summaryrefslogtreecommitdiff
path: root/src/main/java/org/apache/commons/math3/ode/FieldEquationsMapper.java
blob: e28f5be4695ac71411a4391aaac74b505d3ea5df (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.commons.math3.ode;

import org.apache.commons.math3.RealFieldElement;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.exception.MathIllegalArgumentException;
import org.apache.commons.math3.exception.util.LocalizedFormats;
import org.apache.commons.math3.util.MathArrays;

import java.io.Serializable;
import java.util.Arrays;

/**
 * Class mapping the part of a complete state or derivative that pertains to a set of differential
 * equations.
 *
 * <p>Instances of this class are guaranteed to be immutable.
 *
 * @see FieldExpandableODE
 * @param <T> the type of the field elements
 * @since 3.6
 */
public class FieldEquationsMapper<T extends RealFieldElement<T>> implements Serializable {

    /** Serializable UID. */
    private static final long serialVersionUID = 20151114L;

    /** Start indices of the components. */
    private final int[] start;

    /**
     * Create a mapper by adding a new equation to another mapper.
     *
     * <p>The new equation will have index {@code mapper.}{@link #getNumberOfEquations()}, or 0 if
     * {@code mapper} is null.
     *
     * @param mapper former mapper, with one equation less (null for first equation)
     * @param dimension dimension of the equation state vector
     */
    FieldEquationsMapper(final FieldEquationsMapper<T> mapper, final int dimension) {
        final int index = (mapper == null) ? 0 : mapper.getNumberOfEquations();
        this.start = new int[index + 2];
        if (mapper == null) {
            start[0] = 0;
        } else {
            System.arraycopy(mapper.start, 0, start, 0, index + 1);
        }
        start[index + 1] = start[index] + dimension;
    }

    /**
     * Get the number of equations mapped.
     *
     * @return number of equations mapped
     */
    public int getNumberOfEquations() {
        return start.length - 1;
    }

    /**
     * Return the dimension of the complete set of equations.
     *
     * <p>The complete set of equations correspond to the primary set plus all secondary sets.
     *
     * @return dimension of the complete set of equations
     */
    public int getTotalDimension() {
        return start[start.length - 1];
    }

    /**
     * Map a state to a complete flat array.
     *
     * @param state state to map
     * @return flat array containing the mapped state, including primary and secondary components
     */
    public T[] mapState(final FieldODEState<T> state) {
        final T[] y = MathArrays.buildArray(state.getTime().getField(), getTotalDimension());
        int index = 0;
        insertEquationData(index, state.getState(), y);
        while (++index < getNumberOfEquations()) {
            insertEquationData(index, state.getSecondaryState(index), y);
        }
        return y;
    }

    /**
     * Map a state derivative to a complete flat array.
     *
     * @param state state to map
     * @return flat array containing the mapped state derivative, including primary and secondary
     *     components
     */
    public T[] mapDerivative(final FieldODEStateAndDerivative<T> state) {
        final T[] yDot = MathArrays.buildArray(state.getTime().getField(), getTotalDimension());
        int index = 0;
        insertEquationData(index, state.getDerivative(), yDot);
        while (++index < getNumberOfEquations()) {
            insertEquationData(index, state.getSecondaryDerivative(index), yDot);
        }
        return yDot;
    }

    /**
     * Map flat arrays to a state and derivative.
     *
     * @param t time
     * @param y state array to map, including primary and secondary components
     * @param yDot state derivative array to map, including primary and secondary components
     * @return mapped state
     * @exception DimensionMismatchException if an array does not match total dimension
     */
    public FieldODEStateAndDerivative<T> mapStateAndDerivative(
            final T t, final T[] y, final T[] yDot) throws DimensionMismatchException {

        if (y.length != getTotalDimension()) {
            throw new DimensionMismatchException(y.length, getTotalDimension());
        }

        if (yDot.length != getTotalDimension()) {
            throw new DimensionMismatchException(yDot.length, getTotalDimension());
        }

        final int n = getNumberOfEquations();
        int index = 0;
        final T[] state = extractEquationData(index, y);
        final T[] derivative = extractEquationData(index, yDot);
        if (n < 2) {
            return new FieldODEStateAndDerivative<T>(t, state, derivative);
        } else {
            final T[][] secondaryState = MathArrays.buildArray(t.getField(), n - 1, -1);
            final T[][] secondaryDerivative = MathArrays.buildArray(t.getField(), n - 1, -1);
            while (++index < getNumberOfEquations()) {
                secondaryState[index - 1] = extractEquationData(index, y);
                secondaryDerivative[index - 1] = extractEquationData(index, yDot);
            }
            return new FieldODEStateAndDerivative<T>(
                    t, state, derivative, secondaryState, secondaryDerivative);
        }
    }

    /**
     * Extract equation data from a complete state or derivative array.
     *
     * @param index index of the equation, must be between 0 included and {@link
     *     #getNumberOfEquations()} (excluded)
     * @param complete complete state or derivative array from which equation data should be
     *     retrieved
     * @return equation data
     * @exception MathIllegalArgumentException if index is out of range
     * @exception DimensionMismatchException if complete state has not enough elements
     */
    public T[] extractEquationData(final int index, final T[] complete)
            throws MathIllegalArgumentException, DimensionMismatchException {
        checkIndex(index);
        final int begin = start[index];
        final int end = start[index + 1];
        if (complete.length < end) {
            throw new DimensionMismatchException(complete.length, end);
        }
        final int dimension = end - begin;
        final T[] equationData = MathArrays.buildArray(complete[0].getField(), dimension);
        System.arraycopy(complete, begin, equationData, 0, dimension);
        return equationData;
    }

    /**
     * Insert equation data into a complete state or derivative array.
     *
     * @param index index of the equation, must be between 0 included and {@link
     *     #getNumberOfEquations()} (excluded)
     * @param equationData equation data to be inserted into the complete array
     * @param complete placeholder where to put equation data (only the part corresponding to the
     *     equation will be overwritten)
     * @exception DimensionMismatchException if either array has not enough elements
     */
    public void insertEquationData(final int index, T[] equationData, T[] complete)
            throws DimensionMismatchException {
        checkIndex(index);
        final int begin = start[index];
        final int end = start[index + 1];
        final int dimension = end - begin;
        if (complete.length < end) {
            throw new DimensionMismatchException(complete.length, end);
        }
        if (equationData.length != dimension) {
            throw new DimensionMismatchException(equationData.length, dimension);
        }
        System.arraycopy(equationData, 0, complete, begin, dimension);
    }

    /**
     * Check equation index.
     *
     * @param index index of the equation, must be between 0 included and {@link
     *     #getNumberOfEquations()} (excluded)
     * @exception MathIllegalArgumentException if index is out of range
     */
    private void checkIndex(final int index) throws MathIllegalArgumentException {
        if (index < 0 || index > start.length - 2) {
            throw new MathIllegalArgumentException(
                    LocalizedFormats.ARGUMENT_OUTSIDE_DOMAIN, index, 0, start.length - 2);
        }
    }
}