summaryrefslogtreecommitdiff
path: root/nn/common/include/ExecutionBurstController.h
blob: 6328096b0ea466224c6d9f522be2b6ca10dfcd51 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_EXECUTION_BURST_CONTROLLER_H
#define ANDROID_FRAMEWORKS_ML_NN_COMMON_EXECUTION_BURST_CONTROLLER_H

#include "HalInterfaces.h"

#include <android-base/macros.h>
#include <fmq/MessageQueue.h>
#include <hidl/MQDescriptor.h>

#include <atomic>
#include <map>
#include <memory>
#include <mutex>
#include <optional>
#include <stack>
#include <tuple>
#include <utility>
#include <vector>

namespace android::nn {

/**
 * Number of elements in the FMQ.
 *
 * Declared as a C++17 `inline` variable so that every translation unit that
 * includes this header shares one definition, instead of each receiving its
 * own internal-linkage copy. (`constexpr` already implies `const`.)
 */
inline constexpr size_t kExecutionBurstChannelLength = 1024;

/**
 * Function to serialize a request.
 *
 * Prefer calling RequestChannelSender::send.
 *
 * @param request Request object without the pool information.
 * @param measure Whether to collect timing information for the execution.
 * @param slots Slot identifiers corresponding to memory resources for the
 *     request.
 * @return Serialized FMQ request data.
 */
std::vector<hal::FmqRequestDatum> serialize(const hal::Request& request, hal::MeasureTiming measure,
                                            const std::vector<int32_t>& slots);

/**
 * Deserialize the FMQ result data.
 *
 * The three resulting fields are the status of the execution, the dynamic
 * shapes of the output tensors, and the timing information of the execution.
 *
 * @param data Serialized FMQ result data.
 * @return Result object if successfully deserialized, std::nullopt otherwise.
 */
std::optional<std::tuple<hal::ErrorStatus, std::vector<hal::OutputShape>, hal::Timing>> deserialize(
        const std::vector<hal::FmqResultDatum>& data);

/**
 * ResultChannelReceiver is responsible for waiting on the channel until the
 * packet is available, extracting the packet from the channel, and
 * deserializing the packet.
 *
 * Because the receiver can wait on a packet that may never come (e.g., because
 * the sending side of the packet has been closed), this object can be
 * invalidated, unblocking the receiver.
 */
class ResultChannelReceiver {
    using FmqResultDescriptor = ::android::hardware::MQDescriptorSync<hal::FmqResultDatum>;
    using FmqResultChannel =
            hardware::MessageQueue<hal::FmqResultDatum, hardware::kSynchronizedReadWrite>;

   public:
    /**
     * Create the receiving end of a result channel.
     *
     * Prefer this call over the constructor.
     *
     * @param channelLength Number of elements in the FMQ.
     * @param blocking 'true' if FMQ should use futex, 'false' if it should
     *     spin-wait.
     * @return A pair of ResultChannelReceiver and the FMQ descriptor on
     *     successful creation, both nullptr otherwise.
     */
    static std::pair<std::unique_ptr<ResultChannelReceiver>, const FmqResultDescriptor*> create(
            size_t channelLength, bool blocking);

    /**
     * Get the result from the channel.
     *
     * This method will block until either:
     * 1) The packet has been retrieved, or
     * 2) The receiver has been invalidated
     *
     * @return Result object if successfully received, std::nullopt if error or
     *     if the receiver object was invalidated.
     */
    std::optional<std::tuple<hal::ErrorStatus, std::vector<hal::OutputShape>, hal::Timing>>
    getBlocking();

    /**
     * Method to mark the channel as invalid, unblocking any current or future
     * calls to ResultChannelReceiver::getBlocking.
     */
    void invalidate();

    /**
     * Retrieve the raw, still-serialized result packet from the channel.
     *
     * Prefer calling ResultChannelReceiver::getBlocking, which also
     * deserializes the packet.
     *
     * @return The serialized packet, or std::nullopt if no packet could be
     *     retrieved.
     */
    std::optional<std::vector<hal::FmqResultDatum>> getPacketBlocking();

    /**
     * Prefer calling ResultChannelReceiver::create.
     *
     * @param fmqResultChannel FMQ from which result packets are read.
     * @param blocking 'true' if FMQ should use futex, 'false' if it should
     *     spin-wait.
     */
    ResultChannelReceiver(std::unique_ptr<FmqResultChannel> fmqResultChannel, bool blocking);

   private:
    const std::unique_ptr<FmqResultChannel> mFmqResultChannel;
    // Starts true; cleared by invalidate() to release blocked/future readers.
    std::atomic<bool> mValid{true};
    // 'true' if the FMQ uses a futex, 'false' if it spin-waits.
    const bool mBlocking;
};

/**
 * RequestChannelSender is responsible for serializing the request packet of
 * information, sending it on the request channel, and signaling that the data
 * is available.
 */
class RequestChannelSender {
    using FmqRequestDescriptor = ::android::hardware::MQDescriptorSync<hal::FmqRequestDatum>;
    using FmqRequestChannel =
            hardware::MessageQueue<hal::FmqRequestDatum, hardware::kSynchronizedReadWrite>;

