
Commit 5f38e2f

Copilot and njzjz committed
refactor(finetune): consolidate duplicated warning functions into shared utilities
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
1 parent 4b7d82d commit 5f38e2f

File tree: 5 files changed, +190 additions, -359 deletions
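The shared module that now hosts these helpers, deepmd/utils/finetune.py, is among the five changed files, but its diff is not shown in this excerpt. The sketch below is a hypothetical reconstruction of the consolidated functions, inferred from the imports and call sites in the diffs that follow. Only the two public names and their call signatures are confirmed by the commit; the private _collect_descriptor_differences helper, the simplified messages, and the omission of the normalize()-based comparison from deepmd.utils.argcheck are assumptions for illustration.

# Hypothetical sketch of the consolidated helpers in deepmd/utils/finetune.py.
# Only the two public names and their signatures are confirmed by the diffs
# below; the shared helper and the simplified messages are assumptions (the
# removed duplicates also normalized both configs via deepmd.utils.argcheck).
import logging
from typing import Callable, Optional

log = logging.getLogger(__name__)


def _collect_descriptor_differences(
    input_descriptor: dict,
    pretrained_descriptor: dict,
    fmt: Callable[[str, Optional[object], Optional[object]], str],
) -> list:
    """Collect formatted per-key differences between two descriptor configs."""
    differences = []
    for key, value in input_descriptor.items():
        if key not in pretrained_descriptor:
            differences.append(fmt(key, value, None))
        elif value != pretrained_descriptor[key]:
            differences.append(fmt(key, value, pretrained_descriptor[key]))
    for key, value in pretrained_descriptor.items():
        if key not in input_descriptor:
            differences.append(fmt(key, None, value))
    return differences


def warn_configuration_mismatch_during_finetune(
    input_descriptor: dict,
    pretrained_descriptor: dict,
    model_branch: str = "Default",
) -> None:
    """Warn that only compatible state_dict entries of the pretrained model
    will be reused when the descriptor configurations disagree."""
    differences = _collect_descriptor_differences(
        input_descriptor,
        pretrained_descriptor,
        lambda k, a, b: f" {k}: {a} (input) vs {b} (pretrained)",
    )
    if differences:
        log.warning(
            "Descriptor configuration mismatch detected between input.json and "
            f"pretrained model (branch '{model_branch}'). State dict initialization "
            "will only use compatible parameters from the pretrained model. "
            "Mismatched configuration:\n" + "\n".join(differences)
        )


def warn_descriptor_config_differences(
    input_descriptor: dict,
    pretrained_descriptor: dict,
    model_branch: str = "Default",
) -> None:
    """Warn that the input descriptor configuration will be overwritten with
    the pretrained model's configuration."""
    differences = _collect_descriptor_differences(
        input_descriptor,
        pretrained_descriptor,
        lambda k, a, b: f" {k}: {a} -> {b}",
    )
    if differences:
        log.warning(
            "Descriptor configuration in input.json differs from pretrained model "
            f"(branch '{model_branch}'). The input configuration will be overwritten "
            "with the pretrained model's configuration:\n" + "\n".join(differences)
        )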


deepmd/pd/train/training.py

Lines changed: 4 additions & 94 deletions

@@ -72,109 +72,19 @@
     nvprof_context,
     to_numpy_array,
 )
-from deepmd.utils.argcheck import (
-    normalize,
-)
 from deepmd.utils.data import (
     DataRequirementItem,
 )
+from deepmd.utils.finetune import (
+    warn_configuration_mismatch_during_finetune,
+)
 from deepmd.utils.path import (
     DPH5Path,
 )

 log = logging.getLogger(__name__)


-def _warn_configuration_mismatch_during_finetune(
-    input_descriptor: dict,
-    pretrained_descriptor: dict,
-    model_branch: str = "Default",
-) -> None:
-    """
-    Warn about configuration mismatches between input descriptor and pretrained model
-    when fine-tuning without --use-pretrain-script option.
-
-    This function warns when configurations differ and state_dict initialization
-    will only pick relevant keys from the pretrained model (e.g., first 6 layers
-    from a 16-layer model).
-
-    Parameters
-    ----------
-    input_descriptor : dict
-        Descriptor configuration from input.json
-    pretrained_descriptor : dict
-        Descriptor configuration from pretrained model
-    model_branch : str
-        Model branch name for logging context
-    """
-    # Normalize both configurations to ensure consistent comparison
-    # This avoids warnings for parameters that only differ due to default values
-    try:
-        # Create minimal configs for normalization with required fields
-        base_config = {
-            "model": {
-                "fitting_net": {"neuron": [240, 240, 240]},
-                "type_map": ["H", "O"],
-            },
-            "training": {"training_data": {"systems": ["fake"]}, "numb_steps": 100},
-        }
-
-        input_config = base_config.copy()
-        input_config["model"]["descriptor"] = input_descriptor.copy()
-
-        pretrained_config = base_config.copy()
-        pretrained_config["model"]["descriptor"] = pretrained_descriptor.copy()
-
-        # Normalize both configurations
-        normalized_input = normalize(input_config, multi_task=False)["model"][
-            "descriptor"
-        ]
-        normalized_pretrained = normalize(pretrained_config, multi_task=False)["model"][
-            "descriptor"
-        ]
-
-        if normalized_input == normalized_pretrained:
-            return
-
-        # Use normalized configs for comparison to show only meaningful differences
-        input_descriptor = normalized_input
-        pretrained_descriptor = normalized_pretrained
-    except Exception:
-        # If normalization fails, fall back to original comparison
-        pass
-
-    if input_descriptor == pretrained_descriptor:
-        return
-
-    # Collect differences
-    differences = []
-
-    # Check for keys that differ in values
-    for key in input_descriptor:
-        if key in pretrained_descriptor:
-            if input_descriptor[key] != pretrained_descriptor[key]:
-                differences.append(
-                    f" {key}: {input_descriptor[key]} (input) vs {pretrained_descriptor[key]} (pretrained)"
-                )
-        else:
-            differences.append(f" {key}: {input_descriptor[key]} (input only)")
-
-    # Check for keys only in pretrained model
-    for key in pretrained_descriptor:
-        if key not in input_descriptor:
-            differences.append(
-                f" {key}: {pretrained_descriptor[key]} (pretrained only)"
-            )
-
-    if differences:
-        log.warning(
-            f"Descriptor configuration mismatch detected between input.json and pretrained model "
-            f"(branch '{model_branch}'). State dict initialization will only use compatible parameters "
-            f"from the pretrained model. Mismatched configuration:\n"
-            + "\n".join(differences)
-        )
-
-
 class Trainer:
     def __init__(
         self,
@@ -632,7 +542,7 @@ def collect_single_finetune_params(
             )

             # Warn about configuration mismatches
-            _warn_configuration_mismatch_during_finetune(
+            warn_configuration_mismatch_during_finetune(
                 current_descriptor,
                 pretrained_descriptor,
                 _model_key_from,

deepmd/pd/utils/finetune.py

Lines changed: 2 additions & 86 deletions

@@ -6,98 +6,14 @@

 import paddle

-from deepmd.utils.argcheck import (
-    normalize,
-)
 from deepmd.utils.finetune import (
     FinetuneRuleItem,
+    warn_descriptor_config_differences,
 )

 log = logging.getLogger(__name__)


-def _warn_descriptor_config_differences(
-    input_descriptor: dict,
-    pretrained_descriptor: dict,
-    model_branch: str = "Default",
-) -> None:
-    """
-    Warn about differences between input descriptor config and pretrained model's descriptor config.
-
-    Parameters
-    ----------
-    input_descriptor : dict
-        Descriptor configuration from input.json
-    pretrained_descriptor : dict
-        Descriptor configuration from pretrained model
-    model_branch : str
-        Model branch name for logging context
-    """
-    # Normalize both configurations to ensure consistent comparison
-    # This avoids warnings for parameters that only differ due to default values
-    try:
-        # Create minimal configs for normalization with required fields
-        base_config = {
-            "model": {
-                "fitting_net": {"neuron": [240, 240, 240]},
-                "type_map": ["H", "O"],
-            },
-            "training": {"training_data": {"systems": ["fake"]}, "numb_steps": 100},
-        }
-
-        input_config = base_config.copy()
-        input_config["model"]["descriptor"] = input_descriptor.copy()
-
-        pretrained_config = base_config.copy()
-        pretrained_config["model"]["descriptor"] = pretrained_descriptor.copy()
-
-        # Normalize both configurations
-        normalized_input = normalize(input_config, multi_task=False)["model"][
-            "descriptor"
-        ]
-        normalized_pretrained = normalize(pretrained_config, multi_task=False)["model"][
-            "descriptor"
-        ]
-
-        if normalized_input == normalized_pretrained:
-            return
-
-        # Use normalized configs for comparison to show only meaningful differences
-        input_descriptor = normalized_input
-        pretrained_descriptor = normalized_pretrained
-    except Exception:
-        # If normalization fails, fall back to original comparison
-        pass
-
-    if input_descriptor == pretrained_descriptor:
-        return
-
-    # Collect differences
-    differences = []
-
-    # Check for keys that differ in values
-    for key in input_descriptor:
-        if key in pretrained_descriptor:
-            if input_descriptor[key] != pretrained_descriptor[key]:
-                differences.append(
-                    f" {key}: {input_descriptor[key]} -> {pretrained_descriptor[key]}"
-                )
-        else:
-            differences.append(f" {key}: {input_descriptor[key]} -> (removed)")
-
-    # Check for keys only in pretrained model
-    for key in pretrained_descriptor:
-        if key not in input_descriptor:
-            differences.append(f" {key}: (added) -> {pretrained_descriptor[key]}")
-
-    if differences:
-        log.warning(
-            f"Descriptor configuration in input.json differs from pretrained model "
-            f"(branch '{model_branch}'). The input configuration will be overwritten "
-            f"with the pretrained model's configuration:\n" + "\n".join(differences)
-        )
-
-
 def get_finetune_rule_single(
     _single_param_target,
     _model_param_pretrained,
@@ -149,7 +65,7 @@ def get_finetune_rule_single(

     # Warn about descriptor configuration differences before overwriting
     if "descriptor" in single_config and "descriptor" in single_config_chosen:
-        _warn_descriptor_config_differences(
+        warn_descriptor_config_differences(
             single_config["descriptor"],
             single_config_chosen["descriptor"],
             model_branch_chosen,

deepmd/pt/train/training.py

Lines changed: 3 additions & 93 deletions

@@ -81,8 +81,8 @@
     DataLoader,
 )

-from deepmd.utils.argcheck import (
-    normalize,
+from deepmd.utils.finetune import (
+    warn_configuration_mismatch_during_finetune,
 )
 from deepmd.utils.path import (
     DPH5Path,
@@ -91,96 +91,6 @@
 log = logging.getLogger(__name__)


-def _warn_configuration_mismatch_during_finetune(
-    input_descriptor: dict,
-    pretrained_descriptor: dict,
-    model_branch: str = "Default",
-) -> None:
-    """
-    Warn about configuration mismatches between input descriptor and pretrained model
-    when fine-tuning without --use-pretrain-script option.
-
-    This function warns when configurations differ and state_dict initialization
-    will only pick relevant keys from the pretrained model (e.g., first 6 layers
-    from a 16-layer model).
-
-    Parameters
-    ----------
-    input_descriptor : dict
-        Descriptor configuration from input.json
-    pretrained_descriptor : dict
-        Descriptor configuration from pretrained model
-    model_branch : str
-        Model branch name for logging context
-    """
-    # Normalize both configurations to ensure consistent comparison
-    # This avoids warnings for parameters that only differ due to default values
-    try:
-        # Create minimal configs for normalization with required fields
-        base_config = {
-            "model": {
-                "fitting_net": {"neuron": [240, 240, 240]},
-                "type_map": ["H", "O"],
-            },
-            "training": {"training_data": {"systems": ["fake"]}, "numb_steps": 100},
-        }
-
-        input_config = base_config.copy()
-        input_config["model"]["descriptor"] = input_descriptor.copy()
-
-        pretrained_config = base_config.copy()
-        pretrained_config["model"]["descriptor"] = pretrained_descriptor.copy()
-
-        # Normalize both configurations
-        normalized_input = normalize(input_config, multi_task=False)["model"][
-            "descriptor"
-        ]
-        normalized_pretrained = normalize(pretrained_config, multi_task=False)["model"][
-            "descriptor"
-        ]
-
-        if normalized_input == normalized_pretrained:
-            return
-
-        # Use normalized configs for comparison to show only meaningful differences
-        input_descriptor = normalized_input
-        pretrained_descriptor = normalized_pretrained
-    except Exception:
-        # If normalization fails, fall back to original comparison
-        pass
-
-    if input_descriptor == pretrained_descriptor:
-        return
-
-    # Collect differences
-    differences = []
-
-    # Check for keys that differ in values
-    for key in input_descriptor:
-        if key in pretrained_descriptor:
-            if input_descriptor[key] != pretrained_descriptor[key]:
-                differences.append(
-                    f" {key}: {input_descriptor[key]} (input) vs {pretrained_descriptor[key]} (pretrained)"
-                )
-        else:
-            differences.append(f" {key}: {input_descriptor[key]} (input only)")
-
-    # Check for keys only in pretrained model
-    for key in pretrained_descriptor:
-        if key not in input_descriptor:
-            differences.append(
-                f" {key}: {pretrained_descriptor[key]} (pretrained only)"
-            )
-
-    if differences:
-        log.warning(
-            f"Descriptor configuration mismatch detected between input.json and pretrained model "
-            f"(branch '{model_branch}'). State dict initialization will only use compatible parameters "
-            f"from the pretrained model. Mismatched configuration:\n"
-            + "\n".join(differences)
-        )
-
-
 class Trainer:
     def __init__(
         self,
@@ -661,7 +571,7 @@ def collect_single_finetune_params(
             )

             # Warn about configuration mismatches
-            _warn_configuration_mismatch_during_finetune(
+            warn_configuration_mismatch_during_finetune(
                 current_descriptor,
                 pretrained_descriptor,
                 _model_key_from,
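For illustration only (not part of the commit), a toy call to the shared helper with two descriptors that disagree on a single key would emit one consolidated warning. The descriptor contents below are made up, and the model_branch keyword is assumed to be kept from the removed private copies.

# Hypothetical usage example; the descriptor dictionaries are made up.
import logging

from deepmd.utils.finetune import warn_configuration_mismatch_during_finetune

logging.basicConfig(level=logging.WARNING)

warn_configuration_mismatch_during_finetune(
    {"type": "se_e2_a", "rcut": 6.0},  # descriptor section from input.json
    {"type": "se_e2_a", "rcut": 8.0},  # descriptor stored in the pretrained model
    model_branch="Default",
)
# Expected to log a single warning listing the mismatched key, e.g.
#   rcut: 6.0 (input) vs 8.0 (pretrained)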
