diff options
author | Jean-Luc Brouillet <jeanluc@google.com> | 2017-01-08 17:35:31 -0800 |
---|---|---|
committer | Jean-Luc Brouillet <jeanluc@google.com> | 2017-01-09 11:15:16 -0800 |
commit | 2a85b6b9d6f9cb8f1b20d573c1c5ceafe901b011 (patch) | |
tree | 038e6e8d5847568e0b560ff12ebee5f85d670542 /script_api/GenerateTestFiles.cpp | |
parent | 7f1125d183917d99123367eba3c8393bfde58a20 (diff) | |
download | rs-2a85b6b9d6f9cb8f1b20d573c1c5ceafe901b011.tar.gz |
Move scriptc to script_api/include.
Part 1 of the directory re-organization. We're renaming the "api" directory
to "script_api" directory to distinguish between our control api (java or c++)
and our script api.
We're also moving the scriptc directory under that newly renamed directory,
and change its name to the more appropriate "include".
Test: scriptc/generate.sh
Test: compiled ImageProcessing_jb
Change-Id: I00c3bbf5728b652d1541ebe4123717f6ab639f09
Diffstat (limited to 'script_api/GenerateTestFiles.cpp')
-rw-r--r-- | script_api/GenerateTestFiles.cpp | 1142 |
1 files changed, 1142 insertions, 0 deletions
diff --git a/script_api/GenerateTestFiles.cpp b/script_api/GenerateTestFiles.cpp new file mode 100644 index 00000000..3c288013 --- /dev/null +++ b/script_api/GenerateTestFiles.cpp @@ -0,0 +1,1142 @@ +/* + * Copyright (C) 2015 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <iomanip> +#include <iostream> +#include <cmath> +#include <sstream> + +#include "Generator.h" +#include "Specification.h" +#include "Utilities.h" + +using namespace std; + +// Converts float2 to FLOAT_32 and 2, etc. +static void convertToRsType(const string& name, string* dataType, char* vectorSize) { + string s = name; + int last = s.size() - 1; + char lastChar = s[last]; + if (lastChar >= '1' && lastChar <= '4') { + s.erase(last); + *vectorSize = lastChar; + } else { + *vectorSize = '1'; + } + dataType->clear(); + for (int i = 0; i < NUM_TYPES; i++) { + if (s == TYPES[i].cType) { + *dataType = TYPES[i].rsDataType; + break; + } + } +} + +// Returns true if any permutation of the function have tests to b +static bool needTestFiles(const Function& function, unsigned int versionOfTestFiles) { + for (auto spec : function.getSpecifications()) { + if (spec->hasTests(versionOfTestFiles)) { + return true; + } + } + return false; +} + +/* One instance of this class is generated for each permutation of a function for which + * we are generating test code. This instance will generate both the script and the Java + * section of the test files for this permutation. The class is mostly used to keep track + * of the various names shared between script and Java files. + * WARNING: Because the constructor keeps a reference to the FunctionPermutation, PermutationWriter + * should not exceed the lifetime of FunctionPermutation. + */ +class PermutationWriter { +private: + FunctionPermutation& mPermutation; + + string mRsKernelName; + string mJavaArgumentsClassName; + string mJavaArgumentsNClassName; + string mJavaVerifierComputeMethodName; + string mJavaVerifierVerifyMethodName; + string mJavaCheckMethodName; + string mJavaVerifyMethodName; + + // Pointer to the files we are generating. Handy to avoid always passing them in the calls. + GeneratedFile* mRs; + GeneratedFile* mJava; + + /* Shortcuts to the return parameter and the first input parameter of the function + * specification. + */ + const ParameterDefinition* mReturnParam; // Can be nullptr. NOT OWNED. + const ParameterDefinition* mFirstInputParam; // Can be nullptr. NOT OWNED. + + /* All the parameters plus the return param, if present. Collecting them together + * simplifies code generation. NOT OWNED. + */ + vector<const ParameterDefinition*> mAllInputsAndOutputs; + + /* We use a class to pass the arguments between the generated code and the CoreVerifier. This + * method generates this class. The set keeps track if we've generated this class already + * for this test file, as more than one permutation may use the same argument class. + */ + void writeJavaArgumentClass(bool scalar, set<string>* javaGeneratedArgumentClasses) const; + + // Generate the Check* method that invokes the script and calls the verifier. + void writeJavaCheckMethod(bool generateCallToVerifier) const; + + // Generate code to define and randomly initialize the input allocation. + void writeJavaInputAllocationDefinition(const ParameterDefinition& param) const; + + /* Generate code that instantiate an allocation of floats or integers and fills it with + * random data. This random data must be compatible with the specified type. This is + * used for the convert_* tests, as converting values that don't fit yield undefined results. + */ + void writeJavaRandomCompatibleFloatAllocation(const string& dataType, const string& seed, + char vectorSize, + const NumericalType& compatibleType, + const NumericalType& generatedType) const; + void writeJavaRandomCompatibleIntegerAllocation(const string& dataType, const string& seed, + char vectorSize, + const NumericalType& compatibleType, + const NumericalType& generatedType) const; + + // Generate code that defines an output allocation. + void writeJavaOutputAllocationDefinition(const ParameterDefinition& param) const; + + /* Generate the code that verifies the results for RenderScript functions where each entry + * of a vector is evaluated independently. If verifierValidates is true, CoreMathVerifier + * does the actual validation instead of more commonly returning the range of acceptable values. + */ + void writeJavaVerifyScalarMethod(bool verifierValidates) const; + + /* Generate the code that verify the results for a RenderScript function where a vector + * is a point in n-dimensional space. + */ + void writeJavaVerifyVectorMethod() const; + + // Generate the line that creates the Target. + void writeJavaCreateTarget() const; + + // Generate the method header of the verify function. + void writeJavaVerifyMethodHeader() const; + + // Generate codes that copies the content of an allocation to an array. + void writeJavaArrayInitialization(const ParameterDefinition& p) const; + + // Generate code that tests one value returned from the script. + void writeJavaTestAndSetValid(const ParameterDefinition& p, const string& argsIndex, + const string& actualIndex) const; + void writeJavaTestOneValue(const ParameterDefinition& p, const string& argsIndex, + const string& actualIndex) const; + // For test:vector cases, generate code that compares returned vector vs. expected value. + void writeJavaVectorComparison(const ParameterDefinition& p) const; + + // Muliple functions that generates code to build the error message if an error is found. + void writeJavaAppendOutputToMessage(const ParameterDefinition& p, const string& argsIndex, + const string& actualIndex, bool verifierValidates) const; + void writeJavaAppendInputToMessage(const ParameterDefinition& p, const string& actual) const; + void writeJavaAppendNewLineToMessage() const; + void writeJavaAppendVectorInputToMessage(const ParameterDefinition& p) const; + void writeJavaAppendVectorOutputToMessage(const ParameterDefinition& p) const; + + // Generate the set of instructions to call the script. + void writeJavaCallToRs(bool relaxed, bool generateCallToVerifier) const; + + // Write an allocation definition if not already emitted in the .rs file. + void writeRsAllocationDefinition(const ParameterDefinition& param, + set<string>* rsAllocationsGenerated) const; + +public: + /* NOTE: We keep pointers to the permutation and the files. This object should not + * outlive the arguments. + */ + PermutationWriter(FunctionPermutation& permutation, GeneratedFile* rsFile, + GeneratedFile* javaFile); + string getJavaCheckMethodName() const { return mJavaCheckMethodName; } + + // Write the script test function for this permutation. + void writeRsSection(set<string>* rsAllocationsGenerated) const; + // Write the section of the Java code that calls the script and validates the results + void writeJavaSection(set<string>* javaGeneratedArgumentClasses) const; +}; + +PermutationWriter::PermutationWriter(FunctionPermutation& permutation, GeneratedFile* rsFile, + GeneratedFile* javaFile) + : mPermutation(permutation), + mRs(rsFile), + mJava(javaFile), + mReturnParam(nullptr), + mFirstInputParam(nullptr) { + mRsKernelName = "test" + capitalize(permutation.getName()); + + mJavaArgumentsClassName = "Arguments"; + mJavaArgumentsNClassName = "Arguments"; + const string trunk = capitalize(permutation.getNameTrunk()); + mJavaCheckMethodName = "check" + trunk; + mJavaVerifyMethodName = "verifyResults" + trunk; + + for (auto p : permutation.getParams()) { + mAllInputsAndOutputs.push_back(p); + if (mFirstInputParam == nullptr && !p->isOutParameter) { + mFirstInputParam = p; + } + } + mReturnParam = permutation.getReturn(); + if (mReturnParam) { + mAllInputsAndOutputs.push_back(mReturnParam); + } + + for (auto p : mAllInputsAndOutputs) { + const string capitalizedRsType = capitalize(p->rsType); + const string capitalizedBaseType = capitalize(p->rsBaseType); + mRsKernelName += capitalizedRsType; + mJavaArgumentsClassName += capitalizedBaseType; + mJavaArgumentsNClassName += capitalizedBaseType; + if (p->mVectorSize != "1") { + mJavaArgumentsNClassName += "N"; + } + mJavaCheckMethodName += capitalizedRsType; + mJavaVerifyMethodName += capitalizedRsType; + } + mJavaVerifierComputeMethodName = "compute" + trunk; + mJavaVerifierVerifyMethodName = "verify" + trunk; +} + +void PermutationWriter::writeJavaSection(set<string>* javaGeneratedArgumentClasses) const { + // By default, we test the results using item by item comparison. + const string test = mPermutation.getTest(); + if (test == "scalar" || test == "limited") { + writeJavaArgumentClass(true, javaGeneratedArgumentClasses); + writeJavaCheckMethod(true); + writeJavaVerifyScalarMethod(false); + } else if (test == "custom") { + writeJavaArgumentClass(true, javaGeneratedArgumentClasses); + writeJavaCheckMethod(true); + writeJavaVerifyScalarMethod(true); + } else if (test == "vector") { + writeJavaArgumentClass(false, javaGeneratedArgumentClasses); + writeJavaCheckMethod(true); + writeJavaVerifyVectorMethod(); + } else if (test == "noverify") { + writeJavaCheckMethod(false); + } +} + +void PermutationWriter::writeJavaArgumentClass(bool scalar, + set<string>* javaGeneratedArgumentClasses) const { + string name; + if (scalar) { + name = mJavaArgumentsClassName; + } else { + name = mJavaArgumentsNClassName; + } + + // Make sure we have not generated the argument class already. + if (!testAndSet(name, javaGeneratedArgumentClasses)) { + mJava->indent() << "public class " << name; + mJava->startBlock(); + + for (auto p : mAllInputsAndOutputs) { + bool isFieldArray = !scalar && p->mVectorSize != "1"; + bool isFloatyField = p->isOutParameter && p->isFloatType && mPermutation.getTest() != "custom"; + + mJava->indent() << "public "; + if (isFloatyField) { + *mJava << "Target.Floaty"; + } else { + *mJava << p->javaBaseType; + } + if (isFieldArray) { + *mJava << "[]"; + } + *mJava << " " << p->variableName << ";\n"; + + // For Float16 parameters, add an extra 'double' field in the class + // to hold the Double value converted from the input. + if (p->isFloat16Parameter() && !isFloatyField) { + mJava->indent() << "public double"; + if (isFieldArray) { + *mJava << "[]"; + } + *mJava << " " + p->variableName << "Double;\n"; + } + } + mJava->endBlock(); + *mJava << "\n"; + } +} + +void PermutationWriter::writeJavaCheckMethod(bool generateCallToVerifier) const { + mJava->indent() << "private void " << mJavaCheckMethodName << "()"; + mJava->startBlock(); + + // Generate the input allocations and initialization. + for (auto p : mAllInputsAndOutputs) { + if (!p->isOutParameter) { + writeJavaInputAllocationDefinition(*p); + } + } + // Generate code to enforce ordering between two allocations if needed. + for (auto p : mAllInputsAndOutputs) { + if (!p->isOutParameter && !p->smallerParameter.empty()) { + string smallerAlloc = "in" + capitalize(p->smallerParameter); + mJava->indent() << "enforceOrdering(" << smallerAlloc << ", " << p->javaAllocName + << ");\n"; + } + } + + // Generate code to check the full and relaxed scripts. + writeJavaCallToRs(false, generateCallToVerifier); + writeJavaCallToRs(true, generateCallToVerifier); + + mJava->endBlock(); + *mJava << "\n"; +} + +void PermutationWriter::writeJavaInputAllocationDefinition(const ParameterDefinition& param) const { + string dataType; + char vectorSize; + convertToRsType(param.rsType, &dataType, &vectorSize); + + const string seed = hashString(mJavaCheckMethodName + param.javaAllocName); + mJava->indent() << "Allocation " << param.javaAllocName << " = "; + if (param.compatibleTypeIndex >= 0) { + if (TYPES[param.typeIndex].kind == FLOATING_POINT) { + writeJavaRandomCompatibleFloatAllocation(dataType, seed, vectorSize, + TYPES[param.compatibleTypeIndex], + TYPES[param.typeIndex]); + } else { + writeJavaRandomCompatibleIntegerAllocation(dataType, seed, vectorSize, + TYPES[param.compatibleTypeIndex], + TYPES[param.typeIndex]); + } + } else if (!param.minValue.empty()) { + *mJava << "createRandomFloatAllocation(mRS, Element.DataType." << dataType << ", " + << vectorSize << ", " << seed << ", " << param.minValue << ", " << param.maxValue + << ")"; + } else { + /* TODO Instead of passing always false, check whether we are doing a limited test. + * Use instead: (mPermutation.getTest() == "limited" ? "false" : "true") + */ + *mJava << "createRandomAllocation(mRS, Element.DataType." << dataType << ", " << vectorSize + << ", " << seed << ", false)"; + } + *mJava << ";\n"; +} + +void PermutationWriter::writeJavaRandomCompatibleFloatAllocation( + const string& dataType, const string& seed, char vectorSize, + const NumericalType& compatibleType, const NumericalType& generatedType) const { + *mJava << "createRandomFloatAllocation" + << "(mRS, Element.DataType." << dataType << ", " << vectorSize << ", " << seed << ", "; + double minValue = 0.0; + double maxValue = 0.0; + switch (compatibleType.kind) { + case FLOATING_POINT: { + // We're generating floating point values. We just worry about the exponent. + // Subtract 1 for the exponent sign. + int bits = min(compatibleType.exponentBits, generatedType.exponentBits) - 1; + maxValue = ldexp(0.95, (1 << bits) - 1); + minValue = -maxValue; + break; + } + case UNSIGNED_INTEGER: + maxValue = maxDoubleForInteger(compatibleType.significantBits, + generatedType.significantBits); + minValue = 0.0; + break; + case SIGNED_INTEGER: + maxValue = maxDoubleForInteger(compatibleType.significantBits, + generatedType.significantBits); + minValue = -maxValue - 1.0; + break; + } + *mJava << scientific << std::setprecision(19); + *mJava << minValue << ", " << maxValue << ")"; + mJava->unsetf(ios_base::floatfield); +} + +void PermutationWriter::writeJavaRandomCompatibleIntegerAllocation( + const string& dataType, const string& seed, char vectorSize, + const NumericalType& compatibleType, const NumericalType& generatedType) const { + *mJava << "createRandomIntegerAllocation" + << "(mRS, Element.DataType." << dataType << ", " << vectorSize << ", " << seed << ", "; + + if (compatibleType.kind == FLOATING_POINT) { + // Currently, all floating points can take any number we generate. + bool isSigned = generatedType.kind == SIGNED_INTEGER; + *mJava << (isSigned ? "true" : "false") << ", " << generatedType.significantBits; + } else { + bool isSigned = + compatibleType.kind == SIGNED_INTEGER && generatedType.kind == SIGNED_INTEGER; + *mJava << (isSigned ? "true" : "false") << ", " + << min(compatibleType.significantBits, generatedType.significantBits); + } + *mJava << ")"; +} + +void PermutationWriter::writeJavaOutputAllocationDefinition( + const ParameterDefinition& param) const { + string dataType; + char vectorSize; + convertToRsType(param.rsType, &dataType, &vectorSize); + mJava->indent() << "Allocation " << param.javaAllocName << " = Allocation.createSized(mRS, " + << "getElement(mRS, Element.DataType." << dataType << ", " << vectorSize + << "), INPUTSIZE);\n"; +} + +void PermutationWriter::writeJavaVerifyScalarMethod(bool verifierValidates) const { + writeJavaVerifyMethodHeader(); + mJava->startBlock(); + + string vectorSize = "1"; + for (auto p : mAllInputsAndOutputs) { + writeJavaArrayInitialization(*p); + if (p->mVectorSize != "1" && p->mVectorSize != vectorSize) { + if (vectorSize == "1") { + vectorSize = p->mVectorSize; + } else { + cerr << "Error. Had vector " << vectorSize << " and " << p->mVectorSize << "\n"; + } + } + } + + mJava->indent() << "StringBuilder message = new StringBuilder();\n"; + mJava->indent() << "boolean errorFound = false;\n"; + mJava->indent() << "for (int i = 0; i < INPUTSIZE; i++)"; + mJava->startBlock(); + + mJava->indent() << "for (int j = 0; j < " << vectorSize << " ; j++)"; + mJava->startBlock(); + + mJava->indent() << "// Extract the inputs.\n"; + mJava->indent() << mJavaArgumentsClassName << " args = new " << mJavaArgumentsClassName + << "();\n"; + for (auto p : mAllInputsAndOutputs) { + if (!p->isOutParameter) { + mJava->indent() << "args." << p->variableName << " = " << p->javaArrayName << "[i"; + if (p->vectorWidth != "1") { + *mJava << " * " << p->vectorWidth << " + j"; + } + *mJava << "];\n"; + + // Convert the Float16 parameter to double and store it in the appropriate field in the + // Arguments class. + if (p->isFloat16Parameter()) { + mJava->indent() << "args." << p->doubleVariableName + << " = Float16Utils.convertFloat16ToDouble(args." + << p->variableName << ");\n"; + } + } + } + const bool hasFloat = mPermutation.hasFloatAnswers(); + if (verifierValidates) { + mJava->indent() << "// Extract the outputs.\n"; + for (auto p : mAllInputsAndOutputs) { + if (p->isOutParameter) { + mJava->indent() << "args." << p->variableName << " = " << p->javaArrayName + << "[i * " << p->vectorWidth << " + j];\n"; + if (p->isFloat16Parameter()) { + mJava->indent() << "args." << p->doubleVariableName + << " = Float16Utils.convertFloat16ToDouble(args." + << p->variableName << ");\n"; + } + } + } + mJava->indent() << "// Ask the CoreMathVerifier to validate.\n"; + if (hasFloat) { + writeJavaCreateTarget(); + } + mJava->indent() << "String errorMessage = CoreMathVerifier." + << mJavaVerifierVerifyMethodName << "(args"; + if (hasFloat) { + *mJava << ", target"; + } + *mJava << ");\n"; + mJava->indent() << "boolean valid = errorMessage == null;\n"; + } else { + mJava->indent() << "// Figure out what the outputs should have been.\n"; + if (hasFloat) { + writeJavaCreateTarget(); + } + mJava->indent() << "CoreMathVerifier." << mJavaVerifierComputeMethodName << "(args"; + if (hasFloat) { + *mJava << ", target"; + } + *mJava << ");\n"; + mJava->indent() << "// Validate the outputs.\n"; + mJava->indent() << "boolean valid = true;\n"; + for (auto p : mAllInputsAndOutputs) { + if (p->isOutParameter) { + writeJavaTestAndSetValid(*p, "", "[i * " + p->vectorWidth + " + j]"); + } + } + } + + mJava->indent() << "if (!valid)"; + mJava->startBlock(); + mJava->indent() << "if (!errorFound)"; + mJava->startBlock(); + mJava->indent() << "errorFound = true;\n"; + + for (auto p : mAllInputsAndOutputs) { + if (p->isOutParameter) { + writeJavaAppendOutputToMessage(*p, "", "[i * " + p->vectorWidth + " + j]", + verifierValidates); + } else { + writeJavaAppendInputToMessage(*p, "args." + p->variableName); + } + } + if (verifierValidates) { + mJava->indent() << "message.append(errorMessage);\n"; + } + mJava->indent() << "message.append(\"Errors at\");\n"; + mJava->endBlock(); + + mJava->indent() << "message.append(\" [\");\n"; + mJava->indent() << "message.append(Integer.toString(i));\n"; + mJava->indent() << "message.append(\", \");\n"; + mJava->indent() << "message.append(Integer.toString(j));\n"; + mJava->indent() << "message.append(\"]\");\n"; + + mJava->endBlock(); + mJava->endBlock(); + mJava->endBlock(); + + mJava->indent() << "assertFalse(\"Incorrect output for " << mJavaCheckMethodName << "\" +\n"; + mJava->indentPlus() + << "(relaxed ? \"_relaxed\" : \"\") + \":\\n\" + message.toString(), errorFound);\n"; + + mJava->endBlock(); + *mJava << "\n"; +} + +void PermutationWriter::writeJavaVerifyVectorMethod() const { + writeJavaVerifyMethodHeader(); + mJava->startBlock(); + + for (auto p : mAllInputsAndOutputs) { + writeJavaArrayInitialization(*p); + } + mJava->indent() << "StringBuilder message = new StringBuilder();\n"; + mJava->indent() << "boolean errorFound = false;\n"; + mJava->indent() << "for (int i = 0; i < INPUTSIZE; i++)"; + mJava->startBlock(); + + mJava->indent() << mJavaArgumentsNClassName << " args = new " << mJavaArgumentsNClassName + << "();\n"; + + mJava->indent() << "// Create the appropriate sized arrays in args\n"; + for (auto p : mAllInputsAndOutputs) { + if (p->mVectorSize != "1") { + string type = p->javaBaseType; + if (p->isOutParameter && p->isFloatType) { + type = "Target.Floaty"; + } + mJava->indent() << "args." << p->variableName << " = new " << type << "[" + << p->mVectorSize << "];\n"; + if (p->isFloat16Parameter() && !p->isOutParameter) { + mJava->indent() << "args." << p->variableName << "Double = new double[" + << p->mVectorSize << "];\n"; + } + } + } + + mJava->indent() << "// Fill args with the input values\n"; + for (auto p : mAllInputsAndOutputs) { + if (!p->isOutParameter) { + if (p->mVectorSize == "1") { + mJava->indent() << "args." << p->variableName << " = " << p->javaArrayName << "[i]" + << ";\n"; + // Convert the Float16 parameter to double and store it in the appropriate field in + // the Arguments class. + if (p->isFloat16Parameter()) { + mJava->indent() << "args." << p->doubleVariableName << " = " + << "Float16Utils.convertFloat16ToDouble(args." + << p->variableName << ");\n"; + } + } else { + mJava->indent() << "for (int j = 0; j < " << p->mVectorSize << " ; j++)"; + mJava->startBlock(); + mJava->indent() << "args." << p->variableName << "[j] = " + << p->javaArrayName << "[i * " << p->vectorWidth << " + j]" + << ";\n"; + + // Convert the Float16 parameter to double and store it in the appropriate field in + // the Arguments class. + if (p->isFloat16Parameter()) { + mJava->indent() << "args." << p->doubleVariableName << "[j] = " + << "Float16Utils.convertFloat16ToDouble(args." + << p->variableName << "[j]);\n"; + } + mJava->endBlock(); + } + } + } + writeJavaCreateTarget(); + mJava->indent() << "CoreMathVerifier." << mJavaVerifierComputeMethodName + << "(args, target);\n\n"; + + mJava->indent() << "// Compare the expected outputs to the actual values returned by RS.\n"; + mJava->indent() << "boolean valid = true;\n"; + for (auto p : mAllInputsAndOutputs) { + if (p->isOutParameter) { + writeJavaVectorComparison(*p); + } + } + + mJava->indent() << "if (!valid)"; + mJava->startBlock(); + mJava->indent() << "if (!errorFound)"; + mJava->startBlock(); + mJava->indent() << "errorFound = true;\n"; + + for (auto p : mAllInputsAndOutputs) { + if (p->isOutParameter) { + writeJavaAppendVectorOutputToMessage(*p); + } else { + writeJavaAppendVectorInputToMessage(*p); + } + } + mJava->indent() << "message.append(\"Errors at\");\n"; + mJava->endBlock(); + + mJava->indent() << "message.append(\" [\");\n"; + mJava->indent() << "message.append(Integer.toString(i));\n"; + mJava->indent() << "message.append(\"]\");\n"; + + mJava->endBlock(); + mJava->endBlock(); + + mJava->indent() << "assertFalse(\"Incorrect output for " << mJavaCheckMethodName << "\" +\n"; + mJava->indentPlus() + << "(relaxed ? \"_relaxed\" : \"\") + \":\\n\" + message.toString(), errorFound);\n"; + + mJava->endBlock(); + *mJava << "\n"; +} + + +void PermutationWriter::writeJavaCreateTarget() const { + string name = mPermutation.getName(); + + const char* functionType = "NORMAL"; + size_t end = name.find('_'); + if (end != string::npos) { + if (name.compare(0, end, "native") == 0) { + functionType = "NATIVE"; + } else if (name.compare(0, end, "half") == 0) { + functionType = "HALF"; + } else if (name.compare(0, end, "fast") == 0) { + functionType = "FAST"; + } + } + + string floatType = mReturnParam->specType; + const char* precisionStr = ""; + if (floatType.compare("f16") == 0) { + precisionStr = "HALF"; + } else if (floatType.compare("f32") == 0) { + precisionStr = "FLOAT"; + } else if (floatType.compare("f64") == 0) { + precisionStr = "DOUBLE"; + } else { + cerr << "Error. Unreachable. Return type is not floating point\n"; + } + + mJava->indent() << "Target target = new Target(Target.FunctionType." << + functionType << ", Target.ReturnType." << precisionStr << + ", relaxed);\n"; +} + +void PermutationWriter::writeJavaVerifyMethodHeader() const { + mJava->indent() << "private void " << mJavaVerifyMethodName << "("; + for (auto p : mAllInputsAndOutputs) { + *mJava << "Allocation " << p->javaAllocName << ", "; + } + *mJava << "boolean relaxed)"; +} + +void PermutationWriter::writeJavaArrayInitialization(const ParameterDefinition& p) const { + mJava->indent() << p.javaBaseType << "[] " << p.javaArrayName << " = new " << p.javaBaseType + << "[INPUTSIZE * " << p.vectorWidth << "];\n"; + + /* For basic types, populate the array with values, to help understand failures. We have had + * bugs where the output buffer was all 0. We were not sure if there was a failed copy or + * the GPU driver was copying zeroes. + */ + if (p.typeIndex >= 0) { + mJava->indent() << "Arrays.fill(" << p.javaArrayName << ", (" << TYPES[p.typeIndex].javaType + << ") 42);\n"; + } + + mJava->indent() << p.javaAllocName << ".copyTo(" << p.javaArrayName << ");\n"; +} + +void PermutationWriter::writeJavaTestAndSetValid(const ParameterDefinition& p, + const string& argsIndex, + const string& actualIndex) const { + writeJavaTestOneValue(p, argsIndex, actualIndex); + mJava->startBlock(); + mJava->indent() << "valid = false;\n"; + mJava->endBlock(); +} + +void PermutationWriter::writeJavaTestOneValue(const ParameterDefinition& p, const string& argsIndex, + const string& actualIndex) const { + string actualOut; + if (p.isFloat16Parameter()) { + // For Float16 values, the output needs to be converted to Double. + actualOut = "Float16Utils.convertFloat16ToDouble(" + p.javaArrayName + actualIndex + ")"; + } else { + actualOut = p.javaArrayName + actualIndex; + } + + mJava->indent() << "if ("; + if (p.isFloatType) { + *mJava << "!args." << p.variableName << argsIndex << ".couldBe(" << actualOut; + const string s = mPermutation.getPrecisionLimit(); + if (!s.empty()) { + *mJava << ", " << s; + } + *mJava << ")"; + } else { + *mJava << "args." << p.variableName << argsIndex << " != " << p.javaArrayName + << actualIndex; + } + + if (p.undefinedIfOutIsNan && mReturnParam) { + *mJava << " && !args." << mReturnParam->variableName << argsIndex << ".isNaN()"; + } + *mJava << ")"; +} + +void PermutationWriter::writeJavaVectorComparison(const ParameterDefinition& p) const { + if (p.mVectorSize == "1") { + writeJavaTestAndSetValid(p, "", "[i]"); + } else { + mJava->indent() << "for (int j = 0; j < " << p.mVectorSize << " ; j++)"; + mJava->startBlock(); + writeJavaTestAndSetValid(p, "[j]", "[i * " + p.vectorWidth + " + j]"); + mJava->endBlock(); + } +} + +void PermutationWriter::writeJavaAppendOutputToMessage(const ParameterDefinition& p, + const string& argsIndex, + const string& actualIndex, + bool verifierValidates) const { + if (verifierValidates) { + mJava->indent() << "message.append(\"Output " << p.variableName << ": \");\n"; + mJava->indent() << "appendVariableToMessage(message, args." << p.variableName << argsIndex + << ");\n"; + writeJavaAppendNewLineToMessage(); + if (p.isFloat16Parameter()) { + writeJavaAppendNewLineToMessage(); + mJava->indent() << "message.append(\"Output " << p.variableName + << " (in double): \");\n"; + mJava->indent() << "appendVariableToMessage(message, args." << p.doubleVariableName + << ");\n"; + writeJavaAppendNewLineToMessage(); + } + } else { + mJava->indent() << "message.append(\"Expected output " << p.variableName << ": \");\n"; + mJava->indent() << "appendVariableToMessage(message, args." << p.variableName << argsIndex + << ");\n"; + writeJavaAppendNewLineToMessage(); + + mJava->indent() << "message.append(\"Actual output " << p.variableName << ": \");\n"; + mJava->indent() << "appendVariableToMessage(message, " << p.javaArrayName << actualIndex + << ");\n"; + + if (p.isFloat16Parameter()) { + writeJavaAppendNewLineToMessage(); + mJava->indent() << "message.append(\"Actual output " << p.variableName + << " (in double): \");\n"; + mJava->indent() << "appendVariableToMessage(message, Float16Utils.convertFloat16ToDouble(" + << p.javaArrayName << actualIndex << "));\n"; + } + + writeJavaTestOneValue(p, argsIndex, actualIndex); + mJava->startBlock(); + mJava->indent() << "message.append(\" FAIL\");\n"; + mJava->endBlock(); + writeJavaAppendNewLineToMessage(); + } +} + +void PermutationWriter::writeJavaAppendInputToMessage(const ParameterDefinition& p, + const string& actual) const { + mJava->indent() << "message.append(\"Input " << p.variableName << ": \");\n"; + mJava->indent() << "appendVariableToMessage(message, " << actual << ");\n"; + writeJavaAppendNewLineToMessage(); +} + +void PermutationWriter::writeJavaAppendNewLineToMessage() const { + mJava->indent() << "message.append(\"\\n\");\n"; +} + +void PermutationWriter::writeJavaAppendVectorInputToMessage(const ParameterDefinition& p) const { + if (p.mVectorSize == "1") { + writeJavaAppendInputToMessage(p, p.javaArrayName + "[i]"); + } else { + mJava->indent() << "for (int j = 0; j < " << p.mVectorSize << " ; j++)"; + mJava->startBlock(); + writeJavaAppendInputToMessage(p, p.javaArrayName + "[i * " + p.vectorWidth + " + j]"); + mJava->endBlock(); + } +} + +void PermutationWriter::writeJavaAppendVectorOutputToMessage(const ParameterDefinition& p) const { + if (p.mVectorSize == "1") { + writeJavaAppendOutputToMessage(p, "", "[i]", false); + } else { + mJava->indent() << "for (int j = 0; j < " << p.mVectorSize << " ; j++)"; + mJava->startBlock(); + writeJavaAppendOutputToMessage(p, "[j]", "[i * " + p.vectorWidth + " + j]", false); + mJava->endBlock(); + } +} + +void PermutationWriter::writeJavaCallToRs(bool relaxed, bool generateCallToVerifier) const { + string script = "script"; + if (relaxed) { + script += "Relaxed"; + } + + mJava->indent() << "try"; + mJava->startBlock(); + + for (auto p : mAllInputsAndOutputs) { + if (p->isOutParameter) { + writeJavaOutputAllocationDefinition(*p); + } + } + + for (auto p : mPermutation.getParams()) { + if (p != mFirstInputParam) { + mJava->indent() << script << ".set_" << p->rsAllocName << "(" << p->javaAllocName + << ");\n"; + } + } + + mJava->indent() << script << ".forEach_" << mRsKernelName << "("; + bool needComma = false; + if (mFirstInputParam) { + *mJava << mFirstInputParam->javaAllocName; + needComma = true; + } + if (mReturnParam) { + if (needComma) { + *mJava << ", "; + } + *mJava << mReturnParam->variableName << ");\n"; + } + + if (generateCallToVerifier) { + mJava->indent() << mJavaVerifyMethodName << "("; + for (auto p : mAllInputsAndOutputs) { + *mJava << p->variableName << ", "; + } + + if (relaxed) { + *mJava << "true"; + } else { + *mJava << "false"; + } + *mJava << ");\n"; + } + mJava->decreaseIndent(); + mJava->indent() << "} catch (Exception e) {\n"; + mJava->increaseIndent(); + mJava->indent() << "throw new RSRuntimeException(\"RenderScript. Can't invoke forEach_" + << mRsKernelName << ": \" + e.toString());\n"; + mJava->endBlock(); +} + +/* Write the section of the .rs file for this permutation. + * + * We communicate the extra input and output parameters via global allocations. + * For example, if we have a function that takes three arguments, two for input + * and one for output: + * + * start: + * name: gamn + * ret: float3 + * arg: float3 a + * arg: int b + * arg: float3 *c + * end: + * + * We'll produce: + * + * rs_allocation gAllocInB; + * rs_allocation gAllocOutC; + * + * float3 __attribute__((kernel)) test_gamn_float3_int_float3(float3 inA, unsigned int x) { + * int inB; + * float3 outC; + * float2 out; + * inB = rsGetElementAt_int(gAllocInB, x); + * out = gamn(a, in_b, &outC); + * rsSetElementAt_float4(gAllocOutC, &outC, x); + * return out; + * } + * + * We avoid re-using x and y from the definition because these have reserved + * meanings in a .rs file. + */ +void PermutationWriter::writeRsSection(set<string>* rsAllocationsGenerated) const { + // Write the allocation declarations we'll need. + for (auto p : mPermutation.getParams()) { + // Don't need allocation for one input and one return value. + if (p != mFirstInputParam) { + writeRsAllocationDefinition(*p, rsAllocationsGenerated); + } + } + *mRs << "\n"; + + // Write the function header. + if (mReturnParam) { + *mRs << mReturnParam->rsType; + } else { + *mRs << "void"; + } + *mRs << " __attribute__((kernel)) " << mRsKernelName; + *mRs << "("; + bool needComma = false; + if (mFirstInputParam) { + *mRs << mFirstInputParam->rsType << " " << mFirstInputParam->variableName; + needComma = true; + } + if (mPermutation.getOutputCount() > 1 || mPermutation.getInputCount() > 1) { + if (needComma) { + *mRs << ", "; + } + *mRs << "unsigned int x"; + } + *mRs << ")"; + mRs->startBlock(); + + // Write the local variable declarations and initializations. + for (auto p : mPermutation.getParams()) { + if (p == mFirstInputParam) { + continue; + } + mRs->indent() << p->rsType << " " << p->variableName; + if (p->isOutParameter) { + *mRs << " = 0;\n"; + } else { + *mRs << " = rsGetElementAt_" << p->rsType << "(" << p->rsAllocName << ", x);\n"; + } + } + + // Write the function call. + if (mReturnParam) { + if (mPermutation.getOutputCount() > 1) { + mRs->indent() << mReturnParam->rsType << " " << mReturnParam->variableName << " = "; + } else { + mRs->indent() << "return "; + } + } + *mRs << mPermutation.getName() << "("; + needComma = false; + for (auto p : mPermutation.getParams()) { + if (needComma) { + *mRs << ", "; + } + if (p->isOutParameter) { + *mRs << "&"; + } + *mRs << p->variableName; + needComma = true; + } + *mRs << ");\n"; + + if (mPermutation.getOutputCount() > 1) { + // Write setting the extra out parameters into the allocations. + for (auto p : mPermutation.getParams()) { + if (p->isOutParameter) { + mRs->indent() << "rsSetElementAt_" << p->rsType << "(" << p->rsAllocName << ", "; + // Check if we need to use '&' for this type of argument. + char lastChar = p->variableName.back(); + if (lastChar >= '0' && lastChar <= '9') { + *mRs << "&"; + } + *mRs << p->variableName << ", x);\n"; + } + } + if (mReturnParam) { + mRs->indent() << "return " << mReturnParam->variableName << ";\n"; + } + } + mRs->endBlock(); +} + +void PermutationWriter::writeRsAllocationDefinition(const ParameterDefinition& param, + set<string>* rsAllocationsGenerated) const { + if (!testAndSet(param.rsAllocName, rsAllocationsGenerated)) { + *mRs << "rs_allocation " << param.rsAllocName << ";\n"; + } +} + +// Open the mJavaFile and writes the header. +static bool startJavaFile(GeneratedFile* file, const Function& function, const string& directory, + const string& testName, const string& relaxedTestName) { + const string fileName = testName + ".java"; + if (!file->start(directory, fileName)) { + return false; + } + file->writeNotices(); + + *file << "package android.renderscript.cts;\n\n"; + + *file << "import android.renderscript.Allocation;\n"; + *file << "import android.renderscript.RSRuntimeException;\n"; + *file << "import android.renderscript.Element;\n"; + *file << "import android.renderscript.cts.Target;\n\n"; + *file << "import java.util.Arrays;\n\n"; + + *file << "public class " << testName << " extends RSBaseCompute"; + file->startBlock(); // The corresponding endBlock() is in finishJavaFile() + *file << "\n"; + + file->indent() << "private ScriptC_" << testName << " script;\n"; + file->indent() << "private ScriptC_" << relaxedTestName << " scriptRelaxed;\n\n"; + + file->indent() << "@Override\n"; + file->indent() << "protected void setUp() throws Exception"; + file->startBlock(); + + file->indent() << "super.setUp();\n"; + file->indent() << "script = new ScriptC_" << testName << "(mRS);\n"; + file->indent() << "scriptRelaxed = new ScriptC_" << relaxedTestName << "(mRS);\n"; + + file->endBlock(); + *file << "\n"; + return true; +} + +// Write the test method that calls all the generated Check methods. +static void finishJavaFile(GeneratedFile* file, const Function& function, + const vector<string>& javaCheckMethods) { + file->indent() << "public void test" << function.getCapitalizedName() << "()"; + file->startBlock(); + for (auto m : javaCheckMethods) { + file->indent() << m << "();\n"; + } + file->endBlock(); + + file->endBlock(); +} + +// Open the script file and write its header. +static bool startRsFile(GeneratedFile* file, const Function& function, const string& directory, + const string& testName) { + string fileName = testName + ".rs"; + if (!file->start(directory, fileName)) { + return false; + } + file->writeNotices(); + + *file << "#pragma version(1)\n"; + *file << "#pragma rs java_package_name(android.renderscript.cts)\n\n"; + return true; +} + +// Write the entire *Relaxed.rs test file, as it only depends on the name. +static bool writeRelaxedRsFile(const Function& function, const string& directory, + const string& testName, const string& relaxedTestName) { + string name = relaxedTestName + ".rs"; + + GeneratedFile file; + if (!file.start(directory, name)) { + return false; + } + file.writeNotices(); + + file << "#include \"" << testName << ".rs\"\n"; + file << "#pragma rs_fp_relaxed\n"; + file.close(); + return true; +} + +/* Write the .java and the two .rs test files. versionOfTestFiles is used to restrict which API + * to test. + */ +static bool writeTestFilesForFunction(const Function& function, const string& directory, + unsigned int versionOfTestFiles) { + // Avoid creating empty files if we're not testing this function. + if (!needTestFiles(function, versionOfTestFiles)) { + return true; + } + + const string testName = "Test" + function.getCapitalizedName(); + const string relaxedTestName = testName + "Relaxed"; + + if (!writeRelaxedRsFile(function, directory, testName, relaxedTestName)) { + return false; + } + + GeneratedFile rsFile; // The Renderscript test file we're generating. + GeneratedFile javaFile; // The Jave test file we're generating. + if (!startRsFile(&rsFile, function, directory, testName)) { + return false; + } + + if (!startJavaFile(&javaFile, function, directory, testName, relaxedTestName)) { + return false; + } + + /* We keep track of the allocations generated in the .rs file and the argument classes defined + * in the Java file, as we share these between the functions created for each specification. + */ + set<string> rsAllocationsGenerated; + set<string> javaGeneratedArgumentClasses; + // Lines of Java code to invoke the check methods. + vector<string> javaCheckMethods; + + for (auto spec : function.getSpecifications()) { + if (spec->hasTests(versionOfTestFiles)) { + for (auto permutation : spec->getPermutations()) { + PermutationWriter w(*permutation, &rsFile, &javaFile); + w.writeRsSection(&rsAllocationsGenerated); + w.writeJavaSection(&javaGeneratedArgumentClasses); + + // Store the check method to be called. + javaCheckMethods.push_back(w.getJavaCheckMethodName()); + } + } + } + + finishJavaFile(&javaFile, function, javaCheckMethods); + // There's no work to wrap-up in the .rs file. + + rsFile.close(); + javaFile.close(); + return true; +} + +bool generateTestFiles(const string& directory, unsigned int versionOfTestFiles) { + bool success = true; + for (auto f : systemSpecification.getFunctions()) { + if (!writeTestFilesForFunction(*f.second, directory, versionOfTestFiles)) { + success = false; + } + } + return success; +} |