support PyArrow timestamptz with Etc/UTC #910

Merged: 9 commits, Jul 12, 2024
Changes from 5 commits are shown below.

14 changes: 12 additions & 2 deletions pyiceberg/io/pyarrow.py
@@ -174,6 +174,7 @@
 MAP_KEY_NAME = "key"
 MAP_VALUE_NAME = "value"
 DOC = "doc"
+UTC_ALIASES = {"UTC", "+00:00", "Etc/UTC", "Z"}

 T = TypeVar("T")

@@ -937,7 +938,7 @@ def primitive(self, primitive: pa.DataType) -> PrimitiveType:
             else:
                 raise TypeError(f"Unsupported precision for timestamp type: {primitive.unit}")

-            if primitive.tz == "UTC" or primitive.tz == "+00:00":
+            if primitive.tz in UTC_ALIASES:
                 return TimestamptzType()
             elif primitive.tz is None:
                 return TimestampType()
@@ -1320,7 +1321,16 @@ def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array:
             and pa.types.is_timestamp(values.type)
             and values.type.unit == "ns"
         ):
-            return values.cast(target_type, safe=False)
+            if target_type.tz == "UTC" and values.type.tz in UTC_ALIASES or not target_type.tz and not values.type.tz:
+                return values.cast(target_type, safe=False)
+        if (
+            pa.types.is_timestamp(target_type)
+            and target_type.unit == "us"
+            and pa.types.is_timestamp(values.type)
+            and values.type.unit in {"s", "ms", "us"}
Contributor
When we're requesting 'us' and the file provides 'us', I don't think we need to cast?

Suggested change:
-            and values.type.unit in {"s", "ms", "us"}
+            and values.type.unit in {"s", "ms"}

Collaborator Author
Sorry, this is a bit confusing, I agree 😓. I included it here because I wanted to support casting pa.timestamp('us', tz='Etc/UTC') to pa.timestamp('us', tz='UTC') within the same condition.

I think we won't hit this condition if both the input and requested types are pa.timestamp('us'), because we enter this block only if target_type and values.type are not equal:
https://github.com/apache/iceberg-python/blob/main/pyiceberg/io/pyarrow.py#L1315
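For illustration, a minimal sketch of the point above (assuming stock PyArrow behavior; the epoch value is arbitrary): the two UTC spellings compare unequal as Arrow types, so the casting branch is reached even though the stored values need no change.

import pyarrow as pa

# PyArrow compares the tz metadata as strings, so these two logically
# identical UTC types are not equal, and the type-equality check that
# guards this block fails.
etc_utc = pa.timestamp("us", tz="Etc/UTC")
plain_utc = pa.timestamp("us", tz="UTC")
assert etc_utc != plain_utc

# Casting between them only rewrites the type metadata; the underlying
# microsecond values are unchanged.
arr = pa.array([1_700_000_000_000_000], type=etc_utc)
assert arr.cast(plain_utc).type == plain_utc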

Contributor
I think we might need to do some more work here. In Iceberg we're rather strict on the distinction between Timestamp and TimestampTZ. A good way of showing this can be found here:

    @to.register(TimestampType)
    def _(self, _: TimestampType) -> Literal[int]:
        return TimestampLiteral(timestamp_to_micros(self.value))

    @to.register(TimestamptzType)
    def _(self, _: TimestamptzType) -> Literal[int]:
        return TimestampLiteral(timestamptz_to_micros(self.value))

This is when we parse a string from a literal, which often comes from an expression: dt >= '1925-05-22T00:00:00'. If the value has a timezone, even UTC, we reject it as a Timestamp. When it has a timezone, we normalize it to UTC and then store the integer.
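As a hedged sketch of that strictness, assuming the helpers quoted above live in pyiceberg.utils.datetime:

from pyiceberg.utils.datetime import timestamp_to_micros, timestamptz_to_micros

# A zoned value, even at UTC, is normalized to UTC and stored as
# microseconds since the epoch by the timestamptz parser...
micros = timestamptz_to_micros("1925-05-22T00:00:00+00:00")

# ...while the zone-less Timestamp parser rejects any value that
# carries an offset.
try:
    timestamp_to_micros("1925-05-22T00:00:00+00:00")
except ValueError:
    pass  # rejected because the string carries a zone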

Collaborator Author
Sounds good @Fokko, thank you for the review.
I've made these checks stricter and also clearer for users to follow.

+        ):
+            if target_type.tz == "UTC" and values.type.tz in UTC_ALIASES or not target_type.tz and not values.type.tz:
Contributor
Can we add parentheses?
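For reference, a sketch of the explicit grouping being requested; Python's 'and' binds tighter than 'or', so this is equivalent to the unparenthesized line above:

            if (target_type.tz == "UTC" and values.type.tz in UTC_ALIASES) or (not target_type.tz and not values.type.tz):
                return values.cast(target_type)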

Collaborator Author
Added.

+                return values.cast(target_type)
Collaborator Author
Does this look correct: are we casting the types in order to follow the Iceberg spec for Parquet physical and logical types? https://iceberg.apache.org/spec/#parquet

Contributor
Yes, it looks good. Arrow writes INT64 by default for timestamps: https://arrow.apache.org/docs/cpp/parquet.html#logical-types
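A small sketch of how one could confirm this, assuming PyArrow's standard Parquet metadata API (the file name is arbitrary):

import pyarrow as pa
import pyarrow.parquet as pq

table = pa.table({"ts": pa.array([0], type=pa.timestamp("us", tz="UTC"))})
pq.write_table(table, "ts.parquet")

# Inspect the physical and logical type of the written column.
column = pq.ParquetFile("ts.parquet").schema.column(0)
print(column.physical_type)  # INT64
print(column.logical_type)   # Timestamp(isAdjustedToUTC=true, timeUnit=microseconds, ...)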

         return values

     def _construct_field(self, field: NestedField, arrow_type: pa.DataType) -> pa.Field:
8 changes: 0 additions & 8 deletions pyiceberg/table/__init__.py
@@ -528,10 +528,6 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
         )

         _check_schema_compatible(self._table.schema(), other_schema=df.schema)
-        # cast if the two schemas are compatible but not equal
-        table_arrow_schema = self._table.schema().as_arrow()
-        if table_arrow_schema != df.schema:
-            df = df.cast(table_arrow_schema)
Collaborator Author
Removed this cast; to_requested_schema should be responsible for casting the types to the desired schema, instead of doing it here.


         manifest_merge_enabled = PropertyUtil.property_as_bool(
             self.table_metadata.properties,

@@ -587,10 +583,6 @@ def overwrite(
         )

         _check_schema_compatible(self._table.schema(), other_schema=df.schema)
-        # cast if the two schemas are compatible but not equal
-        table_arrow_schema = self._table.schema().as_arrow()
-        if table_arrow_schema != df.schema:
-            df = df.cast(table_arrow_schema)

         self.delete(delete_filter=overwrite_filter, snapshot_properties=snapshot_properties)

44 changes: 38 additions & 6 deletions tests/integration/test_writes/test_writes.py
@@ -23,6 +23,7 @@
 from typing import Any, Dict
 from urllib.parse import urlparse

+import pandas as pd
 import pyarrow as pa
 import pyarrow.parquet as pq
 import pytest
@@ -977,7 +978,9 @@ def table_write_subset_of_schema(session_catalog: Catalog, arrow_table_with_null

 @pytest.mark.integration
 @pytest.mark.parametrize("format_version", [1, 2])
-def test_write_all_timestamp_precision(mocker: MockerFixture, session_catalog: Catalog, format_version: int) -> None:
+def test_write_all_timestamp_precision(
+    mocker: MockerFixture, spark: SparkSession, session_catalog: Catalog, format_version: int
+) -> None:
     identifier = "default.table_all_timestamp_precision"
     arrow_table_schema_with_all_timestamp_precisions = pa.schema([
         ("timestamp_s", pa.timestamp(unit="s")),
@@ -988,8 +991,10 @@ def test_write_all_timestamp_precision(
("timestamptz_us", pa.timestamp(unit="us", tz="UTC")),
("timestamp_ns", pa.timestamp(unit="ns")),
("timestamptz_ns", pa.timestamp(unit="ns", tz="UTC")),
("timestamptz_us_etc_utc", pa.timestamp(unit="us", tz="Etc/UTC")),
("timestamptz_us_z", pa.timestamp(unit="us", tz="Z")),
])
TEST_DATA_WITH_NULL = {
TEST_DATA_WITH_NULL = pd.DataFrame({
"timestamp_s": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)],
"timestamptz_s": [
datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),

@@ -1008,14 +1013,28 @@ def test_write_all_timestamp_precision(
             None,
             datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
         ],
-        "timestamp_ns": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)],
+        "timestamp_ns": [
+            pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=6),
+            None,
+            pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=7),
+        ],
         "timestamptz_ns": [
             datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
             None,
             datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
         ],
-    }
-    input_arrow_table = pa.Table.from_pydict(TEST_DATA_WITH_NULL, schema=arrow_table_schema_with_all_timestamp_precisions)
+        "timestamptz_us_etc_utc": [
+            datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
+            None,
+            datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
+        ],
+        "timestamptz_us_z": [
+            datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
+            None,
+            datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
+        ],
+    })
+    input_arrow_table = pa.Table.from_pandas(TEST_DATA_WITH_NULL, schema=arrow_table_schema_with_all_timestamp_precisions)
     mocker.patch.dict(os.environ, values={"PYICEBERG_DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE": "True"})

     tbl = _create_table(
@@ -1037,9 +1056,22 @@ def test_write_all_timestamp_precision(
("timestamptz_us", pa.timestamp(unit="us", tz="UTC")),
("timestamp_ns", pa.timestamp(unit="us")),
("timestamptz_ns", pa.timestamp(unit="us", tz="UTC")),
("timestamptz_us_etc_utc", pa.timestamp(unit="us", tz="UTC")),
("timestamptz_us_z", pa.timestamp(unit="us", tz="UTC")),
])
assert written_arrow_table.schema == expected_schema_in_all_us
assert written_arrow_table == input_arrow_table.cast(expected_schema_in_all_us)
assert written_arrow_table == input_arrow_table.cast(expected_schema_in_all_us, safe=False)
lhs = spark.table(f"{identifier}").toPandas()
rhs = written_arrow_table.to_pandas()

for column in written_arrow_table.column_names:
for left, right in zip(lhs[column].to_list(), rhs[column].to_list()):
if pd.isnull(left):
assert pd.isnull(right)
else:
# Check only upto microsecond precision since Spark loaded dtype is timezone unaware
# and supports upto microsecond precision
assert left.timestamp() == right.timestamp(), f"Difference in column {column}: {left} != {right}"


 @pytest.mark.integration