From f9f39fb45a490fa49c8542ed271e9a173cb9405e Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Thu, 28 Dec 2023 15:59:13 +0900 Subject: [PATCH] fix #95 issue * check validity of keys --- scripts/rebasin/weight_matching.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/scripts/rebasin/weight_matching.py b/scripts/rebasin/weight_matching.py index c5db200..ebcb56c 100644 --- a/scripts/rebasin/weight_matching.py +++ b/scripts/rebasin/weight_matching.py @@ -828,9 +828,15 @@ def scipylap(cost): return lapfunc +def _valid_key(key): + if "cond_stage_model.transformer.text_model." in key: + return True + return "model_" not in key + + def apply_permutation(ps: PermutationSpec, perm, params): """Apply a `perm` to `params`.""" - return {k: get_permuted_param(ps, perm, k, params) for k in params.keys() if "model_" not in k} + return {k: get_permuted_param(ps, perm, k, params) for k in params.keys() if _valid_key(k)} def weight_matching(ps: PermutationSpec, params_a, params_b, special_layers=None, device="cpu", max_iter=3, init_perm=None, usefp16=False, usetqdm=True, full=False, lap="lap"): """Find a permutation of `params_b` to make them match `params_a`."""