
[BUG] ItemRetrievalTask can’t be saved #188

Closed
marcromeyn opened this issue Feb 16, 2022 · 3 comments
Labels: bug (Something isn't working), P1

Comments

@marcromeyn
Contributor

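Bug description

The following test case fails because the model cannot be saved:

```python
# Import paths are assumed (merlin-models test suite, early 2022) and may differ by version.
# `ecommerce_data` is a pytest fixture that provides synthetic e-commerce data.
import merlin_models.tf as ml
from merlin_models.data.synthetic import SyntheticData
from merlin_models.tf.utils import testing_utils


def test_two_tower_block_saving(ecommerce_data: SyntheticData):
    # Two-tower retrieval block with an MLP query tower
    two_tower = ml.TwoTowerBlock(ecommerce_data.schema, query_tower=ml.MLPBlock([64, 128]))

    # Attach the item-retrieval task to get a trainable model
    model = two_tower.connect(
        ml.ItemRetrievalTask(ecommerce_data.schema, target_name="click", metrics=[])
    )

    dataset = ecommerce_data.tf_dataloader(batch_size=50)
    # The helper is expected to round-trip the model through saving; that step fails
    copy_two_tower = testing_utils.assert_model_is_retrainable(model, dataset)

    outputs = copy_two_tower(ecommerce_data.tf_tensor_dict)
    assert list(outputs.shape) == [100, 1]
```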

Expected behavior

@viswa-nvidia

Sara was able to reproduce this issue. @sararb, please provide an update on this.

@sararb
Contributor

sararb commented Mar 14, 2022

The issue is that the custom classes used by ItemRetrievalTask cannot be saved to disk. These are: ItemRetrievalScorer, InBatchSampler, CachedCrossBatchSampler, CachedUniformSampler, and PopularityBasedSampler.
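As a general illustration of what Keras needs before a custom layer can be saved (this is not the fix merged for this issue, whose details are not shown here), a minimal sketch: the layer must report every constructor argument in `get_config()`, and registering it with `tf.keras.utils.register_keras_serializable` lets Keras resolve the class by name when the model is reloaded. `ToyInBatchSampler`, its `batch_size` argument, and the `"hypothetical"` package name are all made up for this example.

```python
import tensorflow as tf


# Hypothetical layer, loosely inspired by the samplers listed above; it is NOT the
# merlin-models implementation. Registration lets Keras look the class up by name.
@tf.keras.utils.register_keras_serializable(package="hypothetical")
class ToyInBatchSampler(tf.keras.layers.Layer):
    def __init__(self, batch_size: int, **kwargs):
        super().__init__(**kwargs)
        self.batch_size = batch_size

    def call(self, item_embeddings):
        # Toy behavior: keep at most `batch_size` item embeddings from the current batch.
        return item_embeddings[: self.batch_size]

    def get_config(self):
        # Every constructor argument must be recorded here; otherwise Keras
        # cannot rebuild the layer from its saved config.
        config = super().get_config()
        config.update({"batch_size": self.batch_size})
        return config


sampler = ToyInBatchSampler(batch_size=32, name="sampler")
# Rebuilding from the config is what happens (in spirit) when a saved model is loaded.
restored = ToyInBatchSampler.from_config(sampler.get_config())
```

Without the `get_config()` override, `from_config` would call the constructor without `batch_size` and fail, which is the kind of error a custom layer produces when its model is saved and reloaded.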

@oliverholworthy
Member

Saving ItemRetrievalTask and its related layers is now resolved by #594 and #615.

Loading is still a problem and is tracked in #498.
