diff options
Diffstat (limited to 'unsupported/Eigen/src/MatrixFunctions/MatrixExponential.h')
-rw-r--r-- | unsupported/Eigen/src/MatrixFunctions/MatrixExponential.h | 51 |
1 files changed, 30 insertions, 21 deletions
diff --git a/unsupported/Eigen/src/MatrixFunctions/MatrixExponential.h b/unsupported/Eigen/src/MatrixFunctions/MatrixExponential.h index bb6d9e1fe..02284b0dd 100644 --- a/unsupported/Eigen/src/MatrixFunctions/MatrixExponential.h +++ b/unsupported/Eigen/src/MatrixFunctions/MatrixExponential.h @@ -234,12 +234,13 @@ struct matrix_exp_computeUV<MatrixType, float> template <typename MatrixType> struct matrix_exp_computeUV<MatrixType, double> { + typedef typename NumTraits<typename traits<MatrixType>::Scalar>::Real RealScalar; template <typename ArgType> static void run(const ArgType& arg, MatrixType& U, MatrixType& V, int& squarings) { using std::frexp; using std::pow; - const double l1norm = arg.cwiseAbs().colwise().sum().maxCoeff(); + const RealScalar l1norm = arg.cwiseAbs().colwise().sum().maxCoeff(); squarings = 0; if (l1norm < 1.495585217958292e-002) { matrix_exp_pade3(arg, U, V); @@ -250,10 +251,10 @@ struct matrix_exp_computeUV<MatrixType, double> } else if (l1norm < 2.097847961257068e+000) { matrix_exp_pade9(arg, U, V); } else { - const double maxnorm = 5.371920351148152; + const RealScalar maxnorm = 5.371920351148152; frexp(l1norm / maxnorm, &squarings); if (squarings < 0) squarings = 0; - MatrixType A = arg.unaryExpr(MatrixExponentialScalingOp<double>(squarings)); + MatrixType A = arg.unaryExpr(MatrixExponentialScalingOp<RealScalar>(squarings)); matrix_exp_pade13(A, U, V); } } @@ -313,7 +314,7 @@ struct matrix_exp_computeUV<MatrixType, long double> matrix_exp_pade17(A, U, V); } -#elif LDBL_MANT_DIG <= 112 // quadruple precison +#elif LDBL_MANT_DIG <= 113 // quadruple precision if (l1norm < 1.639394610288918690547467954466970e-005L) { matrix_exp_pade3(arg, U, V); @@ -326,6 +327,7 @@ struct matrix_exp_computeUV<MatrixType, long double> } else if (l1norm < 1.125358383453143065081397882891878e+000L) { matrix_exp_pade13(arg, U, V); } else { + const long double maxnorm = 2.884233277829519311757165057717815L; frexp(l1norm / maxnorm, &squarings); if (squarings < 0) squarings = 0; MatrixType A = arg.unaryExpr(MatrixExponentialScalingOp<long double>(squarings)); @@ -342,6 +344,27 @@ struct matrix_exp_computeUV<MatrixType, long double> } }; +template<typename T> struct is_exp_known_type : false_type {}; +template<> struct is_exp_known_type<float> : true_type {}; +template<> struct is_exp_known_type<double> : true_type {}; +#if LDBL_MANT_DIG <= 113 +template<> struct is_exp_known_type<long double> : true_type {}; +#endif + +template <typename ArgType, typename ResultType> +void matrix_exp_compute(const ArgType& arg, ResultType &result, true_type) // natively supported scalar type +{ + typedef typename ArgType::PlainObject MatrixType; + MatrixType U, V; + int squarings; + matrix_exp_computeUV<MatrixType>::run(arg, U, V, squarings); // Pade approximant is (U+V) / (-U+V) + MatrixType numer = U + V; + MatrixType denom = -U + V; + result = denom.partialPivLu().solve(numer); + for (int i=0; i<squarings; i++) + result *= result; // undo scaling by repeated squaring +} + /* Computes the matrix exponential * @@ -349,26 +372,13 @@ struct matrix_exp_computeUV<MatrixType, long double> * \param result variable in which result will be stored */ template <typename ArgType, typename ResultType> -void matrix_exp_compute(const ArgType& arg, ResultType &result) +void matrix_exp_compute(const ArgType& arg, ResultType &result, false_type) // default { typedef typename ArgType::PlainObject MatrixType; -#if LDBL_MANT_DIG > 112 // rarely happens typedef typename traits<MatrixType>::Scalar Scalar; typedef typename NumTraits<Scalar>::Real RealScalar; typedef typename std::complex<RealScalar> ComplexScalar; - if (sizeof(RealScalar) > 14) { - result = arg.matrixFunction(internal::stem_function_exp<ComplexScalar>); - return; - } -#endif - MatrixType U, V; - int squarings; - matrix_exp_computeUV<MatrixType>::run(arg, U, V, squarings); // Pade approximant is (U+V) / (-U+V) - MatrixType numer = U + V; - MatrixType denom = -U + V; - result = denom.partialPivLu().solve(numer); - for (int i=0; i<squarings; i++) - result *= result; // undo scaling by repeated squaring + result = arg.matrixFunction(internal::stem_function_exp<ComplexScalar>); } } // end namespace Eigen::internal @@ -386,7 +396,6 @@ void matrix_exp_compute(const ArgType& arg, ResultType &result) template<typename Derived> struct MatrixExponentialReturnValue : public ReturnByValue<MatrixExponentialReturnValue<Derived> > { - typedef typename Derived::Index Index; public: /** \brief Constructor. * @@ -402,7 +411,7 @@ template<typename Derived> struct MatrixExponentialReturnValue inline void evalTo(ResultType& result) const { const typename internal::nested_eval<Derived, 10>::type tmp(m_src); - internal::matrix_exp_compute(tmp, result); + internal::matrix_exp_compute(tmp, result, internal::is_exp_known_type<typename Derived::RealScalar>()); } Index rows() const { return m_src.rows(); } |