Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more translations for Snowflake #860

Merged
merged 11 commits into from
Aug 12, 2022
227 changes: 220 additions & 7 deletions R/backend-snowflake.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,199 @@ NULL

# Snowflake SQL translation environment.
#
# Builds a `sql_variant()` from three translators -- scalar, aggregate and
# window -- that add Snowflake-specific translations on top of
# `base_odbc_scalar`, `base_agg` and `base_win` respectively.
#
# `con` is the connection the method dispatches on; it is not otherwise used
# in the body.
#' @export
sql_translation.Snowflake <- function(con) {
  sql_variant(
    # ---- scalar translations ------------------------------------------------
    sql_translator(
      .parent = base_odbc_scalar,
      log10 = function(x) sql_expr(log(10, !!x)),
      grepl = snowflake_grepl,
      round = snowflake_round,
      paste = snowflake_paste(" "),
      paste0 = snowflake_paste(""),
      str_c = function(..., sep = "", collapse = NULL) {
        if (!is.null(collapse)) {
          cli_abort(c(
            "{.arg collapse} not supported in DB translation of {.fn str_c}.",
            i = "Please use {.fn str_flatten} instead."
          ))
        }
        sql_call2("CONCAT_WS", sep, ...)
      },
      str_locate = function(string, pattern) {
        sql_expr(POSITION(!!pattern, !!string))
      },
      # REGEXP on Snowflake "implicitly anchors a pattern at both ends", which
      # str_detect does not. Left- and right-pad `pattern` with .* to get
      # str_detect-like behavior
      str_detect = function(string, pattern, negate = FALSE) {
        if (isTRUE(negate)) {
          sql_expr(!(((!!string)) %REGEXP% (".*" || (!!pattern) || ".*")))
        } else {
          sql_expr(((!!string)) %REGEXP% (".*" || (!!pattern) || ".*"))
        }
      },
      # On Snowflake, REGEXP_REPLACE is used like this:
      # REGEXP_REPLACE( <subject> , <pattern> [ , <replacement> ,
      #                 <position> , <occurrence> , <parameters> ] )
      # so we must set <occurrence> to 1 if not replacing all. See:
      # https://docs.snowflake.com/en/sql-reference/functions/regexp_replace.html
      # Also, Snowflake needs backslashes escaped, so we must increase the
      # level of escaping by 1
      str_replace = function(string, pattern, replacement) {
        pattern <- gsub("\\", "\\\\", pattern, fixed = TRUE)
        sql_expr(regexp_replace(!!string, !!pattern, !!replacement, 1, 1))
      },
      str_replace_all = function(string, pattern, replacement) {
        pattern <- gsub("\\", "\\\\", pattern, fixed = TRUE)
        sql_expr(regexp_replace(!!string, !!pattern, !!replacement))
      },
      str_remove = function(string, pattern) {
        pattern <- gsub("\\", "\\\\", pattern, fixed = TRUE)
        sql_expr(regexp_replace(!!string, !!pattern, "", 1, 1))
      },
      # No replacement argument is passed here, so the generated SQL is the
      # two-argument form REGEXP_REPLACE(<subject>, <pattern>).
      str_remove_all = function(string, pattern) {
        pattern <- gsub("\\", "\\\\", pattern, fixed = TRUE)
        sql_expr(regexp_replace(!!string, !!pattern))
      },
      str_trim = function(string) {
        sql_expr(trim(!!string))
      },
      # Trim both ends, then collapse each run of whitespace to one space.
      str_squish = function(string) {
        sql_expr(regexp_replace(trim(!!string), "\\\\s+", " "))
      },

      # lubridate functions
      # https://docs.snowflake.com/en/sql-reference/functions-date-time.html
      day = function(x) {
        sql_expr(EXTRACT(DAY %FROM% !!x))
      },
      mday = function(x) {
        sql_expr(EXTRACT(DAY %FROM% !!x))
      },
      wday = function(x, label = FALSE, abbr = TRUE, week_start = NULL) {
        if (!label) {
          # Shift the date by (7 - week_start) days before extracting the
          # day of week, then add 1 to the extracted value.
          week_start <- week_start %||% getOption("lubridate.week.start", 7)
          offset <- as.integer(7 - week_start)
          sql_expr(EXTRACT("dayofweek", DATE(!!x) + !!offset) + 1)
        } else if (label && !abbr) {
          # Full day names are spelled out via DECODE on the day-of-week value.
          sql_expr(
            DECODE(
              EXTRACT("dayofweek", !!x),
              1, "Monday",
              2, "Tuesday",
              3, "Wednesday",
              4, "Thursday",
              5, "Friday",
              6, "Saturday",
              0, "Sunday"
            )
          )
        } else if (label && abbr) {
          sql_expr(DAYNAME(!!x))
        } else {
          # Defensive fallback; the branches above already cover every
          # combination of TRUE/FALSE `label` and `abbr`.
          abort("Unrecognized arguments to `wday`")
        }
      },
      yday = function(x) sql_expr(EXTRACT("dayofyear", !!x)),
      # Week number counted in 7-day blocks from January 1st.
      week = function(x) {
        sql_expr(FLOOR((EXTRACT("dayofyear", !!x) - 1L) / 7L) + 1L)
      },
      isoweek = function(x) sql_expr(EXTRACT("weekiso", !!x)),
      month = function(x, label = FALSE, abbr = TRUE) {
        if (!label) {
          sql_expr(EXTRACT("month", !!x))
        } else {
          if (abbr) {
            sql_expr(MONTHNAME(!!x))
          } else {
            # Full month names are spelled out via DECODE on the month number.
            sql_expr(
              DECODE(
                EXTRACT("month", !!x),
                1, "January",
                2, "February",
                3, "March",
                4, "April",
                5, "May",
                6, "June",
                7, "July",
                8, "August",
                9, "September",
                10, "October",
                11, "November",
                12, "December"
              )
            )
          }
        }
      },
      quarter = function(x, with_year = FALSE, fiscal_start = 1) {
        if (fiscal_start != 1) {
          abort("`fiscal_start` is not supported in Snowflake translation. Must be 1.")
        }

        if (with_year) {
          sql_expr((EXTRACT("year", !!x) || "." || EXTRACT("quarter", !!x)))
        } else {
          sql_expr(EXTRACT("quarter", !!x))
        }
      },
      # NOTE(review): this extracts the calendar year. The ISO-8601
      # week-numbering year can differ from it around year boundaries;
      # Snowflake appears to offer a `yearofweekiso` date part -- confirm
      # whether that is what `isoyear()` should map to.
      isoyear = function(x) {
        sql_expr(EXTRACT("year", !!x))
      },
      # Period constructors: each emits an INTERVAL literal with a fixed unit.
      seconds = function(x) {
        build_sql("INTERVAL '", x, " second'")
      },
      minutes = function(x) {
        build_sql("INTERVAL '", x, " minute'")
      },
      hours = function(x) {
        build_sql("INTERVAL '", x, " hour'")
      },
      days = function(x) {
        build_sql("INTERVAL '", x, " day'")
      },
      weeks = function(x) {
        build_sql("INTERVAL '", x, " week'")
      },
      months = function(x) {
        build_sql("INTERVAL '", x, " month'")
      },
      years = function(x) {
        build_sql("INTERVAL '", x, " year'")
      },
      # https://docs.snowflake.com/en/sql-reference/functions/date_trunc.html
      # Both singular and plural unit names are accepted, so the plural
      # default ("seconds") is itself a valid choice.
      floor_date = function(x, unit = "seconds") {
        unit <- arg_match(unit,
          c("second", "minute", "hour", "day", "week", "month", "quarter", "year",
            "seconds", "minutes", "hours", "days", "weeks", "months", "quarters", "years")
        )
        sql_expr(DATE_TRUNC(!!unit, !!x))
      }
    ),
    # ---- aggregate translations ---------------------------------------------
    sql_translator(
      .parent = base_agg,
      cor = sql_aggregate_2("CORR"),
      cov = sql_aggregate_2("COVAR_SAMP"),
      all = sql_aggregate("BOOLAND_AGG", "all"),
      any = sql_aggregate("BOOLOR_AGG", "any"),
      sd = sql_aggregate("STDDEV", "sd"),
      # `collapse` deliberately has no default here.
      str_flatten = function(x, collapse) sql_expr(LISTAGG(!!x, !!collapse))
    ),
    # ---- window translations ------------------------------------------------
    sql_translator(
      .parent = base_win,
      cor = win_aggregate_2("CORR"),
      cov = win_aggregate_2("COVAR_SAMP"),
      all = win_aggregate("BOOLAND_AGG"),
      any = win_aggregate("BOOLOR_AGG"),
      sd = win_aggregate("STDDEV"),
      str_flatten = function(x, collapse) {
        win_over(
          sql_expr(LISTAGG(!!x, !!collapse)),
          partition = win_current_group(),
          order = win_current_order()
        )
      }
    )
  )
}

