aboutsummaryrefslogtreecommitdiff
path: root/Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h
blob: 7c2326eb7fd98246e5589d16a541169b9195c917 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2011-2014 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_ITERATIVE_SOLVER_BASE_H
#define EIGEN_ITERATIVE_SOLVER_BASE_H

namespace Eigen { 

namespace internal {

// Compile-time probe: value is true iff an lvalue of type MatrixType can be passed to a
// parameter of type `const Ref<const MatrixType>&`, i.e. MatrixType converts implicitly
// to Ref<const MatrixType>.
//
// Mechanism: the call test<MatrixType>(ms_from, 0) is resolved between two overloads.
//  - The "yes" overload takes the literal 0 as an int (exact match); it is viable only if
//    ms_from converts to Ref<const T>.
//  - The "no" overload accepts any lvalue through any_conversion's templated converting
//    constructors and the 0 through an ellipsis. It is always viable but ranks worse on the
//    second argument, so it is chosen only when the "yes" overload is not viable.
// The chosen overload is identified by the size of its return type (yes and no differ in size).
template<typename MatrixType>
struct is_ref_compatible_impl
{
private:
  // Catch-all parameter type: its templated constructors accept lvalues of any type.
  // T0 is not used in the body; it only makes any_conversion<T> depend on test's T
  // (presumably to keep the fallback overload a template of the same shape).
  template <typename T0>
  struct any_conversion
  {
    template <typename T> any_conversion(const volatile T&);
    template <typename T> any_conversion(T&);
  };
  // Distinct sizes so the selected overload can be detected with sizeof.
  struct yes {int a[1];};
  struct no  {int a[2];};

  template<typename T>
  static yes test(const Ref<const T>&, int);
  template<typename T>
  static no  test(any_conversion<T>, ...);

public:
  // Declared but never defined: it is only named inside sizeof (an unevaluated operand),
  // so no definition is required.
  static MatrixType ms_from;
  enum { value = sizeof(test<MatrixType>(ms_from, 0))==sizeof(yes) };
};

template<typename MatrixType>
struct is_ref_compatible
{
  enum { value = is_ref_compatible_impl<typename remove_all<MatrixType>::type>::value };
};

// Holds the system matrix used by IterativeSolverBase. The second parameter selects the
// storage strategy and defaults from is_ref_compatible:
//  - false: MatrixType is Ref<>-compatible; the matrix is held through a Ref<const MatrixType>.
//  - true:  "matrix-free" case; only a pointer to the caller's MatrixType object is kept.
template<typename MatrixType, bool MatrixFree = !internal::is_ref_compatible<MatrixType>::value>
class generic_matrix_wrapper;

// We have an explicit matrix at hand, compatible with Ref<>: the matrix is held through
// a Ref<const MatrixType>. Depending on the input, that Ref presumably either references
// the caller's storage or holds its own copy (confirm against Ref<> docs). Either way, the
// caller's object must stay alive/unchanged while the wrapper refers to it.
template<typename MatrixType>
class generic_matrix_wrapper<MatrixType,false>
{
public:
  typedef Ref<const MatrixType> ActualMatrixType;
  // Forwards to the self-adjoint view type exposed by Ref<const MatrixType>.
  template<int UpLo> struct ConstSelfAdjointViewReturnType {
    typedef typename ActualMatrixType::template ConstSelfAdjointViewReturnType<UpLo>::Type Type;
  };

  enum {
    MatrixFree = false
  };

  // A Ref<> has to be bound to something at construction: bind it to the empty (0x0)
  // m_dummy. This relies on m_dummy being declared (and thus initialized) before m_matrix.
  generic_matrix_wrapper()
    : m_dummy(0,0), m_matrix(m_dummy)
  {}

  // Binds the Ref directly to mat (m_dummy is default-constructed and unused in this case).
  template<typename InputType>
  generic_matrix_wrapper(const InputType &mat)
    : m_matrix(mat)
  {}

  const ActualMatrixType& matrix() const
  {
    return m_matrix;
  }

  // Re-binds the wrapper to another matrix expression. The Ref is destroyed and
  // re-constructed in place (placement new) instead of assigned, presumably because a
  // Ref<> cannot be re-seated through operator=.
  template<typename MatrixDerived>
  void grab(const EigenBase<MatrixDerived> &mat)
  {
    m_matrix.~Ref<const MatrixType>();
    ::new (&m_matrix) Ref<const MatrixType>(mat.derived());
  }

