summaryrefslogtreecommitdiff
path: root/nn/runtime/Callbacks.h
blob: 75370254ee1cb495af259361521d76782f3ac82f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef ANDROID_FRAMEWORKS_ML_NN_RUNTIME_CALLBACKS_H
#define ANDROID_FRAMEWORKS_ML_NN_RUNTIME_CALLBACKS_H

#include "HalInterfaces.h"

#include <android-base/thread_annotations.h>
#include <condition_variable>
#include <functional>
#include <mutex>
#include <thread>
#include <vector>

/*
 * The Callback classes are used internally by the NeuralNetworks runtime to
 * synchronize between different threads. An asynchronous task is launched
 * paired with a callback object. When a client thread requires the output being
 * generated by the asynchronous task, the client thread can wait for the result
 * and be blocked until it has completed. Any wait may safely be called
 * concurrently, even on the same callback object. When the asynchronous task
 * has finished its workload, it must immediately call "notify*". If the
 * asynchronous task has failed to launch, the function that tried to launch the
 * asynchronous task must immediately call "notify*". This "notify*" call
 * awakens any client threads waiting on the callback object.
 *
 * These classes exist to enable synchronization across HIDL. When
 * synchronization is only required in the same process, consider using
 * std::future, std::mutex, std::condition_variable, or std::experimental::latch
 * instead.
 */

namespace android::nn {

/**
 * The PreparedModelCallback class is used to receive the error status of
 * preparing a model as well as the prepared model from a task executing
 * asynchronously with respect to the runtime. If a calling thread calls wait
 * or get* on a PreparedModelCallback object and the corresponding asynchronous
 * task has not finished preparing the model, the calling thread will block
 * until the asynchronous task has called notify*.
 *
 * If the callback object is notified more than once, only the results of the
 * first call to notify* are used, and the results from subsequent calls are
 * discarded.
 *
 * This callback object is passed as an argument to IDevice::prepareModel*.
 */
class PreparedModelCallback : public hal::IPreparedModelCallback {
   public:
    /**
     * IPreparedModelCallback::notify marks the callback object with the return
     * status of the asynchronous model preparation along with the prepared
     * model, and allows all prior and future wait calls on the
     * PreparedModelCallback object to proceed.
     *
     * One of IPreparedModelCallback::notify, IPreparedModelCallback::notify_1_2,
     * or IPreparedModelCallback::notify_1_3 must be called on a given
     * PreparedModelCallback object.
     *
     * If the callback object is notified more than once, only the results of
     * the first call to notify* are used, and the results from subsequent calls
     * are discarded.
     *
     * @param status Error status returned from asynchronously preparing the
     *     model; will be:
     *     - NONE if the asynchronous preparation was successful
     *     - DEVICE_UNAVAILABLE if driver is offline or busy
     *     - GENERAL_FAILURE if there is an unspecified error
     *     - INVALID_ARGUMENT if the input model is invalid
     * @param preparedModel Returned model that has been prepared for execution,
     *     nullptr if the model was unable to be prepared.
     */
    hal::Return<void> notify(hal::V1_0::ErrorStatus status,
                             const sp<hal::V1_0::IPreparedModel>& preparedModel) override;

    /**
     * IPreparedModelCallback::notify_1_2 marks the callback object with the
     * return status of the asynchronous model preparation along with the
     * prepared model, and allows all prior and future wait calls on the
     * PreparedModelCallback object to proceed.
     *
     * One of IPreparedModelCallback::notify, IPreparedModelCallback::notify_1_2,
     * or IPreparedModelCallback::notify_1_3 must be called on a given
     * PreparedModelCallback object.
     *
     * If the callback object is notified more than once, only the results of
     * the first call to notify* are used, and the results from subsequent calls
     * are discarded.
     *
     * @param status Error status returned from asynchronously preparing the
     *     model; will be:
     *     - NONE if the asynchronous preparation was successful
     *     - DEVICE_UNAVAILABLE if driver is offline or busy
     *     - GENERAL_FAILURE if there is an unspecified error
     *     - INVALID_ARGUMENT if the input model is invalid
     * @param preparedModel Returned model that has been prepared for execution,
     *     nullptr if the model was unable to be prepared.
     */
    hal::Return<void> notify_1_2(hal::V1_0::ErrorStatus status,
                                 const sp<hal::V1_2::IPreparedModel>& preparedModel) override;

