diff --git a/graph_net/constraint_util.py b/graph_net/constraint_util.py index d58bb8fc7..c1fa443eb 100644 --- a/graph_net/constraint_util.py +++ b/graph_net/constraint_util.py @@ -231,8 +231,9 @@ def update_tensor_metas_by_dyn_dim_cstr( tensor_metas: list[TensorMeta], dyn_dim_cstr: DynamicDimConstraints ): input_shapes = dyn_dim_cstr.get_reified_input_shapes() - assert len(tensor_metas) == len(input_shapes) - for i, tensor_meta in enumerate(tensor_metas): + # Only update input tensors (first len(input_shapes) tensors), skip weight tensors + for i in range(min(len(input_shapes), len(tensor_metas))): + tensor_meta = tensor_metas[i] tensor_meta.shape = input_shapes[i] if tensor_meta.data is not None: assert isinstance(tensor_meta.data, (list, tuple)) diff --git a/graph_net/dimension_generalizer.py b/graph_net/dimension_generalizer.py index da754d6d8..e8ac50731 100644 --- a/graph_net/dimension_generalizer.py +++ b/graph_net/dimension_generalizer.py @@ -137,14 +137,41 @@ def _save_tensor_metas_as_weight_meta(self, to_model_path, tensor_metas): (to_model_path / "weight_meta.py").write_text(weight_meta_code) def _get_to_model_path(self, rel_model_path, symbol2example_value): - sym_dim_str = "_".join( - f"{sym_name}_{dim}" - for symbol, dim in symbol2example_value.items() - for sym_name in [symbol.name] + """ + Generates output paths organized by dimension configuration indices rather than + symbolic dimension strings. + + Path structure transformation: + Before: model_name__symbolic_dims (e.g., 'model1__symA_8_symB_16') + After: index/model_name (e.g., '0/model1', '1/model1') + + The index represents a specific dimension configuration from the reification set, + enabling systematic management of dimension variations. + """ + # Use indices instead of symbol strings + symbols, reified_dims = self._get_symbols_and_reified_dims( + Path(self.config["model_path_prefix"]) / rel_model_path, + DynamicDimConstraints.unserialize_from_py_file( + os.path.join( + self.config["model_path_prefix"], + rel_model_path, + "input_tensor_constraints.py", + ) + ), ) - sub_module_name = f"{os.path.basename(rel_model_path)}__{sym_dim_str}" + current_dims = tuple(symbol2example_value[symbol] for symbol in symbols) + + # Find corresponding index through dimension value matching + dim_index = 0 + for i, dims in enumerate(reified_dims): + if tuple(dims) == current_dims: + dim_index = i + break + + # Path structure changed from model/name to index/model + sub_module_name = f"{dim_index}" to_model_path = ( - Path(self.config["output_dir"]) / rel_model_path / sub_module_name + Path(self.config["output_dir"]) / sub_module_name / rel_model_path ) return to_model_path diff --git a/graph_net/sample_pass/group_ranges_from_subgraph_sources.py b/graph_net/sample_pass/group_ranges_from_subgraph_sources.py index eaf5be128..0724c59b8 100644 --- a/graph_net/sample_pass/group_ranges_from_subgraph_sources.py +++ b/graph_net/sample_pass/group_ranges_from_subgraph_sources.py @@ -83,8 +83,15 @@ def _save_json( ): model_dir = Path(self.config["output_dir"]) / original_graph_rel_model_path model_dir.mkdir(parents=True, exist_ok=True) - ranges_json = self._get_ranges_json(subgraph_ranges) - paths_json = self._get_paths_json(subgraph_rel_model_paths) + + # Sort ranges by start index, and sort paths accordingly + sorted_data = sorted( + zip(subgraph_ranges, subgraph_rel_model_paths), key=lambda x: x[0][0] + ) + sorted_ranges, sorted_paths = zip(*sorted_data) if sorted_data else ([], []) + + ranges_json = self._get_ranges_json(list(sorted_ranges)) + paths_json = self._get_paths_json(list(sorted_paths)) json_obj = {**ranges_json, **paths_json} json_str = json.dumps(json_obj, indent=4) (model_dir / self.config["output_json_file_name"]).write_text(json_str)