Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BI-5526: Add regexp_extract_all function #485

Merged
merged 12 commits into from
Jun 26, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,12 @@
),
]
),
# regexp_extract_all
base.FuncRegexpExtractAll(
variants=[
V(D.CLICKHOUSE, lambda text, pattern: sa.func.extractAll(sa.func.assumeNotNull(text), pattern)),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://clickhouse.com/docs/en/sql-reference/functions/functions-for-nulls#assumenotnull
I'm a bit worried about "Returned values: An arbitrary value, if the input value is NULL". Can we explicitly handle the case of text being NULL?
Also, we definitely need a test for that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function wraps NULL text into "".

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added test

]
),
# regexp_extract_nth
base.FuncRegexpExtractNth(
variants=[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@


class TestStringFunctionClickHouse_21_8(ClickHouse_21_8TestBase, StringFunctionClickHouseTestSuite):
    """String-function tests for ClickHouse 21.8 with the ``regexp_extract_all`` flag switched on."""

    # NOTE(review): presumably read by the inherited suite to decide whether
    # the REGEXP_EXTRACT_ALL test runs or is skipped — confirm against the base suite.
    supports_regex_extract_all = True
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from typing import Any

import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as sa_postgresql
from sqlalchemy.ext.compiler import compiles
import sqlalchemy.sql.functions

from dl_formula.definitions.base import (
TranslationVariant,
Expand All @@ -16,6 +20,33 @@
VW = TranslationVariantWrapped.make


class RegexpMatchesInBrackets(sqlalchemy.sql.functions.GenericFunction):
    """SQL function element that is compiled as a parenthesised ``REGEXP_MATCHES`` call.

    It exists only so that a dedicated ``@compiles`` hook can wrap the call in
    brackets, which makes it possible to subscript the result (``(...)[i]``).
    """

    # SQLAlchemy flag: this subclass reuses the cache-key scheme of its parent class.
    inherit_cache = True


@compiles(RegexpMatchesInBrackets)
def compile_regexp_matches_in_brackets(element: Any, compiler: sa.sql.compiler.SQLCompiler, **kw: Any) -> str:
    """Render ``RegexpMatchesInBrackets`` as ``(REGEXP_MATCHES(<args>))``.

    The surrounding parentheses are what allow an array subscript
    (``array[i]``) to be applied to the function result afterwards.

    :param element: the function element being compiled; its ``clauses`` hold the arguments
    :param compiler: the active SQL compiler used to render the arguments
    :param kw: compiler keyword arguments, passed through unchanged
    :return: the SQL text for the bracketed function call
    """
    rendered_args = compiler.process(element.clauses, **kw)
    return "(REGEXP_MATCHES(%s))" % rendered_args


def regexp_matches_in_brackets(text: str, pattern: str) -> sa.sql.expression.TypeCoerce:
    """Build a PostgreSQL expression that collects the first capture group of every match.

    The expression has roughly this shape::

        (SELECT array_agg(m.strs)
         FROM (SELECT (REGEXP_MATCHES(text, pattern, 'g'))[1] AS strs) AS m)

    The inner query yields one row per match (the ``"g"`` flag), keeping only
    element ``[1]`` of each match array; the outer query folds those rows into
    a single array of strings.

    Subqueries are created explicitly with ``.subquery()`` and
    ``.scalar_subquery()`` instead of relying on SQLAlchemy's implicit
    SELECT-to-FROM / SELECT-to-scalar coercions, which are deprecated in
    SQLAlchemy 1.4 and unavailable in 2.0.

    NOTE(review): a FROM-clause subquery is not auto-correlated by SQLAlchemy,
    so this presumably expects ``text`` not to be bound to a ``Table`` object —
    confirm against how the translator passes column references.

    :param text: the value to search (passed straight to the SQL function)
    :param pattern: the regular expression (passed straight to the SQL function)
    :return: a scalar subquery coerced to ``ARRAY(String)``
    """
    # One row per regexp match, each holding the first capture group.
    first_group = sa.type_coerce(
        sa.func.RegexpMatchesInBrackets(text, pattern, "g"),
        sa_postgresql.ARRAY(sa.String),
    )[1]
    matches = sa.select(first_group.label("strs")).subquery()

    # Collapse the per-match rows into one array value usable as a column expression.
    collected = sa.select(sa.func.array_agg(matches.c.strs)).select_from(matches).scalar_subquery()
    return sa.type_coerce(collected, sa_postgresql.ARRAY(sa.String))


DEFINITIONS_STRING = [
# ascii
base.FuncAscii.for_dialect(D.POSTGRESQL),
Expand Down Expand Up @@ -119,6 +150,15 @@
),
]
),
# regexp_extract_all
base.FuncRegexpExtractAll(
variants=[
V(
D.POSTGRESQL,
regexp_matches_in_brackets,
)
]
),
# regexp_extract_nth
base.FuncRegexpExtractNth(
variants=[
Expand All @@ -140,7 +180,10 @@
# regexp_replace
base.FuncRegexpReplace(
variants=[
V(D.POSTGRESQL, lambda text, patt, repl: sa.func.REGEXP_REPLACE(text, patt, repl, "g")),
V(
D.POSTGRESQL,
lambda text, patt, repl: sa.func.REGEXP_REPLACE(text, patt, repl, "g"),
),
]
),
# replace
Expand Down Expand Up @@ -170,7 +213,8 @@
base.FuncSplit3(
variants=[
V(
D.POSTGRESQL, lambda text, delim, ind: sa.func.SPLIT_PART(text, delim, sa.cast(ind, sa.INTEGER))
D.POSTGRESQL,
lambda text, delim, ind: sa.func.SPLIT_PART(text, delim, sa.cast(ind, sa.INTEGER)),
), # FIXME: does not work with negative indices
]
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@


class TestStringFunctionPostgreSQL_9_3(PostgreSQL_9_3TestBase, DefaultStringFunctionFormulaConnectorTestSuite):
    """String-function tests for PostgreSQL 9.3 with the REGEXP_EXTRACT_ALL test enabled."""

    # The default suite skips test_regexp_extract_all unless this flag is True.
    supports_regex_extract_all = True


class TestStringFunctionPostgreSQL_9_4(PostgreSQL_9_4TestBase, DefaultStringFunctionFormulaConnectorTestSuite):
    """String-function tests for PostgreSQL 9.4 with the REGEXP_EXTRACT_ALL test enabled."""

    # The default suite skips test_regexp_extract_all unless this flag is True.
    supports_regex_extract_all = True
12 changes: 12 additions & 0 deletions lib/dl_formula/dl_formula/definitions/functions_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,16 @@ class FuncRegexpExtract(StringFunction):
return_type = Fixed(DataType.STRING)


class FuncRegexpExtractAll(StringFunction):
    # Declaration of the two-argument string function ``regexp_extract_all``.
    # Per-dialect translations are supplied elsewhere through ``variants=[...]``.
    name = "regexp_extract_all"
    arg_cnt = 2
    arg_names = ["string", "pattern"]
    # Single accepted signature: a string value followed by a constant string pattern.
    argument_types = [
        ArgTypeSequence([DataType.STRING, DataType.CONST_STRING]),
    ]
    # The result is always an array of strings.
    return_type = Fixed(DataType.ARRAY_STR)


# In MS SQL Server regular expressions can only be used in user-defined functions


Expand Down Expand Up @@ -751,6 +761,8 @@ class FuncUtf8(StringFunction):
FuncLtrim,
# regexp_extract
FuncRegexpExtract,
# regexp_extract_all
FuncRegexpExtractAll,
# regexp_extract_nth
FuncRegexpExtractNth,
# regexp_match
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class DefaultStringFunctionFormulaConnectorTestSuite(FormulaConnectorTestBase):
supports_trimming_funcs: ClassVar[bool] = True
supports_regex_extract: ClassVar[bool] = True
supports_regex_extract_nth: ClassVar[bool] = True
supports_regex_extract_all: ClassVar[bool] = False
supports_regex_replace: ClassVar[bool] = True
supports_regex_match: ClassVar[bool] = True
supports_split_3: ClassVar[bool] = True
Expand Down Expand Up @@ -126,6 +127,14 @@ def test_regexp_extract(self, dbe: DbEvaluator, data_table: sa.Table) -> None:
assert to_str(dbe.eval('REGEXP_EXTRACT("Карл у Клары украл кораллы", ".лары")')) == "Клары"
assert to_str(dbe.eval('REGEXP_EXTRACT([str_null_value], "or..")', from_=data_table)) in (None, "")

def test_regexp_extract_all(self, dbe: DbEvaluator, data_table: sa.Table) -> None:
    """Check REGEXP_EXTRACT_ALL on matching, non-matching, NULL and Cyrillic inputs.

    Skipped unless the suite sets ``supports_regex_extract_all``. Every
    expectation lists the alternative representations accepted for the same
    logical result.
    """
    if not self.supports_regex_extract_all:
        pytest.skip()

    # Only the first capture group of each match is collected.
    digits_found = dbe.eval("REGEXP_EXTRACT_ALL('100-200, 300-400', '(\\d+)-(\\d+)')")
    assert digits_found in ("['100','300']", ["100", "300"])

    # Nothing matches: an empty array or NULL are both accepted.
    nothing_found = dbe.eval("REGEXP_EXTRACT_ALL('нет_цифр', '(\\d+)-(\\d+)')")
    assert nothing_found in ("[]", None)

    # NULL input taken from the data table.
    from_null = to_str(dbe.eval('REGEXP_EXTRACT_ALL([str_null_value], "or..")', from_=data_table))
    assert from_null in (None, "[]")

    # Cyrillic letters captured between digit runs; matches do not overlap.
    letters_found = dbe.eval("REGEXP_EXTRACT_ALL('1а2б3в4', '\\d+([а-я]*)\\d+')")
    assert letters_found in ("['а','в']", ["а", "в"])

def test_regexp_extract_nth(self, dbe: DbEvaluator, data_table: sa.Table) -> None:
if not self.supports_regex_extract_nth:
pytest.skip()
Expand Down
Loading