Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions graph_net/constraint_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,9 @@ def update_tensor_metas_by_dyn_dim_cstr(
tensor_metas: list[TensorMeta], dyn_dim_cstr: DynamicDimConstraints
):
input_shapes = dyn_dim_cstr.get_reified_input_shapes()
assert len(tensor_metas) == len(input_shapes)
for i, tensor_meta in enumerate(tensor_metas):
# Only update input tensors (first len(input_shapes) tensors), skip weight tensors
for i in range(min(len(input_shapes), len(tensor_metas))):
tensor_meta = tensor_metas[i]
tensor_meta.shape = input_shapes[i]
if tensor_meta.data is not None:
assert isinstance(tensor_meta.data, (list, tuple))
Expand Down
39 changes: 33 additions & 6 deletions graph_net/dimension_generalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,41 @@ def _save_tensor_metas_as_weight_meta(self, to_model_path, tensor_metas):
(to_model_path / "weight_meta.py").write_text(weight_meta_code)

def _get_to_model_path(self, rel_model_path, symbol2example_value):
    """Return the output directory for one reified-dimension variant of a model.

    Output paths are organized by dimension-configuration *index* rather than
    by a symbolic-dimension string:

        Before: model_name__symbolic_dims  (e.g. 'model1__symA_8_symB_16')
        After:  index/model_name           (e.g. '0/model1', '1/model1')

    The index identifies one concrete dimension configuration from the
    reification set, enabling systematic management of dimension variations.

    Args:
        rel_model_path: model path relative to ``self.config["model_path_prefix"]``.
        symbol2example_value: mapping from dimension symbol to the concrete
            dimension value chosen for this variant.

    Returns:
        ``Path``: ``output_dir / <config-index> / rel_model_path``.
    """
    # Recover the ordered symbols and the full set of reified dimension
    # tuples for this model so the current configuration can be indexed.
    symbols, reified_dims = self._get_symbols_and_reified_dims(
        Path(self.config["model_path_prefix"]) / rel_model_path,
        DynamicDimConstraints.unserialize_from_py_file(
            os.path.join(
                self.config["model_path_prefix"],
                rel_model_path,
                "input_tensor_constraints.py",
            )
        ),
    )
    current_dims = tuple(symbol2example_value[symbol] for symbol in symbols)

    # Find the index of this configuration by dimension-value matching.
    # NOTE(review): falls back to index 0 when no entry matches — presumably
    # a match always exists; confirm against the reification logic.
    dim_index = 0
    for i, dims in enumerate(reified_dims):
        if tuple(dims) == current_dims:
            dim_index = i
            break

    # Path structure is index/model (was model/index-suffix before).
    sub_module_name = f"{dim_index}"
    to_model_path = (
        Path(self.config["output_dir"]) / sub_module_name / rel_model_path
    )
    return to_model_path

Expand Down
11 changes: 9 additions & 2 deletions graph_net/sample_pass/group_ranges_from_subgraph_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,15 @@ def _save_json(
):
model_dir = Path(self.config["output_dir"]) / original_graph_rel_model_path
model_dir.mkdir(parents=True, exist_ok=True)
ranges_json = self._get_ranges_json(subgraph_ranges)
paths_json = self._get_paths_json(subgraph_rel_model_paths)

# Sort ranges by start index, and sort paths accordingly
sorted_data = sorted(
zip(subgraph_ranges, subgraph_rel_model_paths), key=lambda x: x[0][0]
)
sorted_ranges, sorted_paths = zip(*sorted_data) if sorted_data else ([], [])

ranges_json = self._get_ranges_json(list(sorted_ranges))
paths_json = self._get_paths_json(list(sorted_paths))
json_obj = {**ranges_json, **paths_json}
json_str = json.dumps(json_obj, indent=4)
(model_dir / self.config["output_json_file_name"]).write_text(json_str)
Expand Down