Expand All @@ -37,3 +223,30 @@ simulate_snowflake <- function() simulate_dbi("Snowflake")
# Link to full list: https://docs.snowflake.com/en/sql-reference/sql-all.html
#' @export
sql_table_analyze.Snowflake <- function(con, table, ...) {
  # Intentionally produces nothing: all arguments are ignored and the
  # result is always NULL.
  NULL
}

snowflake_grepl <- function(pattern, x, ignore.case = FALSE, perl = FALSE, fixed = FALSE, useBytes = FALSE) {
  # Translation of grepl() to Snowflake's REGEXP operator.
  # https://docs.snowflake.com/en/sql-reference/functions/regexp.html
  #
  # Only plain pattern matching is translated; requesting any of the other
  # grepl() options is rejected up front.
  wants_unsupported_option <- perl || fixed || useBytes || ignore.case
  if (wants_unsupported_option) {
    cli_abort("{.arg {c('perl', 'fixed', 'useBytes', 'ignore.case')}} parameters are unsupported.")
  }

  # Snowflake's REGEXP "implicitly anchors a pattern at both ends", whereas
  # grepl() matches anywhere in the string. Wrapping `pattern` in ".*" on
  # both sides restores grepl-like behavior.
  sql_expr(((!!x)) %REGEXP% (".*" || (!!pattern) || ".*"))
}
snowflake_round <- function(x, digits = 0L) {
  # Translation of round(): `x` is cast to FLOAT before rounding, and
  # `digits` is coerced to an integer so it is rendered without a decimal
  # part in the generated SQL.
  sql_expr(round(((!!x)) %::% FLOAT, !!as.integer(digits)))
}

