diff --git a/pkg/workloads/cortex/lib/model/validation.py b/pkg/workloads/cortex/lib/model/validation.py index 1069f0af46..baf0666c43 100644 --- a/pkg/workloads/cortex/lib/model/validation.py +++ b/pkg/workloads/cortex/lib/model/validation.py @@ -307,7 +307,13 @@ def validate_model_paths( f"{predictor_type} predictor at '{common_prefix}'", "model path can't be empty" ) + paths_by_prefix_cache = {} + def _validate_model_paths(pattern: Any, paths: List[str], common_prefix: str) -> None: + if common_prefix not in paths_by_prefix_cache: + paths_by_prefix_cache[common_prefix] = util.get_paths_with_prefix(paths, common_prefix) + paths = paths_by_prefix_cache[common_prefix] + rel_paths = [os.path.relpath(path, common_prefix) for path in paths] rel_paths = [path for path in rel_paths if not path.startswith("../")] @@ -392,18 +398,24 @@ def _validate_model_paths(pattern: Any, paths: List[str], common_prefix: str) -> "unexpected path(s) for " + str(unvisited_paths), ) - aggregated_ooa_valid_key_ids = [] + new_common_prefixes = [] + sub_patterns = [] + paths_by_prefix = {} for obj_id, key_id in enumerate(visited_objects): obj = objects[obj_id] key = keys[key_id] + if key != AnyPlaceholder: + new_common_prefixes.append(os.path.join(common_prefix, obj)) + sub_patterns.append(pattern[key]) - new_common_prefix = os.path.join(common_prefix, obj) - sub_pattern = pattern[key] + if len(new_common_prefixes) > 0: + paths_by_prefix = util.get_paths_by_prefixes(paths, new_common_prefixes) - if key != AnyPlaceholder: - aggregated_ooa_valid_key_ids += _validate_model_paths( - sub_pattern, paths, new_common_prefix - ) + aggregated_ooa_valid_key_ids = [] + for sub_pattern, new_common_prefix in zip(sub_patterns, new_common_prefixes): + aggregated_ooa_valid_key_ids += _validate_model_paths( + sub_pattern, paths_by_prefix[new_common_prefix], new_common_prefix + ) return aggregated_ooa_valid_key_ids diff --git a/pkg/workloads/cortex/lib/util.py b/pkg/workloads/cortex/lib/util.py index 
import inspect
import pathlib
from inspect import Parameter
from copy import deepcopy
from typing import List, Dict, Any


def get_paths_with_prefix(paths: List[str], prefix: str) -> List[str]:
    """
    Selects the paths that start with the given prefix.

    The match is a plain string-prefix test (str.startswith), not a
    path-component test, so "models/a" also matches "models/ab/file".
    The relative order of the input paths is preserved.
    """
    return [path for path in paths if path.startswith(prefix)]


def get_paths_by_prefixes(paths: List[str], prefixes: List[str]) -> Dict[str, List[str]]:
    """
    Groups paths under each of the given prefixes.

    Every prefix is present as a key in the returned dictionary, mapped to
    the list (in input order) of paths that start with it. A prefix that
    matches no path maps to an empty list, so callers can index the result
    with any of the prefixes they passed in without risking a KeyError.
    A path that starts with several prefixes is listed under each of them.
    A prefix repeated in `prefixes` is only considered once.
    """
    # Seed every prefix up front; building the dict also de-duplicates prefixes.
    paths_by_prefix: Dict[str, List[str]] = {prefix: [] for prefix in prefixes}
    for path in paths:
        for prefix, matched_paths in paths_by_prefix.items():
            if path.startswith(prefix):
                matched_paths.append(path)
    return paths_by_prefix


def get_leftmost_part_of_path(path: str) -> str:
    """
    Gets the leftmost part of a path.

    If a path looks like
    models/tensorflow/iris/15559399

    Then this function will return
    models

    Leading "." components are ignored ("./models/iris" gives "models").
    A path that only denotes the current directory ("", ".", "./", "././")
    returns ".". For an absolute path the leftmost part is the root itself,
    as reported by pathlib ("/" on POSIX).
    """
    parts = pathlib.PurePath(path).parts
    # pathlib collapses "", "." and "./" to a path with no parts at all.
    if not parts:
        return "."
    return parts[0]


def remove_non_empty_directory_paths(paths: List[str]) -> List[str]:
    """
    Eliminates the paths that are non-empty directories, i.e. the paths that
    are a proper ancestor of another path in the list.

    If paths looks like:
        models/tensorflow/
        models/tensorflow/iris/1569001258
        models/tensorflow/iris/1569001258/saved_model.pb

    Then after calling this function, it will look like:
        models/tensorflow/iris/1569001258/saved_model.pb

    The returned paths are normalized: empty components (repeated or
    trailing slashes) are dropped, and a single leading slash is kept when
    the input paths are absolute. A path listed more than once is returned
    once. The order follows the first appearance of each path in the input.

    Raises:
        ValueError: if absolute and relative paths are mixed.
    """
    absolute_flags = [path.startswith("/") for path in paths]
    all_absolute = all(absolute_flags)
    if any(absolute_flags) and not all_absolute:
        raise ValueError("can only either pass in absolute paths or relative paths")

    # dict used as an insertion-ordered set of the normalized input paths
    listed_paths: Dict[str, None] = {}
    # every proper ancestor directory of any listed path
    ancestor_dirs = set()

    for path in paths:
        levels = [level for level in path.split("/") if level != ""]
        if not levels:
            # "" or "/" carries no path component; nothing to record
            continue
        composed_path = levels[0]
        for level in levels[1:]:
            ancestor_dirs.add(composed_path)
            composed_path = f"{composed_path}/{level}"
        listed_paths[composed_path] = None

    root = "/" if all_absolute else ""
    return [root + path for path in listed_paths if path not in ancestor_dirs]