summaryrefslogtreecommitdiff
path: root/nn/common/operations/LSHProjection.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'nn/common/operations/LSHProjection.cpp')
-rw-r--r--nn/common/operations/LSHProjection.cpp13
1 files changed, 9 insertions, 4 deletions
diff --git a/nn/common/operations/LSHProjection.cpp b/nn/common/operations/LSHProjection.cpp
index 9ca8be492..bdb106e18 100644
--- a/nn/common/operations/LSHProjection.cpp
+++ b/nn/common/operations/LSHProjection.cpp
@@ -44,8 +44,12 @@ LSHProjection::LSHProjection(const Operation& operation, RunTimeOperandInfo* ope
bool LSHProjection::Prepare(const Operation& operation, RunTimeOperandInfo* operands,
Shape* outputShape) {
- const int num_inputs = NumInputsWithValues(operation, operands);
- NN_CHECK(num_inputs == 3 || num_inputs == 4);
+ // Check that none of the required inputs are omitted.
+ constexpr int requiredInputs[] = {kHashTensor, kInputTensor, kTypeParam};
+ for (const int requiredInput : requiredInputs) {
+ NN_RET_CHECK(!IsNullInput(GetInput(operation, operands, requiredInput)))
+ << "required input " << requiredInput << " is omitted";
+ }
NN_CHECK_EQ(NumOutputs(operation), 1);
const RunTimeOperandInfo* hash = GetInput(operation, operands, kHashTensor);
@@ -56,8 +60,9 @@ bool LSHProjection::Prepare(const Operation& operation, RunTimeOperandInfo* oper
const RunTimeOperandInfo* input = GetInput(operation, operands, kInputTensor);
NN_CHECK(NumDimensions(input) >= 1);
- auto type = static_cast<LSHProjectionType>(
- getScalarData<int32_t>(operands[operation.inputs[kTypeParam]]));
+ const auto& typeOperand = operands[operation.inputs[kTypeParam]];
+ NN_RET_CHECK(typeOperand.length >= sizeof(int32_t));
+ auto type = static_cast<LSHProjectionType>(getScalarData<int32_t>(typeOperand));
switch (type) {
case LSHProjectionType_SPARSE:
case LSHProjectionType_SPARSE_DEPRECATED: