Skip to content

Commit

Permalink
Merge branch 'main' into enh_athena2pyarrow
Browse files Browse the repository at this point in the history
  • Loading branch information
jaidisido committed Jul 15, 2024
2 parents ce73bd5 + c6eb592 commit c7033b8
Show file tree
Hide file tree
Showing 8 changed files with 89 additions and 71 deletions.
4 changes: 4 additions & 0 deletions awswrangler/_data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,10 @@ def athena2pandas(dtype: str, dtype_backend: str | None = None) -> str: # noqa:
return "datetime64" if dtype_backend != "pyarrow" else "timestamp[ns][pyarrow]"
if dtype == "date":
return "date" if dtype_backend != "pyarrow" else "date32[pyarrow]"
if dtype == "time":
# Pandas does not have a type for time of day, so we are returning a string.
# However, if the backend is pyarrow, we can return time32[ms]
return "string" if dtype_backend != "pyarrow" else "time32[ms][pyarrow]"
if dtype.startswith("decimal"):
return "decimal" if dtype_backend != "pyarrow" else "double[pyarrow]"
if dtype in ("binary", "varbinary"):
Expand Down
2 changes: 1 addition & 1 deletion awswrangler/s3/_write_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def _get_bucketing_series(df: pd.DataFrame, bucketing_info: typing.BucketingInfo
axis="columns",
)
)
return bucket_number_series.astype(pd.CategoricalDtype(range(bucketing_info[1])))
return bucket_number_series.astype(np.array([pd.CategoricalDtype(range(bucketing_info[1]))]))


def _simulate_overflow(value: int, bits: int = 31, signed: bool = False) -> int:
Expand Down
121 changes: 62 additions & 59 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,10 @@ msgpack = "*"
poetry = "^1.8.3"

# Lint
boto3-stubs = {version = "^1.34.136", extras = ["athena", "cleanrooms", "chime", "cloudwatch", "dynamodb", "ec2", "emr", "emr-serverless", "glue", "kms", "logs", "neptune", "opensearch", "opensearchserverless", "quicksight", "rds", "rds-data", "redshift", "redshift-data", "s3", "secretsmanager", "ssm", "sts", "timestream-query", "timestream-write"]}
boto3-stubs = {version = "^1.34.144", extras = ["athena", "cleanrooms", "chime", "cloudwatch", "dynamodb", "ec2", "emr", "emr-serverless", "glue", "kms", "logs", "neptune", "opensearch", "opensearchserverless", "quicksight", "rds", "rds-data", "redshift", "redshift-data", "s3", "secretsmanager", "ssm", "sts", "timestream-query", "timestream-write"]}
doc8 = "^1.0"
mypy = "^1.10"
ruff = "^0.5.0"
ruff = "^0.5.2"

# Test
moto = "^5.0"
Expand Down
12 changes: 6 additions & 6 deletions test_infra/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 11 additions & 0 deletions tests/unit/test_athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,17 @@ def test_athena_time_zone(glue_database):
assert df["value"][0].year == datetime.datetime.utcnow().year


@pytest.mark.parametrize("dtype_backend", ["numpy_nullable", "pyarrow"])
def test_athena_time_type(glue_database: str, dtype_backend: str) -> None:
    """Check how an Athena ``time`` literal is surfaced for each dtype backend.

    The query is run without the CTAS approach. With the ``pyarrow`` backend the
    value is expected to compare equal to a ``datetime.time``; with any other
    backend it is expected to come back as its plain string form.
    """
    # Decide the expected value up front so the test body has a single assertion.
    if dtype_backend == "pyarrow":
        expected = datetime.time(13, 24, 11)
    else:
        expected = "13:24:11"

    result = wr.athena.read_sql_query(
        "SELECT time '13:24:11' as col", glue_database, ctas_approach=False, dtype_backend=dtype_backend
    )
    first_value = result["col"].iloc[0]

    assert first_value == expected


@pytest.mark.parametrize(
"ctas_approach",
[
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_athena_geospatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_athena_geospatial(path, glue_table, glue_database):
ctas_approach=False,
)

assert type(df) == geopandas.GeoDataFrame
assert isinstance(df, geopandas.GeoDataFrame)

assert isinstance(df["value"], pd.Series)
assert isinstance(df["point"], geopandas.GeoSeries)
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_moto.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,13 +613,13 @@ def mock_data_api_connector(connector, has_result_set=True):
data = [[col["stringValue"] for col in record] for record in statement_response["Records"]]
response_dataframe = pd.DataFrame(data, columns=column_names)

if type(connector) == wr.data_api.redshift.RedshiftDataApi:
if isinstance(connector, wr.data_api.redshift.RedshiftDataApi):
connector.client.execute_statement = mock.MagicMock(return_value={"Id": request_id})
connector.client.describe_statement = mock.MagicMock(
return_value={"Status": "FINISHED", "HasResultSet": has_result_set}
)
connector.client.get_statement_result = mock.MagicMock(return_value=statement_response)
elif type(connector) == wr.data_api.rds.RdsDataApi:
elif isinstance(connector, wr.data_api.rds.RdsDataApi):
records = statement_response["Records"]
metadata = statement_response["ColumnMetadata"]
del statement_response["Records"]
Expand Down

0 comments on commit c7033b8

Please sign in to comment.