From afd8caa1153d5ed64b262137476300a2de9b9bf5 Mon Sep 17 00:00:00 2001
From: Jon Mease
Date: Fri, 22 Mar 2024 08:48:36 -0400
Subject: [PATCH 1/8] Handle arrow table with date32 columns

---
 altair/utils/__init__.py  |  2 ++
 altair/utils/core.py      |  8 +++----
 altair/vegalite/v5/api.py |  2 +-
 tests/utils/test_utils.py | 49 ++++++++++++++++++++++++++++++++++++++-
 4 files changed, 55 insertions(+), 6 deletions(-)

diff --git a/altair/utils/__init__.py b/altair/utils/__init__.py
index 0bd8ec5e3..dba1e1f81 100644
--- a/altair/utils/__init__.py
+++ b/altair/utils/__init__.py
@@ -2,6 +2,7 @@
     infer_vegalite_type,
     infer_encoding_types,
     sanitize_dataframe,
+    sanitize_arrow_table,
     parse_shorthand,
     use_signature,
     update_nested,
@@ -18,6 +19,7 @@
     "infer_vegalite_type",
     "infer_encoding_types",
     "sanitize_dataframe",
+    "sanitize_arrow_table",
     "spec_to_html",
     "parse_shorthand",
     "use_signature",

diff --git a/altair/utils/core.py b/altair/utils/core.py
index a382ac787..0f3e480bb 100644
--- a/altair/utils/core.py
+++ b/altair/utils/core.py
@@ -429,15 +429,15 @@ def sanitize_arrow_table(pa_table):
     schema = pa_table.schema
     for name in schema.names:
         array = pa_table[name]
-        dtype = schema.field(name).type
-        if str(dtype).startswith("timestamp"):
+        dtype_name = str(schema.field(name).type)
+        if dtype_name.startswith("timestamp") or dtype_name.startswith("date32"):
             arrays.append(pc.strftime(array))
-        elif str(dtype).startswith("duration"):
+        elif dtype_name.startswith("duration"):
             raise ValueError(
                 'Field "{col_name}" has type "{dtype}" which is '
                 "not supported by Altair. Please convert to "
                 "either a timestamp or a numerical value."
-                "".format(col_name=name, dtype=dtype)
+                "".format(col_name=name, dtype=dtype_name)
             )
         else:
             arrays.append(array)

diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py
index 4202fd9a8..363a86e84 100644
--- a/altair/vegalite/v5/api.py
+++ b/altair/vegalite/v5/api.py
@@ -56,7 +56,7 @@ def _dataset_name(values: Union[dict, list, core.InlineDataset]) -> str:
         values = values.to_dict()
     if values == [{}]:
         return "empty"
-    values_json = json.dumps(values, sort_keys=True)
+    values_json = json.dumps(values, sort_keys=True, default=str)
     hsh = hashlib.sha256(values_json.encode()).hexdigest()[:32]
     return "data-" + hsh

diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py
index c0334533a..9cf5bda37 100644
--- a/tests/utils/test_utils.py
+++ b/tests/utils/test_utils.py
@@ -6,7 +6,7 @@
 import pandas as pd
 import pytest

-from altair.utils import infer_vegalite_type, sanitize_dataframe
+from altair.utils import infer_vegalite_type, sanitize_dataframe, sanitize_arrow_table

 try:
     import pyarrow as pa
@@ -120,6 +120,53 @@ def test_sanitize_dataframe_arrow_columns():
     json.dumps(records)


+@pytest.mark.skipif(pa is None, reason="pyarrow not installed")
+def test_sanitize_pyarrow_table_columns():
+    # create a dataframe with various types
+    df = pd.DataFrame(
+        {
+            "s": list("abcde"),
+            "f": np.arange(5, dtype=float),
+            "i": np.arange(5, dtype=int),
+            "b": np.array([True, False, True, True, False]),
+            "d": pd.date_range("2012-01-01", periods=5, freq="H"),
+            "c": pd.Series(list("ababc"), dtype="category"),
+            "p": pd.date_range("2012-01-01", periods=5, freq="H").tz_localize("UTC"),
+        }
+    )
+
+    # Create pyarrow table with explicit schema so that date32 type is preserved
+    pa_table = pa.Table.from_pandas(
+        df,
+        pa.schema(
+            [
+                pa.field("s", pa.string()),
+                pa.field("f", pa.float64()),
+                pa.field("i", pa.int64()),
+                pa.field("b", pa.bool_()),
+                pa.field("d", pa.date32()),
+                pa.field("c", pa.dictionary(pa.int8(), pa.string())),
+                pa.field("p", pa.timestamp("ns", tz="UTC")),
+            ]
+        ),
+    )
+    sanitized = sanitize_arrow_table(pa_table)
+    values = sanitized.to_pylist()
+
+    assert values[0] == {
+        "s": "a",
+        "f": 0.0,
+        "i": 0,
+        "b": True,
+        "d": "2012-01-01T00:00:00",
+        "c": "a",
+        "p": "2012-01-01T00:00:00.000000000",
+    }
+
+    # Make sure we can serialize to JSON without error
+    json.dumps(values)
+
+
 def test_sanitize_dataframe_colnames():
     df = pd.DataFrame(np.arange(12).reshape(4, 3))
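
Note: the new date32 path in sanitize_arrow_table can be exercised directly with a
small table. This is an illustrative sketch, not part of the patch series; it
assumes pyarrow >= 11 is installed and relies on pc.strftime's default
"%Y-%m-%dT%H:%M:%S" format:

    import datetime
    import pyarrow as pa
    from altair.utils import sanitize_arrow_table

    # A one-column table whose field is date32 (days since the Unix epoch)
    table = pa.table({"d": pa.array([datetime.date(2020, 1, 1)], type=pa.date32())})
    print(sanitize_arrow_table(table).to_pylist())
    # [{'d': '2020-01-01T00:00:00'}]

The companion api.py change passes default=str to json.dumps so that _dataset_name
can hash values containing objects such as datetime.date that are not otherwise
JSON serializable, for example:

    import datetime
    import json

    json.dumps([{"d": datetime.date(2020, 1, 1)}], sort_keys=True, default=str)
    # '[{"d": "2020-01-01"}]'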
From 49536d41079944621f2a4ff270fbf56b02d4c48b Mon Sep 17 00:00:00 2001
From: Jon Mease
Date: Fri, 22 Mar 2024 08:53:36 -0400
Subject: [PATCH 2/8] Handle all date types

---
 altair/utils/core.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/altair/utils/core.py b/altair/utils/core.py
index 0f3e480bb..c2b0634cc 100644
--- a/altair/utils/core.py
+++ b/altair/utils/core.py
@@ -430,7 +430,7 @@ def sanitize_arrow_table(pa_table):
     for name in schema.names:
         array = pa_table[name]
         dtype_name = str(schema.field(name).type)
-        if dtype_name.startswith("timestamp") or dtype_name.startswith("date32"):
+        if dtype_name.startswith("timestamp") or dtype_name.startswith("date"):
             arrays.append(pc.strftime(array))
         elif dtype_name.startswith("duration"):
             raise ValueError(

From 9c6454dc9073dab793f885ecf781fe2c14529ba2 Mon Sep 17 00:00:00 2001
From: Jon Mease
Date: Fri, 22 Mar 2024 09:45:06 -0400
Subject: [PATCH 3/8] Add changelog entry

---
 doc/releases/changes.rst | 1 +
 1 file changed, 1 insertion(+)

diff --git a/doc/releases/changes.rst b/doc/releases/changes.rst
index be354bb36..1ca9eb7be 100644
--- a/doc/releases/changes.rst
+++ b/doc/releases/changes.rst
@@ -28,6 +28,7 @@ Bug Fixes
 ~~~~~~~~~
 - Fix type hints for libraries such as Polars where Altair uses the dataframe interchange protocol (#3297)
 - Fix anywidget deprecation warning (#3364)
+- Fix handling of Date32 columns in arrow tables (#3377)

 Backward-Incompatible Changes
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
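
Note: broadening the prefix check from "date32" to "date" means date64 columns are
formatted the same way. An illustrative sketch, not part of the patch series; it
assumes pc.strftime accepts both Arrow date types, as it does for the date32
column in the test above:

    import datetime
    import pyarrow as pa
    from altair.utils import sanitize_arrow_table

    # Same value stored as date64 (milliseconds since the Unix epoch)
    arr = pa.array([datetime.date(2020, 1, 1)], type=pa.date64())
    print(sanitize_arrow_table(pa.table({"d": arr})).to_pylist())
    # [{'d': '2020-01-01T00:00:00'}]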
From f225b5508b657f0325fde2eb218f362f3a2bfd7c Mon Sep 17 00:00:00 2001
From: Jon Mease
Date: Sat, 23 Mar 2024 08:09:31 -0400
Subject: [PATCH 4/8] Use direct arrow conversion methods if available

---
 altair/utils/data.py | 40 ++++++++++++++++++++++++++--------------
 1 file changed, 26 insertions(+), 14 deletions(-)

diff --git a/altair/utils/data.py b/altair/utils/data.py
index 7fba7adaa..0e9071209 100644
--- a/altair/utils/data.py
+++ b/altair/utils/data.py
@@ -106,12 +106,11 @@ def raise_max_rows_error():
         # as equivalent to TDataType
         return data  # type: ignore[return-value]
     elif hasattr(data, "__dataframe__"):
-        pi = import_pyarrow_interchange()
-        pa_table = pi.from_dataframe(data)
+        pa_table = arrow_table_from_dfi_dataframe(data)
         if max_rows is not None and pa_table.num_rows > max_rows:
             raise_max_rows_error()
         # Return pyarrow Table instead of input since the
-        # `from_dataframe` call may be expensive
+        # `arrow_table_from_dfi_dataframe` call above may be expensive
         return pa_table

     if max_rows is not None and len(values) > max_rows:
@@ -143,9 +142,7 @@ def sample(
             # Maybe this should raise an error or return something useful?
             return None
     elif hasattr(data, "__dataframe__"):
-        # experimental interchange dataframe support
-        pi = import_pyarrow_interchange()
-        pa_table = pi.from_dataframe(data)
+        pa_table = arrow_table_from_dfi_dataframe(data)
         if not n:
             if frac is None:
                 raise ValueError(
@@ -233,9 +230,7 @@ def to_values(data: DataType) -> ToValuesReturnType:
             raise KeyError("values expected in data dict, but not present.")
         return data
     elif hasattr(data, "__dataframe__"):
-        # experimental interchange dataframe support
-        pi = import_pyarrow_interchange()
-        pa_table = sanitize_arrow_table(pi.from_dataframe(data))
+        pa_table = sanitize_arrow_table(arrow_table_from_dfi_dataframe(data))
        return {"values": pa_table.to_pylist()}
     else:
         # Should never reach this state as tested by check_data_type
@@ -278,9 +273,7 @@ def _data_to_json_string(data: DataType) -> str:
             raise KeyError("values expected in data dict, but not present.")
         return json.dumps(data["values"], sort_keys=True)
     elif hasattr(data, "__dataframe__"):
-        # experimental interchange dataframe support
-        pi = import_pyarrow_interchange()
-        pa_table = pi.from_dataframe(data)
+        pa_table = arrow_table_from_dfi_dataframe(data)
         return json.dumps(pa_table.to_pylist())
     else:
         raise NotImplementedError(
@@ -305,11 +298,10 @@ def _data_to_csv_string(data: Union[dict, pd.DataFrame, DataFrameLike]) -> str:
         return pd.DataFrame.from_dict(data["values"]).to_csv(index=False)
     elif hasattr(data, "__dataframe__"):
         # experimental interchange dataframe support
-        pi = import_pyarrow_interchange()
         import pyarrow as pa
         import pyarrow.csv as pa_csv

-        pa_table = pi.from_dataframe(data)
+        pa_table = arrow_table_from_dfi_dataframe(data)
         csv_buffer = pa.BufferOutputStream()
         pa_csv.write_csv(pa_table, csv_buffer)
         return csv_buffer.getvalue().to_pybytes().decode()
@@ -346,3 +338,23 @@ def curry(*args, **kwargs):
         stacklevel=1,
     )
     return curried.curry(*args, **kwargs)
+
+
+def arrow_table_from_dfi_dataframe(dfi_df: DataFrameLike) -> "pyarrow.lib.Table":
+    """Convert a DataFrame Interchange Protocol compatible object to an Arrow Table"""
+    import pyarrow as pa
+
+    # First check if the dataframe object has a method to convert to arrow.
+    # Give this preference over the pyarrow from_dataframe function since the object
+    # has more control over the conversion, and may have broader compatibility.
+    # This is the case for Polars, which supports Date32 columns in direct conversion
+    # while pyarrow does not yet support this type in from_dataframe
+    for convert_method_name in ("arrow", "to_arrow", "to_arrow_table"):
+        convert_method = getattr(dfi_df, convert_method_name, None)
+        if callable(convert_method):
+            result = convert_method()
+            if isinstance(result, pa.Table):
+                return result
+
+    pi = import_pyarrow_interchange()
+    return pi.from_dataframe(dfi_df)
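
Note: the new helper's preference for a native conversion method can be seen with
Polars, which exposes to_arrow(). An illustrative sketch, not part of the patch
series; it assumes polars is installed:

    import datetime
    import polars as pl
    from altair.utils.data import arrow_table_from_dfi_dataframe

    df = pl.DataFrame({"d": [datetime.date(2020, 1, 1)]})
    # Takes the df.to_arrow() path, so the interchange protocol is never invoked
    # and the Date column arrives as Arrow date32 rather than failing in
    # pyarrow's from_dataframe
    table = arrow_table_from_dfi_dataframe(df)
    print(table.schema)  # d: date32[day]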
From 1432353c6bfe14f7df95da9547a71657f4113053 Mon Sep 17 00:00:00 2001
From: Jon Mease
Date: Sat, 23 Mar 2024 08:16:32 -0400
Subject: [PATCH 5/8] Make mypy happy by making DataFrameLike protocol runtime checkable and using isinstance

---
 altair/utils/core.py |  3 ++-
 altair/utils/data.py | 10 +++++-----
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/altair/utils/core.py b/altair/utils/core.py
index c2b0634cc..ea275d589 100644
--- a/altair/utils/core.py
+++ b/altair/utils/core.py
@@ -37,7 +37,7 @@
 else:
     from typing_extensions import ParamSpec

-from typing import Literal, Protocol, TYPE_CHECKING
+from typing import Literal, Protocol, TYPE_CHECKING, runtime_checkable

 if TYPE_CHECKING:
     from pandas.core.interchange.dataframe_protocol import Column as PandasColumn
@@ -46,6 +46,7 @@
 P = ParamSpec("P")


+@runtime_checkable
 class DataFrameLike(Protocol):
     def __dataframe__(
         self, nan_as_null: bool = False, allow_copy: bool = True

diff --git a/altair/utils/data.py b/altair/utils/data.py
index 0e9071209..e4b135d38 100644
--- a/altair/utils/data.py
+++ b/altair/utils/data.py
@@ -105,7 +105,7 @@ def raise_max_rows_error():
         # mypy gets confused as it doesn't see Dict[Any, Any]
         # as equivalent to TDataType
         return data  # type: ignore[return-value]
-    elif hasattr(data, "__dataframe__"):
+    elif isinstance(data, DataFrameLike):
         pa_table = arrow_table_from_dfi_dataframe(data)
         if max_rows is not None and pa_table.num_rows > max_rows:
             raise_max_rows_error()
@@ -141,7 +141,7 @@ def sample(
         else:
             # Maybe this should raise an error or return something useful?
             return None
-    elif hasattr(data, "__dataframe__"):
+    elif isinstance(data, DataFrameLike):
         pa_table = arrow_table_from_dfi_dataframe(data)
         if not n:
             if frac is None:
@@ -229,7 +229,7 @@ def to_values(data: DataType) -> ToValuesReturnType:
         if "values" not in data:
             raise KeyError("values expected in data dict, but not present.")
         return data
-    elif hasattr(data, "__dataframe__"):
+    elif isinstance(data, DataFrameLike):
         pa_table = sanitize_arrow_table(arrow_table_from_dfi_dataframe(data))
         return {"values": pa_table.to_pylist()}
     else:
@@ -272,7 +272,7 @@ def _data_to_json_string(data: DataType) -> str:
         if "values" not in data:
             raise KeyError("values expected in data dict, but not present.")
         return json.dumps(data["values"], sort_keys=True)
-    elif hasattr(data, "__dataframe__"):
+    elif isinstance(data, DataFrameLike):
         pa_table = arrow_table_from_dfi_dataframe(data)
         return json.dumps(pa_table.to_pylist())
     else:
@@ -296,7 +296,7 @@ def _data_to_csv_string(data: Union[dict, pd.DataFrame, DataFrameLike]) -> str:
         if "values" not in data:
             raise KeyError("values expected in data dict, but not present")
         return pd.DataFrame.from_dict(data["values"]).to_csv(index=False)
-    elif hasattr(data, "__dataframe__"):
+    elif isinstance(data, DataFrameLike):
         # experimental interchange dataframe support
         import pyarrow as pa
         import pyarrow.csv as pa_csv
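
Note: because the protocol is now runtime checkable, isinstance performs a purely
structural check: any object with a __dataframe__ method matches, regardless of
its class. A minimal sketch with a hypothetical FakeFrame class:

    from altair.utils.core import DataFrameLike

    class FakeFrame:
        def __dataframe__(self, nan_as_null=False, allow_copy=True):
            return object()

    assert isinstance(FakeFrame(), DataFrameLike)
    assert not isinstance(object(), DataFrameLike)

runtime_checkable only verifies that the method exists, not its signature or
return type, so this is equivalent to the hasattr checks it replaces.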
From d843a1a886b12c126bc6a0aca44fd042fa424bcc Mon Sep 17 00:00:00 2001
From: Jon Mease
Date: Sat, 23 Mar 2024 08:45:20 -0400
Subject: [PATCH 6/8] Update changelog

---
 doc/releases/changes.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/doc/releases/changes.rst b/doc/releases/changes.rst
index 1ca9eb7be..c630b99f4 100644
--- a/doc/releases/changes.rst
+++ b/doc/releases/changes.rst
@@ -28,7 +28,7 @@ Bug Fixes
 ~~~~~~~~~
 - Fix type hints for libraries such as Polars where Altair uses the dataframe interchange protocol (#3297)
 - Fix anywidget deprecation warning (#3364)
-- Fix handling of Date32 columns in arrow tables (#3377)
+- Fix handling of Date32 columns in arrow tables and Polars DataFrames (#3377)

 Backward-Incompatible Changes
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

From 7feefd5eef883b77c592f949e67a3f12b5d2570f Mon Sep 17 00:00:00 2001
From: Jon Mease
Date: Sat, 23 Mar 2024 08:49:05 -0400
Subject: [PATCH 7/8] Fix vegafusion test and update VegaFusion constraint

---
 pyproject.toml                 | 2 +-
 tests/utils/test_mimebundle.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 7cc710894..f03b1b721 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -61,7 +61,7 @@ all = [
     "vega_datasets>=0.9.0",
     "vl-convert-python>=1.3.0",
     "pyarrow>=11",
-    "vegafusion[embed]>=1.5.0",
+    "vegafusion[embed]>=1.6.6",
     "anywidget>=0.9.0",
     "altair_tiles>=0.3.0"
 ]

diff --git a/tests/utils/test_mimebundle.py b/tests/utils/test_mimebundle.py
index 541ac483f..97c353c56 100644
--- a/tests/utils/test_mimebundle.py
+++ b/tests/utils/test_mimebundle.py
@@ -241,7 +241,7 @@ def check_pre_transformed_vega_spec(vega_spec):

     # Check that the bin transform has been applied
     row0 = data_0["values"][0]
-    assert row0 == {"a": "A", "b": 28, "b_end": 28.0, "b_start": 0.0}
+    assert row0 == {"a": "A", "b_end": 28.0, "b_start": 0.0}

     # And no transforms remain
     assert len(data_0.get("transform", [])) == 0
From c119d1e7e6b55b23d3166cadd2b54b0c1697c341 Mon Sep 17 00:00:00 2001
From: mattijn
Date: Sun, 24 Mar 2024 16:12:06 +0100
Subject: [PATCH 8/8] check for instance DataFrameLike instead of __dataframe__ attribute

---
 altair/utils/_vegafusion_data.py | 2 +-
 altair/utils/core.py             | 2 +-
 altair/utils/data.py             | 4 ++--
 altair/vegalite/v5/api.py        | 2 +-
 4 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/altair/utils/_vegafusion_data.py b/altair/utils/_vegafusion_data.py
index 8b46bab78..ce30e8d6d 100644
--- a/altair/utils/_vegafusion_data.py
+++ b/altair/utils/_vegafusion_data.py
@@ -45,7 +45,7 @@ def vegafusion_data_transformer(
         # Use default transformer for geo interface objects
         # (e.g. a geopandas GeoDataFrame)
         return default_data_transformer(data)
-    elif hasattr(data, "__dataframe__"):
+    elif isinstance(data, DataFrameLike):
         table_name = f"table_{uuid.uuid4()}".replace("-", "_")
         extracted_inline_tables[table_name] = data
         return {"url": VEGAFUSION_PREFIX + table_name}

diff --git a/altair/utils/core.py b/altair/utils/core.py
index ea275d589..baf1013f7 100644
--- a/altair/utils/core.py
+++ b/altair/utils/core.py
@@ -589,7 +589,7 @@ def parse_shorthand(

     # if data is specified and type is not, infer type from data
     if "type" not in attrs:
-        if pyarrow_available() and data is not None and hasattr(data, "__dataframe__"):
+        if pyarrow_available() and data is not None and isinstance(data, DataFrameLike):
             dfi = data.__dataframe__()
             if "field" in attrs:
                 unescaped_field = attrs["field"].replace("\\", "")

diff --git a/altair/utils/data.py b/altair/utils/data.py
index e4b135d38..871b43092 100644
--- a/altair/utils/data.py
+++ b/altair/utils/data.py
@@ -238,8 +238,8 @@ def to_values(data: DataType) -> ToValuesReturnType:


 def check_data_type(data: DataType) -> None:
-    if not isinstance(data, (dict, pd.DataFrame)) and not any(
-        hasattr(data, attr) for attr in ["__geo_interface__", "__dataframe__"]
+    if not isinstance(data, (dict, pd.DataFrame, DataFrameLike)) and not any(
+        hasattr(data, attr) for attr in ["__geo_interface__"]
     ):
         raise TypeError(
             "Expected dict, DataFrame or a __geo_interface__ attribute, got: {}".format(

diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py
index 363a86e84..dfde5ee7e 100644
--- a/altair/vegalite/v5/api.py
+++ b/altair/vegalite/v5/api.py
@@ -114,7 +114,7 @@ def _prepare_data(data, context=None):
     elif isinstance(data, str):
         data = core.UrlData(data)

-    elif hasattr(data, "__dataframe__"):
+    elif isinstance(data, DataFrameLike):
         data = _pipe(data, data_transformers.get())

     # consolidate inline data to top-level datasets
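
Note: taken together, the series makes a Polars date column survive the whole
values pipeline. An illustrative end-to-end sketch, not part of the patch series;
it assumes polars and pyarrow are installed:

    import datetime
    import polars as pl
    from altair.utils.data import to_values

    df = pl.DataFrame({"d": [datetime.date(2020, 1, 1)]})
    # isinstance(df, DataFrameLike) passes (polars implements __dataframe__),
    # conversion goes through df.to_arrow(), and sanitize_arrow_table formats
    # the resulting date32 column as strings
    print(to_values(df))
    # {'values': [{'d': '2020-01-01T00:00:00'}]}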