summaryrefslogtreecommitdiff
path: root/nn/common/include/OperationResolver.h
blob: ab70e4c3a6bbb337fb1938b319ab73c2a6475801 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATION_RESOLVER_H
#define ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATION_RESOLVER_H

#include <functional>
#include <utility>

#include "HalInterfaces.h"
#include "OperationsUtils.h"

namespace android {
namespace nn {

// Encapsulates an operation implementation.
//
// Bundles an operation's type and name with the callbacks used to validate,
// prepare, and execute it, together with per-operation behavior flags.
struct OperationRegistration {
    // The operation type this registration describes.
    hal::OperationType type;
    // The operation's name. Only the pointer is stored; the string is not copied.
    const char* name;

    // Validates operand types, shapes, and any values known during graph creation.
    std::function<bool(const IOperationValidationContext*)> validate;

    // prepare is called when the inputs this operation depends on have been
    // computed. Typically, prepare does any remaining validation and sets
    // output shapes via context->setOutputShape(...).
    //
    // May be empty: NN_REGISTER_OPERATION passes nullptr here when
    // NN_INCLUDE_CPU_IMPLEMENTATION is not defined.
    std::function<bool(IOperationExecutionContext*)> prepare;

    // Executes the operation, reading from context->getInputBuffer(...)
    // and writing to context->getOutputBuffer(...).
    //
    // May be empty under the same condition as prepare.
    std::function<bool(IOperationExecutionContext*)> execute;

    struct Flag {
        // Whether the operation allows at least one operand to be omitted.
        bool allowOmittedOperand = false;
        // Whether the operation allows at least one input operand to be a zero-sized tensor.
        bool allowZeroSizedInput = false;
    } flags;

    // Constructs a registration from its parts.
    //
    // The callbacks are sink parameters: they are taken by value and then
    // moved into the corresponding members, so each std::function is
    // transferred instead of being copied a second time.
    OperationRegistration(hal::OperationType type, const char* name,
                          std::function<bool(const IOperationValidationContext*)> validate,
                          std::function<bool(IOperationExecutionContext*)> prepare,
                          std::function<bool(IOperationExecutionContext*)> execute, Flag flags)
        : type(type),
          name(name),
          validate(std::move(validate)),
          prepare(std::move(prepare)),
          execute(std::move(execute)),
          flags(flags) {}
};

// Abstract interface for a registry of operation implementations.
class IOperationResolver {
   public:
    // Looks up the registration for the given operation type. Callers are
    // expected to check the returned pointer against nullptr before use.
    virtual const OperationRegistration* findOperation(hal::OperationType operationType) const = 0;

    // Virtual so that a resolver can be destroyed through this interface.
    virtual ~IOperationResolver() = default;
};

// The registry of builtin operation implementations, exposed through a single
// shared instance obtained from get().
//
// Note that some operations bypass BuiltinOperationResolver (b/124041202).
//
// Usage:
//   const OperationRegistration* operationRegistration =
//           BuiltinOperationResolver::get()->findOperation(operationType);
//   NN_RET_CHECK(operationRegistration != nullptr);
//   NN_RET_CHECK(operationRegistration->validate != nullptr);
//   NN_RET_CHECK(operationRegistration->validate(&context));
//
class BuiltinOperationResolver : public IOperationResolver {
   public:
    // The resolver is neither copyable nor assignable.
    BuiltinOperationResolver(const BuiltinOperationResolver&) = delete;
    BuiltinOperationResolver& operator=(const BuiltinOperationResolver&) = delete;

    // Returns the shared resolver. The instance is a function-local static,
    // so it is constructed the first time get() is called.
    static const BuiltinOperationResolver* get() {
        static BuiltinOperationResolver sResolver;
        return &sResolver;
    }

    // Implements IOperationResolver::findOperation; defined out of line.
    const OperationRegistration* findOperation(hal::OperationType operationType) const override;

   private:
    // Private: the only instance is the one created inside get().
    BuiltinOperationResolver();

    // Adds a registration to the table; defined out of line.
    void registerOperation(const OperationRegistration* operationRegistration);

    // One slot per operation type (presumably indexed by the operation type's
    // numeric value -- confirm against the implementation). Every slot starts
    // out as nullptr.
    const OperationRegistration* mRegistrations[kNumberOfOperationTypes] = {};
};

// NN_REGISTER_OPERATION creates an OperationRegistration for consumption by
// OperationResolver. It defines a function register_<identifier>() that
// returns a pointer to a function-local static OperationRegistration.
//
// Any arguments following `execute` are wrapped in braces and passed as the
// OperationRegistration::Flag argument, which is why they are written as
// designated initializers. Check OperationRegistration::Flag for the available
// fields and their default values.
//
// Usage:
//
// - With default flags.
//   NN_REGISTER_OPERATION(FOO_OP, foo_op::kOperationName, foo_op::validate,
//                         foo_op::prepare, foo_op::execute);
//
// - With a customized flag.
//   NN_REGISTER_OPERATION(FOO_OP, foo_op::kOperationName, foo_op::validate,
//                         foo_op::prepare, foo_op::execute, .allowZeroSizedInput = true);
//
// - With multiple customized flags.
//   NN_REGISTER_OPERATION(FOO_OP, foo_op::kOperationName, foo_op::validate,
//                         foo_op::prepare, foo_op::execute, .allowOmittedOperand = true,
//                         .allowZeroSizedInput = true);
//
#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
// Full version: registers validation together with the CPU prepare/execute callbacks.
#define NN_REGISTER_OPERATION(identifier, operationName, validate, prepare, execute, ...)        \
    const OperationRegistration* register_##identifier() {                                       \
        static OperationRegistration sRegistration(                                              \
                hal::OperationType::identifier, operationName, validate, prepare, execute,       \
                {__VA_ARGS__});                                                                  \
        return &sRegistration;                                                                   \
    }
#else
// Validation-only version: the prepare and execute arguments are discarded and
// nullptr is registered in their place. The compiler is supposed to omit the
// CPU execution code so that only validation logic makes it into
// libneuralnetworks_utils.
#define NN_REGISTER_OPERATION(identifier, operationName, validate, unused_prepare, unused_execute, \
                              ...)                                                                 \
    const OperationRegistration* register_##identifier() {                                         \
        static OperationRegistration sRegistration(                                                \
                hal::OperationType::identifier, operationName, validate, nullptr, nullptr,         \
                {__VA_ARGS__});                                                                    \
        return &sRegistration;                                                                     \
    }
#endif

}  // namespace nn
}  // namespace android

#endif  // ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATION_RESOLVER_H