Skip to content

Commit

Permalink
Keep track of metrics at the KeyedModelHandler level (#28228)
Browse files Browse the repository at this point in the history
* Update KeyMhMapping to KeyModelMapping

* Keep track of metrics at the KeyedModelHandler level

* Update returns doc
  • Loading branch information
damccorm authored Aug 31, 2023
1 parent a65f55b commit 205083d
Show file tree
Hide file tree
Showing 2 changed files with 177 additions and 33 deletions.
126 changes: 105 additions & 21 deletions sdks/python/apache_beam/ml/inference/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,12 @@ class RunInferenceDLQ(NamedTuple):
failed_postprocessing: Sequence[beam.PCollection]


class _ModelLoadStats(NamedTuple):
model_tag: str
load_latency: Optional[int]
byte_size: Optional[int]


ModelMetadata.model_id.__doc__ = """Unique identifier for the model. This can be
a file path or a URL where the model can be accessed. It is used to load
the model for inference."""
Expand All @@ -130,17 +136,18 @@ class KeyModelPathMapping(Generic[KeyT]):
Dataclass for mapping 1 or more keys to 1 model path. This is used in
conjunction with a KeyedModelHandler with many model handlers to update
a set of keys' model handlers with the new path. Given
`KeyModelPathMapping(keys: ['key1', 'key2'], update_path: 'updated/path')`,
all examples with keys `key1` or `key2` will have their corresponding model
handler's update_model function called with 'updated/path'. For more
information see the
KeyedModelHandler documentation
`KeyModelPathMapping(keys: ['key1', 'key2'], update_path: 'updated/path',
model_id: 'id1')`, all examples with keys `key1` or `key2` will have their
corresponding model handler's update_model function called with
'updated/path' and their metrics will correspond with 'id1'. For more
information see the KeyedModelHandler documentation
https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.KeyedModelHandler
documentation and the website section on model updates
https://beam.apache.org/documentation/sdks/python-machine-learning/#automatic-model-refresh
"""
keys: List[KeyT]
update_path: str
model_id: str = ''


class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
Expand Down Expand Up @@ -286,6 +293,12 @@ def share_model_across_processes(self) -> bool:
https://beam.apache.org/releases/pydoc/current/apache_beam.utils.multi_process_shared.html"""
return False

def override_metrics(self, metrics_namespace: str = '') -> bool:
"""Returns a boolean representing whether or not a model handler will
override metrics reporting. If True, RunInference will not report any
metrics."""
return False


class _ModelManager:
"""
Expand Down Expand Up @@ -318,20 +331,22 @@ def __init__(
# of this map should last as long as the corresponding entry in _tag_map.
self._proxy_map: Dict[str, multi_process_shared.MultiProcessShared] = {}

def load(self, key: str) -> str:
def load(self, key: str) -> _ModelLoadStats:
"""
Loads the appropriate model for the given key into memory.
Args:
key: the key associated with the model we'd like to load.
Returns:
the tag we can use to access the model using multi_process_shared.py.
_ModelLoadStats with tag, byte size, and latency to load the model. If
the model was already loaded, byte size/latency will be None.
"""
# Map the key for a model to a unique tag that will persist until the model
# is released. This needs to be unique between releasing/reacquiring th
# model because otherwise the ProxyManager will try to reuse the model that
# has been released and deleted.
if key in self._tag_map:
self._tag_map.move_to_end(key)
return _ModelLoadStats(self._tag_map[key], None, None)
else:
self._tag_map[key] = uuid.uuid4().hex

Expand All @@ -346,12 +361,17 @@ def load(self, key: str) -> str:
del self._proxy_map[tag_to_remove]

# Load the new model
memory_before = _get_current_process_memory_in_bytes()
start_time = _to_milliseconds(time.time_ns())
shared_handle = multi_process_shared.MultiProcessShared(
mh.load_model, tag=tag)
model_reference = shared_handle.acquire()
self._proxy_map[tag] = (shared_handle, model_reference)
memory_after = _get_current_process_memory_in_bytes()
end_time = _to_milliseconds(time.time_ns())

return tag
return _ModelLoadStats(
tag, end_time - start_time, memory_after - memory_before)

def increment_max_models(self, increment: int):
"""
Expand Down Expand Up @@ -460,12 +480,24 @@ def __init__(
must appear in your list of KeyModelPathMappings exactly once. No
additional keys can be added.
When using many models defined per key, metrics about inference and model
loading will be gathered on an aggregate basis for all keys. These will be
reported with no prefix. Metrics will also be gathered on a per key basis.
Since some keys can share the same model, only one set of metrics will be
reported per key 'cohort'. These will be reported in the form:
`<cohort_key>-<metric_name>`, where `<cohort_key>` can be any key selected
from the cohort. When model updates occur, the metrics will be reported in
the form `<cohort_key>-<model id>-<metric_name>`.
Args:
unkeyed: Either (a) an implementation of ModelHandler that does not
require keys or (b) a list of KeyModelMappings mapping lists of keys to
unkeyed ModelHandlers.
"""
self._metrics_collectors: Dict[str, _MetricsCollector] = {}
self._default_metrics_collector: _MetricsCollector = None
self._metrics_namespace = ''
self._single_model = not isinstance(unkeyed, list)
if self._single_model:
if len(unkeyed.get_preprocess_fns()) or len(
Expand Down Expand Up @@ -564,17 +596,41 @@ def run_inference(
predictions = []
for id, keys in key_by_id.items():
mh = self._id_to_mh_map[id]
keyed_model_tag = model.load(id)
loaded_model = model.load(id)
keyed_model_tag = loaded_model.model_tag
if loaded_model.byte_size is not None:
self._metrics_collectors[id].update_load_model_metrics(
loaded_model.load_latency, loaded_model.byte_size)
self._default_metrics_collector.update_load_model_metrics(
loaded_model.load_latency, loaded_model.byte_size)
keyed_model_shared_handle = multi_process_shared.MultiProcessShared(
mh.load_model, tag=keyed_model_tag)
keyed_model = keyed_model_shared_handle.acquire()
for key in keys:
unkeyed_batches = batch_by_key[key]
for inf in mh.run_inference(unkeyed_batches,
keyed_model,
inference_args):
predictions.append((key, inf))
keyed_model_shared_handle.release(keyed_model)
start_time = _to_microseconds(time.time_ns())
num_bytes = 0
num_elements = 0
try:
for key in keys:
unkeyed_batches = batch_by_key[key]
try:
for inf in mh.run_inference(unkeyed_batches,
keyed_model,
inference_args):
predictions.append((key, inf))
except BaseException as e:
self._metrics_collectors[id].failed_batches_counter.inc()
self._default_metrics_collector.failed_batches_counter.inc()
raise e
num_bytes += mh.get_num_bytes(unkeyed_batches)
num_elements += len(unkeyed_batches)
finally:
keyed_model_shared_handle.release(keyed_model)
end_time = _to_microseconds(time.time_ns())
inference_latency = end_time - start_time
self._metrics_collectors[id].update(
num_elements, num_bytes, inference_latency)
self._default_metrics_collector.update(
num_elements, num_bytes, inference_latency)

return predictions

Expand Down Expand Up @@ -641,10 +697,12 @@ def update_model_paths(
# }
# }
cohort_path_mapping: Dict[KeyT, Dict[str, List[KeyT]]] = {}
key_modelid_mapping: Dict[KeyT, str] = {}
seen_keys = set()
for mp in model_paths:
keys = mp.keys
update_path = mp.update_path
model_id = mp.model_id
if len(update_path) == 0:
raise ValueError(f'Invalid model update, path for {keys} is empty')
for key in keys:
Expand All @@ -658,6 +716,7 @@ def update_model_paths(
raise ValueError(
f'Invalid model update: {key} appears in '
'update, but not in the original configuration.')
key_modelid_mapping[key] = model_id
cohort_id = self._key_to_id_map[key]
if cohort_id not in cohort_path_mapping:
cohort_path_mapping[cohort_id] = defaultdict(list)
Expand All @@ -682,6 +741,9 @@ def update_model_paths(
self._id_to_mh_map[cohort_id] = deepcopy(mh)
self._id_to_mh_map[cohort_id].update_model_path(updated_path)
model.update_model_handler(cohort_id, updated_path, old_cohort_id)
model_id = key_modelid_mapping[cohort_id]
self._metrics_collectors[cohort_id] = _MetricsCollector(
self._metrics_namespace, f'{cohort_id}-{model_id}-')

def update_model_path(self, model_path: Optional[str] = None):
if self._single_model:
Expand All @@ -697,6 +759,18 @@ def share_model_across_processes(self) -> bool:
return self._unkeyed.share_model_across_processes()
return True

def override_metrics(self, metrics_namespace: str = '') -> bool:
if self._single_model:
return self._unkeyed.override_metrics(metrics_namespace)

self._metrics_namespace = metrics_namespace
self._default_metrics_collector = _MetricsCollector(metrics_namespace)
for cohort_id in self._id_to_mh_map:
self._metrics_collectors[cohort_id] = _MetricsCollector(
metrics_namespace, f'{cohort_id}-')

return True


class MaybeKeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
ModelHandler[Union[ExampleT, Tuple[KeyT,
Expand Down Expand Up @@ -1188,6 +1262,10 @@ def cache_load_model_metrics(self, load_model_latency_ms, model_byte_size):
self._load_model_latency_milli_secs_cache = load_model_latency_ms
self._model_byte_size_cache = model_byte_size

def update_load_model_metrics(self, load_model_latency_ms, model_byte_size):
self._load_model_latency_milli_secs.update(load_model_latency_ms)
self._model_byte_size.update(model_byte_size)

def update(
self,
examples_count: int,
Expand Down Expand Up @@ -1244,8 +1322,9 @@ def load():
memory_after = _get_current_process_memory_in_bytes()
load_model_latency_ms = end_time - start_time
model_byte_size = memory_after - memory_before
self._metrics_collector.cache_load_model_metrics(
load_model_latency_ms, model_byte_size)
if self._metrics_collector:
self._metrics_collector.cache_load_model_metrics(
load_model_latency_ms, model_byte_size)
return model

# TODO(https://github.com/apache/beam/issues/21443): Investigate releasing
Expand Down Expand Up @@ -1277,6 +1356,8 @@ def get_metrics_collector(self, prefix: str = ''):
metrics_namespace = (
self._metrics_namespace) if self._metrics_namespace else (
self._model_handler.get_metrics_namespace())
if self._model_handler.override_metrics(metrics_namespace):
return None
return _MetricsCollector(metrics_namespace, prefix=prefix)

def setup(self):
Expand All @@ -1297,15 +1378,17 @@ def _run_inference(self, batch, inference_args):
result_generator = self._model_handler.run_inference(
batch, self._model, inference_args)
except BaseException as e:
self._metrics_collector.failed_batches_counter.inc()
if self._metrics_collector:
self._metrics_collector.failed_batches_counter.inc()
raise e
predictions = list(result_generator)

end_time = _to_microseconds(self._clock.time_ns())
inference_latency = end_time - start_time
num_bytes = self._model_handler.get_num_bytes(batch)
num_elements = len(batch)
self._metrics_collector.update(num_elements, num_bytes, inference_latency)
if self._metrics_collector:
self._metrics_collector.update(num_elements, num_bytes, inference_latency)

return predictions

Expand Down Expand Up @@ -1345,7 +1428,8 @@ def process(
def finish_bundle(self):
# TODO(https://github.com/apache/beam/issues/21435): Figure out why there
# is a cache.
self._metrics_collector.update_metrics_with_cache()
if self._metrics_collector:
self._metrics_collector.update_metrics_with_cache()


def _is_darwin() -> bool:
Expand Down
Loading

0 comments on commit 205083d

Please sign in to comment.