Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h')
-rw-r--r--  unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h | 242
1 file changed, 175 insertions(+), 67 deletions(-)
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h
index 9b2cb3ff6..9ab900b4a 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h
@@ -22,8 +22,19 @@ enum {
/*
* Implementation of the Eigen blas_data_mapper class for tensors.
*/
-
-template <typename Tensor, bool HasRawAccess> struct CoeffLoader {
+/// The MakePointer class is used by SYCL to build the mapper class on the device. For other platforms the default MakePointer is used, which
+/// is scalar * for CoeffLoader.
+template <typename Tensor, bool HasRawAccess, template <class> class MakePointer_ = MakePointer>
+struct CoeffLoader;
+
+template <typename Scalar, typename Index, int side, typename Tensor,
+ typename nocontract_t, typename contract_t, int packet_size,
+ bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
+ template <class> class MakePointer_ = MakePointer>
+class BaseTensorContractionMapper;
+
+template <typename Tensor, bool HasRawAccess, template <class> class MakePointer_>
+struct CoeffLoader {
enum {
DirectOffsets = false
};
@@ -34,6 +45,12 @@ template <typename Tensor, bool HasRawAccess> struct CoeffLoader {
eigen_assert(false && "unsupported");
}
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const typename MakePointer_<const typename Tensor::Scalar>::Type
+ data() const {
+ eigen_assert(false && "unsupported");
+ return NULL;
+ }
+
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename Tensor::Scalar coeff(typename Tensor::Index index) const { return m_tensor.coeff(index); }
template<int LoadMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
@@ -42,12 +59,19 @@ template <typename Tensor, bool HasRawAccess> struct CoeffLoader {
return m_tensor.template packet<LoadMode>(index);
}
+ #ifdef EIGEN_USE_SYCL
+ // The SYCL placeholder accessors must be bound to a command group handler before use
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const {
+ m_tensor.bind(cgh);
+ }
+ #endif
private:
const Tensor m_tensor;
};
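
The bind(cgh) hooks added above exist because SYCL placeholder accessors are created without a command group and must be registered with the handler of the command group that uses them. A minimal sketch of that pattern under the SYCL 1.2.1 API from <CL/sycl.hpp>; DeviceView, the bump_first kernel and the buffer setup are hypothetical and not part of this patch:

#include <CL/sycl.hpp>
#include <cstddef>
#include <vector>

template <typename T>
struct DeviceView {
  typedef cl::sycl::accessor<T, 1, cl::sycl::access::mode::read_write,
                             cl::sycl::access::target::global_buffer,
                             cl::sycl::access::placeholder::true_t> acc_t;
  acc_t acc;  // placeholder accessor: not yet tied to any command group

  explicit DeviceView(cl::sycl::buffer<T, 1>& buf) : acc(buf) {}

  // Mirrors CoeffLoader::bind(): register the accessor with this command group.
  void bind(cl::sycl::handler& cgh) const { cgh.require(acc); }

  T coeff(std::size_t i) const { return acc[i]; }
};

int main() {
  std::vector<float> host(16, 1.0f);
  {
    cl::sycl::buffer<float, 1> buf(host.data(), cl::sycl::range<1>(host.size()));
    cl::sycl::queue q;
    DeviceView<float> view(buf);
    q.submit([&](cl::sycl::handler& cgh) {
      view.bind(cgh);  // same role as mapper.bind(cgh) in this patch
      cgh.single_task<class bump_first>([=]() { view.acc[0] = view.coeff(1) + 1.0f; });
    });
    q.wait();
  }  // buffer destruction copies the result back into 'host'
  return host[0] == 2.0f ? 0 : 1;
}
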
-template <typename Tensor> struct CoeffLoader<Tensor, true> {
+template <typename Tensor, template <class> class MakePointer_>
+struct CoeffLoader<Tensor, true, MakePointer_> {
enum {
DirectOffsets = true
};
@@ -58,6 +82,11 @@ template <typename Tensor> struct CoeffLoader<Tensor, true> {
m_data += offset;
}
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const typename MakePointer_<const typename Tensor::Scalar>::Type
+ data() const {
+ return m_data;
+ }
+
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename Tensor::Scalar coeff(typename Tensor::Index index) const { return loadConstant(m_data+index); }
template<int LoadMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
@@ -65,15 +94,23 @@ template <typename Tensor> struct CoeffLoader<Tensor, true> {
{
return internal::ploadt_ro<typename Tensor::PacketReturnType, LoadMode>(m_data + index);
}
+
+ #ifdef EIGEN_USE_SYCL
+ // The SYCL placeholder accessors must be bound to a command group handler before use
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const {
+ m_data.bind(cgh);
+ }
+ #endif
private:
typedef typename Tensor::Scalar Scalar;
- const Scalar* m_data;
+
+ typename MakePointer_<const Scalar>::Type m_data;
};
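
The MakePointer_ template template parameter threaded through CoeffLoader above is what lets a device backend swap the stored pointer type (m_data) without touching the loader logic; the default policy keeps a plain scalar*. A minimal standalone sketch of the mechanism, assuming only standard C++ (DefaultMakePointer, FancyPointer, AccessorMakePointer and Loader are hypothetical names):

#include <cstddef>

// Default policy: MakePointer_<T>::Type is a plain raw pointer, matching the
// "scalar *" default described in the comment at the top of the patch.
template <class T>
struct DefaultMakePointer {
  typedef T* Type;
};

// A device backend could map T to an accessor-like handle instead; this toy
// wrapper just forwards to a raw pointer.
template <class T>
struct FancyPointer {
  T* p;
  T operator[](std::size_t i) const { return p[i]; }
};

template <class T>
struct AccessorMakePointer {
  typedef FancyPointer<T> Type;
};

// The loader only ever names MakePointer_<const Scalar>::Type, so the same
// code compiles against either pointer representation (cf. CoeffLoader::m_data).
template <typename Scalar, template <class> class MakePointer_ = DefaultMakePointer>
struct Loader {
  typename MakePointer_<const Scalar>::Type m_data;
  Scalar coeff(std::size_t i) const { return m_data[i]; }
};

int main() {
  const float buf[4] = {1.f, 2.f, 3.f, 4.f};
  Loader<float> raw = {buf};                            // raw-pointer backend
  Loader<float, AccessorMakePointer> fancy = {{buf}};   // handle-style backend
  return (raw.coeff(1) == 2.f && fancy.coeff(2) == 3.f) ? 0 : 1;
}
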
template<typename Scalar, typename Index, int side,
typename Tensor,
typename nocontract_t, typename contract_t,
- int packet_size, bool inner_dim_contiguous, int Alignment>
+ int packet_size, bool inner_dim_contiguous, int Alignment, template <class> class MakePointer_ = MakePointer>
class SimpleTensorContractionMapper {
public:
EIGEN_DEVICE_FUNC
@@ -89,7 +126,7 @@ class SimpleTensorContractionMapper {
m_k_strides(k_strides) { }
enum {
- DirectOffsets = CoeffLoader<Tensor, Tensor::RawAccess>::DirectOffsets
+ DirectOffsets = CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_>::DirectOffsets
};
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void offsetBuffer(typename Tensor::Index offset) {
@@ -113,8 +150,10 @@ class SimpleTensorContractionMapper {
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Index computeIndex(Index row, Index col) const {
const bool left = (side == Lhs);
+ EIGEN_UNUSED_VARIABLE(left); // annoying bug in g++8.1: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=85963
Index nocontract_val = left ? row : col;
Index linidx = 0;
+ EIGEN_UNROLL_LOOP
for (int i = static_cast<int>(array_size<nocontract_t>::value) - 1; i > 0; i--) {
const Index idx = nocontract_val / m_ij_strides[i];
linidx += idx * m_nocontract_strides[i];
@@ -131,6 +170,7 @@ class SimpleTensorContractionMapper {
Index contract_val = left ? col : row;
if(array_size<contract_t>::value > 0) {
+ EIGEN_UNROLL_LOOP
for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) {
const Index idx = contract_val / m_k_strides[i];
linidx += idx * m_contract_strides[i];
@@ -151,9 +191,11 @@ class SimpleTensorContractionMapper {
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE IndexPair<Index> computeIndexPair(Index row, Index col, const Index distance) const {
const bool left = (side == Lhs);
+ EIGEN_UNUSED_VARIABLE(left); // annoying bug in g++8.1: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=85963
Index nocontract_val[2] = {left ? row : col, left ? row + distance : col};
Index linidx[2] = {0, 0};
if (array_size<typename Tensor::Dimensions>::value > array_size<contract_t>::value) {
+ EIGEN_UNROLL_LOOP
for (int i = static_cast<int>(array_size<nocontract_t>::value) - 1; i > 0; i--) {
const Index idx0 = nocontract_val[0] / m_ij_strides[i];
const Index idx1 = nocontract_val[1] / m_ij_strides[i];
@@ -174,6 +216,7 @@ class SimpleTensorContractionMapper {
Index contract_val[2] = {left ? col : row, left ? col : row + distance};
if (array_size<contract_t>::value> 0) {
+ EIGEN_UNROLL_LOOP
for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) {
const Index idx0 = contract_val[0] / m_k_strides[i];
const Index idx1 = contract_val[1] / m_k_strides[i];
@@ -205,24 +248,41 @@ class SimpleTensorContractionMapper {
return ((side == Lhs) && inner_dim_contiguous && array_size<contract_t>::value > 0) ? m_contract_strides[0] : 1;
}
+ #ifdef EIGEN_USE_SYCL
+ // The SYCL placeholder accessors must be bound to a command group handler before use
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const {
+ m_tensor.bind(cgh);
+ }
+ #endif
+
+ const CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_>& tensor() const {
+ return m_tensor;
+ }
+
+ const nocontract_t& nocontract_strides() const {
+ return m_nocontract_strides;
+ }
+ const nocontract_t& ij_strides() const { return m_ij_strides; }
+ const contract_t& contract_strides() const { return m_contract_strides; }
+ const contract_t& k_strides() const { return m_k_strides; }
+
protected:
- CoeffLoader<Tensor, Tensor::RawAccess> m_tensor;
+ CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_> m_tensor;
const nocontract_t m_nocontract_strides;
const nocontract_t m_ij_strides;
const contract_t m_contract_strides;
const contract_t m_k_strides;
};
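
computeIndex and computeIndexPair above flatten a (row, col) coordinate into a linear tensor offset by walking the dimensions from outermost to innermost: divide by the precomputed ij/k stride to get that dimension's coordinate, accumulate it times the tensor's storage stride, and keep the remainder for the inner dimensions. A standalone sketch of that loop with hypothetical stride arrays (not Eigen's actual member layout):

#include <array>
#include <cassert>
#include <cstddef>

// Decompose 'val' into per-dimension coordinates using 'radix_strides' and
// re-linearize it with 'storage_strides': the same divide / multiply-accumulate
// / take-remainder loop used by computeIndex above.
template <std::size_t N>
long flattenIndex(long val,
                  const std::array<long, N>& radix_strides,     // cf. m_ij_strides
                  const std::array<long, N>& storage_strides) { // cf. m_nocontract_strides
  long linidx = 0;
  for (int i = static_cast<int>(N) - 1; i > 0; --i) {
    const long idx = val / radix_strides[i];  // coordinate in dimension i
    linidx += idx * storage_strides[i];       // jump by that dimension's stride
    val -= idx * radix_strides[i];            // remainder for the inner dimensions
  }
  return linidx + val * storage_strides[0];   // innermost dimension
}

int main() {
  // Two logical dimensions of sizes 4 x 5 flattened into one index, so the
  // radix strides are {1, 4}. The underlying tensor stores dimension 1 with
  // stride 6 (e.g. a padded or sliced view), so logical index 13 -> (1, 3)
  // -> 1*1 + 3*6 = 19.
  const std::array<long, 2> radix = {{1, 4}};
  const std::array<long, 2> storage = {{1, 6}};
  assert(flattenIndex(13, radix, storage) == 19);
  return 0;
}
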
-
template<typename Scalar, typename Index, int side,
typename Tensor,
typename nocontract_t, typename contract_t,
int packet_size, bool inner_dim_contiguous,
- bool inner_dim_reordered, int Alignment>
-class BaseTensorContractionMapper : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, Alignment>
+ bool inner_dim_reordered, int Alignment, template <class> class MakePointer_>
+class BaseTensorContractionMapper : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, Alignment, MakePointer_>
{
public:
- typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, Alignment> ParentMapper;
+ typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, Alignment, MakePointer_> ParentMapper;
EIGEN_DEVICE_FUNC
BaseTensorContractionMapper(const Tensor& tensor,
@@ -232,12 +292,11 @@ class BaseTensorContractionMapper : public SimpleTensorContractionMapper<Scalar,
const contract_t& k_strides) :
ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
- typedef typename Tensor::PacketReturnType Packet;
- typedef typename unpacket_traits<Packet>::half HalfPacket;
-
- template <int AlignmentType>
- EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE Packet loadPacket(Index i, Index j) const {
+ template <typename PacketT,int AlignmentType>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ typename internal::enable_if<internal::unpacket_traits<PacketT>::size==packet_size,PacketT>::type
+ load(Index i, Index j) const
+ {
// whole method makes column major assumption
// don't need to add offsets for now (because operator handles that)
@@ -252,7 +311,7 @@ class BaseTensorContractionMapper : public SimpleTensorContractionMapper<Scalar,
const IndexPair<Index> indexPair = this->computeIndexPair(i, j, packet_size - 1);
const Index first = indexPair.first;
- const Index last = indexPair.second;
+ const Index lastIdx = indexPair.second;
// We can always do optimized packet reads from left hand side right now, because
// the vertical matrix dimension on the left hand side is never contracting.
@@ -260,7 +319,7 @@ class BaseTensorContractionMapper : public SimpleTensorContractionMapper<Scalar,
// been shuffled first.
if (Tensor::PacketAccess &&
(side == Lhs || internal::array_size<contract_t>::value <= 1 || !inner_dim_reordered) &&
- (last - first) == (packet_size - 1)) {
+ (lastIdx - first) == (packet_size - 1)) {
return this->m_tensor.template packet<AlignmentType>(first);
}
@@ -268,31 +327,44 @@ class BaseTensorContractionMapper : public SimpleTensorContractionMapper<Scalar,
EIGEN_ALIGN_MAX Scalar data[packet_size];
data[0] = this->m_tensor.coeff(first);
+ EIGEN_UNROLL_LOOP
for (Index k = 1; k < packet_size - 1; k += 2) {
const IndexPair<Index> internal_pair = this->computeIndexPair(i + k, j, 1);
data[k] = this->m_tensor.coeff(internal_pair.first);
data[k + 1] = this->m_tensor.coeff(internal_pair.second);
}
- data[packet_size - 1] = this->m_tensor.coeff(last);
+ data[packet_size - 1] = this->m_tensor.coeff(lastIdx);
- return pload<Packet>(data);
+ return pload<PacketT>(data);
}
- template <int AlignmentType>
- EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE HalfPacket loadHalfPacket(Index i, Index j) const {
- // whole method makes column major assumption
+ template <typename PacketT,int AlignmentType>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ typename internal::enable_if<internal::unpacket_traits<PacketT>::size!=packet_size,PacketT>::type
+ load(Index i, Index j) const
+ {
+ const Index requested_packet_size = internal::unpacket_traits<PacketT>::size;
+ EIGEN_ALIGN_MAX Scalar data[requested_packet_size];
- // don't need to add offsets for now (because operator handles that)
- const Index half_packet_size = unpacket_traits<HalfPacket>::size;
- if (half_packet_size == packet_size) {
- return loadPacket<AlignmentType>(i, j);
- }
- EIGEN_ALIGN_MAX Scalar data[half_packet_size];
- for (Index k = 0; k < half_packet_size; k++) {
- data[k] = operator()(i + k, j);
+ const IndexPair<Index> indexPair = this->computeIndexPair(i, j, requested_packet_size - 1);
+ const Index first = indexPair.first;
+ const Index lastIdx = indexPair.second;
+
+ data[0] = this->m_tensor.coeff(first);
+ for (Index k = 1; k < requested_packet_size - 1; k += 2) {
+ const IndexPair<Index> internal_pair = this->computeIndexPair(i + k, j, 1);
+ data[k] = this->m_tensor.coeff(internal_pair.first);
+ data[k + 1] = this->m_tensor.coeff(internal_pair.second);
}
- return pload<HalfPacket>(data);
+ data[requested_packet_size - 1] = this->m_tensor.coeff(lastIdx);
+
+ return pload<PacketT>(data);
+ }
+
+ template <typename PacketT,int AlignmentType>
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE PacketT loadPacket(Index i, Index j) const {
+ return this->load<PacketT,AlignmentType>(i,j);
}
};
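
The class above replaces the old loadPacket/loadHalfPacket pair with a single load<PacketT, AlignmentType>() whose two overloads are chosen by enable_if on the requested packet's size, so full, half and quarter packets all go through one entry point. A stripped-down sketch of that SFINAE dispatch (Packet4f, Packet2f, packet_size_of and Mapper are hypothetical stand-ins, not Eigen types):

#include <cstddef>
#include <iostream>
#include <type_traits>

// Hypothetical stand-ins for Eigen's packet types and unpacket_traits<>::size.
struct Packet4f { float v[4]; };
struct Packet2f { float v[2]; };
template <class P> struct packet_size_of;
template <> struct packet_size_of<Packet4f> { static const int value = 4; };
template <> struct packet_size_of<Packet2f> { static const int value = 2; };

template <int packet_size>
struct Mapper {
  // Fast path: the requested packet matches the mapper's native width
  // (cf. the enable_if<size==packet_size> overload in the patch).
  template <typename PacketT>
  typename std::enable_if<packet_size_of<PacketT>::value == packet_size, PacketT>::type
  load(const float* data) const {
    std::cout << "native-width load\n";
    PacketT p = PacketT();
    for (int k = 0; k < packet_size_of<PacketT>::value; ++k) p.v[k] = data[k];
    return p;
  }

  // Fallback: any other width gathers element by element
  // (cf. the enable_if<size!=packet_size> overload).
  template <typename PacketT>
  typename std::enable_if<packet_size_of<PacketT>::value != packet_size, PacketT>::type
  load(const float* data) const {
    std::cout << "partial-width load\n";
    PacketT p = PacketT();
    for (int k = 0; k < packet_size_of<PacketT>::value; ++k) p.v[k] = data[k];
    return p;
  }
};

int main() {
  float buf[4] = {1.f, 2.f, 3.f, 4.f};
  Mapper<4> m;
  m.load<Packet4f>(buf);  // resolves to the native-width overload
  m.load<Packet2f>(buf);  // resolves to the partial-width overload
  return 0;
}
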
@@ -301,11 +373,12 @@ template<typename Scalar, typename Index, int side,
typename Tensor,
typename nocontract_t, typename contract_t,
bool inner_dim_contiguous,
- bool inner_dim_reordered, int Alignment>
-class BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment> : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, Alignment>
+ bool inner_dim_reordered, int Alignment, template <class> class MakePointer_>
+class BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_>
+ : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, Alignment, MakePointer_>
{
public:
- typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, Alignment> ParentMapper;
+ typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, Alignment, MakePointer_> ParentMapper;
EIGEN_DEVICE_FUNC
BaseTensorContractionMapper(const Tensor& tensor,
@@ -315,16 +388,17 @@ class BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, con
const contract_t& k_strides) :
ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
- typedef typename Tensor::PacketReturnType Packet;
- template <int> EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE Packet loadPacket(Index i, Index j) const {
+ template <typename PacketT,int> EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE PacketT loadPacket(Index i, Index j) const {
EIGEN_ALIGN_MAX Scalar data[1];
data[0] = this->m_tensor.coeff(this->computeIndex(i, j));
- return pload<typename Tensor::PacketReturnType>(data);
+ return pload<PacketT>(data);
}
- template <int> EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE Packet loadHalfPacket(Index i, Index j) const {
- return loadPacket(i, j);
+ template <typename PacketT,int> EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE PacketT load(Index i, Index j) const {
+ EIGEN_ALIGN_MAX Scalar data[1];
+ data[0] = this->m_tensor.coeff(this->computeIndex(i, j));
+ return pload<PacketT>(data);
}
};
@@ -333,14 +407,12 @@ template<typename Scalar, typename Index, int side,
typename Tensor,
typename nocontract_t, typename contract_t,
int packet_size,
- bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
+ bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment, template <class> class MakePointer_=MakePointer>
class TensorContractionSubMapper {
public:
- typedef typename Tensor::PacketReturnType Packet;
- typedef typename unpacket_traits<Packet>::half HalfPacket;
- typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> ParentMapper;
- typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> Self;
+ typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_> ParentMapper;
+ typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_> Self;
typedef Self LinearMapper;
enum {
@@ -372,27 +444,32 @@ class TensorContractionSubMapper {
return m_base_mapper(i + m_vert_offset, j + m_horiz_offset);
}
- EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
+ template <typename PacketT>
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT loadPacket(Index i) const {
if (UseDirectOffsets) {
- return m_base_mapper.template loadPacket<Alignment>(i, 0);
+ return m_base_mapper.template loadPacket<PacketT,Alignment>(i, 0);
}
- return m_base_mapper.template loadPacket<Alignment>(i + m_vert_offset, m_horiz_offset);
+ return m_base_mapper.template loadPacket<PacketT,Alignment>(i + m_vert_offset, m_horiz_offset);
}
- EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j) const {
+
+ template <typename PacketT>
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT loadPacket(Index i, Index j) const {
if (UseDirectOffsets) {
- return m_base_mapper.template loadPacket<Alignment>(i, j);
+ return m_base_mapper.template loadPacket<PacketT,Alignment>(i, j);
}
- return m_base_mapper.template loadPacket<Alignment>(i + m_vert_offset, j + m_horiz_offset);
+ return m_base_mapper.template loadPacket<PacketT,Alignment>(i + m_vert_offset, j + m_horiz_offset);
}
- EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i) const {
+ template <typename PacketT, int AlignmentType>
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT loadPacket(Index i, Index j) const {
if (UseDirectOffsets) {
- return m_base_mapper.template loadHalfPacket<Alignment>(i, 0);
+ return m_base_mapper.template load<PacketT,AlignmentType>(i, j);
}
- return m_base_mapper.template loadHalfPacket<Alignment>(i + m_vert_offset, m_horiz_offset);
+ return m_base_mapper.template loadPacket<PacketT,AlignmentType>(i + m_vert_offset, j + m_horiz_offset);
}
- EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, Packet p) const {
+ template <typename PacketT>
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, const PacketT& p) const {
if (UseDirectOffsets) {
m_base_mapper.storePacket(i, 0, p);
}
@@ -408,19 +485,30 @@ class TensorContractionSubMapper {
template <typename PacketT, int AlignmentType>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(Index i) const {
- EIGEN_STATIC_ASSERT((internal::is_same<PacketT, Packet>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
+ EIGEN_STATIC_ASSERT((internal::is_same<PacketT, PacketT>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
const int ActualAlignment = (AlignmentType == Aligned) && (Alignment == Aligned) ? Aligned : Unaligned;
if (UseDirectOffsets) {
- return m_base_mapper.template loadPacket<ActualAlignment>(i, 0);
+ return m_base_mapper.template loadPacket<PacketT,ActualAlignment>(i, 0);
}
- return m_base_mapper.template loadPacket<ActualAlignment>(i + m_vert_offset, m_horiz_offset);
+ return m_base_mapper.template loadPacket<PacketT,ActualAlignment>(i + m_vert_offset, m_horiz_offset);
}
- template <typename Packet>
+ template <typename PacketT>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool aligned(Index) const {
return false;
}
+ #ifdef EIGEN_USE_SYCL
+ // The SYCL placeholder accessors must be bound to a command group handler before use
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const {
+ m_base_mapper.bind(cgh);
+ }
+ #endif
+
+ const ParentMapper& base_mapper() const { return m_base_mapper; }
+ Index vert_offset() const { return m_vert_offset; }
+ Index horiz_offset() const { return m_horiz_offset; }
+
private:
ParentMapper m_base_mapper;
const Index m_vert_offset;
@@ -432,14 +520,14 @@ template<typename Scalar_, typename Index, int side,
typename Tensor,
typename nocontract_t, typename contract_t,
int packet_size,
- bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
+ bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment, template <class> class MakePointer_=MakePointer>
class TensorContractionInputMapper
- : public BaseTensorContractionMapper<Scalar_, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> {
+ : public BaseTensorContractionMapper<Scalar_, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_> {
public:
typedef Scalar_ Scalar;
- typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> Base;
- typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> SubMapper;
+ typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_> Base;
+ typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_> SubMapper;
typedef SubMapper VectorMapper;
EIGEN_DEVICE_FUNC TensorContractionInputMapper(const Tensor& tensor,
@@ -457,9 +545,29 @@ class TensorContractionInputMapper
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const {
return VectorMapper(*this, i, j);
}
+
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_>& get_tensor() const {
+ return Base::m_tensor;
+ }
};
+template <typename T> struct TensorContractionInputMapperTrait;
+
+template<typename Scalar_, typename Index_, int side_,
+ typename Tensor_,
+ typename nocontract_t_, typename contract_t_,
+ int packet_size_,
+ bool inner_dim_contiguous_, bool inner_dim_reordered_, int Alignment_, template <class> class MakePointer_>
+struct TensorContractionInputMapperTrait<TensorContractionInputMapper<Scalar_, Index_, side_, Tensor_,
+ nocontract_t_, contract_t_, packet_size_, inner_dim_contiguous_,
+ inner_dim_reordered_, Alignment_, MakePointer_> > {
+
+ typedef Tensor_ XprType;
+ static const bool inner_dim_contiguous = inner_dim_contiguous_;
+ static const bool inner_dim_reordered = inner_dim_reordered_;
+ };
+
} // end namespace internal
} // end namespace Eigen
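
TensorContractionInputMapperTrait, added at the end of the patch, re-exposes a mapper's template arguments (the tensor expression type plus the inner_dim_contiguous / inner_dim_reordered flags) so that device-specific contraction kernels can query them from the mapper type alone. A minimal sketch of that traits pattern, with hypothetical InputMapper / InputMapperTrait / needsShuffle names:

#include <type_traits>

// A hypothetical mapper carrying its compile-time configuration in template
// arguments, mirroring TensorContractionInputMapper's parameter list.
template <typename Xpr, bool Contiguous, bool Reordered>
struct InputMapper {};

// The trait re-exposes those arguments so generic code can branch on them
// without spelling out the full mapper signature.
template <typename T> struct InputMapperTrait;

template <typename Xpr, bool Contiguous, bool Reordered>
struct InputMapperTrait<InputMapper<Xpr, Contiguous, Reordered> > {
  typedef Xpr XprType;
  static const bool inner_dim_contiguous = Contiguous;
  static const bool inner_dim_reordered = Reordered;
};

// Generic kernel code only sees "Mapper" and still recovers the flags.
template <typename Mapper>
constexpr bool needsShuffle() {
  return InputMapperTrait<Mapper>::inner_dim_reordered;
}

int main() {
  typedef InputMapper<float, true, false> LhsMapper;
  static_assert(!needsShuffle<LhsMapper>(), "lhs packets are not reordered");
  static_assert(std::is_same<InputMapperTrait<LhsMapper>::XprType, float>::value,
                "the trait recovers the expression type");
  return 0;
}
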