aboutsummaryrefslogtreecommitdiff
path: root/libc
diff options
context:
space:
mode:
authorSiva Chandra Reddy <sivachandra@google.com>2020-08-20 22:36:53 -0700
committerSiva Chandra Reddy <sivachandra@google.com>2020-08-25 21:42:49 -0700
commit3f4674a5577dcc63a846d33f61e9bd95e388223d (patch)
treed683c03c11b697656d07cb4ff85e6a4e467eeb5e /libc
parent75e0b5866869ea1feb140d6f718d74c786547113 (diff)
downloadllvm-project-3f4674a5577dcc63a846d33f61e9bd95e388223d.tar.gz
[libc] Extend MPFRMatcher to handle multiple-input-multiple-output functions.
Tests for frexp[f|l] now use the new capability. Not all input-output combinations have been addressed by this change. Support for newer combinations can be added in future as needed. Reviewed By: lntue Differential Revision: https://reviews.llvm.org/D86506
Diffstat (limited to 'libc')
-rw-r--r--libc/test/src/math/CMakeLists.txt3
-rw-r--r--libc/test/src/math/frexp_test.cpp25
-rw-r--r--libc/test/src/math/frexpf_test.cpp26
-rw-r--r--libc/test/src/math/frexpl_test.cpp12
-rw-r--r--libc/utils/MPFRWrapper/MPFRUtils.cpp345
-rw-r--r--libc/utils/MPFRWrapper/MPFRUtils.h200
6 files changed, 514 insertions, 97 deletions
diff --git a/libc/test/src/math/CMakeLists.txt b/libc/test/src/math/CMakeLists.txt
index e73de5403564..2fe766a2ffc6 100644
--- a/libc/test/src/math/CMakeLists.txt
+++ b/libc/test/src/math/CMakeLists.txt
@@ -333,6 +333,7 @@ add_fp_unittest(
add_fp_unittest(
frexp_test
+ NEED_MPFR
SUITE
libc_math_unittests
SRCS
@@ -345,6 +346,7 @@ add_fp_unittest(
add_fp_unittest(
frexpf_test
+ NEED_MPFR
SUITE
libc_math_unittests
SRCS
@@ -357,6 +359,7 @@ add_fp_unittest(
add_fp_unittest(
frexpl_test
+ NEED_MPFR
SUITE
libc_math_unittests
SRCS
diff --git a/libc/test/src/math/frexp_test.cpp b/libc/test/src/math/frexp_test.cpp
index f828d515a688..360bbf237560 100644
--- a/libc/test/src/math/frexp_test.cpp
+++ b/libc/test/src/math/frexp_test.cpp
@@ -11,13 +11,18 @@
#include "utils/FPUtil/BasicOperations.h"
#include "utils/FPUtil/BitPatterns.h"
#include "utils/FPUtil/ClassificationFunctions.h"
+#include "utils/FPUtil/FPBits.h"
#include "utils/FPUtil/FloatOperations.h"
#include "utils/FPUtil/FloatProperties.h"
+#include "utils/MPFRWrapper/MPFRUtils.h"
#include "utils/UnitTest/Test.h"
+using FPBits = __llvm_libc::fputil::FPBits<double>;
using __llvm_libc::fputil::valueAsBits;
using __llvm_libc::fputil::valueFromBits;
+namespace mpfr = __llvm_libc::testing::mpfr;
+
using BitPatterns = __llvm_libc::fputil::BitPatterns<double>;
using Properties = __llvm_libc::fputil::FloatProperties<double>;
@@ -127,17 +132,19 @@ TEST(FrexpTest, SomeIntegers) {
}
TEST(FrexpTest, InDoubleRange) {
- using BitsType = Properties::BitsType;
- constexpr BitsType count = 1000000;
- constexpr BitsType step = UINT64_MAX / count;
- for (BitsType i = 0, v = 0; i <= count; ++i, v += step) {
- double x = valueFromBits(v);
+ using UIntType = FPBits::UIntType;
+ constexpr UIntType count = 1000001;
+ constexpr UIntType step = UIntType(-1) / count;
+ for (UIntType i = 0, v = 0; i <= count; ++i, v += step) {
+ double x = FPBits(v);
if (isnan(x) || isinf(x) || x == 0.0)
continue;
- int exponent;
- double frac = __llvm_libc::frexp(x, &exponent);
- ASSERT_TRUE(__llvm_libc::fputil::abs(frac) < 1.0);
- ASSERT_TRUE(__llvm_libc::fputil::abs(frac) >= 0.5);
+ mpfr::BinaryOutput<double> result;
+ result.f = __llvm_libc::frexp(x, &result.i);
+
+ ASSERT_TRUE(__llvm_libc::fputil::abs(result.f) < 1.0);
+ ASSERT_TRUE(__llvm_libc::fputil::abs(result.f) >= 0.5);
+ ASSERT_MPFR_MATCH(mpfr::Operation::Frexp, x, result, 0.0);
}
}
diff --git a/libc/test/src/math/frexpf_test.cpp b/libc/test/src/math/frexpf_test.cpp
index 3b82c68078ee..1bf0c36cf165 100644
--- a/libc/test/src/math/frexpf_test.cpp
+++ b/libc/test/src/math/frexpf_test.cpp
@@ -11,14 +11,18 @@
#include "utils/FPUtil/BasicOperations.h"
#include "utils/FPUtil/BitPatterns.h"
#include "utils/FPUtil/ClassificationFunctions.h"
+#include "utils/FPUtil/FPBits.h"
#include "utils/FPUtil/FloatOperations.h"
#include "utils/FPUtil/FloatProperties.h"
#include "utils/MPFRWrapper/MPFRUtils.h"
#include "utils/UnitTest/Test.h"
+using FPBits = __llvm_libc::fputil::FPBits<float>;
using __llvm_libc::fputil::valueAsBits;
using __llvm_libc::fputil::valueFromBits;
+namespace mpfr = __llvm_libc::testing::mpfr;
+
using BitPatterns = __llvm_libc::fputil::BitPatterns<float>;
using Properties = __llvm_libc::fputil::FloatProperties<float>;
@@ -109,7 +113,7 @@ TEST(FrexpfTest, PowersOfTwo) {
EXPECT_EQ(exponent, 7);
}
-TEST(FrexpTest, SomeIntegers) {
+TEST(FrexpfTest, SomeIntegers) {
int exponent;
EXPECT_EQ(valueAsBits(0.75f),
@@ -135,17 +139,19 @@ TEST(FrexpTest, SomeIntegers) {
}
TEST(FrexpfTest, InFloatRange) {
- using BitsType = Properties::BitsType;
- constexpr BitsType count = 1000000;
- constexpr BitsType step = UINT32_MAX / count;
- for (BitsType i = 0, v = 0; i <= count; ++i, v += step) {
- float x = valueFromBits(v);
+ using UIntType = FPBits::UIntType;
+ constexpr UIntType count = 1000001;
+ constexpr UIntType step = UIntType(-1) / count;
+ for (UIntType i = 0, v = 0; i <= count; ++i, v += step) {
+ float x = FPBits(v);
if (isnan(x) || isinf(x) || x == 0.0)
continue;
- int exponent;
- float frac = __llvm_libc::frexpf(x, &exponent);
- ASSERT_TRUE(__llvm_libc::fputil::abs(frac) < 1.0f);
- ASSERT_TRUE(__llvm_libc::fputil::abs(frac) >= 0.5f);
+ mpfr::BinaryOutput<float> result;
+ result.f = __llvm_libc::frexpf(x, &result.i);
+
+ ASSERT_TRUE(__llvm_libc::fputil::abs(result.f) < 1.0);
+ ASSERT_TRUE(__llvm_libc::fputil::abs(result.f) >= 0.5);
+ ASSERT_MPFR_MATCH(mpfr::Operation::Frexp, x, result, 0.0);
}
}
diff --git a/libc/test/src/math/frexpl_test.cpp b/libc/test/src/math/frexpl_test.cpp
index ace445f0a2de..9846bb84ae27 100644
--- a/libc/test/src/math/frexpl_test.cpp
+++ b/libc/test/src/math/frexpl_test.cpp
@@ -10,10 +10,13 @@
#include "src/math/frexpl.h"
#include "utils/FPUtil/BasicOperations.h"
#include "utils/FPUtil/FPBits.h"
+#include "utils/MPFRWrapper/MPFRUtils.h"
#include "utils/UnitTest/Test.h"
using FPBits = __llvm_libc::fputil::FPBits<long double>;
+namespace mpfr = __llvm_libc::testing::mpfr;
+
TEST(FrexplTest, SpecialNumbers) {
int exponent;
@@ -94,10 +97,11 @@ TEST(FrexplTest, LongDoubleRange) {
if (isnan(x) || isinf(x) || x == 0.0l)
continue;
- int exponent;
- long double frac = __llvm_libc::frexpl(x, &exponent);
+ mpfr::BinaryOutput<long double> result;
+ result.f = __llvm_libc::frexpl(x, &result.i);
- ASSERT_TRUE(__llvm_libc::fputil::abs(frac) < 1.0l);
- ASSERT_TRUE(__llvm_libc::fputil::abs(frac) >= 0.5l);
+ ASSERT_TRUE(__llvm_libc::fputil::abs(result.f) < 1.0);
+ ASSERT_TRUE(__llvm_libc::fputil::abs(result.f) >= 0.5);
+ ASSERT_MPFR_MATCH(mpfr::Operation::Frexp, x, result, 0.0);
}
}
diff --git a/libc/utils/MPFRWrapper/MPFRUtils.cpp b/libc/utils/MPFRWrapper/MPFRUtils.cpp
index a3abfce08bf3..86882d05cc39 100644
--- a/libc/utils/MPFRWrapper/MPFRUtils.cpp
+++ b/libc/utils/MPFRWrapper/MPFRUtils.cpp
@@ -14,6 +14,7 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
+#include <memory>
#include <mpfr.h>
#include <stdint.h>
#include <string>
@@ -65,50 +66,90 @@ public:
mpfr_set_sj(value, x, MPFR_RNDN);
}
- template <typename XType,
- cpp::EnableIfType<cpp::IsFloatingPointType<XType>::Value, int> = 0>
- MPFRNumber(Operation op, XType rawValue) {
- mpfr_init2(value, mpfrPrecision);
- MPFRNumber mpfrInput(rawValue);
- switch (op) {
- case Operation::Abs:
- mpfr_abs(value, mpfrInput.value, MPFR_RNDN);
- break;
- case Operation::Ceil:
- mpfr_ceil(value, mpfrInput.value);
- break;
- case Operation::Cos:
- mpfr_cos(value, mpfrInput.value, MPFR_RNDN);
- break;
- case Operation::Exp:
- mpfr_exp(value, mpfrInput.value, MPFR_RNDN);
- break;
- case Operation::Exp2:
- mpfr_exp2(value, mpfrInput.value, MPFR_RNDN);
- break;
- case Operation::Floor:
- mpfr_floor(value, mpfrInput.value);
- break;
- case Operation::Round:
- mpfr_round(value, mpfrInput.value);
- break;
- case Operation::Sin:
- mpfr_sin(value, mpfrInput.value, MPFR_RNDN);
- break;
- case Operation::Sqrt:
- mpfr_sqrt(value, mpfrInput.value, MPFR_RNDN);
- break;
- case Operation::Trunc:
- mpfr_trunc(value, mpfrInput.value);
- break;
- }
- }
-
MPFRNumber(const MPFRNumber &other) {
mpfr_set(value, other.value, MPFR_RNDN);
}
- ~MPFRNumber() { mpfr_clear(value); }
+ MPFRNumber &operator=(const MPFRNumber &rhs) {
+ mpfr_set(value, rhs.value, MPFR_RNDN);
+ return *this;
+ }
+
+ MPFRNumber abs() const {
+ MPFRNumber result;
+ mpfr_abs(result.value, value, MPFR_RNDN);
+ return result;
+ }
+
+ MPFRNumber ceil() const {
+ MPFRNumber result;
+ mpfr_ceil(result.value, value);
+ return result;
+ }
+
+ MPFRNumber cos() const {
+ MPFRNumber result;
+ mpfr_cos(result.value, value, MPFR_RNDN);
+ return result;
+ }
+
+ MPFRNumber exp() const {
+ MPFRNumber result;
+ mpfr_exp(result.value, value, MPFR_RNDN);
+ return result;
+ }
+
+ MPFRNumber exp2() const {
+ MPFRNumber result;
+ mpfr_exp2(result.value, value, MPFR_RNDN);
+ return result;
+ }
+
+ MPFRNumber floor() const {
+ MPFRNumber result;
+ mpfr_floor(result.value, value);
+ return result;
+ }
+
+ MPFRNumber frexp(int &exp) {
+ MPFRNumber result;
+ mpfr_exp_t resultExp;
+ mpfr_frexp(&resultExp, result.value, value, MPFR_RNDN);
+ exp = resultExp;
+ return result;
+ }
+
+ MPFRNumber remquo(const MPFRNumber &divisor, int &quotient) {
+ MPFRNumber remainder;
+ long q;
+ mpfr_remquo(remainder.value, &q, value, divisor.value, MPFR_RNDN);
+ quotient = q;
+ return remainder;
+ }
+
+ MPFRNumber round() const {
+ MPFRNumber result;
+ mpfr_round(result.value, value);
+ return result;
+ }
+
+ MPFRNumber sin() const {
+ MPFRNumber result;
+ mpfr_sin(result.value, value, MPFR_RNDN);
+ return result;
+ }
+
+ MPFRNumber sqrt() const {
+ MPFRNumber result;
+ mpfr_sqrt(result.value, value, MPFR_RNDN);
+ return result;
+ }
+
+ MPFRNumber trunc() const {
+ MPFRNumber result;
+ mpfr_trunc(result.value, value);
+ return result;
+ }
std::string str() const {
// 200 bytes should be more than sufficient to hold a 100-digit number
@@ -179,10 +220,65 @@ public:
namespace internal {
+template <typename InputType>
+cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber>
+unaryOperation(Operation op, InputType input) {
+ MPFRNumber mpfrInput(input);
+ switch (op) {
+ case Operation::Abs:
+ return mpfrInput.abs();
+ case Operation::Ceil:
+ return mpfrInput.ceil();
+ case Operation::Cos:
+ return mpfrInput.cos();
+ case Operation::Exp:
+ return mpfrInput.exp();
+ case Operation::Exp2:
+ return mpfrInput.exp2();
+ case Operation::Floor:
+ return mpfrInput.floor();
+ case Operation::Round:
+ return mpfrInput.round();
+ case Operation::Sin:
+ return mpfrInput.sin();
+ case Operation::Sqrt:
+ return mpfrInput.sqrt();
+ case Operation::Trunc:
+ return mpfrInput.trunc();
+ default:
+ __builtin_unreachable();
+ }
+}
+
+template <typename InputType>
+cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber>
+unaryOperationTwoOutputs(Operation op, InputType input, int &output) {
+ MPFRNumber mpfrInput(input);
+ switch (op) {
+ case Operation::Frexp:
+ return mpfrInput.frexp(output);
+ default:
+ __builtin_unreachable();
+ }
+}
+
+template <typename InputType>
+cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber>
+binaryOperationTwoOutputs(Operation op, InputType x, InputType y, int &output) {
+ MPFRNumber inputX(x), inputY(y);
+ switch (op) {
+ case Operation::RemQuo:
+ return inputX.remquo(inputY, output);
+ default:
+ __builtin_unreachable();
+ }
+}
+
template <typename T>
-void MPFRMatcher<T>::explainError(testutils::StreamWrapper &OS) {
- MPFRNumber mpfrResult(operation, input);
+void explainUnaryOperationSingleOutputError(Operation op, T input, T matchValue,
+ testutils::StreamWrapper &OS) {
MPFRNumber mpfrInput(input);
+ MPFRNumber mpfrResult = unaryOperation(op, input);
MPFRNumber mpfrMatchValue(matchValue);
FPBits<T> inputBits(input);
FPBits<T> matchBits(matchValue);
@@ -201,25 +297,174 @@ void MPFRMatcher<T>::explainError(testutils::StreamWrapper &OS) {
<< '\n';
}
-template void MPFRMatcher<float>::explainError(testutils::StreamWrapper &);
-template void MPFRMatcher<double>::explainError(testutils::StreamWrapper &);
template void
-MPFRMatcher<long double>::explainError(testutils::StreamWrapper &);
+explainUnaryOperationSingleOutputError<float>(Operation op, float, float,
+ testutils::StreamWrapper &);
+template void
+explainUnaryOperationSingleOutputError<double>(Operation op, double, double,
+ testutils::StreamWrapper &);
+template void explainUnaryOperationSingleOutputError<long double>(
+ Operation op, long double, long double, testutils::StreamWrapper &);
+
+template <typename T>
+void explainUnaryOperationTwoOutputsError(Operation op, T input,
+ const BinaryOutput<T> &libcResult,
+ testutils::StreamWrapper &OS) {
+ MPFRNumber mpfrInput(input);
+ FPBits<T> inputBits(input);
+ int mpfrIntResult;
+ MPFRNumber mpfrResult = unaryOperationTwoOutputs(op, input, mpfrIntResult);
+
+ if (mpfrIntResult != libcResult.i) {
+ OS << "MPFR integral result: " << mpfrIntResult << '\n'
+ << "Libc integral result: " << libcResult.i << '\n';
+ } else {
+ OS << "Integral result from libc matches integral result from MPFR.\n";
+ }
+
+ MPFRNumber mpfrMatchValue(libcResult.f);
+ OS << "Libc floating point result is not within tolerance value of the MPFR "
+ << "result.\n\n";
+
+ OS << " Input decimal: " << mpfrInput.str() << "\n\n";
+
+ OS << "Libc floating point value: " << mpfrMatchValue.str() << '\n';
+ __llvm_libc::fputil::testing::describeValue(
+ " Libc floating point bits: ", libcResult.f, OS);
+ OS << "\n\n";
+
+ OS << " MPFR result: " << mpfrResult.str() << '\n';
+ __llvm_libc::fputil::testing::describeValue(
+ " MPFR rounded: ", mpfrResult.as<T>(), OS);
+ OS << '\n'
+ << " ULP error: "
+ << std::to_string(mpfrResult.ulp(libcResult.f)) << '\n';
+}
+
+template void explainUnaryOperationTwoOutputsError<float>(
+ Operation, float, const BinaryOutput<float> &, testutils::StreamWrapper &);
+template void
+explainUnaryOperationTwoOutputsError<double>(Operation, double,
+ const BinaryOutput<double> &,
+ testutils::StreamWrapper &);
+template void explainUnaryOperationTwoOutputsError<long double>(
+ Operation, long double, const BinaryOutput<long double> &,
+ testutils::StreamWrapper &);
template <typename T>
-bool compare(Operation op, T input, T libcResult, double ulpError) {
+void explainBinaryOperationTwoOutputsError(Operation op,
+ const BinaryInput<T> &input,
+ const BinaryOutput<T> &libcResult,
+ testutils::StreamWrapper &OS) {
+ MPFRNumber mpfrX(input.x);
+ MPFRNumber mpfrY(input.y);
+ FPBits<T> xbits(input.x);
+ FPBits<T> ybits(input.y);
+ int mpfrIntResult;
+ MPFRNumber mpfrResult =
+ binaryOperationTwoOutputs(op, input.x, input.y, mpfrIntResult);
+ MPFRNumber mpfrMatchValue(libcResult.f);
+
+ OS << "Input decimal: x: " << mpfrX.str() << " y: " << mpfrY.str() << '\n'
+ << "MPFR integral result: " << mpfrIntResult << '\n'
+ << "Libc integral result: " << libcResult.i << '\n'
+ << "Libc floating point result: " << mpfrMatchValue.str() << '\n'
+ << " MPFR result: " << mpfrResult.str() << '\n';
+ __llvm_libc::fputil::testing::describeValue(
+ "Libc floating point result bits: ", libcResult.f, OS);
+ __llvm_libc::fputil::testing::describeValue(
+ " MPFR rounded bits: ", mpfrResult.as<T>(), OS);
+ OS << "ULP error: " << std::to_string(mpfrResult.ulp(libcResult.f)) << '\n';
+}
+
+template void explainBinaryOperationTwoOutputsError<float>(
+ Operation, const BinaryInput<float> &, const BinaryOutput<float> &,
+ testutils::StreamWrapper &);
+template void explainBinaryOperationTwoOutputsError<double>(
+ Operation, const BinaryInput<double> &, const BinaryOutput<double> &,
+ testutils::StreamWrapper &);
+template void explainBinaryOperationTwoOutputsError<long double>(
+ Operation, const BinaryInput<long double> &,
+ const BinaryOutput<long double> &, testutils::StreamWrapper &);
+
+template <typename T>
+bool compareUnaryOperationSingleOutput(Operation op, T input, T libcResult,
+ double ulpError) {
// If the ulp error is exactly 0.5 (i.e a tie), we would check that the result
// is rounded to the nearest even.
- MPFRNumber mpfrResult(op, input);
+ MPFRNumber mpfrResult = unaryOperation(op, input);
double ulp = mpfrResult.ulp(libcResult);
bool bitsAreEven = ((FPBits<T>(libcResult).bitsAsUInt() & 1) == 0);
return (ulp < ulpError) ||
((ulp == ulpError) && ((ulp != 0.5) || bitsAreEven));
}
-template bool compare<float>(Operation, float, float, double);
-template bool compare<double>(Operation, double, double, double);
-template bool compare<long double>(Operation, long double, long double, double);
+template bool compareUnaryOperationSingleOutput<float>(Operation, float, float,
+ double);
+template bool compareUnaryOperationSingleOutput<double>(Operation, double,
+ double, double);
+template bool compareUnaryOperationSingleOutput<long double>(Operation,
+ long double,
+ long double,
+ double);
+
+template <typename T>
+bool compareUnaryOperationTwoOutputs(Operation op, T input,
+ const BinaryOutput<T> &libcResult,
+ double ulpError) {
+ int mpfrIntResult;
+ MPFRNumber mpfrResult = unaryOperationTwoOutputs(op, input, mpfrIntResult);
+ double ulp = mpfrResult.ulp(libcResult.f);
+
+ if (mpfrIntResult != libcResult.i)
+ return false;
+
+ bool bitsAreEven = ((FPBits<T>(libcResult.f).bitsAsUInt() & 1) == 0);
+ return (ulp < ulpError) ||
+ ((ulp == ulpError) && ((ulp != 0.5) || bitsAreEven));
+}
+
+template bool
+compareUnaryOperationTwoOutputs<float>(Operation, float,
+ const BinaryOutput<float> &, double);
+template bool
+compareUnaryOperationTwoOutputs<double>(Operation, double,
+ const BinaryOutput<double> &, double);
+template bool compareUnaryOperationTwoOutputs<long double>(
+ Operation, long double, const BinaryOutput<long double> &, double);
+
+template <typename T>
+bool compareBinaryOperationTwoOutputs(Operation op, const BinaryInput<T> &input,
+ const BinaryOutput<T> &libcResult,
+ double ulpError) {
+ int mpfrIntResult;
+ MPFRNumber mpfrResult =
+ binaryOperationTwoOutputs(op, input.x, input.y, mpfrIntResult);
+ double ulp = mpfrResult.ulp(libcResult.f);
+
+ if (mpfrIntResult != libcResult.i) {
+ if (op == Operation::RemQuo) {
+ if ((0x7 & mpfrIntResult) != libcResult.i)
+ return false;
+ } else {
+ return false;
+ }
+ }
+
+ bool bitsAreEven = ((FPBits<T>(libcResult.f).bitsAsUInt() & 1) == 0);
+ return (ulp < ulpError) ||
+ ((ulp == ulpError) && ((ulp != 0.5) || bitsAreEven));
+}
+
+template bool
+compareBinaryOperationTwoOutputs<float>(Operation, const BinaryInput<float> &,
+ const BinaryOutput<float> &, double);
+template bool
+compareBinaryOperationTwoOutputs<double>(Operation, const BinaryInput<double> &,
+ const BinaryOutput<double> &, double);
+template bool compareBinaryOperationTwoOutputs<long double>(
+ Operation, const BinaryInput<long double> &,
+ const BinaryOutput<long double> &, double);
} // namespace internal
diff --git a/libc/utils/MPFRWrapper/MPFRUtils.h b/libc/utils/MPFRWrapper/MPFRUtils.h
index 3d94079e65d8..b46f09dd5e55 100644
--- a/libc/utils/MPFRWrapper/MPFRUtils.h
+++ b/libc/utils/MPFRWrapper/MPFRUtils.h
@@ -19,6 +19,10 @@ namespace testing {
namespace mpfr {
enum class Operation : int {
+ // Operations with take a single floating point number as input
+ // and produce a single floating point number as output. The input
+ // and output floating point numbers are of the same kind.
+ BeginUnaryOperationsSingleOutput,
Abs,
Ceil,
Cos,
@@ -28,45 +32,193 @@ enum class Operation : int {
Round,
Sin,
Sqrt,
- Trunc
+ Trunc,
+ EndUnaryOperationsSingleOutput,
+
+ // Operations which take a single floating point nubmer as input
+ // but produce two outputs. The first ouput is a floating point
+ // number of the same type as the input. The second output is of type
+ // 'int'.
+ BeginUnaryOperationsTwoOutputs,
+ Frexp, // Floating point output, the first output, is the fractional part.
+ EndUnaryOperationsTwoOutputs,
+
+ // Operations wich take two floating point nubmers of the same type as
+ // input and produce a single floating point number of the same type as
+ // output.
+ BeginBinaryOperationsSingleOutput,
+ // TODO: Add operations like hypot.
+ EndBinaryOperationsSingleOutput,
+
+ // Operations which take two floating point numbers of the same type as
+ // input and produce two outputs. The first output is a floating nubmer of
+ // the same type as the inputs. The second output is af type 'int'.
+ BeginBinaryOperationsTwoOutputs,
+ RemQuo, // The first output, the floating point output, is the remainder.
+ EndBinaryOperationsTwoOutputs,
+
+ BeginTernaryOperationsSingleOuput,
+ // TODO: Add operations like fma.
+ EndTernaryOperationsSingleOutput,
+};
+
+template <typename T> struct BinaryInput {
+ static_assert(
+ __llvm_libc::cpp::IsFloatingPointType<T>::Value,
+ "Template parameter of BinaryInput must be a floating point type.");
+
+ using Type = T;
+ T x, y;
+};
+
+template <typename T> struct TernaryInput {
+ static_assert(
+ __llvm_libc::cpp::IsFloatingPointType<T>::Value,
+ "Template parameter of TernaryInput must be a floating point type.");
+
+ using Type = T;
+ T x, y, z;
+};
+
+template <typename T> struct BinaryOutput {
+ T f;
+ int i;
};
namespace internal {
+template <typename T1, typename T2>
+struct AreMatchingBinaryInputAndBinaryOutput {
+ static constexpr bool value = false;
+};
+
template <typename T>
-bool compare(Operation op, T input, T libcOutput, double t);
+struct AreMatchingBinaryInputAndBinaryOutput<BinaryInput<T>, BinaryOutput<T>> {
+ static constexpr bool value = cpp::IsFloatingPointType<T>::Value;
+};
-template <typename T> class MPFRMatcher : public testing::Matcher<T> {
- static_assert(__llvm_libc::cpp::IsFloatingPointType<T>::Value,
- "MPFRMatcher can only be used with floating point values.");
+template <typename T>
+bool compareUnaryOperationSingleOutput(Operation op, T input, T libcOutput,
+ double t);
+template <typename T>
+bool compareUnaryOperationTwoOutputs(Operation op, T input,
+ const BinaryOutput<T> &libcOutput,
+ double t);
+template <typename T>
+bool compareBinaryOperationTwoOutputs(Operation op, const BinaryInput<T> &input,
+ const BinaryOutput<T> &libcOutput,
+ double t);
- Operation operation;
- T input;
- T matchValue;
+template <typename T>
+void explainUnaryOperationSingleOutputError(Operation op, T input, T matchValue,
+ testutils::StreamWrapper &OS);
+template <typename T>
+void explainUnaryOperationTwoOutputsError(Operation op, T input,
+ const BinaryOutput<T> &matchValue,
+ testutils::StreamWrapper &OS);
+template <typename T>
+void explainBinaryOperationTwoOutputsError(Operation op,
+ const BinaryInput<T> &input,
+ const BinaryOutput<T> &matchValue,
+ testutils::StreamWrapper &OS);
+
+template <Operation op, typename InputType, typename OutputType>
+class MPFRMatcher : public testing::Matcher<OutputType> {
+ InputType input;
+ OutputType matchValue;
double ulpTolerance;
public:
- MPFRMatcher(Operation op, T testInput, double ulpTolerance)
- : operation(op), input(testInput), ulpTolerance(ulpTolerance) {}
+ MPFRMatcher(InputType testInput, double ulpTolerance)
+ : input(testInput), ulpTolerance(ulpTolerance) {}
- bool match(T libcResult) {
+ bool match(OutputType libcResult) {
matchValue = libcResult;
- return internal::compare(operation, input, libcResult, ulpTolerance);
+ return match(input, matchValue, ulpTolerance);
}
- void explainError(testutils::StreamWrapper &OS) override;
+ void explainError(testutils::StreamWrapper &OS) override {
+ explainError(input, matchValue, OS);
+ }
+
+private:
+ template <typename T> static bool match(T in, T out, double tolerance) {
+ return compareUnaryOperationSingleOutput(op, in, out, tolerance);
+ }
+
+ template <typename T>
+ static bool match(T in, const BinaryOutput<T> &out, double tolerance) {
+ return compareUnaryOperationTwoOutputs(op, in, out, tolerance);
+ }
+
+ template <typename T>
+ static bool match(const BinaryInput<T> &in, T out, double tolerance) {
+ // TODO: Implement the comparision function and error reporter.
+ }
+
+ template <typename T>
+ static bool match(BinaryInput<T> in, const BinaryOutput<T> &out,
+ double tolerance) {
+ return compareBinaryOperationTwoOutputs(op, in, out, tolerance);
+ }
+
+ template <typename T>
+ static bool match(const TernaryInput<T> &in, T out, double tolerance) {
+ // TODO: Implement the comparision function and error reporter.
+ }
+
+ template <typename T>
+ static void explainError(T in, T out, testutils::StreamWrapper &OS) {
+ explainUnaryOperationSingleOutputError(op, in, out, OS);
+ }
+
+ template <typename T>
+ static void explainError(T in, const BinaryOutput<T> &out,
+ testutils::StreamWrapper &OS) {
+ explainUnaryOperationTwoOutputsError(op, in, out, OS);
+ }
+
+ template <typename T>
+ static void explainError(const BinaryInput<T> &in, const BinaryOutput<T> &out,
+ testutils::StreamWrapper &OS) {
+ explainBinaryOperationTwoOutputsError(op, in, out, OS);
+ }
};
} // namespace internal
-template <typename T, typename U>
+// Return true if the input and ouput types for the operation op are valid
+// types.
+template <Operation op, typename InputType, typename OutputType>
+constexpr bool isValidOperation() {
+ return (Operation::BeginUnaryOperationsSingleOutput < op &&
+ op < Operation::EndUnaryOperationsSingleOutput &&
+ cpp::IsSame<InputType, OutputType>::Value &&
+ cpp::IsFloatingPointType<InputType>::Value) ||
+ (Operation::BeginUnaryOperationsTwoOutputs < op &&
+ op < Operation::EndUnaryOperationsTwoOutputs &&
+ cpp::IsFloatingPointType<InputType>::Value &&
+ cpp::IsSame<OutputType, BinaryOutput<InputType>>::Value) ||
+ (Operation::BeginBinaryOperationsSingleOutput < op &&
+ op < Operation::EndBinaryOperationsSingleOutput &&
+ cpp::IsFloatingPointType<OutputType>::Value &&
+ cpp::IsSame<InputType, BinaryInput<OutputType>>::Value) ||
+ (Operation::BeginBinaryOperationsTwoOutputs < op &&
+ op < Operation::EndBinaryOperationsTwoOutputs &&
+ internal::AreMatchingBinaryInputAndBinaryOutput<InputType,
+ OutputType>::value) ||
+ (Operation::BeginTernaryOperationsSingleOuput < op &&
+ op < Operation::EndTernaryOperationsSingleOutput &&
+ cpp::IsFloatingPointType<OutputType>::Value &&
+ cpp::IsSame<InputType, TernaryInput<OutputType>>::Value);
+}
+
+template <Operation op, typename InputType, typename OutputType>
__attribute__((no_sanitize("address")))
-typename cpp::EnableIfType<cpp::IsSameV<U, double>, internal::MPFRMatcher<T>>
-getMPFRMatcher(Operation op, T input, U t) {
- static_assert(
- __llvm_libc::cpp::IsFloatingPointType<T>::Value,
- "getMPFRMatcher can only be used to match floating point results.");
- return internal::MPFRMatcher<T>(op, input, t);
+cpp::EnableIfType<isValidOperation<op, InputType, OutputType>(),
+ internal::MPFRMatcher<op, InputType, OutputType>>
+getMPFRMatcher(InputType input, OutputType outputUnused, double t) {
+ return internal::MPFRMatcher<op, InputType, OutputType>(input, t);
}
} // namespace mpfr
@@ -74,11 +226,11 @@ getMPFRMatcher(Operation op, T input, U t) {
} // namespace __llvm_libc
#define EXPECT_MPFR_MATCH(op, input, matchValue, tolerance) \
- EXPECT_THAT(matchValue, __llvm_libc::testing::mpfr::getMPFRMatcher( \
- op, input, tolerance))
+ EXPECT_THAT(matchValue, __llvm_libc::testing::mpfr::getMPFRMatcher<op>( \
+ input, matchValue, tolerance))
#define ASSERT_MPFR_MATCH(op, input, matchValue, tolerance) \
- ASSERT_THAT(matchValue, __llvm_libc::testing::mpfr::getMPFRMatcher( \
- op, input, tolerance))
+ ASSERT_THAT(matchValue, __llvm_libc::testing::mpfr::getMPFRMatcher<op>( \
+ input, matchValue, tolerance))
#endif // LLVM_LIBC_UTILS_TESTUTILS_MPFRUTILS_H