summaryrefslogtreecommitdiff
path: root/nn/common/operations/TransposeConv2D.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'nn/common/operations/TransposeConv2D.cpp')
-rw-r--r--nn/common/operations/TransposeConv2D.cpp15
1 files changed, 9 insertions, 6 deletions
diff --git a/nn/common/operations/TransposeConv2D.cpp b/nn/common/operations/TransposeConv2D.cpp
index d67a473e6..0ee5d044c 100644
--- a/nn/common/operations/TransposeConv2D.cpp
+++ b/nn/common/operations/TransposeConv2D.cpp
@@ -25,7 +25,6 @@
#include <vector>
#include "CpuOperationUtils.h"
-#include "HalInterfaces.h"
#include "OperationResolver.h"
#include "Tracing.h"
@@ -46,8 +45,6 @@ constexpr uint32_t kOutputTensor = 0;
namespace {
-using namespace hal;
-
// If possible we will use this static buffer for the tensor.
constexpr size_t kStaticBufferSize = 1605632;
char static_scratch_buffer[kStaticBufferSize];
@@ -452,7 +449,9 @@ bool validate(const IOperationValidationContext* context) {
filterType == inputType)
<< "Unsupported filter tensor type for operation " << kOperationName;
if (filterType == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL) {
- NN_RET_CHECK_EQ(context->getInputExtraParams(kFilterTensor).channelQuant().channelDim,
+ NN_RET_CHECK_EQ(std::get<Operand::SymmPerChannelQuantParams>(
+ context->getInputExtraParams(kFilterTensor))
+ .channelDim,
0)
<< "Unsupported filter tensor channel dimension for operation "
<< kOperationName;
@@ -570,7 +569,9 @@ bool execute(IOperationExecutionContext* context) {
context->getInputShape(kInputTensor),
context->getInputBuffer<int8_t>(kFilterTensor),
context->getInputShape(kFilterTensor),
- context->getInputExtraParams(kFilterTensor).channelQuant().scales.data(),
+ std::get<Operand::SymmPerChannelQuantParams>(
+ context->getInputExtraParams(kFilterTensor))
+ .scales.data(),
context->getInputBuffer<int32_t>(kBiasTensor),
context->getInputShape(kBiasTensor), param,
context->getOutputBuffer<uint8_t>(kOutputTensor),
@@ -595,7 +596,9 @@ bool execute(IOperationExecutionContext* context) {
context->getInputShape(kInputTensor),
context->getInputBuffer<int8_t>(kFilterTensor),
context->getInputShape(kFilterTensor),
- context->getInputExtraParams(kFilterTensor).channelQuant().scales.data(),
+ std::get<Operand::SymmPerChannelQuantParams>(
+ context->getInputExtraParams(kFilterTensor))
+ .scales.data(),
context->getInputBuffer<int32_t>(kBiasTensor),
context->getInputShape(kBiasTensor), param,
context->getOutputBuffer<int8_t>(kOutputTensor),