diff options
Diffstat (limited to 'unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h')
-rw-r--r-- | unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h | 34 |
1 files changed, 16 insertions, 18 deletions
diff --git a/unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h b/unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h index afd88ec4d..e363e779d 100644 --- a/unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h +++ b/unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h @@ -17,7 +17,7 @@ namespace internal { // pre: T.block(i,i,2,2) has complex conjugate eigenvalues // post: sqrtT.block(i,i,2,2) is square root of T.block(i,i,2,2) template <typename MatrixType, typename ResultType> -void matrix_sqrt_quasi_triangular_2x2_diagonal_block(const MatrixType& T, typename MatrixType::Index i, ResultType& sqrtT) +void matrix_sqrt_quasi_triangular_2x2_diagonal_block(const MatrixType& T, Index i, ResultType& sqrtT) { // TODO: This case (2-by-2 blocks with complex conjugate eigenvalues) is probably hidden somewhere // in EigenSolver. If we expose it, we could call it directly from here. @@ -32,7 +32,7 @@ void matrix_sqrt_quasi_triangular_2x2_diagonal_block(const MatrixType& T, typena // all blocks of sqrtT to left of and below (i,j) are correct // post: sqrtT(i,j) has the correct value template <typename MatrixType, typename ResultType> -void matrix_sqrt_quasi_triangular_1x1_off_diagonal_block(const MatrixType& T, typename MatrixType::Index i, typename MatrixType::Index j, ResultType& sqrtT) +void matrix_sqrt_quasi_triangular_1x1_off_diagonal_block(const MatrixType& T, Index i, Index j, ResultType& sqrtT) { typedef typename traits<MatrixType>::Scalar Scalar; Scalar tmp = (sqrtT.row(i).segment(i+1,j-i-1) * sqrtT.col(j).segment(i+1,j-i-1)).value(); @@ -41,7 +41,7 @@ void matrix_sqrt_quasi_triangular_1x1_off_diagonal_block(const MatrixType& T, ty // similar to compute1x1offDiagonalBlock() template <typename MatrixType, typename ResultType> -void matrix_sqrt_quasi_triangular_1x2_off_diagonal_block(const MatrixType& T, typename MatrixType::Index i, typename MatrixType::Index j, ResultType& sqrtT) +void matrix_sqrt_quasi_triangular_1x2_off_diagonal_block(const MatrixType& T, Index i, Index j, ResultType& sqrtT) { typedef typename traits<MatrixType>::Scalar Scalar; Matrix<Scalar,1,2> rhs = T.template block<1,2>(i,j); @@ -54,7 +54,7 @@ void matrix_sqrt_quasi_triangular_1x2_off_diagonal_block(const MatrixType& T, ty // similar to compute1x1offDiagonalBlock() template <typename MatrixType, typename ResultType> -void matrix_sqrt_quasi_triangular_2x1_off_diagonal_block(const MatrixType& T, typename MatrixType::Index i, typename MatrixType::Index j, ResultType& sqrtT) +void matrix_sqrt_quasi_triangular_2x1_off_diagonal_block(const MatrixType& T, Index i, Index j, ResultType& sqrtT) { typedef typename traits<MatrixType>::Scalar Scalar; Matrix<Scalar,2,1> rhs = T.template block<2,1>(i,j); @@ -101,7 +101,7 @@ void matrix_sqrt_quasi_triangular_solve_auxiliary_equation(MatrixType& X, const // similar to compute1x1offDiagonalBlock() template <typename MatrixType, typename ResultType> -void matrix_sqrt_quasi_triangular_2x2_off_diagonal_block(const MatrixType& T, typename MatrixType::Index i, typename MatrixType::Index j, ResultType& sqrtT) +void matrix_sqrt_quasi_triangular_2x2_off_diagonal_block(const MatrixType& T, Index i, Index j, ResultType& sqrtT) { typedef typename traits<MatrixType>::Scalar Scalar; Matrix<Scalar,2,2> A = sqrtT.template block<2,2>(i,i); @@ -120,7 +120,6 @@ template <typename MatrixType, typename ResultType> void matrix_sqrt_quasi_triangular_diagonal(const MatrixType& T, ResultType& sqrtT) { using std::sqrt; - typedef typename MatrixType::Index Index; const Index size = T.rows(); for (Index i = 0; i < size; i++) { if (i == size - 1 || T.coeff(i+1, i) == 0) { @@ -139,7 +138,6 @@ void matrix_sqrt_quasi_triangular_diagonal(const MatrixType& T, ResultType& sqrt template <typename MatrixType, typename ResultType> void matrix_sqrt_quasi_triangular_off_diagonal(const MatrixType& T, ResultType& sqrtT) { - typedef typename MatrixType::Index Index; const Index size = T.rows(); for (Index j = 1; j < size; j++) { if (T.coeff(j, j-1) != 0) // if T(j-1:j, j-1:j) is a 2-by-2 block @@ -206,8 +204,7 @@ template <typename MatrixType, typename ResultType> void matrix_sqrt_triangular(const MatrixType &arg, ResultType &result) { using std::sqrt; - typedef typename MatrixType::Index Index; - typedef typename MatrixType::Scalar Scalar; + typedef typename MatrixType::Scalar Scalar; eigen_assert(arg.rows() == arg.cols()); @@ -256,18 +253,19 @@ struct matrix_sqrt_compute template <typename MatrixType> struct matrix_sqrt_compute<MatrixType, 0> { + typedef typename MatrixType::PlainObject PlainType; template <typename ResultType> static void run(const MatrixType &arg, ResultType &result) { eigen_assert(arg.rows() == arg.cols()); // Compute Schur decomposition of arg - const RealSchur<MatrixType> schurOfA(arg); - const MatrixType& T = schurOfA.matrixT(); - const MatrixType& U = schurOfA.matrixU(); + const RealSchur<PlainType> schurOfA(arg); + const PlainType& T = schurOfA.matrixT(); + const PlainType& U = schurOfA.matrixU(); // Compute square root of T - MatrixType sqrtT = MatrixType::Zero(arg.rows(), arg.cols()); + PlainType sqrtT = PlainType::Zero(arg.rows(), arg.cols()); matrix_sqrt_quasi_triangular(T, sqrtT); // Compute square root of arg @@ -281,18 +279,19 @@ struct matrix_sqrt_compute<MatrixType, 0> template <typename MatrixType> struct matrix_sqrt_compute<MatrixType, 1> { + typedef typename MatrixType::PlainObject PlainType; template <typename ResultType> static void run(const MatrixType &arg, ResultType &result) { eigen_assert(arg.rows() == arg.cols()); // Compute Schur decomposition of arg - const ComplexSchur<MatrixType> schurOfA(arg); - const MatrixType& T = schurOfA.matrixT(); - const MatrixType& U = schurOfA.matrixU(); + const ComplexSchur<PlainType> schurOfA(arg); + const PlainType& T = schurOfA.matrixT(); + const PlainType& U = schurOfA.matrixU(); // Compute square root of T - MatrixType sqrtT; + PlainType sqrtT; matrix_sqrt_triangular(T, sqrtT); // Compute square root of arg @@ -318,7 +317,6 @@ template<typename Derived> class MatrixSquareRootReturnValue : public ReturnByValue<MatrixSquareRootReturnValue<Derived> > { protected: - typedef typename Derived::Index Index; typedef typename internal::ref_selector<Derived>::type DerivedNested; public: |