aboutsummaryrefslogtreecommitdiff
path: root/src/bfloat/convert.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/bfloat/convert.rs')
-rw-r--r--src/bfloat/convert.rs41
1 files changed, 27 insertions, 14 deletions
diff --git a/src/bfloat/convert.rs b/src/bfloat/convert.rs
index 4aa0aec..8f258f5 100644
--- a/src/bfloat/convert.rs
+++ b/src/bfloat/convert.rs
@@ -1,6 +1,11 @@
-pub(crate) fn f32_to_bf16(value: f32) -> u16 {
+use crate::leading_zeros::leading_zeros_u16;
+use core::mem;
+
+#[inline]
+pub(crate) const fn f32_to_bf16(value: f32) -> u16 {
+ // TODO: Replace mem::transmute with to_bits() once to_bits is const-stabilized
// Convert to raw bytes
- let x = value.to_bits();
+ let x: u32 = unsafe { mem::transmute(value) };
// check for NaN
if x & 0x7FFF_FFFFu32 > 0x7F80_0000u32 {
@@ -17,10 +22,12 @@ pub(crate) fn f32_to_bf16(value: f32) -> u16 {
}
}
-pub(crate) fn f64_to_bf16(value: f64) -> u16 {
+#[inline]
+pub(crate) const fn f64_to_bf16(value: f64) -> u16 {
+ // TODO: Replace mem::transmute with to_bits() once to_bits is const-stabilized
// Convert to raw bytes, truncating the last 32-bits of mantissa; that precision will always
// be lost on half-precision.
- let val = value.to_bits();
+ let val: u64 = unsafe { mem::transmute(value) };
let x = (val >> 32) as u32;
// Extract IEEE754 components
@@ -83,19 +90,23 @@ pub(crate) fn f64_to_bf16(value: f64) -> u16 {
}
}
-pub(crate) fn bf16_to_f32(i: u16) -> f32 {
+#[inline]
+pub(crate) const fn bf16_to_f32(i: u16) -> f32 {
+ // TODO: Replace mem::transmute with from_bits() once from_bits is const-stabilized
// If NaN, keep current mantissa but also set most significiant mantissa bit
if i & 0x7FFFu16 > 0x7F80u16 {
- f32::from_bits((i as u32 | 0x0040u32) << 16)
+ unsafe { mem::transmute((i as u32 | 0x0040u32) << 16) }
} else {
- f32::from_bits((i as u32) << 16)
+ unsafe { mem::transmute((i as u32) << 16) }
}
}
-pub(crate) fn bf16_to_f64(i: u16) -> f64 {
+#[inline]
+pub(crate) const fn bf16_to_f64(i: u16) -> f64 {
+ // TODO: Replace mem::transmute with from_bits() once from_bits is const-stabilized
// Check for signed zero
if i & 0x7FFFu16 == 0 {
- return f64::from_bits((i as u64) << 48);
+ return unsafe { mem::transmute((i as u64) << 48) };
}
let half_sign = (i & 0x8000u16) as u64;
@@ -106,10 +117,12 @@ pub(crate) fn bf16_to_f64(i: u16) -> f64 {
if half_exp == 0x7F80u64 {
// Check for signed infinity if mantissa is zero
if half_man == 0 {
- return f64::from_bits((half_sign << 48) | 0x7FF0_0000_0000_0000u64);
+ return unsafe { mem::transmute((half_sign << 48) | 0x7FF0_0000_0000_0000u64) };
} else {
// NaN, keep current mantissa but also set most significiant mantissa bit
- return f64::from_bits((half_sign << 48) | 0x7FF8_0000_0000_0000u64 | (half_man << 45));
+ return unsafe {
+ mem::transmute((half_sign << 48) | 0x7FF8_0000_0000_0000u64 | (half_man << 45))
+ };
}
}
@@ -121,15 +134,15 @@ pub(crate) fn bf16_to_f64(i: u16) -> f64 {
// Check for subnormals, which will be normalized by adjusting exponent
if half_exp == 0 {
// Calculate how much to adjust the exponent by
- let e = (half_man as u16).leading_zeros() - 9;
+ let e = leading_zeros_u16(half_man as u16) - 9;
// Rebias and adjust exponent
let exp = ((1023 - 127 - e) as u64) << 52;
let man = (half_man << (46 + e)) & 0xF_FFFF_FFFF_FFFFu64;
- return f64::from_bits(sign | exp | man);
+ return unsafe { mem::transmute(sign | exp | man) };
}
// Rebias exponent for a normalized normal
let exp = ((unbiased_exp + 1023) as u64) << 52;
let man = (half_man & 0x007Fu64) << 45;
- f64::from_bits(sign | exp | man)
+ unsafe { mem::transmute(sign | exp | man) }
}