aboutsummaryrefslogtreecommitdiff
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h
diff options
context:
space:
mode:
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h69
1 files changed, 69 insertions, 0 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h b/unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h
index 29e50a3b2..96fa46c86 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h
@@ -28,6 +28,8 @@ template <typename ExpressionType, typename DeviceType> class TensorDevice {
public:
TensorDevice(const DeviceType& device, ExpressionType& expression) : m_device(device), m_expression(expression) {}
+ EIGEN_DEFAULT_COPY_CONSTRUCTOR(TensorDevice)
+
template<typename OtherDerived>
EIGEN_STRONG_INLINE TensorDevice& operator=(const OtherDerived& other) {
typedef TensorAssignOp<ExpressionType, const OtherDerived> Assign;
@@ -63,6 +65,73 @@ template <typename ExpressionType, typename DeviceType> class TensorDevice {
ExpressionType& m_expression;
};
+/** \class TensorAsyncDevice
+ * \ingroup CXX11_Tensor_Module
+ *
+ * \brief Pseudo expression providing an operator = that will evaluate its
+ * argument asynchronously on the specified device. Currently only
+ * ThreadPoolDevice implements proper asynchronous execution, while the default
+ * and GPU devices just run the expression synchronously and call m_done() on
+ * completion..
+ *
+ * Example:
+ * auto done = []() { ... expression evaluation done ... };
+ * C.device(thread_pool_device, std::move(done)) = A + B;
+ */
+
+template <typename ExpressionType, typename DeviceType, typename DoneCallback>
+class TensorAsyncDevice {
+ public:
+ TensorAsyncDevice(const DeviceType& device, ExpressionType& expression,
+ DoneCallback done)
+ : m_device(device), m_expression(expression), m_done(std::move(done)) {}
+
+ template <typename OtherDerived>
+ EIGEN_STRONG_INLINE TensorAsyncDevice& operator=(const OtherDerived& other) {
+ typedef TensorAssignOp<ExpressionType, const OtherDerived> Assign;
+ typedef internal::TensorExecutor<const Assign, DeviceType> Executor;
+
+ Assign assign(m_expression, other);
+ Executor::run(assign, m_device);
+ m_done();
+
+ return *this;
+ }
+
+ protected:
+ const DeviceType& m_device;
+ ExpressionType& m_expression;
+ DoneCallback m_done;
+};
+
+
+#ifdef EIGEN_USE_THREADS
+template <typename ExpressionType, typename DoneCallback>
+class TensorAsyncDevice<ExpressionType, ThreadPoolDevice, DoneCallback> {
+ public:
+ TensorAsyncDevice(const ThreadPoolDevice& device, ExpressionType& expression,
+ DoneCallback done)
+ : m_device(device), m_expression(expression), m_done(std::move(done)) {}
+
+ template <typename OtherDerived>
+ EIGEN_STRONG_INLINE TensorAsyncDevice& operator=(const OtherDerived& other) {
+ typedef TensorAssignOp<ExpressionType, const OtherDerived> Assign;
+ typedef internal::TensorAsyncExecutor<const Assign, ThreadPoolDevice, DoneCallback> Executor;
+
+ // WARNING: After assignment 'm_done' callback will be in undefined state.
+ Assign assign(m_expression, other);
+ Executor::runAsync(assign, m_device, std::move(m_done));
+
+ return *this;
+ }
+
+ protected:
+ const ThreadPoolDevice& m_device;
+ ExpressionType& m_expression;
+ DoneCallback m_done;
+};
+#endif
+
} // end namespace Eigen
#endif // EIGEN_CXX11_TENSOR_TENSOR_DEVICE_H