summaryrefslogtreecommitdiff
path: root/nn/common/include/ExecutionBurstServer.h
blob: 9da0dc7425d04d326a03ae43a2c4a92d8f6c7e19 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_EXECUTION_BURST_SERVER_H
#define ANDROID_FRAMEWORKS_ML_NN_COMMON_EXECUTION_BURST_SERVER_H

#include <android-base/macros.h>
#include <fmq/MessageQueue.h>
#include <hidl/MQDescriptor.h>

#include <atomic>
#include <chrono>
#include <cstdint>
#include <memory>
#include <mutex>
#include <optional>
#include <thread>
#include <tuple>
#include <vector>

#include "HalInterfaces.h"

namespace android::nn {

// Descriptors of the synchronized message queues (FMQs) used by a burst:
// FmqRequestDescriptor describes a queue of FmqRequestDatum elements and
// FmqResultDescriptor describes a queue of FmqResultDatum elements.
using FmqRequestDescriptor = hardware::MQDescriptorSync<hal::FmqRequestDatum>;
using FmqResultDescriptor = hardware::MQDescriptorSync<hal::FmqResultDatum>;

/**
 * Serialize the result of an execution into FMQ result data.
 *
 * Prefer calling ResultChannelSender::send, which accepts the same three
 * arguments.
 *
 * @param errorStatus Status of the execution.
 * @param outputShapes Dynamic shapes of the output tensors.
 * @param timing Timing information of the execution.
 * @return Serialized FMQ result data.
 */
std::vector<hal::FmqResultDatum> serialize(hal::ErrorStatus errorStatus,
                                           const std::vector<hal::OutputShape>& outputShapes,
                                           hal::Timing timing);

/**
 * Deserialize the FMQ request data.
 *
 * The three resulting fields are, in tuple order: the Request object (where
 * Request::pools is empty), the slot identifiers (which are stand-ins for
 * Request::pools), and whether timing information must be collected for the
 * run.
 *
 * @param data Serialized FMQ request data.
 * @return Tuple of (Request, slots, MeasureTiming) if successfully
 *     deserialized, std::nullopt otherwise.
 */
std::optional<std::tuple<hal::Request, std::vector<int32_t>, hal::MeasureTiming>> deserialize(
        const std::vector<hal::FmqRequestDatum>& data);

/**
 * RequestChannelReceiver is responsible for waiting on the channel until the
 * packet is available, extracting the packet from the channel, and
 * deserializing the packet.
 *
 * Because the receiver can wait on a packet that may never come (e.g., because
 * the sending side of the packet has been closed), this object can be
 * invalidated, unblocking the receiver.
 */
class RequestChannelReceiver {
    // Synchronized (blocking-capable) FMQ carrying serialized request data.
    using FmqRequestChannel =
            hardware::MessageQueue<hal::FmqRequestDatum, hardware::kSynchronizedReadWrite>;

   public:
    /**
     * Create the receiving end of a request channel.
     *
     * Prefer this call over the constructor.
     *
     * @param requestChannel Descriptor for the request channel.
     * @param pollingTimeWindow How much time (in microseconds) the
     *     RequestChannelReceiver is allowed to poll the FMQ before waiting on
     *     the blocking futex. Polling may result in lower latencies at the
     *     potential cost of more power usage.
     * @return RequestChannelReceiver on successful creation, nullptr otherwise.
     */
    static std::unique_ptr<RequestChannelReceiver> create(
            const FmqRequestDescriptor& requestChannel,
            std::chrono::microseconds pollingTimeWindow);

    /**
     * Get the request from the channel.
     *
     * This method will block until either:
     * 1) The packet has been retrieved, or
     * 2) The receiver has been invalidated
     *
     * The tuple layout matches the one returned by deserialize: the Request
     * object, the slot identifiers, and the timing-measurement setting.
     *
     * @return Request object if successfully received, std::nullopt if error or
     *     if the receiver object was invalidated.
     */
    std::optional<std::tuple<hal::Request, std::vector<int32_t>, hal::MeasureTiming>> getBlocking();

