diff options
Diffstat (limited to 'nn/common/operations/LSHProjection.cpp')
-rw-r--r-- | nn/common/operations/LSHProjection.cpp | 13 |
1 files changed, 9 insertions, 4 deletions
diff --git a/nn/common/operations/LSHProjection.cpp b/nn/common/operations/LSHProjection.cpp index 9ca8be492..bdb106e18 100644 --- a/nn/common/operations/LSHProjection.cpp +++ b/nn/common/operations/LSHProjection.cpp @@ -44,8 +44,12 @@ LSHProjection::LSHProjection(const Operation& operation, RunTimeOperandInfo* ope bool LSHProjection::Prepare(const Operation& operation, RunTimeOperandInfo* operands, Shape* outputShape) { - const int num_inputs = NumInputsWithValues(operation, operands); - NN_CHECK(num_inputs == 3 || num_inputs == 4); + // Check that none of the required inputs are omitted. + constexpr int requiredInputs[] = {kHashTensor, kInputTensor, kTypeParam}; + for (const int requiredInput : requiredInputs) { + NN_RET_CHECK(!IsNullInput(GetInput(operation, operands, requiredInput))) + << "required input " << requiredInput << " is omitted"; + } NN_CHECK_EQ(NumOutputs(operation), 1); const RunTimeOperandInfo* hash = GetInput(operation, operands, kHashTensor); @@ -56,8 +60,9 @@ bool LSHProjection::Prepare(const Operation& operation, RunTimeOperandInfo* oper const RunTimeOperandInfo* input = GetInput(operation, operands, kInputTensor); NN_CHECK(NumDimensions(input) >= 1); - auto type = static_cast<LSHProjectionType>( - getScalarData<int32_t>(operands[operation.inputs[kTypeParam]])); + const auto& typeOperand = operands[operation.inputs[kTypeParam]]; + NN_RET_CHECK(typeOperand.length >= sizeof(int32_t)); + auto type = static_cast<LSHProjectionType>(getScalarData<int32_t>(typeOperand)); switch (type) { case LSHProjectionType_SPARSE: case LSHProjectionType_SPARSE_DEPRECATED: |