diff --git a/cmd/util/cmd/checkpoint-list-tries/cmd.go b/cmd/util/cmd/checkpoint-list-tries/cmd.go index 105f7408fb7..bfad7c18bef 100644 --- a/cmd/util/cmd/checkpoint-list-tries/cmd.go +++ b/cmd/util/cmd/checkpoint-list-tries/cmd.go @@ -28,12 +28,12 @@ func init() { func run(*cobra.Command, []string) { - flattenedForest, err := wal.LoadCheckpoint(flagCheckpoint) + tries, err := wal.LoadCheckpoint(flagCheckpoint) if err != nil { log.Fatal().Err(err).Msg("error while loading checkpoint") } - for _, trie := range flattenedForest.Tries { - fmt.Printf("%x\n", trie.RootHash) + for _, trie := range tries { + fmt.Printf("%x\n", trie.RootHash()) } } diff --git a/cmd/util/cmd/read-execution-state/list-wals/cmd.go b/cmd/util/cmd/read-execution-state/list-wals/cmd.go index 6d71edd0da4..f2a91553a55 100644 --- a/cmd/util/cmd/read-execution-state/list-wals/cmd.go +++ b/cmd/util/cmd/read-execution-state/list-wals/cmd.go @@ -11,7 +11,7 @@ import ( "github.com/onflow/flow-go/ledger" "github.com/onflow/flow-go/ledger/common/pathfinder" "github.com/onflow/flow-go/ledger/complete" - "github.com/onflow/flow-go/ledger/complete/mtrie/flattener" + "github.com/onflow/flow-go/ledger/complete/mtrie/trie" "github.com/onflow/flow-go/ledger/complete/wal" "github.com/onflow/flow-go/module/metrics" ) @@ -52,7 +52,7 @@ func run(*cobra.Command, []string) { }() err = w.ReplayLogsOnly( - func(forestSequencing *flattener.FlattenedForest) error { + func(tries []*trie.MTrie) error { fmt.Printf("forest sequencing \n") return nil }, diff --git a/ledger/common/encoding/encoding.go b/ledger/common/encoding/encoding.go index 2bf7a1a5a53..0f1b3a18097 100644 --- a/ledger/common/encoding/encoding.go +++ b/ledger/common/encoding/encoding.go @@ -108,16 +108,25 @@ func EncodeKeyPart(kp *ledger.KeyPart) []byte { } func encodeKeyPart(kp *ledger.KeyPart) []byte { - buffer := make([]byte, 0) + buffer := make([]byte, 0, encodedKeyPartLength(kp)) + return encodeAndAppendKeyPart(buffer, kp) +} +func encodeAndAppendKeyPart(buffer []byte, kp *ledger.KeyPart) []byte { // encode "Type" field of the key part buffer = utils.AppendUint16(buffer, kp.Type) // encode "Value" field of the key part buffer = append(buffer, kp.Value...) + return buffer } +func encodedKeyPartLength(kp *ledger.KeyPart) int { + // Key part is encoded as: type (2 bytes) + value + return 2 + len(kp.Value) +} + // DecodeKeyPart constructs a key part from an encoded key part func DecodeKeyPart(encodedKeyPart []byte) (*ledger.KeyPart, error) { // currently we ignore the version but in the future we @@ -133,8 +142,8 @@ func DecodeKeyPart(encodedKeyPart []byte) (*ledger.KeyPart, error) { return nil, fmt.Errorf("error decoding key part: %w", err) } - // decode the key part content - key, err := decodeKeyPart(rest) + // decode the key part content (zerocopy) + key, err := decodeKeyPart(rest, true) if err != nil { return nil, fmt.Errorf("error decoding key part: %w", err) } @@ -142,13 +151,20 @@ func DecodeKeyPart(encodedKeyPart []byte) (*ledger.KeyPart, error) { return key, nil } -func decodeKeyPart(inp []byte) (*ledger.KeyPart, error) { +// decodeKeyPart decodes inp into KeyPart. If zeroCopy is true, KeyPart +// references data in inp. Otherwise, it is copied. 
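The hunks above replace grow-as-you-go encoding (`make([]byte, 0)` plus repeated appends) with a compute-the-exact-length-then-append pattern, so each encode does a single allocation. A minimal self-contained sketch of that pattern; the `keyPart` type and helpers below are illustrative stand-ins, not the ledger package's actual code:

```go
package main

import "fmt"

// keyPart is an illustrative stand-in for ledger.KeyPart.
type keyPart struct {
	Type  uint16
	Value []byte
}

// encodedKeyPartLength mirrors the sizing helper above: type (2 bytes) + value.
func encodedKeyPartLength(kp *keyPart) int {
	return 2 + len(kp.Value)
}

// encodeAndAppendKeyPart appends the encoding to buffer; with a pre-sized
// buffer, the appends never reallocate.
func encodeAndAppendKeyPart(buffer []byte, kp *keyPart) []byte {
	buffer = append(buffer, byte(kp.Type>>8), byte(kp.Type)) // big-endian type
	return append(buffer, kp.Value...)
}

func main() {
	kp := keyPart{Type: 1, Value: []byte("owner")}
	buffer := make([]byte, 0, encodedKeyPartLength(&kp)) // single allocation
	buffer = encodeAndAppendKeyPart(buffer, &kp)
	fmt.Printf("%x len=%d cap=%d\n", buffer, len(buffer), cap(buffer))
}
```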
+func decodeKeyPart(inp []byte, zeroCopy bool) (*ledger.KeyPart, error) { // read key part type and the rest is the key item part kpt, kpv, err := utils.ReadUint16(inp) if err != nil { return nil, fmt.Errorf("error decoding key part (content): %w", err) } - return &ledger.KeyPart{Type: kpt, Value: kpv}, nil + if zeroCopy { + return &ledger.KeyPart{Type: kpt, Value: kpv}, nil + } + v := make([]byte, len(kpv)) + copy(v, kpv) + return &ledger.KeyPart{Type: kpt, Value: v}, nil } // EncodeKey encodes a key into a byte slice @@ -168,21 +184,36 @@ func EncodeKey(k *ledger.Key) []byte { // encodeKey encodes a key into a byte slice func encodeKey(k *ledger.Key) []byte { - buffer := make([]byte, 0) + buffer := make([]byte, 0, encodedKeyLength(k)) + return encodeAndAppendKey(buffer, k) +} + +func encodeAndAppendKey(buffer []byte, k *ledger.Key) []byte { // encode number of key parts buffer = utils.AppendUint16(buffer, uint16(len(k.KeyParts))) + // iterate over key parts for _, kp := range k.KeyParts { - // encode the key part - encKP := encodeKeyPart(&kp) // encode the len of the encoded key part - buffer = utils.AppendUint32(buffer, uint32(len(encKP))) - // append the encoded key part - buffer = append(buffer, encKP...) + buffer = utils.AppendUint32(buffer, uint32(encodedKeyPartLength(&kp))) + + // encode the key part + buffer = encodeAndAppendKeyPart(buffer, &kp) } + return buffer } +func encodedKeyLength(k *ledger.Key) int { + // Key is encoded as: number of key parts (2 bytes) and for each key part, + // the key part size (4 bytes) + encoded key part (n bytes). + size := 2 + 4*len(k.KeyParts) + for _, kp := range k.KeyParts { + size += encodedKeyPartLength(&kp) + } + return size +} + // DecodeKey constructs a key from an encoded key part func DecodeKey(encodedKey []byte) (*ledger.Key, error) { // check the enc dec version @@ -196,21 +227,30 @@ func DecodeKey(encodedKey []byte) (*ledger.Key, error) { return nil, fmt.Errorf("error decoding key: %w", err) } - // decode the key content - key, err := decodeKey(rest) + // decode the key content (zerocopy) + key, err := decodeKey(rest, true) if err != nil { return nil, fmt.Errorf("error decoding key: %w", err) } return key, nil } -func decodeKey(inp []byte) (*ledger.Key, error) { +// decodeKey decodes inp into Key. If zeroCopy is true, returned key +// references data in inp. Otherwise, it is copied. +func decodeKey(inp []byte, zeroCopy bool) (*ledger.Key, error) { key := &ledger.Key{} + numOfParts, rest, err := utils.ReadUint16(inp) if err != nil { return nil, fmt.Errorf("error decoding key (content): %w", err) } + if numOfParts == 0 { + return key, nil + } + + key.KeyParts = make([]ledger.KeyPart, numOfParts) + for i := 0; i < int(numOfParts); i++ { var kpEncSize uint32 var kpEnc []byte @@ -227,11 +267,12 @@ func decodeKey(inp []byte) (*ledger.Key, error) { } // decode encoded key part - kp, err := decodeKeyPart(kpEnc) + kp, err := decodeKeyPart(kpEnc, zeroCopy) if err != nil { return nil, fmt.Errorf("error decoding key (content): %w", err) } - key.KeyParts = append(key.KeyParts, *kp) + + key.KeyParts[i] = *kp } return key, nil } @@ -254,6 +295,14 @@ func encodeValue(v ledger.Value) []byte { return v } +func encodeAndAppendValue(buffer []byte, v ledger.Value) []byte { + return append(buffer, v...) 
+} + +func encodedValueLength(v ledger.Value) int { + return len(v) +} + // DecodeValue constructs a ledger value using an encoded byte slice func DecodeValue(encodedValue []byte) (ledger.Value, error) { // check enc dec version @@ -327,30 +376,52 @@ func EncodePayload(p *ledger.Payload) []byte { return buffer } +// EncodeAndAppendPayloadWithoutPrefix encodes a ledger payload +// without prefix (version and type) and appends to buffer. +// If payload is nil, unmodified buffer is returned. +func EncodeAndAppendPayloadWithoutPrefix(buffer []byte, p *ledger.Payload) []byte { + if p == nil { + return buffer + } + return encodeAndAppendPayload(buffer, p) +} + +func EncodedPayloadLengthWithoutPrefix(p *ledger.Payload) int { + return encodedPayloadLength(p) +} + func encodePayload(p *ledger.Payload) []byte { - buffer := make([]byte, 0) + buffer := make([]byte, 0, encodedPayloadLength(p)) + return encodeAndAppendPayload(buffer, p) +} - // encode key - encK := encodeKey(&p.Key) +func encodeAndAppendPayload(buffer []byte, p *ledger.Payload) []byte { // encode encoded key size - buffer = utils.AppendUint32(buffer, uint32(len(encK))) + buffer = utils.AppendUint32(buffer, uint32(encodedKeyLength(&p.Key))) - // append encoded key content - buffer = append(buffer, encK...) - - // encode value - encV := encodeValue(p.Value) + // encode key + buffer = encodeAndAppendKey(buffer, &p.Key) // encode encoded value size - buffer = utils.AppendUint64(buffer, uint64(len(encV))) + buffer = utils.AppendUint64(buffer, uint64(encodedValueLength(p.Value))) - // append encoded key content - buffer = append(buffer, encV...) + // encode value + buffer = encodeAndAppendValue(buffer, p.Value) return buffer } +func encodedPayloadLength(p *ledger.Payload) int { + if p == nil { + return 0 + } + // Payload is encoded as: + // encode key length (4 bytes) + encoded key + + // encoded value length (8 bytes) + encode value + return 4 + encodedKeyLength(&p.Key) + 8 + encodedValueLength(p.Value) +} + // DecodePayload construct a payload from an encoded byte slice func DecodePayload(encodedPayload []byte) (*ledger.Payload, error) { // if empty don't decode @@ -367,10 +438,24 @@ func DecodePayload(encodedPayload []byte) (*ledger.Payload, error) { if err != nil { return nil, fmt.Errorf("error decoding payload: %w", err) } - return decodePayload(rest) + // decode payload (zerocopy) + return decodePayload(rest, true) +} + +// DecodePayloadWithoutPrefix constructs a payload from encoded byte slice +// without prefix (version and type). If zeroCopy is true, returned payload +// references data in encodedPayload. Otherwise, it is copied. +func DecodePayloadWithoutPrefix(encodedPayload []byte, zeroCopy bool) (*ledger.Payload, error) { + // if empty don't decode + if len(encodedPayload) == 0 { + return nil, nil + } + return decodePayload(encodedPayload, zeroCopy) } -func decodePayload(inp []byte) (*ledger.Payload, error) { +// decodePayload decodes inp into payload. If zeroCopy is true, +// returned payload references data in inp. Otherwise, it is copied. 
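The zeroCopy flag's trade-off is easiest to see end to end: a zero-copy payload aliases the encoded buffer and is invalidated when that buffer is mutated or reused, while a copying decode owns its memory. A sketch, assuming the exported helpers added in this diff and the ledger constructors used in its tests:

```go
package main

import (
	"fmt"

	"github.com/onflow/flow-go/ledger"
	"github.com/onflow/flow-go/ledger/common/encoding"
)

func main() {
	kp := ledger.NewKeyPart(0, []byte("owner"))
	p := ledger.NewPayload(ledger.NewKey([]ledger.KeyPart{kp}), ledger.Value("v1"))

	enc := encoding.EncodeAndAppendPayloadWithoutPrefix(nil, p)

	// zeroCopy=true: the decoded payload aliases enc and is only valid
	// while enc stays unmodified.
	shared, err := encoding.DecodePayloadWithoutPrefix(enc, true)
	if err != nil {
		panic(err)
	}

	// zeroCopy=false: the decoded payload owns its memory.
	owned, err := encoding.DecodePayloadWithoutPrefix(enc, false)
	if err != nil {
		panic(err)
	}

	enc[len(enc)-1] = 'X' // mutate the source buffer

	fmt.Println(string(shared.Value)) // "vX" (aliased value changed)
	fmt.Println(string(owned.Value))  // "v1" (copy unaffected)
}
```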
+func decodePayload(inp []byte, zeroCopy bool) (*ledger.Payload, error) { // read encoded key size encKeySize, rest, err := utils.ReadUint32(inp) @@ -385,7 +470,7 @@ func decodePayload(inp []byte) (*ledger.Payload, error) { } // decode the key - key, err := decodeKey(encKey) + key, err := decodeKey(encKey, zeroCopy) if err != nil { return nil, fmt.Errorf("error decoding payload: %w", err) } @@ -402,7 +487,13 @@ func decodePayload(inp []byte) (*ledger.Payload, error) { return nil, fmt.Errorf("error decoding payload: %w", err) } - return &ledger.Payload{Key: *key, Value: encValue}, nil + if zeroCopy { + return &ledger.Payload{Key: *key, Value: encValue}, nil + } + + v := make([]byte, len(encValue)) + copy(v, encValue) + return &ledger.Payload{Key: *key, Value: v}, nil } // EncodeTrieUpdate encodes a trie update struct @@ -475,9 +566,6 @@ func DecodeTrieUpdate(encodedTrieUpdate []byte) (*ledger.TrieUpdate, error) { func decodeTrieUpdate(inp []byte) (*ledger.TrieUpdate, error) { - paths := make([]ledger.Path, 0) - payloads := make([]*ledger.Payload, 0) - // decode root hash rhSize, rest, err := utils.ReadUint16(inp) if err != nil { @@ -505,6 +593,9 @@ func decodeTrieUpdate(inp []byte) (*ledger.TrieUpdate, error) { return nil, fmt.Errorf("error decoding trie update: %w", err) } + paths := make([]ledger.Path, numOfPaths) + payloads := make([]*ledger.Payload, numOfPaths) + var path ledger.Path var encPath []byte for i := 0; i < int(numOfPaths); i++ { @@ -516,7 +607,7 @@ func decodeTrieUpdate(inp []byte) (*ledger.TrieUpdate, error) { if err != nil { return nil, fmt.Errorf("error decoding trie update: %w", err) } - paths = append(paths, path) + paths[i] = path } var payloadSize uint32 @@ -532,11 +623,12 @@ func decodeTrieUpdate(inp []byte) (*ledger.TrieUpdate, error) { if err != nil { return nil, fmt.Errorf("error decoding trie update: %w", err) } - payload, err = decodePayload(encPayload) + // Decode payload (zerocopy) + payload, err = decodePayload(encPayload, true) if err != nil { return nil, fmt.Errorf("error decoding trie update: %w", err) } - payloads = append(payloads, payload) + payloads[i] = payload } return &ledger.TrieUpdate{RootHash: rh, Paths: paths, Payloads: payloads}, nil } @@ -660,7 +752,8 @@ func decodeTrieProof(inp []byte) (*ledger.TrieProof, error) { if err != nil { return nil, fmt.Errorf("error decoding proof: %w", err) } - payload, err := decodePayload(encPayload) + // Decode payload (zerocopy) + payload, err := decodePayload(encPayload, true) if err != nil { return nil, fmt.Errorf("error decoding proof: %w", err) } @@ -671,7 +764,8 @@ func decodeTrieProof(inp []byte) (*ledger.TrieProof, error) { if err != nil { return nil, fmt.Errorf("error decoding proof: %w", err) } - interims := make([]hash.Hash, 0) + + interims := make([]hash.Hash, interimsLen) var interimSize uint16 var interim hash.Hash @@ -692,7 +786,7 @@ func decodeTrieProof(inp []byte) (*ledger.TrieProof, error) { return nil, fmt.Errorf("error decoding proof: %w", err) } - interims = append(interims, interim) + interims[i] = interim } pInst.Interims = interims diff --git a/ledger/common/encoding/encoding_test.go b/ledger/common/encoding/encoding_test.go index ed3baa9f72b..df7a7cb126b 100644 --- a/ledger/common/encoding/encoding_test.go +++ b/ledger/common/encoding/encoding_test.go @@ -71,6 +71,120 @@ func Test_PayloadEncodingDecoding(t *testing.T) { require.True(t, newp.Equals(p)) } +func Test_NilPayloadWithoutPrefixEncodingDecoding(t *testing.T) { + + buf := []byte{1, 2, 3} + bufLen := len(buf) + + // Test encoded 
payload data length + encodedPayloadLen := encoding.EncodedPayloadLengthWithoutPrefix(nil) + require.Equal(t, 0, encodedPayloadLen) + + // Encode payload and append to buffer + encoded := encoding.EncodeAndAppendPayloadWithoutPrefix(buf, nil) + // Test encoded data size + require.Equal(t, bufLen, len(encoded)) + // Test original input data isn't modified + require.Equal(t, buf, encoded) + // Test returned encoded data reuses input data + require.True(t, &buf[0] == &encoded[0]) + + // Decode and copy payload (excluding prefix) + newp, err := encoding.DecodePayloadWithoutPrefix(encoded[bufLen:], false) + require.NoError(t, err) + require.Nil(t, newp) + + // Zerocopy option has no effect for nil payload, but test it anyway. + // Decode payload (excluding prefix) with zero copy + newp, err = encoding.DecodePayloadWithoutPrefix(encoded[bufLen:], true) + require.NoError(t, err) + require.Nil(t, newp) +} + +func Test_PayloadWithoutPrefixEncodingDecoding(t *testing.T) { + + kp1t := uint16(1) + kp1v := []byte("key part 1") + kp1 := ledger.NewKeyPart(kp1t, kp1v) + + kp2t := uint16(22) + kp2v := []byte("key part 2") + kp2 := ledger.NewKeyPart(kp2t, kp2v) + + k := ledger.NewKey([]ledger.KeyPart{kp1, kp2}) + v := ledger.Value([]byte{'A'}) + p := ledger.NewPayload(k, v) + + const encodedPayloadSize = 47 // size of encoded payload p without prefix (version + type) + + testCases := []struct { + name string + payload *ledger.Payload + bufCap int + zeroCopy bool + }{ + // full cap means no capacity for appending payload (new alloc) + {"full cap zerocopy", p, 0, true}, + {"full cap", p, 0, false}, + // small cap means not enough capacity for appending payload (new alloc) + {"small cap zerocopy", p, encodedPayloadSize - 1, true}, + {"small cap", p, encodedPayloadSize - 1, false}, + // exact cap means exact capacity for appending payload (no alloc) + {"exact cap zerocopy", p, encodedPayloadSize, true}, + {"exact cap", p, encodedPayloadSize, false}, + // large cap means extra capacity than is needed for appending payload (no alloc) + {"large cap zerocopy", p, encodedPayloadSize + 1, true}, + {"large cap", p, encodedPayloadSize + 1, false}, + } + + bufPrefix := []byte{1, 2, 3} + bufPrefixLen := len(bufPrefix) + + for _, tc := range testCases { + + t.Run(tc.name, func(t *testing.T) { + + // Create a buffer of specified cap + prefix length + buffer := make([]byte, bufPrefixLen, bufPrefixLen+tc.bufCap) + copy(buffer, bufPrefix) + + // Encode payload and append to buffer + encoded := encoding.EncodeAndAppendPayloadWithoutPrefix(buffer, tc.payload) + encodedPayloadLen := encoding.EncodedPayloadLengthWithoutPrefix(tc.payload) + // Test encoded data size + require.Equal(t, len(encoded), bufPrefixLen+encodedPayloadLen) + // Test if original input data is modified + require.Equal(t, bufPrefix, encoded[:bufPrefixLen]) + // Test if input buffer is reused if it fits + if tc.bufCap >= encodedPayloadLen { + require.True(t, &buffer[0] == &encoded[0]) + } else { + // new alloc + require.True(t, &buffer[0] != &encoded[0]) + } + + // Decode payload (excluding prefix) + newp, err := encoding.DecodePayloadWithoutPrefix(encoded[bufPrefixLen:], tc.zeroCopy) + require.NoError(t, err) + require.True(t, newp.Equals(tc.payload)) + + // Reset encoded payload + for i := 0; i < len(encoded); i++ { + encoded[i] = 0 + } + + if tc.zeroCopy { + // Test if decoded payload is changed after source data is modified + // because data is shared. 
+ require.False(t, newp.Equals(tc.payload))
+ } else {
+ // Test if decoded payload is unchanged after source data is modified.
+ require.True(t, newp.Equals(tc.payload))
+ }
+ })
+ }
+}
+
 // Test_ProofEncodingDecoding tests encoding decoding functionality of a proof
 func Test_TrieProofEncodingDecoding(t *testing.T) {
 p, _ := utils.TrieProofFixture()
diff --git a/ledger/complete/checkpoint_benchmark_test.go b/ledger/complete/checkpoint_benchmark_test.go
new file mode 100644
index 00000000000..799651035a8
--- /dev/null
+++ b/ledger/complete/checkpoint_benchmark_test.go
@@ -0,0 +1,375 @@
+package complete_test
+
+import (
+ "flag"
+ "fmt"
+ "io"
+ "math/rand"
+ "os"
+ "strconv"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/rs/zerolog"
+ "github.com/stretchr/testify/require"
+
+ "github.com/onflow/flow-go/ledger"
+ "github.com/onflow/flow-go/ledger/common/hash"
+ "github.com/onflow/flow-go/ledger/common/pathfinder"
+ "github.com/onflow/flow-go/ledger/common/utils"
+ "github.com/onflow/flow-go/ledger/complete"
+ "github.com/onflow/flow-go/ledger/complete/mtrie"
+ "github.com/onflow/flow-go/ledger/complete/mtrie/trie"
+ "github.com/onflow/flow-go/ledger/complete/wal"
+ "github.com/onflow/flow-go/module/metrics"
+)
+
+var dir = flag.String("dir", ".", "dir containing checkpoint and wal files")
+
+// BenchmarkNewCheckpoint benchmarks checkpoint file creation from existing checkpoint and wal segments.
+// This requires a checkpoint file and one or more segments following the checkpoint file.
+// This benchmark will create a checkpoint file.
+func BenchmarkNewCheckpoint(b *testing.B) {
+ // Check if there is any segment in specified dir
+ foundSeg, err := hasSegmentInDir(*dir)
+ if err != nil {
+ b.Fatal(err)
+ }
+ if !foundSeg {
+ b.Fatalf("failed to find segment in %s. Use -dir to specify dir containing segments and checkpoint files.", *dir)
+ }
+
+ // Check if there is any checkpoint file in specified dir
+ foundCheckpoint, err := hasCheckpointInDir(*dir)
+ if err != nil {
+ b.Fatal(err)
+ }
+ if !foundCheckpoint {
+ b.Fatalf("failed to find checkpoint in %s. Use -dir to specify dir containing segments and checkpoint files.", *dir)
+ }
+
+ diskwal, err := wal.NewDiskWAL(
+ zerolog.Nop(),
+ nil,
+ metrics.NewNoopCollector(),
+ *dir,
+ 500,
+ pathfinder.PathByteSize,
+ wal.SegmentSize,
+ )
+ if err != nil {
+ b.Fatal(err)
+ }
+
+ _, to, err := diskwal.Segments()
+ if err != nil {
+ b.Fatal(err)
+ }
+
+ checkpointer, err := diskwal.NewCheckpointer()
+ if err != nil {
+ b.Fatal(err)
+ }
+
+ start := time.Now()
+ b.ResetTimer()
+
+ err = checkpointer.Checkpoint(to-1, func() (io.WriteCloser, error) {
+ return checkpointer.CheckpointWriter(to - 1)
+ })
+
+ b.StopTimer()
+ elapsed := time.Since(start)
+
+ if err != nil {
+ b.Fatal(err)
+ }
+
+ b.ReportMetric(float64(elapsed/time.Millisecond), "newcheckpoint_time_(ms)")
+ b.ReportAllocs()
+}
+
+// BenchmarkLoadCheckpointAndWALs benchmarks checkpoint file loading and wal segments replaying.
+// This requires a checkpoint file and one or more segments following the checkpoint file.
+// This mimics rebuilding mtrie at EN startup.
+func BenchmarkLoadCheckpointAndWALs(b *testing.B) {
+ // Check if there is any segment in specified dir
+ foundSeg, err := hasSegmentInDir(*dir)
+ if err != nil {
+ b.Fatal(err)
+ }
+ if !foundSeg {
+ b.Fatalf("failed to find segment in %s. Use -dir to specify dir containing segments and checkpoint files.", *dir)
+ }
+
+ // Check if there is any checkpoint file in specified dir
+ foundCheckpoint, err := hasCheckpointInDir(*dir)
+ if err != nil {
+ b.Fatal(err)
+ }
+ if !foundCheckpoint {
+ b.Fatalf("failed to find checkpoint in %s. Use -dir to specify dir containing segments and checkpoint files.", *dir)
+ }
+
+ forest, err := mtrie.NewForest(500, metrics.NewNoopCollector(), nil)
+ if err != nil {
+ b.Fatal(err)
+ }
+
+ diskwal, err := wal.NewDiskWAL(
+ zerolog.Nop(),
+ nil,
+ metrics.NewNoopCollector(),
+ *dir,
+ 500,
+ pathfinder.PathByteSize,
+ wal.SegmentSize,
+ )
+ if err != nil {
+ b.Fatal(err)
+ }
+
+ // pause records to prevent double logging trie removals
+ diskwal.PauseRecord()
+ defer diskwal.UnpauseRecord()
+
+ start := time.Now()
+ b.ResetTimer()
+
+ err = diskwal.Replay(
+ func(tries []*trie.MTrie) error {
+ err := forest.AddTries(tries)
+ if err != nil {
+ return fmt.Errorf("adding rebuilt tries to forest failed: %w", err)
+ }
+ return nil
+ },
+ func(update *ledger.TrieUpdate) error {
+ _, err := forest.Update(update)
+ return err
+ },
+ func(rootHash ledger.RootHash) error {
+ forest.RemoveTrie(rootHash)
+ return nil
+ },
+ )
+ if err != nil {
+ b.Fatal(err)
+ }
+
+ b.StopTimer()
+ elapsed := time.Since(start)
+
+ b.ReportMetric(float64(elapsed/time.Millisecond), "loadcheckpointandwals_time_(ms)")
+ b.ReportAllocs()
+}
+
+func hasSegmentInDir(dir string) (bool, error) {
+ files, err := os.ReadDir(dir)
+ if err != nil {
+ return false, err
+ }
+
+ for _, fn := range files {
+ fname := fn.Name()
+ _, err := strconv.Atoi(fname)
+ if err != nil {
+ continue
+ }
+ return true, nil
+ }
+ return false, nil
+}
+
+func hasCheckpointInDir(dir string) (bool, error) {
+ const checkpointFilenamePrefix = "checkpoint."
+
+ files, err := os.ReadDir(dir)
+ if err != nil {
+ return false, err
+ }
+
+ for _, fn := range files {
+ fname := fn.Name()
+ if !strings.HasPrefix(fname, checkpointFilenamePrefix) {
+ continue
+ }
+ justNumber := fname[len(checkpointFilenamePrefix):]
+ _, err := strconv.Atoi(justNumber)
+ if err != nil {
+ continue
+ }
+ return true, nil
+ }
+
+ return false, nil
+}
+
+func BenchmarkNewCheckpointRandom5Seg(b *testing.B) { benchmarkNewCheckpointRandomData(b, 5) }
+
+func BenchmarkNewCheckpointRandom10Seg(b *testing.B) { benchmarkNewCheckpointRandomData(b, 10) }
+
+func BenchmarkNewCheckpointRandom20Seg(b *testing.B) { benchmarkNewCheckpointRandomData(b, 20) }
+
+func BenchmarkNewCheckpointRandom30Seg(b *testing.B) { benchmarkNewCheckpointRandomData(b, 30) }
+
+func BenchmarkNewCheckpointRandom40Seg(b *testing.B) { benchmarkNewCheckpointRandomData(b, 40) }
+
+// benchmarkNewCheckpointRandomData benchmarks checkpoint file creation.
+// This benchmark creates segmentCount+1 WAL segments. It also creates two checkpoint files:
+// - checkpoint file A from segment 0, and
+// - checkpoint file B from checkpoint file A and all segments after segment 0.
+// This benchmark measures the creation of checkpoint file B:
+// - loading checkpoint file A
+// - replaying all segments after segment 0
+// - creating checkpoint file B
+// Because payload data is random, number of segments created can differ from segmentCount.
+func benchmarkNewCheckpointRandomData(b *testing.B, segmentCount int) {
+
+ const (
+ updatePerSegment = 75 // 75 updates for 1 segment by approximation.
+ kvBatchCount = 500 // Each update has 500 new payloads. 
+ ) + + if segmentCount < 1 { + segmentCount = 1 + } + + kvOpts := randKeyValueOptions{ + keyNumberOfParts: 3, + keyPartMinByteSize: 1, + keyPartMaxByteSize: 50, + valueMinByteSize: 50, + valueMaxByteSize: 1024 * 1.5, + } + updateCount := (segmentCount + 1) * updatePerSegment + + seed := uint64(0x9E3779B97F4A7C15) // golden ratio + rand.Seed(int64(seed)) + + dir, err := os.MkdirTemp("", "test-mtrie-") + defer os.RemoveAll(dir) + if err != nil { + b.Fatal(err) + } + + wal1, err := wal.NewDiskWAL( + zerolog.Nop(), + nil, + metrics.NewNoopCollector(), + dir, + 500, + pathfinder.PathByteSize, + wal.SegmentSize) + if err != nil { + b.Fatal(err) + } + + led, err := complete.NewLedger( + wal1, + 500, + &metrics.NoopCollector{}, + zerolog.Logger{}, + complete.DefaultPathFinderVersion, + ) + if err != nil { + b.Fatal(err) + } + + state := led.InitialState() + + _, err = updateLedgerWithRandomData(led, state, updateCount, kvBatchCount, kvOpts) + if err != nil { + b.Fatal(err) + } + + <-wal1.Done() + <-led.Done() + + wal2, err := wal.NewDiskWAL( + zerolog.Nop(), + nil, + metrics.NewNoopCollector(), + dir, + 500, + pathfinder.PathByteSize, + wal.SegmentSize, + ) + if err != nil { + b.Fatal(err) + } + + checkpointer, err := wal2.NewCheckpointer() + if err != nil { + b.Fatal(err) + } + + // Create checkpoint with only one segment as the base checkpoint for the next step. + err = checkpointer.Checkpoint(0, func() (io.WriteCloser, error) { + return checkpointer.CheckpointWriter(0) + }) + require.NoError(b, err) + + // Create checkpoint with remaining segments + _, to, err := wal2.Segments() + require.NoError(b, err) + + if to == 1 { + fmt.Printf("skip creating second checkpoint file because to segment is 1\n") + return + } + + start := time.Now() + b.ResetTimer() + + err = checkpointer.Checkpoint(to-1, func() (io.WriteCloser, error) { + return checkpointer.CheckpointWriter(to) + }) + + b.StopTimer() + elapsed := time.Since(start) + + if err != nil { + b.Fatal(err) + } + + b.ReportMetric(float64(elapsed/time.Millisecond), "newcheckpoint_rand_time_(ms)") + b.ReportAllocs() +} + +type randKeyValueOptions struct { + keyNumberOfParts int + keyPartMinByteSize int + keyPartMaxByteSize int + valueMinByteSize int + valueMaxByteSize int +} + +func updateLedgerWithRandomData( + led ledger.Ledger, + state ledger.State, + updateCount int, + kvBatchCount int, + kvOpts randKeyValueOptions, +) (ledger.State, error) { + + for i := 0; i < updateCount; i++ { + keys := utils.RandomUniqueKeys(kvBatchCount, kvOpts.keyNumberOfParts, kvOpts.keyPartMinByteSize, kvOpts.keyPartMaxByteSize) + values := utils.RandomValues(kvBatchCount, kvOpts.valueMinByteSize, kvOpts.valueMaxByteSize) + + update, err := ledger.NewUpdate(state, keys, values) + if err != nil { + return ledger.State(hash.DummyHash), err + } + + newState, _, err := led.Set(update) + if err != nil { + return ledger.State(hash.DummyHash), err + } + + state = newState + } + + return state, nil +} diff --git a/ledger/complete/ledger.go b/ledger/complete/ledger.go index 79a7b8d69da..6c05eb17dac 100644 --- a/ledger/complete/ledger.go +++ b/ledger/complete/ledger.go @@ -12,7 +12,6 @@ import ( "github.com/onflow/flow-go/ledger/common/hash" "github.com/onflow/flow-go/ledger/common/pathfinder" "github.com/onflow/flow-go/ledger/complete/mtrie" - "github.com/onflow/flow-go/ledger/complete/mtrie/flattener" "github.com/onflow/flow-go/ledger/complete/mtrie/trie" "github.com/onflow/flow-go/ledger/complete/wal" "github.com/onflow/flow-go/module" @@ -48,15 +47,18 @@ func NewLedger( log 
zerolog.Logger, pathFinderVer uint8) (*Ledger, error) { - forest, err := mtrie.NewForest(capacity, metrics, func(evictedTrie *trie.MTrie) error { - return wal.RecordDelete(evictedTrie.RootHash()) + logger := log.With().Str("ledger", "complete").Logger() + + forest, err := mtrie.NewForest(capacity, metrics, func(evictedTrie *trie.MTrie) { + err := wal.RecordDelete(evictedTrie.RootHash()) + if err != nil { + logger.Error().Err(err).Msg("failed to save delete record in wal") + } }) if err != nil { return nil, fmt.Errorf("cannot create forest: %w", err) } - logger := log.With().Str("ledger", "complete").Logger() - storage := &Ledger{ forest: forest, wal: wal, @@ -359,14 +361,9 @@ func (l *Ledger) ExportCheckpointAt( return ledger.State(hash.DummyHash), fmt.Errorf("failed to create a checkpoint writer: %w", err) } - flatTrie, err := flattener.FlattenTrie(newTrie) - if err != nil { - return ledger.State(hash.DummyHash), fmt.Errorf("failed to flatten the trie: %w", err) - } - l.logger.Info().Msg("storing the checkpoint to the file") - err = wal.StoreCheckpoint(flatTrie.ToFlattenedForestWithASingleTrie(), writer) + err = wal.StoreCheckpoint(writer, newTrie) if err != nil { return ledger.State(hash.DummyHash), fmt.Errorf("failed to store the checkpoint: %w", err) } diff --git a/ledger/complete/ledger_test.go b/ledger/complete/ledger_test.go index c94edae414c..d136e6db7fe 100644 --- a/ledger/complete/ledger_test.go +++ b/ledger/complete/ledger_test.go @@ -440,7 +440,7 @@ func TestLedgerFunctionality(t *testing.T) { // capture new values for future query for j, k := range keys { encKey := encoding.EncodeKey(&k) - histStorage[string(newState[:])+string(encKey[:])] = values[j] + histStorage[string(newState[:])+string(encKey)] = values[j] latestValue[string(encKey)] = values[j] } diff --git a/ledger/complete/mtrie/flattener/encoding.go b/ledger/complete/mtrie/flattener/encoding.go index efc34137190..ccd55782259 100644 --- a/ledger/complete/mtrie/flattener/encoding.go +++ b/ledger/complete/mtrie/flattener/encoding.go @@ -1,187 +1,422 @@ package flattener import ( + "encoding/binary" "fmt" "io" - "github.com/onflow/flow-go/ledger/common/utils" + "github.com/onflow/flow-go/ledger" + "github.com/onflow/flow-go/ledger/common/encoding" + "github.com/onflow/flow-go/ledger/common/hash" + "github.com/onflow/flow-go/ledger/complete/mtrie/node" + "github.com/onflow/flow-go/ledger/complete/mtrie/trie" ) -const encodingDecodingVersion = uint16(0) +type nodeType byte -// EncodeStorableNode encodes StorableNode -func EncodeStorableNode(storableNode *StorableNode) []byte { +const ( + leafNodeType nodeType = iota + interimNodeType +) + +const ( + encNodeTypeSize = 1 + encHeightSize = 2 + encMaxDepthSize = 2 + encRegCountSize = 8 + encHashSize = hash.HashLen + encPathSize = ledger.PathLen + encNodeIndexSize = 8 + encPayloadLengthSize = 4 +) - length := 2 + 2 + 8 + 8 + 2 + 8 + 2 + len(storableNode.Path) + 4 + len(storableNode.EncPayload) + 2 + len(storableNode.HashValue) - buf := make([]byte, 0, length) - // 2-bytes encoding version - buf = utils.AppendUint16(buf, encodingDecodingVersion) +// encodeLeafNode encodes leaf node in the following format: +// - node type (1 byte) +// - height (2 bytes) +// - max depth (2 bytes) +// - reg count (8 bytes) +// - hash (32 bytes) +// - path (32 bytes) +// - payload (4 bytes + n bytes) +// Encoded leaf node size is 81 bytes (assuming length of hash/path is 32 bytes) + +// length of encoded payload size. +// Scratch buffer is used to avoid allocs. 
It should be used directly instead +// of using append. This function uses len(scratch) and ignores cap(scratch), +// so any extra capacity will not be utilized. +// WARNING: The returned buffer is likely to share the same underlying array as +// the scratch buffer. Caller is responsible for copying or using returned buffer +// before scratch buffer is used again. +func encodeLeafNode(n *node.Node, scratch []byte) []byte { + + encPayloadSize := encoding.EncodedPayloadLengthWithoutPrefix(n.Payload()) + + encodedNodeSize := encNodeTypeSize + + encHeightSize + + encMaxDepthSize + + encRegCountSize + + encHashSize + + encPathSize + + encPayloadLengthSize + + encPayloadSize + + // buf uses received scratch buffer if it's large enough. + // Otherwise, a new buffer is allocated. + // buf is used directly so len(buf) must not be 0. + // buf will be resliced to proper size before being returned from this function. + buf := scratch + if len(scratch) < encodedNodeSize { + buf = make([]byte, encodedNodeSize) + } - // 2-bytes Big Endian uint16 height - buf = utils.AppendUint16(buf, storableNode.Height) + pos := 0 - // 8-bytes Big Endian uint64 LIndex - buf = utils.AppendUint64(buf, storableNode.LIndex) + // Encode node type (1 byte) + buf[pos] = byte(leafNodeType) + pos += encNodeTypeSize - // 8-bytes Big Endian uint64 RIndex - buf = utils.AppendUint64(buf, storableNode.RIndex) + // Encode height (2 bytes Big Endian) + binary.BigEndian.PutUint16(buf[pos:], uint16(n.Height())) + pos += encHeightSize - // 2-bytes Big Endian maxDepth - buf = utils.AppendUint16(buf, storableNode.MaxDepth) + // Encode max depth (2 bytes Big Endian) + binary.BigEndian.PutUint16(buf[pos:], n.MaxDepth()) + pos += encMaxDepthSize - // 8-bytes Big Endian regCount - buf = utils.AppendUint64(buf, storableNode.RegCount) + // Encode reg count (8 bytes Big Endian) + binary.BigEndian.PutUint64(buf[pos:], n.RegCount()) + pos += encRegCountSize - // 2-bytes Big Endian uint16 encoded path length and n-bytes encoded path - buf = utils.AppendShortData(buf, storableNode.Path) + // Encode hash (32 bytes hashValue) + hash := n.Hash() + copy(buf[pos:], hash[:]) + pos += encHashSize - // 4-bytes Big Endian uint32 encoded payload length and n-bytes encoded payload - buf = utils.AppendLongData(buf, storableNode.EncPayload) + // Encode path (32 bytes path) + path := n.Path() + copy(buf[pos:], path[:]) + pos += encPathSize - // 2-bytes Big Endian uint16 hashValue length and n-bytes hashValue - buf = utils.AppendShortData(buf, storableNode.HashValue) + // Encode payload (4 bytes Big Endian for encoded payload length and n bytes encoded payload) + binary.BigEndian.PutUint32(buf[pos:], uint32(encPayloadSize)) + pos += encPayloadLengthSize + + // EncodeAndAppendPayloadWithoutPrefix appends encoded payload to the resliced buf. + // Returned buf is resliced to include appended payload. 
+ buf = encoding.EncodeAndAppendPayloadWithoutPrefix(buf[:pos], n.Payload()) return buf } -// ReadStorableNode reads a storable node from io -func ReadStorableNode(reader io.Reader) (*StorableNode, error) { - - // reading version - buf := make([]byte, 2) - read, err := io.ReadFull(reader, buf) - if err != nil { - return nil, fmt.Errorf("error reading storable node, cannot read version part: %w", err) - } - if read != len(buf) { - return nil, fmt.Errorf("not enough bytes read %d expected %d", read, len(buf)) +// encodeInterimNode encodes interim node in the following format: +// - node type (1 byte) +// - height (2 bytes) +// - max depth (2 bytes) +// - reg count (8 bytes) +// - hash (32 bytes) +// - lchild index (8 bytes) +// - rchild index (8 bytes) +// Encoded interim node size is 61 bytes (assuming length of hash is 32 bytes). +// Scratch buffer is used to avoid allocs. It should be used directly instead +// of using append. This function uses len(scratch) and ignores cap(scratch), +// so any extra capacity will not be utilized. +// WARNING: The returned buffer is likely to share the same underlying array as +// the scratch buffer. Caller is responsible for copying or using returned buffer +// before scratch buffer is used again. +func encodeInterimNode(n *node.Node, lchildIndex uint64, rchildIndex uint64, scratch []byte) []byte { + + const encodedNodeSize = encNodeTypeSize + + encHeightSize + + encMaxDepthSize + + encRegCountSize + + encHashSize + + encNodeIndexSize + + encNodeIndexSize + + // buf uses received scratch buffer if it's large enough. + // Otherwise, a new buffer is allocated. + // buf is used directly so len(buf) must not be 0. + // buf will be resliced to proper size before being returned from this function. + buf := scratch + if len(scratch) < encodedNodeSize { + buf = make([]byte, encodedNodeSize) } - version, _, err := utils.ReadUint16(buf) - if err != nil { - return nil, fmt.Errorf("error reading storable node: %w", err) + pos := 0 + + // Encode node type (1 byte) + buf[pos] = byte(interimNodeType) + pos += encNodeTypeSize + + // Encode height (2 bytes Big Endian) + binary.BigEndian.PutUint16(buf[pos:], uint16(n.Height())) + pos += encHeightSize + + // Encode max depth (2 bytes Big Endian) + binary.BigEndian.PutUint16(buf[pos:], n.MaxDepth()) + pos += encMaxDepthSize + + // Encode reg count (8 bytes Big Endian) + binary.BigEndian.PutUint64(buf[pos:], n.RegCount()) + pos += encRegCountSize + + // Encode hash (32 bytes hashValue) + h := n.Hash() + copy(buf[pos:], h[:]) + pos += encHashSize + + // Encode left child index (8 bytes Big Endian) + binary.BigEndian.PutUint64(buf[pos:], lchildIndex) + pos += encNodeIndexSize + + // Encode right child index (8 bytes Big Endian) + binary.BigEndian.PutUint64(buf[pos:], rchildIndex) + pos += encNodeIndexSize + + return buf[:pos] +} + +// EncodeNode encodes node. +// Scratch buffer is used to avoid allocs. +// WARNING: The returned buffer is likely to share the same underlying array as +// the scratch buffer. Caller is responsible for copying or using returned buffer +// before scratch buffer is used again. +func EncodeNode(n *node.Node, lchildIndex uint64, rchildIndex uint64, scratch []byte) []byte { + if n.IsLeaf() { + return encodeLeafNode(n, scratch) } + return encodeInterimNode(n, lchildIndex, rchildIndex, scratch) +} + +// ReadNode reconstructs a node from data read from reader. +// Scratch buffer is used to avoid allocs. It should be used directly instead +// of using append. 
This function uses len(scratch) and ignores cap(scratch), +// so any extra capacity will not be utilized. +// If len(scratch) < 1024, then a new buffer will be allocated and used. +func ReadNode(reader io.Reader, scratch []byte, getNode func(nodeIndex uint64) (*node.Node, error)) (*node.Node, error) { - if version > encodingDecodingVersion { - return nil, fmt.Errorf("error reading storable node: unsuported version %d > %d", version, encodingDecodingVersion) + // minBufSize should be large enough for interim node and leaf node with small payload. + // minBufSize is a failsafe and is only used when len(scratch) is much smaller + // than expected. len(scratch) is 4096 by default, so minBufSize isn't likely to be used. + const minBufSize = 1024 + + if len(scratch) < minBufSize { + scratch = make([]byte, minBufSize) } - // reading fixed-length part - buf = make([]byte, 2+8+8+2+8) + // fixLengthSize is the size of shared data of leaf node and interim node + const fixLengthSize = encNodeTypeSize + + encHeightSize + + encMaxDepthSize + + encRegCountSize + + encHashSize - read, err = io.ReadFull(reader, buf) + _, err := io.ReadFull(reader, scratch[:fixLengthSize]) if err != nil { - return nil, fmt.Errorf("error reading storable node, cannot read fixed-length part: %w", err) - } - if read != len(buf) { - return nil, fmt.Errorf("not enough bytes read %d expected %d", read, len(buf)) + return nil, fmt.Errorf("failed to read fixed-length part of serialized node: %w", err) } - storableNode := &StorableNode{} + pos := 0 - storableNode.Height, buf, err = utils.ReadUint16(buf) - if err != nil { - return nil, fmt.Errorf("error reading storable node: %w", err) + // Decode node type (1 byte) + nType := scratch[pos] + pos += encNodeTypeSize + + if nType != byte(leafNodeType) && nType != byte(interimNodeType) { + return nil, fmt.Errorf("failed to decode node type %d", nType) } - storableNode.LIndex, buf, err = utils.ReadUint64(buf) + // Decode height (2 bytes) + height := binary.BigEndian.Uint16(scratch[pos:]) + pos += encHeightSize + + // Decode max depth (2 bytes) + maxDepth := binary.BigEndian.Uint16(scratch[pos:]) + pos += encMaxDepthSize + + // Decode reg count (8 bytes) + regCount := binary.BigEndian.Uint64(scratch[pos:]) + pos += encRegCountSize + + // Decode and create hash.Hash (32 bytes) + nodeHash, err := hash.ToHash(scratch[pos : pos+encHashSize]) if err != nil { - return nil, fmt.Errorf("error reading storable node: %w", err) + return nil, fmt.Errorf("failed to decode hash of serialized node: %w", err) } - storableNode.RIndex, buf, err = utils.ReadUint64(buf) - if err != nil { - return nil, fmt.Errorf("error reading storable node: %w", err) + if nType == byte(leafNodeType) { + + // Read path (32 bytes) + encPath := scratch[:encPathSize] + _, err := io.ReadFull(reader, encPath) + if err != nil { + return nil, fmt.Errorf("failed to read path of serialized node: %w", err) + } + + // Decode and create ledger.Path. + path, err := ledger.ToPath(encPath) + if err != nil { + return nil, fmt.Errorf("failed to decode path of serialized node: %w", err) + } + + // Read encoded payload data and create ledger.Payload. 
+ payload, err := readPayloadFromReader(reader, scratch) + if err != nil { + return nil, fmt.Errorf("failed to read and decode payload of serialized node: %w", err) + } + + node := node.NewNode(int(height), nil, nil, path, payload, nodeHash, maxDepth, regCount) + return node, nil } - storableNode.MaxDepth, buf, err = utils.ReadUint16(buf) + // Read interim node + + // Read left and right child index (16 bytes) + _, err = io.ReadFull(reader, scratch[:encNodeIndexSize*2]) if err != nil { - return nil, fmt.Errorf("error reading storable node: %w", err) + return nil, fmt.Errorf("failed to read child index of serialized node: %w", err) } - storableNode.RegCount, _, err = utils.ReadUint64(buf) + pos = 0 + + // Decode left child index (8 bytes) + lchildIndex := binary.BigEndian.Uint64(scratch[pos:]) + pos += encNodeIndexSize + + // Decode right child index (8 bytes) + rchildIndex := binary.BigEndian.Uint64(scratch[pos:]) + + // Get left child node by node index + lchild, err := getNode(lchildIndex) if err != nil { - return nil, fmt.Errorf("error reading storable node: %w", err) + return nil, fmt.Errorf("failed to find left child node of serialized node: %w", err) } - storableNode.Path, err = utils.ReadShortDataFromReader(reader) + // Get right child node by node index + rchild, err := getNode(rchildIndex) if err != nil { - return nil, fmt.Errorf("cannot read key data: %w", err) + return nil, fmt.Errorf("failed to find right child node of serialized node: %w", err) } - storableNode.EncPayload, err = utils.ReadLongDataFromReader(reader) - if err != nil { - return nil, fmt.Errorf("cannot read value data: %w", err) + n := node.NewNode(int(height), lchild, rchild, ledger.DummyPath, nil, nodeHash, maxDepth, regCount) + return n, nil +} + +// EncodeTrie encodes trie in the following format: +// - root node index (8 byte) +// - root node hash (32 bytes) +// Scratch buffer is used to avoid allocs. +// WARNING: The returned buffer is likely to share the same underlying array as +// the scratch buffer. Caller is responsible for copying or using returned buffer +// before scratch buffer is used again. +func EncodeTrie(rootNode *node.Node, rootIndex uint64, scratch []byte) []byte { + + const encodedTrieSize = encNodeIndexSize + encHashSize + + // Get root hash + var rootHash ledger.RootHash + if rootNode == nil { + rootHash = trie.EmptyTrieRootHash() + } else { + rootHash = ledger.RootHash(rootNode.Hash()) } - storableNode.HashValue, err = utils.ReadShortDataFromReader(reader) - if err != nil { - return nil, fmt.Errorf("cannot read hashValue data: %w", err) + if len(scratch) < encodedTrieSize { + scratch = make([]byte, encodedTrieSize) } - return storableNode, nil + pos := 0 + + // Encode root node index (8 bytes Big Endian) + binary.BigEndian.PutUint64(scratch, rootIndex) + pos += encNodeIndexSize + + // Encode hash (32-bytes hashValue) + copy(scratch[pos:], rootHash[:]) + pos += encHashSize + + return scratch[:pos] } -// EncodeStorableTrie encodes StorableTrie -func EncodeStorableTrie(storableTrie *StorableTrie) []byte { - length := 2 + 8 + 2 + len(storableTrie.RootHash) - buf := make([]byte, 0, length) - // 2-bytes encoding version - buf = utils.AppendUint16(buf, encodingDecodingVersion) +// ReadTrie reconstructs a trie from data read from reader. 
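A sketch of how EncodeTrie pairs with the ReadTrie implementation below; it assumes the flattener/node/hash APIs from this diff, reuses the root-node shape from its tests, and shows the getNode callback resolving the stored root index, the way a checkpoint reader consults its table of already-decoded nodes:

```go
package main

import (
	"bytes"
	"fmt"

	"github.com/onflow/flow-go/ledger"
	"github.com/onflow/flow-go/ledger/common/hash"
	"github.com/onflow/flow-go/ledger/complete/mtrie/flattener"
	"github.com/onflow/flow-go/ledger/complete/mtrie/node"
)

func main() {
	// Root node shaped like the fixture in this diff's tests.
	rootHash := hash.Hash([32]byte{2, 2, 2})
	root := node.NewNode(256, nil, nil, ledger.DummyPath, nil, rootHash, 7, 5000)

	const rootIndex = uint64(21)
	writeScratch := make([]byte, 1024)
	readScratch := make([]byte, 1024)

	// Encoded trie is 40 bytes: root index (8) + root hash (32).
	encoded := flattener.EncodeTrie(root, rootIndex, writeScratch)

	// The callback maps the stored index back to the in-memory node.
	readBack, err := flattener.ReadTrie(bytes.NewReader(encoded), readScratch,
		func(nodeIndex uint64) (*node.Node, error) {
			if nodeIndex != rootIndex {
				return nil, fmt.Errorf("unknown node index %d", nodeIndex)
			}
			return root, nil
		})
	if err != nil {
		panic(err)
	}
	fmt.Printf("root hash round-tripped: %x\n", readBack.RootHash())
}
```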
+func ReadTrie(reader io.Reader, scratch []byte, getNode func(nodeIndex uint64) (*node.Node, error)) (*trie.MTrie, error) { - // 8-bytes Big Endian uint64 RootIndex - buf = utils.AppendUint64(buf, storableTrie.RootIndex) + // encodedTrieSize is a failsafe and is only used when len(scratch) is much smaller + // than expected (4096 by default). + const encodedTrieSize = encNodeIndexSize + encHashSize - // 2-bytes Big Endian uint16 RootHash length and n-bytes RootHash - buf = utils.AppendShortData(buf, storableTrie.RootHash) + if len(scratch) < encodedTrieSize { + scratch = make([]byte, encodedTrieSize) + } - return buf -} + // Read encoded trie (8 + 32 bytes) + _, err := io.ReadFull(reader, scratch[:encodedTrieSize]) + if err != nil { + return nil, fmt.Errorf("failed to read serialized trie: %w", err) + } + + pos := 0 -// ReadStorableTrie reads a storable trie from io -func ReadStorableTrie(reader io.Reader) (*StorableTrie, error) { - storableTrie := &StorableTrie{} + // Decode root node index + rootIndex := binary.BigEndian.Uint64(scratch) + pos += encNodeIndexSize - // reading version - buf := make([]byte, 2) - read, err := io.ReadFull(reader, buf) + // Decode root node hash + readRootHash, err := hash.ToHash(scratch[pos : pos+encHashSize]) if err != nil { - return nil, fmt.Errorf("error reading storable node, cannot read version part: %w", err) + return nil, fmt.Errorf("failed to decode hash of serialized trie: %w", err) } - if read != len(buf) { - return nil, fmt.Errorf("not enough bytes read %d expected %d", read, len(buf)) + + rootNode, err := getNode(rootIndex) + if err != nil { + return nil, fmt.Errorf("failed to find root node of serialized trie: %w", err) } - version, _, err := utils.ReadUint16(buf) + mtrie, err := trie.NewMTrie(rootNode) if err != nil { - return nil, fmt.Errorf("error reading storable node: %w", err) + return nil, fmt.Errorf("failed to restore serialized trie: %w", err) } - if version > encodingDecodingVersion { - return nil, fmt.Errorf("error reading storable node: unsuported version %d > %d", version, encodingDecodingVersion) + rootHash := mtrie.RootHash() + if !rootHash.Equals(ledger.RootHash(readRootHash)) { + return nil, fmt.Errorf("failed to restore serialized trie: roothash doesn't match") } - // read root uint64 RootIndex - buf = make([]byte, 8) - read, err = io.ReadFull(reader, buf) + return mtrie, nil +} + +// readPayloadFromReader reads and decodes payload from reader. +// Returned payload is a copy. 
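readPayloadFromReader below is the read side of the 4-byte big-endian length-prefix framing that encodeLeafNode writes. A standalone sketch of both sides of that framing, using only the exported encoding helpers from this diff (writeFramedPayload and readFramedPayload are hypothetical names, not part of the change):

```go
package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"io"

	"github.com/onflow/flow-go/ledger"
	"github.com/onflow/flow-go/ledger/common/encoding"
)

// writeFramedPayload writes: length prefix (4 bytes) + encoded payload.
func writeFramedPayload(w io.Writer, p *ledger.Payload) error {
	size := encoding.EncodedPayloadLengthWithoutPrefix(p)
	buf := make([]byte, 4, 4+size)
	binary.BigEndian.PutUint32(buf, uint32(size))
	buf = encoding.EncodeAndAppendPayloadWithoutPrefix(buf, p)
	_, err := w.Write(buf)
	return err
}

// readFramedPayload reads the length prefix, then the payload bytes.
func readFramedPayload(r io.Reader) (*ledger.Payload, error) {
	var lenBuf [4]byte
	if _, err := io.ReadFull(r, lenBuf[:]); err != nil {
		return nil, err
	}
	buf := make([]byte, binary.BigEndian.Uint32(lenBuf[:]))
	if _, err := io.ReadFull(r, buf); err != nil {
		return nil, err
	}
	// Copying decode (zeroCopy=false): buf is transient here.
	return encoding.DecodePayloadWithoutPrefix(buf, false)
}

func main() {
	kp := ledger.NewKeyPart(0, []byte("owner"))
	p := ledger.NewPayload(ledger.NewKey([]ledger.KeyPart{kp}), ledger.Value("v"))

	var stream bytes.Buffer
	if err := writeFramedPayload(&stream, p); err != nil {
		panic(err)
	}
	got, err := readFramedPayload(&stream)
	if err != nil {
		panic(err)
	}
	fmt.Println(got.Equals(p)) // true
}
```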
+func readPayloadFromReader(reader io.Reader, scratch []byte) (*ledger.Payload, error) {
+
+ if len(scratch) < encPayloadLengthSize {
+ scratch = make([]byte, encPayloadLengthSize)
+ }
+
+ // Read payload size
+ _, err := io.ReadFull(reader, scratch[:encPayloadLengthSize])
+ if err != nil {
+ return nil, fmt.Errorf("cannot read payload length: %w", err)
+ }
+
+ // Decode payload size
+ size := binary.BigEndian.Uint32(scratch)
+
+ if len(scratch) < int(size) {
+ scratch = make([]byte, size)
+ } else {
+ scratch = scratch[:size]
+ }
+
+ _, err = io.ReadFull(reader, scratch)
+ if err != nil {
+ return nil, fmt.Errorf("cannot read payload: %w", err)
+ }
+
+ // Decode and copy payload
+ payload, err := encoding.DecodePayloadWithoutPrefix(scratch, false)
+ if err != nil {
+ return nil, fmt.Errorf("failed to decode payload: %w", err)
+ }
+
+ return payload, nil
+}
diff --git a/ledger/complete/mtrie/flattener/encoding_test.go b/ledger/complete/mtrie/flattener/encoding_test.go
index f1c2657371e..2e78d4de22c 100644
--- a/ledger/complete/mtrie/flattener/encoding_test.go
+++ b/ledger/complete/mtrie/flattener/encoding_test.go
@@ -2,88 +2,374 @@ package flattener_test
 
 import (
 "bytes"
+ "errors"
+ "fmt"
+ "math/rand"
 "testing"
 
 "github.com/stretchr/testify/assert"
 "github.com/stretchr/testify/require"
 
- "github.com/onflow/flow-go/ledger/common/encoding"
+ "github.com/onflow/flow-go/ledger"
+ "github.com/onflow/flow-go/ledger/common/hash"
 "github.com/onflow/flow-go/ledger/common/utils"
 "github.com/onflow/flow-go/ledger/complete/mtrie/flattener"
+ "github.com/onflow/flow-go/ledger/complete/mtrie/node"
 )
 
-func TestStorableNode(t *testing.T) {
- path := utils.PathByUint8(3)
-
- storableNode := &flattener.StorableNode{
- LIndex: 1,
- RIndex: 2,
- Height: 2137,
- Path: path[:],
- EncPayload: encoding.EncodePayload(utils.LightPayload8('A', 'a')),
- HashValue: []byte{4, 4, 4},
- MaxDepth: 7,
- RegCount: 5000,
+func TestLeafNodeEncodingDecoding(t *testing.T) {
+
+ // Leaf node with nil payload
+ path1 := utils.PathByUint8(0)
+ payload1 := (*ledger.Payload)(nil)
+ hashValue1 := hash.Hash([32]byte{1, 1, 1})
+ leafNodeNilPayload := node.NewNode(255, nil, nil, ledger.Path(path1), payload1, hashValue1, 0, 1)
+
+ // Leaf node with empty payload (not nil)
+ // EmptyPayload() not used because decoded payload's value is empty slice (not nil)
+ path2 := utils.PathByUint8(1)
+ payload2 := &ledger.Payload{Value: []byte{}}
+ hashValue2 := hash.Hash([32]byte{2, 2, 2})
+ leafNodeEmptyPayload := node.NewNode(255, nil, nil, ledger.Path(path2), payload2, hashValue2, 0, 1)
+
+ // Leaf node with payload
+ path3 := utils.PathByUint8(2)
+ payload3 := utils.LightPayload8('A', 'a')
+ hashValue3 := hash.Hash([32]byte{3, 3, 3})
+ leafNodePayload := node.NewNode(255, nil, nil, ledger.Path(path3), payload3, hashValue3, 0, 1)
+
+ encodedLeafNodeNilPayload := []byte{
+ 0x00, // node type
+ 0x00, 0xff, // height
+ 0x00, 0x00, // max depth
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, // reg count
+ 0x01, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 
0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // hash data + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // path data + 0x00, 0x00, 0x00, 0x00, // payload data len } - // Version 0 - expected := []byte{ - 0, 0, // encoding version - 8, 89, // height - 0, 0, 0, 0, 0, 0, 0, 1, // LIndex - 0, 0, 0, 0, 0, 0, 0, 2, // RIndex - 0, 7, // max depth - 0, 0, 0, 0, 0, 0, 19, 136, // reg count - 0, 32, // path data len - 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // path data - 0, 0, 0, 25, // payload data len - 0, 0, 6, 0, 0, 0, 9, 0, 1, 0, 0, 0, 3, 0, 0, 65, 0, 0, 0, 0, 0, 0, 0, 1, 97, // payload data - 0, 3, // hashValue length - 4, 4, 4, // hashValue + encodedLeafNodeEmptyPayload := []byte{ + 0x00, // node type + 0x00, 0xff, // height + 0x00, 0x00, // max depth + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, // reg count + 0x02, 0x02, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // hash data + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // path data + 0x00, 0x00, 0x00, 0x0e, // payload data len + 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // payload data } - t.Run("encode", func(t *testing.T) { - data := flattener.EncodeStorableNode(storableNode) - assert.Equal(t, expected, data) - }) + encodedLeafNodePayload := []byte{ + 0x00, // node type + 0x00, 0xff, // height + 0x00, 0x00, // max depth + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, // reg count + 0x03, 0x03, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // hash data + 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // path data + 0x00, 0x00, 0x00, 0x16, // payload data len + 0x00, 0x00, 0x00, 0x09, 0x00, 0x01, 0x00, 0x00, + 0x00, 0x03, 0x00, 0x00, 0x41, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x01, 0x61, // payload data + } - t.Run("decode", func(t *testing.T) { - reader := bytes.NewReader(expected) - newStorableNode, err := flattener.ReadStorableNode(reader) - require.NoError(t, err) - assert.Equal(t, storableNode, newStorableNode) - }) + testCases := []struct { + name string + node *node.Node + encodedNode []byte + }{ + {"nil payload", leafNodeNilPayload, encodedLeafNodeNilPayload}, + {"empty payload", leafNodeEmptyPayload, encodedLeafNodeEmptyPayload}, + {"payload", leafNodePayload, encodedLeafNodePayload}, + } + + for _, tc := range testCases { + t.Run("encode "+tc.name, func(t *testing.T) { + scratchBuffers := [][]byte{ + nil, + make([]byte, 0), + make([]byte, 16), + make([]byte, 1024), + } + + for _, scratch := range scratchBuffers { + encodedNode := flattener.EncodeNode(tc.node, 0, 0, scratch) + assert.Equal(t, tc.encodedNode, encodedNode) + + if len(scratch) > 0 { + if len(scratch) >= len(encodedNode) { + // reuse scratch 
buffer + require.True(t, &scratch[0] == &encodedNode[0]) + } else { + // new alloc + require.True(t, &scratch[0] != &encodedNode[0]) + } + } + } + }) + + t.Run("decode "+tc.name, func(t *testing.T) { + scratchBuffers := [][]byte{ + nil, + make([]byte, 0), + make([]byte, 16), + make([]byte, 1024), + } + + for _, scratch := range scratchBuffers { + reader := bytes.NewReader(tc.encodedNode) + newNode, err := flattener.ReadNode(reader, scratch, func(nodeIndex uint64) (*node.Node, error) { + return nil, fmt.Errorf("no call expected") + }) + require.NoError(t, err) + assert.Equal(t, tc.node, newNode) + assert.Equal(t, 0, reader.Len()) + } + }) + } } -func TestStorableTrie(t *testing.T) { +func TestRandomLeafNodeEncodingDecoding(t *testing.T) { + const count = 1000 + const minPayloadSize = 40 + const maxPayloadSize = 1024 * 2 + // scratchBufferSize is intentionally small here to test + // when encoded node size is sometimes larger than scratch buffer. + const scratchBufferSize = 512 + + paths := utils.RandomPaths(count) + payloads := utils.RandomPayloads(count, minPayloadSize, maxPayloadSize) + + writeScratch := make([]byte, scratchBufferSize) + readScratch := make([]byte, scratchBufferSize) - storableTrie := &flattener.StorableTrie{ - RootIndex: 21, - RootHash: []byte{2, 2, 2}, + for i := 0; i < count; i++ { + height := rand.Intn(257) + + var hashValue hash.Hash + rand.Read(hashValue[:]) + + n := node.NewNode(height, nil, nil, paths[i], payloads[i], hashValue, 0, 1) + + encodedNode := flattener.EncodeNode(n, 0, 0, writeScratch) + + if len(writeScratch) >= len(encodedNode) { + // reuse scratch buffer + require.True(t, &writeScratch[0] == &encodedNode[0]) + } else { + // new alloc because scratch buffer isn't big enough + require.True(t, &writeScratch[0] != &encodedNode[0]) + } + + reader := bytes.NewReader(encodedNode) + newNode, err := flattener.ReadNode(reader, readScratch, func(nodeIndex uint64) (*node.Node, error) { + return nil, fmt.Errorf("no call expected") + }) + require.NoError(t, err) + assert.Equal(t, n, newNode) + assert.Equal(t, 0, reader.Len()) } +} + +func TestInterimNodeEncodingDecoding(t *testing.T) { + + const lchildIndex = 1 + const rchildIndex = 2 + + // Child node + path1 := utils.PathByUint8(0) + payload1 := utils.LightPayload8('A', 'a') + hashValue1 := hash.Hash([32]byte{1, 1, 1}) + leafNode1 := node.NewNode(255, nil, nil, ledger.Path(path1), payload1, hashValue1, 0, 1) - // Version 0 - expected := []byte{ - 0, 0, // encoding version - 0, 0, 0, 0, 0, 0, 0, 21, // RootIndex - 0, 3, 2, 2, 2, // RootHash length + data + // Child node + path2 := utils.PathByUint8(1) + payload2 := utils.LightPayload8('B', 'b') + hashValue2 := hash.Hash([32]byte{2, 2, 2}) + leafNode2 := node.NewNode(255, nil, nil, ledger.Path(path2), payload2, hashValue2, 0, 1) + + // Interim node + hashValue3 := hash.Hash([32]byte{3, 3, 3}) + interimNode := node.NewNode(256, leafNode1, leafNode2, ledger.DummyPath, nil, hashValue3, 1, 2) + + encodedInterimNode := []byte{ + 0x01, // node type + 0x01, 0x00, // height + 0x00, 0x01, // max depth + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, // reg count + 0x03, 0x03, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // hash data + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, // LIndex + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, // RIndex } t.Run("encode", func(t *testing.T) { - data := 
flattener.EncodeStorableTrie(storableTrie) + scratchBuffers := [][]byte{ + nil, + make([]byte, 0), + make([]byte, 16), + make([]byte, 1024), + } - assert.Equal(t, expected, data) + for _, scratch := range scratchBuffers { + data := flattener.EncodeNode(interimNode, lchildIndex, rchildIndex, scratch) + assert.Equal(t, encodedInterimNode, data) + } }) t.Run("decode", func(t *testing.T) { + scratchBuffers := [][]byte{ + nil, + make([]byte, 0), + make([]byte, 16), + make([]byte, 1024), + } - reader := bytes.NewReader(expected) + for _, scratch := range scratchBuffers { + reader := bytes.NewReader(encodedInterimNode) + newNode, err := flattener.ReadNode(reader, scratch, func(nodeIndex uint64) (*node.Node, error) { + switch nodeIndex { + case lchildIndex: + return leafNode1, nil + case rchildIndex: + return leafNode2, nil + default: + return nil, fmt.Errorf("unexpected child node index %d ", nodeIndex) + } + }) + require.NoError(t, err) + assert.Equal(t, interimNode, newNode) + assert.Equal(t, 0, reader.Len()) + } + }) - newStorableNode, err := flattener.ReadStorableTrie(reader) - require.NoError(t, err) + t.Run("decode child node not found error", func(t *testing.T) { + nodeNotFoundError := errors.New("failed to find node by index") + scratch := make([]byte, 1024) - assert.Equal(t, storableTrie, newStorableNode) + reader := bytes.NewReader(encodedInterimNode) + newNode, err := flattener.ReadNode(reader, scratch, func(nodeIndex uint64) (*node.Node, error) { + return nil, nodeNotFoundError + }) + require.Nil(t, newNode) + require.ErrorIs(t, err, nodeNotFoundError) }) +} + +func TestTrieEncodingDecoding(t *testing.T) { + // Nil root node + rootNodeNil := (*node.Node)(nil) + rootNodeNilIndex := uint64(20) + + // Not nil root node + hashValue := hash.Hash([32]byte{2, 2, 2}) + rootNode := node.NewNode(256, nil, nil, ledger.DummyPath, nil, hashValue, 7, 5000) + rootNodeIndex := uint64(21) + + encodedNilTrie := []byte{ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x14, // RootIndex + 0x56, 0x8f, 0x4e, 0xc7, 0x40, 0xfe, 0x3b, 0x5d, + 0xe8, 0x80, 0x34, 0xcb, 0x7b, 0x1f, 0xbd, 0xdb, + 0x41, 0x54, 0x8b, 0x06, 0x8f, 0x31, 0xae, 0xbc, + 0x8a, 0xe9, 0x18, 0x9e, 0x42, 0x9c, 0x57, 0x49, // RootHash data + } + encodedTrie := []byte{ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x15, // RootIndex + 0x02, 0x02, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // hash data + } + + testCases := []struct { + name string + rootNode *node.Node + rootNodeIndex uint64 + encodedTrie []byte + }{ + {"nil trie", rootNodeNil, rootNodeNilIndex, encodedNilTrie}, + {"trie", rootNode, rootNodeIndex, encodedTrie}, + } + + for _, tc := range testCases { + + t.Run("encode "+tc.name, func(t *testing.T) { + scratchBuffers := [][]byte{ + nil, + make([]byte, 0), + make([]byte, 16), + make([]byte, 1024), + } + + for _, scratch := range scratchBuffers { + encodedTrie := flattener.EncodeTrie(tc.rootNode, tc.rootNodeIndex, scratch) + assert.Equal(t, tc.encodedTrie, encodedTrie) + + if len(scratch) > 0 { + if len(scratch) >= len(encodedTrie) { + // reuse scratch buffer + require.True(t, &scratch[0] == &encodedTrie[0]) + } else { + // new alloc + require.True(t, &scratch[0] != &encodedTrie[0]) + } + } + } + }) + + t.Run("decode "+tc.name, func(t *testing.T) { + scratchBuffers := [][]byte{ + nil, + make([]byte, 0), + make([]byte, 16), + make([]byte, 1024), + } + + for _, scratch := range scratchBuffers { + 
reader := bytes.NewReader(tc.encodedTrie)
+				trie, err := flattener.ReadTrie(reader, scratch, func(nodeIndex uint64) (*node.Node, error) {
+					if nodeIndex != tc.rootNodeIndex {
+						return nil, fmt.Errorf("unexpected root node index %d ", nodeIndex)
+					}
+					return tc.rootNode, nil
+				})
+				require.NoError(t, err)
+				assert.Equal(t, tc.rootNode, trie.RootNode())
+				assert.Equal(t, 0, reader.Len())
+			}
+		})
+
+		t.Run("decode "+tc.name+" node not found error", func(t *testing.T) {
+			nodeNotFoundError := errors.New("failed to find node by index")
+			scratch := make([]byte, 1024)
+
+			reader := bytes.NewReader(tc.encodedTrie)
+			newNode, err := flattener.ReadTrie(reader, scratch, func(nodeIndex uint64) (*node.Node, error) {
+				return nil, nodeNotFoundError
+			})
+			require.Nil(t, newNode)
+			require.ErrorIs(t, err, nodeNotFoundError)
+		})
+	}
 }
diff --git a/ledger/complete/mtrie/flattener/encoding_v3.go b/ledger/complete/mtrie/flattener/encoding_v3.go
new file mode 100644
index 00000000000..98994557325
--- /dev/null
+++ b/ledger/complete/mtrie/flattener/encoding_v3.go
@@ -0,0 +1,222 @@
+package flattener
+
+import (
+	"bytes"
+	"fmt"
+	"io"
+
+	"github.com/onflow/flow-go/ledger"
+	"github.com/onflow/flow-go/ledger/common/encoding"
+	"github.com/onflow/flow-go/ledger/common/hash"
+	"github.com/onflow/flow-go/ledger/common/utils"
+	"github.com/onflow/flow-go/ledger/complete/mtrie/node"
+	"github.com/onflow/flow-go/ledger/complete/mtrie/trie"
+)
+
+// This file contains decoding functions for checkpoint v3 and earlier versions.
+// These functions are for backwards compatibility, not optimized.
+
+const encodingDecodingVersion = uint16(0)
+
+// ReadNodeFromCheckpointV3AndEarlier reconstructs a node from data in checkpoint v3 and earlier versions.
+// Encoded node in checkpoint v3 and earlier is in the following format:
+// - version (2 bytes)
+// - height (2 bytes)
+// - lindex (8 bytes)
+// - rindex (8 bytes)
+// - max depth (2 bytes)
+// - reg count (8 bytes)
+// - path (2 bytes + 32 bytes)
+// - payload (4 bytes + n bytes)
+// - hash (2 bytes + 32 bytes)
+func ReadNodeFromCheckpointV3AndEarlier(reader io.Reader, getNode func(nodeIndex uint64) (*node.Node, error)) (*node.Node, error) {
+
+	// Read version (2 bytes)
+	buf := make([]byte, 2)
+	_, err := io.ReadFull(reader, buf)
+	if err != nil {
+		return nil, fmt.Errorf("failed to read version of serialized node in v3: %w", err)
+	}
+
+	// Decode version
+	version, _, err := utils.ReadUint16(buf)
+	if err != nil {
+		return nil, fmt.Errorf("failed to decode version of serialized node in v3: %w", err)
+	}
+
+	if version > encodingDecodingVersion {
+		return nil, fmt.Errorf("found unsupported version %d (> %d) of serialized node in v3", version, encodingDecodingVersion)
+	}
+
+	// fixed-length data:
+	// height (2 bytes) +
+	// left child node index (8 bytes) +
+	// right child node index (8 bytes) +
+	// max depth (2 bytes) +
+	// reg count (8 bytes)
+	buf = make([]byte, 2+8+8+2+8)
+
+	// Read fixed-length part
+	_, err = io.ReadFull(reader, buf)
+	if err != nil {
+		return nil, fmt.Errorf("failed to read fixed-length part of serialized node in v3: %w", err)
+	}
+
+	var height, maxDepth uint16
+	var lchildIndex, rchildIndex, regCount uint64
+	var path, hashValue, encPayload []byte
+
+	// Decode height (2 bytes)
+	height, buf, err = utils.ReadUint16(buf)
+	if err != nil {
+		return nil, fmt.Errorf("failed to decode height of serialized node in v3: %w", err)
+	}
+
+	// Decode left child index (8 bytes)
+	lchildIndex, buf, err = utils.ReadUint64(buf)
+	if err != nil {
+
return nil, fmt.Errorf("failed to decode left child index of serialized node in v3: %w", err) + } + + // Decode right child index (8 bytes) + rchildIndex, buf, err = utils.ReadUint64(buf) + if err != nil { + return nil, fmt.Errorf("failed to decode right child index of serialized node in v3: %w", err) + } + + // Decode max depth (2 bytes) + maxDepth, buf, err = utils.ReadUint16(buf) + if err != nil { + return nil, fmt.Errorf("failed to decode max depth of serialized node in v3: %w", err) + } + + // Decode reg count (8 bytes) + regCount, _, err = utils.ReadUint64(buf) + if err != nil { + return nil, fmt.Errorf("failed to decode reg count of serialized node in v3: %w", err) + } + + // Read path (2 bytes + 32 bytes) + path, err = utils.ReadShortDataFromReader(reader) + if err != nil { + return nil, fmt.Errorf("failed to read path of serialized node in v3: %w", err) + } + + // Read payload (4 bytes + n bytes) + encPayload, err = utils.ReadLongDataFromReader(reader) + if err != nil { + return nil, fmt.Errorf("failed to read payload of serialized node in v3: %w", err) + } + + // Read hash (2 bytes + 32 bytes) + hashValue, err = utils.ReadShortDataFromReader(reader) + if err != nil { + return nil, fmt.Errorf("failed to read hash of serialized node in v3: %w", err) + } + + // Create (and copy) hash from raw data. + nodeHash, err := hash.ToHash(hashValue) + if err != nil { + return nil, fmt.Errorf("failed to decode hash of serialized node in v3: %w", err) + } + + if len(path) > 0 { + // Create (and copy) path from raw data. + path, err := ledger.ToPath(path) + if err != nil { + return nil, fmt.Errorf("failed to decode path of serialized node in v3: %w", err) + } + + // Decode payload (payload data isn't copied). + payload, err := encoding.DecodePayload(encPayload) + if err != nil { + return nil, fmt.Errorf("failed to decode payload of serialized node in v3: %w", err) + } + + // Make a copy of payload + var pl *ledger.Payload + if payload != nil { + pl = payload.DeepCopy() + } + + n := node.NewNode(int(height), nil, nil, path, pl, nodeHash, maxDepth, regCount) + return n, nil + } + + // Get left child node by node index + lchild, err := getNode(lchildIndex) + if err != nil { + return nil, fmt.Errorf("failed to find left child node of serialized node in v3: %w", err) + } + + // Get right child node by node index + rchild, err := getNode(rchildIndex) + if err != nil { + return nil, fmt.Errorf("failed to find right child node of serialized node in v3: %w", err) + } + + n := node.NewNode(int(height), lchild, rchild, ledger.DummyPath, nil, nodeHash, maxDepth, regCount) + return n, nil +} + +// ReadTrieFromCheckpointV3AndEarlier reconstructs a trie from data in checkpoint v3 and earlier versions. 
+// Encoded trie in checkpoint v3 and earlier is in the following format:
+// - version (2 bytes)
+// - root node index (8 bytes)
+// - root node hash (2 bytes + 32 bytes)
+func ReadTrieFromCheckpointV3AndEarlier(reader io.Reader, getNode func(nodeIndex uint64) (*node.Node, error)) (*trie.MTrie, error) {
+
+	// Read version (2 bytes)
+	buf := make([]byte, 2)
+	_, err := io.ReadFull(reader, buf)
+	if err != nil {
+		return nil, fmt.Errorf("failed to read version of serialized trie in v3: %w", err)
+	}
+
+	// Decode version
+	version, _, err := utils.ReadUint16(buf)
+	if err != nil {
+		return nil, fmt.Errorf("failed to decode version of serialized trie in v3: %w", err)
+	}
+
+	if version > encodingDecodingVersion {
+		return nil, fmt.Errorf("found unsupported version %d (> %d) of serialized trie in v3", version, encodingDecodingVersion)
+	}
+
+	// Read root index (8 bytes)
+	buf = make([]byte, 8)
+	_, err = io.ReadFull(reader, buf)
+	if err != nil {
+		return nil, fmt.Errorf("failed to read root index of serialized trie in v3: %w", err)
+	}
+
+	// Decode root index
+	rootIndex, _, err := utils.ReadUint64(buf)
+	if err != nil {
+		return nil, fmt.Errorf("failed to decode root index of serialized trie in v3: %w", err)
+	}
+
+	// Read root hash (2 bytes + 32 bytes)
+	readRootHash, err := utils.ReadShortDataFromReader(reader)
+	if err != nil {
+		return nil, fmt.Errorf("failed to read root hash of serialized trie in v3: %w", err)
+	}
+
+	// Get node by index
+	rootNode, err := getNode(rootIndex)
+	if err != nil {
+		return nil, fmt.Errorf("failed to find root node of serialized trie in v3: %w", err)
+	}
+
+	mtrie, err := trie.NewMTrie(rootNode)
+	if err != nil {
+		return nil, fmt.Errorf("failed to restore serialized trie in v3: %w", err)
+	}
+
+	rootHash := mtrie.RootHash()
+	if !bytes.Equal(readRootHash, rootHash[:]) {
+		return nil, fmt.Errorf("failed to restore serialized trie in v3: root hash doesn't match")
+	}
+
+	return mtrie, nil
+}
diff --git a/ledger/complete/mtrie/flattener/encoding_v3_test.go b/ledger/complete/mtrie/flattener/encoding_v3_test.go
new file mode 100644
index 00000000000..374de31cf9c
--- /dev/null
+++ b/ledger/complete/mtrie/flattener/encoding_v3_test.go
@@ -0,0 +1,126 @@
+package flattener_test
+
+import (
+	"bytes"
+	"fmt"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	"github.com/onflow/flow-go/ledger"
+	"github.com/onflow/flow-go/ledger/common/hash"
+	"github.com/onflow/flow-go/ledger/common/utils"
+	"github.com/onflow/flow-go/ledger/complete/mtrie/flattener"
+	"github.com/onflow/flow-go/ledger/complete/mtrie/node"
+)
+
+// This file contains node/trie decoding tests for checkpoint v3 and earlier versions.
+// These tests are based on TestStorableNode and TestStorableTrie.
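Both v3 readers resolve node references through the caller-supplied getNode callback rather than touching storage themselves. A minimal sketch of the intended wiring, assuming a descendants-first node table like the one readCheckpointV3AndEarlier builds later in this diff (readV3Nodes, r, and nodeCount are hypothetical names, not part of the patch):

```go
// Hypothetical helper: decode nodeCount nodes from r into a table.
// Index 0 is reserved to mean nil.
func readV3Nodes(r io.Reader, nodeCount uint64) ([]*node.Node, error) {
	nodes := make([]*node.Node, nodeCount+1)
	for i := uint64(1); i <= nodeCount; i++ {
		n, err := flattener.ReadNodeFromCheckpointV3AndEarlier(r, func(idx uint64) (*node.Node, error) {
			// Descendants-first: children must have strictly smaller indices.
			if idx >= i {
				return nil, fmt.Errorf("node %d referenced before being decoded", idx)
			}
			return nodes[idx], nil
		})
		if err != nil {
			return nil, err
		}
		nodes[i] = n
	}
	return nodes, nil
}
```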
+ +func TestNodeV3Decoding(t *testing.T) { + + const leafNode1Index = 1 + const leafNode2Index = 2 + + leafNode1 := node.NewNode(255, nil, nil, utils.PathByUint8(0), utils.LightPayload8('A', 'a'), hash.Hash([32]byte{1, 1, 1}), 0, 1) + leafNode2 := node.NewNode(255, nil, nil, utils.PathByUint8(1), utils.LightPayload8('B', 'b'), hash.Hash([32]byte{2, 2, 2}), 0, 1) + + interimNode := node.NewNode(256, leafNode1, leafNode2, ledger.DummyPath, nil, hash.Hash([32]byte{3, 3, 3}), 1, 2) + + encodedLeafNode1 := []byte{ + 0x00, 0x00, // encoding version + 0x00, 0xff, // height + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // LIndex + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // RIndex + 0x00, 0x00, // max depth + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, // reg count + 0x00, 0x20, // path data len + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // path data + 0x00, 0x00, 0x00, 0x19, // payload data len + 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x09, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x41, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + 0x61, // payload data + 0x00, 0x20, // hashValue length + 0x01, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // hash value + } + + encodedInterimNode := []byte{ + 0x00, 0x00, // encoding version + 0x01, 0x00, // height + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, // LIndex + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, // RIndex + 0x00, 0x01, // max depth + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, // reg count + 0x00, 0x00, // path data len + 0x00, 0x00, 0x00, 0x00, // payload data len + 0x00, 0x20, // hashValue length + 0x03, 0x03, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // hash value + } + + t.Run("leaf node", func(t *testing.T) { + reader := bytes.NewReader(encodedLeafNode1) + newNode, err := flattener.ReadNodeFromCheckpointV3AndEarlier(reader, func(nodeIndex uint64) (*node.Node, error) { + return nil, fmt.Errorf("no call expected") + }) + require.NoError(t, err) + assert.Equal(t, leafNode1, newNode) + }) + + t.Run("interim node", func(t *testing.T) { + reader := bytes.NewReader(encodedInterimNode) + newNode, err := flattener.ReadNodeFromCheckpointV3AndEarlier(reader, func(nodeIndex uint64) (*node.Node, error) { + switch nodeIndex { + case leafNode1Index: + return leafNode1, nil + case leafNode2Index: + return leafNode2, nil + default: + return nil, fmt.Errorf("unexpected child node index %d ", nodeIndex) + } + }) + require.NoError(t, err) + assert.Equal(t, interimNode, newNode) + }) +} + +func TestTrieV3Decoding(t *testing.T) { + + const rootNodeIndex = 21 + + hashValue := hash.Hash([32]byte{2, 2, 2}) + rootNode := node.NewNode(256, nil, nil, ledger.DummyPath, nil, hashValue, 7, 5000) + + expected := []byte{ + 0x00, 0x00, // encoding version + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 21, // RootIndex + 0x00, 0x20, // hashValue length + 0x02, 0x02, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // hash value + } + + reader := 
bytes.NewReader(expected) + + trie, err := flattener.ReadTrieFromCheckpointV3AndEarlier(reader, func(nodeIndex uint64) (*node.Node, error) { + switch nodeIndex { + case rootNodeIndex: + return rootNode, nil + default: + return nil, fmt.Errorf("unexpected root node index %d ", nodeIndex) + } + }) + require.NoError(t, err) + assert.Equal(t, rootNode, trie.RootNode()) +} diff --git a/ledger/complete/mtrie/flattener/forest.go b/ledger/complete/mtrie/flattener/forest.go deleted file mode 100644 index c2a7ac72f98..00000000000 --- a/ledger/complete/mtrie/flattener/forest.go +++ /dev/null @@ -1,194 +0,0 @@ -package flattener - -import ( - "bytes" - "encoding/hex" - "fmt" - - "github.com/onflow/flow-go/ledger" - "github.com/onflow/flow-go/ledger/common/encoding" - "github.com/onflow/flow-go/ledger/common/hash" - "github.com/onflow/flow-go/ledger/complete/mtrie" - "github.com/onflow/flow-go/ledger/complete/mtrie/node" - "github.com/onflow/flow-go/ledger/complete/mtrie/trie" -) - -// FlattenedForest represents an Forest as a flattened data structure. -// Specifically it consists of : -// * a list of storable nodes, where references to nodes are replaced by index in the slice -// * and a list of storable tries, each referencing their respective root node by index. -// 0 is a special index, meaning nil, but is included in this list for ease of use -// and removing would make it necessary to constantly add/subtract indexes -// -// As an important property, the nodes are listed in an order which satisfies -// Descendents-First-Relationship. The Descendents-First-Relationship has the -// following important property: -// When re-building the Trie from the sequence of nodes, one can build the trie on the fly, -// as for each node, the children have been previously encountered. -type FlattenedForest struct { - Nodes []*StorableNode - Tries []*StorableTrie -} - -// node2indexMap maps a node pointer to the node index in the serialization -type node2indexMap map[*node.Node]uint64 - -// FlattenForest returns forest FlattenedForest, which contains all nodes and tries of the Forest. 
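With forest.go deleted, the FlattenedForest intermediate below disappears entirely; serialization consumes tries directly. A sketch of the replacement flow, mirroring the Checkpoint changes to checkpointer.go later in this diff (forest and writer are hypothetical locals):

```go
// Get all cached tries from the forest and stream them straight into the
// checkpoint writer; no intermediate flattened representation is built.
tries, err := forest.GetTries()
if err != nil {
	return fmt.Errorf("cannot get forest tries: %w", err)
}
if err := wal.StoreCheckpoint(writer, tries...); err != nil {
	return fmt.Errorf("cannot store checkpoint: %w", err)
}
```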
-func FlattenForest(f *mtrie.Forest) (*FlattenedForest, error) { - tries, err := f.GetTries() - if err != nil { - return nil, fmt.Errorf("cannot get cached tries root hashes: %w", err) - } - - storableTries := make([]*StorableTrie, 0, len(tries)) - storableNodes := []*StorableNode{nil} // 0th element is nil - - // assign unique value to every node - allNodes := make(node2indexMap) - allNodes[nil] = 0 // 0th element is nil - - counter := uint64(1) // start from 1, as 0 marks nil - for _, t := range tries { - for itr := NewNodeIterator(t); itr.Next(); { - n := itr.Value() - // if node not in map - if _, has := allNodes[n]; !has { - allNodes[n] = counter - counter++ - storableNode, err := toStorableNode(n, allNodes) - if err != nil { - return nil, fmt.Errorf("failed to construct storable node: %w", err) - } - storableNodes = append(storableNodes, storableNode) - } - } - //fix root nodes indices - // since we indexed all nodes, root must be present - storableTrie, err := toStorableTrie(t, allNodes) - if err != nil { - return nil, fmt.Errorf("failed to construct storable trie: %w", err) - } - storableTries = append(storableTries, storableTrie) - } - - return &FlattenedForest{ - Nodes: storableNodes, - Tries: storableTries, - }, nil -} - -func toStorableNode(node *node.Node, indexForNode node2indexMap) (*StorableNode, error) { - leftIndex, found := indexForNode[node.LeftChild()] - if !found { - hash := node.LeftChild().Hash() - return nil, fmt.Errorf("internal error: missing node with hash %s", hex.EncodeToString(hash[:])) - } - rightIndex, found := indexForNode[node.RightChild()] - if !found { - hash := node.RightChild().Hash() - return nil, fmt.Errorf("internal error: missing node with hash %s", hex.EncodeToString(hash[:])) - } - - hash := node.Hash() - // if node is a leaf, path is a slice of 32 bytes, otherwise path is nil - var path []byte - if node.IsLeaf() { - temp := *node.Path() - path = temp[:] - } - storableNode := &StorableNode{ - LIndex: leftIndex, - RIndex: rightIndex, - Height: uint16(node.Height()), - Path: path, - EncPayload: encoding.EncodePayload(node.Payload()), - HashValue: hash[:], - MaxDepth: node.MaxDepth(), - RegCount: node.RegCount(), - } - return storableNode, nil -} - -func toStorableTrie(mtrie *trie.MTrie, indexForNode node2indexMap) (*StorableTrie, error) { - rootIndex, found := indexForNode[mtrie.RootNode()] - if !found { - hash := mtrie.RootNode().Hash() - return nil, fmt.Errorf("internal error: missing node with hash %s", hex.EncodeToString(hash[:])) - } - hash := mtrie.RootHash() - storableTrie := &StorableTrie{ - RootIndex: rootIndex, - RootHash: hash[:], - } - - return storableTrie, nil -} - -// RebuildTries construct a forest from a storable FlattenedForest -func RebuildTries(flatForest *FlattenedForest) ([]*trie.MTrie, error) { - tries := make([]*trie.MTrie, 0, len(flatForest.Tries)) - nodes, err := RebuildNodes(flatForest.Nodes) - if err != nil { - return nil, fmt.Errorf("reconstructing nodes from storables failed: %w", err) - } - - //restore tries - for _, storableTrie := range flatForest.Tries { - mtrie, err := trie.NewMTrie(nodes[storableTrie.RootIndex]) - if err != nil { - return nil, fmt.Errorf("restoring trie failed: %w", err) - } - rootHash := mtrie.RootHash() - if !bytes.Equal(storableTrie.RootHash, rootHash[:]) { - return nil, fmt.Errorf("restoring trie failed: roothash doesn't match") - } - tries = append(tries, mtrie) - } - return tries, nil -} - -// RebuildNodes generates a list of Nodes from a sequence of StorableNodes. 
-// The sequence must obey the DESCENDANTS-FIRST-RELATIONSHIP -func RebuildNodes(storableNodes []*StorableNode) ([]*node.Node, error) { - nodes := make([]*node.Node, 0, len(storableNodes)) - for i, snode := range storableNodes { - if snode == nil { - nodes = append(nodes, nil) - continue - } - if (snode.LIndex >= uint64(i)) || (snode.RIndex >= uint64(i)) { - return nil, fmt.Errorf("sequence of StorableNodes does not satisfy Descendents-First-Relationship") - } - - if len(snode.Path) > 0 { - path, err := ledger.ToPath(snode.Path) - if err != nil { - return nil, fmt.Errorf("failed to decode a path of a storableNode %w", err) - } - payload, err := encoding.DecodePayload(snode.EncPayload) - if err != nil { - return nil, fmt.Errorf("failed to decode a payload for an storableNode %w", err) - } - nodeHash, err := hash.ToHash(snode.HashValue) - if err != nil { - return nil, fmt.Errorf("failed to decode a hash of a storableNode %w", err) - } - // make a copy of payload - var pl *ledger.Payload - if payload != nil { - pl = payload.DeepCopy() - } - - node := node.NewNode(int(snode.Height), nodes[snode.LIndex], nodes[snode.RIndex], path, pl, nodeHash, snode.MaxDepth, snode.RegCount) - nodes = append(nodes, node) - continue - } - nodeHash, err := hash.ToHash(snode.HashValue) - if err != nil { - return nil, fmt.Errorf("failed to decode a hash of a storableNode %w", err) - } - node := node.NewNode(int(snode.Height), nodes[snode.LIndex], nodes[snode.RIndex], ledger.DummyPath, nil, nodeHash, snode.MaxDepth, snode.RegCount) - nodes = append(nodes, node) - } - return nodes, nil -} diff --git a/ledger/complete/mtrie/flattener/forest_test.go b/ledger/complete/mtrie/flattener/forest_test.go deleted file mode 100644 index 4762dbb7bf0..00000000000 --- a/ledger/complete/mtrie/flattener/forest_test.go +++ /dev/null @@ -1,72 +0,0 @@ -package flattener_test - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/onflow/flow-go/ledger" - "github.com/onflow/flow-go/ledger/common/utils" - "github.com/onflow/flow-go/ledger/complete/mtrie" - "github.com/onflow/flow-go/ledger/complete/mtrie/flattener" - "github.com/onflow/flow-go/module/metrics" -) - -func TestForestStoreAndLoad(t *testing.T) { - - metricsCollector := &metrics.NoopCollector{} - mForest, err := mtrie.NewForest(5, metricsCollector, nil) - require.NoError(t, err) - rootHash := mForest.GetEmptyRootHash() - - p1 := utils.PathByUint8(1) - v1 := utils.LightPayload8('A', 'a') - p2 := utils.PathByUint8(2) - v2 := utils.LightPayload8('B', 'b') - p3 := utils.PathByUint8(130) - v3 := utils.LightPayload8('C', 'c') - p4 := utils.PathByUint8(131) - v4 := utils.LightPayload8('D', 'd') - p5 := utils.PathByUint8(132) - v5 := utils.LightPayload8('E', 'e') - - paths := []ledger.Path{p1, p2, p3, p4, p5} - payloads := []*ledger.Payload{v1, v2, v3, v4, v5} - - update := &ledger.TrieUpdate{RootHash: rootHash, Paths: paths, Payloads: payloads} - rootHash, err = mForest.Update(update) - require.NoError(t, err) - - p6 := utils.PathByUint8(133) - v6 := utils.LightPayload8('F', 'f') - update = &ledger.TrieUpdate{RootHash: rootHash, Paths: []ledger.Path{p6}, Payloads: []*ledger.Payload{v6}} - rootHash, err = mForest.Update(update) - require.NoError(t, err) - - forestSequencing, err := flattener.FlattenForest(mForest) - require.NoError(t, err) - - newForest, err := mtrie.NewForest(5, metricsCollector, nil) - require.NoError(t, err) - - //forests are different - assert.NotEqual(t, mForest, newForest) - - rebuiltTries, err 
:= flattener.RebuildTries(forestSequencing)
-	require.NoError(t, err)
-	err = newForest.AddTries(rebuiltTries)
-	require.NoError(t, err)
-
-	//forests are the same now
-	assert.Equal(t, mForest, newForest)
-
-	read := &ledger.TrieRead{RootHash: rootHash, Paths: paths}
-	retPayloads, err := mForest.Read(read)
-	require.NoError(t, err)
-	newRetPayloads, err := newForest.Read(read)
-	require.NoError(t, err)
-	for i := range paths {
-		require.True(t, retPayloads[i].Equals(newRetPayloads[i]))
-	}
-}
diff --git a/ledger/complete/mtrie/flattener/iterator.go b/ledger/complete/mtrie/flattener/iterator.go
index 7e49a426471..dd01e5b0733 100644
--- a/ledger/complete/mtrie/flattener/iterator.go
+++ b/ledger/complete/mtrie/flattener/iterator.go
@@ -38,7 +38,7 @@ type NodeIterator struct {
 	// no children, it can be recalled without restriction.
 	// * When popping node `n` from the stack, its parent `p` (if it exists) is now the
 	//   head of the stack.
-	//   - If `p` has only one child, this child is must be `n`.
+	//   - If `p` has only one child, this child must be `n`.
 	//     Therefore, by recalling `n`, we have recalled all ancestors of `p`.
 	//   - If `n` is the right child, we haven already searched through all of `p`
 	//     descendents (as the `p.LeftChild` must have been searched before)
@@ -53,6 +53,15 @@ type NodeIterator struct {
 	// This has the advantage, that we gracefully handle tries whose root node is nil.
 	unprocessedRoot *node.Node
 	stack []*node.Node
+	// visitedNodes are nodes that were visited and can be skipped during
+	// traversal through dig(). visitedNodes is used to optimize node traversal
+	// in a forest by skipping nodes in shared sub-tries after they are visited,
+	// because sub-tries are shared between tries (original MTrie before register updates
+	// and updated MTrie after register writes).
+	// NodeIterator only uses visitedNodes for read operations.
+	// No special handling is needed if visitedNodes is nil.
+	// WARNING: visitedNodes is not safe for concurrent use.
+	visitedNodes map[*node.Node]uint64
 }
 
 // NewNodeIterator returns a node NodeIterator, which iterates through all nodes
@@ -65,6 +74,8 @@ type NodeIterator struct {
 // The Descendents-First-Relationship has the following important property:
 // When re-building the Trie from the sequence of nodes, one can build the trie on the fly,
 // as for each node, the children have been previously encountered.
+// NodeIterator created by NewNodeIterator is safe for concurrent use
+// because visitedNodes is always nil in this case.
 func NewNodeIterator(mTrie *trie.MTrie) *NodeIterator {
 	// for a Trie with height H (measured by number of edges), the longest possible path contains H+1 vertices
 	stackSize := ledger.NodeMaxHeight + 1
@@ -75,6 +86,30 @@ func NewNodeIterator(mTrie *trie.MTrie) *NodeIterator {
 	return i
 }
 
+// NewUniqueNodeIterator returns a node NodeIterator, which iterates through all unique nodes
+// that weren't visited. This should be used for forest node iteration to avoid repeatedly
+// traversing shared sub-tries.
+// The Iterator guarantees a DESCENDANTS-FIRST-RELATIONSHIP in the sequence of nodes it generates:
+// * Consider the sequence of nodes, in the order they are generated by NodeIterator.
+//   Let `node[k]` denote the node with index `k` in this sequence.
+// * Descendents-First-Relationship means that for any `node[k]`, all its descendents
+//   have indices strictly smaller than k in the iterator's sequence.
+// The Descendents-First-Relationship has the following important property: +// When re-building the Trie from the sequence of nodes, one can build the trie on the fly, +// as for each node, the children have been previously encountered. +// WARNING: visitedNodes is not safe for concurrent use. +func NewUniqueNodeIterator(mTrie *trie.MTrie, visitedNodes map[*node.Node]uint64) *NodeIterator { + // For a Trie with height H (measured by number of edges), the longest possible path + // contains H+1 vertices. + stackSize := ledger.NodeMaxHeight + 1 + i := &NodeIterator{ + stack: make([]*node.Node, 0, stackSize), + visitedNodes: visitedNodes, + } + i.unprocessedRoot = mTrie.RootNode() + return i +} + func (i *NodeIterator) Next() bool { if i.unprocessedRoot != nil { // initial call to Next() for a non-empty trie @@ -125,15 +160,22 @@ func (i *NodeIterator) dig(n *node.Node) { if n == nil { return } + if _, found := i.visitedNodes[n]; found { + return + } for { i.stack = append(i.stack, n) if lChild := n.LeftChild(); lChild != nil { - n = lChild - continue + if _, found := i.visitedNodes[lChild]; !found { + n = lChild + continue + } } if rChild := n.RightChild(); rChild != nil { - n = rChild - continue + if _, found := i.visitedNodes[rChild]; !found { + n = rChild + continue + } } return } diff --git a/ledger/complete/mtrie/flattener/iterator_test.go b/ledger/complete/mtrie/flattener/iterator_test.go index 8fda1ca4dc7..64d3574034e 100644 --- a/ledger/complete/mtrie/flattener/iterator_test.go +++ b/ledger/complete/mtrie/flattener/iterator_test.go @@ -9,6 +9,7 @@ import ( "github.com/onflow/flow-go/ledger" "github.com/onflow/flow-go/ledger/common/utils" "github.com/onflow/flow-go/ledger/complete/mtrie/flattener" + "github.com/onflow/flow-go/ledger/complete/mtrie/node" "github.com/onflow/flow-go/ledger/complete/mtrie/trie" ) @@ -73,3 +74,195 @@ func TestPopulatedTrie(t *testing.T) { require.False(t, itr.Next()) require.True(t, nil == itr.Value()) } + +func TestUniqueNodeIterator(t *testing.T) { + t.Run("empty trie", func(t *testing.T) { + emptyTrie := trie.NewEmptyMTrie() + + // visitedNodes is nil + itr := flattener.NewUniqueNodeIterator(emptyTrie, nil) + require.False(t, itr.Next()) + require.True(t, nil == itr.Value()) // initial iterator should return nil + + // visitedNodes is empty map + visitedNodes := make(map[*node.Node]uint64) + itr = flattener.NewUniqueNodeIterator(emptyTrie, visitedNodes) + require.False(t, itr.Next()) + require.True(t, nil == itr.Value()) // initial iterator should return nil + }) + + t.Run("trie", func(t *testing.T) { + emptyTrie := trie.NewEmptyMTrie() + + // key: 0000... + p1 := utils.PathByUint8(1) + v1 := utils.LightPayload8('A', 'a') + + // key: 0100.... 
+ p2 := utils.PathByUint8(64) + v2 := utils.LightPayload8('B', 'b') + + paths := []ledger.Path{p1, p2} + payloads := []ledger.Payload{*v1, *v2} + + updatedTrie, err := trie.NewTrieWithUpdatedRegisters(emptyTrie, paths, payloads, true) + require.NoError(t, err) + + // n4 + // / + // / + // n3 + // / \ + // / \ + // n1 (p1/v1) n2 (p2/v2) + // + + expectedNodes := []*node.Node{ + updatedTrie.RootNode().LeftChild().LeftChild(), // n1 + updatedTrie.RootNode().LeftChild().RightChild(), // n2 + updatedTrie.RootNode().LeftChild(), // n3 + updatedTrie.RootNode(), // n4 + } + + // visitedNodes is nil + i := 0 + for itr := flattener.NewUniqueNodeIterator(updatedTrie, nil); itr.Next(); { + n := itr.Value() + require.True(t, i < len(expectedNodes)) + require.Equal(t, expectedNodes[i], n) + i++ + } + require.Equal(t, i, len(expectedNodes)) + + // visitedNodes is not nil, but it's pointless for iterating a single trie because + // there isn't any shared sub-trie. + visitedNodes := make(map[*node.Node]uint64) + i = 0 + for itr := flattener.NewUniqueNodeIterator(updatedTrie, visitedNodes); itr.Next(); { + n := itr.Value() + visitedNodes[n] = uint64(i) + + require.True(t, i < len(expectedNodes)) + require.Equal(t, expectedNodes[i], n) + i++ + } + require.Equal(t, i, len(expectedNodes)) + }) + + t.Run("forest", func(t *testing.T) { + + // tries is a slice of mtries to guarantee order. + var tries []*trie.MTrie + + emptyTrie := trie.NewEmptyMTrie() + + // key: 0000... + p1 := utils.PathByUint8(1) + v1 := utils.LightPayload8('A', 'a') + + // key: 0100.... + p2 := utils.PathByUint8(64) + v2 := utils.LightPayload8('B', 'b') + + paths := []ledger.Path{p1, p2} + payloads := []ledger.Payload{*v1, *v2} + + trie1, err := trie.NewTrieWithUpdatedRegisters(emptyTrie, paths, payloads, true) + require.NoError(t, err) + + // trie1 + // n4 + // / + // / + // n3 + // / \ + // / \ + // n1 (p1/v1) n2 (p2/v2) + // + + tries = append(tries, trie1) + + // New trie reuses its parent's left sub-trie. + + // key: 1000... + p3 := utils.PathByUint8(128) + v3 := utils.LightPayload8('C', 'c') + + // key: 1100.... + p4 := utils.PathByUint8(192) + v4 := utils.LightPayload8('D', 'd') + + paths = []ledger.Path{p3, p4} + payloads = []ledger.Payload{*v3, *v4} + + trie2, err := trie.NewTrieWithUpdatedRegisters(trie1, paths, payloads, true) + require.NoError(t, err) + + // trie2 + // n8 + // / \ + // / \ + // n3 n7 + // (shared) / \ + // / \ + // n5 n6 + // (p3/v3) (p4/v4) + + tries = append(tries, trie2) + + // New trie reuses its parent's right sub-trie, and left sub-trie's leaf node. + + // key: 0000... 
+ v5 := utils.LightPayload8('E', 'e') + + paths = []ledger.Path{p1} + payloads = []ledger.Payload{*v5} + + trie3, err := trie.NewTrieWithUpdatedRegisters(trie2, paths, payloads, true) + require.NoError(t, err) + + // trie3 + // n11 + // / \ + // / \ + // n10 n7 + // / \ (shared) + // / \ + // n9 n2 + // (p1/v5) (shared) + + tries = append(tries, trie3) + + expectedNodes := []*node.Node{ + // unique nodes from trie1 + trie1.RootNode().LeftChild().LeftChild(), // n1 + trie1.RootNode().LeftChild().RightChild(), // n2 + trie1.RootNode().LeftChild(), // n3 + trie1.RootNode(), // n4 + // unique nodes from trie2 + trie2.RootNode().RightChild().LeftChild(), // n5 + trie2.RootNode().RightChild().RightChild(), // n6 + trie2.RootNode().RightChild(), // n7 + trie2.RootNode(), // n8 + // unique nodes from trie3 + trie3.RootNode().LeftChild().LeftChild(), // n9 + trie3.RootNode().LeftChild(), // n10 + trie3.RootNode(), // n11 + } + + // Use visitedNodes to prevent revisiting shared sub-tries. + visitedNodes := make(map[*node.Node]uint64) + i := 0 + for _, trie := range tries { + for itr := flattener.NewUniqueNodeIterator(trie, visitedNodes); itr.Next(); { + n := itr.Value() + visitedNodes[n] = uint64(i) + + require.True(t, i < len(expectedNodes)) + require.Equal(t, expectedNodes[i], n) + i++ + } + } + require.Equal(t, i, len(expectedNodes)) + }) +} diff --git a/ledger/complete/mtrie/flattener/storables.go b/ledger/complete/mtrie/flattener/storables.go deleted file mode 100644 index 9f35a812228..00000000000 --- a/ledger/complete/mtrie/flattener/storables.go +++ /dev/null @@ -1,18 +0,0 @@ -package flattener - -type StorableNode struct { - LIndex uint64 - RIndex uint64 - Height uint16 // Height where the node is at - Path []byte // path - EncPayload []byte // encoded data for payload - HashValue []byte - MaxDepth uint16 - RegCount uint64 -} - -// StorableTrie is a data structure for storing trie -type StorableTrie struct { - RootIndex uint64 - RootHash []byte -} diff --git a/ledger/complete/mtrie/flattener/trie.go b/ledger/complete/mtrie/flattener/trie.go deleted file mode 100644 index ec53bc564b5..00000000000 --- a/ledger/complete/mtrie/flattener/trie.go +++ /dev/null @@ -1,74 +0,0 @@ -package flattener - -import ( - "fmt" - - "github.com/onflow/flow-go/ledger/complete/mtrie/node" - "github.com/onflow/flow-go/ledger/complete/mtrie/trie" -) - -// FlattenedTrie is similar to FlattenedForest except only including a single trie -type FlattenedTrie struct { - Nodes []*StorableNode - Trie *StorableTrie -} - -// ToFlattenedForestWithASingleTrie converts the flattenedTrie into a FlattenedForest with only one trie included -func (ft *FlattenedTrie) ToFlattenedForestWithASingleTrie() *FlattenedForest { - storableTries := make([]*StorableTrie, 1) - storableTries[0] = ft.Trie - return &FlattenedForest{ - Nodes: ft.Nodes, - Tries: storableTries, - } -} - -// FlattenTrie returns the trie as a FlattenedTrie, which contains all nodes of that trie. 
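FlattenTrie and RebuildTrie below are likewise subsumed by the checkpoint format itself: a single trie now round-trips through StoreCheckpoint and LoadCheckpoint, as the rewritten Test_StoringLoadingCheckpoints exercises later in this diff. A sketch under that assumption (the helper name is hypothetical):

```go
// roundTripTrie stores one trie into f and loads it back by file name.
func roundTripTrie(f *os.File, t *trie.MTrie) (*trie.MTrie, error) {
	if err := wal.StoreCheckpoint(f, t); err != nil {
		return nil, err
	}
	if err := f.Close(); err != nil { // flush before re-reading by name
		return nil, err
	}
	tries, err := wal.LoadCheckpoint(f.Name())
	if err != nil {
		return nil, err
	}
	return tries[0], nil // one trie stored, one trie loaded
}
```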
-func FlattenTrie(trie *trie.MTrie) (*FlattenedTrie, error) { - storableNodes := []*StorableNode{nil} // 0th element is nil - - // assign unique value to every node - allNodes := make(map[*node.Node]uint64) - allNodes[nil] = 0 // 0th element is nil - - counter := uint64(1) // start from 1, as 0 marks nil - for itr := NewNodeIterator(trie); itr.Next(); { - n := itr.Value() - // if node not in map - if _, has := allNodes[n]; !has { - allNodes[n] = counter - counter++ - storableNode, err := toStorableNode(n, allNodes) - if err != nil { - return nil, fmt.Errorf("failed to construct storable node: %w", err) - } - storableNodes = append(storableNodes, storableNode) - } - } - // fix root nodes indices - // since we indexed all nodes, root must be present - storableTrie, err := toStorableTrie(trie, allNodes) - if err != nil { - return nil, fmt.Errorf("failed to construct storable trie: %w", err) - } - - return &FlattenedTrie{ - Nodes: storableNodes, - Trie: storableTrie, - }, nil -} - -// RebuildTrie construct a trie from a storable FlattenedForest -func RebuildTrie(flatTrie *FlattenedTrie) (*trie.MTrie, error) { - nodes, err := RebuildNodes(flatTrie.Nodes) - if err != nil { - return nil, fmt.Errorf("reconstructing nodes from storables failed: %w", err) - } - - //restore tries - mtrie, err := trie.NewMTrie(nodes[flatTrie.Trie.RootIndex]) - if err != nil { - return nil, fmt.Errorf("restoring trie failed: %w", err) - } - return mtrie, nil -} diff --git a/ledger/complete/mtrie/flattener/trie_test.go b/ledger/complete/mtrie/flattener/trie_test.go deleted file mode 100644 index c17326c0419..00000000000 --- a/ledger/complete/mtrie/flattener/trie_test.go +++ /dev/null @@ -1,55 +0,0 @@ -package flattener_test - -import ( - "os" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/onflow/flow-go/ledger" - "github.com/onflow/flow-go/ledger/common/utils" - "github.com/onflow/flow-go/ledger/complete/mtrie/flattener" - "github.com/onflow/flow-go/ledger/complete/mtrie/trie" -) - -func TestTrieStoreAndLoad(t *testing.T) { - dir, err := os.MkdirTemp("", "test-mtrie-") - require.NoError(t, err) - defer os.RemoveAll(dir) - - emptyTrie := trie.NewEmptyMTrie() - require.NoError(t, err) - - p1 := utils.PathByUint8(1) - v1 := utils.LightPayload8('A', 'a') - p2 := utils.PathByUint8(2) - v2 := utils.LightPayload8('B', 'b') - p3 := utils.PathByUint8(130) - v3 := utils.LightPayload8('C', 'c') - p4 := utils.PathByUint8(131) - v4 := utils.LightPayload8('D', 'd') - p5 := utils.PathByUint8(132) - v5 := utils.LightPayload8('E', 'e') - - paths := []ledger.Path{p1, p2, p3, p4, p5} - payloads := []ledger.Payload{*v1, *v2, *v3, *v4, *v5} - - newTrie, err := trie.NewTrieWithUpdatedRegisters(emptyTrie, paths, payloads, true) - require.NoError(t, err) - - flattedTrie, err := flattener.FlattenTrie(newTrie) - require.NoError(t, err) - - rebuiltTrie, err := flattener.RebuildTrie(flattedTrie) - require.NoError(t, err) - - //tries are the same now - assert.Equal(t, newTrie, rebuiltTrie) - - retPayloads := newTrie.UnsafeRead(paths) - newRetPayloads := rebuiltTrie.UnsafeRead(paths) - for i := range paths { - require.True(t, retPayloads[i].Equals(newRetPayloads[i])) - } -} diff --git a/ledger/complete/mtrie/forest.go b/ledger/complete/mtrie/forest.go index 69812496bfd..46e4783aab4 100644 --- a/ledger/complete/mtrie/forest.go +++ b/ledger/complete/mtrie/forest.go @@ -29,7 +29,7 @@ type Forest struct { // needed trie in the forest might cause a fatal application logic error. 
tries *lru.Cache forestCapacity int - onTreeEvicted func(tree *trie.MTrie) error + onTreeEvicted func(tree *trie.MTrie) metrics module.LedgerMetrics } @@ -40,7 +40,7 @@ type Forest struct { // THIS IS A ROUGH HEURISTIC as it might evict tries that are still needed. // Make sure you chose a sufficiently large forestCapacity, such that, when reaching the capacity, the // Least Recently Used trie will never be needed again. -func NewForest(forestCapacity int, metrics module.LedgerMetrics, onTreeEvicted func(tree *trie.MTrie) error) (*Forest, error) { +func NewForest(forestCapacity int, metrics module.LedgerMetrics, onTreeEvicted func(tree *trie.MTrie)) (*Forest, error) { // init LRU cache as a SHORTCUT for a usage-related storage eviction policy var cache *lru.Cache var err error @@ -50,8 +50,7 @@ func NewForest(forestCapacity int, metrics module.LedgerMetrics, onTreeEvicted f if !ok { panic(fmt.Sprintf("cache contains item of type %T", value)) } - //TODO Log error - _ = onTreeEvicted(trie) + onTreeEvicted(trie) }) } else { cache, err = lru.New(forestCapacity) diff --git a/ledger/complete/mtrie/trie/trie_test.go b/ledger/complete/mtrie/trie/trie_test.go index ee3c1a282eb..aa72767d31b 100644 --- a/ledger/complete/mtrie/trie/trie_test.go +++ b/ledger/complete/mtrie/trie/trie_test.go @@ -257,7 +257,7 @@ func sampleRandomRegisterWritesWithPrefix(rng *LinearCongruentialGenerator, numb nextRandomByteIndex = 0 } p[b] = nextRandomBytes[nextRandomByteIndex] - nextRandomByteIndex += 1 + nextRandomByteIndex++ } paths = append(paths, p) diff --git a/ledger/complete/wal/checkpointer.go b/ledger/complete/wal/checkpointer.go index 401823ea3da..a2ca396e77f 100644 --- a/ledger/complete/wal/checkpointer.go +++ b/ledger/complete/wal/checkpointer.go @@ -3,6 +3,7 @@ package wal import ( "bufio" "encoding/binary" + "encoding/hex" "fmt" "io" "os" @@ -14,6 +15,7 @@ import ( "github.com/onflow/flow-go/ledger" "github.com/onflow/flow-go/ledger/complete/mtrie" "github.com/onflow/flow-go/ledger/complete/mtrie/flattener" + "github.com/onflow/flow-go/ledger/complete/mtrie/node" "github.com/onflow/flow-go/ledger/complete/mtrie/trie" "github.com/onflow/flow-go/model/bootstrap" "github.com/onflow/flow-go/module/metrics" @@ -29,6 +31,29 @@ const VersionV1 uint16 = 0x01 // Version 3 contains a file checksum for detecting corrupted checkpoint files. const VersionV3 uint16 = 0x03 +// Version 4 contains a footer with node count and trie count (previously in the header). +// Version 4 also reduces checkpoint data size. See EncodeNode() and EncodeTrie() for more details. +const VersionV4 uint16 = 0x04 + +const ( + encMagicSize = 2 + encVersionSize = 2 + headerSize = encMagicSize + encVersionSize + encNodeCountSize = 8 + encTrieCountSize = 2 + crc32SumSize = 4 +) + +// defaultBufioReadSize replaces the default bufio buffer size of 4096 bytes. +// defaultBufioReadSize can be increased to 8KiB, 16KiB, 32KiB, etc. if it +// improves performance on typical EN hardware. +const defaultBufioReadSize = 1024 * 32 + +// defaultBufioWriteSize replaces the default bufio buffer size of 4096 bytes. +// defaultBufioWriteSize can be increased to 8KiB, 16KiB, 32KiB, etc. if it +// improves performance on typical EN hardware. 
+const defaultBufioWriteSize = 1024 * 32
+
 type Checkpointer struct {
 	dir string
 	wal *DiskWAL
@@ -153,26 +178,16 @@ func (c *Checkpointer) Checkpoint(to int, targetWriter func() (io.WriteCloser, e
 		return fmt.Errorf("no segments to checkpoint to %d, latests not checkpointed segment: %d", to, notCheckpointedTo)
 	}
 
-	forest, err := mtrie.NewForest(c.forestCapacity, &metrics.NoopCollector{}, func(evictedTrie *trie.MTrie) error {
-		return nil
-	})
+	forest, err := mtrie.NewForest(c.forestCapacity, &metrics.NoopCollector{}, nil)
 	if err != nil {
 		return fmt.Errorf("cannot create Forest: %w", err)
 	}
 
+	c.wal.log.Info().Msgf("creating checkpoint %d", to)
+
 	err = c.wal.replay(0, to,
-		func(forestSequencing *flattener.FlattenedForest) error {
-			tries, err := flattener.RebuildTries(forestSequencing)
-			if err != nil {
-				return err
-			}
-			for _, t := range tries {
-				err := forest.AddTrie(t)
-				if err != nil {
-					return err
-				}
-			}
-			return nil
+		func(tries []*trie.MTrie) error {
+			return forest.AddTries(tries)
 		},
 		func(update *ledger.TrieUpdate) error {
 			_, err := forest.Update(update)
@@ -185,11 +200,9 @@
 		return fmt.Errorf("cannot replay WAL: %w", err)
 	}
 
-	c.wal.log.Info().Msgf("flattening forest for checkpoint %d", to)
-
-	forestSequencing, err := flattener.FlattenForest(forest)
+	tries, err := forest.GetTries()
 	if err != nil {
-		return fmt.Errorf("cannot get storables: %w", err)
+		return fmt.Errorf("cannot get forest tries: %w", err)
 	}
 
 	c.wal.log.Info().Msgf("serializing checkpoint %d", to)
@@ -206,7 +219,7 @@
 		}
 	}()
 
-	err = StoreCheckpoint(forestSequencing, writer)
+	err = StoreCheckpoint(writer, tries...)
 	return err
 }
@@ -242,7 +255,7 @@ func CreateCheckpointWriterForFile(dir, filename string) (io.WriteCloser, error)
 		return nil, fmt.Errorf("cannot create temporary file for checkpoint %v: %w", tmpFile, err)
 	}
 
-	writer := bufio.NewWriter(tmpFile)
+	writer := bufio.NewWriterSize(tmpFile, defaultBufioWriteSize)
 	return &SyncOnCloseRenameFile{
 		file: tmpFile,
 		targetName: fullname,
@@ -250,59 +263,139 @@ func CreateCheckpointWriterForFile(dir, filename string) (io.WriteCloser, error)
 	}, nil
 }
 
-// StoreCheckpoint writes the given checkpoint to disk, and also append with a CRC32 file checksum for integrity check.
-func StoreCheckpoint(forestSequencing *flattener.FlattenedForest, writer io.Writer) error {
-	storableNodes := forestSequencing.Nodes
-	storableTries := forestSequencing.Tries
-	header := make([]byte, 4+8+2)
+// StoreCheckpoint writes the given tries to a checkpoint file, and also appends
+// a CRC32 file checksum for integrity check.
+// Checkpoint file consists of a flattened forest. Specifically, it consists of:
+// * a list of encoded nodes, where references to other nodes are by list index.
+// * a list of encoded tries, each referencing their respective root node by index.
+// Referencing other nodes by index 0 is a special case, meaning nil.
+//
+// As an important property, the nodes are listed in an order which satisfies
+// Descendents-First-Relationship. The Descendents-First-Relationship has the
+// following important property:
+// When rebuilding the trie from the sequence of nodes, one can build the trie on the fly,
+// as for each node, the children have been previously encountered.
+// TODO: evaluate alternatives to CRC32 since checkpoint file is many GB in size.
+// TODO: add concurrency if the performance gains are enough to offset complexity. +func StoreCheckpoint(writer io.Writer, tries ...*trie.MTrie) error { crc32Writer := NewCRC32Writer(writer) - pos := writeUint16(header, 0, MagicBytes) - pos = writeUint16(header, pos, VersionV3) - pos = writeUint64(header, pos, uint64(len(storableNodes)-1)) // -1 to account for 0 node meaning nil - writeUint16(header, pos, uint16(len(storableTries))) + // Scratch buffer is used as temporary buffer that node can encode into. + // Data in scratch buffer should be copied or used before scratch buffer is used again. + // If the scratch buffer isn't large enough, a new buffer will be allocated. + // However, 4096 bytes will be large enough to handle almost all payloads + // and 100% of interim nodes. + scratch := make([]byte, 1024*4) + + // Write header: magic (2 bytes) + version (2 bytes) + header := scratch[:headerSize] + binary.BigEndian.PutUint16(header, MagicBytes) + binary.BigEndian.PutUint16(header[encMagicSize:], VersionV4) _, err := crc32Writer.Write(header) if err != nil { return fmt.Errorf("cannot write checkpoint header: %w", err) } - // 0 element = nil, we don't need to store it - for i := 1; i < len(storableNodes); i++ { - bytes := flattener.EncodeStorableNode(storableNodes[i]) - _, err = crc32Writer.Write(bytes) - if err != nil { - return fmt.Errorf("error while writing node date: %w", err) + // allNodes contains all unique nodes of given tries and their index + // (ordered by node traversal sequence). + // Index 0 is a special case with nil node. + allNodes := make(map[*node.Node]uint64) + allNodes[nil] = 0 + + allRootNodes := make([]*node.Node, len(tries)) + + // Serialize all unique nodes + nodeCounter := uint64(1) // start from 1, as 0 marks nil node + for i, t := range tries { + + // Traverse all unique nodes for trie t. + for itr := flattener.NewUniqueNodeIterator(t, allNodes); itr.Next(); { + n := itr.Value() + + allNodes[n] = nodeCounter + nodeCounter++ + + var lchildIndex, rchildIndex uint64 + + if lchild := n.LeftChild(); lchild != nil { + var found bool + lchildIndex, found = allNodes[lchild] + if !found { + hash := lchild.Hash() + return fmt.Errorf("internal error: missing node with hash %s", hex.EncodeToString(hash[:])) + } + } + if rchild := n.RightChild(); rchild != nil { + var found bool + rchildIndex, found = allNodes[rchild] + if !found { + hash := rchild.Hash() + return fmt.Errorf("internal error: missing node with hash %s", hex.EncodeToString(hash[:])) + } + } + + encNode := flattener.EncodeNode(n, lchildIndex, rchildIndex, scratch) + _, err = crc32Writer.Write(encNode) + if err != nil { + return fmt.Errorf("cannot serialize node: %w", err) + } } + + // Save trie root for serialization later. 
+ allRootNodes[i] = t.RootNode() } - for _, storableTrie := range storableTries { - bytes := flattener.EncodeStorableTrie(storableTrie) - _, err = crc32Writer.Write(bytes) + // Serialize trie root nodes + for _, rootNode := range allRootNodes { + // Get root node index + rootIndex, found := allNodes[rootNode] + if !found { + var rootHash ledger.RootHash + if rootNode == nil { + rootHash = trie.EmptyTrieRootHash() + } else { + rootHash = ledger.RootHash(rootNode.Hash()) + } + return fmt.Errorf("internal error: missing node with hash %s", hex.EncodeToString(rootHash[:])) + } + + encTrie := flattener.EncodeTrie(rootNode, rootIndex, scratch) + _, err = crc32Writer.Write(encTrie) if err != nil { - return fmt.Errorf("error while writing trie date: %w", err) + return fmt.Errorf("cannot serialize trie: %w", err) } } - // add CRC32 sum - crc32buf := make([]byte, 4) - writeUint32(crc32buf, 0, crc32Writer.Crc32()) + // Write footer with nodes count and tries count + footer := scratch[:encNodeCountSize+encTrieCountSize] + binary.BigEndian.PutUint64(footer, uint64(len(allNodes)-1)) // -1 to account for 0 node meaning nil + binary.BigEndian.PutUint16(footer[encNodeCountSize:], uint16(len(allRootNodes))) + + _, err = crc32Writer.Write(footer) + if err != nil { + return fmt.Errorf("cannot write checkpoint footer: %w", err) + } + + // Write CRC32 sum + crc32buf := scratch[:crc32SumSize] + binary.BigEndian.PutUint32(crc32buf, crc32Writer.Crc32()) _, err = writer.Write(crc32buf) if err != nil { - return fmt.Errorf("cannot write crc32: %w", err) + return fmt.Errorf("cannot write CRC32: %w", err) } return nil } -func (c *Checkpointer) LoadCheckpoint(checkpoint int) (*flattener.FlattenedForest, error) { +func (c *Checkpointer) LoadCheckpoint(checkpoint int) ([]*trie.MTrie, error) { filepath := path.Join(c.dir, NumberToFilename(checkpoint)) return LoadCheckpoint(filepath) } -func (c *Checkpointer) LoadRootCheckpoint() (*flattener.FlattenedForest, error) { +func (c *Checkpointer) LoadRootCheckpoint() ([]*trie.MTrie, error) { filepath := path.Join(c.dir, bootstrap.FilenameWALRootCheckpoint) return LoadCheckpoint(filepath) } @@ -321,7 +414,7 @@ func (c *Checkpointer) RemoveCheckpoint(checkpoint int) error { return os.Remove(path.Join(c.dir, NumberToFilename(checkpoint))) } -func LoadCheckpoint(filepath string) (*flattener.FlattenedForest, error) { +func LoadCheckpoint(filepath string) ([]*trie.MTrie, error) { file, err := os.Open(filepath) if err != nil { return nil, fmt.Errorf("cannot open checkpoint file %s: %w", filepath, err) @@ -330,106 +423,224 @@ func LoadCheckpoint(filepath string) (*flattener.FlattenedForest, error) { _ = file.Close() }() - return ReadCheckpoint(file) + return readCheckpoint(file) } -func ReadCheckpoint(r io.Reader) (*flattener.FlattenedForest, error) { +func readCheckpoint(f *os.File) ([]*trie.MTrie, error) { - var bufReader io.Reader = bufio.NewReader(r) - crcReader := NewCRC32Reader(bufReader) - var reader io.Reader = crcReader + // Read header: magic (2 bytes) + version (2 bytes) + header := make([]byte, headerSize) + _, err := io.ReadFull(f, header) + if err != nil { + return nil, fmt.Errorf("cannot read header: %w", err) + } - header := make([]byte, 4+8+2) + // Decode header + magicBytes := binary.BigEndian.Uint16(header) + version := binary.BigEndian.Uint16(header[encMagicSize:]) - _, err := io.ReadFull(reader, header) + // Reset offset + _, err = f.Seek(0, io.SeekStart) if err != nil { - return nil, fmt.Errorf("cannot read header bytes: %w", err) + return nil, fmt.Errorf("cannot 
seek to start of file: %w", err) } - magicBytes, pos := readUint16(header, 0) - version, pos := readUint16(header, pos) - nodesCount, pos := readUint64(header, pos) - triesCount, _ := readUint16(header, pos) - if magicBytes != MagicBytes { return nil, fmt.Errorf("unknown file format. Magic constant %x does not match expected %x", magicBytes, MagicBytes) } - if version != VersionV1 && version != VersionV3 { - return nil, fmt.Errorf("unsupported file version %x ", version) + + switch version { + case VersionV1, VersionV3: + return readCheckpointV3AndEarlier(f, version) + case VersionV4: + return readCheckpointV4(f) + default: + return nil, fmt.Errorf("unsupported file version %x", version) } +} + +// readCheckpointV3AndEarlier deserializes checkpoint file (version 3 and earlier) and returns a list of tries. +// Header (magic and version) is verified by the caller. +// This function is for backwards compatibility, not optimized. +func readCheckpointV3AndEarlier(f *os.File, version uint16) ([]*trie.MTrie, error) { + + var bufReader io.Reader = bufio.NewReaderSize(f, defaultBufioReadSize) + crcReader := NewCRC32Reader(bufReader) + + var reader io.Reader if version != VersionV3 { - reader = bufReader //switch back to plain reader + reader = bufReader + } else { + reader = crcReader } - nodes := make([]*flattener.StorableNode, nodesCount+1) //+1 for 0 index meaning nil - tries := make([]*flattener.StorableTrie, triesCount) + // Read header (magic + version), node count, and trie count. + header := make([]byte, headerSize+encNodeCountSize+encTrieCountSize) + + _, err := io.ReadFull(reader, header) + if err != nil { + return nil, fmt.Errorf("cannot read header: %w", err) + } + + // Magic and version are verified by the caller. + + // Decode node count and trie count + nodesCount := binary.BigEndian.Uint64(header[headerSize:]) + triesCount := binary.BigEndian.Uint16(header[headerSize+encNodeCountSize:]) + + nodes := make([]*node.Node, nodesCount+1) //+1 for 0 index meaning nil + tries := make([]*trie.MTrie, triesCount) for i := uint64(1); i <= nodesCount; i++ { - storableNode, err := flattener.ReadStorableNode(reader) + n, err := flattener.ReadNodeFromCheckpointV3AndEarlier(reader, func(nodeIndex uint64) (*node.Node, error) { + if nodeIndex >= uint64(i) { + return nil, fmt.Errorf("sequence of stored nodes does not satisfy Descendents-First-Relationship") + } + return nodes[nodeIndex], nil + }) if err != nil { - return nil, fmt.Errorf("cannot read storable node %d: %w", i, err) + return nil, fmt.Errorf("cannot read node %d: %w", i, err) } - nodes[i] = storableNode + nodes[i] = n } - // TODO version ? 
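	// Note the asymmetric bounds in this function's two getNode closures:
	// while decoding node i, only indices below i are valid (the stream is
	// descendants-first), whereas a trie decoded afterwards may reference
	// any of the len(nodes) already-decoded nodes.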
for i := uint16(0); i < triesCount; i++ { - storableTrie, err := flattener.ReadStorableTrie(reader) + trie, err := flattener.ReadTrieFromCheckpointV3AndEarlier(reader, func(nodeIndex uint64) (*node.Node, error) { + if nodeIndex >= uint64(len(nodes)) { + return nil, fmt.Errorf("sequence of stored nodes doesn't contain node") + } + return nodes[nodeIndex], nil + }) if err != nil { - return nil, fmt.Errorf("cannot read storable trie %d: %w", i, err) + return nil, fmt.Errorf("cannot read trie %d: %w", i, err) } - tries[i] = storableTrie + tries[i] = trie } if version == VersionV3 { - crc32buf := make([]byte, 4) - _, err := bufReader.Read(crc32buf) + crc32buf := make([]byte, crc32SumSize) + + _, err := io.ReadFull(bufReader, crc32buf) if err != nil { - return nil, fmt.Errorf("error while reading CRC32 checksum: %w", err) + return nil, fmt.Errorf("cannot read CRC32: %w", err) } - readCrc32, _ := readUint32(crc32buf, 0) + + readCrc32 := binary.BigEndian.Uint32(crc32buf) calculatedCrc32 := crcReader.Crc32() if calculatedCrc32 != readCrc32 { - return nil, fmt.Errorf("checkpoint checksum failed! File contains %x but read data checksums to %x", readCrc32, calculatedCrc32) + return nil, fmt.Errorf("checkpoint checksum failed! File contains %x but calculated crc32 is %x", readCrc32, calculatedCrc32) } } - return &flattener.FlattenedForest{ - Nodes: nodes, - Tries: tries, - }, nil - + return tries, nil } -func writeUint16(buffer []byte, location int, value uint16) int { - binary.BigEndian.PutUint16(buffer[location:], value) - return location + 2 -} +// readCheckpointV4 deserializes checkpoint file (version 4) and returns a list of tries. +// Checkpoint file header (magic and version) are verified by the caller. +func readCheckpointV4(f *os.File) ([]*trie.MTrie, error) { -func readUint16(buffer []byte, location int) (uint16, int) { - value := binary.BigEndian.Uint16(buffer[location:]) - return value, location + 2 -} + // Scratch buffer is used as temporary buffer that reader can read into. + // Raw data in scratch buffer should be copied or converted into desired + // objects before next Read operation. If the scratch buffer isn't large + // enough, a new buffer will be allocated. However, 4096 bytes will + // be large enough to handle almost all payloads and 100% of interim nodes. 
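	// (v4 moved the node and trie counts from the header into a footer,
	// presumably so the writer can stream nodes without knowing the totals
	// up front; the reader below therefore seeks to the footer first and
	// then rewinds to decode the stream.)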
+	scratch := make([]byte, 1024*4) // must not be less than 1024
+
+	// Read footer to get node count and trie count
+
+	// footer offset: nodes count (8 bytes) + tries count (2 bytes) + CRC32 sum (4 bytes)
+	const footerOffset = encNodeCountSize + encTrieCountSize + crc32SumSize
+	const footerSize = encNodeCountSize + encTrieCountSize // footer doesn't include crc32 sum
+
+	// Seek to footer
+	_, err := f.Seek(-footerOffset, io.SeekEnd)
+	if err != nil {
+		return nil, fmt.Errorf("cannot seek to footer: %w", err)
+	}
+
+	footer := scratch[:footerSize]
+
+	_, err = io.ReadFull(f, footer)
+	if err != nil {
+		return nil, fmt.Errorf("cannot read footer: %w", err)
+	}
+
+	// Decode node count and trie count
+	nodesCount := binary.BigEndian.Uint64(footer)
+	triesCount := binary.BigEndian.Uint16(footer[encNodeCountSize:])
+
+	// Seek to the start of file
+	_, err = f.Seek(0, io.SeekStart)
+	if err != nil {
+		return nil, fmt.Errorf("cannot seek to start of file: %w", err)
+	}
+
+	var bufReader io.Reader = bufio.NewReaderSize(f, defaultBufioReadSize)
+	crcReader := NewCRC32Reader(bufReader)
+	var reader io.Reader = crcReader
+
+	// Read header: magic (2 bytes) + version (2 bytes)
+	// No action is needed for header because it is verified by the caller.
+	_, err = io.ReadFull(reader, scratch[:headerSize])
+	if err != nil {
+		return nil, fmt.Errorf("cannot read header: %w", err)
+	}
+
+	// nodes' element at index 0 is special, meaning nil.
+	nodes := make([]*node.Node, nodesCount+1) //+1 for 0 index meaning nil
+	tries := make([]*trie.MTrie, triesCount)
+
+	for i := uint64(1); i <= nodesCount; i++ {
+		n, err := flattener.ReadNode(reader, scratch, func(nodeIndex uint64) (*node.Node, error) {
+			if nodeIndex >= uint64(i) {
+				return nil, fmt.Errorf("sequence of serialized nodes does not satisfy Descendents-First-Relationship")
+			}
+			return nodes[nodeIndex], nil
+		})
+		if err != nil {
+			return nil, fmt.Errorf("cannot read node %d: %w", i, err)
+		}
+		nodes[i] = n
+	}
+
+	for i := uint16(0); i < triesCount; i++ {
+		trie, err := flattener.ReadTrie(reader, scratch, func(nodeIndex uint64) (*node.Node, error) {
+			if nodeIndex >= uint64(len(nodes)) {
+				return nil, fmt.Errorf("sequence of stored nodes doesn't contain node")
+			}
+			return nodes[nodeIndex], nil
+		})
+		if err != nil {
+			return nil, fmt.Errorf("cannot read trie %d: %w", i, err)
+		}
+		tries[i] = trie
+	}
+
+	// Read footer again for crc32 computation
+	// No action is needed.
+	_, err = io.ReadFull(reader, footer)
+	if err != nil {
+		return nil, fmt.Errorf("cannot read footer: %w", err)
+	}
+
+	// Read CRC32
+	crc32buf := scratch[:crc32SumSize]
+	_, err = io.ReadFull(bufReader, crc32buf)
+	if err != nil {
+		return nil, fmt.Errorf("cannot read CRC32: %w", err)
+	}
+
+	readCrc32 := binary.BigEndian.Uint32(crc32buf)
+
+	calculatedCrc32 := crcReader.Crc32()
+
+	if calculatedCrc32 != readCrc32 {
+		return nil, fmt.Errorf("checkpoint checksum failed!
File contains %x but calculated crc32 is %x", readCrc32, calculatedCrc32) + } -func writeUint64(buffer []byte, location int, value uint64) int { - binary.BigEndian.PutUint64(buffer[location:], value) - return location + 8 + return tries, nil } diff --git a/ledger/complete/wal/checkpointer_test.go b/ledger/complete/wal/checkpointer_test.go index ef7cb6d0ef0..cd811ce45c4 100644 --- a/ledger/complete/wal/checkpointer_test.go +++ b/ledger/complete/wal/checkpointer_test.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io" + "io/ioutil" "math/rand" "os" "path" @@ -21,7 +22,6 @@ import ( "github.com/onflow/flow-go/ledger/common/utils" "github.com/onflow/flow-go/ledger/complete" "github.com/onflow/flow-go/ledger/complete/mtrie" - "github.com/onflow/flow-go/ledger/complete/mtrie/flattener" "github.com/onflow/flow-go/ledger/complete/mtrie/trie" realWAL "github.com/onflow/flow-go/ledger/complete/wal" "github.com/onflow/flow-go/module/metrics" @@ -120,7 +120,7 @@ func Test_Checkpointing(t *testing.T) { unittest.RunWithTempDir(t, func(dir string) { - f, err := mtrie.NewForest(size*10, metricsCollector, func(tree *trie.MTrie) error { return nil }) + f, err := mtrie.NewForest(size*10, metricsCollector, nil) require.NoError(t, err) var rootHash = f.GetEmptyRootHash() @@ -170,7 +170,7 @@ func Test_Checkpointing(t *testing.T) { }) // create a new forest and replay WAL - f2, err := mtrie.NewForest(size*10, metricsCollector, func(tree *trie.MTrie) error { return nil }) + f2, err := mtrie.NewForest(size*10, metricsCollector, nil) require.NoError(t, err) t.Run("replay WAL and create checkpoint", func(t *testing.T) { @@ -181,7 +181,7 @@ func Test_Checkpointing(t *testing.T) { require.NoError(t, err) err = wal2.Replay( - func(forestSequencing *flattener.FlattenedForest) error { + func(tries []*trie.MTrie) error { return fmt.Errorf("I should fail as there should be no checkpoints") }, func(update *ledger.TrieUpdate) error { @@ -207,7 +207,7 @@ func Test_Checkpointing(t *testing.T) { <-wal2.Done() }) - f3, err := mtrie.NewForest(size*10, metricsCollector, func(tree *trie.MTrie) error { return nil }) + f3, err := mtrie.NewForest(size*10, metricsCollector, nil) require.NoError(t, err) t.Run("read checkpoint", func(t *testing.T) { @@ -215,8 +215,8 @@ func Test_Checkpointing(t *testing.T) { require.NoError(t, err) err = wal3.Replay( - func(forestSequencing *flattener.FlattenedForest) error { - return loadIntoForest(f3, forestSequencing) + func(tries []*trie.MTrie) error { + return f3.AddTries(tries) }, func(update *ledger.TrieUpdate) error { return fmt.Errorf("I should fail as there should be no updates") @@ -287,7 +287,7 @@ func Test_Checkpointing(t *testing.T) { require.FileExists(t, path.Join(dir, "00000011")) //make sure we have extra segment }) - f5, err := mtrie.NewForest(size*10, metricsCollector, func(tree *trie.MTrie) error { return nil }) + f5, err := mtrie.NewForest(size*10, metricsCollector, nil) require.NoError(t, err) t.Run("replay both checkpoint and updates after checkpoint", func(t *testing.T) { @@ -297,8 +297,8 @@ func Test_Checkpointing(t *testing.T) { updatesLeft := 1 // there should be only one update err = wal5.Replay( - func(forestSequencing *flattener.FlattenedForest) error { - return loadIntoForest(f5, forestSequencing) + func(tries []*trie.MTrie) error { + return f5.AddTries(tries) }, func(update *ledger.TrieUpdate) error { if updatesLeft == 0 { @@ -338,7 +338,7 @@ func Test_Checkpointing(t *testing.T) { t.Run("corrupted checkpoints are skipped", func(t *testing.T) { - f6, err := 
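To make the version 4 layout concrete: the file ends with a 14-byte trailer (8-byte node count, 2-byte trie count, 4-byte CRC32, all big-endian), which is why the reader above seeks backward from the end before reading anything else. Below is a minimal, self-contained sketch of that footer-first read. The field widths mirror the constant names in the diff (encNodeCountSize, encTrieCountSize, crc32SumSize); the literal values and the file name in main are illustrative assumptions, not code from the PR.

```go
package main

import (
	"encoding/binary"
	"fmt"
	"io"
	"os"
)

// Assumed field widths, mirroring encNodeCountSize (8), encTrieCountSize (2)
// and crc32SumSize (4) used by readCheckpointV4 above.
const (
	nodeCountSize = 8
	trieCountSize = 2
	crcSumSize    = 4
)

// readFooter seeks to the footer at the end of a checkpoint-style file and
// decodes the node count and trie count, the same way readCheckpointV4 does.
func readFooter(f *os.File) (nodesCount uint64, triesCount uint16, err error) {
	const footerOffset = nodeCountSize + trieCountSize + crcSumSize
	const footerSize = nodeCountSize + trieCountSize // CRC32 sum is not part of the footer

	if _, err := f.Seek(-footerOffset, io.SeekEnd); err != nil {
		return 0, 0, fmt.Errorf("cannot seek to footer: %w", err)
	}

	footer := make([]byte, footerSize)
	if _, err := io.ReadFull(f, footer); err != nil {
		return 0, 0, fmt.Errorf("cannot read footer: %w", err)
	}

	return binary.BigEndian.Uint64(footer), binary.BigEndian.Uint16(footer[nodeCountSize:]), nil
}

func main() {
	f, err := os.Open("checkpoint.v4") // hypothetical file name
	if err != nil {
		panic(err)
	}
	defer f.Close()

	nodesCount, triesCount, err := readFooter(f)
	if err != nil {
		panic(err)
	}
	fmt.Printf("nodes: %d, tries: %d\n", nodesCount, triesCount)
}
```

Reading the counts up front is what lets the reader pre-allocate the nodes and tries slices exactly once instead of growing them while streaming.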
diff --git a/ledger/complete/wal/checkpointer_test.go b/ledger/complete/wal/checkpointer_test.go
index ef7cb6d0ef0..cd811ce45c4 100644
--- a/ledger/complete/wal/checkpointer_test.go
+++ b/ledger/complete/wal/checkpointer_test.go
@@ -5,6 +5,7 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"io/ioutil"
 	"math/rand"
 	"os"
 	"path"
@@ -21,7 +22,6 @@ import (
 	"github.com/onflow/flow-go/ledger/common/utils"
 	"github.com/onflow/flow-go/ledger/complete"
 	"github.com/onflow/flow-go/ledger/complete/mtrie"
-	"github.com/onflow/flow-go/ledger/complete/mtrie/flattener"
 	"github.com/onflow/flow-go/ledger/complete/mtrie/trie"
 	realWAL "github.com/onflow/flow-go/ledger/complete/wal"
 	"github.com/onflow/flow-go/module/metrics"
@@ -120,7 +120,7 @@ func Test_Checkpointing(t *testing.T) {

 	unittest.RunWithTempDir(t, func(dir string) {

-		f, err := mtrie.NewForest(size*10, metricsCollector, func(tree *trie.MTrie) error { return nil })
+		f, err := mtrie.NewForest(size*10, metricsCollector, nil)
 		require.NoError(t, err)

 		var rootHash = f.GetEmptyRootHash()
@@ -170,7 +170,7 @@ func Test_Checkpointing(t *testing.T) {
 		})

 		// create a new forest and replay WAL
-		f2, err := mtrie.NewForest(size*10, metricsCollector, func(tree *trie.MTrie) error { return nil })
+		f2, err := mtrie.NewForest(size*10, metricsCollector, nil)
 		require.NoError(t, err)

 		t.Run("replay WAL and create checkpoint", func(t *testing.T) {
@@ -181,7 +181,7 @@ func Test_Checkpointing(t *testing.T) {
 			require.NoError(t, err)

 			err = wal2.Replay(
-				func(forestSequencing *flattener.FlattenedForest) error {
+				func(tries []*trie.MTrie) error {
 					return fmt.Errorf("I should fail as there should be no checkpoints")
 				},
 				func(update *ledger.TrieUpdate) error {
@@ -207,7 +207,7 @@ func Test_Checkpointing(t *testing.T) {
 			<-wal2.Done()
 		})

-		f3, err := mtrie.NewForest(size*10, metricsCollector, func(tree *trie.MTrie) error { return nil })
+		f3, err := mtrie.NewForest(size*10, metricsCollector, nil)
 		require.NoError(t, err)

 		t.Run("read checkpoint", func(t *testing.T) {
@@ -215,8 +215,8 @@ func Test_Checkpointing(t *testing.T) {
 			require.NoError(t, err)

 			err = wal3.Replay(
-				func(forestSequencing *flattener.FlattenedForest) error {
-					return loadIntoForest(f3, forestSequencing)
+				func(tries []*trie.MTrie) error {
+					return f3.AddTries(tries)
 				},
 				func(update *ledger.TrieUpdate) error {
 					return fmt.Errorf("I should fail as there should be no updates")
@@ -287,7 +287,7 @@ func Test_Checkpointing(t *testing.T) {
 			require.FileExists(t, path.Join(dir, "00000011")) //make sure we have extra segment
 		})

-		f5, err := mtrie.NewForest(size*10, metricsCollector, func(tree *trie.MTrie) error { return nil })
+		f5, err := mtrie.NewForest(size*10, metricsCollector, nil)
 		require.NoError(t, err)

 		t.Run("replay both checkpoint and updates after checkpoint", func(t *testing.T) {
@@ -297,8 +297,8 @@ func Test_Checkpointing(t *testing.T) {
 			updatesLeft := 1 // there should be only one update

 			err = wal5.Replay(
-				func(forestSequencing *flattener.FlattenedForest) error {
-					return loadIntoForest(f5, forestSequencing)
+				func(tries []*trie.MTrie) error {
+					return f5.AddTries(tries)
 				},
 				func(update *ledger.TrieUpdate) error {
 					if updatesLeft == 0 {
@@ -338,7 +338,7 @@

 		t.Run("corrupted checkpoints are skipped", func(t *testing.T) {

-			f6, err := mtrie.NewForest(size*10, metricsCollector, func(tree *trie.MTrie) error { return nil })
+			f6, err := mtrie.NewForest(size*10, metricsCollector, nil)
 			require.NoError(t, err)

 			wal6, err := realWAL.NewDiskWAL(zerolog.Nop(), nil, metrics.NewNoopCollector(), dir, size*10, pathByteSize, segmentSize)
@@ -495,7 +495,7 @@ func randomlyModifyFile(t *testing.T, filename string) {
 	require.NoError(t, err)

 	// byte addition will simply wrap around
-	buf[0] += 1
+	buf[0]++

 	_, err = file.WriteAt(buf, offset)
 	require.NoError(t, err)
@@ -503,59 +503,60 @@

 func Test_StoringLoadingCheckpoints(t *testing.T) {

-	// some hash will be literally copied into the output file
-	// so we can find it and modify - to make sure we get a different checksum
-	// but not fail process by, for example, modifying saved data length causing EOF
-	someHash := []byte{22, 22, 22}
-	forest := &flattener.FlattenedForest{
-		Nodes: []*flattener.StorableNode{
-			{}, {},
-		},
-		Tries: []*flattener.StorableTrie{
-			{}, {
-				RootHash: someHash,
-			},
-		},
-	}
-	buffer := &bytes.Buffer{}
+	unittest.RunWithTempDir(t, func(dir string) {
+		// Some hash will be encoded verbatim in the output file, so we can find
+		// and modify it. This guarantees a different checksum without breaking
+		// the read in some other way (e.g. changing a saved data length, which
+		// would cause an EOF instead).

-	err := realWAL.StoreCheckpoint(forest, buffer)
-	require.NoError(t, err)
+		emptyTrie := trie.NewEmptyMTrie()

-	// copy buffer data
-	bytes2 := buffer.Bytes()[:]
+		p1 := utils.PathByUint8(0)
+		v1 := utils.LightPayload8('A', 'a')

-	t.Run("works without data modification", func(t *testing.T) {
+		p2 := utils.PathByUint8(1)
+		v2 := utils.LightPayload8('B', 'b')

-		// first buffer reads ok
-		_, err = realWAL.ReadCheckpoint(buffer)
+		paths := []ledger.Path{p1, p2}
+		payloads := []ledger.Payload{*v1, *v2}
+
+		updatedTrie, err := trie.NewTrieWithUpdatedRegisters(emptyTrie, paths, payloads, true)
 		require.NoError(t, err)
-	})

-	t.Run("detects modified data", func(t *testing.T) {
+		someHash := updatedTrie.RootNode().LeftChild().Hash() // hash of the left child

-		index := bytes.Index(bytes2, someHash)
-		bytes2[index] = 23
+		file, err := ioutil.TempFile(dir, "temp-checkpoint")
+		require.NoError(t, err)
+		filepath := file.Name()

-		_, err = realWAL.ReadCheckpoint(bytes.NewBuffer(bytes2))
-		require.Error(t, err)
-		require.Contains(t, err.Error(), "checksum")
-	})
+		err = realWAL.StoreCheckpoint(file, updatedTrie)
+		require.NoError(t, err)

-}
+		file.Close()

-func loadIntoForest(forest *mtrie.Forest, forestSequencing *flattener.FlattenedForest) error {
-	tries, err := flattener.RebuildTries(forestSequencing)
-	if err != nil {
-		return err
-	}
-	for _, t := range tries {
-		err := forest.AddTrie(t)
-		if err != nil {
-			return err
-		}
-	}
-	return nil
+		t.Run("works without data modification", func(t *testing.T) {
+			tries, err := realWAL.LoadCheckpoint(filepath)
+			require.NoError(t, err)
+			require.Equal(t, 1, len(tries))
+			require.Equal(t, updatedTrie, tries[0])
+		})
+
+		t.Run("detects modified data", func(t *testing.T) {
+			b, err := ioutil.ReadFile(filepath)
+			require.NoError(t, err)
+
+			index := bytes.Index(b, someHash[:])
+			require.NotEqual(t, -1, index)
+			b[index] = 23
+
+			err = os.WriteFile(filepath, b, 0644)
+			require.NoError(t, err)
+
+			tries, err := realWAL.LoadCheckpoint(filepath)
+			require.Error(t, err)
+			require.Nil(t, tries)
+			require.Contains(t, err.Error(), "checksum")
+		})
+	})
 }

 type writeCloserWithErrors struct {
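The "detects modified data" test works because both checkpoint readers route every byte through a CRC32-computing reader, so flipping even a single byte of a stored hash changes the calculated sum. The CRC32Reader type itself isn't shown in this diff; the stdlib-only sketch below illustrates the same read-through-checksum pattern using io.TeeReader (the payload and framing here are invented for the example):

```go
package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"hash/crc32"
	"io"
)

func main() {
	payload := []byte("checkpoint body")

	// Writer side: append the big-endian CRC32 of everything written.
	var file bytes.Buffer
	file.Write(payload)
	sum := make([]byte, 4)
	binary.BigEndian.PutUint32(sum, crc32.ChecksumIEEE(payload))
	file.Write(sum)

	// Uncomment to corrupt one byte and watch the check fail:
	// file.Bytes()[3]++

	// Reader side: tee every read through a CRC32 hash, then compare the
	// computed sum against the stored one.
	h := crc32.NewIEEE()
	body := io.TeeReader(io.LimitReader(&file, int64(len(payload))), h)
	if _, err := io.Copy(io.Discard, body); err != nil {
		panic(err)
	}

	stored := make([]byte, 4)
	if _, err := io.ReadFull(&file, stored); err != nil {
		panic(err)
	}

	readCrc32 := binary.BigEndian.Uint32(stored)
	if h.Sum32() != readCrc32 {
		fmt.Printf("checksum failed! file contains %x but calculated crc32 is %x\n", readCrc32, h.Sum32())
		return
	}
	fmt.Println("checksum ok")
}
```

The design benefit is that the checksum is computed as a side effect of normal reading: no second pass over the file is needed.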
diff --git a/ledger/complete/wal/checkpointer_versioning_test.go b/ledger/complete/wal/checkpointer_versioning_test.go
index 9f2ead3e3a4..1e093cc605f 100644
--- a/ledger/complete/wal/checkpointer_versioning_test.go
+++ b/ledger/complete/wal/checkpointer_versioning_test.go
@@ -1,52 +1,159 @@
 package wal

 import (
+	"encoding/hex"
 	"testing"

 	"github.com/stretchr/testify/require"

-	"github.com/onflow/flow-go/ledger/complete/mtrie/flattener"
+	"github.com/onflow/flow-go/ledger"
 )

-var v1Forest = &flattener.FlattenedForest{
-	Nodes: []*flattener.StorableNode{
-		nil, // node 0 is special and skipped
-		{
-			LIndex:     0,
-			RIndex:     0,
-			Height:     0,
-			Path:       []byte{1},
-			EncPayload: []byte{2},
-			HashValue:  []byte{3},
-			MaxDepth:   1,
-			RegCount:   1,
-		}, {
-			LIndex:     1,
-			RIndex:     2,
-			Height:     3,
-			Path:       []byte{11},
-			EncPayload: []byte{22},
-			HashValue:  []byte{33},
-			MaxDepth:   11,
-			RegCount:   11,
-		},
-	},
-	Tries: []*flattener.StorableTrie{
-		{
-			RootIndex: 0,
-			RootHash:  []byte{4},
-		},
-		{
-			RootIndex: 1,
-			RootHash:  []byte{44},
-		},
-	},
+func Test_LoadingV1Checkpoint(t *testing.T) {
+
+	expectedRootHash := [4]ledger.RootHash{
+		mustToHash("568f4ec740fe3b5de88034cb7b1fbddb41548b068f31aebc8ae9189e429c5749"), // empty trie root hash
+		mustToHash("f53f9696b85b7428227f1b39f40b2ce07c175f58dea2b86cb6f84dc7c9fbeabd"),
+		mustToHash("7ac8daf34733cce3d5d03b5a1afde33a572249f81c45da91106412e94661e109"),
+		mustToHash("63df641430e5e0745c3d99ece6ac209467ccfdb77e362e7490a830db8e8803ae"),
+	}
+
+	tries, err := LoadCheckpoint("test_data/checkpoint.v1")
+	require.NoError(t, err)
+	require.Equal(t, len(expectedRootHash), len(tries))
+
+	for i, trie := range tries {
+		require.Equal(t, expectedRootHash[i], trie.RootHash())
+		require.True(t, trie.RootNode().VerifyCachedHash())
+	}
 }

-func Test_LoadingV1Checkpoint(t *testing.T) {
+func Test_LoadingV3Checkpoint(t *testing.T) {
+
+	expectedRootHash := [4]ledger.RootHash{
+		mustToHash("568f4ec740fe3b5de88034cb7b1fbddb41548b068f31aebc8ae9189e429c5749"), // empty trie root hash
+		mustToHash("f53f9696b85b7428227f1b39f40b2ce07c175f58dea2b86cb6f84dc7c9fbeabd"),
+		mustToHash("7ac8daf34733cce3d5d03b5a1afde33a572249f81c45da91106412e94661e109"),
+		mustToHash("63df641430e5e0745c3d99ece6ac209467ccfdb77e362e7490a830db8e8803ae"),
+	}
+
+	tries, err := LoadCheckpoint("test_data/checkpoint.v3")
+	require.NoError(t, err)
+	require.Equal(t, len(expectedRootHash), len(tries))
+
+	for i, trie := range tries {
+		require.Equal(t, expectedRootHash[i], trie.RootHash())
+		require.True(t, trie.RootNode().VerifyCachedHash())
+	}
+}
+
+func mustToHash(s string) ledger.RootHash {
+	b, err := hex.DecodeString(s)
+	if err != nil {
+		panic(err)
+	}
+	h, err := ledger.ToRootHash(b)
+	if err != nil {
+		panic(err)
+	}
+	return h
+}
+
+/*
+// CreateCheckpointV3 is used to create the checkpoint.v3 test file used by Test_LoadingV3Checkpoint.
+func CreateCheckpointV3() {
+
+	f, err := mtrie.NewForest(size*10, metricsCollector, func(tree *trie.MTrie) error { return nil })
+	require.NoError(t, err)
+
+	emptyTrie := trie.NewEmptyMTrie()
+
+	// key: 0000...
+	p1 := utils.PathByUint8(1)
+	v1 := utils.LightPayload8('A', 'a')
+
+	// key: 0100....
+	p2 := utils.PathByUint8(64)
+	v2 := utils.LightPayload8('B', 'b')
+
+	paths := []ledger.Path{p1, p2}
+	payloads := []ledger.Payload{*v1, *v2}
+
+	trie1, err := trie.NewTrieWithUpdatedRegisters(emptyTrie, paths, payloads, true)
+	require.NoError(t, err)
+
+	// trie1
+	//              n4
+	//             /
+	//            /
+	//          n3
+	//         /  \
+	//        /    \
+	//   n1 (p1/v1)  n2 (p2/v2)
+	//
+
+	f.AddTrie(trie1)
+
+	// New trie reuses its parent's left sub-trie.
+
+	// key: 1000...
+	p3 := utils.PathByUint8(128)
+	v3 := utils.LightPayload8('C', 'c')
+
+	// key: 1100....
+	p4 := utils.PathByUint8(192)
+	v4 := utils.LightPayload8('D', 'd')
+
+	paths = []ledger.Path{p3, p4}
+	payloads = []ledger.Payload{*v3, *v4}
+
+	trie2, err := trie.NewTrieWithUpdatedRegisters(trie1, paths, payloads, true)
+	require.NoError(t, err)
+
+	// trie2
+	//              n8
+	//             /  \
+	//            /    \
+	//          n3      n7
+	//      (shared)   /  \
+	//                /    \
+	//              n5      n6
+	//          (p3/v3)  (p4/v4)
+
+	f.AddTrie(trie2)
+
+	// New trie reuses its parent's right sub-trie, and left sub-trie's leaf node.
+
+	// key: 0000...
+	v5 := utils.LightPayload8('E', 'e')
+
+	paths = []ledger.Path{p1}
+	payloads = []ledger.Payload{*v5}
+
+	trie3, err := trie.NewTrieWithUpdatedRegisters(trie2, paths, payloads, true)
+	require.NoError(t, err)
+
+	// trie3
+	//              n11
+	//             /   \
+	//            /     \
+	//          n10      n7
+	//         /   \   (shared)
+	//        /     \
+	//      n9       n2
+	//  (p1/v5)   (shared)
+
+	f.AddTrie(trie3)
+
+	flattenedForest, err := flattener.FlattenForest(f)
+	require.NoError(t, err)
+
+	file, err := os.OpenFile("checkpoint.v3", os.O_RDWR|os.O_CREATE|os.O_EXCL, 0600)
+	require.NoError(t, err)

-	forest, err := LoadCheckpoint("test_data/checkpoint.v1")
+	err = realWAL.StoreCheckpoint(flattenedForest, file)
 	require.NoError(t, err)

-	require.Equal(t, v1Forest, forest)
+	file.Close()
 }
+*/
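The diagrams in the commented-out generator above show why a checkpoint can serialize nodes once and tries as mere root references: each trie version is immutable and shares unchanged sub-tries with its parent. The toy sketch below illustrates that persistence pattern with a hypothetical two-way trie; it is not the real MTrie API, just the structural-sharing idea:

```go
package main

import "fmt"

// toyNode is a hypothetical, simplified stand-in for an immutable trie node.
type toyNode struct {
	left, right *toyNode
	payload     string
}

// update returns a new trie version: only nodes on the path to the updated
// leaf are re-created; every untouched sub-trie is shared with the old version.
func update(n *toyNode, path []bool, payload string) *toyNode {
	if len(path) == 0 {
		return &toyNode{payload: payload}
	}
	if n == nil {
		n = &toyNode{}
	}
	c := *n // shallow copy; the untouched child pointer stays shared
	if path[0] {
		c.right = update(n.right, path[1:], payload)
	} else {
		c.left = update(n.left, path[1:], payload)
	}
	return &c
}

func main() {
	var root1 *toyNode
	root1 = update(root1, []bool{false, false}, "A") // key 00...
	root1 = update(root1, []bool{false, true}, "B")  // key 01...

	// The new version touches only the right half; the left sub-trie is shared.
	root2 := update(root1, []bool{true, false}, "C") // key 10...

	fmt.Println(root1.left == root2.left)   // true: shared sub-trie
	fmt.Println(root1.right == root2.right) // false: re-created on the update path
}
```

This is also why the checkpoint stores each shared node exactly once: three trie versions in the generator above need only eleven nodes, not three full copies of the forest.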
diff --git a/ledger/complete/wal/compactor_test.go b/ledger/complete/wal/compactor_test.go
index 8abab425ead..ab957c412a4 100644
--- a/ledger/complete/wal/compactor_test.go
+++ b/ledger/complete/wal/compactor_test.go
@@ -14,7 +14,6 @@ import (
 	"github.com/onflow/flow-go/ledger"
 	"github.com/onflow/flow-go/ledger/common/utils"
 	"github.com/onflow/flow-go/ledger/complete/mtrie"
-	"github.com/onflow/flow-go/ledger/complete/mtrie/flattener"
 	"github.com/onflow/flow-go/ledger/complete/mtrie/trie"
 	"github.com/onflow/flow-go/model/bootstrap"
 	"github.com/onflow/flow-go/module/metrics"
@@ -54,7 +53,7 @@ func Test_Compactor(t *testing.T) {

 	unittest.RunWithTempDir(t, func(dir string) {

-		f, err := mtrie.NewForest(size*10, metricsCollector, func(tree *trie.MTrie) error { return nil })
+		f, err := mtrie.NewForest(size*10, metricsCollector, nil)
 		require.NoError(t, err)

 		var rootHash = f.GetEmptyRootHash()
@@ -158,7 +157,7 @@ func Test_Compactor(t *testing.T) {
 			}
 		})

-		f2, err := mtrie.NewForest(size*10, metricsCollector, func(tree *trie.MTrie) error { return nil })
+		f2, err := mtrie.NewForest(size*10, metricsCollector, nil)
 		require.NoError(t, err)

 		time.Sleep(2 * time.Second)
@@ -168,8 +167,8 @@ func Test_Compactor(t *testing.T) {
 			require.NoError(t, err)

 			err = wal2.Replay(
-				func(forestSequencing *flattener.FlattenedForest) error {
-					return loadIntoForest(f2, forestSequencing)
+				func(tries []*trie.MTrie) error {
+					return f2.AddTries(tries)
 				},
 				func(update *ledger.TrieUpdate) error {
 					_, err := f2.Update(update)
@@ -234,7 +233,7 @@ func Test_Compactor_checkpointInterval(t *testing.T) {

 	unittest.RunWithTempDir(t, func(dir string) {

-		f, err := mtrie.NewForest(size*10, metricsCollector, func(tree *trie.MTrie) error { return nil })
+		f, err := mtrie.NewForest(size*10, metricsCollector, nil)
 		require.NoError(t, err)

 		var rootHash = f.GetEmptyRootHash()
@@ -330,17 +329,3 @@ func Test_Compactor_checkpointInterval(t *testing.T) {
 		})
 	})
 }
-
-func loadIntoForest(forest *mtrie.Forest, forestSequencing *flattener.FlattenedForest) error {
-	tries, err := flattener.RebuildTries(forestSequencing)
-	if err != nil {
-		return err
-	}
-	for _, t := range tries {
-		err := forest.AddTrie(t)
-		if err != nil {
-			return err
-		}
-	}
-	return nil
-}
diff --git a/ledger/complete/wal/fixtures/noopwal.go b/ledger/complete/wal/fixtures/noopwal.go
index 3f88fc6e557..8f705efdbf2 100644
--- a/ledger/complete/wal/fixtures/noopwal.go
+++ b/ledger/complete/wal/fixtures/noopwal.go
@@ -3,7 +3,7 @@ package fixtures
 import (
 	"github.com/onflow/flow-go/ledger"
 	"github.com/onflow/flow-go/ledger/complete/mtrie"
-	"github.com/onflow/flow-go/ledger/complete/mtrie/flattener"
+	"github.com/onflow/flow-go/ledger/complete/mtrie/trie"
 	"github.com/onflow/flow-go/ledger/complete/wal"
 )
@@ -37,10 +37,10 @@ func (w *NoopWAL) ReplayOnForest(forest *mtrie.Forest) error { return nil }

 func (w *NoopWAL) Segments() (first, last int, err error) { return 0, 0, nil }

-func (w *NoopWAL) Replay(checkpointFn func(forestSequencing *flattener.FlattenedForest) error, updateFn func(update *ledger.TrieUpdate) error, deleteFn func(ledger.RootHash) error) error {
+func (w *NoopWAL) Replay(checkpointFn func(tries []*trie.MTrie) error, updateFn func(update *ledger.TrieUpdate) error, deleteFn func(ledger.RootHash) error) error {
 	return nil
 }

-func (w *NoopWAL) ReplayLogsOnly(checkpointFn func(forestSequencing *flattener.FlattenedForest) error, updateFn func(update *ledger.TrieUpdate) error, deleteFn func(rootHash ledger.RootHash) error) error {
+func (w *NoopWAL) ReplayLogsOnly(checkpointFn func(tries []*trie.MTrie) error, updateFn func(update *ledger.TrieUpdate) error, deleteFn func(rootHash ledger.RootHash) error) error {
 	return nil
 }
diff --git a/ledger/complete/wal/test_data/checkpoint.v1 b/ledger/complete/wal/test_data/checkpoint.v1
index c5567395b83..86f6d15684a 100644
Binary files a/ledger/complete/wal/test_data/checkpoint.v1 and b/ledger/complete/wal/test_data/checkpoint.v1 differ
diff --git a/ledger/complete/wal/test_data/checkpoint.v3 b/ledger/complete/wal/test_data/checkpoint.v3
new file mode 100644
index 00000000000..e7745c797eb
Binary files /dev/null and b/ledger/complete/wal/test_data/checkpoint.v3 differ
diff --git a/ledger/complete/wal/wal.go b/ledger/complete/wal/wal.go
index 3ce1bdbd7aa..63b4891c599 100644
--- a/ledger/complete/wal/wal.go
+++ b/ledger/complete/wal/wal.go
@@ -11,7 +11,7 @@ import (
 	"github.com/onflow/flow-go/ledger"
 	"github.com/onflow/flow-go/ledger/complete/mtrie"
-	"github.com/onflow/flow-go/ledger/complete/mtrie/flattener"
+	"github.com/onflow/flow-go/ledger/complete/mtrie/trie"
 	"github.com/onflow/flow-go/module"
 	"github.com/onflow/flow-go/utils/io"
 )
@@ -105,12 +105,8 @@ func (w *DiskWAL) RecordDelete(rootHash ledger.RootHash) error {
 func (w *DiskWAL) ReplayOnForest(forest *mtrie.Forest) error {
 	return w.Replay(
-		func(forestSequencing *flattener.FlattenedForest) error {
-			rebuiltTries, err := flattener.RebuildTries(forestSequencing)
-			if err != nil {
-				return fmt.Errorf("rebuilding forest from sequenced nodes failed: %w", err)
-			}
-			err = forest.AddTries(rebuiltTries)
+		func(tries []*trie.MTrie) error {
+			err := forest.AddTries(tries)
 			if err != nil {
 				return fmt.Errorf("adding rebuilt tries to forest failed: %w", err)
 			}
@@ -132,7 +128,7 @@ func (w *DiskWAL) Segments() (first, last int, err error) {
 }

 func (w *DiskWAL) Replay(
-	checkpointFn func(forestSequencing *flattener.FlattenedForest) error,
+	checkpointFn func(tries []*trie.MTrie) error,
 	updateFn func(update *ledger.TrieUpdate) error,
 	deleteFn func(ledger.RootHash) error,
 ) error {
@@ -144,7 +140,7 @@ func (w *DiskWAL) Replay(
 }

 func (w *DiskWAL) ReplayLogsOnly(
-	checkpointFn func(forestSequencing *flattener.FlattenedForest) error,
+	checkpointFn func(tries []*trie.MTrie) error,
 	updateFn func(update *ledger.TrieUpdate) error,
 	deleteFn func(rootHash ledger.RootHash) error,
 ) error {
@@ -157,7 +153,7 @@ func (w *DiskWAL) ReplayLogsOnly(

 func (w *DiskWAL) replay(
 	from, to int,
-	checkpointFn func(forestSequencing *flattener.FlattenedForest) error,
+	checkpointFn func(tries []*trie.MTrie) error,
 	updateFn func(update *ledger.TrieUpdate) error,
 	deleteFn func(rootHash ledger.RootHash) error,
 	useCheckpoints bool,
@@ -343,12 +339,12 @@ type LedgerWAL interface {
 	ReplayOnForest(forest *mtrie.Forest) error
 	Segments() (first, last int, err error)
 	Replay(
-		checkpointFn func(forestSequencing *flattener.FlattenedForest) error,
+		checkpointFn func(tries []*trie.MTrie) error,
 		updateFn func(update *ledger.TrieUpdate) error,
 		deleteFn func(ledger.RootHash) error,
 	) error
 	ReplayLogsOnly(
-		checkpointFn func(forestSequencing *flattener.FlattenedForest) error,
+		checkpointFn func(tries []*trie.MTrie) error,
 		updateFn func(update *ledger.TrieUpdate) error,
 		deleteFn func(rootHash ledger.RootHash) error,
 	) error
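Taken together, these signature changes mean a Replay caller now receives fully rebuilt tries instead of a FlattenedForest that it must run through the flattener itself. The sketch below shows what a caller against the LedgerWAL interface above could look like, mirroring ReplayOnForest; the AddTries and Update calls appear in this diff, while the RemoveTrie call in the delete callback is an assumption about the Forest API, not code from the PR:

```go
package replay

import (
	"github.com/onflow/flow-go/ledger"
	"github.com/onflow/flow-go/ledger/complete/mtrie"
	"github.com/onflow/flow-go/ledger/complete/mtrie/trie"
	"github.com/onflow/flow-go/ledger/complete/wal"
)

// rebuildForest replays a WAL into a forest. Checkpoints now hand back
// []*trie.MTrie directly, so there is no flattener round-trip.
func rebuildForest(w wal.LedgerWAL, forest *mtrie.Forest) error {
	return w.Replay(
		func(tries []*trie.MTrie) error {
			// Tries from a checkpoint arrive already rebuilt; just add them.
			return forest.AddTries(tries)
		},
		func(update *ledger.TrieUpdate) error {
			// Re-apply trie updates recorded after the checkpoint.
			_, err := forest.Update(update)
			return err
		},
		func(rootHash ledger.RootHash) error {
			// Assumed Forest method for handling recorded deletions.
			forest.RemoveTrie(rootHash)
			return nil
		},
	)
}
```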