diff options
author | Siva Chandra Reddy <sivachandra@google.com> | 2020-08-20 22:36:53 -0700 |
---|---|---|
committer | Siva Chandra Reddy <sivachandra@google.com> | 2020-08-25 21:42:49 -0700 |
commit | 3f4674a5577dcc63a846d33f61e9bd95e388223d (patch) | |
tree | d683c03c11b697656d07cb4ff85e6a4e467eeb5e /libc | |
parent | 75e0b5866869ea1feb140d6f718d74c786547113 (diff) | |
download | llvm-project-3f4674a5577dcc63a846d33f61e9bd95e388223d.tar.gz |
[libc] Extend MPFRMatcher to handle multiple-input-multiple-output functions.
Tests for frexp[f|l] now use the new capability. Not all input-output
combinations have been addressed by this change. Support for newer combinations
can be added in future as needed.
Reviewed By: lntue
Differential Revision: https://reviews.llvm.org/D86506
Diffstat (limited to 'libc')
-rw-r--r-- | libc/test/src/math/CMakeLists.txt | 3 | ||||
-rw-r--r-- | libc/test/src/math/frexp_test.cpp | 25 | ||||
-rw-r--r-- | libc/test/src/math/frexpf_test.cpp | 26 | ||||
-rw-r--r-- | libc/test/src/math/frexpl_test.cpp | 12 | ||||
-rw-r--r-- | libc/utils/MPFRWrapper/MPFRUtils.cpp | 345 | ||||
-rw-r--r-- | libc/utils/MPFRWrapper/MPFRUtils.h | 200 |
6 files changed, 514 insertions, 97 deletions
diff --git a/libc/test/src/math/CMakeLists.txt b/libc/test/src/math/CMakeLists.txt index e73de5403564..2fe766a2ffc6 100644 --- a/libc/test/src/math/CMakeLists.txt +++ b/libc/test/src/math/CMakeLists.txt @@ -333,6 +333,7 @@ add_fp_unittest( add_fp_unittest( frexp_test + NEED_MPFR SUITE libc_math_unittests SRCS @@ -345,6 +346,7 @@ add_fp_unittest( add_fp_unittest( frexpf_test + NEED_MPFR SUITE libc_math_unittests SRCS @@ -357,6 +359,7 @@ add_fp_unittest( add_fp_unittest( frexpl_test + NEED_MPFR SUITE libc_math_unittests SRCS diff --git a/libc/test/src/math/frexp_test.cpp b/libc/test/src/math/frexp_test.cpp index f828d515a688..360bbf237560 100644 --- a/libc/test/src/math/frexp_test.cpp +++ b/libc/test/src/math/frexp_test.cpp @@ -11,13 +11,18 @@ #include "utils/FPUtil/BasicOperations.h" #include "utils/FPUtil/BitPatterns.h" #include "utils/FPUtil/ClassificationFunctions.h" +#include "utils/FPUtil/FPBits.h" #include "utils/FPUtil/FloatOperations.h" #include "utils/FPUtil/FloatProperties.h" +#include "utils/MPFRWrapper/MPFRUtils.h" #include "utils/UnitTest/Test.h" +using FPBits = __llvm_libc::fputil::FPBits<double>; using __llvm_libc::fputil::valueAsBits; using __llvm_libc::fputil::valueFromBits; +namespace mpfr = __llvm_libc::testing::mpfr; + using BitPatterns = __llvm_libc::fputil::BitPatterns<double>; using Properties = __llvm_libc::fputil::FloatProperties<double>; @@ -127,17 +132,19 @@ TEST(FrexpTest, SomeIntegers) { } TEST(FrexpTest, InDoubleRange) { - using BitsType = Properties::BitsType; - constexpr BitsType count = 1000000; - constexpr BitsType step = UINT64_MAX / count; - for (BitsType i = 0, v = 0; i <= count; ++i, v += step) { - double x = valueFromBits(v); + using UIntType = FPBits::UIntType; + constexpr UIntType count = 1000001; + constexpr UIntType step = UIntType(-1) / count; + for (UIntType i = 0, v = 0; i <= count; ++i, v += step) { + double x = FPBits(v); if (isnan(x) || isinf(x) || x == 0.0) continue; - int exponent; - double frac = __llvm_libc::frexp(x, &exponent); - ASSERT_TRUE(__llvm_libc::fputil::abs(frac) < 1.0); - ASSERT_TRUE(__llvm_libc::fputil::abs(frac) >= 0.5); + mpfr::BinaryOutput<double> result; + result.f = __llvm_libc::frexp(x, &result.i); + + ASSERT_TRUE(__llvm_libc::fputil::abs(result.f) < 1.0); + ASSERT_TRUE(__llvm_libc::fputil::abs(result.f) >= 0.5); + ASSERT_MPFR_MATCH(mpfr::Operation::Frexp, x, result, 0.0); } } diff --git a/libc/test/src/math/frexpf_test.cpp b/libc/test/src/math/frexpf_test.cpp index 3b82c68078ee..1bf0c36cf165 100644 --- a/libc/test/src/math/frexpf_test.cpp +++ b/libc/test/src/math/frexpf_test.cpp @@ -11,14 +11,18 @@ #include "utils/FPUtil/BasicOperations.h" #include "utils/FPUtil/BitPatterns.h" #include "utils/FPUtil/ClassificationFunctions.h" +#include "utils/FPUtil/FPBits.h" #include "utils/FPUtil/FloatOperations.h" #include "utils/FPUtil/FloatProperties.h" #include "utils/MPFRWrapper/MPFRUtils.h" #include "utils/UnitTest/Test.h" +using FPBits = __llvm_libc::fputil::FPBits<float>; using __llvm_libc::fputil::valueAsBits; using __llvm_libc::fputil::valueFromBits; +namespace mpfr = __llvm_libc::testing::mpfr; + using BitPatterns = __llvm_libc::fputil::BitPatterns<float>; using Properties = __llvm_libc::fputil::FloatProperties<float>; @@ -109,7 +113,7 @@ TEST(FrexpfTest, PowersOfTwo) { EXPECT_EQ(exponent, 7); } -TEST(FrexpTest, SomeIntegers) { +TEST(FrexpfTest, SomeIntegers) { int exponent; EXPECT_EQ(valueAsBits(0.75f), @@ -135,17 +139,19 @@ TEST(FrexpTest, SomeIntegers) { } TEST(FrexpfTest, InFloatRange) { - using BitsType = Properties::BitsType; - constexpr BitsType count = 1000000; - constexpr BitsType step = UINT32_MAX / count; - for (BitsType i = 0, v = 0; i <= count; ++i, v += step) { - float x = valueFromBits(v); + using UIntType = FPBits::UIntType; + constexpr UIntType count = 1000001; + constexpr UIntType step = UIntType(-1) / count; + for (UIntType i = 0, v = 0; i <= count; ++i, v += step) { + float x = FPBits(v); if (isnan(x) || isinf(x) || x == 0.0) continue; - int exponent; - float frac = __llvm_libc::frexpf(x, &exponent); - ASSERT_TRUE(__llvm_libc::fputil::abs(frac) < 1.0f); - ASSERT_TRUE(__llvm_libc::fputil::abs(frac) >= 0.5f); + mpfr::BinaryOutput<float> result; + result.f = __llvm_libc::frexpf(x, &result.i); + + ASSERT_TRUE(__llvm_libc::fputil::abs(result.f) < 1.0); + ASSERT_TRUE(__llvm_libc::fputil::abs(result.f) >= 0.5); + ASSERT_MPFR_MATCH(mpfr::Operation::Frexp, x, result, 0.0); } } diff --git a/libc/test/src/math/frexpl_test.cpp b/libc/test/src/math/frexpl_test.cpp index ace445f0a2de..9846bb84ae27 100644 --- a/libc/test/src/math/frexpl_test.cpp +++ b/libc/test/src/math/frexpl_test.cpp @@ -10,10 +10,13 @@ #include "src/math/frexpl.h" #include "utils/FPUtil/BasicOperations.h" #include "utils/FPUtil/FPBits.h" +#include "utils/MPFRWrapper/MPFRUtils.h" #include "utils/UnitTest/Test.h" using FPBits = __llvm_libc::fputil::FPBits<long double>; +namespace mpfr = __llvm_libc::testing::mpfr; + TEST(FrexplTest, SpecialNumbers) { int exponent; @@ -94,10 +97,11 @@ TEST(FrexplTest, LongDoubleRange) { if (isnan(x) || isinf(x) || x == 0.0l) continue; - int exponent; - long double frac = __llvm_libc::frexpl(x, &exponent); + mpfr::BinaryOutput<long double> result; + result.f = __llvm_libc::frexpl(x, &result.i); - ASSERT_TRUE(__llvm_libc::fputil::abs(frac) < 1.0l); - ASSERT_TRUE(__llvm_libc::fputil::abs(frac) >= 0.5l); + ASSERT_TRUE(__llvm_libc::fputil::abs(result.f) < 1.0); + ASSERT_TRUE(__llvm_libc::fputil::abs(result.f) >= 0.5); + ASSERT_MPFR_MATCH(mpfr::Operation::Frexp, x, result, 0.0); } } diff --git a/libc/utils/MPFRWrapper/MPFRUtils.cpp b/libc/utils/MPFRWrapper/MPFRUtils.cpp index a3abfce08bf3..86882d05cc39 100644 --- a/libc/utils/MPFRWrapper/MPFRUtils.cpp +++ b/libc/utils/MPFRWrapper/MPFRUtils.cpp @@ -14,6 +14,7 @@ #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" +#include <memory> #include <mpfr.h> #include <stdint.h> #include <string> @@ -65,50 +66,90 @@ public: mpfr_set_sj(value, x, MPFR_RNDN); } - template <typename XType, - cpp::EnableIfType<cpp::IsFloatingPointType<XType>::Value, int> = 0> - MPFRNumber(Operation op, XType rawValue) { - mpfr_init2(value, mpfrPrecision); - MPFRNumber mpfrInput(rawValue); - switch (op) { - case Operation::Abs: - mpfr_abs(value, mpfrInput.value, MPFR_RNDN); - break; - case Operation::Ceil: - mpfr_ceil(value, mpfrInput.value); - break; - case Operation::Cos: - mpfr_cos(value, mpfrInput.value, MPFR_RNDN); - break; - case Operation::Exp: - mpfr_exp(value, mpfrInput.value, MPFR_RNDN); - break; - case Operation::Exp2: - mpfr_exp2(value, mpfrInput.value, MPFR_RNDN); - break; - case Operation::Floor: - mpfr_floor(value, mpfrInput.value); - break; - case Operation::Round: - mpfr_round(value, mpfrInput.value); - break; - case Operation::Sin: - mpfr_sin(value, mpfrInput.value, MPFR_RNDN); - break; - case Operation::Sqrt: - mpfr_sqrt(value, mpfrInput.value, MPFR_RNDN); - break; - case Operation::Trunc: - mpfr_trunc(value, mpfrInput.value); - break; - } - } - MPFRNumber(const MPFRNumber &other) { mpfr_set(value, other.value, MPFR_RNDN); } - ~MPFRNumber() { mpfr_clear(value); } + MPFRNumber &operator=(const MPFRNumber &rhs) { + mpfr_set(value, rhs.value, MPFR_RNDN); + return *this; + } + + MPFRNumber abs() const { + MPFRNumber result; + mpfr_abs(result.value, value, MPFR_RNDN); + return result; + } + + MPFRNumber ceil() const { + MPFRNumber result; + mpfr_ceil(result.value, value); + return result; + } + + MPFRNumber cos() const { + MPFRNumber result; + mpfr_cos(result.value, value, MPFR_RNDN); + return result; + } + + MPFRNumber exp() const { + MPFRNumber result; + mpfr_exp(result.value, value, MPFR_RNDN); + return result; + } + + MPFRNumber exp2() const { + MPFRNumber result; + mpfr_exp2(result.value, value, MPFR_RNDN); + return result; + } + + MPFRNumber floor() const { + MPFRNumber result; + mpfr_floor(result.value, value); + return result; + } + + MPFRNumber frexp(int &exp) { + MPFRNumber result; + mpfr_exp_t resultExp; + mpfr_frexp(&resultExp, result.value, value, MPFR_RNDN); + exp = resultExp; + return result; + } + + MPFRNumber remquo(const MPFRNumber &divisor, int "ient) { + MPFRNumber remainder; + long q; + mpfr_remquo(remainder.value, &q, value, divisor.value, MPFR_RNDN); + quotient = q; + return remainder; + } + + MPFRNumber round() const { + MPFRNumber result; + mpfr_round(result.value, value); + return result; + } + + MPFRNumber sin() const { + MPFRNumber result; + mpfr_sin(result.value, value, MPFR_RNDN); + return result; + } + + MPFRNumber sqrt() const { + MPFRNumber result; + mpfr_sqrt(result.value, value, MPFR_RNDN); + return result; + } + + MPFRNumber trunc() const { + MPFRNumber result; + mpfr_trunc(result.value, value); + return result; + } std::string str() const { // 200 bytes should be more than sufficient to hold a 100-digit number @@ -179,10 +220,65 @@ public: namespace internal { +template <typename InputType> +cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber> +unaryOperation(Operation op, InputType input) { + MPFRNumber mpfrInput(input); + switch (op) { + case Operation::Abs: + return mpfrInput.abs(); + case Operation::Ceil: + return mpfrInput.ceil(); + case Operation::Cos: + return mpfrInput.cos(); + case Operation::Exp: + return mpfrInput.exp(); + case Operation::Exp2: + return mpfrInput.exp2(); + case Operation::Floor: + return mpfrInput.floor(); + case Operation::Round: + return mpfrInput.round(); + case Operation::Sin: + return mpfrInput.sin(); + case Operation::Sqrt: + return mpfrInput.sqrt(); + case Operation::Trunc: + return mpfrInput.trunc(); + default: + __builtin_unreachable(); + } +} + +template <typename InputType> +cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber> +unaryOperationTwoOutputs(Operation op, InputType input, int &output) { + MPFRNumber mpfrInput(input); + switch (op) { + case Operation::Frexp: + return mpfrInput.frexp(output); + default: + __builtin_unreachable(); + } +} + +template <typename InputType> +cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber> +binaryOperationTwoOutputs(Operation op, InputType x, InputType y, int &output) { + MPFRNumber inputX(x), inputY(y); + switch (op) { + case Operation::RemQuo: + return inputX.remquo(inputY, output); + default: + __builtin_unreachable(); + } +} + template <typename T> -void MPFRMatcher<T>::explainError(testutils::StreamWrapper &OS) { - MPFRNumber mpfrResult(operation, input); +void explainUnaryOperationSingleOutputError(Operation op, T input, T matchValue, + testutils::StreamWrapper &OS) { MPFRNumber mpfrInput(input); + MPFRNumber mpfrResult = unaryOperation(op, input); MPFRNumber mpfrMatchValue(matchValue); FPBits<T> inputBits(input); FPBits<T> matchBits(matchValue); @@ -201,25 +297,174 @@ void MPFRMatcher<T>::explainError(testutils::StreamWrapper &OS) { << '\n'; } -template void MPFRMatcher<float>::explainError(testutils::StreamWrapper &); -template void MPFRMatcher<double>::explainError(testutils::StreamWrapper &); template void -MPFRMatcher<long double>::explainError(testutils::StreamWrapper &); +explainUnaryOperationSingleOutputError<float>(Operation op, float, float, + testutils::StreamWrapper &); +template void +explainUnaryOperationSingleOutputError<double>(Operation op, double, double, + testutils::StreamWrapper &); +template void explainUnaryOperationSingleOutputError<long double>( + Operation op, long double, long double, testutils::StreamWrapper &); + +template <typename T> +void explainUnaryOperationTwoOutputsError(Operation op, T input, + const BinaryOutput<T> &libcResult, + testutils::StreamWrapper &OS) { + MPFRNumber mpfrInput(input); + FPBits<T> inputBits(input); + int mpfrIntResult; + MPFRNumber mpfrResult = unaryOperationTwoOutputs(op, input, mpfrIntResult); + + if (mpfrIntResult != libcResult.i) { + OS << "MPFR integral result: " << mpfrIntResult << '\n' + << "Libc integral result: " << libcResult.i << '\n'; + } else { + OS << "Integral result from libc matches integral result from MPFR.\n"; + } + + MPFRNumber mpfrMatchValue(libcResult.f); + OS << "Libc floating point result is not within tolerance value of the MPFR " + << "result.\n\n"; + + OS << " Input decimal: " << mpfrInput.str() << "\n\n"; + + OS << "Libc floating point value: " << mpfrMatchValue.str() << '\n'; + __llvm_libc::fputil::testing::describeValue( + " Libc floating point bits: ", libcResult.f, OS); + OS << "\n\n"; + + OS << " MPFR result: " << mpfrResult.str() << '\n'; + __llvm_libc::fputil::testing::describeValue( + " MPFR rounded: ", mpfrResult.as<T>(), OS); + OS << '\n' + << " ULP error: " + << std::to_string(mpfrResult.ulp(libcResult.f)) << '\n'; +} + +template void explainUnaryOperationTwoOutputsError<float>( + Operation, float, const BinaryOutput<float> &, testutils::StreamWrapper &); +template void +explainUnaryOperationTwoOutputsError<double>(Operation, double, + const BinaryOutput<double> &, + testutils::StreamWrapper &); +template void explainUnaryOperationTwoOutputsError<long double>( + Operation, long double, const BinaryOutput<long double> &, + testutils::StreamWrapper &); template <typename T> -bool compare(Operation op, T input, T libcResult, double ulpError) { +void explainBinaryOperationTwoOutputsError(Operation op, + const BinaryInput<T> &input, + const BinaryOutput<T> &libcResult, + testutils::StreamWrapper &OS) { + MPFRNumber mpfrX(input.x); + MPFRNumber mpfrY(input.y); + FPBits<T> xbits(input.x); + FPBits<T> ybits(input.y); + int mpfrIntResult; + MPFRNumber mpfrResult = + binaryOperationTwoOutputs(op, input.x, input.y, mpfrIntResult); + MPFRNumber mpfrMatchValue(libcResult.f); + + OS << "Input decimal: x: " << mpfrX.str() << " y: " << mpfrY.str() << '\n' + << "MPFR integral result: " << mpfrIntResult << '\n' + << "Libc integral result: " << libcResult.i << '\n' + << "Libc floating point result: " << mpfrMatchValue.str() << '\n' + << " MPFR result: " << mpfrResult.str() << '\n'; + __llvm_libc::fputil::testing::describeValue( + "Libc floating point result bits: ", libcResult.f, OS); + __llvm_libc::fputil::testing::describeValue( + " MPFR rounded bits: ", mpfrResult.as<T>(), OS); + OS << "ULP error: " << std::to_string(mpfrResult.ulp(libcResult.f)) << '\n'; +} + +template void explainBinaryOperationTwoOutputsError<float>( + Operation, const BinaryInput<float> &, const BinaryOutput<float> &, + testutils::StreamWrapper &); +template void explainBinaryOperationTwoOutputsError<double>( + Operation, const BinaryInput<double> &, const BinaryOutput<double> &, + testutils::StreamWrapper &); +template void explainBinaryOperationTwoOutputsError<long double>( + Operation, const BinaryInput<long double> &, + const BinaryOutput<long double> &, testutils::StreamWrapper &); + +template <typename T> +bool compareUnaryOperationSingleOutput(Operation op, T input, T libcResult, + double ulpError) { // If the ulp error is exactly 0.5 (i.e a tie), we would check that the result // is rounded to the nearest even. - MPFRNumber mpfrResult(op, input); + MPFRNumber mpfrResult = unaryOperation(op, input); double ulp = mpfrResult.ulp(libcResult); bool bitsAreEven = ((FPBits<T>(libcResult).bitsAsUInt() & 1) == 0); return (ulp < ulpError) || ((ulp == ulpError) && ((ulp != 0.5) || bitsAreEven)); } -template bool compare<float>(Operation, float, float, double); -template bool compare<double>(Operation, double, double, double); -template bool compare<long double>(Operation, long double, long double, double); +template bool compareUnaryOperationSingleOutput<float>(Operation, float, float, + double); +template bool compareUnaryOperationSingleOutput<double>(Operation, double, + double, double); +template bool compareUnaryOperationSingleOutput<long double>(Operation, + long double, + long double, + double); + +template <typename T> +bool compareUnaryOperationTwoOutputs(Operation op, T input, + const BinaryOutput<T> &libcResult, + double ulpError) { + int mpfrIntResult; + MPFRNumber mpfrResult = unaryOperationTwoOutputs(op, input, mpfrIntResult); + double ulp = mpfrResult.ulp(libcResult.f); + + if (mpfrIntResult != libcResult.i) + return false; + + bool bitsAreEven = ((FPBits<T>(libcResult.f).bitsAsUInt() & 1) == 0); + return (ulp < ulpError) || + ((ulp == ulpError) && ((ulp != 0.5) || bitsAreEven)); +} + +template bool +compareUnaryOperationTwoOutputs<float>(Operation, float, + const BinaryOutput<float> &, double); +template bool +compareUnaryOperationTwoOutputs<double>(Operation, double, + const BinaryOutput<double> &, double); +template bool compareUnaryOperationTwoOutputs<long double>( + Operation, long double, const BinaryOutput<long double> &, double); + +template <typename T> +bool compareBinaryOperationTwoOutputs(Operation op, const BinaryInput<T> &input, + const BinaryOutput<T> &libcResult, + double ulpError) { + int mpfrIntResult; + MPFRNumber mpfrResult = + binaryOperationTwoOutputs(op, input.x, input.y, mpfrIntResult); + double ulp = mpfrResult.ulp(libcResult.f); + + if (mpfrIntResult != libcResult.i) { + if (op == Operation::RemQuo) { + if ((0x7 & mpfrIntResult) != libcResult.i) + return false; + } else { + return false; + } + } + + bool bitsAreEven = ((FPBits<T>(libcResult.f).bitsAsUInt() & 1) == 0); + return (ulp < ulpError) || + ((ulp == ulpError) && ((ulp != 0.5) || bitsAreEven)); +} + +template bool +compareBinaryOperationTwoOutputs<float>(Operation, const BinaryInput<float> &, + const BinaryOutput<float> &, double); +template bool +compareBinaryOperationTwoOutputs<double>(Operation, const BinaryInput<double> &, + const BinaryOutput<double> &, double); +template bool compareBinaryOperationTwoOutputs<long double>( + Operation, const BinaryInput<long double> &, + const BinaryOutput<long double> &, double); } // namespace internal diff --git a/libc/utils/MPFRWrapper/MPFRUtils.h b/libc/utils/MPFRWrapper/MPFRUtils.h index 3d94079e65d8..b46f09dd5e55 100644 --- a/libc/utils/MPFRWrapper/MPFRUtils.h +++ b/libc/utils/MPFRWrapper/MPFRUtils.h @@ -19,6 +19,10 @@ namespace testing { namespace mpfr { enum class Operation : int { + // Operations with take a single floating point number as input + // and produce a single floating point number as output. The input + // and output floating point numbers are of the same kind. + BeginUnaryOperationsSingleOutput, Abs, Ceil, Cos, @@ -28,45 +32,193 @@ enum class Operation : int { Round, Sin, Sqrt, - Trunc + Trunc, + EndUnaryOperationsSingleOutput, + + // Operations which take a single floating point nubmer as input + // but produce two outputs. The first ouput is a floating point + // number of the same type as the input. The second output is of type + // 'int'. + BeginUnaryOperationsTwoOutputs, + Frexp, // Floating point output, the first output, is the fractional part. + EndUnaryOperationsTwoOutputs, + + // Operations wich take two floating point nubmers of the same type as + // input and produce a single floating point number of the same type as + // output. + BeginBinaryOperationsSingleOutput, + // TODO: Add operations like hypot. + EndBinaryOperationsSingleOutput, + + // Operations which take two floating point numbers of the same type as + // input and produce two outputs. The first output is a floating nubmer of + // the same type as the inputs. The second output is af type 'int'. + BeginBinaryOperationsTwoOutputs, + RemQuo, // The first output, the floating point output, is the remainder. + EndBinaryOperationsTwoOutputs, + + BeginTernaryOperationsSingleOuput, + // TODO: Add operations like fma. + EndTernaryOperationsSingleOutput, +}; + +template <typename T> struct BinaryInput { + static_assert( + __llvm_libc::cpp::IsFloatingPointType<T>::Value, + "Template parameter of BinaryInput must be a floating point type."); + + using Type = T; + T x, y; +}; + +template <typename T> struct TernaryInput { + static_assert( + __llvm_libc::cpp::IsFloatingPointType<T>::Value, + "Template parameter of TernaryInput must be a floating point type."); + + using Type = T; + T x, y, z; +}; + +template <typename T> struct BinaryOutput { + T f; + int i; }; namespace internal { +template <typename T1, typename T2> +struct AreMatchingBinaryInputAndBinaryOutput { + static constexpr bool value = false; +}; + template <typename T> -bool compare(Operation op, T input, T libcOutput, double t); +struct AreMatchingBinaryInputAndBinaryOutput<BinaryInput<T>, BinaryOutput<T>> { + static constexpr bool value = cpp::IsFloatingPointType<T>::Value; +}; -template <typename T> class MPFRMatcher : public testing::Matcher<T> { - static_assert(__llvm_libc::cpp::IsFloatingPointType<T>::Value, - "MPFRMatcher can only be used with floating point values."); +template <typename T> +bool compareUnaryOperationSingleOutput(Operation op, T input, T libcOutput, + double t); +template <typename T> +bool compareUnaryOperationTwoOutputs(Operation op, T input, + const BinaryOutput<T> &libcOutput, + double t); +template <typename T> +bool compareBinaryOperationTwoOutputs(Operation op, const BinaryInput<T> &input, + const BinaryOutput<T> &libcOutput, + double t); - Operation operation; - T input; - T matchValue; +template <typename T> +void explainUnaryOperationSingleOutputError(Operation op, T input, T matchValue, + testutils::StreamWrapper &OS); +template <typename T> +void explainUnaryOperationTwoOutputsError(Operation op, T input, + const BinaryOutput<T> &matchValue, + testutils::StreamWrapper &OS); +template <typename T> +void explainBinaryOperationTwoOutputsError(Operation op, + const BinaryInput<T> &input, + const BinaryOutput<T> &matchValue, + testutils::StreamWrapper &OS); + +template <Operation op, typename InputType, typename OutputType> +class MPFRMatcher : public testing::Matcher<OutputType> { + InputType input; + OutputType matchValue; double ulpTolerance; public: - MPFRMatcher(Operation op, T testInput, double ulpTolerance) - : operation(op), input(testInput), ulpTolerance(ulpTolerance) {} + MPFRMatcher(InputType testInput, double ulpTolerance) + : input(testInput), ulpTolerance(ulpTolerance) {} - bool match(T libcResult) { + bool match(OutputType libcResult) { matchValue = libcResult; - return internal::compare(operation, input, libcResult, ulpTolerance); + return match(input, matchValue, ulpTolerance); } - void explainError(testutils::StreamWrapper &OS) override; + void explainError(testutils::StreamWrapper &OS) override { + explainError(input, matchValue, OS); + } + +private: + template <typename T> static bool match(T in, T out, double tolerance) { + return compareUnaryOperationSingleOutput(op, in, out, tolerance); + } + + template <typename T> + static bool match(T in, const BinaryOutput<T> &out, double tolerance) { + return compareUnaryOperationTwoOutputs(op, in, out, tolerance); + } + + template <typename T> + static bool match(const BinaryInput<T> &in, T out, double tolerance) { + // TODO: Implement the comparision function and error reporter. + } + + template <typename T> + static bool match(BinaryInput<T> in, const BinaryOutput<T> &out, + double tolerance) { + return compareBinaryOperationTwoOutputs(op, in, out, tolerance); + } + + template <typename T> + static bool match(const TernaryInput<T> &in, T out, double tolerance) { + // TODO: Implement the comparision function and error reporter. + } + + template <typename T> + static void explainError(T in, T out, testutils::StreamWrapper &OS) { + explainUnaryOperationSingleOutputError(op, in, out, OS); + } + + template <typename T> + static void explainError(T in, const BinaryOutput<T> &out, + testutils::StreamWrapper &OS) { + explainUnaryOperationTwoOutputsError(op, in, out, OS); + } + + template <typename T> + static void explainError(const BinaryInput<T> &in, const BinaryOutput<T> &out, + testutils::StreamWrapper &OS) { + explainBinaryOperationTwoOutputsError(op, in, out, OS); + } }; } // namespace internal -template <typename T, typename U> +// Return true if the input and ouput types for the operation op are valid +// types. +template <Operation op, typename InputType, typename OutputType> +constexpr bool isValidOperation() { + return (Operation::BeginUnaryOperationsSingleOutput < op && + op < Operation::EndUnaryOperationsSingleOutput && + cpp::IsSame<InputType, OutputType>::Value && + cpp::IsFloatingPointType<InputType>::Value) || + (Operation::BeginUnaryOperationsTwoOutputs < op && + op < Operation::EndUnaryOperationsTwoOutputs && + cpp::IsFloatingPointType<InputType>::Value && + cpp::IsSame<OutputType, BinaryOutput<InputType>>::Value) || + (Operation::BeginBinaryOperationsSingleOutput < op && + op < Operation::EndBinaryOperationsSingleOutput && + cpp::IsFloatingPointType<OutputType>::Value && + cpp::IsSame<InputType, BinaryInput<OutputType>>::Value) || + (Operation::BeginBinaryOperationsTwoOutputs < op && + op < Operation::EndBinaryOperationsTwoOutputs && + internal::AreMatchingBinaryInputAndBinaryOutput<InputType, + OutputType>::value) || + (Operation::BeginTernaryOperationsSingleOuput < op && + op < Operation::EndTernaryOperationsSingleOutput && + cpp::IsFloatingPointType<OutputType>::Value && + cpp::IsSame<InputType, TernaryInput<OutputType>>::Value); +} + +template <Operation op, typename InputType, typename OutputType> __attribute__((no_sanitize("address"))) -typename cpp::EnableIfType<cpp::IsSameV<U, double>, internal::MPFRMatcher<T>> -getMPFRMatcher(Operation op, T input, U t) { - static_assert( - __llvm_libc::cpp::IsFloatingPointType<T>::Value, - "getMPFRMatcher can only be used to match floating point results."); - return internal::MPFRMatcher<T>(op, input, t); +cpp::EnableIfType<isValidOperation<op, InputType, OutputType>(), + internal::MPFRMatcher<op, InputType, OutputType>> +getMPFRMatcher(InputType input, OutputType outputUnused, double t) { + return internal::MPFRMatcher<op, InputType, OutputType>(input, t); } } // namespace mpfr @@ -74,11 +226,11 @@ getMPFRMatcher(Operation op, T input, U t) { } // namespace __llvm_libc #define EXPECT_MPFR_MATCH(op, input, matchValue, tolerance) \ - EXPECT_THAT(matchValue, __llvm_libc::testing::mpfr::getMPFRMatcher( \ - op, input, tolerance)) + EXPECT_THAT(matchValue, __llvm_libc::testing::mpfr::getMPFRMatcher<op>( \ + input, matchValue, tolerance)) #define ASSERT_MPFR_MATCH(op, input, matchValue, tolerance) \ - ASSERT_THAT(matchValue, __llvm_libc::testing::mpfr::getMPFRMatcher( \ - op, input, tolerance)) + ASSERT_THAT(matchValue, __llvm_libc::testing::mpfr::getMPFRMatcher<op>( \ + input, matchValue, tolerance)) #endif // LLVM_LIBC_UTILS_TESTUTILS_MPFRUTILS_H |