diff options
Diffstat (limited to 'absl/random/bit_gen_ref.h')
-rw-r--r-- | absl/random/bit_gen_ref.h | 108 |
1 files changed, 40 insertions, 68 deletions
diff --git a/absl/random/bit_gen_ref.h b/absl/random/bit_gen_ref.h index 9555460f..e8771162 100644 --- a/absl/random/bit_gen_ref.h +++ b/absl/random/bit_gen_ref.h @@ -24,11 +24,11 @@ #ifndef ABSL_RANDOM_BIT_GEN_REF_H_ #define ABSL_RANDOM_BIT_GEN_REF_H_ -#include "absl/base/internal/fast_type_id.h" #include "absl/base/macros.h" #include "absl/meta/type_traits.h" #include "absl/random/internal/distribution_caller.h" #include "absl/random/internal/fast_uniform_bits.h" +#include "absl/random/internal/mocking_bit_gen_base.h" namespace absl { ABSL_NAMESPACE_BEGIN @@ -51,10 +51,6 @@ struct is_urbg< typename std::decay<decltype(std::declval<URBG>()())>::type>::value>> : std::true_type {}; -template <typename> -struct DistributionCaller; -class MockHelpers; - } // namespace random_internal // ----------------------------------------------------------------------------- @@ -81,50 +77,23 @@ class MockHelpers; // } // class BitGenRef { - // SFINAE to detect whether the URBG type includes a member matching - // bool InvokeMock(base_internal::FastTypeIdType, void*, void*). - // - // These live inside BitGenRef so that they have friend access - // to MockingBitGen. (see similar methods in DistributionCaller). - template <template <class...> class Trait, class AlwaysVoid, class... Args> - struct detector : std::false_type {}; - template <template <class...> class Trait, class... Args> - struct detector<Trait, absl::void_t<Trait<Args...>>, Args...> - : std::true_type {}; - - template <class T> - using invoke_mock_t = decltype(std::declval<T*>()->InvokeMock( - std::declval<base_internal::FastTypeIdType>(), std::declval<void*>(), - std::declval<void*>())); - - template <typename T> - using HasInvokeMock = typename detector<invoke_mock_t, void, T>::type; - public: - BitGenRef(const BitGenRef&) = default; - BitGenRef(BitGenRef&&) = default; - BitGenRef& operator=(const BitGenRef&) = default; - BitGenRef& operator=(BitGenRef&&) = default; - - template <typename URBG, typename absl::enable_if_t< - (!std::is_same<URBG, BitGenRef>::value && - random_internal::is_urbg<URBG>::value && - !HasInvokeMock<URBG>::value)>* = nullptr> - BitGenRef(URBG& gen) // NOLINT - : t_erased_gen_ptr_(reinterpret_cast<uintptr_t>(&gen)), - mock_call_(NotAMock), - generate_impl_fn_(ImplFn<URBG>) {} + using result_type = uint64_t; + + BitGenRef(const absl::BitGenRef&) = default; + BitGenRef(absl::BitGenRef&&) = default; + BitGenRef& operator=(const absl::BitGenRef&) = default; + BitGenRef& operator=(absl::BitGenRef&&) = default; template <typename URBG, - typename absl::enable_if_t<(!std::is_same<URBG, BitGenRef>::value && - random_internal::is_urbg<URBG>::value && - HasInvokeMock<URBG>::value)>* = nullptr> + typename absl::enable_if_t< + (!std::is_same<URBG, BitGenRef>::value && + random_internal::is_urbg<URBG>::value)>* = nullptr> BitGenRef(URBG& gen) // NOLINT - : t_erased_gen_ptr_(reinterpret_cast<uintptr_t>(&gen)), - mock_call_(&MockCall<URBG>), - generate_impl_fn_(ImplFn<URBG>) {} - - using result_type = uint64_t; + : mocked_gen_ptr_(MakeMockPointer(&gen)), + t_erased_gen_ptr_(reinterpret_cast<uintptr_t>(&gen)), + generate_impl_fn_(ImplFn<URBG>) { + } static constexpr result_type(min)() { return (std::numeric_limits<result_type>::min)(); @@ -137,9 +106,14 @@ class BitGenRef { result_type operator()() { return generate_impl_fn_(t_erased_gen_ptr_); } private: + friend struct absl::random_internal::DistributionCaller<absl::BitGenRef>; using impl_fn = result_type (*)(uintptr_t); - using mock_call_fn = bool (*)(uintptr_t, base_internal::FastTypeIdType, void*, - void*); + using mocker_base_t = absl::random_internal::MockingBitGenBase; + + // Convert an arbitrary URBG pointer into either a valid mocker_base_t + // pointer or a nullptr. + static inline mocker_base_t* MakeMockPointer(mocker_base_t* t) { return t; } + static inline mocker_base_t* MakeMockPointer(void*) { return nullptr; } template <typename URBG> static result_type ImplFn(uintptr_t ptr) { @@ -149,32 +123,30 @@ class BitGenRef { return fast_uniform_bits(*reinterpret_cast<URBG*>(ptr)); } - // Get a type-erased InvokeMock pointer. - template <typename URBG> - static bool MockCall(uintptr_t gen_ptr, base_internal::FastTypeIdType type, - void* result, void* arg_tuple) { - return reinterpret_cast<URBG*>(gen_ptr)->InvokeMock(type, result, - arg_tuple); - } - static bool NotAMock(uintptr_t, base_internal::FastTypeIdType, void*, void*) { - return false; - } - - inline bool InvokeMock(base_internal::FastTypeIdType type, void* args_tuple, - void* result) { - if (mock_call_ == NotAMock) return false; // avoids an indirect call. - return mock_call_(t_erased_gen_ptr_, type, args_tuple, result); - } - + mocker_base_t* mocked_gen_ptr_; uintptr_t t_erased_gen_ptr_; - mock_call_fn mock_call_; impl_fn generate_impl_fn_; +}; + +namespace random_internal { - template <typename> - friend struct ::absl::random_internal::DistributionCaller; // for InvokeMock - friend class ::absl::random_internal::MockHelpers; // for InvokeMock +template <> +struct DistributionCaller<absl::BitGenRef> { + template <typename DistrT, typename FormatT, typename... Args> + static typename DistrT::result_type Call(absl::BitGenRef* gen_ref, + Args&&... args) { + auto* mock_ptr = gen_ref->mocked_gen_ptr_; + if (mock_ptr == nullptr) { + DistrT dist(std::forward<Args>(args)...); + return dist(*gen_ref); + } else { + return mock_ptr->template Call<DistrT, FormatT>( + std::forward<Args>(args)...); + } + } }; +} // namespace random_internal ABSL_NAMESPACE_END } // namespace absl |