Skip to content

Commit

Permalink
- fix parsing macros in queries (#74)
Browse files Browse the repository at this point in the history
  • Loading branch information
sasklacz authored Oct 18, 2022
1 parent c6f16cd commit 53084cb
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 15 deletions.
96 changes: 85 additions & 11 deletions macros.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ type Macros map[string]MacroFunc
// Default time filter for SQL based on the query time range.
// It requires one argument, the time column to filter.
// Example:
// $__timeFilter(time) => "time BETWEEN '2006-01-02T15:04:05Z07:00' AND '2006-01-02T15:04:05Z07:00'"
//
// $__timeFilter(time) => "time BETWEEN '2006-01-02T15:04:05Z07:00' AND '2006-01-02T15:04:05Z07:00'"
func macroTimeFilter(query *Query, args []string) (string, error) {
if len(args) != 1 {
return "", fmt.Errorf("%w: expected 1 argument, received %d", ErrorBadArgumentCount, len(args))
Expand All @@ -42,7 +43,8 @@ func macroTimeFilter(query *Query, args []string) (string, error) {
// Default time filter for SQL based on the starting query time range.
// It requires one argument, the time column to filter.
// Example:
// $__timeFrom(time) => "time > '2006-01-02T15:04:05Z07:00'"
//
// $__timeFrom(time) => "time > '2006-01-02T15:04:05Z07:00'"
func macroTimeFrom(query *Query, args []string) (string, error) {
if len(args) != 1 {
return "", fmt.Errorf("%w: expected 1 argument, received %d", ErrorBadArgumentCount, len(args))
Expand All @@ -55,7 +57,8 @@ func macroTimeFrom(query *Query, args []string) (string, error) {
// Default time filter for SQL based on the ending query time range.
// It requires one argument, the time column to filter.
// Example:
// $__timeTo(time) => "time < '2006-01-02T15:04:05Z07:00'"
//
// $__timeTo(time) => "time < '2006-01-02T15:04:05Z07:00'"
func macroTimeTo(query *Query, args []string) (string, error) {
if len(args) != 1 {
return "", fmt.Errorf("%w: expected 1 argument, received %d", ErrorBadArgumentCount, len(args))
Expand All @@ -68,7 +71,8 @@ func macroTimeTo(query *Query, args []string) (string, error) {
// This basic example is meant to be customized with more complex periods.
// It requires two arguments, the column to filter and the period.
// Example:
// $__timeGroup(time, month) => "datepart(year, time), datepart(month, time)"
//
// $__timeGroup(time, month) => "datepart(year, time), datepart(month, time)"
func macroTimeGroup(query *Query, args []string) (string, error) {
if len(args) != 2 {
return "", fmt.Errorf("%w: expected 1 argument, received %d", ErrorBadArgumentCount, len(args))
Expand Down Expand Up @@ -97,14 +101,16 @@ func macroTimeGroup(query *Query, args []string) (string, error) {

// macroTable is the default handler for the table macro: it expands to the
// Table field of the query and never fails. The args slice is not consulted.
// Example:
//
//	$__table => "my_table"
func macroTable(query *Query, args []string) (string, error) {
	table := query.Table
	return table, nil
}

// macroColumn is the default handler for the column macro: it expands to the
// Column field of the query and never fails. The args slice is not consulted.
// Example:
//
//	$__column => "my_col"
func macroColumn(query *Query, args []string) (string, error) {
	column := query.Column
	return column, nil
}
Expand All @@ -127,8 +133,73 @@ func trimAll(s []string) []string {
return r
}

func getMacroRegex(name string) string {
return fmt.Sprintf("\\$__%s\\b(?:\\((.*?\\)?)\\))?", name)
// pair maps a closing bracket rune to the opening bracket rune it matches.
// NOTE(review): nothing in the visible code reads this map — confirm it is
// still needed before keeping it.
var pair = map[rune]rune{')': '('}

// getMacroMatches extracts every occurrence of the macro $__<name> from input
// together with its raw argument string.
//
// Each result element is a two-item slice shaped like a regexp submatch:
// index 0 is the full macro text (including the parenthesised argument list,
// if any) and index 1 is the text between the outermost parentheses, or ""
// when the macro has no argument list.
//
// The argument list is located by counting parentheses manually, because a
// regular expression cannot match nested, balanced parentheses such as
// $__timeFilter(cast(col as timestamp)).
//
// An argument list is only recognised when "(" directly follows the macro
// name. Parentheses that appear later in the query belong to the surrounding
// SQL and are left alone. If the opening parenthesis is never closed, the
// macro is reported without arguments instead of slicing out of range.
//
// It returns (nil, nil) when the macro does not occur in input, and a non-nil
// error only when name does not yield a valid regular expression.
func getMacroMatches(input string, name string) ([][]string, error) {
	rgx, err := regexp.Compile(fmt.Sprintf("\\$__%s\\b", name))
	if err != nil {
		return nil, err
	}

	// get all matching macro instances
	matched := rgx.FindAllStringIndex(input, -1)
	if matched == nil {
		return nil, nil
	}

	matchedMacros := make([][]string, 0, len(matched))
	for _, loc := range matched {
		nameStart, nameEnd := loc[0], loc[1]

		// Default: a bare macro such as "$__table" with no argument list.
		macroString := input[nameStart:nameEnd]
		args := ""

		// Only a parenthesis glued to the macro name opens its arguments.
		if nameEnd < len(input) && input[nameEnd] == '(' {
			if closeIdx := findClosingParen(input, nameEnd); closeIdx >= 0 {
				macroString = input[nameStart : closeIdx+1]
				args = input[nameEnd+1 : closeIdx]
			}
		}

		matchedMacros = append(matchedMacros, []string{macroString, args})
	}
	return matchedMacros, nil
}

// findClosingParen returns the byte index of the ")" that balances the "("
// located at openIdx in s, or -1 when that parenthesis is never closed.
func findClosingParen(s string, openIdx int) int {
	depth := 0
	// Scanning bytes is safe: '(' and ')' are ASCII and can never appear as
	// part of a multi-byte UTF-8 sequence.
	for i := openIdx; i < len(s); i++ {
		switch s[i] {
		case '(':
			depth++
		case ')':
			depth--
			if depth == 0 {
				return i
			}
		}
	}
	return -1
}

// Interpolate returns an interpolated query string given a backend.DataQuery
Expand All @@ -141,8 +212,10 @@ func Interpolate(driver Driver, query *Query) (string, error) {
}
}
rawSQL := query.RawSQL

for key, macro := range macros {
matches, err := getMatches(key, rawSQL)

if err != nil {
return rawSQL, err
}
Expand All @@ -165,16 +238,17 @@ func Interpolate(driver Driver, query *Query) (string, error) {

rawSQL = strings.Replace(rawSQL, match[0], res, -1)
}

}

return rawSQL, nil
}

func getMatches(macroName, rawSQL string) ([][]string, error) {
rgx, err := regexp.Compile(getMacroRegex(macroName))
parsedInput, err := getMacroMatches(rawSQL, macroName)

if err != nil {
return nil, err
}
return rgx.FindAllStringSubmatch(rawSQL, -1), nil

return parsedInput, err
}
7 changes: 3 additions & 4 deletions macros_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,14 @@ func TestInterpolate(t *testing.T) {
{input: "select * from $__foo", output: "select * from bar", name: "macro without paranthesis"},
{input: "select * from $__params()", output: "select * from bar", name: "macro without params"},
{input: "select * from $__params(hello)", output: "select * from bar_hello", name: "with param"},
{input: "select * from $__params(h)", output: "select * from bar_h", name: "with short param"},
{input: "select * from $__params(hello) AND $__params(hello)", output: "select * from bar_hello AND bar_hello", name: "same macro multiple times with same param"},
{input: "(select * from $__params(hello) AND $__params(hello))", output: "(select * from bar_hello AND bar_hello)", name: "same macro multiple times with same param and additional parentheses"},
{input: "select * from $__params(hello) AND $__params(world)", output: "select * from bar_hello AND bar_world", name: "same macro multiple times with different param"},
{input: "select * from $__params(world) AND $__foo() AND $__params(hello)", output: "select * from bar_world AND bar AND bar_hello", name: "different macros with different params"},
{input: "select * from foo where $__timeFilter(time)", output: "select * from foo where time >= '0001-01-01T00:00:00Z' AND time <= '0001-01-01T00:00:00Z'", name: "default timeFilter"},
{input: "select * from foo where $__timeFilter(cast(sth as timestamp))", output: "select * from foo where cast(sth as timestamp) >= '0001-01-01T00:00:00Z' AND cast(sth as timestamp) <= '0001-01-01T00:00:00Z'", name: "default timeFilter"},
{input: "select * from foo where $__timeFilter(cast(sth as timestamp) )", output: "select * from foo where cast(sth as timestamp) >= '0001-01-01T00:00:00Z' AND cast(sth as timestamp) <= '0001-01-01T00:00:00Z'", name: "default timeFilter with empty spaces"},
{input: "select * from foo where $__timeTo(time)", output: "select * from foo where time <= '0001-01-01T00:00:00Z'", name: "default timeTo macro"},
{input: "select * from foo where $__timeFrom(time)", output: "select * from foo where time >= '0001-01-01T00:00:00Z'", name: "default timeFrom macro"},
{input: "select * from foo where $__timeFrom(cast(sth as timestamp))", output: "select * from foo where cast(sth as timestamp) >= '0001-01-01T00:00:00Z'", name: "default timeFrom macro"},
Expand All @@ -82,10 +85,6 @@ func TestInterpolate(t *testing.T) {
}
}

// TestGetMacroRegex_returns_composed_regular_expression checks that
// getMacroRegex embeds the given macro name into the expected pattern: the
// literal "$__" prefix, the name, a word boundary, and an optional
// parenthesised argument group.
func TestGetMacroRegex_returns_composed_regular_expression(t *testing.T) {
assert.Equal(t, `\$__some_string\b(?:\((.*?\)?)\))?`, getMacroRegex("some_string"))
}

func TestGetMatches(t *testing.T) {
t.Run("FindAllStringSubmatch returns DefaultMacros", func(t *testing.T) {
for macroName := range DefaultMacros {
Expand Down

0 comments on commit 53084cb

Please sign in to comment.