Change the location and filename of schema.pbtxt to .merlin/schema.json #249

Merged · 3 commits · Mar 17, 2023
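In user-facing terms, the PR moves the schema sidecar that Dataset.to_parquet writes next to the data files into a hidden metadata directory, and switches its format from Protobuf text to JSON. A before/after sketch of the on-disk layout (paths are illustrative):

    import merlin.io

    ds = merlin.io.Dataset("/data/in/*.parquet", engine="parquet")
    ds.to_parquet("/data/out")

    # Before this PR:              After this PR:
    #   /data/out/*.parquet          /data/out/*.parquet
    #   /data/out/schema.pbtxt       /data/out/.merlin/schema.json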
merlin/io/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -17,5 +17,5 @@
 # flake8: noqa
 from merlin.io import dataframe_iter, dataset, shuffle
 from merlin.io.dataframe_iter import DataFrameIter
-from merlin.io.dataset import Dataset
+from merlin.io.dataset import MERLIN_METADATA_DIR_NAME, Dataset
 from merlin.io.shuffle import Shuffle, shuffle_df
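The re-export above lets downstream code refer to the metadata directory by name instead of hard-coding ".merlin". A minimal sketch (the output path is illustrative):

    import os
    from merlin.io import MERLIN_METADATA_DIR_NAME

    # Build the expected schema location for a dataset written to /data/out
    schema_file = os.path.join("/data/out", MERLIN_METADATA_DIR_NAME, "schema.json")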
merlin/io/dataset.py (28 changes: 25 additions & 3 deletions)
@@ -53,6 +53,8 @@
 except ImportError:
     cudf = None
 
+
+MERLIN_METADATA_DIR_NAME = ".merlin"
 LOG = logging.getLogger("merlin")
 

@@ -339,10 +341,28 @@ def __init__(
             if schema_path.is_file():
                 schema_path = schema_path.parent
 
-            if (schema_path / "schema.pbtxt").exists():
+            pbtxt_deprecated_warning = (
+                "Found schema.pbtxt. Loading schema automatically from "
+                "schema.pbtxt is deprecated and will be removed in the "
+                "future. Re-run workflow to generate .merlin/schema.json."
+            )
+
+            if (schema_path / MERLIN_METADATA_DIR_NAME / "schema.json").exists():
+                schema = TensorflowMetadata.from_json_file(
+                    schema_path / MERLIN_METADATA_DIR_NAME
+                )
+                self.schema = schema.to_merlin_schema()
+            elif (schema_path.parent / MERLIN_METADATA_DIR_NAME / "schema.json").exists():
+                schema = TensorflowMetadata.from_json_file(
+                    schema_path.parent / MERLIN_METADATA_DIR_NAME
+                )
+                self.schema = schema.to_merlin_schema()
+            elif (schema_path / "schema.pbtxt").exists():
+                warnings.warn(pbtxt_deprecated_warning, DeprecationWarning)
                 schema = TensorflowMetadata.from_proto_text_file(schema_path)
                 self.schema = schema.to_merlin_schema()
             elif (schema_path.parent / "schema.pbtxt").exists():
+                warnings.warn(pbtxt_deprecated_warning, DeprecationWarning)
                 schema = TensorflowMetadata.from_proto_text_file(schema_path.parent)
                 self.schema = schema.to_merlin_schema()
             else:
@@ -909,10 +929,12 @@ def to_parquet(
                 schema[col_name] = schema[col_name].with_dtype(col_dtype)
 
         fs = get_fs_token_paths(output_path)[0]
-        fs.mkdirs(output_path, exist_ok=True)
+        fs.mkdirs(str(output_path), exist_ok=True)
 
         tf_metadata = TensorflowMetadata.from_merlin_schema(schema)
-        tf_metadata.to_proto_text_file(output_path)
+        metadata_path = fs.sep.join([str(output_path), MERLIN_METADATA_DIR_NAME])
+        fs.mkdirs(metadata_path, exist_ok=True)
+        tf_metadata.to_json_file(metadata_path)
 
         # Output dask_cudf DataFrame to dataset
         _ddf_to_dataset(
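Taken together, the constructor now resolves a schema in four steps: .merlin/schema.json next to the data, then in the parent directory, then the legacy schema.pbtxt in either location, warning on the legacy path. A sketch of what this means when loading an older dataset (the path is illustrative, and it is assumed to still contain only schema.pbtxt):

    import warnings
    import merlin.io

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # Legacy layout: /data/legacy_out/schema.pbtxt, no .merlin/ directory
        ds = merlin.io.Dataset("/data/legacy_out", engine="parquet")

    # The pbtxt fallback still works, but now emits a DeprecationWarning
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)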
merlin/schema/io/tensorflow_metadata.py (29 changes: 25 additions & 4 deletions)
@@ -33,6 +33,8 @@
     "uint": FeatureType.INT,
     "float": FeatureType.FLOAT,
 }
+SCHEMA_PBTXT_FILE_NAME = "schema.pbtxt"
+SCHEMA_JSON_FILE_NAME = "schema.json"
 
 
 class TensorflowMetadata:
@@ -64,20 +66,27 @@ def from_json(cls, json: Union[str, bytes]) -> "TensorflowMetadata":
         return TensorflowMetadata(schema)
 
     @classmethod
-    def from_json_file(cls, path: os.PathLike) -> "TensorflowMetadata":
+    def from_json_file(
+        cls, path: os.PathLike, file_name=SCHEMA_JSON_FILE_NAME
+    ) -> "TensorflowMetadata":
         """Create a TensorflowMetadata schema object from a JSON file
 
         Parameters
         ----------
         path : str
             Path to the JSON file to parse
+        file_name : str
+            Name of the schema file. Defaults to "schema.json".
 
         Returns
         -------
         TensorflowMetadata
             Schema object parsed from JSON file
 
         """
+        path = pathlib.Path(path)
+        if path.is_dir():
+            path = path / file_name
         return cls.from_json(_read_file(path))
 
     @classmethod
@@ -105,7 +114,7 @@ def from_proto_text(cls, proto_text: str) -> "TensorflowMetadata":
 
     @classmethod
     def from_proto_text_file(
-        cls, path: os.PathLike, file_name="schema.pbtxt"
+        cls, path: os.PathLike, file_name=SCHEMA_PBTXT_FILE_NAME
     ) -> "TensorflowMetadata":
         """Create a TensorflowMetadata schema object from a Protobuf text file
 
@@ -138,7 +147,7 @@ def to_proto_text(self) -> str:
 
         return proto_utils.better_proto_to_proto_text(self.proto_schema, schema_pb2.Schema())
 
-    def to_proto_text_file(self, path: str, file_name="schema.pbtxt"):
+    def to_proto_text_file(self, path: str, file_name=SCHEMA_PBTXT_FILE_NAME):
         """Write this TensorflowMetadata schema object to a file as a Proto text string
 
         Parameters
@@ -147,7 +156,6 @@ def to_proto_text_file(self, path: str, file_name="schema.pbtxt"):
             Path to the directory containing the Protobuf text file
         file_name : str
             Name of the output file. Defaults to "schema.pbtxt".
-        path: str :
 
         """
         _write_file(self.to_proto_text(), path, file_name)
@@ -221,6 +229,19 @@ def to_json(self) -> str:
         """
         return self.proto_schema.to_json()
 
+    def to_json_file(self, path: str, file_name=SCHEMA_JSON_FILE_NAME):
+        """Write this TensorflowMetadata schema object to a file as a JSON
+
+        Parameters
+        ----------
+        path : str
+            Path to the directory containing the JSON text file
+        file_name : str
+            Name of the output file. Defaults to "schema.json".
+
+        """
+        _write_file(self.to_json(), path, file_name)
+
 
 def _pb_int_domain(column_schema):
     domain = column_schema.properties.get("domain")
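The new to_json_file mirrors the existing to_proto_text_file, so the two serializers now have symmetric read/write pairs. A round-trip sketch, assuming the output directory already exists and merlin_schema is a merlin.schema.Schema you already have in hand:

    from merlin.schema.io.tensorflow_metadata import TensorflowMetadata

    tf_metadata = TensorflowMetadata.from_merlin_schema(merlin_schema)
    tf_metadata.to_json_file("/tmp/meta")  # writes /tmp/meta/schema.json

    # from_json_file accepts the directory (resolved via file_name) or a full file path
    restored = TensorflowMetadata.from_json_file("/tmp/meta").to_merlin_schema()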
tests/unit/io/test_io.py (17 changes: 10 additions & 7 deletions)
@@ -91,7 +91,8 @@ def test_dataset_infer_schema(dataset, engine):
 
 @pytest.mark.parametrize("engine", ["csv", "parquet", "csv-no-header"])
 @pytest.mark.parametrize("cpu", [None, True])
-def test_string_datatypes(tmpdir, engine, cpu):
+@pytest.mark.parametrize("file_format", ["pbtxt", "json"])
+def test_string_datatypes(tmpdir, engine, cpu, file_format):
     df_lib = dispatch.get_lib()
     df = df_lib.DataFrame({"column": [[0.1, 0.2]]})
     dataset = merlin.io.Dataset(df)
@@ -100,10 +101,15 @@ def test_string_datatypes(tmpdir, engine, cpu):
         assert not isinstance(column_schema.dtype, str)
 
     tf_metadata = TensorflowMetadata.from_merlin_schema(dataset.schema)
-    tf_metadata.to_proto_text_file(tmpdir)
 
-    pb_schema = TensorflowMetadata.from_proto_text_file(str(tmpdir))
-    loaded_schema = pb_schema.to_merlin_schema()
+    if file_format == "pbtxt":
+        tf_metadata.to_proto_text_file(tmpdir)
+        schema = TensorflowMetadata.from_proto_text_file(str(tmpdir))
+    elif file_format == "json":
+        tf_metadata.to_json_file(tmpdir)
+        schema = TensorflowMetadata.from_json_file(str(tmpdir))
+
+    loaded_schema = schema.to_merlin_schema()
 
     column_schema = loaded_schema.column_schemas["column"]
     assert not isinstance(column_schema.dtype, str)
@@ -648,9 +654,6 @@ def test_hive_partitioned_data(tmpdir, cpu):
     assert result_paths
     assert all(p.endswith(".parquet") for p in result_paths)
 
-    # reading into dask dastaframe cannot have schema in same directory
-    os.remove(os.path.join(path, "schema.pbtxt"))
-
     # Read back with dask.dataframe and check the data
     df_check = dd.read_parquet(path, engine="pyarrow").compute()
     df_check["name"] = df_check["name"].astype("object")
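The deleted cleanup step existed because schema.pbtxt sat in the same directory as the parquet files and confused dask's dataset discovery. One reading of why it can now go: pyarrow's dataset discovery skips paths whose basename starts with "." or "_" by default, so the new .merlin/ directory is invisible during read-back (a sketch, assuming path was written by Dataset.to_parquet):

    import dask.dataframe as dd

    # No need to delete the schema first; .merlin/ is ignored during discovery
    df_check = dd.read_parquet(path, engine="pyarrow").compute()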