    /**
     * Method to mark the channel as invalid, unblocking any current or future
     * calls to RequestChannelReceiver::getBlocking.
     */
    void invalidate();

    /**
     * Construct a receiver that takes ownership of an already-created FMQ.
     *
     * Prefer RequestChannelReceiver::create over calling this directly.
     *
     * @param fmqRequestChannel FMQ from which request packets are read.
     * @param pollingTimeWindow Polling window; see
     *     RequestChannelReceiver::create for its meaning.
     */
    RequestChannelReceiver(std::unique_ptr<FmqRequestChannel> fmqRequestChannel,
                           std::chrono::microseconds pollingTimeWindow);

   private:
    // Retrieves a still-serialized request packet, or std::nullopt.
    // NOTE(review): presumably the blocking/polling step behind getBlocking,
    // with deserialization applied afterwards -- confirm in the .cpp.
    std::optional<std::vector<hal::FmqRequestDatum>> getPacketBlocking();

    // FMQ owned by this receiver.
    const std::unique_ptr<FmqRequestChannel> mFmqRequestChannel;
    // NOTE(review): presumably set by invalidate() to signal that waiting
    // should stop -- confirm in the .cpp.
    std::atomic<bool> mTeardown{false};
    // Polling window supplied at construction (per-instance, despite the 'k'
    // prefix).
    const std::chrono::microseconds kPollingTimeWindow;
};

/**
 * ResultChannelSender is responsible for serializing the result packet of
 * information, sending it on the result channel, and signaling that the data is
 * available.
 */
class ResultChannelSender {
    // Synchronized (blocking-capable) FMQ carrying serialized result data.
    using FmqResultChannel =
            hardware::MessageQueue<hal::FmqResultDatum, hardware::kSynchronizedReadWrite>;

   public:
    /**
     * Create the sending end of a result channel.
     *
     * Prefer this call over the constructor.
     *
     * @param resultChannel Descriptor for the result channel.
     * @return ResultChannelSender on successful creation, nullptr otherwise.
     */
    static std::unique_ptr<ResultChannelSender> create(const FmqResultDescriptor& resultChannel);

    /**
     * Send the result to the channel.
     *
     * @param errorStatus Status of the execution.
     * @param outputShapes Dynamic shapes of the output tensors.
     * @param timing Timing information of the execution.
     * @return 'true' on successful send, 'false' otherwise.
     */
    bool send(hal::ErrorStatus errorStatus, const std::vector<hal::OutputShape>& outputShapes,
              hal::Timing timing);

    /**
     * Send an already-serialized result packet to the channel.
     *
     * Prefer calling ResultChannelSender::send.
     *
     * @param packet Serialized FMQ result data.
     * @return Whether the packet was sent. NOTE(review): presumably follows the
     *     same 'true' on success / 'false' otherwise convention as send --
     *     confirm in the implementation.
     */
    bool sendPacket(const std::vector<hal::FmqResultDatum>& packet);

    /**
     * Construct a sender that takes ownership of an already-created FMQ.
     *
     * Prefer ResultChannelSender::create over calling this directly.
     *
     * The constructor is explicit so that a std::unique_ptr<FmqResultChannel>
     * is never silently converted into a ResultChannelSender.
     *
     * @param fmqResultChannel FMQ on which result packets are written.
     */
    explicit ResultChannelSender(std::unique_ptr<FmqResultChannel> fmqResultChannel);

   private:
    // FMQ owned by this sender.
    const std::unique_ptr<FmqResultChannel> mFmqResultChannel;
};

/**
 * The ExecutionBurstServer class is responsible for waiting for and
 * deserializing a request object from a FMQ, performing the inference, and
 * serializing the result back across another FMQ.
 *
 * Instances can be obtained through the ExecutionBurstServer::create factory
 * functions declared below.
 */
class ExecutionBurstServer : public hal::IBurstContext {
    DISALLOW_IMPLICIT_CONSTRUCTORS(ExecutionBurstServer);

