diff --git a/.github/workflows/test_suite.yml b/.github/workflows/test_suite.yml index dff7941f..5d4e2c30 100644 --- a/.github/workflows/test_suite.yml +++ b/.github/workflows/test_suite.yml @@ -87,6 +87,30 @@ jobs: GOOGLE_CLOUD_PROJECT: appdev-soda-spanner-staging SQLALCHEMY_SILENCE_UBER_WARNING: 1 + compliance_tests_20: + runs-on: ubuntu-latest + + services: + emulator-0: + image: gcr.io/cloud-spanner-emulator/emulator:latest + ports: + - 9010:9010 + + steps: + - name: Checkout code + uses: actions/checkout@v2 + - name: Setup Python + uses: actions/setup-python@v4 + with: + python-version: 3.8 + - name: Install nox + run: python -m pip install nox + - name: Run Compliance Tests + run: nox -s compliance_test_20 + env: + SPANNER_EMULATOR_HOST: localhost:9010 + GOOGLE_CLOUD_PROJECT: appdev-soda-spanner-staging + migration_tests: runs-on: ubuntu-latest diff --git a/google/cloud/sqlalchemy_spanner/requirements.py b/google/cloud/sqlalchemy_spanner/requirements.py index ce5e8d53..791a6b10 100644 --- a/google/cloud/sqlalchemy_spanner/requirements.py +++ b/google/cloud/sqlalchemy_spanner/requirements.py @@ -14,6 +14,7 @@ from sqlalchemy.testing import exclusions from sqlalchemy.testing.requirements import SuiteRequirements +from sqlalchemy.testing.exclusions import against, only_on class Requirements(SuiteRequirements): # pragma: no cover @@ -45,6 +46,15 @@ def foreign_key_constraint_name_reflection(self): def schema_reflection(self): return exclusions.open() + @property + def array_type(self): + return only_on([lambda config: against(config, "postgresql")]) + + @property + def uuid_data_type(self): + """Return databases that support the UUID datatype.""" + return only_on(("postgresql >= 8.3", "mariadb >= 10.7.0")) + @property def implicitly_named_constraints(self): return exclusions.open() @@ -99,3 +109,8 @@ def precision_numerics_enotation_large(self): """target backend supports Decimal() objects using E notation to represent very large values.""" return exclusions.open() + + @property + def views(self): + """Target database must support VIEWs.""" + return exclusions.open() diff --git a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py index 12b0e7bd..e7a2ba17 100644 --- a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py +++ b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py @@ -22,7 +22,8 @@ alter_table, format_type, ) -from sqlalchemy import ForeignKeyConstraint, types, util +from sqlalchemy.exc import NoSuchTableError +from sqlalchemy import ForeignKeyConstraint, types from sqlalchemy.engine.base import Engine from sqlalchemy.engine.default import DefaultDialect, DefaultExecutionContext from sqlalchemy.event import listens_for @@ -43,19 +44,29 @@ from google.cloud.spanner_v1.data_types import JsonObject from google.cloud import spanner_dbapi from google.cloud.sqlalchemy_spanner._opentelemetry_tracing import trace_call +import sqlalchemy + +USING_SQLACLCHEMY_20 = False +if sqlalchemy.__version__.split(".")[0] == "2": + USING_SQLACLCHEMY_20 = True + +if USING_SQLACLCHEMY_20: + from sqlalchemy.engine.reflection import ObjectKind @listens_for(Pool, "reset") -def reset_connection(dbapi_conn, connection_record): +def reset_connection(dbapi_conn, connection_record, reset_state=None): """An event of returning a connection back to a pool.""" - if isinstance(dbapi_conn.connection, spanner_dbapi.Connection): - if dbapi_conn.connection.inside_transaction: - dbapi_conn.connection.rollback() - - dbapi_conn.connection.staleness = None - 
dbapi_conn.connection.read_only = False + if hasattr(dbapi_conn, "connection"): + dbapi_conn = dbapi_conn.connection + if isinstance(dbapi_conn, spanner_dbapi.Connection): + if dbapi_conn.inside_transaction: + dbapi_conn.rollback() + + dbapi_conn.staleness = None + dbapi_conn.read_only = False else: - dbapi_conn.connection.rollback() + dbapi_conn.rollback() # register a method to get a single value of a JSON object @@ -186,7 +197,7 @@ def _requires_quotes(self, value): return ( lc_value in self.reserved_words or value[0] in self.illegal_initial_characters - or not self.legal_characters.match(util.text_type(value)) + or not self.legal_characters.match(str(value)) or (lc_value != value) ) @@ -206,7 +217,7 @@ def get_from_hint_text(self, _, text): """ return text - def visit_empty_set_expr(self, type_): + def visit_empty_set_expr(self, type_, **kw): """Return an empty set expression of the given type. Args: @@ -365,7 +376,7 @@ def visit_computed_column(self, generated, **kw): ) return text - def visit_drop_table(self, drop_table): + def visit_drop_table(self, drop_table, **kw): """ Cloud Spanner doesn't drop tables which have indexes or foreign key constraints. This method builds several DDL @@ -396,7 +407,7 @@ def visit_drop_table(self, drop_table): return indexes + constrs + str(drop_table) - def visit_primary_key_constraint(self, constraint): + def visit_primary_key_constraint(self, constraint, **kw): """Build primary key definition. Primary key in Spanner is defined outside of a table columns definition, see: @@ -406,7 +417,7 @@ def visit_primary_key_constraint(self, constraint): """ return None - def visit_unique_constraint(self, constraint): + def visit_unique_constraint(self, constraint, **kw): """Unique constraints in Spanner are defined with indexes: https://cloud.google.com/spanner/docs/secondary-indexes#unique-indexes @@ -510,7 +521,8 @@ class SpannerDialect(DefaultDialect): positional = False paramstyle = "format" encoding = "utf-8" - max_identifier_length = 128 + max_identifier_length = 256 + _legacy_binary_type_literal_encoding = "utf-8" execute_sequence_format = list @@ -540,6 +552,14 @@ def dbapi(cls): """ return spanner_dbapi + @classmethod + def import_dbapi(cls): + """A pointer to the Cloud Spanner DB API package. + + Used to initiate connections to the Cloud Spanner databases. + """ + return spanner_dbapi + @property def default_isolation_level(self): """Default isolation level name. @@ -566,6 +586,58 @@ def _get_default_schema_name(self, _): """ return "" + def _get_table_type_query(self, kind, append_query): + """ + Generates WHERE condition for Kind of Object. + Spanner supports Table and View. 
+ """ + if not USING_SQLACLCHEMY_20: + return "" + + kind = ObjectKind.TABLE if kind is None else kind + if kind == ObjectKind.MATERIALIZED_VIEW: + raise NotImplementedError("Spanner does not support MATERIALIZED VIEWS") + switch_case = { + ObjectKind.ANY: ["BASE TABLE", "VIEW"], + ObjectKind.TABLE: ["BASE TABLE"], + ObjectKind.VIEW: ["VIEW"], + ObjectKind.ANY_VIEW: ["VIEW"], + } + + table_type_query = "" + for table_type in switch_case[kind]: + query = f"t.table_type = '{table_type}'" + if table_type_query != "": + table_type_query = table_type_query + " OR " + query + else: + table_type_query = query + + table_type_query = "(" + table_type_query + ") " + if append_query: + table_type_query = table_type_query + " AND " + return table_type_query + + def _get_table_filter_query( + self, filter_names, info_schema_table, append_query=False + ): + """ + Generates WHERE query for tables or views for which + information is reflected. + """ + table_filter_query = "" + if filter_names is not None: + for table_name in filter_names: + query = f"{info_schema_table}.table_name = '{table_name}'" + if table_filter_query != "": + table_filter_query = table_filter_query + " OR " + query + else: + table_filter_query = query + table_filter_query = "(" + table_filter_query + ") " + if append_query: + table_filter_query = table_filter_query + " AND " + + return table_filter_query + def create_connect_args(self, url): """Parse connection args from the given URL. @@ -589,6 +661,19 @@ def create_connect_args(self, url): @engine_to_connection def get_view_names(self, connection, schema=None, **kw): + """ + Gets a list of view names. + + The method is used by SQLAlchemy introspection systems. + + Args: + connection (sqlalchemy.engine.base.Connection): + SQLAlchemy connection or engine object. + schema (str): Optional. Schema name + + Returns: + list: List of view names. + """ sql = """ SELECT table_name FROM information_schema.views @@ -606,56 +691,137 @@ def get_view_names(self, connection, schema=None, **kw): return all_views @engine_to_connection - def get_columns(self, connection, table_name, schema=None, **kw): - """Get the table columns description. + def get_view_definition(self, connection, view_name, schema=None, **kw): + """ + Gets definition of a particular view. The method is used by SQLAlchemy introspection systems. Args: connection (sqlalchemy.engine.base.Connection): SQLAlchemy connection or engine object. - table_name (str): Name of the table to introspect. + view_name (str): Name of the view. schema (str): Optional. Schema name Returns: - list: The table every column dict-like description. + str: Definition of view. 
""" sql = """ -SELECT column_name, spanner_type, is_nullable, generation_expression -FROM information_schema.columns -WHERE - table_catalog = '' - AND table_schema = '' - AND table_name = '{}' -ORDER BY - table_catalog, - table_schema, - table_name, - ordinal_position -""".format( - table_name + SELECT view_definition + FROM information_schema.views + WHERE TABLE_SCHEMA='{schema_name}' AND TABLE_NAME='{view_name}' + """.format( + schema_name=schema or "", view_name=view_name + ) + + with connection.connection.database.snapshot() as snap: + rows = list(snap.execute_sql(sql)) + if rows == []: + raise NoSuchTableError(f"{schema if schema else ''}.{view_name}") + result = rows[0][0] + + return result + + @engine_to_connection + def get_multi_columns( + self, connection, schema=None, filter_names=None, scope=None, kind=None, **kw + ): + """ + Return information about columns in all objects in the given + schema. + + The method is used by SQLAlchemy introspection systems. + + Args: + connection (sqlalchemy.engine.base.Connection): + SQLAlchemy connection or engine object. + schema (str): Optional. Schema name + filter_names (Sequence[str]): Optional. Optionally return information + only for the objects listed here. + scope (sqlalchemy.engine.reflection.ObjectScope): Optional. Specifies + if columns of default, temporary or any tables + should be reflected. Spanner does not support temporary. + kind (sqlalchemy.engine.reflection.ObjectKind): Optional. Specifies the + type of objects to reflect. + + Returns: + dictionary: a dictionary where the keys are two-tuple schema,table-name + and the values are list of dictionaries, each representing the + definition of a database column. + The schema is ``None`` if no schema is provided. + """ + table_filter_query = self._get_table_filter_query(filter_names, "col", True) + schema_filter_query = " col.table_schema = '{schema}' AND ".format( + schema=schema or "" ) + table_type_query = self._get_table_type_query(kind, True) - cols_desc = [] + sql = """ + SELECT col.table_schema, col.table_name, col.column_name, + col.spanner_type, col.is_nullable, col.generation_expression + FROM information_schema.columns as col + JOIN information_schema.tables AS t + ON col.table_name = t.table_name + WHERE + {table_filter_query} + {table_type_query} + {schema_filter_query} + col.table_catalog = '' + ORDER BY + col.table_catalog, + col.table_schema, + col.table_name, + col.ordinal_position + """.format( + table_filter_query=table_filter_query, + table_type_query=table_type_query, + schema_filter_query=schema_filter_query, + ) with connection.connection.database.snapshot() as snap: - columns = snap.execute_sql(sql) + columns = list(snap.execute_sql(sql)) + result_dict = {} for col in columns: - col_desc = { - "name": col[0], - "type": self._designate_type(col[1]), - "nullable": col[2] == "YES", + column_info = { + "name": col[2], + "type": self._designate_type(col[3]), + "nullable": col[4] == "YES", "default": None, } - if col[3] is not None: - col_desc["computed"] = { + if col[5] is not None: + column_info["computed"] = { "persisted": True, - "sqltext": col[3], + "sqltext": col[5], } - cols_desc.append(col_desc) + col[0] = col[0] or None + table_info = result_dict.get((col[0], col[1]), []) + table_info.append(column_info) + result_dict[(col[0], col[1])] = table_info + + return result_dict + + @engine_to_connection + def get_columns(self, connection, table_name, schema=None, **kw): + """Get the table columns description. 
+ + The method is used by SQLAlchemy introspection systems. - return cols_desc + Args: + connection (sqlalchemy.engine.base.Connection): + SQLAlchemy connection or engine object. + table_name (str): Name of the table to introspect. + schema (str): Optional. Schema name + + Returns: + list: The table every column dict-like description. + """ + kind = None if not USING_SQLACLCHEMY_20 else ObjectKind.ANY + dict = self.get_multi_columns( + connection, schema=schema, filter_names=[table_name], kind=kind + ) + schema = schema or None + return dict.get((schema, table_name), []) def _designate_type(self, str_repr): """ @@ -684,6 +850,87 @@ def _designate_type(self, str_repr): else: return _type_map[str_repr] + @engine_to_connection + def get_multi_indexes( + self, connection, schema=None, filter_names=None, scope=None, kind=None, **kw + ): + """ + Return information about indexes in all objects + in the given schema. + + The method is used by SQLAlchemy introspection systems. + + Args: + connection (sqlalchemy.engine.base.Connection): + SQLAlchemy connection or engine object. + schema (str): Optional. Schema name. + filter_names (Sequence[str]): Optional. Optionally return information + only for the objects listed here. + scope (sqlalchemy.engine.reflection.ObjectScope): Optional. Specifies + if columns of default, temporary or any tables + should be reflected. Spanner does not support temporary. + kind (sqlalchemy.engine.reflection.ObjectKind): Optional. Specifies the + type of objects to reflect. + + Returns: + dictionary: a dictionary where the keys are two-tuple schema,table-name + and the values are list of dictionaries, each representing the + definition of an index. + The schema is ``None`` if no schema is provided. + """ + table_filter_query = self._get_table_filter_query(filter_names, "i", True) + schema_filter_query = " i.table_schema = '{schema}' AND ".format( + schema=schema or "" + ) + table_type_query = self._get_table_type_query(kind, True) + + sql = """ + SELECT + i.table_schema, + i.table_name, + i.index_name, + ARRAY_AGG(ic.column_name), + i.is_unique, + ARRAY_AGG(ic.column_ordering) + FROM information_schema.indexes as i + JOIN information_schema.index_columns AS ic + ON ic.index_name = i.index_name AND ic.table_name = i.table_name + JOIN information_schema.tables AS t + ON i.table_name = t.table_name + WHERE + {table_filter_query} + {table_type_query} + {schema_filter_query} + i.index_type != 'PRIMARY_KEY' + AND i.spanner_is_managed = FALSE + GROUP BY i.table_schema, i.table_name, i.index_name, i.is_unique + ORDER BY i.index_name + """.format( + table_filter_query=table_filter_query, + table_type_query=table_type_query, + schema_filter_query=schema_filter_query, + ) + + with connection.connection.database.snapshot() as snap: + rows = list(snap.execute_sql(sql)) + result_dict = {} + + for row in rows: + index_info = { + "name": row[2], + "column_names": row[3], + "unique": row[4], + "column_sorting": { + col: order for col, order in zip(row[3], row[5]) + }, + } + row[0] = row[0] or None + table_info = result_dict.get((row[0], row[1]), []) + table_info.append(index_info) + result_dict[(row[0], row[1])] = table_info + + return result_dict + @engine_to_connection def get_indexes(self, connection, table_name, schema=None, **kw): """Get the table indexes. @@ -699,41 +946,75 @@ def get_indexes(self, connection, table_name, schema=None, **kw): Returns: list: List with indexes description. 
""" + kind = None if not USING_SQLACLCHEMY_20 else ObjectKind.ANY + dict = self.get_multi_indexes( + connection, schema=schema, filter_names=[table_name], kind=kind + ) + schema = schema or None + return dict.get((schema, table_name), []) + + @engine_to_connection + def get_multi_pk_constraint( + self, connection, schema=None, filter_names=None, scope=None, kind=None, **kw + ): + """ + Return information about primary key constraints in + all tables in the given schema. + + The method is used by SQLAlchemy introspection systems. + + Args: + connection (sqlalchemy.engine.base.Connection): + SQLAlchemy connection or engine object. + schema (str): Optional. Schema name + filter_names (Sequence[str]): Optional. Optionally return information + only for the objects listed here. + scope (sqlalchemy.engine.reflection.ObjectScope): Optional. Specifies + if columns of default, temporary or any tables + should be reflected. Spanner does not support temporary. + kind (sqlalchemy.engine.reflection.ObjectKind): Optional. Specifies the + type of objects to reflect. + + Returns: + dictionary: a dictionary where the keys are two-tuple schema,table-name + and the values are list of dictionaries, each representing the + definition of a primary key constraint. + The schema is ``None`` if no schema is provided. + """ + table_filter_query = self._get_table_filter_query(filter_names, "tc", True) + schema_filter_query = " tc.table_schema = '{schema}' AND ".format( + schema=schema or "" + ) + table_type_query = self._get_table_type_query(kind, True) + sql = """ -SELECT - i.index_name, - ARRAY_AGG(ic.column_name), - i.is_unique, - ARRAY_AGG(ic.column_ordering) -FROM information_schema.indexes as i -JOIN information_schema.index_columns AS ic - ON ic.index_name = i.index_name AND ic.table_name = i.table_name -WHERE - i.table_name="{table_name}" - AND i.index_type != 'PRIMARY_KEY' - AND i.spanner_is_managed = FALSE -GROUP BY i.index_name, i.is_unique -ORDER BY i.index_name -""".format( - table_name=table_name + SELECT tc.table_schema, tc.table_name, ccu.COLUMN_NAME + FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS tc + JOIN INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE AS ccu + ON ccu.CONSTRAINT_NAME = tc.CONSTRAINT_NAME + JOIN information_schema.tables AS t + ON tc.table_name = t.table_name + WHERE {table_filter_query} {table_type_query} + {schema_filter_query} tc.CONSTRAINT_TYPE = "PRIMARY KEY" + """.format( + table_filter_query=table_filter_query, + table_type_query=table_type_query, + schema_filter_query=schema_filter_query, ) - ind_desc = [] with connection.connection.database.snapshot() as snap: - rows = snap.execute_sql(sql) + rows = list(snap.execute_sql(sql)) + result_dict = {} for row in rows: - ind_desc.append( - { - "name": row[0], - "column_names": row[1], - "unique": row[2], - "column_sorting": { - col: order for col, order in zip(row[1], row[3]) - }, - } + row[0] = row[0] or None + table_info = result_dict.get( + (row[0], row[1]), {"constrained_columns": []} ) - return ind_desc + table_info["constrained_columns"].append(row[2]) + result_dict[(row[0], row[1])] = table_info + + return result_dict @engine_to_connection def get_pk_constraint(self, connection, table_name, schema=None, **kw): @@ -750,24 +1031,12 @@ def get_pk_constraint(self, connection, table_name, schema=None, **kw): Returns: dict: Dict with the primary key constraint description. 
""" - sql = """ -SELECT ccu.COLUMN_NAME -FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS tc -JOIN INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE AS ccu - ON ccu.CONSTRAINT_NAME = tc.CONSTRAINT_NAME -WHERE tc.TABLE_NAME="{table_name}" AND tc.CONSTRAINT_TYPE = "PRIMARY KEY" -""".format( - table_name=table_name + kind = None if not USING_SQLACLCHEMY_20 else ObjectKind.ANY + dict = self.get_multi_pk_constraint( + connection, schema=schema, filter_names=[table_name], kind=kind ) - - cols = [] - with connection.connection.database.snapshot() as snap: - rows = snap.execute_sql(sql) - - for row in rows: - cols.append(row[0]) - - return {"constrained_columns": cols} + schema = schema or None + return dict.get((schema, table_name), []) @engine_to_connection def get_schema_names(self, connection, **kw): @@ -792,51 +1061,79 @@ def get_schema_names(self, connection, **kw): return schemas @engine_to_connection - def get_foreign_keys(self, connection, table_name, schema=None, **kw): - """Get the table foreign key constraints. + def get_multi_foreign_keys( + self, connection, schema=None, filter_names=None, scope=None, kind=None, **kw + ): + """ + Return information about foreign_keys in all tables + in the given schema. The method is used by SQLAlchemy introspection systems. Args: connection (sqlalchemy.engine.base.Connection): SQLAlchemy connection or engine object. - table_name (str): Name of the table to introspect. schema (str): Optional. Schema name + filter_names (Sequence[str]): Optional. Optionally return information + only for the objects listed here. + scope (sqlalchemy.engine.reflection.ObjectScope): Optional. Specifies + if columns of default, temporary or any tables + should be reflected. Spanner does not support temporary. + kind (sqlalchemy.engine.reflection.ObjectKind): Optional. Specifies the + type of objects to reflect. Returns: - list: Dicts, each of which describes a foreign key constraint. + dictionary: a dictionary where the keys are two-tuple schema,table-name + and the values are list of dictionaries, each representing + a foreign key definition. + The schema is ``None`` if no schema is provided. 
""" - sql = """ -SELECT - tc.constraint_name, - ctu.table_name, - ctu.table_schema, - ARRAY_AGG(DISTINCT ccu.column_name), - ARRAY_AGG( - DISTINCT CONCAT( - CAST(kcu.ordinal_position AS STRING), - '_____', - kcu.column_name + table_filter_query = self._get_table_filter_query(filter_names, "tc", True) + schema_filter_query = " tc.table_schema = '{schema}' AND".format( + schema=schema or "" ) - ) -FROM information_schema.table_constraints AS tc -JOIN information_schema.constraint_column_usage AS ccu - ON ccu.constraint_name = tc.constraint_name -JOIN information_schema.constraint_table_usage AS ctu - ON ctu.constraint_name = tc.constraint_name -JOIN information_schema.key_column_usage AS kcu - ON kcu.constraint_name = tc.constraint_name -WHERE - tc.table_name="{table_name}" - AND tc.constraint_type = "FOREIGN KEY" -GROUP BY tc.constraint_name, ctu.table_name, ctu.table_schema -""".format( - table_name=table_name + table_type_query = self._get_table_type_query(kind, True) + + sql = """ + SELECT + tc.table_schema, + tc.table_name, + tc.constraint_name, + ctu.table_name, + ctu.table_schema, + ARRAY_AGG(DISTINCT ccu.column_name), + ARRAY_AGG( + DISTINCT CONCAT( + CAST(kcu.ordinal_position AS STRING), + '_____', + kcu.column_name + ) + ) + FROM information_schema.table_constraints AS tc + JOIN information_schema.constraint_column_usage AS ccu + ON ccu.constraint_name = tc.constraint_name + JOIN information_schema.constraint_table_usage AS ctu + ON ctu.constraint_name = tc.constraint_name + JOIN information_schema.key_column_usage AS kcu + ON kcu.constraint_name = tc.constraint_name + JOIN information_schema.tables AS t + ON tc.table_name = t.table_name + WHERE + {table_filter_query} + {table_type_query} + {schema_filter_query} + tc.constraint_type = "FOREIGN KEY" + GROUP BY tc.table_name, tc.table_schema, tc.constraint_name, + ctu.table_name, ctu.table_schema + """.format( + table_filter_query=table_filter_query, + table_type_query=table_type_query, + schema_filter_query=schema_filter_query, ) - keys = [] with connection.connection.database.snapshot() as snap: - rows = snap.execute_sql(sql) + rows = list(snap.execute_sql(sql)) + result_dict = {} for row in rows: # Due to Spanner limitations, arrays order is not guaranteed during @@ -851,20 +1148,45 @@ def get_foreign_keys(self, connection, table_name, schema=None, **kw): # # The solution seem a bit clumsy, and should be improved as soon as a # better approach found. - for index, value in enumerate(sorted(row[4])): - row[4][index] = value.split("_____")[1] - - keys.append( - { - "name": row[0], - "referred_table": row[1], - "referred_schema": row[2] or None, - "referred_columns": row[3], - "constrained_columns": row[4], - } - ) + row[0] = row[0] or None + table_info = result_dict.get((row[0], row[1]), []) + for index, value in enumerate(sorted(row[6])): + row[6][index] = value.split("_____")[1] + + fk_info = { + "name": row[2], + "referred_table": row[3], + "referred_schema": row[4] or None, + "referred_columns": row[5], + "constrained_columns": row[6], + } + + table_info.append(fk_info) + result_dict[(row[0], row[1])] = table_info + + return result_dict + + @engine_to_connection + def get_foreign_keys(self, connection, table_name, schema=None, **kw): + """Get the table foreign key constraints. + + The method is used by SQLAlchemy introspection systems. - return keys + Args: + connection (sqlalchemy.engine.base.Connection): + SQLAlchemy connection or engine object. + table_name (str): Name of the table to introspect. + schema (str): Optional. 
Schema name + + Returns: + list: Dicts, each of which describes a foreign key constraint. + """ + kind = None if not USING_SQLACLCHEMY_20 else ObjectKind.ANY + dict = self.get_multi_foreign_keys( + connection, schema=schema, filter_names=[table_name], kind=kind + ) + schema = schema or None + return dict.get((schema, table_name), []) @engine_to_connection def get_table_names(self, connection, schema=None, **kw): @@ -883,9 +1205,9 @@ def get_table_names(self, connection, schema=None, **kw): sql = """ SELECT table_name FROM information_schema.tables -WHERE table_schema = '{}' +WHERE table_type = 'BASE TABLE' AND table_schema = '{schema}' """.format( - schema or "" + schema=schema or "" ) table_names = [] @@ -937,7 +1259,7 @@ def get_unique_constraints(self, connection, table_name, schema=None, **kw): return cols @engine_to_connection - def has_table(self, connection, table_name, schema=None): + def has_table(self, connection, table_name, schema=None, **kw): """Check if the given table exists. The method is used by SQLAlchemy introspection systems. diff --git a/noxfile.py b/noxfile.py index 74059744..394269df 100644 --- a/noxfile.py +++ b/noxfile.py @@ -165,6 +165,7 @@ def compliance_test_13(session): "--cov-fail-under=0", "--asyncio-mode=auto", "test/test_suite_13.py", + *session.posargs, ) @@ -193,7 +194,6 @@ def compliance_test_14(session): session.install("-e", ".[tracing]") session.run("pip", "install", "sqlalchemy>=1.4,<2.0", "--force-reinstall") session.run("python", "create_test_database.py") - session.run( "py.test", "--cov=google.cloud.sqlalchemy_spanner", @@ -204,6 +204,45 @@ def compliance_test_14(session): "--cov-fail-under=0", "--asyncio-mode=auto", "test/test_suite_14.py", + *session.posargs, + ) + + +@nox.session(python=DEFAULT_PYTHON_VERSION) +def compliance_test_20(session): + """Run SQLAlchemy dialect compliance test suite.""" + + # Sanity check: Only run tests if the environment variable is set. + if not os.environ.get("GOOGLE_APPLICATION_CREDENTIALS", "") and not os.environ.get( + "SPANNER_EMULATOR_HOST", "" + ): + session.skip( + "Credentials or emulator host must be set via environment variable" + ) + + session.install( + "pytest", + "pytest-cov", + "pytest-asyncio", + ) + + session.install("mock") + session.install("-e", ".[tracing]") + session.run("python", "create_test_database.py") + + session.install("sqlalchemy>=2.0") + + session.run( + "py.test", + "--cov=google.cloud.sqlalchemy_spanner", + "--cov=test", + "--cov-append", + "--cov-config=.coveragerc", + "--cov-report=", + "--cov-fail-under=0", + "--asyncio-mode=auto", + "test/test_suite_20.py", + *session.posargs, ) diff --git a/test/test_suite_13.py b/test/test_suite_13.py index e519dd4b..a5b2e3bb 100644 --- a/test/test_suite_13.py +++ b/test/test_suite_13.py @@ -526,6 +526,34 @@ class UnicodeTextTest(UnicodeFixtureTest, _UnicodeTextTest): class ComponentReflectionTest(_ComponentReflectionTest): + @classmethod + def define_views(cls, metadata, schema): + table_info = { + "users": ["user_id", "test1", "test2"], + "email_addresses": ["address_id", "remote_user_id", "email_address"], + } + if testing.requires.self_referential_foreign_keys.enabled: + table_info["users"] = table_info["users"] + ["parent_user_id"] + for table_name in ("users", "email_addresses"): + fullname = table_name + if schema: + fullname = "%s.%s" % (schema, table_name) + view_name = fullname + "_v" + columns = "" + for column in table_info[table_name]: + stmt = table_name + "." 
+ column + " AS " + column + if columns: + columns = columns + ", " + stmt + else: + columns = stmt + query = f"""CREATE VIEW {view_name} + SQL SECURITY INVOKER + AS SELECT {columns} + FROM {fullname}""" + + event.listen(metadata, "after_create", DDL(query)) + event.listen(metadata, "before_drop", DDL("DROP VIEW %s" % view_name)) + @classmethod def define_reflected_tables(cls, metadata, schema): if schema: @@ -665,11 +693,24 @@ def define_reflected_tables(cls, metadata, schema): sqlalchemy.Index("noncol_idx_nopk", noncol_idx_test_nopk.c.q.desc()) sqlalchemy.Index("noncol_idx_pk", noncol_idx_test_pk.c.q.desc()) - if testing.requires.view_column_reflection.enabled: + if testing.requires.view_column_reflection.enabled and not bool( + os.environ.get("SPANNER_EMULATOR_HOST") + ): cls.define_views(metadata, schema) if not schema and testing.requires.temp_table_reflection.enabled: cls.define_temp_tables(metadata) + def _test_get_columns(self, schema=None, table_type="table"): + if table_type == "view" and bool(os.environ.get("SPANNER_EMULATOR_HOST")): + pytest.skip("View tables not supported on emulator") + super()._test_get_columns(schema, table_type) + + @testing.provide_metadata + def _test_get_view_definition(self, schema=None): + if bool(os.environ.get("SPANNER_EMULATOR_HOST")): + pytest.skip("View tables not supported on emulator") + super()._test_get_view_definition(schema) + @classmethod def define_temp_tables(cls, metadata): """ @@ -858,7 +899,7 @@ def _test_get_table_names(self, schema=None, table_type="table", order_by=None): insp = inspect(meta.bind) - if table_type == "view": + if table_type == "view" and not bool(os.environ.get("SPANNER_EMULATOR_HOST")): table_names = insp.get_view_names(schema) table_names.sort() answer = ["email_addresses_v", "users_v"] @@ -921,14 +962,14 @@ def test_array_reflection(self): str_array = ARRAY(String(16)) int_array = ARRAY(Integer) - Table( + arrays_test = Table( "arrays_test", orig_meta, Column("id", Integer, primary_key=True), Column("str_array", str_array), Column("int_array", int_array), ) - orig_meta.create_all() + arrays_test.create(create_engine(get_db_url())) # autoload the table and check its columns reflection tab = Table("arrays_test", orig_meta, autoload=True) @@ -969,6 +1010,22 @@ def test_fk_column_order(self): eq_(set(fkey1.get("referred_columns")), {"name", "id", "attr"}) eq_(set(fkey1.get("constrained_columns")), {"pname", "pid", "pattr"}) + @testing.requires.primary_key_constraint_reflection + def test_pk_column_order(self, connection): + """ + SPANNER OVERRIDE: + Emultor doesn't support returning pk sorted by ordinal value + of columns. 
+ """ + insp = inspect(connection) + primary_key = insp.get_pk_constraint(self.tables.tb1.name) + exp = ( + ["id", "name", "attr"] + if bool(os.environ.get("SPANNER_EMULATOR_HOST")) + else ["name", "id", "attr"] + ) + eq_(primary_key.get("constrained_columns"), exp) + class RowFetchTest(_RowFetchTest): def test_row_w_scalar_select(self): diff --git a/test/test_suite_14.py b/test/test_suite_14.py index bd77a63f..a13477b3 100644 --- a/test/test_suite_14.py +++ b/test/test_suite_14.py @@ -60,6 +60,7 @@ from sqlalchemy.types import Text from sqlalchemy.testing import requires from sqlalchemy.testing import is_true +from sqlalchemy import types as sql_types from sqlalchemy.testing.fixtures import ( ComputedReflectionFixtureTest as _ComputedReflectionFixtureTest, ) @@ -229,6 +230,34 @@ def test_binary_reflection(self, connection, metadata): class ComponentReflectionTest(_ComponentReflectionTest): + @classmethod + def define_views(cls, metadata, schema): + table_info = { + "users": ["user_id", "test1", "test2"], + "email_addresses": ["address_id", "remote_user_id", "email_address"], + } + if testing.requires.self_referential_foreign_keys.enabled: + table_info["users"] = table_info["users"] + ["parent_user_id"] + for table_name in ("users", "email_addresses"): + fullname = table_name + if schema: + fullname = "%s.%s" % (schema, table_name) + view_name = fullname + "_v" + columns = "" + for column in table_info[table_name]: + stmt = table_name + "." + column + " AS " + column + if columns: + columns = columns + ", " + stmt + else: + columns = stmt + query = f"""CREATE VIEW {view_name} + SQL SECURITY INVOKER + AS SELECT {columns} + FROM {fullname}""" + + event.listen(metadata, "after_create", DDL(query)) + event.listen(metadata, "before_drop", DDL("DROP VIEW %s" % view_name)) + @classmethod def define_tables(cls, metadata): cls.define_reflected_tables(metadata, None) @@ -374,9 +403,102 @@ def define_reflected_tables(cls, metadata, schema): sqlalchemy.Index("noncol_idx_nopk", noncol_idx_test_nopk.c.q.desc()) sqlalchemy.Index("noncol_idx_pk", noncol_idx_test_pk.c.q.desc()) - if testing.requires.view_column_reflection.enabled: + if testing.requires.view_column_reflection.enabled and not bool( + os.environ.get("SPANNER_EMULATOR_HOST") + ): cls.define_views(metadata, schema) + @pytest.mark.skipif( + bool(os.environ.get("SPANNER_EMULATOR_HOST")), reason="Skipped on emulator" + ) + @testing.requires.view_reflection + @testing.combinations( + (False,), (True, testing.requires.schemas), argnames="use_schema" + ) + def test_get_view_definition(self, connection, use_schema): + if use_schema: + schema = config.test_schema + else: + schema = None + view_name1 = "users_v" + view_name2 = "email_addresses_v" + insp = inspect(connection) + v1 = insp.get_view_definition(view_name1, schema=schema) + self.assert_(v1) + v2 = insp.get_view_definition(view_name2, schema=schema) + self.assert_(v2) + + @testing.combinations( + (False, False), + (False, True, testing.requires.schemas), + (True, False, testing.requires.view_reflection), + ( + True, + True, + testing.requires.schemas + testing.requires.view_reflection, + ), + argnames="use_views,use_schema", + ) + def test_get_columns(self, connection, use_views, use_schema): + if use_views and bool(os.environ.get("SPANNER_EMULATOR_HOST")): + pytest.skip("Skipped on emulator") + + schema = None + users, addresses = (self.tables.users, self.tables.email_addresses) + if use_views: + table_names = ["users_v", "email_addresses_v"] + else: + table_names = ["users", 
"email_addresses"] + + insp = inspect(connection) + for table_name, table in zip(table_names, (users, addresses)): + schema_name = schema + cols = insp.get_columns(table_name, schema=schema_name) + self.assert_(len(cols) > 0, len(cols)) + + # should be in order + + for i, col in enumerate(table.columns): + eq_(col.name, cols[i]["name"]) + ctype = cols[i]["type"].__class__ + ctype_def = col.type + if isinstance(ctype_def, sqlalchemy.types.TypeEngine): + ctype_def = ctype_def.__class__ + + # Oracle returns Date for DateTime. + + if testing.against("oracle") and ctype_def in ( + sql_types.Date, + sql_types.DateTime, + ): + ctype_def = sql_types.Date + + # assert that the desired type and return type share + # a base within one of the generic types. + + self.assert_( + len( + set(ctype.__mro__) + .intersection(ctype_def.__mro__) + .intersection( + [ + sql_types.Integer, + sql_types.Numeric, + sql_types.DateTime, + sql_types.Date, + sql_types.Time, + sql_types.String, + sql_types._Binary, + ] + ) + ) + > 0, + "%s(%s), %s(%s)" % (col.name, col.type, cols[i]["name"], ctype), + ) + + if not col.primary_key: + assert cols[i]["default"] is None + @testing.combinations((False,), argnames="use_schema") @testing.requires.foreign_key_constraint_reflection def test_get_foreign_keys(self, connection, use_schema): @@ -669,7 +791,7 @@ def _test_get_table_names(self, schema=None, table_type="table", order_by=None): insp = inspect(meta.bind) - if table_type == "view": + if table_type == "view" and not bool(os.environ.get("SPANNER_EMULATOR_HOST")): table_names = insp.get_view_names(schema) table_names.sort() answer = ["email_addresses_v", "users_v"] @@ -730,6 +852,22 @@ def test_fk_column_order(self): eq_(set(fkey1.get("referred_columns")), {"name", "id", "attr"}) eq_(set(fkey1.get("constrained_columns")), {"pname", "pid", "pattr"}) + @testing.requires.primary_key_constraint_reflection + def test_pk_column_order(self, connection): + """ + SPANNER OVERRIDE: + Emultor doesn't support returning pk sorted by ordinal value + of columns. 
+ """ + insp = inspect(connection) + primary_key = insp.get_pk_constraint(self.tables.tb1.name) + exp = ( + ["id", "name", "attr"] + if bool(os.environ.get("SPANNER_EMULATOR_HOST")) + else ["name", "id", "attr"] + ) + eq_(primary_key.get("constrained_columns"), exp) + @pytest.mark.skip("Spanner doesn't support quotes in table names.") class QuotedNameArgumentTest(_QuotedNameArgumentTest): @@ -972,6 +1110,9 @@ def test_simple_limit_expr_offset(self, connection): def test_bound_offset(self, connection): pass + @pytest.mark.skipif( + bool(os.environ.get("SPANNER_EMULATOR_HOST")), reason="Skipped on emulator" + ) def test_limit_render_multiple_times(self, connection): table = self.tables.some_table stmt = select(table.c.id).limit(1).scalar_subquery() @@ -981,7 +1122,16 @@ def test_limit_render_multiple_times(self, connection): self._assert_result( connection, u, - [(2,)], + [(1,)], + ) + + @testing.requires.offset + def test_simple_offset(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table).order_by(table.c.id).offset(2), + [(3, 3, 4), (4, 4, 5), (5, 4, 6)], ) @@ -1827,6 +1977,14 @@ def define_tables(cls, metadata): def test_has_table_schema(self): pass + @testing.requires.views + def test_has_table_view(self, connection): + pass + + @testing.requires.views + def test_has_table_view_schema(self, connection): + pass + class PostCompileParamsTest(_PostCompileParamsTest): def test_execute(self): @@ -2048,7 +2206,7 @@ class JSONTest(_JSONTest): def test_single_element_round_trip(self, element): pass - def _test_round_trip(self, data_element): + def _test_round_trip(self, data_element, connection): data_table = self.tables.data_table config.db.execute( diff --git a/test/test_suite_20.py b/test/test_suite_20.py new file mode 100644 index 00000000..fb59b725 --- /dev/null +++ b/test/test_suite_20.py @@ -0,0 +1,3002 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from datetime import timezone +import decimal +import operator +import os +import pkg_resources +import pytest +import random +import time +from unittest import mock + +from google.cloud.spanner_v1 import RequestOptions +import sqlalchemy +from sqlalchemy import create_engine +from sqlalchemy.engine import Inspector +from sqlalchemy import inspect +from sqlalchemy import testing +from sqlalchemy import ForeignKey +from sqlalchemy import MetaData +from sqlalchemy.engine import ObjectKind +from sqlalchemy.engine import ObjectScope +from sqlalchemy.schema import DDL +from sqlalchemy.schema import Computed +from sqlalchemy.testing import config +from sqlalchemy.testing import engines +from sqlalchemy.testing import eq_ +from sqlalchemy.testing import provide_metadata, emits_warning +from sqlalchemy.testing import fixtures +from sqlalchemy.testing.provision import temp_table_keyword_args +from sqlalchemy.testing.schema import Column +from sqlalchemy.testing.schema import Table +from sqlalchemy import literal_column +from sqlalchemy import select +from sqlalchemy import util +from sqlalchemy import union +from sqlalchemy import event +from sqlalchemy import exists +from sqlalchemy import Boolean +from sqlalchemy import Float +from sqlalchemy import LargeBinary +from sqlalchemy import String +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import relationship +from sqlalchemy.orm import Session +from sqlalchemy.types import Integer +from sqlalchemy.types import Numeric +from sqlalchemy.types import Text +from sqlalchemy.testing import requires +from sqlalchemy.testing import is_true +from sqlalchemy import Index +from sqlalchemy import types +from sqlalchemy.testing.fixtures import ( + ComputedReflectionFixtureTest as _ComputedReflectionFixtureTest, +) + +from google.api_core.datetime_helpers import DatetimeWithNanoseconds + +from google.cloud import spanner_dbapi + +from sqlalchemy.testing.suite.test_cte import * # noqa: F401, F403 +from sqlalchemy.testing.suite.test_ddl import * # noqa: F401, F403 +from sqlalchemy.testing.suite.test_dialect import * # noqa: F401, F403 +from sqlalchemy.testing.suite.test_insert import * # noqa: F401, F403 +from sqlalchemy.testing.suite.test_reflection import * # noqa: F401, F403 +from sqlalchemy.testing.suite.test_deprecations import * # noqa: F401, F403 +from sqlalchemy.testing.suite.test_results import * # noqa: F401, F403 +from sqlalchemy.testing.suite.test_select import * # noqa: F401, F403 +from sqlalchemy.testing.suite.test_sequence import * # noqa: F401, F403 +from sqlalchemy.testing.suite.test_unicode_ddl import * # noqa: F401, F403 +from sqlalchemy.testing.suite.test_update_delete import * # noqa: F401, F403 +from sqlalchemy.testing.suite.test_cte import CTETest as _CTETest +from sqlalchemy.testing.suite.test_ddl import TableDDLTest as _TableDDLTest +from sqlalchemy.testing.suite.test_ddl import ( + FutureTableDDLTest as _FutureTableDDLTest, + LongNameBlowoutTest as _LongNameBlowoutTest, +) +from sqlalchemy.testing.suite.test_update_delete import ( + SimpleUpdateDeleteTest as _SimpleUpdateDeleteTest, +) +from sqlalchemy.testing.suite.test_dialect import ( + DifficultParametersTest as _DifficultParametersTest, + EscapingTest as _EscapingTest, + ReturningGuardsTest as _ReturningGuardsTest, +) +from sqlalchemy.testing.suite.test_insert import ( + InsertBehaviorTest as _InsertBehaviorTest, +) +from sqlalchemy.testing.suite.test_select import ( # noqa: F401, F403 + CompoundSelectTest as _CompoundSelectTest, + ExistsTest as 
_ExistsTest, + FetchLimitOffsetTest as _FetchLimitOffsetTest, + IdentityAutoincrementTest as _IdentityAutoincrementTest, + IsOrIsNotDistinctFromTest as _IsOrIsNotDistinctFromTest, + LikeFunctionsTest as _LikeFunctionsTest, + OrderByLabelTest as _OrderByLabelTest, + PostCompileParamsTest as _PostCompileParamsTest, + SameNamedSchemaTableTest as _SameNamedSchemaTableTest, +) +from sqlalchemy.testing.suite.test_reflection import ( # noqa: F401, F403 + ComponentReflectionTestExtra as _ComponentReflectionTestExtra, + QuotedNameArgumentTest as _QuotedNameArgumentTest, + ComponentReflectionTest as _ComponentReflectionTest, + CompositeKeyReflectionTest as _CompositeKeyReflectionTest, + ComputedReflectionTest as _ComputedReflectionTest, + HasIndexTest as _HasIndexTest, + HasTableTest as _HasTableTest, +) +from sqlalchemy.testing.suite.test_results import ( + RowFetchTest as _RowFetchTest, +) +from sqlalchemy.testing.suite.test_types import ( # noqa: F401, F403 + BooleanTest as _BooleanTest, + DateTest as _DateTest, + _DateFixture as __DateFixture, + DateTimeHistoricTest, + DateTimeCoercedToDateTimeTest as _DateTimeCoercedToDateTimeTest, + DateTimeMicrosecondsTest as _DateTimeMicrosecondsTest, + DateTimeTest as _DateTimeTest, + IntegerTest as _IntegerTest, + JSONTest as _JSONTest, + _LiteralRoundTripFixture, + NumericTest as _NumericTest, + StringTest as _StringTest, + TextTest as _TextTest, + TimeTest as _TimeTest, + TimeMicrosecondsTest as _TimeMicrosecondsTest, + TimestampMicrosecondsTest, + UnicodeVarcharTest as _UnicodeVarcharTest, + UnicodeTextTest as _UnicodeTextTest, + _UnicodeFixture as __UnicodeFixture, +) # noqa: F401, F403 +from test._helpers import get_db_url + +config.test_schema = "" + + +class BooleanTest(_BooleanTest): + @pytest.mark.skip( + "The original test case was split into 2 parts: " + "test_render_literal_bool_true and test_render_literal_bool_false" + ) + def test_render_literal_bool(self): + pass + + def test_render_literal_bool_true(self, literal_round_trip): + """ + SPANNER OVERRIDE: + + Cloud Spanner supports tables with an empty primary key, but + only a single row can be inserted into such a table - + following insertions will fail with `Row [] already exists". + Overriding the test to avoid the same failure. + """ + literal_round_trip(Boolean(), [True], [True]) + + def test_render_literal_bool_false(self, literal_round_trip): + """ + SPANNER OVERRIDE: + + Cloud Spanner supports tables with an empty primary key, but + only a single row can be inserted into such a table - + following insertions will fail with `Row [] already exists". + Overriding the test to avoid the same failure. 
+ """ + literal_round_trip(Boolean(), [False], [False]) + + @pytest.mark.skip("Not supported by Cloud Spanner") + def test_whereclause(self): + pass + + +class ComponentReflectionTestExtra(_ComponentReflectionTestExtra): + @testing.requires.table_reflection + def test_nullable_reflection(self, connection, metadata): + t = Table( + "t", + metadata, + Column("a", Integer, nullable=True), + Column("b", Integer, nullable=False), + ) + t.create(connection) + connection.connection.commit() + eq_( + dict( + (col["name"], col["nullable"]) + for col in inspect(connection).get_columns("t") + ), + {"a": True, "b": False}, + ) + + def _type_round_trip(self, connection, metadata, *types): + t = Table( + "t", metadata, *[Column("t%d" % i, type_) for i, type_ in enumerate(types)] + ) + t.create(connection) + connection.connection.commit() + + return [c["type"] for c in inspect(connection).get_columns("t")] + + @testing.requires.table_reflection + def test_numeric_reflection(self, connection, metadata): + """ + SPANNER OVERRIDE: + + Spanner defines NUMERIC type with the constant precision=38 + and scale=9. Overriding the test to check if the NUMERIC + column is successfully created and has dimensions + correct for Cloud Spanner. + """ + for typ in self._type_round_trip(connection, metadata, Numeric(18, 5)): + assert isinstance(typ, Numeric) + eq_(typ.precision, 38) + eq_(typ.scale, 9) + + @testing.requires.table_reflection + def test_binary_reflection(self, connection, metadata): + """ + Check that a BYTES column with an explicitly + set size is correctly reflected. + """ + for typ in self._type_round_trip(connection, metadata, LargeBinary(20)): + assert isinstance(typ, LargeBinary) + eq_(typ.length, 20) + + +class ComputedReflectionFixtureTest(_ComputedReflectionFixtureTest): + @classmethod + def define_tables(cls, metadata): + """SPANNER OVERRIDE: + + Avoid using default values for computed columns. 
+ """ + Table( + "computed_default_table", + metadata, + Column("id", Integer, primary_key=True), + Column("normal", Integer), + Column("computed_col", Integer, Computed("normal + 42")), + Column("with_default", Integer), + ) + + t = Table( + "computed_column_table", + metadata, + Column("id", Integer, primary_key=True), + Column("normal", Integer), + Column("computed_no_flag", Integer, Computed("normal + 42")), + ) + + if testing.requires.computed_columns_virtual.enabled: + t.append_column( + Column( + "computed_virtual", + Integer, + Computed("normal + 2", persisted=False), + ) + ) + if testing.requires.computed_columns_stored.enabled: + t.append_column( + Column( + "computed_stored", + Integer, + Computed("normal - 42", persisted=True), + ) + ) + + +class ComputedReflectionTest(_ComputedReflectionTest, ComputedReflectionFixtureTest): + @testing.requires.schemas + def test_get_column_returns_persisted_with_schema(self): + insp = inspect(config.db) + + cols = insp.get_columns("computed_column_table", schema=config.test_schema) + data = {c["name"]: c for c in cols} + + self.check_column( + data, + "computed_no_flag", + "normal+42", + testing.requires.computed_columns_default_persisted.enabled, + ) + if testing.requires.computed_columns_virtual.enabled: + self.check_column( + data, + "computed_virtual", + "normal/2", + False, + ) + if testing.requires.computed_columns_stored.enabled: + self.check_column( + data, + "computed_stored", + "normal-42", + True, + ) + + @pytest.mark.skip("Default values are not supported.") + def test_computed_col_default_not_set(self): + pass + + def test_get_column_returns_computed(self): + """ + SPANNER OVERRIDE: + + In Spanner all the generated columns are STORED, + meaning there are no persisted and not persisted + (in the terms of the SQLAlchemy) columns. The + method override omits the persistence reflection checks. + """ + insp = inspect(config.db) + + cols = insp.get_columns("computed_default_table") + data = {c["name"]: c for c in cols} + for key in ("id", "normal", "with_default"): + is_true("computed" not in data[key]) + compData = data["computed_col"] + is_true("computed" in compData) + is_true("sqltext" in compData["computed"]) + eq_(self.normalize(compData["computed"]["sqltext"]), "normal+42") + + def test_create_not_null_computed_column(self, connection): + """ + SPANNER TEST: + + Check that on creating a computed column with a NOT NULL + clause the clause is set in front of the computed column + statement definition and doesn't cause failures. 
+ """ + metadata = MetaData() + + Table( + "Singers", + metadata, + Column("SingerId", String(36), primary_key=True, nullable=False), + Column("FirstName", String(200)), + Column("LastName", String(200), nullable=False), + Column( + "FullName", + String(400), + Computed("COALESCE(FirstName || ' ', '') || LastName"), + nullable=False, + ), + ) + + metadata.create_all(connection) + + +class ComponentReflectionTest(_ComponentReflectionTest): + @pytest.mark.skip("Skip") + def test_not_existing_table(self, method, connection): + pass + + @classmethod + def define_tables(cls, metadata): + cls.define_reflected_tables(metadata, None) + + @classmethod + def define_views(cls, metadata, schema): + table_info = { + "dingalings": [ + "dingaling_id", + "address_id", + "data", + "id_user", + ], + "users": ["user_id", "test1", "test2"], + "email_addresses": ["address_id", "remote_user_id", "email_address"], + } + if testing.requires.self_referential_foreign_keys.enabled: + table_info["users"] = table_info["users"] + ["parent_user_id"] + if testing.requires.materialized_views.enabled: + materialized = {"dingalings"} + else: + materialized = set() + for table_name in ("users", "email_addresses", "dingalings"): + fullname = table_name + if schema: + fullname = f"{schema}.{table_name}" + view_name = fullname + "_v" + prefix = "MATERIALIZED " if table_name in materialized else "" + columns = "" + for column in table_info[table_name]: + stmt = table_name + "." + column + " AS " + column + if columns: + columns = columns + ", " + stmt + else: + columns = stmt + query = f"""CREATE {prefix}VIEW {view_name} + SQL SECURITY INVOKER + AS SELECT {columns} + FROM {fullname}""" + + event.listen(metadata, "after_create", DDL(query)) + if table_name in materialized: + index_name = "mat_index" + if schema and testing.against("oracle"): + index_name = f"{schema}.{index_name}" + idx = f"CREATE INDEX {index_name} ON {view_name}(data)" + event.listen(metadata, "after_create", DDL(idx)) + event.listen(metadata, "before_drop", DDL(f"DROP {prefix}VIEW {view_name}")) + + @classmethod + def define_reflected_tables(cls, metadata, schema): + if schema: + schema_prefix = schema + "." 
+ else: + schema_prefix = "" + + if testing.requires.self_referential_foreign_keys.enabled: + users = Table( + "users", + metadata, + Column("user_id", sqlalchemy.INT, primary_key=True), + Column("test1", sqlalchemy.CHAR(5), nullable=False), + Column("test2", sqlalchemy.Float(5), nullable=False), + Column( + "parent_user_id", + sqlalchemy.Integer, + sqlalchemy.ForeignKey( + "%susers.user_id" % schema_prefix, name="user_id_fk" + ), + ), + schema=schema, + test_needs_fk=True, + ) + else: + users = Table( + "users", + metadata, + Column("user_id", sqlalchemy.INT, primary_key=True), + Column("test1", sqlalchemy.CHAR(5), nullable=False), + Column("test2", sqlalchemy.Float(5), nullable=False), + schema=schema, + test_needs_fk=True, + ) + + Table( + "dingalings", + metadata, + Column("dingaling_id", sqlalchemy.Integer, primary_key=True), + Column( + "address_id", + sqlalchemy.Integer, + sqlalchemy.ForeignKey("%semail_addresses.address_id" % schema_prefix), + ), + Column( + "id_user", + sqlalchemy.Integer, + sqlalchemy.ForeignKey("%susers.user_id" % schema_prefix), + ), + Column("data", sqlalchemy.String(30)), + schema=schema, + test_needs_fk=True, + ) + Table( + "email_addresses", + metadata, + Column("address_id", sqlalchemy.Integer, primary_key=True), + Column( + "remote_user_id", + sqlalchemy.Integer, + sqlalchemy.ForeignKey(users.c.user_id), + ), + Column("email_address", sqlalchemy.String(20)), + sqlalchemy.PrimaryKeyConstraint("address_id", name="email_ad_pk"), + schema=schema, + test_needs_fk=True, + ) + Table( + "comment_test", + metadata, + Column("id", sqlalchemy.Integer, primary_key=True, comment="id comment"), + Column("data", sqlalchemy.String(20), comment="data % comment"), + Column( + "d2", + sqlalchemy.String(20), + comment=r"""Comment types type speedily ' " \ '' Fun!""", + ), + schema=schema, + comment=r"""the test % ' " \ table comment""", + ) + Table( + "no_constraints", + metadata, + Column("data", sqlalchemy.String(20)), + schema=schema, + ) + + if testing.requires.cross_schema_fk_reflection.enabled: + if schema is None: + Table( + "local_table", + metadata, + Column("id", sqlalchemy.Integer, primary_key=True), + Column("data", sqlalchemy.String(20)), + Column( + "remote_id", + ForeignKey("%s.remote_table_2.id" % testing.config.test_schema), + ), + test_needs_fk=True, + schema=config.db.dialect.default_schema_name, + ) + else: + Table( + "remote_table", + metadata, + Column("id", sqlalchemy.Integer, primary_key=True), + Column( + "local_id", + ForeignKey( + "%s.local_table.id" % config.db.dialect.default_schema_name + ), + ), + Column("data", sqlalchemy.String(20)), + schema=schema, + test_needs_fk=True, + ) + Table( + "remote_table_2", + metadata, + Column("id", sqlalchemy.Integer, primary_key=True), + Column("data", sqlalchemy.String(20)), + schema=schema, + test_needs_fk=True, + ) + + if testing.requires.index_reflection.enabled: + sqlalchemy.Index("users_t_idx", users.c.test1, users.c.test2, unique=True) + sqlalchemy.Index( + "users_all_idx", users.c.user_id, users.c.test2, users.c.test1 + ) + + if not schema: + # test_needs_fk is at the moment to force MySQL InnoDB + noncol_idx_test_nopk = Table( + "noncol_idx_test_nopk", + metadata, + Column("id", sqlalchemy.Integer, primary_key=True), + Column("q", sqlalchemy.String(5)), + test_needs_fk=True, + extend_existing=True, + ) + + noncol_idx_test_pk = Table( + "noncol_idx_test_pk", + metadata, + Column("id", sqlalchemy.Integer, primary_key=True), + Column("q", sqlalchemy.String(5)), + test_needs_fk=True, + 
extend_existing=True, + ) + + if testing.requires.indexes_with_ascdesc.enabled: + sqlalchemy.Index("noncol_idx_nopk", noncol_idx_test_nopk.c.q.desc()) + sqlalchemy.Index("noncol_idx_pk", noncol_idx_test_pk.c.q.desc()) + + if testing.requires.view_column_reflection.enabled and not bool( + os.environ.get("SPANNER_EMULATOR_HOST") + ): + cls.define_views(metadata, schema) + + @testing.combinations( + (False, False), + (False, True, testing.requires.schemas), + (True, False, testing.requires.view_reflection), + ( + True, + True, + testing.requires.schemas + testing.requires.view_reflection, + ), + argnames="use_views,use_schema", + ) + def test_get_columns(self, connection, use_views, use_schema): + if use_views and bool(os.environ.get("SPANNER_EMULATOR_HOST")): + pytest.skip("Skipped on emulator") + + schema = None + + users, addresses = (self.tables.users, self.tables.email_addresses) + if use_views: + table_names = ["users_v", "email_addresses_v", "dingalings_v"] + else: + table_names = ["users", "email_addresses"] + + insp = inspect(connection) + for table_name, table in zip(table_names, (users, addresses)): + schema_name = schema + cols = insp.get_columns(table_name, schema=schema_name) + is_true(len(cols) > 0, len(cols)) + + # should be in order + + for i, col in enumerate(table.columns): + eq_(col.name, cols[i]["name"]) + ctype = cols[i]["type"].__class__ + ctype_def = col.type + if isinstance(ctype_def, sqlalchemy.types.TypeEngine): + ctype_def = ctype_def.__class__ + + # Oracle returns Date for DateTime. + + if testing.against("oracle") and ctype_def in ( + types.Date, + types.DateTime, + ): + ctype_def = types.Date + + # assert that the desired type and return type share + # a base within one of the generic types. + + is_true( + len( + set(ctype.__mro__) + .intersection(ctype_def.__mro__) + .intersection( + [ + types.Integer, + types.Numeric, + types.DateTime, + types.Date, + types.Time, + types.String, + types._Binary, + ] + ) + ) + > 0, + "%s(%s), %s(%s)" % (col.name, col.type, cols[i]["name"], ctype), + ) + + if not col.primary_key: + assert cols[i]["default"] is None + + @pytest.mark.skipif( + bool(os.environ.get("SPANNER_EMULATOR_HOST")), reason="Skipped on emulator" + ) + @testing.requires.view_reflection + def test_get_view_definition( + self, + connection, + ): + schema = None + insp = inspect(connection) + for view in ["users_v", "email_addresses_v", "dingalings_v"]: + v = insp.get_view_definition(view, schema=schema) + is_true(bool(v)) + + @pytest.mark.skipif( + bool(os.environ.get("SPANNER_EMULATOR_HOST")), reason="Skipped on emulator" + ) + @testing.requires.view_reflection + def test_get_view_definition_does_not_exist(self, connection): + super().test_get_view_definition_does_not_exist(connection) + + def filter_name_values(): + return testing.combinations(True, False, argnames="use_filter") + + @filter_name_values() + @testing.requires.index_reflection + def test_get_multi_indexes( + self, + get_multi_exp, + use_filter, + schema=None, + scope=ObjectScope.DEFAULT, + kind=ObjectKind.TABLE, + ): + """ + SPANNER OVERRIDE: + + Spanner doesn't support indexes on views and + doesn't support temporary tables, so real tables are + used for testing. As the original test expects only real + tables to be read, and in Spanner all the tables are real, + expected results override is required. 
+ """ + insp, kws, exp = get_multi_exp( + schema, + scope, + kind, + use_filter, + Inspector.get_indexes, + self.exp_indexes, + ) + _ignore_tables = [ + (None, "comment_test"), + (None, "dingalings"), + (None, "email_addresses"), + (None, "no_constraints"), + ] + exp = {k: v for k, v in exp.items() if k not in _ignore_tables} + + for kw in kws: + insp.clear_cache() + result = insp.get_multi_indexes(**kw) + self._check_table_dict(result, exp, self._required_index_keys) + + def exp_pks( + self, + schema=None, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + filter_names=None, + ): + def pk(*cols, name=mock.ANY, comment=None): + return { + "constrained_columns": list(cols), + "name": name, + "comment": comment, + } + + empty = pk(name=None) + if testing.requires.materialized_views_reflect_pk.enabled: + materialized = {(schema, "dingalings_v"): pk("dingaling_id")} + else: + materialized = {(schema, "dingalings_v"): empty} + views = { + (schema, "email_addresses_v"): empty, + (schema, "users_v"): empty, + (schema, "user_tmp_v"): empty, + } + self._resolve_views(views, materialized) + tables = { + (schema, "users"): pk("user_id"), + (schema, "dingalings"): pk("dingaling_id"), + (schema, "email_addresses"): pk( + "address_id", name="email_ad_pk", comment="ea pk comment" + ), + (schema, "comment_test"): pk("id"), + (schema, "no_constraints"): empty, + (schema, "local_table"): pk("id"), + (schema, "remote_table"): pk("id"), + (schema, "remote_table_2"): pk("id"), + (schema, "noncol_idx_test_nopk"): pk("id"), + (schema, "noncol_idx_test_pk"): pk("id"), + (schema, self.temp_table_name()): pk("id"), + } + if not testing.requires.reflects_pk_names.enabled: + for val in tables.values(): + if val["name"] is not None: + val["name"] = mock.ANY + res = self._resolve_kind(kind, tables, views, materialized) + res = self._resolve_names(schema, scope, filter_names, res) + return res + + @filter_name_values() + @testing.requires.primary_key_constraint_reflection + def test_get_multi_pk_constraint( + self, + get_multi_exp, + use_filter, + schema=None, + scope=ObjectScope.DEFAULT, + kind=ObjectKind.TABLE, + ): + """ + SPANNER OVERRIDE: + + Spanner doesn't support temporary tables, so real tables are + used for testing. As the original test expects only real + tables to be read, and in Spanner all the tables are real, + expected results override is required. 
+ """ + insp, kws, exp = get_multi_exp( + schema, + scope, + kind, + use_filter, + Inspector.get_pk_constraint, + self.exp_pks, + ) + _ignore_tables = [(None, "no_constraints")] + exp = {k: v for k, v in exp.items() if k not in _ignore_tables} + + for kw in kws: + insp.clear_cache() + result = insp.get_multi_pk_constraint(**kw) + self._check_table_dict(result, exp, self._required_pk_keys, make_lists=True) + + def exp_fks( + self, + schema=None, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + filter_names=None, + ): + class tt: + def __eq__(self, other): + return other is None or config.db.dialect.default_schema_name == other + + def fk( + cols, + ref_col, + ref_table, + ref_schema=schema, + name=mock.ANY, + comment=None, + ): + return { + "constrained_columns": cols, + "referred_columns": ref_col, + "name": name, + "options": mock.ANY, + "referred_schema": ref_schema if ref_schema is not None else tt(), + "referred_table": ref_table, + "comment": comment, + } + + materialized = {} + views = {} + self._resolve_views(views, materialized) + tables = { + (schema, "users"): [ + fk(["parent_user_id"], ["user_id"], "users", name="user_id_fk") + ], + (schema, "dingalings"): [ + fk(["address_id"], ["address_id"], "email_addresses"), + fk(["id_user"], ["user_id"], "users"), + ], + (schema, "email_addresses"): [fk(["remote_user_id"], ["user_id"], "users")], + (schema, "local_table"): [ + fk( + ["remote_id"], + ["id"], + "remote_table_2", + ref_schema=config.test_schema, + ) + ], + (schema, "remote_table"): [ + fk(["local_id"], ["id"], "local_table", ref_schema=None) + ], + } + if not testing.requires.self_referential_foreign_keys.enabled: + tables[(schema, "users")].clear() + if not testing.requires.named_constraints.enabled: + for vals in tables.values(): + for val in vals: + if val["name"] is not mock.ANY: + val["name"] = mock.ANY + + res = self._resolve_kind(kind, tables, views, materialized) + res = self._resolve_names(schema, scope, filter_names, res) + return res + + @filter_name_values() + @testing.requires.foreign_key_constraint_reflection + def test_get_multi_foreign_keys( + self, + get_multi_exp, + use_filter, + schema=None, + scope=ObjectScope.DEFAULT, + kind=ObjectKind.TABLE, + ): + + """ + SPANNER OVERRIDE: + + Spanner doesn't support temporary tables, so real tables are + used for testing. As the original test expects only real + tables to be read, and in Spanner all the tables are real, + expected results override is required. 
+ """ + insp, kws, exp = get_multi_exp( + schema, + scope, + kind, + use_filter, + Inspector.get_foreign_keys, + self.exp_fks, + ) + for kw in kws: + insp.clear_cache() + result = insp.get_multi_foreign_keys(**kw) + self._adjust_sort(result, exp, lambda d: tuple(d["constrained_columns"])) + self._check_table_dict( + { + key: sorted(value, key=lambda x: x["constrained_columns"]) + for key, value in result.items() + }, + { + key: sorted(value, key=lambda x: x["constrained_columns"]) + for key, value in exp.items() + }, + self._required_fk_keys, + ) + + def exp_columns( + self, + schema=None, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + filter_names=None, + ): + def col(name, auto=False, default=mock.ANY, comment=None, nullable=True): + res = { + "name": name, + "autoincrement": auto, + "type": mock.ANY, + "default": default, + "comment": comment, + "nullable": nullable, + } + if auto == "omit": + res.pop("autoincrement") + return res + + def pk(name, **kw): + kw = {"auto": True, "default": mock.ANY, "nullable": False, **kw} + return col(name, **kw) + + materialized = { + (schema, "dingalings_v"): [ + col("dingaling_id", auto="omit", nullable=mock.ANY), + col("address_id"), + col("id_user"), + col("data"), + ] + } + views = { + (schema, "email_addresses_v"): [ + col("address_id", auto="omit", nullable=mock.ANY), + col("remote_user_id"), + col("email_address"), + ], + (schema, "users_v"): [ + col("user_id", auto="omit", nullable=mock.ANY), + col("test1", nullable=mock.ANY), + col("test2", nullable=mock.ANY), + col("parent_user_id"), + ], + (schema, "user_tmp_v"): [ + col("id", auto="omit", nullable=mock.ANY), + col("name"), + col("foo"), + ], + } + self._resolve_views(views, materialized) + tables = { + (schema, "users"): [ + pk("user_id"), + col("test1", nullable=False), + col("test2", nullable=False), + col("parent_user_id"), + ], + (schema, "dingalings"): [ + pk("dingaling_id"), + col("address_id"), + col("id_user"), + col("data"), + ], + (schema, "email_addresses"): [ + pk("address_id"), + col("remote_user_id"), + col("email_address"), + ], + (schema, "comment_test"): [ + pk("id", comment="id comment"), + col("data", comment="data % comment"), + col( + "d2", + comment=r"""Comment types type speedily ' " \ '' Fun!""", + ), + ], + (schema, "no_constraints"): [col("data")], + (schema, "local_table"): [pk("id"), col("data"), col("remote_id")], + (schema, "remote_table"): [pk("id"), col("local_id"), col("data")], + (schema, "remote_table_2"): [pk("id"), col("data")], + (schema, "noncol_idx_test_nopk"): [pk("id"), col("q")], + (schema, "noncol_idx_test_pk"): [pk("id"), col("q")], + (schema, self.temp_table_name()): [ + pk("id"), + col("name"), + col("foo"), + ], + } + res = self._resolve_kind(kind, tables, views, materialized) + res = self._resolve_names(schema, scope, filter_names, res) + return res + + @filter_name_values() + def test_get_multi_columns( + self, + get_multi_exp, + use_filter, + schema=None, + scope=ObjectScope.DEFAULT, + kind=ObjectKind.TABLE, + ): + """ + SPANNER OVERRIDE: + + Spanner doesn't support temporary tables, so real tables are + used for testing. As the original test expects only real + tables to be read, and in Spanner all the tables are real, + expected results override is required. 
+ """ + insp, kws, exp = get_multi_exp( + schema, + scope, + kind, + use_filter, + Inspector.get_columns, + self.exp_columns, + ) + + for kw in kws: + insp.clear_cache() + result = insp.get_multi_columns(**kw) + self._check_table_dict(result, exp, self._required_column_keys) + + @pytest.mark.skip( + "Requires an introspection method to be implemented in SQLAlchemy first" + ) + def test_get_multi_unique_constraints(): + pass + + @pytest.mark.skip( + "Requires an introspection method to be implemented in SQLAlchemy first" + ) + def test_get_multi_check_constraints(): + pass + + @testing.combinations((False,), argnames="use_schema") + @testing.requires.foreign_key_constraint_reflection + def test_get_foreign_keys(self, connection, use_schema): + if use_schema: + schema = config.test_schema + else: + schema = None + + users, addresses = (self.tables.users, self.tables.email_addresses) + insp = inspect(connection) + expected_schema = schema + # users + + if testing.requires.self_referential_foreign_keys.enabled: + users_fkeys = insp.get_foreign_keys(users.name, schema=schema) + fkey1 = users_fkeys[0] + + with testing.requires.named_constraints.fail_if(): + eq_(fkey1["name"], "user_id_fk") + + eq_(fkey1["referred_schema"], expected_schema) + eq_(fkey1["referred_table"], users.name) + eq_(fkey1["referred_columns"], ["user_id"]) + if testing.requires.self_referential_foreign_keys.enabled: + eq_(fkey1["constrained_columns"], ["parent_user_id"]) + + # addresses + addr_fkeys = insp.get_foreign_keys(addresses.name, schema=schema) + fkey1 = addr_fkeys[0] + + with testing.requires.implicitly_named_constraints.fail_if(): + self.assert_(fkey1["name"] is not None) + + eq_(fkey1["referred_schema"], expected_schema) + eq_(fkey1["referred_table"], users.name) + eq_(fkey1["referred_columns"], ["user_id"]) + eq_(fkey1["constrained_columns"], ["remote_user_id"]) + + @testing.combinations( + None, + ("foreign_key", testing.requires.foreign_key_constraint_reflection), + argnames="order_by", + ) + @testing.combinations( + (True, testing.requires.schemas), False, argnames="use_schema" + ) + def test_get_table_names(self, connection, order_by, use_schema): + + schema = None + + _ignore_tables = [ + "account", + "alembic_version", + "bytes_table", + "comment_test", + "date_table", + "noncol_idx_test_pk", + "noncol_idx_test_nopk", + "local_table", + "remote_table", + "remote_table_2", + "text_table", + "user_tmp", + "no_constraints", + ] + + insp = inspect(connection) + + if order_by: + tables = [ + rec[0] for rec in insp.get_sorted_table_and_fkc_names(schema) if rec[0] + ] + else: + tables = insp.get_table_names(schema) + table_names = [t for t in tables if t not in _ignore_tables] + + if order_by == "foreign_key": + answer = ["users", "email_addresses", "dingalings"] + eq_(table_names, answer) + else: + answer = ["dingalings", "email_addresses", "users"] + eq_(sorted(table_names), answer) + + @classmethod + def define_temp_tables(cls, metadata): + """ + SPANNER OVERRIDE: + + In Cloud Spanner unique indexes are used instead of directly + creating unique constraints. Overriding the test to replace + constraints with indexes in testing data. 
+ """ + kw = temp_table_keyword_args(config, config.db) + user_tmp = Table( + "user_tmp", + metadata, + Column("id", sqlalchemy.INT, primary_key=True), + Column("name", sqlalchemy.VARCHAR(50)), + Column("foo", sqlalchemy.INT), + sqlalchemy.Index("user_tmp_uq", "name", unique=True), + sqlalchemy.Index("user_tmp_ix", "foo"), + extend_existing=True, + **kw, + ) + if ( + testing.requires.view_reflection.enabled + and testing.requires.temporary_views.enabled + ): + event.listen( + user_tmp, + "after_create", + DDL("create temporary view user_tmp_v as " "select * from user_tmp"), + ) + event.listen(user_tmp, "before_drop", DDL("drop view user_tmp_v")) + + @testing.provide_metadata + def test_reflect_string_column_max_len(self, connection): + """ + SPANNER SPECIFIC TEST: + + In Spanner column of the STRING type can be + created with size defined as MAX. The test + checks that such a column is correctly reflected. + """ + metadata = MetaData() + Table("text_table", metadata, Column("TestColumn", Text, nullable=False)) + metadata.create_all(connection) + + Table("text_table", metadata, autoload=True) + + def test_reflect_bytes_column_max_len(self, connection): + """ + SPANNER SPECIFIC TEST: + + In Spanner column of the BYTES type can be + created with size defined as MAX. The test + checks that such a column is correctly reflected. + """ + metadata = MetaData() + Table( + "bytes_table", + metadata, + Column("TestColumn", LargeBinary, nullable=False), + ) + metadata.create_all(connection) + + Table("bytes_table", metadata, autoload=True) + + @testing.requires.unique_constraint_reflection + def test_get_unique_constraints(self, metadata, connection, use_schema=False): + # SQLite dialect needs to parse the names of the constraints + # separately from what it gets from PRAGMA index_list(), and + # then matches them up. so same set of column_names in two + # constraints will confuse it. Perhaps we should no longer + # bother with index_list() here since we have the whole + # CREATE TABLE? 
+ + if use_schema: + schema = config.test_schema + else: + schema = None + uniques = sorted( + [ + {"name": "unique_a", "column_names": ["a"]}, + {"name": "unique_a_b_c", "column_names": ["a", "b", "c"]}, + {"name": "unique_c_a_b", "column_names": ["c", "a", "b"]}, + {"name": "unique_asc_key", "column_names": ["asc", "key"]}, + {"name": "i.have.dots", "column_names": ["b"]}, + {"name": "i have spaces", "column_names": ["c"]}, + ], + key=operator.itemgetter("name"), + ) + table = Table( + "testtbl", + metadata, + Column("id", sqlalchemy.INT, primary_key=True), + Column("a", String(20)), + Column("b", String(30)), + Column("c", Integer), + # reserved identifiers + Column("asc", String(30)), + Column("key", String(30)), + sqlalchemy.Index("unique_a", "a", unique=True), + sqlalchemy.Index("unique_a_b_c", "a", "b", "c", unique=True), + sqlalchemy.Index("unique_c_a_b", "c", "a", "b", unique=True), + sqlalchemy.Index("unique_asc_key", "asc", "key", unique=True), + schema=schema, + ) + table.create(connection) + connection.connection.commit() + + inspector = inspect(connection) + reflected = sorted( + inspector.get_unique_constraints("testtbl", schema=schema), + key=operator.itemgetter("name"), + ) + + names_that_duplicate_index = set() + + for orig, refl in zip(uniques, reflected): + # Different dialects handle duplicate index and constraints + # differently, so ignore this flag + dupe = refl.pop("duplicates_index", None) + if dupe: + names_that_duplicate_index.add(dupe) + eq_(orig, refl) + + reflected_metadata = MetaData() + reflected = Table( + "testtbl", + reflected_metadata, + autoload_with=connection, + schema=schema, + ) + + # test "deduplicates for index" logic. MySQL and Oracle + # "unique constraints" are actually unique indexes (with possible + # exception of a unique that is a dupe of another one in the case + # of Oracle). make sure # they aren't duplicated. + idx_names = set([idx.name for idx in reflected.indexes]) + uq_names = set( + [ + uq.name + for uq in reflected.constraints + if isinstance(uq, sqlalchemy.UniqueConstraint) + ] + ).difference(["unique_c_a_b"]) + + assert not idx_names.intersection(uq_names) + if names_that_duplicate_index: + eq_(names_that_duplicate_index, idx_names) + eq_(uq_names, set()) + + @testing.provide_metadata + def test_unique_constraint_raises(self, connection): + """ + Checking that unique constraint creation + fails due to a ProgrammingError. + """ + metadata = MetaData() + Table( + "user_tmp_failure", + metadata, + Column("id", sqlalchemy.INT, primary_key=True), + Column("name", sqlalchemy.VARCHAR(50)), + sqlalchemy.UniqueConstraint("name", name="user_tmp_uq"), + ) + + with pytest.raises(spanner_dbapi.exceptions.ProgrammingError): + metadata.create_all(connection) + + @testing.provide_metadata + def _test_get_table_names(self, schema=None, table_type="table", order_by=None): + """ + SPANNER OVERRIDE: + + Spanner doesn't support temporary tables, so real tables are + used for testing. As the original test expects only real + tables to be read, and in Spanner all the tables are real, + expected results override is required. 
+ """ + _ignore_tables = [ + "comment_test", + "noncol_idx_test_pk", + "noncol_idx_test_nopk", + "local_table", + "remote_table", + "remote_table_2", + "no_constraints", + ] + meta = self.metadata + + insp = inspect(meta.bind) + + if table_type == "view" and not bool(os.environ.get("SPANNER_EMULATOR_HOST")): + table_names = insp.get_view_names(schema) + table_names.sort() + answer = ["email_addresses_v", "users_v"] + eq_(sorted(table_names), answer) + else: + if order_by: + tables = [ + rec[0] + for rec in insp.get_sorted_table_and_fkc_names(schema) + if rec[0] + ] + else: + tables = insp.get_table_names(schema) + table_names = [t for t in tables if t not in _ignore_tables] + + if order_by == "foreign_key": + answer = {"dingalings", "email_addresses", "user_tmp", "users"} + eq_(set(table_names), answer) + else: + answer = ["dingalings", "email_addresses", "user_tmp", "users"] + eq_(sorted(table_names), answer) + + @pytest.mark.skipif( + bool(os.environ.get("SPANNER_EMULATOR_HOST")), reason="Skipped on emulator" + ) + def test_get_view_names(self, connection, use_schema=False): + insp = inspect(connection) + schema = None + table_names = insp.get_view_names(schema) + if testing.requires.materialized_views.enabled: + eq_(sorted(table_names), ["email_addresses_v", "users_v"]) + eq_(insp.get_materialized_view_names(schema), ["dingalings_v"]) + else: + answer = ["dingalings_v", "email_addresses_v", "users_v"] + eq_(sorted(table_names), answer) + + @pytest.mark.skip("Spanner doesn't support temporary tables") + def test_get_temp_table_indexes(self): + pass + + @pytest.mark.skip("Spanner doesn't support temporary tables") + def test_get_temp_table_unique_constraints(self): + pass + + @pytest.mark.skip("Spanner doesn't support temporary tables") + def test_get_temp_table_columns(self): + pass + + @pytest.mark.skip("Spanner doesn't support temporary tables") + def test_reflect_table_temp_table(self, connection): + pass + + def exp_indexes( + self, + schema=None, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + filter_names=None, + ): + def idx( + *cols, + name, + unique=False, + column_sorting=None, + duplicates=False, + fk=False, + ): + fk_req = testing.requires.foreign_keys_reflect_as_index + dup_req = testing.requires.unique_constraints_reflect_as_index + if (fk and not fk_req.enabled) or (duplicates and not dup_req.enabled): + return () + res = { + "unique": unique, + "column_names": list(cols), + "name": name, + "dialect_options": mock.ANY, + "include_columns": [], + } + if column_sorting: + res["column_sorting"] = {"q": "DESC"} + if duplicates: + res["duplicates_constraint"] = name + return [res] + + materialized = {(schema, "dingalings_v"): []} + views = { + (schema, "email_addresses_v"): [], + (schema, "users_v"): [], + (schema, "user_tmp_v"): [], + } + self._resolve_views(views, materialized) + if materialized: + materialized[(schema, "dingalings_v")].extend(idx("data", name="mat_index")) + tables = { + (schema, "users"): [ + *idx("parent_user_id", name="user_id_fk", fk=True), + *idx("user_id", "test2", "test1", name="users_all_idx"), + *idx("test1", "test2", name="users_t_idx", unique=True), + ], + (schema, "dingalings"): [ + *idx("data", name=mock.ANY, unique=True, duplicates=True), + *idx("id_user", name=mock.ANY, fk=True), + *idx( + "address_id", + "dingaling_id", + name="zz_dingalings_multiple", + unique=True, + duplicates=True, + ), + ], + (schema, "email_addresses"): [ + *idx("email_address", name=mock.ANY), + *idx("remote_user_id", name=mock.ANY, fk=True), + ], + (schema, 
"comment_test"): [], + (schema, "no_constraints"): [], + (schema, "local_table"): [*idx("remote_id", name=mock.ANY, fk=True)], + (schema, "remote_table"): [*idx("local_id", name=mock.ANY, fk=True)], + (schema, "remote_table_2"): [], + (schema, "noncol_idx_test_nopk"): [ + *idx( + "q", + name="noncol_idx_nopk", + column_sorting={"q": "DESC"}, + ) + ], + (schema, "noncol_idx_test_pk"): [ + *idx("q", name="noncol_idx_pk", column_sorting={"q": "DESC"}) + ], + (schema, self.temp_table_name()): [ + *idx("foo", name="user_tmp_ix"), + *idx( + "name", + name=f"user_tmp_uq_{config.ident}", + duplicates=True, + unique=True, + ), + ], + } + if ( + not testing.requires.indexes_with_ascdesc.enabled + or not testing.requires.reflect_indexes_with_ascdesc.enabled + ): + tables[(schema, "noncol_idx_test_nopk")].clear() + tables[(schema, "noncol_idx_test_pk")].clear() + res = self._resolve_kind(kind, tables, views, materialized) + res = self._resolve_names(schema, scope, filter_names, res) + return res + + def _check_list(self, result, exp, req_keys=None, msg=None): + if req_keys is None: + eq_(result, exp, msg) + else: + eq_(len(result), len(exp), msg) + for r, e in zip(result, exp): + for k in set(r) | set(e): + if (k in req_keys and (k in r and k in e)) or (k in r and k in e): + if isinstance(r[k], list): + r[k].sort() + e[k].sort() + eq_(r[k], e[k], f"{msg} - {k} - {r}") + + @pytest.mark.skipif( + bool(os.environ.get("SPANNER_EMULATOR_HOST")), reason="Skipped on emulator" + ) + @testing.combinations(True, False, argnames="use_schema") + @testing.combinations((True, testing.requires.views), False, argnames="views") + def test_metadata(self, connection, use_schema, views): + m = MetaData() + schema = None + m.reflect(connection, schema=schema, views=views, resolve_fks=False) + + insp = inspect(connection) + tables = insp.get_table_names(schema) + if views: + tables += insp.get_view_names(schema) + try: + tables += insp.get_materialized_view_names(schema) + except NotImplementedError: + pass + if schema is not None: + tables = [f"{schema}.{t}" for t in tables] + eq_(sorted(m.tables), sorted(tables)) + + +class CompositeKeyReflectionTest(_CompositeKeyReflectionTest): + @testing.requires.foreign_key_constraint_reflection + def test_fk_column_order(self, connection): + """ + SPANNER OVERRIDE: + + Spanner column usage reflection doesn't support determenistic + ordering. Overriding the test to check that columns are + reflected correctly, without considering their order. + """ + # test for issue #5661 + insp = inspect(connection) + foreign_keys = insp.get_foreign_keys(self.tables.tb2.name) + eq_(len(foreign_keys), 1) + fkey1 = foreign_keys[0] + eq_(set(fkey1.get("referred_columns")), {"name", "id", "attr"}) + eq_(set(fkey1.get("constrained_columns")), {"pname", "pid", "pattr"}) + + @testing.requires.primary_key_constraint_reflection + def test_pk_column_order(self, connection): + """ + SPANNER OVERRIDE: + Emultor doesn't support returning pk sorted by ordinal value + of columns. 
+ """ + insp = inspect(connection) + primary_key = insp.get_pk_constraint(self.tables.tb1.name) + exp = ( + ["id", "name", "attr"] + if bool(os.environ.get("SPANNER_EMULATOR_HOST")) + else ["name", "id", "attr"] + ) + eq_(primary_key.get("constrained_columns"), exp) + + +@pytest.mark.skip("Spanner doesn't support quotes in table names.") +class QuotedNameArgumentTest(_QuotedNameArgumentTest): + pass + + +class _DateFixture(__DateFixture): + compare = None + + @classmethod + def define_tables(cls, metadata): + """ + SPANNER OVERRIDE: + + Cloud Spanner doesn't support auto incrementing ids feature, + which is used by the original test. Overriding the test data + creation method to disable autoincrement and make id column + nullable. + """ + + class Decorated(sqlalchemy.TypeDecorator): + impl = cls.datatype + cache_ok = True + + Table( + "date_table", + metadata, + Column("id", Integer, primary_key=True, nullable=True), + Column("date_data", cls.datatype), + Column("decorated_date_data", Decorated), + ) + + +class DateTest(_DateTest): + """ + SPANNER OVERRIDE: + + DateTest tests used same class method to create table, so to avoid those failures + and maintain DRY concept just inherit the class to run tests successfully. + """ + + @pytest.mark.skipif( + bool(os.environ.get("SPANNER_EMULATOR_HOST")), reason="Skipped on emulator" + ) + def test_null_bound_comparison(self): + super().test_null_bound_comparison() + + @pytest.mark.skipif( + bool(os.environ.get("SPANNER_EMULATOR_HOST")), reason="Skipped on emulator" + ) + def test_null(self, connection): + super().test_null(connection) + + +class CTETest(_CTETest): + @classmethod + def define_tables(cls, metadata): + """ + The original method creates a foreign key without a name, + which causes troubles on test cleanup. Overriding the + method to explicitly set a foreign key name. + """ + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + Column("parent_id", ForeignKey("some_table.id", name="fk_some_table")), + ) + + Table( + "some_other_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + Column("parent_id", Integer), + ) + + @pytest.mark.skip("INSERT from WITH subquery is not supported") + def test_insert_from_select_round_trip(self): + """ + The test checks if an INSERT can be done from a cte, like: + + WITH some_cte AS (...) + INSERT INTO some_other_table (... SELECT * FROM some_cte) + + Such queries are not supported by Spanner. + """ + pass + + @pytest.mark.skip("DELETE from WITH subquery is not supported") + def test_delete_scalar_subq_round_trip(self): + """ + The test checks if a DELETE can be done from a cte, like: + + WITH some_cte AS (...) + DELETE FROM some_other_table (... SELECT * FROM some_cte) + + Such queries are not supported by Spanner. + """ + pass + + @pytest.mark.skip("DELETE from WITH subquery is not supported") + def test_delete_from_round_trip(self): + """ + The test checks if a DELETE can be done from a cte, like: + + WITH some_cte AS (...) + DELETE FROM some_other_table (... SELECT * FROM some_cte) + + Such queries are not supported by Spanner. + """ + pass + + @pytest.mark.skip("UPDATE from WITH subquery is not supported") + def test_update_from_round_trip(self): + """ + The test checks if an UPDATE can be done from a cte, like: + + WITH some_cte AS (...) + UPDATE some_other_table + SET (... SELECT * FROM some_cte) + + Such queries are not supported by Spanner. 
+ """ + pass + + @pytest.mark.skip("WITH RECURSIVE subqueries are not supported") + def test_select_recursive_round_trip(self): + pass + + +class DateTimeMicrosecondsTest(_DateTimeMicrosecondsTest, DateTest): + @pytest.mark.skip("Spanner dates are time zone independent") + def test_select_direct(self): + pass + + def test_round_trip(self): + """ + SPANNER OVERRIDE: + + Spanner converts timestamp into `%Y-%m-%dT%H:%M:%S.%fZ` format, so to avoid + assert failures convert datetime input to the desire timestamp format. + """ + date_table = self.tables.date_table + + with config.db.connect() as connection: + connection.execute(date_table.insert(), {"date_data": self.data, "id": 250}) + row = connection.execute(select(date_table.c.date_data)).first() + + compare = self.compare or self.data + compare = compare.strftime("%Y-%m-%dT%H:%M:%S.%fZ") + eq_(row[0].rfc3339(), compare) + assert isinstance(row[0], DatetimeWithNanoseconds) + + def test_round_trip_decorated(self, connection): + """ + SPANNER OVERRIDE: + + Spanner converts timestamp into `%Y-%m-%dT%H:%M:%S.%fZ` format, so to avoid + assert failures convert datetime input to the desire timestamp format. + """ + date_table = self.tables.date_table + + connection.execute( + date_table.insert(), {"id": 1, "decorated_date_data": self.data} + ) + + row = connection.execute(select(date_table.c.decorated_date_data)).first() + + compare = self.compare or self.data + compare = compare.strftime("%Y-%m-%dT%H:%M:%S.%fZ") + eq_(row[0].rfc3339(), compare) + assert isinstance(row[0], DatetimeWithNanoseconds) + + @pytest.mark.skipif( + bool(os.environ.get("SPANNER_EMULATOR_HOST")), reason="Skipped on emulator" + ) + def test_null_bound_comparison(self): + super().test_null_bound_comparison() + + @pytest.mark.skipif( + bool(os.environ.get("SPANNER_EMULATOR_HOST")), reason="Skipped on emulator" + ) + def test_null(self, connection): + super().test_null(connection) + + +class DateTimeTest(_DateTimeTest, DateTimeMicrosecondsTest): + """ + SPANNER OVERRIDE: + + DateTimeTest tests have the same failures same as DateTimeMicrosecondsTest tests, + so to avoid those failures and maintain DRY concept just inherit the class to run + tests successfully. 
+ """ + + @pytest.mark.skipif( + bool(os.environ.get("SPANNER_EMULATOR_HOST")), reason="Skipped on emulator" + ) + def test_null_bound_comparison(self): + super().test_null_bound_comparison() + + @pytest.mark.skipif( + bool(os.environ.get("SPANNER_EMULATOR_HOST")), reason="Skipped on emulator" + ) + def test_null(self, connection): + super().test_null(connection) + + @pytest.mark.skip("Spanner dates are time zone independent") + def test_select_direct(self): + pass + + +@pytest.mark.skip("Not supported by Spanner") +class DifficultParametersTest(_DifficultParametersTest): + pass + + +class FetchLimitOffsetTest(_FetchLimitOffsetTest): + @pytest.mark.skip("Spanner doesn't support composite LIMIT and OFFSET clauses") + def test_expr_limit(self, connection): + pass + + @pytest.mark.skip("Spanner doesn't support composite LIMIT and OFFSET clauses") + def test_expr_offset(self, connection): + pass + + @pytest.mark.skip("Spanner doesn't support composite LIMIT and OFFSET clauses") + def test_expr_limit_offset(self, connection): + pass + + @pytest.mark.skip("Spanner doesn't support composite LIMIT and OFFSET clauses") + def test_expr_limit_simple_offset(self, connection): + pass + + @pytest.mark.skip("Spanner doesn't support composite LIMIT and OFFSET clauses") + def test_simple_limit_expr_offset(self, connection): + pass + + @pytest.mark.skip("Spanner doesn't support composite LIMIT and OFFSET clauses") + def test_bound_offset(self, connection): + pass + + @pytest.mark.skipif( + bool(os.environ.get("SPANNER_EMULATOR_HOST")), reason="Skipped on emulator" + ) + def test_limit_render_multiple_times(self, connection): + table = self.tables.some_table + stmt = select(table.c.id).limit(1).scalar_subquery() + + u = union(select(stmt), select(stmt)).subquery().select() + + self._assert_result( + connection, + u, + [(1,)], + ) + + @testing.requires.offset + def test_simple_offset(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table).order_by(table.c.id).offset(2), + [(3, 3, 4), (4, 4, 5), (5, 4, 6)], + ) + + +@pytest.mark.skip("Spanner doesn't support autoincrement") +class IdentityAutoincrementTest(_IdentityAutoincrementTest): + pass + + +@pytest.mark.skip("Spanner doesn't support returning") +class ReturningGuardsTest(_ReturningGuardsTest): + pass + + +@pytest.mark.skip("Spanner doesn't support user made schemas") +class SameNamedSchemaTableTest(_SameNamedSchemaTableTest): + pass + + +class EscapingTest(_EscapingTest): + @provide_metadata + def test_percent_sign_round_trip(self): + """Test that the DBAPI accommodates for escaped / nonescaped + percent signs in a way that matches the compiler + + SPANNER OVERRIDE + Cloud Spanner supports tables with empty primary key, but + only single one row can be inserted into such a table - + following insertions will fail with `Row [] already exists". + Overriding the test to avoid the same failure. 
+ """ + m = self.metadata + t = Table("t", m, Column("data", String(50))) + t.create(config.db) + with config.db.begin() as conn: + conn.execute(t.insert(), dict(data="some % value")) + + eq_( + conn.scalar( + select(t.c.data).where(t.c.data == literal_column("'some % value'")) + ), + "some % value", + ) + + conn.execute(t.delete()) + conn.execute(t.insert(), dict(data="some %% other value")) + eq_( + conn.scalar( + select(t.c.data).where( + t.c.data == literal_column("'some %% other value'") + ) + ), + "some %% other value", + ) + + +class ExistsTest(_ExistsTest): + def test_select_exists(self, connection): + """ + SPANNER OVERRIDE: + + The original test is trying to execute a query like: + + SELECT ... + WHERE EXISTS (SELECT ...) + + SELECT WHERE without FROM clause is not supported by Spanner. + Rewriting the test to force it to generate a query like: + + SELECT EXISTS (SELECT ...) + """ + stuff = self.tables.stuff + eq_( + connection.execute( + select(exists().where(stuff.c.data == "some data")) + ).fetchall(), + [(True,)], + ) + + def test_select_exists_false(self, connection): + """ + SPANNER OVERRIDE: + + The original test is trying to execute a query like: + + SELECT ... + WHERE EXISTS (SELECT ...) + + SELECT WHERE without FROM clause is not supported by Spanner. + Rewriting the test to force it to generate a query like: + + SELECT EXISTS (SELECT ...) + """ + stuff = self.tables.stuff + eq_( + connection.execute( + select(exists().where(stuff.c.data == "no data")) + ).fetchall(), + [(False,)], + ) + + +class TableDDLTest(_TableDDLTest): + @pytest.mark.skip( + "Spanner table name must start with an uppercase or lowercase letter" + ) + def test_underscore_names(self): + pass + + @pytest.mark.skip("Table names incuding schemas are not supported by Spanner") + def test_create_table_schema(self): + pass + + +class FutureTableDDLTest(_FutureTableDDLTest): + @pytest.mark.skip("Table names incuding schemas are not supported by Spanner") + def test_create_table_schema(self): + pass + + @pytest.mark.skip( + "Spanner table name must start with an uppercase or lowercase letter" + ) + def test_underscore_names(self): + pass + + +@pytest.mark.skip("Max identifier length in Spanner is 128") +class LongNameBlowoutTest(_LongNameBlowoutTest): + pass + + +@pytest.mark.skip("Spanner doesn't support Time data type.") +class TimeTests(_TimeMicrosecondsTest, _TimeTest): + pass + + +@pytest.mark.skip("Spanner doesn't coerce dates from datetime.") +class DateTimeCoercedToDateTimeTest(_DateTimeCoercedToDateTimeTest): + pass + + +class IntegerTest(_IntegerTest): + @provide_metadata + def _round_trip(self, datatype, data): + """ + SPANNER OVERRIDE: + + This is the helper method for integer class tests which creates a table and + performs an insert operation. + Cloud Spanner supports tables with an empty primary key, but only one + row can be inserted into such a table - following insertions will fail with + `400 id must not be NULL in table date_table`. + Overriding the tests and adding a manual primary key value to avoid the same + failures. 
+ """ + metadata = self.metadata + int_table = Table( + "integer_table", + metadata, + Column("id", Integer, primary_key=True, test_needs_autoincrement=True), + Column("integer_data", datatype), + ) + + metadata.create_all(config.db) + + config.db.execute(int_table.insert(), {"id": 1, "integer_data": data}) + + row = config.db.execute(select(int_table.c.integer_data)).first() + + eq_(row, (data,)) + + if util.py3k: + assert isinstance(row[0], int) + else: + assert isinstance(row[0], (long, int)) # noqa + + def _huge_ints(): + + return testing.combinations( + 2147483649, # 32 bits + 2147483648, # 32 bits + 2147483647, # 31 bits + 2147483646, # 31 bits + -2147483649, # 32 bits + -2147483648, # 32 interestingly, asyncpg accepts this one as int32 + -2147483647, # 31 + -2147483646, # 31 + 0, + 1376537018368127, + -1376537018368127, + argnames="intvalue", + ) + + @_huge_ints() + def test_huge_int_auto_accommodation(self, connection, intvalue): + """ + Spanner does not allow query to have FROM clause without a WHERE clause + """ + eq_( + connection.scalar(select(intvalue)), + intvalue, + ) + + +class _UnicodeFixture(__UnicodeFixture): + @classmethod + def define_tables(cls, metadata): + """ + SPANNER OVERRIDE: + + Cloud Spanner doesn't support auto incrementing ids feature, + which is used by the original test. Overriding the test data + creation method to disable autoincrement and make id column + nullable. + """ + Table( + "unicode_table", + metadata, + Column("id", Integer, primary_key=True, nullable=True), + Column("unicode_data", cls.datatype), + ) + + def test_round_trip_executemany(self, connection): + """ + SPANNER OVERRIDE + + Cloud Spanner supports tables with empty primary key, but + only single one row can be inserted into such a table - + following insertions will fail with `Row [] already exists". + Overriding the test to avoid the same failure. + """ + unicode_table = self.tables.unicode_table + + connection.execute( + unicode_table.insert(), + [{"id": i, "unicode_data": self.data} for i in range(1, 4)], + ) + + rows = connection.execute(select(unicode_table.c.unicode_data)).fetchall() + eq_(rows, [(self.data,) for i in range(1, 4)]) + for row in rows: + assert isinstance(row[0], str) + + @pytest.mark.skip("Spanner doesn't support non-ascii characters") + def test_literal(self): + pass + + @pytest.mark.skip("Spanner doesn't support non-ascii characters") + def test_literal_non_ascii(self): + pass + + +class UnicodeVarcharTest(_UnicodeFixture, _UnicodeVarcharTest): + """ + SPANNER OVERRIDE: + + UnicodeVarcharTest class inherits the __UnicodeFixture class's tests, + so to avoid those failures and maintain DRY concept just inherit the class to run + tests successfully. + """ + + pass + + +class UnicodeTextTest(_UnicodeFixture, _UnicodeTextTest): + """ + SPANNER OVERRIDE: + + UnicodeTextTest class inherits the __UnicodeFixture class's tests, + so to avoid those failures and maintain DRY concept just inherit the class to run + tests successfully. + """ + + pass + + +class RowFetchTest(_RowFetchTest): + def test_row_w_scalar_select(self, connection): + """ + SPANNER OVERRIDE: + + Cloud Spanner returns a DatetimeWithNanoseconds() for date + data types. Overriding the test to use a DatetimeWithNanoseconds + type value as an expected result. + -------------- + + test that a scalar select as a column is returned as such + and that type conversion works OK. + + (this is half a SQLAlchemy Core test and half to catch database + backends that may have unusual behavior with scalar selects.) 
+ """ + datetable = self.tables.has_dates + s = select(datetable.alias("x").c.today).scalar_subquery() + s2 = select(datetable.c.id, s.label("somelabel")) + row = connection.execute(s2).first() + + eq_( + row.somelabel, + DatetimeWithNanoseconds(2006, 5, 12, 12, 0, 0, tzinfo=timezone.utc), + ) + + +class InsertBehaviorTest(_InsertBehaviorTest): + @pytest.mark.skip("Spanner doesn't support empty inserts") + def test_empty_insert(self): + pass + + @pytest.mark.skip("Spanner doesn't support empty inserts") + def test_empty_insert_multiple(self): + pass + + @pytest.mark.skip("Spanner doesn't support auto increment") + def test_insert_from_select_autoinc(self): + pass + + def test_autoclose_on_insert(self): + """ + SPANNER OVERRIDE: + + Cloud Spanner doesn't support tables with an auto increment primary key, + following insertions will fail with `400 id must not be NULL in table + autoinc_pk`. + + Overriding the tests and adding a manual primary key value to avoid the same + failures. + """ + if ( + hasattr(config.requirements, "returning") + and config.requirements.returning.enabled + ): + engine = engines.testing_engine(options={"implicit_returning": False}) + else: + engine = config.db + + with engine.begin() as conn: + r = conn.execute( + self.tables.autoinc_pk.insert(), dict(id=1, data="some data") + ) + + assert r._soft_closed + assert not r.closed + assert r.is_insert + assert not r.returns_rows + + +class BytesTest(_LiteralRoundTripFixture, fixtures.TestBase): + __backend__ = True + + def test_nolength_binary(self): + metadata = MetaData() + foo = Table("foo", metadata, Column("one", LargeBinary)) + + foo.create(config.db) + foo.drop(config.db) + + +class StringTest(_StringTest): + @pytest.mark.skip("Spanner doesn't support non-ascii characters") + def test_literal_non_ascii(self): + pass + + def test_dont_truncate_rightside( + self, metadata, connection, expr=None, expected=None + ): + t = Table( + "t", + metadata, + Column("x", String(2)), + Column("id", Integer, primary_key=True), + ) + t.create(connection) + connection.connection.commit() + connection.execute( + t.insert(), + [{"x": "AB", "id": 1}, {"x": "BC", "id": 2}, {"x": "AC", "id": 3}], + ) + combinations = [("%B%", ["AB", "BC"]), ("A%C", ["AC"]), ("A%C%Z", [])] + + for args in combinations: + eq_( + connection.scalars(select(t.c.x).where(t.c.x.like(args[0]))).all(), + args[1], + ) + + +class TextTest(_TextTest): + @classmethod + def define_tables(cls, metadata): + """ + SPANNER OVERRIDE: + + Cloud Spanner doesn't support auto incrementing ids feature, + which is used by the original test. Overriding the test data + creation method to disable autoincrement and make id column + nullable. 
+ """ + Table( + "text_table", + metadata, + Column("id", Integer, primary_key=True, nullable=True), + Column("text_data", Text), + ) + + @pytest.mark.skip("Spanner doesn't support non-ascii characters") + def test_literal_non_ascii(self): + pass + + @pytest.mark.skip("Not supported by Spanner") + def test_text_roundtrip(self, connection): + pass + + @pytest.mark.skip("Not supported by Spanner") + def test_text_empty_strings(self, connection): + pass + + @pytest.mark.skip("Not supported by Spanner") + def test_text_null_strings(self, connection): + pass + + +class NumericTest(_NumericTest): + @testing.fixture + def do_numeric_test(self, metadata, connection): + @testing.emits_warning(r".*does \*not\* support Decimal objects natively") + def run(type_, input_, output, filter_=None, check_scale=False): + t = Table( + "t", + metadata, + Column("x", type_), + Column("id", Integer, primary_key=True), + ) + t.create(connection) + connection.connection.commit() + connection.execute( + t.insert(), [{"x": x, "id": i} for i, x in enumerate(input_)] + ) + + result = {row[0] for row in connection.execute(t.select())} + output = set(output) + if filter_: + result = set(filter_(x) for x in result) + output = set(filter_(x) for x in output) + eq_(result, output) + if check_scale: + eq_([str(x) for x in result], [str(x) for x in output]) + + return run + + @emits_warning(r".*does \*not\* support Decimal objects natively") + def test_render_literal_numeric(self, literal_round_trip): + """ + SPANNER OVERRIDE: + + Cloud Spanner supports tables with an empty primary key, but + only a single row can be inserted into such a table - + following insertions will fail with `Row [] already exists". + Overriding the test to avoid the same failure. + """ + literal_round_trip( + Numeric(precision=8, scale=4), + [decimal.Decimal("15.7563")], + [decimal.Decimal("15.7563")], + ) + + @emits_warning(r".*does \*not\* support Decimal objects natively") + def test_render_literal_numeric_asfloat(self, literal_round_trip): + """ + SPANNER OVERRIDE: + + Cloud Spanner supports tables with an empty primary key, but + only a single row can be inserted into such a table - + following insertions will fail with `Row [] already exists". + Overriding the test to avoid the same failure. + """ + literal_round_trip( + Numeric(precision=8, scale=4, asdecimal=False), + [decimal.Decimal("15.7563")], + [15.7563], + ) + + def test_render_literal_float(self, literal_round_trip): + """ + SPANNER OVERRIDE: + + Cloud Spanner supports tables with an empty primary key, but + only a single row can be inserted into such a table - + following insertions will fail with `Row [] already exists". + Overriding the test to avoid the same failure. + """ + literal_round_trip( + Float(4), + [decimal.Decimal("15.7563")], + [15.7563], + filter_=lambda n: n is not None and round(n, 5) or None, + ) + + @requires.precision_generic_float_type + def test_float_custom_scale(self, do_numeric_test): + """ + SPANNER OVERRIDE: + + Cloud Spanner supports tables with an empty primary key, but + only a single row can be inserted into such a table - + following insertions will fail with `Row [] already exists". + Overriding the test to avoid the same failure. 
+ """ + do_numeric_test( + Float(None, decimal_return_scale=7, asdecimal=True), + [decimal.Decimal("15.7563827"), decimal.Decimal("15.7563827")], + [decimal.Decimal("15.7563827")], + check_scale=True, + ) + + def test_numeric_as_decimal(self, do_numeric_test): + """ + SPANNER OVERRIDE: + + Spanner throws an error 400 Value has type FLOAT64 which cannot be + inserted into column x, which has type NUMERIC for value 15.7563. + Overriding the test to remove the failure case. + """ + do_numeric_test( + Numeric(precision=8, scale=4), + [decimal.Decimal("15.7563"), decimal.Decimal("15.7563")], + [decimal.Decimal("15.7563")], + ) + + def test_numeric_as_float(self, do_numeric_test): + """ + SPANNER OVERRIDE: + + Spanner throws an error 400 Value has type FLOAT64 which cannot be + inserted into column x, which has type NUMERIC for value 15.7563. + Overriding the test to remove the failure case. + """ + do_numeric_test( + Numeric(precision=8, scale=4, asdecimal=False), + [decimal.Decimal("15.7563"), decimal.Decimal("15.7563")], + [15.7563], + ) + + @requires.floats_to_four_decimals + def test_float_as_decimal(self, do_numeric_test): + """ + SPANNER OVERRIDE: + + Cloud Spanner supports tables with an empty primary key, but + only a single row can be inserted into such a table - + following insertions will fail with `Row [] already exists". + Overriding the test to avoid the same failure. + """ + do_numeric_test( + Float(precision=8, asdecimal=True), + [decimal.Decimal("15.7563"), decimal.Decimal("15.7563"), None], + [decimal.Decimal("15.7563"), None], + filter_=lambda n: n is not None and round(n, 4) or None, + ) + + def test_float_as_float(self, do_numeric_test): + """ + SPANNER OVERRIDE: + + Cloud Spanner supports tables with an empty primary key, but + only a single row can be inserted into such a table - + following insertions will fail with `Row [] already exists". + Overriding the test to avoid the same failure. + """ + do_numeric_test( + Float(precision=8), + [decimal.Decimal("15.7563"), decimal.Decimal("15.7563")], + [15.7563], + filter_=lambda n: n is not None and round(n, 5) or None, + ) + + @requires.precision_numerics_general + def test_precision_decimal(self, do_numeric_test): + """ + SPANNER OVERRIDE: + + Cloud Spanner supports tables with an empty primary key, but + only a single row can be inserted into such a table - + following insertions will fail with `Row [] already exists". + Overriding the test to avoid the same failure. + + Remove an extra digits after decimal point as cloud spanner is + capable of representing an exact numeric value with a precision + of 38 and scale of 9. + """ + numbers = set( + [ + decimal.Decimal("54.246451650"), + decimal.Decimal("0.004354"), + decimal.Decimal("900.0"), + ] + ) + do_numeric_test(Numeric(precision=18, scale=9), numbers, numbers) + + @testing.requires.precision_numerics_enotation_large + def test_enotation_decimal_large(self, do_numeric_test): + """test exceedingly large decimals. + + SPANNER OVERRIDE: + + Cloud Spanner supports tables with an empty primary key, but + only a single row can be inserted into such a table - + following insertions will fail with `Row [] already exists". + Overriding the test to avoid the same failure. 
+ """ + numbers = set( + [ + decimal.Decimal("4E+8"), + decimal.Decimal("5748E+15"), + decimal.Decimal("1.521E+15"), + decimal.Decimal("000000000.1E+9"), + ] + ) + do_numeric_test(Numeric(precision=25, scale=2), numbers, numbers) + + @testing.requires.precision_numerics_enotation_large + def test_enotation_decimal(self, do_numeric_test): + """test exceedingly small decimals. + + Decimal reports values with E notation when the exponent + is greater than 6. + + SPANNER OVERRIDE: + + Remove extra digits after decimal point as Cloud Spanner is + capable of representing an exact numeric value with a precision + of 38 and scale of 9. + """ + numbers = set( + [ + decimal.Decimal("1E-2"), + decimal.Decimal("1E-3"), + decimal.Decimal("1E-4"), + decimal.Decimal("1E-5"), + decimal.Decimal("1E-6"), + decimal.Decimal("1E-7"), + decimal.Decimal("1E-8"), + decimal.Decimal("0.105940696"), + decimal.Decimal("0.005940696"), + decimal.Decimal("0.000000696"), + decimal.Decimal("0.700000696"), + decimal.Decimal("696E-9"), + ] + ) + do_numeric_test(Numeric(precision=38, scale=9), numbers, numbers) + + +class LikeFunctionsTest(_LikeFunctionsTest): + @pytest.mark.skip("Spanner doesn't support LIKE ESCAPE clause") + def test_contains_autoescape(self): + pass + + @pytest.mark.skip("Spanner doesn't support LIKE ESCAPE clause") + def test_contains_autoescape_escape(self): + pass + + @pytest.mark.skip("Spanner doesn't support LIKE ESCAPE clause") + def test_contains_escape(self): + pass + + @pytest.mark.skip("Spanner doesn't support LIKE ESCAPE clause") + def test_endswith_autoescape(self): + pass + + @pytest.mark.skip("Spanner doesn't support LIKE ESCAPE clause") + def test_endswith_escape(self): + pass + + @pytest.mark.skip("Spanner doesn't support LIKE ESCAPE clause") + def test_endswith_autoescape_escape(self): + pass + + @pytest.mark.skip("Spanner doesn't support LIKE ESCAPE clause") + def test_startswith_autoescape(self): + pass + + @pytest.mark.skip("Spanner doesn't support LIKE ESCAPE clause") + def test_startswith_escape(self): + pass + + @pytest.mark.skip("Spanner doesn't support LIKE ESCAPE clause") + def test_startswith_autoescape_escape(self): + pass + + def test_escape_keyword_raises(self): + """Check that ESCAPE keyword causes an exception when used.""" + with pytest.raises(NotImplementedError): + col = self.tables.some_table.c.data + self._test(col.contains("b##cde", escape="#"), {7}) + + +@pytest.mark.skip("Spanner doesn't support IS DISTINCT FROM clause") +class IsOrIsNotDistinctFromTest(_IsOrIsNotDistinctFromTest): + pass + + +class OrderByLabelTest(_OrderByLabelTest): + @pytest.mark.skip( + "Spanner requires an alias for the GROUP BY list when specifying derived " + "columns also used in SELECT" + ) + def test_group_by_composed(self): + pass + + +class CompoundSelectTest(_CompoundSelectTest): + """ + See: https://github.com/googleapis/python-spanner/issues/347 + """ + + @pytest.mark.skip( + "Spanner DBAPI incorrectly classify the statement starting with brackets." + ) + def test_limit_offset_selectable_in_unions(self): + pass + + @pytest.mark.skip( + "Spanner DBAPI incorrectly classify the statement starting with brackets." + ) + def test_order_by_selectable_in_unions(self): + pass + + +class TestQueryHints(fixtures.TablesTest): + """ + Compile a complex query with JOIN and check that + the table hint was set into the right place. 
+ """ + + __backend__ = True + + def test_complex_query_table_hints(self): + EXPECTED_QUERY = ( + "SELECT users.id, users.name \nFROM users @{FORCE_INDEX=table_1_by_int_idx}" + " JOIN addresses ON users.id = addresses.user_id " + "\nWHERE users.name IN (__[POSTCOMPILE_name_1])" + ) + + Base = declarative_base() + engine = create_engine( + "spanner:///projects/project-id/instances/instance-id/databases/database-id" + ) + + class User(Base): + __tablename__ = "users" + id = Column(Integer, primary_key=True) + name = Column(String(50)) + addresses = relationship("Address", backref="user") + + class Address(Base): + __tablename__ = "addresses" + id = Column(Integer, primary_key=True) + email = Column(String(50)) + user_id = Column(Integer, ForeignKey("users.id")) + + session = Session(engine) + + query = session.query(User) + query = query.with_hint( + selectable=User, text="@{FORCE_INDEX=table_1_by_int_idx}" + ) + + query = query.filter(User.name.in_(["val1", "val2"])) + query = query.join(Address) + + assert str(query.statement.compile(session.bind)) == EXPECTED_QUERY + + +class InterleavedTablesTest(fixtures.TestBase): + """ + Check that CREATE TABLE statements for interleaved tables are correctly + generated. + """ + + def setUp(self): + self._engine = create_engine( + "spanner:///projects/appdev-soda-spanner-staging/instances/" + "sqlalchemy-dialect-test/databases/compliance-test" + ) + self._metadata = MetaData() + + def test_interleave(self): + EXP_QUERY = ( + "\nCREATE TABLE client (\n\tteam_id INT64 NOT NULL, " + "\n\tclient_id INT64 NOT NULL, " + "\n\tclient_name STRING(16) NOT NULL" + "\n) PRIMARY KEY (team_id, client_id)," + "\nINTERLEAVE IN PARENT team\n\n" + ) + client = Table( + "client", + self._metadata, + Column("team_id", Integer, primary_key=True), + Column("client_id", Integer, primary_key=True), + Column("client_name", String(16), nullable=False), + spanner_interleave_in="team", + ) + with mock.patch("google.cloud.spanner_dbapi.cursor.Cursor.execute") as execute: + client.create(self._engine) + execute.assert_called_once_with(EXP_QUERY, []) + + def test_interleave_on_delete_cascade(self): + EXP_QUERY = ( + "\nCREATE TABLE client (\n\tteam_id INT64 NOT NULL, " + "\n\tclient_id INT64 NOT NULL, " + "\n\tclient_name STRING(16) NOT NULL" + "\n) PRIMARY KEY (team_id, client_id)," + "\nINTERLEAVE IN PARENT team ON DELETE CASCADE\n\n" + ) + client = Table( + "client", + self._metadata, + Column("team_id", Integer, primary_key=True), + Column("client_id", Integer, primary_key=True), + Column("client_name", String(16), nullable=False), + spanner_interleave_in="team", + spanner_interleave_on_delete_cascade=True, + ) + with mock.patch("google.cloud.spanner_dbapi.cursor.Cursor.execute") as execute: + client.create(self._engine) + execute.assert_called_once_with(EXP_QUERY, []) + + +class UserAgentTest(fixtures.TestBase): + """Check that SQLAlchemy dialect uses correct user agent.""" + + def setUp(self): + self._engine = create_engine( + "spanner:///projects/appdev-soda-spanner-staging/instances/" + "sqlalchemy-dialect-test/databases/compliance-test" + ) + self._metadata = MetaData() + + def test_user_agent(self): + dist = pkg_resources.get_distribution("sqlalchemy-spanner") + + with self._engine.connect() as connection: + assert ( + connection.connection.instance._client._client_info.user_agent + == "gl-" + dist.project_name + "/" + dist.version + ) + + +class SimpleUpdateDeleteTest(_SimpleUpdateDeleteTest): + """ + SPANNER OVERRIDE: + + Spanner doesn't support `rowcount` property. 
These test case overrides omit `rowcount` checks.
+    """
+
+    def test_delete(self, connection):
+        t = self.tables.plain_pk
+        r = connection.execute(t.delete().where(t.c.id == 2))
+        assert not r.is_insert
+        assert not r.returns_rows
+        eq_(
+            connection.execute(t.select().order_by(t.c.id)).fetchall(),
+            [(1, "d1"), (3, "d3")],
+        )
+
+    def test_update(self, connection):
+        t = self.tables.plain_pk
+        r = connection.execute(t.update().where(t.c.id == 2), dict(data="d2_new"))
+        assert not r.is_insert
+        assert not r.returns_rows
+
+        eq_(
+            connection.execute(t.select().order_by(t.c.id)).fetchall(),
+            [(1, "d1"), (2, "d2_new"), (3, "d3")],
+        )
+
+
+class HasIndexTest(_HasIndexTest):
+    __backend__ = True
+    kind = testing.combinations("dialect", "inspector", argnames="kind")
+
+    @classmethod
+    def define_tables(cls, metadata):
+        tt = Table(
+            "test_table",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("data", String(50)),
+            Column("data2", String(50)),
+        )
+        sqlalchemy.Index("my_idx", tt.c.data)
+
+    @kind
+    def test_has_index(self, kind, connection, metadata):
+        meth = self._has_index(kind, connection)
+        assert meth("test_table", "my_idx")
+        assert not meth("test_table", "my_idx_s")
+        assert not meth("nonexistent_table", "my_idx")
+        assert not meth("test_table", "nonexistent_idx")
+
+        assert not meth("test_table", "my_idx_2")
+        assert not meth("test_table_2", "my_idx_3")
+        idx = Index("my_idx_2", self.tables.test_table.c.data2)
+        tbl = Table(
+            "test_table_2",
+            metadata,
+            Column("foo", Integer, primary_key=True),
+            Index("my_idx_3", "foo"),
+        )
+        idx.create(connection)
+        tbl.create(connection)
+
+        try:
+            if kind == "inspector":
+                assert not meth("test_table", "my_idx_2")
+                assert not meth("test_table_2", "my_idx_3")
+                meth.__self__.clear_cache()
+            connection.connection.commit()
+            assert meth("test_table", "my_idx_2") is True
+            assert meth("test_table_2", "my_idx_3") is True
+        finally:
+            tbl.drop(connection)
+            idx.drop(connection)
+            connection.connection.commit()
+            self.tables["test_table"].indexes.remove(idx)
+
+    @pytest.mark.skip("Not supported by Cloud Spanner")
+    @kind
+    def test_has_index_schema(self, kind, connection):
+        pass
+
+
+class HasTableTest(_HasTableTest):
+    @classmethod
+    def define_tables(cls, metadata):
+        Table(
+            "test_table",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("data", String(50)),
+        )
+
+    @pytest.mark.skip("Not supported by Cloud Spanner")
+    def test_has_table_nonexistent_schema(self):
+        pass
+
+    @pytest.mark.skip("Not supported by Cloud Spanner")
+    def test_has_table_schema(self):
+        pass
+
+    @pytest.mark.skip("Not supported by Cloud Spanner")
+    def test_has_table_cache(self):
+        pass
+
+    @testing.requires.views
+    def test_has_table_view(self, connection):
+        pass
+
+    @testing.requires.views
+    def test_has_table_view_schema(self, connection):
+        pass
+
+
+class PostCompileParamsTest(_PostCompileParamsTest):
+    def test_execute(self):
+        table = self.tables.some_table
+
+        stmt = select(table.c.id).where(
+            table.c.x == sqlalchemy.bindparam("q", literal_execute=True)
+        )
+
+        with self.sql_execution_asserter() as asserter:
+            with config.db.connect() as conn:
+                conn.execute(stmt, dict(q=10))
+
+        asserter.assert_(
+            sqlalchemy.testing.assertsql.CursorSQL(
+                "SELECT some_table.id \nFROM some_table " "\nWHERE some_table.x = 10",
+                [] if config.db.dialect.positional else {},
+            )
+        )
+
+    def test_execute_expanding_plus_literal_execute(self):
+        table = self.tables.some_table
+
+        stmt = select(table.c.id).where(
+            table.c.x.in_(
sqlalchemy.bindparam("q", expanding=True, literal_execute=True) + ) + ) + + with self.sql_execution_asserter() as asserter: + with config.db.connect() as conn: + conn.execute(stmt, dict(q=[5, 6, 7])) + + asserter.assert_( + sqlalchemy.testing.assertsql.CursorSQL( + "SELECT some_table.id \nFROM some_table " + "\nWHERE some_table.x IN (5, 6, 7)", + [] if config.db.dialect.positional else {}, + ) + ) + + @testing.requires.tuple_in + def test_execute_tuple_expanding_plus_literal_execute(self): + table = self.tables.some_table + + stmt = select(table.c.id).where( + sqlalchemy.tuple_(table.c.x, table.c.y).in_( + sqlalchemy.bindparam("q", expanding=True, literal_execute=True) + ) + ) + + with self.sql_execution_asserter() as asserter: + with config.db.connect() as conn: + conn.execute(stmt, dict(q=[(5, 10), (12, 18)])) + + asserter.assert_( + sqlalchemy.testing.assertsql.CursorSQL( + "SELECT some_table.id \nFROM some_table " + "\nWHERE (some_table.x, some_table.y) " + "IN (%s(5, 10), (12, 18))" + % ("VALUES " if config.db.dialect.tuple_in_values else ""), + () if config.db.dialect.positional else {}, + ) + ) + + @testing.requires.tuple_in + def test_execute_tuple_expanding_plus_literal_heterogeneous_execute(self): + table = self.tables.some_table + + stmt = select(table.c.id).where( + sqlalchemy.tuple_(table.c.x, table.c.z).in_( + sqlalchemy.bindparam("q", expanding=True, literal_execute=True) + ) + ) + + with self.sql_execution_asserter() as asserter: + with config.db.connect() as conn: + conn.execute(stmt, dict(q=[(5, "z1"), (12, "z3")])) + + asserter.assert_( + sqlalchemy.testing.assertsql.CursorSQL( + "SELECT some_table.id \nFROM some_table " + "\nWHERE (some_table.x, some_table.z) " + "IN (%s(5, 'z1'), (12, 'z3'))" + % ("VALUES " if config.db.dialect.tuple_in_values else ""), + () if config.db.dialect.positional else {}, + ) + ) + + +@pytest.mark.skipif( + bool(os.environ.get("SPANNER_EMULATOR_HOST")), reason="Skipped on emulator" +) +class JSONTest(_JSONTest): + @pytest.mark.skip("Values without keys are not supported.") + def test_single_element_round_trip(self, element): + pass + + def _test_round_trip(self, data_element, connection): + data_table = self.tables.data_table + + connection.execute( + data_table.insert(), + {"id": random.randint(1, 100000000), "name": "row1", "data": data_element}, + ) + + row = connection.execute(select(data_table.c.data)).first() + + eq_(row, (data_element,)) + + def test_unicode_round_trip(self): + # note we include Unicode supplementary characters as well + with config.db.connect() as conn: + conn.execute( + self.tables.data_table.insert(), + { + "id": random.randint(1, 100000000), + "name": "r1", + "data": { + "réve🐍 illé": "réve🐍 illé", + "data": {"k1": "drôl🐍e"}, + }, + }, + ) + + eq_( + conn.scalar(select(self.tables.data_table.c.data)), + { + "réve🐍 illé": "réve🐍 illé", + "data": {"k1": "drôl🐍e"}, + }, + ) + + @pytest.mark.skip("Parameterized types are not supported.") + def test_eval_none_flag_orm(self): + pass + + @pytest.mark.skip( + "Spanner JSON_VALUE() always returns STRING," + "thus, this test case can't be executed." + ) + def test_index_typed_comparison(self): + pass + + @pytest.mark.skip( + "Spanner JSON_VALUE() always returns STRING," + "thus, this test case can't be executed." 
+
+    @pytest.mark.skip(
+        "Spanner JSON_VALUE() always returns STRING, "
+        "thus, this test case can't be executed."
+    )
+    def test_index_typed_comparison(self):
+        pass
+
+    @pytest.mark.skip(
+        "Spanner JSON_VALUE() always returns STRING, "
+        "thus, this test case can't be executed."
+    )
+    def test_path_typed_comparison(self):
+        pass
+
+    @pytest.mark.skip("Custom JSON de-/serializers are not supported.")
+    def test_round_trip_custom_json(self):
+        pass
+
+    def _index_fixtures(fn):
+        fn = testing.combinations(
+            ("boolean", True),
+            ("boolean", False),
+            ("boolean", None),
+            ("string", "some string"),
+            ("string", None),
+            ("integer", 15),
+            ("integer", 1),
+            ("integer", 0),
+            ("integer", None),
+            ("float", 28.5),
+            ("float", None),
+            id_="sa",
+        )(fn)
+        return fn
+
+    @_index_fixtures
+    def test_index_typed_access(self, datatype, value):
+        data_table = self.tables.data_table
+        data_element = {"key1": value}
+        with config.db.connect() as conn:
+            conn.execute(
+                data_table.insert(),
+                {
+                    "id": random.randint(1, 100000000),
+                    "name": "row1",
+                    "data": data_element,
+                    "nulldata": data_element,
+                },
+            )
+
+            expr = data_table.c.data["key1"]
+            expr = getattr(expr, "as_%s" % datatype)()
+
+            roundtrip = conn.scalar(select(expr))
+            # JSON_VALUE() returns lowercase "true"/"false" strings (or None),
+            # so normalize them to match Python's bool representations.
+            if roundtrip in ("true", "false", None):
+                roundtrip = str(roundtrip).capitalize()
+
+            eq_(str(roundtrip), str(value))
+
+    @pytest.mark.skip(
+        "Spanner doesn't support type casts inside JSON_VALUE() function."
+    )
+    def test_round_trip_json_null_as_json_null(self):
+        pass
+
+    @pytest.mark.skip(
+        "Spanner doesn't support type casts inside JSON_VALUE() function."
+    )
+    def test_round_trip_none_as_json_null(self):
+        pass
+
+    @pytest.mark.skip(
+        "Spanner doesn't support type casts inside JSON_VALUE() function."
+    )
+    def test_round_trip_none_as_sql_null(self):
+        pass
+
+
+class ExecutionOptionsRequestPriorityTest(fixtures.TestBase):
+    def setUp(self):
+        self._engine = create_engine(get_db_url(), pool_size=1)
+        metadata = MetaData()
+
+        self._table = Table(
+            "execution_options2",
+            metadata,
+            Column("opt_id", Integer, primary_key=True),
+            Column("opt_name", String(16), nullable=False),
+        )
+
+        metadata.create_all(self._engine)
+        # Give the backend a moment to finish applying the DDL.
+        time.sleep(1)
+
+    def test_request_priority(self):
+        PRIORITY = RequestOptions.Priority.PRIORITY_MEDIUM
+        with self._engine.connect().execution_options(
+            request_priority=PRIORITY
+        ) as connection:
+            connection.execute(select(self._table)).fetchall()
+
+        # The priority must not leak into connections checked out later.
+        with self._engine.connect() as connection:
+            assert connection.connection.request_priority is None
+
+        # Open and close a connection on a non-Spanner engine to check that
+        # the global pool "reset" listener tolerates non-Spanner connections.
+        engine = create_engine("sqlite:///database")
+        with engine.connect() as connection:
+            pass
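+
+
+# NOTE: a minimal sketch, mirroring the test above, of how an application
+# might set a per-connection request priority (names are illustrative):
+#
+#     from google.cloud.spanner_v1 import RequestOptions
+#
+#     with engine.connect().execution_options(
+#         request_priority=RequestOptions.Priority.PRIORITY_LOW
+#     ) as conn:
+#         conn.execute(select(table)).fetchall()
+#
+# The priority applies only while the connection is checked out; it is reset
+# when the connection is returned to the pool.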