    /**
     * IPreparedModelCallback::notify_1_3 marks the callback object with the
     * return status of the asynchronous model preparation along with the
     * prepared model, and allows all prior and future wait calls on the
     * PreparedModelCallback object to proceed.
     *
     * One of IPreparedModelCallback::notify, IPreparedModelCallback::notify_1_2,
     * or IPreparedModelCallback::notify_1_3 must be called on a given
     * PreparedModelCallback object.
     *
     * If the callback object is notified more than once, only the results of
     * the first call to notify* are used, and the results from subsequent calls
     * are discarded.
     *
     * @param status Error status returned from asynchronously preparing the
     *     model; will be:
     *     - NONE if the asynchronous preparation was successful
     *     - DEVICE_UNAVAILABLE if driver is offline or busy
     *     - GENERAL_FAILURE if there is an unspecified error
     *     - INVALID_ARGUMENT if the input model is invalid
     *     - MISSED_DEADLINE_* if the deadline could not be met
     *     - RESOURCE_EXHAUSTED_* if the task was aborted by the driver
     * @param preparedModel Returned model that has been prepared for execution,
     *     nullptr if the model was unable to be prepared.
     */
    hal::Return<void> notify_1_3(hal::V1_3::ErrorStatus status,
                                 const sp<hal::V1_3::IPreparedModel>& preparedModel) override;

    /**
     * Mark the callback object as a dead object. This acts as a call to
     * notify: it counts as the (first-wins) notification and releases any
     * threads blocked in wait or get*. Afterwards isDeadObject() reports the
     * object as dead.
     */
    void notifyAsDeadObject();

    /**
     * PreparedModelCallback::wait blocks until notify* (or
     * notifyAsDeadObject) has been called on the callback object. It may be
     * called concurrently from multiple threads.
     */
    void wait() const;

    /**
     * Retrieves the error status returned from the asynchronous task launched
     * by IDevice::prepareModel*. If IDevice::prepareModel* has not finished
     * asynchronously preparing the model, this call will block until the
     * asynchronous task notifies the object.
     *
     * @return status Error status returned from asynchronously preparing the
     *     model; will be:
     *     - NONE if the asynchronous preparation was successful
     *     - DEVICE_UNAVAILABLE if driver is offline or busy
     *     - GENERAL_FAILURE if there is an unspecified error
     *     - INVALID_ARGUMENT if the input model is invalid
     *     - MISSED_DEADLINE_* if the deadline could not be met
     *     - RESOURCE_EXHAUSTED_* if the task was aborted by the driver
     *     - DEAD_OBJECT if the driver crashed without returning a result
     */
    hal::V1_3::ErrorStatus getStatus() const;

    /**
     * Retrieves the model that has been prepared for execution from the
     * asynchronous task launched by IDevice::prepareModel*. If
     * IDevice::prepareModel* has not finished asynchronously preparing the
     * model, this call will block until the asynchronous task notifies the
     * object.
     *
     * The model is always returned through the V1_0 base interface, even when
     * it was delivered by notify_1_2 or notify_1_3. Callers that need the
     * newer interface must cast it back (presumably via the HIDL castFrom
     * helpers -- TODO confirm at call sites).
     *
     * @return preparedModel Returned model that has been prepared for
     *     execution, nullptr if the model was unable to be prepared.
     */
    sp<hal::V1_0::IPreparedModel> getPreparedModel() const;

    /**
     * Queries whether the object is dead.
     *
     * @return 'true' if dead, 'false' otherwise.
     */
    bool isDeadObject() const;

   private:
    /*
     * Shared implementation target of the notify* overloads and
     * notifyAsDeadObject; all of them funnel their results into the single
     * stored V1_0 prepared model and hal::ErrorStatus below.
     * NOTE(review): expected to record the results only on the first call
     * (per the class contract) and then wake waiters on mCondition -- the
     * body lives in Callbacks.cpp; confirm there.
     */
    hal::Return<void> notifyInternal(bool deadObject, hal::ErrorStatus errorStatus,
                                     const sp<hal::V1_0::IPreparedModel>& preparedModel);

