diff --git a/inclusion/blob_share_commitment_rules.go b/inclusion/blob_share_commitment_rules.go index f51bae7..93a56f1 100644 --- a/inclusion/blob_share_commitment_rules.go +++ b/inclusion/blob_share_commitment_rules.go @@ -3,6 +3,7 @@ package inclusion import ( "fmt" "math" + "math/bits" "golang.org/x/exp/constraints" ) @@ -50,11 +51,19 @@ func RoundUpByMultipleOf(cursor, v int) int { // RoundUpPowerOfTwo returns the next power of two greater than or equal to input. func RoundUpPowerOfTwo[I constraints.Integer](input I) I { - var result I = 1 - for result < input { - result <<= 1 + if input <= 1 { + return 1 } - return result + if input&(input-1) == 0 { // It is already a power of 2 + return input + } + var powUp I = 1 << bits.Len64(uint64(input)) + if powUp < input { + // An overflow occurred due to a very large size + // of input and we should return a positive power of 2. + powUp = 1 + } + return powUp } // RoundDownPowerOfTwo returns the next power of two less than or equal to input. @@ -62,11 +71,12 @@ func RoundDownPowerOfTwo[I constraints.Integer](input I) (I, error) { if input <= 0 { return 0, fmt.Errorf("input %v must be positive", input) } - roundedUp := RoundUpPowerOfTwo(input) - if roundedUp == input { - return roundedUp, nil + if input&(input-1) == 0 { // It is already a power of 2 + return input, nil } - return roundedUp / 2, nil + + // Return 1 << (numberOfBits-1) + return 1 << (bits.Len64(uint64(input)) - 1), nil } // BlobMinSquareSize returns the minimum square size that can contain shareCount @@ -79,6 +89,9 @@ func BlobMinSquareSize(shareCount int) int { // commitment over a given blob. The input should be the total number of shares // used by that blob. See ADR-013. func SubTreeWidth(shareCount, subtreeRootThreshold int) int { + if subtreeRootThreshold <= 0 { + return 1 + } // Per ADR-013, we use a predetermined threshold to determine width of sub // trees used to create share commitments s := (shareCount / subtreeRootThreshold) diff --git a/inclusion/blob_share_commitment_rules_test.go b/inclusion/blob_share_commitment_rules_test.go index 81cc115..9f212bf 100644 --- a/inclusion/blob_share_commitment_rules_test.go +++ b/inclusion/blob_share_commitment_rules_test.go @@ -2,6 +2,7 @@ package inclusion_test import ( "fmt" + "math" "testing" "github.com/celestiaorg/go-square/v2/inclusion" @@ -257,25 +258,35 @@ func TestRoundUpByMultipleOf(t *testing.T) { } } +type roundUpTestCase struct { + input int + want int +} + +var roundUpTestCases = []roundUpTestCase{ + {input: -1, want: 1}, + {input: 0, want: 1}, + {input: 1, want: 1}, + {input: 2, want: 2}, + {input: 4, want: 4}, + {input: 5, want: 8}, + {input: 8, want: 8}, + {input: 11, want: 16}, + {input: 511, want: 512}, + {input: math.MaxInt32 - 1, want: 1 << 31}, + {input: math.MaxInt32 + 1, want: 1 << 31}, + {input: math.MaxInt32, want: 1 << 31}, + {input: math.MaxInt >> 1, want: 1 << 62}, + {input: math.MaxInt, want: 1}, +} + func TestRoundUpPowerOfTwo(t *testing.T) { - type testCase struct { - input int - want int - } - testCases := []testCase{ - {input: -1, want: 1}, - {input: 0, want: 1}, - {input: 1, want: 1}, - {input: 2, want: 2}, - {input: 4, want: 4}, - {input: 5, want: 8}, - {input: 8, want: 8}, - {input: 11, want: 16}, - {input: 511, want: 512}, - } - for _, tc := range testCases { - got := inclusion.RoundUpPowerOfTwo(tc.input) - assert.Equal(t, tc.want, got) + for _, tc := range roundUpTestCases { + testName := fmt.Sprintf("%d=%x", tc.input, tc.input) + t.Run(testName, func(t *testing.T) { + got := inclusion.RoundUpPowerOfTwo(tc.input) + assert.Equal(t, tc.want, got) + }) } } @@ -326,62 +337,64 @@ func TestBlobMinSquareSize(t *testing.T) { } } +type subtreeWidthTestCase struct { + shareCount int + want int +} + +var subtreeWidthTestCases = []subtreeWidthTestCase{ + { + shareCount: 0, + want: 1, + }, + { + shareCount: 1, + want: 1, + }, + { + shareCount: 2, + want: 1, + }, + { + shareCount: defaultSubtreeRootThreshold, + want: 1, + }, + { + shareCount: defaultSubtreeRootThreshold + 1, + want: 2, + }, + { + shareCount: defaultSubtreeRootThreshold - 1, + want: 1, + }, + { + shareCount: defaultSubtreeRootThreshold * 2, + want: 2, + }, + { + shareCount: (defaultSubtreeRootThreshold * 2) + 1, + want: 4, + }, + { + shareCount: (defaultSubtreeRootThreshold * 3) - 1, + want: 4, + }, + { + shareCount: (defaultSubtreeRootThreshold * 4), + want: 4, + }, + { + shareCount: (defaultSubtreeRootThreshold * 5), + want: 8, + }, + { + shareCount: (defaultSubtreeRootThreshold * defaultMaxSquareSize) - 1, + want: 128, + }, +} + func TestSubTreeWidth(t *testing.T) { - type testCase struct { - shareCount int - want int - } - testCases := []testCase{ - { - shareCount: 0, - want: 1, - }, - { - shareCount: 1, - want: 1, - }, - { - shareCount: 2, - want: 1, - }, - { - shareCount: defaultSubtreeRootThreshold, - want: 1, - }, - { - shareCount: defaultSubtreeRootThreshold + 1, - want: 2, - }, - { - shareCount: defaultSubtreeRootThreshold - 1, - want: 1, - }, - { - shareCount: defaultSubtreeRootThreshold * 2, - want: 2, - }, - { - shareCount: (defaultSubtreeRootThreshold * 2) + 1, - want: 4, - }, - { - shareCount: (defaultSubtreeRootThreshold * 3) - 1, - want: 4, - }, - { - shareCount: (defaultSubtreeRootThreshold * 4), - want: 4, - }, - { - shareCount: (defaultSubtreeRootThreshold * 5), - want: 8, - }, - { - shareCount: (defaultSubtreeRootThreshold * defaultMaxSquareSize) - 1, - want: 128, - }, - } - for i, tc := range testCases { + for i, tc := range subtreeWidthTestCases { t.Run(fmt.Sprintf("shareCount %d", tc.shareCount), func(t *testing.T) { got := inclusion.SubTreeWidth(tc.shareCount, defaultSubtreeRootThreshold) assert.Equal(t, tc.want, got, i) @@ -389,23 +402,85 @@ func TestSubTreeWidth(t *testing.T) { } } -func TestRoundDownPowerOfTwo(t *testing.T) { - type testCase struct { - input int - want int +var sink any + +func BenchmarkSubTreeWidth(b *testing.B) { + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + for _, tc := range subtreeWidthTestCases { + got := inclusion.SubTreeWidth(tc.shareCount, defaultSubtreeRootThreshold) + assert.Equal(b, tc.want, got) + sink = got + } } - testCases := []testCase{ - {input: 1, want: 1}, - {input: 2, want: 2}, - {input: 4, want: 4}, - {input: 5, want: 4}, - {input: 8, want: 8}, - {input: 11, want: 8}, - {input: 511, want: 256}, + + if sink == nil { + b.Fatal("Benchmark did not run!") } - for _, tc := range testCases { - got, err := inclusion.RoundDownPowerOfTwo(tc.input) - require.NoError(t, err) - assert.Equal(t, tc.want, got) + sink = nil +} + +func BenchmarkRoundDownPowerOfTwo(b *testing.B) { + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + for _, tc := range roundDownTestCases { + got, _ := inclusion.RoundDownPowerOfTwo(tc.input) + assert.Equal(b, tc.want, got) + sink = got + } + } + + if sink == nil { + b.Fatal("Benchmark did not run!") + } + sink = nil +} + +func BenchmarkRoundUpPowerOfTwo(b *testing.B) { + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + for _, tc := range roundUpTestCases { + got := inclusion.RoundUpPowerOfTwo(tc.input) + assert.Equal(b, tc.want, got) + sink = got + } + } + + if sink == nil { + b.Fatal("Benchmark did not run!") + } + sink = nil +} + +type roundDownTestCase struct { + input int + want int +} + +var roundDownTestCases = []roundDownTestCase{ + {input: 1, want: 1}, + {input: 2, want: 2}, + {input: 4, want: 4}, + {input: 5, want: 4}, + {input: 8, want: 8}, + {input: 11, want: 8}, + {input: 511, want: 256}, + {input: math.MaxInt32 - 1, want: 1 << 30}, + {input: math.MaxInt32, want: 1 << 30}, + {input: math.MaxInt32 + 1, want: 1 << 31}, + {input: math.MaxInt, want: 1 << 62}, +} + +func TestRoundDownPowerOfTwo(t *testing.T) { + for _, tc := range roundDownTestCases { + testName := fmt.Sprintf("%d=%x", tc.input, tc.input) + t.Run(testName, func(t *testing.T) { + got, err := inclusion.RoundDownPowerOfTwo(tc.input) + require.NoError(t, err) + assert.Equal(t, tc.want, got) + }) } }