summaryrefslogtreecommitdiff
path: root/nn/runtime/test/TestUnknownDimensions.cpp
diff options
context:
space:
mode:
authorLev Proleev <levp@google.com>2018-11-02 12:34:31 +0000
committerLev Proleev <levp@google.com>2018-11-02 14:02:07 +0000
commitba304056176c602a8b6272a7bb6931ef7ae7fee6 (patch)
tree0badc16af05661e22715b6358feadc163ad48d5c /nn/runtime/test/TestUnknownDimensions.cpp
parent293cbbc7a6b500d80bae583cb01d2ec66ce4dc4a (diff)
downloadml-ba304056176c602a8b6272a7bb6931ef7ae7fee6.tar.gz
Add static asserts to functions using MixedTyped
This makes it easier to find all the places that need to be changed after adding a new type to MixedTyped. Also make CompareResults function generic so that it doesn't cause errors after adding new type to MixedTyped. Test: NeuralNetworksTest_static Test: VtsHalNeuralnetworksV1_2TargetTest Change-Id: Iebd89b703415b22ade56c93361926ceed0611f7d
Diffstat (limited to 'nn/runtime/test/TestUnknownDimensions.cpp')
-rw-r--r--nn/runtime/test/TestUnknownDimensions.cpp38
1 files changed, 18 insertions, 20 deletions
diff --git a/nn/runtime/test/TestUnknownDimensions.cpp b/nn/runtime/test/TestUnknownDimensions.cpp
index ef44c2802..513aa9f72 100644
--- a/nn/runtime/test/TestUnknownDimensions.cpp
+++ b/nn/runtime/test/TestUnknownDimensions.cpp
@@ -77,29 +77,27 @@ auto ioValues = Combine(ioDimensionValues, ioDimensionValues);
auto constantValues = Combine(constantDimensionValues, constantDimensionValues);
class UnknownDimensionsTest : public ::testing::TestWithParam<OperandParams> {
-protected:
- template<class T, Type TensorType> void TestOne(
- const OperandParams& paramsForInput0, const OperandParams& paramsForInput1,
- const OperandParams& paramsForConst, const OperandParams& paramsForOutput);
- template<class T, Type TensorType> void TestAll();
- void CompareResults(std::map<int, std::vector<float>>& expected,
- std::map<int, std::vector<float>>& actual);
- void CompareResults(std::map<int, std::vector<uint8_t>>& expected,
- std::map<int, std::vector<uint8_t>>& actual);
+ protected:
+ template <class T, Type TensorType>
+ void TestOne(const OperandParams& paramsForInput0, const OperandParams& paramsForInput1,
+ const OperandParams& paramsForConst, const OperandParams& paramsForOutput);
+ template <class T, Type TensorType>
+ void TestAll();
+
+ template <typename T>
+ void CompareResults(std::map<int, std::vector<T>>& expected,
+ std::map<int, std::vector<T>>& actual);
};
-void UnknownDimensionsTest::CompareResults(
- std::map<int, std::vector<uint8_t>>& expected,
- std::map<int, std::vector<uint8_t>>& actual) {
+template <typename T>
+void UnknownDimensionsTest::CompareResults(std::map<int, std::vector<T>>& expected,
+ std::map<int, std::vector<T>>& actual) {
// Uint8_t operands last in MixedType
- compare(MixedTyped{ {}, {}, expected }, MixedTyped{ {}, {}, actual });
-}
-
-void UnknownDimensionsTest::CompareResults(
- std::map<int, std::vector<float>>& expected,
- std::map<int, std::vector<float>>& actual) {
- // Float operands first in MixedType
- compare(MixedTyped{ expected, {}, {} }, MixedTyped{ actual, {}, {} });
+ MixedTyped expectedMixedTyped;
+ std::get<MixedTypedIndex<T>::index>(expectedMixedTyped) = expected;
+ MixedTyped actualMixedTyped;
+ std::get<MixedTypedIndex<T>::index>(actualMixedTyped) = actual;
+ compare(expectedMixedTyped, actualMixedTyped);
}
template<class T, Type TensorType> void UnknownDimensionsTest::TestOne(