# On Snowflake, CONCAT_WS is null if any of its arguments are null. Paste
# is implemented here to avoid this behavior.
snowflake_paste <- function(default_sep) {
  # Factory for paste()-style translators; `default_sep` becomes the default
  # separator of the returned function.
  function(..., sep = default_sep, collapse = NULL) {
    check_collapse(collapse)

    # Gather the pieces with ARRAY_CONSTRUCT_COMPACT, then join them with
    # `sep` via ARRAY_TO_STRING.
    pieces <- sql_call2("ARRAY_CONSTRUCT_COMPACT", ...)
    sql_call2("ARRAY_TO_STRING", pieces, sep)
  }
}
96 changes: 96 additions & 0 deletions tests/testthat/test-test-backend-snowflake.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,100 @@
test_that("custom scalar translated correctly", {
  local_con(simulate_snowflake())

  # log10() becomes a two-argument LOG with base 10.
  expect_equal(
    translate_sql(log10(x)),
    sql("LOG(10.0, `x`)")
  )

  # round() casts to FLOAT and coerces `digits` to an integer.
  expect_equal(
    translate_sql(round(x, digits = 1.1)),
    sql("ROUND((`x`) :: FLOAT, 1)")
  )

  # grepl() pads the pattern with '.*' on both sides.
  expect_equal(
    translate_sql(grepl("exp", x)),
    sql("(`x`) REGEXP ('.*' || 'exp' || '.*')")
  )

  # Extra grepl() options are rejected.
  expect_error(
    translate_sql(grepl("exp", x, ignore.case = TRUE)),
    "`perl`, `fixed`, `useBytes`, and `ignore.case` parameters are unsupported."
  )
})

