From 6865f13b241a88bc51bad1fd3fe18d59b1fe4636 Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Tue, 29 Aug 2023 16:54:35 -0400 Subject: [PATCH 1/3] Update KeyMhMapping to KeyModelMapping --- ...ytorch_model_per_key_image_segmentation.py | 4 +- sdks/python/apache_beam/ml/inference/base.py | 17 ++--- .../apache_beam/ml/inference/base_test.py | 62 ++++++++++--------- 3 files changed, 43 insertions(+), 40 deletions(-) diff --git a/sdks/python/apache_beam/examples/inference/pytorch_model_per_key_image_segmentation.py b/sdks/python/apache_beam/examples/inference/pytorch_model_per_key_image_segmentation.py index e09a348511b2..f0b5462d5335 100644 --- a/sdks/python/apache_beam/examples/inference/pytorch_model_per_key_image_segmentation.py +++ b/sdks/python/apache_beam/examples/inference/pytorch_model_per_key_image_segmentation.py @@ -33,7 +33,7 @@ import torch from apache_beam.io.filesystems import FileSystems from apache_beam.ml.inference.base import KeyedModelHandler -from apache_beam.ml.inference.base import KeyMhMapping +from apache_beam.ml.inference.base import KeyModelMapping from apache_beam.ml.inference.base import PredictionResult from apache_beam.ml.inference.base import RunInference from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor @@ -257,7 +257,7 @@ def run( # Note that multiple keys can also point to a single model handler, # unlike this example. model_handler = KeyedModelHandler( - [KeyMhMapping(['v1'], mh1), KeyMhMapping(['v2'], mh2)]) + [KeyModelMapping(['v1'], mh1), KeyModelMapping(['v2'], mh2)]) pipeline = test_pipeline if not test_pipeline: diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index 10e8981d8bdf..b5aa4f352fa6 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -396,10 +396,10 @@ def update_model_handler(self, key: str, model_path: str, previous_key: str): # Use a dataclass instead of named tuple because NamedTuples and generics don't # mix well across the board for all versions: # https://github.com/python/typing/issues/653 -class KeyMhMapping(Generic[KeyT, ExampleT, PredictionT, ModelT]): +class KeyModelMapping(Generic[KeyT, ExampleT, PredictionT, ModelT]): """ - Dataclass for mapping 1 or more keys to 1 model handler. - Given `KeyMhMapping(['key1', 'key2'], myMh)`, all examples with keys `key1` + Dataclass for mapping 1 or more keys to 1 model handler. Given + `KeyModelMapping(['key1', 'key2'], myMh)`, all examples with keys `key1` or `key2` will be run against the model defined by the `myMh` ModelHandler. """ def __init__( @@ -415,7 +415,8 @@ class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT], def __init__( self, unkeyed: Union[ModelHandler[ExampleT, PredictionT, ModelT], - List[KeyMhMapping[KeyT, ExampleT, PredictionT, ModelT]]]): + List[KeyModelMapping[KeyT, ExampleT, PredictionT, + ModelT]]]): """A ModelHandler that takes keyed examples and returns keyed predictions. 
For example, if the original model is used with RunInference to take a @@ -429,7 +430,7 @@ def __init__( k1 = ['k1', 'k2', 'k3'] k2 = ['k4', 'k5'] - KeyedModelHandler([KeyMhMapping(k1, mh1), KeyMhMapping(k2, mh2)]) + KeyedModelHandler([KeyModelMapping(k1, mh1), KeyModelMapping(k2, mh2)]) Note that a single copy of each of these models may all be held in memory at the same time; be careful not to load too many large models or your @@ -462,7 +463,7 @@ def __init__( Args: unkeyed: Either (a) an implementation of ModelHandler that does not - require keys or (b) a list of KeyMhMappings mapping lists of keys to + require keys or (b) a list of KeyModelMappings mapping lists of keys to unkeyed ModelHandlers. """ self._single_model = not isinstance(unkeyed, list) @@ -479,8 +480,8 @@ def __init__( return # To maintain an efficient representation, we will map all keys in a given - # KeyMhMapping to a single id (the first key in the KeyMhMapping list). - # We will then map that key to a ModelHandler. This will allow us to + # KeyModelMapping to a single id (the first key in the KeyModelMapping + # list). We will then map that key to a ModelHandler. This will allow us to # quickly look up the appropriate ModelHandler for any given key. self._id_to_mh_map: Dict[str, ModelHandler[ExampleT, PredictionT, ModelT]] = {} diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py index 4b551ce55847..af6168c80aff 100644 --- a/sdks/python/apache_beam/ml/inference/base_test.py +++ b/sdks/python/apache_beam/ml/inference/base_test.py @@ -277,11 +277,11 @@ def test_run_inference_impl_with_keyed_examples_many_model_handlers(self): expected[0] = (0, 200) pcoll = pipeline | 'start' >> beam.Create(keyed_examples) mhs = [ - base.KeyMhMapping([0], - FakeModelHandler( - state=200, multi_process_shared=True)), - base.KeyMhMapping([1, 2, 3], - FakeModelHandler(multi_process_shared=True)) + base.KeyModelMapping([0], + FakeModelHandler( + state=200, multi_process_shared=True)), + base.KeyModelMapping([1, 2, 3], + FakeModelHandler(multi_process_shared=True)) ] actual = pcoll | base.RunInference(base.KeyedModelHandler(mhs)) assert_that(actual, equal_to(expected), label='assert:inferences') @@ -291,45 +291,45 @@ def mult_two(example: str) -> int: return int(example) * 2 mhs = [ - base.KeyMhMapping( + base.KeyModelMapping( [0], FakeModelHandler( state=200, multi_process_shared=True).with_preprocess_fn(mult_two)), - base.KeyMhMapping([1, 2, 3], - FakeModelHandler(multi_process_shared=True)) + base.KeyModelMapping([1, 2, 3], + FakeModelHandler(multi_process_shared=True)) ] with self.assertRaises(ValueError): base.KeyedModelHandler(mhs) mhs = [ - base.KeyMhMapping( + base.KeyModelMapping( [0], FakeModelHandler( state=200, multi_process_shared=True).with_postprocess_fn(mult_two)), - base.KeyMhMapping([1, 2, 3], - FakeModelHandler(multi_process_shared=True)) + base.KeyModelMapping([1, 2, 3], + FakeModelHandler(multi_process_shared=True)) ] with self.assertRaises(ValueError): base.KeyedModelHandler(mhs) mhs = [ - base.KeyMhMapping([0], - FakeModelHandler( - state=200, multi_process_shared=True)), - base.KeyMhMapping([0, 1, 2, 3], - FakeModelHandler(multi_process_shared=True)) + base.KeyModelMapping([0], + FakeModelHandler( + state=200, multi_process_shared=True)), + base.KeyModelMapping([0, 1, 2, 3], + FakeModelHandler(multi_process_shared=True)) ] with self.assertRaises(ValueError): base.KeyedModelHandler(mhs) mhs = [ - base.KeyMhMapping([], - FakeModelHandler( - state=200, 
multi_process_shared=True)), - base.KeyMhMapping([0, 1, 2, 3], - FakeModelHandler(multi_process_shared=True)) + base.KeyModelMapping([], + FakeModelHandler( + state=200, multi_process_shared=True)), + base.KeyModelMapping([0, 1, 2, 3], + FakeModelHandler(multi_process_shared=True)) ] with self.assertRaises(ValueError): base.KeyedModelHandler(mhs) @@ -343,8 +343,10 @@ def test_keyed_model_handler_get_num_bytes(self): def test_keyed_model_handler_multiple_models_get_num_bytes(self): mhs = [ - base.KeyMhMapping(['key1'], FakeModelHandler(num_bytes_per_element=10)), - base.KeyMhMapping(['key2'], FakeModelHandler(num_bytes_per_element=20)) + base.KeyModelMapping(['key1'], + FakeModelHandler(num_bytes_per_element=10)), + base.KeyModelMapping(['key2'], + FakeModelHandler(num_bytes_per_element=20)) ] mh = base.KeyedModelHandler(mhs) batch = [('key1', 1), ('key2', 2), ('key1', 3)] @@ -1010,12 +1012,12 @@ def test_run_inference_side_input_in_batch_per_key_models(self): ] model_handler = base.KeyedModelHandler([ - base.KeyMhMapping(['key1'], - FakeModelHandlerReturnsPredictionResult( - multi_process_shared=True, state=True)), - base.KeyMhMapping(['key2'], - FakeModelHandlerReturnsPredictionResult( - multi_process_shared=True, state=True)) + base.KeyModelMapping(['key1'], + FakeModelHandlerReturnsPredictionResult( + multi_process_shared=True, state=True)), + base.KeyModelMapping(['key2'], + FakeModelHandlerReturnsPredictionResult( + multi_process_shared=True, state=True)) ]) class _EmitElement(beam.DoFn): @@ -1114,7 +1116,7 @@ def test_run_inference_side_input_in_batch_per_key_models_split_cohort(self): ] model_handler = base.KeyedModelHandler([ - base.KeyMhMapping( + base.KeyModelMapping( ['key1', 'key2'], FakeModelHandlerReturnsPredictionResult(multi_process_shared=True)) ]) From 1c00f1aafde97ce612fa7c519f4a2f5df30e99e5 Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Wed, 30 Aug 2023 12:29:22 -0400 Subject: [PATCH 2/3] Keep track of metrics at the KeyedModelHandler level --- sdks/python/apache_beam/ml/inference/base.py | 123 +++++++++++++++--- .../apache_beam/ml/inference/base_test.py | 84 ++++++++++-- 2 files changed, 175 insertions(+), 32 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index b5aa4f352fa6..09823866189c 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -108,6 +108,12 @@ class RunInferenceDLQ(NamedTuple): failed_postprocessing: Sequence[beam.PCollection] +class _ModelLoadStats(NamedTuple): + model_tag: str + load_latency: Optional[int] + byte_size: Optional[int] + + ModelMetadata.model_id.__doc__ = """Unique identifier for the model. This can be a file path or a URL where the model can be accessed. It is used to load the model for inference.""" @@ -130,17 +136,18 @@ class KeyModelPathMapping(Generic[KeyT]): Dataclass for mapping 1 or more keys to 1 model path. This is used in conjunction with a KeyedModelHandler with many model handlers to update a set of keys' model handlers with the new path. Given - `KeyModelPathMapping(keys: ['key1', 'key2'], update_path: 'updated/path')`, - all examples with keys `key1` or `key2` will have their corresponding model - handler's update_model function called with 'updated/path'. 
For more
-  information see the
-  KeyedModelHandler documentation
+  `KeyModelPathMapping(keys: ['key1', 'key2'], update_path: 'updated/path',
+  model_id: 'id1')`, all examples with keys `key1` or `key2` will have their
+  corresponding model handler's update_model function called with
+  'updated/path' and their metrics will correspond with 'id1'. For more
+  information see the KeyedModelHandler documentation
   https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.KeyedModelHandler
   documentation and the website section on model updates
   https://beam.apache.org/documentation/sdks/python-machine-learning/#automatic-model-refresh
   """
   keys: List[KeyT]
   update_path: str
+  model_id: str = ''
 
 
 class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
@@ -286,6 +293,12 @@ def share_model_across_processes(self) -> bool:
     https://beam.apache.org/releases/pydoc/current/apache_beam.utils.multi_process_shared.html"""
     return False
 
+  def override_metrics(self, metrics_namespace: str = '') -> bool:
+    """Returns a boolean representing whether or not a model handler will
+    override metrics reporting. If True, RunInference will not report any
+    metrics."""
+    return False
+
 
 class _ModelManager:
   """
@@ -318,7 +331,7 @@ def __init__(
     # of this map should last as long as the corresponding entry in _tag_map.
     self._proxy_map: Dict[str, multi_process_shared.MultiProcessShared] = {}
 
-  def load(self, key: str) -> str:
+  def load(self, key: str) -> _ModelLoadStats:
     """
     Loads the appropriate model for the given key into memory.
     Args:
@@ -332,6 +345,7 @@ def load(self, key: str) -> str:
     # has been released and deleted.
     if key in self._tag_map:
       self._tag_map.move_to_end(key)
+      return _ModelLoadStats(self._tag_map[key], None, None)
     else:
       self._tag_map[key] = uuid.uuid4().hex
 
@@ -346,12 +360,17 @@ def load(self, key: str) -> str:
       del self._proxy_map[tag_to_remove]
 
     # Load the new model
+    memory_before = _get_current_process_memory_in_bytes()
+    start_time = _to_milliseconds(time.time_ns())
     shared_handle = multi_process_shared.MultiProcessShared(
         mh.load_model, tag=tag)
     model_reference = shared_handle.acquire()
     self._proxy_map[tag] = (shared_handle, model_reference)
+    memory_after = _get_current_process_memory_in_bytes()
+    end_time = _to_milliseconds(time.time_ns())
 
-    return tag
+    return _ModelLoadStats(
+        tag, end_time - start_time, memory_after - memory_before)
 
   def increment_max_models(self, increment: int):
     """
@@ -460,12 +479,24 @@ def __init__(
       must appear in your list of KeyModelPathMappings exactly once. No
       additional keys can be added.
 
+    When using many models defined per key, metrics about inference and model
+    loading will be gathered on an aggregate basis for all keys. These will be
+    reported with no prefix. Metrics will also be gathered on a per key basis.
+    Since some keys can share the same model, only one set of metrics will be
+    reported per key 'cohort'. These will be reported in the form:
+    `<cohort_key>-<metric_name>`, where `<cohort_key>` can be any key selected
+    from the cohort. When model updates occur, the metrics will be reported in
+    the form `<cohort_key>-<model id>-<metric_name>`.
+
     Args:
       unkeyed: Either (a) an implementation of ModelHandler that does not
         require keys or (b) a list of KeyModelMappings mapping lists of keys to
        unkeyed ModelHandlers.
""" + self._metrics_collectors: Dict[str, _MetricsCollector] = {} + self._default_metrics_collector: _MetricsCollector = None + self._metrics_namespace = '' self._single_model = not isinstance(unkeyed, list) if self._single_model: if len(unkeyed.get_preprocess_fns()) or len( @@ -564,17 +595,41 @@ def run_inference( predictions = [] for id, keys in key_by_id.items(): mh = self._id_to_mh_map[id] - keyed_model_tag = model.load(id) + loaded_model = model.load(id) + keyed_model_tag = loaded_model.model_tag + if loaded_model.byte_size is not None: + self._metrics_collectors[id].update_load_model_metrics( + loaded_model.load_latency, loaded_model.byte_size) + self._default_metrics_collector.update_load_model_metrics( + loaded_model.load_latency, loaded_model.byte_size) keyed_model_shared_handle = multi_process_shared.MultiProcessShared( mh.load_model, tag=keyed_model_tag) keyed_model = keyed_model_shared_handle.acquire() - for key in keys: - unkeyed_batches = batch_by_key[key] - for inf in mh.run_inference(unkeyed_batches, - keyed_model, - inference_args): - predictions.append((key, inf)) - keyed_model_shared_handle.release(keyed_model) + start_time = _to_microseconds(time.time_ns()) + num_bytes = 0 + num_elements = 0 + try: + for key in keys: + unkeyed_batches = batch_by_key[key] + try: + for inf in mh.run_inference(unkeyed_batches, + keyed_model, + inference_args): + predictions.append((key, inf)) + except BaseException as e: + self._metrics_collectors[id].failed_batches_counter.inc() + self._default_metrics_collector.failed_batches_counter.inc() + raise e + num_bytes += mh.get_num_bytes(unkeyed_batches) + num_elements += len(unkeyed_batches) + finally: + keyed_model_shared_handle.release(keyed_model) + end_time = _to_microseconds(time.time_ns()) + inference_latency = end_time - start_time + self._metrics_collectors[id].update( + num_elements, num_bytes, inference_latency) + self._default_metrics_collector.update( + num_elements, num_bytes, inference_latency) return predictions @@ -641,10 +696,12 @@ def update_model_paths( # } # } cohort_path_mapping: Dict[KeyT, Dict[str, List[KeyT]]] = {} + key_modelid_mapping: Dict[KeyT, str] = {} seen_keys = set() for mp in model_paths: keys = mp.keys update_path = mp.update_path + model_id = mp.model_id if len(update_path) == 0: raise ValueError(f'Invalid model update, path for {keys} is empty') for key in keys: @@ -658,6 +715,7 @@ def update_model_paths( raise ValueError( f'Invalid model update: {key} appears in ' 'update, but not in the original configuration.') + key_modelid_mapping[key] = model_id cohort_id = self._key_to_id_map[key] if cohort_id not in cohort_path_mapping: cohort_path_mapping[cohort_id] = defaultdict(list) @@ -682,6 +740,9 @@ def update_model_paths( self._id_to_mh_map[cohort_id] = deepcopy(mh) self._id_to_mh_map[cohort_id].update_model_path(updated_path) model.update_model_handler(cohort_id, updated_path, old_cohort_id) + model_id = key_modelid_mapping[cohort_id] + self._metrics_collectors[cohort_id] = _MetricsCollector( + self._metrics_namespace, f'{cohort_id}-{model_id}-') def update_model_path(self, model_path: Optional[str] = None): if self._single_model: @@ -697,6 +758,18 @@ def share_model_across_processes(self) -> bool: return self._unkeyed.share_model_across_processes() return True + def override_metrics(self, metrics_namespace: str = '') -> bool: + if self._single_model: + return self._unkeyed.override_metrics(metrics_namespace) + + self._metrics_namespace = metrics_namespace + self._default_metrics_collector = 
_MetricsCollector(metrics_namespace) + for cohort_id in self._id_to_mh_map: + self._metrics_collectors[cohort_id] = _MetricsCollector( + metrics_namespace, f'{cohort_id}-') + + return True + class MaybeKeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT], ModelHandler[Union[ExampleT, Tuple[KeyT, @@ -1179,6 +1252,10 @@ def cache_load_model_metrics(self, load_model_latency_ms, model_byte_size): self._load_model_latency_milli_secs_cache = load_model_latency_ms self._model_byte_size_cache = model_byte_size + def update_load_model_metrics(self, load_model_latency_ms, model_byte_size): + self._load_model_latency_milli_secs.update(load_model_latency_ms) + self._model_byte_size.update(model_byte_size) + def update( self, examples_count: int, @@ -1235,8 +1312,9 @@ def load(): memory_after = _get_current_process_memory_in_bytes() load_model_latency_ms = end_time - start_time model_byte_size = memory_after - memory_before - self._metrics_collector.cache_load_model_metrics( - load_model_latency_ms, model_byte_size) + if self._metrics_collector: + self._metrics_collector.cache_load_model_metrics( + load_model_latency_ms, model_byte_size) return model # TODO(https://github.com/apache/beam/issues/21443): Investigate releasing @@ -1268,6 +1346,8 @@ def get_metrics_collector(self, prefix: str = ''): metrics_namespace = ( self._metrics_namespace) if self._metrics_namespace else ( self._model_handler.get_metrics_namespace()) + if self._model_handler.override_metrics(metrics_namespace): + return None return _MetricsCollector(metrics_namespace, prefix=prefix) def setup(self): @@ -1288,7 +1368,8 @@ def _run_inference(self, batch, inference_args): result_generator = self._model_handler.run_inference( batch, self._model, inference_args) except BaseException as e: - self._metrics_collector.failed_batches_counter.inc() + if self._metrics_collector: + self._metrics_collector.failed_batches_counter.inc() raise e predictions = list(result_generator) @@ -1296,7 +1377,8 @@ def _run_inference(self, batch, inference_args): inference_latency = end_time - start_time num_bytes = self._model_handler.get_num_bytes(batch) num_elements = len(batch) - self._metrics_collector.update(num_elements, num_bytes, inference_latency) + if self._metrics_collector: + self._metrics_collector.update(num_elements, num_bytes, inference_latency) return predictions @@ -1336,7 +1418,8 @@ def process( def finish_bundle(self): # TODO(https://github.com/apache/beam/issues/21435): Figure out why there # is a cache. 
- self._metrics_collector.update_metrics_with_cache() + if self._metrics_collector: + self._metrics_collector.update_metrics_with_cache() def _is_darwin() -> bool: diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py index af6168c80aff..f2146cdd1e56 100644 --- a/sdks/python/apache_beam/ml/inference/base_test.py +++ b/sdks/python/apache_beam/ml/inference/base_test.py @@ -286,6 +286,66 @@ def test_run_inference_impl_with_keyed_examples_many_model_handlers(self): actual = pcoll | base.RunInference(base.KeyedModelHandler(mhs)) assert_that(actual, equal_to(expected), label='assert:inferences') + def test_run_inference_impl_with_keyed_examples_many_model_handlers_metrics( + self): + pipeline = TestPipeline() + examples = [1, 5, 3, 10] + metrics_namespace = 'test_namespace' + keyed_examples = [(i, example) for i, example in enumerate(examples)] + pcoll = pipeline | 'start' >> beam.Create(keyed_examples) + mhs = [ + base.KeyModelMapping([0], + FakeModelHandler( + state=200, multi_process_shared=True)), + base.KeyModelMapping([1, 2, 3], + FakeModelHandler(multi_process_shared=True)) + ] + _ = pcoll | base.RunInference( + base.KeyedModelHandler(mhs), metrics_namespace=metrics_namespace) + result = pipeline.run() + result.wait_until_finish() + + metrics_filter = MetricsFilter().with_namespace(namespace=metrics_namespace) + metrics = result.metrics().query(metrics_filter) + assert len(metrics['counters']) != 0 + assert len(metrics['distributions']) != 0 + + metrics_filter = MetricsFilter().with_name('0-_num_inferences') + metrics = result.metrics().query(metrics_filter) + num_inferences_counter_key_0 = metrics['counters'][0] + self.assertEqual(num_inferences_counter_key_0.committed, 1) + + metrics_filter = MetricsFilter().with_name('1-_num_inferences') + metrics = result.metrics().query(metrics_filter) + num_inferences_counter_key_1 = metrics['counters'][0] + self.assertEqual(num_inferences_counter_key_1.committed, 3) + + metrics_filter = MetricsFilter().with_name('num_inferences') + metrics = result.metrics().query(metrics_filter) + num_inferences_counter_aggregate = metrics['counters'][0] + self.assertEqual(num_inferences_counter_aggregate.committed, 4) + + metrics_filter = MetricsFilter().with_name('0-_failed_batches_counter') + metrics = result.metrics().query(metrics_filter) + failed_batches_counter_key_0 = metrics['counters'] + self.assertEqual(len(failed_batches_counter_key_0), 0) + + metrics_filter = MetricsFilter().with_name('failed_batches_counter') + metrics = result.metrics().query(metrics_filter) + failed_batches_counter_aggregate = metrics['counters'] + self.assertEqual(len(failed_batches_counter_aggregate), 0) + + metrics_filter = MetricsFilter().with_name( + '0-_load_model_latency_milli_secs') + metrics = result.metrics().query(metrics_filter) + load_latency_dist_key_0 = metrics['distributions'][0] + self.assertEqual(load_latency_dist_key_0.committed.count, 1) + + metrics_filter = MetricsFilter().with_name('load_model_latency_milli_secs') + metrics = result.metrics().query(metrics_filter) + load_latency_dist_aggregate = metrics['distributions'][0] + self.assertEqual(load_latency_dist_aggregate.committed.count, 2) + def test_keyed_many_model_handlers_validation(self): def mult_two(example: str) -> int: return int(example) * 2 @@ -1285,7 +1345,7 @@ def test_model_manager_loads_shared_model(self): 'key3': FakeModelHandler(state=3) } mm = base._ModelManager(mh_map=mhs) - tag1 = mm.load('key1') + tag1 = mm.load('key1').model_tag 
# Use bad_mh's load function to make sure we're actually loading the # version already stored bad_mh = FakeModelHandler(state=100) @@ -1293,8 +1353,8 @@ def test_model_manager_loads_shared_model(self): bad_mh.load_model, tag=tag1).acquire() self.assertEqual(1, model1.predict(10)) - tag2 = mm.load('key2') - tag3 = mm.load('key3') + tag2 = mm.load('key2').model_tag + tag3 = mm.load('key3').model_tag model2 = multi_process_shared.MultiProcessShared( bad_mh.load_model, tag=tag2).acquire() model3 = multi_process_shared.MultiProcessShared( @@ -1308,14 +1368,14 @@ def test_model_manager_evicts_models(self): mh3 = FakeModelHandler(state=3) mhs = {'key1': mh1, 'key2': mh2, 'key3': mh3} mm = base._ModelManager(mh_map=mhs, max_models=2) - tag1 = mm.load('key1') + tag1 = mm.load('key1').model_tag sh1 = multi_process_shared.MultiProcessShared(mh1.load_model, tag=tag1) model1 = sh1.acquire() self.assertEqual(1, model1.predict(10)) model1.increment_state(5) - tag2 = mm.load('key2') - tag3 = mm.load('key3') + tag2 = mm.load('key2').model_tag + tag3 = mm.load('key3').model_tag sh2 = multi_process_shared.MultiProcessShared(mh2.load_model, tag=tag2) model2 = sh2.acquire() sh3 = multi_process_shared.MultiProcessShared(mh3.load_model, tag=tag3) @@ -1349,7 +1409,7 @@ def test_model_manager_evicts_models_after_update(self): mh1 = FakeModelHandler(state=1) mhs = {'key1': mh1} mm = base._ModelManager(mh_map=mhs) - tag1 = mm.load('key1') + tag1 = mm.load('key1').model_tag sh1 = multi_process_shared.MultiProcessShared(mh1.load_model, tag=tag1) model1 = sh1.acquire() self.assertEqual(1, model1.predict(10)) @@ -1359,7 +1419,7 @@ def test_model_manager_evicts_models_after_update(self): self.assertEqual(6, model1.predict(10)) sh1.release(model1) - tag1 = mm.load('key1') + tag1 = mm.load('key1').model_tag sh1 = multi_process_shared.MultiProcessShared(mh1.load_model, tag=tag1) model1 = sh1.acquire() self.assertEqual(1, model1.predict(10)) @@ -1369,7 +1429,7 @@ def test_model_manager_evicts_models_after_update(self): # Shouldn't evict if path is the same as last update mm.update_model_handler('key1', 'fake/path', 'key1') - tag1 = mm.load('key1') + tag1 = mm.load('key1').model_tag sh1 = multi_process_shared.MultiProcessShared(mh1.load_model, tag=tag1) model1 = sh1.acquire() self.assertEqual(6, model1.predict(10)) @@ -1383,7 +1443,7 @@ def test_model_manager_evicts_correct_num_of_models_after_being_incremented( mhs = {'key1': mh1, 'key2': mh2, 'key3': mh3} mm = base._ModelManager(mh_map=mhs, max_models=1) mm.increment_max_models(1) - tag1 = mm.load('key1') + tag1 = mm.load('key1').model_tag sh1 = multi_process_shared.MultiProcessShared(mh1.load_model, tag=tag1) model1 = sh1.acquire() self.assertEqual(1, model1.predict(10)) @@ -1391,8 +1451,8 @@ def test_model_manager_evicts_correct_num_of_models_after_being_incremented( self.assertEqual(6, model1.predict(10)) sh1.release(model1) - tag2 = mm.load('key2') - tag3 = mm.load('key3') + tag2 = mm.load('key2').model_tag + tag3 = mm.load('key3').model_tag sh2 = multi_process_shared.MultiProcessShared(mh2.load_model, tag=tag2) model2 = sh2.acquire() sh3 = multi_process_shared.MultiProcessShared(mh3.load_model, tag=tag3) From 001d49d79cde76178f94c8e79c06429fe9bcf413 Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Thu, 31 Aug 2023 10:25:24 -0400 Subject: [PATCH 3/3] Update returns doc --- sdks/python/apache_beam/ml/inference/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/ml/inference/base.py 
b/sdks/python/apache_beam/ml/inference/base.py index 09823866189c..9bd59f99e3da 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -337,7 +337,8 @@ def load(self, key: str) -> _ModelLoadStats: Args: key: the key associated with the model we'd like to load. Returns: - the tag we can use to access the model using multi_process_shared.py. + _ModelLoadStats with tag, byte size, and latency to load the model. If + the model was already loaded, byte size/latency will be None. """ # Map the key for a model to a unique tag that will persist until the model # is released. This needs to be unique between releasing/reacquiring th
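
The following sketch is illustrative only and is not part of the patches above: it shows the renamed `KeyModelMapping` API together with the per-cohort metric naming that the new base_test.py test asserts. `Multiplier`, `MultiplierModelHandler`, the `'per_key_demo'` namespace, and the example keys are hypothetical stand-ins; only `KeyedModelHandler`, `KeyModelMapping`, `RunInference`, and `MetricsFilter` come from the Apache Beam API touched here.

from typing import Iterable, Optional, Sequence

import apache_beam as beam
from apache_beam.metrics.metric import MetricsFilter
from apache_beam.ml.inference.base import KeyedModelHandler
from apache_beam.ml.inference.base import KeyModelMapping
from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import RunInference


class Multiplier:
  """Hypothetical 'model' that multiplies its input by a fixed factor."""
  def __init__(self, factor: int):
    self._factor = factor

  def predict(self, example: int) -> int:
    return self._factor * example


class MultiplierModelHandler(ModelHandler[int, int, Multiplier]):
  """Hypothetical unkeyed ModelHandler wrapping Multiplier; a stand-in for any
  real handler such as PytorchModelHandlerTensor."""
  def __init__(self, factor: int):
    self._factor = factor

  def load_model(self) -> Multiplier:
    return Multiplier(self._factor)

  def run_inference(
      self,
      batch: Sequence[int],
      model: Multiplier,
      inference_args: Optional[dict] = None) -> Iterable[int]:
    return [model.predict(example) for example in batch]


# Keys 'v2' and 'v3' form one cohort that shares a single model; 'v1' gets its
# own model. The cohort id is the first key in each KeyModelMapping list.
keyed_handler = KeyedModelHandler([
    KeyModelMapping(['v1'], MultiplierModelHandler(2)),
    KeyModelMapping(['v2', 'v3'], MultiplierModelHandler(10)),
])

pipeline = beam.Pipeline()
_ = (
    pipeline
    | beam.Create([('v1', 1), ('v2', 2), ('v3', 3)])
    | RunInference(keyed_handler, metrics_namespace='per_key_demo'))
result = pipeline.run()
result.wait_until_finish()

# Aggregate metrics carry no prefix, while per-cohort metrics are prefixed with
# a key from the cohort, mirroring the 'num_inferences' / '0-_num_inferences'
# assertions in base_test.py above.
aggregate_counters = result.metrics().query(
    MetricsFilter().with_namespace('per_key_demo').with_name('num_inferences'))
cohort_counters = result.metrics().query(
    MetricsFilter().with_name('v1-_num_inferences'))

Because the first key in each KeyModelMapping becomes the cohort id, the 'v2' and 'v3' examples above would report under a single 'v2-'-prefixed set of metrics, while the unprefixed counters aggregate across all keys.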