    // mutable so that the const accessors (wait, get*, isDeadObject) can lock
    // the mutex and wait on the condition variable.
    mutable std::mutex mMutex;
    mutable std::condition_variable mCondition;
    // Whether the first notification has been recorded.
    bool mNotified GUARDED_BY(mMutex) = false;
    // Results of the first notification. Until then the status defaults to
    // GENERAL_FAILURE and the model to nullptr.
    // NOTE(review): these are not annotated GUARDED_BY(mMutex); presumably
    // they are written under mMutex before mNotified is set and read only
    // after wait() has returned -- confirm in Callbacks.cpp.
    bool mDeadObject = false;
    hal::ErrorStatus mErrorStatus = hal::ErrorStatus::GENERAL_FAILURE;
    sp<hal::V1_0::IPreparedModel> mPreparedModel;
};

/**
 * The ExecutionCallback class is used to receive the results of the execution
 * from a task executing asynchronously with respect to the runtime. If a
 * calling thread calls wait or get* on a ExecutionCallback object and the
 * corresponding asynchronous task has not finished the execution, the calling
 * thread will block until the asynchronous task has called one of the notify*
 * methods.
 *
 * If the callback object is notified more than once, only the results of the
 * first call to notify* are used, and the results from subsequent calls are
 * discarded.
 *
 * This callback object is passed as an argument to IPreparedModel::execute*.
 */
class ExecutionCallback : public hal::IExecutionCallback {
    using ExecutionFinish =
            std::function<hal::ErrorStatus(hal::ErrorStatus, const std::vector<hal::OutputShape>&)>;

   public:
    /**
     * IExecutionCallback::notify marks the callback object with the return
     * status of the asynchronous execution that held this callback and enables
     * all prior and future wait calls on the ExecutionCallback object to
     * proceed.
     *
     * One of the IExecutionCallback::notify* methods must be called on a given
     * ExecutionCallback object.
     *
     * If the callback object is notified more than once, only the results of
     * the first call to notify* are used, and the results from subsequent calls
     * are discarded.
     *
     * @param status Error status returned from launching the asynchronous task
     *     (if the launch fails) or from the asynchronous task itself (if the
     *     launch succeeds). Must be:
     *     - NONE if the asynchronous execution was successful
     *     - DEVICE_UNAVAILABLE if driver is offline or busy
     *     - GENERAL_FAILURE if there is an unspecified error
     *     - OUTPUT_INSUFFICIENT_SIZE if provided output buffer is not large
     *         enough to store the resultant values
     *     - INVALID_ARGUMENT if the input request is invalid
     */
    hal::Return<void> notify(hal::V1_0::ErrorStatus status) override;

    /**
     * IExecutionCallback::notify_1_2 marks the callback object with the results
     * (error status, dynamic output shapes, and timing information) of the
     * asynchronous execution that held this callback and enables all prior and
     * future wait calls on the ExecutionCallback object to proceed.
     *
     * One of the IExecutionCallback::notify* methods must be called on a given
     * ExecutionCallback object.
     *
     * If the callback object is notified more than once, only the results of
     * the first call to notify* are used, and the results from subsequent calls
     * are discarded.
     *
     * @param status Error status returned from launching the asynchronous task
     *     (if the launch fails) or from the asynchronous task itself (if the
     *     launch succeeds). Must be:
     *     - NONE if the asynchronous execution was successful
     *     - DEVICE_UNAVAILABLE if driver is offline or busy
     *     - GENERAL_FAILURE if the asynchronous task resulted in an unspecified
     *         error
     *     - OUTPUT_INSUFFICIENT_SIZE if at least one output operand buffer is
     *         not large enough to store the corresponding output
     *     - INVALID_ARGUMENT if one of the input arguments to prepareModel is
     *         invalid
     * @param outputShapes A list of shape information of model output operands.
     *     The index into "outputShapes" corresponds to the index of the output
     *     operand in the Request outputs vector. outputShapes must be empty
     *     unless the status is either NONE or OUTPUT_INSUFFICIENT_SIZE.
     * @param Timing Duration of execution. Unless MeasureTiming::YES was passed
     *     when launching the execution and status is NONE, all times must be
     *     reported as UINT64_MAX. A driver may choose to report any time as
     *     UINT64_MAX, indicating that particular measurement is not available.
     */
    hal::Return<void> notify_1_2(hal::V1_0::ErrorStatus status,
                                 const hal::hidl_vec<hal::OutputShape>& outputShapes,
                                 const hal::Timing& timing) override;