test_that("pasting translated correctly", {
  local_con(simulate_snowflake())

  # paste()/paste0() go through ARRAY_CONSTRUCT_COMPACT + ARRAY_TO_STRING.
  expect_equal(
    translate_sql(paste(x, y)),
    sql("ARRAY_TO_STRING(ARRAY_CONSTRUCT_COMPACT(`x`, `y`), ' ')")
  )
  expect_equal(
    translate_sql(paste0(x, y)),
    sql("ARRAY_TO_STRING(ARRAY_CONSTRUCT_COMPACT(`x`, `y`), '')")
  )

  # str_c() maps onto CONCAT_WS.
  expect_equal(
    translate_sql(str_c(x, y)),
    sql("CONCAT_WS('', `x`, `y`)")
  )
  expect_equal(
    translate_sql(str_c(x, y, sep = '|')),
    sql("CONCAT_WS('|', `x`, `y`)")
  )

  # `collapse` is not translated for paste0().
  expect_error(
    translate_sql(paste0(x, collapse = "")),
    "`collapse` not supported"
  )

  # str_flatten() requires `collapse` and maps onto LISTAGG.
  expect_error(
    translate_sql(str_flatten(x)),
    'argument "collapse" is missing, with no default'
  )
  expect_equal(
    translate_sql(str_flatten(x, collapse = "|"), window = TRUE),
    sql("LISTAGG(`x`, '|') OVER ()")
  )
  expect_equal(
    translate_sql(str_flatten(x, collapse = "|"), window = FALSE),
    sql("LISTAGG(`x`, '|')")
  )
})

test_that("custom stringr functions translated correctly", {
  local_con(simulate_snowflake())

  # Locating and detecting
  expect_equal(
    translate_sql(str_locate(x, y)),
    sql("POSITION(`y`, `x`)")
  )
  expect_equal(
    translate_sql(str_detect(x, y)),
    sql("(`x`) REGEXP ('.*' || `y` || '.*')")
  )
  expect_equal(
    translate_sql(str_detect(x, y, negate = TRUE)),
    sql("!((`x`) REGEXP ('.*' || `y` || '.*'))")
  )

  # Replacing: first occurrence only vs. all occurrences
  expect_equal(
    translate_sql(str_replace(x, y, z)),
    sql("REGEXP_REPLACE(`x`, `y`, `z`, 1.0, 1.0)")
  )
  expect_equal(
    translate_sql(str_replace(x, "\\d", z)),
    sql("REGEXP_REPLACE(`x`, '\\\\d', `z`, 1.0, 1.0)")
  )
  expect_equal(
    translate_sql(str_replace_all(x, y, z)),
    sql("REGEXP_REPLACE(`x`, `y`, `z`)")
  )

  # Whitespace handling and removal
  expect_equal(
    translate_sql(str_squish(x)),
    sql("REGEXP_REPLACE(TRIM(`x`), '\\\\s+', ' ')")
  )
  expect_equal(
    translate_sql(str_remove(x, y)),
    sql("REGEXP_REPLACE(`x`, `y`, '', 1.0, 1.0)")
  )
  expect_equal(
    translate_sql(str_remove_all(x, y)),
    sql("REGEXP_REPLACE(`x`, `y`)")
  )
  expect_equal(
    translate_sql(str_trim(x)),
    sql("TRIM(`x`)")
  )
})

test_that("aggregates are translated correctly", {
  local_con(simulate_snowflake())

  # Each function is checked in aggregate form and in window form.

  # cor() -> CORR
  expect_equal(
    translate_sql(cor(x, y), window = FALSE),
    sql("CORR(`x`, `y`)")
  )
  expect_equal(
    translate_sql(cor(x, y), window = TRUE),
    sql("CORR(`x`, `y`) OVER ()")
  )

  # cov() -> COVAR_SAMP
  expect_equal(
    translate_sql(cov(x, y), window = FALSE),
    sql("COVAR_SAMP(`x`, `y`)")
  )
  expect_equal(
    translate_sql(cov(x, y), window = TRUE),
    sql("COVAR_SAMP(`x`, `y`) OVER ()")
  )

  # all() -> BOOLAND_AGG
  expect_equal(
    translate_sql(all(x, na.rm = TRUE), window = FALSE),
    sql("BOOLAND_AGG(`x`)")
  )
  expect_equal(
    translate_sql(all(x, na.rm = TRUE), window = TRUE),
    sql("BOOLAND_AGG(`x`) OVER ()")
  )

  # any() -> BOOLOR_AGG
  expect_equal(
    translate_sql(any(x, na.rm = TRUE), window = FALSE),
    sql("BOOLOR_AGG(`x`)")
  )
  expect_equal(
    translate_sql(any(x, na.rm = TRUE), window = TRUE),
    sql("BOOLOR_AGG(`x`) OVER ()")
  )

  # sd() -> STDDEV
  expect_equal(
    translate_sql(sd(x, na.rm = TRUE), window = FALSE),
    sql("STDDEV(`x`)")
  )
  expect_equal(
    translate_sql(sd(x, na.rm = TRUE), window = TRUE),
    sql("STDDEV(`x`) OVER ()")
  )
})

