diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index 09823866189c..9bd59f99e3da 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -337,7 +337,8 @@ def load(self, key: str) -> _ModelLoadStats: Args: key: the key associated with the model we'd like to load. Returns: - the tag we can use to access the model using multi_process_shared.py. + _ModelLoadStats with tag, byte size, and latency to load the model. If + the model was already loaded, byte size/latency will be None. """ # Map the key for a model to a unique tag that will persist until the model # is released. This needs to be unique between releasing/reacquiring th