diff options
author | Michael Butler <butlermichael@google.com> | 2020-10-06 21:22:49 +0000 |
---|---|---|
committer | Gerrit Code Review <noreply-gerritcodereview@google.com> | 2020-10-06 21:22:49 +0000 |
commit | e238d86505854623f9a39bf480b1aa5c78151f61 (patch) | |
tree | 3450bcddea95ab9d30968a372a07c05be7a4370e /nn/common | |
parent | 9a4f318c3827ad986bccbc8576d79f73f488ea9e (diff) | |
parent | bfef4a044a9176fdb0c1c5ac1fd90e67e2022b7a (diff) | |
download | ml-e238d86505854623f9a39bf480b1aa5c78151f61.tar.gz |
Merge changes from topic "nnapi-canonical-types"
* changes:
Introduce Result type into NNAPI
Implement canonical types in NNAPI
Diffstat (limited to 'nn/common')
-rw-r--r-- | nn/common/Android.bp | 21 | ||||
-rw-r--r-- | nn/common/SharedMemory.cpp | 109 | ||||
-rw-r--r-- | nn/common/SharedMemoryAndroid.cpp | 278 | ||||
-rw-r--r-- | nn/common/SharedMemoryHost.cpp | 161 | ||||
-rw-r--r-- | nn/common/TypeUtils.cpp | 849 | ||||
-rw-r--r-- | nn/common/Types.cpp | 8 | ||||
-rw-r--r-- | nn/common/Validation.cpp | 2664 | ||||
-rw-r--r-- | nn/common/include/nnapi/Result.h | 147 | ||||
-rw-r--r-- | nn/common/include/nnapi/SharedMemory.h | 93 | ||||
-rw-r--r-- | nn/common/include/nnapi/TypeUtils.h | 112 | ||||
-rw-r--r-- | nn/common/include/nnapi/Types.h | 5 | ||||
-rw-r--r-- | nn/common/include/nnapi/Validation.h | 97 |
12 files changed, 4537 insertions, 7 deletions
diff --git a/nn/common/Android.bp b/nn/common/Android.bp index 65d993247..202c69ed1 100644 --- a/nn/common/Android.bp +++ b/nn/common/Android.bp @@ -243,8 +243,27 @@ cc_library_static { name: "neuralnetworks_types", defaults: ["neuralnetworks_utils_defaults"], srcs: [ + "SharedMemory.cpp", + "TypeUtils.cpp", "Types.cpp", - ], + "Validation.cpp", + ], + target: { + android: { + srcs: ["SharedMemoryAndroid.cpp"], + shared_libs: [ + "android.hidl.allocator@1.0", + "android.hidl.memory@1.0", + "libhidlbase", + "libhidlmemory", + "libnativewindow", + ], + static_libs: ["libarect"], + }, + host: { + srcs: ["SharedMemoryHost.cpp"], + }, + }, local_include_dirs: ["include/nnapi"], export_include_dirs: ["include"], shared_libs: [ diff --git a/nn/common/SharedMemory.cpp b/nn/common/SharedMemory.cpp new file mode 100644 index 000000000..9186be20e --- /dev/null +++ b/nn/common/SharedMemory.cpp @@ -0,0 +1,109 @@ +/* + * Copyright (C) 2020 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <android-base/logging.h> + +#include <limits> +#include <optional> +#include <utility> +#include <variant> +#include <vector> + +#include "Result.h" +#include "SharedMemory.h" +#include "Types.h" + +namespace android::nn { +namespace { + +constexpr size_t safeDivideRoundedUp(size_t numerator, size_t denominator) { + CHECK_NE(denominator, 0u); + CHECK_LE(numerator, std::numeric_limits<size_t>::max() - denominator); + return (numerator + denominator - 1) / denominator; +} + +constexpr size_t safeMultiply(size_t a, size_t b) { + if (b == 0) { + return 0; + } + CHECK_LE(a, std::numeric_limits<size_t>::max() / b); + return a * b; +} + +constexpr size_t extendToAlignment(size_t length) { + constexpr size_t kMaxAlignment = alignof(AlignedData); + return safeMultiply(safeDivideRoundedUp(length, kMaxAlignment), kMaxAlignment); +} + +} // namespace + +MutableMemoryBuilder::MutableMemoryBuilder(uint32_t poolIndex) : mPoolIndex(poolIndex) {} + +DataLocation MutableMemoryBuilder::append(size_t length) { + CHECK_GT(length, 0u); + const size_t offset = mSize; + mSize += extendToAlignment(length); + CHECK_LE(offset, std::numeric_limits<uint32_t>::max()); + CHECK_LE(length, std::numeric_limits<uint32_t>::max()); + return {.poolIndex = mPoolIndex, + .offset = static_cast<uint32_t>(offset), + .length = static_cast<uint32_t>(length)}; +} + +bool MutableMemoryBuilder::empty() const { + return mSize == 0; +} + +Result<Memory> MutableMemoryBuilder::finish() { + return createSharedMemory(mSize); +} + +ConstantMemoryBuilder::ConstantMemoryBuilder(uint32_t poolIndex) : mBuilder(poolIndex) {} + +DataLocation ConstantMemoryBuilder::append(const void* data, size_t length) { + const auto location = mBuilder.append(length); + CHECK_EQ(location.length, length); + mSlices.push_back({.data = data, .length = length, .offset = location.offset}); + return location; +} + +bool ConstantMemoryBuilder::empty() const { + return mBuilder.empty(); +} + +Result<Memory> ConstantMemoryBuilder::finish() { + // Allocate the memory. + auto memory = NN_TRY(mBuilder.finish()); + + // Map the memory. + const auto [pointer, size, context] = NN_TRY(map(memory);); + + // Get mutable pointer. + if (!std::holds_alternative<void*>(pointer)) { + return NN_ERROR() + << "MemoryBuilder::finish failed because the mapped pointer is not mutable"; + } + uint8_t* mutablePointer = static_cast<uint8_t*>(std::get<void*>(pointer)); + + // Copy data to the memory pool. + std::for_each(mSlices.begin(), mSlices.end(), [mutablePointer](const auto& slice) { + std::memcpy(mutablePointer + slice.offset, slice.data, slice.length); + }); + + return memory; +} + +} // namespace android::nn diff --git a/nn/common/SharedMemoryAndroid.cpp b/nn/common/SharedMemoryAndroid.cpp new file mode 100644 index 000000000..caa83a6e8 --- /dev/null +++ b/nn/common/SharedMemoryAndroid.cpp @@ -0,0 +1,278 @@ +/* + * Copyright (C) 2020 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <android-base/logging.h> +#include <android-base/mapped_file.h> +#include <android-base/scopeguard.h> +#include <android/hardware_buffer.h> +#include <android/hidl/allocator/1.0/IAllocator.h> +#include <cutils/native_handle.h> +#include <hidl/HidlSupport.h> +#include <hidlmemory/mapping.h> +#include <sys/mman.h> +#include <vndk/hardware_buffer.h> + +#include <any> +#include <limits> +#include <memory> +#include <string> +#include <utility> +#include <variant> +#include <vector> + +#include "Result.h" +#include "SharedMemory.h" +#include "TypeUtils.h" +#include "Types.h" + +namespace android::nn { +namespace { + +using ::android::hardware::hidl_memory; +using ::android::hidl::allocator::V1_0::IAllocator; + +const char* const kAllocatorService = "ashmem"; + +Result<hidl_memory> allocateSharedMemory(size_t size) { + static const auto allocator = IAllocator::getService(kAllocatorService); + CHECK_GT(size, 0u); + + hidl_memory maybeMemory; + auto fn = [&maybeMemory](bool success, const hidl_memory& memory) { + if (success) { + maybeMemory = memory; + } + }; + allocator->allocate(size, fn); + + if (!maybeMemory.valid()) { + return NN_ERROR() << "IAllocator::allocate returned an invalid (empty) memory object"; + } + + return maybeMemory; +} + +Memory createMemory(const hidl_memory& memory) { + CHECK_LE(memory.size(), std::numeric_limits<uint32_t>::max()); + + auto* cloned = native_handle_clone(memory.handle()); + auto nativeHandle = ::android::NativeHandle::create(cloned, /*ownsHandle=*/true); + + return { + .handle = std::move(nativeHandle), + .size = static_cast<uint32_t>(memory.size()), + .name = memory.name(), + }; +} + +hidl_memory createHidlMemory(const Memory& memory) { + const auto hidlMemory = hidl_memory(memory.name, memory.handle->handle(), memory.size); + // Copy memory to force the native_handle_t to be copied. + auto copiedMemory = hidlMemory; + return copiedMemory; +} + +Result<Mapping> mapAshmem(const Memory& memory) { + const auto hidlMemory = createHidlMemory(memory); + const auto mapping = mapMemory(hidlMemory); + if (mapping == nullptr) { + return NN_ERROR() << "Failed to map memory"; + } + auto* const pointer = mapping->getPointer().withDefault(nullptr); + if (pointer == nullptr) { + return NN_ERROR() << "Failed to get the mapped pointer"; + } + const auto fullSize = mapping->getSize().withDefault(0); + if (fullSize == 0 || fullSize > std::numeric_limits<size_t>::max()) { + return NN_ERROR() << "Failed to get the mapped size"; + } + const size_t size = static_cast<size_t>(fullSize); + + return Mapping{ + .pointer = pointer, + .size = size, + .context = mapping, + }; +} + +struct MmapFdMappingContext { + int prot; + std::any context; +}; + +Result<Mapping> mapMemFd(const Memory& memory) { + const size_t size = memory.size; + const native_handle_t* handle = memory.handle->handle(); + const int fd = handle->data[0]; + const int prot = handle->data[1]; + const size_t offset = getOffsetFromInts(handle->data[2], handle->data[3]); + + std::shared_ptr<base::MappedFile> mapping = base::MappedFile::FromFd(fd, offset, size, prot); + if (mapping == nullptr) { + return NN_ERROR() << "Can't mmap the file descriptor."; + } + void* data = mapping->data(); + + auto context = MmapFdMappingContext{.prot = prot, .context = std::move(mapping)}; + return Mapping{.pointer = data, .size = size, .context = std::move(context)}; +} + +Result<Mapping> mapAhwbBlobMemory(const Memory& memory) { + const auto* handle = memory.handle->handle(); + const auto size = memory.size; + const auto format = AHARDWAREBUFFER_FORMAT_BLOB; + const auto usage = AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN | AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN; + const uint32_t width = size; + const uint32_t height = 1; // height is always 1 for BLOB mode AHardwareBuffer. + const uint32_t layers = 1; // layers is always 1 for BLOB mode AHardwareBuffer. + const uint32_t stride = size; + + AHardwareBuffer_Desc desc{ + .width = width, + .height = height, + .layers = layers, + .format = format, + .usage = usage, + .stride = stride, + }; + + AHardwareBuffer* hardwareBuffer = nullptr; + status_t status = AHardwareBuffer_createFromHandle( + &desc, handle, AHARDWAREBUFFER_CREATE_FROM_HANDLE_METHOD_CLONE, &hardwareBuffer); + if (status != NO_ERROR) { + return NN_ERROR() << "Can't create AHardwareBuffer from handle. Error: " << status; + } + + void* data = nullptr; + status = AHardwareBuffer_lock(hardwareBuffer, usage, -1, nullptr, &data); + if (status != NO_ERROR) { + return NN_ERROR() << "Can't lock the AHardwareBuffer. Error: " << status; + // TODO(b/169166682): do we need to call AHardwareBuffer_release? + } + + // Create shared scoped object to munmap. + auto scoped = base::make_scope_guard([hardwareBuffer] { + AHardwareBuffer_unlock(hardwareBuffer, nullptr); + if (hardwareBuffer != nullptr) { + AHardwareBuffer_release(hardwareBuffer); + } + }); + auto sharedScoped = std::make_shared<decltype(scoped)>(std::move(scoped)); + + return Mapping{.pointer = data, .size = size, .context = std::move(sharedScoped)}; +} + +Result<Mapping> mapAhwbMemory(const Memory& /*memory*/) { + return NN_ERROR() << "Unable to map non-BLOB AHardwareBuffer memory"; +} + +} // namespace + +Result<Memory> createSharedMemory(size_t size) { + const auto memory = NN_TRY(allocateSharedMemory(size)); + return createSharedMemoryFromHidlMemory(memory); +} + +Result<Memory> createSharedMemoryFromFd(size_t size, int prot, int fd, size_t offset) { + if (size == 0 || fd < 0) { + return NN_ERROR() << "Invalid size or fd"; + } + + // Duplicate the file descriptor so the resultant Memory owns its own version. + int dupfd = dup(fd); + if (dupfd == -1) { + // TODO(b/120417090): is ANEURALNETWORKS_UNEXPECTED_NULL the correct error to return here? + return NN_ERROR() << "Failed to dup the fd"; + } + + // Create a temporary native handle to own the dupfd. + native_handle_t* nativeHandle = native_handle_create(1, 3); + if (nativeHandle == nullptr) { + close(dupfd); + // TODO(b/120417090): is ANEURALNETWORKS_UNEXPECTED_NULL the correct error to return here? + return NN_ERROR() << "Failed to create native_handle"; + } + + const auto [lowOffsetBits, highOffsetBits] = getIntsFromOffset(offset); + nativeHandle->data[0] = dupfd; + nativeHandle->data[1] = prot; + nativeHandle->data[2] = lowOffsetBits; + nativeHandle->data[3] = highOffsetBits; + + // Create a NativeHandle which owns the native handle and fd so that we don't have to manually + // clean either the native handle or the fd. + auto ownedHandle = ::android::NativeHandle::create(nativeHandle, /*ownsHandle=*/true); + + return Memory{.handle = std::move(ownedHandle), .size = size, .name = "mmap_fd"}; +} + +Result<Memory> createSharedMemoryFromHidlMemory(const hardware::hidl_memory& memory) { + return createMemory(memory); +} + +Result<Memory> createSharedMemoryFromAHWB(const AHardwareBuffer& ahwb) { + AHardwareBuffer_Desc bufferDesc; + AHardwareBuffer_describe(&ahwb, &bufferDesc); + const native_handle_t* handle = AHardwareBuffer_getNativeHandle(&ahwb); + + auto* cloned = native_handle_clone(handle); + auto nativeHandle = ::android::NativeHandle::create(cloned, /*ownsHandle=*/true); + + if (bufferDesc.format == AHARDWAREBUFFER_FORMAT_BLOB) { + return Memory{ + .handle = std::move(nativeHandle), + .size = bufferDesc.width, + .name = "hardware_buffer_blob", + }; + } + + // memory size is not used for non-BLOB AHWB memory. + return Memory{.handle = std::move(nativeHandle), .size = 0, .name = "hardware_buffer"}; +} + +Result<Mapping> map(const Memory& memory) { + if (memory.name == "ashmem") { + return mapAshmem(memory); + } + if (memory.name == "mmap_fd") { + return mapMemFd(memory); + } + if (memory.name == "hardware_buffer_blob") { + return mapAhwbBlobMemory(memory); + } + if (memory.name == "hardware_buffer") { + return mapAhwbMemory(memory); + } + return NN_ERROR() << "Cannot map unknown memory " << memory.name; +} + +bool flush(const Mapping& mapping) { + if (const auto* mmapFdMapping = std::any_cast<MmapFdMappingContext>(&mapping.context)) { + if (!std::holds_alternative<void*>(mapping.pointer)) { + return true; + } + void* data = std::get<void*>(mapping.pointer); + const int prot = mmapFdMapping->prot; + if (prot & PROT_WRITE) { + const size_t size = mapping.size; + return msync(data, size, MS_SYNC) == 0; + } + } + // No-op for other types of memory. + return true; +} + +} // namespace android::nn diff --git a/nn/common/SharedMemoryHost.cpp b/nn/common/SharedMemoryHost.cpp new file mode 100644 index 000000000..bc29d1ff1 --- /dev/null +++ b/nn/common/SharedMemoryHost.cpp @@ -0,0 +1,161 @@ +/* + * Copyright (C) 2020 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <android-base/logging.h> +#include <android-base/mapped_file.h> +#include <cutils/ashmem.h> +#include <cutils/native_handle.h> +#include <sys/mman.h> + +#include <limits> +#include <memory> +#include <utility> + +#include "Result.h" +#include "SharedMemory.h" +#include "TypeUtils.h" +#include "Types.h" + +namespace android::nn { +namespace { + +Result<Mapping> mapAshmem(const Memory& memory) { + CHECK_LE(memory.size, std::numeric_limits<uint32_t>::max()); + const auto size = memory.size; + + int fd = memory.handle->handle()->data[0]; + std::shared_ptr<base::MappedFile> mapping = + base::MappedFile::FromFd(fd, /*offset=*/0, size, PROT_READ | PROT_WRITE); + if (mapping == nullptr) { + return NN_ERROR() << "Can't mmap the file descriptor."; + } + void* data = mapping->data(); + + return Mapping{.pointer = data, .size = size, .context = std::move(mapping)}; +} + +struct MmapFdMappingContext { + int prot; + std::any context; +}; + +Result<Mapping> mapMemFd(const Memory& memory) { + const size_t size = memory.size; + const native_handle_t* handle = memory.handle->handle(); + const int fd = handle->data[0]; + const int prot = handle->data[1]; + const size_t offset = getOffsetFromInts(handle->data[2], handle->data[3]); + + std::shared_ptr<base::MappedFile> mapping = base::MappedFile::FromFd(fd, offset, size, prot); + if (mapping == nullptr) { + return NN_ERROR() << "Can't mmap the file descriptor."; + } + void* data = mapping->data(); + + auto context = MmapFdMappingContext{.prot = prot, .context = std::move(mapping)}; + return Mapping{.pointer = data, .size = size, .context = std::move(context)}; +} + +} // namespace + +Result<Memory> createSharedMemory(size_t size) { + int fd = ashmem_create_region("NnapiAshmem", size); + if (fd < 0) { + return NN_ERROR() << "ashmem_create_region(" << size << ") fails with " << fd; + } + + native_handle_t* handle = native_handle_create(1, 0); + if (handle == nullptr) { + // TODO(b/120417090): is ANEURALNETWORKS_UNEXPECTED_NULL the correct error to return here? + return NN_ERROR() << "Failed to create native_handle"; + } + handle->data[0] = fd; + + // Create a NativeHandle which owns the native handle and fd so that we don't have to manually + // clean either the native handle or the fd. + auto nativeHandle = ::android::NativeHandle::create(handle, /*ownsHandle=*/true); + + return Memory{.handle = std::move(nativeHandle), .size = size, .name = "ashmem"}; +} + +Result<Memory> createSharedMemoryFromFd(size_t size, int prot, int fd, size_t offset) { + if (size == 0 || fd < 0) { + return NN_ERROR() << "Invalid size or fd"; + } + + // Duplicate the file descriptor so the resultant Memory owns its own version. + int dupfd = dup(fd); + if (dupfd == -1) { + // TODO(b/120417090): is ANEURALNETWORKS_UNEXPECTED_NULL the correct error to return here? + return NN_ERROR() << "Failed to dup the fd"; + } + + // Create a temporary native handle to own the dupfd. + native_handle_t* nativeHandle = native_handle_create(1, 3); + if (nativeHandle == nullptr) { + close(dupfd); + // TODO(b/120417090): is ANEURALNETWORKS_UNEXPECTED_NULL the correct error to return here? + return NN_ERROR() << "Failed to create native_handle"; + } + + const auto [lowOffsetBits, highOffsetBits] = getIntsFromOffset(offset); + nativeHandle->data[0] = dupfd; + nativeHandle->data[1] = prot; + nativeHandle->data[2] = lowOffsetBits; + nativeHandle->data[3] = highOffsetBits; + + // Create a NativeHandle which owns the native handle and fd so that we don't have to manually + // clean either the native handle or the fd. + auto ownedHandle = ::android::NativeHandle::create(nativeHandle, /*ownsHandle=*/true); + + return Memory{.handle = std::move(ownedHandle), .size = size, .name = "mmap_fd"}; +} + +Result<Memory> createSharedMemoryFromHidlMemory(const hardware::hidl_memory& /*memory*/) { + return NN_ERROR() << "hidl_memory not supported on host"; +} + +Result<Memory> createSharedMemoryFromAHWB(const AHardwareBuffer& /*ahwb*/) { + return NN_ERROR() << "AHardwareBuffer memory not supported on host"; +} + +Result<Mapping> map(const Memory& memory) { + if (memory.name == "ashmem") { + return mapAshmem(memory); + } + if (memory.name == "mmap_fd") { + return mapMemFd(memory); + } + return NN_ERROR() << "Cannot map unknown memory " << memory.name; +} + +bool flush(const Mapping& mapping) { + if (const auto* mmapFdMapping = std::any_cast<MmapFdMappingContext>(&mapping.context)) { + if (!std::holds_alternative<void*>(mapping.pointer)) { + return true; + } + void* data = std::get<void*>(mapping.pointer); + const int prot = mmapFdMapping->prot; + if (prot & PROT_WRITE) { + const size_t size = mapping.size; + return msync(data, size, MS_SYNC) == 0; + } + } + // No-op for other types of memory. + return true; +} + +} // namespace android::nn diff --git a/nn/common/TypeUtils.cpp b/nn/common/TypeUtils.cpp new file mode 100644 index 000000000..d997f65a2 --- /dev/null +++ b/nn/common/TypeUtils.cpp @@ -0,0 +1,849 @@ +/* + * Copyright (C) 2020 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "TypeUtils.h" + +#include <android-base/logging.h> + +#include <chrono> +#include <limits> +#include <memory> +#include <ostream> +#include <type_traits> +#include <utility> +#include <vector> + +#include "OperandTypes.h" +#include "OperationTypes.h" +#include "Result.h" +#include "Types.h" + +namespace android::nn { +namespace { + +template <typename Type> +constexpr std::underlying_type_t<Type> underlyingType(Type object) { + return static_cast<std::underlying_type_t<Type>>(object); +} + +uint16_t getExtensionPrefix(uint32_t type) { + return static_cast<uint16_t>(type >> kExtensionTypeBits); +} + +template <typename Type> +std::ostream& operator<<(std::ostream& os, const std::vector<Type>& vec) { + constexpr size_t kMaxVectorPrint = 20; + os << "["; + size_t count = 0; + for (const auto& element : vec) { + if (count > 0) { + os << ", "; + } + os << element; + count++; + if (count >= kMaxVectorPrint) { + return os << "...]"; + } + } + return os << "]"; +} + +} // namespace + +bool isExtension(OperandType type) { + return getExtensionPrefix(underlyingType(type)) != 0; +} + +bool isExtension(OperationType type) { + return getExtensionPrefix(underlyingType(type)) != 0; +} + +bool isNonExtensionScalar(OperandType operandType) { + CHECK(!isExtension(operandType)); + switch (operandType) { + case OperandType::FLOAT32: + case OperandType::INT32: + case OperandType::UINT32: + case OperandType::BOOL: + case OperandType::FLOAT16: + case OperandType::SUBGRAPH: + case OperandType::OEM: + return true; + case OperandType::TENSOR_FLOAT32: + case OperandType::TENSOR_INT32: + case OperandType::TENSOR_QUANT8_ASYMM: + case OperandType::TENSOR_QUANT16_SYMM: + case OperandType::TENSOR_FLOAT16: + case OperandType::TENSOR_BOOL8: + case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: + case OperandType::TENSOR_QUANT16_ASYMM: + case OperandType::TENSOR_QUANT8_SYMM: + case OperandType::TENSOR_QUANT8_ASYMM_SIGNED: + case OperandType::TENSOR_OEM_BYTE: + return false; + } + return false; +} + +size_t getNonExtensionSize(OperandType operandType) { + CHECK(!isExtension(operandType)); + switch (operandType) { + case OperandType::SUBGRAPH: + case OperandType::OEM: + case OperandType::TENSOR_OEM_BYTE: + return 0; + case OperandType::TENSOR_QUANT8_ASYMM: + case OperandType::BOOL: + case OperandType::TENSOR_BOOL8: + case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: + case OperandType::TENSOR_QUANT8_SYMM: + case OperandType::TENSOR_QUANT8_ASYMM_SIGNED: + return 1; + case OperandType::TENSOR_QUANT16_SYMM: + case OperandType::TENSOR_FLOAT16: + case OperandType::FLOAT16: + case OperandType::TENSOR_QUANT16_ASYMM: + return 2; + case OperandType::FLOAT32: + case OperandType::INT32: + case OperandType::UINT32: + case OperandType::TENSOR_FLOAT32: + case OperandType::TENSOR_INT32: + return 4; + } + return 0; +} + +std::optional<size_t> getNonExtensionSize(OperandType operandType, const Dimensions& dimensions) { + CHECK(!isExtension(operandType)) << "Size of extension operand data is unknown"; + size_t size = getNonExtensionSize(operandType); + if (isNonExtensionScalar(operandType)) { + return size; + } else if (dimensions.empty()) { + return 0; + } + for (Dimension dimension : dimensions) { + if (dimension != 0 && size > std::numeric_limits<size_t>::max() / dimension) { + return std::nullopt; + } + size *= dimension; + } + return size; +} + +std::optional<size_t> getNonExtensionSize(const Operand& operand) { + return getNonExtensionSize(operand.type, operand.dimensions); +} + +size_t getOffsetFromInts(int lower, int higher) { + const int32_t lowBits = static_cast<int32_t>(lower); + const int32_t highBits = static_cast<int32_t>(higher); + const uint32_t lowOffsetBits = *reinterpret_cast<const uint32_t*>(&lowBits); + const uint32_t highOffsetBits = *reinterpret_cast<const uint32_t*>(&highBits); + const uint64_t offset = lowOffsetBits | (static_cast<uint64_t>(highOffsetBits) << 32); + return offset; +} + +std::pair<int32_t, int32_t> getIntsFromOffset(size_t offset) { + const uint64_t bits = static_cast<uint64_t>(offset); + const uint32_t lowBits = static_cast<uint32_t>(bits & 0xffffffff); + const uint32_t highBits = static_cast<uint32_t>(bits >> 32); + const int32_t lowOffsetBits = *reinterpret_cast<const int32_t*>(&lowBits); + const int32_t highOffsetBits = *reinterpret_cast<const int32_t*>(&highBits); + return std::make_pair(lowOffsetBits, highOffsetBits); +} + +std::vector<uint32_t> countNumberOfConsumers(size_t numberOfOperands, + const std::vector<nn::Operation>& operations) { + std::vector<uint32_t> numberOfConsumers(numberOfOperands, 0); + auto eachOperandIndex = [&numberOfConsumers](uint32_t operandIndex) { + numberOfConsumers.at(operandIndex)++; + }; + auto eachOperation = [&eachOperandIndex](const nn::Operation& operation) { + std::for_each(operation.inputs.begin(), operation.inputs.end(), eachOperandIndex); + }; + std::for_each(operations.begin(), operations.end(), eachOperation); + return numberOfConsumers; +} + +Result<Dimensions> combineDimensions(const Dimensions& lhs, const Dimensions& rhs) { + if (rhs.empty()) return lhs; + if (lhs.empty()) return rhs; + if (lhs.size() != rhs.size()) { + std::ostringstream os; + os << "Incompatible ranks: " << lhs << " and " << rhs; + return NN_ERROR() << os.str(); + } + Dimensions combined = lhs; + for (size_t i = 0; i < lhs.size(); i++) { + if (lhs[i] == 0) { + combined[i] = rhs[i]; + } else if (rhs[i] != 0 && lhs[i] != rhs[i]) { + std::ostringstream os; + os << "Incompatible dimensions: " << lhs << " and " << rhs; + return NN_ERROR() << os.str(); + } + } + return combined; +} + +std::ostream& operator<<(std::ostream& os, const DeviceStatus& deviceStatus) { + switch (deviceStatus) { + case DeviceStatus::AVAILABLE: + return os << "AVAILABLE"; + case DeviceStatus::BUSY: + return os << "BUSY"; + case DeviceStatus::OFFLINE: + return os << "OFFLINE"; + case DeviceStatus::UNKNOWN: + return os << "UNKNOWN"; + } + return os << "DeviceStatus{" << underlyingType(deviceStatus) << "}"; +} + +std::ostream& operator<<(std::ostream& os, const ExecutionPreference& executionPreference) { + switch (executionPreference) { + case ExecutionPreference::LOW_POWER: + return os << "LOW_POWER"; + case ExecutionPreference::FAST_SINGLE_ANSWER: + return os << "FAST_SINGLE_ANSWER"; + case ExecutionPreference::SUSTAINED_SPEED: + return os << "SUSTAINED_SPEED"; + } + return os << "ExecutionPreference{" << underlyingType(executionPreference) << "}"; +} + +std::ostream& operator<<(std::ostream& os, const DeviceType& deviceType) { + switch (deviceType) { + case DeviceType::UNKNOWN: + return os << "UNKNOWN"; + case DeviceType::OTHER: + return os << "OTHER"; + case DeviceType::CPU: + return os << "CPU"; + case DeviceType::GPU: + return os << "GPU"; + case DeviceType::ACCELERATOR: + return os << "ACCELERATOR"; + } + return os << "DeviceType{" << underlyingType(deviceType) << "}"; +} + +std::ostream& operator<<(std::ostream& os, const MeasureTiming& measureTiming) { + switch (measureTiming) { + case MeasureTiming::NO: + return os << "NO"; + case MeasureTiming::YES: + return os << "YES"; + } + return os << "MeasureTiming{" << underlyingType(measureTiming) << "}"; +} + +std::ostream& operator<<(std::ostream& os, const OperandType& operandType) { + switch (operandType) { + case OperandType::FLOAT32: + return os << "FLOAT32"; + case OperandType::INT32: + return os << "INT32"; + case OperandType::UINT32: + return os << "UINT32"; + case OperandType::TENSOR_FLOAT32: + return os << "TENSOR_FLOAT32"; + case OperandType::TENSOR_INT32: + return os << "TENSOR_INT32"; + case OperandType::TENSOR_QUANT8_ASYMM: + return os << "TENSOR_QUANT8_ASYMM"; + case OperandType::BOOL: + return os << "BOOL"; + case OperandType::TENSOR_QUANT16_SYMM: + return os << "TENSOR_QUANT16_SYMM"; + case OperandType::TENSOR_FLOAT16: + return os << "TENSOR_FLOAT16"; + case OperandType::TENSOR_BOOL8: + return os << "TENSOR_BOOL8"; + case OperandType::FLOAT16: + return os << "FLOAT16"; + case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: + return os << "TENSOR_QUANT8_SYMM_PER_CHANNEL"; + case OperandType::TENSOR_QUANT16_ASYMM: + return os << "TENSOR_QUANT16_ASYMM"; + case OperandType::TENSOR_QUANT8_SYMM: + return os << "TENSOR_QUANT8_SYMM"; + case OperandType::TENSOR_QUANT8_ASYMM_SIGNED: + return os << "TENSOR_QUANT8_ASYMM_SIGNED"; + case OperandType::SUBGRAPH: + return os << "SUBGRAPH"; + case OperandType::OEM: + return os << "OEM"; + case OperandType::TENSOR_OEM_BYTE: + return os << "TENSOR_OEM_BYTE"; + } + if (isExtension(operandType)) { + return os << "Extension OperandType " << underlyingType(operandType); + } + return os << "OperandType{" << underlyingType(operandType) << "}"; +} + +std::ostream& operator<<(std::ostream& os, const Operand::LifeTime& lifetime) { + switch (lifetime) { + case Operand::LifeTime::TEMPORARY_VARIABLE: + return os << "TEMPORARY_VARIABLE"; + case Operand::LifeTime::SUBGRAPH_INPUT: + return os << "SUBGRAPH_INPUT"; + case Operand::LifeTime::SUBGRAPH_OUTPUT: + return os << "SUBGRAPH_OUTPUT"; + case Operand::LifeTime::CONSTANT_COPY: + return os << "CONSTANT_COPY"; + case Operand::LifeTime::CONSTANT_REFERENCE: + return os << "CONSTANT_REFERENCE"; + case Operand::LifeTime::NO_VALUE: + return os << "NO_VALUE"; + case Operand::LifeTime::SUBGRAPH: + return os << "SUBGRAPH"; + case Operand::LifeTime::POINTER: + return os << "POINTER"; + } + return os << "Operand::LifeTime{" << underlyingType(lifetime) << "}"; +} + +std::ostream& operator<<(std::ostream& os, const OperationType& operationType) { + switch (operationType) { + case OperationType::ADD: + return os << "ADD"; + case OperationType::AVERAGE_POOL_2D: + return os << "AVERAGE_POOL_2D"; + case OperationType::CONCATENATION: + return os << "CONCATENATION"; + case OperationType::CONV_2D: + return os << "CONV_2D"; + case OperationType::DEPTHWISE_CONV_2D: + return os << "DEPTHWISE_CONV_2D"; + case OperationType::DEPTH_TO_SPACE: + return os << "DEPTH_TO_SPACE"; + case OperationType::DEQUANTIZE: + return os << "DEQUANTIZE"; + case OperationType::EMBEDDING_LOOKUP: + return os << "EMBEDDING_LOOKUP"; + case OperationType::FLOOR: + return os << "FLOOR"; + case OperationType::FULLY_CONNECTED: + return os << "FULLY_CONNECTED"; + case OperationType::HASHTABLE_LOOKUP: + return os << "HASHTABLE_LOOKUP"; + case OperationType::L2_NORMALIZATION: + return os << "L2_NORMALIZATION"; + case OperationType::L2_POOL_2D: + return os << "L2_POOL_2D"; + case OperationType::LOCAL_RESPONSE_NORMALIZATION: + return os << "LOCAL_RESPONSE_NORMALIZATION"; + case OperationType::LOGISTIC: + return os << "LOGISTIC"; + case OperationType::LSH_PROJECTION: + return os << "LSH_PROJECTION"; + case OperationType::LSTM: + return os << "LSTM"; + case OperationType::MAX_POOL_2D: + return os << "MAX_POOL_2D"; + case OperationType::MUL: + return os << "MUL"; + case OperationType::RELU: + return os << "RELU"; + case OperationType::RELU1: + return os << "RELU1"; + case OperationType::RELU6: + return os << "RELU6"; + case OperationType::RESHAPE: + return os << "RESHAPE"; + case OperationType::RESIZE_BILINEAR: + return os << "RESIZE_BILINEAR"; + case OperationType::RNN: + return os << "RNN"; + case OperationType::SOFTMAX: + return os << "SOFTMAX"; + case OperationType::SPACE_TO_DEPTH: + return os << "SPACE_TO_DEPTH"; + case OperationType::SVDF: + return os << "SVDF"; + case OperationType::TANH: + return os << "TANH"; + case OperationType::BATCH_TO_SPACE_ND: + return os << "BATCH_TO_SPACE_ND"; + case OperationType::DIV: + return os << "DIV"; + case OperationType::MEAN: + return os << "MEAN"; + case OperationType::PAD: + return os << "PAD"; + case OperationType::SPACE_TO_BATCH_ND: + return os << "SPACE_TO_BATCH_ND"; + case OperationType::SQUEEZE: + return os << "SQUEEZE"; + case OperationType::STRIDED_SLICE: + return os << "STRIDED_SLICE"; + case OperationType::SUB: + return os << "SUB"; + case OperationType::TRANSPOSE: + return os << "TRANSPOSE"; + case OperationType::ABS: + return os << "ABS"; + case OperationType::ARGMAX: + return os << "ARGMAX"; + case OperationType::ARGMIN: + return os << "ARGMIN"; + case OperationType::AXIS_ALIGNED_BBOX_TRANSFORM: + return os << "AXIS_ALIGNED_BBOX_TRANSFORM"; + case OperationType::BIDIRECTIONAL_SEQUENCE_LSTM: + return os << "BIDIRECTIONAL_SEQUENCE_LSTM"; + case OperationType::BIDIRECTIONAL_SEQUENCE_RNN: + return os << "BIDIRECTIONAL_SEQUENCE_RNN"; + case OperationType::BOX_WITH_NMS_LIMIT: + return os << "BOX_WITH_NMS_LIMIT"; + case OperationType::CAST: + return os << "CAST"; + case OperationType::CHANNEL_SHUFFLE: + return os << "CHANNEL_SHUFFLE"; + case OperationType::DETECTION_POSTPROCESSING: + return os << "DETECTION_POSTPROCESSING"; + case OperationType::EQUAL: + return os << "EQUAL"; + case OperationType::EXP: + return os << "EXP"; + case OperationType::EXPAND_DIMS: + return os << "EXPAND_DIMS"; + case OperationType::GATHER: + return os << "GATHER"; + case OperationType::GENERATE_PROPOSALS: + return os << "GENERATE_PROPOSALS"; + case OperationType::GREATER: + return os << "GREATER"; + case OperationType::GREATER_EQUAL: + return os << "GREATER_EQUAL"; + case OperationType::GROUPED_CONV_2D: + return os << "GROUPED_CONV_2D"; + case OperationType::HEATMAP_MAX_KEYPOINT: + return os << "HEATMAP_MAX_KEYPOINT"; + case OperationType::INSTANCE_NORMALIZATION: + return os << "INSTANCE_NORMALIZATION"; + case OperationType::LESS: + return os << "LESS"; + case OperationType::LESS_EQUAL: + return os << "LESS_EQUAL"; + case OperationType::LOG: + return os << "LOG"; + case OperationType::LOGICAL_AND: + return os << "LOGICAL_AND"; + case OperationType::LOGICAL_NOT: + return os << "LOGICAL_NOT"; + case OperationType::LOGICAL_OR: + return os << "LOGICAL_OR"; + case OperationType::LOG_SOFTMAX: + return os << "LOG_SOFTMAX"; + case OperationType::MAXIMUM: + return os << "MAXIMUM"; + case OperationType::MINIMUM: + return os << "MINIMUM"; + case OperationType::NEG: + return os << "NEG"; + case OperationType::NOT_EQUAL: + return os << "NOT_EQUAL"; + case OperationType::PAD_V2: + return os << "PAD_V2"; + case OperationType::POW: + return os << "POW"; + case OperationType::PRELU: + return os << "PRELU"; + case OperationType::QUANTIZE: + return os << "QUANTIZE"; + case OperationType::QUANTIZED_16BIT_LSTM: + return os << "QUANTIZED_16BIT_LSTM"; + case OperationType::RANDOM_MULTINOMIAL: + return os << "RANDOM_MULTINOMIAL"; + case OperationType::REDUCE_ALL: + return os << "REDUCE_ALL"; + case OperationType::REDUCE_ANY: + return os << "REDUCE_ANY"; + case OperationType::REDUCE_MAX: + return os << "REDUCE_MAX"; + case OperationType::REDUCE_MIN: + return os << "REDUCE_MIN"; + case OperationType::REDUCE_PROD: + return os << "REDUCE_PROD"; + case OperationType::REDUCE_SUM: + return os << "REDUCE_SUM"; + case OperationType::ROI_ALIGN: + return os << "ROI_ALIGN"; + case OperationType::ROI_POOLING: + return os << "ROI_POOLING"; + case OperationType::RSQRT: + return os << "RSQRT"; + case OperationType::SELECT: + return os << "SELECT"; + case OperationType::SIN: + return os << "SIN"; + case OperationType::SLICE: + return os << "SLICE"; + case OperationType::SPLIT: + return os << "SPLIT"; + case OperationType::SQRT: + return os << "SQRT"; + case OperationType::TILE: + return os << "TILE"; + case OperationType::TOPK_V2: + return os << "TOPK_V2"; + case OperationType::TRANSPOSE_CONV_2D: + return os << "TRANSPOSE_CONV_2D"; + case OperationType::UNIDIRECTIONAL_SEQUENCE_LSTM: + return os << "UNIDIRECTIONAL_SEQUENCE_LSTM"; + case OperationType::UNIDIRECTIONAL_SEQUENCE_RNN: + return os << "UNIDIRECTIONAL_SEQUENCE_RNN"; + case OperationType::RESIZE_NEAREST_NEIGHBOR: + return os << "RESIZE_NEAREST_NEIGHBOR"; + case OperationType::QUANTIZED_LSTM: + return os << "QUANTIZED_LSTM"; + case OperationType::IF: + return os << "IF"; + case OperationType::WHILE: + return os << "WHILE"; + case OperationType::ELU: + return os << "ELU"; + case OperationType::HARD_SWISH: + return os << "HARD_SWISH"; + case OperationType::FILL: + return os << "FILL"; + case OperationType::RANK: + return os << "RANK"; + case OperationType::OEM_OPERATION: + return os << "OEM_OPERATION"; + } + if (isExtension(operationType)) { + return os << "Extension OperationType " << underlyingType(operationType); + } + return os << "OperationType{" << underlyingType(operationType) << "}"; +} + +std::ostream& operator<<(std::ostream& os, const Request::Argument::LifeTime& lifetime) { + switch (lifetime) { + case Request::Argument::LifeTime::POOL: + return os << "POOL"; + case Request::Argument::LifeTime::NO_VALUE: + return os << "NO_VALUE"; + case Request::Argument::LifeTime::POINTER: + return os << "POINTER"; + } + return os << "Request::Argument::LifeTime{" << underlyingType(lifetime) << "}"; +} + +std::ostream& operator<<(std::ostream& os, const Priority& priority) { + switch (priority) { + case Priority::LOW: + return os << "LOW"; + case Priority::MEDIUM: + return os << "MEDIUM"; + case Priority::HIGH: + return os << "HIGH"; + } + return os << "Priority{" << underlyingType(priority) << "}"; +} + +std::ostream& operator<<(std::ostream& os, const ErrorStatus& errorStatus) { + switch (errorStatus) { + case ErrorStatus::NONE: + return os << "NONE"; + case ErrorStatus::DEVICE_UNAVAILABLE: + return os << "DEVICE_UNAVAILABLE"; + case ErrorStatus::GENERAL_FAILURE: + return os << "GENERAL_FAILURE"; + case ErrorStatus::OUTPUT_INSUFFICIENT_SIZE: + return os << "OUTPUT_INSUFFICIENT_SIZE"; + case ErrorStatus::INVALID_ARGUMENT: + return os << "INVALID_ARGUMENT"; + case ErrorStatus::MISSED_DEADLINE_TRANSIENT: + return os << "MISSED_DEADLINE_TRANSIENT"; + case ErrorStatus::MISSED_DEADLINE_PERSISTENT: + return os << "MISSED_DEADLINE_PERSISTENT"; + case ErrorStatus::RESOURCE_EXHAUSTED_TRANSIENT: + return os << "RESOURCE_EXHAUSTED_TRANSIENT"; + case ErrorStatus::RESOURCE_EXHAUSTED_PERSISTENT: + return os << "RESOURCE_EXHAUSTED_PERSISTENT"; + case ErrorStatus::DEAD_OBJECT: + return os << "DEAD_OBJECT"; + } + return os << "ErrorStatus{" << underlyingType(errorStatus) << "}"; +} + +std::ostream& operator<<(std::ostream& os, const OutputShape& outputShape) { + return os << "OutputShape{.dimensions=" << outputShape.dimensions + << ", .isSufficient=" << (outputShape.isSufficient ? "true" : "false") << "}"; +} + +std::ostream& operator<<(std::ostream& os, const Timing& timing) { + constexpr auto printTime = [](std::ostream& os, uint64_t nanoseconds) { + if (nanoseconds == std::numeric_limits<uint64_t>::max()) { + os << "<no time information provided>"; + } else { + os << nanoseconds << "ns"; + } + }; + os << "Timing{.timeOnDevice="; + printTime(os, timing.timeOnDevice); + os << ", .timeInDriver="; + printTime(os, timing.timeInDriver); + return os << "}"; +} + +std::ostream& operator<<(std::ostream& os, const Capabilities::PerformanceInfo& performanceInfo) { + return os << "Capabilities::PerformanceInfo{.execTime=" << performanceInfo.execTime + << ", .powerUsage=" << performanceInfo.powerUsage << "}"; +} + +std::ostream& operator<<(std::ostream& os, + const Capabilities::OperandPerformance& operandPerformance) { + return os << "Capabilities::OperandPerformance{.type=" << operandPerformance.type + << ", .info=" << operandPerformance.info << "}"; +} + +std::ostream& operator<<(std::ostream& os, + const Capabilities::OperandPerformanceTable& operandPerformances) { + return os << operandPerformances.asVector(); +} + +std::ostream& operator<<(std::ostream& os, const Capabilities& capabilities) { + return os << "Capabilities{.relaxedFloat32toFloat16PerformanceScalar=" + << capabilities.relaxedFloat32toFloat16PerformanceScalar + << ", .relaxedFloat32toFloat16PerformanceTensor=" + << capabilities.relaxedFloat32toFloat16PerformanceTensor + << ", .operandPerformance=" << capabilities.operandPerformance + << ", .ifPerformance=" << capabilities.ifPerformance + << ", .whilePerformance=" << capabilities.whilePerformance << "}"; +} + +std::ostream& operator<<(std::ostream& os, + const Extension::OperandTypeInformation& operandTypeInformation) { + return os << "Extension::OperandTypeInformation{.type=" << operandTypeInformation.type + << ", .isTensor=" << (operandTypeInformation.isTensor ? "true" : "false") + << ", .byteSize=" << operandTypeInformation.byteSize << "}"; +} + +std::ostream& operator<<(std::ostream& os, const Extension& extension) { + return os << "Extension{.name=" << extension.name + << ", .operandTypes=" << extension.operandTypes << "}"; +} + +std::ostream& operator<<(std::ostream& os, const DataLocation& location) { + const auto printPointer = [&os](const std::variant<const void*, void*>& pointer) { + os << (std::holds_alternative<const void*>(pointer) ? "<constant " : "<mutable "); + os << std::visit( + [](const auto* ptr) { + return ptr == nullptr ? "null pointer>" : "non-null pointer>"; + }, + pointer); + }; + os << "DataLocation{.pointer="; + printPointer(location.pointer); + return os << ", .poolIndex=" << location.poolIndex << ", .offset=" << location.offset + << ", .length=" << location.length << "}"; +} + +std::ostream& operator<<(std::ostream& os, + const Operand::SymmPerChannelQuantParams& symmPerChannelQuantParams) { + return os << "Operand::SymmPerChannelQuantParams{.scales=" << symmPerChannelQuantParams.scales + << ", .channelDim=" << symmPerChannelQuantParams.channelDim << "}"; +} + +std::ostream& operator<<(std::ostream& os, const Operand::ExtraParams& extraParams) { + os << "Operand::ExtraParams{"; + if (std::holds_alternative<Operand::NoParams>(extraParams)) { + os << "<no params>"; + } else if (std::holds_alternative<Operand::SymmPerChannelQuantParams>(extraParams)) { + os << std::get<Operand::SymmPerChannelQuantParams>(extraParams); + } else if (std::holds_alternative<Operand::ExtensionParams>(extraParams)) { + os << std::get<Operand::ExtensionParams>(extraParams); + } + return os << "}"; +} + +std::ostream& operator<<(std::ostream& os, const Operand& operand) { + return os << "Operand{.type=" << operand.type << ", .dimensions=" << operand.dimensions + << ", .scale=" << operand.scale << ", .zeroPoint=" << operand.zeroPoint + << ", lifetime=" << operand.lifetime << ", .location=" << operand.location + << ", .extraParams=" << operand.extraParams << "}"; +} + +std::ostream& operator<<(std::ostream& os, const Operation& operation) { + return os << "Operation{.type=" << operation.type << ", .inputs=" << operation.inputs + << ", .outputs=" << operation.outputs << "}"; +} + +std::ostream& operator<<(std::ostream& os, const NativeHandle& handle) { + return os << (handle != nullptr ? "<non-empty handle>" : "<empty handle>"); +} + +std::ostream& operator<<(std::ostream& os, const Memory& memory) { + return os << "Memory{.handle=" << memory.handle << ", .size=" << memory.size + << ", .name=" << memory.name << "}"; +} + +std::ostream& operator<<(std::ostream& os, const Model::Subgraph& subgraph) { + std::vector<Operand> operands; + std::vector<Operation> operations; + std::vector<uint32_t> inputIndexes; + std::vector<uint32_t> outputIndexes; + return os << "Model::Subgraph{.operands=" << subgraph.operands + << ", .operations=" << subgraph.operations + << ", .inputIndexes=" << subgraph.inputIndexes + << ", .outputIndexes=" << subgraph.outputIndexes << "}"; +} + +std::ostream& operator<<(std::ostream& os, const Model::OperandValues& operandValues) { + return os << "Model::OperandValues{<" << operandValues.size() << "bytes>}"; +} + +std::ostream& operator<<(std::ostream& os, + const Model::ExtensionNameAndPrefix& extensionNameAndPrefix) { + return os << "Model::ExtensionNameAndPrefix{.name=" << extensionNameAndPrefix.name + << ", .prefix=" << extensionNameAndPrefix.prefix << "}"; +} + +std::ostream& operator<<(std::ostream& os, const Model& model) { + return os << "Model{.main=" << model.main << ", .referenced=" << model.referenced + << ", .operandValues=" << model.operandValues << ", .pools=" << model.pools + << ", .relaxComputationFloat32toFloat16=" + << (model.relaxComputationFloat32toFloat16 ? "true" : "false") + << ", extensionNameToPrefix=" << model.extensionNameToPrefix << "}"; +} + +std::ostream& operator<<(std::ostream& os, const BufferDesc& bufferDesc) { + return os << "BufferDesc{.dimensions=" << bufferDesc.dimensions << "}"; +} + +std::ostream& operator<<(std::ostream& os, const BufferRole& bufferRole) { + return os << "BufferRole{.modelIndex=" << bufferRole.modelIndex + << ", .ioIndex=" << bufferRole.ioIndex << ", .frequency=" << bufferRole.frequency + << "}"; +} + +std::ostream& operator<<(std::ostream& os, const Request::Argument& requestArgument) { + return os << "Request::Argument{.lifetime=" << requestArgument.lifetime + << ", .location=" << requestArgument.location + << ", .dimensions=" << requestArgument.dimensions << "}"; +} + +std::ostream& operator<<(std::ostream& os, const Request::MemoryPool& memoryPool) { + os << "Request::MemoryPool{"; + if (std::holds_alternative<Memory>(memoryPool)) { + os << std::get<Memory>(memoryPool); + } else if (std::holds_alternative<Request::MemoryDomainToken>(memoryPool)) { + const auto& token = std::get<Request::MemoryDomainToken>(memoryPool); + if (token == Request::MemoryDomainToken{}) { + os << "<invalid MemoryDomainToken>"; + } else { + os << "MemoryDomainToken=" << underlyingType(token); + } + } else if (std::holds_alternative<std::shared_ptr<const IBuffer>>(memoryPool)) { + const auto& buffer = std::get<std::shared_ptr<const IBuffer>>(memoryPool); + os << (buffer != nullptr ? "<non-null IBuffer>" : "<null IBuffer>"); + } + return os << "}"; +} + +std::ostream& operator<<(std::ostream& os, const Request& request) { + return os << "Request{.inputs=" << request.inputs << ", .outputs=" << request.outputs + << ", .pools=" << request.pools << "}"; +} + +std::ostream& operator<<(std::ostream& os, const TimePoint& timePoint) { + return os << timePoint.time_since_epoch().count() << "ns since epoch"; +} + +std::ostream& operator<<(std::ostream& os, const OptionalTimePoint& optionalTimePoint) { + if (!optionalTimePoint.has_value()) { + return os << "<no time point>"; + } + return os << optionalTimePoint.value(); +} + +std::ostream& operator<<(std::ostream& os, const TimeoutDuration& timeoutDuration) { + return os << timeoutDuration.count() << "ns"; +} + +std::ostream& operator<<(std::ostream& os, const OptionalTimeoutDuration& optionalTimeoutDuration) { + if (!optionalTimeoutDuration.has_value()) { + return os << "<no timeout duration>"; + } + return os << optionalTimeoutDuration.value(); +} + +std::ostream& operator<<(std::ostream& os, const Version& version) { + switch (version) { + case Version::ANDROID_OC_MR1: + return os << "ANDROID_OC_MR1"; + case Version::ANDROID_P: + return os << "ANDROID_P"; + case Version::ANDROID_Q: + return os << "ANDROID_Q"; + case Version::ANDROID_R: + return os << "ANDROID_R"; + case Version::CURRENT_RUNTIME: + return os << "CURRENT_RUNTIME"; + } + return os << "Version{" << underlyingType(version) << "}"; +} + +bool operator==(const Timing& a, const Timing& b) { + return a.timeOnDevice == b.timeOnDevice && a.timeInDriver == b.timeInDriver; +} + +bool operator!=(const Timing& a, const Timing& b) { + return !(a == b); +} + +bool operator==(const Capabilities::PerformanceInfo& a, const Capabilities::PerformanceInfo& b) { + return a.execTime == b.execTime && a.powerUsage == b.powerUsage; +} + +bool operator==(const Capabilities::OperandPerformance& a, + const Capabilities::OperandPerformance& b) { + return a.type == b.type && a.info == b.info; +} + +bool operator==(const Capabilities& a, const Capabilities& b) { + return a.relaxedFloat32toFloat16PerformanceScalar == + b.relaxedFloat32toFloat16PerformanceScalar && + a.relaxedFloat32toFloat16PerformanceTensor == + b.relaxedFloat32toFloat16PerformanceTensor && + a.operandPerformance.asVector() == b.operandPerformance.asVector() && + a.ifPerformance == b.ifPerformance && a.whilePerformance == b.whilePerformance; +} + +bool operator==(const Extension::OperandTypeInformation& a, + const Extension::OperandTypeInformation& b) { + return a.type == b.type && a.isTensor == b.isTensor && a.byteSize == b.byteSize; +} + +bool operator==(const Extension& a, const Extension& b) { + return a.name == b.name && a.operandTypes == b.operandTypes; +} + +bool operator==(const Operand::SymmPerChannelQuantParams& a, + const Operand::SymmPerChannelQuantParams& b) { + return a.scales == b.scales && a.channelDim == b.channelDim; +} +bool operator!=(const Operand::SymmPerChannelQuantParams& a, + const Operand::SymmPerChannelQuantParams& b) { + return !(a == b); +} + +} // namespace android::nn diff --git a/nn/common/Types.cpp b/nn/common/Types.cpp index a49a8dc35..761ffa72c 100644 --- a/nn/common/Types.cpp +++ b/nn/common/Types.cpp @@ -28,6 +28,7 @@ #include "OperandTypes.h" #include "OperationTypes.h" +#include "Result.h" namespace android::nn { namespace { @@ -87,15 +88,14 @@ Capabilities::OperandPerformanceTable::OperandPerformanceTable( std::vector<OperandPerformance> operandPerformances) : mSorted(std::move(operandPerformances)) {} -std::optional<Capabilities::OperandPerformanceTable> Capabilities::OperandPerformanceTable::create( +Result<Capabilities::OperandPerformanceTable> Capabilities::OperandPerformanceTable::create( std::vector<OperandPerformance> operandPerformances) { const auto notUnique = [](const auto& lhs, const auto& rhs) { return !(lhs.type < rhs.type); }; const bool isUnique = std::adjacent_find(operandPerformances.begin(), operandPerformances.end(), notUnique) == operandPerformances.end(); if (!isUnique) { - LOG(ERROR) << "Failed to create OperandPerformanceTable: Input must be sorted by key (in " - "ascending order), and there must be no duplicate keys"; - return std::nullopt; + return NN_ERROR() << "Failed to create OperandPerformanceTable: Input must be sorted by " + "key (in ascending order), and there must be no duplicate keys"; } return Capabilities::OperandPerformanceTable(std::move(operandPerformances)); diff --git a/nn/common/Validation.cpp b/nn/common/Validation.cpp new file mode 100644 index 000000000..d18ba3443 --- /dev/null +++ b/nn/common/Validation.cpp @@ -0,0 +1,2664 @@ +/* + * Copyright (C) 2020 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "Validation.h" + +#include <android-base/logging.h> + +#include <algorithm> +#include <cctype> +#include <functional> +#include <limits> +#include <memory> +#include <numeric> +#include <set> +#include <sstream> +#include <string> +#include <string_view> +#include <tuple> +#include <utility> +#include <variant> +#include <vector> + +#include "ControlFlow.h" +#include "OperandTypes.h" +#include "OperationTypes.h" +#include "Result.h" +#include "TypeUtils.h" +#include "Types.h" + +// The NN_VALIDATE family of macros defined below is similar to the CHECK family defined in +// system/core/base/include/android-base/logging.h +// +// The difference is that NN_VALIDATE macros use LOG(ERROR) instead of LOG(FATAL) +// and return false instead of aborting. + +// Logs an error and returns false or INVALID. Append context using << after. For example: +// +// NN_VALIDATE_FAIL() << "Something went wrong"; +// +// The containing function must return a bool or Version. +#define NN_VALIDATE_FAIL() \ + return NN_ERROR() << "NN_VALIDATE failed (" << __FILE__ << ":" << __LINE__ << "): " + +// Logs an error and returns false or Version::INVALID if condition is false. Extra logging can be +// appended using << after. For example: +// +// NN_VALIDATE(false) << "Something went wrong"; +// +// The containing function must return a bool. +#define NN_VALIDATE(condition) \ + while (UNLIKELY(!(condition))) NN_VALIDATE_FAIL() << #condition << " " + +// Helper for NN_VALIDATE_xx(x, y) macros. +#define NN_VALIDATE_OP(LHS, RHS, OP) \ + for (auto _values = ::android::base::MakeEagerEvaluator(LHS, RHS); \ + UNLIKELY(!(_values.lhs.v OP _values.rhs.v)); \ + /* empty */) \ + NN_VALIDATE_FAIL() \ + << #LHS << " " << #OP << " " << #RHS << " (" << #LHS << " = " \ + << ::android::base::LogNullGuard<decltype(_values.lhs.v)>::Guard(_values.lhs.v) \ + << ", " << #RHS << " = " \ + << ::android::base::LogNullGuard<decltype(_values.rhs.v)>::Guard(_values.rhs.v) \ + << ") " + +// Logs an error and returns false or Version::INVALID if a condition between x and y does not hold. +// Extra logging can be appended using << after. For example: +// +// NN_VALIDATE_EQ(a, b) << "Something went wrong"; +// +// The values must implement the appropriate comparison operator as well as +// `operator<<(std::ostream&, ...)`. +// The containing function must return a bool or Version. +#define NN_VALIDATE_EQ(x, y) NN_VALIDATE_OP(x, y, ==) +#define NN_VALIDATE_NE(x, y) NN_VALIDATE_OP(x, y, !=) +#define NN_VALIDATE_LE(x, y) NN_VALIDATE_OP(x, y, <=) +#define NN_VALIDATE_LT(x, y) NN_VALIDATE_OP(x, y, <) +#define NN_VALIDATE_GE(x, y) NN_VALIDATE_OP(x, y, >=) +#define NN_VALIDATE_GT(x, y) NN_VALIDATE_OP(x, y, >) + +namespace android::nn { +namespace { + +constexpr auto kNullptrVariant = std::variant<const void*, void*>{}; +constexpr auto kInvalidMemoryDomainToken = Request::MemoryDomainToken{}; + +template <typename Type, typename ValidationFunction> +Result<Version> validateVector(const std::vector<Type>& objects, + const ValidationFunction& validationFunction) { + auto version = Version::ANDROID_OC_MR1; + for (const auto& object : objects) { + version = combineVersions(version, NN_TRY(validationFunction(object))); + } + return version; +} + +bool isValidExtensionName(const std::string& name) { + constexpr auto validSymbol = [](char symbol) { + return std::islower(symbol) || std::isdigit(symbol) || symbol == '.' || symbol == '_'; + }; + const bool hasOnlyValidSymbols = std::all_of(name.begin(), name.end(), validSymbol); + const bool hasAtLeastOnePeriod = std::find(name.begin(), name.end(), '.') != name.end(); + return hasOnlyValidSymbols && hasAtLeastOnePeriod; +} + +Result<Version> validateDeviceStatus(const DeviceStatus& deviceStatus) { + switch (deviceStatus) { + case DeviceStatus::AVAILABLE: + case DeviceStatus::BUSY: + case DeviceStatus::OFFLINE: + case DeviceStatus::UNKNOWN: + return Version::ANDROID_OC_MR1; + } + NN_VALIDATE_FAIL() << "Invalid DeviceStatus " << deviceStatus; +} + +Result<Version> validateExecutionPreference(const ExecutionPreference& executionPreference) { + switch (executionPreference) { + case ExecutionPreference::LOW_POWER: + case ExecutionPreference::FAST_SINGLE_ANSWER: + case ExecutionPreference::SUSTAINED_SPEED: + return Version::ANDROID_P; + } + NN_VALIDATE_FAIL() << "Invalid ExecutionPreference " << executionPreference; +} + +Result<Version> validateDeviceType(const DeviceType& deviceType) { + switch (deviceType) { + case DeviceType::UNKNOWN: + // DeviceType was introduced in the 1.2 NN HAL. DeviceType::UNKNOWN is returned when + // querying versions that are prior to the 1.2 NN HAL. DeviceType::UNKNOWN is not a + // valid code to return for a driver that implement at least a 1.2 NN HAL. If we need a + // range of versions, make ANDROID_Q (NN HAL 1.2) the exclusive upper bound for + // DeviceType::UNKNOWN. + return Version::ANDROID_OC_MR1; + case DeviceType::OTHER: + case DeviceType::CPU: + case DeviceType::GPU: + case DeviceType::ACCELERATOR: + return Version::ANDROID_Q; + } + NN_VALIDATE_FAIL() << "Invalid DeviceType " << deviceType; +} + +Result<Version> validateMeasureTiming(const MeasureTiming& measureTiming) { + switch (measureTiming) { + case MeasureTiming::NO: + case MeasureTiming::YES: + return Version::ANDROID_Q; + } + NN_VALIDATE_FAIL() << "Invalid MeasureTiming " << measureTiming; +} + +Result<Version> validateOperandType(const OperandType& operandType) { + switch (operandType) { + case OperandType::FLOAT32: + case OperandType::INT32: + case OperandType::UINT32: + case OperandType::TENSOR_FLOAT32: + case OperandType::TENSOR_INT32: + case OperandType::TENSOR_QUANT8_ASYMM: + case OperandType::OEM: + case OperandType::TENSOR_OEM_BYTE: + return Version::ANDROID_OC_MR1; + case OperandType::BOOL: + case OperandType::TENSOR_QUANT16_SYMM: + case OperandType::TENSOR_FLOAT16: + case OperandType::TENSOR_BOOL8: + case OperandType::FLOAT16: + case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: + case OperandType::TENSOR_QUANT16_ASYMM: + case OperandType::TENSOR_QUANT8_SYMM: + return Version::ANDROID_Q; + case OperandType::TENSOR_QUANT8_ASYMM_SIGNED: + case OperandType::SUBGRAPH: + return Version::ANDROID_R; + } + if (isExtension(operandType)) { + return Version::ANDROID_Q; + } + NN_VALIDATE_FAIL() << "Invalid OperandType " << operandType; +} + +Result<Version> validateOperandLifeTime(const Operand& operand) { + // Make sure SUBGRAPH operand type and lifetime always go together. + NN_VALIDATE_EQ((operand.type == OperandType::SUBGRAPH), + (operand.lifetime == Operand::LifeTime::SUBGRAPH)) + << "Operand of type " << operand.type << " cannot have lifetime " << operand.lifetime; + + switch (operand.lifetime) { + case Operand::LifeTime::TEMPORARY_VARIABLE: + case Operand::LifeTime::SUBGRAPH_INPUT: + case Operand::LifeTime::SUBGRAPH_OUTPUT: + case Operand::LifeTime::CONSTANT_COPY: + case Operand::LifeTime::CONSTANT_REFERENCE: + case Operand::LifeTime::NO_VALUE: + case Operand::LifeTime::POINTER: + return Version::ANDROID_OC_MR1; + case Operand::LifeTime::SUBGRAPH: + return Version::ANDROID_R; + } + NN_VALIDATE_FAIL() << "Invalid Operand::LifeTime " << operand.lifetime; +} + +Result<Version> validatePriority(const Priority& priority) { + switch (priority) { + case Priority::LOW: + case Priority::MEDIUM: + case Priority::HIGH: + return Version::ANDROID_R; + } + NN_VALIDATE_FAIL() << "Invalid Priority " << priority; +} + +Result<Version> validateErrorStatus(const ErrorStatus& errorStatus) { + // Note that MISSED_DEADLINE_*, RESOURCE_EXHAUSTED_*, and DEAD_OBJECT were introduced ih + // ANDROID_R, but these can be cast to ANDROID_OC_MR1 as GENERAL_FAILURE. + switch (errorStatus) { + case ErrorStatus::NONE: + case ErrorStatus::DEVICE_UNAVAILABLE: + case ErrorStatus::GENERAL_FAILURE: + case ErrorStatus::OUTPUT_INSUFFICIENT_SIZE: + case ErrorStatus::INVALID_ARGUMENT: + case ErrorStatus::MISSED_DEADLINE_TRANSIENT: + case ErrorStatus::MISSED_DEADLINE_PERSISTENT: + case ErrorStatus::RESOURCE_EXHAUSTED_TRANSIENT: + case ErrorStatus::RESOURCE_EXHAUSTED_PERSISTENT: + case ErrorStatus::DEAD_OBJECT: + return Version::ANDROID_OC_MR1; + } + NN_VALIDATE_FAIL() << "Invalid ErrorStatus " << errorStatus; +} + +Result<Version> validateOutputShape(const OutputShape& /*outputShape*/) { + return Version::ANDROID_Q; +} + +Result<Version> validateTiming(const Timing& timing) { + if (timing.timeInDriver != kNoTiming && timing.timeOnDevice != kNoTiming) { + NN_VALIDATE_LE(timing.timeOnDevice, timing.timeInDriver); + } + return Version::ANDROID_Q; +} + +Result<Version> validateCapabilitiesPerformanceInfo( + const Capabilities::PerformanceInfo& performanceInfo) { + NN_VALIDATE_GT(performanceInfo.execTime, 0.0f); + NN_VALIDATE_GT(performanceInfo.powerUsage, 0.0f); + return Version::ANDROID_OC_MR1; +} + +Result<Version> validateCapabilitiesOperandPerformance( + const Capabilities::OperandPerformance& operandPerformance) { + auto version = NN_TRY(validateOperandType(operandPerformance.type)); + return combineVersions(version, + NN_TRY(validateCapabilitiesPerformanceInfo(operandPerformance.info))); +} + +Result<Version> validateCapabilitiesOperandPerformanceTable( + const Capabilities::OperandPerformanceTable& operandPerformances) { + // OperandPerformanceTable's order was validated when it was created, and it is castable to any + // version. If an OperandType does not exist in the lower version being converted to, that + // OperandPerformance will be dropped. + NN_TRY(validateVector(operandPerformances.asVector(), validateCapabilitiesOperandPerformance)); + return Version::ANDROID_OC_MR1; +} + +Result<Version> validateCapabilities(const Capabilities& capabilities) { + auto version = + NN_TRY(validateCapabilitiesOperandPerformanceTable(capabilities.operandPerformance)); + + version = combineVersions(version, + NN_TRY(validateCapabilitiesPerformanceInfo( + capabilities.relaxedFloat32toFloat16PerformanceScalar))); + version = combineVersions(version, + NN_TRY(validateCapabilitiesPerformanceInfo( + capabilities.relaxedFloat32toFloat16PerformanceTensor))); + version = combineVersions( + version, NN_TRY(validateCapabilitiesPerformanceInfo(capabilities.ifPerformance))); + version = combineVersions( + version, NN_TRY(validateCapabilitiesPerformanceInfo(capabilities.whilePerformance))); + + return version; +} + +Result<Version> validateExtensionOperandTypeInformation( + const Extension::OperandTypeInformation& operandTypeInformation) { + NN_VALIDATE_GT(operandTypeInformation.byteSize, 0u); + return Version::ANDROID_Q; +} + +Result<Version> validateExtension(const Extension& extension) { + NN_VALIDATE(isValidExtensionName(extension.name)); + + // Verify all OperandTypeInformations have unique types. + std::vector<uint16_t> types; + types.reserve(extension.operandTypes.size()); + std::transform(extension.operandTypes.begin(), extension.operandTypes.end(), + std::back_inserter(types), + [](const Extension::OperandTypeInformation& operandTypeInformation) { + return operandTypeInformation.type; + }); + std::sort(types.begin(), types.end()); + const auto iter = std::adjacent_find(types.begin(), types.end()); + NN_VALIDATE(iter == types.end()) << "Extension has duplicate type " << *iter; + + return combineVersions(Version::ANDROID_Q, + NN_TRY(validateVector(extension.operandTypes, + validateExtensionOperandTypeInformation))); +} + +Result<Version> validateExtensions(const std::vector<Extension>& extensions) { + const auto version = NN_TRY(validateVector(extensions, validateExtension)); + + // Verify all extensions have unique names. + std::vector<std::reference_wrapper<const std::string>> names; + names.reserve(extensions.size()); + std::transform(extensions.begin(), extensions.end(), std::back_inserter(names), + [](const Extension& extension) { return std::cref(extension.name); }); + std::sort(names.begin(), names.end(), std::less<std::string>{}); + const auto nameIter = + std::adjacent_find(names.begin(), names.end(), std::equal_to<std::string>{}); + NN_VALIDATE(nameIter == names.end()) + << "Two or more extensions have the duplicate name " << nameIter->get(); + + return version; +} + +Result<Version> validateOperandDataLocation(const Operand& operand, size_t operandValuesSize, + const std::vector<size_t>& poolSizes, + size_t subgraphCount) { + const DataLocation& location = operand.location; + switch (operand.lifetime) { + case Operand::LifeTime::CONSTANT_COPY: + NN_VALIDATE(location.pointer == kNullptrVariant) + << "CONSTANT_COPY with a non-null pointer"; + NN_VALIDATE_EQ(location.poolIndex, 0u) + << "CONSTANT_COPY with a non-zero poolIndex " << location.poolIndex; + // Do the addition using uint64_t to avoid potential wrap-around problems. + NN_VALIDATE_LE(static_cast<uint64_t>(location.offset) + location.length, + operandValuesSize) + << "OperandValue location out of range. Starts at " << location.offset + << ", length " << location.length << ", max " << operandValuesSize; + return Version::ANDROID_OC_MR1; + case Operand::LifeTime::CONSTANT_REFERENCE: + NN_VALIDATE_LT(location.poolIndex, poolSizes.size()); + // Do the addition using uint64_t to avoid potential wrap-around problems. + NN_VALIDATE_LE(static_cast<uint64_t>(location.offset) + location.length, + poolSizes[location.poolIndex]) + << "OperandValue location out of range. Starts at " << location.offset + << ", length " << location.length << ", max " << poolSizes[location.poolIndex]; + return Version::ANDROID_OC_MR1; + case Operand::LifeTime::TEMPORARY_VARIABLE: + case Operand::LifeTime::SUBGRAPH_INPUT: + case Operand::LifeTime::SUBGRAPH_OUTPUT: + case Operand::LifeTime::NO_VALUE: + NN_VALIDATE(location.pointer == kNullptrVariant) + << "Unexpected pointer value for operand of lifetime " << operand.lifetime; + NN_VALIDATE_EQ(location.poolIndex, 0u) + << "Unexpected poolIndex " << location.poolIndex << " for operand of lifetime " + << operand.lifetime; + NN_VALIDATE_EQ(location.offset, 0u) << "Unexpected offset " << location.offset + << " for operand of lifetime " << operand.lifetime; + NN_VALIDATE_EQ(location.length, 0u) << "Unexpected length " << location.length + << " for operand of lifetime " << operand.lifetime; + return Version::ANDROID_OC_MR1; + case Operand::LifeTime::SUBGRAPH: + NN_VALIDATE(location.pointer == kNullptrVariant) << "SUBGRAPH with a non-null pointer"; + NN_VALIDATE_EQ(location.poolIndex, 0u) + << "SUBGRAPH with a non-zero poolIndex " << location.poolIndex; + NN_VALIDATE_LT(location.offset, subgraphCount) + << "Subgraph index out of range: " << location.offset + << " >= " << subgraphCount; + NN_VALIDATE_EQ(location.length, 0u) + << "SUBGRAPH with a non-zero length " << location.length; + return Version::ANDROID_R; + case Operand::LifeTime::POINTER: { + const bool nonNull = + std::visit([](auto* ptr) { return ptr != nullptr; }, location.pointer); + NN_VALIDATE(nonNull) << "POINTER with a null pointer"; + NN_VALIDATE_EQ(location.poolIndex, 0u) + << "POINTER with a non-zero poolIndex " << location.poolIndex; + NN_VALIDATE_EQ(location.offset, 0u) + << "POINTER with a non-zero offset " << location.offset; + return Version::ANDROID_OC_MR1; + } + } + NN_VALIDATE_FAIL() << "Invalid Operand::LifeTime " << operand.lifetime; +} + +Result<Version> validateOperandDimensions(const Operand& operand) { + switch (operand.type) { + case OperandType::FLOAT32: + case OperandType::INT32: + case OperandType::UINT32: + case OperandType::BOOL: + case OperandType::FLOAT16: + case OperandType::SUBGRAPH: + case OperandType::OEM: + NN_VALIDATE(operand.dimensions.empty()) + << "Scalar data has dimensions of rank " << operand.dimensions.size(); + return Version::ANDROID_OC_MR1; + case OperandType::TENSOR_FLOAT32: + case OperandType::TENSOR_INT32: + case OperandType::TENSOR_QUANT8_ASYMM: + case OperandType::TENSOR_QUANT16_SYMM: + case OperandType::TENSOR_FLOAT16: + case OperandType::TENSOR_BOOL8: + case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: + case OperandType::TENSOR_QUANT16_ASYMM: + case OperandType::TENSOR_QUANT8_SYMM: + case OperandType::TENSOR_QUANT8_ASYMM_SIGNED: + case OperandType::TENSOR_OEM_BYTE: { + if (operand.lifetime == Operand::LifeTime::CONSTANT_COPY || + operand.lifetime == Operand::LifeTime::CONSTANT_REFERENCE || + operand.lifetime == Operand::LifeTime::POINTER) { + NN_VALIDATE(!operand.dimensions.empty()) + << "Tensor has lifetime of " << operand.lifetime + << " but dimensions of rank 0"; + const auto size = getNonExtensionSize(operand); + NN_VALIDATE(size.has_value()) << "Tensor dimensions overflow"; + NN_VALIDATE_EQ(size.value(), 0u) << "Tensor has at least one unknown dimension"; + } + // TODO(b/165152547): aren't NO_VALUE arguments allowed to be .empty() even before + // Android Q? + if (operand.dimensions.empty()) { + // Unspecified rank was added in Android Q. + return Version::ANDROID_Q; + } + return Version::ANDROID_OC_MR1; + } + } + if (isExtension(operand.type)) { + // Extension types were added in Android Q. + return Version::ANDROID_Q; + } + NN_VALIDATE_FAIL() << "Invalid OperandType " << operand.type; +} + +Result<Version> validateOperandScale(const Operand& operand) { + switch (operand.type) { + case OperandType::FLOAT32: + case OperandType::INT32: + case OperandType::UINT32: + case OperandType::TENSOR_FLOAT32: + case OperandType::BOOL: + case OperandType::TENSOR_FLOAT16: + case OperandType::TENSOR_BOOL8: + case OperandType::FLOAT16: + case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: + case OperandType::SUBGRAPH: + NN_VALIDATE_EQ(operand.scale, 0.0f) + << "Operand of type " << operand.type << " with a non-zero scale (" + << operand.scale << ")"; + return Version::ANDROID_OC_MR1; + case OperandType::TENSOR_INT32: + // TENSOR_INT32 may be used with or without scale, depending on the operation. + // TODO(b/119869082) We should have a separate type for TENSOR_INT32 with a scale. + NN_VALIDATE_GE(operand.scale, 0.0f) + << "Operand of type " << operand.type << " with a negative scale"; + return Version::ANDROID_OC_MR1; + case OperandType::TENSOR_QUANT8_ASYMM: + case OperandType::TENSOR_QUANT16_SYMM: + case OperandType::TENSOR_QUANT16_ASYMM: + case OperandType::TENSOR_QUANT8_SYMM: + case OperandType::TENSOR_QUANT8_ASYMM_SIGNED: + NN_VALIDATE_GT(operand.scale, 0.0f) + << "Operand of type " << operand.type << " with a non-positive scale"; + return Version::ANDROID_OC_MR1; + case OperandType::OEM: + case OperandType::TENSOR_OEM_BYTE: + // No validation for OEM types. + return Version::ANDROID_OC_MR1; + } + if (isExtension(operand.type)) { + NN_VALIDATE_EQ(operand.scale, 0.0f) << "Operand of type " << operand.type + << " with a non-zero scale (" << operand.scale << ")"; + return Version::ANDROID_Q; + } + NN_VALIDATE_FAIL() << "Invalid OperandType " << operand.type; +} + +Result<Version> validateOperandZeroPoint(const Operand& operand) { + switch (operand.type) { + case OperandType::FLOAT32: + case OperandType::INT32: + case OperandType::UINT32: + case OperandType::TENSOR_FLOAT32: + case OperandType::TENSOR_INT32: + case OperandType::BOOL: + case OperandType::TENSOR_FLOAT16: + case OperandType::TENSOR_BOOL8: + case OperandType::FLOAT16: + case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: + case OperandType::TENSOR_QUANT8_SYMM: + case OperandType::SUBGRAPH: + NN_VALIDATE_EQ(operand.zeroPoint, 0) + << "Operand of type " << operand.type << " with a non-zero zeroPoint " + << operand.zeroPoint; + return Version::ANDROID_OC_MR1; + case OperandType::TENSOR_QUANT8_ASYMM: + NN_VALIDATE(operand.zeroPoint >= 0 && operand.zeroPoint <= 255) + << "Operand of type " << operand.type << " with an invalid zeroPoint " + << operand.zeroPoint << ", must be in range [0, 255]"; + return Version::ANDROID_OC_MR1; + case OperandType::TENSOR_QUANT8_ASYMM_SIGNED: + NN_VALIDATE(operand.zeroPoint >= -128 && operand.zeroPoint <= 127) + << "Operand of type " << operand.type << " with an invalid zeroPoint " + << operand.zeroPoint << ", must be in range [-128, 127]"; + return Version::ANDROID_OC_MR1; + case OperandType::TENSOR_QUANT16_ASYMM: + NN_VALIDATE(operand.zeroPoint >= 0 && operand.zeroPoint <= 65535) + << "Operand of type " << operand.type << " with an invalid zeroPoint " + << operand.zeroPoint << ", must be in range [0, 65535]"; + return Version::ANDROID_OC_MR1; + case OperandType::TENSOR_QUANT16_SYMM: + NN_VALIDATE_EQ(operand.zeroPoint, 0) + << "Operand of type " << operand.type << " with a non-zero zeroPoint " + << operand.zeroPoint; + return Version::ANDROID_OC_MR1; + case OperandType::OEM: + case OperandType::TENSOR_OEM_BYTE: + // No validation for OEM types. + return Version::ANDROID_OC_MR1; + } + if (isExtension(operand.type)) { + NN_VALIDATE_EQ(operand.zeroPoint, 0) << "Operand of type " << operand.type + << " with a non-zero zeroPoint " << operand.zeroPoint; + return Version::ANDROID_Q; + } + NN_VALIDATE_FAIL() << "Invalid OperandType " << operand.type; +} + +Result<Version> validateOperandExtraParams(const Operand& operand) { + switch (operand.type) { + case OperandType::FLOAT32: + case OperandType::INT32: + case OperandType::UINT32: + case OperandType::TENSOR_FLOAT32: + case OperandType::TENSOR_INT32: + case OperandType::TENSOR_QUANT8_ASYMM: + case OperandType::BOOL: + case OperandType::TENSOR_QUANT16_SYMM: + case OperandType::TENSOR_FLOAT16: + case OperandType::TENSOR_BOOL8: + case OperandType::FLOAT16: + case OperandType::TENSOR_QUANT16_ASYMM: + case OperandType::TENSOR_QUANT8_SYMM: + case OperandType::TENSOR_QUANT8_ASYMM_SIGNED: + case OperandType::SUBGRAPH: + NN_VALIDATE(std::holds_alternative<Operand::NoParams>(operand.extraParams)) + << "Operand of type " << operand.type + << " has extraParams when there must be none"; + return Version::ANDROID_OC_MR1; + case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: { + NN_VALIDATE( + std::holds_alternative<Operand::SymmPerChannelQuantParams>(operand.extraParams)) + << "Operand of type " << operand.type + << " without a Channel Quantization params"; + const auto& channelQuant = + std::get<Operand::SymmPerChannelQuantParams>(operand.extraParams); + + const size_t count = operand.dimensions.size(); + NN_VALIDATE_LT(channelQuant.channelDim, count) + << "Operand of type " << operand.type + << " with an invalid channelQuant.channelDim " << channelQuant.channelDim + << ", must be valid dimension index in range [0, " << count << ")"; + const uint32_t expected = operand.dimensions[channelQuant.channelDim]; + NN_VALIDATE_EQ(channelQuant.scales.size(), expected) + << "Operand of type " << operand.type << " with a wrong-sized scales, expected " + << expected << " was " << channelQuant.scales.size(); + NN_VALIDATE_NE(expected, 0u) + << "Operand of type " << operand.type << " channel dimension " + << channelQuant.channelDim << " is underspecified (can't be 0)"; + for (uint32_t i = 0; i < expected; ++i) { + NN_VALIDATE_GT(channelQuant.scales[i], 0.0f) + << "Operand of type " << operand.type + << " with a non-positive value in scales[" << i + << "]=" << channelQuant.scales[i]; + } + return Version::ANDROID_Q; + } + case OperandType::OEM: + case OperandType::TENSOR_OEM_BYTE: + // No validation for OEM types. + return Version::ANDROID_OC_MR1; + } + if (isExtension(operand.type)) { + NN_VALIDATE(std::holds_alternative<Operand::NoParams>(operand.extraParams) || + std::holds_alternative<Operand::ExtensionParams>(operand.extraParams)) + << "Extension operand of type " << operand.type + << " must not have SymmPerChannelQuant extraParams"; + return Version::ANDROID_OC_MR1; + } + NN_VALIDATE_FAIL() << "Invalid OperandType " << operand.type; +} + +Result<Version> validateOperand(const Operand& operand, size_t operandValuesSize, + const std::vector<size_t>& poolSizes, size_t subgraphCount) { + auto version = NN_TRY(validateOperandType(operand.type)); + version = combineVersions(version, NN_TRY(validateOperandLifeTime(operand))); + version = combineVersions(version, NN_TRY(validateOperandDimensions(operand))); + version = combineVersions(version, NN_TRY(validateOperandScale(operand))); + version = combineVersions(version, NN_TRY(validateOperandZeroPoint(operand))); + version = combineVersions(version, NN_TRY(validateOperandExtraParams(operand))); + version = + combineVersions(version, NN_TRY(validateOperandDataLocation(operand, operandValuesSize, + poolSizes, subgraphCount))); + + // For constants, validate that the length is as expected. The other lifetimes + // expect the length to be 0. Don't validate for OEM types. + if (operand.lifetime == Operand::LifeTime::CONSTANT_REFERENCE || + operand.lifetime == Operand::LifeTime::CONSTANT_COPY || + operand.lifetime == Operand::LifeTime::POINTER) { + if (!isExtension(operand.type) && operand.type != OperandType::OEM && + operand.type != OperandType::TENSOR_OEM_BYTE) { + const auto expectedLength = getNonExtensionSize(operand).value(); + NN_VALIDATE_EQ(operand.location.length, expectedLength) + << "For operand " << operand.type << " expected a size of " << expectedLength + << " but got " << operand.location.length; + } + } + + return version; +} + +Result<Version> validateOperands(const std::vector<Operand>& operands, size_t operandValuesSize, + const std::vector<size_t>& poolSizes, size_t subgraphCount) { + auto version = Version::ANDROID_OC_MR1; + for (size_t i = 0; i < operands.size(); ++i) { + auto result = validateOperand(operands[i], operandValuesSize, poolSizes, subgraphCount); + if (!result.has_value()) { + return NN_ERROR() << std::move(result).error() << " for operand " << i; + } + version = combineVersions(version, result.value()); + } + return version; +} + +Result<Version> validateOperationImpl(const Operation& operation, + const std::vector<Operand>& operands, + const std::vector<Model::Subgraph>& subgraphs); + +Result<Version> validateOperations(const std::vector<Operation>& operations, + const std::vector<Operand>& operands, + const std::vector<Model::Subgraph>& subgraphs) { + auto version = Version::ANDROID_OC_MR1; + for (size_t i = 0; i < operations.size(); ++i) { + auto result = validateOperationImpl(operations[i], operands, subgraphs); + if (!result.has_value()) { + return NN_ERROR() << std::move(result).error() << " for operation " << i; + } + version = combineVersions(version, result.value()); + } + return version; +} + +Result<Version> validateNativeHandle(const NativeHandle& handle) { + NN_VALIDATE(handle != nullptr); + return Version::ANDROID_OC_MR1; +} + +Result<Version> validateMemory(const Memory& memory) { + NN_TRY(validateNativeHandle(memory.handle)); + + if (memory.name == "ashmem") { + NN_VALIDATE_NE(memory.size, 0u); + return Version::ANDROID_OC_MR1; + } + if (memory.name == "mmap_fd") { + NN_VALIDATE_NE(memory.size, 0u); + return Version::ANDROID_OC_MR1; + } + if (memory.name == "hardware_buffer_blob") { + NN_VALIDATE_NE(memory.size, 0u); + return Version::ANDROID_Q; + } + if (memory.name == "hardware_buffer") { + // For hardware_buffer memory, all size information is bound to the AHardwareBuffer, so + // memory.size must be 0. + NN_VALIDATE_EQ(memory.size, 0u); + return Version::ANDROID_Q; + } + + NN_VALIDATE_FAIL() << "Unknown memory type " << memory.name; +} + +Result<void> validateModelSubgraphInputOutputs(const std::vector<uint32_t>& indexes, + const std::vector<Operand>& operands, + Operand::LifeTime lifetime) { + const size_t operandCount = operands.size(); + for (uint32_t i : indexes) { + NN_VALIDATE_LT(i, operandCount) + << "Model " << lifetime << " input or output index out of range: " << i << "/" + << operandCount; + const Operand& operand = operands[i]; + NN_VALIDATE_EQ(operand.lifetime, lifetime) + << "Model " << lifetime << " operand " << i << " has lifetime of " + << operand.lifetime << " instead of the expected " << lifetime; + } + + std::vector<uint32_t> sortedIndexes = indexes; + std::sort(sortedIndexes.begin(), sortedIndexes.end()); + const auto iter = std::adjacent_find(sortedIndexes.begin(), sortedIndexes.end()); + NN_VALIDATE(iter == sortedIndexes.end()) + << "Model input or output occurs multiple times: " << *iter; + + for (size_t i = 0; i < operands.size(); ++i) { + if (operands[i].lifetime == lifetime) { + const auto containsIndex = [&sortedIndexes](size_t index) { + return binary_search(sortedIndexes.begin(), sortedIndexes.end(), index); + }; + NN_VALIDATE(containsIndex(i)) + << "Operand " << i << " marked as " << lifetime + << " but is not included in Model input or output indexes"; + } + } + + return {}; +} + +Result<void> validateExecutionOrder(const Model::Subgraph& subgraph) { + // Either the operand has a known value before model execution begins, or we've seen a writer + // for this operand while walking operands in execution order. Initialize to known operands. + std::vector<bool> operandValueKnown; + operandValueKnown.reserve(subgraph.operands.size()); + std::transform(subgraph.operands.begin(), subgraph.operands.end(), + std::back_inserter(operandValueKnown), [](const Operand& operand) { + return operand.lifetime != Operand::LifeTime::TEMPORARY_VARIABLE && + operand.lifetime != Operand::LifeTime::SUBGRAPH_OUTPUT; + }); + + // Validate that operations are sorted into execution order. + // + // If there is a cycle in the graph, the operations will not + // appear to be sorted into execution order: Some operation will + // have an input for which operandValueKnown[] is false. + for (size_t i = 0; i < subgraph.operations.size(); ++i) { + const auto& operation = subgraph.operations[i]; + + for (size_t j = 0; j < operation.inputs.size(); ++j) { + const uint32_t k = operation.inputs[j]; + NN_VALIDATE(operandValueKnown[k]) << "Operation " << i << " input " << j << " (operand " + << k << ") is read before it is written"; + } + + for (size_t j = 0; j < operation.outputs.size(); ++j) { + const uint32_t k = operation.outputs[j]; + // Assuming validateOperations() has returned true, we know that this output is + // TEMPORARY_VARIABLE or MODEL_OUTPUT, and so the only way operandValueKnown[k] can be + // true is if we've already seen a writer for this operand. + NN_VALIDATE(!operandValueKnown[k]) << "Operation " << i << " output " << j + << " (operand " << k << ") has already been written"; + operandValueKnown[k] = true; + } + } + + // Verify all operands are written. + for (size_t i = 0; i < subgraph.operands.size(); ++i) { + NN_VALIDATE(operandValueKnown[i]) << "Operand " << i << " is never written"; + } + + // TODO(b/77871786): verify that every operation has at least one output operand that is read? + + return {}; +} + +Result<Version> validateModelSubgraph(const Model::Subgraph& subgraph, size_t operandValuesSize, + const std::vector<size_t>& poolSizes, + const std::vector<Model::Subgraph>& referenced) { + NN_VALIDATE(!subgraph.operands.empty()); + NN_VALIDATE(!subgraph.operations.empty()); + // TODO(b/134529942#comment7): should we verify !subgraph.inputIndexes.empty()? + NN_VALIDATE(!subgraph.inputIndexes.empty()); + NN_VALIDATE(!subgraph.outputIndexes.empty()); + + auto version = NN_TRY( + validateOperands(subgraph.operands, operandValuesSize, poolSizes, referenced.size())); + version = combineVersions(version, NN_TRY(validateOperations(subgraph.operations, + subgraph.operands, referenced))); + + NN_TRY(validateModelSubgraphInputOutputs(subgraph.inputIndexes, subgraph.operands, + Operand::LifeTime::SUBGRAPH_INPUT)); + NN_TRY(validateModelSubgraphInputOutputs(subgraph.outputIndexes, subgraph.operands, + Operand::LifeTime::SUBGRAPH_OUTPUT)); + + NN_TRY(validateExecutionOrder(subgraph)); + + return version; +} + +Result<Version> validateModelExtensionNamesAndPrefixes( + const std::vector<Model::ExtensionNameAndPrefix>& extensionNamesAndPrefixes) { + for (const auto& extensionNameAndPrefix : extensionNamesAndPrefixes) { + NN_VALIDATE(isValidExtensionName(extensionNameAndPrefix.name)); + } + + std::vector<std::reference_wrapper<const std::string>> names; + names.reserve(extensionNamesAndPrefixes.size()); + std::transform(extensionNamesAndPrefixes.begin(), extensionNamesAndPrefixes.end(), + std::back_inserter(names), + [](const Model::ExtensionNameAndPrefix& extensionNameAndPrefix) { + return std::cref(extensionNameAndPrefix.name); + }); + std::sort(names.begin(), names.end(), std::less<std::string>{}); + const auto nameIter = + std::adjacent_find(names.begin(), names.end(), std::equal_to<std::string>{}); + NN_VALIDATE(nameIter == names.end()) + << "ExtensionNamesAndPrefixes has duplicate name " << nameIter->get(); + + std::vector<uint16_t> types; + types.reserve(extensionNamesAndPrefixes.size()); + std::transform(extensionNamesAndPrefixes.begin(), extensionNamesAndPrefixes.end(), + std::back_inserter(types), + [](const Model::ExtensionNameAndPrefix& extensionNameAndPrefix) { + return extensionNameAndPrefix.prefix; + }); + std::sort(types.begin(), types.end()); + const auto typeIter = std::adjacent_find(types.begin(), types.end()); + NN_VALIDATE(typeIter == types.end()) + << "ExtensionNamesAndPrefixes has duplicate type " << *typeIter; + + const bool hasExtensions = !extensionNamesAndPrefixes.empty(); + return hasExtensions ? Version::ANDROID_Q : Version::ANDROID_OC_MR1; +} + +// Makes sure the model does not contain subgraph reference cycles. +Result<void> checkNoReferenceCycles(const Model& model, const Model::Subgraph& subgraph, + std::set<const Model::Subgraph*>* path) { + const auto [_, isNew] = path->insert(&subgraph); + NN_VALIDATE(isNew) << "Model contains a circular subgraph reference"; + // TODO(b/165154824): It looks like this functions is doing a lot of redundant work. + for (const Operand& operand : subgraph.operands) { + if (operand.lifetime == Operand::LifeTime::SUBGRAPH) { + const uint32_t refSubgraphIndex = operand.location.offset; + NN_TRY(checkNoReferenceCycles(model, model.referenced[refSubgraphIndex], path)); + } + } + path->erase(&subgraph); + return {}; +} + +Result<void> checkNoReferenceCycles(const Model& model) { + std::set<const Model::Subgraph*> path; + return checkNoReferenceCycles(model, model.main, &path); +} + +Result<Version> validateModel(const Model& model) { + auto version = NN_TRY(validateVector(model.pools, validateMemory)); + version = combineVersions( + version, NN_TRY(validateModelExtensionNamesAndPrefixes(model.extensionNameToPrefix))); + + // Ignore relaxComputationFloat32toFloat16 version because in the worst case it makes the + // execution stricter. + + // Referenced models were introduced in Android R. + const bool hasReferencedModels = !model.referenced.empty(); + const auto referenceModelVersion = + hasReferencedModels ? Version::ANDROID_R : Version::ANDROID_OC_MR1; + version = combineVersions(version, referenceModelVersion); + + // Get memory sizes. + std::vector<size_t> poolSizes; + poolSizes.reserve(model.pools.size()); + std::transform(model.pools.begin(), model.pools.end(), std::back_inserter(poolSizes), + [](const Memory& memory) { return memory.size; }); + const size_t operandValuesSize = model.operandValues.size(); + + // Validate main subgraph. + auto result = validateModelSubgraph(model.main, operandValuesSize, poolSizes, model.referenced); + if (!result.has_value()) { + return NN_ERROR() << std::move(result).error() << " for main subgraph"; + } + version = combineVersions(version, result.value()); + + // Validate referenced subgraphs. + for (size_t i = 0; i < model.referenced.size(); ++i) { + const auto& subgraph = model.referenced[i]; + auto result = + validateModelSubgraph(subgraph, operandValuesSize, poolSizes, model.referenced); + if (!result.has_value()) { + return NN_ERROR() << std::move(result).error() << " for referenced subgraph" << i; + } + version = combineVersions(version, result.value()); + } + + // Ensure that there are no references formed by calling the subgraphs. + NN_TRY(checkNoReferenceCycles(model)); + + return version; +} + +Result<Version> validateBufferDesc(const BufferDesc& /*bufferDesc*/) { + return Version::ANDROID_R; +} + +Result<Version> validateBufferRole(const BufferRole& bufferRole) { + NN_VALIDATE_GT(bufferRole.frequency, 0.0f); + NN_VALIDATE_LE(bufferRole.frequency, 1.0f); + return Version::ANDROID_R; +} + +Result<Version> validateRequestArgument(const Request::Argument& requestArgument, + const std::vector<size_t>& memorySizes, bool isOutput) { + const auto lifetime = requestArgument.lifetime; + const auto& location = requestArgument.location; + const auto& dimensions = requestArgument.dimensions; + + switch (lifetime) { + case Request::Argument::LifeTime::POOL: { + NN_VALIDATE(location.pointer == kNullptrVariant); + NN_VALIDATE_LT(location.poolIndex, memorySizes.size()); + // Do the addition using uint64_t to avoid potential wrap-around problems. + const auto lastPosition = static_cast<uint64_t>(location.offset) + location.length; + NN_VALIDATE_LE(lastPosition, memorySizes[location.poolIndex]); + return Version::ANDROID_OC_MR1; + } + case Request::Argument::LifeTime::NO_VALUE: + NN_VALIDATE(location.pointer == kNullptrVariant); + NN_VALIDATE_EQ(location.poolIndex, 0u); + NN_VALIDATE_EQ(location.offset, 0u); + NN_VALIDATE_EQ(location.length, 0u); + NN_VALIDATE(dimensions.empty()); + return Version::ANDROID_OC_MR1; + case Request::Argument::LifeTime::POINTER: { + const bool isNullptr = + std::visit([](auto ptr) { return ptr == nullptr; }, location.pointer); + NN_VALIDATE(!isNullptr); + NN_VALIDATE_EQ(location.poolIndex, 0u); + NN_VALIDATE_EQ(location.offset, 0u); + NN_VALIDATE_NE(location.length, 0u); + if (isOutput) { + NN_VALIDATE(std::holds_alternative<void*>(location.pointer)); + } + return Version::ANDROID_OC_MR1; + } + } + NN_VALIDATE_FAIL() << "Invalid Request::Argument::LifeTime " << lifetime; +} + +Result<Version> validateRequestMemoryPool(const Request::MemoryPool& memoryPool) { + if (std::holds_alternative<Request::MemoryDomainToken>(memoryPool)) { + NN_VALIDATE(std::get<Request::MemoryDomainToken>(memoryPool) != kInvalidMemoryDomainToken); + return Version::ANDROID_R; + } + if (std::holds_alternative<std::shared_ptr<const IBuffer>>(memoryPool)) { + NN_VALIDATE(std::get<std::shared_ptr<const IBuffer>>(memoryPool) != nullptr); + return Version::ANDROID_R; + } + return validateMemory(std::get<Memory>(memoryPool)); +} + +Result<Version> validateRequest(const Request& request) { + auto version = NN_TRY(validateVector(request.pools, validateRequestMemoryPool)); + + // Get memory sizes. For IBuffer or MemoryDomainToken types, set size to 0. + std::vector<size_t> memorySizes; + memorySizes.reserve(request.pools.size()); + std::transform(request.pools.begin(), request.pools.end(), std::back_inserter(memorySizes), + [](const Request::MemoryPool& memoryPool) { + const auto* memory = std::get_if<Memory>(&memoryPool); + return memory != nullptr ? memory->size : 0; + }); + + for (size_t i = 0; i < request.inputs.size(); ++i) { + const auto& input = request.inputs[i]; + auto result = validateRequestArgument(input, memorySizes, /*isOutput=*/false); + if (!result.has_value()) { + return NN_ERROR() << std::move(result).error() << " for input RequestArgument " << i; + } + version = combineVersions(version, result.value()); + } + for (size_t i = 0; i < request.outputs.size(); ++i) { + const auto& output = request.outputs[i]; + auto result = validateRequestArgument(output, memorySizes, /*isOutput=*/true); + if (!result.has_value()) { + return NN_ERROR() << std::move(result).error() << " for output RequestArgument " << i; + } + version = combineVersions(version, result.value()); + } + + return version; +} + +Result<Version> validateOptionalTimePoint(const OptionalTimePoint& optionalTimePoint) { + if (optionalTimePoint.has_value()) { + NN_VALIDATE_GE(optionalTimePoint.value().time_since_epoch().count(), 0); + } + return Version::ANDROID_R; +} + +Result<Version> validateOptionalTimeoutDuration( + const OptionalTimeoutDuration& optionalTimeoutDuration) { + if (optionalTimeoutDuration.has_value()) { + NN_VALIDATE_GE(optionalTimeoutDuration.value().count(), 0); + } + return Version::ANDROID_R; +} + +Result<Version> validateRequestArgumentsForModel( + const std::vector<Request::Argument>& requestArguments, + const std::vector<uint32_t>& operandIndexes, const std::vector<Operand>& operands, + bool isOutput) { + auto version = Version::ANDROID_OC_MR1; + // The request should specify as many arguments as were described in the model. + const std::string_view type = isOutput ? "output" : "input"; + const size_t requestArgumentCount = requestArguments.size(); + NN_VALIDATE_EQ(requestArgumentCount, operandIndexes.size()) + << "Request specifies " << requestArgumentCount << " " << type << "s but the model has " + << operandIndexes.size(); + for (size_t requestArgumentIndex = 0; requestArgumentIndex < requestArgumentCount; + requestArgumentIndex++) { + const Request::Argument& requestArgument = requestArguments[requestArgumentIndex]; + // Get the operand index for this argument. We extract it from the list + // that was provided in the call to ANeuralNetworksModel_identifyInputsAndOutputs. + // We assume in this function that the model has been validated already. + const uint32_t operandIndex = operandIndexes[requestArgumentIndex]; + const Operand& operand = operands[operandIndex]; + if (requestArgument.lifetime != Request::Argument::LifeTime::NO_VALUE) { + // If the argument specified a dimension, validate it. + uint32_t modelRank = operand.dimensions.size(); + uint32_t requestRank = requestArgument.dimensions.size(); + if (requestRank == 0) { + // NOTE: validateRequestArguments cannot validate unknown tensor rank with + // extension operand type. + if (!isExtension(operand.type) && !isNonExtensionScalar(operand.type)) { + if (modelRank <= 0) { + NN_VALIDATE(isOutput) + << "Model has unknown input rank but the request does not " + "specify the rank."; + // Unspecified output dimensions introduced in Android Q. + version = combineVersions(version, Version::ANDROID_Q); + } + } + // Validate that all the dimensions are specified in the model. + for (size_t i = 0; i < modelRank; i++) { + if (operand.dimensions[i] == 0) { + NN_VALIDATE(isOutput) + << "Model has dimension " << i + << " set to 0 but the request does not specify the dimension."; + // Unspecified output dimensions introduced in Android Q. + version = combineVersions(version, Version::ANDROID_Q); + } + } + } else { + NN_VALIDATE(modelRank == 0 || requestRank == modelRank) + << "Request " << type << " " << requestArgumentIndex + << " has number of dimensions (" << requestRank + << ") different than the model's (" << modelRank << ")"; + for (size_t i = 0; i < requestRank; i++) { + NN_VALIDATE(modelRank == 0 || operand.dimensions[i] == 0 || + requestArgument.dimensions[i] == operand.dimensions[i]) + << "Request " << type << " " << requestArgumentIndex + << " has dimension " << i << " of " << requestArgument.dimensions[i] + << " different than the model's " << operand.dimensions[i]; + if (requestArgument.dimensions[i] == 0) { + NN_VALIDATE(isOutput) << "Request " << type << " " << requestArgumentIndex + << " has dimension " << i << " of zero"; + // Unspecified output dimensions introduced in Android Q. + version = combineVersions(version, Version::ANDROID_Q); + } + } + } + } + } + return version; +} + +Result<Version> validateRequestForModelImpl(const Request& request, const Model& model) { + auto version = NN_TRY(validateRequest(request)); + version = combineVersions(version, NN_TRY(validateModel(model))); + version = combineVersions(version, NN_TRY(validateRequestArgumentsForModel( + request.inputs, model.main.inputIndexes, + model.main.operands, /*isOutput=*/false))); + version = combineVersions(version, NN_TRY(validateRequestArgumentsForModel( + request.inputs, model.main.inputIndexes, + model.main.operands, /*isOutput=*/true))); + return version; +} + +Result<Version> validateMemoryDescImpl( + const BufferDesc& desc, + const std::vector<std::shared_ptr<const IPreparedModel>>& preparedModels, + const std::vector<BufferRole>& inputRoles, const std::vector<BufferRole>& outputRoles, + const std::function<const Model*(const std::shared_ptr<const IPreparedModel>&)>& getModel, + std::set<PreparedModelRole>* preparedModelRoles, Operand* combinedOperand) { + NN_VALIDATE(!preparedModels.empty()); + NN_VALIDATE(!inputRoles.empty() || !outputRoles.empty()); + + std::set<PreparedModelRole> roles; + std::vector<nn::Operand> operands; + operands.reserve(inputRoles.size() + outputRoles.size()); + for (const auto& role : inputRoles) { + NN_VALIDATE_LT(role.modelIndex, preparedModels.size()); + const auto& preparedModel = preparedModels[role.modelIndex]; + NN_VALIDATE(preparedModel != nullptr); + const auto* model = getModel(preparedModel); + NN_VALIDATE(model != nullptr); + const auto& inputIndexes = model->main.inputIndexes; + NN_VALIDATE_LT(role.ioIndex, inputIndexes.size()); + NN_VALIDATE_GT(role.frequency, 0.0f); + NN_VALIDATE_LE(role.frequency, 1.0f); + const auto [it, success] = roles.emplace(preparedModel.get(), IOType::INPUT, role.ioIndex); + NN_VALIDATE(success); + operands.push_back(model->main.operands[inputIndexes[role.ioIndex]]); + } + for (const auto& role : outputRoles) { + NN_VALIDATE_LT(role.modelIndex, preparedModels.size()); + const auto& preparedModel = preparedModels[role.modelIndex]; + NN_VALIDATE(preparedModel != nullptr); + const auto* model = getModel(preparedModel); + NN_VALIDATE(model != nullptr); + const auto& outputIndexes = model->main.outputIndexes; + NN_VALIDATE_LT(role.ioIndex, outputIndexes.size()); + NN_VALIDATE_GT(role.frequency, 0.0f); + NN_VALIDATE_LE(role.frequency, 1.0f); + const auto [it, success] = roles.emplace(preparedModel.get(), IOType::OUTPUT, role.ioIndex); + NN_VALIDATE(success); + operands.push_back(model->main.operands[outputIndexes[role.ioIndex]]); + } + + CHECK(!operands.empty()); + const auto opType = operands.front().type; + + Dimensions dimensions = desc.dimensions; + for (const auto& operand : operands) { + NN_VALIDATE_EQ(operand.type, opType) << operand.type << " vs " << operands.front().type; + NN_VALIDATE_EQ(operand.scale, operands.front().scale); + NN_VALIDATE_EQ(operand.zeroPoint, operands.front().zeroPoint); + // NOTE: validateMemoryDesc cannot validate extra parameters for extension operand type. + if (!isExtension(opType)) { + NN_VALIDATE_EQ(operand.extraParams, operands.front().extraParams) + << operand.extraParams << " vs " << operands.front().extraParams; + } + dimensions = NN_TRY(combineDimensions(dimensions, operand.dimensions)); + } + + // NOTE: validateMemoryDesc cannot validate scalar dimensions with extension operand type. + if (!isExtension(opType)) { + NN_VALIDATE(!isNonExtensionScalar(opType) || dimensions.empty()) + << "invalid dimensions with scalar operand type."; + } + + if (preparedModelRoles != nullptr) { + *preparedModelRoles = std::move(roles); + } + if (combinedOperand != nullptr) { + *combinedOperand = operands.front(); + combinedOperand->dimensions = dimensions; + } + return Version::ANDROID_R; +} + +// TODO: Enable this block of code once canonical types are integrated in the rest of the NNAPI +// codebase. +#if 0 +class OperationValidationContext : public IOperationValidationContext { + DISALLOW_IMPLICIT_CONSTRUCTORS(OperationValidationContext); + + public: + OperationValidationContext(const char* operationName, const std::vector<uint32_t>& + inputIndexes, + const std::vector<uint32_t>& outputIndexes, + const std::vector<Operand>& operands, Version version) + : operationName(operationName), + inputIndexes(inputIndexes), + outputIndexes(outputIndexes), + operands(operands), + version(version) {} + + const char* getOperationName() const override; + Version getVersion() const override; + + uint32_t getNumInputs() const override; + OperandType getInputType(uint32_t index) const override; + Shape getInputShape(uint32_t index) const override; + const Operand::ExtraParams getInputExtraParams(uint32_t index) const override; + + uint32_t getNumOutputs() const override; + OperandType getOutputType(uint32_t index) const override; + Shape getOutputShape(uint32_t index) const override; + + private: + const Operand* getInputOperand(uint32_t index) const; + const Operand* getOutputOperand(uint32_t index) const; + + const char* operationName; + const std::vector<uint32_t>& inputIndexes; + const std::vector<uint32_t>& outputIndexes; + const std::vector<Operand>& operands; + Version version; +}; + +const char* OperationValidationContext::getOperationName() const { + return operationName; +} + +Version OperationValidationContext::getVersion() const { + return version; +} + +const Operand* OperationValidationContext::getInputOperand(uint32_t index) const { + return &operands.at(inputIndexes.at(index)); +} + +const Operand* OperationValidationContext::getOutputOperand(uint32_t index) const { + return &operands.at(outputIndexes.at(index)); +} + +uint32_t OperationValidationContext::getNumInputs() const { + auto count = inputIndexes.size(); + CHECK_LE(count, std::numeric_limits<uint32_t>::max()); + return static_cast<uint32_t>(count); +} + +uint32_t OperationValidationContext::getNumOutputs() const { + auto count = outputIndexes.size(); + CHECK_LE(count, std::numeric_limits<uint32_t>::max()); + return static_cast<uint32_t>(count); +} + +OperandType OperationValidationContext::getInputType(uint32_t index) const { + return getInputOperand(index)->type; +} + +Shape OperationValidationContext::getInputShape(uint32_t index) const { + const Operand* operand = getInputOperand(index); + return {operand->type, operand->dimensions, operand->scale, operand->zeroPoint, + operand->extraParams}; +} + +const Operand::ExtraParams OperationValidationContext::getInputExtraParams(uint32_t index) const +{ + return getInputOperand(index)->extraParams; +} + +OperandType OperationValidationContext::getOutputType(uint32_t index) const { + return getOutputOperand(index)->type; +} + +Shape OperationValidationContext::getOutputShape(uint32_t index) const { + const Operand* operand = getOutputOperand(index); + return {operand->type, operand->dimensions, operand->scale, operand->zeroPoint, + operand->extraParams}; +} +#endif + +// TODO(b/169345292): reduce the duplicate validation here + +Result<void> validateOperandSymmPerChannelQuantParamsImpl( + const Operand& operand, const Operand::SymmPerChannelQuantParams& channelQuant, + const char* tag) { + if (operand.type != OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL) { + NN_VALIDATE_FAIL(); + } + + NN_VALIDATE_LT(channelQuant.channelDim, operand.dimensions.size()) << tag; + NN_VALIDATE(!channelQuant.scales.empty()) << tag; + NN_VALIDATE_EQ(channelQuant.scales.size(), operand.dimensions[channelQuant.channelDim]) << tag; + NN_VALIDATE_NE(operand.dimensions[channelQuant.channelDim], 0u) + << tag << " channel dimension " << channelQuant.channelDim << " is underspecified"; + for (uint32_t i = 0; i < operand.dimensions[channelQuant.channelDim]; i++) { + NN_VALIDATE_GT(channelQuant.scales[i], 0.0f) << tag << " invalid scaleArray[" << i << "]"; + } + return {}; +} + +Result<void> validateScalarDimensions(const Operand& type, const char* tag) { + NN_VALIDATE(type.dimensions.empty()) << tag << " invalid dimensions for scalar type"; + return {}; +} + +Result<void> validateQuant8AsymmParams(const Operand& type, const char* tag) { + NN_VALIDATE(0 <= type.zeroPoint && type.zeroPoint <= 255) + << tag << " invalid zeroPoint: " << type.zeroPoint; + NN_VALIDATE_GT(type.scale, 0.0f) << tag << " invalid scale"; + return {}; +} + +Result<void> validateQuant8AsymmSignedParams(const Operand& type, const char* tag) { + NN_VALIDATE(-128 <= type.zeroPoint && type.zeroPoint <= 127) + << tag << " invalid zeroPoint: " << type.zeroPoint; + NN_VALIDATE_GT(type.scale, 0.0f) << tag << " invalid scale"; + return {}; +} + +Result<void> validateQuant8SymmParams(const Operand& type, const char* tag) { + NN_VALIDATE_EQ(type.zeroPoint, 0) << tag << " invalid zeroPoint: " << type.zeroPoint; + NN_VALIDATE_GT(type.scale, 0.0f) << tag << " invalid scale"; + return {}; +} + +Result<void> validateQuant16AsymmParams(const Operand& type, const char* tag) { + NN_VALIDATE(0 <= type.zeroPoint && type.zeroPoint <= 65535) + << tag << " invalid zeroPoint: " << type.zeroPoint; + NN_VALIDATE_GT(type.scale, 0.0f) << tag << " invalid scale"; + return {}; +} + +Result<void> validateQuantSymmParams(const Operand& type, const char* tag) { + NN_VALIDATE_EQ(type.zeroPoint, 0) << tag << " zeroPoint is not zero"; + NN_VALIDATE_GT(type.scale, 0.0f) << tag << " invalid scale"; + return {}; +} + +Result<void> validateNoQuantParams(const Operand& type, const char* tag) { + NN_VALIDATE_EQ(type.zeroPoint, 0) << tag << " zeroPoint is not zero"; + NN_VALIDATE_EQ(type.scale, 0.0f) << tag << " scale is not zero"; + return {}; +} + +Result<void> validateTensorDimensions( + const Operand& type, const Extension::OperandTypeInformation* extensionOperandTypeInfo, + const char* tag, bool allowPartial) { + if (!allowPartial) { + NN_VALIDATE(!type.dimensions.empty()) << tag << " invalid operand dimensions"; + } + uint64_t size = isExtension(type.type) ? extensionOperandTypeInfo->byteSize + : getNonExtensionSize(type.type); + constexpr uint64_t kMaxSize = std::numeric_limits<uint32_t>::max(); + for (size_t i = 0; i < type.dimensions.size(); i++) { + if (!allowPartial) { + NN_VALIDATE_NE(type.dimensions[i], 0u) << tag << " invalid operand dimensions"; + } + if (type.dimensions[i] != 0) { + size *= type.dimensions[i]; + NN_VALIDATE_LE(size, kMaxSize) << tag << " operand byte size exceeds " << kMaxSize; + } + } + return {}; +} + +Result<void> validateOperandTypeImpl( + const Operand& type, + const Extension::OperandTypeInformation* const extensionOperandTypeInfo, const char* tag, + bool allowPartial) { + if (isExtension(type.type)) { + NN_VALIDATE(extensionOperandTypeInfo != nullptr); + if (extensionOperandTypeInfo->isTensor) { + NN_TRY(validateTensorDimensions(type, extensionOperandTypeInfo, tag, allowPartial)); + } else { + NN_TRY(validateScalarDimensions(type, tag)); + } + return validateNoQuantParams(type, tag); + } + + NN_VALIDATE(extensionOperandTypeInfo == nullptr); + NN_TRY(validateOperandType(type.type)); + + if (isNonExtensionScalar(type.type)) { + NN_TRY(validateScalarDimensions(type, tag)); + if (type.type != OperandType::OEM) { // Historically, we have allowed OEM types + // to use quantization parameters. + NN_TRY(validateNoQuantParams(type, tag)); + } + } else { + NN_TRY(validateTensorDimensions(type, extensionOperandTypeInfo, tag, allowPartial)); + if (type.type == OperandType::TENSOR_QUANT8_ASYMM) { + NN_TRY(validateQuant8AsymmParams(type, tag)); + } else if (type.type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { + NN_TRY(validateQuant8AsymmSignedParams(type, tag)); + } else if (type.type == OperandType::TENSOR_QUANT8_SYMM) { + NN_TRY(validateQuant8SymmParams(type, tag)); + } else if (type.type == OperandType::TENSOR_QUANT16_ASYMM) { + NN_TRY(validateQuant16AsymmParams(type, tag)); + } else if (type.type == OperandType::TENSOR_QUANT16_SYMM) { + NN_TRY(validateQuantSymmParams(type, tag)); + } else if (type.type == OperandType::TENSOR_INT32 || + type.type == OperandType::TENSOR_OEM_BYTE) { + // TODO(b/119869082): TENSOR_INT32 should not use quantization parameters. + // Historically, we have allowed OEM types to use quantization parameters. + } else { + NN_TRY(validateNoQuantParams(type, tag)); + } + } + + return {}; +} + +Result<void> validateOperandListImpl(const std::vector<uint32_t>& list, size_t operandCount, + const char* tag) { + for (size_t i = 0; i < list.size(); i++) { + NN_VALIDATE_LT(list[i], operandCount) << tag << " invalid operand index at " << i << " = " + << list[i] << ", operandCount " << operandCount; + } + return {}; +} + +Result<void> validateOperationOperandTypes(const std::vector<Operand>& operands, + const std::vector<uint32_t>& inputIndexes, + const std::vector<OperandType>& inExpectedTypes, + const std::vector<uint32_t>& outputIndexes, + const std::vector<OperandType>& outExpectedInTypes) { + NN_VALIDATE_EQ(inputIndexes.size(), inExpectedTypes.size()) + << "Wrong operand count: expected " << inputIndexes.size() << " inputs, got " + << inputIndexes.size() << " inputs"; + NN_VALIDATE_EQ(outputIndexes.size(), outExpectedInTypes.size()) + << "Wrong operand count: expected " << outputIndexes.size() << " outputs, got " + << outputIndexes.size() << " outputs"; + for (size_t i = 0; i < inputIndexes.size(); i++) { + NN_VALIDATE_EQ(operands[inputIndexes[i]].type, inExpectedTypes[i]) + << "Invalid input tensor type " << operands[inputIndexes[i]].type << " for input " + << i << ", expected " << inExpectedTypes[i]; + } + for (size_t i = 0; i < outputIndexes.size(); i++) { + NN_VALIDATE_EQ(operands[outputIndexes[i]].type, outExpectedInTypes[i]) + << "Invalid output tensor type " << operands[outputIndexes[i]].type << " for input " + << i << ", expected " << outExpectedInTypes[i]; + } + + return {}; +} + +Result<void> validateSubgraphReference(const std::vector<Model::Subgraph>& subgraphs, + const Operand& modelOperand) { + NN_VALIDATE_EQ(modelOperand.type, OperandType::SUBGRAPH) + << "Unexpected operand type: " << modelOperand.type; + NN_VALIDATE_LT(modelOperand.location.offset, subgraphs.size()) << "Invalid subgraph reference"; + return {}; +} +const Model::Subgraph& getSubgraph(const std::vector<Model::Subgraph>& subgraphs, + const Operand& modelOperand) { + return subgraphs.at(modelOperand.location.offset); +} +uint32_t getInputCount(const std::vector<Model::Subgraph>& subgraphs, const Operand& modelOperand) { + return getSubgraph(subgraphs, modelOperand).inputIndexes.size(); +} +uint32_t getOutputCount(const std::vector<Model::Subgraph>& subgraphs, + const Operand& modelOperand) { + return getSubgraph(subgraphs, modelOperand).outputIndexes.size(); +} +const Operand& getInputOperand(const std::vector<Model::Subgraph>& subgraphs, + const Operand& modelOperand, uint32_t index) { + const Model::Subgraph& subgraph = getSubgraph(subgraphs, modelOperand); + return subgraph.operands.at(subgraph.inputIndexes.at(index)); +} +const Operand& getOutputOperand(const std::vector<Model::Subgraph>& subgraphs, + const Operand& modelOperand, uint32_t index) { + const Model::Subgraph& subgraph = getSubgraph(subgraphs, modelOperand); + return subgraph.operands.at(subgraph.outputIndexes.at(index)); +} + +// Checks if two operands have the same types, ranks (if specified), dimensions +// (if specified), scales, zeroPoints, and extraParams. +Result<void> compatible(const Operand& a, const Operand& b) { + NN_VALIDATE_EQ(a.type, b.type) << a.type << " != " << b.type; + if (!a.dimensions.empty() && !b.dimensions.empty()) { + NN_VALIDATE_EQ(a.dimensions.size(), b.dimensions.size()) << "Incompatible dimensions"; + for (uint32_t i = 0, n = a.dimensions.size(); i < n; ++i) { + if (a.dimensions[i] != 0 && b.dimensions[i] != 0) { + NN_VALIDATE_EQ(a.dimensions[i], b.dimensions[i]) << "Incompatible dimensions"; + } + } + } + NN_VALIDATE_EQ(a.scale, b.scale); + NN_VALIDATE_EQ(a.zeroPoint, b.zeroPoint); + NN_VALIDATE_EQ(a.extraParams, b.extraParams) << a.extraParams << " != " << b.extraParams; + return {}; +} + +Result<void> validateConditionOperand(const Operand& operand) { + NN_VALIDATE_EQ(operand.type, OperandType::TENSOR_BOOL8) + << "Unexpected condition operand type: " << operand.type; + NN_VALIDATE_EQ(operand.dimensions.size(), 1u) << "Condition operand must be a singleton"; + NN_VALIDATE_EQ(operand.dimensions[0], 1u) << "Condition operand must be a singleton"; + return {}; +} + +Result<Version> validateIfOperation(const std::vector<uint32_t>& inputs, + const std::vector<uint32_t>& outputs, + const std::vector<Operand>& operands, + const std::vector<Model::Subgraph>& subgraphs) { + namespace op = operation_if; + NN_VALIDATE_GE(inputs.size(), 3u) << "IF must have at least 3 inputs"; + NN_VALIDATE_GE(outputs.size(), 1u) << "IF must have at least 1 output"; + auto validateBranchOperand = [&](const Operand& branchModelOperand) -> Result<void> { + auto result = validateSubgraphReference(subgraphs, branchModelOperand); + if (!result.has_value()) { + return NN_ERROR() << std::move(result).error() + << "Operand is not a valid subgraph reference"; + } + const uint32_t branchModelInputCount = getInputCount(subgraphs, branchModelOperand); + const uint32_t branchModelOutputCount = getOutputCount(subgraphs, branchModelOperand); + NN_VALIDATE_EQ(inputs.size(), op::kFirstInput + branchModelInputCount); + NN_VALIDATE_EQ(outputs.size(), branchModelOutputCount); + for (uint32_t i = 0; i < branchModelInputCount; ++i) { + const Operand& innerOperand = getInputOperand(subgraphs, branchModelOperand, i); + const Operand& outerOperand = operands[inputs[op::kFirstInput + i]]; + NN_TRY(compatible(innerOperand, outerOperand)); + } + for (uint32_t i = 0; i < branchModelOutputCount; ++i) { + const Operand& innerOperand = getOutputOperand(subgraphs, branchModelOperand, i); + const Operand& outerOperand = operands[outputs[i]]; + NN_TRY(compatible(innerOperand, outerOperand)); + } + return {}; + }; + auto result = validateConditionOperand(operands[inputs[op::kCondBoolOperand]]); + if (!result.has_value()) { + return NN_ERROR() << std::move(result).error() + << "Validation failed for IF condition operand"; + } + result = validateBranchOperand(operands[inputs[op::kThenModelOperand]]); + if (!result.has_value()) { + return NN_ERROR() << std::move(result).error() << "Validation failed for IF then model"; + } + result = validateBranchOperand(operands[inputs[op::kElseModelOperand]]); + if (!result.has_value()) { + return NN_ERROR() << std::move(result).error() << "Validation failed for IF else model"; + } + return Version::ANDROID_R; +} + +Result<Version> validateControlFlowOperandUnknownSize(const Operand& operand) { + if (!isExtension(operand.type) && getNonExtensionSize(operand).value() == 0) { + // 1.3 HAL (corresponding to Version::ANDROID_R) does not support CF operations with + // operands of unknown size. See http://b/132458982#comment63. + return Version::CURRENT_RUNTIME; + } + return Version::ANDROID_R; +} + +Result<Version> validateWhileOperation(const std::vector<uint32_t>& inputs, + const std::vector<uint32_t>& outputs, + const std::vector<Operand>& operands, + const std::vector<Model::Subgraph>& subgraphs) { + // Let the loop have + // - m >= 1 input-output operands, + // - k >= 0 state-only operands, and + // - n >= 0 input-only operands. + // Then + // - the WHILE loop operation has (2 + m + k + n) inputs and m outputs. + // - the condition model has (m + k + n) inputs and 1 output. + // - the body model has (m + k + n) inputs and (m + k) outputs. + namespace op = operation_while; + NN_VALIDATE_GE(inputs.size(), 3u) << "WHILE must have at least 3 inputs"; + NN_VALIDATE_GE(outputs.size(), 1u) << "WHILE must have at least 1 output"; + auto validateCondOperand = [&](const Operand& condModelOperand) -> Result<Version> { + Version version = Version::ANDROID_R; + auto result = validateSubgraphReference(subgraphs, condModelOperand); + if (!result.has_value()) { + return NN_ERROR() << std::move(result).error() + << "Operand is not a valid subgraph reference"; + } + const uint32_t condModelInputCount = getInputCount(subgraphs, condModelOperand); + const uint32_t condModelOutputCount = getOutputCount(subgraphs, condModelOperand); + NN_VALIDATE_EQ(inputs.size(), op::kFirstInput + condModelInputCount); + NN_VALIDATE_EQ(condModelOutputCount, 1u); + for (uint32_t i = 0; i < condModelInputCount; ++i) { + const Operand& innerOperand = getInputOperand(subgraphs, condModelOperand, i); + const Operand& outerOperand = operands[inputs[op::kFirstInput + i]]; + NN_TRY(compatible(innerOperand, outerOperand)); + version = combineVersions(version, + NN_TRY(validateControlFlowOperandUnknownSize(innerOperand))); + version = combineVersions(version, + NN_TRY(validateControlFlowOperandUnknownSize(outerOperand))); + } + NN_TRY(validateConditionOperand(getOutputOperand(subgraphs, condModelOperand, 0))); + return version; + }; + auto validateBodyOperand = [&](const Operand& bodyModelOperand) -> Result<Version> { + Version version = Version::ANDROID_R; + auto result = validateSubgraphReference(subgraphs, bodyModelOperand); + if (!result.has_value()) { + return NN_ERROR() << std::move(result).error() + << "Operand is not a valid subgraph reference"; + } + const uint32_t bodyModelInputCount = getInputCount(subgraphs, bodyModelOperand); + const uint32_t bodyModelOutputCount = getOutputCount(subgraphs, bodyModelOperand); + NN_VALIDATE_EQ(inputs.size(), op::kFirstInput + bodyModelInputCount); + NN_VALIDATE_GE(bodyModelOutputCount, outputs.size()); + NN_VALIDATE_GE(bodyModelInputCount, bodyModelOutputCount); + const uint32_t inputOutputCount = outputs.size(); + const uint32_t stateOnlyCount = bodyModelOutputCount - inputOutputCount; + const uint32_t inputOnlyCount = bodyModelInputCount - bodyModelOutputCount; + for (uint32_t i = 0, n = inputOutputCount + stateOnlyCount + inputOnlyCount; i < n; ++i) { + const Operand& innerOperand = getInputOperand(subgraphs, bodyModelOperand, i); + const Operand& outerOperand = operands[inputs[op::kFirstInput + i]]; + NN_TRY(compatible(innerOperand, outerOperand)); + version = combineVersions(version, + NN_TRY(validateControlFlowOperandUnknownSize(innerOperand))); + version = combineVersions(version, + NN_TRY(validateControlFlowOperandUnknownSize(outerOperand))); + } + for (uint32_t i = 0; i < inputOutputCount; ++i) { + const Operand& innerOperand = getOutputOperand(subgraphs, bodyModelOperand, i); + const Operand& outerOperand = operands[outputs[i]]; + NN_TRY(compatible(innerOperand, outerOperand)); + version = combineVersions(version, + NN_TRY(validateControlFlowOperandUnknownSize(outerOperand))); + } + for (uint32_t i = 0, n = inputOutputCount + stateOnlyCount; i < n; ++i) { + const Operand& inputOperand = getInputOperand(subgraphs, bodyModelOperand, i); + const Operand& outputOperand = getOutputOperand(subgraphs, bodyModelOperand, i); + NN_TRY(compatible(inputOperand, outputOperand)); + version = combineVersions(version, + NN_TRY(validateControlFlowOperandUnknownSize(outputOperand))); + } + return version; + }; + auto result = validateCondOperand(operands[inputs[op::kCondModelOperand]]); + if (!result.has_value()) { + return NN_ERROR() << std::move(result).error() + << "Validation failed for WHILE condition model"; + } + auto version = result.value(); + result = validateBodyOperand(operands[inputs[op::kBodyModelOperand]]); + if (!result.has_value()) { + return NN_ERROR() << std::move(result).error() << "Validation failed for WHILE body model"; + } + version = combineVersions(version, result.value()); + return version; +} + +Result<Version> validateOperationImpl(const Operation& operation, + const std::vector<Operand>& operands, + const std::vector<Model::Subgraph>& subgraphs) { + const auto opType = operation.type; + const auto& inputIndexes = operation.inputs; + const auto& outputIndexes = operation.outputs; + + NN_TRY(validateOperandListImpl(inputIndexes, operands.size(), + "ANeuralNetworksModel_addOperation inputs")); + NN_TRY(validateOperandListImpl(outputIndexes, operands.size(), + "ANeuralNetworksModel_addOperation outputs")); + + if (isExtension(opType)) { + // There is no other validation we can do for an extension operation. + return Version::ANDROID_Q; + } + + auto invalidInOutNumberMessage = [opType, &inputIndexes, &outputIndexes](int expIn, + int expOut) { + std::ostringstream os; + os << "Invalid number of input operands (" << inputIndexes.size() << ", expected " << expIn + << ") or output operands (" << outputIndexes.size() << ", expected " << expOut + << ") for operation " << opType; + return os.str(); + }; + + switch (opType) { + case OperationType::OEM_OPERATION: { + return Version::ANDROID_OC_MR1; + } + case OperationType::RESHAPE: { + NN_VALIDATE(inputIndexes.size() == 2 && outputIndexes.size() == 1) + << invalidInOutNumberMessage(2, 1); + auto inputType = operands[inputIndexes[0]].type; + Version version; + std::vector<OperandType> inExpectedTypes; + std::vector<OperandType> outExpectedTypes; + if (inputType == OperandType::TENSOR_FLOAT32) { + version = Version::ANDROID_OC_MR1; + inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_INT32}; + outExpectedTypes = {OperandType::TENSOR_FLOAT32}; + } else if (inputType == OperandType::TENSOR_FLOAT16) { + version = Version::ANDROID_Q; + inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::TENSOR_INT32}; + outExpectedTypes = {OperandType::TENSOR_FLOAT16}; + } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) { + version = Version::ANDROID_OC_MR1; + inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_INT32}; + outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM}; + } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { + version = Version::ANDROID_R; + inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED, + OperandType::TENSOR_INT32}; + outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED}; + } else { + NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType; + } + const auto inputRank = operands[inputIndexes[0]].dimensions.size(); + NN_VALIDATE_LE(inputRank, 4u) + << "Unsupported input tensor rank for operation " << opType; + NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, + outputIndexes, outExpectedTypes)); + return version; + } + case OperationType::DEPTH_TO_SPACE: { + NN_VALIDATE((inputIndexes.size() == 3 || inputIndexes.size() == 2) && + outputIndexes.size() == 1) + << "Invalid number of input operands (" << inputIndexes.size() + << ", expected 3 or 2) or output operands (" << outputIndexes.size() + << ", expected 1) for operation " << opType; + auto inputType = operands[inputIndexes[0]].type; + Version version; + std::vector<OperandType> inExpectedTypes; + std::vector<OperandType> outExpectedTypes; + if (inputType == OperandType::TENSOR_FLOAT32) { + version = Version::ANDROID_OC_MR1; + inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::INT32}; + outExpectedTypes = {OperandType::TENSOR_FLOAT32}; + } else if (inputType == OperandType::TENSOR_FLOAT16) { + version = Version::ANDROID_Q; + inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::INT32}; + outExpectedTypes = {OperandType::TENSOR_FLOAT16}; + } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) { + version = Version::ANDROID_OC_MR1; + inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, OperandType::INT32}; + outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM}; + } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { + version = Version::ANDROID_R; + inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED, OperandType::INT32}; + outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED}; + } else { + NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType; + } + if (inputIndexes.size() == 3) { + inExpectedTypes.push_back(OperandType::BOOL); + version = combineVersions(version, Version::ANDROID_Q); + } else { + version = combineVersions(version, Version::ANDROID_OC_MR1); + } + NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, + outputIndexes, outExpectedTypes)); + return version; + } + case OperationType::SPACE_TO_DEPTH: { + NN_VALIDATE((inputIndexes.size() == 3 || inputIndexes.size() == 2) && + outputIndexes.size() == 1) + << "Invalid number of input operands (" << inputIndexes.size() + << ", expected 3 or 2) or output operands (" << outputIndexes.size() + << ", expected 1) for operation " << opType; + auto inputType = operands[inputIndexes[0]].type; + Version version; + std::vector<OperandType> inExpectedTypes; + std::vector<OperandType> outExpectedTypes; + if (inputType == OperandType::TENSOR_FLOAT32) { + version = Version::ANDROID_OC_MR1; + inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::INT32}; + outExpectedTypes = {OperandType::TENSOR_FLOAT32}; + } else if (inputType == OperandType::TENSOR_FLOAT16) { + version = Version::ANDROID_Q; + inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::INT32}; + outExpectedTypes = {OperandType::TENSOR_FLOAT16}; + } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) { + version = Version::ANDROID_OC_MR1; + inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, OperandType::INT32}; + outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM}; + } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { + version = Version::ANDROID_R; + inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED, OperandType::INT32}; + outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED}; + } else { + NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType; + } + if (inputIndexes.size() == 3) { + inExpectedTypes.push_back(OperandType::BOOL); + version = combineVersions(version, Version::ANDROID_Q); + } else { + version = combineVersions(version, Version::ANDROID_OC_MR1); + } + NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, + outputIndexes, outExpectedTypes)); + return version; + } + case OperationType::EMBEDDING_LOOKUP: { + NN_VALIDATE(inputIndexes.size() == 2 && outputIndexes.size() == 1) + << invalidInOutNumberMessage(2, 1); + auto inputType = operands[inputIndexes[1]].type; + NN_VALIDATE(inputType == OperandType::TENSOR_FLOAT16 || + inputType == OperandType::TENSOR_FLOAT32 || + inputType == OperandType::TENSOR_INT32 || + inputType == OperandType::TENSOR_QUANT8_ASYMM || + inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) + << "Unsupported input tensor type for operation " << opType; + Version version; + std::vector<OperandType> inExpectedTypes = {OperandType::TENSOR_INT32, inputType}; + std::vector<OperandType> outExpectedTypes = {inputType}; + if (inputType == OperandType::TENSOR_FLOAT16 || + inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { + version = Version::ANDROID_R; + } else if (inputType == OperandType::TENSOR_INT32 || + inputType == OperandType::TENSOR_QUANT8_ASYMM) { + version = Version::ANDROID_Q; + } else { + version = Version::ANDROID_OC_MR1; + } + NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, + outputIndexes, outExpectedTypes)); + return version; + } + case OperationType::HASHTABLE_LOOKUP: { + NN_VALIDATE(inputIndexes.size() == 3 && outputIndexes.size() == 2) + << invalidInOutNumberMessage(3, 2); + auto inputType = operands[inputIndexes[2]].type; + NN_VALIDATE(inputType == OperandType::TENSOR_FLOAT32 || + inputType == OperandType::TENSOR_INT32 || + inputType == OperandType::TENSOR_QUANT8_ASYMM) + << "Unsupported input tensor type for operation " << opType; + std::vector<OperandType> inExpectedTypes = {OperandType::TENSOR_INT32, + OperandType::TENSOR_INT32, inputType}; + std::vector<OperandType> outExpectedTypes = {inputType, + OperandType::TENSOR_QUANT8_ASYMM}; + NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, + outputIndexes, outExpectedTypes)); + return Version::ANDROID_OC_MR1; + } + case OperationType::LSH_PROJECTION: { + NN_VALIDATE(inputIndexes.size() == 4 && outputIndexes.size() == 1) + << invalidInOutNumberMessage(4, 1); + auto inputType = operands[inputIndexes[1]].type; + NN_VALIDATE(inputType == OperandType::TENSOR_FLOAT16 || + inputType == OperandType::TENSOR_FLOAT32 || + inputType == OperandType::TENSOR_INT32 || + inputType == OperandType::TENSOR_QUANT8_ASYMM) + << "Unsupported input tensor type for operation " << opType; + auto hashType = operands[inputIndexes[0]].type; + Version version; + std::vector<OperandType> inExpectedTypes; + if (hashType == OperandType::TENSOR_FLOAT16) { + version = Version::ANDROID_Q; + inExpectedTypes = { + OperandType::TENSOR_FLOAT16, + inputType, + OperandType::TENSOR_FLOAT16, + OperandType::INT32, + }; + } else if (hashType == OperandType::TENSOR_FLOAT32) { + version = Version::ANDROID_OC_MR1; + inExpectedTypes = { + OperandType::TENSOR_FLOAT32, + inputType, + OperandType::TENSOR_FLOAT32, + OperandType::INT32, + }; + } else { + NN_VALIDATE_FAIL() << "Unsupported hash tensor type for operation " << opType; + } + std::vector<OperandType> outExpectedTypes = {OperandType::TENSOR_INT32}; + NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, + outputIndexes, outExpectedTypes)); + return version; + } + case OperationType::BIDIRECTIONAL_SEQUENCE_LSTM: { + const uint32_t kNumOutputs = 2; + const uint32_t kNumOutputsMerged = 1; + const uint32_t kNumOutputsWithState = 6; + const uint32_t kNumOutputsMergedWithState = 5; + NN_VALIDATE(inputIndexes.size() == 61 && + (outputIndexes.size() == kNumOutputs || + outputIndexes.size() == kNumOutputsMerged || + outputIndexes.size() == kNumOutputsWithState || + outputIndexes.size() == kNumOutputsMergedWithState)) + << "Invalid number of input operands (" << inputIndexes.size() + << ", expected 61) or output operands (" << outputIndexes.size() + << ", expected 1, 2, 5 or 6) for operation " << opType; + + std::vector<OperandType> inExpectedTypes; + auto inputType = operands[inputIndexes[0]].type; + NN_VALIDATE(inputType == OperandType::TENSOR_FLOAT32 || + inputType == OperandType::TENSOR_FLOAT16) + << "Unsupported input tensor type for operation " << opType; + + inExpectedTypes = {}; + for (int i = 0; i < 48; ++i) { + inExpectedTypes.push_back(inputType); + } + inExpectedTypes.push_back(OperandType::INT32); + inExpectedTypes.push_back(inputType == OperandType::TENSOR_FLOAT32 + ? OperandType::FLOAT32 + : OperandType::FLOAT16); + inExpectedTypes.push_back(inputType == OperandType::TENSOR_FLOAT32 + ? OperandType::FLOAT32 + : OperandType::FLOAT16); + inExpectedTypes.push_back(OperandType::BOOL); + inExpectedTypes.push_back(OperandType::BOOL); + for (int i = 0; i < 8; ++i) { + inExpectedTypes.push_back(inputType); + } + + Version version = Version::ANDROID_Q; + if (outputIndexes.size() == kNumOutputsWithState || + outputIndexes.size() == kNumOutputsMergedWithState) { + version = Version::ANDROID_R; + } + std::vector<OperandType> outExpectedTypes(outputIndexes.size(), inputType); + NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, + outputIndexes, outExpectedTypes)); + return version; + } + case OperationType::LSTM: { + NN_VALIDATE((inputIndexes.size() == 23 || inputIndexes.size() == 27) && + outputIndexes.size() == 4) + << "Invalid number of input operands (" << inputIndexes.size() + << ", expected 23 or 27) or output operands (" << outputIndexes.size() + << ", expected 4) for operation " << opType; + std::vector<OperandType> inExpectedTypes; + std::vector<OperandType> outExpectedTypes; + auto inputType = operands[inputIndexes[0]].type; + NN_VALIDATE(inputType == OperandType::TENSOR_FLOAT32 || + inputType == OperandType::TENSOR_FLOAT16) + << "Unsupported input tensor type for operation " << opType; + + Version version = Version::ANDROID_OC_MR1; + inExpectedTypes = {inputType, inputType, inputType, inputType, inputType, + inputType, inputType, inputType, inputType, inputType, + inputType, inputType, inputType, inputType, inputType, + inputType, inputType, inputType, inputType, inputType, + OperandType::INT32}; + if (inputType == OperandType::TENSOR_FLOAT32) { + inExpectedTypes.push_back(OperandType::FLOAT32); + inExpectedTypes.push_back(OperandType::FLOAT32); + } else { + version = Version::ANDROID_Q; + inExpectedTypes.push_back(OperandType::FLOAT16); + inExpectedTypes.push_back(OperandType::FLOAT16); + } + + outExpectedTypes = {inputType, inputType, inputType, inputType}; + if (inputIndexes.size() == 23) { + version = combineVersions(version, Version::ANDROID_OC_MR1); + } else { + version = combineVersions(version, Version::ANDROID_Q); + for (int i = 0; i < 4; ++i) { + inExpectedTypes.push_back(inputType); + } + } + NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, + outputIndexes, outExpectedTypes)); + return version; + } + case OperationType::QUANTIZED_16BIT_LSTM: { + NN_VALIDATE(inputIndexes.size() == 15 && outputIndexes.size() == 2) + << invalidInOutNumberMessage(15, 2); + std::vector<OperandType> inExpectedTypes = { + OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_QUANT8_ASYMM, + OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_QUANT8_ASYMM, + OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_QUANT8_ASYMM, + OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_QUANT8_ASYMM, + OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_INT32, + OperandType::TENSOR_INT32, OperandType::TENSOR_INT32, + OperandType::TENSOR_INT32, OperandType::TENSOR_QUANT16_SYMM, + OperandType::TENSOR_QUANT8_ASYMM}; + std::vector<OperandType> outExpectedTypes = {OperandType::TENSOR_QUANT16_SYMM, + OperandType::TENSOR_QUANT8_ASYMM}; + NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, + outputIndexes, outExpectedTypes)); + return Version::ANDROID_Q; + } + case OperationType::RANDOM_MULTINOMIAL: { + NN_VALIDATE(inputIndexes.size() == 3 && outputIndexes.size() == 1) + << invalidInOutNumberMessage(3, 1); + OperandType inputType = operands[inputIndexes[0]].type; + std::vector<OperandType> inExpectedTypes; + if (inputType == OperandType::TENSOR_FLOAT32 || + inputType == OperandType::TENSOR_FLOAT16) { + inExpectedTypes = {inputType, OperandType::INT32, OperandType::TENSOR_INT32}; + } else { + NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType; + } + std::vector<OperandType> outExpectedTypes = {OperandType::TENSOR_INT32}; + NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, + outputIndexes, outExpectedTypes)); + return Version::ANDROID_Q; + } + case OperationType::RNN: { + NN_VALIDATE(inputIndexes.size() == 6 && outputIndexes.size() == 2) + << invalidInOutNumberMessage(6, 2); + OperandType inputType = operands[inputIndexes[0]].type; + Version version; + std::vector<OperandType> inExpectedTypes; + std::vector<OperandType> outExpectedTypes; + if (inputType == OperandType::TENSOR_FLOAT32) { + version = Version::ANDROID_OC_MR1; + inExpectedTypes = { + OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32, + OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32, + OperandType::TENSOR_FLOAT32, OperandType::INT32, + }; + outExpectedTypes = { + OperandType::TENSOR_FLOAT32, + OperandType::TENSOR_FLOAT32, + }; + } else if (inputType == OperandType::TENSOR_FLOAT16) { + version = Version::ANDROID_Q; + inExpectedTypes = { + OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16, + OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16, + OperandType::TENSOR_FLOAT16, OperandType::INT32, + }; + outExpectedTypes = { + OperandType::TENSOR_FLOAT16, + OperandType::TENSOR_FLOAT16, + }; + } else { + NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType; + } + NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, + outputIndexes, outExpectedTypes)); + return version; + } + case OperationType::SVDF: { + NN_VALIDATE(inputIndexes.size() == 7 && outputIndexes.size() == 2) + << invalidInOutNumberMessage(7, 2); + Version version; + OperandType inputType = operands[inputIndexes[0]].type; + if (inputType == OperandType::TENSOR_FLOAT32) { + version = Version::ANDROID_OC_MR1; + } else if (inputType == OperandType::TENSOR_FLOAT16) { + version = Version::ANDROID_Q; + } else { + NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType; + } + std::vector<OperandType> inExpectedTypes = { + inputType, inputType, inputType, inputType, + inputType, OperandType::INT32, OperandType::INT32, + }; + std::vector<OperandType> outExpectedTypes = {inputType, inputType}; + NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, + outputIndexes, outExpectedTypes)); + return version; + } + case OperationType::BATCH_TO_SPACE_ND: { + NN_VALIDATE((inputIndexes.size() == 3 || inputIndexes.size() == 2) && + outputIndexes.size() == 1) + << "Invalid number of input operands (" << inputIndexes.size() + << ", expected 3 or 2) or output operands (" << outputIndexes.size() + << ", expected 1) for operation " << opType; + auto inputType = operands[inputIndexes[0]].type; + Version version = Version::ANDROID_OC_MR1; + std::vector<OperandType> inExpectedTypes; + std::vector<OperandType> outExpectedTypes; + if (inputType == OperandType::TENSOR_FLOAT32) { + inExpectedTypes = { + OperandType::TENSOR_FLOAT32, + OperandType::TENSOR_INT32, + }; + outExpectedTypes = {OperandType::TENSOR_FLOAT32}; + } else if (inputType == OperandType::TENSOR_FLOAT16) { + version = Version::ANDROID_Q; + inExpectedTypes = { + OperandType::TENSOR_FLOAT16, + OperandType::TENSOR_INT32, + }; + outExpectedTypes = {OperandType::TENSOR_FLOAT16}; + } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) { + inExpectedTypes = { + OperandType::TENSOR_QUANT8_ASYMM, + OperandType::TENSOR_INT32, + }; + outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM}; + } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { + version = Version::ANDROID_R; + inExpectedTypes = { + OperandType::TENSOR_QUANT8_ASYMM_SIGNED, + OperandType::TENSOR_INT32, + }; + outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED}; + } else { + NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType; + } + if (inputIndexes.size() == 3) { + inExpectedTypes.push_back(OperandType::BOOL); + version = combineVersions(version, Version::ANDROID_Q); + } else { + version = combineVersions(version, Version::ANDROID_P); + } + NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, + outputIndexes, outExpectedTypes)); + return version; + } + case OperationType::SPACE_TO_BATCH_ND: { + NN_VALIDATE((inputIndexes.size() == 4 || inputIndexes.size() == 3) && + outputIndexes.size() == 1) + << "Invalid number of input operands (" << inputIndexes.size() + << ", expected 4 or 3) or output operands (" << outputIndexes.size() + << ", expected 1) for operation " << opType; + auto inputType = operands[inputIndexes[0]].type; + Version version = Version::ANDROID_OC_MR1; + std::vector<OperandType> inExpectedTypes; + std::vector<OperandType> outExpectedTypes; + if (inputType == OperandType::TENSOR_FLOAT32) { + inExpectedTypes = { + OperandType::TENSOR_FLOAT32, + OperandType::TENSOR_INT32, + OperandType::TENSOR_INT32, + }; + outExpectedTypes = {OperandType::TENSOR_FLOAT32}; + } else if (inputType == OperandType::TENSOR_FLOAT16) { + version = Version::ANDROID_Q; + inExpectedTypes = { + OperandType::TENSOR_FLOAT16, + OperandType::TENSOR_INT32, + OperandType::TENSOR_INT32, + }; + outExpectedTypes = {OperandType::TENSOR_FLOAT16}; + } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) { + if (operands[inputIndexes[0]].zeroPoint != 0) { + version = Version::ANDROID_Q; + } + inExpectedTypes = { + OperandType::TENSOR_QUANT8_ASYMM, + OperandType::TENSOR_INT32, + OperandType::TENSOR_INT32, + }; + outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM}; + } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { + version = Version::ANDROID_R; + inExpectedTypes = { + OperandType::TENSOR_QUANT8_ASYMM_SIGNED, + OperandType::TENSOR_INT32, + OperandType::TENSOR_INT32, + }; + outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED}; + } else { + NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType; + } + if (inputIndexes.size() == 4) { + inExpectedTypes.push_back(OperandType::BOOL); + version = combineVersions(version, Version::ANDROID_Q); + } else { + version = combineVersions(version, Version::ANDROID_P); + } + NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, + outputIndexes, outExpectedTypes)); + return version; + } + case OperationType::PAD: { + NN_VALIDATE(inputIndexes.size() == 2 && outputIndexes.size() == 1) + << invalidInOutNumberMessage(2, 1); + auto inputType = operands[inputIndexes[0]].type; + Version version; + std::vector<OperandType> inExpectedTypes; + std::vector<OperandType> outExpectedTypes; + if (inputType == OperandType::TENSOR_FLOAT32) { + version = Version::ANDROID_P; + inExpectedTypes = { + OperandType::TENSOR_FLOAT32, + OperandType::TENSOR_INT32, + }; + outExpectedTypes = {OperandType::TENSOR_FLOAT32}; + } else if (inputType == OperandType::TENSOR_FLOAT16) { + version = Version::ANDROID_Q; + inExpectedTypes = { + OperandType::TENSOR_FLOAT16, + OperandType::TENSOR_INT32, + }; + outExpectedTypes = {OperandType::TENSOR_FLOAT16}; + } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM || + inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { + if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { + version = Version::ANDROID_R; + } else { + if (operands[inputIndexes[0]].zeroPoint == 0) { + version = Version::ANDROID_P; + } else { + version = Version::ANDROID_Q; + } + } + inExpectedTypes = { + inputType, + OperandType::TENSOR_INT32, + }; + outExpectedTypes = {inputType}; + } else { + NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType; + } + const auto inputRank = operands[inputIndexes[0]].dimensions.size(); + NN_VALIDATE_LE(inputRank, 4u) + << "Unsupported input tensor rank for operation " << opType; + NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, + outputIndexes, outExpectedTypes)); + return version; + } + case OperationType::PAD_V2: { + NN_VALIDATE(inputIndexes.size() == 3 && outputIndexes.size() == 1) + << invalidInOutNumberMessage(3, 1); + auto inputType = operands[inputIndexes[0]].type; + Version version; + std::vector<OperandType> inExpectedTypes; + std::vector<OperandType> outExpectedTypes; + if (inputType == OperandType::TENSOR_FLOAT32) { + version = Version::ANDROID_Q; + inExpectedTypes = { + OperandType::TENSOR_FLOAT32, + OperandType::TENSOR_INT32, + OperandType::FLOAT32, + }; + outExpectedTypes = {OperandType::TENSOR_FLOAT32}; + } else if (inputType == OperandType::TENSOR_FLOAT16) { + version = Version::ANDROID_Q; + inExpectedTypes = { + OperandType::TENSOR_FLOAT16, + OperandType::TENSOR_INT32, + OperandType::FLOAT16, + }; + outExpectedTypes = {OperandType::TENSOR_FLOAT16}; + } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM || + inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { + if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { + version = Version::ANDROID_R; + } else { + version = Version::ANDROID_Q; + } + inExpectedTypes = { + inputType, + OperandType::TENSOR_INT32, + OperandType::INT32, + }; // TODO(b/116699425): Make it UINT8. + outExpectedTypes = {inputType}; + } else { + NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType; + } + const auto inputRank = operands[inputIndexes[0]].dimensions.size(); + NN_VALIDATE_LE(inputRank, 4u) + << "Unsupported input tensor rank for operation " << opType; + NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, + outputIndexes, outExpectedTypes)); + return version; + } + case OperationType::CAST: { + NN_VALIDATE(inputIndexes.size() == 1 && outputIndexes.size() == 1) + << invalidInOutNumberMessage(1, 1); + auto inputOperand = operands[inputIndexes[0]]; + auto outputOperand = operands[outputIndexes[0]]; + auto inputType = inputOperand.type; + auto outputType = outputOperand.type; + Version version; + std::vector<OperandType> inExpectedTypes; + std::vector<OperandType> outExpectedTypes; + if ((inputType == OperandType::TENSOR_FLOAT16 || + inputType == OperandType::TENSOR_FLOAT32 || + inputType == OperandType::TENSOR_INT32 || + inputType == OperandType::TENSOR_QUANT8_ASYMM) && + (outputType == OperandType::TENSOR_FLOAT16 || + outputType == OperandType::TENSOR_FLOAT32 || + outputType == OperandType::TENSOR_INT32 || + outputType == OperandType::TENSOR_QUANT8_ASYMM)) { + version = Version::ANDROID_Q; + inExpectedTypes = {inputType}; + outExpectedTypes = {outputType}; + } else if (inputType == OperandType::TENSOR_BOOL8 || + inputType == OperandType::TENSOR_QUANT16_ASYMM || + inputType == OperandType::TENSOR_QUANT16_SYMM || + inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED || + inputType == OperandType::TENSOR_QUANT8_SYMM) { + version = Version::ANDROID_R; + inExpectedTypes = {inputType}; + outExpectedTypes = {inputType}; // Only identity CAST is supported. + } else { + NN_VALIDATE_FAIL() << "Unsupported data type for operation " << opType; + } + // Validate that output shape is equal to input shape if dimensions + // are already known. + auto getNumberOfElements = [](const std::vector<uint32_t>& dims) { + if (dims.empty()) { + return 0; + } + return std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<>()); + }; + NN_VALIDATE(inputOperand.dimensions.empty() || outputOperand.dimensions.empty() || + getNumberOfElements(outputOperand.dimensions) == 0 || + inputOperand.dimensions == outputOperand.dimensions); + NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, + outputIndexes, outExpectedTypes)); + return version; + } + case OperationType::MEAN: { + NN_VALIDATE(inputIndexes.size() == 3 && outputIndexes.size() == 1) + << invalidInOutNumberMessage(3, 1); + const auto inputRank = operands[inputIndexes[0]].dimensions.size(); + NN_VALIDATE_LE(inputRank, 4u) + << "Unsupported input tensor rank for operation " << opType; + auto inputType = operands[inputIndexes[0]].type; + Version version; + if (inputType == OperandType::TENSOR_FLOAT32 || + inputType == OperandType::TENSOR_QUANT8_ASYMM) { + version = Version::ANDROID_P; + } else if (inputType == OperandType::TENSOR_FLOAT16) { + version = Version::ANDROID_Q; + } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { + version = Version::ANDROID_R; + } else { + NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType; + } + std::vector<OperandType> inExpectedTypes = {inputType, OperandType::TENSOR_INT32, + OperandType::INT32}; + std::vector<OperandType> outExpectedTypes = {inputType}; + NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, + outputIndexes, outExpectedTypes)); + return version; + } + case OperationType::ARGMAX: + case OperationType::ARGMIN: { + NN_VALIDATE(inputIndexes.size() == 2 && outputIndexes.size() == 1) + << invalidInOutNumberMessage(2, 1); + auto inputType = operands[inputIndexes[0]].type; + std::vector<OperandType> inExpectedTypes; + std::vector<OperandType> outExpectedTypes; + if (inputType == OperandType::TENSOR_FLOAT16 || + inputType == OperandType::TENSOR_FLOAT32 || + inputType == OperandType::TENSOR_INT32 || + inputType == OperandType::TENSOR_QUANT8_ASYMM || + inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { + inExpectedTypes = {inputType, OperandType::INT32}; + outExpectedTypes = {OperandType::TENSOR_INT32}; + } else { + NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType; + } + NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, + outputIndexes, outExpectedTypes)); + return Version::ANDROID_Q; + } + case OperationType::EXPAND_DIMS: { + NN_VALIDATE(inputIndexes.size() == 2 && outputIndexes.size() == 1) + << invalidInOutNumberMessage(2, 1); + auto inputType = operands[inputIndexes[0]].type; + std::vector<OperandType> inExpectedTypes; + std::vector<OperandType> outExpectedTypes; + if (inputType == OperandType::TENSOR_FLOAT16 || + inputType == OperandType::TENSOR_FLOAT32 || + inputType == OperandType::TENSOR_INT32 || + inputType == OperandType::TENSOR_QUANT8_ASYMM || + inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { + inExpectedTypes = {inputType, OperandType::INT32}; + outExpectedTypes = {inputType}; + } else { + NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType; + } + Version version; + if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { + version = Version::ANDROID_R; + } else { + version = Version::ANDROID_Q; + } + NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, + outputIndexes, outExpectedTypes)); + return version; + } + case OperationType::SPLIT: { + NN_VALIDATE_EQ(inputIndexes.size(), 3u) + << "Invalid number of input operands (" << inputIndexes.size() + << ", expected 3)" << opType; + auto inputType = operands[inputIndexes[0]].type; + NN_VALIDATE(inputType == OperandType::TENSOR_FLOAT16 || + inputType == OperandType::TENSOR_FLOAT32 || + inputType == OperandType::TENSOR_INT32 || + inputType == OperandType::TENSOR_QUANT8_ASYMM || + inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) + << "Unsupported input tensor type for operation " << opType; + Version version; + if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { + version = Version::ANDROID_R; + } else { + version = Version::ANDROID_Q; + } + std::vector<OperandType> inExpectedTypes = {inputType, OperandType::INT32, + OperandType::INT32}; + std::vector<OperandType> outExpectedTypes(outputIndexes.size(), inputType); + NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, + outputIndexes, outExpectedTypes)); + return version; + } + case OperationType::MAXIMUM: + case OperationType::MINIMUM: { + NN_VALIDATE(inputIndexes.size() == 2 && outputIndexes.size() == 1) + << invalidInOutNumberMessage(2, 1); + std::vector<OperandType> inExpectedTypes; + std::vector<OperandType> outExpectedTypes; + OperandType inputType = operands[inputIndexes[0]].type; + if (inputType == OperandType::TENSOR_FLOAT16 || + inputType == OperandType::TENSOR_FLOAT32 || + inputType == OperandType::TENSOR_INT32 || + inputType == OperandType::TENSOR_QUANT8_ASYMM || + inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { + inExpectedTypes = {inputType, inputType}; + outExpectedTypes = {inputType}; + } else { + NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType; + } + Version version; + if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { + version = Version::ANDROID_R; + } else { + version = Version::ANDROID_Q; + } + NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, + outputIndexes, outExpectedTypes)); + return version; + } + case OperationType::GROUPED_CONV_2D: { + NN_VALIDATE((inputIndexes.size() == 12 || inputIndexes.size() == 9) && + outputIndexes.size() == 1) + << "Invalid number of input operands (" << inputIndexes.size() + << ", expected 12 or 9) or output operands (" << outputIndexes.size() + << ", expected 1) for operation " << opType; + auto inputType = operands[inputIndexes[0]].type; + auto filterType = operands[inputIndexes[1]].type; + std::vector<OperandType> inExpectedTypes; + std::vector<OperandType> outExpectedTypes; + if (inputType == OperandType::TENSOR_FLOAT32) { + inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32, + OperandType::TENSOR_FLOAT32, OperandType::INT32, + OperandType::INT32, OperandType::INT32, + OperandType::INT32, OperandType::INT32}; + outExpectedTypes = {OperandType::TENSOR_FLOAT32}; + } else if (inputType == OperandType::TENSOR_FLOAT16) { + inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16, + OperandType::TENSOR_FLOAT16, OperandType::INT32, + OperandType::INT32, OperandType::INT32, + OperandType::INT32, OperandType::INT32}; + outExpectedTypes = {OperandType::TENSOR_FLOAT16}; + } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM || + inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { + NN_VALIDATE(filterType == inputType || + filterType == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL) + << "Unsupported filter tensor type for operation " << opType; + + NN_VALIDATE(filterType != OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL || + std::get<Operand::SymmPerChannelQuantParams>( + operands[inputIndexes[1]].extraParams) + .channelDim == 0) + << "Unsupported filter tensor channel dimension for operation " << opType; + + inExpectedTypes = { + inputType, filterType, OperandType::TENSOR_INT32, + OperandType::INT32, OperandType::INT32, OperandType::INT32, + OperandType::INT32, OperandType::INT32}; + outExpectedTypes = {inputType}; + } else { + NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType; + } + + if (inputIndexes.size() == 12) { + std::vector<OperandType> explicitScalarTypes(3, OperandType::INT32); + inExpectedTypes.insert(inExpectedTypes.end(), explicitScalarTypes.begin(), + explicitScalarTypes.end()); + } + inExpectedTypes.push_back(OperandType::BOOL); + Version version; + if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { + version = Version::ANDROID_R; + } else { + version = Version::ANDROID_Q; + } + NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, + outputIndexes, outExpectedTypes)); + return version; + } + case OperationType::TILE: { + NN_VALIDATE(inputIndexes.size() == 2 && outputIndexes.size() == 1) + << invalidInOutNumberMessage(2, 1); + auto inputType = operands[inputIndexes[0]].type; + std::vector<OperandType> inExpectedTypes; + std::vector<OperandType> outExpectedTypes; + if (inputType == OperandType::TENSOR_FLOAT16 || + inputType == OperandType::TENSOR_FLOAT32 || + inputType == OperandType::TENSOR_INT32 || + inputType == OperandType::TENSOR_QUANT8_ASYMM || + inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { + inExpectedTypes = {inputType, OperandType::TENSOR_INT32}; + outExpectedTypes = {inputType}; + } else { + NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType; + } + Version version; + if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { + version = Version::ANDROID_R; + } else { + version = Version::ANDROID_Q; + } + NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, + outputIndexes, outExpectedTypes)); + return version; + } + case OperationType::POW: { + NN_VALIDATE(inputIndexes.size() == 2 && outputIndexes.size() == 1) + << invalidInOutNumberMessage(2, 1); + auto inputType = operands[inputIndexes[0]].type; + std::vector<OperandType> inExpectedTypes; + std::vector<OperandType> outExpectedTypes; + if (inputType == OperandType::TENSOR_FLOAT16 || + inputType == OperandType::TENSOR_FLOAT32) { + inExpectedTypes = {inputType, inputType}; + outExpectedTypes = {inputType}; + } else { + NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType; + } + Version version; + if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { + version = Version::ANDROID_R; + } else { + version = Version::ANDROID_Q; + } + NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, + outputIndexes, outExpectedTypes)); + return version; + } + case OperationType::IF: { + return validateIfOperation(inputIndexes, outputIndexes, operands, subgraphs); + } + case OperationType::WHILE: { + return validateWhileOperation(inputIndexes, outputIndexes, operands, subgraphs); + } + default: { + // TODO: Enable this block of code once canonical types are integrated in the rest of + // the NNAPI codebase. +#if 0 + const OperationRegistration* operationRegistration = + BuiltinOperationResolver::get()->findOperation( + static_cast<OperationType>(opType)); + // TODO: return ErrorStatus::UNEXPECTED_NULL + NN_VALIDATE(operationRegistration != nullptr) << opType << " not registered"; + // TODO: return ErrorStatus::UNEXPECTED_NULL + NN_VALIDATE(operationRegistration->validate != nullptr) + << "Incomplete operation registration: " << opType; + OperationValidationContext context(operationRegistration->name, inputIndexes, + outputIndexes, operands); + auto result = operationRegistration->validate(&context); + if (!result.has_value()) { + return NN_ERROR() << "Validation failed for operation " << opType << ": " + << std::move(result).error(); + } + return result; +#endif + NN_VALIDATE_FAIL() << "Validation for " << opType << " is not yet implemented"; + } + } +} + +} // anonymous namespace + +Version combineVersions(Version lhs, Version rhs) { + return std::max<Version>(lhs, rhs); +} + +Result<Version> validate(const DeviceStatus& deviceStatus) { + return validateDeviceStatus(deviceStatus); +} + +Result<Version> validate(const ExecutionPreference& executionPreference) { + return validateExecutionPreference(executionPreference); +} + +Result<Version> validate(const DeviceType& deviceType) { + return validateDeviceType(deviceType); +} + +Result<Version> validate(const MeasureTiming& measureTiming) { + return validateMeasureTiming(measureTiming); +} + +Result<Version> validate(const Priority& priority) { + return validatePriority(priority); +} + +Result<Version> validate(const ErrorStatus& errorStatus) { + return validateErrorStatus(errorStatus); +} + +Result<Version> validate(const OutputShape& outputShape) { + return validateOutputShape(outputShape); +} + +Result<Version> validate(const Timing& timing) { + return validateTiming(timing); +} + +Result<Version> validate(const Capabilities& capabilities) { + return validateCapabilities(capabilities); +} + +Result<Version> validate(const Extension& extension) { + return validateExtension(extension); +} + +Result<Version> validate(const NativeHandle& handle) { + return validateNativeHandle(handle); +} + +Result<Version> validate(const Memory& memory) { + return validateMemory(memory); +} + +Result<Version> validate(const Model& model) { + return validateModel(model); +} + +Result<Version> validate(const BufferDesc& bufferDesc) { + return validateBufferDesc(bufferDesc); +} + +Result<Version> validate(const BufferRole& bufferRole) { + return validateBufferRole(bufferRole); +} + +Result<Version> validate(const Request& request) { + return validateRequest(request); +} + +Result<Version> validate(const OptionalTimePoint& optionalTimePoint) { + return validateOptionalTimePoint(optionalTimePoint); +} + +Result<Version> validate(const OptionalTimeoutDuration& optionalTimeoutDuration) { + return validateOptionalTimeoutDuration(optionalTimeoutDuration); +} + +Result<Version> validate(const std::vector<OutputShape>& outputShapes) { + return validateVector(outputShapes, validateOutputShape); +} + +Result<Version> validate(const std::vector<Extension>& extensions) { + return validateExtensions(extensions); +} + +Result<Version> validate(const std::vector<NativeHandle>& handles) { + return validateVector(handles, validateNativeHandle); +} + +Result<Version> validate(const std::vector<BufferRole>& bufferRoles) { + return validateVector(bufferRoles, validateBufferRole); +} + +Result<Version> validateRequestForModel(const Request& request, const Model& model) { + return validateRequestForModelImpl(request, model); +} + +Result<Version> validateMemoryDesc( + const BufferDesc& desc, + const std::vector<std::shared_ptr<const IPreparedModel>>& preparedModels, + const std::vector<BufferRole>& inputRoles, const std::vector<BufferRole>& outputRoles, + const std::function<const Model*(const std::shared_ptr<const IPreparedModel>&)>& getModel, + std::set<PreparedModelRole>* preparedModelRoles, Operand* combinedOperand) { + return validateMemoryDescImpl(desc, preparedModels, inputRoles, outputRoles, getModel, + preparedModelRoles, combinedOperand); +} + +Result<void> validateOperandSymmPerChannelQuantParams( + const Operand& operand, const Operand::SymmPerChannelQuantParams& channelQuant, + const char* tag) { + return validateOperandSymmPerChannelQuantParamsImpl(operand, channelQuant, tag); +} + +Result<void> validateOperandType(const Operand& type, + const Extension::OperandTypeInformation* extensionOperandTypeInfo, + const char* tag, bool allowPartial) { + return validateOperandTypeImpl(type, extensionOperandTypeInfo, tag, allowPartial); +} + +Result<void> validateOperandList(const std::vector<uint32_t>& list, size_t operandCount, + const char* tag) { + return validateOperandListImpl(list, operandCount, tag); +} + +Result<Version> validateOperation(const Operation& operation, const std::vector<Operand>& operands, + const std::vector<Model::Subgraph>& subgraphs) { + return validateOperationImpl(operation, operands, subgraphs); +} + +} // namespace android::nn diff --git a/nn/common/include/nnapi/Result.h b/nn/common/include/nnapi/Result.h new file mode 100644 index 000000000..e232cef3f --- /dev/null +++ b/nn/common/include/nnapi/Result.h @@ -0,0 +1,147 @@ +/* + * Copyright (C) 2020 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_NNAPI_RESULT_H +#define ANDROID_FRAMEWORKS_ML_NN_COMMON_NNAPI_RESULT_H + +#include <android-base/expected.h> + +#include <optional> +#include <sstream> +#include <string> +#include <utility> + +namespace android::nn { + +/** + * Type alias for `::android::base::expected` where the unexpected state is represented by a + * std::string describing the error. + * + * See the following file for more information on ::android::base::expected: + * system/libbase/include/android-base/expected.h + */ +template <typename Type> +using Result = base::expected<Type, std::string>; + +namespace detail { + +class ErrorBuilder { + public: + template <typename T, typename E> + operator base::expected<T, E>() const /* NOLINT(google-explicit-constructor) */ { + return base::unexpected<E>(std::move(mStream).str()); + } + + template <typename T> + ErrorBuilder operator<<(const T& t) { + mStream << t; + return std::move(*this); + } + + private: + std::ostringstream mStream; +}; + +} // namespace detail + +/** + * Creates an error builder for the case where no arguments are provided. + */ +inline detail::ErrorBuilder error() { + return detail::ErrorBuilder(); +} + +/** + * Helper macro that will create an error builder already populated with the file name and line + * number. + * + * This macro uses the following customization points: + * * `::android::nn::error` is a set of functions that can be customized to return a specialized + * error builder object. Customization is based on the types of arguments passed and the number + * of arguments passed to `error`. + * + * Usage at error site: + * if (errorDetected) { + * return NN_ERROR() << "<error_message>"; + * } + * return <regular_return_value>; + */ +#define NN_ERROR(...) \ + [] { \ + using ::android::nn::error; \ + return error(__VA_ARGS__) << __FILE__ << ":" << __LINE__ << ": "; \ + }() + +template <typename T, typename E> +bool nnTryHasValue(const base::expected<T, E>& o) { + return o.has_value(); +} + +template <typename T, typename E> +T nnTryGetValue(base::expected<T, E> o) { + return std::move(o).value(); +} + +template <typename T, typename E> +base::unexpected<E> nnTryGetError(base::expected<T, E> o) { + return base::unexpected(std::move(o).error()); +} + +template <typename T> +bool nnTryHasValue(const std::optional<T>& o) { + return o.has_value(); +} + +template <typename T> +T nnTryGetValue(std::optional<T> o) { + return std::move(o).value(); +} + +template <typename T> +std::nullopt_t nnTryGetError(std::optional<T> /*o*/) { + return std::nullopt; +} + +/** + * A macro that will exit from the current function if `expr` is unexpected or return the expected + * value from the macro if `expr` is expected. + * + * This macro can currently be used on `::android::nn::Result`, `::android::base::expected`, or + * `std::optional` values. To enable this macro to be used with other values, implement the + * following functions for the type: + * * `::android::nn::nnTryHasValue` returns `true` if the `expr` holds a successful value, false if + * the `expr` value holds an error + * * `::android::nn::nnTryGetError` returns the successful value of `expr` or crashes + * * `::android::nn::nnTryGetValue` returns the error value of `expr` or crashes + * + * Usage at call site: + * const auto [a, b, c] = NN_TRY(failableFunction(args)); + */ +#define NN_TRY(expr) \ + ({ \ + using ::android::nn::nnTryHasValue; \ + using ::android::nn::nnTryGetValue; \ + using ::android::nn::nnTryGetError; \ + auto nnTryTemporaryResult = expr; \ + if (!nnTryHasValue(nnTryTemporaryResult)) { \ + return nnTryGetError(std::move(nnTryTemporaryResult)); \ + } \ + nnTryGetValue(std::move(nnTryTemporaryResult)); \ + }) + +} // namespace android::nn + +#endif // ANDROID_FRAMEWORKS_ML_NN_COMMON_NNAPI_RESULT_H diff --git a/nn/common/include/nnapi/SharedMemory.h b/nn/common/include/nnapi/SharedMemory.h new file mode 100644 index 000000000..ee8330f45 --- /dev/null +++ b/nn/common/include/nnapi/SharedMemory.h @@ -0,0 +1,93 @@ +/* + * Copyright (C) 2020 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_NNAPI_SHARED_MEMORY_H +#define ANDROID_FRAMEWORKS_ML_NN_COMMON_NNAPI_SHARED_MEMORY_H + +#include <any> +#include <optional> +#include <string> +#include <variant> +#include <vector> + +#include "nnapi/Result.h" +#include "nnapi/Types.h" + +// Forward declare AHardwareBuffer +extern "C" typedef struct AHardwareBuffer AHardwareBuffer; + +// Forward declare hidl_memory +namespace android::hardware { +struct hidl_memory; +} // namespace android::hardware + +namespace android::nn { + +class MutableMemoryBuilder { + public: + explicit MutableMemoryBuilder(uint32_t poolIndex); + + DataLocation append(size_t length); + bool empty() const; + + Result<Memory> finish(); + + private: + uint32_t mPoolIndex; + size_t mSize = 0; +}; + +class ConstantMemoryBuilder { + public: + explicit ConstantMemoryBuilder(uint32_t poolIndex); + + DataLocation append(const void* data, size_t length); + bool empty() const; + + Result<Memory> finish(); + + private: + struct LazyCopy { + const void* data; + size_t length; + size_t offset; + }; + + MutableMemoryBuilder mBuilder; + std::vector<LazyCopy> mSlices; +}; + +Result<Memory> createSharedMemory(size_t size); + +Result<Memory> createSharedMemoryFromFd(size_t size, int prot, int fd, size_t offset); + +Result<Memory> createSharedMemoryFromHidlMemory(const hardware::hidl_memory& memory); + +Result<Memory> createSharedMemoryFromAHWB(const AHardwareBuffer& ahwb); + +struct Mapping { + std::variant<void*, const void*> pointer; + size_t size; + std::any context; +}; + +Result<Mapping> map(const Memory& memory); + +bool flush(const Mapping& mapping); + +} // namespace android::nn + +#endif // ANDROID_FRAMEWORKS_ML_NN_COMMON_NNAPI_SHARED_MEMORY_H diff --git a/nn/common/include/nnapi/TypeUtils.h b/nn/common/include/nnapi/TypeUtils.h new file mode 100644 index 000000000..58021a780 --- /dev/null +++ b/nn/common/include/nnapi/TypeUtils.h @@ -0,0 +1,112 @@ +/* + * Copyright (C) 2020 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_NNAPI_TYPE_UTILS_H +#define ANDROID_FRAMEWORKS_ML_NN_COMMON_NNAPI_TYPE_UTILS_H + +#include <ostream> +#include <utility> +#include <vector> + +#include "nnapi/OperandTypes.h" +#include "nnapi/OperationTypes.h" +#include "nnapi/Result.h" +#include "nnapi/Types.h" + +namespace android::nn { + +bool isExtension(OperandType type); +bool isExtension(OperationType type); + +bool isNonExtensionScalar(OperandType operandType); + +size_t getNonExtensionSize(OperandType operandType); + +std::optional<size_t> getNonExtensionSize(OperandType operandType, const Dimensions& dimensions); +std::optional<size_t> getNonExtensionSize(const Operand& operand); + +size_t getOffsetFromInts(int lower, int higher); +std::pair<int32_t, int32_t> getIntsFromOffset(size_t offset); + +std::vector<uint32_t> countNumberOfConsumers(size_t numberOfOperands, + const std::vector<nn::Operation>& operations); + +// Combine two tensor dimensions, both may have unspecified dimensions or rank. +Result<Dimensions> combineDimensions(const Dimensions& lhs, const Dimensions& rhs); + +// Set of output utility functions. +std::ostream& operator<<(std::ostream& os, const DeviceStatus& deviceStatus); +std::ostream& operator<<(std::ostream& os, const ExecutionPreference& executionPreference); +std::ostream& operator<<(std::ostream& os, const DeviceType& deviceType); +std::ostream& operator<<(std::ostream& os, const MeasureTiming& measureTiming); +std::ostream& operator<<(std::ostream& os, const OperandType& operandType); +std::ostream& operator<<(std::ostream& os, const Operand::LifeTime& lifetime); +std::ostream& operator<<(std::ostream& os, const OperationType& operationType); +std::ostream& operator<<(std::ostream& os, const Request::Argument::LifeTime& lifetime); +std::ostream& operator<<(std::ostream& os, const Priority& priority); +std::ostream& operator<<(std::ostream& os, const ErrorStatus& errorStatus); +std::ostream& operator<<(std::ostream& os, const OutputShape& outputShape); +std::ostream& operator<<(std::ostream& os, const Timing& timing); +std::ostream& operator<<(std::ostream& os, const Capabilities::PerformanceInfo& performanceInfo); +std::ostream& operator<<(std::ostream& os, + const Capabilities::OperandPerformance& operandPerformance); +std::ostream& operator<<(std::ostream& os, + const Capabilities::OperandPerformanceTable& operandPerformances); +std::ostream& operator<<(std::ostream& os, const Capabilities& capabilities); +std::ostream& operator<<(std::ostream& os, + const Extension::OperandTypeInformation& operandTypeInformation); +std::ostream& operator<<(std::ostream& os, const Extension& extension); +std::ostream& operator<<(std::ostream& os, const DataLocation& location); +std::ostream& operator<<(std::ostream& os, + const Operand::SymmPerChannelQuantParams& symmPerChannelQuantParams); +std::ostream& operator<<(std::ostream& os, const Operand::ExtraParams& extraParams); +std::ostream& operator<<(std::ostream& os, const Operand& operand); +std::ostream& operator<<(std::ostream& os, const Operation& operation); +std::ostream& operator<<(std::ostream& os, const NativeHandle& handle); +std::ostream& operator<<(std::ostream& os, const Memory& memory); +std::ostream& operator<<(std::ostream& os, const Model::Subgraph& subgraph); +std::ostream& operator<<(std::ostream& os, const Model::OperandValues& operandValues); +std::ostream& operator<<(std::ostream& os, + const Model::ExtensionNameAndPrefix& extensionNameAndPrefix); +std::ostream& operator<<(std::ostream& os, const Model& model); +std::ostream& operator<<(std::ostream& os, const BufferDesc& bufferDesc); +std::ostream& operator<<(std::ostream& os, const BufferRole& bufferRole); +std::ostream& operator<<(std::ostream& os, const Request::Argument& requestArgument); +std::ostream& operator<<(std::ostream& os, const Request::MemoryPool& memoryPool); +std::ostream& operator<<(std::ostream& os, const Request& request); +std::ostream& operator<<(std::ostream& os, const TimePoint& timePoint); +std::ostream& operator<<(std::ostream& os, const OptionalTimePoint& optionalTimePoint); +std::ostream& operator<<(std::ostream& os, const TimeoutDuration& timeoutDuration); +std::ostream& operator<<(std::ostream& os, const OptionalTimeoutDuration& optionalTimeoutDuration); +std::ostream& operator<<(std::ostream& os, const Version& version); + +bool operator==(const Timing& a, const Timing& b); +bool operator!=(const Timing& a, const Timing& b); +bool operator==(const Capabilities::PerformanceInfo& a, const Capabilities::PerformanceInfo& b); +bool operator==(const Capabilities::OperandPerformance& a, + const Capabilities::OperandPerformance& b); +bool operator==(const Capabilities& a, const Capabilities& b); +bool operator==(const Extension::OperandTypeInformation& a, + const Extension::OperandTypeInformation& b); +bool operator==(const Extension& a, const Extension& b); +bool operator==(const Operand::SymmPerChannelQuantParams& a, + const Operand::SymmPerChannelQuantParams& b); +bool operator!=(const Operand::SymmPerChannelQuantParams& a, + const Operand::SymmPerChannelQuantParams& b); + +} // namespace android::nn + +#endif // ANDROID_FRAMEWORKS_ML_NN_COMMON_NNAPI_TYPE_UTILS_H diff --git a/nn/common/include/nnapi/Types.h b/nn/common/include/nnapi/Types.h index 0536a2e2b..c0718dae1 100644 --- a/nn/common/include/nnapi/Types.h +++ b/nn/common/include/nnapi/Types.h @@ -32,6 +32,7 @@ #include "nnapi/OperandTypes.h" #include "nnapi/OperationTypes.h" +#include "nnapi/Result.h" namespace android::nn { @@ -141,7 +142,7 @@ struct Capabilities { }; class OperandPerformanceTable { public: - static std::optional<OperandPerformanceTable> create( + static Result<OperandPerformanceTable> create( std::vector<OperandPerformance> operandPerformances); PerformanceInfo lookup(OperandType type) const; @@ -290,7 +291,7 @@ using OptionalTimePoint = std::optional<TimePoint>; using TimeoutDuration = std::chrono::nanoseconds; using OptionalTimeoutDuration = std::optional<TimeoutDuration>; -enum class Version { ANDROID_OC_MR1, ANDROID_P, ANDROID_Q, ANDROID_R, CURRENT_RUNTIME, INVALID }; +enum class Version { ANDROID_OC_MR1, ANDROID_P, ANDROID_Q, ANDROID_R, CURRENT_RUNTIME }; } // namespace android::nn diff --git a/nn/common/include/nnapi/Validation.h b/nn/common/include/nnapi/Validation.h new file mode 100644 index 000000000..61a2a273c --- /dev/null +++ b/nn/common/include/nnapi/Validation.h @@ -0,0 +1,97 @@ +/* + * Copyright (C) 2020 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_NNAPI_VALIDATION_H +#define ANDROID_FRAMEWORKS_ML_NN_COMMON_NNAPI_VALIDATION_H + +#include <memory> +#include <set> +#include <tuple> +#include <vector> + +#include "nnapi/Result.h" +#include "nnapi/Types.h" + +namespace android::nn { + +// Utility functions + +Version combineVersions(Version lhs, Version rhs); + +Result<Version> validate(const DeviceStatus& deviceStatus); +Result<Version> validate(const ExecutionPreference& executionPreference); +Result<Version> validate(const DeviceType& deviceType); +Result<Version> validate(const MeasureTiming& measureTiming); +Result<Version> validate(const Priority& priority); +Result<Version> validate(const ErrorStatus& errorStatus); +Result<Version> validate(const OutputShape& outputShape); +Result<Version> validate(const Timing& timing); +Result<Version> validate(const Capabilities& capabilities); +Result<Version> validate(const Extension& extension); +Result<Version> validate(const NativeHandle& handle); +Result<Version> validate(const Memory& memory); +Result<Version> validate(const Model& model); +Result<Version> validate(const BufferDesc& bufferDesc); +Result<Version> validate(const BufferRole& bufferRole); +Result<Version> validate(const Request& request); +Result<Version> validate(const OptionalTimePoint& optionalTimePoint); +Result<Version> validate(const OptionalTimeoutDuration& optionalTimeoutDuration); + +Result<Version> validate(const std::vector<OutputShape>& outputShapes); +Result<Version> validate(const std::vector<Extension>& extensions); +Result<Version> validate(const std::vector<NativeHandle>& handles); +Result<Version> validate(const std::vector<BufferRole>& bufferRoles); + +// Validate request applied to model. +Result<Version> validateRequestForModel(const Request& request, const Model& model); + +// Validate memory descriptor. +enum class IOType { INPUT, OUTPUT }; +using PreparedModelRole = std::tuple<const IPreparedModel*, IOType, uint32_t>; + +// Verifies that the input arguments to IDevice::allocate are valid. +// Optionally, this function can return a flattened prepared model roles and a combined operand. +// Pass nullptr if either value is not needed. +// IMPORTANT: This function cannot validate dimensions and extraParams with extension operand type. +// Each driver should do their own validation of extension type dimensions and extraParams. +Result<Version> validateMemoryDesc( + const BufferDesc& desc, + const std::vector<std::shared_ptr<const IPreparedModel>>& preparedModels, + const std::vector<BufferRole>& inputRoles, const std::vector<BufferRole>& outputRoles, + const std::function<const Model*(const std::shared_ptr<const IPreparedModel>&)>& getModel, + std::set<PreparedModelRole>* preparedModelRoles, Operand* combinedOperand); + +Result<void> validateOperandSymmPerChannelQuantParams( + const Operand& operand, const Operand::SymmPerChannelQuantParams& channelQuant, + const char* tag); + +// Validates an operand type. +// +// extensionOperandTypeInfo must be nullptr iff the type is not an extension type. +// +// If allowPartial is true, the dimensions may be underspecified. +Result<void> validateOperandType(const Operand& type, + const Extension::OperandTypeInformation* extensionOperandTypeInfo, + const char* tag, bool allowPartial); +Result<void> validateOperandList(const std::vector<uint32_t>& list, size_t operandCount, + const char* tag); + +Result<Version> validateOperation(const Operation& operation, const std::vector<Operand>& operands, + const std::vector<Model::Subgraph>& subgraphs); + +} // namespace android::nn + +#endif // ANDROID_FRAMEWORKS_ML_NN_COMMON_NNAPI_VALIDATION_H |