  // Same as above for an existing Ref. If mat IS m_matrix (e.g. a caller passes back the
  // result of matrix()), nothing is done: destroying m_matrix first would leave mat
  // referring to a destroyed object.
  void grab(const Ref<const MatrixType> &mat)
  {
    if(&(mat.derived()) != &m_matrix)
    {
      m_matrix.~Ref<const MatrixType>();
      ::new (&m_matrix) Ref<const MatrixType>(mat);
    }
  }

protected:
  // NOTE(review): the implicitly generated copy operations copy m_matrix with Ref<>'s own
  // copy semantics; a copy of a default-constructed wrapper may still refer to the source
  // object's m_dummy. Confirm copies of this wrapper are not relied upon.
  MatrixType m_dummy; // used to default initialize the Ref<> object
  ActualMatrixType m_matrix;
};

// MatrixType is not compatible with Ref<> -> matrix-free wrapper.
// Only a pointer to the caller's object is stored: that object must outlive the wrapper
// (or at least every use of matrix()), and must not be a temporary.
template<typename MatrixType>
class generic_matrix_wrapper<MatrixType,true>
{
public:
  typedef MatrixType ActualMatrixType;
  // No self-adjoint view is formed for matrix-free operators: the operator type itself is used.
  template<int UpLo> struct ConstSelfAdjointViewReturnType
  {
    typedef ActualMatrixType Type;
  };

  enum {
    MatrixFree = true
  };

  // No matrix yet: matrix() must not be called until a matrix is provided (ctor or grab()).
  generic_matrix_wrapper()
    : mp_matrix(0)
  {}

  generic_matrix_wrapper(const MatrixType &mat)
    : mp_matrix(&mat)
  {}

  /** \returns the wrapped operator.
    * \pre a matrix has been provided through the constructor or grab(); dereferencing the
    * null pointer of a default-constructed wrapper would be undefined behavior, so it is
    * caught by an assertion instead.
    */
  const ActualMatrixType& matrix() const
  {
    eigen_assert(mp_matrix!=0 && "generic_matrix_wrapper: no matrix has been provided (call compute() first)");
    return *mp_matrix;
  }

  // Re-points the wrapper to mat (no copy is made).
  void grab(const MatrixType &mat)
  {
    mp_matrix = &mat;
  }

protected:
  const ActualMatrixType *mp_matrix;
};

}

/** \ingroup IterativeLinearSolvers_Module
  * \brief Base class for linear iterative solvers
  *
  * Stores a handle on the system matrix (through internal::generic_matrix_wrapper), the
  * preconditioner, and the iteration controls (tolerance, maximal number of iterations)
  * shared by iterative solvers. This class contains no iteration loop itself: dense solves
  * are dispatched to \c Derived.
  *
  * \sa class SimplicialCholesky, DiagonalPreconditioner, IdentityPreconditioner
  */
template< typename Derived>
class IterativeSolverBase : public SparseSolverBase<Derived>
{
protected:
  typedef SparseSolverBase<Derived> Base;
  using Base::m_isInitialized;
  
public:
  typedef typename internal::traits<Derived>::MatrixType MatrixType;
  typedef typename internal::traits<Derived>::Preconditioner Preconditioner;
  typedef typename MatrixType::Scalar Scalar;
  typedef typename MatrixType::StorageIndex StorageIndex;
  typedef typename MatrixType::RealScalar RealScalar;

  enum {
    ColsAtCompileTime = MatrixType::ColsAtCompileTime,
    MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime
  };

public:

  using Base::derived;

  /** Default constructor. */
  IterativeSolverBase()
  {
    init();
  }

  /** Initialize the solver with matrix \a A for further \c Ax=b solving.
    * 
    * This constructor is a shortcut for the default constructor followed
    * by a call to compute().
    * 
    * \warning this class stores a reference to the matrix A as well as some
    * precomputed values that depend on it. Therefore, if \a A is changed
    * this class becomes invalid. Call compute() to update it with the new
    * matrix A, or modify a copy of A.
    */
  template<typename MatrixDerived>
  explicit IterativeSolverBase(const EigenBase<MatrixDerived>& A)
    : m_matrixWrapper(A.derived())
  {
    init();
    compute(matrix());
  }

