diff options
Diffstat (limited to 'nn/tools/test_generator/test_generator.py')
-rwxr-xr-x | nn/tools/test_generator/test_generator.py | 10 |
1 files changed, 9 insertions, 1 deletions
diff --git a/nn/tools/test_generator/test_generator.py b/nn/tools/test_generator/test_generator.py index e1b10a7ad..92dfad756 100755 --- a/nn/tools/test_generator/test_generator.py +++ b/nn/tools/test_generator/test_generator.py @@ -295,6 +295,7 @@ class Operand(NamedVariable): self.model_index = None self.ins = [] self.outs = [] + self.mayBeInternal = True def SetValue(self, value): self.value = value if type(value) is list or type(value) is tuple or value is None \ @@ -330,8 +331,15 @@ class Operand(NamedVariable): extraParams=self.type.extraParams) if not issubclass(DerivedClass, Internal): newop.SetValue(self.value) + if not self.mayBeInternal: + assert not issubclass(DerivedClass, Internal) + newop.ShouldNeverBeInternal() return newop + def ShouldNeverBeInternal(self): + self.mayBeInternal = False + return self + # Base class of user-defined input/output operand class InOut(Operand): @@ -1031,7 +1039,7 @@ class AllInputsAsInternalCoverter(ModelVariation): raise SkipVariation # Find all input tensors that can be an output of the ADD operation. - modelInputs = [i for i in model.GetInputs() if CompatibleWithADD(i)] + modelInputs = [i for i in model.GetInputs() if CompatibleWithADD(i) and i.mayBeInternal] if not modelInputs: raise SkipVariation |