Update T4R for ColumnSchema API changes from Merlin Core
Since `is_list` and `is_ragged` have become derived properties computed from the shape, it's no longer possible to directly set them from the constructor. They can be smuggled in through the properties, after which they'll be used to determine an appropriate shape that results in the same `is_list` and `is_ragged` values on the other side.

(This is a first step toward capturing and using more comprehensive shape information, with the goal of putting `Shape` in place while breaking as little as possible. There will be subsequent changes to directly capture more shape information, but this gets us part-way there.)
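
The round trip described in the first paragraph can be sketched as follows. This is a minimal illustration, not Merlin Core's implementation: `SketchColumn` and its shape rules are hypothetical stand-ins, and the only names taken from this commit are the `is_list`, `is_ragged`, and `value_count` property keys.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple


@dataclass
class SketchColumn:
    """Hypothetical stand-in for a column schema; NOT the Merlin Core ColumnSchema."""

    name: str
    properties: Dict[str, Any] = field(default_factory=dict)

    @property
    def shape(self) -> Tuple[Optional[int], ...]:
        """Infer a shape from hints in properties; None marks a variable dimension."""
        props = self.properties
        if not props.get("is_list", False):
            return (None,)  # scalar column: batch dimension only
        if props.get("is_ragged", False):
            return (None, None)  # rows hold lists of varying length
        max_len = props.get("value_count", {}).get("max")
        if max_len is None:
            # Without a known length, a fixed-length list cannot be told
            # apart from a ragged one, so the round trip would not hold.
            raise ValueError("fixed-length list columns need value_count['max']")
        return (None, max_len)

    @property
    def is_list(self) -> bool:
        # Derived: any dimension beyond the batch dimension means a list column.
        return len(self.shape) > 1

    @property
    def is_ragged(self) -> bool:
        # Derived: a list column whose inner dimension is not fixed.
        return self.is_list and any(dim is None for dim in self.shape[1:])


ragged = SketchColumn(
    "item_ids", {"is_list": True, "is_ragged": True, "value_count": {"max": 20}}
)
fixed = SketchColumn(
    "item_ids", {"is_list": True, "is_ragged": False, "value_count": {"max": 20}}
)
scalar = SketchColumn("age")
```

In this sketch the flags are never stored as attributes. They are read back from the shape, so the values placed in `properties` and the values seen "on the other side" agree by construction.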

Depends on NVIDIA-Merlin/core#195
karlhigley committed Jan 20, 2023
1 parent 04148b8 commit 3a2effa
Showing 2 changed files with 4 additions and 8 deletions.
4 changes: 2 additions & 2 deletions merlin_standard_lib/utils/misc_utils.py
@@ -265,14 +265,14 @@ def _augment_schema(
     for col in sparse_names or []:
         cs = schema[col]
         properties = cs.properties
+        properties["is_list"] = True
+        properties["is_ragged"] = True
         if sparse_max and col in sparse_max:
             properties["value_count"] = {"max": sparse_max[col]}
         schema[col] = ColumnSchema(
             name=cs.name,
             tags=cs.tags,
             dtype=cs.dtype,
-            is_list=True,
-            is_ragged=not sparse_as_dense,
             properties=properties,
         )

8 changes: 2 additions & 6 deletions transformers4rec/torch/model/base.py
@@ -739,19 +739,15 @@ def input_schema(self):

             dtype = {0: np.float32, 2: np.int64, 3: np.float32}[column.type]
             tags = column.tags
-            is_list = column.value_count.max > 0
+            value_counts = {"min": column.value_count.min, "max": column.value_count.max}
             int_domain = {"min": column.int_domain.min, "max": column.int_domain.max}
-            properties = {
-                "int_domain": int_domain,
-            }
+            properties = {"int_domain": int_domain, "value_counts": value_counts}

             col_schema = ColumnSchema(
                 name,
                 dtype=dtype,
                 tags=tags,
                 properties=properties,
-                is_list=is_list,
-                is_ragged=False,
             )
             core_schema[name] = col_schema
         return core_schema
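
As a rough illustration of the second hunk, the sketch below collects the range metadata into a plain dict without passing any explicit list flag. The `properties_from_column` helper and the `SimpleNamespace` column are hypothetical stand-ins for the schema column object that `input_schema` iterates over. Only the `int_domain` and `value_counts` keys come from the diff.

```python
from types import SimpleNamespace


def properties_from_column(column):
    """Hypothetical helper: gather range metadata the way the updated hunk does."""
    return {
        "int_domain": {"min": column.int_domain.min, "max": column.int_domain.max},
        "value_counts": {"min": column.value_count.min, "max": column.value_count.max},
    }


# Stand-in for a schema column; the real object has more fields.
column = SimpleNamespace(
    int_domain=SimpleNamespace(min=0, max=1000),
    value_count=SimpleNamespace(min=1, max=20),
)

props = properties_from_column(column)
```

The resulting dict carries the value-count range itself rather than a precomputed boolean.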