test_that("snowflake mimics two argument log", {
  local_con(simulate_snowflake())

  # One argument: natural log.
  expect_equal(
    translate_sql(log(x)),
    sql('LN(`x`)')
  )

  # Two arguments: the base comes first in the generated LOG call.
  expect_equal(
    translate_sql(log(x, 10)),
    sql('LOG(10.0, `x`)')
  )
  expect_equal(
    translate_sql(log(x, 10L)),
    sql('LOG(10, `x`)')
  )
})

test_that("custom lubridate functions translated correctly", {
  local_con(simulate_snowflake())

  # Date-part extraction
  expect_equal(translate_sql(day(x)), sql("EXTRACT(DAY FROM `x`)"))
  expect_equal(translate_sql(mday(x)), sql("EXTRACT(DAY FROM `x`)"))
  expect_equal(translate_sql(yday(x)), sql("EXTRACT('dayofyear', `x`)"))
  expect_equal(translate_sql(wday(x)), sql("EXTRACT('dayofweek', DATE(`x`) + 0) + 1.0"))
  expect_equal(translate_sql(wday(x, label = TRUE)), sql("DAYNAME(`x`)"))
  expect_equal(translate_sql(wday(x, label = TRUE, abbr = FALSE)), sql(
    "DECODE(EXTRACT('dayofweek', `x`), 1.0, 'Monday', 2.0, 'Tuesday', 3.0, 'Wednesday', 4.0, 'Thursday', 5.0, 'Friday', 6.0, 'Saturday', 0.0, 'Sunday')"
  ))
  expect_equal(translate_sql(week(x)), sql("FLOOR((EXTRACT('dayofyear', `x`) - 1) / 7) + 1"))
  expect_equal(translate_sql(isoweek(x)), sql("EXTRACT('weekiso', `x`)"))
  expect_equal(translate_sql(month(x)), sql("EXTRACT('month', `x`)"))
  expect_equal(translate_sql(month(x, label = TRUE)), sql("MONTHNAME(`x`)"))
  expect_equal(translate_sql(month(x, label = TRUE, abbr = FALSE)), sql(
    "DECODE(EXTRACT('month', `x`), 1.0, 'January', 2.0, 'February', 3.0, 'March', 4.0, 'April', 5.0, 'May', 6.0, 'June', 7.0, 'July', 8.0, 'August', 9.0, 'September', 10.0, 'October', 11.0, 'November', 12.0, 'December')"
  ))
  expect_equal(translate_sql(quarter(x)), sql("EXTRACT('quarter', `x`)"))
  expect_equal(translate_sql(quarter(x, with_year = TRUE)), sql("(EXTRACT('year', `x`) || '.' || EXTRACT('quarter', `x`))"))
  # Pin the message so an unrelated failure cannot satisfy this expectation.
  expect_error(
    translate_sql(quarter(x, fiscal_start = 2)),
    "`fiscal_start` is not supported"
  )
  expect_equal(translate_sql(isoyear(x)), sql("EXTRACT('year', `x`)"))

  # Period constructors become INTERVAL literals
  expect_equal(translate_sql(seconds(x)), sql("INTERVAL '`x` second'"))
  expect_equal(translate_sql(minutes(x)), sql("INTERVAL '`x` minute'"))
  expect_equal(translate_sql(hours(x)), sql("INTERVAL '`x` hour'"))
  expect_equal(translate_sql(days(x)), sql("INTERVAL '`x` day'"))
  expect_equal(translate_sql(weeks(x)), sql("INTERVAL '`x` week'"))
  expect_equal(translate_sql(months(x)), sql("INTERVAL '`x` month'"))
  expect_equal(translate_sql(years(x)), sql("INTERVAL '`x` year'"))

  # floor_date() accepts singular and plural unit names
  expect_equal(translate_sql(floor_date(x, 'month')), sql("DATE_TRUNC('month', `x`)"))
  expect_equal(translate_sql(floor_date(x, 'week')), sql("DATE_TRUNC('week', `x`)"))
  expect_equal(translate_sql(floor_date(x, 'months')), sql("DATE_TRUNC('months', `x`)"))
})