From 80b334f867499c55cd721ec72639c07a4fc16984 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C4=90=E1=BA=B7ng=20Minh=20D=C5=A9ng?= Date: Sat, 3 Apr 2021 18:32:27 +0700 Subject: [PATCH] test: add unit tests for trino.sqlalchemy --- setup.py | 2 +- tests/__init__.py | 0 tests/sqlalchemy/__init__.py | 0 tests/sqlalchemy/conftest.py | 34 ++++++++ tests/sqlalchemy/test_datatype_parse.py | 111 ++++++++++++++++++++++++ tests/sqlalchemy/test_datatype_split.py | 53 +++++++++++ 6 files changed, 199 insertions(+), 1 deletion(-) create mode 100644 tests/__init__.py create mode 100644 tests/sqlalchemy/__init__.py create mode 100644 tests/sqlalchemy/conftest.py create mode 100644 tests/sqlalchemy/test_datatype_parse.py create mode 100644 tests/sqlalchemy/test_datatype_split.py diff --git a/setup.py b/setup.py index f2fce5a2..9176b303 100755 --- a/setup.py +++ b/setup.py @@ -30,7 +30,7 @@ all_require = kerberos_require + sqlalchemy_require -tests_require = all_require + ["httpretty", "pytest", "pytest-runner", "mock", "pytz", "flake8"] +tests_require = all_require + ["httpretty", "pytest", "pytest-runner", "mock", "pytz", "flake8", "assertpy"] py27_require = ["ipaddress", "typing"] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/sqlalchemy/__init__.py b/tests/sqlalchemy/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/sqlalchemy/conftest.py b/tests/sqlalchemy/conftest.py new file mode 100644 index 00000000..0bba9362 --- /dev/null +++ b/tests/sqlalchemy/conftest.py @@ -0,0 +1,34 @@ +from assertpy import add_extension, assert_that +from sqlalchemy.sql.sqltypes import ARRAY + +from trino.sqlalchemy.datatype import SQLType, MAP, ROW + + +def assert_sqltype(this: SQLType, that: SQLType): + if isinstance(this, type): + this = this() + if isinstance(that, type): + that = that() + assert_that(type(this)).is_same_as(type(that)) + if isinstance(this, ARRAY): + assert_sqltype(this.item_type, that.item_type) + if this.dimensions is None or this.dimensions == 1: + assert_that(that.dimensions).is_in(None, 1) + else: + assert_that(this.dimensions).is_equal_to(this.dimensions) + elif isinstance(this, MAP): + assert_sqltype(this.key_type, that.key_type) + assert_sqltype(this.value_type, that.value_type) + elif isinstance(this, ROW): + assert_that(len(this.attr_types)).is_equal_to(len(that.attr_types)) + for name, this_attr in this.attr_types.items(): + that_attr = this.attr_types[name] + assert_sqltype(this_attr, that_attr) + else: + assert_that(str(this)).is_equal_to(str(that)) + + +@add_extension +def is_sqltype(self, that): + this = self.val + assert_sqltype(this, that) diff --git a/tests/sqlalchemy/test_datatype_parse.py b/tests/sqlalchemy/test_datatype_parse.py new file mode 100644 index 00000000..154cb44a --- /dev/null +++ b/tests/sqlalchemy/test_datatype_parse.py @@ -0,0 +1,111 @@ +import pytest +from assertpy import assert_that +from sqlalchemy.sql.sqltypes import * +from sqlalchemy.sql.type_api import TypeEngine + +from trino.sqlalchemy import datatype +from trino.sqlalchemy.datatype import MAP, ROW + + +@pytest.mark.parametrize( + 'type_str, sql_type', + datatype._type_map.items(), + ids=datatype._type_map.keys() +) +def test_parse_simple_type(type_str: str, sql_type: TypeEngine): + actual_type = datatype.parse_sqltype(type_str) + if not isinstance(actual_type, type): + actual_type = type(actual_type) + assert_that(actual_type).is_equal_to(sql_type) + + +parse_type_options_testcases = { + 'VARCHAR(10)': VARCHAR(10), + 'DECIMAL(20)': DECIMAL(20), + 'DECIMAL(20, 3)': DECIMAL(20, 3), +} + + +@pytest.mark.parametrize( + 'type_str, sql_type', + parse_type_options_testcases.items(), + ids=parse_type_options_testcases.keys() +) +def test_parse_type_options(type_str: str, sql_type: TypeEngine): + actual_type = datatype.parse_sqltype(type_str) + assert_that(actual_type).is_sqltype(sql_type) + + +parse_array_testcases = { + 'array(integer)': ARRAY(INTEGER()), + 'array(varchar(10))': ARRAY(VARCHAR(10)), + 'array(decimal(20,3))': ARRAY(DECIMAL(20, 3)), + 'array(array(varchar(10)))': ARRAY(VARCHAR(10), dimensions=2), +} + + +@pytest.mark.parametrize( + 'type_str, sql_type', + parse_array_testcases.items(), + ids=parse_array_testcases.keys() +) +def test_parse_array(type_str: str, sql_type: ARRAY): + actual_type = datatype.parse_sqltype(type_str) + assert_that(actual_type).is_sqltype(sql_type) + + +parse_map_testcases = { + 'map(char, integer)': MAP(CHAR(), INTEGER()), + 'map(varchar(10), varchar(10))': MAP(VARCHAR(10), VARCHAR(10)), + 'map(varchar(10), decimal(20,3))': MAP(VARCHAR(10), DECIMAL(20, 3)), + 'map(char, array(varchar(10)))': MAP(CHAR(), ARRAY(VARCHAR(10))), + 'map(varchar(10), array(varchar(10)))': MAP(VARCHAR(10), ARRAY(VARCHAR(10))), + 'map(varchar(10), array(array(varchar(10))))': MAP(VARCHAR(10), ARRAY(VARCHAR(10), dimensions=2)), +} + + +@pytest.mark.parametrize( + 'type_str, sql_type', + parse_map_testcases.items(), + ids=parse_map_testcases.keys() +) +def test_parse_map(type_str: str, sql_type: ARRAY): + actual_type = datatype.parse_sqltype(type_str) + assert_that(actual_type).is_sqltype(sql_type) + + +parse_row_testcases = { + 'row(a integer, b varchar)': ROW(dict(a=INTEGER(), b=VARCHAR())), + 'row(a varchar(20), b decimal(20,3))': ROW(dict(a=VARCHAR(20), b=DECIMAL(20, 3))), + 'row(x array(varchar(10)), y array(array(varchar(10))), z decimal(20,3))': + ROW(dict(x=ARRAY(VARCHAR(10)), y=ARRAY(VARCHAR(10), dimensions=2), z=DECIMAL(20, 3))), +} + + +@pytest.mark.parametrize( + 'type_str, sql_type', + parse_row_testcases.items(), + ids=parse_row_testcases.keys() +) +def test_parse_row(type_str: str, sql_type: ARRAY): + actual_type = datatype.parse_sqltype(type_str) + assert_that(actual_type).is_sqltype(sql_type) + + +parse_datetime_testcases = { + 'date': DATE(), + 'time': TIME(), + 'time with time zone': TIME(timezone=True), + 'timestamp': TIMESTAMP(), + 'timestamp with time zone': TIMESTAMP(timezone=True), +} + + +@pytest.mark.parametrize( + 'type_str, sql_type', + parse_datetime_testcases.items(), + ids=parse_datetime_testcases.keys() +) +def test_parse_datetime(type_str: str, sql_type: ARRAY): + actual_type = datatype.parse_sqltype(type_str) + assert_that(actual_type).is_sqltype(sql_type) diff --git a/tests/sqlalchemy/test_datatype_split.py b/tests/sqlalchemy/test_datatype_split.py new file mode 100644 index 00000000..e1ebdd08 --- /dev/null +++ b/tests/sqlalchemy/test_datatype_split.py @@ -0,0 +1,53 @@ +from typing import * + +import pytest +from assertpy import assert_that + +from trino.sqlalchemy import datatype + +split_string_testcases = { + '10': ['10'], + '10,3': ['10', '3'], + 'varchar': ['varchar'], + 'varchar,int': ['varchar', 'int'], + 'varchar,int,float': ['varchar', 'int', 'float'], + 'array(varchar)': ['array(varchar)'], + 'array(varchar),int': ['array(varchar)', 'int'], + 'array(varchar(20))': ['array(varchar(20))'], + 'array(varchar(20)),int': ['array(varchar(20))', 'int'], + 'array(varchar(20)),array(varchar(20))': ['array(varchar(20))', 'array(varchar(20))'], + 'map(varchar, integer),int': ['map(varchar, integer)', 'int'], + 'map(varchar(20), integer),int': ['map(varchar(20), integer)', 'int'], + 'map(varchar(20), varchar(20)),int': ['map(varchar(20), varchar(20))', 'int'], + 'map(varchar(20), varchar(20)),array(varchar)': ['map(varchar(20), varchar(20))', 'array(varchar)'], + 'row(first_name varchar(20), last_name varchar(20)),int': + ['row(first_name varchar(20), last_name varchar(20))', 'int'], +} + + +@pytest.mark.parametrize( + 'input_string, output_strings', + split_string_testcases.items(), + ids=split_string_testcases.keys() +) +def test_split_string(input_string: str, output_strings: List[str]): + actual = list(datatype.split(input_string)) + assert_that(actual).is_equal_to(output_strings) + + +split_delimiter_testcases = [ + ('first,second', ',', ['first', 'second']), + ('first second', ' ', ['first', 'second']), + ('first|second', '|', ['first', 'second']), + ('first,second third', ',', ['first', 'second third']), + ('first,second third', ' ', ['first,second', 'third']), +] + + +@pytest.mark.parametrize( + 'input_string, delimiter, output_strings', + split_delimiter_testcases, +) +def test_split_delimiter(input_string: str, delimiter: str, output_strings: List[str]): + actual = list(datatype.split(input_string, delimiter=delimiter)) + assert_that(actual).is_equal_to(output_strings)