summaryrefslogtreecommitdiff
path: root/nn/runtime/Callbacks.h
blob: fb9b29d17ab0ef5a803d849270faa752574b6256 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef ANDROID_HARDWARE_NEURALNETWORKS_V1_0_CALLBACKS_H
#define ANDROID_HARDWARE_NEURALNETWORKS_V1_0_CALLBACKS_H

#include <android/hardware/neuralnetworks/1.0/IExecutionCallback.h>
#include <android/hardware/neuralnetworks/1.0/IPreparedModelCallback.h>
#include <chrono>
#include <condition_variable>
#include <functional>
#include <hidl/MQDescriptor.h>
#include <hidl/Status.h>
#include <mutex>
#include <thread>

namespace android {
namespace hardware {
namespace neuralnetworks {
namespace V1_0 {
namespace implementation {

using ::android::hardware::hidl_array;
using ::android::hardware::hidl_memory;
using ::android::hardware::hidl_string;
using ::android::hardware::hidl_vec;
using ::android::hardware::Return;
using ::android::hardware::Void;
using ::android::sp;

/**
 * The CallbackBase class is used internally by the NeuralNetworks runtime to
 * synchronize between different threads. An asynchronous task is launched
 * paired with a callback object. When a client thread requires the output being
 * generated by the asynchronous task, the client thread can wait for the result
 * and be blocked until it has completed or a timeout condition has been
 * reached. Any wait* may safely be called concurrently, even on the same
 * callback object. When the asynchronous task has finished its workload, it
 * must immediately call "notify". If the asynchronous task has failed to launch,
 * the function that tried to launch the asynchronous task must immediately call
 * "notify". This "notify" call awakens any client threads waiting on the
 * callback object.
 *
 * The CallbackBase class implements some of the base synchronization common to
 * both PreparedModelCallback and ExecutionCallback. For consistency, any HIDL
 * callback class must inherit from CallbackBase as well as the HIDL callback
 * interface it implements.
 *
 * This class exists to enable synchronization across HIDL. When synchronization
 * is only required in the same process, consider using std::future, std::mutex,
 * std::condition_variable, or std::experimental::latch instead.
 */
class CallbackBase {
 public:
    CallbackBase();
    // NOTE(review): this destructor is not declared virtual, so an object of a
    // derived callback class must not be destroyed through a CallbackBase
    // pointer.
    ~CallbackBase();

    /**
     * CallbackBase::wait blocks until notify has been called on the callback
     * object.
     */
    void wait();

    /**
     * CallbackBase::wait_for blocks until notify has been called on the
     * callback object or the time duration from the time the wait_for function
     * was called has expired, whichever comes first.
     *
     * If the callback was notified before the duration expired, wait_for also
     * joins the thread (if any) bound with CallbackBase::bind_thread before
     * returning.
     *
     * @param timeout_duration Maximum amount of time to block waiting for
     *                         notify.
     * @return Status std::cv_status::no_timeout if the callback was notified
     *                before the time duration expired, std::cv_status::timeout
     *                otherwise.
     */
    template<class Rep, class Period>
    std::cv_status wait_for(const std::chrono::duration<Rep,Period>& timeout_duration);

    /**
     * CallbackBase::on_finish binds a function to the callback object. This
     * bound function will be executed when CallbackBase::notify is called,
     * before any calls to wait* return. (Note that CallbackBase::wait_for can
     * return std::cv_status::timeout before CallbackBase::notify is called for
     * the first time, and hence before the bound function is executed.)
     *
     * The bound function must not synchronize with or otherwise access the
     * callback object it is bound to, as this could cause a deadlock.
     *
     * CallbackBase::on_finish can be called at most once on a given callback
     * object, and the call to CallbackBase::on_finish must finish before
     * CallbackBase::notify is called.
     *
     * @param post_work Function to be invoked the first time
     *                  CallbackBase::notify is called. Must have a target --
     *                  i.e., must not compare equal to nullptr. post_work
     *                  returns true if it successfully completes, false if it
     *                  fails.
     * @return bool True if the function was successfully bound, false if
     *              unsuccessful.
     *
     * TODO: Why does the return value of the callback matter?
     */
    bool on_finish(std::function<bool(void)> post_work);

