From 5d3aad736d6459eee9aabb25877fac885ba67419 Mon Sep 17 00:00:00 2001 From: Benjamin Philip Date: Wed, 8 Jun 2022 18:09:18 +0530 Subject: [PATCH 1/2] Support implicit promoting with f16 - Create a ImplicitPromote trait and implement it for other types - Implement ImplicitPromote traits for f16 --- src/core/util.rs | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/src/core/util.rs b/src/core/util.rs index d5c10f4f..6a0ceb2d 100644 --- a/src/core/util.rs +++ b/src/core/util.rs @@ -468,6 +468,7 @@ macro_rules! implicit { implicit!(c64, c32 => c64); implicit!(c64, f64 => c64); implicit!(c64, f32 => c64); +implicit!(c64, f16 => c64); implicit!(c64, i64 => c64); implicit!(c64, u64 => c64); implicit!(c64, i32 => c64); @@ -481,6 +482,7 @@ implicit!(c64, u8 => c64); implicit!(c32, c64 => c64); implicit!(c32, f64 => c64); implicit!(c32, f32 => c32); +implicit!(c32, f16 => c32); implicit!(c32, i64 => c32); implicit!(c32, u64 => c32); implicit!(c32, i32 => c32); @@ -494,6 +496,7 @@ implicit!(c32, u8 => c32); implicit!(f64, c64 => c64); implicit!(f64, c32 => c64); implicit!(f64, f32 => f64); +implicit!(f64, f16 => f64); implicit!(f64, i64 => f64); implicit!(f64, u64 => f64); implicit!(f64, i32 => f64); @@ -507,6 +510,7 @@ implicit!(f64, u8 => f64); implicit!(f32, c64 => c64); implicit!(f32, c32 => c32); implicit!(f32, f64 => f64); +implicit!(f32, f16 => f32); implicit!(f32, i64 => f32); implicit!(f32, u64 => f32); implicit!(f32, i32 => f32); @@ -516,11 +520,26 @@ implicit!(f32, u16 => f32); implicit!(f32, bool => f32); implicit!(f32, u8 => f32); +//LHS is 16-bit floating point +implicit!(f16, c64 => c64); +implicit!(f16, c32 => c32); +implicit!(f16, f64 => f64); +implicit!(f16, f32 => f32); +implicit!(f16, i64 => f16); +implicit!(f16, u64 => f16); +implicit!(f16, i32 => f16); +implicit!(f16, u32 => f16); +implicit!(f16, i16 => f16); +implicit!(f16, u16 => f16); +implicit!(f16, bool => f16); +implicit!(f16, u8 => f16); + //LHS is 64-bit signed integer implicit!(i64, c64 => c64); implicit!(i64, c32 => c32); implicit!(i64, f64 => f64); implicit!(i64, f32 => f32); +implicit!(i64, f16 => f16); implicit!(i64, u64 => u64); implicit!(i64, i32 => i64); implicit!(i64, u32 => i64); @@ -534,6 +553,7 @@ implicit!(u64, c64 => c64); implicit!(u64, c32 => c32); implicit!(u64, f64 => f64); implicit!(u64, f32 => f32); +implicit!(u64, f16 => f16); implicit!(u64, i64 => u64); implicit!(u64, i32 => u64); implicit!(u64, u32 => u64); @@ -547,6 +567,7 @@ implicit!(i32, c64 => c64); implicit!(i32, c32 => c32); implicit!(i32, f64 => f64); implicit!(i32, f32 => f32); +implicit!(i32, f16 => f16); implicit!(i32, i64 => i64); implicit!(i32, u64 => u64); implicit!(i32, u32 => u32); @@ -560,6 +581,7 @@ implicit!(u32, c64 => c64); implicit!(u32, c32 => c32); implicit!(u32, f64 => f64); implicit!(u32, f32 => f32); +implicit!(u32, f16 => f16); implicit!(u32, i64 => i64); implicit!(u32, u64 => u64); implicit!(u32, i32 => u32); @@ -573,6 +595,7 @@ implicit!(i16, c64 => c64); implicit!(i16, c32 => c32); implicit!(i16, f64 => f64); implicit!(i16, f32 => f32); +implicit!(i16, f16 => f16); implicit!(i16, i64 => i64); implicit!(i16, u64 => u64); implicit!(i16, i32 => i32); @@ -586,6 +609,7 @@ implicit!(u16, c64 => c64); implicit!(u16, c32 => c32); implicit!(u16, f64 => f64); implicit!(u16, f32 => f32); +implicit!(u16, f16 => f16); implicit!(u16, i64 => i64); implicit!(u16, u64 => u64); implicit!(u16, i32 => i32); @@ -599,6 +623,7 @@ implicit!(u8, c64 => c64); implicit!(u8, c32 => c32); implicit!(u8, f64 => f64); implicit!(u8, f32 => f32); +implicit!(u8, f16 => f16); implicit!(u8, i64 => i64); implicit!(u8, u64 => u64); implicit!(u8, i32 => i32); @@ -612,6 +637,7 @@ implicit!(bool, c64 => c64); implicit!(bool, c32 => c32); implicit!(bool, f64 => f64); implicit!(bool, f32 => f32); +implicit!(bool, f16 => f16); implicit!(bool, i64 => i64); implicit!(bool, u64 => u64); implicit!(bool, i32 => i32); From f0c63ac7c208e7b6908ea0ebddab45ed8b4b477c Mon Sep 17 00:00:00 2001 From: Umar Arshad Date: Wed, 13 Jul 2022 17:53:32 -0400 Subject: [PATCH 2/2] Additional changes for f16 type --- src/core/arith.rs | 3 +++ src/core/data.rs | 20 ++++++++++++++++++++ src/core/num.rs | 3 +++ src/core/util.rs | 12 +++++++++++- 4 files changed, 37 insertions(+), 1 deletion(-) diff --git a/src/core/arith.rs b/src/core/arith.rs index d9a39629..d0f1dda8 100644 --- a/src/core/arith.rs +++ b/src/core/arith.rs @@ -4,6 +4,8 @@ use super::defines::AfError; use super::dim4::Dim4; use super::error::HANDLE_ERROR; use super::util::{af_array, HasAfEnum, ImplicitPromote, IntegralType}; + +use half::f16; use num::Zero; use libc::c_int; @@ -758,6 +760,7 @@ arith_scalar_spec!(Complex); arith_scalar_spec!(Complex); arith_scalar_spec!(f64); arith_scalar_spec!(f32); +arith_scalar_spec!(f16); arith_scalar_spec!(u64); arith_scalar_spec!(i64); arith_scalar_spec!(u32); diff --git a/src/core/data.rs b/src/core/data.rs index 62630f58..75308807 100644 --- a/src/core/data.rs +++ b/src/core/data.rs @@ -4,6 +4,7 @@ use super::dim4::Dim4; use super::error::HANDLE_ERROR; use super::util::{af_array, c32, c64, dim_t, u64_t, HasAfEnum}; +use half::f16; use libc::{c_double, c_int, c_uint}; use std::option::Option; use std::vec::Vec; @@ -231,6 +232,25 @@ impl ConstGenerator for bool { } } +impl ConstGenerator for f16 { + type OutType = f16; + + fn generate(&self, dims: Dim4) -> Array { + unsafe { + let mut temp: af_array = std::ptr::null_mut(); + let err_val = af_constant( + &mut temp as *mut af_array, + f16::to_f64(*self), + dims.ndims() as c_uint, + dims.get().as_ptr() as *const dim_t, + 12, + ); + HANDLE_ERROR(AfError::from(err_val)); + temp.into() + } + } +} + macro_rules! cnst { ($rust_type:ty, $ffi_type:expr) => { impl ConstGenerator for $rust_type { diff --git a/src/core/num.rs b/src/core/num.rs index 27f5ab26..73eca474 100644 --- a/src/core/num.rs +++ b/src/core/num.rs @@ -1,3 +1,4 @@ + pub trait Zero { fn zero() -> Self; } @@ -28,6 +29,7 @@ zero_impl!(i64, 0); zero_impl!(isize, 0); zero_impl!(f32, 0.0); zero_impl!(f64, 0.0); +zero_impl!(half::f16, half::f16::from_f32(0.0)); macro_rules! one_impl { ( $t:ident, $o:expr ) => { @@ -51,3 +53,4 @@ one_impl!(i64, 1); one_impl!(isize, 1); one_impl!(f32, 1.0); one_impl!(f64, 1.0); +one_impl!(half::f16, half::f16::from_f32(1.0)); diff --git a/src/core/util.rs b/src/core/util.rs index 6a0ceb2d..5a626d52 100644 --- a/src/core/util.rs +++ b/src/core/util.rs @@ -337,7 +337,7 @@ impl HasAfEnum for f16 { type AbsOutType = Self; type ArgOutType = Self; type UnaryOutType = Self; - type ComplexOutType = Complex; + type ComplexOutType = Complex; type MeanOutType = Self; type AggregateOutType = f32; type ProductOutType = f32; @@ -678,12 +678,18 @@ impl FloatingPoint for f32 { true } } +impl FloatingPoint for f16 { + fn is_real() -> bool { + true + } +} ///Trait qualifier to accept real data(numbers) pub trait RealFloating: HasAfEnum {} impl RealFloating for f64 {} impl RealFloating for f32 {} +impl RealFloating for f16 {} ///Trait qualifier to accept complex data(numbers) pub trait ComplexFloating: HasAfEnum {} @@ -696,6 +702,7 @@ pub trait RealNumber: HasAfEnum {} impl RealNumber for f64 {} impl RealNumber for f32 {} +impl RealNumber for f16 {} impl RealNumber for i32 {} impl RealNumber for u32 {} impl RealNumber for i16 {} @@ -856,6 +863,8 @@ impl Fromf64 for u32 { fn fromf64(value: f64) -> Self { value as Self }} #[rustfmt::skip] impl Fromf64 for i32 { fn fromf64(value: f64) -> Self { value as Self }} #[rustfmt::skip] +impl Fromf64 for f16 { fn fromf64(value: f64) -> Self { f16::from_f64(value) }} +#[rustfmt::skip] impl Fromf64 for u16 { fn fromf64(value: f64) -> Self { value as Self }} #[rustfmt::skip] impl Fromf64 for i16 { fn fromf64(value: f64) -> Self { value as Self }} @@ -873,6 +882,7 @@ impl IndexableType for u64 {} impl IndexableType for f32 {} impl IndexableType for i32 {} impl IndexableType for u32 {} +impl IndexableType for f16 {} impl IndexableType for i16 {} impl IndexableType for u16 {} impl IndexableType for u8 {}