diff --git a/js/bundle.go b/js/bundle.go index 258cbae69dd1..c2eeacfaf28d 100644 --- a/js/bundle.go +++ b/js/bundle.go @@ -1,14 +1,15 @@ package js import ( - "bytes" "context" "encoding/json" "errors" "fmt" "net/url" "path/filepath" + "reflect" "runtime" + "sort" "github.com/dop251/goja" "github.com/sirupsen/logrus" @@ -102,7 +103,7 @@ func newBundle( return nil, err } - err = bundle.populateExports(piState.Logger, updateOptions, instance) + err = bundle.populateExports(updateOptions, instance) if err != nil { return nil, err } @@ -159,7 +160,7 @@ func (b *Bundle) makeArchive() *lib.Archive { } // populateExports validates and extracts exported objects -func (b *Bundle) populateExports(logger logrus.FieldLogger, updateOptions bool, instance moduleInstance) error { +func (b *Bundle) populateExports(updateOptions bool, instance moduleInstance) error { exports := instance.exports() if exports == nil { return errors.New("exports must be an object") @@ -175,17 +176,8 @@ func (b *Bundle) populateExports(logger logrus.FieldLogger, updateOptions bool, if !updateOptions { continue } - data, err := json.Marshal(v.Export()) - if err != nil { - return fmt.Errorf("error parsing script options: %w", err) - } - dec := json.NewDecoder(bytes.NewReader(data)) - dec.DisallowUnknownFields() - if err := dec.Decode(&b.Options); err != nil { - if uerr := json.Unmarshal(data, &b.Options); uerr != nil { - return uerr - } - logger.WithError(err).Warn("There were unknown fields in the options exported in the script") + if err := b.updateOptions(v); err != nil { + return err } case consts.SetupFn: return errors.New("exported 'setup' must be a function") @@ -201,6 +193,66 @@ func (b *Bundle) populateExports(logger logrus.FieldLogger, updateOptions bool, return nil } +// TODO: something cleaner than this, with far less reflection magic... +func (b *Bundle) updateOptions(jsVal goja.Value) (err error) { + if common.IsNullish(jsVal) { + return nil // no options were exported, nothing to update + } + + if jsVal.ExportType().Kind() != reflect.Map { + return fmt.Errorf("the exported script options should be a JS object") + } + + // TODO: maybe work with the *goja.Object directly, if we can pass the + // runtime shomehow to call jsVal.ToObject(rt)? + expOptions, isMap := jsVal.Export().(map[string]interface{}) + if !isMap { + return fmt.Errorf("the exported script options should be a JS object with string keys") + } + + keys := make([]string, 0, len(expOptions)) + for k := range expOptions { + keys = append(keys, k) + } + sort.Strings(keys) + + optionsJSONFields := lib.GetStructFieldsByTagKey(&b.Options, "json") + + var errs []error + for _, k := range keys { + opt, ok := optionsJSONFields[k] + if !ok { + // TODO: make this an error + b.preInitState.Logger.Warnf("'%s' is used in the exported script options, but it's not a valid k6 option", k) + continue + } + + // TODO: have a way to work with these values without having to go through JSON? + optJSON, err := json.Marshal(expOptions[k]) + if err != nil { + errs = append(errs, fmt.Errorf("error extracting '%s': %w", k, err)) + continue + } + + switch v := opt.(type) { + case lib.JSONUnmarshalerWithPreInitState: + err = v.UnmarshalJSONWithPIState(b.preInitState, optJSON) + case json.Unmarshaler: + err = v.UnmarshalJSON(optJSON) + default: + err = json.Unmarshal(optJSON, opt) // fingers crossed... + } + if err != nil { + errs = append(errs, fmt.Errorf("error parsing '%s': %w", k, err)) + } + } + if len(errs) > 0 { + return fmt.Errorf("there were errors with the exported script options: %w", errors.Join(errs...)) + } + + return nil +} + // Instantiate creates a new runtime from this bundle. func (b *Bundle) Instantiate(ctx context.Context, vuID uint64) (*BundleInstance, error) { // Instantiate the bundle into a new VM using a bound init context. This uses a context with a diff --git a/js/bundle_test.go b/js/bundle_test.go index 4ac2cf270c1b..d1a52aaa5043 100644 --- a/js/bundle_test.go +++ b/js/bundle_test.go @@ -201,8 +201,8 @@ func TestNewBundle(t *testing.T) { invalidOptions := map[string]struct { Expr, Error string }{ - "Array": {`[]`, "json: cannot unmarshal array into Go value of type lib.Options"}, - "Function": {`function(){}`, "error parsing script options: json: unsupported type: func(goja.FunctionCall) goja.Value"}, + "Array": {`[]`, "the exported script options should be a JS object"}, + "Function": {`function(){}`, "the exported script options should be a JS object"}, } for name, data := range invalidOptions { t.Run(name, func(t *testing.T) { @@ -453,8 +453,7 @@ func TestNewBundle(t *testing.T) { entries := hook.Drain() require.Len(t, entries, 1) assert.Equal(t, logrus.WarnLevel, entries[0].Level) - assert.Contains(t, entries[0].Message, "There were unknown fields") - assert.Contains(t, entries[0].Data["error"].(error).Error(), "unknown field \"something\"") + assert.Contains(t, entries[0].Message, "'something' is used in the exported script options, but it's not a valid k6 option") }) }) } diff --git a/lib/options.go b/lib/options.go index b90e666b0a68..ee30c5823f0d 100644 --- a/lib/options.go +++ b/lib/options.go @@ -563,3 +563,35 @@ func (o Options) ForEachSpecified(structTag string, callback func(key string, va } } } + +// JSONUnmarshalerWithPreInitState can be implemented by types that require +// stateful unmarshalling of JSON values. +type JSONUnmarshalerWithPreInitState interface { + UnmarshalJSONWithPIState(*TestPreInitState, []byte) error +} + +// GetStructFieldsByTagKey returns a map with pointers to all of the struct +// fields. The keys of that map are the confugured struct tag values for the +// given structTagKey (e.g. "json"). +func GetStructFieldsByTagKey(val interface{}, structTagKey string) map[string]interface{} { + structPType := reflect.TypeOf(val) + if structPType.Kind() != reflect.Pointer { + panic(fmt.Errorf("GetStructFieldsByTagKey() expects a pointer, but was given %s", structPType.Kind())) + } + structPVal := reflect.ValueOf(val) + + structType := structPType.Elem() + structVal := structPVal.Elem() + res := map[string]interface{}{} + for i := 0; i < structType.NumField(); i++ { + fieldType := structType.Field(i) + fieldVal := structVal.Field(i) + + key, ok := fieldType.Tag.Lookup(structTagKey) + if !ok { + continue + } + res[key] = fieldVal.Addr().Interface() + } + return res +}