summaryrefslogtreecommitdiff
path: root/nn/runtime/test/TestGenerated.cpp
diff options
context:
space:
mode:
authorXusong Wang <xusongw@google.com>2019-07-17 15:57:13 -0700
committerXusong Wang <xusongw@google.com>2019-07-19 10:52:43 -0700
commitf1919e85ed072baf0a57535198374982cd477dc8 (patch)
tree2bb2615a8cd132fd228bef2897ed0f5e2ed5326d /nn/runtime/test/TestGenerated.cpp
parent57b3051fef611dbab1fee34e7a0af2c486dabf8f (diff)
downloadml-f1919e85ed072baf0a57535198374982cd477dc8.tar.gz
Enable generated validation tests.
Also move strided_slice_invalid_output_dims to spec directory. Tests tagged by ExpectFailure are expected to fail gracefully during compilation or execution. Updated the documentation. Bug: 132155416 Test: NNT_static Test: 1.2 VTS Change-Id: I1f747cf289e78fdd6e04c7d8e51b71468d72237c
Diffstat (limited to 'nn/runtime/test/TestGenerated.cpp')
-rw-r--r--nn/runtime/test/TestGenerated.cpp51
1 files changed, 36 insertions, 15 deletions
diff --git a/nn/runtime/test/TestGenerated.cpp b/nn/runtime/test/TestGenerated.cpp
index d259bb692..5f663e3b0 100644
--- a/nn/runtime/test/TestGenerated.cpp
+++ b/nn/runtime/test/TestGenerated.cpp
@@ -83,21 +83,29 @@ void printAll(std::ostream& os, const MixedTyped& test) {
}
} // namespace
-Compilation GeneratedTests::compileModel(const Model* model) {
+std::optional<Compilation> GeneratedTests::compileModel(const Model* model) {
NNTRACE_APP(NNTRACE_PHASE_COMPILATION, "compileModel");
if (mTestCompilationCaching) {
// Compile the model twice with the same token, so that compilation caching will be
// exercised if supported by the driver.
+ // No invalid model will be passed to this branch.
+ EXPECT_FALSE(mExpectFailure);
Compilation compilation1(model);
- compilation1.setCaching(mCacheDir, mToken);
- compilation1.finish();
+ EXPECT_EQ(compilation1.setCaching(mCacheDir, mToken), Result::NO_ERROR);
+ EXPECT_EQ(compilation1.finish(), Result::NO_ERROR);
Compilation compilation2(model);
- compilation2.setCaching(mCacheDir, mToken);
- compilation2.finish();
+ EXPECT_EQ(compilation2.setCaching(mCacheDir, mToken), Result::NO_ERROR);
+ EXPECT_EQ(compilation2.finish(), Result::NO_ERROR);
return compilation2;
} else {
Compilation compilation(model);
- compilation.finish();
+ Result result = compilation.finish();
+
+ // For valid model, we check the compilation result == NO_ERROR.
+ // For invalid model, the driver may fail at compilation or execution, so any result code is
+ // permitted at this point.
+ if (mExpectFailure && result != Result::NO_ERROR) return std::nullopt;
+ EXPECT_EQ(result, Result::NO_ERROR);
return compilation;
}
}
@@ -151,8 +159,14 @@ void GeneratedTests::executeWithCompilation(const Model* model, Compilation* com
});
}
- Result r = execution.compute();
- ASSERT_EQ(Result::NO_ERROR, r);
+ Result result = execution.compute();
+ if (mExpectFailure) {
+ ASSERT_NE(result, Result::NO_ERROR);
+ continue;
+ } else {
+ ASSERT_EQ(result, Result::NO_ERROR);
+ }
+
{
NNTRACE_APP(NNTRACE_PHASE_RESULTS, "executeWithCompilation example");
@@ -188,8 +202,10 @@ void GeneratedTests::executeWithCompilation(const Model* model, Compilation* com
void GeneratedTests::executeOnce(const Model* model, std::function<bool(int)> isIgnored,
std::vector<MixedTypedExample>& examples, std::string dumpFile) {
NNTRACE_APP(NNTRACE_PHASE_OVERALL, "executeOnce");
- Compilation compilation = compileModel(model);
- executeWithCompilation(model, &compilation, isIgnored, examples, dumpFile);
+ std::optional<Compilation> compilation = compileModel(model);
+ // Early return if compilation fails. The compilation result code is checked in compileModel.
+ if (!compilation) return;
+ executeWithCompilation(model, &compilation.value(), isIgnored, examples, dumpFile);
}
void GeneratedTests::executeMultithreadedOwnCompilation(const Model* model,
@@ -209,11 +225,14 @@ void GeneratedTests::executeMultithreadedSharedCompilation(
std::vector<MixedTypedExample>& examples) {
NNTRACE_APP(NNTRACE_PHASE_OVERALL, "executeMultithreadedSharedCompilation");
SCOPED_TRACE("MultithreadedSharedCompilation");
- Compilation compilation = compileModel(model);
+ std::optional<Compilation> compilation = compileModel(model);
+ // Early return if compilation fails. The ompilation result code is checked in compileModel.
+ if (!compilation) return;
std::vector<std::thread> threads;
for (int i = 0; i < 10; i++) {
- threads.push_back(std::thread(
- [&]() { executeWithCompilation(model, &compilation, isIgnored, examples, ""); }));
+ threads.push_back(std::thread([&]() {
+ executeWithCompilation(model, &compilation.value(), isIgnored, examples, "");
+ }));
}
std::for_each(threads.begin(), threads.end(), [](std::thread& t) { t.join(); });
}
@@ -239,8 +258,10 @@ void GeneratedTests::execute(std::function<void(Model*)> createModel,
};
mTestCompilationCaching = false;
executeInternal(dumpFile);
- mTestCompilationCaching = true;
- executeInternal("");
+ if (!mExpectFailure) {
+ mTestCompilationCaching = true;
+ executeInternal("");
+ }
}
void GeneratedTests::SetUp() {