  ~IterativeSolverBase() {}
  
  /** Initializes the iterative solver for the sparsity pattern of the matrix \a A for further solving \c Ax=b problems.
    *
    * Currently, this function mostly calls analyzePattern on the preconditioner. In the future
    * we might, for instance, implement column reordering for faster matrix vector products.
    */
  template<typename MatrixDerived>
  Derived& analyzePattern(const EigenBase<MatrixDerived>& A)
  {
    grab(A.derived());
    m_preconditioner.analyzePattern(matrix());
    m_isInitialized = true;
    m_analysisIsOk = true;
    m_info = m_preconditioner.info();
    return derived();
  }
  
  /** Initializes the iterative solver with the numerical values of the matrix \a A for further solving \c Ax=b problems.
    *
    * Currently, this function mostly calls factorize on the preconditioner.
    *
    * \warning this class stores a reference to the matrix A as well as some
    * precomputed values that depend on it. Therefore, if \a A is changed
    * this class becomes invalid. Call compute() to update it with the new
    * matrix A, or modify a copy of A.
    */
  template<typename MatrixDerived>
  Derived& factorize(const EigenBase<MatrixDerived>& A)
  {
    eigen_assert(m_analysisIsOk && "You must first call analyzePattern()"); 
    grab(A.derived());
    m_preconditioner.factorize(matrix());
    m_factorizationIsOk = true;
    m_info = m_preconditioner.info();
    return derived();
  }

  /** Initializes the iterative solver with the matrix \a A for further solving \c Ax=b problems.
    *
    * Currently, this function mostly initializes/computes the preconditioner. In the future
    * we might, for instance, implement column reordering for faster matrix vector products.
    *
    * \warning this class stores a reference to the matrix A as well as some
    * precomputed values that depend on it. Therefore, if \a A is changed
    * this class becomes invalid. Call compute() to update it with the new
    * matrix A, or modify a copy of A.
    */
  template<typename MatrixDerived>
  Derived& compute(const EigenBase<MatrixDerived>& A)
  {
    grab(A.derived());
    m_preconditioner.compute(matrix());
    m_isInitialized = true;
    m_analysisIsOk = true;
    m_factorizationIsOk = true;
    m_info = m_preconditioner.info();
    return derived();
  }

  /** \internal */
  Index rows() const { return matrix().rows(); }

  /** \internal */
  Index cols() const { return matrix().cols(); }

  /** \returns the tolerance threshold used by the stopping criteria.
    * \sa setTolerance()
    */
  RealScalar tolerance() const { return m_tolerance; }
  
  /** Sets the tolerance threshold used by the stopping criteria.
    *
    * This value is used as an upper bound to the relative residual error: |Ax-b|/|b|.
    * The default value is the machine precision given by NumTraits<Scalar>::epsilon()
    */
  Derived& setTolerance(const RealScalar& tolerance)
  {
    m_tolerance = tolerance;
    return derived();
  }

  /** \returns a read-write reference to the preconditioner for custom configuration. */
  Preconditioner& preconditioner() { return m_preconditioner; }
  
  /** \returns a read-only reference to the preconditioner. */
  const Preconditioner& preconditioner() const { return m_preconditioner; }

  /** \returns the max number of iterations.
    * It is either the value set by setMaxIterations or, by default (any negative value),
    * twice the number of columns of the matrix.
    */
  Index maxIterations() const
  {
    return (m_maxIterations<0) ? 2*matrix().cols() : m_maxIterations;
  }
  
  /** Sets the max number of iterations.
    * Default is twice the number of columns of the matrix.
    */
  Derived& setMaxIterations(Index maxIters)
  {
    m_maxIterations = maxIters;
    return derived();
  }

  /** \returns the number of iterations performed during the last solve
    * (0 if no solve has been performed yet). */
  Index iterations() const
  {
    eigen_assert(m_isInitialized && "IterativeSolverBase is not initialized.");
    return m_iterations;
  }

  /** \returns the tolerance error reached during the last solve
    * (0 if no solve has been performed yet).
    * It is a close approximation of the true relative residual error |Ax-b|/|b|.
    */
  RealScalar error() const
  {
    eigen_assert(m_isInitialized && "IterativeSolverBase is not initialized.");
    return m_error;
  }