   public:
    /**
     * IBurstExecutorWithCache is a callback object passed to
     * ExecutionBurstServer's factory function that is used to perform an
     * execution. Because some memory resources are needed across multiple
     * executions, this object also contains a local cache that can directly be
     * used in the execution.
     *
     * ExecutionBurstServer will never access its IBurstExecutorWithCache object
     * with concurrent calls.
     */
    class IBurstExecutorWithCache {
        DISALLOW_COPY_AND_ASSIGN(IBurstExecutorWithCache);

       public:
        IBurstExecutorWithCache() = default;
        // Virtual so that implementations can be destroyed through a pointer
        // to this interface.
        virtual ~IBurstExecutorWithCache() = default;

        /**
         * Checks if a cache entry specified by a slot is present in the cache.
         *
         * @param slot Identifier of the cache entry.
         * @return 'true' if the cache entry is present in the cache, 'false'
         *     otherwise.
         */
        virtual bool isCacheEntryPresent(int32_t slot) const = 0;

        /**
         * Adds an entry specified by a slot to the cache.
         *
         * The caller of this function must ensure that the cache entry that is
         * being added is not already present in the cache. This can be checked
         * via isCacheEntryPresent.
         *
         * @param memory Memory resource to be cached.
         * @param slot Slot identifier corresponding to the memory resource.
         */
        virtual void addCacheEntry(const hal::hidl_memory& memory, int32_t slot) = 0;

        /**
         * Removes an entry specified by a slot from the cache.
         *
         * If the cache entry corresponding to the slot number does not exist,
         * the call does nothing.
         *
         * @param slot Slot identifier corresponding to the memory resource.
         */
        virtual void removeCacheEntry(int32_t slot) = 0;

        /**
         * Perform an execution.
         *
         * @param request Request object with inputs and outputs specified.
         *     Request::pools is empty, and DataLocation::poolIndex instead
         *     refers to the 'slots' argument as if it were Request::pools.
         * @param slots Slots corresponding to the cached memory entries to be
         *     used.
         * @param measure Whether timing information is requested for the
         *     execution.
         * @return Result of the execution, including the status of the
         *     execution, dynamic output shapes, and any timing information.
         */
        virtual std::tuple<hal::ErrorStatus, hal::hidl_vec<hal::OutputShape>, hal::Timing> execute(
                const hal::Request& request, const std::vector<int32_t>& slots,
                hal::MeasureTiming measure) = 0;
    };

    /**
     * Create automated context to manage FMQ-based executions.
     *
     * This function is intended to be used by a service to automatically:
     * 1) Receive data from a provided FMQ
     * 2) Execute a model with the given information
     * 3) Send the result to the created FMQ
     *
     * @param callback Callback used to retrieve memories corresponding to
     *     unrecognized slots.
     * @param requestChannel Input FMQ channel through which the client passes the
     *     request to the service.
     * @param resultChannel Output FMQ channel from which the client can retrieve
     *     the result of the execution.
     * @param executorWithCache Object which maintains a local cache of the
     *     memory pools and executes using the cached memory pools.
     * @param pollingTimeWindow How much time (in microseconds) the
     *     ExecutionBurstServer is allowed to poll the FMQ before waiting on
     *     the blocking futex. Polling may result in lower latencies at the
     *     potential cost of more power usage. Defaults to 0 microseconds.
     * @return IBurstContext Handle to the burst context.
     */
    static sp<ExecutionBurstServer> create(
            const sp<hal::IBurstCallback>& callback, const FmqRequestDescriptor& requestChannel,
            const FmqResultDescriptor& resultChannel,
            std::shared_ptr<IBurstExecutorWithCache> executorWithCache,
            std::chrono::microseconds pollingTimeWindow = std::chrono::microseconds{0});

