diff options
Diffstat (limited to 'nn/common/Utils.cpp')
-rw-r--r-- | nn/common/Utils.cpp | 15 |
1 files changed, 15 insertions, 0 deletions
diff --git a/nn/common/Utils.cpp b/nn/common/Utils.cpp index 188436f11..1694c9c23 100644 --- a/nn/common/Utils.cpp +++ b/nn/common/Utils.cpp @@ -762,6 +762,15 @@ static bool validateIfOperation(uint32_t inputCount, const uint32_t* inputs, uin return true; } +static bool validateControlFlowOperandUnknownSize(const SubgraphValidationHelper& helper, + const Operand& operand) { + if (!helper.allowControlFlowOperationWithOperandOfUnknownSize && + !isExtensionOperandType(operand.type)) { + NN_RET_CHECK_NE(nonExtensionOperandSizeOfData(operand.type, operand.dimensions), 0u); + } + return true; +} + static bool validateWhileOperation(uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount, const uint32_t* outputs, const std::vector<Operand>& operands, @@ -789,6 +798,8 @@ static bool validateWhileOperation(uint32_t inputCount, const uint32_t* inputs, const Operand& innerOperand = *helper.getSubgraphInputOperand(condModelOperand, i); const Operand& outerOperand = operands[inputs[op::kFirstInput + i]]; NN_RET_CHECK(compatible(innerOperand, outerOperand)); + NN_RET_CHECK(validateControlFlowOperandUnknownSize(helper, innerOperand)); + NN_RET_CHECK(validateControlFlowOperandUnknownSize(helper, outerOperand)); } NN_RET_CHECK( validateConditionOperand(*helper.getSubgraphOutputOperand(condModelOperand, 0))); @@ -809,16 +820,20 @@ static bool validateWhileOperation(uint32_t inputCount, const uint32_t* inputs, const Operand& innerOperand = *helper.getSubgraphInputOperand(bodyModelOperand, i); const Operand& outerOperand = operands[inputs[op::kFirstInput + i]]; NN_RET_CHECK(compatible(innerOperand, outerOperand)); + NN_RET_CHECK(validateControlFlowOperandUnknownSize(helper, innerOperand)); + NN_RET_CHECK(validateControlFlowOperandUnknownSize(helper, outerOperand)); } for (uint32_t i = 0; i < inputOutputCount; ++i) { const Operand& innerOperand = *helper.getSubgraphOutputOperand(bodyModelOperand, i); const Operand& outerOperand = operands[outputs[i]]; NN_RET_CHECK(compatible(innerOperand, outerOperand)); + NN_RET_CHECK(validateControlFlowOperandUnknownSize(helper, outerOperand)); } for (uint32_t i = 0, n = inputOutputCount + stateOnlyCount; i < n; ++i) { const Operand& inputOperand = *helper.getSubgraphInputOperand(bodyModelOperand, i); const Operand& outputOperand = *helper.getSubgraphOutputOperand(bodyModelOperand, i); NN_RET_CHECK(compatible(inputOperand, outputOperand)); + NN_RET_CHECK(validateControlFlowOperandUnknownSize(helper, outputOperand)); } return true; }; |