diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorSyclExprConstructor.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorSyclExprConstructor.h | 239 |
1 files changed, 239 insertions, 0 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExprConstructor.h b/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExprConstructor.h new file mode 100644 index 000000000..7ed3a3a56 --- /dev/null +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExprConstructor.h @@ -0,0 +1,239 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Mehdi Goli Codeplay Software Ltd. +// Ralph Potter Codeplay Software Ltd. +// Luke Iwanski Codeplay Software Ltd. +// Contact: <eigen@codeplay.com> +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +/***************************************************************** + * TensorSyclExprConstructor.h + * + * \brief: + * This file re-create an expression on the SYCL device in order + * to use the original tensor evaluator. + * +*****************************************************************/ + +#ifndef UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_EXPR_CONSTRUCTOR_HPP +#define UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_EXPR_CONSTRUCTOR_HPP + +namespace Eigen { +namespace TensorSycl { +namespace internal { +/// this class is used by EvalToOp in order to create an lhs expression which is +/// a pointer from an accessor on device-only buffer +template <typename PtrType, size_t N, typename... Params> +struct EvalToLHSConstructor { + PtrType expr; + EvalToLHSConstructor(const utility::tuple::Tuple<Params...> &t): expr((&(*(utility::tuple::get<N>(t).get_pointer())))) {} +}; + +/// \struct ExprConstructor is used to reconstruct the expression on the device and +/// recreate the expression with MakeGlobalPointer containing the device address +/// space for the TensorMap pointers used in eval function. +/// It receives the original expression type, the functor of the node, the tuple +/// of accessors, and the device expression type to re-instantiate the +/// expression tree for the device +template <typename OrigExpr, typename IndexExpr, typename... Params> +struct ExprConstructor; + +/// specialisation of the \ref ExprConstructor struct when the node type is +/// TensorMap +#define TENSORMAP(CVQual)\ +template <typename Scalar_, int Options_, int Options2_, int Options3_, int NumIndices_, typename IndexType_,\ +template <class> class MakePointer_, size_t N, typename... Params>\ +struct ExprConstructor< CVQual TensorMap<Tensor<Scalar_, NumIndices_, Options_, IndexType_>, Options2_, MakeGlobalPointer>,\ +CVQual PlaceHolder<CVQual TensorMap<Tensor<Scalar_, NumIndices_, Options_, IndexType_>, Options3_, MakePointer_>, N>, Params...>{\ + typedef CVQual TensorMap<Tensor<Scalar_, NumIndices_, Options_, IndexType_>, Options2_, MakeGlobalPointer> Type;\ + Type expr;\ + template <typename FuncDetector>\ + ExprConstructor(FuncDetector &fd, const utility::tuple::Tuple<Params...> &t)\ + : expr(Type((&(*(utility::tuple::get<N>(t).get_pointer()))), fd.dimensions())) {}\ +}; + +TENSORMAP(const) +TENSORMAP() +#undef TENSORMAP + +#define UNARYCATEGORY(CVQual)\ +template <template<class, class> class UnaryCategory, typename OP, typename OrigRHSExpr, typename RHSExpr, typename... Params>\ +struct ExprConstructor<CVQual UnaryCategory<OP, OrigRHSExpr>, CVQual UnaryCategory<OP, RHSExpr>, Params...> {\ + typedef ExprConstructor<OrigRHSExpr, RHSExpr, Params...> my_type;\ + my_type rhsExpr;\ + typedef CVQual UnaryCategory<OP, typename my_type::Type> Type;\ + Type expr;\ + template <typename FuncDetector>\ + ExprConstructor(FuncDetector &funcD, const utility::tuple::Tuple<Params...> &t)\ + : rhsExpr(funcD.rhsExpr, t), expr(rhsExpr.expr, funcD.func) {}\ +}; + +UNARYCATEGORY(const) +UNARYCATEGORY() +#undef UNARYCATEGORY + +/// specialisation of the \ref ExprConstructor struct when the node type is +/// TensorBinaryOp +#define BINARYCATEGORY(CVQual)\ +template <template<class, class, class> class BinaryCategory, typename OP, typename OrigLHSExpr, typename OrigRHSExpr, typename LHSExpr,\ +typename RHSExpr, typename... Params>\ +struct ExprConstructor<CVQual BinaryCategory<OP, OrigLHSExpr, OrigRHSExpr>, CVQual BinaryCategory<OP, LHSExpr, RHSExpr>, Params...> {\ + typedef ExprConstructor<OrigLHSExpr, LHSExpr, Params...> my_left_type;\ + typedef ExprConstructor<OrigRHSExpr, RHSExpr, Params...> my_right_type;\ + typedef CVQual BinaryCategory<OP, typename my_left_type::Type, typename my_right_type::Type> Type;\ + my_left_type lhsExpr;\ + my_right_type rhsExpr;\ + Type expr;\ + template <typename FuncDetector>\ + ExprConstructor(FuncDetector &funcD, const utility::tuple::Tuple<Params...> &t)\ + : lhsExpr(funcD.lhsExpr, t),rhsExpr(funcD.rhsExpr, t), expr(lhsExpr.expr, rhsExpr.expr, funcD.func) {}\ +}; + +BINARYCATEGORY(const) +BINARYCATEGORY() +#undef BINARYCATEGORY + +/// specialisation of the \ref ExprConstructor struct when the node type is +/// TensorCwiseTernaryOp +#define TERNARYCATEGORY(CVQual)\ +template <template <class, class, class, class> class TernaryCategory, typename OP, typename OrigArg1Expr, typename OrigArg2Expr,typename OrigArg3Expr,\ +typename Arg1Expr, typename Arg2Expr, typename Arg3Expr, typename... Params>\ +struct ExprConstructor<CVQual TernaryCategory<OP, OrigArg1Expr, OrigArg2Expr, OrigArg3Expr>, CVQual TernaryCategory<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Params...> {\ + typedef ExprConstructor<OrigArg1Expr, Arg1Expr, Params...> my_arg1_type;\ + typedef ExprConstructor<OrigArg2Expr, Arg2Expr, Params...> my_arg2_type;\ + typedef ExprConstructor<OrigArg3Expr, Arg3Expr, Params...> my_arg3_type;\ + typedef CVQual TernaryCategory<OP, typename my_arg1_type::Type, typename my_arg2_type::Type, typename my_arg3_type::Type> Type;\ + my_arg1_type arg1Expr;\ + my_arg2_type arg2Expr;\ + my_arg3_type arg3Expr;\ + Type expr;\ + template <typename FuncDetector>\ + ExprConstructor(FuncDetector &funcD,const utility::tuple::Tuple<Params...> &t)\ + : arg1Expr(funcD.arg1Expr, t), arg2Expr(funcD.arg2Expr, t), arg3Expr(funcD.arg3Expr, t), expr(arg1Expr.expr, arg2Expr.expr, arg3Expr.expr, funcD.func) {}\ +}; + +TERNARYCATEGORY(const) +TERNARYCATEGORY() +#undef TERNARYCATEGORY + +/// specialisation of the \ref ExprConstructor struct when the node type is +/// TensorCwiseSelectOp +#define SELECTOP(CVQual)\ +template <typename OrigIfExpr, typename OrigThenExpr, typename OrigElseExpr, typename IfExpr, typename ThenExpr, typename ElseExpr, typename... Params>\ +struct ExprConstructor< CVQual TensorSelectOp<OrigIfExpr, OrigThenExpr, OrigElseExpr>, CVQual TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Params...> {\ + typedef ExprConstructor<OrigIfExpr, IfExpr, Params...> my_if_type;\ + typedef ExprConstructor<OrigThenExpr, ThenExpr, Params...> my_then_type;\ + typedef ExprConstructor<OrigElseExpr, ElseExpr, Params...> my_else_type;\ + typedef CVQual TensorSelectOp<typename my_if_type::Type, typename my_then_type::Type, typename my_else_type::Type> Type;\ + my_if_type ifExpr;\ + my_then_type thenExpr;\ + my_else_type elseExpr;\ + Type expr;\ + template <typename FuncDetector>\ + ExprConstructor(FuncDetector &funcD, const utility::tuple::Tuple<Params...> &t)\ + : ifExpr(funcD.ifExpr, t), thenExpr(funcD.thenExpr, t), elseExpr(funcD.elseExpr, t), expr(ifExpr.expr, thenExpr.expr, elseExpr.expr) {}\ +}; + +SELECTOP(const) +SELECTOP() +#undef SELECTOP + +/// specialisation of the \ref ExprConstructor struct when the node type is +/// const TensorAssignOp +#define ASSIGN(CVQual)\ +template <typename OrigLHSExpr, typename OrigRHSExpr, typename LHSExpr, typename RHSExpr, typename... Params>\ +struct ExprConstructor<CVQual TensorAssignOp<OrigLHSExpr, OrigRHSExpr>, CVQual TensorAssignOp<LHSExpr, RHSExpr>, Params...> {\ + typedef ExprConstructor<OrigLHSExpr, LHSExpr, Params...> my_left_type;\ + typedef ExprConstructor<OrigRHSExpr, RHSExpr, Params...> my_right_type;\ + typedef CVQual TensorAssignOp<typename my_left_type::Type, typename my_right_type::Type> Type;\ + my_left_type lhsExpr;\ + my_right_type rhsExpr;\ + Type expr;\ + template <typename FuncDetector>\ + ExprConstructor(FuncDetector &funcD, const utility::tuple::Tuple<Params...> &t)\ + : lhsExpr(funcD.lhsExpr, t), rhsExpr(funcD.rhsExpr, t), expr(lhsExpr.expr, rhsExpr.expr) {}\ + }; + + ASSIGN(const) + ASSIGN() + #undef ASSIGN +/// specialisation of the \ref ExprConstructor struct when the node type is +/// TensorEvalToOp +#define EVALTO(CVQual)\ +template <typename OrigExpr, typename Expr, typename... Params>\ +struct ExprConstructor<CVQual TensorEvalToOp<OrigExpr, MakeGlobalPointer>, CVQual TensorEvalToOp<Expr>, Params...> {\ + typedef ExprConstructor<OrigExpr, Expr, Params...> my_expr_type;\ + typedef typename TensorEvalToOp<OrigExpr, MakeGlobalPointer>::PointerType my_buffer_type;\ + typedef CVQual TensorEvalToOp<typename my_expr_type::Type, MakeGlobalPointer> Type;\ + my_expr_type nestedExpression;\ + EvalToLHSConstructor<my_buffer_type, 0, Params...> buffer;\ + Type expr;\ + template <typename FuncDetector>\ + ExprConstructor(FuncDetector &funcD, const utility::tuple::Tuple<Params...> &t)\ + : nestedExpression(funcD.rhsExpr, t), buffer(t), expr(buffer.expr, nestedExpression.expr) {}\ +}; + +EVALTO(const) +EVALTO() +#undef EVALTO + +/// specialisation of the \ref ExprConstructor struct when the node type is +/// TensorForcedEvalOp +#define FORCEDEVAL(CVQual)\ +template <typename OrigExpr, typename DevExpr, size_t N, typename... Params>\ +struct ExprConstructor<CVQual TensorForcedEvalOp<OrigExpr, MakeGlobalPointer>,\ +CVQual PlaceHolder<CVQual TensorForcedEvalOp<DevExpr>, N>, Params...> {\ + typedef CVQual TensorMap<Tensor<typename TensorForcedEvalOp<DevExpr, MakeGlobalPointer>::Scalar,\ + TensorForcedEvalOp<DevExpr, MakeGlobalPointer>::NumDimensions, 0, typename TensorForcedEvalOp<DevExpr>::Index>, 0, MakeGlobalPointer> Type;\ + Type expr;\ + template <typename FuncDetector>\ + ExprConstructor(FuncDetector &fd, const utility::tuple::Tuple<Params...> &t)\ + : expr(Type((&(*(utility::tuple::get<N>(t).get_pointer()))), fd.dimensions())) {}\ +}; + +FORCEDEVAL(const) +FORCEDEVAL() +#undef FORCEDEVAL + +template <bool Conds, size_t X , size_t Y > struct ValueCondition { + static const size_t Res =X; +}; +template<size_t X, size_t Y> struct ValueCondition<false, X , Y> { + static const size_t Res =Y; +}; + +/// specialisation of the \ref ExprConstructor struct when the node type is TensorReductionOp +#define SYCLREDUCTIONEXPR(CVQual)\ +template <typename OP, typename Dim, typename OrigExpr, typename DevExpr, size_t N, typename... Params>\ +struct ExprConstructor<CVQual TensorReductionOp<OP, Dim, OrigExpr, MakeGlobalPointer>,\ +CVQual PlaceHolder<CVQual TensorReductionOp<OP, Dim, DevExpr>, N>, Params...> {\ + static const size_t NumIndices= ValueCondition< TensorReductionOp<OP, Dim, DevExpr, MakeGlobalPointer>::NumDimensions==0, 1, TensorReductionOp<OP, Dim, DevExpr, MakeGlobalPointer>::NumDimensions >::Res;\ + typedef CVQual TensorMap<Tensor<typename TensorReductionOp<OP, Dim, DevExpr, MakeGlobalPointer>::Scalar,\ + NumIndices, 0, typename TensorReductionOp<OP, Dim, DevExpr>::Index>, 0, MakeGlobalPointer> Type;\ + Type expr;\ + template <typename FuncDetector>\ + ExprConstructor(FuncDetector &fd, const utility::tuple::Tuple<Params...> &t)\ + : expr(Type((&(*(utility::tuple::get<N>(t).get_pointer()))), fd.dimensions())) {}\ +}; + +SYCLREDUCTIONEXPR(const) +SYCLREDUCTIONEXPR() +#undef SYCLREDUCTIONEXPR + +/// template deduction for \ref ExprConstructor struct +template <typename OrigExpr, typename IndexExpr, typename FuncD, typename... Params> +auto createDeviceExpression(FuncD &funcD, const utility::tuple::Tuple<Params...> &t) + -> decltype(ExprConstructor<OrigExpr, IndexExpr, Params...>(funcD, t)) { + return ExprConstructor<OrigExpr, IndexExpr, Params...>(funcD, t); +} + +} /// namespace TensorSycl +} /// namespace internal +} /// namespace Eigen + + +#endif // UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_EXPR_CONSTRUCTOR_HPP |