diff options
Diffstat (limited to 'src/bfloat/convert.rs')
-rw-r--r-- | src/bfloat/convert.rs | 41 |
1 files changed, 27 insertions, 14 deletions
diff --git a/src/bfloat/convert.rs b/src/bfloat/convert.rs index 4aa0aec..8f258f5 100644 --- a/src/bfloat/convert.rs +++ b/src/bfloat/convert.rs @@ -1,6 +1,11 @@ -pub(crate) fn f32_to_bf16(value: f32) -> u16 { +use crate::leading_zeros::leading_zeros_u16; +use core::mem; + +#[inline] +pub(crate) const fn f32_to_bf16(value: f32) -> u16 { + // TODO: Replace mem::transmute with to_bits() once to_bits is const-stabilized // Convert to raw bytes - let x = value.to_bits(); + let x: u32 = unsafe { mem::transmute(value) }; // check for NaN if x & 0x7FFF_FFFFu32 > 0x7F80_0000u32 { @@ -17,10 +22,12 @@ pub(crate) fn f32_to_bf16(value: f32) -> u16 { } } -pub(crate) fn f64_to_bf16(value: f64) -> u16 { +#[inline] +pub(crate) const fn f64_to_bf16(value: f64) -> u16 { + // TODO: Replace mem::transmute with to_bits() once to_bits is const-stabilized // Convert to raw bytes, truncating the last 32-bits of mantissa; that precision will always // be lost on half-precision. - let val = value.to_bits(); + let val: u64 = unsafe { mem::transmute(value) }; let x = (val >> 32) as u32; // Extract IEEE754 components @@ -83,19 +90,23 @@ pub(crate) fn f64_to_bf16(value: f64) -> u16 { } } -pub(crate) fn bf16_to_f32(i: u16) -> f32 { +#[inline] +pub(crate) const fn bf16_to_f32(i: u16) -> f32 { + // TODO: Replace mem::transmute with from_bits() once from_bits is const-stabilized // If NaN, keep current mantissa but also set most significiant mantissa bit if i & 0x7FFFu16 > 0x7F80u16 { - f32::from_bits((i as u32 | 0x0040u32) << 16) + unsafe { mem::transmute((i as u32 | 0x0040u32) << 16) } } else { - f32::from_bits((i as u32) << 16) + unsafe { mem::transmute((i as u32) << 16) } } } -pub(crate) fn bf16_to_f64(i: u16) -> f64 { +#[inline] +pub(crate) const fn bf16_to_f64(i: u16) -> f64 { + // TODO: Replace mem::transmute with from_bits() once from_bits is const-stabilized // Check for signed zero if i & 0x7FFFu16 == 0 { - return f64::from_bits((i as u64) << 48); + return unsafe { mem::transmute((i as u64) << 48) }; } let half_sign = (i & 0x8000u16) as u64; @@ -106,10 +117,12 @@ pub(crate) fn bf16_to_f64(i: u16) -> f64 { if half_exp == 0x7F80u64 { // Check for signed infinity if mantissa is zero if half_man == 0 { - return f64::from_bits((half_sign << 48) | 0x7FF0_0000_0000_0000u64); + return unsafe { mem::transmute((half_sign << 48) | 0x7FF0_0000_0000_0000u64) }; } else { // NaN, keep current mantissa but also set most significiant mantissa bit - return f64::from_bits((half_sign << 48) | 0x7FF8_0000_0000_0000u64 | (half_man << 45)); + return unsafe { + mem::transmute((half_sign << 48) | 0x7FF8_0000_0000_0000u64 | (half_man << 45)) + }; } } @@ -121,15 +134,15 @@ pub(crate) fn bf16_to_f64(i: u16) -> f64 { // Check for subnormals, which will be normalized by adjusting exponent if half_exp == 0 { // Calculate how much to adjust the exponent by - let e = (half_man as u16).leading_zeros() - 9; + let e = leading_zeros_u16(half_man as u16) - 9; // Rebias and adjust exponent let exp = ((1023 - 127 - e) as u64) << 52; let man = (half_man << (46 + e)) & 0xF_FFFF_FFFF_FFFFu64; - return f64::from_bits(sign | exp | man); + return unsafe { mem::transmute(sign | exp | man) }; } // Rebias exponent for a normalized normal let exp = ((unbiased_exp + 1023) as u64) << 52; let man = (half_man & 0x007Fu64) << 45; - f64::from_bits(sign | exp | man) + unsafe { mem::transmute(sign | exp | man) } } |