aboutsummaryrefslogtreecommitdiff
path: root/Eigen/src/IterativeLinearSolvers/BiCGSTAB.h
diff options
context:
space:
mode:
Diffstat (limited to 'Eigen/src/IterativeLinearSolvers/BiCGSTAB.h')
-rw-r--r--Eigen/src/IterativeLinearSolvers/BiCGSTAB.h32
1 files changed, 27 insertions, 5 deletions
diff --git a/Eigen/src/IterativeLinearSolvers/BiCGSTAB.h b/Eigen/src/IterativeLinearSolvers/BiCGSTAB.h
index 126341be8..2b9fb7f88 100644
--- a/Eigen/src/IterativeLinearSolvers/BiCGSTAB.h
+++ b/Eigen/src/IterativeLinearSolvers/BiCGSTAB.h
@@ -39,10 +39,17 @@ bool bicgstab(const MatrixType& mat, const Rhs& rhs, Dest& x,
int maxIters = iters;
int n = mat.cols();
+ x = precond.solve(x);
VectorType r = rhs - mat * x;
VectorType r0 = r;
RealScalar r0_sqnorm = r0.squaredNorm();
+ RealScalar rhs_sqnorm = rhs.squaredNorm();
+ if(rhs_sqnorm == 0)
+ {
+ x.setZero();
+ return true;
+ }
Scalar rho = 1;
Scalar alpha = 1;
Scalar w = 1;
@@ -54,14 +61,24 @@ bool bicgstab(const MatrixType& mat, const Rhs& rhs, Dest& x,
VectorType s(n), t(n);
RealScalar tol2 = tol*tol;
+ RealScalar eps2 = NumTraits<Scalar>::epsilon()*NumTraits<Scalar>::epsilon();
int i = 0;
+ int restarts = 0;
- while ( r.squaredNorm()/r0_sqnorm > tol2 && i<maxIters )
+ while ( r.squaredNorm()/rhs_sqnorm > tol2 && i<maxIters )
{
Scalar rho_old = rho;
rho = r0.dot(r);
- if (rho == Scalar(0)) return false; /* New search directions cannot be found */
+ if (abs(rho) < eps2*r0_sqnorm)
+ {
+ // The new residual vector became too orthogonal to the arbitrarily choosen direction r0
+ // Let's restart with a new r0:
+ r0 = r;
+ rho = r0_sqnorm = r.squaredNorm();
+ if(restarts++ == 0)
+ i = 0;
+ }
Scalar beta = (rho/rho_old) * (alpha / w);
p = r + beta * (p - w * v);
@@ -75,12 +92,16 @@ bool bicgstab(const MatrixType& mat, const Rhs& rhs, Dest& x,
z = precond.solve(s);
t.noalias() = mat * z;
- w = t.dot(s) / t.squaredNorm();
+ RealScalar tmp = t.squaredNorm();
+ if(tmp>RealScalar(0))
+ w = t.dot(s) / tmp;
+ else
+ w = Scalar(0);
x += alpha * y + w * z;
r = s - w * t;
++i;
}
- tol_error = sqrt(r.squaredNorm()/r0_sqnorm);
+ tol_error = sqrt(r.squaredNorm()/rhs_sqnorm);
iters = i;
return true;
}
@@ -223,7 +244,8 @@ public:
template<typename Rhs,typename Dest>
void _solve(const Rhs& b, Dest& x) const
{
- x.setZero();
+// x.setZero();
+ x = b;
_solveWithGuess(b,x);
}