From 215420d20940e7c24e94bcd9844ca68f789e8141 Mon Sep 17 00:00:00 2001 From: Chao Xu Date: Wed, 14 Dec 2022 18:24:56 -0500 Subject: [PATCH] add log_2 to more_math (#3) --- more_math.move.template | 39 ++++++ more_math/sources/more.move | 228 ++++++++++++++++++++++++++++++++++++ 2 files changed, 267 insertions(+) diff --git a/more_math.move.template b/more_math.move.template index 6ce6a7b..9c88291 100644 --- a/more_math.move.template +++ b/more_math.move.template @@ -5,6 +5,9 @@ // Version: {{.Version}} module {{.Address}}::{{.ModuleName}} { const E_WIDTH_OVERFLOW_U8: u64 = 1; + const E_LOG_2_OUT_OF_RANGE: u64 = 2; + const E_ZERO_FOR_LOG_2: u64 = 3; + const HALF_SIZE: u8 = {{.HalfWidth}}; const MAX_SHIFT_SIZE: u8 = {{.MaxShiftSize}}; @@ -88,6 +91,42 @@ module {{.Address}}::{{.ModuleName}} { z } } + + // log_2 calculates log_2 z, where z is assumed to be ratio of x/2^n and in [1,2) + // This can be used to calculate log_2 for any positive number by bit shifting the binary representation + // into the desired region. For float point number (1.b1b2b3...b53 * 2^n), the martissa can be directly used. + // see: https://en.wikipedia.org/wiki/Binary_logarithm + public fun log_2(x: {{$typename}}, n: u8): {{$typename}} { + assert!(x != 0, E_ZERO_FOR_LOG_2); + + let one: {{$typename}} = 1 << n; + let two: {{$typename}} = one << 1; + + assert!(x >= one && x < two, E_LOG_2_OUT_OF_RANGE); + + let z_2 = x; + let r: {{$typename}} = 0; + let sum_m: u8 = 0; + + loop { + if (z_2 == one) { + break + }; + let z = (z_2 * z_2) >> n; + sum_m = sum_m + 1; + if (sum_m > n) { + break + }; + if (z >= two) { + r = r + (one >> sum_m); + z_2 = z >> 1; + } else { + z_2 = z; + }; + }; + + r + } {{if .DoTest}} {{range $i, $t := .SqrtTestCases}} #[test] diff --git a/more_math/sources/more.move b/more_math/sources/more.move index 4ee9160..f46066e 100644 --- a/more_math/sources/more.move +++ b/more_math/sources/more.move @@ -5,6 +5,9 @@ // Version: v1.4.0 module more_math::more_math_u8 { const E_WIDTH_OVERFLOW_U8: u64 = 1; + const E_LOG_2_OUT_OF_RANGE: u64 = 2; + const E_ZERO_FOR_LOG_2: u64 = 3; + const HALF_SIZE: u8 = 4; const MAX_SHIFT_SIZE: u8 = 7; @@ -96,6 +99,41 @@ module more_math::more_math_u8 { } } + // log_2 calculates log_2 z, where z is assumed to be ratio of x/2^n and in [1,2) + // This can be used to calculate log_2 for any positive number by bit shifting the binary representation + // into the desired region. For float point number (1.b1b2b3...b53 * 2^n), the martissa can be directly used. + public fun log_2(x: u8, n: u8): u8 { + assert!(x != 0, E_ZERO_FOR_LOG_2); + + let one: u8 = 1 << n; + let two: u8 = one << 1; + + assert!(x >= one && x < two, E_LOG_2_OUT_OF_RANGE); + + let z_2 = x; + let r: u8 = 0; + let sum_m: u8 = 0; + + loop { + if (z_2 == one) { + break + }; + let z = (z_2 * z_2) >> n; + sum_m = sum_m + 1; + if (sum_m > n) { + break + }; + if (z >= two) { + r = r + (one >> sum_m); + z_2 = z >> 1; + } else { + z_2 = z; + }; + }; + + r + } + #[test] fun test_sqrt_u8_0() { @@ -266,6 +304,9 @@ module more_math::more_math_u8 { // Version: v1.4.0 module more_math::more_math_u16 { const E_WIDTH_OVERFLOW_U8: u64 = 1; + const E_LOG_2_OUT_OF_RANGE: u64 = 2; + const E_ZERO_FOR_LOG_2: u64 = 3; + const HALF_SIZE: u8 = 8; const MAX_SHIFT_SIZE: u8 = 15; @@ -362,6 +403,41 @@ module more_math::more_math_u16 { } } + // log_2 calculates log_2 z, where z is assumed to be ratio of x/2^n and in [1,2) + // This can be used to calculate log_2 for any positive number by bit shifting the binary representation + // into the desired region. For float point number (1.b1b2b3...b53 * 2^n), the martissa can be directly used. + public fun log_2(x: u16, n: u8): u16 { + assert!(x != 0, E_ZERO_FOR_LOG_2); + + let one: u16 = 1 << n; + let two: u16 = one << 1; + + assert!(x >= one && x < two, E_LOG_2_OUT_OF_RANGE); + + let z_2 = x; + let r: u16 = 0; + let sum_m: u8 = 0; + + loop { + if (z_2 == one) { + break + }; + let z = (z_2 * z_2) >> n; + sum_m = sum_m + 1; + if (sum_m > n) { + break + }; + if (z >= two) { + r = r + (one >> sum_m); + z_2 = z >> 1; + } else { + z_2 = z; + }; + }; + + r + } + #[test] fun test_sqrt_u16_0() { @@ -692,6 +768,9 @@ module more_math::more_math_u16 { // Version: v1.4.0 module more_math::more_math_u32 { const E_WIDTH_OVERFLOW_U8: u64 = 1; + const E_LOG_2_OUT_OF_RANGE: u64 = 2; + const E_ZERO_FOR_LOG_2: u64 = 3; + const HALF_SIZE: u8 = 16; const MAX_SHIFT_SIZE: u8 = 31; @@ -793,6 +872,41 @@ module more_math::more_math_u32 { } } + // log_2 calculates log_2 z, where z is assumed to be ratio of x/2^n and in [1,2) + // This can be used to calculate log_2 for any positive number by bit shifting the binary representation + // into the desired region. For float point number (1.b1b2b3...b53 * 2^n), the martissa can be directly used. + public fun log_2(x: u32, n: u8): u32 { + assert!(x != 0, E_ZERO_FOR_LOG_2); + + let one: u32 = 1 << n; + let two: u32 = one << 1; + + assert!(x >= one && x < two, E_LOG_2_OUT_OF_RANGE); + + let z_2 = x; + let r: u32 = 0; + let sum_m: u8 = 0; + + loop { + if (z_2 == one) { + break + }; + let z = (z_2 * z_2) >> n; + sum_m = sum_m + 1; + if (sum_m > n) { + break + }; + if (z >= two) { + r = r + (one >> sum_m); + z_2 = z >> 1; + } else { + z_2 = z; + }; + }; + + r + } + #[test] fun test_sqrt_u32_0() { @@ -1443,6 +1557,9 @@ module more_math::more_math_u32 { // Version: v1.4.0 module more_math::more_math_u64 { const E_WIDTH_OVERFLOW_U8: u64 = 1; + const E_LOG_2_OUT_OF_RANGE: u64 = 2; + const E_ZERO_FOR_LOG_2: u64 = 3; + const HALF_SIZE: u8 = 32; const MAX_SHIFT_SIZE: u8 = 63; @@ -1549,6 +1666,41 @@ module more_math::more_math_u64 { } } + // log_2 calculates log_2 z, where z is assumed to be ratio of x/2^n and in [1,2) + // This can be used to calculate log_2 for any positive number by bit shifting the binary representation + // into the desired region. For float point number (1.b1b2b3...b53 * 2^n), the martissa can be directly used. + public fun log_2(x: u64, n: u8): u64 { + assert!(x != 0, E_ZERO_FOR_LOG_2); + + let one: u64 = 1 << n; + let two: u64 = one << 1; + + assert!(x >= one && x < two, E_LOG_2_OUT_OF_RANGE); + + let z_2 = x; + let r: u64 = 0; + let sum_m: u8 = 0; + + loop { + if (z_2 == one) { + break + }; + let z = (z_2 * z_2) >> n; + sum_m = sum_m + 1; + if (sum_m > n) { + break + }; + if (z >= two) { + r = r + (one >> sum_m); + z_2 = z >> 1; + } else { + z_2 = z; + }; + }; + + r + } + #[test] fun test_sqrt_u64_0() { @@ -2839,6 +2991,9 @@ module more_math::more_math_u64 { // Version: v1.4.0 module more_math::more_math_u128 { const E_WIDTH_OVERFLOW_U8: u64 = 1; + const E_LOG_2_OUT_OF_RANGE: u64 = 2; + const E_ZERO_FOR_LOG_2: u64 = 3; + const HALF_SIZE: u8 = 64; const MAX_SHIFT_SIZE: u8 = 127; @@ -2950,6 +3105,41 @@ module more_math::more_math_u128 { } } + // log_2 calculates log_2 z, where z is assumed to be ratio of x/2^n and in [1,2) + // This can be used to calculate log_2 for any positive number by bit shifting the binary representation + // into the desired region. For float point number (1.b1b2b3...b53 * 2^n), the martissa can be directly used. + public fun log_2(x: u128, n: u8): u128 { + assert!(x != 0, E_ZERO_FOR_LOG_2); + + let one: u128 = 1 << n; + let two: u128 = one << 1; + + assert!(x >= one && x < two, E_LOG_2_OUT_OF_RANGE); + + let z_2 = x; + let r: u128 = 0; + let sum_m: u8 = 0; + + loop { + if (z_2 == one) { + break + }; + let z = (z_2 * z_2) >> n; + sum_m = sum_m + 1; + if (sum_m > n) { + break + }; + if (z >= two) { + r = r + (one >> sum_m); + z_2 = z >> 1; + } else { + z_2 = z; + }; + }; + + r + } + #[test] fun test_sqrt_u128_0() { @@ -5520,6 +5710,9 @@ module more_math::more_math_u128 { // Version: v1.4.0 module more_math::more_math_u256 { const E_WIDTH_OVERFLOW_U8: u64 = 1; + const E_LOG_2_OUT_OF_RANGE: u64 = 2; + const E_ZERO_FOR_LOG_2: u64 = 3; + const HALF_SIZE: u8 = 128; const MAX_SHIFT_SIZE: u8 = 255; @@ -5636,6 +5829,41 @@ module more_math::more_math_u256 { } } + // log_2 calculates log_2 z, where z is assumed to be ratio of x/2^n and in [1,2) + // This can be used to calculate log_2 for any positive number by bit shifting the binary representation + // into the desired region. For float point number (1.b1b2b3...b53 * 2^n), the martissa can be directly used. + public fun log_2(x: u256, n: u8): u256 { + assert!(x != 0, E_ZERO_FOR_LOG_2); + + let one: u256 = 1 << n; + let two: u256 = one << 1; + + assert!(x >= one && x < two, E_LOG_2_OUT_OF_RANGE); + + let z_2 = x; + let r: u256 = 0; + let sum_m: u8 = 0; + + loop { + if (z_2 == one) { + break + }; + let z = (z_2 * z_2) >> n; + sum_m = sum_m + 1; + if (sum_m > n) { + break + }; + if (z >= two) { + r = r + (one >> sum_m); + z_2 = z >> 1; + } else { + z_2 = z; + }; + }; + + r + } + #[test] fun test_sqrt_u256_0() {