Skip to content

Commit 552a84a

Browse files
committed
Additional changes for f16 type
1 parent fb98756 commit 552a84a

File tree

4 files changed

+37
-1
lines changed

4 files changed

+37
-1
lines changed

src/core/arith.rs

+3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ use super::defines::AfError;
44
use super::dim4::Dim4;
55
use super::error::HANDLE_ERROR;
66
use super::util::{af_array, HasAfEnum, ImplicitPromote, IntegralType};
7+
8+
use half::f16;
79
use num::Zero;
810

911
use libc::c_int;
@@ -763,6 +765,7 @@ arith_scalar_spec!(Complex<f64>);
763765
arith_scalar_spec!(Complex<f32>);
764766
arith_scalar_spec!(f64);
765767
arith_scalar_spec!(f32);
768+
arith_scalar_spec!(f16);
766769
arith_scalar_spec!(u64);
767770
arith_scalar_spec!(i64);
768771
arith_scalar_spec!(u32);

src/core/data.rs

+20
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use super::dim4::Dim4;
44
use super::error::HANDLE_ERROR;
55
use super::util::{af_array, c32, c64, dim_t, u64_t, HasAfEnum};
66

7+
use half::f16;
78
use libc::{c_double, c_int, c_uint};
89
use std::option::Option;
910
use std::vec::Vec;
@@ -231,6 +232,25 @@ impl ConstGenerator for bool {
231232
}
232233
}
233234

235+
impl ConstGenerator for f16 {
236+
type OutType = f16;
237+
238+
fn generate(&self, dims: Dim4) -> Array<Self::OutType> {
239+
unsafe {
240+
let mut temp: af_array = std::ptr::null_mut();
241+
let err_val = af_constant(
242+
&mut temp as *mut af_array,
243+
f16::to_f64(*self),
244+
dims.ndims() as c_uint,
245+
dims.get().as_ptr() as *const dim_t,
246+
12,
247+
);
248+
HANDLE_ERROR(AfError::from(err_val));
249+
temp.into()
250+
}
251+
}
252+
}
253+
234254
macro_rules! cnst {
235255
($rust_type:ty, $ffi_type:expr) => {
236256
impl ConstGenerator for $rust_type {

src/core/num.rs

+3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
12
pub trait Zero {
23
fn zero() -> Self;
34
}
@@ -28,6 +29,7 @@ zero_impl!(i64, 0);
2829
zero_impl!(isize, 0);
2930
zero_impl!(f32, 0.0);
3031
zero_impl!(f64, 0.0);
32+
zero_impl!(half::f16, half::f16::from_f32(0.0));
3133

3234
macro_rules! one_impl {
3335
( $t:ident, $o:expr ) => {
@@ -51,3 +53,4 @@ one_impl!(i64, 1);
5153
one_impl!(isize, 1);
5254
one_impl!(f32, 1.0);
5355
one_impl!(f64, 1.0);
56+
one_impl!(half::f16, half::f16::from_f32(1.0));

src/core/util.rs

+11-1
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ impl HasAfEnum for f16 {
341341
type AbsOutType = Self;
342342
type ArgOutType = Self;
343343
type UnaryOutType = Self;
344-
type ComplexOutType = Complex<f16>;
344+
type ComplexOutType = Complex<f32>;
345345
type MeanOutType = Self;
346346
type AggregateOutType = f32;
347347
type ProductOutType = f32;
@@ -682,12 +682,18 @@ impl FloatingPoint for f32 {
682682
true
683683
}
684684
}
685+
impl FloatingPoint for f16 {
686+
fn is_real() -> bool {
687+
true
688+
}
689+
}
685690

686691
///Trait qualifier to accept real data(numbers)
687692
pub trait RealFloating: HasAfEnum {}
688693

689694
impl RealFloating for f64 {}
690695
impl RealFloating for f32 {}
696+
impl RealFloating for f16 {}
691697

692698
///Trait qualifier to accept complex data(numbers)
693699
pub trait ComplexFloating: HasAfEnum {}
@@ -700,6 +706,7 @@ pub trait RealNumber: HasAfEnum {}
700706

701707
impl RealNumber for f64 {}
702708
impl RealNumber for f32 {}
709+
impl RealNumber for f16 {}
703710
impl RealNumber for i32 {}
704711
impl RealNumber for u32 {}
705712
impl RealNumber for i16 {}
@@ -860,6 +867,8 @@ impl Fromf64 for u32 { fn fromf64(value: f64) -> Self { value as Self }}
860867
#[rustfmt::skip]
861868
impl Fromf64 for i32 { fn fromf64(value: f64) -> Self { value as Self }}
862869
#[rustfmt::skip]
870+
impl Fromf64 for f16 { fn fromf64(value: f64) -> Self { f16::from_f64(value) }}
871+
#[rustfmt::skip]
863872
impl Fromf64 for u16 { fn fromf64(value: f64) -> Self { value as Self }}
864873
#[rustfmt::skip]
865874
impl Fromf64 for i16 { fn fromf64(value: f64) -> Self { value as Self }}
@@ -877,6 +886,7 @@ impl IndexableType for u64 {}
877886
impl IndexableType for f32 {}
878887
impl IndexableType for i32 {}
879888
impl IndexableType for u32 {}
889+
impl IndexableType for f16 {}
880890
impl IndexableType for i16 {}
881891
impl IndexableType for u16 {}
882892
impl IndexableType for u8 {}

0 commit comments

Comments
 (0)