summaryrefslogtreecommitdiff
path: root/nn/tools/test_generator/test_generator.py
diff options
context:
space:
mode:
Diffstat (limited to 'nn/tools/test_generator/test_generator.py')
-rwxr-xr-xnn/tools/test_generator/test_generator.py10
1 files changed, 9 insertions, 1 deletions
diff --git a/nn/tools/test_generator/test_generator.py b/nn/tools/test_generator/test_generator.py
index e1b10a7ad..92dfad756 100755
--- a/nn/tools/test_generator/test_generator.py
+++ b/nn/tools/test_generator/test_generator.py
@@ -295,6 +295,7 @@ class Operand(NamedVariable):
self.model_index = None
self.ins = []
self.outs = []
+ self.mayBeInternal = True
def SetValue(self, value):
self.value = value if type(value) is list or type(value) is tuple or value is None \
@@ -330,8 +331,15 @@ class Operand(NamedVariable):
extraParams=self.type.extraParams)
if not issubclass(DerivedClass, Internal):
newop.SetValue(self.value)
+ if not self.mayBeInternal:
+ assert not issubclass(DerivedClass, Internal)
+ newop.ShouldNeverBeInternal()
return newop
+ def ShouldNeverBeInternal(self):
+ self.mayBeInternal = False
+ return self
+
# Base class of user-defined input/output operand
class InOut(Operand):
@@ -1031,7 +1039,7 @@ class AllInputsAsInternalCoverter(ModelVariation):
raise SkipVariation
# Find all input tensors that can be an output of the ADD operation.
- modelInputs = [i for i in model.GetInputs() if CompatibleWithADD(i)]
+ modelInputs = [i for i in model.GetInputs() if CompatibleWithADD(i) and i.mayBeInternal]
if not modelInputs:
raise SkipVariation