diff options
Diffstat (limited to 'api/gen_runtime.cpp')
-rw-r--r-- | api/gen_runtime.cpp | 71 |
1 files changed, 51 insertions, 20 deletions
diff --git a/api/gen_runtime.cpp b/api/gen_runtime.cpp index bfc3d6d8..4570d680 100644 --- a/api/gen_runtime.cpp +++ b/api/gen_runtime.cpp @@ -302,6 +302,8 @@ private: // The number of input and output parameters. int mInputCount; int mOutputCount; + // Whether one of the output parameters is a float. + bool mHasFloatAnswers; string mRsKernelName; string mJavaArgumentsClassName; @@ -348,8 +350,8 @@ private: void writeJavaAppendNewLineToMessage(ofstream& file, int indent) const; void writeJavaAppendVariableToMessage(ofstream& file, int indent, const ParameterDefinition& p, const string& value) const; - void writeJavaAppendFloatyVariableToMessage(ofstream& file, int indent, - const string& value) const; + void writeJavaAppendFloatVariableToMessage(ofstream& file, int indent, const string& value, + bool regularFloat) const; void writeJavaVectorComparison(ofstream& file, int indent, const ParameterDefinition& p) const; void writeJavaAppendVectorInputToMessage(ofstream& file, int indent, const ParameterDefinition& p) const; @@ -1118,12 +1120,16 @@ Permutation::Permutation(Function* func, Specification* spec, int i1, int i2, in vector<string> paramDefinitions; spec->getParams(i1, i2, i3, i4, ¶mDefinitions); + mHasFloatAnswers = false; for (size_t i = 0; i < paramDefinitions.size(); i++) { ParameterDefinition* def = new ParameterDefinition(); def->parseParameterDefinition(paramDefinitions[i], false, &mInputCount, &mOutputCount); if (!def->isOutParameter && mFirstInputIndex < 0) { mFirstInputIndex = mParams.size(); } + if (def->isOutParameter && def->isFloatType) { + mHasFloatAnswers = true; + } mParams.push_back(def); } @@ -1132,6 +1138,9 @@ Permutation::Permutation(Function* func, Specification* spec, int i1, int i2, in ParameterDefinition* def = new ParameterDefinition(); // Adding "*" tells the parse method it's an output. def->parseParameterDefinition(s, true, &mInputCount, &mOutputCount); + if (def->isOutParameter && def->isFloatType) { + mHasFloatAnswers = true; + } mReturnIndex = mParams.size(); mParams.push_back(def); } @@ -1410,8 +1419,8 @@ void Permutation::writeJavaArgumentClass(ofstream& file, bool scalar) const { for (size_t i = 0; i < mParams.size(); i++) { const ParameterDefinition& p = *mParams[i]; s += tab(2) + "public "; - if (p.isOutParameter && p.isFloatType) { - s += "Floaty"; + if (p.isOutParameter && p.isFloatType && mTest != "custom") { + s += "Target.Floaty"; } else { s += p.javaBaseType; } @@ -1476,7 +1485,9 @@ void Permutation::writeJavaInputAllocationDefinition(ofstream& file, const strin } } else { file << "createRandomAllocation(mRS, Element.DataType." << dataType << ", " << vectorSize - << ", " << seed << ", false)"; // TODO set to false only for native + // TODO set to false only for native, i.e. + // << ", " << seed << ", " << (mTest == "limited" ? "false" : "true") << ")"; + << ", " << seed << ", false)"; } file << ";\n"; } @@ -1606,14 +1617,26 @@ void Permutation::writeJavaVerifyScalarMethod(ofstream& file, bool verifierValid } } file << tab(4) << "// Ask the CoreMathVerifier to validate.\n"; - file << tab(4) << "Floaty.setRelaxed(relaxed);\n"; + if (mHasFloatAnswers) { + file << tab(4) << "Target target = new Target(relaxed);\n"; + } file << tab(4) << "String errorMessage = CoreMathVerifier." << mJavaVerifierVerifyMethodName - << "(args, relaxed);\n"; + << "(args"; + if (mHasFloatAnswers) { + file << ", target"; + } + file << ");\n"; file << tab(4) << "boolean valid = errorMessage == null;\n"; } else { file << tab(4) << "// Figure out what the outputs should have been.\n"; - file << tab(4) << "Floaty.setRelaxed(relaxed);\n"; - file << tab(4) << "CoreMathVerifier." << mJavaVerifierComputeMethodName << "(args);\n"; + if (mHasFloatAnswers) { + file << tab(4) << "Target target = new Target(relaxed);\n"; + } + file << tab(4) << "CoreMathVerifier." << mJavaVerifierComputeMethodName << "(args"; + if (mHasFloatAnswers) { + file << ", target"; + } + file << ");\n"; file << tab(4) << "// Validate the outputs.\n"; file << tab(4) << "boolean valid = true;\n"; for (size_t i = 0; i < mParams.size(); i++) { @@ -1678,7 +1701,7 @@ void Permutation::writeJavaTestOneValue(ofstream& file, int indent, const Parame file << "args." << p.variableName << argsIndex << " != " << p.javaArrayName << actualIndex; } if (p.undefinedIfOutIsNan && mReturnIndex >= 0) { - file << " && args." << mParams[mReturnIndex]->variableName << argsIndex << ".isNaN()"; + file << " && !args." << mParams[mReturnIndex]->variableName << argsIndex << ".isNaN()"; } file << ") {\n"; } @@ -1691,7 +1714,7 @@ void Permutation::writeJavaAppendOutputToMessage(ofstream& file, int indent, const string actual = "args." + p.variableName + argsIndex; file << tab(indent) << "message.append(\"Output " + p.variableName + ": \");\n"; if (p.isFloatType) { - writeJavaAppendFloatyVariableToMessage(file, indent, actual); + writeJavaAppendFloatVariableToMessage(file, indent, actual, true); } else { writeJavaAppendVariableToMessage(file, indent, p, actual); } @@ -1701,7 +1724,7 @@ void Permutation::writeJavaAppendOutputToMessage(ofstream& file, int indent, const string actual = p.javaArrayName + actualIndex; file << tab(indent) << "message.append(\"Expected output " + p.variableName + ": \");\n"; if (p.isFloatType) { - writeJavaAppendFloatyVariableToMessage(file, indent, expected); + writeJavaAppendFloatVariableToMessage(file, indent, expected, false); } else { writeJavaAppendVariableToMessage(file, indent, p, expected); } @@ -1732,11 +1755,11 @@ void Permutation::writeJavaAppendVariableToMessage(ofstream& file, int indent, const ParameterDefinition& p, const string& value) const { if (p.specType == "f16" || p.specType == "f32") { - file << tab(indent) << "message.append(String.format(\"%14.8g %8x %15a\",\n"; + file << tab(indent) << "message.append(String.format(\"%14.8g {%8x} %15a\",\n"; file << tab(indent + 2) << value << ", " << "Float.floatToRawIntBits(" << value << "), " << value << "));\n"; } else if (p.specType == "f64") { - file << tab(indent) << "message.append(String.format(\"%24.8g %16x %31a\",\n"; + file << tab(indent) << "message.append(String.format(\"%24.8g {%16x} %31a\",\n"; file << tab(indent + 2) << value << ", " << "Double.doubleToRawLongBits(" << value << "), " << value << "));\n"; } else if (p.specType[0] == 'u') { @@ -1746,9 +1769,16 @@ void Permutation::writeJavaAppendVariableToMessage(ofstream& file, int indent, } } -void Permutation::writeJavaAppendFloatyVariableToMessage(ofstream& file, int indent, - const string& value) const { - file << tab(indent) << "message.append(" << value << ".toString());\n"; +void Permutation::writeJavaAppendFloatVariableToMessage(ofstream& file, int indent, + const string& value, + bool regularFloat) const { + file << tab(indent) << "message.append("; + if (regularFloat) { + file << "Float.toString(" << value << ")"; + } else { + file << value << ".toString()"; + } + file << ");\n"; } void Permutation::writeJavaVectorComparison(ofstream& file, int indent, @@ -1804,7 +1834,7 @@ void Permutation::writeJavaVerifyVectorMethod(ofstream& file) const { if (p.mVectorSize != "1") { string type = p.javaBaseType; if (p.isOutParameter && p.isFloatType) { - type = "Floaty"; + type = "Target.Floaty"; } file << tab(3) << "args." << p.variableName << " = new " << type << "[" << p.mVectorSize << "];\n"; @@ -1827,8 +1857,9 @@ void Permutation::writeJavaVerifyVectorMethod(ofstream& file) const { } } } - file << tab(3) << "Floaty.setRelaxed(relaxed);\n"; - file << tab(3) << "CoreMathVerifier." << mJavaVerifierComputeMethodName << "(args);\n\n"; + file << tab(3) << "Target target = new Target(relaxed);\n"; + file << tab(3) << "CoreMathVerifier." << mJavaVerifierComputeMethodName + << "(args, target);\n\n"; file << tab(3) << "// Compare the expected outputs to the actual values returned by RS.\n"; file << tab(3) << "boolean valid = true;\n"; |