Skip to content

Commit

Permalink
0.0.159
Browse files Browse the repository at this point in the history
  • Loading branch information
joocer committed May 13, 2024
1 parent f8d841d commit 1dcc122
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 1 deletion.
70 changes: 70 additions & 0 deletions orso/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
from warnings import warn

import numpy
import orjson
from data_expectations import Expectation

from orso.exceptions import ColumnDefinitionError
Expand Down Expand Up @@ -260,6 +261,54 @@ def all_names(self):
return self.aliases + [self.name]
return [self.name]

@property
def arrow_field(self):
import pyarrow

TYPE_MAP: dict = {
OrsoTypes.BOOLEAN: pyarrow.bool_(),
OrsoTypes.BLOB: pyarrow.binary(),
OrsoTypes.DATE: pyarrow.date64(),
OrsoTypes.TIMESTAMP: pyarrow.timestamp("us"),
OrsoTypes.TIME: pyarrow.time32("ms"),
OrsoTypes.INTERVAL: pyarrow.month_day_nano_interval(),
OrsoTypes.STRUCT: pyarrow.struct([]),
OrsoTypes.DECIMAL: lambda col: pyarrow.decimal128(col.precision, col.scale),
OrsoTypes.DOUBLE: pyarrow.float64(),
OrsoTypes.INTEGER: pyarrow.int64(),
OrsoTypes.ARRAY: pyarrow.list_(pyarrow.string()),
OrsoTypes.VARCHAR: pyarrow.string(),
OrsoTypes.BSON: pyarrow.binary(),
OrsoTypes.NULL: pyarrow.null(),
}

return pyarrow.field(name=self.name, type=TYPE_MAP.get(self.type, pyarrow.string()))

def to_json(self) -> str:
def default_serializer(o):
if isinstance(o, OrsoTypes):
return str(o)
if isinstance(o, Expectation):
return o.__dict__
raise TypeError(f"Object of type {o.__class__.__name__} is not JSON serializable")

return orjson.dumps(asdict(self), default=default_serializer)

@classmethod
def from_json(cls, json_str: str) -> "FlatColumn":
def custom_decoder(dct):
if "type" in dct:
dct["type"] = OrsoTypes.__members__.get(dct["type"], OrsoTypes._MISSING_TYPE)
if "expectations" in dct and isinstance(dct["expectations"], list):
dct["expectations"] = [
SchemaExpectation.load(v) if isinstance(v, dict) else v
for v in dct["expectations"]
]
return dct

data = orjson.loads(json_str)
return cls(**data)


@dataclass(init=False)
class FunctionColumn(FlatColumn):
Expand Down Expand Up @@ -605,3 +654,24 @@ def validate(self, data: MutableMapping) -> bool:
if errors:
raise DataValidationError(errors=errors)
return True

def to_json(self):
return {
"name": self.name,
"aliases": self.aliases,
"primary_key": self.primary_key,
"columns": [col.to_json() for col in self.columns],
}


def convert_arrow_schema_to_orso_schema(arrow_schema):
return RelationSchema(
name="arrow",
columns=[FlatColumn.from_arrow(field) for field in arrow_schema],
)


def convert_orso_schema_to_arrow_schema(orso_schema):
from pyarrow import schema

return schema([col.arrow_field for col in orso_schema.columns])
2 changes: 1 addition & 1 deletion orso/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__: str = "0.0.158"
__version__: str = "0.0.159"
__author__: str = "@joocer"
33 changes: 33 additions & 0 deletions tests/test_schema_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,29 @@ def test_to_flatcolumn_preserve_attributes():
assert getattr(new_column, field) == getattr(flat_column, field), field


def test_to_json_and_back():
"""
Test that to_flatcolumn preserves the attributes.
"""
flat_column = FlatColumn(
name="id",
type=OrsoTypes.INTEGER,
description="An ID column",
aliases=["ID"],
nullable=False,
precision=5,
)
as_json = flat_column.to_json()
as_column = FlatColumn.from_json(as_json)

for field in [
f
for f in dir(FlatColumn)
if (f[0] != "_" and isinstance(getattr(flat_column, f), (int, str, float, list, OrsoTypes)))
]:
assert getattr(as_column, field) == getattr(flat_column, field), field


def test_aliasing():
col = FlatColumn(name="alpha", type=OrsoTypes.VARCHAR)

Expand All @@ -375,7 +398,17 @@ def test_minimum_definition():
col = FlatColumn(name="a")


def test_arrow_conversion():
from tests.cities import schema as city_schema
from pyarrow import schema as arrow_schema

_arrow_schema = arrow_schema([col.arrow_field for col in city_schema.columns])

print(_arrow_schema)


if __name__ == "__main__": # prgama: nocover
from tests import run_tests

test_arrow_conversion()
run_tests()

0 comments on commit 1dcc122

Please sign in to comment.