
Commit

Merge branch 'main' into 2956-athena-read_sql_query-provides-completely-wrong-results-for-qmark-style-parametrized-queries-with-cache-enabled
jaidisido committed Sep 11, 2024
2 parents 202abb2 + 5cb7a4d commit 64c4e13
Showing 3 changed files with 353 additions and 44 deletions.
175 changes: 132 additions & 43 deletions awswrangler/redshift/_utils.py
@@ -106,6 +106,27 @@ def _get_primary_keys(cursor: "redshift_connector.Cursor", schema: str, table: s
return fields


def _get_table_columns(cursor: "redshift_connector.Cursor", schema: str, table: str) -> list[str]:
sql = f"SELECT column_name FROM svv_columns\n WHERE table_schema = '{schema}' AND table_name = '{table}'"
_logger.debug("Executing select query:\n%s", sql)
cursor.execute(sql)
result: tuple[list[str]] = cursor.fetchall()
columns = ["".join(lst) for lst in result]
return columns
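# Editor's sketch (not part of the diff): per the annotation above, fetchall()
# yields one single-element row per column, which the join flattens.
# Hypothetical values:
#   rows = [["id"], ["name"], ["created_at"]]   # shape of cursor.fetchall()
#   ["".join(lst) for lst in rows]              # -> ["id", "name", "created_at"]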


def _add_table_columns(
cursor: "redshift_connector.Cursor", schema: str, table: str, new_columns: dict[str, str]
) -> None:
for column_name, column_type in new_columns.items():
sql = (
f"ALTER TABLE {_identifier(schema)}.{_identifier(table)}"
f"\nADD COLUMN {_identifier(column_name)} {column_type};"
)
_logger.debug("Executing alter query:\n%s", sql)
cursor.execute(sql)
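# Editor's sketch (not part of the diff): with a hypothetical input of
#   new_columns = {"age": "BIGINT", "city": "VARCHAR(256)"}
# the loop above issues one statement per column, e.g.:
#   ALTER TABLE "public"."users"
#   ADD COLUMN "age" BIGINT;
# Redshift allows only one ADD COLUMN per ALTER TABLE statement, hence the
# loop rather than a single batched statement.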


def _does_table_exist(cursor: "redshift_connector.Cursor", schema: str | None, table: str) -> bool:
schema_str = f"TABLE_SCHEMA = '{schema}' AND" if schema else ""
sql = (
@@ -128,6 +149,16 @@ def _get_paths_from_manifest(path: str, boto3_session: boto3.Session | None = No
return paths


def _get_parameter_setting(cursor: "redshift_connector.Cursor", parameter_name: str) -> str:
sql = f"SHOW {parameter_name}"
_logger.debug("Executing select query:\n%s", sql)
cursor.execute(sql)
result = cursor.fetchall()
status = str(result[0][0])
_logger.debug(f"{parameter_name}='{status}'")
return status
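# Editor's sketch (not part of the diff): for the parameter used further
# below this runs, e.g., `SHOW enable_case_sensitive_identifier`; fetchall()
# returns something like [["off"]], and the helper yields the string "off".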


def _lock(
cursor: "redshift_connector.Cursor",
table_names: list[str],
@@ -267,7 +298,90 @@ def _redshift_types_from_path(
return redshift_types


def _create_table( # noqa: PLR0912,PLR0913,PLR0915
def _get_rsh_columns_types(
df: pd.DataFrame | None,
path: str | list[str] | None,
index: bool,
dtype: dict[str, str] | None,
varchar_lengths_default: int,
varchar_lengths: dict[str, int] | None,
data_format: Literal["parquet", "orc", "csv"] = "parquet",
redshift_column_types: dict[str, str] | None = None,
parquet_infer_sampling: float = 1.0,
path_suffix: str | None = None,
path_ignore_suffix: str | list[str] | None = None,
manifest: bool | None = False,
use_threads: bool | int = True,
boto3_session: boto3.Session | None = None,
s3_additional_kwargs: dict[str, str] | None = None,
) -> dict[str, str]:
if df is not None:
redshift_types: dict[str, str] = _data_types.database_types_from_pandas(
df=df,
index=index,
dtype=dtype,
varchar_lengths_default=varchar_lengths_default,
varchar_lengths=varchar_lengths,
converter_func=_data_types.pyarrow2redshift,
)
_logger.debug("Converted redshift types from pandas: %s", redshift_types)
elif path is not None:
if manifest:
if not isinstance(path, str):
raise TypeError(
f"""type: {type(path)} is not a valid type for 'path' when 'manifest' is set to True;
must be a string"""
)
path = _get_paths_from_manifest(
path=path,
boto3_session=boto3_session,
)

if data_format in ["parquet", "orc"]:
redshift_types = _redshift_types_from_path(
path=path,
data_format=data_format, # type: ignore[arg-type]
varchar_lengths_default=varchar_lengths_default,
varchar_lengths=varchar_lengths,
parquet_infer_sampling=parquet_infer_sampling,
path_suffix=path_suffix,
path_ignore_suffix=path_ignore_suffix,
use_threads=use_threads,
boto3_session=boto3_session,
s3_additional_kwargs=s3_additional_kwargs,
)
else:
if redshift_column_types is None:
raise ValueError(
"redshift_column_types is None. It must be specified for files formats other than Parquet or ORC."
)
redshift_types = redshift_column_types
else:
raise ValueError("df and path are None. You MUST pass at least one.")
return redshift_types
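# Editor's sketch (not part of the diff): the DataFrame branch in practice.
# The exact output depends on _data_types.pyarrow2redshift; values below are
# plausible, not authoritative.
#   import pandas as pd
#   df = pd.DataFrame({"id": [1, 2], "name": ["a", "b"]})
#   _get_rsh_columns_types(df=df, path=None, index=False, dtype=None,
#                          varchar_lengths_default=256, varchar_lengths=None)
#   # -> {"id": "BIGINT", "name": "VARCHAR(256)"}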


def _add_new_table_columns(
cursor: "redshift_connector.Cursor", schema: str, table: str, redshift_columns_types: dict[str, str]
) -> None:
# Check if Redshift is configured as case sensitive or not
is_case_sensitive = False
if _get_parameter_setting(cursor=cursor, parameter_name="enable_case_sensitive_identifier").lower() in [
"on",
"true",
]:
is_case_sensitive = True

    # If case-insensitive, convert all the DataFrame column names to lowercase before performing the comparison
if is_case_sensitive is False:
redshift_columns_types = {key.lower(): value for key, value in redshift_columns_types.items()}
actual_table_columns = set(_get_table_columns(cursor=cursor, schema=schema, table=table))
new_df_columns = {key: value for key, value in redshift_columns_types.items() if key not in actual_table_columns}

_add_table_columns(cursor=cursor, schema=schema, table=table, new_columns=new_df_columns)
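# Editor's sketch (not part of the diff): with the default
# enable_case_sensitive_identifier=off, incoming names are lowercased before
# the comparison, so only genuinely new columns trigger DDL. Hypothetical:
#   redshift_columns_types = {"id": "BIGINT", "Name": "VARCHAR(256)", "age": "BIGINT"}
#   # lowercased -> {"id": ..., "name": ..., "age": ...}
#   actual_table_columns   = {"id", "name"}
#   new_df_columns         = {"age": "BIGINT"}   # one ALTER TABLE ... ADD COLUMN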


def _create_table( # noqa: PLR0913
df: pd.DataFrame | None,
path: str | list[str] | None,
con: "redshift_connector.Connection",
@@ -336,49 +450,24 @@ def _create_table( # noqa: PLR0912,PLR0913,PLR0915
return table, schema
diststyle = diststyle.upper() if diststyle else "AUTO"
sortstyle = sortstyle.upper() if sortstyle else "COMPOUND"
if df is not None:
redshift_types: dict[str, str] = _data_types.database_types_from_pandas(
df=df,
index=index,
dtype=dtype,
varchar_lengths_default=varchar_lengths_default,
varchar_lengths=varchar_lengths,
converter_func=_data_types.pyarrow2redshift,
)
_logger.debug("Converted redshift types from pandas: %s", redshift_types)
elif path is not None:
if manifest:
if not isinstance(path, str):
raise TypeError(
f"""type: {type(path)} is not a valid type for 'path' when 'manifest' is set to True;
must be a string"""
)
path = _get_paths_from_manifest(
path=path,
boto3_session=boto3_session,
)

if data_format in ["parquet", "orc"]:
redshift_types = _redshift_types_from_path(
path=path,
data_format=data_format, # type: ignore[arg-type]
varchar_lengths_default=varchar_lengths_default,
varchar_lengths=varchar_lengths,
parquet_infer_sampling=parquet_infer_sampling,
path_suffix=path_suffix,
path_ignore_suffix=path_ignore_suffix,
use_threads=use_threads,
boto3_session=boto3_session,
s3_additional_kwargs=s3_additional_kwargs,
)
else:
if redshift_column_types is None:
raise ValueError(
"redshift_column_types is None. It must be specified for files formats other than Parquet or ORC."
)
redshift_types = redshift_column_types
else:
raise ValueError("df and path are None. You MUST pass at least one.")
redshift_types = _get_rsh_columns_types(
df=df,
path=path,
index=index,
dtype=dtype,
varchar_lengths_default=varchar_lengths_default,
varchar_lengths=varchar_lengths,
parquet_infer_sampling=parquet_infer_sampling,
path_suffix=path_suffix,
path_ignore_suffix=path_ignore_suffix,
use_threads=use_threads,
boto3_session=boto3_session,
s3_additional_kwargs=s3_additional_kwargs,
data_format=data_format,
redshift_column_types=redshift_column_types,
manifest=manifest,
)
_validate_parameters(
redshift_types=redshift_types,
diststyle=diststyle,
53 changes: 52 additions & 1 deletion awswrangler/redshift/_write.py
@@ -13,7 +13,14 @@
from awswrangler._config import apply_configs

from ._connect import _validate_connection
from ._utils import _create_table, _make_s3_auth_string, _upsert
from ._utils import (
_add_new_table_columns,
_create_table,
_does_table_exist,
_get_rsh_columns_types,
_make_s3_auth_string,
_upsert,
)

if TYPE_CHECKING:
try:
@@ -102,6 +109,7 @@ def to_sql(
chunksize: int = 200,
commit_transaction: bool = True,
precombine_key: str | None = None,
add_new_columns: bool = False,
) -> None:
"""Write records stored in a DataFrame into Redshift.
@@ -169,6 +177,8 @@ def to_sql(
When there is a primary_key match during upsert, this column will change the upsert method,
comparing the values of the specified column from source and target, and keeping the
larger of the two. Will only work when mode = upsert.
add_new_columns
If True, any new DataFrame columns are automatically added to the target table.
Examples
--------
@@ -191,6 +201,19 @@
con.autocommit = False
try:
with con.cursor() as cursor:
if add_new_columns and _does_table_exist(cursor=cursor, schema=schema, table=table):
redshift_columns_types = _get_rsh_columns_types(
df=df,
path=None,
index=index,
dtype=dtype,
varchar_lengths_default=varchar_lengths_default,
varchar_lengths=varchar_lengths,
)
_add_new_table_columns(
cursor=cursor, schema=schema, table=table, redshift_columns_types=redshift_columns_types
)

created_table, created_schema = _create_table(
df=df,
path=None,
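Editor's note: a hedged end-to-end sketch of the new add_new_columns flag in to_sql (not part of the diff; connection, schema, and table names are hypothetical):

import awswrangler as wr
import pandas as pd

con = wr.redshift.connect("aws-sdk-pandas-redshift")  # hypothetical Glue connection name

# First write creates the table with columns (id, name).
wr.redshift.to_sql(df=pd.DataFrame({"id": [1], "name": ["a"]}),
                   con=con, schema="public", table="demo", mode="overwrite")

# The second frame carries an extra column; with add_new_columns=True the
# column is added to the table before the append instead of failing the load.
wr.redshift.to_sql(df=pd.DataFrame({"id": [2], "name": ["b"], "age": [30]}),
                   con=con, schema="public", table="demo", mode="append",
                   add_new_columns=True)
con.close()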
@@ -280,6 +303,7 @@ def copy_from_files( # noqa: PLR0913
s3_additional_kwargs: dict[str, str] | None = None,
precombine_key: str | None = None,
column_names: list[str] | None = None,
add_new_columns: bool = False,
) -> None:
"""Load files from S3 to a Table on Amazon Redshift (Through COPY command).
@@ -396,6 +420,8 @@
larger of the two. Will only work when mode = upsert.
column_names
List of column names to map source data fields to the target columns.
add_new_columns
If True, any new columns detected in the source files are automatically added to the target table.
Examples
--------
@@ -420,6 +446,27 @@
con.autocommit = False
try:
with con.cursor() as cursor:
if add_new_columns and _does_table_exist(cursor=cursor, schema=schema, table=table):
redshift_columns_types = _get_rsh_columns_types(
df=None,
path=path,
index=False,
dtype=None,
varchar_lengths_default=varchar_lengths_default,
varchar_lengths=varchar_lengths,
parquet_infer_sampling=parquet_infer_sampling,
path_suffix=path_suffix,
path_ignore_suffix=path_ignore_suffix,
use_threads=use_threads,
boto3_session=boto3_session,
s3_additional_kwargs=s3_additional_kwargs,
data_format=data_format,
redshift_column_types=redshift_column_types,
manifest=manifest,
)
_add_new_table_columns(
cursor=cursor, schema=schema, table=table, redshift_columns_types=redshift_columns_types
)
created_table, created_schema = _create_table(
df=None,
path=path,
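Editor's note: the equivalent hedged sketch for the COPY path (S3 prefix, IAM role, and names are hypothetical):

import awswrangler as wr

con = wr.redshift.connect("aws-sdk-pandas-redshift")  # hypothetical Glue connection name
wr.redshift.copy_from_files(
    path="s3://my-bucket/staged/",  # hypothetical prefix holding Parquet files
    con=con,
    schema="public",
    table="demo",
    mode="append",
    iam_role="arn:aws:iam::123456789012:role/RedshiftCopyRole",  # hypothetical role
    add_new_columns=True,  # columns found in the files but not in the table are added first
)
con.close()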
@@ -521,6 +568,7 @@ def copy( # noqa: PLR0913
max_rows_by_file: int | None = 10_000_000,
precombine_key: str | None = None,
use_column_names: bool = False,
add_new_columns: bool = False,
) -> None:
"""Load Pandas DataFrame as a Table on Amazon Redshift using parquet files on S3 as stage.
@@ -628,6 +676,8 @@ def copy( # noqa: PLR0913
If set to True, will use the column names of the DataFrame for generating the INSERT SQL Query.
E.g. If the DataFrame has two columns `col1` and `col3` and `use_column_names` is True, data will only be
inserted into the database columns `col1` and `col3`.
add_new_columns
If True, any new DataFrame columns are automatically added to the target table.
Examples
--------
@@ -692,6 +742,7 @@ def copy( # noqa: PLR0913
sql_copy_extra_params=sql_copy_extra_params,
precombine_key=precombine_key,
column_names=column_names,
add_new_columns=add_new_columns,
)
finally:
if keep_files is False:

0 comments on commit 64c4e13
