diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorBase.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorBase.h | 274 |
1 files changed, 220 insertions, 54 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h index 7a45a5cf4..35b6458e5 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h @@ -20,9 +20,11 @@ namespace Eigen { * \brief The tensor base class. * * This class is the common parent of the Tensor and TensorMap class, thus - * making it possible to use either class interchangably in expressions. + * making it possible to use either class interchangeably in expressions. */ - +#ifndef EIGEN_PARSED_BY_DOXYGEN +// FIXME Doxygen does not like the inheritance with different template parameters +// Since there is no doxygen documentation inside, we disable it for now template<typename Derived> class TensorBase<Derived, ReadOnlyAccessors> { @@ -133,6 +135,78 @@ class TensorBase<Derived, ReadOnlyAccessors> return unaryExpr(internal::scalar_digamma_op<Scalar>()); } + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_bessel_i0_op<Scalar>, const Derived> + bessel_i0() const { + return unaryExpr(internal::scalar_bessel_i0_op<Scalar>()); + } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_bessel_i0e_op<Scalar>, const Derived> + bessel_i0e() const { + return unaryExpr(internal::scalar_bessel_i0e_op<Scalar>()); + } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_bessel_i1_op<Scalar>, const Derived> + bessel_i1() const { + return unaryExpr(internal::scalar_bessel_i1_op<Scalar>()); + } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_bessel_i1e_op<Scalar>, const Derived> + bessel_i1e() const { + return unaryExpr(internal::scalar_bessel_i1e_op<Scalar>()); + } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_bessel_j0_op<Scalar>, const Derived> + bessel_j0() const { + return unaryExpr(internal::scalar_bessel_j0_op<Scalar>()); + } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_bessel_y0_op<Scalar>, const Derived> + bessel_y0() const { + return unaryExpr(internal::scalar_bessel_y0_op<Scalar>()); + } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_bessel_j1_op<Scalar>, const Derived> + bessel_j1() const { + return unaryExpr(internal::scalar_bessel_j1_op<Scalar>()); + } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_bessel_y1_op<Scalar>, const Derived> + bessel_y1() const { + return unaryExpr(internal::scalar_bessel_y1_op<Scalar>()); + } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_bessel_k0_op<Scalar>, const Derived> + bessel_k0() const { + return unaryExpr(internal::scalar_bessel_k0_op<Scalar>()); + } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_bessel_k0e_op<Scalar>, const Derived> + bessel_k0e() const { + return unaryExpr(internal::scalar_bessel_k0e_op<Scalar>()); + } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_bessel_k1_op<Scalar>, const Derived> + bessel_k1() const { + return unaryExpr(internal::scalar_bessel_k1_op<Scalar>()); + } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_bessel_k1e_op<Scalar>, const Derived> + bessel_k1e() const { + return unaryExpr(internal::scalar_bessel_k1e_op<Scalar>()); + } + // igamma(a = this, x = other) template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<internal::scalar_igamma_op<Scalar>, const Derived, const OtherDerived> @@ -140,6 +214,20 @@ class TensorBase<Derived, ReadOnlyAccessors> return binaryExpr(other.derived(), internal::scalar_igamma_op<Scalar>()); } + // igamma_der_a(a = this, x = other) + template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorCwiseBinaryOp<internal::scalar_igamma_der_a_op<Scalar>, const Derived, const OtherDerived> + igamma_der_a(const OtherDerived& other) const { + return binaryExpr(other.derived(), internal::scalar_igamma_der_a_op<Scalar>()); + } + + // gamma_sample_der_alpha(alpha = this, sample = other) + template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorCwiseBinaryOp<internal::scalar_gamma_sample_der_alpha_op<Scalar>, const Derived, const OtherDerived> + gamma_sample_der_alpha(const OtherDerived& other) const { + return binaryExpr(other.derived(), internal::scalar_gamma_sample_der_alpha_op<Scalar>()); + } + // igammac(a = this, x = other) template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<internal::scalar_igammac_op<Scalar>, const Derived, const OtherDerived> @@ -174,9 +262,15 @@ class TensorBase<Derived, ReadOnlyAccessors> } EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_sigmoid_op<Scalar>, const Derived> + EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_ndtri_op<Scalar>, const Derived> + ndtri() const { + return unaryExpr(internal::scalar_ndtri_op<Scalar>()); + } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_logistic_op<Scalar>, const Derived> sigmoid() const { - return unaryExpr(internal::scalar_sigmoid_op<Scalar>()); + return unaryExpr(internal::scalar_logistic_op<Scalar>()); } EIGEN_DEVICE_FUNC @@ -186,6 +280,12 @@ class TensorBase<Derived, ReadOnlyAccessors> } EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_expm1_op<Scalar>, const Derived> + expm1() const { + return unaryExpr(internal::scalar_expm1_op<Scalar>()); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_log_op<Scalar>, const Derived> log() const { return unaryExpr(internal::scalar_log_op<Scalar>()); @@ -198,15 +298,29 @@ class TensorBase<Derived, ReadOnlyAccessors> } EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_log2_op<Scalar>, const Derived> + log2() const { + return unaryExpr(internal::scalar_log2_op<Scalar>()); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_abs_op<Scalar>, const Derived> abs() const { return unaryExpr(internal::scalar_abs_op<Scalar>()); } EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_conjugate_op<Scalar>, const Derived> + EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_clamp_op<Scalar>, const Derived> + clip(Scalar min, Scalar max) const { + return unaryExpr(internal::scalar_clamp_op<Scalar>(min, max)); + } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const typename internal::conditional<NumTraits<CoeffReturnType>::IsComplex, + TensorCwiseUnaryOp<internal::scalar_conjugate_op<Scalar>, const Derived>, + Derived>::type conjugate() const { - return unaryExpr(internal::scalar_conjugate_op<Scalar>()); + return choose(Cond<NumTraits<CoeffReturnType>::IsComplex>(), unaryExpr(internal::scalar_conjugate_op<Scalar>()), derived()); } EIGEN_DEVICE_FUNC @@ -287,22 +401,27 @@ class TensorBase<Derived, ReadOnlyAccessors> return unaryExpr(internal::scalar_mod_op<Scalar>(rhs)); } + template <int NanPropagation=PropagateFast> EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<internal::scalar_max_op<Scalar>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> > + EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<internal::scalar_max_op<Scalar,Scalar,NanPropagation>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> > cwiseMax(Scalar threshold) const { - return cwiseMax(constant(threshold)); + return cwiseMax<NanPropagation>(constant(threshold)); } + template <int NanPropagation=PropagateFast> EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<internal::scalar_min_op<Scalar>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> > + EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<internal::scalar_min_op<Scalar,Scalar,NanPropagation>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> > cwiseMin(Scalar threshold) const { - return cwiseMin(constant(threshold)); + return cwiseMin<NanPropagation>(constant(threshold)); } - template <typename NewType> EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE const TensorConversionOp<NewType, const Derived> + template<typename NewType> + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const typename internal::conditional<internal::is_same<NewType, CoeffReturnType>::value, + Derived, + TensorConversionOp<NewType, const Derived> >::type cast() const { - return TensorConversionOp<NewType, const Derived>(derived()); + return choose(Cond<internal::is_same<NewType, CoeffReturnType>::value>(), derived(), TensorConversionOp<NewType, const Derived>(derived())); } EIGEN_DEVICE_FUNC @@ -312,6 +431,12 @@ class TensorBase<Derived, ReadOnlyAccessors> } EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_rint_op<Scalar>, const Derived> + rint() const { + return unaryExpr(internal::scalar_rint_op<Scalar>()); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_ceil_op<Scalar>, const Derived> ceil() const { return unaryExpr(internal::scalar_ceil_op<Scalar>()); @@ -355,16 +480,16 @@ class TensorBase<Derived, ReadOnlyAccessors> return binaryExpr(other.derived(), internal::scalar_quotient_op<Scalar>()); } - template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - const TensorCwiseBinaryOp<internal::scalar_max_op<Scalar>, const Derived, const OtherDerived> + template<int NaNPropagation=PropagateFast, typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorCwiseBinaryOp<internal::scalar_max_op<Scalar,Scalar, NaNPropagation>, const Derived, const OtherDerived> cwiseMax(const OtherDerived& other) const { - return binaryExpr(other.derived(), internal::scalar_max_op<Scalar>()); + return binaryExpr(other.derived(), internal::scalar_max_op<Scalar,Scalar, NaNPropagation>()); } - template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - const TensorCwiseBinaryOp<internal::scalar_min_op<Scalar>, const Derived, const OtherDerived> + template<int NaNPropagation=PropagateFast, typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorCwiseBinaryOp<internal::scalar_min_op<Scalar,Scalar, NaNPropagation>, const Derived, const OtherDerived> cwiseMin(const OtherDerived& other) const { - return binaryExpr(other.derived(), internal::scalar_min_op<Scalar>()); + return binaryExpr(other.derived(), internal::scalar_min_op<Scalar,Scalar, NaNPropagation>()); } template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE @@ -479,9 +604,15 @@ class TensorBase<Derived, ReadOnlyAccessors> typedef Eigen::IndexPair<Index> DimensionPair; template<typename OtherDerived, typename Dimensions> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - const TensorContractionOp<const Dimensions, const Derived, const OtherDerived> + const TensorContractionOp<const Dimensions, const Derived, const OtherDerived, const NoOpOutputKernel> contract(const OtherDerived& other, const Dimensions& dims) const { - return TensorContractionOp<const Dimensions, const Derived, const OtherDerived>(derived(), other.derived(), dims); + return TensorContractionOp<const Dimensions, const Derived, const OtherDerived, const NoOpOutputKernel>(derived(), other.derived(), dims); + } + + template<typename OtherDerived, typename Dimensions, typename OutputKernel> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorContractionOp<const Dimensions, const Derived, const OtherDerived, const OutputKernel> + contract(const OtherDerived& other, const Dimensions& dims, const OutputKernel& output_kernel) const { + return TensorContractionOp<const Dimensions, const Derived, const OtherDerived, const OutputKernel>(derived(), other.derived(), dims, output_kernel); } // Convolutions. @@ -494,8 +625,8 @@ class TensorBase<Derived, ReadOnlyAccessors> // Fourier transforms template <int FFTDataType, int FFTDirection, typename FFT> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorFFTOp<const FFT, const Derived, FFTDataType, FFTDirection> - fft(const FFT& fft) const { - return TensorFFTOp<const FFT, const Derived, FFTDataType, FFTDirection>(derived(), fft); + fft(const FFT& dims) const { + return TensorFFTOp<const FFT, const Derived, FFTDataType, FFTDirection>(derived(), dims); } // Scan. @@ -557,51 +688,53 @@ class TensorBase<Derived, ReadOnlyAccessors> return TensorReductionOp<internal::ProdReducer<CoeffReturnType>, const DimensionList<Index, NumDimensions>, const Derived>(derived(), in_dims, internal::ProdReducer<CoeffReturnType>()); } - template <typename Dims> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - const TensorReductionOp<internal::MaxReducer<CoeffReturnType>, const Dims, const Derived> + template <typename Dims,int NanPropagation=PropagateFast> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorReductionOp<internal::MaxReducer<CoeffReturnType,NanPropagation>, const Dims, const Derived> maximum(const Dims& dims) const { - return TensorReductionOp<internal::MaxReducer<CoeffReturnType>, const Dims, const Derived>(derived(), dims, internal::MaxReducer<CoeffReturnType>()); + return TensorReductionOp<internal::MaxReducer<CoeffReturnType,NanPropagation>, const Dims, const Derived>(derived(), dims, internal::MaxReducer<CoeffReturnType,NanPropagation>()); } - const TensorReductionOp<internal::MaxReducer<CoeffReturnType>, const DimensionList<Index, NumDimensions>, const Derived> + template <int NanPropagation=PropagateFast> + const TensorReductionOp<internal::MaxReducer<CoeffReturnType,NanPropagation>, const DimensionList<Index, NumDimensions>, const Derived> maximum() const { DimensionList<Index, NumDimensions> in_dims; - return TensorReductionOp<internal::MaxReducer<CoeffReturnType>, const DimensionList<Index, NumDimensions>, const Derived>(derived(), in_dims, internal::MaxReducer<CoeffReturnType>()); + return TensorReductionOp<internal::MaxReducer<CoeffReturnType,NanPropagation>, const DimensionList<Index, NumDimensions>, const Derived>(derived(), in_dims, internal::MaxReducer<CoeffReturnType,NanPropagation>()); } - template <typename Dims> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - const TensorReductionOp<internal::MinReducer<CoeffReturnType>, const Dims, const Derived> + template <typename Dims,int NanPropagation=PropagateFast> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorReductionOp<internal::MinReducer<CoeffReturnType,NanPropagation>, const Dims, const Derived> minimum(const Dims& dims) const { - return TensorReductionOp<internal::MinReducer<CoeffReturnType>, const Dims, const Derived>(derived(), dims, internal::MinReducer<CoeffReturnType>()); + return TensorReductionOp<internal::MinReducer<CoeffReturnType,NanPropagation>, const Dims, const Derived>(derived(), dims, internal::MinReducer<CoeffReturnType,NanPropagation>()); } - const TensorReductionOp<internal::MinReducer<CoeffReturnType>, const DimensionList<Index, NumDimensions>, const Derived> + template <int NanPropagation=PropagateFast> + const TensorReductionOp<internal::MinReducer<CoeffReturnType,NanPropagation>, const DimensionList<Index, NumDimensions>, const Derived> minimum() const { DimensionList<Index, NumDimensions> in_dims; - return TensorReductionOp<internal::MinReducer<CoeffReturnType>, const DimensionList<Index, NumDimensions>, const Derived>(derived(), in_dims, internal::MinReducer<CoeffReturnType>()); + return TensorReductionOp<internal::MinReducer<CoeffReturnType,NanPropagation>, const DimensionList<Index, NumDimensions>, const Derived>(derived(), in_dims, internal::MinReducer<CoeffReturnType,NanPropagation>()); } template <typename Dims> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - const TensorReductionOp<internal::AndReducer, const Dims, const TensorConversionOp<bool, const Derived> > + const TensorReductionOp<internal::AndReducer, const Dims, const typename internal::conditional<internal::is_same<bool, CoeffReturnType>::value, Derived, TensorConversionOp<bool, const Derived> >::type > all(const Dims& dims) const { return cast<bool>().reduce(dims, internal::AndReducer()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - const TensorReductionOp<internal::AndReducer, const DimensionList<Index, NumDimensions>, const TensorConversionOp<bool, const Derived> > + const TensorReductionOp<internal::AndReducer, const DimensionList<Index, NumDimensions>, const typename internal::conditional<internal::is_same<bool, CoeffReturnType>::value, Derived, TensorConversionOp<bool, const Derived> >::type > all() const { DimensionList<Index, NumDimensions> in_dims; return cast<bool>().reduce(in_dims, internal::AndReducer()); } template <typename Dims> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - const TensorReductionOp<internal::OrReducer, const Dims, const TensorConversionOp<bool, const Derived> > + const TensorReductionOp<internal::OrReducer, const Dims, const typename internal::conditional<internal::is_same<bool, CoeffReturnType>::value, Derived, TensorConversionOp<bool, const Derived> >::type > any(const Dims& dims) const { return cast<bool>().reduce(dims, internal::OrReducer()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - const TensorReductionOp<internal::OrReducer, const DimensionList<Index, NumDimensions>, const TensorConversionOp<bool, const Derived> > + const TensorReductionOp<internal::OrReducer, const DimensionList<Index, NumDimensions>, const typename internal::conditional<internal::is_same<bool, CoeffReturnType>::value, Derived, TensorConversionOp<bool, const Derived> >::type > any() const { DimensionList<Index, NumDimensions> in_dims; return cast<bool>().reduce(in_dims, internal::OrReducer()); @@ -613,7 +746,7 @@ class TensorBase<Derived, ReadOnlyAccessors> const array<Index, NumDimensions>, const Derived> argmax() const { array<Index, NumDimensions> in_dims; - for (int d = 0; d < NumDimensions; ++d) in_dims[d] = d; + for (Index d = 0; d < NumDimensions; ++d) in_dims[d] = d; return TensorTupleReducerOp< internal::ArgMaxTupleReducer<Tuple<Index, CoeffReturnType> >, const array<Index, NumDimensions>, @@ -626,7 +759,7 @@ class TensorBase<Derived, ReadOnlyAccessors> const array<Index, NumDimensions>, const Derived> argmin() const { array<Index, NumDimensions> in_dims; - for (int d = 0; d < NumDimensions; ++d) in_dims[d] = d; + for (Index d = 0; d < NumDimensions; ++d) in_dims[d] = d; return TensorTupleReducerOp< internal::ArgMinTupleReducer<Tuple<Index, CoeffReturnType> >, const array<Index, NumDimensions>, @@ -637,7 +770,7 @@ class TensorBase<Derived, ReadOnlyAccessors> const TensorTupleReducerOp< internal::ArgMaxTupleReducer<Tuple<Index, CoeffReturnType> >, const array<Index, 1>, const Derived> - argmax(const int return_dim) const { + argmax(const Index return_dim) const { array<Index, 1> in_dims; in_dims[0] = return_dim; return TensorTupleReducerOp< @@ -650,7 +783,7 @@ class TensorBase<Derived, ReadOnlyAccessors> const TensorTupleReducerOp< internal::ArgMinTupleReducer<Tuple<Index, CoeffReturnType> >, const array<Index, 1>, const Derived> - argmin(const int return_dim) const { + argmin(const Index return_dim) const { array<Index, 1> in_dims; in_dims[0] = return_dim; return TensorTupleReducerOp< @@ -665,10 +798,22 @@ class TensorBase<Derived, ReadOnlyAccessors> return TensorReductionOp<Reducer, const Dims, const Derived>(derived(), dims, reducer); } + template <typename Dims> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorTraceOp<const Dims, const Derived> + trace(const Dims& dims) const { + return TensorTraceOp<const Dims, const Derived>(derived(), dims); + } + + const TensorTraceOp<const DimensionList<Index, NumDimensions>, const Derived> + trace() const { + DimensionList<Index, NumDimensions> in_dims; + return TensorTraceOp<const DimensionList<Index, NumDimensions>, const Derived>(derived(), in_dims); + } + template <typename Broadcast> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorBroadcastingOp<const Broadcast, const Derived> - broadcast(const Broadcast& broadcast) const { - return TensorBroadcastingOp<const Broadcast, const Derived>(derived(), broadcast); + broadcast(const Broadcast& bcast) const { + return TensorBroadcastingOp<const Broadcast, const Derived>(derived(), bcast); } template <typename Axis, typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE @@ -776,8 +921,8 @@ class TensorBase<Derived, ReadOnlyAccessors> } template <typename Shuffle> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorShufflingOp<const Shuffle, const Derived> - shuffle(const Shuffle& shuffle) const { - return TensorShufflingOp<const Shuffle, const Derived>(derived(), shuffle); + shuffle(const Shuffle& shfl) const { + return TensorShufflingOp<const Shuffle, const Derived>(derived(), shfl); } template <typename Strides> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorStridingOp<const Strides, const Derived> @@ -818,7 +963,8 @@ class TensorBase<Derived, ReadOnlyAccessors> protected: template <typename Scalar, int NumIndices, int Options, typename IndexType> friend class Tensor; template <typename Scalar, typename Dimensions, int Option, typename IndexTypes> friend class TensorFixedSize; - template <typename OtherDerived, int AccessLevel> friend class TensorBase; + // the Eigen:: prefix is required to workaround a compilation issue with nvcc 9.0 + template <typename OtherDerived, int AccessLevel> friend class Eigen::TensorBase; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Derived& derived() const { return *static_cast<const Derived*>(this); } }; @@ -826,6 +972,7 @@ class TensorBase<Derived, ReadOnlyAccessors> template<typename Derived, int AccessLevel = internal::accessors_level<Derived>::value> class TensorBase : public TensorBase<Derived, ReadOnlyAccessors> { public: + typedef TensorBase<Derived, ReadOnlyAccessors> Base; typedef internal::traits<Derived> DerivedTraits; typedef typename DerivedTraits::Scalar Scalar; typedef typename DerivedTraits::Index Index; @@ -834,7 +981,8 @@ class TensorBase : public TensorBase<Derived, ReadOnlyAccessors> { template <typename Scalar, int NumIndices, int Options, typename IndexType> friend class Tensor; template <typename Scalar, typename Dimensions, int Option, typename IndexTypes> friend class TensorFixedSize; - template <typename OtherDerived, int OtherAccessLevel> friend class TensorBase; + // the Eigen:: prefix is required to workaround a compilation issue with nvcc 9.0 + template <typename OtherDerived, int OtherAccessLevel> friend class Eigen::TensorBase; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& setZero() { @@ -972,13 +1120,13 @@ class TensorBase : public TensorBase<Derived, ReadOnlyAccessors> { template <typename Shuffle> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorShufflingOp<const Shuffle, const Derived> - shuffle(const Shuffle& shuffle) const { - return TensorShufflingOp<const Shuffle, const Derived>(derived(), shuffle); + shuffle(const Shuffle& shfl) const { + return TensorShufflingOp<const Shuffle, const Derived>(derived(), shfl); } template <typename Shuffle> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorShufflingOp<const Shuffle, Derived> - shuffle(const Shuffle& shuffle) { - return TensorShufflingOp<const Shuffle, Derived>(derived(), shuffle); + shuffle(const Shuffle& shfl) { + return TensorShufflingOp<const Shuffle, Derived>(derived(), shfl); } template <typename Strides> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE @@ -994,17 +1142,35 @@ class TensorBase : public TensorBase<Derived, ReadOnlyAccessors> { // Select the device on which to evaluate the expression. template <typename DeviceType> - TensorDevice<Derived, DeviceType> device(const DeviceType& device) { - return TensorDevice<Derived, DeviceType>(device, derived()); + TensorDevice<Derived, DeviceType> device(const DeviceType& dev) { + return TensorDevice<Derived, DeviceType>(dev, derived()); + } + + // Select the async device on which to evaluate the expression. + template <typename DeviceType, typename DoneCallback> + TensorAsyncDevice<Derived, DeviceType, DoneCallback> device(const DeviceType& dev, DoneCallback done) { + return TensorAsyncDevice<Derived, DeviceType, DoneCallback>(dev, derived(), std::move(done)); } protected: + EIGEN_DEFAULT_EMPTY_CONSTRUCTOR_AND_DESTRUCTOR(TensorBase) + EIGEN_DEFAULT_COPY_CONSTRUCTOR(TensorBase) + + template<typename OtherDerived> EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Derived& operator=(const OtherDerived& other) + { + typedef TensorAssignOp<Derived, const OtherDerived> Assign; + Assign assign(derived(), other.derived()); + internal::TensorExecutor<const Assign, DefaultDevice>::run(assign, DefaultDevice()); + return derived(); + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& derived() { return *static_cast<Derived*>(this); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Derived& derived() const { return *static_cast<const Derived*>(this); } }; - +#endif // EIGEN_PARSED_BY_DOXYGEN } // end namespace Eigen #endif // EIGEN_CXX11_TENSOR_TENSOR_BASE_H |