    /**
     * CallbackBase::bind_thread binds a thread to the event for later use by
     * CallbackBase::join_thread.
     *
     * The thread must be passed using std::move.
     *
     * Once a thread is bound with CallbackBase::bind_thread, the client code
     * should ensure that one of the following occurs before the event is
     * destroyed:
     * - CallbackBase::join_thread has been called.
     * - CallbackBase::wait has been called.
     * - CallbackBase::wait_for has been called and returned
     *   std::cv_status::no_timeout (a wait_for call that times out does not
     *   join the bound thread).
     *
     * The bound thread shall not call any CallbackBase method with the
     * exception of CallbackBase::notify, which it must call when the thread has
     * finished its computation.
     *
     * CallbackBase::bind_thread can be called at most once on a given callback
     * object.
     *
     * @param asyncThread Thread to be bound to the callback object. The thread
     *                    object must represent a thread of execution -- i.e.,
     *                    asyncThread.joinable() must be true.
     * @return bool True if successful, false if thread was not properly bound.
     */
    bool bind_thread(std::thread&& asyncThread);

    /**
     * CallbackBase::join_thread ensures that the thread (if any) bound to this
     * event with CallbackBase::bind_thread has fully finished and cleaned its
     * resources. It is legal to call this function multiple times, concurrently
     * or sequentially.
     */
    void join_thread();

 protected:
    /**
     * CallbackBase::notify enables all prior and future wait* calls on the
     * callback object to proceed. The call to CallbackBase::notify happens
     * before any wait* calls on this callback object return (except in the case
     * of wait_for timing out). The asynchronous call the callback object is
     * paired with must ensure that any update to state that should be visible
     * to the caller of wait* happens before the call to CallbackBase::notify.
     *
     * CallbackBase::notify must be called exactly once on a given callback
     * object.
     */
    void notify();

 private:
    // Same as CallbackBase::join_thread but assumes we already hold a lock on
    // mMutex.
    void join_thread_locked();

    // Flag tested by the wait_for predicate while mMutex is held; presumably
    // set by notify (defined outside this header -- confirm).
    bool                      mNotified;
    // Locked by wait_for for the duration of the wait and the subsequent
    // join_thread_locked call.
    std::mutex                mMutex;
    // Condition variable that wait_for blocks on.
    std::condition_variable   mCondition;
    // Presumably the function bound by on_finish -- confirm against the
    // out-of-line definitions.
    std::function<bool(void)> mPostWork;
    // Presumably the thread bound by bind_thread and joined by
    // join_thread/join_thread_locked -- confirm against the out-of-line
    // definitions.
    std::thread               mThread;
};

/**
 * The PreparedModelCallback class is used to receive the error status of
 * preparing a model as well as the prepared model from a task executing
 * asynchronously with respect to the runtime. If a calling thread calls wait*
 * or get* on a PreparedModelCallback object and the corresponding asynchronous
 * task has not finished preparing the model, the calling thread will block
 * until the asynchronous task has called notify. For more information on the
 * synchronization behavior, refer to the CallbackBase class.
 *
 * This class inherits the basic blocking and signaling calls from
 * CallbackBase, and implements the HIDL notify call from
 * IPreparedModelCallback. This callback object is passed as an argument to
 * IDevice::prepareModel.
 */
class PreparedModelCallback : public CallbackBase, public IPreparedModelCallback {
 public:
    // Constructor and destructor are defined outside this header. The
    // destructor overrides a virtual destructor of a base class (presumably
    // one reached through IPreparedModelCallback, since CallbackBase's
    // destructor is non-virtual).
    PreparedModelCallback();
    ~PreparedModelCallback() override;

    /**
     * IPreparedModelCallback::notify marks the callback object with the return
     * status of the asynchronous model preparation along with the prepared
     * model, and calls CallbackBase::notify, enabling all prior and future
     * wait* calls on the PreparedModelCallback object to proceed. For more
     * information on the synchronization behavior, refer to the CallbackBase
     * class.
     *
     * IPreparedModelCallback::notify must be called exactly once on a given
     * PreparedModelCallback object.
     *
     * @param status Error status returned from asynchronously preparing the
     *               model; will be:
     *               - NONE if the asynchronous preparation was successful
     *               - DEVICE_UNAVAILABLE if driver is offline or busy
     *               - GENERAL_FAILURE if there is an unspecified error
     *               - INVALID_ARGUMENT if the input model is invalid
     * @param preparedModel Returned model that has been prepared for execution,
     *                      nullptr if the model was unable to be prepared.
     * @return Return<void> HIDL transport result of the call; carries no
     *                      payload.
     */
    Return<void> notify(ErrorStatus status, const sp<IPreparedModel>& preparedModel) override;

