Skip to content

Commit 93b1036

Browse files
umar4569prady9
authored andcommitted
Additional changes for f16 type
1 parent 9eb882a commit 93b1036

File tree

4 files changed

+37
-1
lines changed

4 files changed

+37
-1
lines changed

src/core/arith.rs

+3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ use super::defines::AfError;
44
use super::dim4::Dim4;
55
use super::error::HANDLE_ERROR;
66
use super::util::{af_array, HasAfEnum, ImplicitPromote, IntegralType};
7+
8+
use half::f16;
79
use num::Zero;
810

911
use libc::c_int;
@@ -758,6 +760,7 @@ arith_scalar_spec!(Complex<f64>);
758760
arith_scalar_spec!(Complex<f32>);
759761
arith_scalar_spec!(f64);
760762
arith_scalar_spec!(f32);
763+
arith_scalar_spec!(f16);
761764
arith_scalar_spec!(u64);
762765
arith_scalar_spec!(i64);
763766
arith_scalar_spec!(u32);

src/core/data.rs

+20
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use super::dim4::Dim4;
44
use super::error::HANDLE_ERROR;
55
use super::util::{af_array, c32, c64, dim_t, u64_t, HasAfEnum};
66

7+
use half::f16;
78
use libc::{c_double, c_int, c_uint};
89
use std::option::Option;
910
use std::vec::Vec;
@@ -231,6 +232,25 @@ impl ConstGenerator for bool {
231232
}
232233
}
233234

235+
impl ConstGenerator for f16 {
236+
type OutType = f16;
237+
238+
fn generate(&self, dims: Dim4) -> Array<Self::OutType> {
239+
unsafe {
240+
let mut temp: af_array = std::ptr::null_mut();
241+
let err_val = af_constant(
242+
&mut temp as *mut af_array,
243+
f16::to_f64(*self),
244+
dims.ndims() as c_uint,
245+
dims.get().as_ptr() as *const dim_t,
246+
12,
247+
);
248+
HANDLE_ERROR(AfError::from(err_val));
249+
temp.into()
250+
}
251+
}
252+
}
253+
234254
macro_rules! cnst {
235255
($rust_type:ty, $ffi_type:expr) => {
236256
impl ConstGenerator for $rust_type {

src/core/num.rs

+3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
12
pub trait Zero {
23
fn zero() -> Self;
34
}
@@ -28,6 +29,7 @@ zero_impl!(i64, 0);
2829
zero_impl!(isize, 0);
2930
zero_impl!(f32, 0.0);
3031
zero_impl!(f64, 0.0);
32+
zero_impl!(half::f16, half::f16::from_f32(0.0));
3133

3234
macro_rules! one_impl {
3335
( $t:ident, $o:expr ) => {
@@ -51,3 +53,4 @@ one_impl!(i64, 1);
5153
one_impl!(isize, 1);
5254
one_impl!(f32, 1.0);
5355
one_impl!(f64, 1.0);
56+
one_impl!(half::f16, half::f16::from_f32(1.0));

src/core/util.rs

+11-1
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ impl HasAfEnum for f16 {
337337
type AbsOutType = Self;
338338
type ArgOutType = Self;
339339
type UnaryOutType = Self;
340-
type ComplexOutType = Complex<f16>;
340+
type ComplexOutType = Complex<f32>;
341341
type MeanOutType = Self;
342342
type AggregateOutType = f32;
343343
type ProductOutType = f32;
@@ -678,12 +678,18 @@ impl FloatingPoint for f32 {
678678
true
679679
}
680680
}
681+
impl FloatingPoint for f16 {
682+
fn is_real() -> bool {
683+
true
684+
}
685+
}
681686

682687
///Trait qualifier to accept real data(numbers)
683688
pub trait RealFloating: HasAfEnum {}
684689

685690
impl RealFloating for f64 {}
686691
impl RealFloating for f32 {}
692+
impl RealFloating for f16 {}
687693

688694
///Trait qualifier to accept complex data(numbers)
689695
pub trait ComplexFloating: HasAfEnum {}
@@ -696,6 +702,7 @@ pub trait RealNumber: HasAfEnum {}
696702

697703
impl RealNumber for f64 {}
698704
impl RealNumber for f32 {}
705+
impl RealNumber for f16 {}
699706
impl RealNumber for i32 {}
700707
impl RealNumber for u32 {}
701708
impl RealNumber for i16 {}
@@ -856,6 +863,8 @@ impl Fromf64 for u32 { fn fromf64(value: f64) -> Self { value as Self }}
856863
#[rustfmt::skip]
857864
impl Fromf64 for i32 { fn fromf64(value: f64) -> Self { value as Self }}
858865
#[rustfmt::skip]
866+
impl Fromf64 for f16 { fn fromf64(value: f64) -> Self { f16::from_f64(value) }}
867+
#[rustfmt::skip]
859868
impl Fromf64 for u16 { fn fromf64(value: f64) -> Self { value as Self }}
860869
#[rustfmt::skip]
861870
impl Fromf64 for i16 { fn fromf64(value: f64) -> Self { value as Self }}
@@ -873,6 +882,7 @@ impl IndexableType for u64 {}
873882
impl IndexableType for f32 {}
874883
impl IndexableType for i32 {}
875884
impl IndexableType for u32 {}
885+
impl IndexableType for f16 {}
876886
impl IndexableType for i16 {}
877887
impl IndexableType for u16 {}
878888
impl IndexableType for u8 {}

0 commit comments

Comments
 (0)