    /**
     * IExecutionCallback::notify_1_3 marks the callback object with the results
     * (error status, dynamic output shapes, and timing information) of the
     * asynchronous execution that held this callback and enables all prior and
     * future wait calls on the ExecutionCallback object to proceed.
     *
     * One of the IExecutionCallback::notify* methods must be called on a given
     * ExecutionCallback object.
     *
     * If the callback object is notified more than once, only the results of
     * the first call to notify* are used, and the results from subsequent calls
     * are discarded.
     *
     * @param status Error status returned from launching the asynchronous task
     *     (if the launch fails) or from the asynchronous task itself (if the
     *     launch succeeds). Must be:
     *     - NONE if the asynchronous execution was successful
     *     - DEVICE_UNAVAILABLE if driver is offline or busy
     *     - GENERAL_FAILURE if the asynchronous task resulted in an unspecified
     *         error
     *     - OUTPUT_INSUFFICIENT_SIZE if at least one output operand buffer is
     *         not large enough to store the corresponding output
     *     - INVALID_ARGUMENT if one of the input arguments to prepareModel is
     *         invalid
     *     - MISSED_DEADLINE_* if the deadline could not be met
     *     - RESOURCE_EXHAUSTED_* if the execution was aborted by the driver
     * @param outputShapes A list of shape information of model output operands.
     *     The index into "outputShapes" corresponds to the index of the output
     *     operand in the Request outputs vector. outputShapes must be empty
     *     unless the status is either NONE or OUTPUT_INSUFFICIENT_SIZE.
     * @param Timing Duration of execution. Unless MeasureTiming::YES was passed
     *     when launching the execution and status is NONE, all times must be
     *     reported as UINT64_MAX. A driver may choose to report any time as
     *     UINT64_MAX, indicating that particular measurement is not available.
     */
    hal::Return<void> notify_1_3(hal::V1_3::ErrorStatus status,
                                 const hal::hidl_vec<hal::OutputShape>& outputShapes,
                                 const hal::Timing& timing) override;

    // An overload of the latest notify interface to hide the version from ExecutionBuilder.
    hal::Return<void> notify(hal::V1_3::ErrorStatus status,
                             const hal::hidl_vec<hal::OutputShape>& outputShapes,
                             const hal::Timing& timing) {
        return notify_1_3(status, outputShapes, timing);
    }

    /**
     * Mark the callback object as a dead object. This acts as a call to notify.
     */
    void notifyAsDeadObject();

    /**
     * ExecutionCallback::wait blocks until notify* has been called on the
     * callback object.
     */
    void wait() const;

    /**
     * Retrieves the error status returned from the asynchronous task launched
     * by IPreparedModel::execute* (but not by
     * IPreparedModel::executeSynchronously*). If IPreparedModel::execute* has
     * not finished asynchronously executing, this call will block until the
     * asynchronous task notifies the object.
     *
     * @return status Error status returned from launching the asynchronous task
     *     (if the launch fails) or from the asynchronous task itself (if the
     *     launch succeeds). Must be:
     *     - NONE if the asynchronous execution was successful
     *     - DEVICE_UNAVAILABLE if driver is offline or busy
     *     - GENERAL_FAILURE if the asynchronous task resulted in an unspecified
     *         error
     *     - OUTPUT_INSUFFICIENT_SIZE if at least one output operand buffer is
     *         not large enough to store the corresponding output
     *     - INVALID_ARGUMENT if one of the input arguments to prepareModel is
     *         invalid
     *     - MISSED_DEADLINE_* if the deadline could not be met
     *     - RESOURCE_EXHAUSTED_* if the task was aborted by the driver
     *     - DEAD_OBJECT if the driver crashed without returning a result
     */
    hal::V1_3::ErrorStatus getStatus() const;

