aboutsummaryrefslogtreecommitdiff
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h
diff options
context:
space:
mode:
authorMiao Wang <miaowang@google.com>2017-03-08 17:23:02 +0000
committerandroid-build-merger <android-build-merger@google.com>2017-03-08 17:23:02 +0000
commit4b1361e1aca4189cfba2bd2c0fc87d593da76728 (patch)
tree0488797fc544fe977bec6418c73445759f052482 /unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h
parent29a3140b512228375adb9f24414fb2996c2889e6 (diff)
parentd3ea081e6f4a81b7086df9f35f08f25e54f20863 (diff)
downloadeigen-4b1361e1aca4189cfba2bd2c0fc87d593da76728.tar.gz
Merge "Rebase Eigen to 3.3.3." am: 7de1f32623 am: 6688b8b260
am: d3ea081e6f Change-Id: Id61115ffaa6ad3c134b182547652c7248e8aa5cd
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h56
1 files changed, 56 insertions, 0 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h
new file mode 100644
index 000000000..5cf7b4f71
--- /dev/null
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h
@@ -0,0 +1,56 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_BLOCKING_H
+#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_BLOCKING_H
+
+
+namespace Eigen {
+namespace internal {
+
+enum {
+ ShardByRow = 0,
+ ShardByCol = 1
+};
+
+
+// Default Blocking Strategy
+template <typename LhsMapper, typename RhsMapper, typename Index, int ShardingType=ShardByCol>
+class TensorContractionBlocking {
+ public:
+
+ typedef typename LhsMapper::Scalar LhsScalar;
+ typedef typename RhsMapper::Scalar RhsScalar;
+
+ EIGEN_DEVICE_FUNC TensorContractionBlocking(Index k, Index m, Index n, Index num_threads = 1) :
+ kc_(k), mc_(m), nc_(n)
+ {
+ if (ShardingType == ShardByCol) {
+ computeProductBlockingSizes<LhsScalar, RhsScalar, 1>(kc_, mc_, nc_, num_threads);
+ }
+ else {
+ computeProductBlockingSizes<LhsScalar, RhsScalar, 1>(kc_, nc_, mc_, num_threads);
+ }
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index kc() const { return kc_; }
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index mc() const { return mc_; }
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index nc() const { return nc_; }
+
+ private:
+ Index kc_;
+ Index mc_;
+ Index nc_;
+};
+
+
+} // end namespace internal
+} // end namespace Eigen
+
+#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_BLOCKING_H