aboutsummaryrefslogtreecommitdiff
path: root/Eigen/src/Core/products/TriangularMatrixVector.h
diff options
context:
space:
mode:
authorAndroid Build Coastguard Worker <android-build-coastguard-worker@google.com>2022-05-10 06:53:32 +0000
committerAndroid Build Coastguard Worker <android-build-coastguard-worker@google.com>2022-05-10 06:53:32 +0000
commit4ac63e5b0d908975847fc698cd8dcecd4d4b0569 (patch)
treefb979fb4cf4f8052c8cc66b1ec9516d91fcd859b /Eigen/src/Core/products/TriangularMatrixVector.h
parent72bcb7722ab4e1a8b57ff4c474da04ebeecd0254 (diff)
parentedb0ad5bb04b48aab7dd0978f0475edd3550de7c (diff)
downloadeigen-4ac63e5b0d908975847fc698cd8dcecd4d4b0569.tar.gz
Change-Id: Ia95263c756e823d68c32190fe0aa6d9d10f3d474
Diffstat (limited to 'Eigen/src/Core/products/TriangularMatrixVector.h')
-rw-r--r--Eigen/src/Core/products/TriangularMatrixVector.h22
1 files changed, 18 insertions, 4 deletions
diff --git a/Eigen/src/Core/products/TriangularMatrixVector.h b/Eigen/src/Core/products/TriangularMatrixVector.h
index 4b292e74d..76bfa159c 100644
--- a/Eigen/src/Core/products/TriangularMatrixVector.h
+++ b/Eigen/src/Core/products/TriangularMatrixVector.h
@@ -221,8 +221,9 @@ template<int Mode> struct trmv_selector<Mode,ColMajor>
typename internal::add_const_on_value_type<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(lhs);
typename internal::add_const_on_value_type<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(rhs);
- ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(lhs)
- * RhsBlasTraits::extractScalarFactor(rhs);
+ LhsScalar lhs_alpha = LhsBlasTraits::extractScalarFactor(lhs);
+ RhsScalar rhs_alpha = RhsBlasTraits::extractScalarFactor(rhs);
+ ResScalar actualAlpha = alpha * lhs_alpha * rhs_alpha;
enum {
// FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1
@@ -274,6 +275,12 @@ template<int Mode> struct trmv_selector<Mode,ColMajor>
else
dest = MappedDest(actualDestPtr, dest.size());
}
+
+ if ( ((Mode&UnitDiag)==UnitDiag) && (lhs_alpha!=LhsScalar(1)) )
+ {
+ Index diagSize = (std::min)(lhs.rows(),lhs.cols());
+ dest.head(diagSize) -= (lhs_alpha-LhsScalar(1))*rhs.head(diagSize);
+ }
}
};
@@ -295,8 +302,9 @@ template<int Mode> struct trmv_selector<Mode,RowMajor>
typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(lhs);
typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(rhs);
- ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(lhs)
- * RhsBlasTraits::extractScalarFactor(rhs);
+ LhsScalar lhs_alpha = LhsBlasTraits::extractScalarFactor(lhs);
+ RhsScalar rhs_alpha = RhsBlasTraits::extractScalarFactor(rhs);
+ ResScalar actualAlpha = alpha * lhs_alpha * rhs_alpha;
enum {
DirectlyUseRhs = ActualRhsTypeCleaned::InnerStrideAtCompileTime==1
@@ -326,6 +334,12 @@ template<int Mode> struct trmv_selector<Mode,RowMajor>
actualRhsPtr,1,
dest.data(),dest.innerStride(),
actualAlpha);
+
+ if ( ((Mode&UnitDiag)==UnitDiag) && (lhs_alpha!=LhsScalar(1)) )
+ {
+ Index diagSize = (std::min)(lhs.rows(),lhs.cols());
+ dest.head(diagSize) -= (lhs_alpha-LhsScalar(1))*rhs.head(diagSize);
+ }
}
};