aboutsummaryrefslogtreecommitdiff
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractFunctors.h
diff options
context:
space:
mode:
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractFunctors.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractFunctors.h177
1 files changed, 177 insertions, 0 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractFunctors.h b/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractFunctors.h
new file mode 100644
index 000000000..427125343
--- /dev/null
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractFunctors.h
@@ -0,0 +1,177 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Mehdi Goli Codeplay Software Ltd.
+// Ralph Potter Codeplay Software Ltd.
+// Luke Iwanski Codeplay Software Ltd.
+// Contact: <eigen@codeplay.com>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+/*****************************************************************
+ * TensorSyclextractFunctors.h
+ *
+ * \brief:
+ * Used to extract all the functors allocated to each node of the expression
+*tree.
+ *
+*****************************************************************/
+
+#ifndef UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_EXTRACT_FUNCTORS_HPP
+#define UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_EXTRACT_FUNCTORS_HPP
+
+namespace Eigen {
+namespace TensorSycl {
+namespace internal {
+/// \struct FunctorExtractor: This struct is used to extract the functors
+/// constructed on
+/// the host-side, to pack them and reuse them in reconstruction of the
+/// expression on the device.
+/// We have to do that as in Eigen the functors are not stateless so we cannot
+/// re-instantiate them on the device.
+/// We have to pass instantiated functors to the device.
+// This struct is used for leafNode (TensorMap) and nodes behaving like leafNode (TensorForcedEval).
+template <typename Evaluator> struct FunctorExtractor{
+ typedef typename Evaluator::Dimensions Dimensions;
+ const Dimensions m_dimensions;
+ const Dimensions& dimensions() const { return m_dimensions; }
+ FunctorExtractor(const Evaluator& expr)
+ : m_dimensions(expr.dimensions()) {}
+
+};
+
+/// specialisation of the \ref FunctorExtractor struct when the node type is
+/// const TensorCwiseNullaryOp, const TensorCwiseUnaryOp, and const TensorBroadcastingOp
+template <template <class, class> class UnaryCategory, typename OP, typename RHSExpr, typename Dev>
+struct FunctorExtractor<TensorEvaluator<const UnaryCategory<OP, RHSExpr>, Dev> > {
+ FunctorExtractor<TensorEvaluator<RHSExpr, Dev> > rhsExpr;
+ OP func;
+ FunctorExtractor(const TensorEvaluator<const UnaryCategory<OP, RHSExpr>, Dev>& expr)
+ : rhsExpr(expr.impl()), func(expr.functor()) {}
+};
+/// specialisation of the \ref FunctorExtractor struct when the node type is
+/// TensorCwiseNullaryOp, TensorCwiseUnaryOp, and TensorBroadcastingOp
+template <template <class, class> class UnaryCategory, typename OP, typename RHSExpr, typename Dev>
+struct FunctorExtractor<TensorEvaluator<UnaryCategory<OP, RHSExpr>, Dev> >
+: FunctorExtractor<TensorEvaluator<const UnaryCategory<OP, RHSExpr>, Dev> >{};
+
+/// specialisation of the \ref FunctorExtractor struct when the node type is
+/// const TensorCwiseBinaryOp
+template <template<class, class, class> class BinaryCategory, typename OP, typename LHSExpr, typename RHSExpr, typename Dev>
+struct FunctorExtractor<TensorEvaluator<const BinaryCategory<OP, LHSExpr, RHSExpr>, Dev> > {
+ FunctorExtractor<TensorEvaluator<LHSExpr, Dev> > lhsExpr;
+ FunctorExtractor<TensorEvaluator<RHSExpr, Dev> > rhsExpr;
+ OP func;
+ FunctorExtractor(const TensorEvaluator<const BinaryCategory<OP, LHSExpr, RHSExpr>, Dev>& expr)
+ : lhsExpr(expr.left_impl()),rhsExpr(expr.right_impl()),func(expr.functor()) {}
+};
+
+/// specialisation of the \ref FunctorExtractor struct when the node type is
+/// const TensorCwiseBinaryOp
+template <template <class, class, class> class BinaryCategory, typename OP, typename LHSExpr, typename RHSExpr, typename Dev>
+struct FunctorExtractor<TensorEvaluator<BinaryCategory<OP, LHSExpr, RHSExpr>, Dev> >
+: FunctorExtractor<TensorEvaluator<const BinaryCategory<OP, LHSExpr, RHSExpr>, Dev> >{};
+
+/// specialisation of the \ref FunctorExtractor struct when the node type is
+/// const TensorCwiseTernaryOp
+template <template <class, class, class, class> class TernaryCategory, typename OP, typename Arg1Expr, typename Arg2Expr, typename Arg3Expr,typename Dev>
+struct FunctorExtractor<TensorEvaluator<const TernaryCategory<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Dev> > {
+ FunctorExtractor<TensorEvaluator<Arg1Expr, Dev> > arg1Expr;
+ FunctorExtractor<TensorEvaluator<Arg2Expr, Dev> > arg2Expr;
+ FunctorExtractor<TensorEvaluator<Arg3Expr, Dev> > arg3Expr;
+ OP func;
+ FunctorExtractor(const TensorEvaluator<const TernaryCategory<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Dev>& expr)
+ : arg1Expr(expr.arg1Impl()), arg2Expr(expr.arg2Impl()), arg3Expr(expr.arg3Impl()), func(expr.functor()) {}
+};
+
+/// specialisation of the \ref FunctorExtractor struct when the node type is
+/// TensorCwiseTernaryOp
+template <template <class, class, class, class> class TernaryCategory, typename OP, typename Arg1Expr, typename Arg2Expr, typename Arg3Expr, typename Dev>
+struct FunctorExtractor<TensorEvaluator< TernaryCategory<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Dev> >
+:FunctorExtractor<TensorEvaluator<const TernaryCategory<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Dev> >{};
+
+/// specialisation of the \ref FunctorExtractor struct when the node type is
+/// const TensorCwiseSelectOp. This is an specialisation without OP so it has to be separated.
+template <typename IfExpr, typename ThenExpr, typename ElseExpr, typename Dev>
+struct FunctorExtractor< TensorEvaluator<const TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Dev> > {
+ FunctorExtractor<TensorEvaluator<IfExpr, Dev> > ifExpr;
+ FunctorExtractor<TensorEvaluator<ThenExpr, Dev> > thenExpr;
+ FunctorExtractor<TensorEvaluator<ElseExpr, Dev> > elseExpr;
+ FunctorExtractor(const TensorEvaluator<const TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Dev>& expr)
+ : ifExpr(expr.cond_impl()), thenExpr(expr.then_impl()), elseExpr(expr.else_impl()) {}
+};
+
+/// specialisation of the \ref FunctorExtractor struct when the node type is
+/// TensorCwiseSelectOp. This is an specialisation without OP so it has to be separated
+template <typename IfExpr, typename ThenExpr, typename ElseExpr, typename Dev>
+struct FunctorExtractor<TensorEvaluator<TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Dev> >
+:FunctorExtractor< TensorEvaluator<const TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Dev> > {};
+
+/// specialisation of the \ref FunctorExtractor struct when the node type is
+/// const TensorAssignOp. This is an specialisation without OP so it has to be separated.
+template <typename LHSExpr, typename RHSExpr, typename Dev>
+struct FunctorExtractor<TensorEvaluator<const TensorAssignOp<LHSExpr, RHSExpr>, Dev> > {
+ FunctorExtractor<TensorEvaluator<LHSExpr, Dev> > lhsExpr;
+ FunctorExtractor<TensorEvaluator<RHSExpr, Dev> > rhsExpr;
+ FunctorExtractor(const TensorEvaluator<const TensorAssignOp<LHSExpr, RHSExpr>, Dev>& expr)
+ : lhsExpr(expr.left_impl()), rhsExpr(expr.right_impl()) {}
+};
+
+/// specialisation of the \ref FunctorExtractor struct when the node type is
+/// TensorAssignOp. This is an specialisation without OP so it has to be separated.
+template <typename LHSExpr, typename RHSExpr, typename Dev>
+struct FunctorExtractor<TensorEvaluator<TensorAssignOp<LHSExpr, RHSExpr>, Dev> >
+:FunctorExtractor<TensorEvaluator<const TensorAssignOp<LHSExpr, RHSExpr>, Dev> >{};
+
+
+/// specialisation of the \ref FunctorExtractor struct when the node type is
+/// const TensorEvalToOp, This is an specialisation without OP so it has to be separated.
+template <typename RHSExpr, typename Dev>
+struct FunctorExtractor<TensorEvaluator<const TensorEvalToOp<RHSExpr>, Dev> > {
+ FunctorExtractor<TensorEvaluator<RHSExpr, Dev> > rhsExpr;
+ FunctorExtractor(const TensorEvaluator<const TensorEvalToOp<RHSExpr>, Dev>& expr)
+ : rhsExpr(expr.impl()) {}
+};
+
+/// specialisation of the \ref FunctorExtractor struct when the node type is
+/// TensorEvalToOp. This is a specialisation without OP so it has to be separated.
+template <typename RHSExpr, typename Dev>
+struct FunctorExtractor<TensorEvaluator<TensorEvalToOp<RHSExpr>, Dev> >
+: FunctorExtractor<TensorEvaluator<const TensorEvalToOp<RHSExpr>, Dev> > {};
+
+template<typename Dim, size_t NumOutputDim> struct DimConstr {
+template<typename InDim>
+ static inline Dim getDim(InDim dims ) {return dims;}
+};
+
+template<typename Dim> struct DimConstr<Dim, 0> {
+ template<typename InDim>
+ static inline Dim getDim(InDim dims ) {return Dim(dims.TotalSize());}
+};
+
+template<typename Op, typename Dims, typename ArgType, template <class> class MakePointer_, typename Device>
+struct FunctorExtractor<TensorEvaluator<const TensorReductionOp<Op, Dims, ArgType, MakePointer_>, Device>>{
+ typedef TensorEvaluator<const TensorReductionOp<Op, Dims, ArgType, MakePointer_>, Device> Evaluator;
+ typedef typename Eigen::internal::conditional<Evaluator::NumOutputDims==0, DSizes<typename Evaluator::Index, 1>, typename Evaluator::Dimensions >::type Dimensions;
+ const Dimensions m_dimensions;
+ const Dimensions& dimensions() const { return m_dimensions; }
+ FunctorExtractor(const TensorEvaluator<const TensorReductionOp<Op, Dims, ArgType, MakePointer_>, Device>& expr)
+ : m_dimensions(DimConstr<Dimensions, Evaluator::NumOutputDims>::getDim(expr.dimensions())) {}
+};
+
+
+template<typename Op, typename Dims, typename ArgType, template <class> class MakePointer_, typename Device>
+struct FunctorExtractor<TensorEvaluator<TensorReductionOp<Op, Dims, ArgType, MakePointer_>, Device>>
+: FunctorExtractor<TensorEvaluator<const TensorReductionOp<Op, Dims, ArgType, MakePointer_>, Device>>{};
+/// template deduction function for FunctorExtractor
+template <typename Evaluator>
+auto inline extractFunctors(const Evaluator& evaluator)-> FunctorExtractor<Evaluator> {
+ return FunctorExtractor<Evaluator>(evaluator);
+}
+} // namespace internal
+} // namespace TensorSycl
+} // namespace Eigen
+
+#endif // UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_EXTRACT_FUNCTORS_HPP