   public:
    /**
     * Create the sending end of a request channel.
     *
     * Prefer this call over the constructor.
     *
     * @param channelLength Number of elements in the FMQ.
     * @param blocking 'true' if FMQ should use futex, 'false' if it should
     *     spin-wait.
     * @return A pair of RequestChannelSender and the FMQ descriptor on
     *     successful creation, both nullptr otherwise.
     */
    static std::pair<std::unique_ptr<RequestChannelSender>, const FmqRequestDescriptor*> create(
            size_t channelLength, bool blocking);

    /**
     * Send the request to the channel.
     *
     * @param request Request object without the pool information.
     * @param measure Whether to collect timing information for the execution.
     * @param slots Slot identifiers corresponding to memory resources for the
     *     request.
     * @return 'true' on successful send, 'false' otherwise.
     */
    bool send(const hal::Request& request, hal::MeasureTiming measure,
              const std::vector<int32_t>& slots);

    /**
     * Method to mark the channel as invalid, causing all future calls to
     * RequestChannelSender::send to immediately return false without attempting
     * to send a message across the FMQ.
     */
    void invalidate();

    /**
     * Write an already-serialized request packet to the channel.
     *
     * Prefer calling RequestChannelSender::send, which also serializes the
     * request.
     *
     * @param packet Serialized FMQ request data.
     * @return 'true' on successful send, 'false' otherwise.
     */
    bool sendPacket(const std::vector<hal::FmqRequestDatum>& packet);

    /**
     * Prefer calling RequestChannelSender::create.
     *
     * @param fmqRequestChannel FMQ onto which request packets are written.
     * @param blocking 'true' if FMQ should use futex, 'false' if it should
     *     spin-wait.
     */
    RequestChannelSender(std::unique_ptr<FmqRequestChannel> fmqRequestChannel, bool blocking);

   private:
    const std::unique_ptr<FmqRequestChannel> mFmqRequestChannel;
    // Starts true; cleared by invalidate() so later sends fail fast.
    std::atomic<bool> mValid{true};
    // 'true' if the FMQ uses a futex, 'false' if it spin-waits.
    const bool mBlocking;
};

/**
 * The ExecutionBurstController class manages both the serialization and
 * deserialization of data across FMQ, making it appear to the runtime as a
 * regular synchronous inference. Additionally, this class manages the burst's
 * memory cache.
 */
class ExecutionBurstController {
    DISALLOW_IMPLICIT_CONSTRUCTORS(ExecutionBurstController);

   public:
    /**
     * NN runtime burst callback object and memory cache.
     *
     * ExecutionBurstCallback associates a hidl_memory object with a slot number
     * to be passed across FMQ. The ExecutionBurstServer can use this callback
     * to retrieve this hidl_memory corresponding to the slot via HIDL.
     *
     * Whenever a hidl_memory object is copied, it will duplicate the underlying
     * file descriptor. Because the NN runtime currently copies the hidl_memory
     * on each execution, it is difficult to associate hidl_memory objects with
     * previously cached hidl_memory objects. For this reason, callers of this
     * class must pair each hidl_memory object with an associated key. For
     * efficiency, if two hidl_memory objects represent the same underlying
     * buffer, they must use the same key.
     */
    class ExecutionBurstCallback : public hal::IBurstCallback {
        DISALLOW_COPY_AND_ASSIGN(ExecutionBurstCallback);

       public:
        ExecutionBurstCallback() = default;

        /**
         * IBurstCallback override through which the ExecutionBurstServer
         * retrieves the hidl_memory objects bound to the given slots.
         *
         * @param slots Slot identifiers whose memory is requested.
         * @param cb HIDL callback that receives the result.
         */
        hal::Return<void> getMemories(const hal::hidl_vec<int32_t>& slots,
                                      getMemories_cb cb) override;

        /**
         * This function performs one of two different actions:
         * 1) If a key corresponding to a memory resource is unrecognized by the
         *    ExecutionBurstCallback object, the ExecutionBurstCallback object
         *    will allocate a slot, bind the memory to the slot, and return the
         *    slot identifier.
         * 2) If a key corresponding to a memory resource is recognized by the
         *    ExecutionBurstCallback object, the ExecutionBurstCallback object
         *    will return the existing slot identifier.
         *
         * @param memories Memory resources used in an inference.
         * @param keys Unique identifiers where each element corresponds to a
         *     memory resource element in "memories".
         * @return Unique slot identifiers where each returned slot element
         *     corresponds to a memory resource element in "memories".
         */
        std::vector<int32_t> getSlots(const hal::hidl_vec<hal::hidl_memory>& memories,
                                      const std::vector<intptr_t>& keys);

