aboutsummaryrefslogtreecommitdiff
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorSyclPlaceHolderExpr.h
blob: d4c250c6dd35be5b444c0ff8b1308cb21bc1c729 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Mehdi Goli    Codeplay Software Ltd.
// Ralph Potter  Codeplay Software Ltd.
// Luke Iwanski  Codeplay Software Ltd.
// Contact: <eigen@codeplay.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

/*****************************************************************
 * TensorSyclPlaceHolderExpr.h
 *
 * \brief:
 *  Specialisations of PlaceHolderExpression for each supported operation
 *  type. Together they rewrite an expression type so that its leaf nodes
 *  are replaced with PlaceHolder types.
 *
*****************************************************************/

#ifndef UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_PLACEHOLDER_EXPR_HPP
#define UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_PLACEHOLDER_EXPR_HPP

namespace Eigen {
namespace TensorSycl {
namespace internal {

/// \struct PlaceHolder
/// \brief PlaceHolder is used to replace the \ref TensorMap in the expression
/// tree.
/// PlaceHolder contains the order of the leaf node in the expression tree.
template <typename Scalar, size_t N>
struct PlaceHolder {
  static constexpr size_t I = N;
  typedef Scalar Type;
};

/// \struct PlaceHolderExpression
/// \brief Builds the placeholder expression type for \p Expr: a copy of the
/// expression type in which the leaf nodes (such as TensorMap) have been
/// replaced with \ref PlaceHolder types, exposed as \c Type.
///
/// \tparam Expr the expression type to rewrite.
/// \tparam N    the index assigned to the right-most leaf of \p Expr; leaves
///              further to the left receive smaller indices.
///
/// Only partial specialisations are defined; the primary template is
/// intentionally left incomplete, so an unsupported node type fails to
/// compile.
template <typename Expr, size_t N>
struct PlaceHolderExpression;

/// \struct CalculateIndex
/// \brief Rewrites the children of an expression node into placeholder
/// expressions and collects them in a \c utility::tuple::Tuple (\c ArgsTuple).
///
/// \tparam N    the index assigned to the right-most leaf among \p Args.
/// \tparam Args the child expression types (one to three are supported).
///
/// Indices are handed out right-to-left. The last child receives \p N
/// directly. Each earlier child is shifted down by the LeafCount of every
/// child that follows it. Assuming LeafCount matches the number of
/// placeholders a child produces, this gives each child a disjoint index range.
template <size_t N, typename... Args>
struct CalculateIndex;

/// Single child: it owns index \p N.
template <size_t N, typename Arg>
struct CalculateIndex<N, Arg> {
  using ArgType = typename PlaceHolderExpression<Arg, N>::Type;
  using ArgsTuple = utility::tuple::Tuple<ArgType>;
};

/// Two children: the left child's indices end just below the right child's.
template <size_t N, typename Arg1, typename Arg2>
struct CalculateIndex<N, Arg1, Arg2> {
  static const size_t Arg2LeafCount = LeafCount<Arg2>::Count;
  using Arg2Type = typename PlaceHolderExpression<Arg2, N>::Type;
  using Arg1Type = typename PlaceHolderExpression<Arg1, N - Arg2LeafCount>::Type;
  using ArgsTuple = utility::tuple::Tuple<Arg1Type, Arg2Type>;
};

/// Three children: each child starts below the leaves of those to its right.
template <size_t N, typename Arg1, typename Arg2, typename Arg3>
struct CalculateIndex<N, Arg1, Arg2, Arg3> {
  static const size_t Arg3LeafCount = LeafCount<Arg3>::Count;
  static const size_t Arg2LeafCount = LeafCount<Arg2>::Count;
  using Arg3Type = typename PlaceHolderExpression<Arg3, N>::Type;
  using Arg2Type = typename PlaceHolderExpression<Arg2, N - Arg3LeafCount>::Type;
  using Arg1Type = typename PlaceHolderExpression<Arg1, N - Arg3LeafCount - Arg2LeafCount>::Type;
  using ArgsTuple = utility::tuple::Tuple<Arg1Type, Arg2Type, Arg3Type>;
};

/// \struct CategoryHelper
/// \brief Re-applies the node template \p Category to the functor \p OP and
/// to the element types of the tuple \p TPL, exposing the result as \c Type.
///
/// When \p OP is \c NoOP the functor slot is dropped. This lets node
/// templates that take no functor (e.g. TensorSelectOp, TensorAssignOp) be
/// rebuilt the same way.
template <template <class...> class Category, class OP, class TPL>
struct CategoryHelper;

template <template <class...> class Category, class OP, class... T>
struct CategoryHelper<Category, OP, utility::tuple::Tuple<T...> > {
  using Type = Category<OP, T...>;
};

template <template <class...> class Category, class... T>
struct CategoryHelper<Category, NoOP, utility::tuple::Tuple<T...> > {
  using Type = Category<T...>;
};

/// specialisation of the \ref PlaceHolderExpression for generic nodes of the
/// form Category<OP, SubExpr...>, e.g. TensorCwiseNullaryOp,
/// TensorCwiseUnaryOp, TensorBroadcastingOp, TensorCwiseBinaryOp and
/// TensorCwiseTernaryOp. The children are rewritten via CalculateIndex, and
/// the node template is re-applied to the same OP and the rewritten children.
/// Instantiated once for const-qualified and once for unqualified nodes.
#define OPEXPRCATEGORY(CVQual)\
template <template <class, class... > class Category, typename OP, typename... SubExpr, size_t N>\
struct PlaceHolderExpression<CVQual Category<OP, SubExpr...>, N> {\
  using Type = CVQual typename CategoryHelper<Category, OP,\
      typename CalculateIndex<N, SubExpr...>::ArgsTuple>::Type;\
};

OPEXPRCATEGORY(const)
OPEXPRCATEGORY()
#undef OPEXPRCATEGORY

/// specialisation of the \ref PlaceHolderExpression when the node is
/// TensorSelectOp. Its three children (IfExpr, ThenExpr, ElseExpr) are
/// rewritten via CalculateIndex. The node carries no functor, so it is rebuilt
/// with NoOP. Being more specialised, this is chosen over the generic
/// Category<OP, SubExpr...> specialisation, which would otherwise treat
/// IfExpr as the functor.
#define SELECTEXPR(CVQual)\
template <typename IfExpr, typename ThenExpr, typename ElseExpr, size_t N>\
struct PlaceHolderExpression<CVQual TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, N> {\
  typedef CVQual typename CategoryHelper<TensorSelectOp, NoOP, typename CalculateIndex<N, IfExpr, ThenExpr, ElseExpr>::ArgsTuple>::Type Type;\
};

SELECTEXPR(const)
SELECTEXPR()
#undef SELECTEXPR

/// specialisation of the \ref PlaceHolderExpression when the node is
/// TensorAssignOp. Both sides are rewritten via CalculateIndex, with the
/// right-hand side receiving the highest index N. The assignment node is then
/// rebuilt without a functor (NoOP).
#define ASSIGNEXPR(CVQual)\
template <typename LHSExpr, typename RHSExpr, size_t N>\
struct PlaceHolderExpression<CVQual TensorAssignOp<LHSExpr, RHSExpr>, N> {\
  using Type = CVQual typename CategoryHelper<TensorAssignOp, NoOP,\
      typename CalculateIndex<N, LHSExpr, RHSExpr>::ArgsTuple>::Type;\
};

ASSIGNEXPR(const)
ASSIGNEXPR()
#undef ASSIGNEXPR

/// specialisation of the \ref PlaceHolderExpression when the node is
/// TensorMap
#define TENSORMAPEXPR(CVQual)\
template <typename Scalar_, int Options_, int Options2_, int NumIndices_, typename IndexType_, template <class> class MakePointer_, size_t N>\
struct PlaceHolderExpression< CVQual TensorMap< Tensor<Scalar_, NumIndices_, Options_, IndexType_>, Options2_, MakePointer_>, N> {\
  typedef CVQual PlaceHolder<CVQual TensorMap<Tensor<Scalar_, NumIndices_, Options_, IndexType_>, Options2_, MakePointer_>, N> Type;\
};

TENSORMAPEXPR(const)
TENSORMAPEXPR()
#undef TENSORMAPEXPR

/// specialisation of the \ref PlaceHolderExpression when the node is
/// TensorForcedEvalOp
#define FORCEDEVAL(CVQual)\
template <typename Expr, size_t N>\
struct PlaceHolderExpression<CVQual TensorForcedEvalOp<Expr>, N> {\
  typedef CVQual PlaceHolder<CVQual TensorForcedEvalOp<Expr>, N> Type;\
};

FORCEDEVAL(const)
FORCEDEVAL()
#undef FORCEDEVAL

/// specialisation of the \ref PlaceHolderExpression when the node is
/// TensorEvalToOp
#define EVALTO(CVQual)\
template <typename Expr, size_t N>\
struct PlaceHolderExpression<CVQual TensorEvalToOp<Expr>, N> {\
  typedef CVQual TensorEvalToOp<typename CalculateIndex <N, Expr>::ArgType> Type;\
};

EVALTO(const)
EVALTO()
#undef EVALTO


/// specialisation of the \ref PlaceHolderExpression when the node is
/// TensorReductionOp
#define SYCLREDUCTION(CVQual)\
template <typename OP, typename Dims, typename Expr, size_t N>\
struct PlaceHolderExpression<CVQual TensorReductionOp<OP, Dims, Expr>, N>{\
  typedef CVQual PlaceHolder<CVQual TensorReductionOp<OP, Dims,Expr>, N> Type;\
};
SYCLREDUCTION(const)
SYCLREDUCTION()
#undef SYCLREDUCTION

/// \struct createPlaceHolderExpression
/// \brief Entry point: builds the placeholder expression (\c Type) for a whole
/// expression tree \p Expr.
///
/// The right-most leaf receives index \c TotalLeaves - 1. Assuming LeafCount
/// matches the number of placeholders produced, the leaves are numbered
/// 0 .. TotalLeaves - 1.
/// NOTE(review): if LeafCount<Expr>::Count were 0, the starting index would
/// wrap around (size_t arithmetic). Presumably every supported expression has
/// at least one leaf; confirm.
template <typename Expr>
struct createPlaceHolderExpression {
  static const size_t TotalLeaves = LeafCount<Expr>::Count;
  using Type = typename PlaceHolderExpression<Expr, TotalLeaves - 1>::Type;
};

}  // internal
}  // TensorSycl
}  // namespace Eigen

#endif  // UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_PLACEHOLDER_EXPR_HPP