Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: provide plugins get_log_level runtime function & support levels #74

Merged
merged 2 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 61 additions & 40 deletions extism.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@ import (
"fmt"
"io"
"log"
"math"
"net/http"
"os"
"strings"
"sync/atomic"
"time"

observe "github.com/dylibso/observe-sdk/go"
Expand All @@ -27,6 +29,9 @@ type module struct {
wasm []byte
}

type PluginCtxKey string
type InputOffsetKey string

//go:embed extism-runtime.wasm
var extismRuntimeWasm []byte

Expand Down Expand Up @@ -63,31 +68,51 @@ type HttpRequest struct {
}

// LogLevel defines different log levels.
type LogLevel uint8
type LogLevel int32

const (
logLevelUnset LogLevel = iota // unexporting this intentionally so its only ever the default
LogLevelOff
LogLevelError
LogLevelWarn
LogLevelInfo
LogLevelDebug
LogLevelTrace
LogLevelDebug
LogLevelInfo
LogLevelWarn
LogLevelError

LogLevelOff LogLevel = math.MaxInt32
)

func (l LogLevel) ExtismCompat() int32 {
switch l {
case LogLevelTrace:
return 0
case LogLevelDebug:
return 1
case LogLevelInfo:
return 2
case LogLevelWarn:
return 3
case LogLevelError:
return 4
default:
return int32(LogLevelOff)
}
}

func (l LogLevel) String() string {
s := ""
switch l {
case LogLevelError:
s = "ERROR"
case LogLevelWarn:
s = "WARN"
case LogLevelInfo:
s = "INFO"
case LogLevelDebug:
s = "DEBUG"
case LogLevelTrace:
s = "TRACE"
case LogLevelDebug:
s = "DEBUG"
case LogLevelInfo:
s = "INFO"
case LogLevelWarn:
s = "WARN"
case LogLevelError:
s = "ERROR"
default:
s = "OFF"
}
return s
}
Expand All @@ -107,7 +132,6 @@ type Plugin struct {
MaxHttpResponseBytes int64
MaxVarBytes int64
log func(LogLevel, string)
logLevel LogLevel
guestRuntime guestRuntime
Adapter *observe.AdapterBase
TraceCtx *observe.TraceCtx
Expand All @@ -122,13 +146,8 @@ func (p *Plugin) SetLogger(logger func(LogLevel, string)) {
p.log = logger
}

// SetLogLevel sets the minim logging level, applies to custom logging callbacks too
func (p *Plugin) SetLogLevel(level LogLevel) {
p.logLevel = level
}

func (p *Plugin) Log(level LogLevel, message string) {
if level > p.logLevel {
if level < LogLevel(pluginLogLevel.Load()) {
return
}

Expand Down Expand Up @@ -311,7 +330,7 @@ func (m *Manifest) UnmarshalJSON(data []byte) error {
Name: w.Name,
})
} else {
return errors.New("Invalid Wasm entry")
return errors.New("invalid Wasm entry")
}
}
return nil
Expand All @@ -327,6 +346,14 @@ func (p *Plugin) CloseWithContext(ctx context.Context) error {
return p.Runtime.Wazero.Close(ctx)
}

// add an atomic global to store the plugin runtime-wide log level
var pluginLogLevel = atomic.Int32{}

// SetPluginLogLevel sets the log level for the plugin
func SetLogLevel(level LogLevel) {
pluginLogLevel.Store(int32(level.ExtismCompat()))
}

// NewPlugin creates a new Extism plugin with the given manifest, configuration, and host functions.
// The returned plugin can be used to call WebAssembly functions and interact with the plugin.
func NewPlugin(
Expand Down Expand Up @@ -390,7 +417,7 @@ func NewPlugin(

count := len(manifest.Wasm)
if count == 0 {
return nil, fmt.Errorf("Manifest can't be empty.")
return nil, fmt.Errorf("manifest can't be empty")
}

modules := map[string]module{}
Expand Down Expand Up @@ -444,7 +471,7 @@ func NewPlugin(
if data.Name == "main" && config.ObserveAdapter != nil {
trace, err = config.ObserveAdapter.NewTraceCtx(ctx, c.Wazero, data.Data, config.ObserveOptions)
if err != nil {
return nil, fmt.Errorf("Failed to initialize Observe Adapter: %v", err)
return nil, fmt.Errorf("failed to initialize Observe Adapter: %v", err)
}

trace.Finish()
Expand All @@ -454,13 +481,13 @@ func NewPlugin(
_, okm := modules[data.Name]

if data.Name == "extism:host/env" || okh || okm {
return nil, fmt.Errorf("Module name collision: '%s'", data.Name)
return nil, fmt.Errorf("module name collision: '%s'", data.Name)
}

if data.Hash != "" {
calculatedHash := calculateHash(data.Data)
if data.Hash != calculatedHash {
return nil, fmt.Errorf("Hash mismatch for module '%s'", data.Name)
return nil, fmt.Errorf("hash mismatch for module '%s'", data.Name)
}
}

Expand All @@ -472,11 +499,6 @@ func NewPlugin(
modules[data.Name] = module{module: m, wasm: data.Data}
}

logLevel := LogLevelWarn
if config.LogLevel != logLevelUnset {
logLevel = config.LogLevel
}

i := 0
httpMax := int64(1024 * 1024 * 50)
if manifest.Memory != nil && manifest.Memory.MaxHttpResponseBytes >= 0 {
Expand All @@ -502,7 +524,6 @@ func NewPlugin(
MaxHttpResponseBytes: httpMax,
MaxVarBytes: varMax,
log: logStd,
logLevel: logLevel,
Adapter: config.ObserveAdapter,
TraceCtx: trace,
}
Expand All @@ -514,7 +535,7 @@ func NewPlugin(
i++
}

return nil, errors.New("No main module found")
return nil, errors.New("no main module found")
}

// SetInput sets the input data for the plugin to be used in the next WebAssembly function call.
Expand Down Expand Up @@ -612,28 +633,28 @@ func (plugin *Plugin) CallWithContext(ctx context.Context, name string, data []b
defer cancel()
}

ctx = context.WithValue(ctx, "plugin", plugin)
ctx = context.WithValue(ctx, PluginCtxKey("plugin"), plugin)

intputOffset, err := plugin.SetInput(data)
if err != nil {
return 1, []byte{}, err
}

ctx = context.WithValue(ctx, "inputOffset", intputOffset)
ctx = context.WithValue(ctx, InputOffsetKey("inputOffset"), intputOffset)

var f = plugin.Main.module.ExportedFunction(name)

if f == nil {
return 1, []byte{}, errors.New(fmt.Sprintf("Unknown function: %s", name))
return 1, []byte{}, fmt.Errorf("unknown function: %s", name)
} else if n := len(f.Definition().ResultTypes()); n > 1 {
return 1, []byte{}, errors.New(fmt.Sprintf("Function %s has %v results, expected 0 or 1", name, n))
return 1, []byte{}, fmt.Errorf("function %s has %v results, expected 0 or 1", name, n)
}

var isStart = name == "_start"
if plugin.guestRuntime.init != nil && !isStart && !plugin.guestRuntime.initialized {
err := plugin.guestRuntime.init(ctx)
if err != nil {
return 1, []byte{}, errors.New(fmt.Sprintf("failed to initialize runtime: %v", err))
return 1, []byte{}, fmt.Errorf("failed to initialize runtime: %v", err)
}
plugin.guestRuntime.initialized = true
}
Expand Down Expand Up @@ -678,14 +699,14 @@ func (plugin *Plugin) CallWithContext(ctx context.Context, name string, data []b
if rc != 0 {
errMsg := plugin.GetError()
if errMsg == "" {
errMsg = "Encountered an unknown error in call to Extism plugin function " + name
errMsg = "encountered an unknown error in call to Extism plugin function " + name
}
return rc, []byte{}, errors.New(errMsg)
}

output, err := plugin.GetOutput()
if err != nil {
return rc, []byte{}, fmt.Errorf("Failed to get output: %v", err)
return rc, []byte{}, fmt.Errorf("failed to get output: %v", err)
}

return rc, output, nil
Expand Down
71 changes: 49 additions & 22 deletions extism_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ import (
"encoding/json"
"fmt"
"log"
"math/rand"
"os"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -443,13 +443,16 @@ func TestLog_default(t *testing.T) {
if plugin, ok := plugin(t, manifest); ok {
defer plugin.Close()

SetLogLevel(LogLevelWarn) // Only warn and error logs should be printed to the console
exit, _, err := plugin.Call("run_test", []byte{})

if assertCall(t, err, exit) {
logs := buf.String()

assert.Contains(t, logs, "this is a warning log")
assert.Contains(t, logs, "this is an error log")
assert.NotContains(t, logs, "this is a trace log")
assert.NotContains(t, logs, "this is a debug log")
assert.NotContains(t, logs, "this is an info log")
}
}
}
Expand All @@ -465,34 +468,68 @@ func TestLog_custom(t *testing.T) {
if plugin, ok := plugin(t, manifest); ok {
defer plugin.Close()

var actual []LogEntry
var actual strings.Builder

var fmtLogMessage = func(level LogLevel, message string) string {
return fmt.Sprintf("%s: %s\n", level.String(), message)
}

plugin.SetLogger(func(level LogLevel, message string) {
actual = append(actual, LogEntry{message: message, level: level})
actual.WriteString(fmtLogMessage(level, message))
switch level {
case LogLevelDebug:
assert.Equal(t, level.String(), "DEBUG")
case LogLevelInfo:
assert.Equal(t, fmt.Sprintf("%s", level), "INFO")
assert.Equal(t, level.String(), "INFO")
case LogLevelWarn:
assert.Equal(t, fmt.Sprintf("%s", level), "WARN")
assert.Equal(t, level.String(), "WARN")
case LogLevelError:
assert.Equal(t, fmt.Sprintf("%s", level), "ERROR")
assert.Equal(t, level.String(), "ERROR")
case LogLevelTrace:
assert.Equal(t, fmt.Sprintf("%s", level), "TRACE")
assert.Equal(t, level.String(), "TRACE")
}
})

plugin.SetLogLevel(LogLevelInfo)
SetLogLevel(LogLevelTrace)

exit, _, err := plugin.Call("run_test", []byte{})

if assertCall(t, err, exit) {
expected := []LogEntry{
{message: "this is a trace log", level: LogLevelTrace},
{message: "this is a debug log", level: LogLevelDebug},
{message: "this is an info log", level: LogLevelInfo},
{message: "this is a warning log", level: LogLevelWarn},
{message: "this is an error log", level: LogLevelError},
{message: "this is a trace log", level: LogLevelTrace}}
}
actualLogs := actual.String()
for _, log := range expected {
assert.Contains(t, actualLogs, fmtLogMessage(log.level, log.message))
}
}

assert.Equal(t, expected, actual)
SetLogLevel(LogLevelWarn)
actual.Reset()

exit, _, err = plugin.Call("run_test", []byte{})

if assertCall(t, err, exit) {
expected := []LogEntry{
{message: "this is a warning log", level: LogLevelWarn},
{message: "this is an error log", level: LogLevelError},
}
expectedNot := []LogEntry{
{message: "this is a trace log", level: LogLevelTrace},
{message: "this is a debug log", level: LogLevelDebug},
{message: "this is an info log", level: LogLevelInfo},
}
actualLogs := actual.String()
for _, log := range expected {
assert.Contains(t, actualLogs, fmtLogMessage(log.level, log.message))
}
for _, log := range expectedNot {
assert.NotContains(t, actualLogs, fmtLogMessage(log.level, log.message))
}
}
}
}
Expand Down Expand Up @@ -699,7 +736,7 @@ func TestHelloHaskell(t *testing.T) {
if plugin, ok := plugin(t, manifest); ok {
defer plugin.Close()

plugin.SetLogLevel(LogLevelTrace)
SetLogLevel(LogLevelTrace)
plugin.Config["greeting"] = "Howdy"

exit, output, err := plugin.Call("testing", []byte("John"))
Expand Down Expand Up @@ -1068,16 +1105,6 @@ func BenchmarkReplace(b *testing.B) {
}
}

func generateRandomString(length int, seed int64) string {
rand.Seed(seed)
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
result := make([]byte, length)
for i := range result {
result[i] = charset[rand.Intn(len(charset))]
}
return string(result)
}

func wasiPluginConfig() PluginConfig {
config := PluginConfig{
ModuleConfig: wazero.NewModuleConfig().WithSysWalltime(),
Expand Down
Loading
Loading