Skip to content

Commit

Permalink
database functions
Browse files Browse the repository at this point in the history
  • Loading branch information
faddat committed Dec 20, 2024
1 parent 62a2b75 commit e0a26b2
Show file tree
Hide file tree
Showing 2 changed files with 290 additions and 72 deletions.
275 changes: 218 additions & 57 deletions internal/runtime/hostfunctions.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package runtime

import (
"context"
"encoding/binary"
"encoding/json"
"fmt"

Expand Down Expand Up @@ -229,7 +230,7 @@ func hostValidateAddress(ctx context.Context, mod api.Module, addrPtr, addrLen u
}

// hostScan implements db_scan
func hostScan(ctx context.Context, mod api.Module, startPtr, startLen, endPtr, endLen uint32, order uint32) (uint64, uint64, uint32) {
func hostScan(ctx context.Context, mod api.Module, startPtr, startLen, endPtr, endLen, order uint32) (uint64, uint64, uint32) {
env := ctx.Value("env").(*RuntimeEnvironment)
mem := mod.Memory()

Expand All @@ -239,17 +240,25 @@ func hostScan(ctx context.Context, mod api.Module, startPtr, startLen, endPtr, e
}
env.GasUsed += gasCostIteratorCreate

start, err := ReadMemory(mem, startPtr, startLen)
if err != nil {
panic(fmt.Sprintf("failed to read start key from memory: %v", err))
// Read start and end keys
var start, end []byte
var err error

if startPtr != 0 {
start, err = ReadMemory(mem, startPtr, startLen)
if err != nil {
panic(fmt.Sprintf("failed to read start key from memory: %v", err))
}
}

end, err := ReadMemory(mem, endPtr, endLen)
if err != nil {
panic(fmt.Sprintf("failed to read end key from memory: %v", err))
if endPtr != 0 {
end, err = ReadMemory(mem, endPtr, endLen)
if err != nil {
panic(fmt.Sprintf("failed to read end key from memory: %v", err))
}
}

// Check iterator limits
// Start a new call context for this iterator
callID := env.StartCall()
if len(env.iterators[callID]) >= maxIteratorsPerCall {
return 0, 0, 2 // Return error code 2 for too many iterators
Expand Down Expand Up @@ -424,74 +433,226 @@ func hostCloseIterator(ctx context.Context, mod api.Module, callID, iterID uint6
}
}

// RegisterHostFunctions registers all host functions into a module named "env"
func RegisterHostFunctions(r wazero.Runtime, env *RuntimeEnvironment) (wazero.CompiledModule, error) {
// Initialize memory allocator if not already done
if env.Memory == nil {
env.Memory = NewMemoryAllocator(65536) // Start at 64KB offset
// hostAbort implements the abort function required by Wasm modules
func hostAbort(ctx context.Context, mod api.Module, code uint32) {
panic(fmt.Sprintf("Wasm contract aborted with code: %d", code))
}

// hostDbRead implements db_read
func hostDbRead(ctx context.Context, mod api.Module, keyPtr uint32) uint32 {
env := ctx.Value("env").(*RuntimeEnvironment)
mem := mod.Memory()

// Read length prefix (4 bytes) from the key pointer
lenBytes, err := ReadMemory(mem, keyPtr, 4)
if err != nil {
panic(fmt.Sprintf("failed to read key length from memory: %v", err))
}
keyLen := binary.LittleEndian.Uint32(lenBytes)

// Build a module that exports these functions
hostModBuilder := r.NewHostModuleBuilder("env")
// Read the actual key
key, err := ReadMemory(mem, keyPtr+4, keyLen)
if err != nil {
panic(fmt.Sprintf("failed to read key from memory: %v", err))
}

// Register memory management functions
RegisterMemoryManagement(hostModBuilder, env.Memory)
value := env.DB.Get(key)
if len(value) == 0 {
return 0
}

// Register DB functions
hostModBuilder.NewFunctionBuilder().
WithFunc(hostGet).
Export("db_get")
// Allocate memory for the result: 4 bytes for length + actual value
totalLen := 4 + len(value)
offset, err := env.Memory.Allocate(mem, uint32(totalLen))
if err != nil {
panic(fmt.Sprintf("failed to allocate memory: %v", err))
}

hostModBuilder.NewFunctionBuilder().
WithFunc(hostSet).
Export("db_set")
// Write length prefix
lenData := make([]byte, 4)
binary.LittleEndian.PutUint32(lenData, uint32(len(value)))
if err := WriteMemory(mem, offset, lenData); err != nil {
panic(fmt.Sprintf("failed to write value length to memory: %v", err))
}

// Register API functions
hostModBuilder.NewFunctionBuilder().
WithFunc(hostHumanizeAddress).
Export("api_humanize_address")
// Write value
if err := WriteMemory(mem, offset+4, value); err != nil {
panic(fmt.Sprintf("failed to write value to memory: %v", err))
}

hostModBuilder.NewFunctionBuilder().
WithFunc(hostCanonicalizeAddress).
Export("api_canonicalize_address")
return offset
}

hostModBuilder.NewFunctionBuilder().
WithFunc(hostValidateAddress).
Export("api_validate_address")
// hostDbWrite implements db_write
func hostDbWrite(ctx context.Context, mod api.Module, keyPtr, valuePtr uint32) {
env := ctx.Value("env").(*RuntimeEnvironment)
mem := mod.Memory()

// Register Query functions
hostModBuilder.NewFunctionBuilder().
WithFunc(hostQueryExternal).
Export("querier_query")
// Read key length prefix (4 bytes)
keyLenBytes, err := ReadMemory(mem, keyPtr, 4)
if err != nil {
panic(fmt.Sprintf("failed to read key length from memory: %v", err))
}
keyLen := binary.LittleEndian.Uint32(keyLenBytes)

// Read value length prefix (4 bytes)
valLenBytes, err := ReadMemory(mem, valuePtr, 4)
if err != nil {
panic(fmt.Sprintf("failed to read value length from memory: %v", err))
}
valLen := binary.LittleEndian.Uint32(valLenBytes)

// Register Iterator functions
hostModBuilder.NewFunctionBuilder().
WithFunc(hostScan).
// Read the actual key and value
key, err := ReadMemory(mem, keyPtr+4, keyLen)
if err != nil {
panic(fmt.Sprintf("failed to read key from memory: %v", err))
}

value, err := ReadMemory(mem, valuePtr+4, valLen)
if err != nil {
panic(fmt.Sprintf("failed to read value from memory: %v", err))
}

env.DB.Set(key, value)
}

// hostDbRemove implements db_remove
func hostDbRemove(ctx context.Context, mod api.Module, keyPtr uint32) {
env := ctx.Value("env").(*RuntimeEnvironment)
mem := mod.Memory()

// Read length prefix (4 bytes) from the key pointer
lenBytes, err := ReadMemory(mem, keyPtr, 4)
if err != nil {
panic(fmt.Sprintf("failed to read key length from memory: %v", err))
}
keyLen := binary.LittleEndian.Uint32(lenBytes)

// Read the actual key
key, err := ReadMemory(mem, keyPtr+4, keyLen)
if err != nil {
panic(fmt.Sprintf("failed to read key from memory: %v", err))
}

env.DB.Delete(key)
}

// RegisterHostFunctions registers all host functions with the wazero runtime
func RegisterHostFunctions(runtime wazero.Runtime, env *RuntimeEnvironment) (wazero.CompiledModule, error) {
builder := runtime.NewHostModuleBuilder("env")

// Register abort function
builder.NewFunctionBuilder().
WithFunc(func(ctx context.Context, m api.Module, code uint32) {
ctx = context.WithValue(ctx, "env", env)
hostAbort(ctx, m, code)
}).
WithParameterNames("code").
Export("abort")

// Register DB functions
builder.NewFunctionBuilder().
WithFunc(func(ctx context.Context, m api.Module, keyPtr, keyLen uint32) (uint32, uint32) {
ctx = context.WithValue(ctx, "env", env)
return hostGet(ctx, m, keyPtr, keyLen)
}).
WithParameterNames("key_ptr", "key_len").
Export("db_get")

builder.NewFunctionBuilder().
WithFunc(func(ctx context.Context, m api.Module, keyPtr, keyLen, valPtr, valLen uint32) {
ctx = context.WithValue(ctx, "env", env)
hostSet(ctx, m, keyPtr, keyLen, valPtr, valLen)
}).
WithParameterNames("key_ptr", "key_len", "val_ptr", "val_len").
Export("db_set")

builder.NewFunctionBuilder().
WithFunc(func(ctx context.Context, m api.Module, startPtr, startLen, endPtr, endLen, order uint32) (uint64, uint64, uint32) {
ctx = context.WithValue(ctx, "env", env)
return hostScan(ctx, m, startPtr, startLen, endPtr, endLen, order)
}).
WithParameterNames("start_ptr", "start_len", "end_ptr", "end_len", "order").
Export("db_scan")

hostModBuilder.NewFunctionBuilder().
WithFunc(hostNext).
builder.NewFunctionBuilder().
WithFunc(func(ctx context.Context, m api.Module, callID, iterID uint64) (uint32, uint32, uint32, uint32, uint32) {
ctx = context.WithValue(ctx, "env", env)
return hostNext(ctx, m, callID, iterID)
}).
WithParameterNames("call_id", "iter_id").
Export("db_next")

hostModBuilder.NewFunctionBuilder().
WithFunc(hostNextKey).
builder.NewFunctionBuilder().
WithFunc(func(ctx context.Context, m api.Module, callID, iterID uint64) (uint32, uint32, uint32) {
ctx = context.WithValue(ctx, "env", env)
return hostNextKey(ctx, m, callID, iterID)
}).
WithParameterNames("call_id", "iter_id").
Export("db_next_key")

hostModBuilder.NewFunctionBuilder().
WithFunc(hostNextValue).
Export("db_next_value")
// Register API functions
builder.NewFunctionBuilder().
WithFunc(func(ctx context.Context, m api.Module, addrPtr, addrLen uint32) (uint32, uint32) {
ctx = context.WithValue(ctx, "env", env)
return hostHumanizeAddress(ctx, m, addrPtr, addrLen)
}).
WithParameterNames("addr_ptr", "addr_len").
Export("api_humanize_address")

hostModBuilder.NewFunctionBuilder().
WithFunc(hostCloseIterator).
Export("db_close_iterator")
builder.NewFunctionBuilder().
WithFunc(func(ctx context.Context, m api.Module, addrPtr, addrLen uint32) (uint32, uint32) {
ctx = context.WithValue(ctx, "env", env)
return hostCanonicalizeAddress(ctx, m, addrPtr, addrLen)
}).
WithParameterNames("addr_ptr", "addr_len").
Export("api_canonicalize_address")

// Compile the host module
compiled, err := hostModBuilder.Compile(context.Background())
if err != nil {
return nil, fmt.Errorf("failed to compile host module: %w", err)
}
builder.NewFunctionBuilder().
WithFunc(func(ctx context.Context, m api.Module, addrPtr, addrLen uint32) uint32 {
ctx = context.WithValue(ctx, "env", env)
return hostValidateAddress(ctx, m, addrPtr, addrLen)
}).
WithParameterNames("addr_ptr", "addr_len").
Export("api_validate_address")

// Register Query functions
builder.NewFunctionBuilder().
WithFunc(func(ctx context.Context, m api.Module, reqPtr, reqLen, gasLimit uint32) (uint32, uint32) {
ctx = context.WithValue(ctx, "env", env)
return hostQueryExternal(ctx, m, reqPtr, reqLen, gasLimit)
}).
WithParameterNames("req_ptr", "req_len", "gas_limit").
Export("querier_query")

return compiled, nil
// Register DB read function
builder.NewFunctionBuilder().
WithFunc(func(ctx context.Context, m api.Module, keyPtr uint32) uint32 {
ctx = context.WithValue(ctx, "env", env)
return hostDbRead(ctx, m, keyPtr)
}).
WithParameterNames("key_ptr").
Export("db_read")

// Register DB write function
builder.NewFunctionBuilder().
WithFunc(func(ctx context.Context, m api.Module, keyPtr, valuePtr uint32) {
ctx = context.WithValue(ctx, "env", env)
hostDbWrite(ctx, m, keyPtr, valuePtr)
}).
WithParameterNames("key_ptr", "value_ptr").
Export("db_write")

// Register DB remove function
builder.NewFunctionBuilder().
WithFunc(func(ctx context.Context, m api.Module, keyPtr uint32) {
ctx = context.WithValue(ctx, "env", env)
hostDbRemove(ctx, m, keyPtr)
}).
WithParameterNames("key_ptr").
Export("db_remove")

return builder.Compile(context.Background())
}

// When you instantiate a contract, you can do something like:
Expand Down
Loading

0 comments on commit e0a26b2

Please sign in to comment.