        /**
         * This function performs two different actions:
         * 1) Removes an entry from the cache (if present), including the local
         *    storage of the hidl_memory object. Note that this call does not
         *    free any corresponding hidl_memory object in ExecutionBurstServer,
         *    which is separately freed via IBurstContext::freeMemory.
         * 2) Returns whether a cache entry was removed and which slot was
         *    removed if found. If the key did not correspond to any entry in
         *    the cache, a slot number of 0 is returned. The slot number and
         *    whether the entry existed is useful so the same slot can be freed
         *    in the ExecutionBurstServer's cache via IBurstContext::freeMemory.
         *
         * @param key Key corresponding to the memory object.
         * @return Pair of (whether an entry was removed, removed slot or 0).
         */
        std::pair<bool, int32_t> freeMemory(intptr_t key);

       private:
        // The "Locked" suffix presumably means mMutex must be held by the
        // caller -- TODO confirm against the implementation.
        int32_t getSlotLocked(const hal::hidl_memory& memory, intptr_t key);
        int32_t allocateSlotLocked();

        std::mutex mMutex;
        // Slot identifiers available for reuse.
        std::stack<int32_t, std::vector<int32_t>> mFreeSlots;
        // Maps a caller-provided memory key to its slot identifier.
        std::map<intptr_t, int32_t> mMemoryIdToSlot;
        // Cached hidl_memory objects; presumably indexed by slot -- TODO confirm.
        std::vector<hal::hidl_memory> mMemoryCache;
    };

    /**
     * Creates a burst controller on a prepared model.
     *
     * Prefer this over ExecutionBurstController's constructor.
     *
     * @param preparedModel Model prepared for execution to execute on.
     * @param blocking 'true' if the FMQ should use a futex to perform blocking
     *     until data is available in a less responsive, but more energy
     *     efficient manner. 'false' if the FMQ should use spin-looping to
     *     wait until data is available in a more responsive, but less energy
     *     efficient manner.
     * @return ExecutionBurstController Execution burst controller object.
     */
    static std::unique_ptr<ExecutionBurstController> create(
            const sp<hal::IPreparedModel>& preparedModel, bool blocking);

    // prefer calling ExecutionBurstController::create
    ExecutionBurstController(const std::shared_ptr<RequestChannelSender>& requestChannelSender,
                             const std::shared_ptr<ResultChannelReceiver>& resultChannelReceiver,
                             const sp<hal::IBurstContext>& burstContext,
                             const sp<ExecutionBurstCallback>& callback,
                             const sp<hal::hidl_death_recipient>& deathHandler = nullptr);

    // explicit destructor to unregister the death recipient
    ~ExecutionBurstController();

    /**
     * Execute a request on a model.
     *
     * @param request Arguments to be executed on a model.
     * @param measure Whether to collect timing measurements, either YES or NO
     * @param memoryIds Identifiers corresponding to each memory object in the
     *     request's pools.
     * @return A tuple of:
     *     - status of the execution
     *     - dynamic output shapes from the execution
     *     - any execution time measurements of the execution
     */
    std::tuple<hal::ErrorStatus, std::vector<hal::OutputShape>, hal::Timing> compute(
            const hal::Request& request, hal::MeasureTiming measure,
            const std::vector<intptr_t>& memoryIds);

    // TODO: combine "compute" and "tryCompute" back into a single function.
    // "tryCompute" was created later to return the "fallback" boolean. This
    // could not be done directly in "compute" because the VTS test cases (which
    // test burst using "compute") had already been locked down and could not be
    // changed.
    /**
     * Execute a request on a model.
     *
     * @param request Arguments to be executed on a model.
     * @param measure Whether to collect timing measurements, either YES or NO
     * @param memoryIds Identifiers corresponding to each memory object in the
     *     request's pools.
     * @return A tuple of:
     *     - result code of the execution
     *     - dynamic output shapes from the execution
     *     - any execution time measurements of the execution
     *     - whether or not a failed burst execution should be re-run using a
     *       different path (e.g., IPreparedModel::executeSynchronously)
     */
    std::tuple<int, std::vector<hal::OutputShape>, hal::Timing, bool> tryCompute(
            const hal::Request& request, hal::MeasureTiming measure,
            const std::vector<intptr_t>& memoryIds);

    /**
     * Propagate a user's freeing of memory to the service.
     *
     * @param key Key corresponding to the memory object.
     */
    void freeMemory(intptr_t key);

   private:
    std::mutex mMutex;
    const std::shared_ptr<RequestChannelSender> mRequestChannelSender;
    const std::shared_ptr<ResultChannelReceiver> mResultChannelReceiver;
    const sp<hal::IBurstContext> mBurstContext;
    // Burst callback object that also owns the memory-slot cache.
    const sp<ExecutionBurstCallback> mMemoryCache;
    // Death recipient unregistered by the destructor; may be nullptr.
    const sp<hal::hidl_death_recipient> mDeathHandler;
};

}  // namespace android::nn

#endif  // ANDROID_FRAMEWORKS_ML_NN_COMMON_EXECUTION_BURST_CONTROLLER_H