diff options
author | Yi Kong <yikong@google.com> | 2022-02-25 16:41:05 +0000 |
---|---|---|
committer | Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com> | 2022-02-25 16:41:05 +0000 |
commit | bc0f5df265caa21a2120c22453655a7fcc941991 (patch) | |
tree | fb979fb4cf4f8052c8cc66b1ec9516d91fcd859b /unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsFunctors.h | |
parent | 8fd413e275f78a4c240f1442ce5cf77c73a20a55 (diff) | |
parent | 7cb50013986f04dce5fac87bebf319bb8db37a36 (diff) | |
download | eigen-android13-d4-release.tar.gz |
Merge changes Iee153445,Iee274471 am: 79df15ea88 am: 10f298fc41 am: 7cb5001398t_frc_odp_330442040t_frc_odp_330442000t_frc_ase_330444010android-wear-13.0.0-gpl_r3android-wear-13.0.0-gpl_r2android-wear-13.0.0-gpl_r1android-vts-13.0_r8android-vts-13.0_r7android-vts-13.0_r6android-vts-13.0_r5android-vts-13.0_r4android-vts-13.0_r3android-vts-13.0_r2android-t-qpr3-beta-3-gplandroid-t-qpr3-beta-1-gplandroid-t-qpr2-beta-3-gplandroid-t-qpr2-beta-2-gplandroid-t-qpr1-beta-3-gplandroid-t-qpr1-beta-1-gplandroid-cts-13.0_r8android-cts-13.0_r7android-cts-13.0_r6android-cts-13.0_r5android-cts-13.0_r4android-cts-13.0_r3android-cts-13.0_r2android-13.0.0_r83android-13.0.0_r82android-13.0.0_r81android-13.0.0_r80android-13.0.0_r79android-13.0.0_r78android-13.0.0_r77android-13.0.0_r76android-13.0.0_r75android-13.0.0_r74android-13.0.0_r73android-13.0.0_r72android-13.0.0_r71android-13.0.0_r70android-13.0.0_r69android-13.0.0_r68android-13.0.0_r67android-13.0.0_r66android-13.0.0_r65android-13.0.0_r64android-13.0.0_r63android-13.0.0_r62android-13.0.0_r61android-13.0.0_r60android-13.0.0_r59android-13.0.0_r58android-13.0.0_r57android-13.0.0_r56android-13.0.0_r55android-13.0.0_r54android-13.0.0_r53android-13.0.0_r52android-13.0.0_r51android-13.0.0_r50android-13.0.0_r49android-13.0.0_r48android-13.0.0_r47android-13.0.0_r46android-13.0.0_r45android-13.0.0_r44android-13.0.0_r43android-13.0.0_r42android-13.0.0_r41android-13.0.0_r40android-13.0.0_r39android-13.0.0_r38android-13.0.0_r37android-13.0.0_r36android-13.0.0_r35android-13.0.0_r34android-13.0.0_r33android-13.0.0_r32android-13.0.0_r30android-13.0.0_r29android-13.0.0_r28android-13.0.0_r27android-13.0.0_r24android-13.0.0_r23android-13.0.0_r22android-13.0.0_r21android-13.0.0_r20android-13.0.0_r19android-13.0.0_r18android-13.0.0_r17android-13.0.0_r16aml_go_odp_330912000aml_go_ads_330915100aml_go_ads_330915000aml_go_ads_330913000android13-tests-releaseandroid13-tests-devandroid13-qpr3-s9-releaseandroid13-qpr3-s8-releaseandroid13-qpr3-s7-releaseandroid13-qpr3-s6-releaseandroid13-qpr3-s5-releaseandroid13-qpr3-s4-releaseandroid13-qpr3-s3-releaseandroid13-qpr3-s2-releaseandroid13-qpr3-s14-releaseandroid13-qpr3-s13-releaseandroid13-qpr3-s12-releaseandroid13-qpr3-s11-releaseandroid13-qpr3-s10-releaseandroid13-qpr3-s1-releaseandroid13-qpr3-releaseandroid13-qpr3-c-s8-releaseandroid13-qpr3-c-s7-releaseandroid13-qpr3-c-s6-releaseandroid13-qpr3-c-s5-releaseandroid13-qpr3-c-s4-releaseandroid13-qpr3-c-s3-releaseandroid13-qpr3-c-s2-releaseandroid13-qpr3-c-s12-releaseandroid13-qpr3-c-s11-releaseandroid13-qpr3-c-s10-releaseandroid13-qpr3-c-s1-releaseandroid13-qpr2-s9-releaseandroid13-qpr2-s8-releaseandroid13-qpr2-s7-releaseandroid13-qpr2-s6-releaseandroid13-qpr2-s5-releaseandroid13-qpr2-s3-releaseandroid13-qpr2-s2-releaseandroid13-qpr2-s12-releaseandroid13-qpr2-s11-releaseandroid13-qpr2-s10-releaseandroid13-qpr2-s1-releaseandroid13-qpr2-releaseandroid13-qpr2-b-s1-releaseandroid13-qpr1-s8-releaseandroid13-qpr1-s7-releaseandroid13-qpr1-s6-releaseandroid13-qpr1-s5-releaseandroid13-qpr1-s4-releaseandroid13-qpr1-s3-releaseandroid13-qpr1-s2-releaseandroid13-qpr1-s1-releaseandroid13-qpr1-releaseandroid13-mainline-go-adservices-releaseandroid13-frc-odp-releaseandroid13-devandroid13-d4-s2-releaseandroid13-d4-s1-releaseandroid13-d4-releaseandroid13-d3-s1-releaseandroid13-d2-releaseandroid-wear-13.0.0-gpl_r1
Original change: https://android-review.googlesource.com/c/platform/external/eigen/+/1999079
Change-Id: I4c76dc5ddc7fb0ae9fc42436f28bd8bf9de50a97
Diffstat (limited to 'unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsFunctors.h')
-rw-r--r-- | unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsFunctors.h | 140 |
1 files changed, 117 insertions, 23 deletions
diff --git a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsFunctors.h b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsFunctors.h index d8f2363be..abefe99b7 100644 --- a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsFunctors.h +++ b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsFunctors.h @@ -41,6 +41,60 @@ struct functor_traits<scalar_igamma_op<Scalar> > { }; }; +/** \internal + * \brief Template functor to compute the derivative of the incomplete gamma + * function igamma_der_a(a, x) + * + * \sa class CwiseBinaryOp, Cwise::igamma_der_a + */ +template <typename Scalar> +struct scalar_igamma_der_a_op { + EIGEN_EMPTY_STRUCT_CTOR(scalar_igamma_der_a_op) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& a, const Scalar& x) const { + using numext::igamma_der_a; + return igamma_der_a(a, x); + } + template <typename Packet> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a, const Packet& x) const { + return internal::pigamma_der_a(a, x); + } +}; +template <typename Scalar> +struct functor_traits<scalar_igamma_der_a_op<Scalar> > { + enum { + // 2x the cost of igamma + Cost = 40 * NumTraits<Scalar>::MulCost + 20 * NumTraits<Scalar>::AddCost, + PacketAccess = packet_traits<Scalar>::HasIGammaDerA + }; +}; + +/** \internal + * \brief Template functor to compute the derivative of the sample + * of a Gamma(alpha, 1) random variable with respect to the parameter alpha + * gamma_sample_der_alpha(alpha, sample) + * + * \sa class CwiseBinaryOp, Cwise::gamma_sample_der_alpha + */ +template <typename Scalar> +struct scalar_gamma_sample_der_alpha_op { + EIGEN_EMPTY_STRUCT_CTOR(scalar_gamma_sample_der_alpha_op) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& alpha, const Scalar& sample) const { + using numext::gamma_sample_der_alpha; + return gamma_sample_der_alpha(alpha, sample); + } + template <typename Packet> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& alpha, const Packet& sample) const { + return internal::pgamma_sample_der_alpha(alpha, sample); + } +}; +template <typename Scalar> +struct functor_traits<scalar_gamma_sample_der_alpha_op<Scalar> > { + enum { + // 2x the cost of igamma, minus the lgamma cost (the lgamma cancels out) + Cost = 30 * NumTraits<Scalar>::MulCost + 15 * NumTraits<Scalar>::AddCost, + PacketAccess = packet_traits<Scalar>::HasGammaSampleDerAlpha + }; +}; /** \internal * \brief Template functor to compute the complementary incomplete gamma function igammac(a, x) @@ -101,11 +155,11 @@ struct functor_traits<scalar_betainc_op<Scalar> > { */ template<typename Scalar> struct scalar_lgamma_op { EIGEN_EMPTY_STRUCT_CTOR(scalar_lgamma_op) - EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& a) const { using numext::lgamma; return lgamma(a); } typedef typename packet_traits<Scalar>::type Packet; - EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::plgamma(a); } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a) const { return internal::plgamma(a); } }; template<typename Scalar> struct functor_traits<scalar_lgamma_op<Scalar> > @@ -123,11 +177,11 @@ struct functor_traits<scalar_lgamma_op<Scalar> > */ template<typename Scalar> struct scalar_digamma_op { EIGEN_EMPTY_STRUCT_CTOR(scalar_digamma_op) - EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& a) const { using numext::digamma; return digamma(a); } typedef typename packet_traits<Scalar>::type Packet; - EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::pdigamma(a); } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a) const { return internal::pdigamma(a); } }; template<typename Scalar> struct functor_traits<scalar_digamma_op<Scalar> > @@ -145,11 +199,11 @@ struct functor_traits<scalar_digamma_op<Scalar> > */ template<typename Scalar> struct scalar_zeta_op { EIGEN_EMPTY_STRUCT_CTOR(scalar_zeta_op) - EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& x, const Scalar& q) const { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& x, const Scalar& q) const { using numext::zeta; return zeta(x, q); } typedef typename packet_traits<Scalar>::type Packet; - EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& x, const Packet& q) const { return internal::pzeta(x, q); } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x, const Packet& q) const { return internal::pzeta(x, q); } }; template<typename Scalar> struct functor_traits<scalar_zeta_op<Scalar> > @@ -167,11 +221,11 @@ struct functor_traits<scalar_zeta_op<Scalar> > */ template<typename Scalar> struct scalar_polygamma_op { EIGEN_EMPTY_STRUCT_CTOR(scalar_polygamma_op) - EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& n, const Scalar& x) const { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& n, const Scalar& x) const { using numext::polygamma; return polygamma(n, x); } typedef typename packet_traits<Scalar>::type Packet; - EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& n, const Packet& x) const { return internal::ppolygamma(n, x); } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& n, const Packet& x) const { return internal::ppolygamma(n, x); } }; template<typename Scalar> struct functor_traits<scalar_polygamma_op<Scalar> > @@ -184,25 +238,40 @@ struct functor_traits<scalar_polygamma_op<Scalar> > }; /** \internal - * \brief Template functor to compute the Gauss error function of a - * scalar - * \sa class CwiseUnaryOp, Cwise::erf() + * \brief Template functor to compute the error function of a scalar + * \sa class CwiseUnaryOp, ArrayBase::erf() */ template<typename Scalar> struct scalar_erf_op { EIGEN_EMPTY_STRUCT_CTOR(scalar_erf_op) - EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const { - using numext::erf; return erf(a); + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar + operator()(const Scalar& a) const { + return numext::erf(a); + } + template <typename Packet> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const { + return perf(x); } - typedef typename packet_traits<Scalar>::type Packet; - EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::perf(a); } }; -template<typename Scalar> -struct functor_traits<scalar_erf_op<Scalar> > -{ +template <typename Scalar> +struct functor_traits<scalar_erf_op<Scalar> > { enum { - // Guesstimate - Cost = 10 * NumTraits<Scalar>::MulCost + 5 * NumTraits<Scalar>::AddCost, - PacketAccess = packet_traits<Scalar>::HasErf + PacketAccess = packet_traits<Scalar>::HasErf, + Cost = + (PacketAccess +#ifdef EIGEN_VECTORIZE_FMA + // TODO(rmlarsen): Move the FMA cost model to a central location. + // Haswell can issue 2 add/mul/madd per cycle. + // 10 pmadd, 2 pmul, 1 div, 2 other + ? (2 * NumTraits<Scalar>::AddCost + + 7 * NumTraits<Scalar>::MulCost + + scalar_div_cost<Scalar, packet_traits<Scalar>::HasDiv>::value) +#else + ? (12 * NumTraits<Scalar>::AddCost + + 12 * NumTraits<Scalar>::MulCost + + scalar_div_cost<Scalar, packet_traits<Scalar>::HasDiv>::value) +#endif + // Assume for simplicity that this is as expensive as an exp(). + : (functor_traits<scalar_exp_op<Scalar> >::Cost)) }; }; @@ -213,11 +282,11 @@ struct functor_traits<scalar_erf_op<Scalar> > */ template<typename Scalar> struct scalar_erfc_op { EIGEN_EMPTY_STRUCT_CTOR(scalar_erfc_op) - EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& a) const { using numext::erfc; return erfc(a); } typedef typename packet_traits<Scalar>::type Packet; - EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::perfc(a); } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a) const { return internal::perfc(a); } }; template<typename Scalar> struct functor_traits<scalar_erfc_op<Scalar> > @@ -229,6 +298,31 @@ struct functor_traits<scalar_erfc_op<Scalar> > }; }; +/** \internal + * \brief Template functor to compute the Inverse of the normal distribution + * function of a scalar + * \sa class CwiseUnaryOp, Cwise::ndtri() + */ +template<typename Scalar> struct scalar_ndtri_op { + EIGEN_EMPTY_STRUCT_CTOR(scalar_ndtri_op) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& a) const { + using numext::ndtri; return ndtri(a); + } + typedef typename packet_traits<Scalar>::type Packet; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a) const { return internal::pndtri(a); } +}; +template<typename Scalar> +struct functor_traits<scalar_ndtri_op<Scalar> > +{ + enum { + // On average, We are evaluating rational functions with degree N=9 in the + // numerator and denominator. This results in 2*N additions and 2*N + // multiplications. + Cost = 18 * NumTraits<Scalar>::MulCost + 18 * NumTraits<Scalar>::AddCost, + PacketAccess = packet_traits<Scalar>::HasNdtri + }; +}; + } // end namespace internal } // end namespace Eigen |