Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(plugins): Use wazero instead of wasmtime #3042

Merged
merged 8 commits into from
Jan 2, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ go 1.21

require (
github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230321174746-8dcc6526cfb1
github.com/bytecodealliance/wasmtime-go/v14 v14.0.0
github.com/cubicdaiya/gonp v1.0.4
github.com/davecgh/go-spew v1.1.1
github.com/fatih/structtag v1.2.0
Expand All @@ -20,6 +19,7 @@ require (
github.com/riza-io/grpc-go v0.2.0
github.com/spf13/cobra v1.8.0
github.com/spf13/pflag v1.0.5
github.com/tetratelabs/wazero v1.5.0
github.com/wasilibs/go-pgquery v0.0.0-20231205013331-96e794bb074e
github.com/xeipuuv/gojsonschema v1.2.0
golang.org/x/sync v0.5.0
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230321174746-8dcc6526cfb1/g
github.com/antlr4-go/antlr/v4 v4.13.0 h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8TVTI=
github.com/antlr4-go/antlr/v4 v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9UTnmpPaOR2g=
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
github.com/bytecodealliance/wasmtime-go/v14 v14.0.0 h1:ur7S3P+PAeJmgllhSrKnGQOAmmtUbLQxb/nw2NZiaEM=
github.com/bytecodealliance/wasmtime-go/v14 v14.0.0/go.mod h1:tqOVEUjnXY6aGpSfM9qdVRR6G//Yc513fFYUdzZb/DY=
github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I=
github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ=
github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
Expand Down Expand Up @@ -185,6 +183,8 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/tetratelabs/wazero v1.5.0 h1:Yz3fZHivfDiZFUXnWMPUoiW7s8tC1sjdBtlJn08qYa0=
github.com/tetratelabs/wazero v1.5.0/go.mod h1:0U0G41+ochRKoPKCJlh0jMg1CHkyfK8kDqiirMmKY8A=
github.com/wasilibs/go-pgquery v0.0.0-20231205013331-96e794bb074e h1:sGIC6/D0KqpA+qBSDSVDQswU/IJVYkbnUXnipgTLQWk=
github.com/wasilibs/go-pgquery v0.0.0-20231205013331-96e794bb074e/go.mod h1:KW0azBSWqkPZ71r+3O4qt8h6A/NisFLp0rbjZ3py4OE=
github.com/wasilibs/wazerox v0.0.0-20231117065139-b3503f4aeff6 h1:jwbU8u5TuXModzdEG4wI0g4FyuD7ROSttU86go5sPdU=
Expand Down
1 change: 0 additions & 1 deletion internal/endtoend/case_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ type Exec struct {
Contexts []string `json:"contexts"`
Process string `json:"process"`
OS []string `json:"os"`
WASM bool `json:"wasm"`
Env map[string]string `json:"env"`
}

Expand Down
5 changes: 0 additions & 5 deletions internal/endtoend/endtoend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import (

"github.com/sqlc-dev/sqlc/internal/cmd"
"github.com/sqlc-dev/sqlc/internal/config"
"github.com/sqlc-dev/sqlc/internal/ext/wasm"
"github.com/sqlc-dev/sqlc/internal/opts"
)

Expand Down Expand Up @@ -177,10 +176,6 @@ func TestReplay(t *testing.T) {
}
}

if args.WASM && !wasm.Enabled() {
t.Skipf("wasm support not enabled")
}

if len(args.OS) > 0 {
if !slices.Contains(args.OS, runtime.GOOS) {
t.Skipf("unsupported os: %s", runtime.GOOS)
Expand Down

This file was deleted.

This file was deleted.

23 changes: 0 additions & 23 deletions internal/ext/wasm/nowasm.go

This file was deleted.

160 changes: 46 additions & 114 deletions internal/ext/wasm/wasm.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
//go:build !nowasm && cgo && ((linux && amd64) || (linux && arm64) || (darwin && amd64) || (darwin && arm64) || (windows && amd64))

// The above build constraint is based of the cgo directives in this file:
// https://github.com/bytecodealliance/wasmtime-go/blob/main/ffi.go
package wasm

import (
"bytes"
"context"
"crypto/sha256"
"errors"
Expand All @@ -15,10 +12,11 @@ import (
"os"
"path/filepath"
"runtime"
"runtime/trace"
"strings"

wasmtime "github.com/bytecodealliance/wasmtime-go/v14"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"
"github.com/tetratelabs/wazero/sys"
"golang.org/x/sync/singleflight"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
Expand All @@ -31,13 +29,6 @@ import (
"github.com/sqlc-dev/sqlc/internal/plugin"
)

func Enabled() bool {
return true
}

// This version must be updated whenever the wasmtime-go dependency is updated
const wasmtimeVersion = `v14.0.0`

func cacheDir() (string, error) {
cache := os.Getenv("SQLCCACHE")
if cache != "" {
Expand Down Expand Up @@ -70,13 +61,17 @@ func (r *Runner) getChecksum(ctx context.Context) (string, error) {
return sum, nil
}

func (r *Runner) loadModule(ctx context.Context, engine *wasmtime.Engine) (*wasmtime.Module, error) {
func (r *Runner) loadBytes(ctx context.Context) ([]byte, error) {
expected, err := r.getChecksum(ctx)
if err != nil {
return nil, err
}
cacheDir, err := cache.PluginsDir()
if err != nil {
return nil, err
}
value, err, _ := flight.Do(expected, func() (interface{}, error) {
return r.loadSerializedModule(ctx, engine, expected)
return r.loadWASM(ctx, cacheDir, expected)
})
if err != nil {
return nil, err
Expand All @@ -85,52 +80,7 @@ func (r *Runner) loadModule(ctx context.Context, engine *wasmtime.Engine) (*wasm
if !ok {
return nil, fmt.Errorf("returned value was not a byte slice")
}
return wasmtime.NewModuleDeserialize(engine, data)
}

func (r *Runner) loadSerializedModule(ctx context.Context, engine *wasmtime.Engine, expectedSha string) ([]byte, error) {
cacheDir, err := cache.PluginsDir()
if err != nil {
return nil, err
}

pluginDir := filepath.Join(cacheDir, expectedSha)
modName := fmt.Sprintf("plugin_%s_%s_%s.module", runtime.GOOS, runtime.GOARCH, wasmtimeVersion)
modPath := filepath.Join(pluginDir, modName)
_, staterr := os.Stat(modPath)
if staterr == nil {
data, err := os.ReadFile(modPath)
if err != nil {
return nil, err
}
return data, nil
}

wmod, err := r.loadWASM(ctx, cacheDir, expectedSha)
if err != nil {
return nil, err
}

moduRegion := trace.StartRegion(ctx, "wasmtime.NewModule")
module, err := wasmtime.NewModule(engine, wmod)
moduRegion.End()
if err != nil {
return nil, fmt.Errorf("define wasi: %w", err)
}

err = os.Mkdir(pluginDir, 0755)
if err != nil && !os.IsExist(err) {
return nil, fmt.Errorf("mkdirall: %w", err)
}
out, err := module.Serialize()
if err != nil {
return nil, fmt.Errorf("serialize: %w", err)
}
if err := os.WriteFile(modPath, out, 0444); err != nil {
return nil, fmt.Errorf("cache wasm: %w", err)
}

return out, nil
return data, nil
}

func (r *Runner) fetch(ctx context.Context, uri string) ([]byte, string, error) {
Expand Down Expand Up @@ -245,72 +195,56 @@ func (r *Runner) Invoke(ctx context.Context, method string, args any, reply any,
return fmt.Errorf("failed to encode codegen request: %w", err)
}

engine := wasmtime.NewEngine()
module, err := r.loadModule(ctx, engine)
cacheDir, err := cache.PluginsDir()
if err != nil {
return fmt.Errorf("loadModule: %w", err)
return err
}

linker := wasmtime.NewLinker(engine)
if err := linker.DefineWasi(); err != nil {
cache, err := wazero.NewCompilationCacheWithDir(filepath.Join(cacheDir, "wazero"))
if err != nil {
return err
}

dir, err := os.MkdirTemp(os.Getenv("SQLCTMPDIR"), "out")
wasmBytes, err := r.loadBytes(ctx)
if err != nil {
return fmt.Errorf("temp dir: %w", err)
return fmt.Errorf("loadModule: %w", err)
}

defer os.RemoveAll(dir)
stdinPath := filepath.Join(dir, "stdin")
stderrPath := filepath.Join(dir, "stderr")
stdoutPath := filepath.Join(dir, "stdout")

if err := os.WriteFile(stdinPath, stdinBlob, 0755); err != nil {
return fmt.Errorf("write file: %w", err)
}
config := wazero.NewRuntimeConfig().WithCompilationCache(cache)
rt := wazero.NewRuntimeWithConfig(ctx, config)
defer rt.Close(ctx)

// Configure WASI imports to write stdout into a file.
wasiConfig := wasmtime.NewWasiConfig()
wasiConfig.SetArgv([]string{"plugin.wasm", method})
wasiConfig.SetStdinFile(stdinPath)
wasiConfig.SetStdoutFile(stdoutPath)
wasiConfig.SetStderrFile(stderrPath)
// TODO: Handle error
wasi_snapshot_preview1.MustInstantiate(ctx, rt)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is Instantiate if you'd like to return the error. Though I think any failure here would be a programming bug, not non-determinstic


keys := []string{"SQLC_VERSION"}
vals := []string{info.Version}
for _, key := range r.Env {
keys = append(keys, key)
vals = append(vals, os.Getenv(key))
// Compile the Wasm binary once so that we can skip the entire compilation time during instantiation.
mod, err := rt.CompileModule(ctx, wasmBytes)
if err != nil {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it's possible, it would be nice to rejigger to scope this to Runner, possibly with some map[/* wasm url */ string]wazero.CompiledModule. The compilation cache is good to reuse across executions of the sqlc process itself, but it's also good to only compile once per wasm within a process if possible, the cache key computation isn't trivial. Though if the latter doesn't happen that much maybe it doesn't matter

return err
}
wasiConfig.SetEnv(keys, vals)

store := wasmtime.NewStore(engine)
store.SetWasi(wasiConfig)
var stderr, stdout bytes.Buffer

linkRegion := trace.StartRegion(ctx, "linker.DefineModule")
err = linker.DefineModule(store, "", module)
linkRegion.End()
if err != nil {
return fmt.Errorf("define wasi: %w", err)
conf := wazero.NewModuleConfig()
conf = conf.WithArgs("plugin.wasm", method)
conf = conf.WithEnv("SQLC_VERSION", info.Version)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit, consider chaining, it's arguably idiomatic for wazero users

conf := wazero.NewModuleConfig().
  WithArgs().
  WithStdin().
  WithStdout().

for _, key := range r.Env {
conf = conf.WithEnv(key, os.Getenv(key))
}
conf = conf.WithStdin(bytes.NewReader(stdinBlob))
conf = conf.WithStdout(&stdout)
conf = conf.WithStderr(&stderr)

// Run the function
fn, err := linker.GetDefault(store, "")
if err != nil {
return fmt.Errorf("wasi: get default: %w", err)
result, err := rt.InstantiateModule(ctx, mod, conf)
if result != nil {
defer result.Close(ctx)
}

callRegion := trace.StartRegion(ctx, "call _start")
_, err = fn.Call(store)
callRegion.End()

if cerr := checkError(err, stderrPath); cerr != nil {
if cerr := checkError(err, &stderr); cerr != nil {
return cerr
}

// Print WASM stdout
stdoutBlob, err := os.ReadFile(stdoutPath)
stdoutBlob, err := io.ReadAll(&stdout)
if err != nil {
return fmt.Errorf("read file: %w", err)
}
Expand All @@ -331,21 +265,19 @@ func (r *Runner) NewStream(ctx context.Context, desc *grpc.StreamDesc, method st
return nil, status.Error(codes.Unimplemented, "")
}

func checkError(err error, stderrPath string) error {
func checkError(err error, stderr io.Reader) error {
if err == nil {
return err
}

var wtError *wasmtime.Error
if errors.As(err, &wtError) {
if code, ok := wtError.ExitStatus(); ok {
if code == 0 {
return nil
}
if exitErr, ok := err.(*sys.ExitError); ok {
if exitErr.ExitCode() == 0 {
return nil
}
}

// Print WASM stdout
stderrBlob, rferr := os.ReadFile(stderrPath)
stderrBlob, rferr := io.ReadAll(stderr)
if rferr == nil && len(stderrBlob) > 0 {
return errors.New(string(stderrBlob))
}
Expand Down
Loading