diff options
Diffstat (limited to 'nn/common/CpuExecutor.cpp')
-rw-r--r-- | nn/common/CpuExecutor.cpp | 17 |
1 files changed, 15 insertions, 2 deletions
diff --git a/nn/common/CpuExecutor.cpp b/nn/common/CpuExecutor.cpp index 2673f2de6..d8582ed58 100644 --- a/nn/common/CpuExecutor.cpp +++ b/nn/common/CpuExecutor.cpp @@ -991,6 +991,9 @@ int CpuExecutor::executeOperation(const Operation& operation, RunTimeOperandInfo } } break; case OperationType::EMBEDDING_LOOKUP: { + if (!allParametersPresent(2, 1)) { + return ANEURALNETWORKS_BAD_DATA; + } const RunTimeOperandInfo& values = operands[ins[EmbeddingLookup::kValueTensor]]; const RunTimeOperandInfo& lookups = operands[ins[EmbeddingLookup::kLookupTensor]]; RunTimeOperandInfo& output = operands[outs[EmbeddingLookup::kOutputTensor]]; @@ -1002,6 +1005,9 @@ int CpuExecutor::executeOperation(const Operation& operation, RunTimeOperandInfo setInfoAndAllocateIfNeeded(&output, outputShape, &result) && lookup.Eval(); } break; case OperationType::HASHTABLE_LOOKUP: { + if (!allParametersPresent(3, 2)) { + return ANEURALNETWORKS_BAD_DATA; + } const RunTimeOperandInfo& lookups = operands[ins[HashtableLookup::kLookupTensor]]; const RunTimeOperandInfo& keys = operands[ins[HashtableLookup::kKeyTensor]]; const RunTimeOperandInfo& values = operands[ins[HashtableLookup::kValueTensor]]; @@ -1102,6 +1108,9 @@ int CpuExecutor::executeOperation(const Operation& operation, RunTimeOperandInfo setInfoAndAllocateIfNeeded(&output, outputShape, &result) && lstm_cell.Eval(); } break; case OperationType::RANDOM_MULTINOMIAL: { + if (!allParametersPresent(3, 1)) { + return ANEURALNETWORKS_BAD_DATA; + } const RunTimeOperandInfo& lookups = operands[ins[HashtableLookup::kLookupTensor]]; const RunTimeOperandInfo& keys = operands[ins[HashtableLookup::kKeyTensor]]; const RunTimeOperandInfo& values = operands[ins[HashtableLookup::kValueTensor]]; @@ -1115,6 +1124,10 @@ int CpuExecutor::executeOperation(const Operation& operation, RunTimeOperandInfo multinomial.Eval(); } break; case OperationType::RNN: { + if (!allParametersPresent(6, 2)) { + return ANEURALNETWORKS_BAD_DATA; + } + RunTimeOperandInfo& hiddenStateOut = operands[outs[RNN::kHiddenStateOutTensor]]; RunTimeOperandInfo& output = operands[outs[RNN::kOutputTensor]]; @@ -1409,8 +1422,8 @@ int CpuExecutor::executeOperation(const Operation& operation, RunTimeOperandInfo expand_dims::eval(input.buffer, input.shape(), axis, output.buffer, outShape); } break; case OperationType::SPLIT: { - if (ins.size() != 3) { - LOG(ERROR) << "Wrong input count"; + const size_t outCount = outs.size(); + if (!allParametersPresent(3, outCount)) { return ANEURALNETWORKS_BAD_DATA; } |