summaryrefslogtreecommitdiff
path: root/nn/common/Utils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'nn/common/Utils.cpp')
-rw-r--r--nn/common/Utils.cpp15
1 files changed, 15 insertions, 0 deletions
diff --git a/nn/common/Utils.cpp b/nn/common/Utils.cpp
index 188436f11..1694c9c23 100644
--- a/nn/common/Utils.cpp
+++ b/nn/common/Utils.cpp
@@ -762,6 +762,15 @@ static bool validateIfOperation(uint32_t inputCount, const uint32_t* inputs, uin
return true;
}
+static bool validateControlFlowOperandUnknownSize(const SubgraphValidationHelper& helper,
+ const Operand& operand) {
+ if (!helper.allowControlFlowOperationWithOperandOfUnknownSize &&
+ !isExtensionOperandType(operand.type)) {
+ NN_RET_CHECK_NE(nonExtensionOperandSizeOfData(operand.type, operand.dimensions), 0u);
+ }
+ return true;
+}
+
static bool validateWhileOperation(uint32_t inputCount, const uint32_t* inputs,
uint32_t outputCount, const uint32_t* outputs,
const std::vector<Operand>& operands,
@@ -789,6 +798,8 @@ static bool validateWhileOperation(uint32_t inputCount, const uint32_t* inputs,
const Operand& innerOperand = *helper.getSubgraphInputOperand(condModelOperand, i);
const Operand& outerOperand = operands[inputs[op::kFirstInput + i]];
NN_RET_CHECK(compatible(innerOperand, outerOperand));
+ NN_RET_CHECK(validateControlFlowOperandUnknownSize(helper, innerOperand));
+ NN_RET_CHECK(validateControlFlowOperandUnknownSize(helper, outerOperand));
}
NN_RET_CHECK(
validateConditionOperand(*helper.getSubgraphOutputOperand(condModelOperand, 0)));
@@ -809,16 +820,20 @@ static bool validateWhileOperation(uint32_t inputCount, const uint32_t* inputs,
const Operand& innerOperand = *helper.getSubgraphInputOperand(bodyModelOperand, i);
const Operand& outerOperand = operands[inputs[op::kFirstInput + i]];
NN_RET_CHECK(compatible(innerOperand, outerOperand));
+ NN_RET_CHECK(validateControlFlowOperandUnknownSize(helper, innerOperand));
+ NN_RET_CHECK(validateControlFlowOperandUnknownSize(helper, outerOperand));
}
for (uint32_t i = 0; i < inputOutputCount; ++i) {
const Operand& innerOperand = *helper.getSubgraphOutputOperand(bodyModelOperand, i);
const Operand& outerOperand = operands[outputs[i]];
NN_RET_CHECK(compatible(innerOperand, outerOperand));
+ NN_RET_CHECK(validateControlFlowOperandUnknownSize(helper, outerOperand));
}
for (uint32_t i = 0, n = inputOutputCount + stateOnlyCount; i < n; ++i) {
const Operand& inputOperand = *helper.getSubgraphInputOperand(bodyModelOperand, i);
const Operand& outputOperand = *helper.getSubgraphOutputOperand(bodyModelOperand, i);
NN_RET_CHECK(compatible(inputOperand, outputOperand));
+ NN_RET_CHECK(validateControlFlowOperandUnknownSize(helper, outputOperand));
}
return true;
};