diff --git a/cmd/schemagen/keyspace.tmpl b/cmd/schemagen/keyspace.tmpl index cddb23d..22f18b8 100644 --- a/cmd/schemagen/keyspace.tmpl +++ b/cmd/schemagen/keyspace.tmpl @@ -41,6 +41,7 @@ var ( {{- $type_name := .Name | camelize}} {{- $field_types := .FieldTypes}} type {{$type_name}}UserType struct { + gocqlx.UDT {{- range $index, $element := .FieldNames}} {{- $type := index $field_types $index}} {{. | camelize}} {{typeToString $type | mapScyllaToGoType}} diff --git a/cmd/schemagen/map_types.go b/cmd/schemagen/map_types.go index 4f59554..ebc04d5 100644 --- a/cmd/schemagen/map_types.go +++ b/cmd/schemagen/map_types.go @@ -33,54 +33,168 @@ var types = map[string]string{ "varint": "int64", } -func mapScyllaToGoType(s string) string { - frozenRegex := regexp.MustCompile(`frozen<([a-z]*)>`) - match := frozenRegex.FindAllStringSubmatch(s, -1) - if match != nil { - s = match[0][1] - } +type tokenStyle int - mapRegex := regexp.MustCompile(`map<([a-z]*), ([a-z]*)>`) - setRegex := regexp.MustCompile(`set<([a-z]*)>`) - listRegex := regexp.MustCompile(`list<([a-z]*)>`) - tupleRegex := regexp.MustCompile(`tuple<(?:([a-z]*),? ?)*>`) - match = mapRegex.FindAllStringSubmatch(s, -1) - if match != nil { - key := match[0][1] - value := match[0][2] +const ( + FrozenToken tokenStyle = iota + MapToken + SetToken + ListToken + TupleToken + CommaToken + AnchorToken +) - return "map[" + types[key] + "]" + types[value] - } +type token struct { + style tokenStyle + count int +} - match = setRegex.FindAllStringSubmatch(s, -1) - if match != nil { - key := match[0][1] +func (ts tokenStyle) String() string { + switch ts { + case FrozenToken: + return "frozen" + case MapToken: + return "map" + case SetToken: + return "set" + case ListToken: + return "list" + case TupleToken: + return "tuple" + case CommaToken: + return "comma" + case AnchorToken: + return "anchor" + default: + return "unknown" + } +} - return "[]" + types[key] +func (ts tokenStyle) format(values ...string) (string, error) { + l := len(values) + switch ts { + case FrozenToken: + if l != 1 { + return "", fmt.Errorf("Invalid values count=%d for %s", l, ts) + } + return values[0], nil + case MapToken: + if l != 2 { + return "", fmt.Errorf("Invalid values count=%d for %s", l, ts) + } + return "map[" + values[0] + "]" + values[1], nil + case SetToken: + if l != 1 { + return "", fmt.Errorf("Invalid values count=%d for %s", l, ts) + } + return "[]" + values[0], nil + case ListToken: + if l != 1 { + return "", fmt.Errorf("Invalid values count=%d for %s", l, ts) + } + return "[]" + values[0], nil + case TupleToken: + if l == 0 { + return "", fmt.Errorf("Invalid values count=%d for %s", l, ts) + } + tupleStr := "struct {\n" + for i, v := range values { + tupleStr = tupleStr + "\t\tField" + strconv.Itoa(i+1) + " " + v + "\n" + } + tupleStr = tupleStr + "\t}" + return tupleStr, nil + default: + return "", fmt.Errorf("Invalid token type: %s", ts) } +} - match = listRegex.FindAllStringSubmatch(s, -1) - if match != nil { - key := match[0][1] +func parseToken(s string) (*token, string) { + regexps := make(map[tokenStyle]*regexp.Regexp) + regexps[FrozenToken] = regexp.MustCompile(`^frozen<(.*)$`) + regexps[MapToken] = regexp.MustCompile(`^map<(.*)$`) + regexps[SetToken] = regexp.MustCompile(`^set<(.*)`) + regexps[ListToken] = regexp.MustCompile(`^list<(.*)$`) + regexps[TupleToken] = regexp.MustCompile(`^tuple<(.*)$`) + regexps[CommaToken] = regexp.MustCompile(`^,(.*)$`) + regexps[AnchorToken] = regexp.MustCompile(`^>(.*)$`) - return "[]" + types[key] + for tokenStyle, tokenRegexp := range regexps { + match := tokenRegexp.FindStringSubmatch(s) + if match != nil { + return &token{tokenStyle, 0}, match[1] + } } - match = tupleRegex.FindAllStringSubmatch(s, -1) - if match != nil { - tuple := match[0][0] - subStr := tuple[6 : len(tuple)-1] - types := strings.Split(subStr, ", ") + return nil, s +} + +func parsePolishNotation(s string) (*Stack, error) { + tokenStack := NewStack() + notation := NewStack() + left := s - typeStr := "struct {\n" - for i, t := range types { - typeStr = typeStr + "\t\tField" + strconv.Itoa(i+1) + " " + mapScyllaToGoType(t) + "\n" + for { + left = strings.TrimSpace(left) + if len(left) == 0 { + break } - typeStr = typeStr + "\t}" - return typeStr + var t *token + t, left = parseToken(left) + if t != nil { + switch t.style { + case CommaToken: + top, err := tokenStack.top() + if err != nil { + return nil, err + } + v, ok := top.(*token) + if !ok { + return nil, fmt.Errorf("Invalid type: %T", v) + } + v.count += 1 + case AnchorToken: + prev, err := tokenStack.pop() + if err != nil { + return nil, err + } + v, ok := prev.(*token) + if !ok { + return nil, fmt.Errorf("Invalid type: %T", v) + } + v.count += 1 + notation.push(prev) + default: + tokenStack.push(t) + } + } else { + itemRegex := regexp.MustCompile(`([^,>]+?)[,>]`) + match := itemRegex.FindStringSubmatchIndex(left) + if match != nil { + item := strings.TrimSpace(left[match[2]:match[3]]) + notation.push(item) + + left = left[match[3]:] + } else { + notation.push(strings.TrimSpace(left)) + left = "" + } + } } + for { + t, err := tokenStack.pop() + if err != nil { + break + } + notation.push(t) + } + + return notation, nil +} + +func mapToGoType(s string) string { t, exists := types[s] if exists { return t @@ -89,6 +203,65 @@ func mapScyllaToGoType(s string) string { return camelize(s) + "UserType" } +func calcPolishNotation(notation *Stack) (string, error) { + outputStack := NewStack() + for _, item := range notation.toSlice() { + switch v := item.(type) { + case string: + outputStack.push(mapToGoType(v)) + case *token: + datas, err := outputStack.popSlice(v.count) + if err != nil { + return "", err + } + var values []string + for _, data := range datas { + value, ok := data.(string) + if !ok { + return "", fmt.Errorf("Invalid output value: %v", value) + } + values = append(values, value) + } + fmtStr, err := v.style.format(values...) + if err != nil { + return "", err + } + outputStack.push(fmtStr) + default: + return "", fmt.Errorf("Invalid type: %T", v) + } + } + + if outputStack.count() != 1 { + return "", fmt.Errorf("Invalid polish notation") + } + + result, err := outputStack.pop() + if err != nil { + return "", nil + } + + if resultStr, ok := result.(string); !ok { + return "", fmt.Errorf("Invalid result value type: %T", result) + } else { + return resultStr, nil + } +} + +func mapScyllaToGoType(s string) string { + notation, err := parsePolishNotation(s) + if err != nil { + panic(fmt.Sprintf("Failed to parse polish notation for %s: %v", s, err)) + } + + goTypeStr, err := calcPolishNotation(notation) + if err != nil { + panic(fmt.Sprintf("Failed to calculate polish notation: %v", err)) + } + + return goTypeStr +} + func typeToString(t interface{}) string { tType := fmt.Sprintf("%T", t) switch tType { diff --git a/cmd/schemagen/map_types_test.go b/cmd/schemagen/map_types_test.go index d217418..c9b84b8 100644 --- a/cmd/schemagen/map_types_test.go +++ b/cmd/schemagen/map_types_test.go @@ -38,6 +38,8 @@ func TestMapScyllaToGoType(t *testing.T) { {"list", "[]int32"}, {"set", "[]int32"}, {"tuple", "struct {\n\t\tField1 bool\n\t\tField2 int32\n\t\tField3 int16\n\t}"}, + {"list>>>>", "[]map[string]map[int32]float32"}, + {"frozen>, frozen>, frozen>>>", "struct {\n\t\tField1 map[int32]string\n\t\tField2 []string\n\t\tField3 []int32\n\t}"}, } for _, tt := range tests { t.Run(tt.input, func(t *testing.T) { diff --git a/cmd/schemagen/schemagen.go b/cmd/schemagen/schemagen.go index 9cff543..e7f205c 100644 --- a/cmd/schemagen/schemagen.go +++ b/cmd/schemagen/schemagen.go @@ -96,6 +96,9 @@ func renderTemplate(md *gocql.KeyspaceMetadata) ([]byte, error) { } } } + if len(md.UserTypes) != 0 && !existsInSlice(imports, "github.com/scylladb/gocqlx/v2") { + imports = append(imports, "github.com/scylladb/gocqlx/v2") + } buf := &bytes.Buffer{} data := map[string]interface{}{ diff --git a/cmd/schemagen/stack.go b/cmd/schemagen/stack.go new file mode 100644 index 0000000..a34348e --- /dev/null +++ b/cmd/schemagen/stack.go @@ -0,0 +1,71 @@ +package main + +import ( + "fmt" +) + +type Stack struct { + data []interface{} +} + +func NewStack() *Stack { + return &Stack{} +} + +func (s *Stack) push(v interface{}) { + s.data = append(s.data, v) +} + +func (s *Stack) pop() (interface{}, error) { + l := len(s.data) + if l == 0 { + return nil, fmt.Errorf("Empty stack") + } + + v := s.data[l-1] + s.data = s.data[:l-1] + + return v, nil +} + +func (s *Stack) popSlice(count int) ([]interface{}, error) { + if count == 0 { + return nil, nil + } + + l := len(s.data) + if l < count { + return nil, fmt.Errorf("Too small stack") + } + + v := s.data[l-count:] + s.data = s.data[:l-count] + + return v, nil +} + +func (s *Stack) top() (interface{}, error) { + l := len(s.data) + if l == 0 { + return nil, fmt.Errorf("Empty stack") + } + + return s.data[l-1], nil +} + +func (s *Stack) bottom() (interface{}, error) { + l := len(s.data) + if l == 0 { + return nil, fmt.Errorf("Empty stack") + } + + return s.data[0], nil +} + +func (s *Stack) toSlice() []interface{} { + return s.data +} + +func (s *Stack) count() int { + return len(s.data) +}