    /**
     * Retrieves the output shapes returned from the asynchronous task launched
     * by either IPreparedModel::execute_1_2 or IPreparedModel::execute_1_3. If
     * IPreparedModel::execute_1_2 or IPreparedModel::execute_1_3 has not
     * finished asynchronously executing, this call will block until the
     * asynchronous task notifies the object.
     *
     * If the asynchronous task was launched by IPreparedModel::execute, an
     * empty vector will be returned.
     *
     * @return outputShapes A list of shape information of model output
     *     operands. The index into "outputShapes" corresponds to the index of
     *     the output operand in the Request outputs vector. outputShapes must
     *     be empty unless the status is either NONE or
     *     OUTPUT_INSUFFICIENT_SIZE. outputShaps may be empty if the status is
     *     NONE and all model output operands are fully-specified at execution
     *     time. outputShapes must have the same number of elements as the
     *     number of model output operands if the status is
     *     OUTPUT_INSUFFICIENT_SIZE, or if the status is NONE and the model has
     *     at least one output operand that is not fully-specified.
     */
    const std::vector<hal::OutputShape>& getOutputShapes() const;

    /**
     * Retrieves the duration of execution of the asynchronous task launched by
     * by either IPreparedModel::execute_1_2 or IPreparedModel::execute_1_3. If
     * IPreparedModel::execute_1_2 or IPreparedModel::execute_1_3 has not
     * finished asynchronously executing, this call will block until the
     * asynchronous task notifies the object.
     *
     * If the asynchronous task was launched by IPreparedModel::execute, every
     * time must be UINT64_MAX.
     *
     * @return timing Duration of the execution. Every time must be UINT64_MAX
     *     unless the status is NONE.
     */
    hal::Timing getTiming() const;

    /**
     * ExecutionCallback::bindThread binds a thread to the ExecutionCallback
     * object. The bound thread is later joined by ExecutionCallback::wait or
     * ExecutionCallback::get*.
     *
     * Once a thread is bound with ExecutionCallback::bindThread, the client
     * code must ensure that ExecutionCallback::wait or ExecutionCallback::get*
     * has been called before the ExecutionCallback object is destroyed.
     *
     * The bound thread must not call any ExecutionCallback method with the
     * exception of ExecutionCallback::notify*, which it must call when the
     * thread has finished its computation.
     *
     * ExecutionCallback::bindThread can be called at most once on a given
     * callback object.
     *
     * @param asyncThread Thread to be bound to the callback object. The thread
     *     object must represent a thread of execution -- i.e.,
     *     std::thread::joinable() must be true.
     * @return bool True if successful, false if thread was not properly bound.
     */
    bool bindThread(std::thread asyncThread);

    /**
     * ExecutionCallback::setOnFinish binds a callback to the ExecutionCallback
     * object that will be executed during one of the ExecutionCallback::notify*
     * calls but before any calls to wait or get* return. This provided callback
     * is provided with both the ErrorStatus and the output shapes from
     * ExecutionCallback::notify*.
     *
     * The bound function must not synchronize with or otherwise access the
     * callback object it is bound to, as this could cause a deadlock.
     *
     * This call will not bind the provided callback if any of the following
     * occur:
     * (1) the provided callback is invalid (i.e., "(bool) finish" is false)
     * (2) ExecutionCallback already contains a bound callback
     * (3) ExecutionCallback has already been notified with results
     *
     * @param finish Callback to be executed when ExecutionCallback is notified
     *     with results.
     */
    void setOnFinish(const ExecutionFinish& finish);

    /**
     * Queries whether the object is dead.
     *
     * @return 'true' if dead, 'false' otherwise.
     */
    bool isDeadObject() const;

   private:
    /*
     * ExecutionCallback::notifyInternal stores the results of the execution
     * (status, output shapes, and timing information) in the ExecutionCallback
     * object and invokes the bound callback function "mOnFinish" (if present)
     * before any call to wait or get* return. It then enables all prior and
     * future wait calls on the ExecutionCallback object to proceed.
     */
    hal::Return<void> notifyInternal(bool deadObject, hal::ErrorStatus errorStatus,
                                     std::vector<hal::OutputShape> outputShapes,
                                     hal::Timing timing);

    // members
    mutable std::mutex mMutex;
    mutable std::condition_variable mCondition;
    mutable std::thread mThread GUARDED_BY(mMutex);
    ExecutionFinish mOnFinish GUARDED_BY(mMutex);
    bool mNotified GUARDED_BY(mMutex) = false;
    bool mDeadObject = false;
    hal::ErrorStatus mErrorStatus = hal::ErrorStatus::GENERAL_FAILURE;
    std::vector<hal::OutputShape> mOutputShapes;
    hal::Timing mTiming = {};
};

}  // namespace android::nn

#endif  // ANDROID_FRAMEWORKS_ML_NN_RUNTIME_CALLBACKS_H