    /**
     * Create automated context to manage FMQ-based executions.
     *
     * This function is intended to be used by a service to automatically:
     * 1) Receive data from a provided FMQ
     * 2) Execute a model with the given information
     * 3) Send the result to the created FMQ
     *
     * @param callback Callback used to retrieve memories corresponding to
     *     unrecognized slots.
     * @param requestChannel Input FMQ channel through which the client passes the
     *     request to the service.
     * @param resultChannel Output FMQ channel from which the client can retrieve
     *     the result of the execution.
     * @param preparedModel PreparedModel that the burst object was created from.
     *     IPreparedModel::executeSynchronously will be used to perform the
     *     execution.
     *     NOTE(review): passed as a raw pointer; presumably the caller must
     *     keep the prepared model alive for as long as the returned burst
     *     object is in use -- confirm against the implementation.
     * @param pollingTimeWindow How much time (in microseconds) the
     *     ExecutionBurstServer is allowed to poll the FMQ before waiting on
     *     the blocking futex. Polling may result in lower latencies at the
     *     potential cost of more power usage. Defaults to 0 microseconds.
     * @return IBurstContext Handle to the burst context.
     */
    static sp<ExecutionBurstServer> create(
            const sp<hal::IBurstCallback>& callback, const FmqRequestDescriptor& requestChannel,
            const FmqResultDescriptor& resultChannel, hal::IPreparedModel* preparedModel,
            std::chrono::microseconds pollingTimeWindow = std::chrono::microseconds{0});

    // Constructs the server from already-created channel endpoints and an
    // executor. The arguments correspond to the mCallback,
    // mRequestChannelReceiver, mResultChannelSender and mExecutorWithCache
    // members declared below.
    ExecutionBurstServer(const sp<hal::IBurstCallback>& callback,
                         std::unique_ptr<RequestChannelReceiver> requestChannel,
                         std::unique_ptr<ResultChannelSender> resultChannel,
                         std::shared_ptr<IBurstExecutorWithCache> cachedExecutor);
    // NOTE(review): presumably stops the work loop and joins mWorker --
    // confirm in the .cpp.
    ~ExecutionBurstServer();

    // Used by the NN runtime to preemptively remove any stored memory.
    // 'slot' identifies the cached memory to be removed.
    hal::Return<void> freeMemory(int32_t slot) override;

   private:
    // Ensures that the cache entries identified by 'slots' are present in
    // mExecutorWithCache. If they are not present, they are retrieved (via
    // IBurstCallback::getMemories) and added to mExecutorWithCache.
    //
    // mMutex must already be held by the caller when this method is called.
    void ensureCacheEntriesArePresentLocked(const std::vector<int32_t>& slots);

    // Work loop that will continue processing execution requests until the
    // ExecutionBurstServer object is freed.
    void task();

    // NOTE(review): presumably the thread that runs task() -- confirm in the
    // .cpp.
    std::thread mWorker;
    // Mutex that must be held when calling ensureCacheEntriesArePresentLocked.
    std::mutex mMutex;
    // NOTE(review): presumably signals the work loop to stop -- confirm in the
    // .cpp.
    std::atomic<bool> mTeardown{false};
    // Callback used to retrieve memories corresponding to unrecognized slots.
    const sp<hal::IBurstCallback> mCallback;
    // Receiving end of the request FMQ.
    const std::unique_ptr<RequestChannelReceiver> mRequestChannelReceiver;
    // Sending end of the result FMQ.
    const std::unique_ptr<ResultChannelSender> mResultChannelSender;
    // Executor (with its local memory cache) used to perform executions.
    const std::shared_ptr<IBurstExecutorWithCache> mExecutorWithCache;
};

}  // namespace android::nn

#endif  // ANDROID_FRAMEWORKS_ML_NN_COMMON_EXECUTION_BURST_SERVER_H