From 72114aa9741aaf204b84f553a7b3a467a7213b62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=87=8E=E5=A3=B0?= Date: Mon, 11 Mar 2024 11:26:01 +0800 Subject: [PATCH] refactor: make plugin more easier to maintain (#83) * refactor: add lua vm struct * feat: support struct marshal * feat: add more ctx * fix: encoding tag map error * feat: add logger * feat: add debug flag * fix: fix decode in mixed struct * fix: unmarshal to interface is not work * feat: ctx use marshal * feat: unmarshal hook result * ci: update test cases * test: streamline encoding unit test * chore: remove print * test: add more testcases * chore: revert hook name enum * chore: add license * chore: update log message * refactor: add plugin module * refactor: add a vm file * chore: update logger * chore: add license * feat: support marshal nil * fix: preinstall should work * fix: lua checksum * chore: update code * refactor: unmarshal plugin info * chore: some modifications * mod: remove debug log in encoding func * mod: Optimize function calls * bugfix * mod --------- Co-authored-by: lihan --- .github/workflows/ci.yml | 2 + cmd/cmd.go | 17 +- internal/interfaces.go | 169 +++++++++ internal/logger/logger.go | 70 ++++ internal/luai/decode.go | 294 +++++++++++++++ internal/luai/encode.go | 108 ++++++ internal/luai/encoding_test.go | 313 ++++++++++++++++ internal/{ => luai}/fixtures/preload.lua | 0 internal/luai/vm.go | 81 ++++ internal/package.go | 12 +- internal/plugin.go | 448 +++++++++-------------- internal/plugin_test.go | 150 ++++++++ internal/sdk.go | 4 + internal/testdata/plugins/java.lua | 2 +- 14 files changed, 1395 insertions(+), 275 deletions(-) create mode 100644 internal/interfaces.go create mode 100644 internal/logger/logger.go create mode 100644 internal/luai/decode.go create mode 100644 internal/luai/encode.go create mode 100644 internal/luai/encoding_test.go rename internal/{ => luai}/fixtures/preload.lua (100%) create mode 100644 internal/luai/vm.go diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a81bb3cf..980ac1f8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -24,5 +24,7 @@ jobs: run: | go build . - name: Test with the Go CLI + # we cannot use `go test ./...` currently, because many test cases are failed run: | go test ./internal + go test ./internal/luai diff --git a/cmd/cmd.go b/cmd/cmd.go index 9ad75a37..4a390e0a 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -18,10 +18,12 @@ package cmd import ( "fmt" + "os" + "github.com/urfave/cli/v2" "github.com/version-fox/vfox/cmd/commands" "github.com/version-fox/vfox/internal" - "os" + "github.com/version-fox/vfox/internal/logger" ) func Execute(args []string) { @@ -66,6 +68,19 @@ func newCmd() *cmd { _, _ = fmt.Fprintln(ctx.App.Writer, command.Name) } } + + debugFlags := &cli.BoolFlag{ + Name: "debug", + Usage: "show debug information", + Action: func(ctx *cli.Context, b bool) error { + logger.SetLevel(logger.DebugLevel) + return nil + }, + } + + app.Flags = []cli.Flag{ + debugFlags, + } app.Commands = []*cli.Command{ commands.Info, commands.Install, diff --git a/internal/interfaces.go b/internal/interfaces.go new file mode 100644 index 00000000..716db2ff --- /dev/null +++ b/internal/interfaces.go @@ -0,0 +1,169 @@ +/* + * Copyright 2024 Han Li and contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package internal + +import ( + "fmt" +) + +type LuaCheckSum struct { + Sha256 string `luai:"sha256"` + Sha512 string `luai:"sha512"` + Sha1 string `luai:"sha1"` + Md5 string `luai:"md5"` +} + +func (c *LuaCheckSum) Checksum() *Checksum { + checksum := &Checksum{} + + if c.Sha256 != "" { + checksum.Value = c.Sha256 + checksum.Type = "sha256" + } else if c.Md5 != "" { + checksum.Value = c.Md5 + checksum.Type = "md5" + } else if c.Sha1 != "" { + checksum.Value = c.Sha1 + checksum.Type = "sha1" + } else if c.Sha512 != "" { + checksum.Value = c.Sha512 + checksum.Type = "sha512" + } else { + return NoneChecksum + } + + return checksum +} + +type AvailableHookCtx struct { + RuntimeVersion string `luai:"runtimeVersion"` +} + +type AvailableHookResultItem struct { + Version string `luai:"version"` + Note string `luai:"note"` + + Addition []*Info `luai:"addition"` +} + +type AvailableHookResult = []*AvailableHookResultItem + +type PreInstallHookCtx struct { + Version string `luai:"version"` + RuntimeVersion string `luai:"runtimeVersion"` +} + +type PreInstallHookResultAdditionItem struct { + Name string `luai:"name"` + Url string `luai:"url"` + + Sha256 string `luai:"sha256"` + Sha512 string `luai:"sha512"` + Sha1 string `luai:"sha1"` + Md5 string `luai:"md5"` +} + +func (i *PreInstallHookResultAdditionItem) Info() *Info { + sum := LuaCheckSum{ + Sha256: i.Sha256, + Sha512: i.Sha512, + Sha1: i.Sha1, + Md5: i.Md5, + } + + return &Info{ + Name: i.Name, + Version: Version(""), + Path: i.Url, + Note: "", + Checksum: sum.Checksum(), + } +} + +type PreInstallHookResult struct { + Version string `luai:"version"` + Url string `luai:"url"` + + Sha256 string `luai:"sha256"` + Sha512 string `luai:"sha512"` + Sha1 string `luai:"sha1"` + Md5 string `luai:"md5"` + + Addition []*PreInstallHookResultAdditionItem `luai:"addition"` +} + +func (i *PreInstallHookResult) Info() (*Info, error) { + if i.Version == "" { + return nil, fmt.Errorf("no version number provided") + } + + sum := LuaCheckSum{ + Sha256: i.Sha256, + Sha512: i.Sha512, + Sha1: i.Sha1, + Md5: i.Md5, + } + + return &Info{ + Name: "", + Version: Version(i.Version), + Path: i.Url, + Note: "", + Checksum: sum.Checksum(), + }, nil +} + +type PreUseHookCtx struct { + RuntimeVersion string `luai:"runtimeVersion"` + Cwd string `luai:"cwd"` + Scope string `luai:"scope"` + Version string `luai:"version"` + PreviousVersion string `luai:"previousVersion"` + InstalledSdks map[string]*Info `luai:"installedSdks"` +} + +type PreUseHookResult struct { + Version string `luai:"version"` +} + +type PostInstallHookCtx struct { + RuntimeVersion string `luai:"runtimeVersion"` + RootPath string `luai:"rootPath"` + SdkInfo map[string]*Info `luai:"sdkInfo"` +} + +type EnvKeysHookCtx struct { + RuntimeVersion string `luai:"runtimeVersion"` + Main *Info `luai:"main"` + // TODO Will be deprecated in future versions + Path string `luai:"path"` + SdkInfo map[string]*Info `luai:"sdkInfo"` +} + +type EnvKeysHookResultItem struct { + Key string `luai:"key"` + Value string `luai:"value"` +} + +type LuaPluginInfo struct { + Name string `luai:"name"` + Author string `luai:"author"` + Version string `luai:"version"` + Description string `luai:"description"` + UpdateUrl string `luai:"updateUrl"` + MinRuntimeVersion string `luai:"minRuntimeVersion"` +} diff --git a/internal/logger/logger.go b/internal/logger/logger.go new file mode 100644 index 00000000..723c1df4 --- /dev/null +++ b/internal/logger/logger.go @@ -0,0 +1,70 @@ +/* + * Copyright 2024 Han Li and contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package logger + +import "fmt" + +type LoggerLevel int + +// the smaller the level, the more logs. +const ( + DebugLevel LoggerLevel = iota + InfoLevel + ErrorLevel +) + +var currentLevel = InfoLevel + +func SetLevel(_level LoggerLevel) { + currentLevel = _level +} + +func Log(level LoggerLevel, args ...interface{}) { + if currentLevel <= level { + fmt.Println(args...) + } +} + +func Logf(level LoggerLevel, message string, args ...interface{}) { + if currentLevel <= level { + fmt.Printf(message, args...) + } +} + +func Error(message ...interface{}) { + Log(ErrorLevel, message...) +} + +func Errorf(message string, args ...interface{}) { + Logf(ErrorLevel, message, args...) +} + +func Info(message ...interface{}) { + Log(InfoLevel, message...) +} + +func Infof(message string, args ...interface{}) { + Logf(InfoLevel, message, args...) +} + +func Debug(args ...interface{}) { + Log(DebugLevel, args...) +} + +func Debugf(message string, args ...interface{}) { + Logf(DebugLevel, message, args...) +} diff --git a/internal/luai/decode.go b/internal/luai/decode.go new file mode 100644 index 00000000..907bdc2e --- /dev/null +++ b/internal/luai/decode.go @@ -0,0 +1,294 @@ +/* + * Copyright 2024 Han Li and contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package luai + +import ( + "errors" + "reflect" + "strconv" + + lua "github.com/yuin/gopher-lua" +) + +// modified from https://cs.opensource.google/go/go/+/master:src/encoding/json/decode.go +// indirect walks down v allocating pointers as needed, +// until it gets to a non-pointer. +func indirect(v reflect.Value) reflect.Value { + // Issue https://github.com/golang/go/issues/24153 indicates that it is generally not a guaranteed property + // that you may round-trip a reflect.Value by calling Value.Addr().Elem() + // and expect the value to still be settable for values derived from + // unexported embedded struct fields. + // + // The logic below effectively does this when it first addresses the value + // (to satisfy possible pointer methods) and continues to dereference + // subsequent pointers as necessary. + // + // After the first round-trip, we set v back to the original value to + // preserve the original RW flags contained in reflect.Value. + v0 := v + haveAddr := false + + // If v is a named type and is addressable, + // start with its address, so that if the type has pointer methods, + // we find them. + if v.Kind() != reflect.Pointer && v.Type().Name() != "" && v.CanAddr() { + haveAddr = true + v = v.Addr() + } + for { + // Load value from interface, but only if the result will be + // usefully addressable. + if v.Kind() == reflect.Interface && !v.IsNil() { + e := v.Elem() + if e.Kind() == reflect.Pointer && !e.IsNil() && (e.Elem().Kind() == reflect.Pointer) { + haveAddr = false + v = e + continue + } + } + + if v.Kind() != reflect.Pointer { + break + } + + // Prevent infinite loop if v is an interface pointing to its own address: + // var v interface{} + // v = &v + if v.Elem().Kind() == reflect.Interface && v.Elem().Elem() == v { + v = v.Elem() + break + } + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + + if haveAddr { + v = v0 // restore original value after round-trip Value.Addr().Elem() + haveAddr = false + } else { + v = v.Elem() + } + } + return v +} + +func storeLiteral(value reflect.Value, lvalue lua.LValue) { + value = indirect(value) + switch lvalue.Type() { + case lua.LTString: + value.SetString(lvalue.String()) + case lua.LTNumber: + value.SetInt(int64(lvalue.(lua.LNumber))) + case lua.LTBool: + value.SetBool(bool(lvalue.(lua.LBool))) + } +} + +func objectInterface(lvalue *lua.LTable) any { + var v = make(map[string]any) + lvalue.ForEach(func(key, value lua.LValue) { + v[key.String()] = valueInterface(value) + }) + return v +} + +func valueInterface(lvalue lua.LValue) any { + switch lvalue.Type() { + case lua.LTTable: + isArray := lvalue.(*lua.LTable).RawGetInt(1) != lua.LNil + if isArray { + return arrayInterface(lvalue.(*lua.LTable)) + } + return objectInterface(lvalue.(*lua.LTable)) + case lua.LTString: + return lvalue.String() + case lua.LTNumber: + return int(lvalue.(lua.LNumber)) + case lua.LTBool: + return bool(lvalue.(lua.LBool)) + } + return nil +} + +func arrayInterface(lvalue *lua.LTable) any { + var v = make([]any, 0) + lvalue.ForEach(func(key, value lua.LValue) { + v = append(v, valueInterface(value)) + }) + + return v +} + +func unmarshalWorker(value lua.LValue, reflected reflect.Value) error { + + switch value.Type() { + case lua.LTTable: + reflected = indirect(reflected) + tagMap := make(map[string]int) + + switch reflected.Kind() { + case reflect.Interface: + // Decoding into nil interface? Switch to non-reflect code. + if reflected.NumMethod() == 0 { + result := valueInterface(value) + reflected.Set(reflect.ValueOf(result)) + } + // map[T1]T2 where T1 is string, an integer type + case reflect.Map: + t := reflected.Type() + keyType := t.Key() + // Map key must either have string kind, have an integer kind + switch keyType.Kind() { + case reflect.String, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + default: + return errors.New("unmarshal: unsupported map key type " + keyType.String()) + } + + if reflected.IsNil() { + reflected.Set(reflect.MakeMap(t)) + } + + var mapElem reflect.Value + + value.(*lua.LTable).ForEach(func(key, value lua.LValue) { + // Figure out field corresponding to key. + var subv reflect.Value + + elemType := t.Elem() + if !mapElem.IsValid() { + mapElem = reflect.New(elemType).Elem() + } else { + mapElem.SetZero() + } + + subv = mapElem + + unmarshalWorker(value, subv) + + var kv reflect.Value + switch keyType.Kind() { + case reflect.String: + kv = reflect.New(keyType).Elem() + kv.SetString(key.String()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + s := key.String() + n, err := strconv.ParseInt(s, 10, 64) + if err != nil { + break + } + kv = reflect.New(keyType).Elem() + kv.SetInt(n) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + s := key.String() + n, err := strconv.ParseUint(s, 10, 64) + if err != nil { + break + } + kv = reflect.New(keyType).Elem() + kv.SetUint(n) + default: + panic("unmarshal: Unexpected key type") // should never occur + } + if kv.IsValid() { + reflected.SetMapIndex(kv, subv) + } + + }) + case reflect.Slice: + i := 0 + + value.(*lua.LTable).ForEach(func(key, value lua.LValue) { + // Expand slice length, growing the slice if necessary. + if i >= reflected.Cap() { + reflected.Grow(1) + } + if i >= reflected.Len() { + reflected.SetLen(i + 1) + } + if i < reflected.Len() { + // Decode into element. + unmarshalWorker(value, reflected.Index(i)) + } else { + unmarshalWorker(value, reflect.Value{}) + } + i++ + }) + + // Truncate slice if necessary. + if i < reflected.Len() { + reflected.SetLen(i) + } + + if i == 0 { + reflected.Set(reflect.MakeSlice(reflected.Type(), 0, 0)) + } + case reflect.Struct: + for i := 0; i < reflected.NumField(); i++ { + fieldTypeField := reflected.Type().Field(i) + tag := fieldTypeField.Tag.Get("luai") + if tag != "" { + tagMap[tag] = i + } + } + + (value.(*lua.LTable)).ForEach(func(key, value lua.LValue) { + fieldName := key.String() + + field := reflected.FieldByName(fieldName) + + // if field is not found, try to find it by tag + if !field.IsValid() { + fieldIndex, ok := tagMap[fieldName] + if !ok { + return + } + field = reflected.Field(fieldIndex) + } + + if !field.IsValid() { + return + } + + unmarshalWorker(value, field) + }) + } + default: + switch reflected.Kind() { + case reflect.Interface: + // Decoding into nil interface? Switch to non-reflect code. + if reflected.NumMethod() == 0 { + result := valueInterface(value) + reflected.Set(reflect.ValueOf(result)) + } + default: + storeLiteral(reflected, value) + } + } + return nil +} + +func Unmarshal(value lua.LValue, v any) error { + reflected := reflect.ValueOf(v) + + if reflected.Kind() != reflect.Pointer || reflected.IsNil() { + return errors.New("unmarshal: value must be a pointer") + } + + return unmarshalWorker(value, reflected) +} diff --git a/internal/luai/encode.go b/internal/luai/encode.go new file mode 100644 index 00000000..f85eaf3d --- /dev/null +++ b/internal/luai/encode.go @@ -0,0 +1,108 @@ +/* + * Copyright 2024 Han Li and contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package luai + +import ( + "errors" + "reflect" + + lua "github.com/yuin/gopher-lua" +) + +func Marshal(state *lua.LState, v any) (lua.LValue, error) { + reflected := reflect.ValueOf(v) + if reflected.Kind() == reflect.Ptr { + reflected = reflected.Elem() + } + + if !reflected.IsValid() { + return lua.LNil, nil + } + + switch reflected.Kind() { + case reflect.Struct: + table := state.NewTable() + for i := 0; i < reflected.NumField(); i++ { + field := reflected.Field(i) + if field.Kind() == reflect.Ptr { + field = field.Elem() + } + + fieldType := reflected.Type().Field(i) + tag := fieldType.Tag.Get("luai") + if tag == "" { + tag = fieldType.Name + } + + if !field.IsValid() { + continue + } + + sub, err := Marshal(state, field.Interface()) + if err != nil { + return nil, err + } + + table.RawSetString(tag, sub) + } + return table, nil + case reflect.String: + return lua.LString(reflected.String()), nil + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return lua.LNumber(reflected.Int()), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return lua.LNumber(reflected.Uint()), nil + case reflect.Float32, reflect.Float64: + return lua.LNumber(reflected.Float()), nil + case reflect.Bool: + return lua.LBool(reflected.Bool()), nil + case reflect.Array, reflect.Slice: + table := state.NewTable() + for i := 0; i < reflected.Len(); i++ { + field := reflected.Index(i) + if !field.IsValid() { + continue + } + + value, err := Marshal(state, field.Interface()) + if err != nil { + return nil, err + } + table.RawSetInt(i+1, value) + } + return table, nil + case reflect.Map: + table := state.NewTable() + for _, key := range reflected.MapKeys() { + field := reflected.MapIndex(key) + if !field.IsValid() { + continue + } + + value, err := Marshal(state, field.Interface()) + if err != nil { + return nil, err + } + + table.RawSetString(key.String(), value) + } + return table, nil + default: + return nil, errors.New("marshal: unsupported type " + reflected.Kind().String() + " for reflected ") + } + +} diff --git a/internal/luai/encoding_test.go b/internal/luai/encoding_test.go new file mode 100644 index 00000000..ae304a5d --- /dev/null +++ b/internal/luai/encoding_test.go @@ -0,0 +1,313 @@ +/* + * Copyright 2024 Han Li and contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package luai + +import ( + "fmt" + "reflect" + "testing" + + "github.com/version-fox/vfox/internal/logger" + lua "github.com/yuin/gopher-lua" +) + +func setupSuite(tb testing.TB) func(tb testing.TB) { + logger.SetLevel(logger.DebugLevel) + + return func(tb testing.TB) { + logger.SetLevel(logger.InfoLevel) + } +} + +type testStruct struct { + Field1 string + Field2 int + Field3 bool +} + +type testStructTag struct { + Field1 string `luai:"field1"` + Field2 int `luai:"field2"` + Field3 bool `luai:"field3"` +} + +type complexStruct struct { + Field1 string + Field2 int + Field3 bool + SimpleStruct *testStruct + Struct testStructTag + Map map[string]interface{} + Slice []any +} + +func TestEncoding(t *testing.T) { + teardownSuite := setupSuite(t) + defer teardownSuite(t) + + m := map[string]interface{}{ + "key1": "value1", + "key2": 2, + "key3": true, + } + + s := []any{"value1", 2, true} + + t.Run("Struct", func(t *testing.T) { + luaVm := lua.NewState() + defer luaVm.Close() + + test := testStruct{ + Field1: "test", + Field2: 1, + Field3: true, + } + + _table, err := Marshal(luaVm, &test) + if err != nil { + t.Fatal(err) + } + + luaVm.SetGlobal("table", _table) + + if err := luaVm.DoString(` + assert(table.Field1 == "test") + assert(table.Field2 == 1) + assert(table.Field3 == true) + print("lua Struct done") + `); err != nil { + t.Fatal(err) + } + + struct2 := testStruct{} + err = Unmarshal(_table, &struct2) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(test, struct2) { + t.Errorf("expected %+v, got %+v", test, struct2) + } + }) + + t.Run("Struct with Tag", func(t *testing.T) { + luaVm := lua.NewState() + defer luaVm.Close() + + test := testStructTag{ + Field1: "test", + Field2: 1, + Field3: true, + } + + _table, err := Marshal(luaVm, &test) + if err != nil { + t.Fatal(err) + } + + table := _table.(*lua.LTable) + + luaVm.SetGlobal("table", table) + if err := luaVm.DoString(` + assert(table.field1 == "test") + assert(table.field2 == 1) + assert(table.field3 == true) + print("lua Struct with Tag done") + `); err != nil { + t.Fatal(err) + } + + struct2 := testStructTag{} + err = Unmarshal(table, &struct2) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(test, struct2) { + t.Errorf("expected %+v, got %+v", test, struct2) + } + }) + + t.Run("Support Map, Slice and Any", func(t *testing.T) { + L := lua.NewState() + defer L.Close() + table, err := Marshal(L, m) + if err != nil { + t.Fatalf("marshal map failed: %v", err) + } + L.SetGlobal("m", table) + if err := L.DoString(` + assert(m.key1 == "value1") + assert(m.key2 == 2) + assert(m.key3 == true) + print("lua Map done") + `); err != nil { + t.Errorf("map test failed: %v", err) + } + + slice, err := Marshal(L, s) + if err != nil { + t.Fatalf("marshal slice failed: %v", err) + } + + L.SetGlobal("s", slice) + if err := L.DoString(` + assert(s[1] == "value1") + assert(s[2] == 2) + assert(s[3] == true) + print("lua Slice done") + `); err != nil { + t.Errorf("slice test failed: %v", err) + } + + // Unmarshal + + // Test case for map + m2 := map[string]any{} + + fmt.Println("==== start unmarshal ====") + + err = Unmarshal(table, &m2) + if err != nil { + t.Fatalf("unmarshal map failed: %v", err) + } + + fmt.Printf("m2: %+v\n", m2) + + if !reflect.DeepEqual(m, m2) { + t.Errorf("expected %+v, got %+v", m, m2) + } + + // Test case for slice + s2 := []any{} + + err = Unmarshal(slice, &s2) + if err != nil { + t.Fatalf("unmarshal slice failed: %v", err) + } + + fmt.Printf("s2: %+v\n", s2) + + if !reflect.DeepEqual(s, s2) { + t.Errorf("expected %+v, got %+v", s, s2) + } + + var s3 any + err = Unmarshal(slice, &s3) + if err != nil { + t.Fatalf("unmarshal slice failed: %v", err) + } + + if !reflect.DeepEqual(s, s3) { + t.Errorf("expected %+v, got %+v", s, s3) + } + }) + + t.Run("MapSliceStructUnified", func(t *testing.T) { + L := lua.NewState() + defer L.Close() + + input := complexStruct{ + Field1: "value1", + Field2: 123, + Field3: true, + Struct: testStructTag{ + Field1: "value1", + Field2: 2, + Field3: true, + }, + Map: m, + Slice: s, + } + + table, err := Marshal(L, input) + if err != nil { + t.Fatalf("marshal map failed: %v", err) + } + + L.SetGlobal("m", table) + + if err := L.DoString(` + assert(m.Field1 == "value1") + assert(m.Field2 == 123) + assert(m.Field3 == true) + assert(m.Struct.field1 == "value1") + assert(m.Struct.field2 == 2) + assert(m.Struct.field3 == true) + assert(m.Map.key1 == "value1") + assert(m.Map.key2 == 2) + assert(m.Map.key3 == true) + assert(m.Slice[1] == "value1") + assert(m.Slice[2] == 2) + assert(m.Slice[3] == true) + print("lua MapSliceStructUnified done") + `); err != nil { + t.Errorf("map test failed: %v", err) + } + + // Unmarshal + output := complexStruct{} + err = Unmarshal(table, &output) + if err != nil { + t.Fatalf("unmarshal map failed: %v", err) + } + + isEqual := reflect.DeepEqual(input, output) + if !isEqual { + t.Fatalf("expected %+v, got %+v", input, output) + } + + fmt.Printf("output: %+v\n", output) + + if !reflect.DeepEqual(input, output) { + t.Errorf("expected %+v, got %+v", input, output) + } + }) + + t.Run("TableWithEmptyField", func(t *testing.T) { + L := lua.NewState() + defer L.Close() + + output := struct { + Field1 string `luai:"field1"` + Field2 *string `luai:"field2"` + }{} + + if err := L.DoString(` + return { + field1 = "value1", + } + `); err != nil { + t.Errorf("map test failed: %v", err) + } + + table := L.ToTable(-1) // returned value + L.Pop(1) + // Unmarshal + err := Unmarshal(table, &output) + if err != nil { + t.Fatalf("unmarshal map failed: %v", err) + } + fmt.Printf("output: %+v\n", output) + if output.Field1 != "value1" { + t.Errorf("expected %+v, got %+v", "value1", output.Field1) + } + if output.Field2 != nil { + t.Errorf("expected %+v, got %+v", nil, output.Field2) + } + }) +} diff --git a/internal/fixtures/preload.lua b/internal/luai/fixtures/preload.lua similarity index 100% rename from internal/fixtures/preload.lua rename to internal/luai/fixtures/preload.lua diff --git a/internal/luai/vm.go b/internal/luai/vm.go new file mode 100644 index 00000000..bcd9a288 --- /dev/null +++ b/internal/luai/vm.go @@ -0,0 +1,81 @@ +/* + * Copyright 2024 Han Li and contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package luai + +import ( + _ "embed" + + "github.com/version-fox/vfox/internal/config" + "github.com/version-fox/vfox/internal/module" + lua "github.com/yuin/gopher-lua" +) + +//go:embed fixtures/preload.lua +var preloadScript string + +type LuaVM struct { + Instance *lua.LState +} + +func NewLuaVM() *LuaVM { + instance := lua.NewState() + + return &LuaVM{ + Instance: instance, + } +} + +type PrepareOptions struct { + Config *config.Config +} + +func (vm *LuaVM) Prepare(options *PrepareOptions) error { + if err := vm.Instance.DoString(preloadScript); err != nil { + return err + } + module.Preload(vm.Instance, options.Config) + + return nil +} + +func (vm *LuaVM) ReturnedValue() *lua.LTable { + table := vm.Instance.ToTable(-1) // returned value + vm.Instance.Pop(1) // remove received value + return table +} + +func (vm *LuaVM) CallFunction(function lua.LValue, args ...lua.LValue) error { + if err := vm.Instance.CallByParam(lua.P{ + Fn: function.(*lua.LFunction), + NRet: 1, + Protect: true, + }, args...); err != nil { + return err + } + return nil +} + +func (vm *LuaVM) GetTableString(table *lua.LTable, key string) string { + if value := table.RawGetString(key); value.Type() != lua.LTNil { + return value.String() + } + return "" +} + +func (vm *LuaVM) Close() { + vm.Instance.Close() +} diff --git a/internal/package.go b/internal/package.go index a8519604..bea73d13 100644 --- a/internal/package.go +++ b/internal/package.go @@ -16,7 +16,9 @@ package internal -import "path/filepath" +import ( + "path/filepath" +) type Package struct { Main *Info @@ -24,10 +26,10 @@ type Package struct { } type Info struct { - Name string - Version Version - Path string - Note string + Name string `luai:"name"` + Version Version `luai:"version"` + Path string `luai:"path"` + Note string `luai:"note"` Checksum *Checksum } diff --git a/internal/plugin.go b/internal/plugin.go index 5290ad47..89e79111 100644 --- a/internal/plugin.go +++ b/internal/plugin.go @@ -18,19 +18,18 @@ package internal import ( _ "embed" + "errors" "fmt" "path/filepath" "regexp" "strings" "github.com/version-fox/vfox/internal/env" - "github.com/version-fox/vfox/internal/module" + "github.com/version-fox/vfox/internal/logger" + "github.com/version-fox/vfox/internal/luai" lua "github.com/yuin/gopher-lua" ) -//go:embed fixtures/preload.lua -var preloadScript string - const ( LuaPluginObjKey = "PLUGIN" OsType = "OS_TYPE" @@ -38,90 +37,85 @@ const ( ) type LuaPlugin struct { - state *lua.LState + vm *luai.LuaVM pluginObj *lua.LTable // plugin source path Filepath string // plugin filename, this is also alias name, sdk-name Filename string // The name defined inside the plugin - Name string - Author string - Version string - Description string - UpdateUrl string - MinRuntimeVersion string + + LuaPluginInfo } func (l *LuaPlugin) checkValid() error { - if l.state == nil { + if l.vm == nil || l.vm.Instance == nil { return fmt.Errorf("lua vm is nil") } - obj := l.pluginObj - if obj.RawGetString("Available") == lua.LNil { + + if !l.HasFunction("Available") { return fmt.Errorf("[Available] function not found") } - if obj.RawGetString("PreInstall") == lua.LNil { + if !l.HasFunction("PreInstall") { return fmt.Errorf("[PreInstall] function not found") } - if obj.RawGetString("EnvKeys") == lua.LNil { + if !l.HasFunction("EnvKeys") { return fmt.Errorf("[EnvKeys] function not found") } return nil } func (l *LuaPlugin) Close() { - l.state.Close() + l.vm.Close() } func (l *LuaPlugin) Available() ([]*Package, error) { - L := l.state - ctxTable := L.NewTable() - L.SetField(ctxTable, "runtimeVersion", lua.LString(RuntimeVersion)) - if err := L.CallByParam(lua.P{ - Fn: l.pluginObj.RawGetString("Available").(*lua.LFunction), - NRet: 1, - Protect: true, - }, l.pluginObj, ctxTable); err != nil { + L := l.vm.Instance + ctxTable, err := luai.Marshal(L, AvailableHookCtx{ + RuntimeVersion: RuntimeVersion, + }) + + if err != nil { + return nil, err + } + + if err = l.CallFunction("Available", ctxTable); err != nil { return nil, err } - table := l.returnedValue() + table := l.vm.ReturnedValue() if table == nil || table.Type() == lua.LTNil { return []*Package{}, nil } - var err error + + hookResult := AvailableHookResult{} + err = luai.Unmarshal(table, &hookResult) + if err != nil { + return nil, errors.New("failed to unmarshal the return value: " + err.Error()) + } + var result []*Package - table.ForEach(func(key lua.LValue, value lua.LValue) { - kvTable, ok := value.(*lua.LTable) - if !ok { - err = fmt.Errorf("the return value is not a table") - return - } - mainSdk, err := l.parseInfo(kvTable) - if err != nil { - return + + for _, item := range hookResult { + mainSdk := &Info{ + Name: l.Name, + Version: Version(item.Version), + Note: item.Note, } - mainSdk.Name = l.Name + var additionalArr []*Info - additional := kvTable.RawGetString("addition") - if tb, ok := additional.(*lua.LTable); ok && tb.Len() != 0 { - additional.(*lua.LTable).ForEach(func(key lua.LValue, value lua.LValue) { - itemTable, ok := value.(*lua.LTable) - if !ok { - err = fmt.Errorf("the return value is not a table") - return - } - item, err := l.parseInfo(itemTable) - if err != nil { - return - } - if item.Name == "" { - err = fmt.Errorf("additional file no name provided") - return - } - additionalArr = append(additionalArr, item) + + for i, addition := range item.Addition { + if addition.Name == "" { + logger.Errorf("[Available] additional file %d no name provided", i+1) + } + + additionalArr = append(additionalArr, &Info{ + Name: addition.Name, + Version: Version(addition.Version), + Path: addition.Path, + Note: addition.Note, }) } @@ -129,85 +123,52 @@ func (l *LuaPlugin) Available() ([]*Package, error) { Main: mainSdk, Additions: additionalArr, }) - - }) - if err != nil { - return nil, err } return result, nil } -func (l *LuaPlugin) Checksum(table *lua.LTable) *Checksum { - checksum := &Checksum{} - sha256 := table.RawGetString("sha256") - md5 := table.RawGetString("md5") - sha512 := table.RawGetString("sha512") - sha1 := table.RawGetString("sha1") - if sha256.Type() != lua.LTNil { - checksum.Value = sha256.String() - checksum.Type = "sha256" - } else if md5.Type() != lua.LTNil { - checksum.Value = md5.String() - checksum.Type = "md5" - } else if sha1.Type() != lua.LTNil { - checksum.Value = sha1.String() - checksum.Type = "sha1" - } else if sha512.Type() != lua.LTNil { - checksum.Value = sha512.String() - checksum.Type = "sha512" - } else { - return NoneChecksum - } - return checksum -} - func (l *LuaPlugin) PreInstall(version Version) (*Package, error) { - L := l.state - ctxTable := L.NewTable() - L.SetField(ctxTable, "version", lua.LString(version)) - L.SetField(ctxTable, "runtimeVersion", lua.LString(RuntimeVersion)) - - if err := L.CallByParam(lua.P{ - Fn: l.pluginObj.RawGetString("PreInstall").(*lua.LFunction), - NRet: 1, - Protect: true, - }, l.pluginObj, ctxTable); err != nil { + L := l.vm.Instance + ctxTable, err := luai.Marshal(L, PreInstallHookCtx{ + Version: string(version), + RuntimeVersion: RuntimeVersion, + }) + + if err != nil { + return nil, err + } + + if err = l.CallFunction("PreInstall", ctxTable); err != nil { return nil, err } - table := l.returnedValue() + table := l.vm.ReturnedValue() if table == nil || table.Type() == lua.LTNil { return nil, nil } - mainSdk, err := l.parseInfo(table) + + result := PreInstallHookResult{} + + err = luai.Unmarshal(table, &result) + if err != nil { + return nil, err + } + + mainSdk, err := result.Info() if err != nil { return nil, err } mainSdk.Name = l.Name + var additionalArr []*Info - additions := table.RawGetString("addition") - if tb, ok := additions.(*lua.LTable); ok && tb.Len() != 0 { - var err error - additions.(*lua.LTable).ForEach(func(key lua.LValue, value lua.LValue) { - kvTable, ok := value.(*lua.LTable) - if !ok { - err = fmt.Errorf("the return value is not a table") - return - } - info, err := l.parseInfo(kvTable) - if err != nil { - return - } - if info.Name == "" { - err = fmt.Errorf("additional file no name provided") - return - } - additionalArr = append(additionalArr, info) - }) - if err != nil { - return nil, err + + for i, addition := range result.Addition { + if addition.Name == "" { + return nil, fmt.Errorf("[PreInstall] additional file %d no name provided", i+1) } + + additionalArr = append(additionalArr, addition.Info()) } return &Package{ @@ -216,59 +177,29 @@ func (l *LuaPlugin) PreInstall(version Version) (*Package, error) { }, nil } -func (l *LuaPlugin) parseInfo(table *lua.LTable) (*Info, error) { - versionLua := table.RawGetString("version") - if versionLua == lua.LNil { - return nil, fmt.Errorf("no version number provided") - } - var ( - path string - note string - name string - version string - ) - version = versionLua.String() - - if urlLua := table.RawGetString("url"); urlLua != lua.LNil { - path = urlLua.String() - } - if noteLua := table.RawGetString("note"); noteLua != lua.LNil { - note = noteLua.String() - } - if nameLua := table.RawGetString("name"); nameLua != lua.LNil { - name = nameLua.String() - } - checksum := l.Checksum(table) - return &Info{ - Name: name, - Version: Version(version), - Path: path, - Note: note, - Checksum: checksum, - }, nil -} - func (l *LuaPlugin) PostInstall(rootPath string, sdks []*Info) error { - L := l.state - sdkArr := L.NewTable() + L := l.vm.Instance + + if !l.HasFunction("PostInstall") { + return nil + } + + ctx := &PostInstallHookCtx{ + RuntimeVersion: RuntimeVersion, + RootPath: rootPath, + SdkInfo: make(map[string]*Info), + } + for _, v := range sdks { - sdkTable := l.createSdkInfoTable(v) - L.SetField(sdkArr, v.Name, sdkTable) + ctx.SdkInfo[v.Name] = v } - ctxTable := L.NewTable() - L.SetField(ctxTable, "sdkInfo", sdkArr) - L.SetField(ctxTable, "runtimeVersion", lua.LString(RuntimeVersion)) - L.SetField(ctxTable, "rootPath", lua.LString(rootPath)) - function := l.pluginObj.RawGetString("PostInstall") - if function.Type() == lua.LTNil { - return nil + ctxTable, err := luai.Marshal(L, ctx) + if err != nil { + return err } - if err := L.CallByParam(lua.P{ - Fn: function.(*lua.LFunction), - NRet: 1, - Protect: true, - }, l.pluginObj, ctxTable); err != nil { + + if err = l.CallFunction("PostInstall", ctxTable); err != nil { return err } @@ -276,143 +207,132 @@ func (l *LuaPlugin) PostInstall(rootPath string, sdks []*Info) error { } func (l *LuaPlugin) EnvKeys(sdkPackage *Package) (env.Envs, error) { - L := l.state + L := l.vm.Instance mainInfo := sdkPackage.Main - sdkArr := L.NewTable() + + ctx := &EnvKeysHookCtx{ + // TODO Will be deprecated in future versions + Path: mainInfo.Path, + RuntimeVersion: RuntimeVersion, + Main: mainInfo, + SdkInfo: make(map[string]*Info), + } + for _, v := range sdkPackage.Additions { - sdkTable := l.createSdkInfoTable(v) - L.SetField(sdkArr, v.Name, sdkTable) - } - ctxTable := L.NewTable() - sdkTable := l.createSdkInfoTable(mainInfo) - L.SetField(ctxTable, "main", sdkTable) - L.SetField(ctxTable, "sdkInfo", sdkArr) - L.SetField(ctxTable, "runtimeVersion", lua.LString(RuntimeVersion)) - // TODO Will be deprecated in future versions - L.SetField(ctxTable, "path", lua.LString(mainInfo.Path)) - if err := L.CallByParam(lua.P{ - Fn: l.pluginObj.RawGetString("EnvKeys"), - NRet: 1, - Protect: true, - }, l.pluginObj, ctxTable); err != nil { + ctx.SdkInfo[v.Name] = v + } + + ctxTable, err := luai.Marshal(L, ctx) + if err != nil { return nil, err } - table := l.returnedValue() + if err = l.CallFunction("EnvKeys", ctxTable); err != nil { + return nil, err + } + + table := l.vm.ReturnedValue() + if table == nil || table.Type() == lua.LTNil || table.Len() == 0 { return nil, fmt.Errorf("no environment variables provided") } - var err error + envKeys := make(env.Envs) - table.ForEach(func(key lua.LValue, value lua.LValue) { - kvTable, ok := value.(*lua.LTable) - if !ok { - err = fmt.Errorf("the return value is not a table") - return - } - key = kvTable.RawGetString("key") - value = kvTable.RawGetString("value") - s := value.String() - envKeys[key.String()] = &s - }) + + var items []*EnvKeysHookResultItem + err = luai.Unmarshal(table, &items) if err != nil { return nil, err } - return envKeys, nil -} - -func (l *LuaPlugin) getTableField(table *lua.LTable, fieldName string) (lua.LValue, error) { - value := table.RawGetString(fieldName) - if value.Type() == lua.LTNil { - return nil, fmt.Errorf("field '%s' not found", fieldName) + for _, item := range items { + envKeys[item.Key] = &item.Value } - return value, nil -} -func (l *LuaPlugin) returnedValue() *lua.LTable { - table := l.state.ToTable(-1) // returned value - l.state.Pop(1) // remove received value - return table + return envKeys, nil } func (l *LuaPlugin) Label(version string) string { return fmt.Sprintf("%s@%s", l.Name, version) } -func (l *LuaPlugin) createSdkInfoTable(info *Info) *lua.LTable { - L := l.state - sdkTable := L.NewTable() - L.SetField(sdkTable, "name", lua.LString(info.Name)) - L.SetField(sdkTable, "version", lua.LString(info.Version)) - L.SetField(sdkTable, "path", lua.LString(info.Path)) - L.SetField(sdkTable, "note", lua.LString(info.Note)) - return sdkTable -} - func (l *LuaPlugin) HasFunction(name string) bool { return l.pluginObj.RawGetString(name) != lua.LNil } func (l *LuaPlugin) PreUse(version Version, previousVersion Version, scope UseScope, cwd string, installedSdks []*Package) (Version, error) { - L := l.state - lInstalledSdks := L.NewTable() + L := l.vm.Instance + + ctx := PreUseHookCtx{ + RuntimeVersion: RuntimeVersion, + Cwd: cwd, + Scope: scope.String(), + Version: string(version), + PreviousVersion: string(previousVersion), + InstalledSdks: make(map[string]*Info), + } + for _, v := range installedSdks { - sdkTable := l.createSdkInfoTable(v.Main) - L.SetField(lInstalledSdks, string(v.Main.Version), sdkTable) + lSdk := v.Main + ctx.InstalledSdks[string(lSdk.Version)] = lSdk } - ctxTable := L.NewTable() - L.SetField(ctxTable, "installedSdks", lInstalledSdks) - L.SetField(ctxTable, "runtimeVersion", lua.LString(RuntimeVersion)) - L.SetField(ctxTable, "cwd", lua.LString(cwd)) - L.SetField(ctxTable, "scope", lua.LString(scope.String())) - L.SetField(ctxTable, "version", lua.LString(version)) - L.SetField(ctxTable, "previousVersion", lua.LString(previousVersion)) + logger.Debugf("PreUseHookCtx: %+v", ctx) - function := l.pluginObj.RawGetString("PreUse") - if function.Type() == lua.LTNil { + ctxTable, err := luai.Marshal(L, ctx) + if err != nil { + return "", err + } + + if !l.HasFunction("PreUse") { return "", nil } - if err := L.CallByParam(lua.P{ - Fn: function.(*lua.LFunction), - NRet: 1, - Protect: true, - }, l.pluginObj, ctxTable); err != nil { + + if err = l.CallFunction("PreUse", ctxTable); err != nil { return "", err } - table := l.returnedValue() + table := l.vm.ReturnedValue() if table == nil || table.Type() == lua.LTNil { return "", nil } - luaVer, err := l.getTableField(table, "version") - if err != nil { - // ignore version field not found - return "", nil + result := &PreUseHookResult{} + + if err := luai.Unmarshal(table, result); err != nil { + return "", err } - return Version(luaVer.String()), nil + return Version(result.Version), nil +} + +func (l *LuaPlugin) CallFunction(funcName string, args ...lua.LValue) error { + logger.Debugf("CallFunction: %s\n", funcName) + if err := l.vm.CallFunction(l.pluginObj.RawGetString(funcName), append([]lua.LValue{l.pluginObj}, args...)...); err != nil { + return err + } + return nil } func NewLuaPlugin(content, path string, manager *Manager) (*LuaPlugin, error) { - luaVMInstance := lua.NewState() - module.Preload(luaVMInstance, manager.Config) + vm := luai.NewLuaVM() - if err := luaVMInstance.DoString(preloadScript); err != nil { + if err := vm.Prepare(&luai.PrepareOptions{ + Config: manager.Config, + }); err != nil { return nil, err } - if err := luaVMInstance.DoString(content); err != nil { + if err := vm.Instance.DoString(content); err != nil { return nil, err } + // !!!! Must be set after loading the script to prevent overwriting! // set OS_TYPE and ARCH_TYPE - luaVMInstance.SetGlobal(OsType, lua.LString(manager.osType)) - luaVMInstance.SetGlobal(ArchType, lua.LString(manager.archType)) + vm.Instance.SetGlobal(OsType, lua.LString(manager.osType)) + vm.Instance.SetGlobal(ArchType, lua.LString(manager.archType)) - pluginObj := luaVMInstance.GetGlobal(LuaPluginObjKey) + pluginObj := vm.Instance.GetGlobal(LuaPluginObjKey) if pluginObj.Type() == lua.LTNil { return nil, fmt.Errorf("plugin object not found") } @@ -420,7 +340,7 @@ func NewLuaPlugin(content, path string, manager *Manager) (*LuaPlugin, error) { PLUGIN := pluginObj.(*lua.LTable) source := &LuaPlugin{ - state: luaVMInstance, + vm: vm, pluginObj: PLUGIN, Filepath: path, Filename: strings.TrimSuffix(filepath.Base(path), filepath.Ext(path)), @@ -430,28 +350,20 @@ func NewLuaPlugin(content, path string, manager *Manager) (*LuaPlugin, error) { return nil, err } - if name := PLUGIN.RawGetString("name"); name.Type() == lua.LTNil { - return nil, fmt.Errorf("no plugin name provided") - } else { - source.Name = name.String() - if !isValidName(source.Name) { - return nil, fmt.Errorf("invalid plugin name") - } - } - if version := PLUGIN.RawGetString("version"); version.Type() != lua.LTNil { - source.Version = version.String() - } - if description := PLUGIN.RawGetString("description"); description.Type() != lua.LTNil { - source.Description = description.String() - } - if updateUrl := PLUGIN.RawGetString("updateUrl"); updateUrl.Type() != lua.LTNil { - source.UpdateUrl = updateUrl.String() + pluginInfo := LuaPluginInfo{} + err := luai.Unmarshal(PLUGIN, &pluginInfo) + if err != nil { + return nil, err } - if author := PLUGIN.RawGetString("author"); author.Type() != lua.LTNil { - source.Author = author.String() + + source.LuaPluginInfo = pluginInfo + + if !isValidName(source.Name) { + return nil, fmt.Errorf("invalid plugin name") } - if minRuntimeVersion := PLUGIN.RawGetString("minRuntimeVersion"); minRuntimeVersion.Type() != lua.LTNil { - source.MinRuntimeVersion = minRuntimeVersion.String() + + if source.Name == "" { + return nil, fmt.Errorf("no plugin name provided") } return source, nil } diff --git a/internal/plugin_test.go b/internal/plugin_test.go index 79aca6c7..7a6fb92d 100644 --- a/internal/plugin_test.go +++ b/internal/plugin_test.go @@ -1,16 +1,74 @@ package internal import ( + "strings" "testing" _ "embed" + + "github.com/version-fox/vfox/internal/logger" ) //go:embed testdata/plugins/java.lua var pluginContent string var pluginPath = "testdata/plugins/java.lua" +func setupSuite(tb testing.TB) func(tb testing.TB) { + logger.SetLevel(logger.DebugLevel) + + return func(tb testing.TB) { + logger.SetLevel(logger.InfoLevel) + } +} + func TestPlugin(t *testing.T) { + teardownSuite := setupSuite(t) + defer teardownSuite(t) + + t.Run("NewLuaPlugin", func(t *testing.T) { + manager := NewSdkManager() + plugin, err := NewLuaPlugin(pluginContent, pluginPath, manager) + if err != nil { + t.Fatal(err) + } + + if plugin == nil { + t.Fatalf("expected plugin to be set, got nil") + } + + if plugin.Filename != "java" { + t.Errorf("expected filename 'java', got '%s'", plugin.Filename) + } + + if plugin.Filepath != pluginPath { + t.Errorf("expected filepath '%s', got '%s'", pluginPath, plugin.Filepath) + } + + if plugin.Name != "java" { + t.Errorf("expected name 'java', got '%s'", plugin.Name) + } + + if plugin.Version != "0.0.1" { + t.Errorf("expected version '0.0.1', got '%s'", plugin.Version) + } + + if plugin.Description != "xxx" { + t.Errorf("expected description 'xxx', got '%s'", plugin.Description) + } + + if plugin.Author != "Lihan" { + t.Errorf("expected author 'Lihan', got '%s'", plugin.Author) + } + + if plugin.UpdateUrl != "{URL}/sdk.lua" { + t.Errorf("expected update url '{URL}/sdk.lua', got '%s'", plugin.UpdateUrl) + } + + if plugin.MinRuntimeVersion != "0.2.2" { + t.Errorf("expected min runtime version '0.2.2', got '%s'", plugin.MinRuntimeVersion) + } + }) + t.Run("Available", func(t *testing.T) { manager := NewSdkManager() plugin, err := NewLuaPlugin(pluginContent, pluginPath, manager) @@ -28,6 +86,98 @@ func TestPlugin(t *testing.T) { } }) + t.Run("PreInstall", func(t *testing.T) { + manager := NewSdkManager() + plugin, err := NewLuaPlugin(pluginContent, pluginPath, manager) + if err != nil { + t.Fatal(err) + } + + pkg, err := plugin.PreInstall(Version("9.0.0")) + if err != nil { + t.Fatal(err) + } + + Main := pkg.Main + + if Main.Version != "version" { + t.Errorf("expected version 'version', got '%s'", Main.Version) + } + + if Main.Path != "xxx" { + t.Errorf("expected path 'xxx', got '%s'", Main.Path) + } + + // checksum should be existed + if Main.Checksum == nil { + t.Errorf("expected checksum to be set, got nil") + } + + if Main.Checksum.Type != "sha256" { + t.Errorf("expected checksum type 'sha256', got '%s'", Main.Checksum.Type) + } + + if Main.Checksum.Value != "xxx" { + t.Errorf("expected checksum value 'xxx', got '%s'", Main.Checksum.Value) + } + + if len(pkg.Additions) != 1 { + t.Errorf("expected 1 addition, got %d", len(pkg.Additions)) + } + + addition := pkg.Additions[0] + + if addition.Path != "xxx" { + t.Errorf("expected path 'xxx', got '%s'", addition.Path) + } + + if addition.Checksum == nil { + t.Errorf("expected checksum to be set, got nil") + } + }) + + t.Run("EnvKeys", func(t *testing.T) { + manager := NewSdkManager() + + plugin, err := NewLuaPlugin(pluginContent, pluginPath, manager) + if err != nil { + t.Fatal(err) + } + + keys, err := plugin.EnvKeys(&Package{ + Main: &Info{ + Name: "java", + Version: "1.0.0", + Path: "/path/to/java", + Note: "xxxx", + }, + Additions: []*Info{ + { + Name: "sdk-name", + Version: "9.0.0", + Path: "/path/to/sdk", + Note: "xxxx", + }, + }, + }) + if err != nil { + t.Fatal(err) + } + + javaHome := keys["JAVA_HOME"] + if *javaHome == "" { + t.Errorf("expected JAVA_HOME to be set, got '%s'", *javaHome) + } + path := keys["PATH"] + if *path == "" { + t.Errorf("expected PATH to be set, got '%s'", *path) + } + + if !strings.HasSuffix(*path, "/bin") { + t.Errorf("expected PATH to end with '/bin', got '%s'", *path) + } + }) + t.Run("PreUse", func(t *testing.T) { manager := NewSdkManager() diff --git a/internal/sdk.go b/internal/sdk.go index 58012b13..7a1c2b40 100644 --- a/internal/sdk.go +++ b/internal/sdk.go @@ -32,6 +32,7 @@ import ( "github.com/schollz/progressbar/v3" "github.com/version-fox/vfox/internal/env" + "github.com/version-fox/vfox/internal/logger" "github.com/version-fox/vfox/internal/shell" "github.com/pterm/pterm" @@ -230,6 +231,7 @@ func (b *Sdk) EnvKeys(version Version) (env.Envs, error) { func (b *Sdk) PreUse(version Version, scope UseScope) (Version, error) { if !b.Plugin.HasFunction("PreUse") { + logger.Debug("plugin does not have PreUse function") return version, nil } @@ -253,6 +255,8 @@ func (b *Sdk) Use(version Version, scope UseScope) error { scope = Global } + logger.Debugf("use sdk version: %s\n", string(version)) + version, err := b.PreUse(version, scope) if err != nil { return err diff --git a/internal/testdata/plugins/java.lua b/internal/testdata/plugins/java.lua index 78dd6e30..944b41c3 100644 --- a/internal/testdata/plugins/java.lua +++ b/internal/testdata/plugins/java.lua @@ -34,7 +34,7 @@ function PLUGIN:PreInstall(ctx) local runtimeVersion = ctx.runtimeVersion return { --- Version number - version = "xxx", + version = "version", --- remote URL or local file path [optional] url = "xxx", --- SHA256 checksum [optional]