Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h')
-rw-r--r--  unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h | 78
1 file changed, 54 insertions(+), 24 deletions(-)
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h b/unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h
index d06f40cd8..8b8fb9235 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h
@@ -37,7 +37,7 @@ struct traits<TensorIndexTupleOp<XprType> > : public traits<XprType>
template<typename XprType>
struct eval<TensorIndexTupleOp<XprType>, Eigen::Dense>
{
- typedef const TensorIndexTupleOp<XprType>& type;
+ typedef const TensorIndexTupleOp<XprType>EIGEN_DEVICE_REF type;
};
template<typename XprType>
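Why the EIGEN_DEVICE_REF change: SYCL device code cannot store references to expression objects, so eval<>::type must decay to a value when compiling for a SYCL device. A minimal sketch of how the macro behaves (the real definition lives in Eigen's TensorMacros.h; treat the exact guard name as an assumption):

    // Sketch only: under SYCL device compilation the reference qualifier
    // vanishes, so `const T EIGEN_DEVICE_REF` is `const T`; everywhere
    // else it expands to `const T&`.
    #if defined(SYCL_DEVICE_ONLY)
      #define EIGEN_DEVICE_REF
    #else
      #define EIGEN_DEVICE_REF &
    #endif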
@@ -82,28 +82,35 @@ struct TensorEvaluator<const TensorIndexTupleOp<ArgType>, Device>
typedef typename TensorEvaluator<ArgType, Device>::Dimensions Dimensions;
static const int NumDims = internal::array_size<Dimensions>::value;
+ typedef StorageMemory<CoeffReturnType, Device> Storage;
+ typedef typename Storage::Type EvaluatorPointerType;
enum {
IsAligned = /*TensorEvaluator<ArgType, Device>::IsAligned*/ false,
PacketAccess = /*TensorEvaluator<ArgType, Device>::PacketAccess*/ false,
BlockAccess = false,
+ PreferBlockAccess = TensorEvaluator<ArgType, Device>::PreferBlockAccess,
Layout = TensorEvaluator<ArgType, Device>::Layout,
CoordAccess = false, // to be implemented
RawAccess = false
};
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
+ //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
+ typedef internal::TensorBlockNotImplemented TensorBlock;
+ //===--------------------------------------------------------------------===//
+
+ EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
: m_impl(op.expression(), device) { }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const {
return m_impl.dimensions();
}
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* /*data*/) {
+ EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType /*data*/) {
m_impl.evalSubExprsIfNeeded(NULL);
return true;
}
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
+ EIGEN_STRONG_INLINE void cleanup() {
m_impl.cleanup();
}
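The new Storage/EvaluatorPointerType typedefs abstract the evaluator's notion of a data pointer. On host (and CUDA) devices EvaluatorPointerType is still a plain CoeffReturnType*; under EIGEN_USE_SYCL it becomes an accessor-like wrapper that can be re-bound inside a kernel. A simplified sketch of the pattern, not Eigen's exact SYCL specialization:

    // Default device: storage is addressed through an ordinary pointer.
    template <typename T, typename Device>
    struct StorageMemory { typedef T* Type; };

    // Under EIGEN_USE_SYCL, a specialization for the SYCL device
    // substitutes a device-accessor wrapper instead, which is why
    // data() and evalSubExprsIfNeeded() below now traffic in
    // EvaluatorPointerType rather than Scalar*.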
@@ -117,7 +124,13 @@ struct TensorEvaluator<const TensorIndexTupleOp<ArgType>, Device>
return m_impl.costPerCoeff(vectorized) + TensorOpCost(0, 0, 1);
}
- EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }
+ EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return NULL; }
+
+#ifdef EIGEN_USE_SYCL
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const {
+ m_impl.bind(cgh);
+ }
+#endif
protected:
TensorEvaluator<ArgType, Device> m_impl;
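bind() exists so that, under SYCL, every evaluator in the expression tree can register its device buffers with the command group before the kernel runs; a wrapper evaluator like this one simply forwards to its child. An illustrative submission loop, not Eigen's actual SYCL executor (`queue`, `range`, `evaluator`, and the kernel body are placeholders):

    queue.submit([&](cl::sycl::handler &cgh) {
      evaluator.bind(cgh);  // each node re-registers its accessors with cgh
      cgh.parallel_for<class ArgMaxKernel>(range, [=](cl::sycl::item<1> item) {
        // ... evaluate coefficients through the bound accessors ...
      });
    });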
@@ -147,7 +160,7 @@ struct traits<TensorTupleReducerOp<ReduceOp, Dims, XprType> > : public traits<XprType>
template<typename ReduceOp, typename Dims, typename XprType>
struct eval<TensorTupleReducerOp<ReduceOp, Dims, XprType>, Eigen::Dense>
{
- typedef const TensorTupleReducerOp<ReduceOp, Dims, XprType>& type;
+ typedef const TensorTupleReducerOp<ReduceOp, Dims, XprType>EIGEN_DEVICE_REF type;
};
template<typename ReduceOp, typename Dims, typename XprType>
@@ -172,7 +185,7 @@ class TensorTupleReducerOp : public TensorBase<TensorTupleReducerOp<ReduceOp, Dims, XprType>, ReadOnlyAccessors>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorTupleReducerOp(const XprType& expr,
const ReduceOp& reduce_op,
- const int return_dim,
+ const Index return_dim,
const Dims& reduce_dims)
: m_xpr(expr), m_reduce_op(reduce_op), m_return_dim(return_dim), m_reduce_dims(reduce_dims) {}
@@ -187,12 +200,12 @@ class TensorTupleReducerOp : public TensorBase<TensorTupleReducerOp<ReduceOp, Dims, XprType>, ReadOnlyAccessors>
const Dims& reduce_dims() const { return m_reduce_dims; }
EIGEN_DEVICE_FUNC
- int return_dim() const { return m_return_dim; }
+ Index return_dim() const { return m_return_dim; }
protected:
typename XprType::Nested m_xpr;
const ReduceOp m_reduce_op;
- const int m_return_dim;
+ const Index m_return_dim;
const Dims m_reduce_dims;
};
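Widening return_dim from int to Index keeps the dimension argument in the same integer type Eigen uses for all tensor indexing (ptrdiff_t by default) and avoids int/Index mixing on 64-bit targets. A typical call site through TensorBase::argmax, which forwards its argument as return_dim:

    #include <unsupported/Eigen/CXX11/Tensor>

    int main() {
      Eigen::Tensor<float, 3> t(2, 3, 7);
      t.setRandom();
      // Reduce away dimension 1: amax(i, k) is the position along
      // dimension 1 at which t(i, ., k) is maximal.
      Eigen::Tensor<Eigen::DenseIndex, 2> amax = t.argmax(1);
      return 0;
    }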
@@ -209,21 +222,29 @@ struct TensorEvaluator<const TensorTupleReducerOp<ReduceOp, Dims, ArgType>, Device>
typedef typename TensorEvaluator<const TensorIndexTupleOp<ArgType> , Device>::Dimensions InputDimensions;
static const int NumDims = internal::array_size<InputDimensions>::value;
typedef array<Index, NumDims> StrideDims;
+ typedef StorageMemory<CoeffReturnType, Device> Storage;
+ typedef typename Storage::Type EvaluatorPointerType;
+ typedef StorageMemory<TupleType, Device> TupleStorageMem;
enum {
- IsAligned = /*TensorEvaluator<ArgType, Device>::IsAligned*/ false,
- PacketAccess = /*TensorEvaluator<ArgType, Device>::PacketAccess*/ false,
- BlockAccess = false,
- Layout = TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexTupleOp<ArgType> >, Device>::Layout,
- CoordAccess = false, // to be implemented
- RawAccess = false
+ IsAligned = /*TensorEvaluator<ArgType, Device>::IsAligned*/ false,
+ PacketAccess = /*TensorEvaluator<ArgType, Device>::PacketAccess*/ false,
+ BlockAccess = false,
+ PreferBlockAccess = TensorEvaluator<ArgType, Device>::PreferBlockAccess,
+ Layout = TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexTupleOp<ArgType> >, Device>::Layout,
+ CoordAccess = false, // to be implemented
+ RawAccess = false
};
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
+ //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
+ typedef internal::TensorBlockNotImplemented TensorBlock;
+ //===--------------------------------------------------------------------===//
+
+ EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
: m_orig_impl(op.expression(), device),
m_impl(op.expression().index_tuples().reduce(op.reduce_dims(), op.reduce_op()), device),
- m_return_dim(op.return_dim()) {
-
+ m_return_dim(op.return_dim())
+ {
gen_strides(m_orig_impl.dimensions(), m_strides);
if (Layout == static_cast<int>(ColMajor)) {
const Index total_size = internal::array_prod(m_orig_impl.dimensions());
@@ -231,19 +252,22 @@ struct TensorEvaluator<const TensorTupleReducerOp<ReduceOp, Dims, ArgType>, Device>
} else {
const Index total_size = internal::array_prod(m_orig_impl.dimensions());
m_stride_mod = (m_return_dim > 0) ? m_strides[m_return_dim - 1] : total_size;
- }
- m_stride_div = m_strides[m_return_dim];
+ }
+ // If m_return_dim is not a valid index, fall back to a divisor of 1:
+ m_stride_div = ((m_return_dim >= 0) &&
+ (m_return_dim < static_cast<Index>(m_strides.size())))
+ ? m_strides[m_return_dim] : 1;
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const {
return m_impl.dimensions();
}
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* /*data*/) {
+ EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType /*data*/) {
m_impl.evalSubExprsIfNeeded(NULL);
return true;
}
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
+ EIGEN_STRONG_INLINE void cleanup() {
m_impl.cleanup();
}
@@ -252,7 +276,13 @@ struct TensorEvaluator<const TensorTupleReducerOp<ReduceOp, Dims, ArgType>, Device>
return (m_return_dim < 0) ? v.first : (v.first % m_stride_mod) / m_stride_div;
}
- EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }
+ EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return NULL; }
+#ifdef EIGEN_USE_SYCL
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const {
+ m_impl.bind(cgh);
+ m_orig_impl.bind(cgh);
+ }
+#endif
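The coeff() computation above recovers the coordinate along return_dim from the flat index stored in the reduced tuple: with column-major strides strides[k] = dims[0] * ... * dims[k-1], the coordinate along dimension k of a flat offset i is (i % strides[k+1]) / strides[k], i.e. (i % m_stride_mod) / m_stride_div. A self-contained check of the same arithmetic (standalone example, not Eigen code):

    #include <cassert>
    #include <cstddef>

    int main() {
      // Column-major 2x3x4 tensor: strides are 1, 2, 6.
      const std::ptrdiff_t strides[3] = {1, 2, 6};
      const int return_dim = 1;
      const std::ptrdiff_t stride_div = strides[return_dim];      // 2
      const std::ptrdiff_t stride_mod = strides[return_dim + 1];  // 6
      // Flat offset of element (1, 2, 3): 1*1 + 2*2 + 3*6 = 23.
      const std::ptrdiff_t flat = 1 * strides[0] + 2 * strides[1] + 3 * strides[2];
      assert((flat % stride_mod) / stride_div == 2);  // recovers coordinate 2
      return 0;
    }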
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
costPerCoeff(bool vectorized) const {
@@ -288,7 +318,7 @@ struct TensorEvaluator<const TensorTupleReducerOp<ReduceOp, Dims, ArgType>, Device>
protected:
TensorEvaluator<const TensorIndexTupleOp<ArgType>, Device> m_orig_impl;
TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexTupleOp<ArgType> >, Device> m_impl;
- const int m_return_dim;
+ const Index m_return_dim;
StrideDims m_strides;
Index m_stride_mod;
Index m_stride_div;