Add support for limiting number of models in memory #28263

Merged: 4 commits, Aug 31, 2023 (changes shown from 3 commits)
29 changes: 19 additions & 10 deletions sdks/python/apache_beam/ml/inference/base.py
@@ -308,17 +308,13 @@ class _ModelManager:
parameter, if that is set it will only hold that many models in memory at
once before evicting one (using LRU logic).
"""
- def __init__(
- self, mh_map: Dict[str, ModelHandler], max_models: Optional[int] = None):
+ def __init__(self, mh_map: Dict[str, ModelHandler]):
"""
Args:
mh_map: A map from keys to model handlers which can be used to load a
model.
- max_models: The maximum number of models to load at any given time
- before evicting 1 from memory (using LRU logic). Leave as None to
- allow unlimited models.
"""
- self._max_models = max_models
+ self._max_models = None
# Map keys to model handlers
self._mh_map: Dict[str, ModelHandler] = mh_map
# Map keys to the last updated model path for that key
@@ -381,9 +377,7 @@ def increment_max_models(self, increment: int):
increment: the amount by which we are incrementing the number of models.
"""
if self._max_models is None:
- raise ValueError(
- "Cannot increment max_models if self._max_models is None (unlimited" +
- " models mode).")
+ self._max_models = 0
self._max_models += increment

def update_model_handler(self, key: str, model_path: str, previous_key: str):
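
With this change the manager's limit is no longer fixed at construction time: it starts out unset and is grown via increment_max_models, which now initializes the limit to 0 instead of raising. A minimal sketch of the resulting call pattern, assuming the FakeModelHandler test helper from base_test.py below (the keys and the limit of 1 are illustrative):

from apache_beam.ml.inference import base

# FakeModelHandler is the helper defined in base_test.py; any ModelHandler works here.
mm = base._ModelManager(
    mh_map={'key1': FakeModelHandler(state=1), 'key2': FakeModelHandler(state=2)})
mm.increment_max_models(1)         # limit goes from unset to 1
tag1 = mm.load('key1').model_tag   # loads key1's model
tag2 = mm.load('key2').model_tag   # limit of 1 exceeded, so key1's model is evicted (LRU)
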
@@ -436,7 +430,8 @@ def __init__(
self,
unkeyed: Union[ModelHandler[ExampleT, PredictionT, ModelT],
List[KeyModelMapping[KeyT, ExampleT, PredictionT,
- ModelT]]]):
+ ModelT]]],
+ max_models_per_worker_hint: Optional[int] = None):
"""A ModelHandler that takes keyed examples and returns keyed predictions.

For example, if the original model is used with RunInference to take a
@@ -494,6 +489,11 @@ def __init__(
unkeyed: Either (a) an implementation of ModelHandler that does not
require keys or (b) a list of KeyModelMappings mapping lists of keys to
unkeyed ModelHandlers.
+ max_models_per_worker_hint: A hint to the runner indicating how many
+ models can be held in memory at one time per worker process. For
+ example, if your worker has 8 GB of memory provisioned and your models
+ take up 1 GB each, you should set this to 7 to allow all models to sit
+ in memory with some buffer.
"""
self._metrics_collectors: Dict[str, _MetricsCollector] = {}
self._default_metrics_collector: _MetricsCollector = None
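
In pipeline terms, the new argument is simply passed to the KeyedModelHandler constructor alongside the KeyModelMapping list; the new test further down exercises exactly this. A rough usage sketch (MyModelHandler, the model paths, and the hint value of 2 are placeholders, not part of this diff):

import apache_beam as beam
from apache_beam.ml.inference import base

# MyModelHandler stands in for any unkeyed ModelHandler implementation.
keyed_mh = base.KeyedModelHandler(
    [
        base.KeyModelMapping(['a', 'b'], MyModelHandler('path/to/model_1')),
        base.KeyModelMapping(['c'], MyModelHandler('path/to/model_2')),
    ],
    max_models_per_worker_hint=2)

with beam.Pipeline() as p:
    _ = (
        p
        | beam.Create([('a', 1), ('b', 2), ('c', 3)])
        | base.RunInference(keyed_mh))
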
@@ -511,6 +511,7 @@ def __init__(
self._unkeyed = unkeyed
return

+ self._max_models_per_worker_hint = max_models_per_worker_hint
# To maintain an efficient representation, we will map all keys in a given
# KeyModelMapping to a single id (the first key in the KeyModelMapping
# list). We will then map that key to a ModelHandler. This will allow us to
@@ -587,6 +588,14 @@ def run_inference(
keys,
self._unkeyed.run_inference(unkeyed_batch, model, inference_args))

+ # The first time a MultiProcessShared ModelManager is used for inference
+ # from this process, we should increment its max model count
+ if self._max_models_per_worker_hint is not None:
+ lock = threading.Lock()
+ if lock.acquire(blocking=False):
+ model.increment_max_models(self._max_models_per_worker_hint)
+ self._max_models_per_worker_hint = None

batch_by_key = defaultdict(list)
key_by_id = defaultdict(set)
for key, example in batch:
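
The sizing guidance in the max_models_per_worker_hint docstring above reduces to simple arithmetic: budget one model slot fewer than would fit exactly, leaving some headroom. A hypothetical helper, not part of this PR, that mirrors the 8 GB worker / 1 GB models -> 7 example:

def max_models_hint_for(worker_memory_gb: float, model_memory_gb: float) -> int:
  # Hold one model fewer than would fit exactly, matching the 8 GB / 1 GB -> 7
  # example in the max_models_per_worker_hint docstring (hypothetical helper).
  return max(1, int(worker_memory_gb // model_memory_gb) - 1)
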
46 changes: 39 additions & 7 deletions sdks/python/apache_beam/ml/inference/base_test.py
@@ -346,6 +346,41 @@ def test_run_inference_impl_with_keyed_examples_many_model_handlers_metrics(
load_latency_dist_aggregate = metrics['distributions'][0]
self.assertEqual(load_latency_dist_aggregate.committed.count, 2)

+ def test_run_inference_impl_with_keyed_examples_many_mhs_max_models_hint(
+ self):
+ pipeline = TestPipeline()
+ examples = [1, 5, 3, 10, 2, 4, 6, 8, 9, 7, 1, 5, 3, 10, 2, 4, 6, 8, 9, 7]
+ metrics_namespace = 'test_namespace'
+ keyed_examples = [(i, example) for i, example in enumerate(examples)]
+ pcoll = pipeline | 'start' >> beam.Create(keyed_examples)
+ mhs = [
+ base.KeyModelMapping([0, 2, 4, 6, 8],
+ FakeModelHandler(
+ state=200, multi_process_shared=True)),
+ base.KeyModelMapping(
+ [1, 3, 5, 7, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
+ FakeModelHandler(multi_process_shared=True))
+ ]
+ _ = pcoll | base.RunInference(
+ base.KeyedModelHandler(mhs, max_models_per_worker_hint=1),
+ metrics_namespace=metrics_namespace)
+ result = pipeline.run()
+ result.wait_until_finish()

+ metrics_filter = MetricsFilter().with_namespace(namespace=metrics_namespace)
+ metrics = result.metrics().query(metrics_filter)
+ assert len(metrics['counters']) != 0
+ assert len(metrics['distributions']) != 0

+ metrics_filter = MetricsFilter().with_name('load_model_latency_milli_secs')
+ metrics = result.metrics().query(metrics_filter)
+ load_latency_dist_aggregate = metrics['distributions'][0]
+ # We should flip back and forth between models a bit since
+ # max_models_per_worker_hint=1, but we shouldn't thrash forever
+ # since most examples belong to the second ModelMapping
+ self.assertGreater(load_latency_dist_aggregate.committed.count, 2)
+ self.assertLess(load_latency_dist_aggregate.committed.count, 12)

def test_keyed_many_model_handlers_validation(self):
def mult_two(example: str) -> int:
return int(example) * 2
@@ -1367,7 +1402,8 @@ def test_model_manager_evicts_models(self):
mh2 = FakeModelHandler(state=2)
mh3 = FakeModelHandler(state=3)
mhs = {'key1': mh1, 'key2': mh2, 'key3': mh3}
- mm = base._ModelManager(mh_map=mhs, max_models=2)
+ mm = base._ModelManager(mh_map=mhs)
+ mm.increment_max_models(2)
tag1 = mm.load('key1').model_tag
sh1 = multi_process_shared.MultiProcessShared(mh1.load_model, tag=tag1)
model1 = sh1.acquire()
@@ -1441,7 +1477,8 @@ def test_model_manager_evicts_correct_num_of_models_after_being_incremented(
mh2 = FakeModelHandler(state=2)
mh3 = FakeModelHandler(state=3)
mhs = {'key1': mh1, 'key2': mh2, 'key3': mh3}
- mm = base._ModelManager(mh_map=mhs, max_models=1)
+ mm = base._ModelManager(mh_map=mhs)
+ mm.increment_max_models(1)
mm.increment_max_models(1)
tag1 = mm.load('key1').model_tag
sh1 = multi_process_shared.MultiProcessShared(mh1.load_model, tag=tag1)
@@ -1477,11 +1514,6 @@ def test_model_manager_evicts_correct_num_of_models_after_being_incremented(
mh3.load_model, tag=tag3).acquire()
self.assertEqual(8, model3.predict(10))

- def test_model_manager_fails_if_no_default_initially(self):
- mm = base._ModelManager(mh_map={})
- with self.assertRaisesRegex(ValueError, r'self._max_models is None'):
- mm.increment_max_models(5)


if __name__ == '__main__':
unittest.main()