aboutsummaryrefslogtreecommitdiff
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
diff options
context:
space:
mode:
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorBase.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorBase.h274
1 files changed, 220 insertions, 54 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
index 7a45a5cf4..35b6458e5 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
@@ -20,9 +20,11 @@ namespace Eigen {
* \brief The tensor base class.
*
* This class is the common parent of the Tensor and TensorMap class, thus
- * making it possible to use either class interchangably in expressions.
+ * making it possible to use either class interchangeably in expressions.
*/
-
+#ifndef EIGEN_PARSED_BY_DOXYGEN
+// FIXME Doxygen does not like the inheritance with different template parameters
+// Since there is no doxygen documentation inside, we disable it for now
template<typename Derived>
class TensorBase<Derived, ReadOnlyAccessors>
{
@@ -133,6 +135,78 @@ class TensorBase<Derived, ReadOnlyAccessors>
return unaryExpr(internal::scalar_digamma_op<Scalar>());
}
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_bessel_i0_op<Scalar>, const Derived>
+ bessel_i0() const {
+ return unaryExpr(internal::scalar_bessel_i0_op<Scalar>());
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_bessel_i0e_op<Scalar>, const Derived>
+ bessel_i0e() const {
+ return unaryExpr(internal::scalar_bessel_i0e_op<Scalar>());
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_bessel_i1_op<Scalar>, const Derived>
+ bessel_i1() const {
+ return unaryExpr(internal::scalar_bessel_i1_op<Scalar>());
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_bessel_i1e_op<Scalar>, const Derived>
+ bessel_i1e() const {
+ return unaryExpr(internal::scalar_bessel_i1e_op<Scalar>());
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_bessel_j0_op<Scalar>, const Derived>
+ bessel_j0() const {
+ return unaryExpr(internal::scalar_bessel_j0_op<Scalar>());
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_bessel_y0_op<Scalar>, const Derived>
+ bessel_y0() const {
+ return unaryExpr(internal::scalar_bessel_y0_op<Scalar>());
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_bessel_j1_op<Scalar>, const Derived>
+ bessel_j1() const {
+ return unaryExpr(internal::scalar_bessel_j1_op<Scalar>());
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_bessel_y1_op<Scalar>, const Derived>
+ bessel_y1() const {
+ return unaryExpr(internal::scalar_bessel_y1_op<Scalar>());
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_bessel_k0_op<Scalar>, const Derived>
+ bessel_k0() const {
+ return unaryExpr(internal::scalar_bessel_k0_op<Scalar>());
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_bessel_k0e_op<Scalar>, const Derived>
+ bessel_k0e() const {
+ return unaryExpr(internal::scalar_bessel_k0e_op<Scalar>());
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_bessel_k1_op<Scalar>, const Derived>
+ bessel_k1() const {
+ return unaryExpr(internal::scalar_bessel_k1_op<Scalar>());
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_bessel_k1e_op<Scalar>, const Derived>
+ bessel_k1e() const {
+ return unaryExpr(internal::scalar_bessel_k1e_op<Scalar>());
+ }
+
// igamma(a = this, x = other)
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorCwiseBinaryOp<internal::scalar_igamma_op<Scalar>, const Derived, const OtherDerived>
@@ -140,6 +214,20 @@ class TensorBase<Derived, ReadOnlyAccessors>
return binaryExpr(other.derived(), internal::scalar_igamma_op<Scalar>());
}
+ // igamma_der_a(a = this, x = other)
+ template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ const TensorCwiseBinaryOp<internal::scalar_igamma_der_a_op<Scalar>, const Derived, const OtherDerived>
+ igamma_der_a(const OtherDerived& other) const {
+ return binaryExpr(other.derived(), internal::scalar_igamma_der_a_op<Scalar>());
+ }
+
+ // gamma_sample_der_alpha(alpha = this, sample = other)
+ template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ const TensorCwiseBinaryOp<internal::scalar_gamma_sample_der_alpha_op<Scalar>, const Derived, const OtherDerived>
+ gamma_sample_der_alpha(const OtherDerived& other) const {
+ return binaryExpr(other.derived(), internal::scalar_gamma_sample_der_alpha_op<Scalar>());
+ }
+
// igammac(a = this, x = other)
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorCwiseBinaryOp<internal::scalar_igammac_op<Scalar>, const Derived, const OtherDerived>
@@ -174,9 +262,15 @@ class TensorBase<Derived, ReadOnlyAccessors>
}
EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_sigmoid_op<Scalar>, const Derived>
+ EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_ndtri_op<Scalar>, const Derived>
+ ndtri() const {
+ return unaryExpr(internal::scalar_ndtri_op<Scalar>());
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_logistic_op<Scalar>, const Derived>
sigmoid() const {
- return unaryExpr(internal::scalar_sigmoid_op<Scalar>());
+ return unaryExpr(internal::scalar_logistic_op<Scalar>());
}
EIGEN_DEVICE_FUNC
@@ -186,6 +280,12 @@ class TensorBase<Derived, ReadOnlyAccessors>
}
EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_expm1_op<Scalar>, const Derived>
+ expm1() const {
+ return unaryExpr(internal::scalar_expm1_op<Scalar>());
+ }
+
+ EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_log_op<Scalar>, const Derived>
log() const {
return unaryExpr(internal::scalar_log_op<Scalar>());
@@ -198,15 +298,29 @@ class TensorBase<Derived, ReadOnlyAccessors>
}
EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_log2_op<Scalar>, const Derived>
+ log2() const {
+ return unaryExpr(internal::scalar_log2_op<Scalar>());
+ }
+
+ EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_abs_op<Scalar>, const Derived>
abs() const {
return unaryExpr(internal::scalar_abs_op<Scalar>());
}
EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_conjugate_op<Scalar>, const Derived>
+ EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_clamp_op<Scalar>, const Derived>
+ clip(Scalar min, Scalar max) const {
+ return unaryExpr(internal::scalar_clamp_op<Scalar>(min, max));
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const typename internal::conditional<NumTraits<CoeffReturnType>::IsComplex,
+ TensorCwiseUnaryOp<internal::scalar_conjugate_op<Scalar>, const Derived>,
+ Derived>::type
conjugate() const {
- return unaryExpr(internal::scalar_conjugate_op<Scalar>());
+ return choose(Cond<NumTraits<CoeffReturnType>::IsComplex>(), unaryExpr(internal::scalar_conjugate_op<Scalar>()), derived());
}
EIGEN_DEVICE_FUNC
@@ -287,22 +401,27 @@ class TensorBase<Derived, ReadOnlyAccessors>
return unaryExpr(internal::scalar_mod_op<Scalar>(rhs));
}
+ template <int NanPropagation=PropagateFast>
EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<internal::scalar_max_op<Scalar>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> >
+ EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<internal::scalar_max_op<Scalar,Scalar,NanPropagation>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> >
cwiseMax(Scalar threshold) const {
- return cwiseMax(constant(threshold));
+ return cwiseMax<NanPropagation>(constant(threshold));
}
+ template <int NanPropagation=PropagateFast>
EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<internal::scalar_min_op<Scalar>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> >
+ EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<internal::scalar_min_op<Scalar,Scalar,NanPropagation>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> >
cwiseMin(Scalar threshold) const {
- return cwiseMin(constant(threshold));
+ return cwiseMin<NanPropagation>(constant(threshold));
}
- template <typename NewType> EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE const TensorConversionOp<NewType, const Derived>
+ template<typename NewType>
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const typename internal::conditional<internal::is_same<NewType, CoeffReturnType>::value,
+ Derived,
+ TensorConversionOp<NewType, const Derived> >::type
cast() const {
- return TensorConversionOp<NewType, const Derived>(derived());
+ return choose(Cond<internal::is_same<NewType, CoeffReturnType>::value>(), derived(), TensorConversionOp<NewType, const Derived>(derived()));
}
EIGEN_DEVICE_FUNC
@@ -312,6 +431,12 @@ class TensorBase<Derived, ReadOnlyAccessors>
}
EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_rint_op<Scalar>, const Derived>
+ rint() const {
+ return unaryExpr(internal::scalar_rint_op<Scalar>());
+ }
+
+ EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_ceil_op<Scalar>, const Derived>
ceil() const {
return unaryExpr(internal::scalar_ceil_op<Scalar>());
@@ -355,16 +480,16 @@ class TensorBase<Derived, ReadOnlyAccessors>
return binaryExpr(other.derived(), internal::scalar_quotient_op<Scalar>());
}
- template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
- const TensorCwiseBinaryOp<internal::scalar_max_op<Scalar>, const Derived, const OtherDerived>
+ template<int NaNPropagation=PropagateFast, typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ const TensorCwiseBinaryOp<internal::scalar_max_op<Scalar,Scalar, NaNPropagation>, const Derived, const OtherDerived>
cwiseMax(const OtherDerived& other) const {
- return binaryExpr(other.derived(), internal::scalar_max_op<Scalar>());
+ return binaryExpr(other.derived(), internal::scalar_max_op<Scalar,Scalar, NaNPropagation>());
}
- template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
- const TensorCwiseBinaryOp<internal::scalar_min_op<Scalar>, const Derived, const OtherDerived>
+ template<int NaNPropagation=PropagateFast, typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ const TensorCwiseBinaryOp<internal::scalar_min_op<Scalar,Scalar, NaNPropagation>, const Derived, const OtherDerived>
cwiseMin(const OtherDerived& other) const {
- return binaryExpr(other.derived(), internal::scalar_min_op<Scalar>());
+ return binaryExpr(other.derived(), internal::scalar_min_op<Scalar,Scalar, NaNPropagation>());
}
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
@@ -479,9 +604,15 @@ class TensorBase<Derived, ReadOnlyAccessors>
typedef Eigen::IndexPair<Index> DimensionPair;
template<typename OtherDerived, typename Dimensions> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
- const TensorContractionOp<const Dimensions, const Derived, const OtherDerived>
+ const TensorContractionOp<const Dimensions, const Derived, const OtherDerived, const NoOpOutputKernel>
contract(const OtherDerived& other, const Dimensions& dims) const {
- return TensorContractionOp<const Dimensions, const Derived, const OtherDerived>(derived(), other.derived(), dims);
+ return TensorContractionOp<const Dimensions, const Derived, const OtherDerived, const NoOpOutputKernel>(derived(), other.derived(), dims);
+ }
+
+ template<typename OtherDerived, typename Dimensions, typename OutputKernel> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ const TensorContractionOp<const Dimensions, const Derived, const OtherDerived, const OutputKernel>
+ contract(const OtherDerived& other, const Dimensions& dims, const OutputKernel& output_kernel) const {
+ return TensorContractionOp<const Dimensions, const Derived, const OtherDerived, const OutputKernel>(derived(), other.derived(), dims, output_kernel);
}
// Convolutions.
@@ -494,8 +625,8 @@ class TensorBase<Derived, ReadOnlyAccessors>
// Fourier transforms
template <int FFTDataType, int FFTDirection, typename FFT> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorFFTOp<const FFT, const Derived, FFTDataType, FFTDirection>
- fft(const FFT& fft) const {
- return TensorFFTOp<const FFT, const Derived, FFTDataType, FFTDirection>(derived(), fft);
+ fft(const FFT& dims) const {
+ return TensorFFTOp<const FFT, const Derived, FFTDataType, FFTDirection>(derived(), dims);
}
// Scan.
@@ -557,51 +688,53 @@ class TensorBase<Derived, ReadOnlyAccessors>
return TensorReductionOp<internal::ProdReducer<CoeffReturnType>, const DimensionList<Index, NumDimensions>, const Derived>(derived(), in_dims, internal::ProdReducer<CoeffReturnType>());
}
- template <typename Dims> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
- const TensorReductionOp<internal::MaxReducer<CoeffReturnType>, const Dims, const Derived>
+ template <typename Dims,int NanPropagation=PropagateFast> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ const TensorReductionOp<internal::MaxReducer<CoeffReturnType,NanPropagation>, const Dims, const Derived>
maximum(const Dims& dims) const {
- return TensorReductionOp<internal::MaxReducer<CoeffReturnType>, const Dims, const Derived>(derived(), dims, internal::MaxReducer<CoeffReturnType>());
+ return TensorReductionOp<internal::MaxReducer<CoeffReturnType,NanPropagation>, const Dims, const Derived>(derived(), dims, internal::MaxReducer<CoeffReturnType,NanPropagation>());
}
- const TensorReductionOp<internal::MaxReducer<CoeffReturnType>, const DimensionList<Index, NumDimensions>, const Derived>
+ template <int NanPropagation=PropagateFast>
+ const TensorReductionOp<internal::MaxReducer<CoeffReturnType,NanPropagation>, const DimensionList<Index, NumDimensions>, const Derived>
maximum() const {
DimensionList<Index, NumDimensions> in_dims;
- return TensorReductionOp<internal::MaxReducer<CoeffReturnType>, const DimensionList<Index, NumDimensions>, const Derived>(derived(), in_dims, internal::MaxReducer<CoeffReturnType>());
+ return TensorReductionOp<internal::MaxReducer<CoeffReturnType,NanPropagation>, const DimensionList<Index, NumDimensions>, const Derived>(derived(), in_dims, internal::MaxReducer<CoeffReturnType,NanPropagation>());
}
- template <typename Dims> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
- const TensorReductionOp<internal::MinReducer<CoeffReturnType>, const Dims, const Derived>
+ template <typename Dims,int NanPropagation=PropagateFast> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ const TensorReductionOp<internal::MinReducer<CoeffReturnType,NanPropagation>, const Dims, const Derived>
minimum(const Dims& dims) const {
- return TensorReductionOp<internal::MinReducer<CoeffReturnType>, const Dims, const Derived>(derived(), dims, internal::MinReducer<CoeffReturnType>());
+ return TensorReductionOp<internal::MinReducer<CoeffReturnType,NanPropagation>, const Dims, const Derived>(derived(), dims, internal::MinReducer<CoeffReturnType,NanPropagation>());
}
- const TensorReductionOp<internal::MinReducer<CoeffReturnType>, const DimensionList<Index, NumDimensions>, const Derived>
+ template <int NanPropagation=PropagateFast>
+ const TensorReductionOp<internal::MinReducer<CoeffReturnType,NanPropagation>, const DimensionList<Index, NumDimensions>, const Derived>
minimum() const {
DimensionList<Index, NumDimensions> in_dims;
- return TensorReductionOp<internal::MinReducer<CoeffReturnType>, const DimensionList<Index, NumDimensions>, const Derived>(derived(), in_dims, internal::MinReducer<CoeffReturnType>());
+ return TensorReductionOp<internal::MinReducer<CoeffReturnType,NanPropagation>, const DimensionList<Index, NumDimensions>, const Derived>(derived(), in_dims, internal::MinReducer<CoeffReturnType,NanPropagation>());
}
template <typename Dims> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
- const TensorReductionOp<internal::AndReducer, const Dims, const TensorConversionOp<bool, const Derived> >
+ const TensorReductionOp<internal::AndReducer, const Dims, const typename internal::conditional<internal::is_same<bool, CoeffReturnType>::value, Derived, TensorConversionOp<bool, const Derived> >::type >
all(const Dims& dims) const {
return cast<bool>().reduce(dims, internal::AndReducer());
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
- const TensorReductionOp<internal::AndReducer, const DimensionList<Index, NumDimensions>, const TensorConversionOp<bool, const Derived> >
+ const TensorReductionOp<internal::AndReducer, const DimensionList<Index, NumDimensions>, const typename internal::conditional<internal::is_same<bool, CoeffReturnType>::value, Derived, TensorConversionOp<bool, const Derived> >::type >
all() const {
DimensionList<Index, NumDimensions> in_dims;
return cast<bool>().reduce(in_dims, internal::AndReducer());
}
template <typename Dims> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
- const TensorReductionOp<internal::OrReducer, const Dims, const TensorConversionOp<bool, const Derived> >
+ const TensorReductionOp<internal::OrReducer, const Dims, const typename internal::conditional<internal::is_same<bool, CoeffReturnType>::value, Derived, TensorConversionOp<bool, const Derived> >::type >
any(const Dims& dims) const {
return cast<bool>().reduce(dims, internal::OrReducer());
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
- const TensorReductionOp<internal::OrReducer, const DimensionList<Index, NumDimensions>, const TensorConversionOp<bool, const Derived> >
+ const TensorReductionOp<internal::OrReducer, const DimensionList<Index, NumDimensions>, const typename internal::conditional<internal::is_same<bool, CoeffReturnType>::value, Derived, TensorConversionOp<bool, const Derived> >::type >
any() const {
DimensionList<Index, NumDimensions> in_dims;
return cast<bool>().reduce(in_dims, internal::OrReducer());
@@ -613,7 +746,7 @@ class TensorBase<Derived, ReadOnlyAccessors>
const array<Index, NumDimensions>, const Derived>
argmax() const {
array<Index, NumDimensions> in_dims;
- for (int d = 0; d < NumDimensions; ++d) in_dims[d] = d;
+ for (Index d = 0; d < NumDimensions; ++d) in_dims[d] = d;
return TensorTupleReducerOp<
internal::ArgMaxTupleReducer<Tuple<Index, CoeffReturnType> >,
const array<Index, NumDimensions>,
@@ -626,7 +759,7 @@ class TensorBase<Derived, ReadOnlyAccessors>
const array<Index, NumDimensions>, const Derived>
argmin() const {
array<Index, NumDimensions> in_dims;
- for (int d = 0; d < NumDimensions; ++d) in_dims[d] = d;
+ for (Index d = 0; d < NumDimensions; ++d) in_dims[d] = d;
return TensorTupleReducerOp<
internal::ArgMinTupleReducer<Tuple<Index, CoeffReturnType> >,
const array<Index, NumDimensions>,
@@ -637,7 +770,7 @@ class TensorBase<Derived, ReadOnlyAccessors>
const TensorTupleReducerOp<
internal::ArgMaxTupleReducer<Tuple<Index, CoeffReturnType> >,
const array<Index, 1>, const Derived>
- argmax(const int return_dim) const {
+ argmax(const Index return_dim) const {
array<Index, 1> in_dims;
in_dims[0] = return_dim;
return TensorTupleReducerOp<
@@ -650,7 +783,7 @@ class TensorBase<Derived, ReadOnlyAccessors>
const TensorTupleReducerOp<
internal::ArgMinTupleReducer<Tuple<Index, CoeffReturnType> >,
const array<Index, 1>, const Derived>
- argmin(const int return_dim) const {
+ argmin(const Index return_dim) const {
array<Index, 1> in_dims;
in_dims[0] = return_dim;
return TensorTupleReducerOp<
@@ -665,10 +798,22 @@ class TensorBase<Derived, ReadOnlyAccessors>
return TensorReductionOp<Reducer, const Dims, const Derived>(derived(), dims, reducer);
}
+ template <typename Dims> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ const TensorTraceOp<const Dims, const Derived>
+ trace(const Dims& dims) const {
+ return TensorTraceOp<const Dims, const Derived>(derived(), dims);
+ }
+
+ const TensorTraceOp<const DimensionList<Index, NumDimensions>, const Derived>
+ trace() const {
+ DimensionList<Index, NumDimensions> in_dims;
+ return TensorTraceOp<const DimensionList<Index, NumDimensions>, const Derived>(derived(), in_dims);
+ }
+
template <typename Broadcast> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorBroadcastingOp<const Broadcast, const Derived>
- broadcast(const Broadcast& broadcast) const {
- return TensorBroadcastingOp<const Broadcast, const Derived>(derived(), broadcast);
+ broadcast(const Broadcast& bcast) const {
+ return TensorBroadcastingOp<const Broadcast, const Derived>(derived(), bcast);
}
template <typename Axis, typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
@@ -776,8 +921,8 @@ class TensorBase<Derived, ReadOnlyAccessors>
}
template <typename Shuffle> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorShufflingOp<const Shuffle, const Derived>
- shuffle(const Shuffle& shuffle) const {
- return TensorShufflingOp<const Shuffle, const Derived>(derived(), shuffle);
+ shuffle(const Shuffle& shfl) const {
+ return TensorShufflingOp<const Shuffle, const Derived>(derived(), shfl);
}
template <typename Strides> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorStridingOp<const Strides, const Derived>
@@ -818,7 +963,8 @@ class TensorBase<Derived, ReadOnlyAccessors>
protected:
template <typename Scalar, int NumIndices, int Options, typename IndexType> friend class Tensor;
template <typename Scalar, typename Dimensions, int Option, typename IndexTypes> friend class TensorFixedSize;
- template <typename OtherDerived, int AccessLevel> friend class TensorBase;
+ // the Eigen:: prefix is required to workaround a compilation issue with nvcc 9.0
+ template <typename OtherDerived, int AccessLevel> friend class Eigen::TensorBase;
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const Derived& derived() const { return *static_cast<const Derived*>(this); }
};
@@ -826,6 +972,7 @@ class TensorBase<Derived, ReadOnlyAccessors>
template<typename Derived, int AccessLevel = internal::accessors_level<Derived>::value>
class TensorBase : public TensorBase<Derived, ReadOnlyAccessors> {
public:
+ typedef TensorBase<Derived, ReadOnlyAccessors> Base;
typedef internal::traits<Derived> DerivedTraits;
typedef typename DerivedTraits::Scalar Scalar;
typedef typename DerivedTraits::Index Index;
@@ -834,7 +981,8 @@ class TensorBase : public TensorBase<Derived, ReadOnlyAccessors> {
template <typename Scalar, int NumIndices, int Options, typename IndexType> friend class Tensor;
template <typename Scalar, typename Dimensions, int Option, typename IndexTypes> friend class TensorFixedSize;
- template <typename OtherDerived, int OtherAccessLevel> friend class TensorBase;
+ // the Eigen:: prefix is required to workaround a compilation issue with nvcc 9.0
+ template <typename OtherDerived, int OtherAccessLevel> friend class Eigen::TensorBase;
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Derived& setZero() {
@@ -972,13 +1120,13 @@ class TensorBase : public TensorBase<Derived, ReadOnlyAccessors> {
template <typename Shuffle> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorShufflingOp<const Shuffle, const Derived>
- shuffle(const Shuffle& shuffle) const {
- return TensorShufflingOp<const Shuffle, const Derived>(derived(), shuffle);
+ shuffle(const Shuffle& shfl) const {
+ return TensorShufflingOp<const Shuffle, const Derived>(derived(), shfl);
}
template <typename Shuffle> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
TensorShufflingOp<const Shuffle, Derived>
- shuffle(const Shuffle& shuffle) {
- return TensorShufflingOp<const Shuffle, Derived>(derived(), shuffle);
+ shuffle(const Shuffle& shfl) {
+ return TensorShufflingOp<const Shuffle, Derived>(derived(), shfl);
}
template <typename Strides> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
@@ -994,17 +1142,35 @@ class TensorBase : public TensorBase<Derived, ReadOnlyAccessors> {
// Select the device on which to evaluate the expression.
template <typename DeviceType>
- TensorDevice<Derived, DeviceType> device(const DeviceType& device) {
- return TensorDevice<Derived, DeviceType>(device, derived());
+ TensorDevice<Derived, DeviceType> device(const DeviceType& dev) {
+ return TensorDevice<Derived, DeviceType>(dev, derived());
+ }
+
+ // Select the async device on which to evaluate the expression.
+ template <typename DeviceType, typename DoneCallback>
+ TensorAsyncDevice<Derived, DeviceType, DoneCallback> device(const DeviceType& dev, DoneCallback done) {
+ return TensorAsyncDevice<Derived, DeviceType, DoneCallback>(dev, derived(), std::move(done));
}
protected:
+ EIGEN_DEFAULT_EMPTY_CONSTRUCTOR_AND_DESTRUCTOR(TensorBase)
+ EIGEN_DEFAULT_COPY_CONSTRUCTOR(TensorBase)
+
+ template<typename OtherDerived> EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE Derived& operator=(const OtherDerived& other)
+ {
+ typedef TensorAssignOp<Derived, const OtherDerived> Assign;
+ Assign assign(derived(), other.derived());
+ internal::TensorExecutor<const Assign, DefaultDevice>::run(assign, DefaultDevice());
+ return derived();
+ }
+
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Derived& derived() { return *static_cast<Derived*>(this); }
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const Derived& derived() const { return *static_cast<const Derived*>(this); }
};
-
+#endif // EIGEN_PARSED_BY_DOXYGEN
} // end namespace Eigen
#endif // EIGEN_CXX11_TENSOR_TENSOR_BASE_H