From d9f5a49e89fbc75f643eb286fa9b6c43c9544724 Mon Sep 17 00:00:00 2001 From: nbos <35620093+nbos@users.noreply.github.com> Date: Thu, 29 Dec 2022 12:58:35 +0000 Subject: [PATCH] Add log1p (#1821) Closes #1820 Co-authored-by: nbos <> --- prelude/math.fut | 5 +++++ rts/c/scalar.h | 26 ++++++++++++++++++++++++++ rts/c/scalar_f16.h | 19 +++++++++++++++++++ rts/c/uniform.h | 21 +++++++++++++++++++++ rts/python/scalar.py | 9 +++++++++ src/Futhark/AD/Derivatives.hs | 6 ++++++ src/Language/Futhark/Primitive.hs | 5 +++++ tests/primitive/log32.fut | 19 ++++++++++++------- tests/primitive/log64.fut | 20 +++++++++++++------- tests/primitive/naninf32.fut | 8 +++++++- tests/primitive/naninf64.fut | 8 +++++++- 11 files changed, 130 insertions(+), 16 deletions(-) diff --git a/prelude/math.fut b/prelude/math.fut index e2e1b0b297..ed52df6f84 100644 --- a/prelude/math.fut +++ b/prelude/math.fut @@ -187,6 +187,8 @@ module type real = { val log2: t -> t -- | Base-10 logarithm. val log10: t -> t + -- | Compute `log (1 + x)` accurately even when `x` is very small. + val log1p: t -> t -- | Round towards infinity. val ceil : t -> t @@ -927,6 +929,7 @@ module f64: (float with t = f64 with int_t = u64) = { def log (x: f64) = intrinsics.log64 x def log2 (x: f64) = intrinsics.log2_64 x def log10 (x: f64) = intrinsics.log10_64 x + def log1p (x: f64) = intrinsics.log1p_64 x def exp (x: f64) = intrinsics.exp64 x def sin (x: f64) = intrinsics.sin64 x def cos (x: f64) = intrinsics.cos64 x @@ -1041,6 +1044,7 @@ module f32: (float with t = f32 with int_t = u32) = { def log (x: f32) = intrinsics.log32 x def log2 (x: f32) = intrinsics.log2_32 x def log10 (x: f32) = intrinsics.log10_32 x + def log1p (x: f32) = intrinsics.log1p_32 x def exp (x: f32) = intrinsics.exp32 x def sin (x: f32) = intrinsics.sin32 x def cos (x: f32) = intrinsics.cos32 x @@ -1159,6 +1163,7 @@ module f16: (float with t = f16 with int_t = u16) = { def log (x: f16) = intrinsics.log16 x def log2 (x: f16) = intrinsics.log2_16 x def log10 (x: f16) = intrinsics.log10_16 x + def log1p (x: f16) = intrinsics.log1p_16 x def exp (x: f16) = intrinsics.exp16 x def sin (x: f16) = intrinsics.sin16 x def cos (x: f16) = intrinsics.cos16 x diff --git a/rts/c/scalar.h b/rts/c/scalar.h index 5cc1b4e42c..76e8c7021e 100644 --- a/rts/c/scalar.h +++ b/rts/c/scalar.h @@ -1774,6 +1774,10 @@ static inline float futrts_log10_32(float x) { return log10(x); } +static inline float futrts_log1p_32(float x) { + return log1p(x); +} + static inline float futrts_sqrt32(float x) { return sqrt(x); } @@ -1904,6 +1908,13 @@ static inline float futrts_log10_32(float x) { return futrts_log32(x) / log(10.0f); } +static inline float futrts_log1p_32(float x) { + if(x == -1.0f || (futrts_isinf32(x) && x > 0.0f)) return x / 0.0f; + float y = 1.0f + x; + float z = y - 1.0f; + return log(y) - (z-x)/y; +} + static inline float futrts_sqrt32(float x) { return sqrt(x); } @@ -2106,6 +2117,10 @@ static inline float futrts_log10_32(float x) { return log10f(x); } +static inline float futrts_log1p_32(float x) { + return log1pf(x); +} + static inline float futrts_sqrt32(float x) { return sqrtf(x); } @@ -2357,6 +2372,13 @@ static inline double futrts_log10_64(double x) { return futrts_log64(x)/log(10.0d); } +static inline double futrts_log1p_64(double x) { + if(x == -1.0d || (futrts_isinf64(x) && x > 0.0d)) return x / 0.0d; + double y = 1.0d + x; + double z = y - 1.0d; + return log(y) - (z-x)/y; +} + static inline double futrts_sqrt64(double x) { return sqrt(x); } @@ -2724,6 +2746,10 @@ static inline double futrts_log10_64(double x) { return log10(x); } +static inline double futrts_log1p_64(double x) { + return log1p(x); +} + static inline double futrts_sqrt64(double x) { return sqrt(x); } diff --git a/rts/c/scalar_f16.h b/rts/c/scalar_f16.h index 81841472f9..6afad424f0 100644 --- a/rts/c/scalar_f16.h +++ b/rts/c/scalar_f16.h @@ -217,6 +217,10 @@ static inline f16 futrts_log10_16(f16 x) { return log10(x); } +static inline f16 futrts_log1p_16(f16 x) { + return log1p(x); +} + static inline f16 futrts_sqrt16(f16 x) { return sqrt(x); } @@ -346,6 +350,13 @@ static inline f16 futrts_log10_16(f16 x) { return futrts_log16(x) / log(10.0f16); } +static inline f16 futrts_log1p_16(f16 x) { + if(x == -1.0f16 || (futrts_isinf16(x) && x > 0.0f16)) return x / 0.0f16; + f16 y = 1.0f16 + x; + f16 z = y - 1.0f16; + return log(y) - (z-x)/y; +} + static inline f16 futrts_sqrt16(f16 x) { return (float16)sqrt((float)x); } @@ -497,6 +508,10 @@ static inline f16 futrts_log10_16(f16 x) { return hlog10(x); } +static inline f16 futrts_log1p_16(f16 x) { + return (float16)log1pf((float)x); +} + static inline f16 futrts_sqrt16(f16 x) { return hsqrt(x); } @@ -695,6 +710,10 @@ static inline f16 futrts_log10_16(f16 x) { return futrts_log10_32(x); } +static inline f16 futrts_log1p_16(f16 x) { + return futrts_log1p_32(x); +} + static inline f16 futrts_sqrt16(f16 x) { return futrts_sqrt32(x); } diff --git a/rts/c/uniform.h b/rts/c/uniform.h index c6121c4ddc..cd457e24dc 100644 --- a/rts/c/uniform.h +++ b/rts/c/uniform.h @@ -931,6 +931,13 @@ static inline uniform float futrts_log10_32(uniform float x) { return futrts_log32(x) / log(10.0f); } +static inline uniform float futrts_log1p_32(uniform float x) { + if(x == -1.0f || (futrts_isinf32(x) && x > 0.0f)) return x / 0.0f; + uniform float y = 1.0f + x; + uniform float z = y - 1.0f; + return log(y) - (z-x)/y; +} + static inline uniform float futrts_sqrt32(uniform float x) { return sqrt(x); } @@ -1189,6 +1196,13 @@ static inline uniform double futrts_log10_64(uniform double x) { return futrts_log64(x)/log(10.0d); } +static inline uniform double futrts_log1p_64(uniform double x) { + if(x == -1.0d || (futrts_isinf64(x) && x > 0.0d)) return x / 0.0d; + uniform double y = 1.0d + x; + uniform double z = y - 1.0d; + return log(y) - (z-x)/y; +} + static inline uniform double futrts_sqrt64(uniform double x) { return sqrt(x); } @@ -1561,6 +1575,13 @@ static inline uniform f16 futrts_log10_16(uniform f16 x) { return futrts_log16(x) / log(10.0f16); } +static inline uniform f16 futrts_log1p_16(uniform f16 x) { + if(x == -1.0f16 || (futrts_isinf16(x) && x > 0.0f16)) return x / 0.0f16; + uniform f16 y = 1.0f16 + x; + uniform f16 z = y - 1.0f16; + return log(y) - (z-x)/y; +} + static inline uniform f16 futrts_sqrt16(uniform f16 x) { return (uniform f16)sqrt((uniform float)x); } diff --git a/rts/python/scalar.py b/rts/python/scalar.py index f870e48cc4..9ab89ae974 100644 --- a/rts/python/scalar.py +++ b/rts/python/scalar.py @@ -436,6 +436,9 @@ def futhark_log2_64(x): def futhark_log10_64(x): return np.float64(np.log10(x)) +def futhark_log1p_64(x): + return np.float64(np.log1p(x)) + def futhark_sqrt64(x): return np.sqrt(x) @@ -534,6 +537,9 @@ def futhark_log2_32(x): def futhark_log10_32(x): return np.float32(np.log10(x)) +def futhark_log1p_32(x): + return np.float32(np.log1p(x)) + def futhark_sqrt32(x): return np.float32(np.sqrt(x)) @@ -632,6 +638,9 @@ def futhark_log2_16(x): def futhark_log10_16(x): return np.float16(np.log10(x)) +def futhark_log1p_16(x): + return np.float16(np.log1p(x)) + def futhark_sqrt16(x): return np.float16(np.sqrt(x)) diff --git a/src/Futhark/AD/Derivatives.hs b/src/Futhark/AD/Derivatives.hs index 5df1412c64..e6a474fabc 100644 --- a/src/Futhark/AD/Derivatives.hs +++ b/src/Futhark/AD/Derivatives.hs @@ -183,6 +183,12 @@ pdBuiltin "log2_32" [x] = Just [untyped $ 1 / (isF32 x * log 2)] pdBuiltin "log2_64" [x] = Just [untyped $ 1 / (isF64 x * log 2)] +pdBuiltin "log1p_16" [x] = + Just [untyped $ 1 / (isF16 x + 1)] +pdBuiltin "log1p_32" [x] = + Just [untyped $ 1 / (isF32 x + 1)] +pdBuiltin "log1p_64" [x] = + Just [untyped $ 1 / (isF64 x + 1)] pdBuiltin "exp16" [x] = Just [untyped $ exp (isF16 x)] pdBuiltin "exp32" [x] = diff --git a/src/Language/Futhark/Primitive.hs b/src/Language/Futhark/Primitive.hs index 274c148f34..fc43fc1445 100644 --- a/src/Language/Futhark/Primitive.hs +++ b/src/Language/Futhark/Primitive.hs @@ -135,6 +135,7 @@ import Foreign.C.Types (CUShort (..)) import Futhark.Util (convFloat) import Futhark.Util.CMath import Futhark.Util.Pretty +import Numeric (log1p) import Numeric.Half import Prelude hiding (id, (.)) @@ -1194,6 +1195,10 @@ primFuns = f32 "log10_32" (logBase 10), f64 "log10_64" (logBase 10), -- + f16 "log1p_16" log1p, + f32 "log1p_32" log1p, + f64 "log1p_64" log1p, + -- f16 "log2_16" (logBase 2), f32 "log2_32" (logBase 2), f64 "log2_64" (logBase 2), diff --git a/tests/primitive/log32.fut b/tests/primitive/log32.fut index a40d82c152..32e2f575b7 100644 --- a/tests/primitive/log32.fut +++ b/tests/primitive/log32.fut @@ -1,19 +1,24 @@ -- == -- entry: logf32 --- input { [2.718281828459045f32, 2f32, 10f32] } --- output { [1f32, 0.6931471805599453f32, 2.302585092994046f32] } +-- input { [0.0f32, 2.718281828459045f32, 2f32, 10f32, f32.inf] } +-- output { [-f32.inf, 1f32, 0.6931471805599453f32, 2.302585092994046f32, f32.inf] } -- == -- entry: log2f32 --- input { [2.718281828459045f32, 2f32, 10f32] } --- output { [1.4426950408889634f32, 1f32, 3.321928094887362f32] } +-- input { [0.0f32, 2.718281828459045f32, 2f32, 10f32, f32.inf] } +-- output { [-f32.inf, 1.4426950408889634f32, 1f32, 3.321928094887362f32, f32.inf] } -- == -- entry: log10f32 --- input { [2.718281828459045f32, 2f32, 10f32] } --- output { [0.4342944819032518f32, 0.3010299956639812f32, 1f32] } +-- input { [0.0f32, 2.718281828459045f32, 2f32, 10f32, f32.inf] } +-- output { [-f32.inf, 0.4342944819032518f32, 0.3010299956639812f32, 1f32, f32.inf] } + +-- == +-- entry: log1pf32 +-- input { [-1.0f32, -1e-12f32, 0.0f32, 1e-23f32, 1.718281828459045f32, 1f32, f32.inf] } +-- output { [-f32.inf, -1e-12f32, 0.0f32, 1e-23f32, 1.0f32, 0.6931471805599453f32, f32.inf] } entry logf32 = map f32.log entry log2f32 = map f32.log2 entry log10f32 = map f32.log10 - +entry log1pf32 = map f32.log1p diff --git a/tests/primitive/log64.fut b/tests/primitive/log64.fut index 535d257a75..83a6465a2e 100644 --- a/tests/primitive/log64.fut +++ b/tests/primitive/log64.fut @@ -1,18 +1,24 @@ -- == -- entry: logf64 --- input { [2.718281828459045f64, 2f64, 10f64] } --- output { [1f64, 0.6931471805599453f64, 2.302585092994046f64] } +-- input { [0.0f64, 2.718281828459045f64, 2f64, 10f64, f64.inf] } +-- output { [-f64.inf, 1f64, 0.6931471805599453f64, 2.302585092994046f64, f64.inf] } -- == -- entry: log2f64 --- input { [2.718281828459045f64, 2f64, 10f64] } --- output { [1.4426950408889634f64, 1f64, 3.321928094887362f64] } +-- input { [0.0f64, 2.718281828459045f64, 2f64, 10f64, f64.inf] } +-- output { [-f64.inf, 1.4426950408889634f64, 1f64, 3.321928094887362f64, f64.inf] } -- == -- entry: log10f64 --- input { [2.718281828459045f64, 2f64, 10f64] } --- output { [0.4342944819032518f64, 0.3010299956639812f64, 1f64] } +-- input { [0.0f64, 2.718281828459045f64, 2f64, 10f64, f64.inf] } +-- output { [-f64.inf, 0.4342944819032518f64, 0.3010299956639812f64, 1f64, f64.inf] } + +-- == +-- entry: log1pf64 +-- input { [-1.0f64, -1e-123f64, 0.0f64, 1e-234f64, 1.718281828459045f64, 1f64, f64.inf] } +-- output { [-f64.inf, -1e-123f64, 0.0f64, 1e-234f64, 1.0f64, 0.6931471805599453f64, f64.inf] } entry logf64 = map f64.log entry log2f64 = map f64.log2 -entry log10f64 = map f64.log10 \ No newline at end of file +entry log10f64 = map f64.log10 +entry log1pf64 = map f64.log1p diff --git a/tests/primitive/naninf32.fut b/tests/primitive/naninf32.fut index 6cf788f6ef..144ae0fa02 100644 --- a/tests/primitive/naninf32.fut +++ b/tests/primitive/naninf32.fut @@ -50,6 +50,11 @@ -- input { [10f32, f32.nan, f32.inf, -f32.inf] } -- output { [false, true, false, true] } +-- == +-- entry: log1p +-- input { [-2f32, -1f32, 2f32, f32.nan, f32.inf, -f32.inf] } +-- output { [true, false, false, true, false, true] } + entry eqNaN = map (\x -> x == f32.nan) entry ltNaN = map (\x -> x < f32.nan) entry lteNaN = map (\x -> x <= f32.nan) @@ -59,4 +64,5 @@ entry diffInf = map (\x -> x - f32.inf < x + f32.inf) entry sumNaN = map (\x -> f32.isnan (x + f32.nan)) entry sumInf = map (\x -> f32.isinf (x + f32.inf)) entry log2 = map (\x -> f32.isnan (f32.log2 (x))) -entry log10 = map (\x -> f32.isnan (f32.log10 (x))) \ No newline at end of file +entry log10 = map (\x -> f32.isnan (f32.log10 (x))) +entry log1p = map (\x -> f32.isnan (f32.log1p (x))) diff --git a/tests/primitive/naninf64.fut b/tests/primitive/naninf64.fut index a92d841f31..1db2d581a7 100644 --- a/tests/primitive/naninf64.fut +++ b/tests/primitive/naninf64.fut @@ -50,6 +50,11 @@ -- input { [10f64, f64.nan, f64.inf, -f64.inf] } -- output { [false, true, false, true] } +-- == +-- entry: log1p +-- input { [-2f64, -1f64, 2f64, f64.nan, f64.inf, -f64.inf] } +-- output { [true, false, false, true, false, true] } + entry eqNaN = map (\x -> x == f64.nan) entry ltNaN = map (\x -> x < f64.nan) entry lteNaN = map (\x -> x <= f64.nan) @@ -59,4 +64,5 @@ entry diffInf = map (\x -> x - f64.inf < x + f64.inf) entry sumNaN = map (\x -> f64.isnan (x + f64.nan)) entry sumInf = map (\x -> f64.isinf (x + f64.inf)) entry log2 = map (\x -> f64.isnan (f64.log2 (x))) -entry log10 = map (\x -> f64.isnan (f64.log10 (x))) \ No newline at end of file +entry log10 = map (\x -> f64.isnan (f64.log10 (x))) +entry log1p = map (\x -> f64.isnan (f64.log1p (x)))