diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index 80b73c006752..90d43cfddb94 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -304,21 +304,15 @@ class _ModelManager:
   """
   A class for efficiently managing copies of multiple models. Will load a
   single copy of each model into a multi_process_shared object and then
-  return a lookup key for that object. Optionally takes in a max_models
-  parameter, if that is set it will only hold that many models in memory at
-  once before evicting one (using LRU logic).
+  return a lookup key for that object.
   """
-  def __init__(
-      self, mh_map: Dict[str, ModelHandler], max_models: Optional[int] = None):
+  def __init__(self, mh_map: Dict[str, ModelHandler]):
     """
     Args:
       mh_map: A map from keys to model handlers which can be used to load a
         model.
-      max_models: The maximum number of models to load at any given time
-        before evicting 1 from memory (using LRU logic). Leave as None to
-        allow unlimited models.
     """
-    self._max_models = max_models
+    self._max_models = None
     # Map keys to model handlers
     self._mh_map: Dict[str, ModelHandler] = mh_map
     # Map keys to the last updated model path for that key
@@ -376,14 +370,12 @@ def load(self, key: str) -> _ModelLoadStats:
   def increment_max_models(self, increment: int):
     """
     Increments the number of models that this instance of a _ModelManager is
-    able to hold.
+    able to hold. If it is never called, no limit is imposed.
     Args:
       increment: the amount by which we are incrementing the number of models.
     """
     if self._max_models is None:
-      raise ValueError(
-          "Cannot increment max_models if self._max_models is None (unlimited" +
-          " models mode).")
+      self._max_models = 0
     self._max_models += increment
 
   def update_model_handler(self, key: str, model_path: str, previous_key: str):
@@ -436,7 +428,8 @@ def __init__(
       self,
       unkeyed: Union[ModelHandler[ExampleT, PredictionT, ModelT],
                      List[KeyModelMapping[KeyT, ExampleT, PredictionT,
-                                          ModelT]]]):
+                                          ModelT]]],
+      max_models_per_worker_hint: Optional[int] = None):
     """A ModelHandler that takes keyed examples and returns keyed predictions.
 
     For example, if the original model is used with RunInference to take a
@@ -494,6 +487,11 @@ def __init__(
       unkeyed: Either (a) an implementation of ModelHandler that does not
         require keys or (b) a list of KeyModelMappings mapping lists of
         keys to unkeyed ModelHandlers.
+      max_models_per_worker_hint: A hint to the runner indicating how many
+        models can be held in memory at one time per worker process. For
+        example, if your worker has 8 GB of memory provisioned and your models
+        take up 1 GB each, you should set this to 7 to allow all models to sit
+        in memory with some buffer.
     """
     self._metrics_collectors: Dict[str, _MetricsCollector] = {}
     self._default_metrics_collector: _MetricsCollector = None
@@ -511,6 +509,7 @@ def __init__(
       self._unkeyed = unkeyed
       return
 
+    self._max_models_per_worker_hint = max_models_per_worker_hint
     # To maintain an efficient representation, we will map all keys in a given
     # KeyModelMapping to a single id (the first key in the KeyModelMapping
     # list). We will then map that key to a ModelHandler. This will allow us to
@@ -587,6 +586,14 @@ def run_inference(
           keys,
           self._unkeyed.run_inference(unkeyed_batch, model, inference_args))
 
+    # The first time a MultiProcessShared ModelManager is used for inference
+    # from this process, we should increment its max model count
+    if self._max_models_per_worker_hint is not None:
+      lock = threading.Lock()
+      if lock.acquire(blocking=False):
+        model.increment_max_models(self._max_models_per_worker_hint)
+      self._max_models_per_worker_hint = None
+
     batch_by_key = defaultdict(list)
     key_by_id = defaultdict(set)
     for key, example in batch:
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index f2146cdd1e56..b93f1e185b9d 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -346,6 +346,41 @@ def test_run_inference_impl_with_keyed_examples_many_model_handlers_metrics(
     load_latency_dist_aggregate = metrics['distributions'][0]
     self.assertEqual(load_latency_dist_aggregate.committed.count, 2)
 
+  def test_run_inference_impl_with_keyed_examples_many_mhs_max_models_hint(
+      self):
+    pipeline = TestPipeline()
+    examples = [1, 5, 3, 10, 2, 4, 6, 8, 9, 7, 1, 5, 3, 10, 2, 4, 6, 8, 9, 7]
+    metrics_namespace = 'test_namespace'
+    keyed_examples = [(i, example) for i, example in enumerate(examples)]
+    pcoll = pipeline | 'start' >> beam.Create(keyed_examples)
+    mhs = [
+        base.KeyModelMapping([0, 2, 4, 6, 8],
+                             FakeModelHandler(
+                                 state=200, multi_process_shared=True)),
+        base.KeyModelMapping(
+            [1, 3, 5, 7, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
+            FakeModelHandler(multi_process_shared=True))
+    ]
+    _ = pcoll | base.RunInference(
+        base.KeyedModelHandler(mhs, max_models_per_worker_hint=1),
+        metrics_namespace=metrics_namespace)
+    result = pipeline.run()
+    result.wait_until_finish()
+
+    metrics_filter = MetricsFilter().with_namespace(namespace=metrics_namespace)
+    metrics = result.metrics().query(metrics_filter)
+    assert len(metrics['counters']) != 0
+    assert len(metrics['distributions']) != 0
+
+    metrics_filter = MetricsFilter().with_name('load_model_latency_milli_secs')
+    metrics = result.metrics().query(metrics_filter)
+    load_latency_dist_aggregate = metrics['distributions'][0]
+    # We should flip back and forth between models a bit since
+    # max_models_per_worker_hint=1, but we shouldn't thrash forever
+    # since most examples belong to the second ModelMapping
+    self.assertGreater(load_latency_dist_aggregate.committed.count, 2)
+    self.assertLess(load_latency_dist_aggregate.committed.count, 12)
+
   def test_keyed_many_model_handlers_validation(self):
     def mult_two(example: str) -> int:
       return int(example) * 2
@@ -1367,7 +1402,8 @@ def test_model_manager_evicts_models(self):
     mh2 = FakeModelHandler(state=2)
     mh3 = FakeModelHandler(state=3)
     mhs = {'key1': mh1, 'key2': mh2, 'key3': mh3}
-    mm = base._ModelManager(mh_map=mhs, max_models=2)
+    mm = base._ModelManager(mh_map=mhs)
+    mm.increment_max_models(2)
     tag1 = mm.load('key1').model_tag
     sh1 = multi_process_shared.MultiProcessShared(mh1.load_model, tag=tag1)
     model1 = sh1.acquire()
@@ -1441,7 +1477,8 @@ def test_model_manager_evicts_correct_num_of_models_after_being_incremented(
     mh2 = FakeModelHandler(state=2)
     mh3 = FakeModelHandler(state=3)
     mhs = {'key1': mh1, 'key2': mh2, 'key3': mh3}
-    mm = base._ModelManager(mh_map=mhs, max_models=1)
+    mm = base._ModelManager(mh_map=mhs)
+    mm.increment_max_models(1)
     mm.increment_max_models(1)
     tag1 = mm.load('key1').model_tag
     sh1 = multi_process_shared.MultiProcessShared(mh1.load_model, tag=tag1)
@@ -1477,11 +1514,6 @@ def test_model_manager_evicts_correct_num_of_models_after_being_incremented(
         mh3.load_model, tag=tag3).acquire()
     self.assertEqual(8, model3.predict(10))
 
-  def test_model_manager_fails_if_no_default_initially(self):
-    mm = base._ModelManager(mh_map={})
-    with self.assertRaisesRegex(ValueError, r'self._max_models is None'):
-      mm.increment_max_models(5)
-
 
 if __name__ == '__main__':
   unittest.main()
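
For context, below is a minimal, hypothetical sketch of how the new max_models_per_worker_hint argument is meant to be used from a pipeline. The model paths, the keys, and the choice of SklearnModelHandlerNumpy are illustrative assumptions and are not part of this change; any unkeyed ModelHandler that works with KeyedModelHandler applies.

import numpy
import apache_beam as beam
from apache_beam.ml.inference.base import KeyModelMapping
from apache_beam.ml.inference.base import KeyedModelHandler
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy

# Hypothetical model paths; each KeyModelMapping routes its keys to one model.
mhs = [
    KeyModelMapping(['en'],
                    SklearnModelHandlerNumpy(model_uri='gs://bucket/en.pkl')),
    KeyModelMapping(['fr'],
                    SklearnModelHandlerNumpy(model_uri='gs://bucket/fr.pkl')),
]

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create([('en', numpy.array([1.0])), ('fr', numpy.array([2.0]))])
      # With the hint set to 1, at most one model is kept in memory per
      # worker process; the other is evicted (LRU) and reloaded on demand.
      | RunInference(KeyedModelHandler(mhs, max_models_per_worker_hint=1)))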