From 4dacdd8c100659b3522f55f640a564294198039b Mon Sep 17 00:00:00 2001 From: Jacob Gadikian Date: Tue, 24 Dec 2024 04:04:29 +0700 Subject: [PATCH] all signature mismatches resolved --- internal/runtime/hostfunctions.go | 139 +++++++++++++++++------- internal/runtime/wazeroruntime.go | 175 +++++++----------------------- 2 files changed, 137 insertions(+), 177 deletions(-) diff --git a/internal/runtime/hostfunctions.go b/internal/runtime/hostfunctions.go index 0693174d..4780a035 100644 --- a/internal/runtime/hostfunctions.go +++ b/internal/runtime/hostfunctions.go @@ -878,67 +878,72 @@ func RegisterHostFunctions(runtime wazero.Runtime, env *RuntimeEnvironment) (waz // Register BLS12-381 functions builder.NewFunctionBuilder(). - WithFunc(func(ctx context.Context, m api.Module, elementsPtr uint32) (uint32, uint32) { + WithFunc(func(ctx context.Context, m api.Module, g1sPtr, outPtr uint32) uint32 { ctx = context.WithValue(ctx, envKey, env) - return hostBls12381AggregateG1(ctx, m, elementsPtr) + ptr, _ := hostBls12381AggregateG1(ctx, m, g1sPtr) + return ptr }). - WithParameterNames("elements_ptr"). - WithResultNames("result_ptr", "result_len"). + WithParameterNames("g1s_ptr", "out_ptr"). + WithResultNames("result"). Export("bls12_381_aggregate_g1") builder.NewFunctionBuilder(). - WithFunc(func(ctx context.Context, m api.Module, elementsPtr uint32) (uint32, uint32) { + WithFunc(func(ctx context.Context, m api.Module, g2sPtr, outPtr uint32) uint32 { ctx = context.WithValue(ctx, envKey, env) - return hostBls12381AggregateG2(ctx, m, elementsPtr) + ptr, _ := hostBls12381AggregateG2(ctx, m, g2sPtr) + return ptr }). - WithParameterNames("elements_ptr"). - WithResultNames("result_ptr", "result_len"). + WithParameterNames("g2s_ptr", "out_ptr"). + WithResultNames("result"). Export("bls12_381_aggregate_g2") builder.NewFunctionBuilder(). - WithFunc(func(ctx context.Context, m api.Module, hashPtr, hashLen uint32) (uint32, uint32) { + WithFunc(func(ctx context.Context, m api.Module, psPtr, qsPtr, rPtr, sPtr uint32) uint32 { ctx = context.WithValue(ctx, envKey, env) - return hostBls12381HashToG1(ctx, m, hashPtr, hashLen) + return hostBls12381PairingEquality(ctx, m, psPtr, 0, qsPtr, 0, rPtr, 0, sPtr, 0) }). - WithParameterNames("hash_ptr", "hash_len"). - WithResultNames("result_ptr", "result_len"). - Export("bls12_381_hash_to_g1") + WithParameterNames("ps_ptr", "qs_ptr", "r_ptr", "s_ptr"). + WithResultNames("result"). + Export("bls12_381_pairing_equality") builder.NewFunctionBuilder(). - WithFunc(func(ctx context.Context, m api.Module, hashPtr, hashLen uint32) (uint32, uint32) { + WithFunc(func(ctx context.Context, m api.Module, hashFunction, msgPtr, dstPtr, outPtr uint32) uint32 { ctx = context.WithValue(ctx, envKey, env) - return hostBls12381HashToG2(ctx, m, hashPtr, hashLen) + ptr, _ := hostBls12381HashToG1(ctx, m, msgPtr, hashFunction) + return ptr }). - WithParameterNames("hash_ptr", "hash_len"). - WithResultNames("result_ptr", "result_len"). - Export("bls12_381_hash_to_g2") + WithParameterNames("hash_function", "msg_ptr", "dst_ptr", "out_ptr"). + WithResultNames("result"). + Export("bls12_381_hash_to_g1") builder.NewFunctionBuilder(). - WithFunc(func(ctx context.Context, m api.Module, a1Ptr, a1Len, a2Ptr, a2Len, b1Ptr, b1Len, b2Ptr, b2Len uint32) uint32 { + WithFunc(func(ctx context.Context, m api.Module, hashFunction, msgPtr, dstPtr, outPtr uint32) uint32 { ctx = context.WithValue(ctx, envKey, env) - return hostBls12381PairingEquality(ctx, m, a1Ptr, a1Len, a2Ptr, a2Len, b1Ptr, b1Len, b2Ptr, b2Len) + ptr, _ := hostBls12381HashToG2(ctx, m, msgPtr, hashFunction) + return ptr }). - WithParameterNames("a1_ptr", "a1_len", "a2_ptr", "a2_len", "b1_ptr", "b1_len", "b2_ptr", "b2_len"). + WithParameterNames("hash_function", "msg_ptr", "dst_ptr", "out_ptr"). WithResultNames("result"). - Export("bls12_381_pairing_equality") + Export("bls12_381_hash_to_g2") - // Register SECP256r1 functions + // SECP256r1 functions builder.NewFunctionBuilder(). - WithFunc(func(ctx context.Context, m api.Module, hashPtr, hashLen, sigPtr, sigLen, pubkeyPtr, pubkeyLen uint32) uint32 { + WithFunc(func(ctx context.Context, m api.Module, messageHashPtr, signaturePtr, publicKeyPtr uint32) uint32 { ctx = context.WithValue(ctx, envKey, env) - return hostSecp256r1Verify(ctx, m, hashPtr, hashLen, sigPtr, sigLen, pubkeyPtr, pubkeyLen) + return hostSecp256r1Verify(ctx, m, messageHashPtr, 0, signaturePtr, 0, publicKeyPtr, 0) }). - WithParameterNames("hash_ptr", "hash_len", "sig_ptr", "sig_len", "pubkey_ptr", "pubkey_len"). + WithParameterNames("message_hash_ptr", "signature_ptr", "public_key_ptr"). WithResultNames("result"). Export("secp256r1_verify") builder.NewFunctionBuilder(). - WithFunc(func(ctx context.Context, m api.Module, hashPtr, hashLen, sigPtr, sigLen, recovery uint32) (uint32, uint32) { + WithFunc(func(ctx context.Context, m api.Module, messageHashPtr, signaturePtr, recoveryParam uint32) uint64 { ctx = context.WithValue(ctx, envKey, env) - return hostSecp256r1RecoverPubkey(ctx, m, hashPtr, hashLen, sigPtr, sigLen, recovery) + ptr, len := hostSecp256r1RecoverPubkey(ctx, m, messageHashPtr, 0, signaturePtr, 0, recoveryParam) + return (uint64(len) << 32) | uint64(ptr) }). - WithParameterNames("hash_ptr", "hash_len", "sig_ptr", "sig_len", "recovery"). - WithResultNames("pubkey_ptr", "pubkey_len"). + WithParameterNames("message_hash_ptr", "signature_ptr", "recovery_param"). + WithResultNames("result"). Export("secp256r1_recover_pubkey") builder.NewFunctionBuilder(). @@ -950,6 +955,7 @@ func RegisterHostFunctions(runtime wazero.Runtime, env *RuntimeEnvironment) (waz WithResultNames("iter_id"). Export("db_scan") + // db_next builder.NewFunctionBuilder(). WithFunc(func(ctx context.Context, m api.Module, iterID uint32) uint32 { ctx = context.WithValue(ctx, envKey, env) @@ -959,6 +965,17 @@ func RegisterHostFunctions(runtime wazero.Runtime, env *RuntimeEnvironment) (waz WithResultNames("kv_region_ptr"). Export("db_next") + // db_next_value + builder.NewFunctionBuilder(). + WithFunc(func(ctx context.Context, m api.Module, iterID uint32) uint32 { + ctx = context.WithValue(ctx, envKey, env) + ptr, _, _ := hostNextValue(ctx, m, uint64(iterID), 0) + return ptr + }). + WithParameterNames("iter_id"). + WithResultNames("value_ptr"). + Export("db_next_value") + builder.NewFunctionBuilder(). WithFunc(func(ctx context.Context, m api.Module, addrPtr, addrLen uint32) uint32 { ctx = context.WithValue(ctx, envKey, env) @@ -1030,16 +1047,6 @@ func RegisterHostFunctions(runtime wazero.Runtime, env *RuntimeEnvironment) (waz WithParameterNames("key_ptr"). Export("db_remove") - // db_next_value - builder.NewFunctionBuilder(). - WithFunc(func(ctx context.Context, m api.Module, callID, iterID uint64) (uint32, uint32, uint32) { - ctx = context.WithValue(ctx, envKey, env) - return hostNextValue(ctx, m, callID, iterID) - }). - WithParameterNames("call_id", "iter_id"). - WithResultNames("val_ptr", "val_len", "err_code"). - Export("db_next_value") - // db_close_iterator builder.NewFunctionBuilder(). WithFunc(func(ctx context.Context, m api.Module, callID, iterID uint64) { @@ -1088,6 +1095,17 @@ func RegisterHostFunctions(runtime wazero.Runtime, env *RuntimeEnvironment) (waz WithParameterNames("msg_ptr"). Export("debug") + // db_next_key + builder.NewFunctionBuilder(). + WithFunc(func(ctx context.Context, m api.Module, iterID uint32) uint32 { + ctx = context.WithValue(ctx, envKey, env) + ptr, _, _ := hostNextKey(ctx, m, uint64(iterID), 0) + return ptr + }). + WithParameterNames("iter_id"). + WithResultNames("key_ptr"). + Export("db_next_key") + return builder.Compile(context.Background()) } @@ -1110,3 +1128,44 @@ type contextKey string const ( envKey contextKey = "env" ) + +// hostNextKey implements db_next_key +func hostNextKey(ctx context.Context, mod api.Module, callID, iterID uint64) (keyPtr, keyLen, errCode uint32) { + env := ctx.Value("env").(*RuntimeEnvironment) + mem := mod.Memory() + + // Check gas for iterator next operation + if env.GasUsed+gasCostIteratorNext > env.Gas.GasConsumed() { + return 0, 0, 1 // Return error code 1 for out of gas + } + env.GasUsed += gasCostIteratorNext + + // Get iterator from environment + iter := env.GetIterator(callID, iterID) + if iter == nil { + return 0, 0, 2 // Return error code 2 for invalid iterator + } + + // Check if there are more items + if !iter.Valid() { + return 0, 0, 0 // Return 0 for end of iteration + } + + // Read key + key := iter.Key() + + // Allocate memory for key + keyOffset, err := allocateInContract(ctx, mod, uint32(len(key))) + if err != nil { + panic(fmt.Sprintf("failed to allocate memory for key (via contract's allocate): %v", err)) + } + + if err := writeMemory(mem, keyOffset, key); err != nil { + panic(fmt.Sprintf("failed to write key to memory: %v", err)) + } + + // Move to next item + iter.Next() + + return keyOffset, uint32(len(key)), 0 +} diff --git a/internal/runtime/wazeroruntime.go b/internal/runtime/wazeroruntime.go index d6269963..0caa7b14 100644 --- a/internal/runtime/wazeroruntime.go +++ b/internal/runtime/wazeroruntime.go @@ -8,10 +8,8 @@ import ( "encoding/json" "errors" "fmt" - "strconv" "strings" "sync" - "time" "github.com/CosmWasm/wasmvm/v2/types" "github.com/tetratelabs/wazero" @@ -1168,187 +1166,90 @@ func (w *WazeroRuntime) GetPinnedMetrics() (*types.PinnedMetrics, error) { // serializeEnvForContract serializes the environment based on the contract version func serializeEnvForContract(env []byte, checksum []byte, w *WazeroRuntime) ([]byte, error) { - // First try to unmarshal into a strongly typed Env + // We'll try to parse it as a strongly typed Env first. + // If that works, we adapt to a 1.0+ shape (numeric block height/time). var typedEnv types.Env if err := json.Unmarshal(env, &typedEnv); err == nil { - // Convert to a map for easier manipulation + // Convert to raw map so we can adjust fields. var rawEnv map[string]interface{} if err := json.Unmarshal(env, &rawEnv); err != nil { - return nil, fmt.Errorf("failed to unmarshal env: %w", err) + return nil, fmt.Errorf("failed to unmarshal env into raw map: %w", err) } - // Handle block info time conversion + // If there's a "block" section, set `height` and `time` as numeric if block, ok := rawEnv["block"].(map[string]interface{}); ok { - // For strongly typed Env, we know this is a Uint64 - block["time"] = strconv.FormatUint(uint64(typedEnv.Block.Time), 10) - block["height"] = strconv.FormatUint(typedEnv.Block.Height, 10) + block["height"] = typedEnv.Block.Height // store as integer + block["time"] = typedEnv.Block.Time // store as integer + // chain_id presumably remains a string } - // Get contract version info for version-specific adaptations - report, err := w.AnalyzeCode(checksum) - if err != nil { - return nil, fmt.Errorf("failed to analyze code: %w", err) - } - - // Apply version-specific adaptations - if report.ContractMigrateVersion != nil && *report.ContractMigrateVersion < 1 { - delete(rawEnv, "transaction") - } else if txn, ok := rawEnv["transaction"].(map[string]interface{}); ok { - // Ensure transaction index is a string - if idx, ok := txn["index"]; ok { - txn["index"] = strconv.FormatUint(uint64(idx.(float64)), 10) + // If there's a "transaction" section, set the transaction index as numeric (if present) + if txn, ok := rawEnv["transaction"].(map[string]interface{}); ok { + if typedEnv.Transaction != nil { + txn["index"] = typedEnv.Transaction.Index } } - // Re-serialize with appropriate version adaptations + // Re-serialize with the numeric fields adaptedEnv, err := json.Marshal(rawEnv) if err != nil { return nil, fmt.Errorf("failed to marshal adapted env: %w", err) } - return adaptedEnv, nil } - // If we couldn't unmarshal into a strongly typed Env, try as a raw map + // If we *couldn’t* parse typedEnv, then we just handle it as raw JSON + // but still enforce numeric shape for block.height/time and transaction.index + // by heuristics. Example below: var rawEnv map[string]interface{} d := json.NewDecoder(strings.NewReader(string(env))) - d.UseNumber() // Use json.Number to preserve numeric precision + d.UseNumber() // preserve numeric precision if err := d.Decode(&rawEnv); err != nil { return nil, fmt.Errorf("failed to unmarshal env: %w", err) } - // Handle block info time conversion + // Try to enforce numeric `height` and `time` in "block" if block, ok := rawEnv["block"].(map[string]interface{}); ok { - if timeVal, ok := block["time"]; ok { - var timeUint uint64 - switch v := timeVal.(type) { - case json.Number: - if i, err := v.Int64(); err == nil { - timeUint = uint64(i) - } else { - return nil, fmt.Errorf("failed to parse time as int64: %w", err) - } - case string: - // Try parsing as nanosecond timestamp first - if i, err := strconv.ParseUint(v, 10, 64); err == nil { - timeUint = i + // parse block.height + if hval, ok := block["height"]; ok { + if num, ok := hval.(json.Number); ok { + if i64, err := num.Int64(); err == nil { + block["height"] = i64 // store as integer } else { - // Try parsing as RFC3339 - t, err := time.Parse(time.RFC3339, v) - if err != nil { - return nil, fmt.Errorf("failed to parse time: %w", err) - } - timeUint = uint64(t.UnixNano()) + return nil, fmt.Errorf("unable to parse block.height as integer: %w", err) } - case float64: - timeUint = uint64(v) - case uint64: - timeUint = v - case types.Uint64: - timeUint = uint64(v) - case map[string]interface{}: - // Handle the case where time is serialized as a JSON object with a uint64 field - if u, ok := v["uint64"].(float64); ok { - timeUint = uint64(u) - } else { - return nil, fmt.Errorf("invalid uint64 object format: %v", v) - } - default: - return nil, fmt.Errorf("unexpected time format: %T", v) } - block["time"] = strconv.FormatUint(timeUint, 10) } - - // Ensure height is a string - if heightVal, ok := block["height"]; ok { - var height uint64 - switch v := heightVal.(type) { - case json.Number: - if i, err := v.Int64(); err == nil { - height = uint64(i) - } else { - return nil, fmt.Errorf("failed to parse height as int64: %w", err) - } - case float64: - height = uint64(v) - case uint64: - height = v - case string: - if i, err := strconv.ParseUint(v, 10, 64); err == nil { - height = i + // parse block.time + if tval, ok := block["time"]; ok { + if num, ok := tval.(json.Number); ok { + if i64, err := num.Int64(); err == nil { + block["time"] = i64 // store as integer } else { - return nil, fmt.Errorf("failed to parse height: %w", err) + return nil, fmt.Errorf("unable to parse block.time as integer: %w", err) } - default: - return nil, fmt.Errorf("unexpected height format: %T", v) - } - block["height"] = strconv.FormatUint(height, 10) - } - - // Ensure chain_id is a string - if chainID, ok := block["chain_id"]; ok { - if str, ok := chainID.(string); ok { - block["chain_id"] = str - } else { - return nil, fmt.Errorf("chain_id must be a string, got %T", chainID) } } } - // Ensure contract address is a string - if contract, ok := rawEnv["contract"].(map[string]interface{}); ok { - if addr, ok := contract["address"]; ok { - if str, ok := addr.(string); ok { - contract["address"] = str - } else { - return nil, fmt.Errorf("contract address must be a string, got %T", addr) - } - } - } - - // Get contract version info for version-specific adaptations - report, err := w.AnalyzeCode(checksum) - if err != nil { - return nil, fmt.Errorf("failed to analyze code: %w", err) - } - - // Apply version-specific adaptations - if report.ContractMigrateVersion != nil && *report.ContractMigrateVersion < 1 { - delete(rawEnv, "transaction") - } else if txn, ok := rawEnv["transaction"].(map[string]interface{}); ok { - // Ensure transaction index is a string - if idx, ok := txn["index"]; ok { - var index uint32 - switch v := idx.(type) { - case json.Number: - if i, err := v.Int64(); err == nil { - index = uint32(i) + // parse transaction.index as numeric if present + if txn, ok := rawEnv["transaction"].(map[string]interface{}); ok { + if idxVal, ok := txn["index"]; ok { + if num, ok := idxVal.(json.Number); ok { + if i64, err := num.Int64(); err == nil { + txn["index"] = i64 // store as integer } else { - return nil, fmt.Errorf("failed to parse transaction index as int64: %w", err) + return nil, fmt.Errorf("unable to parse transaction.index: %w", err) } - case float64: - index = uint32(v) - case uint32: - index = v - case string: - if i, err := strconv.ParseUint(v, 10, 32); err == nil { - index = uint32(i) - } else { - return nil, fmt.Errorf("failed to parse transaction index: %w", err) - } - default: - return nil, fmt.Errorf("unexpected transaction index format: %T", v) } - txn["index"] = strconv.FormatUint(uint64(index), 10) } } - // Re-serialize with appropriate version adaptations + // Now re-serialize adaptedEnv, err := json.Marshal(rawEnv) if err != nil { return nil, fmt.Errorf("failed to marshal adapted env: %w", err) } - return adaptedEnv, nil }