diff options
Diffstat (limited to 'include/ceres/internal/autodiff.h')
-rw-r--r-- | include/ceres/internal/autodiff.h | 181 |
1 files changed, 24 insertions, 157 deletions
diff --git a/include/ceres/internal/autodiff.h b/include/ceres/internal/autodiff.h index 581e881..cf21d7a 100644 --- a/include/ceres/internal/autodiff.h +++ b/include/ceres/internal/autodiff.h @@ -38,7 +38,7 @@ // // struct F { // template<typename T> -// bool operator(const T *x, const T *y, ..., T *z) { +// bool operator()(const T *x, const T *y, ..., T *z) { // // Compute z[] based on x[], y[], ... // // return true if computation succeeded, false otherwise. // } @@ -102,7 +102,7 @@ // // struct F { // template<typename T> -// bool operator(const T *p, const T *q, T *z) { +// bool operator()(const T *p, const T *q, T *z) { // // ... // } // }; @@ -142,10 +142,11 @@ #include <stddef.h> -#include <glog/logging.h> #include "ceres/jet.h" #include "ceres/internal/eigen.h" #include "ceres/internal/fixed_array.h" +#include "ceres/internal/variadic_evaluate.h" +#include "glog/logging.h" namespace ceres { namespace internal { @@ -164,13 +165,14 @@ namespace internal { // // is what would get put in dst if N was 3, offset was 3, and the jet type JetT // was 8-dimensional. -template <typename JetT, typename T> -inline void Make1stOrderPerturbation(int offset, int N, const T *src, - JetT *dst) { +template <typename JetT, typename T, int N> +inline void Make1stOrderPerturbation(int offset, const T* src, JetT* dst) { DCHECK(src); DCHECK(dst); for (int j = 0; j < N; ++j) { - dst[j] = JetT(src[j], offset + j); + dst[j].a = src[j]; + dst[j].v.setZero(); + dst[j].v[offset + j] = 1.0; } } @@ -191,151 +193,15 @@ inline void Take1stOrderPart(const int M, const JetT *src, T *dst) { DCHECK(src); DCHECK(dst); for (int i = 0; i < M; ++i) { - Eigen::Map<Eigen::Matrix<T, N, 1> >(dst + N * i, N) = src[i].v.template segment<N>(N0); + Eigen::Map<Eigen::Matrix<T, N, 1> >(dst + N * i, N) = + src[i].v.template segment<N>(N0); } } -// This block of quasi-repeated code calls the user-supplied functor, which may -// take a variable number of arguments. This is accomplished by specializing the -// struct based on the size of the trailing parameters; parameters with 0 size -// are assumed missing. -// -// Supporting variadic functions is the primary source of complexity in the -// autodiff implementation. - -template<typename Functor, typename T, int N0, int N1, int N2, int N3, int N4, - int N5, int N6, int N7, int N8, int N9> -struct VariadicEvaluate { - static bool Call(const Functor& functor, T const *const *input, T* output) { - return functor(input[0], - input[1], - input[2], - input[3], - input[4], - input[5], - input[6], - input[7], - input[8], - input[9], - output); - } -}; - -template<typename Functor, typename T, int N0, int N1, int N2, int N3, int N4, - int N5, int N6, int N7, int N8> -struct VariadicEvaluate<Functor, T, N0, N1, N2, N3, N4, N5, N6, N7, N8, 0> { - static bool Call(const Functor& functor, T const *const *input, T* output) { - return functor(input[0], - input[1], - input[2], - input[3], - input[4], - input[5], - input[6], - input[7], - input[8], - output); - } -}; - -template<typename Functor, typename T, int N0, int N1, int N2, int N3, int N4, - int N5, int N6, int N7> -struct VariadicEvaluate<Functor, T, N0, N1, N2, N3, N4, N5, N6, N7, 0, 0> { - static bool Call(const Functor& functor, T const *const *input, T* output) { - return functor(input[0], - input[1], - input[2], - input[3], - input[4], - input[5], - input[6], - input[7], - output); - } -}; - -template<typename Functor, typename T, int N0, int N1, int N2, int N3, int N4, - int N5, int N6> -struct VariadicEvaluate<Functor, T, N0, N1, N2, N3, N4, N5, N6, 0, 0, 0> { - static bool Call(const Functor& functor, T const *const *input, T* output) { - return functor(input[0], - input[1], - input[2], - input[3], - input[4], - input[5], - input[6], - output); - } -}; - -template<typename Functor, typename T, int N0, int N1, int N2, int N3, int N4, - int N5> -struct VariadicEvaluate<Functor, T, N0, N1, N2, N3, N4, N5, 0, 0, 0, 0> { - static bool Call(const Functor& functor, T const *const *input, T* output) { - return functor(input[0], - input[1], - input[2], - input[3], - input[4], - input[5], - output); - } -}; - -template<typename Functor, typename T, int N0, int N1, int N2, int N3, int N4> -struct VariadicEvaluate<Functor, T, N0, N1, N2, N3, N4, 0, 0, 0, 0, 0> { - static bool Call(const Functor& functor, T const *const *input, T* output) { - return functor(input[0], - input[1], - input[2], - input[3], - input[4], - output); - } -}; - -template<typename Functor, typename T, int N0, int N1, int N2, int N3> -struct VariadicEvaluate<Functor, T, N0, N1, N2, N3, 0, 0, 0, 0, 0, 0> { - static bool Call(const Functor& functor, T const *const *input, T* output) { - return functor(input[0], - input[1], - input[2], - input[3], - output); - } -}; - -template<typename Functor, typename T, int N0, int N1, int N2> -struct VariadicEvaluate<Functor, T, N0, N1, N2, 0, 0, 0, 0, 0, 0, 0> { - static bool Call(const Functor& functor, T const *const *input, T* output) { - return functor(input[0], - input[1], - input[2], - output); - } -}; - -template<typename Functor, typename T, int N0, int N1> -struct VariadicEvaluate<Functor, T, N0, N1, 0, 0, 0, 0, 0, 0, 0, 0> { - static bool Call(const Functor& functor, T const *const *input, T* output) { - return functor(input[0], - input[1], - output); - } -}; - -template<typename Functor, typename T, int N0> -struct VariadicEvaluate<Functor, T, N0, 0, 0, 0, 0, 0, 0, 0, 0, 0> { - static bool Call(const Functor& functor, T const *const *input, T* output) { - return functor(input[0], - output); - } -}; - -// This is in a struct because default template parameters on a function are not -// supported in C++03 (though it is available in C++0x). N0 through N5 are the -// dimension of the input arguments to the user supplied functor. +// This is in a struct because default template parameters on a +// function are not supported in C++03 (though it is available in +// C++0x). N0 through N5 are the dimension of the input arguments to +// the user supplied functor. template <typename Functor, typename T, int N0 = 0, int N1 = 0, int N2 = 0, int N3 = 0, int N4 = 0, int N5 = 0, int N6 = 0, int N7 = 0, int N8 = 0, int N9 = 0> @@ -347,7 +213,7 @@ struct AutoDiff { T **jacobians) { // This block breaks the 80 column rule to keep it somewhat readable. DCHECK_GT(num_outputs, 0); - CHECK((!N1 && !N2 && !N3 && !N4 && !N5 && !N6 && !N7 && !N8 && !N9) || + DCHECK((!N1 && !N2 && !N3 && !N4 && !N5 && !N6 && !N7 && !N8 && !N9) || ((N1 > 0) && !N2 && !N3 && !N4 && !N5 && !N6 && !N7 && !N8 && !N9) || ((N1 > 0) && (N2 > 0) && !N3 && !N4 && !N5 && !N6 && !N7 && !N8 && !N9) || ((N1 > 0) && (N2 > 0) && (N3 > 0) && !N4 && !N5 && !N6 && !N7 && !N8 && !N9) || @@ -390,14 +256,15 @@ struct AutoDiff { x.get() + jet8, x.get() + jet9, }; - JetT *output = x.get() + jet6; -#define CERES_MAKE_1ST_ORDER_PERTURBATION(i) \ - if (N ## i) { \ - internal::Make1stOrderPerturbation(jet ## i, \ - N ## i, \ - parameters[i], \ - x.get() + jet ## i); \ + JetT* output = x.get() + N0 + N1 + N2 + N3 + N4 + N5 + N6 + N7 + N8 + N9; + +#define CERES_MAKE_1ST_ORDER_PERTURBATION(i) \ + if (N ## i) { \ + internal::Make1stOrderPerturbation<JetT, T, N ## i>( \ + jet ## i, \ + parameters[i], \ + x.get() + jet ## i); \ } CERES_MAKE_1ST_ORDER_PERTURBATION(0); CERES_MAKE_1ST_ORDER_PERTURBATION(1); |