diff options
author | Jeff Vander Stoep <jeffv@google.com> | 2023-02-03 11:31:30 +0000 |
---|---|---|
committer | Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com> | 2023-02-03 11:31:30 +0000 |
commit | 170fa82e98838b99e720386b52f725d3a500b19a (patch) | |
tree | 3e81bb8b1cdf5eccc346a10d1f1a3ae8cdcc787a | |
parent | 6e1f8fc5fb5533109ff4bc245d3c74af69241fba (diff) | |
parent | c451789e02a0d2c9f2ec4df721ddd8b2ababa345 (diff) | |
download | half-170fa82e98838b99e720386b52f725d3a500b19a.tar.gz |
Upgrade half to 2.2.1 am: dc2e05e117 am: 75287a0360 am: c451789e02
Original change: https://android-review.googlesource.com/c/platform/external/rust/crates/half/+/2419359
Change-Id: Iecc1d694a24dfb561f460aa9bddde8a6a2a3db59
Signed-off-by: Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>
-rw-r--r-- | .cargo_vcs_info.json | 2 | ||||
-rw-r--r-- | Android.bp | 2 | ||||
-rw-r--r-- | CHANGELOG.md | 24 | ||||
-rw-r--r-- | Cargo.toml | 5 | ||||
-rw-r--r-- | Cargo.toml.orig | 4 | ||||
-rw-r--r-- | METADATA | 10 | ||||
-rw-r--r-- | src/bfloat.rs | 109 | ||||
-rw-r--r-- | src/bfloat/convert.rs | 4 | ||||
-rw-r--r-- | src/binary16.rs | 109 | ||||
-rw-r--r-- | src/binary16/convert.rs | 333 | ||||
-rw-r--r-- | src/lib.rs | 18 | ||||
-rw-r--r-- | src/slice.rs | 76 |
12 files changed, 567 insertions, 129 deletions
diff --git a/.cargo_vcs_info.json b/.cargo_vcs_info.json index c58fed4..31d400d 100644 --- a/.cargo_vcs_info.json +++ b/.cargo_vcs_info.json @@ -1,6 +1,6 @@ { "git": { - "sha1": "ebc0b1fb4c8a21fc15b99be3d44c0ad4621cdd93" + "sha1": "4c56dab175e41faaa367bae9eba82c555f002859" }, "path_in_vcs": "" }
\ No newline at end of file @@ -42,7 +42,7 @@ rust_library { host_supported: true, crate_name: "half", cargo_env_compat: true, - cargo_pkg_version: "2.1.0", + cargo_pkg_version: "2.2.1", srcs: ["src/lib.rs"], edition: "2021", features: [ diff --git a/CHANGELOG.md b/CHANGELOG.md index ba163b9..2e72b9c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,23 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ## [Unreleased] +## [2.2.1] - 2023-01-08 <a name="2.2.1"></a> +### Changed +- Reduced unnecessary bounds checks for SIMD operations on slices. By [@Shnatsel]. +- Further slice conversion optimizations for slices. Resolves [#66]. + +## [2.2.0] - 2022-12-30 <a name="2.2.0"></a> +### Added +- Add `serialize_as_f32` and `serialize_as_string` functions when `serde` cargo feature is enabled. + They allow customizing the serialization by using + `#[serde(serialize_with="f16::serialize_as_f32")]` attribute in serde derive macros. Closes [#60]. +- Deserialize now supports deserializing from `f32`, `f64`, and string values in addition to its + previous default deserialization. Closes [#60]. + +### Changed +- Add `#[inline]` on fallback functions, which improved conversion execution on non-nightly rust + by up to 50%. By [@Shnatsel]. + ## [2.1.0] - 2022-07-18 <a name="2.1.0"></a> ### Added - Add support for target_arch `spirv`. Some traits and functions are unavailable on this @@ -257,6 +274,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. [#37]: https://github.com/starkat99/half-rs/issues/37 [#48]: https://github.com/starkat99/half-rs/issues/48 [#55]: https://github.com/starkat99/half-rs/issues/55 +[#60]: https://github.com/starkat99/half-rs/issues/60 +[#66]: https://github.com/starkat99/half-rs/issues/66 [@tspiteri]: https://github.com/tspiteri [@PSeitz]: https://github.com/PSeitz @@ -271,9 +290,12 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. 
[@kali]: https://github.com/kali [@Nilstrieb]: https://github.com/Nilstrieb [@joseluis]: https://github.com/joseluis +[@Shnatsel]: https://github.com/Shnatsel -[Unreleased]: https://github.com/starkat99/half-rs/compare/v2.1.0...HEAD +[Unreleased]: https://github.com/starkat99/half-rs/compare/v2.2.1...HEAD +[2.2.1]: https://github.com/starkat99/half-rs/compare/v2.2.0...v2.2.1 +[2.2.0]: https://github.com/starkat99/half-rs/compare/v2.1.0...v2.2.0 [2.1.0]: https://github.com/starkat99/half-rs/compare/v2.0.0...v2.1.0 [2.0.0]: https://github.com/starkat99/half-rs/compare/v1.8.2...v2.0.0 [1.8.2]: https://github.com/starkat99/half-rs/compare/v1.8.1...v1.8.2 @@ -13,7 +13,7 @@ edition = "2021" rust-version = "1.58" name = "half" -version = "2.1.0" +version = "2.2.1" authors = ["Kathryn Long <squeeself@gmail.com>"] exclude = [ ".git*", @@ -33,7 +33,6 @@ categories = [ ] license = "MIT OR Apache-2.0" repository = "https://github.com/starkat99/half-rs" -resolver = "2" [package.metadata.docs.rs] rustc-args = [ @@ -76,7 +75,7 @@ optional = true default-features = false [dev-dependencies.criterion] -version = "0.3.5" +version = "0.4.0" [dev-dependencies.crunchy] version = "0.2.2" diff --git a/Cargo.toml.orig b/Cargo.toml.orig index afd8fcc..e26c032 100644 --- a/Cargo.toml.orig +++ b/Cargo.toml.orig @@ -1,7 +1,7 @@ [package] name = "half" # Remember to keep in sync with html_root_url crate attribute -version = "2.1.0" +version = "2.2.1" authors = ["Kathryn Long <squeeself@gmail.com>"] description = "Half-precision floating point f16 and bf16 types for Rust implementing the IEEE 754-2008 standard binary16 and bfloat16 types." 
repository = "https://github.com/starkat99/half-rs" @@ -33,7 +33,7 @@ zerocopy = { version = "0.6.0", default-features = false, optional = true } crunchy = "0.2.2" [dev-dependencies] -criterion = "0.3.5" +criterion = "0.4.0" quickcheck = "1.0" quickcheck_macros = "1.0" rand = "0.8.4" @@ -11,13 +11,13 @@ third_party { } url { type: ARCHIVE - value: "https://static.crates.io/crates/half/half-2.1.0.crate" + value: "https://static.crates.io/crates/half/half-2.2.1.crate" } - version: "2.1.0" + version: "2.2.1" license_type: NOTICE last_upgrade_date { - year: 2022 - month: 12 - day: 12 + year: 2023 + month: 2 + day: 2 } } diff --git a/src/bfloat.rs b/src/bfloat.rs index bf4f2b7..8b23863 100644 --- a/src/bfloat.rs +++ b/src/bfloat.rs @@ -36,7 +36,7 @@ pub(crate) mod convert; #[allow(non_camel_case_types)] #[derive(Clone, Copy, Default)] #[repr(transparent)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde", derive(Serialize))] #[cfg_attr(feature = "bytemuck", derive(Zeroable, Pod))] #[cfg_attr(feature = "zerocopy", derive(AsBytes, FromBytes))] pub struct bf16(u16); @@ -642,6 +642,61 @@ impl bf16 { left.cmp(&right) } + /// Alternate serialize adapter for serializing as a float. + /// + /// By default, [`bf16`] serializes as a newtype of [`u16`]. This is an alternate serialize + /// implementation that serializes as an [`f32`] value. It is designed for use with + /// `serialize_with` serde attributes. Deserialization from `f32` values is already supported by + /// the default deserialize implementation. 
+ /// + /// # Examples + /// + /// A demonstration on how to use this adapter: + /// + /// ``` + /// use serde::{Serialize, Deserialize}; + /// use half::bf16; + /// + /// #[derive(Serialize, Deserialize)] + /// struct MyStruct { + /// #[serde(serialize_with = "bf16::serialize_as_f32")] + /// value: bf16 // Will be serialized as f32 instead of u16 + /// } + /// ``` + #[cfg(feature = "serde")] + pub fn serialize_as_f32<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { + serializer.serialize_f32(self.to_f32()) + } + + /// Alternate serialize adapter for serializing as a string. + /// + /// By default, [`bf16`] serializes as a newtype of [`u16`]. This is an alternate serialize + /// implementation that serializes as a string value. It is designed for use with + /// `serialize_with` serde attributes. Deserialization from string values is already supported + /// by the default deserialize implementation. + /// + /// # Examples + /// + /// A demonstration on how to use this adapter: + /// + /// ``` + /// use serde::{Serialize, Deserialize}; + /// use half::bf16; + /// + /// #[derive(Serialize, Deserialize)] + /// struct MyStruct { + /// #[serde(serialize_with = "bf16::serialize_as_string")] + /// value: bf16 // Will be serialized as a string instead of u16 + /// } + /// ``` + #[cfg(feature = "serde")] + pub fn serialize_as_string<S: serde::Serializer>( + &self, + serializer: S, + ) -> Result<S::Ok, S::Error> { + serializer.serialize_str(&self.to_string()) + } + /// Approximate number of [`bf16`] significant digits in base 10 pub const DIGITS: u32 = 2; /// [`bf16`] @@ -1209,6 +1264,58 @@ impl<'a> Sum<&'a bf16> for bf16 { } } +#[cfg(feature = "serde")] +struct Visitor; + +#[cfg(feature = "serde")] +impl<'de> Deserialize<'de> for bf16 { + fn deserialize<D>(deserializer: D) -> Result<bf16, D::Error> + where + D: serde::de::Deserializer<'de>, + { + deserializer.deserialize_newtype_struct("bf16", Visitor) + } +} + +#[cfg(feature = "serde")] +impl<'de> 
serde::de::Visitor<'de> for Visitor { + type Value = bf16; + + fn expecting(&self, formatter: &mut alloc::fmt::Formatter) -> alloc::fmt::Result { + write!(formatter, "tuple struct bf16") + } + + fn visit_newtype_struct<D>(self, deserializer: D) -> Result<Self::Value, D::Error> + where + D: serde::Deserializer<'de>, + { + Ok(bf16(<u16 as Deserialize>::deserialize(deserializer)?)) + } + + fn visit_str<E>(self, v: &str) -> Result<Self::Value, E> + where + E: serde::de::Error, + { + v.parse().map_err(|_| { + serde::de::Error::invalid_value(serde::de::Unexpected::Str(v), &"a float string") + }) + } + + fn visit_f32<E>(self, v: f32) -> Result<Self::Value, E> + where + E: serde::de::Error, + { + Ok(bf16::from_f32(v)) + } + + fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E> + where + E: serde::de::Error, + { + Ok(bf16::from_f64(v)) + } +} + #[allow( clippy::cognitive_complexity, clippy::float_cmp, diff --git a/src/bfloat/convert.rs b/src/bfloat/convert.rs index b612a9b..8f258f5 100644 --- a/src/bfloat/convert.rs +++ b/src/bfloat/convert.rs @@ -1,6 +1,7 @@ use crate::leading_zeros::leading_zeros_u16; use core::mem; +#[inline] pub(crate) const fn f32_to_bf16(value: f32) -> u16 { // TODO: Replace mem::transmute with to_bits() once to_bits is const-stabilized // Convert to raw bytes @@ -21,6 +22,7 @@ pub(crate) const fn f32_to_bf16(value: f32) -> u16 { } } +#[inline] pub(crate) const fn f64_to_bf16(value: f64) -> u16 { // TODO: Replace mem::transmute with to_bits() once to_bits is const-stabilized // Convert to raw bytes, truncating the last 32-bits of mantissa; that precision will always @@ -88,6 +90,7 @@ pub(crate) const fn f64_to_bf16(value: f64) -> u16 { } } +#[inline] pub(crate) const fn bf16_to_f32(i: u16) -> f32 { // TODO: Replace mem::transmute with from_bits() once from_bits is const-stabilized // If NaN, keep current mantissa but also set most significiant mantissa bit @@ -98,6 +101,7 @@ pub(crate) const fn bf16_to_f32(i: u16) -> f32 { } } +#[inline] 
pub(crate) const fn bf16_to_f64(i: u16) -> f64 { // TODO: Replace mem::transmute with from_bits() once from_bits is const-stabilized // Check for signed zero diff --git a/src/binary16.rs b/src/binary16.rs index 08dbf04..b622f01 100644 --- a/src/binary16.rs +++ b/src/binary16.rs @@ -34,7 +34,7 @@ pub(crate) mod convert; #[allow(non_camel_case_types)] #[derive(Clone, Copy, Default)] #[repr(transparent)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde", derive(Serialize))] #[cfg_attr(feature = "bytemuck", derive(Zeroable, Pod))] #[cfg_attr(feature = "zerocopy", derive(AsBytes, FromBytes))] pub struct f16(u16); @@ -651,6 +651,61 @@ impl f16 { left.cmp(&right) } + /// Alternate serialize adapter for serializing as a float. + /// + /// By default, [`f16`] serializes as a newtype of [`u16`]. This is an alternate serialize + /// implementation that serializes as an [`f32`] value. It is designed for use with + /// `serialize_with` serde attributes. Deserialization from `f32` values is already supported by + /// the default deserialize implementation. + /// + /// # Examples + /// + /// A demonstration on how to use this adapter: + /// + /// ``` + /// use serde::{Serialize, Deserialize}; + /// use half::f16; + /// + /// #[derive(Serialize, Deserialize)] + /// struct MyStruct { + /// #[serde(serialize_with = "f16::serialize_as_f32")] + /// value: f16 // Will be serialized as f32 instead of u16 + /// } + /// ``` + #[cfg(feature = "serde")] + pub fn serialize_as_f32<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { + serializer.serialize_f32(self.to_f32()) + } + + /// Alternate serialize adapter for serializing as a string. + /// + /// By default, [`f16`] serializes as a newtype of [`u16`]. This is an alternate serialize + /// implementation that serializes as a string value. It is designed for use with + /// `serialize_with` serde attributes. 
Deserialization from string values is already supported + /// by the default deserialize implementation. + /// + /// # Examples + /// + /// A demonstration on how to use this adapter: + /// + /// ``` + /// use serde::{Serialize, Deserialize}; + /// use half::f16; + /// + /// #[derive(Serialize, Deserialize)] + /// struct MyStruct { + /// #[serde(serialize_with = "f16::serialize_as_string")] + /// value: f16 // Will be serialized as a string instead of u16 + /// } + /// ``` + #[cfg(feature = "serde")] + pub fn serialize_as_string<S: serde::Serializer>( + &self, + serializer: S, + ) -> Result<S::Ok, S::Error> { + serializer.serialize_str(&self.to_string()) + } + /// Approximate number of [`f16`] significant digits in base 10 pub const DIGITS: u32 = 3; /// [`f16`] @@ -1224,6 +1279,58 @@ impl<'a> Sum<&'a f16> for f16 { } } +#[cfg(feature = "serde")] +struct Visitor; + +#[cfg(feature = "serde")] +impl<'de> Deserialize<'de> for f16 { + fn deserialize<D>(deserializer: D) -> Result<f16, D::Error> + where + D: serde::de::Deserializer<'de>, + { + deserializer.deserialize_newtype_struct("f16", Visitor) + } +} + +#[cfg(feature = "serde")] +impl<'de> serde::de::Visitor<'de> for Visitor { + type Value = f16; + + fn expecting(&self, formatter: &mut alloc::fmt::Formatter) -> alloc::fmt::Result { + write!(formatter, "tuple struct f16") + } + + fn visit_newtype_struct<D>(self, deserializer: D) -> Result<Self::Value, D::Error> + where + D: serde::Deserializer<'de>, + { + Ok(f16(<u16 as Deserialize>::deserialize(deserializer)?)) + } + + fn visit_str<E>(self, v: &str) -> Result<Self::Value, E> + where + E: serde::de::Error, + { + v.parse().map_err(|_| { + serde::de::Error::invalid_value(serde::de::Unexpected::Str(v), &"a float string") + }) + } + + fn visit_f32<E>(self, v: f32) -> Result<Self::Value, E> + where + E: serde::de::Error, + { + Ok(f16::from_f32(v)) + } + + fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E> + where + E: serde::de::Error, + { + Ok(f16::from_f64(v)) + } 
+} + #[allow( clippy::cognitive_complexity, clippy::float_cmp, diff --git a/src/binary16/convert.rs b/src/binary16/convert.rs index dc1772a..b96910f 100644 --- a/src/binary16/convert.rs +++ b/src/binary16/convert.rs @@ -3,11 +3,11 @@ use crate::leading_zeros::leading_zeros_u16; use core::mem; macro_rules! convert_fn { - (fn $name:ident($var:ident : $vartype:ty) -> $restype:ty { + (fn $name:ident($($var:ident : $vartype:ty),+) -> $restype:ty { if feature("f16c") { $f16c:expr } else { $fallback:expr }}) => { #[inline] - pub(crate) fn $name($var: $vartype) -> $restype { + pub(crate) fn $name($($var: $vartype),+) -> $restype { // Use CPU feature detection if using std #[cfg(all( feature = "use-intrinsics", @@ -84,11 +84,8 @@ convert_fn! { } } -// TODO: While SIMD versions are faster, further improvements can be made by doing runtime feature -// detection once at beginning of convert slice method, rather than per chunk - convert_fn! { - fn f32x4_to_f16x4(f: &[f32]) -> [u16; 4] { + fn f32x4_to_f16x4(f: &[f32; 4]) -> [u16; 4] { if feature("f16c") { unsafe { x86::f32x4_to_f16x4_x86_f16c(f) } } else { @@ -98,7 +95,7 @@ convert_fn! { } convert_fn! { - fn f16x4_to_f32x4(i: &[u16]) -> [f32; 4] { + fn f16x4_to_f32x4(i: &[u16; 4]) -> [f32; 4] { if feature("f16c") { unsafe { x86::f16x4_to_f32x4_x86_f16c(i) } } else { @@ -108,7 +105,7 @@ convert_fn! { } convert_fn! { - fn f64x4_to_f16x4(f: &[f64]) -> [u16; 4] { + fn f64x4_to_f16x4(f: &[f64; 4]) -> [u16; 4] { if feature("f16c") { unsafe { x86::f64x4_to_f16x4_x86_f16c(f) } } else { @@ -118,7 +115,7 @@ convert_fn! { } convert_fn! { - fn f16x4_to_f64x4(i: &[u16]) -> [f64; 4] { + fn f16x4_to_f64x4(i: &[u16; 4]) -> [f64; 4] { if feature("f16c") { unsafe { x86::f16x4_to_f64x4_x86_f16c(i) } } else { @@ -127,6 +124,155 @@ convert_fn! { } } +convert_fn! { + fn f32x8_to_f16x8(f: &[f32; 8]) -> [u16; 8] { + if feature("f16c") { + unsafe { x86::f32x8_to_f16x8_x86_f16c(f) } + } else { + f32x8_to_f16x8_fallback(f) + } + } +} + +convert_fn! 
{ + fn f16x8_to_f32x8(i: &[u16; 8]) -> [f32; 8] { + if feature("f16c") { + unsafe { x86::f16x8_to_f32x8_x86_f16c(i) } + } else { + f16x8_to_f32x8_fallback(i) + } + } +} + +convert_fn! { + fn f64x8_to_f16x8(f: &[f64; 8]) -> [u16; 8] { + if feature("f16c") { + unsafe { x86::f64x8_to_f16x8_x86_f16c(f) } + } else { + f64x8_to_f16x8_fallback(f) + } + } +} + +convert_fn! { + fn f16x8_to_f64x8(i: &[u16; 8]) -> [f64; 8] { + if feature("f16c") { + unsafe { x86::f16x8_to_f64x8_x86_f16c(i) } + } else { + f16x8_to_f64x8_fallback(i) + } + } +} + +convert_fn! { + fn f32_to_f16_slice(src: &[f32], dst: &mut [u16]) -> () { + if feature("f16c") { + convert_chunked_slice_8(src, dst, x86::f32x8_to_f16x8_x86_f16c, + x86::f32x4_to_f16x4_x86_f16c) + } else { + slice_fallback(src, dst, f32_to_f16_fallback) + } + } +} + +convert_fn! { + fn f16_to_f32_slice(src: &[u16], dst: &mut [f32]) -> () { + if feature("f16c") { + convert_chunked_slice_8(src, dst, x86::f16x8_to_f32x8_x86_f16c, + x86::f16x4_to_f32x4_x86_f16c) + } else { + slice_fallback(src, dst, f16_to_f32_fallback) + } + } +} + +convert_fn! { + fn f64_to_f16_slice(src: &[f64], dst: &mut [u16]) -> () { + if feature("f16c") { + convert_chunked_slice_8(src, dst, x86::f64x8_to_f16x8_x86_f16c, + x86::f64x4_to_f16x4_x86_f16c) + } else { + slice_fallback(src, dst, f64_to_f16_fallback) + } + } +} + +convert_fn! 
{ + fn f16_to_f64_slice(src: &[u16], dst: &mut [f64]) -> () { + if feature("f16c") { + convert_chunked_slice_8(src, dst, x86::f16x8_to_f64x8_x86_f16c, + x86::f16x4_to_f64x4_x86_f16c) + } else { + slice_fallback(src, dst, f16_to_f64_fallback) + } + } +} + +/// Chunks sliced into x8 or x4 arrays +#[inline] +fn convert_chunked_slice_8<S: Copy + Default, D: Copy>( + src: &[S], + dst: &mut [D], + fn8: unsafe fn(&[S; 8]) -> [D; 8], + fn4: unsafe fn(&[S; 4]) -> [D; 4], +) { + assert_eq!(src.len(), dst.len()); + + // TODO: Can be further optimized with array_chunks when it becomes stabilized + + let src_chunks = src.chunks_exact(8); + let mut dst_chunks = dst.chunks_exact_mut(8); + let src_remainder = src_chunks.remainder(); + for (s, d) in src_chunks.zip(&mut dst_chunks) { + let chunk: &[S; 8] = s.try_into().unwrap(); + d.copy_from_slice(unsafe { &fn8(chunk) }); + } + + // Process remainder + if src_remainder.len() > 4 { + let mut buf: [S; 8] = Default::default(); + buf[..src_remainder.len()].copy_from_slice(src_remainder); + let vec = unsafe { fn8(&buf) }; + let dst_remainder = dst_chunks.into_remainder(); + dst_remainder.copy_from_slice(&vec[..dst_remainder.len()]); + } else if !src_remainder.is_empty() { + let mut buf: [S; 4] = Default::default(); + buf[..src_remainder.len()].copy_from_slice(src_remainder); + let vec = unsafe { fn4(&buf) }; + let dst_remainder = dst_chunks.into_remainder(); + dst_remainder.copy_from_slice(&vec[..dst_remainder.len()]); + } +} + +/// Chunks sliced into x4 arrays +#[inline] +fn convert_chunked_slice_4<S: Copy + Default, D: Copy>( + src: &[S], + dst: &mut [D], + f: unsafe fn(&[S; 4]) -> [D; 4], +) { + assert_eq!(src.len(), dst.len()); + + // TODO: Can be further optimized with array_chunks when it becomes stabilized + + let src_chunks = src.chunks_exact(4); + let mut dst_chunks = dst.chunks_exact_mut(4); + let src_remainder = src_chunks.remainder(); + for (s, d) in src_chunks.zip(&mut dst_chunks) { + let chunk: &[S; 4] = 
s.try_into().unwrap(); + d.copy_from_slice(unsafe { &f(chunk) }); + } + + // Process remainder + if !src_remainder.is_empty() { + let mut buf: [S; 4] = Default::default(); + buf[..src_remainder.len()].copy_from_slice(src_remainder); + let vec = unsafe { f(&buf) }; + let dst_remainder = dst_chunks.into_remainder(); + dst_remainder.copy_from_slice(&vec[..dst_remainder.len()]); + } +} + /////////////// Fallbacks //////////////// // In the below functions, round to nearest, with ties to even. @@ -143,6 +289,7 @@ convert_fn! { // which can be simplified into // (mantissa & round_bit) != 0 && (mantissa & (3 * round_bit - 1)) != 0 +#[inline] pub(crate) const fn f32_to_f16_fallback(value: f32) -> u16 { // TODO: Replace mem::transmute with to_bits() once to_bits is const-stabilized // Convert to raw bytes @@ -203,6 +350,7 @@ pub(crate) const fn f32_to_f16_fallback(value: f32) -> u16 { } } +#[inline] pub(crate) const fn f64_to_f16_fallback(value: f64) -> u16 { // Convert to raw bytes, truncating the last 32-bits of mantissa; that precision will always // be lost on half-precision. 
@@ -270,6 +418,7 @@ pub(crate) const fn f64_to_f16_fallback(value: f64) -> u16 { } } +#[inline] pub(crate) const fn f16_to_f32_fallback(i: u16) -> f32 { // Check for signed zero // TODO: Replace mem::transmute with from_bits() once from_bits is const-stabilized @@ -316,6 +465,7 @@ pub(crate) const fn f16_to_f32_fallback(i: u16) -> f32 { unsafe { mem::transmute(sign | exp | man) } } +#[inline] pub(crate) const fn f16_to_f64_fallback(i: u16) -> f64 { // Check for signed zero // TODO: Replace mem::transmute with from_bits() once from_bits is const-stabilized @@ -363,9 +513,7 @@ pub(crate) const fn f16_to_f64_fallback(i: u16) -> f64 { } #[inline] -fn f16x4_to_f32x4_fallback(v: &[u16]) -> [f32; 4] { - debug_assert!(v.len() >= 4); - +fn f16x4_to_f32x4_fallback(v: &[u16; 4]) -> [f32; 4] { [ f16_to_f32_fallback(v[0]), f16_to_f32_fallback(v[1]), @@ -375,9 +523,7 @@ fn f16x4_to_f32x4_fallback(v: &[u16]) -> [f32; 4] { } #[inline] -fn f32x4_to_f16x4_fallback(v: &[f32]) -> [u16; 4] { - debug_assert!(v.len() >= 4); - +fn f32x4_to_f16x4_fallback(v: &[f32; 4]) -> [u16; 4] { [ f32_to_f16_fallback(v[0]), f32_to_f16_fallback(v[1]), @@ -387,9 +533,7 @@ fn f32x4_to_f16x4_fallback(v: &[f32]) -> [u16; 4] { } #[inline] -fn f16x4_to_f64x4_fallback(v: &[u16]) -> [f64; 4] { - debug_assert!(v.len() >= 4); - +fn f16x4_to_f64x4_fallback(v: &[u16; 4]) -> [f64; 4] { [ f16_to_f64_fallback(v[0]), f16_to_f64_fallback(v[1]), @@ -399,17 +543,79 @@ fn f16x4_to_f64x4_fallback(v: &[u16]) -> [f64; 4] { } #[inline] -fn f64x4_to_f16x4_fallback(v: &[f64]) -> [u16; 4] { - debug_assert!(v.len() >= 4); +fn f64x4_to_f16x4_fallback(v: &[f64; 4]) -> [u16; 4] { + [ + f64_to_f16_fallback(v[0]), + f64_to_f16_fallback(v[1]), + f64_to_f16_fallback(v[2]), + f64_to_f16_fallback(v[3]), + ] +} + +#[inline] +fn f16x8_to_f32x8_fallback(v: &[u16; 8]) -> [f32; 8] { + [ + f16_to_f32_fallback(v[0]), + f16_to_f32_fallback(v[1]), + f16_to_f32_fallback(v[2]), + f16_to_f32_fallback(v[3]), + f16_to_f32_fallback(v[4]), + 
f16_to_f32_fallback(v[5]), + f16_to_f32_fallback(v[6]), + f16_to_f32_fallback(v[7]), + ] +} + +#[inline] +fn f32x8_to_f16x8_fallback(v: &[f32; 8]) -> [u16; 8] { + [ + f32_to_f16_fallback(v[0]), + f32_to_f16_fallback(v[1]), + f32_to_f16_fallback(v[2]), + f32_to_f16_fallback(v[3]), + f32_to_f16_fallback(v[4]), + f32_to_f16_fallback(v[5]), + f32_to_f16_fallback(v[6]), + f32_to_f16_fallback(v[7]), + ] +} +#[inline] +fn f16x8_to_f64x8_fallback(v: &[u16; 8]) -> [f64; 8] { + [ + f16_to_f64_fallback(v[0]), + f16_to_f64_fallback(v[1]), + f16_to_f64_fallback(v[2]), + f16_to_f64_fallback(v[3]), + f16_to_f64_fallback(v[4]), + f16_to_f64_fallback(v[5]), + f16_to_f64_fallback(v[6]), + f16_to_f64_fallback(v[7]), + ] +} + +#[inline] +fn f64x8_to_f16x8_fallback(v: &[f64; 8]) -> [u16; 8] { [ f64_to_f16_fallback(v[0]), f64_to_f16_fallback(v[1]), f64_to_f16_fallback(v[2]), f64_to_f16_fallback(v[3]), + f64_to_f16_fallback(v[4]), + f64_to_f16_fallback(v[5]), + f64_to_f16_fallback(v[6]), + f64_to_f16_fallback(v[7]), ] } +#[inline] +fn slice_fallback<S: Copy, D>(src: &[S], dst: &mut [D], f: fn(S) -> D) { + assert_eq!(src.len(), dst.len()); + for (s, d) in src.iter().copied().zip(dst.iter_mut()) { + *d = f(s); + } +} + /////////////// x86/x86_64 f16c //////////////// #[cfg(all( feature = "use-intrinsics", @@ -419,12 +625,18 @@ mod x86 { use core::{mem::MaybeUninit, ptr}; #[cfg(target_arch = "x86")] - use core::arch::x86::{__m128, __m128i, _mm_cvtph_ps, _mm_cvtps_ph, _MM_FROUND_TO_NEAREST_INT}; + use core::arch::x86::{ + __m128, __m128i, __m256, _mm256_cvtph_ps, _mm256_cvtps_ph, _mm_cvtph_ps, + _MM_FROUND_TO_NEAREST_INT, + }; #[cfg(target_arch = "x86_64")] use core::arch::x86_64::{ - __m128, __m128i, _mm_cvtph_ps, _mm_cvtps_ph, _MM_FROUND_TO_NEAREST_INT, + __m128, __m128i, __m256, _mm256_cvtph_ps, _mm256_cvtps_ph, _mm_cvtph_ps, _mm_cvtps_ph, + _MM_FROUND_TO_NEAREST_INT, }; + use super::convert_chunked_slice_8; + #[target_feature(enable = "f16c")] #[inline] pub(super) unsafe fn 
f16_to_f32_x86_f16c(i: u16) -> f32 { @@ -445,9 +657,7 @@ mod x86 { #[target_feature(enable = "f16c")] #[inline] - pub(super) unsafe fn f16x4_to_f32x4_x86_f16c(v: &[u16]) -> [f32; 4] { - debug_assert!(v.len() >= 4); - + pub(super) unsafe fn f16x4_to_f32x4_x86_f16c(v: &[u16; 4]) -> [f32; 4] { let mut vec = MaybeUninit::<__m128i>::zeroed(); ptr::copy_nonoverlapping(v.as_ptr(), vec.as_mut_ptr().cast(), 4); let retval = _mm_cvtph_ps(vec.assume_init()); @@ -456,9 +666,7 @@ mod x86 { #[target_feature(enable = "f16c")] #[inline] - pub(super) unsafe fn f32x4_to_f16x4_x86_f16c(v: &[f32]) -> [u16; 4] { - debug_assert!(v.len() >= 4); - + pub(super) unsafe fn f32x4_to_f16x4_x86_f16c(v: &[f32; 4]) -> [u16; 4] { let mut vec = MaybeUninit::<__m128>::uninit(); ptr::copy_nonoverlapping(v.as_ptr(), vec.as_mut_ptr().cast(), 4); let retval = _mm_cvtps_ph(vec.assume_init(), _MM_FROUND_TO_NEAREST_INT); @@ -467,13 +675,8 @@ mod x86 { #[target_feature(enable = "f16c")] #[inline] - pub(super) unsafe fn f16x4_to_f64x4_x86_f16c(v: &[u16]) -> [f64; 4] { - debug_assert!(v.len() >= 4); - - let mut vec = MaybeUninit::<__m128i>::zeroed(); - ptr::copy_nonoverlapping(v.as_ptr(), vec.as_mut_ptr().cast(), 4); - let retval = _mm_cvtph_ps(vec.assume_init()); - let array = *(&retval as *const __m128).cast::<[f32; 4]>(); + pub(super) unsafe fn f16x4_to_f64x4_x86_f16c(v: &[u16; 4]) -> [f64; 4] { + let array = f16x4_to_f32x4_x86_f16c(v); // Let compiler vectorize this regular cast for now. // TODO: investigate auto-detecting sse2/avx convert features [ @@ -486,16 +689,64 @@ mod x86 { #[target_feature(enable = "f16c")] #[inline] - pub(super) unsafe fn f64x4_to_f16x4_x86_f16c(v: &[f64]) -> [u16; 4] { - debug_assert!(v.len() >= 4); - + pub(super) unsafe fn f64x4_to_f16x4_x86_f16c(v: &[f64; 4]) -> [u16; 4] { // Let compiler vectorize this regular cast for now. 
// TODO: investigate auto-detecting sse2/avx convert features let v = [v[0] as f32, v[1] as f32, v[2] as f32, v[3] as f32]; + f32x4_to_f16x4_x86_f16c(&v) + } - let mut vec = MaybeUninit::<__m128>::uninit(); - ptr::copy_nonoverlapping(v.as_ptr(), vec.as_mut_ptr().cast(), 4); - let retval = _mm_cvtps_ph(vec.assume_init(), _MM_FROUND_TO_NEAREST_INT); + #[target_feature(enable = "f16c")] + #[inline] + pub(super) unsafe fn f16x8_to_f32x8_x86_f16c(v: &[u16; 8]) -> [f32; 8] { + let mut vec = MaybeUninit::<__m128i>::zeroed(); + ptr::copy_nonoverlapping(v.as_ptr(), vec.as_mut_ptr().cast(), 8); + let retval = _mm256_cvtph_ps(vec.assume_init()); + *(&retval as *const __m256).cast() + } + + #[target_feature(enable = "f16c")] + #[inline] + pub(super) unsafe fn f32x8_to_f16x8_x86_f16c(v: &[f32; 8]) -> [u16; 8] { + let mut vec = MaybeUninit::<__m256>::uninit(); + ptr::copy_nonoverlapping(v.as_ptr(), vec.as_mut_ptr().cast(), 8); + let retval = _mm256_cvtps_ph(vec.assume_init(), _MM_FROUND_TO_NEAREST_INT); *(&retval as *const __m128i).cast() } + + #[target_feature(enable = "f16c")] + #[inline] + pub(super) unsafe fn f16x8_to_f64x8_x86_f16c(v: &[u16; 8]) -> [f64; 8] { + let array = f16x8_to_f32x8_x86_f16c(v); + // Let compiler vectorize this regular cast for now. + // TODO: investigate auto-detecting sse2/avx convert features + [ + array[0] as f64, + array[1] as f64, + array[2] as f64, + array[3] as f64, + array[4] as f64, + array[5] as f64, + array[6] as f64, + array[7] as f64, + ] + } + + #[target_feature(enable = "f16c")] + #[inline] + pub(super) unsafe fn f64x8_to_f16x8_x86_f16c(v: &[f64; 8]) -> [u16; 8] { + // Let compiler vectorize this regular cast for now. + // TODO: investigate auto-detecting sse2/avx convert features + let v = [ + v[0] as f32, + v[1] as f32, + v[2] as f32, + v[3] as f32, + v[4] as f32, + v[5] as f32, + v[6] as f32, + v[7] as f32, + ]; + f32x8_to_f16x8_x86_f16c(&v) + } } @@ -27,6 +27,22 @@ //! //! 
A [`prelude`] module is provided for easy importing of available utility traits. //! +//! # Serialization +//! +//! When the `serde` feature is enabled, [`f16`] and [`bf16`] will be serialized as a newtype of +//! [`u16`] by default. In binary formats this is ideal, as it will generally use just two bytes for +//! storage. For string formats like JSON, however, this isn't as useful, and due to design +//! limitations of serde, it's not possible for the default `Serialize` implementation to support +//! different serialization for different formats. +//! +//! Instead, it's up to the container type of the floats to control how it is serialized. This can +//! easily be controlled when using the derive macros using `#[serde(serialize_with="")]` +//! attributes. For both [`f16`] and [`bf16`] a `serialize_as_f32` and `serialize_as_string` are +//! provided for use with this attribute. +//! +//! Deserialization of both float types supports deserializing from the default serialization, +//! strings, and `f32`/`f64` values, so no additional work is required. +//! +//! # Cargo Features //! //! This crate supports a number of optional cargo features. 
None of these features are enabled by @@ -163,7 +179,7 @@ ), feature(stdsimd, f16c_target_feature) )] -#![doc(html_root_url = "https://docs.rs/half/2.1.0")] +#![doc(html_root_url = "https://docs.rs/half/2.2.1")] #![doc(test(attr(deny(warnings), allow(unused))))] #![cfg_attr(docsrs, feature(doc_cfg))] diff --git a/src/slice.rs b/src/slice.rs index f0dc3e4..f1e9feb 100644 --- a/src/slice.rs +++ b/src/slice.rs @@ -325,24 +325,7 @@ impl HalfFloatSliceExt for [f16] { "destination and source slices have different lengths" ); - let mut chunks = src.chunks_exact(4); - let mut chunk_count = 0usize; // Not using .enumerate() because we need this value for remainder - for chunk in &mut chunks { - let vec = convert::f32x4_to_f16x4(chunk); - let dst_idx = chunk_count * 4; - self[dst_idx..dst_idx + 4].copy_from_slice(vec.reinterpret_cast()); - chunk_count += 1; - } - - // Process remainder - if !chunks.remainder().is_empty() { - let mut buf = [0f32; 4]; - buf[..chunks.remainder().len()].copy_from_slice(chunks.remainder()); - let vec = convert::f32x4_to_f16x4(&buf); - let dst_idx = chunk_count * 4; - self[dst_idx..dst_idx + chunks.remainder().len()] - .copy_from_slice(vec[..chunks.remainder().len()].reinterpret_cast()); - } + convert::f32_to_f16_slice(src, self.reinterpret_cast_mut()) } fn convert_from_f64_slice(&mut self, src: &[f64]) { @@ -352,24 +335,7 @@ impl HalfFloatSliceExt for [f16] { "destination and source slices have different lengths" ); - let mut chunks = src.chunks_exact(4); - let mut chunk_count = 0usize; // Not using .enumerate() because we need this value for remainder - for chunk in &mut chunks { - let vec = convert::f64x4_to_f16x4(chunk); - let dst_idx = chunk_count * 4; - self[dst_idx..dst_idx + 4].copy_from_slice(vec.reinterpret_cast()); - chunk_count += 1; - } - - // Process remainder - if !chunks.remainder().is_empty() { - let mut buf = [0f64; 4]; - buf[..chunks.remainder().len()].copy_from_slice(chunks.remainder()); - let vec = 
convert::f64x4_to_f16x4(&buf); - let dst_idx = chunk_count * 4; - self[dst_idx..dst_idx + chunks.remainder().len()] - .copy_from_slice(vec[..chunks.remainder().len()].reinterpret_cast()); - } + convert::f64_to_f16_slice(src, self.reinterpret_cast_mut()) } fn convert_to_f32_slice(&self, dst: &mut [f32]) { @@ -379,24 +345,7 @@ impl HalfFloatSliceExt for [f16] { "destination and source slices have different lengths" ); - let mut chunks = self.chunks_exact(4); - let mut chunk_count = 0usize; // Not using .enumerate() because we need this value for remainder - for chunk in &mut chunks { - let vec = convert::f16x4_to_f32x4(chunk.reinterpret_cast()); - let dst_idx = chunk_count * 4; - dst[dst_idx..dst_idx + 4].copy_from_slice(&vec); - chunk_count += 1; - } - - // Process remainder - if !chunks.remainder().is_empty() { - let mut buf = [0u16; 4]; - buf[..chunks.remainder().len()].copy_from_slice(chunks.remainder().reinterpret_cast()); - let vec = convert::f16x4_to_f32x4(&buf); - let dst_idx = chunk_count * 4; - dst[dst_idx..dst_idx + chunks.remainder().len()] - .copy_from_slice(&vec[..chunks.remainder().len()]); - } + convert::f16_to_f32_slice(self.reinterpret_cast(), dst) } fn convert_to_f64_slice(&self, dst: &mut [f64]) { @@ -406,24 +355,7 @@ impl HalfFloatSliceExt for [f16] { "destination and source slices have different lengths" ); - let mut chunks = self.chunks_exact(4); - let mut chunk_count = 0usize; // Not using .enumerate() because we need this value for remainder - for chunk in &mut chunks { - let vec = convert::f16x4_to_f64x4(chunk.reinterpret_cast()); - let dst_idx = chunk_count * 4; - dst[dst_idx..dst_idx + 4].copy_from_slice(&vec); - chunk_count += 1; - } - - // Process remainder - if !chunks.remainder().is_empty() { - let mut buf = [0u16; 4]; - buf[..chunks.remainder().len()].copy_from_slice(chunks.remainder().reinterpret_cast()); - let vec = convert::f16x4_to_f64x4(&buf); - let dst_idx = chunk_count * 4; - dst[dst_idx..dst_idx + chunks.remainder().len()] - 
.copy_from_slice(&vec[..chunks.remainder().len()]); - } + convert::f16_to_f64_slice(self.reinterpret_cast(), dst) } #[cfg(any(feature = "alloc", feature = "std"))] |