diff options
Diffstat (limited to 'nn/tools')
-rw-r--r-- | nn/tools/test_generator/README.md | 8 | ||||
-rwxr-xr-x | nn/tools/test_generator/test_generator.py | 10 |
2 files changed, 17 insertions, 1 deletions
diff --git a/nn/tools/test_generator/README.md b/nn/tools/test_generator/README.md index 29c1155d2..e781771ce 100644 --- a/nn/tools/test_generator/README.md +++ b/nn/tools/test_generator/README.md @@ -244,6 +244,14 @@ example.DisableLifeTimeVariation() example.DisableDynamicOutputShapeVariation() ``` +You may also specify a certain operand to be input/const-only that `AllInputsAsInternalCoverter` will skip converting this operand. + +```Python +# "hash" will be converted to a model input when applying AllTensorsAsInputsConverter, +# but will be skipped when further applying AllInputsAsInternalCoverter. +hash = Parameter("hash", "TENSOR_FLOAT32", "{1, 1}", [0.123]).ShouldNeverBeInternal() +``` + #### Some helper functions The test generator provides several helper functions or shorthands to add commonly used group of variations. diff --git a/nn/tools/test_generator/test_generator.py b/nn/tools/test_generator/test_generator.py index e1b10a7ad..92dfad756 100755 --- a/nn/tools/test_generator/test_generator.py +++ b/nn/tools/test_generator/test_generator.py @@ -295,6 +295,7 @@ class Operand(NamedVariable): self.model_index = None self.ins = [] self.outs = [] + self.mayBeInternal = True def SetValue(self, value): self.value = value if type(value) is list or type(value) is tuple or value is None \ @@ -330,8 +331,15 @@ class Operand(NamedVariable): extraParams=self.type.extraParams) if not issubclass(DerivedClass, Internal): newop.SetValue(self.value) + if not self.mayBeInternal: + assert not issubclass(DerivedClass, Internal) + newop.ShouldNeverBeInternal() return newop + def ShouldNeverBeInternal(self): + self.mayBeInternal = False + return self + # Base class of user-defined input/output operand class InOut(Operand): @@ -1031,7 +1039,7 @@ class AllInputsAsInternalCoverter(ModelVariation): raise SkipVariation # Find all input tensors that can be an output of the ADD operation. - modelInputs = [i for i in model.GetInputs() if CompatibleWithADD(i)] + modelInputs = [i for i in model.GetInputs() if CompatibleWithADD(i) and i.mayBeInternal] if not modelInputs: raise SkipVariation |