summaryrefslogtreecommitdiff
path: root/nn/common/CpuExecutor.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'nn/common/CpuExecutor.cpp')
-rw-r--r--nn/common/CpuExecutor.cpp17
1 files changed, 15 insertions, 2 deletions
diff --git a/nn/common/CpuExecutor.cpp b/nn/common/CpuExecutor.cpp
index 2673f2de6..d8582ed58 100644
--- a/nn/common/CpuExecutor.cpp
+++ b/nn/common/CpuExecutor.cpp
@@ -991,6 +991,9 @@ int CpuExecutor::executeOperation(const Operation& operation, RunTimeOperandInfo
}
} break;
case OperationType::EMBEDDING_LOOKUP: {
+ if (!allParametersPresent(2, 1)) {
+ return ANEURALNETWORKS_BAD_DATA;
+ }
const RunTimeOperandInfo& values = operands[ins[EmbeddingLookup::kValueTensor]];
const RunTimeOperandInfo& lookups = operands[ins[EmbeddingLookup::kLookupTensor]];
RunTimeOperandInfo& output = operands[outs[EmbeddingLookup::kOutputTensor]];
@@ -1002,6 +1005,9 @@ int CpuExecutor::executeOperation(const Operation& operation, RunTimeOperandInfo
setInfoAndAllocateIfNeeded(&output, outputShape, &result) && lookup.Eval();
} break;
case OperationType::HASHTABLE_LOOKUP: {
+ if (!allParametersPresent(3, 2)) {
+ return ANEURALNETWORKS_BAD_DATA;
+ }
const RunTimeOperandInfo& lookups = operands[ins[HashtableLookup::kLookupTensor]];
const RunTimeOperandInfo& keys = operands[ins[HashtableLookup::kKeyTensor]];
const RunTimeOperandInfo& values = operands[ins[HashtableLookup::kValueTensor]];
@@ -1102,6 +1108,9 @@ int CpuExecutor::executeOperation(const Operation& operation, RunTimeOperandInfo
setInfoAndAllocateIfNeeded(&output, outputShape, &result) && lstm_cell.Eval();
} break;
case OperationType::RANDOM_MULTINOMIAL: {
+ if (!allParametersPresent(3, 1)) {
+ return ANEURALNETWORKS_BAD_DATA;
+ }
const RunTimeOperandInfo& lookups = operands[ins[HashtableLookup::kLookupTensor]];
const RunTimeOperandInfo& keys = operands[ins[HashtableLookup::kKeyTensor]];
const RunTimeOperandInfo& values = operands[ins[HashtableLookup::kValueTensor]];
@@ -1115,6 +1124,10 @@ int CpuExecutor::executeOperation(const Operation& operation, RunTimeOperandInfo
multinomial.Eval();
} break;
case OperationType::RNN: {
+ if (!allParametersPresent(6, 2)) {
+ return ANEURALNETWORKS_BAD_DATA;
+ }
+
RunTimeOperandInfo& hiddenStateOut = operands[outs[RNN::kHiddenStateOutTensor]];
RunTimeOperandInfo& output = operands[outs[RNN::kOutputTensor]];
@@ -1409,8 +1422,8 @@ int CpuExecutor::executeOperation(const Operation& operation, RunTimeOperandInfo
expand_dims::eval(input.buffer, input.shape(), axis, output.buffer, outShape);
} break;
case OperationType::SPLIT: {
- if (ins.size() != 3) {
- LOG(ERROR) << "Wrong input count";
+ const size_t outCount = outs.size();
+ if (!allParametersPresent(3, outCount)) {
return ANEURALNETWORKS_BAD_DATA;
}