diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h | 39 |
1 files changed, 28 insertions, 11 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h index 5cf7b4f71..974feb0ad 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h @@ -21,14 +21,28 @@ enum { // Default Blocking Strategy -template <typename LhsMapper, typename RhsMapper, typename Index, int ShardingType=ShardByCol> +template<typename ResScalar, typename LhsScalar, typename RhsScalar, typename StorageIndex, int ShardingType = ShardByCol> class TensorContractionBlocking { public: - typedef typename LhsMapper::Scalar LhsScalar; - typedef typename RhsMapper::Scalar RhsScalar; + /* + adding EIGEN_DEVICE_FUNC unconditionally to 'TensorContractionBlocking' constructor in `TensorContractionBlocking.h` + requires adding EIGEN_DEVICE_FUNC to `computeProductBlockingSizes` in `GeneralBlockPanelKernel.h` + which in turn, requires adding EIGEN_DEVICE_FUNC to `evaluateProductBlockingSizesHeuristic` in `GeneralBlockPanelKernel.h` + which in turn, requires adding EIGEN_DEVICE_FUNC to `manage_caching_sizes` in `GeneralBlockPanelKernel.h` + (else HIPCC will error out) - EIGEN_DEVICE_FUNC TensorContractionBlocking(Index k, Index m, Index n, Index num_threads = 1) : + However adding EIGEN_DEVICE_FUNC to `manage_caching_sizes` in `GeneralBlockPanelKernel.h` + results in NVCC erroring out with the following error + + ../Eigen/src/Core/products/GeneralBlockPanelKernel.h(57): error #2901: + dynamic initialization is not supported for function-scope static variables within a __device__/__global__ function + */ + + #if !defined(EIGEN_HIPCC) + EIGEN_DEVICE_FUNC + #endif + TensorContractionBlocking(StorageIndex k, StorageIndex m, StorageIndex n, StorageIndex num_threads = 1) : kc_(k), mc_(m), nc_(n) { if (ShardingType == ShardByCol) { @@ -37,19 +51,22 @@ class TensorContractionBlocking { else { computeProductBlockingSizes<LhsScalar, RhsScalar, 1>(kc_, nc_, mc_, num_threads); } + + const int rhs_packet_size = internal::packet_traits<RhsScalar>::size; + kc_ = (rhs_packet_size <= 8 || kc_ <= rhs_packet_size) ? + kc_ : (kc_ / rhs_packet_size) * rhs_packet_size; } - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index kc() const { return kc_; } - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index mc() const { return mc_; } - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index nc() const { return nc_; } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE StorageIndex kc() const { return kc_; } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE StorageIndex mc() const { return mc_; } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE StorageIndex nc() const { return nc_; } private: - Index kc_; - Index mc_; - Index nc_; + StorageIndex kc_; + StorageIndex mc_; + StorageIndex nc_; }; - } // end namespace internal } // end namespace Eigen |