summaryrefslogtreecommitdiff
path: root/api/gen_runtime.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'api/gen_runtime.cpp')
-rw-r--r--api/gen_runtime.cpp71
1 files changed, 51 insertions, 20 deletions
diff --git a/api/gen_runtime.cpp b/api/gen_runtime.cpp
index bfc3d6d8..4570d680 100644
--- a/api/gen_runtime.cpp
+++ b/api/gen_runtime.cpp
@@ -302,6 +302,8 @@ private:
// The number of input and output parameters.
int mInputCount;
int mOutputCount;
+ // Whether one of the output parameters is a float.
+ bool mHasFloatAnswers;
string mRsKernelName;
string mJavaArgumentsClassName;
@@ -348,8 +350,8 @@ private:
void writeJavaAppendNewLineToMessage(ofstream& file, int indent) const;
void writeJavaAppendVariableToMessage(ofstream& file, int indent, const ParameterDefinition& p,
const string& value) const;
- void writeJavaAppendFloatyVariableToMessage(ofstream& file, int indent,
- const string& value) const;
+ void writeJavaAppendFloatVariableToMessage(ofstream& file, int indent, const string& value,
+ bool regularFloat) const;
void writeJavaVectorComparison(ofstream& file, int indent, const ParameterDefinition& p) const;
void writeJavaAppendVectorInputToMessage(ofstream& file, int indent,
const ParameterDefinition& p) const;
@@ -1118,12 +1120,16 @@ Permutation::Permutation(Function* func, Specification* spec, int i1, int i2, in
vector<string> paramDefinitions;
spec->getParams(i1, i2, i3, i4, &paramDefinitions);
+ mHasFloatAnswers = false;
for (size_t i = 0; i < paramDefinitions.size(); i++) {
ParameterDefinition* def = new ParameterDefinition();
def->parseParameterDefinition(paramDefinitions[i], false, &mInputCount, &mOutputCount);
if (!def->isOutParameter && mFirstInputIndex < 0) {
mFirstInputIndex = mParams.size();
}
+ if (def->isOutParameter && def->isFloatType) {
+ mHasFloatAnswers = true;
+ }
mParams.push_back(def);
}
@@ -1132,6 +1138,9 @@ Permutation::Permutation(Function* func, Specification* spec, int i1, int i2, in
ParameterDefinition* def = new ParameterDefinition();
// Adding "*" tells the parse method it's an output.
def->parseParameterDefinition(s, true, &mInputCount, &mOutputCount);
+ if (def->isOutParameter && def->isFloatType) {
+ mHasFloatAnswers = true;
+ }
mReturnIndex = mParams.size();
mParams.push_back(def);
}
@@ -1410,8 +1419,8 @@ void Permutation::writeJavaArgumentClass(ofstream& file, bool scalar) const {
for (size_t i = 0; i < mParams.size(); i++) {
const ParameterDefinition& p = *mParams[i];
s += tab(2) + "public ";
- if (p.isOutParameter && p.isFloatType) {
- s += "Floaty";
+ if (p.isOutParameter && p.isFloatType && mTest != "custom") {
+ s += "Target.Floaty";
} else {
s += p.javaBaseType;
}
@@ -1476,7 +1485,9 @@ void Permutation::writeJavaInputAllocationDefinition(ofstream& file, const strin
}
} else {
file << "createRandomAllocation(mRS, Element.DataType." << dataType << ", " << vectorSize
- << ", " << seed << ", false)"; // TODO set to false only for native
+ // TODO set to false only for native, i.e.
+ // << ", " << seed << ", " << (mTest == "limited" ? "false" : "true") << ")";
+ << ", " << seed << ", false)";
}
file << ";\n";
}
@@ -1606,14 +1617,26 @@ void Permutation::writeJavaVerifyScalarMethod(ofstream& file, bool verifierValid
}
}
file << tab(4) << "// Ask the CoreMathVerifier to validate.\n";
- file << tab(4) << "Floaty.setRelaxed(relaxed);\n";
+ if (mHasFloatAnswers) {
+ file << tab(4) << "Target target = new Target(relaxed);\n";
+ }
file << tab(4) << "String errorMessage = CoreMathVerifier." << mJavaVerifierVerifyMethodName
- << "(args, relaxed);\n";
+ << "(args";
+ if (mHasFloatAnswers) {
+ file << ", target";
+ }
+ file << ");\n";
file << tab(4) << "boolean valid = errorMessage == null;\n";
} else {
file << tab(4) << "// Figure out what the outputs should have been.\n";
- file << tab(4) << "Floaty.setRelaxed(relaxed);\n";
- file << tab(4) << "CoreMathVerifier." << mJavaVerifierComputeMethodName << "(args);\n";
+ if (mHasFloatAnswers) {
+ file << tab(4) << "Target target = new Target(relaxed);\n";
+ }
+ file << tab(4) << "CoreMathVerifier." << mJavaVerifierComputeMethodName << "(args";
+ if (mHasFloatAnswers) {
+ file << ", target";
+ }
+ file << ");\n";
file << tab(4) << "// Validate the outputs.\n";
file << tab(4) << "boolean valid = true;\n";
for (size_t i = 0; i < mParams.size(); i++) {
@@ -1678,7 +1701,7 @@ void Permutation::writeJavaTestOneValue(ofstream& file, int indent, const Parame
file << "args." << p.variableName << argsIndex << " != " << p.javaArrayName << actualIndex;
}
if (p.undefinedIfOutIsNan && mReturnIndex >= 0) {
- file << " && args." << mParams[mReturnIndex]->variableName << argsIndex << ".isNaN()";
+ file << " && !args." << mParams[mReturnIndex]->variableName << argsIndex << ".isNaN()";
}
file << ") {\n";
}
@@ -1691,7 +1714,7 @@ void Permutation::writeJavaAppendOutputToMessage(ofstream& file, int indent,
const string actual = "args." + p.variableName + argsIndex;
file << tab(indent) << "message.append(\"Output " + p.variableName + ": \");\n";
if (p.isFloatType) {
- writeJavaAppendFloatyVariableToMessage(file, indent, actual);
+ writeJavaAppendFloatVariableToMessage(file, indent, actual, true);
} else {
writeJavaAppendVariableToMessage(file, indent, p, actual);
}
@@ -1701,7 +1724,7 @@ void Permutation::writeJavaAppendOutputToMessage(ofstream& file, int indent,
const string actual = p.javaArrayName + actualIndex;
file << tab(indent) << "message.append(\"Expected output " + p.variableName + ": \");\n";
if (p.isFloatType) {
- writeJavaAppendFloatyVariableToMessage(file, indent, expected);
+ writeJavaAppendFloatVariableToMessage(file, indent, expected, false);
} else {
writeJavaAppendVariableToMessage(file, indent, p, expected);
}
@@ -1732,11 +1755,11 @@ void Permutation::writeJavaAppendVariableToMessage(ofstream& file, int indent,
const ParameterDefinition& p,
const string& value) const {
if (p.specType == "f16" || p.specType == "f32") {
- file << tab(indent) << "message.append(String.format(\"%14.8g %8x %15a\",\n";
+ file << tab(indent) << "message.append(String.format(\"%14.8g {%8x} %15a\",\n";
file << tab(indent + 2) << value << ", "
<< "Float.floatToRawIntBits(" << value << "), " << value << "));\n";
} else if (p.specType == "f64") {
- file << tab(indent) << "message.append(String.format(\"%24.8g %16x %31a\",\n";
+ file << tab(indent) << "message.append(String.format(\"%24.8g {%16x} %31a\",\n";
file << tab(indent + 2) << value << ", "
<< "Double.doubleToRawLongBits(" << value << "), " << value << "));\n";
} else if (p.specType[0] == 'u') {
@@ -1746,9 +1769,16 @@ void Permutation::writeJavaAppendVariableToMessage(ofstream& file, int indent,
}
}
-void Permutation::writeJavaAppendFloatyVariableToMessage(ofstream& file, int indent,
- const string& value) const {
- file << tab(indent) << "message.append(" << value << ".toString());\n";
+void Permutation::writeJavaAppendFloatVariableToMessage(ofstream& file, int indent,
+ const string& value,
+ bool regularFloat) const {
+ file << tab(indent) << "message.append(";
+ if (regularFloat) {
+ file << "Float.toString(" << value << ")";
+ } else {
+ file << value << ".toString()";
+ }
+ file << ");\n";
}
void Permutation::writeJavaVectorComparison(ofstream& file, int indent,
@@ -1804,7 +1834,7 @@ void Permutation::writeJavaVerifyVectorMethod(ofstream& file) const {
if (p.mVectorSize != "1") {
string type = p.javaBaseType;
if (p.isOutParameter && p.isFloatType) {
- type = "Floaty";
+ type = "Target.Floaty";
}
file << tab(3) << "args." << p.variableName << " = new " << type << "[" << p.mVectorSize
<< "];\n";
@@ -1827,8 +1857,9 @@ void Permutation::writeJavaVerifyVectorMethod(ofstream& file) const {
}
}
}
- file << tab(3) << "Floaty.setRelaxed(relaxed);\n";
- file << tab(3) << "CoreMathVerifier." << mJavaVerifierComputeMethodName << "(args);\n\n";
+ file << tab(3) << "Target target = new Target(relaxed);\n";
+ file << tab(3) << "CoreMathVerifier." << mJavaVerifierComputeMethodName
+ << "(args, target);\n\n";
file << tab(3) << "// Compare the expected outputs to the actual values returned by RS.\n";
file << tab(3) << "boolean valid = true;\n";