diff --git a/cmd/util/cmd/execution-state-extract/cmd.go b/cmd/util/cmd/execution-state-extract/cmd.go
index 55728b428a8..55e3432ba9c 100644
--- a/cmd/util/cmd/execution-state-extract/cmd.go
+++ b/cmd/util/cmd/execution-state-extract/cmd.go
@@ -3,10 +3,13 @@ package extract
 import (
 	"encoding/hex"
 	"path"
+	"strings"
 
 	"github.com/rs/zerolog/log"
 	"github.com/spf13/cobra"
 
+	runtimeCommon "github.com/onflow/cadence/runtime/common"
+
 	"github.com/onflow/flow-go/cmd/util/cmd/common"
 	"github.com/onflow/flow-go/model/bootstrap"
 	"github.com/onflow/flow-go/model/flow"
@@ -26,6 +29,8 @@ var (
 	flagNoReport bool
 	flagValidateMigration bool
 	flagLogVerboseValidationError bool
+	flagInputPayload bool
+	flagOutputPayloadByAddresses string
 )
 
 var Cmd = &cobra.Command{
@@ -68,6 +73,19 @@ func init() {
 	Cmd.Flags().BoolVar(&flagLogVerboseValidationError, "log-verbose-validation-error", false,
 		"log entire Cadence values on validation error (atree migration)")
 
+	Cmd.Flags().StringVar(
+		&flagOutputPayloadByAddresses,
+		"extract-payloads-by-address",
+		"",
+		"extract payloads of specified addresses (comma-separated list of hex-encoded addresses, or \"all\")", // empty string ignores this flag
+	)
+
+	Cmd.Flags().BoolVar(
+		&flagInputPayload,
+		"use-payload-as-input",
+		false,
+		"use payload file instead of checkpoint file as input",
+	)
 }
 
 func run(*cobra.Command, []string) {
@@ -112,20 +130,65 @@ func run(*cobra.Command, []string) {
 		log.Info().Msgf("extracting state by state commitment: %x", stateCommitment)
 	}
 
-	if len(flagBlockHash) == 0 && len(flagStateCommitment) == 0 {
-		log.Fatal().Msg("no --block-hash or --state-commitment was specified")
+	if len(flagBlockHash) == 0 && len(flagStateCommitment) == 0 && !flagInputPayload {
+		log.Fatal().Msg("none of --block-hash, --state-commitment, or --use-payload-as-input was specified")
 	}
 
-	log.Info().Msgf("Extracting state from %s, exporting root checkpoint to %s, version: %v",
-		flagExecutionStateDir,
-		path.Join(flagOutputDir, bootstrap.FilenameWALRootCheckpoint),
-		6,
-	)
+	exportPayloads := len(flagOutputPayloadByAddresses) > 0
+
+	var exportedAddresses []runtimeCommon.Address
+
+	if exportPayloads {
+
+		addresses := strings.Split(flagOutputPayloadByAddresses, ",")
+
+		if len(addresses) == 1 && strings.TrimSpace(addresses[0]) == "all" {
+			// Extract payloads of the entire state.
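+			// No address filter is collected in this branch: exportedAddresses
+			// stays empty, so createPayloadFile (export_payloads.go) writes every payload.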
+ log.Info().Msgf("Extracting state from %s, exporting all payloads to %s", + flagExecutionStateDir, + path.Join(flagOutputDir, FilenamePayloads), + ) + } else { + // Extract payloads of specified accounts + for _, hexAddr := range addresses { + b, err := hex.DecodeString(strings.TrimSpace(hexAddr)) + if err != nil { + log.Fatal().Err(err).Msgf("cannot hex decode address %s for payload export", strings.TrimSpace(hexAddr)) + } + + addr, err := runtimeCommon.BytesToAddress(b) + if err != nil { + log.Fatal().Err(err).Msgf("cannot decode address %x for payload export", b) + } + + exportedAddresses = append(exportedAddresses, addr) + } + + log.Info().Msgf("Extracting state from %s, exporting payloads by addresses %v to %s", + flagExecutionStateDir, + flagOutputPayloadByAddresses, + path.Join(flagOutputDir, FilenamePayloads), + ) + } - log.Info().Msgf("Block state commitment: %s from %v, output dir: %s", - hex.EncodeToString(stateCommitment[:]), - flagExecutionStateDir, - flagOutputDir) + } else { + log.Info().Msgf("Extracting state from %s, exporting root checkpoint to %s, version: %v", + flagExecutionStateDir, + path.Join(flagOutputDir, bootstrap.FilenameWALRootCheckpoint), + 6, + ) + } + + if flagInputPayload { + log.Info().Msgf("Payload input from %v, output dir: %s", + flagExecutionStateDir, + flagOutputDir) + } else { + log.Info().Msgf("Block state commitment: %s from %v, output dir: %s", + hex.EncodeToString(stateCommitment[:]), + flagExecutionStateDir, + flagOutputDir) + } // err := ensureCheckpointFileExist(flagExecutionStateDir) // if err != nil { @@ -148,14 +211,29 @@ func run(*cobra.Command, []string) { log.Warn().Msgf("atree migration has verbose validation error logging enabled which may increase size of log") } - err := extractExecutionState( - log.Logger, - flagExecutionStateDir, - stateCommitment, - flagOutputDir, - flagNWorker, - !flagNoMigration, - ) + var err error + if flagInputPayload { + err = extractExecutionStateFromPayloads( + log.Logger, + flagExecutionStateDir, + flagOutputDir, + flagNWorker, + !flagNoMigration, + exportPayloads, + exportedAddresses, + ) + } else { + err = extractExecutionState( + log.Logger, + flagExecutionStateDir, + stateCommitment, + flagOutputDir, + flagNWorker, + !flagNoMigration, + exportPayloads, + exportedAddresses, + ) + } if err != nil { log.Fatal().Err(err).Msgf("error extracting the execution state: %s", err.Error()) diff --git a/cmd/util/cmd/execution-state-extract/execution_state_extract.go b/cmd/util/cmd/execution-state-extract/execution_state_extract.go index 90bcd70533d..dc552682886 100644 --- a/cmd/util/cmd/execution-state-extract/execution_state_extract.go +++ b/cmd/util/cmd/execution-state-extract/execution_state_extract.go @@ -5,7 +5,9 @@ import ( "fmt" "math" "os" + "time" + "github.com/onflow/cadence/runtime/common" "github.com/rs/zerolog" "go.uber.org/atomic" @@ -34,6 +36,8 @@ func extractExecutionState( outputDir string, nWorker int, // number of concurrent worker to migation payloads runMigrations bool, + exportPayloads bool, + exportPayloadsByAddresses []common.Address, ) error { log.Info().Msg("init WAL") @@ -84,30 +88,7 @@ func extractExecutionState( <-compactor.Done() }() - var migrations []ledger.Migration - - if runMigrations { - rwf := reporters.NewReportFileWriterFactory(dir, log) - - migrations = []ledger.Migration{ - migrators.CreateAccountBasedMigration( - log, - nWorker, - []migrators.AccountBasedMigration{ - migrators.NewAtreeRegisterMigrator( - rwf, - flagValidateMigration, - flagLogVerboseValidationError, - ), 
-
-					&migrators.DeduplicateContractNamesMigration{},
-
-					// This will fix storage used discrepancies caused by the
-					// DeduplicateContractNamesMigration.
-					&migrators.AccountUsageMigrator{},
-				}),
-		}
-	}
+	migrations := newMigrations(log, dir, nWorker, runMigrations)
 
 	newState := ledger.State(targetHash)
 
@@ -134,6 +115,19 @@
 		log.Error().Err(err).Msgf("can not generate report for migrated state: %v", newMigratedState)
 	}
 
+	if exportPayloads {
+		payloads := newTrie.AllPayloads()
+
+		exportedPayloadCount, err := createPayloadFile(log, outputDir, payloads, exportPayloadsByAddresses)
+		if err != nil {
+			return fmt.Errorf("cannot generate payloads file: %w", err)
+		}
+
+		log.Info().Msgf("Exported %d out of %d payloads", exportedPayloadCount, len(payloads))
+
+		return nil
+	}
+
 	migratedState, err := createCheckpoint(
 		newTrie,
 		log,
@@ -191,3 +185,160 @@ func writeStatusFile(fileName string, e error) error {
 	err := os.WriteFile(fileName, checkpointStatusJson, 0644)
 	return err
 }
+
+func extractExecutionStateFromPayloads(
+	log zerolog.Logger,
+	dir string,
+	outputDir string,
+	nWorker int, // number of concurrent workers to migrate payloads
+	runMigrations bool,
+	exportPayloads bool,
+	exportPayloadsByAddresses []common.Address,
+) error {
+
+	payloads, err := readPayloadFile(log, dir)
+	if err != nil {
+		return err
+	}
+
+	log.Info().Msgf("read %d payloads", len(payloads))
+
+	migrations := newMigrations(log, dir, nWorker, runMigrations)
+
+	payloads, err = migratePayloads(log, payloads, migrations)
+	if err != nil {
+		return err
+	}
+
+	if exportPayloads {
+		exportedPayloadCount, err := createPayloadFile(log, outputDir, payloads, exportPayloadsByAddresses)
+		if err != nil {
+			return fmt.Errorf("cannot generate payloads file: %w", err)
+		}
+
+		log.Info().Msgf("Exported %d out of %d payloads", exportedPayloadCount, len(payloads))
+
+		return nil
+	}
+
+	newTrie, err := createTrieFromPayloads(log, payloads)
+	if err != nil {
+		return err
+	}
+
+	migratedState, err := createCheckpoint(
+		newTrie,
+		log,
+		outputDir,
+		bootstrap.FilenameWALRootCheckpoint,
+	)
+	if err != nil {
+		return fmt.Errorf("cannot generate the output checkpoint: %w", err)
+	}
+
+	log.Info().Msgf(
+		"New state commitment for the exported state is: %s (base64: %s)",
+		migratedState.String(),
+		migratedState.Base64(),
+	)
+
+	return nil
+}
+
+func migratePayloads(logger zerolog.Logger, payloads []*ledger.Payload, migrations []ledger.Migration) ([]*ledger.Payload, error) {
+
+	if len(migrations) == 0 {
+		return payloads, nil
+	}
+
+	var err error
+	payloadCount := len(payloads)
+
+	// migrate payloads
+	for i, migrate := range migrations {
+		logger.Info().Msgf("migration %d/%d is underway", i+1, len(migrations))
+
+		start := time.Now()
+		payloads, err = migrate(payloads)
+		elapsed := time.Since(start)
+
+		if err != nil {
+			return nil, fmt.Errorf("error applying migration (%d): %w", i, err)
+		}
+
+		newPayloadCount := len(payloads)
+
+		if payloadCount != newPayloadCount {
+			logger.Warn().
+				Int("migration_step", i).
+				Int("expected_size", payloadCount).
+				Int("outcome_size", newPayloadCount).
+				Msg("payload count changed during migration; make sure this is expected")
+		}
+		logger.Info().Str("timeTaken", elapsed.String()).Msgf("migration %d is done", i+1)
+
+		payloadCount = newPayloadCount
+	}
+
+	return payloads, nil
+}
+
+func createTrieFromPayloads(logger zerolog.Logger, payloads []*ledger.Payload) (*trie.MTrie, error) {
+	// get paths
+	paths, err := pathfinder.PathsFromPayloads(payloads, complete.DefaultPathFinderVersion)
+	if err != nil {
+		return nil, fmt.Errorf("cannot export checkpoint, can't construct paths: %w", err)
+	}
+
+	logger.Info().Msgf("constructing a new trie with migrated payloads (count: %d)...", len(payloads))
+
+	emptyTrie := trie.NewEmptyMTrie()
+
+	derefPayloads := make([]ledger.Payload, len(payloads))
+	for i, p := range payloads {
+		derefPayloads[i] = *p
+	}
+
+	// no need to prune the data since it has already been pruned through migrations
+	applyPruning := false
+	newTrie, _, err := trie.NewTrieWithUpdatedRegisters(emptyTrie, paths, derefPayloads, applyPruning)
+	if err != nil {
+		return nil, fmt.Errorf("constructing updated trie failed: %w", err)
+	}
+
+	return newTrie, nil
+}
+
+func newMigrations(
+	log zerolog.Logger,
+	dir string,
+	nWorker int, // number of concurrent workers to migrate payloads
+	runMigrations bool,
+) []ledger.Migration {
+	if runMigrations {
+		rwf := reporters.NewReportFileWriterFactory(dir, log)
+
+		migrations := []ledger.Migration{
+			migrators.CreateAccountBasedMigration(
+				log,
+				nWorker,
+				[]migrators.AccountBasedMigration{
+					migrators.NewAtreeRegisterMigrator(
+						rwf,
+						flagValidateMigration,
+						flagLogVerboseValidationError,
+					),
+
+					&migrators.DeduplicateContractNamesMigration{},
+
+					// This will fix storage used discrepancies caused by the
+					// DeduplicateContractNamesMigration.
+ &migrators.AccountUsageMigrator{}, + }), + } + + return migrations + } + + return nil +} diff --git a/cmd/util/cmd/execution-state-extract/execution_state_extract_test.go b/cmd/util/cmd/execution-state-extract/execution_state_extract_test.go index 2f91ea7d603..39b05ad557f 100644 --- a/cmd/util/cmd/execution-state-extract/execution_state_extract_test.go +++ b/cmd/util/cmd/execution-state-extract/execution_state_extract_test.go @@ -2,13 +2,17 @@ package extract import ( "crypto/rand" + "encoding/hex" "math" + "strings" "testing" "github.com/rs/zerolog" "github.com/stretchr/testify/require" "go.uber.org/atomic" + runtimeCommon "github.com/onflow/cadence/runtime/common" + "github.com/onflow/flow-go/cmd/util/cmd/common" "github.com/onflow/flow-go/ledger" "github.com/onflow/flow-go/ledger/common/pathfinder" @@ -66,6 +70,8 @@ func TestExtractExecutionState(t *testing.T) { outdir, 10, false, + false, + nil, ) require.Error(t, err) }) @@ -96,7 +102,7 @@ func TestExtractExecutionState(t *testing.T) { var stateCommitment = f.InitialState() - //saved data after updates + // saved data after updates keysValuesByCommit := make(map[string]map[string]keyPair) commitsByBlocks := make(map[flow.Identifier]ledger.State) blocksInOrder := make([]flow.Identifier, size) @@ -108,7 +114,7 @@ func TestExtractExecutionState(t *testing.T) { require.NoError(t, err) stateCommitment, _, err = f.Set(update) - //stateCommitment, err = f.UpdateRegisters(keys, values, stateCommitment) + // stateCommitment, err = f.UpdateRegisters(keys, values, stateCommitment) require.NoError(t, err) // generate random block and map it to state commitment @@ -135,13 +141,13 @@ func TestExtractExecutionState(t *testing.T) { err = db.Close() require.NoError(t, err) - //for blockID, stateCommitment := range commitsByBlocks { + // for blockID, stateCommitment := range commitsByBlocks { for i, blockID := range blocksInOrder { stateCommitment := commitsByBlocks[blockID] - //we need fresh output dir to prevent contamination + // we need fresh output dir to prevent contamination unittest.RunWithTempDir(t, func(outdir string) { Cmd.SetArgs([]string{ @@ -182,7 +188,7 @@ func TestExtractExecutionState(t *testing.T) { require.NoError(t, err) registerValues, err := storage.Get(query) - //registerValues, err := mForest.Read([]byte(stateCommitment), keys) + // registerValues, err := mForest.Read([]byte(stateCommitment), keys) require.NoError(t, err) for i, key := range keys { @@ -190,7 +196,7 @@ func TestExtractExecutionState(t *testing.T) { require.Equal(t, data[key.String()].value, registerValue) } - //make sure blocks after this one are not in checkpoint + // make sure blocks after this one are not in checkpoint // ie - extraction stops after hitting right hash for j := i + 1; j < len(blocksInOrder); j++ { @@ -207,6 +213,312 @@ func TestExtractExecutionState(t *testing.T) { }) } +// TestExtractPayloadsFromExecutionState tests state extraction with checkpoint as input and payload as output. +func TestExtractPayloadsFromExecutionState(t *testing.T) { + + metr := &metrics.NoopCollector{} + + t.Run("all payloads", func(t *testing.T) { + withDirs(t, func(_, execdir, outdir string) { + + const ( + checkpointDistance = math.MaxInt // A large number to prevent checkpoint creation. 
+ checkpointsToKeep = 1 + ) + + size := 10 + + diskWal, err := wal.NewDiskWAL(zerolog.Nop(), nil, metrics.NewNoopCollector(), execdir, size, pathfinder.PathByteSize, wal.SegmentSize) + require.NoError(t, err) + f, err := complete.NewLedger(diskWal, size*10, metr, zerolog.Nop(), complete.DefaultPathFinderVersion) + require.NoError(t, err) + compactor, err := complete.NewCompactor(f, diskWal, zerolog.Nop(), uint(size), checkpointDistance, checkpointsToKeep, atomic.NewBool(false)) + require.NoError(t, err) + <-compactor.Ready() + + var stateCommitment = f.InitialState() + + // Save generated data after updates + keysValues := make(map[string]keyPair) + + for i := 0; i < size; i++ { + keys, values := getSampleKeyValues(i) + + update, err := ledger.NewUpdate(stateCommitment, keys, values) + require.NoError(t, err) + + stateCommitment, _, err = f.Set(update) + require.NoError(t, err) + + for j, key := range keys { + keysValues[key.String()] = keyPair{ + key: key, + value: values[j], + } + } + } + + <-f.Done() + <-compactor.Done() + + tries, err := f.Tries() + require.NoError(t, err) + + err = wal.StoreCheckpointV6SingleThread(tries, execdir, "checkpoint.00000001", zerolog.Nop()) + require.NoError(t, err) + + // Export all payloads + Cmd.SetArgs([]string{ + "--execution-state-dir", execdir, + "--output-dir", outdir, + "--state-commitment", hex.EncodeToString(stateCommitment[:]), + "--no-migration", + "--no-report", + "--extract-payloads-by-address", "all", + "--chain", flow.Emulator.Chain().String()}) + + err = Cmd.Execute() + require.NoError(t, err) + + // Verify exported payloads. + payloadsFromFile, err := readPayloadFile(zerolog.Nop(), outdir) + require.NoError(t, err) + require.Equal(t, len(keysValues), len(payloadsFromFile)) + + for _, payloadFromFile := range payloadsFromFile { + k, err := payloadFromFile.Key() + require.NoError(t, err) + + kv, exist := keysValues[k.String()] + require.True(t, exist) + require.Equal(t, kv.value, payloadFromFile.Value()) + } + }) + }) + + t.Run("some payloads", func(t *testing.T) { + withDirs(t, func(_, execdir, outdir string) { + const ( + checkpointDistance = math.MaxInt // A large number to prevent checkpoint creation. 
+ checkpointsToKeep = 1 + ) + + size := 10 + + diskWal, err := wal.NewDiskWAL(zerolog.Nop(), nil, metrics.NewNoopCollector(), execdir, size, pathfinder.PathByteSize, wal.SegmentSize) + require.NoError(t, err) + f, err := complete.NewLedger(diskWal, size*10, metr, zerolog.Nop(), complete.DefaultPathFinderVersion) + require.NoError(t, err) + compactor, err := complete.NewCompactor(f, diskWal, zerolog.Nop(), uint(size), checkpointDistance, checkpointsToKeep, atomic.NewBool(false)) + require.NoError(t, err) + <-compactor.Ready() + + var stateCommitment = f.InitialState() + + // Save generated data after updates + keysValues := make(map[string]keyPair) + + for i := 0; i < size; i++ { + keys, values := getSampleKeyValues(i) + + update, err := ledger.NewUpdate(stateCommitment, keys, values) + require.NoError(t, err) + + stateCommitment, _, err = f.Set(update) + require.NoError(t, err) + + for j, key := range keys { + keysValues[key.String()] = keyPair{ + key: key, + value: values[j], + } + } + } + + <-f.Done() + <-compactor.Done() + + tries, err := f.Tries() + require.NoError(t, err) + + err = wal.StoreCheckpointV6SingleThread(tries, execdir, "checkpoint.00000001", zerolog.Nop()) + require.NoError(t, err) + + const selectedAddressCount = 10 + selectedAddresses := make(map[string]struct{}) + selectedKeysValues := make(map[string]keyPair) + for k, kv := range keysValues { + owner := kv.key.KeyParts[0].Value + if len(owner) != runtimeCommon.AddressLength { + continue + } + + address, err := runtimeCommon.BytesToAddress(owner) + require.NoError(t, err) + + if len(selectedAddresses) < selectedAddressCount { + selectedAddresses[address.Hex()] = struct{}{} + } + + if _, exist := selectedAddresses[address.Hex()]; exist { + selectedKeysValues[k] = kv + } + } + + addresses := make([]string, 0, len(selectedAddresses)) + for address := range selectedAddresses { + addresses = append(addresses, address) + } + + // Export selected payloads + Cmd.SetArgs([]string{ + "--execution-state-dir", execdir, + "--output-dir", outdir, + "--state-commitment", hex.EncodeToString(stateCommitment[:]), + "--no-migration", + "--no-report", + "--extract-payloads-by-address", strings.Join(addresses, ","), + "--chain", flow.Emulator.Chain().String()}) + + err = Cmd.Execute() + require.NoError(t, err) + + // Verify exported payloads. + payloadsFromFile, err := readPayloadFile(zerolog.Nop(), outdir) + require.NoError(t, err) + require.Equal(t, len(selectedKeysValues), len(payloadsFromFile)) + + for _, payloadFromFile := range payloadsFromFile { + k, err := payloadFromFile.Key() + require.NoError(t, err) + + kv, exist := selectedKeysValues[k.String()] + require.True(t, exist) + require.Equal(t, kv.value, payloadFromFile.Value()) + } + }) + }) +} + +// TestExtractStateFromPayloads tests state extraction with payload as input. 
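+// It covers both output modes: writing a root checkpoint ("create checkpoint")
+// and writing a payload file ("create payloads").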
+func TestExtractStateFromPayloads(t *testing.T) { + + t.Run("create checkpoint", func(t *testing.T) { + withDirs(t, func(_, execdir, outdir string) { + size := 10 + + // Generate some data + keysValues := make(map[string]keyPair) + var payloads []*ledger.Payload + + for i := 0; i < size; i++ { + keys, values := getSampleKeyValues(i) + + for j, key := range keys { + keysValues[key.String()] = keyPair{ + key: key, + value: values[j], + } + + payloads = append(payloads, ledger.NewPayload(key, values[j])) + } + } + + numOfPayloadWritten, err := createPayloadFile(zerolog.Nop(), execdir, payloads, nil) + require.NoError(t, err) + require.Equal(t, len(payloads), numOfPayloadWritten) + + // Export checkpoint file + Cmd.SetArgs([]string{ + "--execution-state-dir", execdir, + "--output-dir", outdir, + "--no-migration", + "--no-report", + "--use-payload-as-input", + "--extract-payloads-by-address", "", + "--chain", flow.Emulator.Chain().String()}) + + err = Cmd.Execute() + require.NoError(t, err) + + tries, err := wal.OpenAndReadCheckpointV6(outdir, "root.checkpoint", zerolog.Nop()) + require.NoError(t, err) + require.Equal(t, 1, len(tries)) + + // Verify exported checkpoint + payloadsFromFile := tries[0].AllPayloads() + require.NoError(t, err) + require.Equal(t, len(keysValues), len(payloadsFromFile)) + + for _, payloadFromFile := range payloadsFromFile { + k, err := payloadFromFile.Key() + require.NoError(t, err) + + kv, exist := keysValues[k.String()] + require.True(t, exist) + + require.Equal(t, kv.value, payloadFromFile.Value()) + } + }) + + }) + + t.Run("create payloads", func(t *testing.T) { + withDirs(t, func(_, execdir, outdir string) { + size := 10 + + // Generate some data + keysValues := make(map[string]keyPair) + var payloads []*ledger.Payload + + for i := 0; i < size; i++ { + keys, values := getSampleKeyValues(i) + + for j, key := range keys { + keysValues[key.String()] = keyPair{ + key: key, + value: values[j], + } + + payloads = append(payloads, ledger.NewPayload(key, values[j])) + } + } + + numOfPayloadWritten, err := createPayloadFile(zerolog.Nop(), execdir, payloads, nil) + require.NoError(t, err) + require.Equal(t, len(payloads), numOfPayloadWritten) + + // Export all payloads + Cmd.SetArgs([]string{ + "--execution-state-dir", execdir, + "--output-dir", outdir, + "--no-migration", + "--no-report", + "--use-payload-as-input", + "--extract-payloads-by-address", "all", + "--chain", flow.Emulator.Chain().String()}) + + err = Cmd.Execute() + require.NoError(t, err) + + // Verify exported payloads. 
+			payloadsFromFile, err := readPayloadFile(zerolog.Nop(), outdir)
+			require.NoError(t, err)
+			require.Equal(t, len(keysValues), len(payloadsFromFile))
+
+			for _, payloadFromFile := range payloadsFromFile {
+				k, err := payloadFromFile.Key()
+				require.NoError(t, err)
+
+				kv, exist := keysValues[k.String()]
+				require.True(t, exist)
+
+				require.Equal(t, kv.value, payloadFromFile.Value())
+			}
+		})
+	})
+}
+
 func getSampleKeyValues(i int) ([]ledger.Key, []ledger.Value) {
 	switch i {
 	case 0:
@@ -226,7 +538,8 @@ func getSampleKeyValues(i int) ([]ledger.Key, []ledger.Value) {
 	keys := make([]ledger.Key, 0)
 	values := make([]ledger.Value, 0)
 	for j := 0; j < 10; j++ {
-		address := make([]byte, 32)
+		// Use 8-byte owners to match runtimeCommon.AddressLength in payload export tests.
+		address := make([]byte, 8)
 		_, err := rand.Read(address)
 		if err != nil {
 			panic(err)
diff --git a/cmd/util/cmd/execution-state-extract/export_payloads.go b/cmd/util/cmd/execution-state-extract/export_payloads.go
new file mode 100644
index 00000000000..68325dac3de
--- /dev/null
+++ b/cmd/util/cmd/execution-state-extract/export_payloads.go
@@ -0,0 +1,200 @@
+package extract
+
+import (
+	"bufio"
+	"bytes"
+	"fmt"
+	"io"
+	"os"
+	"path/filepath"
+
+	"github.com/fxamacker/cbor/v2"
+	"github.com/rs/zerolog"
+
+	"github.com/onflow/cadence/runtime/common"
+
+	"github.com/onflow/flow-go/ledger"
+)
+
+const (
+	FilenamePayloads = "root.payloads"
+
+	defaultBufioWriteSize = 1024 * 32
+	defaultBufioReadSize  = 1024 * 32
+
+	payloadEncodingVersion = 1
+)
+
+func createPayloadFile(
+	logger zerolog.Logger,
+	outputDir string,
+	payloads []*ledger.Payload,
+	addresses []common.Address,
+) (int, error) {
+	payloadFile := filepath.Join(outputDir, FilenamePayloads)
+
+	f, err := os.Create(payloadFile)
+	if err != nil {
+		return 0, fmt.Errorf("can't create %s: %w", payloadFile, err)
+	}
+	defer f.Close()
+
+	writer := bufio.NewWriterSize(f, defaultBufioWriteSize)
+	defer writer.Flush()
+
+	includeAllPayloads := len(addresses) == 0
+
+	if includeAllPayloads {
+		return writeAllPayloads(logger, writer, payloads)
+	}
+
+	return writeSelectedPayloads(logger, writer, payloads, addresses)
+}
+
+func writeAllPayloads(logger zerolog.Logger, w io.Writer, payloads []*ledger.Payload) (int, error) {
+	logger.Info().Msgf("writing %d payloads to file", len(payloads))
+
+	enc := cbor.NewEncoder(w)
+
+	// Encode number of payloads
+	err := enc.Encode(len(payloads))
+	if err != nil {
+		return 0, fmt.Errorf("failed to encode number of payloads %d in CBOR: %w", len(payloads), err)
+	}
+
+	var payloadScratchBuffer [1024 * 2]byte
+	for _, p := range payloads {
+
+		buf := ledger.EncodeAndAppendPayloadWithoutPrefix(payloadScratchBuffer[:0], p, payloadEncodingVersion)
+
+		// Encode payload
+		err = enc.Encode(buf)
+		if err != nil {
+			return 0, err
+		}
+	}
+
+	return len(payloads), nil
+}
+
+func writeSelectedPayloads(logger zerolog.Logger, w io.Writer, payloads []*ledger.Payload, addresses []common.Address) (int, error) {
+	var includedPayloadCount int
+
+	includedFlags := make([]bool, len(payloads))
+	for i, p := range payloads {
+		include, err := includePayloadByAddresses(p, addresses)
+		if err != nil {
+			return 0, err
+		}
+
+		includedFlags[i] = include
+
+		if include {
+			includedPayloadCount++
+		}
+	}
+
+	logger.Info().Msgf("writing %d payloads to file", includedPayloadCount)
+
+	enc := cbor.NewEncoder(w)
+
+	// Encode number of payloads
+	err := enc.Encode(includedPayloadCount)
+	if err != nil {
+		return 0, fmt.Errorf("failed to encode number of payloads %d in CBOR: %w", includedPayloadCount, err)
fmt.Errorf("failed to encode number of payloads %d in CBOR: %w", includedPayloadCount, err) + } + + var payloadScratchBuffer [1024 * 2]byte + for i, included := range includedFlags { + if !included { + continue + } + + p := payloads[i] + + buf := ledger.EncodeAndAppendPayloadWithoutPrefix(payloadScratchBuffer[:0], p, payloadEncodingVersion) + + // Encode payload + err = enc.Encode(buf) + if err != nil { + return 0, err + } + } + + return includedPayloadCount, nil +} + +func includePayloadByAddresses(payload *ledger.Payload, addresses []common.Address) (bool, error) { + if len(addresses) == 0 { + // Include all payloads + return true, nil + } + + for _, address := range addresses { + k, err := payload.Key() + if err != nil { + return false, fmt.Errorf("failed to get key from payload: %w", err) + } + + owner := k.KeyParts[0].Value + if bytes.Equal(owner, address[:]) { + return true, nil + } + } + + return false, nil +} + +func readPayloadFile(logger zerolog.Logger, inputDir string) ([]*ledger.Payload, error) { + payloadFile := filepath.Join(inputDir, FilenamePayloads) + + if _, err := os.Stat(payloadFile); os.IsNotExist(err) { + return nil, fmt.Errorf("%s doesn't exist", payloadFile) + } + + f, err := os.Open(payloadFile) + if err != nil { + return nil, fmt.Errorf("failed to open %s: %w", payloadFile, err) + } + defer f.Close() + + r := bufio.NewReaderSize(f, defaultBufioReadSize) + if err != nil { + return nil, fmt.Errorf("failed to create bufio reader for %s: %w", payloadFile, err) + } + + dec := cbor.NewDecoder(r) + + // Decode number of payloads + var payloadCount int + err = dec.Decode(&payloadCount) + if err != nil { + return nil, fmt.Errorf("failed to decode number of payload in CBOR: %w", err) + } + + logger.Info().Msgf("reading %d payloads from file", payloadCount) + + payloads := make([]*ledger.Payload, 0, payloadCount) + + for { + var rawPayload []byte + err := dec.Decode(&rawPayload) + if err == io.EOF { + break + } + if err != nil { + return nil, fmt.Errorf("failed to decode payload in CBOR: %w", err) + } + + payload, err := ledger.DecodePayloadWithoutPrefix(rawPayload, false, payloadEncodingVersion) + if err != nil { + return nil, fmt.Errorf("failed to decode payload 0x%x: %w", rawPayload, err) + } + + payloads = append(payloads, payload) + } + + if payloadCount != len(payloads) { + return nil, fmt.Errorf("failed to decode %s: expect %d payloads, got %d payloads", payloadFile, payloadCount, len(payloads)) + } + + return payloads, nil +}