Skip to content

Commit

Permalink
all signature mismatches resolved
Browse files Browse the repository at this point in the history
  • Loading branch information
faddat committed Dec 23, 2024
1 parent 1e0588a commit 4dacdd8
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 177 deletions.
139 changes: 99 additions & 40 deletions internal/runtime/hostfunctions.go
Original file line number Diff line number Diff line change
Expand Up @@ -878,67 +878,72 @@ func RegisterHostFunctions(runtime wazero.Runtime, env *RuntimeEnvironment) (waz

// Register BLS12-381 functions
builder.NewFunctionBuilder().
WithFunc(func(ctx context.Context, m api.Module, elementsPtr uint32) (uint32, uint32) {
WithFunc(func(ctx context.Context, m api.Module, g1sPtr, outPtr uint32) uint32 {
ctx = context.WithValue(ctx, envKey, env)
return hostBls12381AggregateG1(ctx, m, elementsPtr)
ptr, _ := hostBls12381AggregateG1(ctx, m, g1sPtr)
return ptr
}).
WithParameterNames("elements_ptr").
WithResultNames("result_ptr", "result_len").
WithParameterNames("g1s_ptr", "out_ptr").
WithResultNames("result").
Export("bls12_381_aggregate_g1")

builder.NewFunctionBuilder().
WithFunc(func(ctx context.Context, m api.Module, elementsPtr uint32) (uint32, uint32) {
WithFunc(func(ctx context.Context, m api.Module, g2sPtr, outPtr uint32) uint32 {
ctx = context.WithValue(ctx, envKey, env)
return hostBls12381AggregateG2(ctx, m, elementsPtr)
ptr, _ := hostBls12381AggregateG2(ctx, m, g2sPtr)
return ptr
}).
WithParameterNames("elements_ptr").
WithResultNames("result_ptr", "result_len").
WithParameterNames("g2s_ptr", "out_ptr").
WithResultNames("result").
Export("bls12_381_aggregate_g2")

builder.NewFunctionBuilder().
WithFunc(func(ctx context.Context, m api.Module, hashPtr, hashLen uint32) (uint32, uint32) {
WithFunc(func(ctx context.Context, m api.Module, psPtr, qsPtr, rPtr, sPtr uint32) uint32 {
ctx = context.WithValue(ctx, envKey, env)
return hostBls12381HashToG1(ctx, m, hashPtr, hashLen)
return hostBls12381PairingEquality(ctx, m, psPtr, 0, qsPtr, 0, rPtr, 0, sPtr, 0)
}).
WithParameterNames("hash_ptr", "hash_len").
WithResultNames("result_ptr", "result_len").
Export("bls12_381_hash_to_g1")
WithParameterNames("ps_ptr", "qs_ptr", "r_ptr", "s_ptr").
WithResultNames("result").
Export("bls12_381_pairing_equality")

builder.NewFunctionBuilder().
WithFunc(func(ctx context.Context, m api.Module, hashPtr, hashLen uint32) (uint32, uint32) {
WithFunc(func(ctx context.Context, m api.Module, hashFunction, msgPtr, dstPtr, outPtr uint32) uint32 {
ctx = context.WithValue(ctx, envKey, env)
return hostBls12381HashToG2(ctx, m, hashPtr, hashLen)
ptr, _ := hostBls12381HashToG1(ctx, m, msgPtr, hashFunction)
return ptr
}).
WithParameterNames("hash_ptr", "hash_len").
WithResultNames("result_ptr", "result_len").
Export("bls12_381_hash_to_g2")
WithParameterNames("hash_function", "msg_ptr", "dst_ptr", "out_ptr").
WithResultNames("result").
Export("bls12_381_hash_to_g1")

builder.NewFunctionBuilder().
WithFunc(func(ctx context.Context, m api.Module, a1Ptr, a1Len, a2Ptr, a2Len, b1Ptr, b1Len, b2Ptr, b2Len uint32) uint32 {
WithFunc(func(ctx context.Context, m api.Module, hashFunction, msgPtr, dstPtr, outPtr uint32) uint32 {
ctx = context.WithValue(ctx, envKey, env)
return hostBls12381PairingEquality(ctx, m, a1Ptr, a1Len, a2Ptr, a2Len, b1Ptr, b1Len, b2Ptr, b2Len)
ptr, _ := hostBls12381HashToG2(ctx, m, msgPtr, hashFunction)
return ptr
}).
WithParameterNames("a1_ptr", "a1_len", "a2_ptr", "a2_len", "b1_ptr", "b1_len", "b2_ptr", "b2_len").
WithParameterNames("hash_function", "msg_ptr", "dst_ptr", "out_ptr").
WithResultNames("result").
Export("bls12_381_pairing_equality")
Export("bls12_381_hash_to_g2")

// Register SECP256r1 functions
// SECP256r1 functions
builder.NewFunctionBuilder().
WithFunc(func(ctx context.Context, m api.Module, hashPtr, hashLen, sigPtr, sigLen, pubkeyPtr, pubkeyLen uint32) uint32 {
WithFunc(func(ctx context.Context, m api.Module, messageHashPtr, signaturePtr, publicKeyPtr uint32) uint32 {
ctx = context.WithValue(ctx, envKey, env)
return hostSecp256r1Verify(ctx, m, hashPtr, hashLen, sigPtr, sigLen, pubkeyPtr, pubkeyLen)
return hostSecp256r1Verify(ctx, m, messageHashPtr, 0, signaturePtr, 0, publicKeyPtr, 0)
}).
WithParameterNames("hash_ptr", "hash_len", "sig_ptr", "sig_len", "pubkey_ptr", "pubkey_len").
WithParameterNames("message_hash_ptr", "signature_ptr", "public_key_ptr").
WithResultNames("result").
Export("secp256r1_verify")

builder.NewFunctionBuilder().
WithFunc(func(ctx context.Context, m api.Module, hashPtr, hashLen, sigPtr, sigLen, recovery uint32) (uint32, uint32) {
WithFunc(func(ctx context.Context, m api.Module, messageHashPtr, signaturePtr, recoveryParam uint32) uint64 {
ctx = context.WithValue(ctx, envKey, env)
return hostSecp256r1RecoverPubkey(ctx, m, hashPtr, hashLen, sigPtr, sigLen, recovery)
ptr, len := hostSecp256r1RecoverPubkey(ctx, m, messageHashPtr, 0, signaturePtr, 0, recoveryParam)
return (uint64(len) << 32) | uint64(ptr)
}).
WithParameterNames("hash_ptr", "hash_len", "sig_ptr", "sig_len", "recovery").
WithResultNames("pubkey_ptr", "pubkey_len").
WithParameterNames("message_hash_ptr", "signature_ptr", "recovery_param").
WithResultNames("result").
Export("secp256r1_recover_pubkey")

builder.NewFunctionBuilder().
Expand All @@ -950,6 +955,7 @@ func RegisterHostFunctions(runtime wazero.Runtime, env *RuntimeEnvironment) (waz
WithResultNames("iter_id").
Export("db_scan")

// db_next
builder.NewFunctionBuilder().
WithFunc(func(ctx context.Context, m api.Module, iterID uint32) uint32 {
ctx = context.WithValue(ctx, envKey, env)
Expand All @@ -959,6 +965,17 @@ func RegisterHostFunctions(runtime wazero.Runtime, env *RuntimeEnvironment) (waz
WithResultNames("kv_region_ptr").
Export("db_next")

// db_next_value
builder.NewFunctionBuilder().
WithFunc(func(ctx context.Context, m api.Module, iterID uint32) uint32 {
ctx = context.WithValue(ctx, envKey, env)
ptr, _, _ := hostNextValue(ctx, m, uint64(iterID), 0)
return ptr
}).
WithParameterNames("iter_id").
WithResultNames("value_ptr").
Export("db_next_value")

builder.NewFunctionBuilder().
WithFunc(func(ctx context.Context, m api.Module, addrPtr, addrLen uint32) uint32 {
ctx = context.WithValue(ctx, envKey, env)
Expand Down Expand Up @@ -1030,16 +1047,6 @@ func RegisterHostFunctions(runtime wazero.Runtime, env *RuntimeEnvironment) (waz
WithParameterNames("key_ptr").
Export("db_remove")

// db_next_value
builder.NewFunctionBuilder().
WithFunc(func(ctx context.Context, m api.Module, callID, iterID uint64) (uint32, uint32, uint32) {
ctx = context.WithValue(ctx, envKey, env)
return hostNextValue(ctx, m, callID, iterID)
}).
WithParameterNames("call_id", "iter_id").
WithResultNames("val_ptr", "val_len", "err_code").
Export("db_next_value")

// db_close_iterator
builder.NewFunctionBuilder().
WithFunc(func(ctx context.Context, m api.Module, callID, iterID uint64) {
Expand Down Expand Up @@ -1088,6 +1095,17 @@ func RegisterHostFunctions(runtime wazero.Runtime, env *RuntimeEnvironment) (waz
WithParameterNames("msg_ptr").
Export("debug")

// db_next_key
builder.NewFunctionBuilder().
WithFunc(func(ctx context.Context, m api.Module, iterID uint32) uint32 {
ctx = context.WithValue(ctx, envKey, env)
ptr, _, _ := hostNextKey(ctx, m, uint64(iterID), 0)
return ptr
}).
WithParameterNames("iter_id").
WithResultNames("key_ptr").
Export("db_next_key")

return builder.Compile(context.Background())
}

Expand All @@ -1110,3 +1128,44 @@ type contextKey string
const (
envKey contextKey = "env"
)

// hostNextKey implements db_next_key
func hostNextKey(ctx context.Context, mod api.Module, callID, iterID uint64) (keyPtr, keyLen, errCode uint32) {
env := ctx.Value("env").(*RuntimeEnvironment)
mem := mod.Memory()

// Check gas for iterator next operation
if env.GasUsed+gasCostIteratorNext > env.Gas.GasConsumed() {
return 0, 0, 1 // Return error code 1 for out of gas
}
env.GasUsed += gasCostIteratorNext

// Get iterator from environment
iter := env.GetIterator(callID, iterID)
if iter == nil {
return 0, 0, 2 // Return error code 2 for invalid iterator
}

// Check if there are more items
if !iter.Valid() {
return 0, 0, 0 // Return 0 for end of iteration
}

// Read key
key := iter.Key()

// Allocate memory for key
keyOffset, err := allocateInContract(ctx, mod, uint32(len(key)))
if err != nil {
panic(fmt.Sprintf("failed to allocate memory for key (via contract's allocate): %v", err))
}

if err := writeMemory(mem, keyOffset, key); err != nil {
panic(fmt.Sprintf("failed to write key to memory: %v", err))
}

// Move to next item
iter.Next()

return keyOffset, uint32(len(key)), 0
}
Loading

0 comments on commit 4dacdd8

Please sign in to comment.