path: root/nn/common
author    Michael Butler <butlermichael@google.com>  2020-10-06 21:22:49 +0000
committer Gerrit Code Review <noreply-gerritcodereview@google.com>  2020-10-06 21:22:49 +0000
commit    e238d86505854623f9a39bf480b1aa5c78151f61 (patch)
tree      3450bcddea95ab9d30968a372a07c05be7a4370e /nn/common
parent    9a4f318c3827ad986bccbc8576d79f73f488ea9e (diff)
parent    bfef4a044a9176fdb0c1c5ac1fd90e67e2022b7a (diff)
download  ml-e238d86505854623f9a39bf480b1aa5c78151f61.tar.gz
Merge changes from topic "nnapi-canonical-types"
* changes:
  Introduce Result type into NNAPI
  Implement canonical types in NNAPI
Diffstat (limited to 'nn/common')
-rw-r--r--  nn/common/Android.bp                       21
-rw-r--r--  nn/common/SharedMemory.cpp                109
-rw-r--r--  nn/common/SharedMemoryAndroid.cpp         278
-rw-r--r--  nn/common/SharedMemoryHost.cpp            161
-rw-r--r--  nn/common/TypeUtils.cpp                   849
-rw-r--r--  nn/common/Types.cpp                         8
-rw-r--r--  nn/common/Validation.cpp                 2664
-rw-r--r--  nn/common/include/nnapi/Result.h          147
-rw-r--r--  nn/common/include/nnapi/SharedMemory.h     93
-rw-r--r--  nn/common/include/nnapi/TypeUtils.h       112
-rw-r--r--  nn/common/include/nnapi/Types.h             5
-rw-r--r--  nn/common/include/nnapi/Validation.h       97
12 files changed, 4537 insertions, 7 deletions
diff --git a/nn/common/Android.bp b/nn/common/Android.bp
index 65d993247..202c69ed1 100644
--- a/nn/common/Android.bp
+++ b/nn/common/Android.bp
@@ -243,8 +243,27 @@ cc_library_static {
name: "neuralnetworks_types",
defaults: ["neuralnetworks_utils_defaults"],
srcs: [
+ "SharedMemory.cpp",
+ "TypeUtils.cpp",
"Types.cpp",
- ],
+ "Validation.cpp",
+ ],
+ target: {
+ android: {
+ srcs: ["SharedMemoryAndroid.cpp"],
+ shared_libs: [
+ "android.hidl.allocator@1.0",
+ "android.hidl.memory@1.0",
+ "libhidlbase",
+ "libhidlmemory",
+ "libnativewindow",
+ ],
+ static_libs: ["libarect"],
+ },
+ host: {
+ srcs: ["SharedMemoryHost.cpp"],
+ },
+ },
local_include_dirs: ["include/nnapi"],
export_include_dirs: ["include"],
shared_libs: [
diff --git a/nn/common/SharedMemory.cpp b/nn/common/SharedMemory.cpp
new file mode 100644
index 000000000..9186be20e
--- /dev/null
+++ b/nn/common/SharedMemory.cpp
@@ -0,0 +1,109 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <android-base/logging.h>
+
+#include <algorithm>
+#include <cstring>
+#include <limits>
+#include <optional>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "Result.h"
+#include "SharedMemory.h"
+#include "Types.h"
+
+namespace android::nn {
+namespace {
+
+constexpr size_t safeDivideRoundedUp(size_t numerator, size_t denominator) {
+ CHECK_NE(denominator, 0u);
+ CHECK_LE(numerator, std::numeric_limits<size_t>::max() - denominator);
+ return (numerator + denominator - 1) / denominator;
+}
+
+constexpr size_t safeMultiply(size_t a, size_t b) {
+ if (b == 0) {
+ return 0;
+ }
+ CHECK_LE(a, std::numeric_limits<size_t>::max() / b);
+ return a * b;
+}
+
+constexpr size_t extendToAlignment(size_t length) {
+ constexpr size_t kMaxAlignment = alignof(AlignedData);
+ return safeMultiply(safeDivideRoundedUp(length, kMaxAlignment), kMaxAlignment);
+}
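+
+// For example (illustrative, assuming alignof(AlignedData) == 8):
+// extendToAlignment(5) returns 8 and extendToAlignment(16) returns 16.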
+
+} // namespace
+
+MutableMemoryBuilder::MutableMemoryBuilder(uint32_t poolIndex) : mPoolIndex(poolIndex) {}
+
+DataLocation MutableMemoryBuilder::append(size_t length) {
+ CHECK_GT(length, 0u);
+ const size_t offset = mSize;
+ mSize += extendToAlignment(length);
+ CHECK_LE(offset, std::numeric_limits<uint32_t>::max());
+ CHECK_LE(length, std::numeric_limits<uint32_t>::max());
+ return {.poolIndex = mPoolIndex,
+ .offset = static_cast<uint32_t>(offset),
+ .length = static_cast<uint32_t>(length)};
+}
+
+bool MutableMemoryBuilder::empty() const {
+ return mSize == 0;
+}
+
+Result<Memory> MutableMemoryBuilder::finish() {
+ return createSharedMemory(mSize);
+}
+
+ConstantMemoryBuilder::ConstantMemoryBuilder(uint32_t poolIndex) : mBuilder(poolIndex) {}
+
+DataLocation ConstantMemoryBuilder::append(const void* data, size_t length) {
+ const auto location = mBuilder.append(length);
+ CHECK_EQ(location.length, length);
+ mSlices.push_back({.data = data, .length = length, .offset = location.offset});
+ return location;
+}
+
+bool ConstantMemoryBuilder::empty() const {
+ return mBuilder.empty();
+}
+
+Result<Memory> ConstantMemoryBuilder::finish() {
+ // Allocate the memory.
+ auto memory = NN_TRY(mBuilder.finish());
+
+ // Map the memory.
+ const auto [pointer, size, context] = NN_TRY(map(memory));
+
+ // Get mutable pointer.
+ if (!std::holds_alternative<void*>(pointer)) {
+ return NN_ERROR()
+ << "MemoryBuilder::finish failed because the mapped pointer is not mutable";
+ }
+ uint8_t* mutablePointer = static_cast<uint8_t*>(std::get<void*>(pointer));
+
+ // Copy data to the memory pool.
+ std::for_each(mSlices.begin(), mSlices.end(), [mutablePointer](const auto& slice) {
+ std::memcpy(mutablePointer + slice.offset, slice.data, slice.length);
+ });
+
+ return memory;
+}
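+
+// Illustrative usage (not part of this change; kPoolIndex is a hypothetical
+// pool index): pack a constant value into a pool and materialize the backing
+// shared memory.
+//
+//     ConstantMemoryBuilder builder(kPoolIndex);
+//     const int32_t value = 42;
+//     const DataLocation loc = builder.append(&value, sizeof(value));
+//     Result<Memory> memory = builder.finish();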
+
+} // namespace android::nn
diff --git a/nn/common/SharedMemoryAndroid.cpp b/nn/common/SharedMemoryAndroid.cpp
new file mode 100644
index 000000000..caa83a6e8
--- /dev/null
+++ b/nn/common/SharedMemoryAndroid.cpp
@@ -0,0 +1,278 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <android-base/logging.h>
+#include <android-base/mapped_file.h>
+#include <android-base/scopeguard.h>
+#include <android/hardware_buffer.h>
+#include <android/hidl/allocator/1.0/IAllocator.h>
+#include <cutils/native_handle.h>
+#include <hidl/HidlSupport.h>
+#include <hidlmemory/mapping.h>
+#include <sys/mman.h>
+#include <vndk/hardware_buffer.h>
+
+#include <any>
+#include <limits>
+#include <memory>
+#include <string>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "Result.h"
+#include "SharedMemory.h"
+#include "TypeUtils.h"
+#include "Types.h"
+
+namespace android::nn {
+namespace {
+
+using ::android::hardware::hidl_memory;
+using ::android::hidl::allocator::V1_0::IAllocator;
+
+const char* const kAllocatorService = "ashmem";
+
+Result<hidl_memory> allocateSharedMemory(size_t size) {
+ static const auto allocator = IAllocator::getService(kAllocatorService);
+ CHECK_GT(size, 0u);
+
+ hidl_memory maybeMemory;
+ auto fn = [&maybeMemory](bool success, const hidl_memory& memory) {
+ if (success) {
+ maybeMemory = memory;
+ }
+ };
+ allocator->allocate(size, fn);
+
+ if (!maybeMemory.valid()) {
+ return NN_ERROR() << "IAllocator::allocate returned an invalid (empty) memory object";
+ }
+
+ return maybeMemory;
+}
+
+Memory createMemory(const hidl_memory& memory) {
+ CHECK_LE(memory.size(), std::numeric_limits<uint32_t>::max());
+
+ auto* cloned = native_handle_clone(memory.handle());
+ auto nativeHandle = ::android::NativeHandle::create(cloned, /*ownsHandle=*/true);
+
+ return {
+ .handle = std::move(nativeHandle),
+ .size = static_cast<uint32_t>(memory.size()),
+ .name = memory.name(),
+ };
+}
+
+hidl_memory createHidlMemory(const Memory& memory) {
+ const auto hidlMemory = hidl_memory(memory.name, memory.handle->handle(), memory.size);
+ // Copy memory to force the native_handle_t to be copied.
+ auto copiedMemory = hidlMemory;
+ return copiedMemory;
+}
+
+Result<Mapping> mapAshmem(const Memory& memory) {
+ const auto hidlMemory = createHidlMemory(memory);
+ const auto mapping = mapMemory(hidlMemory);
+ if (mapping == nullptr) {
+ return NN_ERROR() << "Failed to map memory";
+ }
+ auto* const pointer = mapping->getPointer().withDefault(nullptr);
+ if (pointer == nullptr) {
+ return NN_ERROR() << "Failed to get the mapped pointer";
+ }
+ const auto fullSize = mapping->getSize().withDefault(0);
+ if (fullSize == 0 || fullSize > std::numeric_limits<size_t>::max()) {
+ return NN_ERROR() << "Failed to get the mapped size";
+ }
+ const size_t size = static_cast<size_t>(fullSize);
+
+ return Mapping{
+ .pointer = pointer,
+ .size = size,
+ .context = mapping,
+ };
+}
+
+struct MmapFdMappingContext {
+ int prot;
+ std::any context;
+};
+
+Result<Mapping> mapMemFd(const Memory& memory) {
+ const size_t size = memory.size;
+ const native_handle_t* handle = memory.handle->handle();
+ const int fd = handle->data[0];
+ const int prot = handle->data[1];
+ const size_t offset = getOffsetFromInts(handle->data[2], handle->data[3]);
+
+ std::shared_ptr<base::MappedFile> mapping = base::MappedFile::FromFd(fd, offset, size, prot);
+ if (mapping == nullptr) {
+ return NN_ERROR() << "Can't mmap the file descriptor.";
+ }
+ void* data = mapping->data();
+
+ auto context = MmapFdMappingContext{.prot = prot, .context = std::move(mapping)};
+ return Mapping{.pointer = data, .size = size, .context = std::move(context)};
+}
+
+Result<Mapping> mapAhwbBlobMemory(const Memory& memory) {
+ const auto* handle = memory.handle->handle();
+ const auto size = memory.size;
+ const auto format = AHARDWAREBUFFER_FORMAT_BLOB;
+ const auto usage = AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN | AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN;
+ const uint32_t width = size;
+ const uint32_t height = 1; // height is always 1 for BLOB mode AHardwareBuffer.
+ const uint32_t layers = 1; // layers is always 1 for BLOB mode AHardwareBuffer.
+ const uint32_t stride = size;
+
+ AHardwareBuffer_Desc desc{
+ .width = width,
+ .height = height,
+ .layers = layers,
+ .format = format,
+ .usage = usage,
+ .stride = stride,
+ };
+
+ AHardwareBuffer* hardwareBuffer = nullptr;
+ status_t status = AHardwareBuffer_createFromHandle(
+ &desc, handle, AHARDWAREBUFFER_CREATE_FROM_HANDLE_METHOD_CLONE, &hardwareBuffer);
+ if (status != NO_ERROR) {
+ return NN_ERROR() << "Can't create AHardwareBuffer from handle. Error: " << status;
+ }
+
+ void* data = nullptr;
+ status = AHardwareBuffer_lock(hardwareBuffer, usage, -1, nullptr, &data);
+ if (status != NO_ERROR) {
+ // TODO(b/169166682): do we need to call AHardwareBuffer_release?
+ return NN_ERROR() << "Can't lock the AHardwareBuffer. Error: " << status;
+ }
+
+ // Create a shared scope guard to unlock and release the AHardwareBuffer.
+ auto scoped = base::make_scope_guard([hardwareBuffer] {
+ AHardwareBuffer_unlock(hardwareBuffer, nullptr);
+ if (hardwareBuffer != nullptr) {
+ AHardwareBuffer_release(hardwareBuffer);
+ }
+ });
+ auto sharedScoped = std::make_shared<decltype(scoped)>(std::move(scoped));
+
+ return Mapping{.pointer = data, .size = size, .context = std::move(sharedScoped)};
+}
+
+Result<Mapping> mapAhwbMemory(const Memory& /*memory*/) {
+ return NN_ERROR() << "Unable to map non-BLOB AHardwareBuffer memory";
+}
+
+} // namespace
+
+Result<Memory> createSharedMemory(size_t size) {
+ const auto memory = NN_TRY(allocateSharedMemory(size));
+ return createSharedMemoryFromHidlMemory(memory);
+}
+
+Result<Memory> createSharedMemoryFromFd(size_t size, int prot, int fd, size_t offset) {
+ if (size == 0 || fd < 0) {
+ return NN_ERROR() << "Invalid size or fd";
+ }
+
+ // Duplicate the file descriptor so the resultant Memory owns its own version.
+ int dupfd = dup(fd);
+ if (dupfd == -1) {
+ // TODO(b/120417090): is ANEURALNETWORKS_UNEXPECTED_NULL the correct error to return here?
+ return NN_ERROR() << "Failed to dup the fd";
+ }
+
+ // Create a temporary native handle to own the dupfd.
+ native_handle_t* nativeHandle = native_handle_create(1, 3);
+ if (nativeHandle == nullptr) {
+ close(dupfd);
+ // TODO(b/120417090): is ANEURALNETWORKS_UNEXPECTED_NULL the correct error to return here?
+ return NN_ERROR() << "Failed to create native_handle";
+ }
+
+ const auto [lowOffsetBits, highOffsetBits] = getIntsFromOffset(offset);
+ nativeHandle->data[0] = dupfd;
+ nativeHandle->data[1] = prot;
+ nativeHandle->data[2] = lowOffsetBits;
+ nativeHandle->data[3] = highOffsetBits;
+
+ // Create a NativeHandle which owns the native handle and fd so that we don't have to manually
+ // clean either the native handle or the fd.
+ auto ownedHandle = ::android::NativeHandle::create(nativeHandle, /*ownsHandle=*/true);
+
+ return Memory{.handle = std::move(ownedHandle), .size = size, .name = "mmap_fd"};
+}
+
+Result<Memory> createSharedMemoryFromHidlMemory(const hardware::hidl_memory& memory) {
+ return createMemory(memory);
+}
+
+Result<Memory> createSharedMemoryFromAHWB(const AHardwareBuffer& ahwb) {
+ AHardwareBuffer_Desc bufferDesc;
+ AHardwareBuffer_describe(&ahwb, &bufferDesc);
+ const native_handle_t* handle = AHardwareBuffer_getNativeHandle(&ahwb);
+
+ auto* cloned = native_handle_clone(handle);
+ auto nativeHandle = ::android::NativeHandle::create(cloned, /*ownsHandle=*/true);
+
+ if (bufferDesc.format == AHARDWAREBUFFER_FORMAT_BLOB) {
+ return Memory{
+ .handle = std::move(nativeHandle),
+ .size = bufferDesc.width,
+ .name = "hardware_buffer_blob",
+ };
+ }
+
+ // memory size is not used for non-BLOB AHWB memory.
+ return Memory{.handle = std::move(nativeHandle), .size = 0, .name = "hardware_buffer"};
+}
+
+Result<Mapping> map(const Memory& memory) {
+ if (memory.name == "ashmem") {
+ return mapAshmem(memory);
+ }
+ if (memory.name == "mmap_fd") {
+ return mapMemFd(memory);
+ }
+ if (memory.name == "hardware_buffer_blob") {
+ return mapAhwbBlobMemory(memory);
+ }
+ if (memory.name == "hardware_buffer") {
+ return mapAhwbMemory(memory);
+ }
+ return NN_ERROR() << "Cannot map unknown memory " << memory.name;
+}
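+
+// Illustrative usage (not part of this change): map a Memory, write through the
+// mutable pointer if one was returned, then flush any mmap-backed mapping:
+//
+//     const auto mapping = NN_TRY(map(memory));
+//     if (void* const* pointer = std::get_if<void*>(&mapping.pointer)) {
+//         std::memset(*pointer, 0, mapping.size);
+//     }
+//     flush(mapping);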
+
+bool flush(const Mapping& mapping) {
+ if (const auto* mmapFdMapping = std::any_cast<MmapFdMappingContext>(&mapping.context)) {
+ if (!std::holds_alternative<void*>(mapping.pointer)) {
+ return true;
+ }
+ void* data = std::get<void*>(mapping.pointer);
+ const int prot = mmapFdMapping->prot;
+ if (prot & PROT_WRITE) {
+ const size_t size = mapping.size;
+ return msync(data, size, MS_SYNC) == 0;
+ }
+ }
+ // No-op for other types of memory.
+ return true;
+}
+
+} // namespace android::nn
diff --git a/nn/common/SharedMemoryHost.cpp b/nn/common/SharedMemoryHost.cpp
new file mode 100644
index 000000000..bc29d1ff1
--- /dev/null
+++ b/nn/common/SharedMemoryHost.cpp
@@ -0,0 +1,161 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <android-base/logging.h>
+#include <android-base/mapped_file.h>
+#include <cutils/ashmem.h>
+#include <cutils/native_handle.h>
+#include <sys/mman.h>
+
+#include <any>
+#include <limits>
+#include <memory>
+#include <utility>
+#include <variant>
+
+#include "Result.h"
+#include "SharedMemory.h"
+#include "TypeUtils.h"
+#include "Types.h"
+
+namespace android::nn {
+namespace {
+
+Result<Mapping> mapAshmem(const Memory& memory) {
+ CHECK_LE(memory.size, std::numeric_limits<uint32_t>::max());
+ const auto size = memory.size;
+
+ int fd = memory.handle->handle()->data[0];
+ std::shared_ptr<base::MappedFile> mapping =
+ base::MappedFile::FromFd(fd, /*offset=*/0, size, PROT_READ | PROT_WRITE);
+ if (mapping == nullptr) {
+ return NN_ERROR() << "Can't mmap the file descriptor.";
+ }
+ void* data = mapping->data();
+
+ return Mapping{.pointer = data, .size = size, .context = std::move(mapping)};
+}
+
+struct MmapFdMappingContext {
+ int prot;
+ std::any context;
+};
+
+Result<Mapping> mapMemFd(const Memory& memory) {
+ const size_t size = memory.size;
+ const native_handle_t* handle = memory.handle->handle();
+ const int fd = handle->data[0];
+ const int prot = handle->data[1];
+ const size_t offset = getOffsetFromInts(handle->data[2], handle->data[3]);
+
+ std::shared_ptr<base::MappedFile> mapping = base::MappedFile::FromFd(fd, offset, size, prot);
+ if (mapping == nullptr) {
+ return NN_ERROR() << "Can't mmap the file descriptor.";
+ }
+ void* data = mapping->data();
+
+ auto context = MmapFdMappingContext{.prot = prot, .context = std::move(mapping)};
+ return Mapping{.pointer = data, .size = size, .context = std::move(context)};
+}
+
+} // namespace
+
+Result<Memory> createSharedMemory(size_t size) {
+ int fd = ashmem_create_region("NnapiAshmem", size);
+ if (fd < 0) {
+ return NN_ERROR() << "ashmem_create_region(" << size << ") fails with " << fd;
+ }
+
+ native_handle_t* handle = native_handle_create(1, 0);
+ if (handle == nullptr) {
+ // TODO(b/120417090): is ANEURALNETWORKS_UNEXPECTED_NULL the correct error to return here?
+ return NN_ERROR() << "Failed to create native_handle";
+ }
+ handle->data[0] = fd;
+
+ // Create a NativeHandle which owns the native handle and fd so that we don't have to manually
+ // clean either the native handle or the fd.
+ auto nativeHandle = ::android::NativeHandle::create(handle, /*ownsHandle=*/true);
+
+ return Memory{.handle = std::move(nativeHandle), .size = size, .name = "ashmem"};
+}
+
+Result<Memory> createSharedMemoryFromFd(size_t size, int prot, int fd, size_t offset) {
+ if (size == 0 || fd < 0) {
+ return NN_ERROR() << "Invalid size or fd";
+ }
+
+ // Duplicate the file descriptor so the resultant Memory owns its own version.
+ int dupfd = dup(fd);
+ if (dupfd == -1) {
+ // TODO(b/120417090): is ANEURALNETWORKS_UNEXPECTED_NULL the correct error to return here?
+ return NN_ERROR() << "Failed to dup the fd";
+ }
+
+ // Create a temporary native handle to own the dupfd.
+ native_handle_t* nativeHandle = native_handle_create(1, 3);
+ if (nativeHandle == nullptr) {
+ close(dupfd);
+ // TODO(b/120417090): is ANEURALNETWORKS_UNEXPECTED_NULL the correct error to return here?
+ return NN_ERROR() << "Failed to create native_handle";
+ }
+
+ const auto [lowOffsetBits, highOffsetBits] = getIntsFromOffset(offset);
+ nativeHandle->data[0] = dupfd;
+ nativeHandle->data[1] = prot;
+ nativeHandle->data[2] = lowOffsetBits;
+ nativeHandle->data[3] = highOffsetBits;
+
+ // Create a NativeHandle which owns the native handle and fd so that we don't have to manually
+ // clean either the native handle or the fd.
+ auto ownedHandle = ::android::NativeHandle::create(nativeHandle, /*ownsHandle=*/true);
+
+ return Memory{.handle = std::move(ownedHandle), .size = size, .name = "mmap_fd"};
+}
+
+Result<Memory> createSharedMemoryFromHidlMemory(const hardware::hidl_memory& /*memory*/) {
+ return NN_ERROR() << "hidl_memory not supported on host";
+}
+
+Result<Memory> createSharedMemoryFromAHWB(const AHardwareBuffer& /*ahwb*/) {
+ return NN_ERROR() << "AHardwareBuffer memory not supported on host";
+}
+
+Result<Mapping> map(const Memory& memory) {
+ if (memory.name == "ashmem") {
+ return mapAshmem(memory);
+ }
+ if (memory.name == "mmap_fd") {
+ return mapMemFd(memory);
+ }
+ return NN_ERROR() << "Cannot map unknown memory " << memory.name;
+}
+
+bool flush(const Mapping& mapping) {
+ if (const auto* mmapFdMapping = std::any_cast<MmapFdMappingContext>(&mapping.context)) {
+ if (!std::holds_alternative<void*>(mapping.pointer)) {
+ return true;
+ }
+ void* data = std::get<void*>(mapping.pointer);
+ const int prot = mmapFdMapping->prot;
+ if (prot & PROT_WRITE) {
+ const size_t size = mapping.size;
+ return msync(data, size, MS_SYNC) == 0;
+ }
+ }
+ // No-op for other types of memory.
+ return true;
+}
+
+} // namespace android::nn
diff --git a/nn/common/TypeUtils.cpp b/nn/common/TypeUtils.cpp
new file mode 100644
index 000000000..d997f65a2
--- /dev/null
+++ b/nn/common/TypeUtils.cpp
@@ -0,0 +1,849 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TypeUtils.h"
+
+#include <android-base/logging.h>
+
+#include <algorithm>
+#include <chrono>
+#include <limits>
+#include <memory>
+#include <optional>
+#include <ostream>
+#include <sstream>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include "OperandTypes.h"
+#include "OperationTypes.h"
+#include "Result.h"
+#include "Types.h"
+
+namespace android::nn {
+namespace {
+
+template <typename Type>
+constexpr std::underlying_type_t<Type> underlyingType(Type object) {
+ return static_cast<std::underlying_type_t<Type>>(object);
+}
+
+uint16_t getExtensionPrefix(uint32_t type) {
+ return static_cast<uint16_t>(type >> kExtensionTypeBits);
+}
+
+template <typename Type>
+std::ostream& operator<<(std::ostream& os, const std::vector<Type>& vec) {
+ constexpr size_t kMaxVectorPrint = 20;
+ os << "[";
+ size_t count = 0;
+ for (const auto& element : vec) {
+ if (count > 0) {
+ os << ", ";
+ }
+ os << element;
+ count++;
+ if (count >= kMaxVectorPrint) {
+ return os << "...]";
+ }
+ }
+ return os << "]";
+}
+
+} // namespace
+
+bool isExtension(OperandType type) {
+ return getExtensionPrefix(underlyingType(type)) != 0;
+}
+
+bool isExtension(OperationType type) {
+ return getExtensionPrefix(underlyingType(type)) != 0;
+}
+
+bool isNonExtensionScalar(OperandType operandType) {
+ CHECK(!isExtension(operandType));
+ switch (operandType) {
+ case OperandType::FLOAT32:
+ case OperandType::INT32:
+ case OperandType::UINT32:
+ case OperandType::BOOL:
+ case OperandType::FLOAT16:
+ case OperandType::SUBGRAPH:
+ case OperandType::OEM:
+ return true;
+ case OperandType::TENSOR_FLOAT32:
+ case OperandType::TENSOR_INT32:
+ case OperandType::TENSOR_QUANT8_ASYMM:
+ case OperandType::TENSOR_QUANT16_SYMM:
+ case OperandType::TENSOR_FLOAT16:
+ case OperandType::TENSOR_BOOL8:
+ case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
+ case OperandType::TENSOR_QUANT16_ASYMM:
+ case OperandType::TENSOR_QUANT8_SYMM:
+ case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
+ case OperandType::TENSOR_OEM_BYTE:
+ return false;
+ }
+ return false;
+}
+
+size_t getNonExtensionSize(OperandType operandType) {
+ CHECK(!isExtension(operandType));
+ switch (operandType) {
+ case OperandType::SUBGRAPH:
+ case OperandType::OEM:
+ case OperandType::TENSOR_OEM_BYTE:
+ return 0;
+ case OperandType::TENSOR_QUANT8_ASYMM:
+ case OperandType::BOOL:
+ case OperandType::TENSOR_BOOL8:
+ case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
+ case OperandType::TENSOR_QUANT8_SYMM:
+ case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
+ return 1;
+ case OperandType::TENSOR_QUANT16_SYMM:
+ case OperandType::TENSOR_FLOAT16:
+ case OperandType::FLOAT16:
+ case OperandType::TENSOR_QUANT16_ASYMM:
+ return 2;
+ case OperandType::FLOAT32:
+ case OperandType::INT32:
+ case OperandType::UINT32:
+ case OperandType::TENSOR_FLOAT32:
+ case OperandType::TENSOR_INT32:
+ return 4;
+ }
+ return 0;
+}
+
+std::optional<size_t> getNonExtensionSize(OperandType operandType, const Dimensions& dimensions) {
+ CHECK(!isExtension(operandType)) << "Size of extension operand data is unknown";
+ size_t size = getNonExtensionSize(operandType);
+ if (isNonExtensionScalar(operandType)) {
+ return size;
+ } else if (dimensions.empty()) {
+ return 0;
+ }
+ for (Dimension dimension : dimensions) {
+ if (dimension != 0 && size > std::numeric_limits<size_t>::max() / dimension) {
+ return std::nullopt;
+ }
+ size *= dimension;
+ }
+ return size;
+}
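+
+// For example (illustrative): a TENSOR_FLOAT32 operand with dimensions {2, 3}
+// occupies 2 * 3 * 4 = 24 bytes:
+//
+//     getNonExtensionSize(OperandType::TENSOR_FLOAT32, {2, 3}); // yields 24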
+
+std::optional<size_t> getNonExtensionSize(const Operand& operand) {
+ return getNonExtensionSize(operand.type, operand.dimensions);
+}
+
+size_t getOffsetFromInts(int lower, int higher) {
+ const int32_t lowBits = static_cast<int32_t>(lower);
+ const int32_t highBits = static_cast<int32_t>(higher);
+ const uint32_t lowOffsetBits = *reinterpret_cast<const uint32_t*>(&lowBits);
+ const uint32_t highOffsetBits = *reinterpret_cast<const uint32_t*>(&highBits);
+ const uint64_t offset = lowOffsetBits | (static_cast<uint64_t>(highOffsetBits) << 32);
+ return offset;
+}
+
+std::pair<int32_t, int32_t> getIntsFromOffset(size_t offset) {
+ const uint64_t bits = static_cast<uint64_t>(offset);
+ const uint32_t lowBits = static_cast<uint32_t>(bits & 0xffffffff);
+ const uint32_t highBits = static_cast<uint32_t>(bits >> 32);
+ const int32_t lowOffsetBits = *reinterpret_cast<const int32_t*>(&lowBits);
+ const int32_t highOffsetBits = *reinterpret_cast<const int32_t*>(&highBits);
+ return std::make_pair(lowOffsetBits, highOffsetBits);
+}
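+
+// Illustrative round trip (not part of this change): a 64-bit offset packed
+// into two native_handle ints and recovered unchanged:
+//
+//     const auto [low, high] = getIntsFromOffset(size_t{0x100000010});
+//     CHECK_EQ(getOffsetFromInts(low, high), size_t{0x100000010});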
+
+std::vector<uint32_t> countNumberOfConsumers(size_t numberOfOperands,
+ const std::vector<nn::Operation>& operations) {
+ std::vector<uint32_t> numberOfConsumers(numberOfOperands, 0);
+ auto eachOperandIndex = [&numberOfConsumers](uint32_t operandIndex) {
+ numberOfConsumers.at(operandIndex)++;
+ };
+ auto eachOperation = [&eachOperandIndex](const nn::Operation& operation) {
+ std::for_each(operation.inputs.begin(), operation.inputs.end(), eachOperandIndex);
+ };
+ std::for_each(operations.begin(), operations.end(), eachOperation);
+ return numberOfConsumers;
+}
+
+Result<Dimensions> combineDimensions(const Dimensions& lhs, const Dimensions& rhs) {
+ if (rhs.empty()) return lhs;
+ if (lhs.empty()) return rhs;
+ if (lhs.size() != rhs.size()) {
+ std::ostringstream os;
+ os << "Incompatible ranks: " << lhs << " and " << rhs;
+ return NN_ERROR() << os.str();
+ }
+ Dimensions combined = lhs;
+ for (size_t i = 0; i < lhs.size(); i++) {
+ if (lhs[i] == 0) {
+ combined[i] = rhs[i];
+ } else if (rhs[i] != 0 && lhs[i] != rhs[i]) {
+ std::ostringstream os;
+ os << "Incompatible dimensions: " << lhs << " and " << rhs;
+ return NN_ERROR() << os.str();
+ }
+ }
+ return combined;
+}
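+
+// Illustrative examples (not part of this change): a zero dimension acts as
+// "unknown" and is filled in from the other operand; known dimensions must match:
+//
+//     combineDimensions({2, 0, 5}, {2, 3, 0}); // yields {2, 3, 5}
+//     combineDimensions({2, 4}, {2, 3}); // error: incompatible dimensions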
+
+std::ostream& operator<<(std::ostream& os, const DeviceStatus& deviceStatus) {
+ switch (deviceStatus) {
+ case DeviceStatus::AVAILABLE:
+ return os << "AVAILABLE";
+ case DeviceStatus::BUSY:
+ return os << "BUSY";
+ case DeviceStatus::OFFLINE:
+ return os << "OFFLINE";
+ case DeviceStatus::UNKNOWN:
+ return os << "UNKNOWN";
+ }
+ return os << "DeviceStatus{" << underlyingType(deviceStatus) << "}";
+}
+
+std::ostream& operator<<(std::ostream& os, const ExecutionPreference& executionPreference) {
+ switch (executionPreference) {
+ case ExecutionPreference::LOW_POWER:
+ return os << "LOW_POWER";
+ case ExecutionPreference::FAST_SINGLE_ANSWER:
+ return os << "FAST_SINGLE_ANSWER";
+ case ExecutionPreference::SUSTAINED_SPEED:
+ return os << "SUSTAINED_SPEED";
+ }
+ return os << "ExecutionPreference{" << underlyingType(executionPreference) << "}";
+}
+
+std::ostream& operator<<(std::ostream& os, const DeviceType& deviceType) {
+ switch (deviceType) {
+ case DeviceType::UNKNOWN:
+ return os << "UNKNOWN";
+ case DeviceType::OTHER:
+ return os << "OTHER";
+ case DeviceType::CPU:
+ return os << "CPU";
+ case DeviceType::GPU:
+ return os << "GPU";
+ case DeviceType::ACCELERATOR:
+ return os << "ACCELERATOR";
+ }
+ return os << "DeviceType{" << underlyingType(deviceType) << "}";
+}
+
+std::ostream& operator<<(std::ostream& os, const MeasureTiming& measureTiming) {
+ switch (measureTiming) {
+ case MeasureTiming::NO:
+ return os << "NO";
+ case MeasureTiming::YES:
+ return os << "YES";
+ }
+ return os << "MeasureTiming{" << underlyingType(measureTiming) << "}";
+}
+
+std::ostream& operator<<(std::ostream& os, const OperandType& operandType) {
+ switch (operandType) {
+ case OperandType::FLOAT32:
+ return os << "FLOAT32";
+ case OperandType::INT32:
+ return os << "INT32";
+ case OperandType::UINT32:
+ return os << "UINT32";
+ case OperandType::TENSOR_FLOAT32:
+ return os << "TENSOR_FLOAT32";
+ case OperandType::TENSOR_INT32:
+ return os << "TENSOR_INT32";
+ case OperandType::TENSOR_QUANT8_ASYMM:
+ return os << "TENSOR_QUANT8_ASYMM";
+ case OperandType::BOOL:
+ return os << "BOOL";
+ case OperandType::TENSOR_QUANT16_SYMM:
+ return os << "TENSOR_QUANT16_SYMM";
+ case OperandType::TENSOR_FLOAT16:
+ return os << "TENSOR_FLOAT16";
+ case OperandType::TENSOR_BOOL8:
+ return os << "TENSOR_BOOL8";
+ case OperandType::FLOAT16:
+ return os << "FLOAT16";
+ case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
+ return os << "TENSOR_QUANT8_SYMM_PER_CHANNEL";
+ case OperandType::TENSOR_QUANT16_ASYMM:
+ return os << "TENSOR_QUANT16_ASYMM";
+ case OperandType::TENSOR_QUANT8_SYMM:
+ return os << "TENSOR_QUANT8_SYMM";
+ case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
+ return os << "TENSOR_QUANT8_ASYMM_SIGNED";
+ case OperandType::SUBGRAPH:
+ return os << "SUBGRAPH";
+ case OperandType::OEM:
+ return os << "OEM";
+ case OperandType::TENSOR_OEM_BYTE:
+ return os << "TENSOR_OEM_BYTE";
+ }
+ if (isExtension(operandType)) {
+ return os << "Extension OperandType " << underlyingType(operandType);
+ }
+ return os << "OperandType{" << underlyingType(operandType) << "}";
+}
+
+std::ostream& operator<<(std::ostream& os, const Operand::LifeTime& lifetime) {
+ switch (lifetime) {
+ case Operand::LifeTime::TEMPORARY_VARIABLE:
+ return os << "TEMPORARY_VARIABLE";
+ case Operand::LifeTime::SUBGRAPH_INPUT:
+ return os << "SUBGRAPH_INPUT";
+ case Operand::LifeTime::SUBGRAPH_OUTPUT:
+ return os << "SUBGRAPH_OUTPUT";
+ case Operand::LifeTime::CONSTANT_COPY:
+ return os << "CONSTANT_COPY";
+ case Operand::LifeTime::CONSTANT_REFERENCE:
+ return os << "CONSTANT_REFERENCE";
+ case Operand::LifeTime::NO_VALUE:
+ return os << "NO_VALUE";
+ case Operand::LifeTime::SUBGRAPH:
+ return os << "SUBGRAPH";
+ case Operand::LifeTime::POINTER:
+ return os << "POINTER";
+ }
+ return os << "Operand::LifeTime{" << underlyingType(lifetime) << "}";
+}
+
+std::ostream& operator<<(std::ostream& os, const OperationType& operationType) {
+ switch (operationType) {
+ case OperationType::ADD:
+ return os << "ADD";
+ case OperationType::AVERAGE_POOL_2D:
+ return os << "AVERAGE_POOL_2D";
+ case OperationType::CONCATENATION:
+ return os << "CONCATENATION";
+ case OperationType::CONV_2D:
+ return os << "CONV_2D";
+ case OperationType::DEPTHWISE_CONV_2D:
+ return os << "DEPTHWISE_CONV_2D";
+ case OperationType::DEPTH_TO_SPACE:
+ return os << "DEPTH_TO_SPACE";
+ case OperationType::DEQUANTIZE:
+ return os << "DEQUANTIZE";
+ case OperationType::EMBEDDING_LOOKUP:
+ return os << "EMBEDDING_LOOKUP";
+ case OperationType::FLOOR:
+ return os << "FLOOR";
+ case OperationType::FULLY_CONNECTED:
+ return os << "FULLY_CONNECTED";
+ case OperationType::HASHTABLE_LOOKUP:
+ return os << "HASHTABLE_LOOKUP";
+ case OperationType::L2_NORMALIZATION:
+ return os << "L2_NORMALIZATION";
+ case OperationType::L2_POOL_2D:
+ return os << "L2_POOL_2D";
+ case OperationType::LOCAL_RESPONSE_NORMALIZATION:
+ return os << "LOCAL_RESPONSE_NORMALIZATION";
+ case OperationType::LOGISTIC:
+ return os << "LOGISTIC";
+ case OperationType::LSH_PROJECTION:
+ return os << "LSH_PROJECTION";
+ case OperationType::LSTM:
+ return os << "LSTM";
+ case OperationType::MAX_POOL_2D:
+ return os << "MAX_POOL_2D";
+ case OperationType::MUL:
+ return os << "MUL";
+ case OperationType::RELU:
+ return os << "RELU";
+ case OperationType::RELU1:
+ return os << "RELU1";
+ case OperationType::RELU6:
+ return os << "RELU6";
+ case OperationType::RESHAPE:
+ return os << "RESHAPE";
+ case OperationType::RESIZE_BILINEAR:
+ return os << "RESIZE_BILINEAR";
+ case OperationType::RNN:
+ return os << "RNN";
+ case OperationType::SOFTMAX:
+ return os << "SOFTMAX";
+ case OperationType::SPACE_TO_DEPTH:
+ return os << "SPACE_TO_DEPTH";
+ case OperationType::SVDF:
+ return os << "SVDF";
+ case OperationType::TANH:
+ return os << "TANH";
+ case OperationType::BATCH_TO_SPACE_ND:
+ return os << "BATCH_TO_SPACE_ND";
+ case OperationType::DIV:
+ return os << "DIV";
+ case OperationType::MEAN:
+ return os << "MEAN";
+ case OperationType::PAD:
+ return os << "PAD";
+ case OperationType::SPACE_TO_BATCH_ND:
+ return os << "SPACE_TO_BATCH_ND";
+ case OperationType::SQUEEZE:
+ return os << "SQUEEZE";
+ case OperationType::STRIDED_SLICE:
+ return os << "STRIDED_SLICE";
+ case OperationType::SUB:
+ return os << "SUB";
+ case OperationType::TRANSPOSE:
+ return os << "TRANSPOSE";
+ case OperationType::ABS:
+ return os << "ABS";
+ case OperationType::ARGMAX:
+ return os << "ARGMAX";
+ case OperationType::ARGMIN:
+ return os << "ARGMIN";
+ case OperationType::AXIS_ALIGNED_BBOX_TRANSFORM:
+ return os << "AXIS_ALIGNED_BBOX_TRANSFORM";
+ case OperationType::BIDIRECTIONAL_SEQUENCE_LSTM:
+ return os << "BIDIRECTIONAL_SEQUENCE_LSTM";
+ case OperationType::BIDIRECTIONAL_SEQUENCE_RNN:
+ return os << "BIDIRECTIONAL_SEQUENCE_RNN";
+ case OperationType::BOX_WITH_NMS_LIMIT:
+ return os << "BOX_WITH_NMS_LIMIT";
+ case OperationType::CAST:
+ return os << "CAST";
+ case OperationType::CHANNEL_SHUFFLE:
+ return os << "CHANNEL_SHUFFLE";
+ case OperationType::DETECTION_POSTPROCESSING:
+ return os << "DETECTION_POSTPROCESSING";
+ case OperationType::EQUAL:
+ return os << "EQUAL";
+ case OperationType::EXP:
+ return os << "EXP";
+ case OperationType::EXPAND_DIMS:
+ return os << "EXPAND_DIMS";
+ case OperationType::GATHER:
+ return os << "GATHER";
+ case OperationType::GENERATE_PROPOSALS:
+ return os << "GENERATE_PROPOSALS";
+ case OperationType::GREATER:
+ return os << "GREATER";
+ case OperationType::GREATER_EQUAL:
+ return os << "GREATER_EQUAL";
+ case OperationType::GROUPED_CONV_2D:
+ return os << "GROUPED_CONV_2D";
+ case OperationType::HEATMAP_MAX_KEYPOINT:
+ return os << "HEATMAP_MAX_KEYPOINT";
+ case OperationType::INSTANCE_NORMALIZATION:
+ return os << "INSTANCE_NORMALIZATION";
+ case OperationType::LESS:
+ return os << "LESS";
+ case OperationType::LESS_EQUAL:
+ return os << "LESS_EQUAL";
+ case OperationType::LOG:
+ return os << "LOG";
+ case OperationType::LOGICAL_AND:
+ return os << "LOGICAL_AND";
+ case OperationType::LOGICAL_NOT:
+ return os << "LOGICAL_NOT";
+ case OperationType::LOGICAL_OR:
+ return os << "LOGICAL_OR";
+ case OperationType::LOG_SOFTMAX:
+ return os << "LOG_SOFTMAX";
+ case OperationType::MAXIMUM:
+ return os << "MAXIMUM";
+ case OperationType::MINIMUM:
+ return os << "MINIMUM";
+ case OperationType::NEG:
+ return os << "NEG";
+ case OperationType::NOT_EQUAL:
+ return os << "NOT_EQUAL";
+ case OperationType::PAD_V2:
+ return os << "PAD_V2";
+ case OperationType::POW:
+ return os << "POW";
+ case OperationType::PRELU:
+ return os << "PRELU";
+ case OperationType::QUANTIZE:
+ return os << "QUANTIZE";
+ case OperationType::QUANTIZED_16BIT_LSTM:
+ return os << "QUANTIZED_16BIT_LSTM";
+ case OperationType::RANDOM_MULTINOMIAL:
+ return os << "RANDOM_MULTINOMIAL";
+ case OperationType::REDUCE_ALL:
+ return os << "REDUCE_ALL";
+ case OperationType::REDUCE_ANY:
+ return os << "REDUCE_ANY";
+ case OperationType::REDUCE_MAX:
+ return os << "REDUCE_MAX";
+ case OperationType::REDUCE_MIN:
+ return os << "REDUCE_MIN";
+ case OperationType::REDUCE_PROD:
+ return os << "REDUCE_PROD";
+ case OperationType::REDUCE_SUM:
+ return os << "REDUCE_SUM";
+ case OperationType::ROI_ALIGN:
+ return os << "ROI_ALIGN";
+ case OperationType::ROI_POOLING:
+ return os << "ROI_POOLING";
+ case OperationType::RSQRT:
+ return os << "RSQRT";
+ case OperationType::SELECT:
+ return os << "SELECT";
+ case OperationType::SIN:
+ return os << "SIN";
+ case OperationType::SLICE:
+ return os << "SLICE";
+ case OperationType::SPLIT:
+ return os << "SPLIT";
+ case OperationType::SQRT:
+ return os << "SQRT";
+ case OperationType::TILE:
+ return os << "TILE";
+ case OperationType::TOPK_V2:
+ return os << "TOPK_V2";
+ case OperationType::TRANSPOSE_CONV_2D:
+ return os << "TRANSPOSE_CONV_2D";
+ case OperationType::UNIDIRECTIONAL_SEQUENCE_LSTM:
+ return os << "UNIDIRECTIONAL_SEQUENCE_LSTM";
+ case OperationType::UNIDIRECTIONAL_SEQUENCE_RNN:
+ return os << "UNIDIRECTIONAL_SEQUENCE_RNN";
+ case OperationType::RESIZE_NEAREST_NEIGHBOR:
+ return os << "RESIZE_NEAREST_NEIGHBOR";
+ case OperationType::QUANTIZED_LSTM:
+ return os << "QUANTIZED_LSTM";
+ case OperationType::IF:
+ return os << "IF";
+ case OperationType::WHILE:
+ return os << "WHILE";
+ case OperationType::ELU:
+ return os << "ELU";
+ case OperationType::HARD_SWISH:
+ return os << "HARD_SWISH";
+ case OperationType::FILL:
+ return os << "FILL";
+ case OperationType::RANK:
+ return os << "RANK";
+ case OperationType::OEM_OPERATION:
+ return os << "OEM_OPERATION";
+ }
+ if (isExtension(operationType)) {
+ return os << "Extension OperationType " << underlyingType(operationType);
+ }
+ return os << "OperationType{" << underlyingType(operationType) << "}";
+}
+
+std::ostream& operator<<(std::ostream& os, const Request::Argument::LifeTime& lifetime) {
+ switch (lifetime) {
+ case Request::Argument::LifeTime::POOL:
+ return os << "POOL";
+ case Request::Argument::LifeTime::NO_VALUE:
+ return os << "NO_VALUE";
+ case Request::Argument::LifeTime::POINTER:
+ return os << "POINTER";
+ }
+ return os << "Request::Argument::LifeTime{" << underlyingType(lifetime) << "}";
+}
+
+std::ostream& operator<<(std::ostream& os, const Priority& priority) {
+ switch (priority) {
+ case Priority::LOW:
+ return os << "LOW";
+ case Priority::MEDIUM:
+ return os << "MEDIUM";
+ case Priority::HIGH:
+ return os << "HIGH";
+ }
+ return os << "Priority{" << underlyingType(priority) << "}";
+}
+
+std::ostream& operator<<(std::ostream& os, const ErrorStatus& errorStatus) {
+ switch (errorStatus) {
+ case ErrorStatus::NONE:
+ return os << "NONE";
+ case ErrorStatus::DEVICE_UNAVAILABLE:
+ return os << "DEVICE_UNAVAILABLE";
+ case ErrorStatus::GENERAL_FAILURE:
+ return os << "GENERAL_FAILURE";
+ case ErrorStatus::OUTPUT_INSUFFICIENT_SIZE:
+ return os << "OUTPUT_INSUFFICIENT_SIZE";
+ case ErrorStatus::INVALID_ARGUMENT:
+ return os << "INVALID_ARGUMENT";
+ case ErrorStatus::MISSED_DEADLINE_TRANSIENT:
+ return os << "MISSED_DEADLINE_TRANSIENT";
+ case ErrorStatus::MISSED_DEADLINE_PERSISTENT:
+ return os << "MISSED_DEADLINE_PERSISTENT";
+ case ErrorStatus::RESOURCE_EXHAUSTED_TRANSIENT:
+ return os << "RESOURCE_EXHAUSTED_TRANSIENT";
+ case ErrorStatus::RESOURCE_EXHAUSTED_PERSISTENT:
+ return os << "RESOURCE_EXHAUSTED_PERSISTENT";
+ case ErrorStatus::DEAD_OBJECT:
+ return os << "DEAD_OBJECT";
+ }
+ return os << "ErrorStatus{" << underlyingType(errorStatus) << "}";
+}
+
+std::ostream& operator<<(std::ostream& os, const OutputShape& outputShape) {
+ return os << "OutputShape{.dimensions=" << outputShape.dimensions
+ << ", .isSufficient=" << (outputShape.isSufficient ? "true" : "false") << "}";
+}
+
+std::ostream& operator<<(std::ostream& os, const Timing& timing) {
+ constexpr auto printTime = [](std::ostream& os, uint64_t nanoseconds) {
+ if (nanoseconds == std::numeric_limits<uint64_t>::max()) {
+ os << "<no time information provided>";
+ } else {
+ os << nanoseconds << "ns";
+ }
+ };
+ os << "Timing{.timeOnDevice=";
+ printTime(os, timing.timeOnDevice);
+ os << ", .timeInDriver=";
+ printTime(os, timing.timeInDriver);
+ return os << "}";
+}
+
+std::ostream& operator<<(std::ostream& os, const Capabilities::PerformanceInfo& performanceInfo) {
+ return os << "Capabilities::PerformanceInfo{.execTime=" << performanceInfo.execTime
+ << ", .powerUsage=" << performanceInfo.powerUsage << "}";
+}
+
+std::ostream& operator<<(std::ostream& os,
+ const Capabilities::OperandPerformance& operandPerformance) {
+ return os << "Capabilities::OperandPerformance{.type=" << operandPerformance.type
+ << ", .info=" << operandPerformance.info << "}";
+}
+
+std::ostream& operator<<(std::ostream& os,
+ const Capabilities::OperandPerformanceTable& operandPerformances) {
+ return os << operandPerformances.asVector();
+}
+
+std::ostream& operator<<(std::ostream& os, const Capabilities& capabilities) {
+ return os << "Capabilities{.relaxedFloat32toFloat16PerformanceScalar="
+ << capabilities.relaxedFloat32toFloat16PerformanceScalar
+ << ", .relaxedFloat32toFloat16PerformanceTensor="
+ << capabilities.relaxedFloat32toFloat16PerformanceTensor
+ << ", .operandPerformance=" << capabilities.operandPerformance
+ << ", .ifPerformance=" << capabilities.ifPerformance
+ << ", .whilePerformance=" << capabilities.whilePerformance << "}";
+}
+
+std::ostream& operator<<(std::ostream& os,
+ const Extension::OperandTypeInformation& operandTypeInformation) {
+ return os << "Extension::OperandTypeInformation{.type=" << operandTypeInformation.type
+ << ", .isTensor=" << (operandTypeInformation.isTensor ? "true" : "false")
+ << ", .byteSize=" << operandTypeInformation.byteSize << "}";
+}
+
+std::ostream& operator<<(std::ostream& os, const Extension& extension) {
+ return os << "Extension{.name=" << extension.name
+ << ", .operandTypes=" << extension.operandTypes << "}";
+}
+
+std::ostream& operator<<(std::ostream& os, const DataLocation& location) {
+ const auto printPointer = [&os](const std::variant<const void*, void*>& pointer) {
+ os << (std::holds_alternative<const void*>(pointer) ? "<constant " : "<mutable ");
+ os << std::visit(
+ [](const auto* ptr) {
+ return ptr == nullptr ? "null pointer>" : "non-null pointer>";
+ },
+ pointer);
+ };
+ os << "DataLocation{.pointer=";
+ printPointer(location.pointer);
+ return os << ", .poolIndex=" << location.poolIndex << ", .offset=" << location.offset
+ << ", .length=" << location.length << "}";
+}
+
+std::ostream& operator<<(std::ostream& os,
+ const Operand::SymmPerChannelQuantParams& symmPerChannelQuantParams) {
+ return os << "Operand::SymmPerChannelQuantParams{.scales=" << symmPerChannelQuantParams.scales
+ << ", .channelDim=" << symmPerChannelQuantParams.channelDim << "}";
+}
+
+std::ostream& operator<<(std::ostream& os, const Operand::ExtraParams& extraParams) {
+ os << "Operand::ExtraParams{";
+ if (std::holds_alternative<Operand::NoParams>(extraParams)) {
+ os << "<no params>";
+ } else if (std::holds_alternative<Operand::SymmPerChannelQuantParams>(extraParams)) {
+ os << std::get<Operand::SymmPerChannelQuantParams>(extraParams);
+ } else if (std::holds_alternative<Operand::ExtensionParams>(extraParams)) {
+ os << std::get<Operand::ExtensionParams>(extraParams);
+ }
+ return os << "}";
+}
+
+std::ostream& operator<<(std::ostream& os, const Operand& operand) {
+ return os << "Operand{.type=" << operand.type << ", .dimensions=" << operand.dimensions
+ << ", .scale=" << operand.scale << ", .zeroPoint=" << operand.zeroPoint
+ << ", lifetime=" << operand.lifetime << ", .location=" << operand.location
+ << ", .extraParams=" << operand.extraParams << "}";
+}
+
+std::ostream& operator<<(std::ostream& os, const Operation& operation) {
+ return os << "Operation{.type=" << operation.type << ", .inputs=" << operation.inputs
+ << ", .outputs=" << operation.outputs << "}";
+}
+
+std::ostream& operator<<(std::ostream& os, const NativeHandle& handle) {
+ return os << (handle != nullptr ? "<non-empty handle>" : "<empty handle>");
+}
+
+std::ostream& operator<<(std::ostream& os, const Memory& memory) {
+ return os << "Memory{.handle=" << memory.handle << ", .size=" << memory.size
+ << ", .name=" << memory.name << "}";
+}
+
+std::ostream& operator<<(std::ostream& os, const Model::Subgraph& subgraph) {
+ return os << "Model::Subgraph{.operands=" << subgraph.operands
+ << ", .operations=" << subgraph.operations
+ << ", .inputIndexes=" << subgraph.inputIndexes
+ << ", .outputIndexes=" << subgraph.outputIndexes << "}";
+}
+
+std::ostream& operator<<(std::ostream& os, const Model::OperandValues& operandValues) {
+ return os << "Model::OperandValues{<" << operandValues.size() << "bytes>}";
+}
+
+std::ostream& operator<<(std::ostream& os,
+ const Model::ExtensionNameAndPrefix& extensionNameAndPrefix) {
+ return os << "Model::ExtensionNameAndPrefix{.name=" << extensionNameAndPrefix.name
+ << ", .prefix=" << extensionNameAndPrefix.prefix << "}";
+}
+
+std::ostream& operator<<(std::ostream& os, const Model& model) {
+ return os << "Model{.main=" << model.main << ", .referenced=" << model.referenced
+ << ", .operandValues=" << model.operandValues << ", .pools=" << model.pools
+ << ", .relaxComputationFloat32toFloat16="
+ << (model.relaxComputationFloat32toFloat16 ? "true" : "false")
+ << ", extensionNameToPrefix=" << model.extensionNameToPrefix << "}";
+}
+
+std::ostream& operator<<(std::ostream& os, const BufferDesc& bufferDesc) {
+ return os << "BufferDesc{.dimensions=" << bufferDesc.dimensions << "}";
+}
+
+std::ostream& operator<<(std::ostream& os, const BufferRole& bufferRole) {
+ return os << "BufferRole{.modelIndex=" << bufferRole.modelIndex
+ << ", .ioIndex=" << bufferRole.ioIndex << ", .frequency=" << bufferRole.frequency
+ << "}";
+}
+
+std::ostream& operator<<(std::ostream& os, const Request::Argument& requestArgument) {
+ return os << "Request::Argument{.lifetime=" << requestArgument.lifetime
+ << ", .location=" << requestArgument.location
+ << ", .dimensions=" << requestArgument.dimensions << "}";
+}
+
+std::ostream& operator<<(std::ostream& os, const Request::MemoryPool& memoryPool) {
+ os << "Request::MemoryPool{";
+ if (std::holds_alternative<Memory>(memoryPool)) {
+ os << std::get<Memory>(memoryPool);
+ } else if (std::holds_alternative<Request::MemoryDomainToken>(memoryPool)) {
+ const auto& token = std::get<Request::MemoryDomainToken>(memoryPool);
+ if (token == Request::MemoryDomainToken{}) {
+ os << "<invalid MemoryDomainToken>";
+ } else {
+ os << "MemoryDomainToken=" << underlyingType(token);
+ }
+ } else if (std::holds_alternative<std::shared_ptr<const IBuffer>>(memoryPool)) {
+ const auto& buffer = std::get<std::shared_ptr<const IBuffer>>(memoryPool);
+ os << (buffer != nullptr ? "<non-null IBuffer>" : "<null IBuffer>");
+ }
+ return os << "}";
+}
+
+std::ostream& operator<<(std::ostream& os, const Request& request) {
+ return os << "Request{.inputs=" << request.inputs << ", .outputs=" << request.outputs
+ << ", .pools=" << request.pools << "}";
+}
+
+std::ostream& operator<<(std::ostream& os, const TimePoint& timePoint) {
+ return os << timePoint.time_since_epoch().count() << "ns since epoch";
+}
+
+std::ostream& operator<<(std::ostream& os, const OptionalTimePoint& optionalTimePoint) {
+ if (!optionalTimePoint.has_value()) {
+ return os << "<no time point>";
+ }
+ return os << optionalTimePoint.value();
+}
+
+std::ostream& operator<<(std::ostream& os, const TimeoutDuration& timeoutDuration) {
+ return os << timeoutDuration.count() << "ns";
+}
+
+std::ostream& operator<<(std::ostream& os, const OptionalTimeoutDuration& optionalTimeoutDuration) {
+ if (!optionalTimeoutDuration.has_value()) {
+ return os << "<no timeout duration>";
+ }
+ return os << optionalTimeoutDuration.value();
+}
+
+std::ostream& operator<<(std::ostream& os, const Version& version) {
+ switch (version) {
+ case Version::ANDROID_OC_MR1:
+ return os << "ANDROID_OC_MR1";
+ case Version::ANDROID_P:
+ return os << "ANDROID_P";
+ case Version::ANDROID_Q:
+ return os << "ANDROID_Q";
+ case Version::ANDROID_R:
+ return os << "ANDROID_R";
+ case Version::CURRENT_RUNTIME:
+ return os << "CURRENT_RUNTIME";
+ }
+ return os << "Version{" << underlyingType(version) << "}";
+}
+
+bool operator==(const Timing& a, const Timing& b) {
+ return a.timeOnDevice == b.timeOnDevice && a.timeInDriver == b.timeInDriver;
+}
+
+bool operator!=(const Timing& a, const Timing& b) {
+ return !(a == b);
+}
+
+bool operator==(const Capabilities::PerformanceInfo& a, const Capabilities::PerformanceInfo& b) {
+ return a.execTime == b.execTime && a.powerUsage == b.powerUsage;
+}
+
+bool operator==(const Capabilities::OperandPerformance& a,
+ const Capabilities::OperandPerformance& b) {
+ return a.type == b.type && a.info == b.info;
+}
+
+bool operator==(const Capabilities& a, const Capabilities& b) {
+ return a.relaxedFloat32toFloat16PerformanceScalar ==
+ b.relaxedFloat32toFloat16PerformanceScalar &&
+ a.relaxedFloat32toFloat16PerformanceTensor ==
+ b.relaxedFloat32toFloat16PerformanceTensor &&
+ a.operandPerformance.asVector() == b.operandPerformance.asVector() &&
+ a.ifPerformance == b.ifPerformance && a.whilePerformance == b.whilePerformance;
+}
+
+bool operator==(const Extension::OperandTypeInformation& a,
+ const Extension::OperandTypeInformation& b) {
+ return a.type == b.type && a.isTensor == b.isTensor && a.byteSize == b.byteSize;
+}
+
+bool operator==(const Extension& a, const Extension& b) {
+ return a.name == b.name && a.operandTypes == b.operandTypes;
+}
+
+bool operator==(const Operand::SymmPerChannelQuantParams& a,
+ const Operand::SymmPerChannelQuantParams& b) {
+ return a.scales == b.scales && a.channelDim == b.channelDim;
+}
+bool operator!=(const Operand::SymmPerChannelQuantParams& a,
+ const Operand::SymmPerChannelQuantParams& b) {
+ return !(a == b);
+}
+
+} // namespace android::nn
diff --git a/nn/common/Types.cpp b/nn/common/Types.cpp
index a49a8dc35..761ffa72c 100644
--- a/nn/common/Types.cpp
+++ b/nn/common/Types.cpp
@@ -28,6 +28,7 @@
#include "OperandTypes.h"
#include "OperationTypes.h"
+#include "Result.h"
namespace android::nn {
namespace {
@@ -87,15 +88,14 @@ Capabilities::OperandPerformanceTable::OperandPerformanceTable(
std::vector<OperandPerformance> operandPerformances)
: mSorted(std::move(operandPerformances)) {}
-std::optional<Capabilities::OperandPerformanceTable> Capabilities::OperandPerformanceTable::create(
+Result<Capabilities::OperandPerformanceTable> Capabilities::OperandPerformanceTable::create(
std::vector<OperandPerformance> operandPerformances) {
const auto notUnique = [](const auto& lhs, const auto& rhs) { return !(lhs.type < rhs.type); };
const bool isUnique = std::adjacent_find(operandPerformances.begin(), operandPerformances.end(),
notUnique) == operandPerformances.end();
if (!isUnique) {
- LOG(ERROR) << "Failed to create OperandPerformanceTable: Input must be sorted by key (in "
- "ascending order), and there must be no duplicate keys";
- return std::nullopt;
+ return NN_ERROR() << "Failed to create OperandPerformanceTable: Input must be sorted by "
+ "key (in ascending order), and there must be no duplicate keys";
}
return Capabilities::OperandPerformanceTable(std::move(operandPerformances));
diff --git a/nn/common/Validation.cpp b/nn/common/Validation.cpp
new file mode 100644
index 000000000..d18ba3443
--- /dev/null
+++ b/nn/common/Validation.cpp
@@ -0,0 +1,2664 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Validation.h"
+
+#include <android-base/logging.h>
+
+#include <algorithm>
+#include <cctype>
+#include <functional>
+#include <limits>
+#include <memory>
+#include <numeric>
+#include <set>
+#include <sstream>
+#include <string>
+#include <string_view>
+#include <tuple>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "ControlFlow.h"
+#include "OperandTypes.h"
+#include "OperationTypes.h"
+#include "Result.h"
+#include "TypeUtils.h"
+#include "Types.h"
+
+// The NN_VALIDATE family of macros defined below is similar to the CHECK family defined in
+// system/core/base/include/android-base/logging.h
+//
+// The difference is that NN_VALIDATE macros use LOG(ERROR) instead of LOG(FATAL)
+// and return false instead of aborting.
+
+// Logs an error and returns false or INVALID. Append context using << after. For example:
+//
+// NN_VALIDATE_FAIL() << "Something went wrong";
+//
+// The containing function must return a bool or Version.
+#define NN_VALIDATE_FAIL() \
+ return NN_ERROR() << "NN_VALIDATE failed (" << __FILE__ << ":" << __LINE__ << "): "
+
+// Logs an error and returns false or Version::INVALID if condition is false. Extra logging can be
+// appended using << after. For example:
+//
+// NN_VALIDATE(false) << "Something went wrong";
+//
+// The containing function must return a bool or Version.
+#define NN_VALIDATE(condition) \
+ while (UNLIKELY(!(condition))) NN_VALIDATE_FAIL() << #condition << " "
+
+// Helper for NN_VALIDATE_xx(x, y) macros.
+#define NN_VALIDATE_OP(LHS, RHS, OP) \
+ for (auto _values = ::android::base::MakeEagerEvaluator(LHS, RHS); \
+ UNLIKELY(!(_values.lhs.v OP _values.rhs.v)); \
+ /* empty */) \
+ NN_VALIDATE_FAIL() \
+ << #LHS << " " << #OP << " " << #RHS << " (" << #LHS << " = " \
+ << ::android::base::LogNullGuard<decltype(_values.lhs.v)>::Guard(_values.lhs.v) \
+ << ", " << #RHS << " = " \
+ << ::android::base::LogNullGuard<decltype(_values.rhs.v)>::Guard(_values.rhs.v) \
+ << ") "
+
+// Logs an error and returns false or Version::INVALID if a condition between x and y does not hold.
+// Extra logging can be appended using << after. For example:
+//
+// NN_VALIDATE_EQ(a, b) << "Something went wrong";
+//
+// The values must implement the appropriate comparison operator as well as
+// `operator<<(std::ostream&, ...)`.
+// The containing function must return a bool or Version.
+#define NN_VALIDATE_EQ(x, y) NN_VALIDATE_OP(x, y, ==)
+#define NN_VALIDATE_NE(x, y) NN_VALIDATE_OP(x, y, !=)
+#define NN_VALIDATE_LE(x, y) NN_VALIDATE_OP(x, y, <=)
+#define NN_VALIDATE_LT(x, y) NN_VALIDATE_OP(x, y, <)
+#define NN_VALIDATE_GE(x, y) NN_VALIDATE_OP(x, y, >=)
+#define NN_VALIDATE_GT(x, y) NN_VALIDATE_OP(x, y, >)
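+
+// Illustrative usage (not part of this change; validateRank is a hypothetical
+// helper): validation functions built from these macros return the minimum
+// required Version on success and an error on failure:
+//
+//     Result<Version> validateRank(const Operand& operand, size_t expectedRank) {
+//         NN_VALIDATE_EQ(operand.dimensions.size(), expectedRank)
+//                 << "Unexpected rank for operand of type " << operand.type;
+//         return Version::ANDROID_OC_MR1;
+//     }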
+
+namespace android::nn {
+namespace {
+
+constexpr auto kNullptrVariant = std::variant<const void*, void*>{};
+constexpr auto kInvalidMemoryDomainToken = Request::MemoryDomainToken{};
+
+template <typename Type, typename ValidationFunction>
+Result<Version> validateVector(const std::vector<Type>& objects,
+ const ValidationFunction& validationFunction) {
+ auto version = Version::ANDROID_OC_MR1;
+ for (const auto& object : objects) {
+ version = combineVersions(version, NN_TRY(validationFunction(object)));
+ }
+ return version;
+}
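+
+// Illustrative usage (not part of this change; outputShapes is a hypothetical
+// std::vector<OutputShape>): validate every element and combine the versions:
+//
+//     const auto version = NN_TRY(validateVector(outputShapes, validateOutputShape));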
+
+bool isValidExtensionName(const std::string& name) {
+ constexpr auto validSymbol = [](char symbol) {
+ return std::islower(symbol) || std::isdigit(symbol) || symbol == '.' || symbol == '_';
+ };
+ const bool hasOnlyValidSymbols = std::all_of(name.begin(), name.end(), validSymbol);
+ const bool hasAtLeastOnePeriod = std::find(name.begin(), name.end(), '.') != name.end();
+ return hasOnlyValidSymbols && hasAtLeastOnePeriod;
+}
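+
+// For example (illustrative): "com.example.my_extension" is a valid extension
+// name, while "MyExtension" is rejected (uppercase letters, no period).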
+
+Result<Version> validateDeviceStatus(const DeviceStatus& deviceStatus) {
+ switch (deviceStatus) {
+ case DeviceStatus::AVAILABLE:
+ case DeviceStatus::BUSY:
+ case DeviceStatus::OFFLINE:
+ case DeviceStatus::UNKNOWN:
+ return Version::ANDROID_OC_MR1;
+ }
+ NN_VALIDATE_FAIL() << "Invalid DeviceStatus " << deviceStatus;
+}
+
+Result<Version> validateExecutionPreference(const ExecutionPreference& executionPreference) {
+ switch (executionPreference) {
+ case ExecutionPreference::LOW_POWER:
+ case ExecutionPreference::FAST_SINGLE_ANSWER:
+ case ExecutionPreference::SUSTAINED_SPEED:
+ return Version::ANDROID_P;
+ }
+ NN_VALIDATE_FAIL() << "Invalid ExecutionPreference " << executionPreference;
+}
+
+Result<Version> validateDeviceType(const DeviceType& deviceType) {
+ switch (deviceType) {
+ case DeviceType::UNKNOWN:
+ // DeviceType was introduced in the 1.2 NN HAL. DeviceType::UNKNOWN is returned when
+ // querying versions that are prior to the 1.2 NN HAL. DeviceType::UNKNOWN is not a
+ // valid code to return for a driver that implements at least a 1.2 NN HAL. If we need a
+ // range of versions, make ANDROID_Q (NN HAL 1.2) the exclusive upper bound for
+ // DeviceType::UNKNOWN.
+ return Version::ANDROID_OC_MR1;
+ case DeviceType::OTHER:
+ case DeviceType::CPU:
+ case DeviceType::GPU:
+ case DeviceType::ACCELERATOR:
+ return Version::ANDROID_Q;
+ }
+ NN_VALIDATE_FAIL() << "Invalid DeviceType " << deviceType;
+}
+
+Result<Version> validateMeasureTiming(const MeasureTiming& measureTiming) {
+ switch (measureTiming) {
+ case MeasureTiming::NO:
+ case MeasureTiming::YES:
+ return Version::ANDROID_Q;
+ }
+ NN_VALIDATE_FAIL() << "Invalid MeasureTiming " << measureTiming;
+}
+
+Result<Version> validateOperandType(const OperandType& operandType) {
+ switch (operandType) {
+ case OperandType::FLOAT32:
+ case OperandType::INT32:
+ case OperandType::UINT32:
+ case OperandType::TENSOR_FLOAT32:
+ case OperandType::TENSOR_INT32:
+ case OperandType::TENSOR_QUANT8_ASYMM:
+ case OperandType::OEM:
+ case OperandType::TENSOR_OEM_BYTE:
+ return Version::ANDROID_OC_MR1;
+ case OperandType::BOOL:
+ case OperandType::TENSOR_QUANT16_SYMM:
+ case OperandType::TENSOR_FLOAT16:
+ case OperandType::TENSOR_BOOL8:
+ case OperandType::FLOAT16:
+ case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
+ case OperandType::TENSOR_QUANT16_ASYMM:
+ case OperandType::TENSOR_QUANT8_SYMM:
+ return Version::ANDROID_Q;
+ case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
+ case OperandType::SUBGRAPH:
+ return Version::ANDROID_R;
+ }
+ if (isExtension(operandType)) {
+ return Version::ANDROID_Q;
+ }
+ NN_VALIDATE_FAIL() << "Invalid OperandType " << operandType;
+}
+
+Result<Version> validateOperandLifeTime(const Operand& operand) {
+ // Make sure SUBGRAPH operand type and lifetime always go together.
+ NN_VALIDATE_EQ((operand.type == OperandType::SUBGRAPH),
+ (operand.lifetime == Operand::LifeTime::SUBGRAPH))
+ << "Operand of type " << operand.type << " cannot have lifetime " << operand.lifetime;
+
+ switch (operand.lifetime) {
+ case Operand::LifeTime::TEMPORARY_VARIABLE:
+ case Operand::LifeTime::SUBGRAPH_INPUT:
+ case Operand::LifeTime::SUBGRAPH_OUTPUT:
+ case Operand::LifeTime::CONSTANT_COPY:
+ case Operand::LifeTime::CONSTANT_REFERENCE:
+ case Operand::LifeTime::NO_VALUE:
+ case Operand::LifeTime::POINTER:
+ return Version::ANDROID_OC_MR1;
+ case Operand::LifeTime::SUBGRAPH:
+ return Version::ANDROID_R;
+ }
+ NN_VALIDATE_FAIL() << "Invalid Operand::LifeTime " << operand.lifetime;
+}
+
+Result<Version> validatePriority(const Priority& priority) {
+ switch (priority) {
+ case Priority::LOW:
+ case Priority::MEDIUM:
+ case Priority::HIGH:
+ return Version::ANDROID_R;
+ }
+ NN_VALIDATE_FAIL() << "Invalid Priority " << priority;
+}
+
+Result<Version> validateErrorStatus(const ErrorStatus& errorStatus) {
+    // Note that MISSED_DEADLINE_*, RESOURCE_EXHAUSTED_*, and DEAD_OBJECT were introduced in
+    // ANDROID_R, but they can be mapped to GENERAL_FAILURE when converting to ANDROID_OC_MR1.
+ switch (errorStatus) {
+ case ErrorStatus::NONE:
+ case ErrorStatus::DEVICE_UNAVAILABLE:
+ case ErrorStatus::GENERAL_FAILURE:
+ case ErrorStatus::OUTPUT_INSUFFICIENT_SIZE:
+ case ErrorStatus::INVALID_ARGUMENT:
+ case ErrorStatus::MISSED_DEADLINE_TRANSIENT:
+ case ErrorStatus::MISSED_DEADLINE_PERSISTENT:
+ case ErrorStatus::RESOURCE_EXHAUSTED_TRANSIENT:
+ case ErrorStatus::RESOURCE_EXHAUSTED_PERSISTENT:
+ case ErrorStatus::DEAD_OBJECT:
+ return Version::ANDROID_OC_MR1;
+ }
+ NN_VALIDATE_FAIL() << "Invalid ErrorStatus " << errorStatus;
+}
+
+Result<Version> validateOutputShape(const OutputShape& /*outputShape*/) {
+ return Version::ANDROID_Q;
+}
+
+Result<Version> validateTiming(const Timing& timing) {
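+    // timeOnDevice covers only the on-device execution, which is a portion of the total time
+    // spent in the driver, so when both values are reported the device time must not exceed the
+    // driver time.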
+ if (timing.timeInDriver != kNoTiming && timing.timeOnDevice != kNoTiming) {
+ NN_VALIDATE_LE(timing.timeOnDevice, timing.timeInDriver);
+ }
+ return Version::ANDROID_Q;
+}
+
+Result<Version> validateCapabilitiesPerformanceInfo(
+ const Capabilities::PerformanceInfo& performanceInfo) {
+ NN_VALIDATE_GT(performanceInfo.execTime, 0.0f);
+ NN_VALIDATE_GT(performanceInfo.powerUsage, 0.0f);
+ return Version::ANDROID_OC_MR1;
+}
+
+Result<Version> validateCapabilitiesOperandPerformance(
+ const Capabilities::OperandPerformance& operandPerformance) {
+ auto version = NN_TRY(validateOperandType(operandPerformance.type));
+ return combineVersions(version,
+ NN_TRY(validateCapabilitiesPerformanceInfo(operandPerformance.info)));
+}
+
+Result<Version> validateCapabilitiesOperandPerformanceTable(
+ const Capabilities::OperandPerformanceTable& operandPerformances) {
+ // OperandPerformanceTable's order was validated when it was created, and it is castable to any
+ // version. If an OperandType does not exist in the lower version being converted to, that
+ // OperandPerformance will be dropped.
+ NN_TRY(validateVector(operandPerformances.asVector(), validateCapabilitiesOperandPerformance));
+ return Version::ANDROID_OC_MR1;
+}
+
+Result<Version> validateCapabilities(const Capabilities& capabilities) {
+ auto version =
+ NN_TRY(validateCapabilitiesOperandPerformanceTable(capabilities.operandPerformance));
+
+ version = combineVersions(version,
+ NN_TRY(validateCapabilitiesPerformanceInfo(
+ capabilities.relaxedFloat32toFloat16PerformanceScalar)));
+ version = combineVersions(version,
+ NN_TRY(validateCapabilitiesPerformanceInfo(
+ capabilities.relaxedFloat32toFloat16PerformanceTensor)));
+ version = combineVersions(
+ version, NN_TRY(validateCapabilitiesPerformanceInfo(capabilities.ifPerformance)));
+ version = combineVersions(
+ version, NN_TRY(validateCapabilitiesPerformanceInfo(capabilities.whilePerformance)));
+
+ return version;
+}
+
+Result<Version> validateExtensionOperandTypeInformation(
+ const Extension::OperandTypeInformation& operandTypeInformation) {
+ NN_VALIDATE_GT(operandTypeInformation.byteSize, 0u);
+ return Version::ANDROID_Q;
+}
+
+Result<Version> validateExtension(const Extension& extension) {
+ NN_VALIDATE(isValidExtensionName(extension.name));
+
+ // Verify all OperandTypeInformations have unique types.
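+    // Sorting brings any duplicates next to each other, so std::adjacent_find below detects them
+    // in a single pass.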
+ std::vector<uint16_t> types;
+ types.reserve(extension.operandTypes.size());
+ std::transform(extension.operandTypes.begin(), extension.operandTypes.end(),
+ std::back_inserter(types),
+ [](const Extension::OperandTypeInformation& operandTypeInformation) {
+ return operandTypeInformation.type;
+ });
+ std::sort(types.begin(), types.end());
+ const auto iter = std::adjacent_find(types.begin(), types.end());
+ NN_VALIDATE(iter == types.end()) << "Extension has duplicate type " << *iter;
+
+ return combineVersions(Version::ANDROID_Q,
+ NN_TRY(validateVector(extension.operandTypes,
+ validateExtensionOperandTypeInformation)));
+}
+
+Result<Version> validateExtensions(const std::vector<Extension>& extensions) {
+ const auto version = NN_TRY(validateVector(extensions, validateExtension));
+
+ // Verify all extensions have unique names.
+ std::vector<std::reference_wrapper<const std::string>> names;
+ names.reserve(extensions.size());
+ std::transform(extensions.begin(), extensions.end(), std::back_inserter(names),
+ [](const Extension& extension) { return std::cref(extension.name); });
+ std::sort(names.begin(), names.end(), std::less<std::string>{});
+ const auto nameIter =
+ std::adjacent_find(names.begin(), names.end(), std::equal_to<std::string>{});
+ NN_VALIDATE(nameIter == names.end())
+ << "Two or more extensions have the duplicate name " << nameIter->get();
+
+ return version;
+}
+
+Result<Version> validateOperandDataLocation(const Operand& operand, size_t operandValuesSize,
+ const std::vector<size_t>& poolSizes,
+ size_t subgraphCount) {
+ const DataLocation& location = operand.location;
+ switch (operand.lifetime) {
+ case Operand::LifeTime::CONSTANT_COPY:
+ NN_VALIDATE(location.pointer == kNullptrVariant)
+ << "CONSTANT_COPY with a non-null pointer";
+ NN_VALIDATE_EQ(location.poolIndex, 0u)
+ << "CONSTANT_COPY with a non-zero poolIndex " << location.poolIndex;
+ // Do the addition using uint64_t to avoid potential wrap-around problems.
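+            // For example, with 32-bit arithmetic, offset = 0xFFFFFFFC and length = 8 would wrap
+            // around to 4 and incorrectly pass the bounds check.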
+ NN_VALIDATE_LE(static_cast<uint64_t>(location.offset) + location.length,
+ operandValuesSize)
+ << "OperandValue location out of range. Starts at " << location.offset
+ << ", length " << location.length << ", max " << operandValuesSize;
+ return Version::ANDROID_OC_MR1;
+ case Operand::LifeTime::CONSTANT_REFERENCE:
+ NN_VALIDATE_LT(location.poolIndex, poolSizes.size());
+ // Do the addition using uint64_t to avoid potential wrap-around problems.
+ NN_VALIDATE_LE(static_cast<uint64_t>(location.offset) + location.length,
+ poolSizes[location.poolIndex])
+ << "OperandValue location out of range. Starts at " << location.offset
+ << ", length " << location.length << ", max " << poolSizes[location.poolIndex];
+ return Version::ANDROID_OC_MR1;
+ case Operand::LifeTime::TEMPORARY_VARIABLE:
+ case Operand::LifeTime::SUBGRAPH_INPUT:
+ case Operand::LifeTime::SUBGRAPH_OUTPUT:
+ case Operand::LifeTime::NO_VALUE:
+ NN_VALIDATE(location.pointer == kNullptrVariant)
+ << "Unexpected pointer value for operand of lifetime " << operand.lifetime;
+ NN_VALIDATE_EQ(location.poolIndex, 0u)
+ << "Unexpected poolIndex " << location.poolIndex << " for operand of lifetime "
+ << operand.lifetime;
+ NN_VALIDATE_EQ(location.offset, 0u) << "Unexpected offset " << location.offset
+ << " for operand of lifetime " << operand.lifetime;
+ NN_VALIDATE_EQ(location.length, 0u) << "Unexpected length " << location.length
+ << " for operand of lifetime " << operand.lifetime;
+ return Version::ANDROID_OC_MR1;
+ case Operand::LifeTime::SUBGRAPH:
+ NN_VALIDATE(location.pointer == kNullptrVariant) << "SUBGRAPH with a non-null pointer";
+ NN_VALIDATE_EQ(location.poolIndex, 0u)
+ << "SUBGRAPH with a non-zero poolIndex " << location.poolIndex;
+ NN_VALIDATE_LT(location.offset, subgraphCount)
+ << "Subgraph index out of range: " << location.offset
+ << " >= " << subgraphCount;
+ NN_VALIDATE_EQ(location.length, 0u)
+ << "SUBGRAPH with a non-zero length " << location.length;
+ return Version::ANDROID_R;
+ case Operand::LifeTime::POINTER: {
+ const bool nonNull =
+ std::visit([](auto* ptr) { return ptr != nullptr; }, location.pointer);
+ NN_VALIDATE(nonNull) << "POINTER with a null pointer";
+ NN_VALIDATE_EQ(location.poolIndex, 0u)
+ << "POINTER with a non-zero poolIndex " << location.poolIndex;
+ NN_VALIDATE_EQ(location.offset, 0u)
+ << "POINTER with a non-zero offset " << location.offset;
+ return Version::ANDROID_OC_MR1;
+ }
+ }
+ NN_VALIDATE_FAIL() << "Invalid Operand::LifeTime " << operand.lifetime;
+}
+
+Result<Version> validateOperandDimensions(const Operand& operand) {
+ switch (operand.type) {
+ case OperandType::FLOAT32:
+ case OperandType::INT32:
+ case OperandType::UINT32:
+ case OperandType::BOOL:
+ case OperandType::FLOAT16:
+ case OperandType::SUBGRAPH:
+ case OperandType::OEM:
+ NN_VALIDATE(operand.dimensions.empty())
+ << "Scalar data has dimensions of rank " << operand.dimensions.size();
+ return Version::ANDROID_OC_MR1;
+ case OperandType::TENSOR_FLOAT32:
+ case OperandType::TENSOR_INT32:
+ case OperandType::TENSOR_QUANT8_ASYMM:
+ case OperandType::TENSOR_QUANT16_SYMM:
+ case OperandType::TENSOR_FLOAT16:
+ case OperandType::TENSOR_BOOL8:
+ case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
+ case OperandType::TENSOR_QUANT16_ASYMM:
+ case OperandType::TENSOR_QUANT8_SYMM:
+ case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
+ case OperandType::TENSOR_OEM_BYTE: {
+ if (operand.lifetime == Operand::LifeTime::CONSTANT_COPY ||
+ operand.lifetime == Operand::LifeTime::CONSTANT_REFERENCE ||
+ operand.lifetime == Operand::LifeTime::POINTER) {
+ NN_VALIDATE(!operand.dimensions.empty())
+ << "Tensor has lifetime of " << operand.lifetime
+ << " but dimensions of rank 0";
+ const auto size = getNonExtensionSize(operand);
+ NN_VALIDATE(size.has_value()) << "Tensor dimensions overflow";
+ NN_VALIDATE_EQ(size.value(), 0u) << "Tensor has at least one unknown dimension";
+ }
+ // TODO(b/165152547): aren't NO_VALUE arguments allowed to be .empty() even before
+ // Android Q?
+ if (operand.dimensions.empty()) {
+ // Unspecified rank was added in Android Q.
+ return Version::ANDROID_Q;
+ }
+ return Version::ANDROID_OC_MR1;
+ }
+ }
+ if (isExtension(operand.type)) {
+ // Extension types were added in Android Q.
+ return Version::ANDROID_Q;
+ }
+ NN_VALIDATE_FAIL() << "Invalid OperandType " << operand.type;
+}
+
+Result<Version> validateOperandScale(const Operand& operand) {
+ switch (operand.type) {
+ case OperandType::FLOAT32:
+ case OperandType::INT32:
+ case OperandType::UINT32:
+ case OperandType::TENSOR_FLOAT32:
+ case OperandType::BOOL:
+ case OperandType::TENSOR_FLOAT16:
+ case OperandType::TENSOR_BOOL8:
+ case OperandType::FLOAT16:
+ case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
+ case OperandType::SUBGRAPH:
+ NN_VALIDATE_EQ(operand.scale, 0.0f)
+ << "Operand of type " << operand.type << " with a non-zero scale ("
+ << operand.scale << ")";
+ return Version::ANDROID_OC_MR1;
+ case OperandType::TENSOR_INT32:
+ // TENSOR_INT32 may be used with or without scale, depending on the operation.
+ // TODO(b/119869082) We should have a separate type for TENSOR_INT32 with a scale.
+ NN_VALIDATE_GE(operand.scale, 0.0f)
+ << "Operand of type " << operand.type << " with a negative scale";
+ return Version::ANDROID_OC_MR1;
+ case OperandType::TENSOR_QUANT8_ASYMM:
+ case OperandType::TENSOR_QUANT16_SYMM:
+ case OperandType::TENSOR_QUANT16_ASYMM:
+ case OperandType::TENSOR_QUANT8_SYMM:
+ case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
+ NN_VALIDATE_GT(operand.scale, 0.0f)
+ << "Operand of type " << operand.type << " with a non-positive scale";
+ return Version::ANDROID_OC_MR1;
+ case OperandType::OEM:
+ case OperandType::TENSOR_OEM_BYTE:
+ // No validation for OEM types.
+ return Version::ANDROID_OC_MR1;
+ }
+ if (isExtension(operand.type)) {
+ NN_VALIDATE_EQ(operand.scale, 0.0f) << "Operand of type " << operand.type
+ << " with a non-zero scale (" << operand.scale << ")";
+ return Version::ANDROID_Q;
+ }
+ NN_VALIDATE_FAIL() << "Invalid OperandType " << operand.type;
+}
+
+Result<Version> validateOperandZeroPoint(const Operand& operand) {
+ switch (operand.type) {
+ case OperandType::FLOAT32:
+ case OperandType::INT32:
+ case OperandType::UINT32:
+ case OperandType::TENSOR_FLOAT32:
+ case OperandType::TENSOR_INT32:
+ case OperandType::BOOL:
+ case OperandType::TENSOR_FLOAT16:
+ case OperandType::TENSOR_BOOL8:
+ case OperandType::FLOAT16:
+ case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
+ case OperandType::TENSOR_QUANT8_SYMM:
+ case OperandType::SUBGRAPH:
+ NN_VALIDATE_EQ(operand.zeroPoint, 0)
+ << "Operand of type " << operand.type << " with a non-zero zeroPoint "
+ << operand.zeroPoint;
+ return Version::ANDROID_OC_MR1;
+ case OperandType::TENSOR_QUANT8_ASYMM:
+ NN_VALIDATE(operand.zeroPoint >= 0 && operand.zeroPoint <= 255)
+ << "Operand of type " << operand.type << " with an invalid zeroPoint "
+ << operand.zeroPoint << ", must be in range [0, 255]";
+ return Version::ANDROID_OC_MR1;
+ case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
+ NN_VALIDATE(operand.zeroPoint >= -128 && operand.zeroPoint <= 127)
+ << "Operand of type " << operand.type << " with an invalid zeroPoint "
+ << operand.zeroPoint << ", must be in range [-128, 127]";
+ return Version::ANDROID_OC_MR1;
+ case OperandType::TENSOR_QUANT16_ASYMM:
+ NN_VALIDATE(operand.zeroPoint >= 0 && operand.zeroPoint <= 65535)
+ << "Operand of type " << operand.type << " with an invalid zeroPoint "
+ << operand.zeroPoint << ", must be in range [0, 65535]";
+ return Version::ANDROID_OC_MR1;
+ case OperandType::TENSOR_QUANT16_SYMM:
+ NN_VALIDATE_EQ(operand.zeroPoint, 0)
+ << "Operand of type " << operand.type << " with a non-zero zeroPoint "
+ << operand.zeroPoint;
+ return Version::ANDROID_OC_MR1;
+ case OperandType::OEM:
+ case OperandType::TENSOR_OEM_BYTE:
+ // No validation for OEM types.
+ return Version::ANDROID_OC_MR1;
+ }
+ if (isExtension(operand.type)) {
+ NN_VALIDATE_EQ(operand.zeroPoint, 0) << "Operand of type " << operand.type
+ << " with a non-zero zeroPoint " << operand.zeroPoint;
+ return Version::ANDROID_Q;
+ }
+ NN_VALIDATE_FAIL() << "Invalid OperandType " << operand.type;
+}
+
+Result<Version> validateOperandExtraParams(const Operand& operand) {
+ switch (operand.type) {
+ case OperandType::FLOAT32:
+ case OperandType::INT32:
+ case OperandType::UINT32:
+ case OperandType::TENSOR_FLOAT32:
+ case OperandType::TENSOR_INT32:
+ case OperandType::TENSOR_QUANT8_ASYMM:
+ case OperandType::BOOL:
+ case OperandType::TENSOR_QUANT16_SYMM:
+ case OperandType::TENSOR_FLOAT16:
+ case OperandType::TENSOR_BOOL8:
+ case OperandType::FLOAT16:
+ case OperandType::TENSOR_QUANT16_ASYMM:
+ case OperandType::TENSOR_QUANT8_SYMM:
+ case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
+ case OperandType::SUBGRAPH:
+ NN_VALIDATE(std::holds_alternative<Operand::NoParams>(operand.extraParams))
+ << "Operand of type " << operand.type
+ << " has extraParams when there must be none";
+ return Version::ANDROID_OC_MR1;
+ case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: {
+ NN_VALIDATE(
+ std::holds_alternative<Operand::SymmPerChannelQuantParams>(operand.extraParams))
+ << "Operand of type " << operand.type
+ << " without a Channel Quantization params";
+ const auto& channelQuant =
+ std::get<Operand::SymmPerChannelQuantParams>(operand.extraParams);
+
+ const size_t count = operand.dimensions.size();
+ NN_VALIDATE_LT(channelQuant.channelDim, count)
+ << "Operand of type " << operand.type
+ << " with an invalid channelQuant.channelDim " << channelQuant.channelDim
+ << ", must be valid dimension index in range [0, " << count << ")";
+ const uint32_t expected = operand.dimensions[channelQuant.channelDim];
+            NN_VALIDATE_EQ(channelQuant.scales.size(), expected)
+                    << "Operand of type " << operand.type << " with wrong-sized scales: expected "
+                    << expected << ", got " << channelQuant.scales.size();
+ NN_VALIDATE_NE(expected, 0u)
+ << "Operand of type " << operand.type << " channel dimension "
+ << channelQuant.channelDim << " is underspecified (can't be 0)";
+ for (uint32_t i = 0; i < expected; ++i) {
+ NN_VALIDATE_GT(channelQuant.scales[i], 0.0f)
+ << "Operand of type " << operand.type
+ << " with a non-positive value in scales[" << i
+ << "]=" << channelQuant.scales[i];
+ }
+ return Version::ANDROID_Q;
+ }
+ case OperandType::OEM:
+ case OperandType::TENSOR_OEM_BYTE:
+ // No validation for OEM types.
+ return Version::ANDROID_OC_MR1;
+ }
+ if (isExtension(operand.type)) {
+ NN_VALIDATE(std::holds_alternative<Operand::NoParams>(operand.extraParams) ||
+ std::holds_alternative<Operand::ExtensionParams>(operand.extraParams))
+ << "Extension operand of type " << operand.type
+ << " must not have SymmPerChannelQuant extraParams";
+ return Version::ANDROID_OC_MR1;
+ }
+ NN_VALIDATE_FAIL() << "Invalid OperandType " << operand.type;
+}
+
+Result<Version> validateOperand(const Operand& operand, size_t operandValuesSize,
+ const std::vector<size_t>& poolSizes, size_t subgraphCount) {
+ auto version = NN_TRY(validateOperandType(operand.type));
+ version = combineVersions(version, NN_TRY(validateOperandLifeTime(operand)));
+ version = combineVersions(version, NN_TRY(validateOperandDimensions(operand)));
+ version = combineVersions(version, NN_TRY(validateOperandScale(operand)));
+ version = combineVersions(version, NN_TRY(validateOperandZeroPoint(operand)));
+ version = combineVersions(version, NN_TRY(validateOperandExtraParams(operand)));
+ version =
+ combineVersions(version, NN_TRY(validateOperandDataLocation(operand, operandValuesSize,
+ poolSizes, subgraphCount)));
+
+ // For constants, validate that the length is as expected. The other lifetimes
+ // expect the length to be 0. Don't validate for OEM types.
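+    // getNonExtensionSize() is guaranteed to have a value here: validateOperandDimensions() has
+    // already rejected constants whose size computation overflows.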
+ if (operand.lifetime == Operand::LifeTime::CONSTANT_REFERENCE ||
+ operand.lifetime == Operand::LifeTime::CONSTANT_COPY ||
+ operand.lifetime == Operand::LifeTime::POINTER) {
+ if (!isExtension(operand.type) && operand.type != OperandType::OEM &&
+ operand.type != OperandType::TENSOR_OEM_BYTE) {
+ const auto expectedLength = getNonExtensionSize(operand).value();
+ NN_VALIDATE_EQ(operand.location.length, expectedLength)
+ << "For operand " << operand.type << " expected a size of " << expectedLength
+ << " but got " << operand.location.length;
+ }
+ }
+
+ return version;
+}
+
+Result<Version> validateOperands(const std::vector<Operand>& operands, size_t operandValuesSize,
+ const std::vector<size_t>& poolSizes, size_t subgraphCount) {
+ auto version = Version::ANDROID_OC_MR1;
+ for (size_t i = 0; i < operands.size(); ++i) {
+ auto result = validateOperand(operands[i], operandValuesSize, poolSizes, subgraphCount);
+ if (!result.has_value()) {
+ return NN_ERROR() << std::move(result).error() << " for operand " << i;
+ }
+ version = combineVersions(version, result.value());
+ }
+ return version;
+}
+
+Result<Version> validateOperationImpl(const Operation& operation,
+ const std::vector<Operand>& operands,
+ const std::vector<Model::Subgraph>& subgraphs);
+
+Result<Version> validateOperations(const std::vector<Operation>& operations,
+ const std::vector<Operand>& operands,
+ const std::vector<Model::Subgraph>& subgraphs) {
+ auto version = Version::ANDROID_OC_MR1;
+ for (size_t i = 0; i < operations.size(); ++i) {
+ auto result = validateOperationImpl(operations[i], operands, subgraphs);
+ if (!result.has_value()) {
+ return NN_ERROR() << std::move(result).error() << " for operation " << i;
+ }
+ version = combineVersions(version, result.value());
+ }
+ return version;
+}
+
+Result<Version> validateNativeHandle(const NativeHandle& handle) {
+ NN_VALIDATE(handle != nullptr);
+ return Version::ANDROID_OC_MR1;
+}
+
+Result<Version> validateMemory(const Memory& memory) {
+ NN_TRY(validateNativeHandle(memory.handle));
+
+ if (memory.name == "ashmem") {
+ NN_VALIDATE_NE(memory.size, 0u);
+ return Version::ANDROID_OC_MR1;
+ }
+ if (memory.name == "mmap_fd") {
+ NN_VALIDATE_NE(memory.size, 0u);
+ return Version::ANDROID_OC_MR1;
+ }
+ if (memory.name == "hardware_buffer_blob") {
+ NN_VALIDATE_NE(memory.size, 0u);
+ return Version::ANDROID_Q;
+ }
+ if (memory.name == "hardware_buffer") {
+ // For hardware_buffer memory, all size information is bound to the AHardwareBuffer, so
+ // memory.size must be 0.
+ NN_VALIDATE_EQ(memory.size, 0u);
+ return Version::ANDROID_Q;
+ }
+
+ NN_VALIDATE_FAIL() << "Unknown memory type " << memory.name;
+}
+
+Result<void> validateModelSubgraphInputOutputs(const std::vector<uint32_t>& indexes,
+ const std::vector<Operand>& operands,
+ Operand::LifeTime lifetime) {
+ const size_t operandCount = operands.size();
+ for (uint32_t i : indexes) {
+ NN_VALIDATE_LT(i, operandCount)
+ << "Model " << lifetime << " input or output index out of range: " << i << "/"
+ << operandCount;
+ const Operand& operand = operands[i];
+ NN_VALIDATE_EQ(operand.lifetime, lifetime)
+ << "Model " << lifetime << " operand " << i << " has lifetime of "
+ << operand.lifetime << " instead of the expected " << lifetime;
+ }
+
+ std::vector<uint32_t> sortedIndexes = indexes;
+ std::sort(sortedIndexes.begin(), sortedIndexes.end());
+ const auto iter = std::adjacent_find(sortedIndexes.begin(), sortedIndexes.end());
+ NN_VALIDATE(iter == sortedIndexes.end())
+ << "Model input or output occurs multiple times: " << *iter;
+
+ for (size_t i = 0; i < operands.size(); ++i) {
+ if (operands[i].lifetime == lifetime) {
+ const auto containsIndex = [&sortedIndexes](size_t index) {
+ return binary_search(sortedIndexes.begin(), sortedIndexes.end(), index);
+ };
+ NN_VALIDATE(containsIndex(i))
+ << "Operand " << i << " marked as " << lifetime
+ << " but is not included in Model input or output indexes";
+ }
+ }
+
+ return {};
+}
+
+Result<void> validateExecutionOrder(const Model::Subgraph& subgraph) {
+    // Either an operand has a known value before model execution begins, or we've seen a writer
+    // for it while walking the operations in execution order. Initialize the vector to mark the
+    // operands whose values are known up front.
+ std::vector<bool> operandValueKnown;
+ operandValueKnown.reserve(subgraph.operands.size());
+ std::transform(subgraph.operands.begin(), subgraph.operands.end(),
+ std::back_inserter(operandValueKnown), [](const Operand& operand) {
+ return operand.lifetime != Operand::LifeTime::TEMPORARY_VARIABLE &&
+ operand.lifetime != Operand::LifeTime::SUBGRAPH_OUTPUT;
+ });
+
+ // Validate that operations are sorted into execution order.
+ //
+ // If there is a cycle in the graph, the operations will not
+ // appear to be sorted into execution order: Some operation will
+ // have an input for which operandValueKnown[] is false.
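+    //
+    // For example, if operation 1 produces a TEMPORARY_VARIABLE operand that operation 0 reads,
+    // the list is not in execution order, and the check on operation 0's inputs fails.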
+ for (size_t i = 0; i < subgraph.operations.size(); ++i) {
+ const auto& operation = subgraph.operations[i];
+
+ for (size_t j = 0; j < operation.inputs.size(); ++j) {
+ const uint32_t k = operation.inputs[j];
+ NN_VALIDATE(operandValueKnown[k]) << "Operation " << i << " input " << j << " (operand "
+ << k << ") is read before it is written";
+ }
+
+ for (size_t j = 0; j < operation.outputs.size(); ++j) {
+ const uint32_t k = operation.outputs[j];
+ // Assuming validateOperations() has returned true, we know that this output is
+            // TEMPORARY_VARIABLE or SUBGRAPH_OUTPUT, and so the only way operandValueKnown[k] can be
+ // true is if we've already seen a writer for this operand.
+ NN_VALIDATE(!operandValueKnown[k]) << "Operation " << i << " output " << j
+ << " (operand " << k << ") has already been written";
+ operandValueKnown[k] = true;
+ }
+ }
+
+ // Verify all operands are written.
+ for (size_t i = 0; i < subgraph.operands.size(); ++i) {
+ NN_VALIDATE(operandValueKnown[i]) << "Operand " << i << " is never written";
+ }
+
+ // TODO(b/77871786): verify that every operation has at least one output operand that is read?
+
+ return {};
+}
+
+Result<Version> validateModelSubgraph(const Model::Subgraph& subgraph, size_t operandValuesSize,
+ const std::vector<size_t>& poolSizes,
+ const std::vector<Model::Subgraph>& referenced) {
+ NN_VALIDATE(!subgraph.operands.empty());
+ NN_VALIDATE(!subgraph.operations.empty());
+ // TODO(b/134529942#comment7): should we verify !subgraph.inputIndexes.empty()?
+ NN_VALIDATE(!subgraph.inputIndexes.empty());
+ NN_VALIDATE(!subgraph.outputIndexes.empty());
+
+ auto version = NN_TRY(
+ validateOperands(subgraph.operands, operandValuesSize, poolSizes, referenced.size()));
+ version = combineVersions(version, NN_TRY(validateOperations(subgraph.operations,
+ subgraph.operands, referenced)));
+
+ NN_TRY(validateModelSubgraphInputOutputs(subgraph.inputIndexes, subgraph.operands,
+ Operand::LifeTime::SUBGRAPH_INPUT));
+ NN_TRY(validateModelSubgraphInputOutputs(subgraph.outputIndexes, subgraph.operands,
+ Operand::LifeTime::SUBGRAPH_OUTPUT));
+
+ NN_TRY(validateExecutionOrder(subgraph));
+
+ return version;
+}
+
+Result<Version> validateModelExtensionNamesAndPrefixes(
+ const std::vector<Model::ExtensionNameAndPrefix>& extensionNamesAndPrefixes) {
+ for (const auto& extensionNameAndPrefix : extensionNamesAndPrefixes) {
+ NN_VALIDATE(isValidExtensionName(extensionNameAndPrefix.name));
+ }
+
+ std::vector<std::reference_wrapper<const std::string>> names;
+ names.reserve(extensionNamesAndPrefixes.size());
+ std::transform(extensionNamesAndPrefixes.begin(), extensionNamesAndPrefixes.end(),
+ std::back_inserter(names),
+ [](const Model::ExtensionNameAndPrefix& extensionNameAndPrefix) {
+ return std::cref(extensionNameAndPrefix.name);
+ });
+ std::sort(names.begin(), names.end(), std::less<std::string>{});
+ const auto nameIter =
+ std::adjacent_find(names.begin(), names.end(), std::equal_to<std::string>{});
+ NN_VALIDATE(nameIter == names.end())
+ << "ExtensionNamesAndPrefixes has duplicate name " << nameIter->get();
+
+ std::vector<uint16_t> types;
+ types.reserve(extensionNamesAndPrefixes.size());
+ std::transform(extensionNamesAndPrefixes.begin(), extensionNamesAndPrefixes.end(),
+ std::back_inserter(types),
+ [](const Model::ExtensionNameAndPrefix& extensionNameAndPrefix) {
+ return extensionNameAndPrefix.prefix;
+ });
+ std::sort(types.begin(), types.end());
+ const auto typeIter = std::adjacent_find(types.begin(), types.end());
+ NN_VALIDATE(typeIter == types.end())
+ << "ExtensionNamesAndPrefixes has duplicate type " << *typeIter;
+
+ const bool hasExtensions = !extensionNamesAndPrefixes.empty();
+ return hasExtensions ? Version::ANDROID_Q : Version::ANDROID_OC_MR1;
+}
+
+// Makes sure the model does not contain subgraph reference cycles.
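+// The walk is depth-first: `path` holds the subgraphs on the current recursion stack, so
+// attempting to re-insert one of them means a cycle has been found.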
+Result<void> checkNoReferenceCycles(const Model& model, const Model::Subgraph& subgraph,
+ std::set<const Model::Subgraph*>* path) {
+ const auto [_, isNew] = path->insert(&subgraph);
+ NN_VALIDATE(isNew) << "Model contains a circular subgraph reference";
+    // TODO(b/165154824): It looks like this function is doing a lot of redundant work.
+ for (const Operand& operand : subgraph.operands) {
+ if (operand.lifetime == Operand::LifeTime::SUBGRAPH) {
+ const uint32_t refSubgraphIndex = operand.location.offset;
+ NN_TRY(checkNoReferenceCycles(model, model.referenced[refSubgraphIndex], path));
+ }
+ }
+ path->erase(&subgraph);
+ return {};
+}
+
+Result<void> checkNoReferenceCycles(const Model& model) {
+ std::set<const Model::Subgraph*> path;
+ return checkNoReferenceCycles(model, model.main, &path);
+}
+
+Result<Version> validateModel(const Model& model) {
+ auto version = NN_TRY(validateVector(model.pools, validateMemory));
+ version = combineVersions(
+ version, NN_TRY(validateModelExtensionNamesAndPrefixes(model.extensionNameToPrefix)));
+
+ // Ignore relaxComputationFloat32toFloat16 version because in the worst case it makes the
+ // execution stricter.
+
+ // Referenced models were introduced in Android R.
+ const bool hasReferencedModels = !model.referenced.empty();
+ const auto referenceModelVersion =
+ hasReferencedModels ? Version::ANDROID_R : Version::ANDROID_OC_MR1;
+ version = combineVersions(version, referenceModelVersion);
+
+ // Get memory sizes.
+ std::vector<size_t> poolSizes;
+ poolSizes.reserve(model.pools.size());
+ std::transform(model.pools.begin(), model.pools.end(), std::back_inserter(poolSizes),
+ [](const Memory& memory) { return memory.size; });
+ const size_t operandValuesSize = model.operandValues.size();
+
+ // Validate main subgraph.
+ auto result = validateModelSubgraph(model.main, operandValuesSize, poolSizes, model.referenced);
+ if (!result.has_value()) {
+ return NN_ERROR() << std::move(result).error() << " for main subgraph";
+ }
+ version = combineVersions(version, result.value());
+
+ // Validate referenced subgraphs.
+ for (size_t i = 0; i < model.referenced.size(); ++i) {
+ const auto& subgraph = model.referenced[i];
+ auto result =
+ validateModelSubgraph(subgraph, operandValuesSize, poolSizes, model.referenced);
+ if (!result.has_value()) {
+ return NN_ERROR() << std::move(result).error() << " for referenced subgraph" << i;
+ }
+ version = combineVersions(version, result.value());
+ }
+
+    // Ensure that the subgraph references do not form any cycles.
+ NN_TRY(checkNoReferenceCycles(model));
+
+ return version;
+}
+
+Result<Version> validateBufferDesc(const BufferDesc& /*bufferDesc*/) {
+ return Version::ANDROID_R;
+}
+
+Result<Version> validateBufferRole(const BufferRole& bufferRole) {
+ NN_VALIDATE_GT(bufferRole.frequency, 0.0f);
+ NN_VALIDATE_LE(bufferRole.frequency, 1.0f);
+ return Version::ANDROID_R;
+}
+
+Result<Version> validateRequestArgument(const Request::Argument& requestArgument,
+ const std::vector<size_t>& memorySizes, bool isOutput) {
+ const auto lifetime = requestArgument.lifetime;
+ const auto& location = requestArgument.location;
+ const auto& dimensions = requestArgument.dimensions;
+
+ switch (lifetime) {
+ case Request::Argument::LifeTime::POOL: {
+ NN_VALIDATE(location.pointer == kNullptrVariant);
+ NN_VALIDATE_LT(location.poolIndex, memorySizes.size());
+ // Do the addition using uint64_t to avoid potential wrap-around problems.
+ const auto lastPosition = static_cast<uint64_t>(location.offset) + location.length;
+ NN_VALIDATE_LE(lastPosition, memorySizes[location.poolIndex]);
+ return Version::ANDROID_OC_MR1;
+ }
+ case Request::Argument::LifeTime::NO_VALUE:
+ NN_VALIDATE(location.pointer == kNullptrVariant);
+ NN_VALIDATE_EQ(location.poolIndex, 0u);
+ NN_VALIDATE_EQ(location.offset, 0u);
+ NN_VALIDATE_EQ(location.length, 0u);
+ NN_VALIDATE(dimensions.empty());
+ return Version::ANDROID_OC_MR1;
+ case Request::Argument::LifeTime::POINTER: {
+ const bool isNullptr =
+ std::visit([](auto ptr) { return ptr == nullptr; }, location.pointer);
+ NN_VALIDATE(!isNullptr);
+ NN_VALIDATE_EQ(location.poolIndex, 0u);
+ NN_VALIDATE_EQ(location.offset, 0u);
+ NN_VALIDATE_NE(location.length, 0u);
+ if (isOutput) {
+ NN_VALIDATE(std::holds_alternative<void*>(location.pointer));
+ }
+ return Version::ANDROID_OC_MR1;
+ }
+ }
+ NN_VALIDATE_FAIL() << "Invalid Request::Argument::LifeTime " << lifetime;
+}
+
+Result<Version> validateRequestMemoryPool(const Request::MemoryPool& memoryPool) {
+ if (std::holds_alternative<Request::MemoryDomainToken>(memoryPool)) {
+ NN_VALIDATE(std::get<Request::MemoryDomainToken>(memoryPool) != kInvalidMemoryDomainToken);
+ return Version::ANDROID_R;
+ }
+ if (std::holds_alternative<std::shared_ptr<const IBuffer>>(memoryPool)) {
+ NN_VALIDATE(std::get<std::shared_ptr<const IBuffer>>(memoryPool) != nullptr);
+ return Version::ANDROID_R;
+ }
+ return validateMemory(std::get<Memory>(memoryPool));
+}
+
+Result<Version> validateRequest(const Request& request) {
+ auto version = NN_TRY(validateVector(request.pools, validateRequestMemoryPool));
+
+ // Get memory sizes. For IBuffer or MemoryDomainToken types, set size to 0.
+ std::vector<size_t> memorySizes;
+ memorySizes.reserve(request.pools.size());
+ std::transform(request.pools.begin(), request.pools.end(), std::back_inserter(memorySizes),
+ [](const Request::MemoryPool& memoryPool) {
+ const auto* memory = std::get_if<Memory>(&memoryPool);
+ return memory != nullptr ? memory->size : 0;
+ });
+
+ for (size_t i = 0; i < request.inputs.size(); ++i) {
+ const auto& input = request.inputs[i];
+ auto result = validateRequestArgument(input, memorySizes, /*isOutput=*/false);
+ if (!result.has_value()) {
+ return NN_ERROR() << std::move(result).error() << " for input RequestArgument " << i;
+ }
+ version = combineVersions(version, result.value());
+ }
+ for (size_t i = 0; i < request.outputs.size(); ++i) {
+ const auto& output = request.outputs[i];
+ auto result = validateRequestArgument(output, memorySizes, /*isOutput=*/true);
+ if (!result.has_value()) {
+ return NN_ERROR() << std::move(result).error() << " for output RequestArgument " << i;
+ }
+ version = combineVersions(version, result.value());
+ }
+
+ return version;
+}
+
+Result<Version> validateOptionalTimePoint(const OptionalTimePoint& optionalTimePoint) {
+ if (optionalTimePoint.has_value()) {
+ NN_VALIDATE_GE(optionalTimePoint.value().time_since_epoch().count(), 0);
+ }
+ return Version::ANDROID_R;
+}
+
+Result<Version> validateOptionalTimeoutDuration(
+ const OptionalTimeoutDuration& optionalTimeoutDuration) {
+ if (optionalTimeoutDuration.has_value()) {
+ NN_VALIDATE_GE(optionalTimeoutDuration.value().count(), 0);
+ }
+ return Version::ANDROID_R;
+}
+
+Result<Version> validateRequestArgumentsForModel(
+ const std::vector<Request::Argument>& requestArguments,
+ const std::vector<uint32_t>& operandIndexes, const std::vector<Operand>& operands,
+ bool isOutput) {
+ auto version = Version::ANDROID_OC_MR1;
+ // The request should specify as many arguments as were described in the model.
+ const std::string_view type = isOutput ? "output" : "input";
+ const size_t requestArgumentCount = requestArguments.size();
+ NN_VALIDATE_EQ(requestArgumentCount, operandIndexes.size())
+ << "Request specifies " << requestArgumentCount << " " << type << "s but the model has "
+ << operandIndexes.size();
+ for (size_t requestArgumentIndex = 0; requestArgumentIndex < requestArgumentCount;
+ requestArgumentIndex++) {
+ const Request::Argument& requestArgument = requestArguments[requestArgumentIndex];
+ // Get the operand index for this argument. We extract it from the list
+ // that was provided in the call to ANeuralNetworksModel_identifyInputsAndOutputs.
+ // We assume in this function that the model has been validated already.
+ const uint32_t operandIndex = operandIndexes[requestArgumentIndex];
+ const Operand& operand = operands[operandIndex];
+ if (requestArgument.lifetime != Request::Argument::LifeTime::NO_VALUE) {
+ // If the argument specified a dimension, validate it.
+ uint32_t modelRank = operand.dimensions.size();
+ uint32_t requestRank = requestArgument.dimensions.size();
+ if (requestRank == 0) {
+ // NOTE: validateRequestArguments cannot validate unknown tensor rank with
+ // extension operand type.
+ if (!isExtension(operand.type) && !isNonExtensionScalar(operand.type)) {
+ if (modelRank <= 0) {
+ NN_VALIDATE(isOutput)
+ << "Model has unknown input rank but the request does not "
+ "specify the rank.";
+ // Unspecified output dimensions introduced in Android Q.
+ version = combineVersions(version, Version::ANDROID_Q);
+ }
+ }
+ // Validate that all the dimensions are specified in the model.
+ for (size_t i = 0; i < modelRank; i++) {
+ if (operand.dimensions[i] == 0) {
+ NN_VALIDATE(isOutput)
+ << "Model has dimension " << i
+ << " set to 0 but the request does not specify the dimension.";
+ // Unspecified output dimensions introduced in Android Q.
+ version = combineVersions(version, Version::ANDROID_Q);
+ }
+ }
+ } else {
+ NN_VALIDATE(modelRank == 0 || requestRank == modelRank)
+ << "Request " << type << " " << requestArgumentIndex
+ << " has number of dimensions (" << requestRank
+ << ") different than the model's (" << modelRank << ")";
+ for (size_t i = 0; i < requestRank; i++) {
+ NN_VALIDATE(modelRank == 0 || operand.dimensions[i] == 0 ||
+ requestArgument.dimensions[i] == operand.dimensions[i])
+ << "Request " << type << " " << requestArgumentIndex
+ << " has dimension " << i << " of " << requestArgument.dimensions[i]
+ << " different than the model's " << operand.dimensions[i];
+ if (requestArgument.dimensions[i] == 0) {
+ NN_VALIDATE(isOutput) << "Request " << type << " " << requestArgumentIndex
+ << " has dimension " << i << " of zero";
+ // Unspecified output dimensions introduced in Android Q.
+ version = combineVersions(version, Version::ANDROID_Q);
+ }
+ }
+ }
+ }
+ }
+ return version;
+}
+
+Result<Version> validateRequestForModelImpl(const Request& request, const Model& model) {
+ auto version = NN_TRY(validateRequest(request));
+ version = combineVersions(version, NN_TRY(validateModel(model)));
+ version = combineVersions(version, NN_TRY(validateRequestArgumentsForModel(
+ request.inputs, model.main.inputIndexes,
+ model.main.operands, /*isOutput=*/false)));
+ version = combineVersions(version, NN_TRY(validateRequestArgumentsForModel(
+                                               request.outputs, model.main.outputIndexes,
+ model.main.operands, /*isOutput=*/true)));
+ return version;
+}
+
+Result<Version> validateMemoryDescImpl(
+ const BufferDesc& desc,
+ const std::vector<std::shared_ptr<const IPreparedModel>>& preparedModels,
+ const std::vector<BufferRole>& inputRoles, const std::vector<BufferRole>& outputRoles,
+ const std::function<const Model*(const std::shared_ptr<const IPreparedModel>&)>& getModel,
+ std::set<PreparedModelRole>* preparedModelRoles, Operand* combinedOperand) {
+ NN_VALIDATE(!preparedModels.empty());
+ NN_VALIDATE(!inputRoles.empty() || !outputRoles.empty());
+
+ std::set<PreparedModelRole> roles;
+ std::vector<nn::Operand> operands;
+ operands.reserve(inputRoles.size() + outputRoles.size());
+ for (const auto& role : inputRoles) {
+ NN_VALIDATE_LT(role.modelIndex, preparedModels.size());
+ const auto& preparedModel = preparedModels[role.modelIndex];
+ NN_VALIDATE(preparedModel != nullptr);
+ const auto* model = getModel(preparedModel);
+ NN_VALIDATE(model != nullptr);
+ const auto& inputIndexes = model->main.inputIndexes;
+ NN_VALIDATE_LT(role.ioIndex, inputIndexes.size());
+ NN_VALIDATE_GT(role.frequency, 0.0f);
+ NN_VALIDATE_LE(role.frequency, 1.0f);
+ const auto [it, success] = roles.emplace(preparedModel.get(), IOType::INPUT, role.ioIndex);
+ NN_VALIDATE(success);
+ operands.push_back(model->main.operands[inputIndexes[role.ioIndex]]);
+ }
+ for (const auto& role : outputRoles) {
+ NN_VALIDATE_LT(role.modelIndex, preparedModels.size());
+ const auto& preparedModel = preparedModels[role.modelIndex];
+ NN_VALIDATE(preparedModel != nullptr);
+ const auto* model = getModel(preparedModel);
+ NN_VALIDATE(model != nullptr);
+ const auto& outputIndexes = model->main.outputIndexes;
+ NN_VALIDATE_LT(role.ioIndex, outputIndexes.size());
+ NN_VALIDATE_GT(role.frequency, 0.0f);
+ NN_VALIDATE_LE(role.frequency, 1.0f);
+ const auto [it, success] = roles.emplace(preparedModel.get(), IOType::OUTPUT, role.ioIndex);
+ NN_VALIDATE(success);
+ operands.push_back(model->main.operands[outputIndexes[role.ioIndex]]);
+ }
+
+ CHECK(!operands.empty());
+ const auto opType = operands.front().type;
+
+ Dimensions dimensions = desc.dimensions;
+ for (const auto& operand : operands) {
+ NN_VALIDATE_EQ(operand.type, opType) << operand.type << " vs " << operands.front().type;
+ NN_VALIDATE_EQ(operand.scale, operands.front().scale);
+ NN_VALIDATE_EQ(operand.zeroPoint, operands.front().zeroPoint);
+ // NOTE: validateMemoryDesc cannot validate extra parameters for extension operand type.
+ if (!isExtension(opType)) {
+ NN_VALIDATE_EQ(operand.extraParams, operands.front().extraParams)
+ << operand.extraParams << " vs " << operands.front().extraParams;
+ }
+ dimensions = NN_TRY(combineDimensions(dimensions, operand.dimensions));
+ }
+
+ // NOTE: validateMemoryDesc cannot validate scalar dimensions with extension operand type.
+ if (!isExtension(opType)) {
+ NN_VALIDATE(!isNonExtensionScalar(opType) || dimensions.empty())
+ << "invalid dimensions with scalar operand type.";
+ }
+
+ if (preparedModelRoles != nullptr) {
+ *preparedModelRoles = std::move(roles);
+ }
+ if (combinedOperand != nullptr) {
+ *combinedOperand = operands.front();
+ combinedOperand->dimensions = dimensions;
+ }
+ return Version::ANDROID_R;
+}
+
+// TODO: Enable this block of code once canonical types are integrated in the rest of the NNAPI
+// codebase.
+#if 0
+class OperationValidationContext : public IOperationValidationContext {
+ DISALLOW_IMPLICIT_CONSTRUCTORS(OperationValidationContext);
+
+ public:
+ OperationValidationContext(const char* operationName, const std::vector<uint32_t>&
+ inputIndexes,
+ const std::vector<uint32_t>& outputIndexes,
+ const std::vector<Operand>& operands, Version version)
+ : operationName(operationName),
+ inputIndexes(inputIndexes),
+ outputIndexes(outputIndexes),
+ operands(operands),
+ version(version) {}
+
+ const char* getOperationName() const override;
+ Version getVersion() const override;
+
+ uint32_t getNumInputs() const override;
+ OperandType getInputType(uint32_t index) const override;
+ Shape getInputShape(uint32_t index) const override;
+ const Operand::ExtraParams getInputExtraParams(uint32_t index) const override;
+
+ uint32_t getNumOutputs() const override;
+ OperandType getOutputType(uint32_t index) const override;
+ Shape getOutputShape(uint32_t index) const override;
+
+ private:
+ const Operand* getInputOperand(uint32_t index) const;
+ const Operand* getOutputOperand(uint32_t index) const;
+
+ const char* operationName;
+ const std::vector<uint32_t>& inputIndexes;
+ const std::vector<uint32_t>& outputIndexes;
+ const std::vector<Operand>& operands;
+ Version version;
+};
+
+const char* OperationValidationContext::getOperationName() const {
+ return operationName;
+}
+
+Version OperationValidationContext::getVersion() const {
+ return version;
+}
+
+const Operand* OperationValidationContext::getInputOperand(uint32_t index) const {
+ return &operands.at(inputIndexes.at(index));
+}
+
+const Operand* OperationValidationContext::getOutputOperand(uint32_t index) const {
+ return &operands.at(outputIndexes.at(index));
+}
+
+uint32_t OperationValidationContext::getNumInputs() const {
+ auto count = inputIndexes.size();
+ CHECK_LE(count, std::numeric_limits<uint32_t>::max());
+ return static_cast<uint32_t>(count);
+}
+
+uint32_t OperationValidationContext::getNumOutputs() const {
+ auto count = outputIndexes.size();
+ CHECK_LE(count, std::numeric_limits<uint32_t>::max());
+ return static_cast<uint32_t>(count);
+}
+
+OperandType OperationValidationContext::getInputType(uint32_t index) const {
+ return getInputOperand(index)->type;
+}
+
+Shape OperationValidationContext::getInputShape(uint32_t index) const {
+ const Operand* operand = getInputOperand(index);
+ return {operand->type, operand->dimensions, operand->scale, operand->zeroPoint,
+ operand->extraParams};
+}
+
+const Operand::ExtraParams OperationValidationContext::getInputExtraParams(uint32_t index) const
+{
+ return getInputOperand(index)->extraParams;
+}
+
+OperandType OperationValidationContext::getOutputType(uint32_t index) const {
+ return getOutputOperand(index)->type;
+}
+
+Shape OperationValidationContext::getOutputShape(uint32_t index) const {
+ const Operand* operand = getOutputOperand(index);
+ return {operand->type, operand->dimensions, operand->scale, operand->zeroPoint,
+ operand->extraParams};
+}
+#endif
+
+// TODO(b/169345292): reduce the duplicate validation here
+
+Result<void> validateOperandSymmPerChannelQuantParamsImpl(
+ const Operand& operand, const Operand::SymmPerChannelQuantParams& channelQuant,
+ const char* tag) {
+ if (operand.type != OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL) {
+ NN_VALIDATE_FAIL();
+ }
+
+ NN_VALIDATE_LT(channelQuant.channelDim, operand.dimensions.size()) << tag;
+ NN_VALIDATE(!channelQuant.scales.empty()) << tag;
+ NN_VALIDATE_EQ(channelQuant.scales.size(), operand.dimensions[channelQuant.channelDim]) << tag;
+ NN_VALIDATE_NE(operand.dimensions[channelQuant.channelDim], 0u)
+ << tag << " channel dimension " << channelQuant.channelDim << " is underspecified";
+ for (uint32_t i = 0; i < operand.dimensions[channelQuant.channelDim]; i++) {
+ NN_VALIDATE_GT(channelQuant.scales[i], 0.0f) << tag << " invalid scaleArray[" << i << "]";
+ }
+ return {};
+}
+
+Result<void> validateScalarDimensions(const Operand& type, const char* tag) {
+ NN_VALIDATE(type.dimensions.empty()) << tag << " invalid dimensions for scalar type";
+ return {};
+}
+
+Result<void> validateQuant8AsymmParams(const Operand& type, const char* tag) {
+ NN_VALIDATE(0 <= type.zeroPoint && type.zeroPoint <= 255)
+ << tag << " invalid zeroPoint: " << type.zeroPoint;
+ NN_VALIDATE_GT(type.scale, 0.0f) << tag << " invalid scale";
+ return {};
+}
+
+Result<void> validateQuant8AsymmSignedParams(const Operand& type, const char* tag) {
+ NN_VALIDATE(-128 <= type.zeroPoint && type.zeroPoint <= 127)
+ << tag << " invalid zeroPoint: " << type.zeroPoint;
+ NN_VALIDATE_GT(type.scale, 0.0f) << tag << " invalid scale";
+ return {};
+}
+
+Result<void> validateQuant8SymmParams(const Operand& type, const char* tag) {
+ NN_VALIDATE_EQ(type.zeroPoint, 0) << tag << " invalid zeroPoint: " << type.zeroPoint;
+ NN_VALIDATE_GT(type.scale, 0.0f) << tag << " invalid scale";
+ return {};
+}
+
+Result<void> validateQuant16AsymmParams(const Operand& type, const char* tag) {
+ NN_VALIDATE(0 <= type.zeroPoint && type.zeroPoint <= 65535)
+ << tag << " invalid zeroPoint: " << type.zeroPoint;
+ NN_VALIDATE_GT(type.scale, 0.0f) << tag << " invalid scale";
+ return {};
+}
+
+Result<void> validateQuantSymmParams(const Operand& type, const char* tag) {
+ NN_VALIDATE_EQ(type.zeroPoint, 0) << tag << " zeroPoint is not zero";
+ NN_VALIDATE_GT(type.scale, 0.0f) << tag << " invalid scale";
+ return {};
+}
+
+Result<void> validateNoQuantParams(const Operand& type, const char* tag) {
+ NN_VALIDATE_EQ(type.zeroPoint, 0) << tag << " zeroPoint is not zero";
+ NN_VALIDATE_EQ(type.scale, 0.0f) << tag << " scale is not zero";
+ return {};
+}
+
+Result<void> validateTensorDimensions(
+ const Operand& type, const Extension::OperandTypeInformation* extensionOperandTypeInfo,
+ const char* tag, bool allowPartial) {
+ if (!allowPartial) {
+ NN_VALIDATE(!type.dimensions.empty()) << tag << " invalid operand dimensions";
+ }
+ uint64_t size = isExtension(type.type) ? extensionOperandTypeInfo->byteSize
+ : getNonExtensionSize(type.type);
+ constexpr uint64_t kMaxSize = std::numeric_limits<uint32_t>::max();
+ for (size_t i = 0; i < type.dimensions.size(); i++) {
+ if (!allowPartial) {
+ NN_VALIDATE_NE(type.dimensions[i], 0u) << tag << " invalid operand dimensions";
+ }
+ if (type.dimensions[i] != 0) {
+ size *= type.dimensions[i];
+ NN_VALIDATE_LE(size, kMaxSize) << tag << " operand byte size exceeds " << kMaxSize;
+ }
+ }
+ return {};
+}
+
+Result<void> validateOperandTypeImpl(
+ const Operand& type,
+ const Extension::OperandTypeInformation* const extensionOperandTypeInfo, const char* tag,
+ bool allowPartial) {
+ if (isExtension(type.type)) {
+ NN_VALIDATE(extensionOperandTypeInfo != nullptr);
+ if (extensionOperandTypeInfo->isTensor) {
+ NN_TRY(validateTensorDimensions(type, extensionOperandTypeInfo, tag, allowPartial));
+ } else {
+ NN_TRY(validateScalarDimensions(type, tag));
+ }
+ return validateNoQuantParams(type, tag);
+ }
+
+ NN_VALIDATE(extensionOperandTypeInfo == nullptr);
+ NN_TRY(validateOperandType(type.type));
+
+ if (isNonExtensionScalar(type.type)) {
+ NN_TRY(validateScalarDimensions(type, tag));
+ if (type.type != OperandType::OEM) { // Historically, we have allowed OEM types
+ // to use quantization parameters.
+ NN_TRY(validateNoQuantParams(type, tag));
+ }
+ } else {
+ NN_TRY(validateTensorDimensions(type, extensionOperandTypeInfo, tag, allowPartial));
+ if (type.type == OperandType::TENSOR_QUANT8_ASYMM) {
+ NN_TRY(validateQuant8AsymmParams(type, tag));
+ } else if (type.type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
+ NN_TRY(validateQuant8AsymmSignedParams(type, tag));
+ } else if (type.type == OperandType::TENSOR_QUANT8_SYMM) {
+ NN_TRY(validateQuant8SymmParams(type, tag));
+ } else if (type.type == OperandType::TENSOR_QUANT16_ASYMM) {
+ NN_TRY(validateQuant16AsymmParams(type, tag));
+ } else if (type.type == OperandType::TENSOR_QUANT16_SYMM) {
+ NN_TRY(validateQuantSymmParams(type, tag));
+ } else if (type.type == OperandType::TENSOR_INT32 ||
+ type.type == OperandType::TENSOR_OEM_BYTE) {
+ // TODO(b/119869082): TENSOR_INT32 should not use quantization parameters.
+ // Historically, we have allowed OEM types to use quantization parameters.
+ } else {
+ NN_TRY(validateNoQuantParams(type, tag));
+ }
+ }
+
+ return {};
+}
+
+Result<void> validateOperandListImpl(const std::vector<uint32_t>& list, size_t operandCount,
+ const char* tag) {
+ for (size_t i = 0; i < list.size(); i++) {
+ NN_VALIDATE_LT(list[i], operandCount) << tag << " invalid operand index at " << i << " = "
+ << list[i] << ", operandCount " << operandCount;
+ }
+ return {};
+}
+
+Result<void> validateOperationOperandTypes(const std::vector<Operand>& operands,
+ const std::vector<uint32_t>& inputIndexes,
+ const std::vector<OperandType>& inExpectedTypes,
+ const std::vector<uint32_t>& outputIndexes,
+ const std::vector<OperandType>& outExpectedInTypes) {
+    NN_VALIDATE_EQ(inputIndexes.size(), inExpectedTypes.size())
+            << "Wrong operand count: expected " << inExpectedTypes.size() << " inputs, got "
+            << inputIndexes.size() << " inputs";
+    NN_VALIDATE_EQ(outputIndexes.size(), outExpectedInTypes.size())
+            << "Wrong operand count: expected " << outExpectedInTypes.size() << " outputs, got "
+            << outputIndexes.size() << " outputs";
+ for (size_t i = 0; i < inputIndexes.size(); i++) {
+ NN_VALIDATE_EQ(operands[inputIndexes[i]].type, inExpectedTypes[i])
+ << "Invalid input tensor type " << operands[inputIndexes[i]].type << " for input "
+ << i << ", expected " << inExpectedTypes[i];
+ }
+ for (size_t i = 0; i < outputIndexes.size(); i++) {
+ NN_VALIDATE_EQ(operands[outputIndexes[i]].type, outExpectedInTypes[i])
+ << "Invalid output tensor type " << operands[outputIndexes[i]].type << " for input "
+ << i << ", expected " << outExpectedInTypes[i];
+ }
+
+ return {};
+}
+
+Result<void> validateSubgraphReference(const std::vector<Model::Subgraph>& subgraphs,
+ const Operand& modelOperand) {
+ NN_VALIDATE_EQ(modelOperand.type, OperandType::SUBGRAPH)
+ << "Unexpected operand type: " << modelOperand.type;
+ NN_VALIDATE_LT(modelOperand.location.offset, subgraphs.size()) << "Invalid subgraph reference";
+ return {};
+}
+const Model::Subgraph& getSubgraph(const std::vector<Model::Subgraph>& subgraphs,
+ const Operand& modelOperand) {
+ return subgraphs.at(modelOperand.location.offset);
+}
+uint32_t getInputCount(const std::vector<Model::Subgraph>& subgraphs, const Operand& modelOperand) {
+ return getSubgraph(subgraphs, modelOperand).inputIndexes.size();
+}
+uint32_t getOutputCount(const std::vector<Model::Subgraph>& subgraphs,
+ const Operand& modelOperand) {
+ return getSubgraph(subgraphs, modelOperand).outputIndexes.size();
+}
+const Operand& getInputOperand(const std::vector<Model::Subgraph>& subgraphs,
+ const Operand& modelOperand, uint32_t index) {
+ const Model::Subgraph& subgraph = getSubgraph(subgraphs, modelOperand);
+ return subgraph.operands.at(subgraph.inputIndexes.at(index));
+}
+const Operand& getOutputOperand(const std::vector<Model::Subgraph>& subgraphs,
+ const Operand& modelOperand, uint32_t index) {
+ const Model::Subgraph& subgraph = getSubgraph(subgraphs, modelOperand);
+ return subgraph.operands.at(subgraph.outputIndexes.at(index));
+}
+
+// Checks if two operands have the same types, ranks (if specified), dimensions
+// (if specified), scales, zeroPoints, and extraParams.
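+// For example, dimensions {2, 0} and {2, 3} are compatible because a dimension of 0 is treated as
+// unspecified, while {2, 3} and {2, 4} are not compatible.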
+Result<void> compatible(const Operand& a, const Operand& b) {
+ NN_VALIDATE_EQ(a.type, b.type) << a.type << " != " << b.type;
+ if (!a.dimensions.empty() && !b.dimensions.empty()) {
+ NN_VALIDATE_EQ(a.dimensions.size(), b.dimensions.size()) << "Incompatible dimensions";
+ for (uint32_t i = 0, n = a.dimensions.size(); i < n; ++i) {
+ if (a.dimensions[i] != 0 && b.dimensions[i] != 0) {
+ NN_VALIDATE_EQ(a.dimensions[i], b.dimensions[i]) << "Incompatible dimensions";
+ }
+ }
+ }
+ NN_VALIDATE_EQ(a.scale, b.scale);
+ NN_VALIDATE_EQ(a.zeroPoint, b.zeroPoint);
+ NN_VALIDATE_EQ(a.extraParams, b.extraParams) << a.extraParams << " != " << b.extraParams;
+ return {};
+}
+
+Result<void> validateConditionOperand(const Operand& operand) {
+ NN_VALIDATE_EQ(operand.type, OperandType::TENSOR_BOOL8)
+ << "Unexpected condition operand type: " << operand.type;
+ NN_VALIDATE_EQ(operand.dimensions.size(), 1u) << "Condition operand must be a singleton";
+ NN_VALIDATE_EQ(operand.dimensions[0], 1u) << "Condition operand must be a singleton";
+ return {};
+}
+
+Result<Version> validateIfOperation(const std::vector<uint32_t>& inputs,
+ const std::vector<uint32_t>& outputs,
+ const std::vector<Operand>& operands,
+ const std::vector<Model::Subgraph>& subgraphs) {
+ namespace op = operation_if;
+ NN_VALIDATE_GE(inputs.size(), 3u) << "IF must have at least 3 inputs";
+ NN_VALIDATE_GE(outputs.size(), 1u) << "IF must have at least 1 output";
+ auto validateBranchOperand = [&](const Operand& branchModelOperand) -> Result<void> {
+ auto result = validateSubgraphReference(subgraphs, branchModelOperand);
+ if (!result.has_value()) {
+ return NN_ERROR() << std::move(result).error()
+ << "Operand is not a valid subgraph reference";
+ }
+ const uint32_t branchModelInputCount = getInputCount(subgraphs, branchModelOperand);
+ const uint32_t branchModelOutputCount = getOutputCount(subgraphs, branchModelOperand);
+ NN_VALIDATE_EQ(inputs.size(), op::kFirstInput + branchModelInputCount);
+ NN_VALIDATE_EQ(outputs.size(), branchModelOutputCount);
+ for (uint32_t i = 0; i < branchModelInputCount; ++i) {
+ const Operand& innerOperand = getInputOperand(subgraphs, branchModelOperand, i);
+ const Operand& outerOperand = operands[inputs[op::kFirstInput + i]];
+ NN_TRY(compatible(innerOperand, outerOperand));
+ }
+ for (uint32_t i = 0; i < branchModelOutputCount; ++i) {
+ const Operand& innerOperand = getOutputOperand(subgraphs, branchModelOperand, i);
+ const Operand& outerOperand = operands[outputs[i]];
+ NN_TRY(compatible(innerOperand, outerOperand));
+ }
+ return {};
+ };
+ auto result = validateConditionOperand(operands[inputs[op::kCondBoolOperand]]);
+ if (!result.has_value()) {
+ return NN_ERROR() << std::move(result).error()
+ << "Validation failed for IF condition operand";
+ }
+ result = validateBranchOperand(operands[inputs[op::kThenModelOperand]]);
+ if (!result.has_value()) {
+ return NN_ERROR() << std::move(result).error() << "Validation failed for IF then model";
+ }
+ result = validateBranchOperand(operands[inputs[op::kElseModelOperand]]);
+ if (!result.has_value()) {
+ return NN_ERROR() << std::move(result).error() << "Validation failed for IF else model";
+ }
+ return Version::ANDROID_R;
+}
+
+Result<Version> validateControlFlowOperandUnknownSize(const Operand& operand) {
+ if (!isExtension(operand.type) && getNonExtensionSize(operand).value() == 0) {
+ // 1.3 HAL (corresponding to Version::ANDROID_R) does not support CF operations with
+ // operands of unknown size. See http://b/132458982#comment63.
+ return Version::CURRENT_RUNTIME;
+ }
+ return Version::ANDROID_R;
+}
+
+Result<Version> validateWhileOperation(const std::vector<uint32_t>& inputs,
+ const std::vector<uint32_t>& outputs,
+ const std::vector<Operand>& operands,
+ const std::vector<Model::Subgraph>& subgraphs) {
+ // Let the loop have
+ // - m >= 1 input-output operands,
+ // - k >= 0 state-only operands, and
+ // - n >= 0 input-only operands.
+ // Then
+ // - the WHILE loop operation has (2 + m + k + n) inputs and m outputs.
+ // - the condition model has (m + k + n) inputs and 1 output.
+ // - the body model has (m + k + n) inputs and (m + k) outputs.
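+ // For example, with m = 1, k = 1, and n = 1:
+ // - the WHILE operation has 5 inputs (the condition and body model operands plus
+ // 3 data operands) and 1 output,
+ // - the condition model has 3 inputs and 1 output, and
+ // - the body model has 3 inputs and 2 outputs.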
+ namespace op = operation_while;
+ NN_VALIDATE_GE(inputs.size(), 3u) << "WHILE must have at least 3 inputs";
+ NN_VALIDATE_GE(outputs.size(), 1u) << "WHILE must have at least 1 output";
+ auto validateCondOperand = [&](const Operand& condModelOperand) -> Result<Version> {
+ Version version = Version::ANDROID_R;
+ auto result = validateSubgraphReference(subgraphs, condModelOperand);
+ if (!result.has_value()) {
+ return NN_ERROR() << std::move(result).error()
+ << "Operand is not a valid subgraph reference";
+ }
+ const uint32_t condModelInputCount = getInputCount(subgraphs, condModelOperand);
+ const uint32_t condModelOutputCount = getOutputCount(subgraphs, condModelOperand);
+ NN_VALIDATE_EQ(inputs.size(), op::kFirstInput + condModelInputCount);
+ NN_VALIDATE_EQ(condModelOutputCount, 1u);
+ for (uint32_t i = 0; i < condModelInputCount; ++i) {
+ const Operand& innerOperand = getInputOperand(subgraphs, condModelOperand, i);
+ const Operand& outerOperand = operands[inputs[op::kFirstInput + i]];
+ NN_TRY(compatible(innerOperand, outerOperand));
+ version = combineVersions(version,
+ NN_TRY(validateControlFlowOperandUnknownSize(innerOperand)));
+ version = combineVersions(version,
+ NN_TRY(validateControlFlowOperandUnknownSize(outerOperand)));
+ }
+ NN_TRY(validateConditionOperand(getOutputOperand(subgraphs, condModelOperand, 0)));
+ return version;
+ };
+ auto validateBodyOperand = [&](const Operand& bodyModelOperand) -> Result<Version> {
+ Version version = Version::ANDROID_R;
+ auto result = validateSubgraphReference(subgraphs, bodyModelOperand);
+ if (!result.has_value()) {
+ return NN_ERROR() << std::move(result).error()
+ << "Operand is not a valid subgraph reference";
+ }
+ const uint32_t bodyModelInputCount = getInputCount(subgraphs, bodyModelOperand);
+ const uint32_t bodyModelOutputCount = getOutputCount(subgraphs, bodyModelOperand);
+ NN_VALIDATE_EQ(inputs.size(), op::kFirstInput + bodyModelInputCount);
+ NN_VALIDATE_GE(bodyModelOutputCount, outputs.size());
+ NN_VALIDATE_GE(bodyModelInputCount, bodyModelOutputCount);
+ const uint32_t inputOutputCount = outputs.size();
+ const uint32_t stateOnlyCount = bodyModelOutputCount - inputOutputCount;
+ const uint32_t inputOnlyCount = bodyModelInputCount - bodyModelOutputCount;
+ for (uint32_t i = 0, n = inputOutputCount + stateOnlyCount + inputOnlyCount; i < n; ++i) {
+ const Operand& innerOperand = getInputOperand(subgraphs, bodyModelOperand, i);
+ const Operand& outerOperand = operands[inputs[op::kFirstInput + i]];
+ NN_TRY(compatible(innerOperand, outerOperand));
+ version = combineVersions(version,
+ NN_TRY(validateControlFlowOperandUnknownSize(innerOperand)));
+ version = combineVersions(version,
+ NN_TRY(validateControlFlowOperandUnknownSize(outerOperand)));
+ }
+ for (uint32_t i = 0; i < inputOutputCount; ++i) {
+ const Operand& innerOperand = getOutputOperand(subgraphs, bodyModelOperand, i);
+ const Operand& outerOperand = operands[outputs[i]];
+ NN_TRY(compatible(innerOperand, outerOperand));
+ version = combineVersions(version,
+ NN_TRY(validateControlFlowOperandUnknownSize(outerOperand)));
+ }
+ for (uint32_t i = 0, n = inputOutputCount + stateOnlyCount; i < n; ++i) {
+ const Operand& inputOperand = getInputOperand(subgraphs, bodyModelOperand, i);
+ const Operand& outputOperand = getOutputOperand(subgraphs, bodyModelOperand, i);
+ NN_TRY(compatible(inputOperand, outputOperand));
+ version = combineVersions(version,
+ NN_TRY(validateControlFlowOperandUnknownSize(outputOperand)));
+ }
+ return version;
+ };
+ auto result = validateCondOperand(operands[inputs[op::kCondModelOperand]]);
+ if (!result.has_value()) {
+ return NN_ERROR() << std::move(result).error()
+ << "Validation failed for WHILE condition model";
+ }
+ auto version = result.value();
+ result = validateBodyOperand(operands[inputs[op::kBodyModelOperand]]);
+ if (!result.has_value()) {
+ return NN_ERROR() << std::move(result).error() << "Validation failed for WHILE body model";
+ }
+ version = combineVersions(version, result.value());
+ return version;
+}
+
+Result<Version> validateOperationImpl(const Operation& operation,
+ const std::vector<Operand>& operands,
+ const std::vector<Model::Subgraph>& subgraphs) {
+ const auto opType = operation.type;
+ const auto& inputIndexes = operation.inputs;
+ const auto& outputIndexes = operation.outputs;
+
+ NN_TRY(validateOperandListImpl(inputIndexes, operands.size(),
+ "ANeuralNetworksModel_addOperation inputs"));
+ NN_TRY(validateOperandListImpl(outputIndexes, operands.size(),
+ "ANeuralNetworksModel_addOperation outputs"));
+
+ if (isExtension(opType)) {
+ // There is no other validation we can do for an extension operation.
+ return Version::ANDROID_Q;
+ }
+
+ auto invalidInOutNumberMessage = [opType, &inputIndexes, &outputIndexes](int expIn,
+ int expOut) {
+ std::ostringstream os;
+ os << "Invalid number of input operands (" << inputIndexes.size() << ", expected " << expIn
+ << ") or output operands (" << outputIndexes.size() << ", expected " << expOut
+ << ") for operation " << opType;
+ return os.str();
+ };
+
+ switch (opType) {
+ case OperationType::OEM_OPERATION: {
+ return Version::ANDROID_OC_MR1;
+ }
+ case OperationType::RESHAPE: {
+ NN_VALIDATE(inputIndexes.size() == 2 && outputIndexes.size() == 1)
+ << invalidInOutNumberMessage(2, 1);
+ auto inputType = operands[inputIndexes[0]].type;
+ Version version;
+ std::vector<OperandType> inExpectedTypes;
+ std::vector<OperandType> outExpectedTypes;
+ if (inputType == OperandType::TENSOR_FLOAT32) {
+ version = Version::ANDROID_OC_MR1;
+ inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_INT32};
+ outExpectedTypes = {OperandType::TENSOR_FLOAT32};
+ } else if (inputType == OperandType::TENSOR_FLOAT16) {
+ version = Version::ANDROID_Q;
+ inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::TENSOR_INT32};
+ outExpectedTypes = {OperandType::TENSOR_FLOAT16};
+ } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
+ version = Version::ANDROID_OC_MR1;
+ inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_INT32};
+ outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM};
+ } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
+ version = Version::ANDROID_R;
+ inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED,
+ OperandType::TENSOR_INT32};
+ outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED};
+ } else {
+ NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
+ }
+ const auto inputRank = operands[inputIndexes[0]].dimensions.size();
+ NN_VALIDATE_LE(inputRank, 4u)
+ << "Unsupported input tensor rank for operation " << opType;
+ NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
+ outputIndexes, outExpectedTypes));
+ return version;
+ }
+ case OperationType::DEPTH_TO_SPACE: {
+ NN_VALIDATE((inputIndexes.size() == 3 || inputIndexes.size() == 2) &&
+ outputIndexes.size() == 1)
+ << "Invalid number of input operands (" << inputIndexes.size()
+ << ", expected 3 or 2) or output operands (" << outputIndexes.size()
+ << ", expected 1) for operation " << opType;
+ auto inputType = operands[inputIndexes[0]].type;
+ Version version;
+ std::vector<OperandType> inExpectedTypes;
+ std::vector<OperandType> outExpectedTypes;
+ if (inputType == OperandType::TENSOR_FLOAT32) {
+ version = Version::ANDROID_OC_MR1;
+ inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::INT32};
+ outExpectedTypes = {OperandType::TENSOR_FLOAT32};
+ } else if (inputType == OperandType::TENSOR_FLOAT16) {
+ version = Version::ANDROID_Q;
+ inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::INT32};
+ outExpectedTypes = {OperandType::TENSOR_FLOAT16};
+ } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
+ version = Version::ANDROID_OC_MR1;
+ inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, OperandType::INT32};
+ outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM};
+ } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
+ version = Version::ANDROID_R;
+ inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED, OperandType::INT32};
+ outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED};
+ } else {
+ NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
+ }
+ if (inputIndexes.size() == 3) {
+ inExpectedTypes.push_back(OperandType::BOOL);
+ version = combineVersions(version, Version::ANDROID_Q);
+ } else {
+ version = combineVersions(version, Version::ANDROID_OC_MR1);
+ }
+ NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
+ outputIndexes, outExpectedTypes));
+ return version;
+ }
+ case OperationType::SPACE_TO_DEPTH: {
+ NN_VALIDATE((inputIndexes.size() == 3 || inputIndexes.size() == 2) &&
+ outputIndexes.size() == 1)
+ << "Invalid number of input operands (" << inputIndexes.size()
+ << ", expected 3 or 2) or output operands (" << outputIndexes.size()
+ << ", expected 1) for operation " << opType;
+ auto inputType = operands[inputIndexes[0]].type;
+ Version version;
+ std::vector<OperandType> inExpectedTypes;
+ std::vector<OperandType> outExpectedTypes;
+ if (inputType == OperandType::TENSOR_FLOAT32) {
+ version = Version::ANDROID_OC_MR1;
+ inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::INT32};
+ outExpectedTypes = {OperandType::TENSOR_FLOAT32};
+ } else if (inputType == OperandType::TENSOR_FLOAT16) {
+ version = Version::ANDROID_Q;
+ inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::INT32};
+ outExpectedTypes = {OperandType::TENSOR_FLOAT16};
+ } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
+ version = Version::ANDROID_OC_MR1;
+ inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, OperandType::INT32};
+ outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM};
+ } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
+ version = Version::ANDROID_R;
+ inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED, OperandType::INT32};
+ outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED};
+ } else {
+ NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
+ }
+ if (inputIndexes.size() == 3) {
+ inExpectedTypes.push_back(OperandType::BOOL);
+ version = combineVersions(version, Version::ANDROID_Q);
+ } else {
+ version = combineVersions(version, Version::ANDROID_OC_MR1);
+ }
+ NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
+ outputIndexes, outExpectedTypes));
+ return version;
+ }
+ case OperationType::EMBEDDING_LOOKUP: {
+ NN_VALIDATE(inputIndexes.size() == 2 && outputIndexes.size() == 1)
+ << invalidInOutNumberMessage(2, 1);
+ auto inputType = operands[inputIndexes[1]].type;
+ NN_VALIDATE(inputType == OperandType::TENSOR_FLOAT16 ||
+ inputType == OperandType::TENSOR_FLOAT32 ||
+ inputType == OperandType::TENSOR_INT32 ||
+ inputType == OperandType::TENSOR_QUANT8_ASYMM ||
+ inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED)
+ << "Unsupported input tensor type for operation " << opType;
+ Version version;
+ std::vector<OperandType> inExpectedTypes = {OperandType::TENSOR_INT32, inputType};
+ std::vector<OperandType> outExpectedTypes = {inputType};
+ if (inputType == OperandType::TENSOR_FLOAT16 ||
+ inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
+ version = Version::ANDROID_R;
+ } else if (inputType == OperandType::TENSOR_INT32 ||
+ inputType == OperandType::TENSOR_QUANT8_ASYMM) {
+ version = Version::ANDROID_Q;
+ } else {
+ version = Version::ANDROID_OC_MR1;
+ }
+ NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
+ outputIndexes, outExpectedTypes));
+ return version;
+ }
+ case OperationType::HASHTABLE_LOOKUP: {
+ NN_VALIDATE(inputIndexes.size() == 3 && outputIndexes.size() == 2)
+ << invalidInOutNumberMessage(3, 2);
+ auto inputType = operands[inputIndexes[2]].type;
+ NN_VALIDATE(inputType == OperandType::TENSOR_FLOAT32 ||
+ inputType == OperandType::TENSOR_INT32 ||
+ inputType == OperandType::TENSOR_QUANT8_ASYMM)
+ << "Unsupported input tensor type for operation " << opType;
+ std::vector<OperandType> inExpectedTypes = {OperandType::TENSOR_INT32,
+ OperandType::TENSOR_INT32, inputType};
+ std::vector<OperandType> outExpectedTypes = {inputType,
+ OperandType::TENSOR_QUANT8_ASYMM};
+ NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
+ outputIndexes, outExpectedTypes));
+ return Version::ANDROID_OC_MR1;
+ }
+ case OperationType::LSH_PROJECTION: {
+ NN_VALIDATE(inputIndexes.size() == 4 && outputIndexes.size() == 1)
+ << invalidInOutNumberMessage(4, 1);
+ auto inputType = operands[inputIndexes[1]].type;
+ NN_VALIDATE(inputType == OperandType::TENSOR_FLOAT16 ||
+ inputType == OperandType::TENSOR_FLOAT32 ||
+ inputType == OperandType::TENSOR_INT32 ||
+ inputType == OperandType::TENSOR_QUANT8_ASYMM)
+ << "Unsupported input tensor type for operation " << opType;
+ auto hashType = operands[inputIndexes[0]].type;
+ Version version;
+ std::vector<OperandType> inExpectedTypes;
+ if (hashType == OperandType::TENSOR_FLOAT16) {
+ version = Version::ANDROID_Q;
+ inExpectedTypes = {
+ OperandType::TENSOR_FLOAT16,
+ inputType,
+ OperandType::TENSOR_FLOAT16,
+ OperandType::INT32,
+ };
+ } else if (hashType == OperandType::TENSOR_FLOAT32) {
+ version = Version::ANDROID_OC_MR1;
+ inExpectedTypes = {
+ OperandType::TENSOR_FLOAT32,
+ inputType,
+ OperandType::TENSOR_FLOAT32,
+ OperandType::INT32,
+ };
+ } else {
+ NN_VALIDATE_FAIL() << "Unsupported hash tensor type for operation " << opType;
+ }
+ std::vector<OperandType> outExpectedTypes = {OperandType::TENSOR_INT32};
+ NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
+ outputIndexes, outExpectedTypes));
+ return version;
+ }
+ case OperationType::BIDIRECTIONAL_SEQUENCE_LSTM: {
+ const uint32_t kNumOutputs = 2;
+ const uint32_t kNumOutputsMerged = 1;
+ const uint32_t kNumOutputsWithState = 6;
+ const uint32_t kNumOutputsMergedWithState = 5;
+ NN_VALIDATE(inputIndexes.size() == 61 &&
+ (outputIndexes.size() == kNumOutputs ||
+ outputIndexes.size() == kNumOutputsMerged ||
+ outputIndexes.size() == kNumOutputsWithState ||
+ outputIndexes.size() == kNumOutputsMergedWithState))
+ << "Invalid number of input operands (" << inputIndexes.size()
+ << ", expected 61) or output operands (" << outputIndexes.size()
+ << ", expected 1, 2, 5 or 6) for operation " << opType;
+
+ std::vector<OperandType> inExpectedTypes;
+ auto inputType = operands[inputIndexes[0]].type;
+ NN_VALIDATE(inputType == OperandType::TENSOR_FLOAT32 ||
+ inputType == OperandType::TENSOR_FLOAT16)
+ << "Unsupported input tensor type for operation " << opType;
+
+ inExpectedTypes = {};
+ for (int i = 0; i < 48; ++i) {
+ inExpectedTypes.push_back(inputType);
+ }
+ inExpectedTypes.push_back(OperandType::INT32);
+ inExpectedTypes.push_back(inputType == OperandType::TENSOR_FLOAT32
+ ? OperandType::FLOAT32
+ : OperandType::FLOAT16);
+ inExpectedTypes.push_back(inputType == OperandType::TENSOR_FLOAT32
+ ? OperandType::FLOAT32
+ : OperandType::FLOAT16);
+ inExpectedTypes.push_back(OperandType::BOOL);
+ inExpectedTypes.push_back(OperandType::BOOL);
+ for (int i = 0; i < 8; ++i) {
+ inExpectedTypes.push_back(inputType);
+ }
+
+ Version version = Version::ANDROID_Q;
+ if (outputIndexes.size() == kNumOutputsWithState ||
+ outputIndexes.size() == kNumOutputsMergedWithState) {
+ version = Version::ANDROID_R;
+ }
+ std::vector<OperandType> outExpectedTypes(outputIndexes.size(), inputType);
+ NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
+ outputIndexes, outExpectedTypes));
+ return version;
+ }
+ case OperationType::LSTM: {
+ NN_VALIDATE((inputIndexes.size() == 23 || inputIndexes.size() == 27) &&
+ outputIndexes.size() == 4)
+ << "Invalid number of input operands (" << inputIndexes.size()
+ << ", expected 23 or 27) or output operands (" << outputIndexes.size()
+ << ", expected 4) for operation " << opType;
+ std::vector<OperandType> inExpectedTypes;
+ std::vector<OperandType> outExpectedTypes;
+ auto inputType = operands[inputIndexes[0]].type;
+ NN_VALIDATE(inputType == OperandType::TENSOR_FLOAT32 ||
+ inputType == OperandType::TENSOR_FLOAT16)
+ << "Unsupported input tensor type for operation " << opType;
+
+ Version version = Version::ANDROID_OC_MR1;
+ inExpectedTypes = {inputType, inputType, inputType, inputType, inputType,
+ inputType, inputType, inputType, inputType, inputType,
+ inputType, inputType, inputType, inputType, inputType,
+ inputType, inputType, inputType, inputType, inputType,
+ OperandType::INT32};
+ if (inputType == OperandType::TENSOR_FLOAT32) {
+ inExpectedTypes.push_back(OperandType::FLOAT32);
+ inExpectedTypes.push_back(OperandType::FLOAT32);
+ } else {
+ version = Version::ANDROID_Q;
+ inExpectedTypes.push_back(OperandType::FLOAT16);
+ inExpectedTypes.push_back(OperandType::FLOAT16);
+ }
+
+ outExpectedTypes = {inputType, inputType, inputType, inputType};
+ if (inputIndexes.size() == 23) {
+ version = combineVersions(version, Version::ANDROID_OC_MR1);
+ } else {
+ version = combineVersions(version, Version::ANDROID_Q);
+ for (int i = 0; i < 4; ++i) {
+ inExpectedTypes.push_back(inputType);
+ }
+ }
+ NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
+ outputIndexes, outExpectedTypes));
+ return version;
+ }
+ case OperationType::QUANTIZED_16BIT_LSTM: {
+ NN_VALIDATE(inputIndexes.size() == 15 && outputIndexes.size() == 2)
+ << invalidInOutNumberMessage(15, 2);
+ std::vector<OperandType> inExpectedTypes = {
+ OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_QUANT8_ASYMM,
+ OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_QUANT8_ASYMM,
+ OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_QUANT8_ASYMM,
+ OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_QUANT8_ASYMM,
+ OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_INT32,
+ OperandType::TENSOR_INT32, OperandType::TENSOR_INT32,
+ OperandType::TENSOR_INT32, OperandType::TENSOR_QUANT16_SYMM,
+ OperandType::TENSOR_QUANT8_ASYMM};
+ std::vector<OperandType> outExpectedTypes = {OperandType::TENSOR_QUANT16_SYMM,
+ OperandType::TENSOR_QUANT8_ASYMM};
+ NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
+ outputIndexes, outExpectedTypes));
+ return Version::ANDROID_Q;
+ }
+ case OperationType::RANDOM_MULTINOMIAL: {
+ NN_VALIDATE(inputIndexes.size() == 3 && outputIndexes.size() == 1)
+ << invalidInOutNumberMessage(3, 1);
+ OperandType inputType = operands[inputIndexes[0]].type;
+ std::vector<OperandType> inExpectedTypes;
+ if (inputType == OperandType::TENSOR_FLOAT32 ||
+ inputType == OperandType::TENSOR_FLOAT16) {
+ inExpectedTypes = {inputType, OperandType::INT32, OperandType::TENSOR_INT32};
+ } else {
+ NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
+ }
+ std::vector<OperandType> outExpectedTypes = {OperandType::TENSOR_INT32};
+ NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
+ outputIndexes, outExpectedTypes));
+ return Version::ANDROID_Q;
+ }
+ case OperationType::RNN: {
+ NN_VALIDATE(inputIndexes.size() == 6 && outputIndexes.size() == 2)
+ << invalidInOutNumberMessage(6, 2);
+ OperandType inputType = operands[inputIndexes[0]].type;
+ Version version;
+ std::vector<OperandType> inExpectedTypes;
+ std::vector<OperandType> outExpectedTypes;
+ if (inputType == OperandType::TENSOR_FLOAT32) {
+ version = Version::ANDROID_OC_MR1;
+ inExpectedTypes = {
+ OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
+ OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
+ OperandType::TENSOR_FLOAT32, OperandType::INT32,
+ };
+ outExpectedTypes = {
+ OperandType::TENSOR_FLOAT32,
+ OperandType::TENSOR_FLOAT32,
+ };
+ } else if (inputType == OperandType::TENSOR_FLOAT16) {
+ version = Version::ANDROID_Q;
+ inExpectedTypes = {
+ OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
+ OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
+ OperandType::TENSOR_FLOAT16, OperandType::INT32,
+ };
+ outExpectedTypes = {
+ OperandType::TENSOR_FLOAT16,
+ OperandType::TENSOR_FLOAT16,
+ };
+ } else {
+ NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
+ }
+ NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
+ outputIndexes, outExpectedTypes));
+ return version;
+ }
+ case OperationType::SVDF: {
+ NN_VALIDATE(inputIndexes.size() == 7 && outputIndexes.size() == 2)
+ << invalidInOutNumberMessage(7, 2);
+ Version version;
+ OperandType inputType = operands[inputIndexes[0]].type;
+ if (inputType == OperandType::TENSOR_FLOAT32) {
+ version = Version::ANDROID_OC_MR1;
+ } else if (inputType == OperandType::TENSOR_FLOAT16) {
+ version = Version::ANDROID_Q;
+ } else {
+ NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
+ }
+ std::vector<OperandType> inExpectedTypes = {
+ inputType, inputType, inputType, inputType,
+ inputType, OperandType::INT32, OperandType::INT32,
+ };
+ std::vector<OperandType> outExpectedTypes = {inputType, inputType};
+ NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
+ outputIndexes, outExpectedTypes));
+ return version;
+ }
+ case OperationType::BATCH_TO_SPACE_ND: {
+ NN_VALIDATE((inputIndexes.size() == 3 || inputIndexes.size() == 2) &&
+ outputIndexes.size() == 1)
+ << "Invalid number of input operands (" << inputIndexes.size()
+ << ", expected 3 or 2) or output operands (" << outputIndexes.size()
+ << ", expected 1) for operation " << opType;
+ auto inputType = operands[inputIndexes[0]].type;
+ Version version = Version::ANDROID_OC_MR1;
+ std::vector<OperandType> inExpectedTypes;
+ std::vector<OperandType> outExpectedTypes;
+ if (inputType == OperandType::TENSOR_FLOAT32) {
+ inExpectedTypes = {
+ OperandType::TENSOR_FLOAT32,
+ OperandType::TENSOR_INT32,
+ };
+ outExpectedTypes = {OperandType::TENSOR_FLOAT32};
+ } else if (inputType == OperandType::TENSOR_FLOAT16) {
+ version = Version::ANDROID_Q;
+ inExpectedTypes = {
+ OperandType::TENSOR_FLOAT16,
+ OperandType::TENSOR_INT32,
+ };
+ outExpectedTypes = {OperandType::TENSOR_FLOAT16};
+ } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
+ inExpectedTypes = {
+ OperandType::TENSOR_QUANT8_ASYMM,
+ OperandType::TENSOR_INT32,
+ };
+ outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM};
+ } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
+ version = Version::ANDROID_R;
+ inExpectedTypes = {
+ OperandType::TENSOR_QUANT8_ASYMM_SIGNED,
+ OperandType::TENSOR_INT32,
+ };
+ outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED};
+ } else {
+ NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
+ }
+ if (inputIndexes.size() == 3) {
+ inExpectedTypes.push_back(OperandType::BOOL);
+ version = combineVersions(version, Version::ANDROID_Q);
+ } else {
+ version = combineVersions(version, Version::ANDROID_P);
+ }
+ NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
+ outputIndexes, outExpectedTypes));
+ return version;
+ }
+ case OperationType::SPACE_TO_BATCH_ND: {
+ NN_VALIDATE((inputIndexes.size() == 4 || inputIndexes.size() == 3) &&
+ outputIndexes.size() == 1)
+ << "Invalid number of input operands (" << inputIndexes.size()
+ << ", expected 4 or 3) or output operands (" << outputIndexes.size()
+ << ", expected 1) for operation " << opType;
+ auto inputType = operands[inputIndexes[0]].type;
+ Version version = Version::ANDROID_OC_MR1;
+ std::vector<OperandType> inExpectedTypes;
+ std::vector<OperandType> outExpectedTypes;
+ if (inputType == OperandType::TENSOR_FLOAT32) {
+ inExpectedTypes = {
+ OperandType::TENSOR_FLOAT32,
+ OperandType::TENSOR_INT32,
+ OperandType::TENSOR_INT32,
+ };
+ outExpectedTypes = {OperandType::TENSOR_FLOAT32};
+ } else if (inputType == OperandType::TENSOR_FLOAT16) {
+ version = Version::ANDROID_Q;
+ inExpectedTypes = {
+ OperandType::TENSOR_FLOAT16,
+ OperandType::TENSOR_INT32,
+ OperandType::TENSOR_INT32,
+ };
+ outExpectedTypes = {OperandType::TENSOR_FLOAT16};
+ } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
+ if (operands[inputIndexes[0]].zeroPoint != 0) {
+ version = Version::ANDROID_Q;
+ }
+ inExpectedTypes = {
+ OperandType::TENSOR_QUANT8_ASYMM,
+ OperandType::TENSOR_INT32,
+ OperandType::TENSOR_INT32,
+ };
+ outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM};
+ } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
+ version = Version::ANDROID_R;
+ inExpectedTypes = {
+ OperandType::TENSOR_QUANT8_ASYMM_SIGNED,
+ OperandType::TENSOR_INT32,
+ OperandType::TENSOR_INT32,
+ };
+ outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED};
+ } else {
+ NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
+ }
+ if (inputIndexes.size() == 4) {
+ inExpectedTypes.push_back(OperandType::BOOL);
+ version = combineVersions(version, Version::ANDROID_Q);
+ } else {
+ version = combineVersions(version, Version::ANDROID_P);
+ }
+ NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
+ outputIndexes, outExpectedTypes));
+ return version;
+ }
+ case OperationType::PAD: {
+ NN_VALIDATE(inputIndexes.size() == 2 && outputIndexes.size() == 1)
+ << invalidInOutNumberMessage(2, 1);
+ auto inputType = operands[inputIndexes[0]].type;
+ Version version;
+ std::vector<OperandType> inExpectedTypes;
+ std::vector<OperandType> outExpectedTypes;
+ if (inputType == OperandType::TENSOR_FLOAT32) {
+ version = Version::ANDROID_P;
+ inExpectedTypes = {
+ OperandType::TENSOR_FLOAT32,
+ OperandType::TENSOR_INT32,
+ };
+ outExpectedTypes = {OperandType::TENSOR_FLOAT32};
+ } else if (inputType == OperandType::TENSOR_FLOAT16) {
+ version = Version::ANDROID_Q;
+ inExpectedTypes = {
+ OperandType::TENSOR_FLOAT16,
+ OperandType::TENSOR_INT32,
+ };
+ outExpectedTypes = {OperandType::TENSOR_FLOAT16};
+ } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM ||
+ inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
+ if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
+ version = Version::ANDROID_R;
+ } else {
+ if (operands[inputIndexes[0]].zeroPoint == 0) {
+ version = Version::ANDROID_P;
+ } else {
+ version = Version::ANDROID_Q;
+ }
+ }
+ inExpectedTypes = {
+ inputType,
+ OperandType::TENSOR_INT32,
+ };
+ outExpectedTypes = {inputType};
+ } else {
+ NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
+ }
+ const auto inputRank = operands[inputIndexes[0]].dimensions.size();
+ NN_VALIDATE_LE(inputRank, 4u)
+ << "Unsupported input tensor rank for operation " << opType;
+ NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
+ outputIndexes, outExpectedTypes));
+ return version;
+ }
+ case OperationType::PAD_V2: {
+ NN_VALIDATE(inputIndexes.size() == 3 && outputIndexes.size() == 1)
+ << invalidInOutNumberMessage(3, 1);
+ auto inputType = operands[inputIndexes[0]].type;
+ Version version;
+ std::vector<OperandType> inExpectedTypes;
+ std::vector<OperandType> outExpectedTypes;
+ if (inputType == OperandType::TENSOR_FLOAT32) {
+ version = Version::ANDROID_Q;
+ inExpectedTypes = {
+ OperandType::TENSOR_FLOAT32,
+ OperandType::TENSOR_INT32,
+ OperandType::FLOAT32,
+ };
+ outExpectedTypes = {OperandType::TENSOR_FLOAT32};
+ } else if (inputType == OperandType::TENSOR_FLOAT16) {
+ version = Version::ANDROID_Q;
+ inExpectedTypes = {
+ OperandType::TENSOR_FLOAT16,
+ OperandType::TENSOR_INT32,
+ OperandType::FLOAT16,
+ };
+ outExpectedTypes = {OperandType::TENSOR_FLOAT16};
+ } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM ||
+ inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
+ if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
+ version = Version::ANDROID_R;
+ } else {
+ version = Version::ANDROID_Q;
+ }
+ inExpectedTypes = {
+ inputType,
+ OperandType::TENSOR_INT32,
+ OperandType::INT32,
+ }; // TODO(b/116699425): Make it UINT8.
+ outExpectedTypes = {inputType};
+ } else {
+ NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
+ }
+ const auto inputRank = operands[inputIndexes[0]].dimensions.size();
+ NN_VALIDATE_LE(inputRank, 4u)
+ << "Unsupported input tensor rank for operation " << opType;
+ NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
+ outputIndexes, outExpectedTypes));
+ return version;
+ }
+ case OperationType::CAST: {
+ NN_VALIDATE(inputIndexes.size() == 1 && outputIndexes.size() == 1)
+ << invalidInOutNumberMessage(1, 1);
+ auto inputOperand = operands[inputIndexes[0]];
+ auto outputOperand = operands[outputIndexes[0]];
+ auto inputType = inputOperand.type;
+ auto outputType = outputOperand.type;
+ Version version;
+ std::vector<OperandType> inExpectedTypes;
+ std::vector<OperandType> outExpectedTypes;
+ if ((inputType == OperandType::TENSOR_FLOAT16 ||
+ inputType == OperandType::TENSOR_FLOAT32 ||
+ inputType == OperandType::TENSOR_INT32 ||
+ inputType == OperandType::TENSOR_QUANT8_ASYMM) &&
+ (outputType == OperandType::TENSOR_FLOAT16 ||
+ outputType == OperandType::TENSOR_FLOAT32 ||
+ outputType == OperandType::TENSOR_INT32 ||
+ outputType == OperandType::TENSOR_QUANT8_ASYMM)) {
+ version = Version::ANDROID_Q;
+ inExpectedTypes = {inputType};
+ outExpectedTypes = {outputType};
+ } else if (inputType == OperandType::TENSOR_BOOL8 ||
+ inputType == OperandType::TENSOR_QUANT16_ASYMM ||
+ inputType == OperandType::TENSOR_QUANT16_SYMM ||
+ inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED ||
+ inputType == OperandType::TENSOR_QUANT8_SYMM) {
+ version = Version::ANDROID_R;
+ inExpectedTypes = {inputType};
+ outExpectedTypes = {inputType}; // Only identity CAST is supported.
+ } else {
+ NN_VALIDATE_FAIL() << "Unsupported data type for operation " << opType;
+ }
+ // Validate that output shape is equal to input shape if dimensions
+ // are already known.
+ auto getNumberOfElements = [](const std::vector<uint32_t>& dims) {
+ if (dims.empty()) {
+ return 0;
+ }
+ return std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<>());
+ };
+ NN_VALIDATE(inputOperand.dimensions.empty() || outputOperand.dimensions.empty() ||
+ getNumberOfElements(outputOperand.dimensions) == 0 ||
+ inputOperand.dimensions == outputOperand.dimensions);
+ NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
+ outputIndexes, outExpectedTypes));
+ return version;
+ }
+ case OperationType::MEAN: {
+ NN_VALIDATE(inputIndexes.size() == 3 && outputIndexes.size() == 1)
+ << invalidInOutNumberMessage(3, 1);
+ const auto inputRank = operands[inputIndexes[0]].dimensions.size();
+ NN_VALIDATE_LE(inputRank, 4u)
+ << "Unsupported input tensor rank for operation " << opType;
+ auto inputType = operands[inputIndexes[0]].type;
+ Version version;
+ if (inputType == OperandType::TENSOR_FLOAT32 ||
+ inputType == OperandType::TENSOR_QUANT8_ASYMM) {
+ version = Version::ANDROID_P;
+ } else if (inputType == OperandType::TENSOR_FLOAT16) {
+ version = Version::ANDROID_Q;
+ } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
+ version = Version::ANDROID_R;
+ } else {
+ NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
+ }
+ std::vector<OperandType> inExpectedTypes = {inputType, OperandType::TENSOR_INT32,
+ OperandType::INT32};
+ std::vector<OperandType> outExpectedTypes = {inputType};
+ NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
+ outputIndexes, outExpectedTypes));
+ return version;
+ }
+ case OperationType::ARGMAX:
+ case OperationType::ARGMIN: {
+ NN_VALIDATE(inputIndexes.size() == 2 && outputIndexes.size() == 1)
+ << invalidInOutNumberMessage(2, 1);
+ auto inputType = operands[inputIndexes[0]].type;
+ std::vector<OperandType> inExpectedTypes;
+ std::vector<OperandType> outExpectedTypes;
+ if (inputType == OperandType::TENSOR_FLOAT16 ||
+ inputType == OperandType::TENSOR_FLOAT32 ||
+ inputType == OperandType::TENSOR_INT32 ||
+ inputType == OperandType::TENSOR_QUANT8_ASYMM ||
+ inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
+ inExpectedTypes = {inputType, OperandType::INT32};
+ outExpectedTypes = {OperandType::TENSOR_INT32};
+ } else {
+ NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
+ }
+ NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
+ outputIndexes, outExpectedTypes));
+ return Version::ANDROID_Q;
+ }
+ case OperationType::EXPAND_DIMS: {
+ NN_VALIDATE(inputIndexes.size() == 2 && outputIndexes.size() == 1)
+ << invalidInOutNumberMessage(2, 1);
+ auto inputType = operands[inputIndexes[0]].type;
+ std::vector<OperandType> inExpectedTypes;
+ std::vector<OperandType> outExpectedTypes;
+ if (inputType == OperandType::TENSOR_FLOAT16 ||
+ inputType == OperandType::TENSOR_FLOAT32 ||
+ inputType == OperandType::TENSOR_INT32 ||
+ inputType == OperandType::TENSOR_QUANT8_ASYMM ||
+ inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
+ inExpectedTypes = {inputType, OperandType::INT32};
+ outExpectedTypes = {inputType};
+ } else {
+ NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
+ }
+ Version version;
+ if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
+ version = Version::ANDROID_R;
+ } else {
+ version = Version::ANDROID_Q;
+ }
+ NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
+ outputIndexes, outExpectedTypes));
+ return version;
+ }
+ case OperationType::SPLIT: {
+ NN_VALIDATE_EQ(inputIndexes.size(), 3u)
+ << "Invalid number of input operands (" << inputIndexes.size()
+ << ", expected 3)" << opType;
+ auto inputType = operands[inputIndexes[0]].type;
+ NN_VALIDATE(inputType == OperandType::TENSOR_FLOAT16 ||
+ inputType == OperandType::TENSOR_FLOAT32 ||
+ inputType == OperandType::TENSOR_INT32 ||
+ inputType == OperandType::TENSOR_QUANT8_ASYMM ||
+ inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED)
+ << "Unsupported input tensor type for operation " << opType;
+ Version version;
+ if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
+ version = Version::ANDROID_R;
+ } else {
+ version = Version::ANDROID_Q;
+ }
+ std::vector<OperandType> inExpectedTypes = {inputType, OperandType::INT32,
+ OperandType::INT32};
+ std::vector<OperandType> outExpectedTypes(outputIndexes.size(), inputType);
+ NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
+ outputIndexes, outExpectedTypes));
+ return version;
+ }
+ case OperationType::MAXIMUM:
+ case OperationType::MINIMUM: {
+ NN_VALIDATE(inputIndexes.size() == 2 && outputIndexes.size() == 1)
+ << invalidInOutNumberMessage(2, 1);
+ std::vector<OperandType> inExpectedTypes;
+ std::vector<OperandType> outExpectedTypes;
+ OperandType inputType = operands[inputIndexes[0]].type;
+ if (inputType == OperandType::TENSOR_FLOAT16 ||
+ inputType == OperandType::TENSOR_FLOAT32 ||
+ inputType == OperandType::TENSOR_INT32 ||
+ inputType == OperandType::TENSOR_QUANT8_ASYMM ||
+ inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
+ inExpectedTypes = {inputType, inputType};
+ outExpectedTypes = {inputType};
+ } else {
+ NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
+ }
+ Version version;
+ if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
+ version = Version::ANDROID_R;
+ } else {
+ version = Version::ANDROID_Q;
+ }
+ NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
+ outputIndexes, outExpectedTypes));
+ return version;
+ }
+ case OperationType::GROUPED_CONV_2D: {
+ NN_VALIDATE((inputIndexes.size() == 12 || inputIndexes.size() == 9) &&
+ outputIndexes.size() == 1)
+ << "Invalid number of input operands (" << inputIndexes.size()
+ << ", expected 12 or 9) or output operands (" << outputIndexes.size()
+ << ", expected 1) for operation " << opType;
+ auto inputType = operands[inputIndexes[0]].type;
+ auto filterType = operands[inputIndexes[1]].type;
+ std::vector<OperandType> inExpectedTypes;
+ std::vector<OperandType> outExpectedTypes;
+ if (inputType == OperandType::TENSOR_FLOAT32) {
+ inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
+ OperandType::TENSOR_FLOAT32, OperandType::INT32,
+ OperandType::INT32, OperandType::INT32,
+ OperandType::INT32, OperandType::INT32};
+ outExpectedTypes = {OperandType::TENSOR_FLOAT32};
+ } else if (inputType == OperandType::TENSOR_FLOAT16) {
+ inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
+ OperandType::TENSOR_FLOAT16, OperandType::INT32,
+ OperandType::INT32, OperandType::INT32,
+ OperandType::INT32, OperandType::INT32};
+ outExpectedTypes = {OperandType::TENSOR_FLOAT16};
+ } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM ||
+ inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
+ NN_VALIDATE(filterType == inputType ||
+ filterType == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL)
+ << "Unsupported filter tensor type for operation " << opType;
+
+ NN_VALIDATE(filterType != OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL ||
+ std::get<Operand::SymmPerChannelQuantParams>(
+ operands[inputIndexes[1]].extraParams)
+ .channelDim == 0)
+ << "Unsupported filter tensor channel dimension for operation " << opType;
+
+ inExpectedTypes = {
+ inputType, filterType, OperandType::TENSOR_INT32,
+ OperandType::INT32, OperandType::INT32, OperandType::INT32,
+ OperandType::INT32, OperandType::INT32};
+ outExpectedTypes = {inputType};
+ } else {
+ NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
+ }
+
+ if (inputIndexes.size() == 12) {
+ std::vector<OperandType> explicitScalarTypes(3, OperandType::INT32);
+ inExpectedTypes.insert(inExpectedTypes.end(), explicitScalarTypes.begin(),
+ explicitScalarTypes.end());
+ }
+ inExpectedTypes.push_back(OperandType::BOOL);
+ Version version;
+ if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
+ version = Version::ANDROID_R;
+ } else {
+ version = Version::ANDROID_Q;
+ }
+ NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
+ outputIndexes, outExpectedTypes));
+ return version;
+ }
+ case OperationType::TILE: {
+ NN_VALIDATE(inputIndexes.size() == 2 && outputIndexes.size() == 1)
+ << invalidInOutNumberMessage(2, 1);
+ auto inputType = operands[inputIndexes[0]].type;
+ std::vector<OperandType> inExpectedTypes;
+ std::vector<OperandType> outExpectedTypes;
+ if (inputType == OperandType::TENSOR_FLOAT16 ||
+ inputType == OperandType::TENSOR_FLOAT32 ||
+ inputType == OperandType::TENSOR_INT32 ||
+ inputType == OperandType::TENSOR_QUANT8_ASYMM ||
+ inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
+ inExpectedTypes = {inputType, OperandType::TENSOR_INT32};
+ outExpectedTypes = {inputType};
+ } else {
+ NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
+ }
+ Version version;
+ if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
+ version = Version::ANDROID_R;
+ } else {
+ version = Version::ANDROID_Q;
+ }
+ NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
+ outputIndexes, outExpectedTypes));
+ return version;
+ }
+ case OperationType::POW: {
+ NN_VALIDATE(inputIndexes.size() == 2 && outputIndexes.size() == 1)
+ << invalidInOutNumberMessage(2, 1);
+ auto inputType = operands[inputIndexes[0]].type;
+ std::vector<OperandType> inExpectedTypes;
+ std::vector<OperandType> outExpectedTypes;
+ if (inputType == OperandType::TENSOR_FLOAT16 ||
+ inputType == OperandType::TENSOR_FLOAT32) {
+ inExpectedTypes = {inputType, inputType};
+ outExpectedTypes = {inputType};
+ } else {
+ NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
+ }
+ Version version;
+ if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
+ version = Version::ANDROID_R;
+ } else {
+ version = Version::ANDROID_Q;
+ }
+ NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
+ outputIndexes, outExpectedTypes));
+ return version;
+ }
+ case OperationType::IF: {
+ return validateIfOperation(inputIndexes, outputIndexes, operands, subgraphs);
+ }
+ case OperationType::WHILE: {
+ return validateWhileOperation(inputIndexes, outputIndexes, operands, subgraphs);
+ }
+ default: {
+ // TODO: Enable this block of code once canonical types are integrated in the rest of
+ // the NNAPI codebase.
+#if 0
+ const OperationRegistration* operationRegistration =
+ BuiltinOperationResolver::get()->findOperation(
+ static_cast<OperationType>(opType));
+ // TODO: return ErrorStatus::UNEXPECTED_NULL
+ NN_VALIDATE(operationRegistration != nullptr) << opType << " not registered";
+ // TODO: return ErrorStatus::UNEXPECTED_NULL
+ NN_VALIDATE(operationRegistration->validate != nullptr)
+ << "Incomplete operation registration: " << opType;
+ OperationValidationContext context(operationRegistration->name, inputIndexes,
+ outputIndexes, operands);
+ auto result = operationRegistration->validate(&context);
+ if (!result.has_value()) {
+ return NN_ERROR() << "Validation failed for operation " << opType << ": "
+ << std::move(result).error();
+ }
+ return result;
+#endif
+ NN_VALIDATE_FAIL() << "Validation for " << opType << " is not yet implemented";
+ }
+ }
+}
+
+} // anonymous namespace
+
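+// The combined version is simply the newer of the two: an object requires the highest
+// version required by any of its parts.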
+Version combineVersions(Version lhs, Version rhs) {
+ return std::max<Version>(lhs, rhs);
+}
+
+Result<Version> validate(const DeviceStatus& deviceStatus) {
+ return validateDeviceStatus(deviceStatus);
+}
+
+Result<Version> validate(const ExecutionPreference& executionPreference) {
+ return validateExecutionPreference(executionPreference);
+}
+
+Result<Version> validate(const DeviceType& deviceType) {
+ return validateDeviceType(deviceType);
+}
+
+Result<Version> validate(const MeasureTiming& measureTiming) {
+ return validateMeasureTiming(measureTiming);
+}
+
+Result<Version> validate(const Priority& priority) {
+ return validatePriority(priority);
+}
+
+Result<Version> validate(const ErrorStatus& errorStatus) {
+ return validateErrorStatus(errorStatus);
+}
+
+Result<Version> validate(const OutputShape& outputShape) {
+ return validateOutputShape(outputShape);
+}
+
+Result<Version> validate(const Timing& timing) {
+ return validateTiming(timing);
+}
+
+Result<Version> validate(const Capabilities& capabilities) {
+ return validateCapabilities(capabilities);
+}
+
+Result<Version> validate(const Extension& extension) {
+ return validateExtension(extension);
+}
+
+Result<Version> validate(const NativeHandle& handle) {
+ return validateNativeHandle(handle);
+}
+
+Result<Version> validate(const Memory& memory) {
+ return validateMemory(memory);
+}
+
+Result<Version> validate(const Model& model) {
+ return validateModel(model);
+}
+
+Result<Version> validate(const BufferDesc& bufferDesc) {
+ return validateBufferDesc(bufferDesc);
+}
+
+Result<Version> validate(const BufferRole& bufferRole) {
+ return validateBufferRole(bufferRole);
+}
+
+Result<Version> validate(const Request& request) {
+ return validateRequest(request);
+}
+
+Result<Version> validate(const OptionalTimePoint& optionalTimePoint) {
+ return validateOptionalTimePoint(optionalTimePoint);
+}
+
+Result<Version> validate(const OptionalTimeoutDuration& optionalTimeoutDuration) {
+ return validateOptionalTimeoutDuration(optionalTimeoutDuration);
+}
+
+Result<Version> validate(const std::vector<OutputShape>& outputShapes) {
+ return validateVector(outputShapes, validateOutputShape);
+}
+
+Result<Version> validate(const std::vector<Extension>& extensions) {
+ return validateExtensions(extensions);
+}
+
+Result<Version> validate(const std::vector<NativeHandle>& handles) {
+ return validateVector(handles, validateNativeHandle);
+}
+
+Result<Version> validate(const std::vector<BufferRole>& bufferRoles) {
+ return validateVector(bufferRoles, validateBufferRole);
+}
+
+Result<Version> validateRequestForModel(const Request& request, const Model& model) {
+ return validateRequestForModelImpl(request, model);
+}
+
+Result<Version> validateMemoryDesc(
+ const BufferDesc& desc,
+ const std::vector<std::shared_ptr<const IPreparedModel>>& preparedModels,
+ const std::vector<BufferRole>& inputRoles, const std::vector<BufferRole>& outputRoles,
+ const std::function<const Model*(const std::shared_ptr<const IPreparedModel>&)>& getModel,
+ std::set<PreparedModelRole>* preparedModelRoles, Operand* combinedOperand) {
+ return validateMemoryDescImpl(desc, preparedModels, inputRoles, outputRoles, getModel,
+ preparedModelRoles, combinedOperand);
+}
+
+Result<void> validateOperandSymmPerChannelQuantParams(
+ const Operand& operand, const Operand::SymmPerChannelQuantParams& channelQuant,
+ const char* tag) {
+ return validateOperandSymmPerChannelQuantParamsImpl(operand, channelQuant, tag);
+}
+
+Result<void> validateOperandType(const Operand& type,
+ const Extension::OperandTypeInformation* extensionOperandTypeInfo,
+ const char* tag, bool allowPartial) {
+ return validateOperandTypeImpl(type, extensionOperandTypeInfo, tag, allowPartial);
+}
+
+Result<void> validateOperandList(const std::vector<uint32_t>& list, size_t operandCount,
+ const char* tag) {
+ return validateOperandListImpl(list, operandCount, tag);
+}
+
+Result<Version> validateOperation(const Operation& operation, const std::vector<Operand>& operands,
+ const std::vector<Model::Subgraph>& subgraphs) {
+ return validateOperationImpl(operation, operands, subgraphs);
+}
+
+} // namespace android::nn
diff --git a/nn/common/include/nnapi/Result.h b/nn/common/include/nnapi/Result.h
new file mode 100644
index 000000000..e232cef3f
--- /dev/null
+++ b/nn/common/include/nnapi/Result.h
@@ -0,0 +1,147 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_NNAPI_RESULT_H
+#define ANDROID_FRAMEWORKS_ML_NN_COMMON_NNAPI_RESULT_H
+
+#include <android-base/expected.h>
+
+#include <optional>
+#include <sstream>
+#include <string>
+#include <utility>
+
+namespace android::nn {
+
+/**
+ * Type alias for `::android::base::expected` where the unexpected state is represented by a
+ * std::string describing the error.
+ *
+ * See the following file for more information on ::android::base::expected:
+ * system/libbase/include/android-base/expected.h
+ */
+template <typename Type>
+using Result = base::expected<Type, std::string>;
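+
+// Usage sketch; `parsePositive` is a hypothetical example:
+//
+//     Result<int> parsePositive(int value) {
+//         if (value <= 0) return NN_ERROR() << value << " is not positive";
+//         return value;
+//     }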
+
+namespace detail {
+
+class ErrorBuilder {
+ public:
+ template <typename T, typename E>
+ operator base::expected<T, E>() const /* NOLINT(google-explicit-constructor) */ {
+ return base::unexpected<E>(std::move(mStream).str());
+ }
+
+ template <typename T>
+ ErrorBuilder operator<<(const T& t) {
+ mStream << t;
+ return std::move(*this);
+ }
+
+ private:
+ std::ostringstream mStream;
+};
+
+} // namespace detail
+
+/**
+ * Creates an error builder for the case where no arguments are provided.
+ */
+inline detail::ErrorBuilder error() {
+ return detail::ErrorBuilder();
+}
+
+/**
+ * Helper macro that will create an error builder already populated with the file name and line
+ * number.
+ *
+ * This macro uses the following customization points:
+ * * `::android::nn::error` is a set of functions that can be customized to return a specialized
+ * error builder object. Customization is based on the types of arguments passed and the number
+ * of arguments passed to `error`.
+ *
+ * Usage at error site:
+ * if (errorDetected) {
+ * return NN_ERROR() << "<error_message>";
+ * }
+ * return <regular_return_value>;
+ */
+#define NN_ERROR(...) \
+ [] { \
+ using ::android::nn::error; \
+ return error(__VA_ARGS__) << __FILE__ << ":" << __LINE__ << ": "; \
+ }()
+
+template <typename T, typename E>
+bool nnTryHasValue(const base::expected<T, E>& o) {
+ return o.has_value();
+}
+
+template <typename T, typename E>
+T nnTryGetValue(base::expected<T, E> o) {
+ return std::move(o).value();
+}
+
+template <typename T, typename E>
+base::unexpected<E> nnTryGetError(base::expected<T, E> o) {
+ return base::unexpected(std::move(o).error());
+}
+
+template <typename T>
+bool nnTryHasValue(const std::optional<T>& o) {
+ return o.has_value();
+}
+
+template <typename T>
+T nnTryGetValue(std::optional<T> o) {
+ return std::move(o).value();
+}
+
+template <typename T>
+std::nullopt_t nnTryGetError(std::optional<T> /*o*/) {
+ return std::nullopt;
+}
+
+/**
+ * A macro that returns from the current function with the error value if `expr` holds an
+ * error, and otherwise yields the successful value of `expr` as the result of the macro.
+ *
+ * This macro can currently be used on `::android::nn::Result`, `::android::base::expected`, or
+ * `std::optional` values. To enable this macro to be used with other values, implement the
+ * following functions for the type:
+ * * `::android::nn::nnTryHasValue` returns `true` if the `expr` holds a successful value, false if
+ * the `expr` value holds an error
+ * * `::android::nn::nnTryGetValue` returns the successful value of `expr` or crashes
+ * * `::android::nn::nnTryGetError` returns the error value of `expr` or crashes
+ *
+ * Usage at call site:
+ * const auto [a, b, c] = NN_TRY(failableFunction(args));
+ */
+#define NN_TRY(expr) \
+ ({ \
+ using ::android::nn::nnTryHasValue; \
+ using ::android::nn::nnTryGetValue; \
+ using ::android::nn::nnTryGetError; \
+ auto nnTryTemporaryResult = expr; \
+ if (!nnTryHasValue(nnTryTemporaryResult)) { \
+ return nnTryGetError(std::move(nnTryTemporaryResult)); \
+ } \
+ nnTryGetValue(std::move(nnTryTemporaryResult)); \
+ })
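+
+// A minimal sketch of these customization points for a hypothetical MaybeInt type
+// (illustrative only):
+//
+//     struct MaybeInt { bool ok; int value; std::string error; };
+//     bool nnTryHasValue(const MaybeInt& o) { return o.ok; }
+//     int nnTryGetValue(MaybeInt o) { return o.value; }
+//     std::string nnTryGetError(MaybeInt o) { return std::move(o.error); }
+//
+// Note that the value returned by nnTryGetError must be convertible to the return
+// type of any function that applies NN_TRY to a MaybeInt.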
+
+} // namespace android::nn
+
+#endif // ANDROID_FRAMEWORKS_ML_NN_COMMON_NNAPI_RESULT_H
diff --git a/nn/common/include/nnapi/SharedMemory.h b/nn/common/include/nnapi/SharedMemory.h
new file mode 100644
index 000000000..ee8330f45
--- /dev/null
+++ b/nn/common/include/nnapi/SharedMemory.h
@@ -0,0 +1,93 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_NNAPI_SHARED_MEMORY_H
+#define ANDROID_FRAMEWORKS_ML_NN_COMMON_NNAPI_SHARED_MEMORY_H
+
+#include <any>
+#include <optional>
+#include <string>
+#include <variant>
+#include <vector>
+
+#include "nnapi/Result.h"
+#include "nnapi/Types.h"
+
+// Forward declare AHardwareBuffer
+extern "C" typedef struct AHardwareBuffer AHardwareBuffer;
+
+// Forward declare hidl_memory
+namespace android::hardware {
+struct hidl_memory;
+} // namespace android::hardware
+
+namespace android::nn {
+
+class MutableMemoryBuilder {
+ public:
+ explicit MutableMemoryBuilder(uint32_t poolIndex);
+
+ DataLocation append(size_t length);
+ bool empty() const;
+
+ Result<Memory> finish();
+
+ private:
+ uint32_t mPoolIndex;
+ size_t mSize = 0;
+};
+
+class ConstantMemoryBuilder {
+ public:
+ explicit ConstantMemoryBuilder(uint32_t poolIndex);
+
+ DataLocation append(const void* data, size_t length);
+ bool empty() const;
+
+ Result<Memory> finish();
+
+ private:
+ struct LazyCopy {
+ const void* data;
+ size_t length;
+ size_t offset;
+ };
+
+ MutableMemoryBuilder mBuilder;
+ std::vector<LazyCopy> mSlices;
+};
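+
+// Usage sketch; kPoolIndex is a hypothetical request pool index:
+//
+//     MutableMemoryBuilder builder(kPoolIndex);
+//     const DataLocation loc = builder.append(4 * sizeof(float));
+//     Result<Memory> memory = builder.finish();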
+
+Result<Memory> createSharedMemory(size_t size);
+
+Result<Memory> createSharedMemoryFromFd(size_t size, int prot, int fd, size_t offset);
+
+Result<Memory> createSharedMemoryFromHidlMemory(const hardware::hidl_memory& memory);
+
+Result<Memory> createSharedMemoryFromAHWB(const AHardwareBuffer& ahwb);
+
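+// A mapped view of a Memory object: `pointer` holds either a mutable or an immutable
+// pointer to the mapped data, `size` is the size of the mapping in bytes, and
+// `context` holds any state needed to keep the mapping alive.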
+struct Mapping {
+ std::variant<void*, const void*> pointer;
+ size_t size;
+ std::any context;
+};
+
+Result<Mapping> map(const Memory& memory);
+
+bool flush(const Mapping& mapping);
+
+} // namespace android::nn
+
+#endif // ANDROID_FRAMEWORKS_ML_NN_COMMON_NNAPI_SHARED_MEMORY_H
diff --git a/nn/common/include/nnapi/TypeUtils.h b/nn/common/include/nnapi/TypeUtils.h
new file mode 100644
index 000000000..58021a780
--- /dev/null
+++ b/nn/common/include/nnapi/TypeUtils.h
@@ -0,0 +1,112 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_NNAPI_TYPE_UTILS_H
+#define ANDROID_FRAMEWORKS_ML_NN_COMMON_NNAPI_TYPE_UTILS_H
+
+#include <ostream>
+#include <utility>
+#include <vector>
+
+#include "nnapi/OperandTypes.h"
+#include "nnapi/OperationTypes.h"
+#include "nnapi/Result.h"
+#include "nnapi/Types.h"
+
+namespace android::nn {
+
+bool isExtension(OperandType type);
+bool isExtension(OperationType type);
+
+bool isNonExtensionScalar(OperandType operandType);
+
+size_t getNonExtensionSize(OperandType operandType);
+
+std::optional<size_t> getNonExtensionSize(OperandType operandType, const Dimensions& dimensions);
+std::optional<size_t> getNonExtensionSize(const Operand& operand);
+
+size_t getOffsetFromInts(int lower, int higher);
+std::pair<int32_t, int32_t> getIntsFromOffset(size_t offset);
+
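+// For each of the numberOfOperands operands, count how many times it appears as an
+// operation input.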
+std::vector<uint32_t> countNumberOfConsumers(size_t numberOfOperands,
+ const std::vector<nn::Operation>& operations);
+
+// Combine two tensor dimensions; either may have unspecified dimensions or rank.
+Result<Dimensions> combineDimensions(const Dimensions& lhs, const Dimensions& rhs);
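+// For example, combining {2, 0} with {0, 3} yields {2, 3}, while combining {2, 3}
+// with {2, 4} fails.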
+
+// Set of output utility functions.
+std::ostream& operator<<(std::ostream& os, const DeviceStatus& deviceStatus);
+std::ostream& operator<<(std::ostream& os, const ExecutionPreference& executionPreference);
+std::ostream& operator<<(std::ostream& os, const DeviceType& deviceType);
+std::ostream& operator<<(std::ostream& os, const MeasureTiming& measureTiming);
+std::ostream& operator<<(std::ostream& os, const OperandType& operandType);
+std::ostream& operator<<(std::ostream& os, const Operand::LifeTime& lifetime);
+std::ostream& operator<<(std::ostream& os, const OperationType& operationType);
+std::ostream& operator<<(std::ostream& os, const Request::Argument::LifeTime& lifetime);
+std::ostream& operator<<(std::ostream& os, const Priority& priority);
+std::ostream& operator<<(std::ostream& os, const ErrorStatus& errorStatus);
+std::ostream& operator<<(std::ostream& os, const OutputShape& outputShape);
+std::ostream& operator<<(std::ostream& os, const Timing& timing);
+std::ostream& operator<<(std::ostream& os, const Capabilities::PerformanceInfo& performanceInfo);
+std::ostream& operator<<(std::ostream& os,
+ const Capabilities::OperandPerformance& operandPerformance);
+std::ostream& operator<<(std::ostream& os,
+ const Capabilities::OperandPerformanceTable& operandPerformances);
+std::ostream& operator<<(std::ostream& os, const Capabilities& capabilities);
+std::ostream& operator<<(std::ostream& os,
+ const Extension::OperandTypeInformation& operandTypeInformation);
+std::ostream& operator<<(std::ostream& os, const Extension& extension);
+std::ostream& operator<<(std::ostream& os, const DataLocation& location);
+std::ostream& operator<<(std::ostream& os,
+ const Operand::SymmPerChannelQuantParams& symmPerChannelQuantParams);
+std::ostream& operator<<(std::ostream& os, const Operand::ExtraParams& extraParams);
+std::ostream& operator<<(std::ostream& os, const Operand& operand);
+std::ostream& operator<<(std::ostream& os, const Operation& operation);
+std::ostream& operator<<(std::ostream& os, const NativeHandle& handle);
+std::ostream& operator<<(std::ostream& os, const Memory& memory);
+std::ostream& operator<<(std::ostream& os, const Model::Subgraph& subgraph);
+std::ostream& operator<<(std::ostream& os, const Model::OperandValues& operandValues);
+std::ostream& operator<<(std::ostream& os,
+ const Model::ExtensionNameAndPrefix& extensionNameAndPrefix);
+std::ostream& operator<<(std::ostream& os, const Model& model);
+std::ostream& operator<<(std::ostream& os, const BufferDesc& bufferDesc);
+std::ostream& operator<<(std::ostream& os, const BufferRole& bufferRole);
+std::ostream& operator<<(std::ostream& os, const Request::Argument& requestArgument);
+std::ostream& operator<<(std::ostream& os, const Request::MemoryPool& memoryPool);
+std::ostream& operator<<(std::ostream& os, const Request& request);
+std::ostream& operator<<(std::ostream& os, const TimePoint& timePoint);
+std::ostream& operator<<(std::ostream& os, const OptionalTimePoint& optionalTimePoint);
+std::ostream& operator<<(std::ostream& os, const TimeoutDuration& timeoutDuration);
+std::ostream& operator<<(std::ostream& os, const OptionalTimeoutDuration& optionalTimeoutDuration);
+std::ostream& operator<<(std::ostream& os, const Version& version);
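+// These operators make the canonical types loggable, e.g. (illustrative):
+//   LOG(ERROR) << "unsupported operand: " << operand;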
+
+bool operator==(const Timing& a, const Timing& b);
+bool operator!=(const Timing& a, const Timing& b);
+bool operator==(const Capabilities::PerformanceInfo& a, const Capabilities::PerformanceInfo& b);
+bool operator==(const Capabilities::OperandPerformance& a,
+ const Capabilities::OperandPerformance& b);
+bool operator==(const Capabilities& a, const Capabilities& b);
+bool operator==(const Extension::OperandTypeInformation& a,
+ const Extension::OperandTypeInformation& b);
+bool operator==(const Extension& a, const Extension& b);
+bool operator==(const Operand::SymmPerChannelQuantParams& a,
+ const Operand::SymmPerChannelQuantParams& b);
+bool operator!=(const Operand::SymmPerChannelQuantParams& a,
+ const Operand::SymmPerChannelQuantParams& b);
+
+} // namespace android::nn
+
+#endif // ANDROID_FRAMEWORKS_ML_NN_COMMON_NNAPI_TYPE_UTILS_H
diff --git a/nn/common/include/nnapi/Types.h b/nn/common/include/nnapi/Types.h
index 0536a2e2b..c0718dae1 100644
--- a/nn/common/include/nnapi/Types.h
+++ b/nn/common/include/nnapi/Types.h
@@ -32,6 +32,7 @@
#include "nnapi/OperandTypes.h"
#include "nnapi/OperationTypes.h"
+#include "nnapi/Result.h"
namespace android::nn {
@@ -141,7 +142,7 @@ struct Capabilities {
};
class OperandPerformanceTable {
public:
- static std::optional<OperandPerformanceTable> create(
+ static Result<OperandPerformanceTable> create(
std::vector<OperandPerformance> operandPerformances);
PerformanceInfo lookup(OperandType type) const;
@@ -290,7 +291,7 @@ using OptionalTimePoint = std::optional<TimePoint>;
using TimeoutDuration = std::chrono::nanoseconds;
using OptionalTimeoutDuration = std::optional<TimeoutDuration>;
-enum class Version { ANDROID_OC_MR1, ANDROID_P, ANDROID_Q, ANDROID_R, CURRENT_RUNTIME, INVALID };
+enum class Version { ANDROID_OC_MR1, ANDROID_P, ANDROID_Q, ANDROID_R, CURRENT_RUNTIME };
} // namespace android::nn
diff --git a/nn/common/include/nnapi/Validation.h b/nn/common/include/nnapi/Validation.h
new file mode 100644
index 000000000..61a2a273c
--- /dev/null
+++ b/nn/common/include/nnapi/Validation.h
@@ -0,0 +1,97 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_NNAPI_VALIDATION_H
+#define ANDROID_FRAMEWORKS_ML_NN_COMMON_NNAPI_VALIDATION_H
+
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <set>
+#include <tuple>
+#include <vector>
+
+#include "nnapi/Result.h"
+#include "nnapi/Types.h"
+
+namespace android::nn {
+
+// Utility functions
+
+Version combineVersions(Version lhs, Version rhs);
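+// Expected behavior (an assumption, not specified here): yields the later of the two versions,
+// i.e. the minimum version able to support both, e.g.
+//   combineVersions(Version::ANDROID_P, Version::ANDROID_R) == Version::ANDROID_R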
+
+Result<Version> validate(const DeviceStatus& deviceStatus);
+Result<Version> validate(const ExecutionPreference& executionPreference);
+Result<Version> validate(const DeviceType& deviceType);
+Result<Version> validate(const MeasureTiming& measureTiming);
+Result<Version> validate(const Priority& priority);
+Result<Version> validate(const ErrorStatus& errorStatus);
+Result<Version> validate(const OutputShape& outputShape);
+Result<Version> validate(const Timing& timing);
+Result<Version> validate(const Capabilities& capabilities);
+Result<Version> validate(const Extension& extension);
+Result<Version> validate(const NativeHandle& handle);
+Result<Version> validate(const Memory& memory);
+Result<Version> validate(const Model& model);
+Result<Version> validate(const BufferDesc& bufferDesc);
+Result<Version> validate(const BufferRole& bufferRole);
+Result<Version> validate(const Request& request);
+Result<Version> validate(const OptionalTimePoint& optionalTimePoint);
+Result<Version> validate(const OptionalTimeoutDuration& optionalTimeoutDuration);
+
+Result<Version> validate(const std::vector<OutputShape>& outputShapes);
+Result<Version> validate(const std::vector<Extension>& extensions);
+Result<Version> validate(const std::vector<NativeHandle>& handles);
+Result<Version> validate(const std::vector<BufferRole>& bufferRoles);
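+
+// A typical call site checks the Result before using the version, e.g. (a sketch, assuming the
+// has_value()/error() accessors from nnapi/Result.h):
+//   if (const auto version = validate(model); !version.has_value()) {
+//       LOG(ERROR) << "invalid model: " << version.error();
+//   }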
+
+// Validates a request when applied to the given model.
+Result<Version> validateRequestForModel(const Request& request, const Model& model);
+
+// Validate memory descriptor.
+enum class IOType { INPUT, OUTPUT };
+using PreparedModelRole = std::tuple<const IPreparedModel*, IOType, uint32_t>;
+
+// Verifies that the input arguments to IDevice::allocate are valid.
+// Optionally, this function can return the flattened prepared-model roles and a combined operand.
+// Pass nullptr for either output parameter if it is not needed.
+// IMPORTANT: This function cannot validate the dimensions and extraParams of operands with an
+// extension operand type. Each driver must do its own validation of extension type dimensions
+// and extraParams.
+Result<Version> validateMemoryDesc(
+ const BufferDesc& desc,
+ const std::vector<std::shared_ptr<const IPreparedModel>>& preparedModels,
+ const std::vector<BufferRole>& inputRoles, const std::vector<BufferRole>& outputRoles,
+ const std::function<const Model*(const std::shared_ptr<const IPreparedModel>&)>& getModel,
+ std::set<PreparedModelRole>* preparedModelRoles, Operand* combinedOperand);
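+// For example, a caller interested only in the validation result might pass (illustrative):
+//   validateMemoryDesc(desc, preparedModels, inputRoles, outputRoles, getModel,
+//                      /*preparedModelRoles=*/nullptr, /*combinedOperand=*/nullptr);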
+
+Result<void> validateOperandSymmPerChannelQuantParams(
+ const Operand& operand, const Operand::SymmPerChannelQuantParams& channelQuant,
+ const char* tag);
+
+// Validates an operand type.
+//
+// extensionOperandTypeInfo must be nullptr iff the type is not an extension type.
+//
+// If allowPartial is true, the dimensions may be underspecified.
+Result<void> validateOperandType(const Operand& type,
+ const Extension::OperandTypeInformation* extensionOperandTypeInfo,
+ const char* tag, bool allowPartial);
+Result<void> validateOperandList(const std::vector<uint32_t>& list, size_t operandCount,
+ const char* tag);
+
+Result<Version> validateOperation(const Operation& operation, const std::vector<Operand>& operands,
+ const std::vector<Model::Subgraph>& subgraphs);
+
+} // namespace android::nn
+
+#endif // ANDROID_FRAMEWORKS_ML_NN_COMMON_NNAPI_VALIDATION_H