diff --git a/resolver.go b/resolver.go index 21b26bb..300a08c 100644 --- a/resolver.go +++ b/resolver.go @@ -180,7 +180,7 @@ func (r *resolver) resolve(f ResolveCallback) error { // check if all ResolveValue fields were resolved. if _, ok := fieldValue.(ResolveValue); ok { if rv, ok := ctx.resolved[fieldName]; ok { - if rgen, ok := fieldValue.(ResolveGenerate); ok { + if rgen, ok := fieldValue.(*ResolveGenerate); ok { rv, err = r.parseResolvedValue(rgen.Type, rv) if err != nil { return errors.Join(ResolveCallbackError, diff --git a/resolver_test.go b/resolver_test.go index 8e0981d..b562cd2 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -5,6 +5,7 @@ import ( "testing" "testing/fstest" + "github.com/google/uuid" "gotest.tools/v3/assert" is "gotest.tools/v3/assert/cmp" ) @@ -330,3 +331,71 @@ func TestResolveReturnResolved(t *testing.T) { assert.Assert(t, is.Len(retData.Tables["tags"].Rows, 1)) assert.Equal(t, 935, retData.Tables["tags"].Rows[0].Fields["tag_id"]) } + +func TestResolveGeneratedWithType(t *testing.T) { + provider := NewFSFileProvider(fstest.MapFS{ + "users.dbf.yaml": &fstest.MapFile{ + Data: []byte(`tags: + rows: + - tag_id: !dbfexpr generated:int + tag_name: "All" +`), + }, + }) + + data, err := Load(provider) + assert.NilError(t, err) + + rowCount := map[string]int{} + + resolved, err := Resolve(data, func(ctx ResolveContext, fields map[string]any) error { + rowCount[ctx.TableID()]++ + assert.DeepEqual(t, &ResolveGenerate{Type: "int"}, fields["tag_id"]) + ctx.ResolveField("tag_id", "45") + return nil + }) + assert.NilError(t, err) + + assert.DeepEqual(t, map[string]int{ + "tags": 1, + }, rowCount) + + assert.DeepEqual(t, map[string]any{ + "tag_id": int64(45), + "tag_name": "All", + }, resolved.Tables["tags"].Rows[0].Fields) +} + +func TestResolveGeneratedWithUUIDType(t *testing.T) { + provider := NewFSFileProvider(fstest.MapFS{ + "users.dbf.yaml": &fstest.MapFile{ + Data: []byte(`tags: + rows: + - tag_id: !dbfexpr generated:uuid + tag_name: "All" +`), + }, + }) + + data, err := Load(provider) + assert.NilError(t, err) + + rowCount := map[string]int{} + + resolved, err := Resolve(data, func(ctx ResolveContext, fields map[string]any) error { + rowCount[ctx.TableID()]++ + assert.DeepEqual(t, &ResolveGenerate{Type: "uuid"}, fields["tag_id"]) + ctx.ResolveField("tag_id", "305e1f2b-dfea-4939-862a-069abace0a40") + return nil + }, WithResolvedValueParser(&ResolvedValueParserUUID{})) + assert.NilError(t, err) + + assert.DeepEqual(t, map[string]int{ + "tags": 1, + }, rowCount) + + assert.DeepEqual(t, map[string]any{ + "tag_id": uuid.MustParse("305e1f2b-dfea-4939-862a-069abace0a40"), + "tag_name": "All", + }, resolved.Tables["tags"].Rows[0].Fields) +} diff --git a/resolver_value_parser_default.go b/resolver_value_parser_default.go new file mode 100644 index 0000000..e2dfbb6 --- /dev/null +++ b/resolver_value_parser_default.go @@ -0,0 +1,25 @@ +package debefix + +import ( + "fmt" + + "github.com/google/uuid" +) + +// ResolvedValueParserUUID is a [ResolvedValueParser] to parse "uuid" type to [uuid.UUID]. +type ResolvedValueParserUUID struct { +} + +func (r ResolvedValueParserUUID) Parse(typ string, value any) (bool, any, error) { + if typ != "uuid" { + return false, nil, nil + } + + switch vv := value.(type) { + case uuid.UUID: + return true, vv, nil + default: + v, err := uuid.Parse(fmt.Sprint(value)) + return true, v, err + } +}