diff --git a/Cargo.lock b/Cargo.lock index d378e36b1e..6083db9af4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -794,7 +794,9 @@ dependencies = [ "chrono", "deltalake", "env_logger 0.9.0", + "lazy_static", "pyo3", + "regex", "reqwest", "serde_json", "tokio", diff --git a/python/Cargo.toml b/python/Cargo.toml index 525a532189..3371772606 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -21,6 +21,8 @@ env_logger = "0" reqwest = { version = "*", features = ["native-tls-vendored"] } serde_json = "1" chrono = "0" +regex = "1" +lazy_static = "1" [dependencies.pyo3] version = "0.16" diff --git a/python/deltalake/deltalake.pyi b/python/deltalake/deltalake.pyi new file mode 100644 index 0000000000..2dac677a9b --- /dev/null +++ b/python/deltalake/deltalake.pyi @@ -0,0 +1,110 @@ +from typing import Any, Callable, Dict, List, Optional, Union + +import pyarrow as pa +from typing_extensions import Literal + +from deltalake.writer import AddAction + +RawDeltaTable: Any +rust_core_version: Callable[[], str] +DeltaStorageFsBackend: Any + +class PyDeltaTableError(BaseException): ... + +write_new_deltalake: Callable[[str, pa.Schema, List[AddAction], str, List[str]], None] + +# Can't implement inheritance (see note in src/schema.rs), so this is next +# best thing. +DataType = Union["PrimitiveType", "MapType", "StructType", "ArrayType"] + +class PrimitiveType: + def __init__(self, data_type: str) -> None: ... + type: str + + def to_json(self) -> str: ... + @staticmethod + def from_json(json: str) -> "PrimitiveType": ... + def to_pyarrow(self) -> pa.DataType: ... + @staticmethod + def from_pyarrow(type: pa.DataType) -> "PrimitiveType": ... + +class ArrayType: + def __init__( + self, element_type: DataType, *, contains_null: bool = True + ) -> None: ... + type: Literal["array"] + element_type: DataType + contains_null: bool + + def to_json(self) -> str: ... + @staticmethod + def from_json(json: str) -> "ArrayType": ... + def to_pyarrow( + self, + ) -> pa.ListType: ... 
+ @staticmethod + def from_pyarrow(type: pa.ListType) -> "ArrayType": ... + +class MapType: + def __init__( + self, + key_type: DataType, + value_type: DataType, + *, + value_contains_null: bool = True + ) -> None: ... + type: Literal["map"] + key_type: DataType + value_type: DataType + value_contains_null: bool + + def to_json(self) -> str: ... + @staticmethod + def from_json(json: str) -> "MapType": ... + def to_pyarrow(self) -> pa.MapType: ... + @staticmethod + def from_pyarrow(type: pa.MapType) -> "MapType": ... + +class Field: + def __init__( + self, + name: str, + type: DataType, + *, + nullable: bool = True, + metadata: Optional[Dict[str, Any]] = None + ) -> None: ... + name: str + type: DataType + nullable: bool + metadata: Dict[str, Any] + + def to_json(self) -> str: ... + @staticmethod + def from_json(json: str) -> "Field": ... + def to_pyarrow(self) -> pa.Field: ... + @staticmethod + def from_pyarrow(type: pa.Field) -> "Field": ... + +class StructType: + def __init__(self, fields: List[Field]) -> None: ... + type: Literal["struct"] + fields: List[Field] + + def to_json(self) -> str: ... + @staticmethod + def from_json(json: str) -> "StructType": ... + def to_pyarrow(self) -> pa.StructType: ... + @staticmethod + def from_pyarrow(type: pa.StructType) -> "StructType": ... + +class Schema: + def __init__(self, fields: List[Field]) -> None: ... + fields: List[Field] + + def to_json(self) -> str: ... + @staticmethod + def from_json(json: str) -> "Schema": ... + def to_pyarrow(self) -> pa.Schema: ... + @staticmethod + def from_pyarrow(type: pa.Schema) -> "Schema": ... 
diff --git a/python/deltalake/schema.py b/python/deltalake/schema.py index 9363b6d097..233795583e 100644 --- a/python/deltalake/schema.py +++ b/python/deltalake/schema.py @@ -1,298 +1,7 @@ -import json -from dataclasses import dataclass -from typing import Any, Dict, List, Optional +from typing import Union -import pyarrow +from .deltalake import ArrayType, Field, MapType, PrimitiveType, Schema, StructType -# TODO: implement this module in Rust land to avoid JSON serialization -# https://github.com/delta-io/delta-rs/issues/95 - - -@dataclass -class DataType: - """ - Base class of all Delta data types. - """ - - type: str - - def __str__(self) -> str: - return f"DataType({self.type})" - - @classmethod - def from_dict(cls, json_dict: Dict[str, Any]) -> "DataType": - """ - Generate a DataType from a DataType in json format. - - :param json_dict: the data type in json format - :return: the Delta DataType - """ - type_class = json_dict["type"] - if type_class == "map": - key_type_dict = {"type": json_dict["keyType"]} - value_type_dict = {"type": json_dict["valueType"]} - value_contains_null = json_dict["valueContainsNull"] - key_type = cls.from_dict(json_dict=key_type_dict) - value_type = cls.from_dict(json_dict=value_type_dict) - return MapType( - key_type=key_type, - value_type=value_type, - value_contains_null=value_contains_null, - ) - if type_class == "array": - field = json_dict["elementType"] - if isinstance(field, str): - element_type = cls(field) - else: - element_type = cls.from_dict(json_dict=field) - return ArrayType( - element_type=element_type, - contains_null=json_dict["containsNull"], - ) - if type_class == "struct": - fields = [] - for json_field in json_dict["fields"]: - if isinstance(json_field["type"], str): - data_type = cls(json_field["type"]) - else: - data_type = cls.from_dict(json_field["type"]) - field = Field( - name=json_field["name"], - type=data_type, - nullable=json_field["nullable"], - metadata=json_field.get("metadata"), - ) - 
fields.append(field) - return StructType(fields=fields) - - return DataType(type_class) - - -@dataclass(init=False) -class MapType(DataType): - """Concrete class for map data types.""" - - key_type: DataType - value_type: DataType - value_contains_null: bool - type: str - - def __init__( - self, key_type: "DataType", value_type: "DataType", value_contains_null: bool - ): - super().__init__("map") - self.key_type = key_type - self.value_type = value_type - self.value_contains_null = value_contains_null - - def __str__(self) -> str: - return f"DataType(map<{self.key_type}, {self.value_type}, {self.value_contains_null}>)" - - -@dataclass(init=False) -class ArrayType(DataType): - """Concrete class for array data types.""" - - element_type: DataType - contains_null: bool - type: str - - def __init__(self, element_type: DataType, contains_null: bool): - super().__init__("array") - self.element_type = element_type - self.contains_null = contains_null - - def __str__(self) -> str: - return f"DataType(array<{self.element_type}> {self.contains_null})" - - -@dataclass(init=False) -class StructType(DataType): - """Concrete class for struct data types.""" - - fields: List["Field"] - type: str - - def __init__(self, fields: List["Field"]): - super().__init__("struct") - self.fields = fields - - def __str__(self) -> str: - field_strs = [str(f) for f in self.fields] - return f"DataType(struct<{', '.join(field_strs)}>)" - - -@dataclass -class Field: - """Create a DeltaTable Field instance.""" - - name: str - type: DataType - nullable: bool - metadata: Optional[Dict[str, str]] = None - - def __str__(self) -> str: - return f"Field({self.name}: {self.type} nullable({self.nullable}) metadata({self.metadata}))" - - -@dataclass -class Schema: - """Create a DeltaTable Schema instance.""" - - fields: List[Field] - json_value: Dict[str, Any] - - def __str__(self) -> str: - field_strs = [str(f) for f in self.fields] - return f"Schema({', '.join(field_strs)})" - - def json(self) -> Dict[str, 
Any]: - return self.json_value - - @classmethod - def from_json(cls, json_data: str) -> "Schema": - """ - Generate a DeltaTable Schema from a json format. - - :param json_data: the schema in json format - :return: the DeltaTable schema - """ - json_value = json.loads(json_data) - fields = [] - for json_field in json_value["fields"]: - if isinstance(json_field["type"], str): - data_type = DataType(json_field["type"]) - else: - data_type = DataType.from_dict(json_field["type"]) - field = Field( - name=json_field["name"], - type=data_type, - nullable=json_field["nullable"], - metadata=json_field.get("metadata"), - ) - fields.append(field) - return cls(fields=fields, json_value=json_value) - - -def pyarrow_datatype_from_dict(json_dict: Dict[str, Any]) -> pyarrow.DataType: - """ - Create a DataType in PyArrow format from a Schema json format. - - :param json_dict: the DataType in json format - :return: the DataType in PyArrow format - """ - type_class = json_dict["type"]["name"] - if type_class == "dictionary": - key_type = json_dict["dictionary"]["indexType"] - value_type = json_dict["children"][0] - key_type = pyarrow_datatype_from_dict(key_type) - value_type = pyarrow_datatype_from_dict(value_type) - return pyarrow.map_(key_type, value_type) - elif "dictionary" in json_dict: - key_type = { - "name": "key", - "type": json_dict["dictionary"]["indexType"], - "nullable": json_dict["nullable"], - } - key = pyarrow_datatype_from_dict(key_type) - if type_class == "list": - value_type = { - "name": "val", - "type": json_dict["dictionary"]["indexType"], - "nullable": json_dict["nullable"], - } - return pyarrow.map_( - key, - pyarrow.list_( - pyarrow.field( - "entries", pyarrow.struct([pyarrow_field_from_dict(value_type)]) - ) - ), - ) - value_type = { - "name": "value", - "type": json_dict["type"], - "nullable": json_dict["nullable"], - } - return pyarrow.map_(key, pyarrow_datatype_from_dict(value_type)) - elif type_class == "list": - field = json_dict["children"][0] - 
element_type = pyarrow_datatype_from_dict(field) - return pyarrow.list_(pyarrow.field("item", element_type)) - elif type_class == "struct": - fields = [pyarrow_field_from_dict(field) for field in json_dict["children"]] - return pyarrow.struct(fields) - elif type_class == "int": - return pyarrow.type_for_alias(f'{type_class}{json_dict["type"]["bitWidth"]}') - elif type_class == "date": - type_info = json_dict["type"] - if type_info["unit"] == "DAY": - return pyarrow.date32() - else: - return pyarrow.date64() - elif type_class == "time": - type_info = json_dict["type"] - if type_info["unit"] == "MICROSECOND": - unit = "us" - elif type_info["unit"] == "NANOSECOND": - unit = "ns" - elif type_info["unit"] == "MILLISECOND": - unit = "ms" - else: - unit = "s" - return pyarrow.type_for_alias(f'{type_class}{type_info["bitWidth"]}[{unit}]') - elif type_class == "timestamp": - type_info = json_dict["type"] - if "unit" in type_info: - if type_info["unit"] == "MICROSECOND": - unit = "us" - elif type_info["unit"] == "NANOSECOND": - unit = "ns" - elif type_info["unit"] == "MILLISECOND": - unit = "ms" - elif type_info["unit"] == "SECOND": - unit = "s" - else: - unit = "ns" - return pyarrow.type_for_alias(f"{type_class}[{unit}]") - elif type_class.startswith("decimal"): - type_info = json_dict["type"] - return pyarrow.decimal128( - precision=type_info["precision"], scale=type_info["scale"] - ) - elif type_class.startswith("floatingpoint"): - type_info = json_dict["type"] - if type_info["precision"] == "HALF": - return pyarrow.float16() - elif type_info["precision"] == "SINGLE": - return pyarrow.float32() - elif type_info["precision"] == "DOUBLE": - return pyarrow.float64() - else: - return pyarrow.type_for_alias(type_class) - - -def pyarrow_field_from_dict(field: Dict[str, Any]) -> pyarrow.Field: - """ - Create a Field in PyArrow format from a Field in json format. 
- :param field: the field in json format - :return: the Field in PyArrow format - """ - return pyarrow.field( - field["name"], - pyarrow_datatype_from_dict(field), - field["nullable"], - field.get("metadata"), - ) - - -def pyarrow_schema_from_json(json_data: str) -> pyarrow.Schema: - """ - Create a Schema in PyArrow format from a Schema in json format. - - :param json_data: the field in json format - :return: the Schema in PyArrow format - """ - schema_json = json.loads(json_data) - arrow_fields = [pyarrow_field_from_dict(field) for field in schema_json["fields"]] - return pyarrow.schema(arrow_fields) +# Can't implement inheritance (see note in src/schema.rs), so this is next +# best thing. +DataType = Union["PrimitiveType", "MapType", "StructType", "ArrayType"] diff --git a/python/deltalake/table.py b/python/deltalake/table.py index 62e1b91e3f..256e4ab794 100644 --- a/python/deltalake/table.py +++ b/python/deltalake/table.py @@ -13,7 +13,7 @@ from .data_catalog import DataCatalog from .deltalake import RawDeltaTable from .fs import DeltaStorageHandler -from .schema import Schema, pyarrow_schema_from_json +from .schema import Schema @dataclass(init=False) @@ -214,7 +214,7 @@ def schema(self) -> Schema: :return: the current Schema registered in the transaction log """ - return Schema.from_json(self._table.schema_json()) + return self._table.schema def metadata(self) -> Metadata: """ @@ -264,9 +264,16 @@ def pyarrow_schema(self) -> pyarrow.Schema: """ Get the current schema of the DeltaTable with the Parquet PyArrow format. + DEPRECATED: use DeltaTable.schema().to_pyarrow() instead. + :return: the current Schema with the Parquet PyArrow format """ - return pyarrow_schema_from_json(self._table.arrow_schema_json()) + warnings.warn( + "DeltaTable.pyarrow_schema() is deprecated. 
Use DeltaTable.schema().to_pyarrow() instead.", + category=DeprecationWarning, + stacklevel=2, + ) + return self.schema().to_pyarrow() def to_pyarrow_dataset( self, @@ -297,11 +304,13 @@ def to_pyarrow_dataset( partition_expression=part_expression, ) for file, part_expression in self._table.dataset_partitions( - partitions, self.pyarrow_schema() + partitions, self.schema().to_pyarrow() ) ] - return FileSystemDataset(fragments, self.pyarrow_schema(), format, filesystem) + return FileSystemDataset( + fragments, self.schema().to_pyarrow(), format, filesystem + ) def to_pyarrow_table( self, diff --git a/python/deltalake/writer.py b/python/deltalake/writer.py index 76fa493c1b..0ceca587b5 100644 --- a/python/deltalake/writer.py +++ b/python/deltalake/writer.py @@ -151,12 +151,12 @@ def write_deltalake( fs = DeltaStorageHandler(table_uri) if table: # already exists - if schema != table.pyarrow_schema() and not ( + if schema != table.schema().to_pyarrow() and not ( mode == "overwrite" and overwrite_schema ): raise ValueError( "Schema of data does not match table schema\n" - f"Table schema:\n{schema}\nData Schema:\n{table.pyarrow_schema()}" + f"Table schema:\n{schema}\nData Schema:\n{table.schema().to_pyarrow()}" ) if mode == "error": diff --git a/python/docs/source/api_reference.rst b/python/docs/source/api_reference.rst index 09659ebc10..70efd9e2ae 100644 --- a/python/docs/source/api_reference.rst +++ b/python/docs/source/api_reference.rst @@ -12,10 +12,27 @@ Writing DeltaTables .. autofunction:: deltalake.write_deltalake -DeltaSchema ------------ +Delta Lake Schemas +------------------ + +Schemas, fields, and data types are provided in the ``deltalake.schema`` submodule. + +.. autoclass:: deltalake.schema.Schema + :members: + +.. autoclass:: deltalake.schema.PrimitiveType + :members: + +.. autoclass:: deltalake.schema.ArrayType + :members: + +.. autoclass:: deltalake.schema.MapType + :members: + +.. autoclass:: deltalake.schema.Field + :members: -.. 
automodule:: deltalake.schema +.. autoclass:: deltalake.schema.StructType :members: DataCatalog diff --git a/python/docs/source/usage.rst b/python/docs/source/usage.rst index 99d8c25624..9d1563fe3c 100644 --- a/python/docs/source/usage.rst +++ b/python/docs/source/usage.rst @@ -126,21 +126,21 @@ Use :meth:`DeltaTable.schema` to retrieve the delta lake schema: >>> from deltalake import DeltaTable >>> dt = DeltaTable("../rust/tests/data/simple_table") >>> dt.schema() - Schema(Field(id: DataType(long) nullable(True) metadata({}))) + Schema([Field(id, PrimitiveType("long"), nullable=True)]) These schemas have a JSON representation that can be retrieved. To reconstruct -from json, use :meth:`deltalake.schema.Schema.from_json`. +from json, use :meth:`deltalake.schema.Schema.from_json()`. .. code-block:: python >>> dt.schema().json() - {'type': 'struct', 'fields': [{'name': 'id', 'type': 'long', 'nullable': True, 'metadata': {}}]} + '{"type":"struct","fields":[{"name":"id","type":"long","nullable":true,"metadata":{}}]}' -Use :meth:`DeltaTable.pyarrow_schema` to retrieve the PyArrow schema: +Use :meth:`deltalake.schema.Schema.to_pyarrow()` to retrieve the PyArrow schema: .. code-block:: python - >>> dt.pyarrow_schema() + >>> dt.schema().to_pyarrow() id: int64 @@ -194,7 +194,7 @@ support filtering partitions and selecting particular columns. 
>>> from deltalake import DeltaTable >>> dt = DeltaTable("../rust/tests/data/delta-0.8.0-partitioned") - >>> dt.dt.pyarrow_schema() + >>> dt.schema().to_pyarrow() value: string year: string month: string diff --git a/python/pyproject.toml b/python/pyproject.toml index 16e0b1dc0c..f26ed4d936 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -63,7 +63,7 @@ no_implicit_optional = true warn_redundant_casts = true warn_unused_ignores = true warn_return_any = false -implicit_reexport = false +implicit_reexport = true strict_equality = true [tool.isort] diff --git a/python/src/lib.rs b/python/src/lib.rs index a5987c387d..ef011238da 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -2,6 +2,8 @@ extern crate pyo3; +pub mod schema; + use chrono::{DateTime, FixedOffset, Utc}; use deltalake::action; use deltalake::action::Action; @@ -28,6 +30,8 @@ use std::sync::Arc; use std::time::SystemTime; use std::time::UNIX_EPOCH; +use crate::schema::schema_to_pyobject; + create_exception!(deltalake, PyDeltaTableError, PyException); impl PyDeltaTableError { @@ -220,13 +224,13 @@ impl RawDeltaTable { Ok(self._table.get_file_uris().collect()) } - pub fn schema_json(&self) -> PyResult { - let schema = self + #[getter] + pub fn schema(&self, py: Python) -> PyResult { + let schema: &Schema = self ._table .get_schema() .map_err(PyDeltaTableError::from_raw)?; - serde_json::to_string(&schema) - .map_err(|_| PyDeltaTableError::new_err("Got invalid table schema")) + schema_to_pyobject(schema, py) } /// Run the Vacuum command on the Delta Table: list and delete files no longer referenced by the Delta table and are older than the retention threshold. 
@@ -616,5 +620,13 @@ fn deltalake(py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add("PyDeltaTableError", py.get_type::())?; + // There are issues with submodules, so we will expose them flat for now + // See also: https://github.com/PyO3/pyo3/issues/759 + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; Ok(()) } diff --git a/python/src/schema.rs b/python/src/schema.rs new file mode 100644 index 0000000000..580f52d03f --- /dev/null +++ b/python/src/schema.rs @@ -0,0 +1,1067 @@ +extern crate pyo3; + +use crate::pyo3::types::IntoPyDict; +use deltalake::arrow::datatypes::{ + DataType as ArrowDataType, Field as ArrowField, Schema as ArrowSchema, +}; +use deltalake::arrow::error::ArrowError; +use deltalake::schema::{ + Schema, SchemaDataType, SchemaField, SchemaTypeArray, SchemaTypeMap, SchemaTypeStruct, +}; +use lazy_static::lazy_static; +use pyo3::exceptions::{PyException, PyNotImplementedError, PyTypeError, PyValueError}; +use pyo3::prelude::*; +use pyo3::{PyRef, PyResult}; +use regex::Regex; +use std::collections::HashMap; + +// PyO3 doesn't yet support converting classes with inheritance with Python +// objects within Rust code, which we need here. So for now, we implement +// the types with no inheritance. Later, we may add inheritance. +// See: https://github.com/PyO3/pyo3/issues/1836 + +// Decimal is separate special case, since it has parameters +const VALID_PRIMITIVE_TYPES: [&str; 11] = [ + "string", + "long", + "integer", + "short", + "byte", + "float", + "double", + "boolean", + "binary", + "date", + "timestamp", +]; + +fn try_parse_decimal_type(data_type: &str) -> Option<(usize, usize)> { + lazy_static! 
{ + static ref DECIMAL_REGEX: Regex = Regex::new(r"\((\d{1,2}),(\d{1,2})\)").unwrap(); + } + let extract = DECIMAL_REGEX.captures(data_type)?; + let precision = extract + .get(1) + .and_then(|v| v.as_str().parse::().ok())?; + let scale = extract + .get(2) + .and_then(|v| v.as_str().parse::().ok())?; + Some((precision, scale)) +} + +fn schema_type_to_python(schema_type: SchemaDataType, py: Python) -> PyResult { + match schema_type { + SchemaDataType::primitive(data_type) => Ok((PrimitiveType::new(data_type)?).into_py(py)), + SchemaDataType::array(array_type) => { + let array_type: ArrayType = array_type.into(); + Ok(array_type.into_py(py)) + } + SchemaDataType::map(map_type) => { + let map_type: MapType = map_type.into(); + Ok(map_type.into_py(py)) + } + SchemaDataType::r#struct(struct_type) => { + let struct_type: StructType = struct_type.into(); + Ok(struct_type.into_py(py)) + } + } +} + +fn python_type_to_schema(ob: PyObject, py: Python) -> PyResult { + if let Ok(data_type) = ob.extract::(py) { + return Ok(SchemaDataType::primitive(data_type.inner_type)); + } + if let Ok(array_type) = ob.extract::(py) { + return Ok(array_type.into()); + } + if let Ok(map_type) = ob.extract::(py) { + return Ok(map_type.into()); + } + if let Ok(struct_type) = ob.extract::(py) { + return Ok(struct_type.into()); + } + if let Ok(raw_primitive) = ob.extract::(py) { + // Pass through PrimitiveType::new() to do validation + return PrimitiveType::new(raw_primitive) + .map(|data_type| SchemaDataType::primitive(data_type.inner_type)); + } + Err(PyValueError::new_err("Invalid data type")) +} + +/// A primitive datatype, such as a string or number. 
+/// +/// Can be initialized with a string value: +/// +/// >>> PrimitiveType("integer") +/// PrimitiveType("integer") +/// +/// Valid primitive data types include: +/// +/// * "string", +/// * "long", +/// * "integer", +/// * "short", +/// * "byte", +/// * "float", +/// * "double", +/// * "boolean", +/// * "binary", +/// * "date", +/// * "timestamp", +/// * "decimal(<precision>, <scale>)" +/// +/// :param data_type: string representation of the data type +#[pyclass(module = "deltalake.schema", text_signature = "(data_type)")] +#[derive(Clone)] +pub struct PrimitiveType { + inner_type: String, +} + +impl TryFrom for PrimitiveType { + type Error = PyErr; + fn try_from(value: SchemaDataType) -> PyResult { + match value { + SchemaDataType::primitive(type_name) => Self::new(type_name), + _ => Err(PyTypeError::new_err("Type is not primitive")), + } + } +} + +#[pymethods] +impl PrimitiveType { + #[new] + fn new(data_type: String) -> PyResult { + if data_type.starts_with("decimal") { + if try_parse_decimal_type(&data_type).is_none() { + Err(PyValueError::new_err(format!( + "invalid decimal type: {}", + data_type + ))) + } else { + Ok(Self { + inner_type: data_type, + }) + } + } else if !VALID_PRIMITIVE_TYPES + .iter() + .any(|&valid| data_type == valid) + { + Err(PyValueError::new_err(format!( + "data_type must be one of decimal(, ), {}.", + VALID_PRIMITIVE_TYPES.join(", ") + ))) + } else { + Ok(Self { + inner_type: data_type, + }) + } + } + + /// The inner type + /// + /// :rtype: str + #[getter] + fn get_type(&self) -> PyResult { + Ok(self.inner_type.clone()) + } + + fn __richcmp__(&self, other: PrimitiveType, cmp: pyo3::basic::CompareOp) -> PyResult { + match cmp { + pyo3::basic::CompareOp::Eq => Ok(self.inner_type == other.inner_type), + pyo3::basic::CompareOp::Ne => Ok(self.inner_type != other.inner_type), + _ => Err(PyNotImplementedError::new_err( + "Only == and != are supported.", + )), + } + } + + fn __repr__(&self) -> PyResult { + Ok(format!("PrimitiveType(\"{}\")", 
&self.inner_type)) + } + + /// Get the JSON string representation of the type. + /// + /// :rtype: str + #[pyo3(text_signature = "($self)")] + fn to_json(&self) -> PyResult { + let inner_type = SchemaDataType::primitive(self.inner_type.clone()); + serde_json::to_string(&inner_type).map_err(|err| PyException::new_err(err.to_string())) + } + + /// Create a PrimitiveType from a JSON string + /// + /// The JSON representation for a primitive type is just a quoted string: + /// + /// >>> PrimitiveType.from_json('"integer"') + /// PrimitiveType("integer") + /// + /// :param type_json: A JSON string + /// :type type_json: str + /// :rtype: PrimitiveType + #[staticmethod] + #[pyo3(text_signature = "(type_json)")] + fn from_json(type_json: String) -> PyResult { + let data_type: SchemaDataType = serde_json::from_str(&type_json) + .map_err(|err| PyValueError::new_err(err.to_string()))?; + + data_type.try_into() + } + + /// Get the equivalent PyArrow type. + /// + /// :rtype: pyarrow.DataType + #[pyo3(text_signature = "($self)")] + fn to_pyarrow(&self) -> PyResult { + let inner_type = SchemaDataType::primitive(self.inner_type.clone()); + (&inner_type) + .try_into() + .map_err(|err: ArrowError| PyException::new_err(err.to_string())) + } + + /// Create a PrimitiveType from a PyArrow type + /// + /// Will raise ``TypeError`` if the PyArrow type is not a primitive type. 
+ /// + /// :param data_type: A PyArrow DataType + /// :type data_type: pyarrow.DataType + /// :rtype: PrimitiveType + #[pyo3(text_signature = "(data_type)")] + #[staticmethod] + fn from_pyarrow(data_type: ArrowDataType) -> PyResult { + let inner_type: SchemaDataType = (&data_type) + .try_into() + .map_err(|err: ArrowError| PyException::new_err(err.to_string()))?; + + inner_type.try_into() + } +} + +/// An Array (List) DataType +/// +/// Can either pass the element type explicitly or can pass a string +/// if it is a primitive type: +/// +/// >>> ArrayType(PrimitiveType("integer")) +/// ArrayType(PrimitiveType("integer"), contains_null=True) +/// >>> ArrayType("integer", contains_null=False) +/// ArrayType(PrimitiveType("integer"), contains_null=False) +#[pyclass( + module = "deltalake.schema", + text_signature = "(element_type, contains_null=True)" +)] +#[derive(Clone)] +pub struct ArrayType { + inner_type: SchemaTypeArray, +} + +impl From for ArrayType { + fn from(inner_type: SchemaTypeArray) -> Self { + Self { inner_type } + } +} + +impl From for SchemaDataType { + fn from(arr: ArrayType) -> SchemaDataType { + SchemaDataType::array(arr.inner_type) + } +} + +impl TryFrom for ArrayType { + type Error = PyErr; + fn try_from(value: SchemaDataType) -> PyResult { + match value { + SchemaDataType::array(inner_type) => Ok(Self { inner_type }), + _ => Err(PyTypeError::new_err("Type is not an array")), + } + } +} + +#[pymethods] +impl ArrayType { + #[new] + #[args(contains_null = true)] + fn new(element_type: PyObject, contains_null: bool, py: Python) -> PyResult { + let inner_type = SchemaTypeArray::new( + Box::new(python_type_to_schema(element_type, py)?), + contains_null, + ); + Ok(Self { inner_type }) + } + + fn __repr__(&self, py: Python) -> PyResult { + let type_repr: String = + schema_type_to_python(self.inner_type.get_element_type().clone(), py)? + .call_method0(py, "__repr__")? 
+ .extract(py)?; + Ok(format!( + "ArrayType({}, contains_null={})", + type_repr, + if self.inner_type.contains_null() { + "True" + } else { + "False" + }, + )) + } + + fn __richcmp__(&self, other: ArrayType, cmp: pyo3::basic::CompareOp) -> PyResult { + match cmp { + pyo3::basic::CompareOp::Eq => Ok(self.inner_type == other.inner_type), + pyo3::basic::CompareOp::Ne => Ok(self.inner_type != other.inner_type), + _ => Err(PyNotImplementedError::new_err( + "Only == and != are supported.", + )), + } + } + + /// The string "array" + /// + /// :rtype: str + #[getter] + fn get_type(&self) -> String { + "array".to_string() + } + + /// The type of the element + /// + /// :rtype: Union[PrimitiveType, ArrayType, MapType, StructType] + #[getter] + fn element_type(&self, py: Python) -> PyResult { + schema_type_to_python(self.inner_type.get_element_type().to_owned(), py) + } + + /// Whether the arrays may contain null values + /// + /// :rtype: bool + #[getter] + fn contains_null(&self, py: Python) -> PyResult { + Ok(self.inner_type.contains_null().into_py(py)) + } + + /// Get the JSON string representation of the type. + /// + /// :rtype: str + #[pyo3(text_signature = "($self)")] + fn to_json(&self) -> PyResult { + serde_json::to_string(&self.inner_type).map_err(|err| PyException::new_err(err.to_string())) + } + + /// Create an ArrayType from a JSON string + /// + /// The JSON representation for an array type is an object with ``type`` (set to + /// ``"array"``), ``elementType``, and ``containsNull``: + /// + /// >>> ArrayType.from_json("""{ + /// ... "type": "array", + /// ... "elementType": "integer", + /// ... "containsNull": false + /// ... 
}""") + /// ArrayType(PrimitiveType("integer"), contains_null=False) + /// + /// :param type_json: A JSON string + /// :type type_json: str + /// :rtype: ArrayType + #[staticmethod] + #[pyo3(text_signature = "(type_json)")] + fn from_json(type_json: String) -> PyResult { + let data_type: SchemaDataType = serde_json::from_str(&type_json) + .map_err(|err| PyValueError::new_err(err.to_string()))?; + + data_type.try_into() + } + + /// Get the equivalent PyArrow type. + /// + /// :rtype: pyarrow.DataType + #[pyo3(text_signature = "($self)")] + fn to_pyarrow(&self) -> PyResult { + (&SchemaDataType::array(self.inner_type.clone())) + .try_into() + .map_err(|err: ArrowError| PyException::new_err(err.to_string())) + } + + /// Create an ArrayType from a pyarrow.ListType. + /// + /// Will raise ``TypeError`` if a different PyArrow DataType is provided. + /// + /// :param data_type: The PyArrow datatype + /// :type data_type: pyarrow.ListType + /// :rtype: ArrayType + #[staticmethod] + #[pyo3(text_signature = "(data_type)")] + fn from_pyarrow(data_type: ArrowDataType) -> PyResult { + let inner_type: SchemaDataType = (&data_type) + .try_into() + .map_err(|err: ArrowError| PyException::new_err(err.to_string()))?; + + inner_type.try_into() + } +} + +/// A map data type +/// +/// ``key_type`` and ``value_type`` should be :class:`PrimitiveType`, :class:`ArrayType`, +/// :class:`MapType`, or :class:`StructType`. 
A string can also be passed, which will be +/// parsed as a primitive type: +/// +/// >>> MapType(PrimitiveType("integer"), PrimitiveType("string")) +/// MapType(PrimitiveType("integer"), PrimitiveType("string"), value_contains_null=True) +/// >>> MapType("integer", "string", value_contains_null=False) +/// MapType(PrimitiveType("integer"), PrimitiveType("string"), value_contains_null=False) +#[pyclass( + module = "deltalake.schema", + text_signature = "(key_type, value_type, value_contains_null=True)" +)] +#[derive(Clone)] +pub struct MapType { + inner_type: SchemaTypeMap, +} + +impl From for MapType { + fn from(inner_type: SchemaTypeMap) -> Self { + Self { inner_type } + } +} + +impl From for SchemaDataType { + fn from(map: MapType) -> SchemaDataType { + SchemaDataType::map(map.inner_type) + } +} + +impl TryFrom for MapType { + type Error = PyErr; + fn try_from(value: SchemaDataType) -> PyResult { + match value { + SchemaDataType::map(inner_type) => Ok(Self { inner_type }), + _ => Err(PyTypeError::new_err("Type is not a map")), + } + } +} + +#[pymethods] +impl MapType { + #[new] + #[args(value_contains_null = true)] + fn new( + key_type: PyObject, + value_type: PyObject, + value_contains_null: bool, + py: Python, + ) -> PyResult { + let inner_type = SchemaTypeMap::new( + Box::new(python_type_to_schema(key_type, py)?), + Box::new(python_type_to_schema(value_type, py)?), + value_contains_null, + ); + Ok(Self { inner_type }) + } + + fn __repr__(&self, py: Python) -> PyResult { + let key_repr: String = schema_type_to_python(self.inner_type.get_key_type().clone(), py)? + .call_method0(py, "__repr__")? + .extract(py)?; + let value_repr: String = + schema_type_to_python(self.inner_type.get_value_type().clone(), py)? + .call_method0(py, "__repr__")? 
+ .extract(py)?; + Ok(format!( + "MapType({}, {}, value_contains_null={})", + key_repr, + value_repr, + if self.inner_type.get_value_contains_null() { + "True" + } else { + "False" + } + )) + } + + fn __richcmp__(&self, other: MapType, cmp: pyo3::basic::CompareOp) -> PyResult { + match cmp { + pyo3::basic::CompareOp::Eq => Ok(self.inner_type == other.inner_type), + pyo3::basic::CompareOp::Ne => Ok(self.inner_type != other.inner_type), + _ => Err(PyNotImplementedError::new_err( + "Only == and != are supported.", + )), + } + } + + /// The string "map" + /// + /// :rtype: str + #[getter] + fn get_type(&self) -> String { + "map".to_string() + } + + /// The type of the keys + /// + /// :rtype: Union[PrimitiveType, ArrayType, MapType, StructType] + #[getter] + fn key_type(&self, py: Python) -> PyResult { + schema_type_to_python(self.inner_type.get_key_type().to_owned(), py) + } + + /// The type of the values + /// + /// :rtype: Union[PrimitiveType, ArrayType, MapType, StructType] + #[getter] + fn value_type(&self, py: Python) -> PyResult { + schema_type_to_python(self.inner_type.get_value_type().to_owned(), py) + } + + /// Whether the values in a map may be null + /// + /// :rtype: bool + #[getter] + fn value_contains_null(&self, py: Python) -> PyResult { + Ok(self.inner_type.get_value_contains_null().into_py(py)) + } + + /// Get JSON string representation of map type. + /// + /// :rtype: str + #[pyo3(text_signature = "($self)")] + fn to_json(&self) -> PyResult { + serde_json::to_string(&self.inner_type).map_err(|err| PyException::new_err(err.to_string())) + } + + /// Create a MapType from a JSON string + /// + /// The JSON representation for a map type is an object with ``type`` (set to ``map``), + /// ``keyType``, ``valueType``, and ``valueContainsNull``: + /// + /// >>> MapType.from_json("""{ + /// ... "type": "map", + /// ... "keyType": "integer", + /// ... "valueType": "string", + /// ... "valueContainsNull": true + /// ... 
}""") + /// MapType(PrimitiveType("integer"), PrimitiveType("string"), value_contains_null=True) + /// + /// :param type_json: A JSON string + /// :type type_json: str + /// :rtype: MapType + #[staticmethod] + #[pyo3(text_signature = "(type_json)")] + fn from_json(type_json: String) -> PyResult { + let data_type: SchemaDataType = serde_json::from_str(&type_json) + .map_err(|err| PyValueError::new_err(err.to_string()))?; + + data_type.try_into() + } + + /// Get the equivalent PyArrow data type. + /// + /// :rtype: pyarrow.MapType + #[pyo3(text_signature = "($self)")] + fn to_pyarrow(&self) -> PyResult { + (&SchemaDataType::map(self.inner_type.clone())) + .try_into() + .map_err(|err: ArrowError| PyException::new_err(err.to_string())) + } + + /// Create a MapType from a PyArrow MapType. + /// + /// Will raise ``TypeError`` if passed a different type. + /// + /// :param data_type: the PyArrow MapType + /// :type data_type: pyarrow.MapType + /// :rtype: MapType + #[staticmethod] + #[pyo3(text_signature = "(data_type)")] + fn from_pyarrow(data_type: ArrowDataType) -> PyResult { + let inner_type: SchemaDataType = (&data_type) + .try_into() + .map_err(|err: ArrowError| PyException::new_err(err.to_string()))?; + + inner_type.try_into() + } +} + +/// A field in a Delta StructType or Schema +/// +/// Can create with just a name and a type: +/// +/// >>> Field("my_int_col", "integer") +/// Field(my_int_col, PrimitiveType("integer"), nullable=True) +/// +/// Can also attach metadata to the field.
Metadata should be a dictionary with +/// string keys and JSON-serializable values (str, list, int, float, dict): +/// +/// >>> Field("my_col", "integer", metadata={"custom_metadata": {"test": 2}}) +/// Field(my_col, PrimitiveType("integer"), nullable=True, metadata={'custom_metadata': {'test': 2}}) +#[pyclass( + module = "deltalake.schema", + text_signature = "(name, ty, nullable=True, metadata=None)" +)] +#[derive(Clone)] +pub struct Field { + inner: SchemaField, +} + +#[pymethods] +impl Field { + #[new] + #[args(nullable = true)] + fn new( + name: String, + ty: PyObject, + nullable: bool, + metadata: Option, + py: Python, + ) -> PyResult { + let ty = python_type_to_schema(ty, py)?; + + // Serialize and de-serialize JSON (it needs to be valid JSON anyways) + let metadata: HashMap = if let Some(ref json) = metadata { + let json_dumps = PyModule::import(py, "json")?.getattr("dumps")?; + let metadata_json: String = json_dumps.call1((json,))?.extract()?; + let metadata_json = Some(metadata_json) + .filter(|x| x != "null") + .unwrap_or_else(|| "{}".to_string()); + serde_json::from_str(&metadata_json) + .map_err(|err| PyValueError::new_err(err.to_string()))?
+ } else { + HashMap::new() + }; + + Ok(Self { + inner: SchemaField::new(name, ty, nullable, metadata), + }) + } + + /// The name of the field + /// + /// :rtype: str + #[getter] + fn name(&self) -> String { + self.inner.get_name().to_string() + } + + /// The type of the field + /// + /// :rtype: Union[PrimitiveType, ArrayType, MapType, StructType] + #[getter] + fn get_type(&self, py: Python) -> PyResult { + schema_type_to_python(self.inner.get_type().clone(), py) + } + + /// Whether there may be null values in the field + /// + /// :rtype: bool + #[getter] + fn nullable(&self) -> bool { + self.inner.is_nullable() + } + + /// The metadata of the field + /// + /// :rtype: dict + #[getter] + fn metadata(&self, py: Python) -> PyResult { + let json_loads = PyModule::import(py, "json")?.getattr("loads")?; + let metadata_json: String = serde_json::to_string(self.inner.get_metadata()) + .map_err(|err| PyValueError::new_err(err.to_string()))?; + Ok(json_loads.call1((metadata_json,))?.to_object(py)) + } + + fn __repr__(&self, py: Python) -> PyResult { + let type_repr: String = schema_type_to_python(self.inner.get_type().clone(), py)? + .call_method0(py, "__repr__")? + .extract(py)?; + + let metadata = self.inner.get_metadata(); + let maybe_metadata = if metadata.is_empty() { + "".to_string() + } else { + let metadata_repr: String = self + .metadata(py)? + .call_method0(py, "__repr__")? + .extract(py)?; + format!(", metadata={}", metadata_repr) + }; + Ok(format!( + "Field({}, {}, nullable={}{})", + self.inner.get_name(), + type_repr, + if self.inner.is_nullable() { + "True" + } else { + "False" + }, + maybe_metadata, + )) + } + + fn __richcmp__(&self, other: Field, cmp: pyo3::basic::CompareOp) -> PyResult { + match cmp { + pyo3::basic::CompareOp::Eq => Ok(self.inner == other.inner), + pyo3::basic::CompareOp::Ne => Ok(self.inner != other.inner), + _ => Err(PyNotImplementedError::new_err( + "Only == and != are supported.", + )), + } + } + + /// Get the field as JSON string. 
+ /// + /// >>> Field("col", "integer").to_json() + /// '{"name":"col","type":"integer","nullable":true,"metadata":{}}' + /// + /// :rtype: str + #[pyo3(text_signature = "($self)")] + fn to_json(&self) -> PyResult { + serde_json::to_string(&self.inner).map_err(|err| PyException::new_err(err.to_string())) + } + + /// Create a Field from a JSON string. + /// + /// >>> Field.from_json("""{ + /// ... "name": "col", + /// ... "type": "integer", + /// ... "nullable": true, + /// ... "metadata": {} + /// ... }""") + /// Field(col, PrimitiveType("integer"), nullable=True) + /// + /// :param field_json: the JSON string. + /// :type field_json: str + /// :rtype: Field + #[staticmethod] + #[pyo3(text_signature = "(field_json)")] + fn from_json(field_json: String) -> PyResult { + let field: SchemaField = serde_json::from_str(&field_json) + .map_err(|err| PyValueError::new_err(err.to_string()))?; + + Ok(Self { inner: field }) + } + + /// Convert to an equivalent PyArrow field + /// + /// Note: This currently doesn't preserve field metadata. + /// + /// :rtype: pyarrow.Field + #[pyo3(text_signature = "($self)")] + fn to_pyarrow(&self) -> PyResult { + (&self.inner) + .try_into() + .map_err(|err: ArrowError| PyException::new_err(err.to_string())) + } + + /// Create a Field from a PyArrow field + /// + /// Note: This currently doesn't preserve field metadata. 
+ /// + /// :param field: a field + /// :type field: pyarrow.Field + /// :rtype: Field + #[staticmethod] + #[pyo3(text_signature = "(field)")] + fn from_pyarrow(field: ArrowField) -> PyResult { + Ok(Self { + inner: SchemaField::try_from(&field) + .map_err(|err: ArrowError| PyException::new_err(err.to_string()))?, + }) + } +} + +/// A struct datatype, containing one or more subfields +/// +/// Create with a list of :class:`Field`: +/// +/// >>> StructType([Field("x", "integer"), Field("y", "string")]) +/// StructType([Field(x, PrimitiveType("integer"), nullable=True), Field(y, PrimitiveType("string"), nullable=True)]) +#[pyclass(subclass, module = "deltalake.schema", text_signature = "(fields)")] +#[derive(Clone)] +pub struct StructType { + inner_type: SchemaTypeStruct, +} + +impl From for StructType { + fn from(inner_type: SchemaTypeStruct) -> Self { + Self { inner_type } + } +} + +impl From for SchemaDataType { + fn from(str: StructType) -> SchemaDataType { + SchemaDataType::r#struct(str.inner_type) + } +} + +impl TryFrom for StructType { + type Error = PyErr; + fn try_from(value: SchemaDataType) -> PyResult { + match value { + SchemaDataType::r#struct(inner_type) => Ok(Self { inner_type }), + _ => Err(PyTypeError::new_err("Type is not a struct")), + } + } +} +#[pymethods] +impl StructType { + #[new] + fn new(fields: Vec>) -> Self { + let fields: Vec = fields + .into_iter() + .map(|field| field.inner.clone()) + .collect(); + let inner_type = SchemaTypeStruct::new(fields); + Self { inner_type } + } + + fn __repr__(&self, py: Python) -> PyResult { + let inner_data: Vec = self + .inner_type + .get_fields() + .iter() + .map(|field| { + let field = Field { + inner: field.clone(), + }; + field.__repr__(py) + }) + .collect::>()?; + Ok(format!("StructType([{}])", inner_data.join(", "))) + } + + fn __richcmp__(&self, other: StructType, cmp: pyo3::basic::CompareOp) -> PyResult { + match cmp { + pyo3::basic::CompareOp::Eq => Ok(self.inner_type == other.inner_type), +
pyo3::basic::CompareOp::Ne => Ok(self.inner_type != other.inner_type), + _ => Err(PyNotImplementedError::new_err( + "Only == and != are supported.", + )), + } + } + + /// The string "struct" + #[getter] + fn get_type(&self) -> String { + "struct".to_string() + } + + /// The fields within the struct + /// + /// :rtype: List[Field] + #[getter] + fn fields(&self) -> Vec { + self.inner_type + .get_fields() + .iter() + .map(|field| Field { + inner: field.clone(), + }) + .collect::>() + } + + /// Get the JSON representation of the type. + /// + /// >>> StructType([Field("x", "integer")]).to_json() + /// '{"type":"struct","fields":[{"name":"x","type":"integer","nullable":true,"metadata":{}}]}' + /// + /// :rtype: str + #[pyo3(text_signature = "($self)")] + fn to_json(&self) -> PyResult { + serde_json::to_string(&self.inner_type).map_err(|err| PyException::new_err(err.to_string())) + } + + /// Create a new StructType from a JSON string. + /// + /// >>> StructType.from_json("""{ + /// ... "type": "struct", + /// ... "fields": [{"name": "x", "type": "integer", "nullable": true, "metadata": {}}] + /// ... }""") + /// StructType([Field(x, PrimitiveType("integer"), nullable=True)]) + /// + /// :param type_json: a JSON string + /// :type type_json: str + /// :rtype: StructType + #[staticmethod] + #[pyo3(text_signature = "(type_json)")] + fn from_json(type_json: String) -> PyResult { + let data_type: SchemaDataType = serde_json::from_str(&type_json) + .map_err(|err| PyValueError::new_err(err.to_string()))?; + + data_type.try_into() + } + + /// Get the equivalent PyArrow StructType + /// + /// :rtype: pyarrow.StructType + #[pyo3(text_signature = "($self)")] + fn to_pyarrow(&self) -> PyResult { + (&SchemaDataType::r#struct(self.inner_type.clone())) + .try_into() + .map_err(|err: ArrowError| PyException::new_err(err.to_string())) + } + + /// Create a new StructType from a PyArrow struct type. + /// + /// Will raise ``TypeError`` if a different data type is provided. 
+ /// + /// :param data_type: a PyArrow struct type. + /// :type data_type: pyarrow.StructType + /// :rtype: StructType + #[staticmethod] + #[pyo3(text_signature = "(data_type)")] + fn from_pyarrow(data_type: ArrowDataType) -> PyResult { + let inner_type: SchemaDataType = (&data_type) + .try_into() + .map_err(|err: ArrowError| PyException::new_err(err.to_string()))?; + + inner_type.try_into() + } +} + +pub fn schema_to_pyobject(schema: &Schema, py: Python) -> PyResult { + let fields: Vec = schema + .get_fields() + .iter() + .map(|field| Field { + inner: field.clone(), + }) + .collect(); + + let py_schema = PyModule::import(py, "deltalake.schema")?.getattr("Schema")?; + + py_schema + .call1((fields,)) + .map(|schema| schema.to_object(py)) +} + +/// A Delta Lake schema +/// +/// Create using a list of :class:`Field`: +/// +/// >>> Schema([Field("x", "integer"), Field("y", "string")]) +/// Schema([Field(x, PrimitiveType("integer"), nullable=True), Field(y, PrimitiveType("string"), nullable=True)]) +/// +/// Or create from a PyArrow schema: +/// +/// >>> import pyarrow as pa +/// >>> Schema.from_pyarrow(pa.schema({"x": pa.int32(), "y": pa.string()})) +/// Schema([Field(x, PrimitiveType("integer"), nullable=True), Field(y, PrimitiveType("string"), nullable=True)]) +#[pyclass(extends=StructType, name="Schema", module="deltalake.schema", +text_signature = "(fields)")] +pub struct PySchema; + +#[pymethods] +impl PySchema { + #[new] + fn new(fields: Vec>) -> PyResult<(Self, StructType)> { + let fields: Vec = fields + .into_iter() + .map(|field| field.inner.clone()) + .collect(); + let inner_type = SchemaTypeStruct::new(fields); + Ok((Self {}, StructType { inner_type })) + } + + fn __repr__(self_: PyRef<'_, Self>, py: Python) -> PyResult { + let super_ = self_.as_ref(); + let inner_data: Vec = super_ + .inner_type + .get_fields() + .iter() + .map(|field| { + let field = Field { + inner: field.clone(), + }; + field.__repr__(py) + }) + .collect::>()?; + 
Ok(format!("Schema([{}])", inner_data.join(", "))) + } + + /// DEPRECATED: Convert to JSON dictionary representation + fn json(self_: PyRef<'_, Self>, py: Python) -> PyResult { + let warnings_warn = PyModule::import(py, "warnings")?.getattr("warn")?; + let deprecation_warning = PyModule::import(py, "builtins")? + .getattr("DeprecationWarning")? + .to_object(py); + let kwargs: [(&str, PyObject); 2] = [ + ("category", deprecation_warning), + ("stacklevel", 2.to_object(py)), + ]; + warnings_warn.call( + ("Schema.json() is deprecated. Use json.loads(Schema.to_json()) instead.",), + Some(kwargs.into_py_dict(py)), + )?; + + let super_ = self_.as_ref(); + let json = super_.to_json()?; + let json_loads = PyModule::import(py, "json")?.getattr("loads")?; + json_loads + .call1((json.into_py(py),)) + .map(|obj| obj.to_object(py)) + } + + /// Return equivalent PyArrow schema + /// + /// :rtype: pyarrow.Schema + #[pyo3(text_signature = "($self)")] + fn to_pyarrow(self_: PyRef<'_, Self>) -> PyResult { + let super_ = self_.as_ref(); + (&super_.inner_type.clone()) + .try_into() + .map_err(|err: ArrowError| PyException::new_err(err.to_string())) + } + + /// Create from a PyArrow schema + /// + /// :param data_type: a PyArrow schema + /// :type data_type: pyarrow.Schema + /// :rtype: Schema + #[staticmethod] + #[pyo3(text_signature = "(data_type)")] + fn from_pyarrow(data_type: ArrowSchema, py: Python) -> PyResult { + let inner_type: SchemaTypeStruct = (&data_type) + .try_into() + .map_err(|err: ArrowError| PyException::new_err(err.to_string()))?; + + schema_to_pyobject(&inner_type, py) + } + + /// Get the JSON representation of the schema. + /// + /// A schema has the same JSON format as a StructType. 
+ /// + /// >>> Schema([Field("x", "integer")]).to_json() + /// '{"type":"struct","fields":[{"name":"x","type":"integer","nullable":true,"metadata":{}}]}' + /// + /// :rtype: str + #[pyo3(text_signature = "($self)")] + fn to_json(self_: PyRef<'_, Self>) -> PyResult { + let super_ = self_.as_ref(); + super_.to_json() + } + + /// Create a new Schema from a JSON string. + /// + /// A schema has the same JSON format as a StructType. + /// + /// >>> Schema.from_json("""{ + /// ... "type": "struct", + /// ... "fields": [{"name": "x", "type": "integer", "nullable": true, "metadata": {}}] + /// ... }""") + /// Schema([Field(x, PrimitiveType("integer"), nullable=True)]) + /// + /// :param schema_json: a JSON string + /// :type schema_json: str + /// :rtype: Schema + #[staticmethod] + #[pyo3(text_signature = "(schema_json)")] + fn from_json(schema_json: String, py: Python) -> PyResult> { + let data_type: SchemaDataType = serde_json::from_str(&schema_json) + .map_err(|err| PyValueError::new_err(err.to_string()))?; + + if let SchemaDataType::r#struct(inner_type) = data_type { + Py::new(py, (Self {}, StructType { inner_type })) + } else { + Err(PyTypeError::new_err("Type is not a struct")) + } + } +} diff --git a/python/stubs/deltalake/__init__.pyi b/python/stubs/deltalake/__init__.pyi deleted file mode 100644 index e994c5f3a6..0000000000 --- a/python/stubs/deltalake/__init__.pyi +++ /dev/null @@ -1,3 +0,0 @@ -from typing import Any - -RawDeltaTableMetadata: Any diff --git a/python/stubs/deltalake/deltalake.pyi b/python/stubs/deltalake/deltalake.pyi deleted file mode 100644 index b872e61bd3..0000000000 --- a/python/stubs/deltalake/deltalake.pyi +++ /dev/null @@ -1,13 +0,0 @@ -from typing import Any, Callable, List - -import pyarrow as pa - -from deltalake.writer import AddAction - -RawDeltaTable: Any -rust_core_version: Callable[[], str] -DeltaStorageFsBackend: Any - -write_new_deltalake: Callable[[str, pa.Schema, List[AddAction], str, List[str]], None] - -class 
PyDeltaTableError(BaseException): ... diff --git a/python/stubs/pyarrow/__init__.pyi b/python/stubs/pyarrow/__init__.pyi index 330db25ab6..574ff878b6 100644 --- a/python/stubs/pyarrow/__init__.pyi +++ b/python/stubs/pyarrow/__init__.pyi @@ -6,6 +6,9 @@ Table: Any RecordBatch: Any Field: Any DataType: Any +ListType: Any +StructType: Any +MapType: Any schema: Any map_: Any list_: Any diff --git a/python/tests/test_schema.py b/python/tests/test_schema.py index 296a5d8c5f..c55c4e625b 100644 --- a/python/tests/test_schema.py +++ b/python/tests/test_schema.py @@ -1,13 +1,16 @@ +from array import array + import pyarrow +import pytest from deltalake import DeltaTable, Field from deltalake.schema import ( ArrayType, DataType, MapType, + PrimitiveType, Schema, StructType, - pyarrow_field_from_dict, ) @@ -22,19 +25,21 @@ def test_table_schema(): assert len(schema.fields) == 1 field = schema.fields[0] assert field.name == "id" - assert field.type == DataType("long") + assert field.type == PrimitiveType("long") assert field.nullable is True assert field.metadata == {} json = '{"type":"struct","fields":[{"name":"x","type":{"type":"array","elementType":"long","containsNull":true},"nullable":true,"metadata":{}}]}' schema = Schema.from_json(json) - assert schema.fields[0] == Field("x", ArrayType(DataType("long"), True), True, {}) + assert schema.fields[0] == Field( + "x", ArrayType(PrimitiveType("long"), True), True, {} + ) def test_table_schema_pyarrow_simple(): table_path = "../rust/tests/data/simple_table" dt = DeltaTable(table_path) - schema = dt.pyarrow_schema() + schema = dt.schema().to_pyarrow() field = schema.field(0) assert len(schema.types) == 1 assert field.name == "id" @@ -46,7 +51,7 @@ def test_table_schema_pyarrow_simple(): def test_table_schema_pyarrow_020(): table_path = "../rust/tests/data/delta-0.2.0" dt = DeltaTable(table_path) - schema = dt.pyarrow_schema() + schema = dt.schema().to_pyarrow() field = schema.field(0) assert len(schema.types) == 1 assert 
field.name == "value" @@ -55,270 +60,160 @@ def test_table_schema_pyarrow_020(): assert field.metadata is None -def test_schema_pyarrow_from_decimal_and_floating_types(): - field_name = "decimal_test" - metadata = {b"metadata_k": b"metadata_v"} - precision = 20 - scale = 2 - pyarrow_field = pyarrow_field_from_dict( - { - "name": field_name, - "nullable": False, - "metadata": metadata, - "type": {"name": "decimal", "precision": precision, "scale": scale}, - } - ) - assert pyarrow_field.name == field_name - assert pyarrow_field.type == pyarrow.decimal128(precision=precision, scale=scale) - assert dict(pyarrow_field.metadata) == metadata - assert pyarrow_field.nullable is False - - field_name = "floating_test" - metadata = {b"metadata_k": b"metadata_v"} - pyarrow_field = pyarrow_field_from_dict( - { - "name": field_name, - "nullable": False, - "metadata": metadata, - "type": {"name": "floatingpoint", "precision": "HALF"}, - } - ) - assert pyarrow_field.name == field_name - assert pyarrow_field.type == pyarrow.float16() - assert dict(pyarrow_field.metadata) == metadata - assert pyarrow_field.nullable is False - - -def test_schema_delta_types(): - field_name = "column1" - metadata = {"metadata_k": "metadata_v"} - delta_field = Field( - name=field_name, - type=DataType.from_dict({"type": "integer"}), - metadata={"metadata_k": "metadata_v"}, - nullable=False, - ) - assert delta_field.name == field_name - assert delta_field.type == DataType("integer") - assert delta_field.metadata == metadata - assert delta_field.nullable is False - - delta_field = Field( - name=field_name, - type=DataType.from_dict( - {"type": "array", "elementType": {"type": "integer"}, "containsNull": True} - ), - metadata={"metadata_k": "metadata_v"}, - nullable=False, - ) - assert delta_field.name == field_name - assert delta_field.type == ArrayType(DataType("integer"), True) - assert delta_field.metadata == metadata - assert delta_field.nullable is False - - delta_field = Field( - name=field_name, - 
type=DataType.from_dict( - { - "type": "map", - "keyType": "integer", - "valueType": "integer", - "valueContainsNull": True, - } - ), - metadata={"metadata_k": "metadata_v"}, - nullable=False, - ) - assert delta_field.name == field_name - key_type = DataType("integer") - value_type = DataType("integer") - assert delta_field.type == MapType(key_type, value_type, True) - assert delta_field.metadata == metadata - assert delta_field.nullable is False - - delta_field = Field( - name=field_name, - type=DataType.from_dict( - { - "type": "struct", - "fields": [ - { - "name": "x", - "type": {"type": "integer"}, - "nullable": True, - "metadata": {}, - } - ], - } - ), - metadata={"metadata_k": "metadata_v"}, - nullable=False, - ) - assert delta_field.name == field_name - assert delta_field.type == StructType([Field("x", DataType("integer"), True, {})]) - assert delta_field.metadata == metadata - assert delta_field.nullable is False - - -def test_schema_pyarrow_types(): - field_name = "column1" - metadata = {b"metadata_k": b"metadata_v"} - pyarrow_field = pyarrow_field_from_dict( - { - "name": field_name, - "nullable": False, - "metadata": metadata, - "type": {"name": "int", "bitWidth": 8, "isSigned": True}, - } - ) - assert pyarrow_field.name == field_name - assert pyarrow_field.type == pyarrow.int8() - assert dict(pyarrow_field.metadata) == metadata - assert pyarrow_field.nullable is False - - field_name = "column_timestamp_no_unit" - metadata = {b"metadata_k": b"metadata_v"} - pyarrow_field = pyarrow_field_from_dict( - { - "name": field_name, - "nullable": False, - "metadata": metadata, - "type": {"name": "timestamp"}, - } - ) - assert pyarrow_field.name == field_name - assert pyarrow_field.type == pyarrow.timestamp("ns") - assert dict(pyarrow_field.metadata) == metadata - assert pyarrow_field.nullable is False - - field_name = "column_timestamp_with_unit" - metadata = {b"metadata_k": b"metadata_v"} - pyarrow_field = pyarrow_field_from_dict( - { - "name": field_name, - 
"nullable": False, - "metadata": metadata, - "type": {"name": "timestamp", "unit": "MICROSECOND"}, - } - ) - assert pyarrow_field.name == field_name - assert pyarrow_field.type == pyarrow.timestamp("us") - assert dict(pyarrow_field.metadata) == metadata - assert pyarrow_field.nullable is False - - field_name = "date_with_day_unit" - metadata = {b"metadata_k": b"metadata_v"} - pyarrow_field = pyarrow_field_from_dict( - { - "name": field_name, - "nullable": False, - "metadata": metadata, - "type": {"name": "date", "unit": "DAY"}, - } - ) - assert pyarrow_field.name == field_name - assert pyarrow_field.type == pyarrow.date32() - assert dict(pyarrow_field.metadata) == metadata - assert pyarrow_field.nullable is False - - field_name = "simple_list" - pyarrow_field = pyarrow_field_from_dict( - { - "name": field_name, - "nullable": False, - "metadata": metadata, - "type": {"name": "list"}, - "children": [{"type": {"name": "int", "bitWidth": 32, "isSigned": True}}], - } - ) - assert pyarrow_field.name == field_name - assert pyarrow_field.type == pyarrow.list_(pyarrow.field("item", pyarrow.int32())) - assert pyarrow_field.metadata == metadata - assert pyarrow_field.nullable is False - - field_name = "dictionary" - pyarrow_field = pyarrow_field_from_dict( - { - "name": field_name, - "nullable": False, - "metadata": metadata, - "type": {"name": "int", "bitWidth": 32, "isSigned": True}, - "children": [], - "dictionary": { - "id": 0, - "indexType": {"name": "int", "bitWidth": 16, "isSigned": True}, - }, - } - ) - assert pyarrow_field.name == field_name - assert pyarrow_field.type == pyarrow.map_(pyarrow.int16(), pyarrow.int32()) - assert pyarrow_field.metadata == metadata - assert pyarrow_field.nullable is False - - field_name = "struct_array" - pyarrow_field = pyarrow_field_from_dict( - { - "name": field_name, - "nullable": False, - "metadata": metadata, - "type": {"name": "list"}, - "children": [], - "dictionary": { - "id": 0, - "indexType": {"name": "int", "bitWidth": 32, 
"isSigned": True}, - }, - } - ) - assert pyarrow_field.name == field_name - assert pyarrow_field.type == pyarrow.map_( - pyarrow.int32(), - pyarrow.list_( - pyarrow.field( - "entries", - pyarrow.struct( - [pyarrow.field("val", pyarrow.int32(), False, metadata)] - ), - ) +def test_primitive_delta_types(): + valid_types = [ + "string", + "long", + "integer", + "short", + "byte", + "float", + "double", + "boolean", + "binary", + "date", + "timestamp", + "decimal(10,2)", + ] + + invalid_types = ["int", "decimal", "decimal()"] + + for data_type in valid_types: + delta_type = PrimitiveType(data_type) + assert delta_type.type == data_type + assert data_type in str(delta_type) + assert data_type in repr(delta_type) + + pa_type = delta_type.to_pyarrow() + assert delta_type == PrimitiveType.from_pyarrow(pa_type) + + json_type = delta_type.to_json() + assert delta_type == PrimitiveType.from_json(json_type) + + for data_type in invalid_types: + with pytest.raises(ValueError): + PrimitiveType(data_type) + + +def test_array_delta_types(): + init_values = [ + (PrimitiveType("string"), False), + (ArrayType(PrimitiveType("string"), True), True), + ] + + for element_type, contains_null in init_values: + array_type = ArrayType(element_type, contains_null) + + assert array_type.type == "array" + assert array_type.element_type == element_type + assert array_type.contains_null == contains_null + + pa_type = array_type.to_pyarrow() + assert array_type == ArrayType.from_pyarrow(pa_type) + + json_type = array_type.to_json() + assert array_type == ArrayType.from_json(json_type) + + +def test_map_delta_types(): + init_values = [ + (PrimitiveType("string"), PrimitiveType("decimal(20,9)"), False), + (PrimitiveType("float"), PrimitiveType("string"), True), + ( + PrimitiveType("string"), + MapType(PrimitiveType("date"), PrimitiveType("date"), True), + False, ), - ) - assert pyarrow_field.metadata == metadata - assert pyarrow_field.nullable is False - - field_name = "simple_dictionary" - 
pyarrow_field = pyarrow_field_from_dict( - { - "name": field_name, - "metadata": {"metadata_k": "metadata_v"}, - "nullable": False, - "type": {"name": "dictionary"}, - "dictionary": {"indexType": {"type": {"name": "int", "bitWidth": 8}}}, - "children": [{"type": {"name": "int", "bitWidth": 32}}], - } - ) - assert pyarrow_field.name == field_name - assert pyarrow_field.type == pyarrow.map_(pyarrow.int8(), pyarrow.int32()) - assert pyarrow_field.metadata == metadata - assert pyarrow_field.nullable is False - - pyarrow_field = pyarrow_field_from_dict( - { - "name": field_name, - "type": {"name": "struct"}, - "children": [ - { - "name": "x", - "type": {"name": "int", "bitWidth": 64}, - "nullable": True, - "metadata": {}, - } - ], - "metadata": {"metadata_k": "metadata_v"}, - "nullable": False, - } - ) - assert pyarrow_field.name == field_name - assert pyarrow_field.type == pyarrow.struct( - [pyarrow.field("x", pyarrow.int64(), True, {})] - ) - assert pyarrow_field.metadata == metadata - assert pyarrow_field.nullable is False + ] + for key_type, value_type, value_contains_null in init_values: + map_type = MapType(key_type, value_type, value_contains_null) + + assert map_type.type == "map" + assert map_type.key_type == key_type + assert map_type.value_type == value_type + assert map_type.value_contains_null == value_contains_null + + # Map type is not yet supported in C Data Interface + # https://github.com/apache/arrow-rs/issues/2037 + # pa_type = map_type.to_pyarrow() + # assert map_type == PrimitiveType.from_pyarrow(pa_type) + + json_type = map_type.to_json() + assert map_type == MapType.from_json(json_type) + + +def test_struct_delta_types(): + fields = [ + Field("x", "integer", nullable=True, metadata={"x": {"y": 3}}), + Field("y", PrimitiveType("string"), nullable=False), + ] + + struct_type = StructType(fields) + + assert struct_type.type == "struct" + assert struct_type.fields == fields + + json_type = struct_type.to_json() + assert struct_type == 
StructType.from_json(json_type) + + # Field metadata doesn't roundtrip currently + # See: https://github.com/apache/arrow-rs/issues/478 + fields = [ + Field("x", "integer", nullable=True), + Field("y", PrimitiveType("string"), nullable=False), + ] + struct_type = StructType(fields) + pa_type = struct_type.to_pyarrow() + assert struct_type == StructType.from_pyarrow(pa_type) + + +def test_delta_field(): + args = [ + ("x", PrimitiveType("string"), True, {}), + ("y", "float", False, {"x": {"y": 3}}), + ("z", ArrayType(StructType([Field("x", "integer", True)]), True), True, None), + ] + + # TODO: are there field names we should reject? + + for name, ty, nullable, metadata in args: + field = Field(name, ty, nullable=nullable, metadata=metadata) + + assert field.name == name + assert field.type == (PrimitiveType(ty) if isinstance(ty, str) else ty) + assert field.nullable == nullable + assert field.metadata == (metadata or {}) + + # Field metadata doesn't roundtrip currently + # See: https://github.com/apache/arrow-rs/issues/478 + if len(field.metadata) == 0: + pa_field = field.to_pyarrow() + assert field == Field.from_pyarrow(pa_field) + + json_field = field.to_json() + assert field == Field.from_json(json_field) + + +def test_delta_schema(): + fields = [ + Field("x", "integer", nullable=True, metadata={"x": {"y": 3}}), + Field("y", PrimitiveType("string"), nullable=False), + ] + + schema = Schema(fields) + + assert schema.fields == fields + + empty_schema = Schema([]) + pa_schema = empty_schema.to_pyarrow() + assert empty_schema == Schema.from_pyarrow(pa_schema) + + # Field metadata doesn't roundtrip currently + # See: https://github.com/apache/arrow-rs/issues/478 + fields = [ + Field("x", "integer", nullable=True), + Field("y", ArrayType("string", contains_null=True), nullable=False), + ] + schema_without_metadata = schema = Schema(fields) + pa_schema = schema_without_metadata.to_pyarrow() + assert schema_without_metadata == Schema.from_pyarrow(pa_schema) diff --git 
a/python/tests/test_writer.py b/python/tests/test_writer.py index 808db8bd1b..31e7486880 100644 --- a/python/tests/test_writer.py +++ b/python/tests/test_writer.py @@ -46,7 +46,7 @@ def test_roundtrip_basic(tmp_path: pathlib.Path, sample_data: pa.Table): assert ("0" * 20 + ".json") in os.listdir(tmp_path / "_delta_log") delta_table = DeltaTable(str(tmp_path)) - assert delta_table.pyarrow_schema() == sample_data.schema + assert delta_table.schema().to_pyarrow() == sample_data.schema table = delta_table.to_pyarrow_table() assert table == sample_data @@ -85,7 +85,7 @@ def test_update_schema(existing_table: DeltaTable): read_data = existing_table.to_pyarrow_table() assert new_data == read_data - assert existing_table.pyarrow_schema() == new_data.schema + assert existing_table.schema().to_pyarrow() == new_data.schema def test_local_path(tmp_path: pathlib.Path, sample_data: pa.Table, monkeypatch): @@ -95,7 +95,7 @@ def test_local_path(tmp_path: pathlib.Path, sample_data: pa.Table, monkeypatch): local_path = "./path/to/table" write_deltalake(local_path, sample_data) delta_table = DeltaTable(local_path) - assert delta_table.pyarrow_schema() == sample_data.schema + assert delta_table.schema().to_pyarrow() == sample_data.schema table = delta_table.to_pyarrow_table() assert table == sample_data @@ -140,7 +140,7 @@ def test_roundtrip_partitioned( write_deltalake(str(tmp_path), sample_data, partition_by=[column]) delta_table = DeltaTable(str(tmp_path)) - assert delta_table.pyarrow_schema() == sample_data.schema + assert delta_table.schema().to_pyarrow() == sample_data.schema table = delta_table.to_pyarrow_table() table = table.take(pc.sort_indices(table["int64"])) @@ -158,7 +158,7 @@ def test_roundtrip_null_partition(tmp_path: pathlib.Path, sample_data: pa.Table) write_deltalake(str(tmp_path), sample_data, partition_by=["utf8_with_nulls"]) delta_table = DeltaTable(str(tmp_path)) - assert delta_table.pyarrow_schema() == sample_data.schema + assert 
delta_table.schema().to_pyarrow() == sample_data.schema table = delta_table.to_pyarrow_table() table = table.take(pc.sort_indices(table["int64"])) @@ -169,7 +169,7 @@ def test_roundtrip_multi_partitioned(tmp_path: pathlib.Path, sample_data: pa.Tab write_deltalake(str(tmp_path), sample_data, partition_by=["int32", "bool"]) delta_table = DeltaTable(str(tmp_path)) - assert delta_table.pyarrow_schema() == sample_data.schema + assert delta_table.schema().to_pyarrow() == sample_data.schema table = delta_table.to_pyarrow_table() table = table.take(pc.sort_indices(table["int64"])) diff --git a/rust/src/delta_arrow.rs b/rust/src/delta_arrow.rs index 772a6a9dc2..e572160959 100644 --- a/rust/src/delta_arrow.rs +++ b/rust/src/delta_arrow.rs @@ -8,6 +8,7 @@ use arrow::datatypes::{ use arrow::error::ArrowError; use lazy_static::lazy_static; use regex::Regex; +use std::collections::BTreeMap; use std::collections::HashMap; use std::convert::TryFrom; @@ -29,11 +30,25 @@ impl TryFrom<&schema::SchemaField> for ArrowField { type Error = ArrowError; fn try_from(f: &schema::SchemaField) -> Result { - Ok(ArrowField::new( + let mut field = ArrowField::new( f.get_name(), ArrowDataType::try_from(f.get_type())?, f.is_nullable(), - )) + ); + + let metadata: Option> = Some(f.get_metadata()) + .filter(|metadata| !metadata.is_empty()) // convert only non-empty metadata; empty metadata stays None + .map(|metadata| { + metadata + .iter() + .map(|(key, val)| Ok((key.clone(), serde_json::to_string(val)?))) + .collect::>() + .map_err(|err| ArrowError::JsonError(err.to_string())) + }) + .transpose()?; + + field.set_metadata(metadata); + Ok(field) } } @@ -42,7 +57,7 @@ impl TryFrom<&schema::SchemaTypeArray> for ArrowField { fn try_from(a: &schema::SchemaTypeArray) -> Result { Ok(ArrowField::new( - "element", + "item", ArrowDataType::try_from(a.get_element_type())?, a.contains_null(), )) @@ -54,7 +69,7 @@ impl TryFrom<&schema::SchemaTypeMap> for ArrowField { fn try_from(a: &schema::SchemaTypeMap) -> Result { Ok(ArrowField::new( - "key_value", + "entries",
ArrowDataType::Struct(vec![ ArrowField::new("key", ArrowDataType::try_from(a.get_key_type())?, false), ArrowField::new(