From d3a541f0e8b23c9548c6d52061fd40f7aaa242c6 Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 16 Mar 2021 11:49:49 -0700 Subject: [PATCH] compat: sqlalchemy deprecations --- pandas/io/sql.py | 42 ++++++++++++++++++++++++++++++------- pandas/tests/io/test_sql.py | 24 ++++++++++++++++----- 2 files changed, 54 insertions(+), 12 deletions(-) diff --git a/pandas/io/sql.py b/pandas/io/sql.py index fb08abb6fea45..e3347468828d1 100644 --- a/pandas/io/sql.py +++ b/pandas/io/sql.py @@ -9,6 +9,7 @@ datetime, time, ) +from distutils.version import LooseVersion from functools import partial import re from typing import ( @@ -77,6 +78,16 @@ def _is_sqlalchemy_connectable(con): return False +def _gt14() -> bool: + """ + Check if sqlalchemy.__version__ is at least 1.4.0, when several + deprecations were made. + """ + import sqlalchemy + + return LooseVersion(sqlalchemy.__version__) >= LooseVersion("1.4.0") + + def _convert_params(sql, params): """Convert SQL and params args to DBAPI2.0 compliant format.""" args = [sql] @@ -823,7 +834,10 @@ def sql_schema(self): def _execute_create(self): # Inserting table into database, add to MetaData object - self.table = self.table.tometadata(self.pd_sql.meta) + if _gt14(): + self.table = self.table.to_metadata(self.pd_sql.meta) + else: + self.table = self.table.tometadata(self.pd_sql.meta) self.table.create() def create(self): @@ -1596,9 +1610,17 @@ def to_sql( # Only check when name is not a number and name is not lower case engine = self.connectable.engine with self.connectable.connect() as conn: - table_names = engine.table_names( - schema=schema or self.meta.schema, connection=conn - ) + if _gt14(): + from sqlalchemy import inspect + + insp = inspect(conn) + table_names = insp.get_table_names( + schema=schema or self.meta.schema + ) + else: + table_names = engine.table_names( + schema=schema or self.meta.schema, connection=conn + ) if name not in table_names: msg = ( f"The provided table name '{name}' is not found exactly as " @@ -1613,9 +1635,15 @@ def tables(self): return self.meta.tables def has_table(self, name: str, schema: Optional[str] = None): - return self.connectable.run_callable( - self.connectable.dialect.has_table, name, schema or self.meta.schema - ) + if _gt14(): + import sqlalchemy as sa + + insp = sa.inspect(self.connectable) + return insp.has_table(name, schema or self.meta.schema) + else: + return self.connectable.run_callable( + self.connectable.dialect.has_table, name, schema or self.meta.schema + ) def get_table(self, table_name: str, schema: Optional[str] = None): schema = schema or self.meta.schema diff --git a/pandas/tests/io/test_sql.py b/pandas/tests/io/test_sql.py index e57030a4bf125..7d923e57834ea 100644 --- a/pandas/tests/io/test_sql.py +++ b/pandas/tests/io/test_sql.py @@ -52,12 +52,14 @@ import pandas.io.sql as sql from pandas.io.sql import ( + _gt14, read_sql_query, read_sql_table, ) try: import sqlalchemy + from sqlalchemy import inspect from sqlalchemy.ext import declarative from sqlalchemy.orm import session as sa_session import sqlalchemy.schema @@ -1487,7 +1489,11 @@ def test_create_table(self): pandasSQL = sql.SQLDatabase(temp_conn) pandasSQL.to_sql(temp_frame, "temp_frame") - assert temp_conn.has_table("temp_frame") + if _gt14(): + insp = inspect(temp_conn) + assert insp.has_table("temp_frame") + else: + assert temp_conn.has_table("temp_frame") def test_drop_table(self): temp_conn = self.connect() @@ -1499,11 +1505,18 @@ def test_drop_table(self): pandasSQL = sql.SQLDatabase(temp_conn) pandasSQL.to_sql(temp_frame, "temp_frame") - assert temp_conn.has_table("temp_frame") + if _gt14(): + insp = inspect(temp_conn) + assert insp.has_table("temp_frame") + else: + assert temp_conn.has_table("temp_frame") pandasSQL.drop_table("temp_frame") - assert not temp_conn.has_table("temp_frame") + if _gt14(): + assert not insp.has_table("temp_frame") + else: + assert not temp_conn.has_table("temp_frame") def test_roundtrip(self): self._roundtrip() @@ -1843,9 +1856,10 @@ def test_nan_string(self): tm.assert_frame_equal(result, df) def _get_index_columns(self, tbl_name): - from sqlalchemy.engine import reflection + from sqlalchemy import inspect + + insp = inspect(self.conn) - insp = reflection.Inspector.from_engine(self.conn) ixs = insp.get_indexes(tbl_name) ixs = [i["column_names"] for i in ixs] return ixs