  /** \returns the solution x of \f$ A x = b \f$ using the current decomposition of A
    * and \a x0 as an initial solution.
    *
    * \sa solve(), compute()
    */
  template<typename Rhs,typename Guess>
  inline const SolveWithGuess<Derived, Rhs, Guess>
  solveWithGuess(const MatrixBase<Rhs>& b, const Guess& x0) const
  {
    eigen_assert(m_isInitialized && "Solver is not initialized.");
    eigen_assert(derived().rows()==b.rows() && "solve(): invalid number of rows of the right hand side matrix b");
    return SolveWithGuess<Derived, Rhs, Guess>(derived(), b.derived(), x0);
  }

  /** \returns Success if the iterations converged, and NoConvergence otherwise.
    * For a multi-column sparse right-hand side, a failure on any column is reported. */
  ComputationInfo info() const
  {
    eigen_assert(m_isInitialized && "IterativeSolverBase is not initialized.");
    return m_info;
  }
  
  /** \internal
    * Solves with a sparse destination, one column of \a b at a time: each column is
    * densified, solved through derived().solve(), and sparsified back.
    *
    * On exit, m_info reports a failure if ANY column failed; a NumericalIssue is kept in
    * preference to other failures. m_iterations/m_error are not aggregated (they
    * presumably describe the last column's solve).
    */
  template<typename Rhs, typename DestDerived>
  void _solve_impl(const Rhs& b, SparseMatrixBase<DestDerived> &aDest) const
  {
    eigen_assert(rows()==b.rows());
    
    Index rhsCols = b.cols();
    Index size = b.rows();
    DestDerived& dest(aDest.derived());
    typedef typename DestDerived::Scalar DestScalar;
    Eigen::Matrix<DestScalar,Dynamic,1> tb(size);
    Eigen::Matrix<DestScalar,Dynamic,1> tx(cols());
    // We do not directly fill dest because sparse expressions have to be free of aliasing issue.
    // For non square least-square problems, b and dest might not have the same size whereas they might alias each-other.
    typename DestDerived::PlainObject tmp(cols(),rhsCols);
    // The per-column dense solve presumably updates the mutable m_info each time:
    // accumulate it so that a failure on an early column is not masked by a later success.
    ComputationInfo globalInfo = Success;
    for(Index k=0; k<rhsCols; ++k)
    {
      tb = b.col(k);
      tx = derived().solve(tb);
      tmp.col(k) = tx.sparseView(0);
      if(m_info != Success && globalInfo != NumericalIssue)
        globalInfo = m_info;
    }
    // With no column solved, keep the status set by compute()/factorize().
    if(rhsCols > 0)
      m_info = globalInfo;
    dest.swap(tmp);
  }

protected:
  /** \internal Resets the state flags, the iteration controls (automatic max iterations,
    * tolerance = NumTraits<Scalar>::epsilon()) and the solve statistics. */
  void init()
  {
    m_isInitialized = false;
    m_analysisIsOk = false;
    m_factorizationIsOk = false;
    m_maxIterations = -1;
    m_tolerance = NumTraits<Scalar>::epsilon();
    // Defined values so that iterations()/error() called after compute() but before any
    // solve do not read uninitialized members.
    m_error = RealScalar(0);
    m_iterations = 0;
    m_info = Success;
  }

  typedef internal::generic_matrix_wrapper<MatrixType> MatrixWrapper;
  typedef typename MatrixWrapper::ActualMatrixType ActualMatrixType;

  const ActualMatrixType& matrix() const
  {
    return m_matrixWrapper.matrix();
  }
  
  template<typename InputType>
  void grab(const InputType &A)
  {
    m_matrixWrapper.grab(A);
  }
  
  MatrixWrapper m_matrixWrapper;
  Preconditioner m_preconditioner;

  Index m_maxIterations;   // negative => use 2*cols() (see maxIterations())
  RealScalar m_tolerance;
  
  // Solve statistics, written from const solve paths (hence mutable).
  mutable RealScalar m_error;
  mutable Index m_iterations;
  mutable ComputationInfo m_info;
  mutable bool m_analysisIsOk, m_factorizationIsOk;
};

} // end namespace Eigen

#endif // EIGEN_ITERATIVE_SOLVER_BASE_H