    /**
     * Retrieves the error status returned from the asynchronous task launched
     * by IDevice::prepareModel. If IDevice::prepareModel has not finished
     * asynchronously preparing the model, this call will block until the
     * asynchronous task notifies the object.
     *
     * @return status Error status returned from asynchronously preparing the
     *                model; will be:
     *                - NONE if the asynchronous preparation was successful
     *                - DEVICE_UNAVAILABLE if driver is offline or busy
     *                - GENERAL_FAILURE if there is an unspecified error
     *                - INVALID_ARGUMENT if the input model is invalid
     */
    ErrorStatus getStatus();

    /**
     * Retrieves the model that has been prepared for execution from the
     * asynchronous task launched by IDevice::prepareModel. If
     * IDevice::prepareModel has not finished asynchronously preparing the
     * model, this call will block until the asynchronous task notifies the
     * object.
     *
     * @return preparedModel Returned model that has been prepared for
     *                       execution, nullptr if the model was unable to be
     *                       prepared.
     */
    sp<IPreparedModel> getPreparedModel();

 private:
    // Presumably the status passed to notify and returned by getStatus --
    // confirm against the out-of-line definitions.
    ErrorStatus        mErrorStatus;
    // Presumably the model passed to notify and returned by getPreparedModel
    // -- confirm against the out-of-line definitions.
    sp<IPreparedModel> mPreparedModel;
};

/**
 * The ExecutionCallback class is used to receive the error status of the
 * execution from a task executing asynchronously with respect to the runtime.
 * If a calling thread calls wait* or get* on an ExecutionCallback object and
 * the corresponding asynchronous task has not finished the execution, the
 * calling thread will block until the asynchronous task has called notify. For
 * more information on the synchronization behavior, refer to the CallbackBase
 * class.
 *
 * This class inherits the basic blocking and signaling calls from
 * CallbackBase, and implements the HIDL notify call from
 * IExecutionCallback. This callback object is passed as an argument to
 * IPreparedModel::execute.
 */
class ExecutionCallback : public CallbackBase,  public IExecutionCallback {
 public:
    // Constructor and destructor are defined outside this header. The
    // destructor overrides a virtual destructor of a base class (presumably
    // one reached through IExecutionCallback, since CallbackBase's destructor
    // is non-virtual).
    ExecutionCallback();
    ~ExecutionCallback() override;

    /**
     * IExecutionCallback::notify marks the callback object with the return
     * status of the asynchronous execution that held this callback and enables
     * all prior and future wait* calls on the ExecutionCallback object to
     * proceed. For more information on the synchronization behavior, refer to
     * the CallbackBase class.
     *
     * IExecutionCallback::notify must be called exactly once on a given
     * ExecutionCallback object.
     *
     * @param status Error status returned from the asynchronous execution;
     *               will be:
     *               - NONE if the asynchronous execution was successful
     *               - DEVICE_UNAVAILABLE if driver is offline or busy
     *               - GENERAL_FAILURE if there is an unspecified error
     *               - OUTPUT_INSUFFICIENT_SIZE if provided output buffer is
     *                 not large enough to store the resultant values
     *               - INVALID_ARGUMENT if the input request is invalid
     * @return Return<void> HIDL transport result of the call; carries no
     *                      payload.
     */
    Return<void> notify(ErrorStatus status) override;

    /**
     * Retrieves the error status returned from the asynchronous task launched
     * by IPreparedModel::execute. If IPreparedModel::execute has not finished
     * asynchronously executing, this call will block until the asynchronous task
     * notifies the object.
     *
     * @return status Error status returned from the asynchronous execution;
     *                will be:
     *                - NONE if the asynchronous execution was successful
     *                - DEVICE_UNAVAILABLE if driver is offline or busy
     *                - GENERAL_FAILURE if there is an unspecified error
     *                - OUTPUT_INSUFFICIENT_SIZE if provided output buffer is
     *                  not large enough to store the resultant values
     *                - INVALID_ARGUMENT if the input request is invalid
     */
    ErrorStatus getStatus();

 private:
    // Presumably the status passed to notify and returned by getStatus --
    // confirm against the out-of-line definitions.
    ErrorStatus mErrorStatus;
};


// template function implementation(s) below this point

template<class Rep, class Period>
std::cv_status CallbackBase::wait_for(const std::chrono::duration<Rep,Period>& timeout_duration) {
    std::unique_lock<std::mutex> lock(mMutex);
    std::cv_status status = mCondition.wait_for(lock, timeout_duration, [this]{return mNotified;});
    if (status != std::cv_status::timeout) {
        join_thread_locked();
    }
    return status;
}

}  // namespace implementation
}  // namespace V1_0
}  // namespace neuralnetworks
}  // namespace hardware
}  // namespace android

#endif  // ANDROID_HARDWARE_NEURALNETWORKS_V1_0_CALLBACKS_H