Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve ckpt_export #6965

Merged
merged 14 commits into from
Sep 11, 2023
39 changes: 29 additions & 10 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -1191,10 +1191,10 @@


def ckpt_export(
net_id: str | None = None,
filepath: PathLike | None = None,
ckpt_file: str | None = None,
meta_file: str | Sequence[str] | None = None,
net_id: str | None = "network_def",
wyli marked this conversation as resolved.
Show resolved Hide resolved
filepath: PathLike | None = "models/model.ts",
wyli marked this conversation as resolved.
Show resolved Hide resolved
ckpt_file: str | None = "models/model.pt",
meta_file: str | Sequence[str] | None = "configs/metadata.json",
config_file: str | Sequence[str] | None = None,
key_in_ckpt: str | None = None,
use_trace: bool | None = None,
Expand All @@ -1214,9 +1214,13 @@

Args:
net_id: ID name of the network component in the config, it must be `torch.nn.Module`.
Default to "network_def".
filepath: filepath to export, if filename has no extension it becomes `.ts`.
Default to "models/model.ts" under "os.getcwd()" if `bundle_root` is not specified.
ckpt_file: filepath of the model checkpoint to load.
Default to "models/model.pt" under "os.getcwd()" if `bundle_root` is not specified.
meta_file: filepath of the metadata file, if it is a list of file paths, the content of them will be merged.
Default to "configs/metadata.json" under "os.getcwd()" if `bundle_root` is not specified.
config_file: filepath of the config file to save in TorchScript model and extract network information,
the saved key in the TorchScript model is the config filename without extension, and the saved config
value is always serialized in JSON format no matter the original file format is JSON or YAML.
Expand Down Expand Up @@ -1250,9 +1254,10 @@
)
_log_input_summary(tag="ckpt_export", args=_args)
(
config_file_,
filepath_,
ckpt_file_,
config_file_,
bundle_root_,
net_id_,
meta_file_,
key_in_ckpt_,
Expand All @@ -1261,11 +1266,12 @@
converter_kwargs_,
) = _pop_args(
_args,
"filepath",
"ckpt_file",
"config_file",
net_id="",
meta_file=None,
filepath="models/model.ts",
ckpt_file="models/model.pt",
bundle_root=os.getcwd(),
net_id="network_def",
meta_file="configs/metadata.json",
key_in_ckpt="",
use_trace=False,
input_shape=None,
Expand All @@ -1275,9 +1281,22 @@
parser = ConfigParser()

parser.read_config(f=config_file_)
if meta_file_ is not None:
meta_file_ = (
os.path.join(bundle_root_, "configs", "metadata.json") if meta_file_ == "configs/metadata.json" else meta_file_
)
filepath_ = os.path.join(bundle_root_, "models", "model.ts") if filepath_ == "models/model.ts" else filepath_
ckpt_file_ = os.path.join(bundle_root_, "models", "model.pt") if ckpt_file_ == "models/model.pt" else ckpt_file_
if not os.path.exists(ckpt_file_):
raise FileNotFoundError(f"ckpt_file in {ckpt_file_} does not exist, please specify it.")

Check warning on line 1290 in monai/bundle/scripts.py

View check run for this annotation

Codecov / codecov/patch

monai/bundle/scripts.py#L1290

Added line #L1290 was not covered by tests
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
if os.path.exists(meta_file_):
parser.read_meta(f=meta_file_)

if net_id_ == "network_def":
try:
parser.get_parsed_content(net_id_)
except ValueError as e:
raise ValueError(f"Default net_id: network_def in {config_file_} does not exist.") from e

Check warning on line 1298 in monai/bundle/scripts.py

View check run for this annotation

Codecov / codecov/patch

monai/bundle/scripts.py#L1297-L1298

Added lines #L1297 - L1298 were not covered by tests

# the rest key-values in the _args are to override config content
for k, v in _args.items():
parser[k] = v
Expand Down
23 changes: 23 additions & 0 deletions tests/test_bundle_ckpt_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,29 @@
self.assertTrue("meta_file" in json.loads(extra_files["def_args.json"]))
self.assertTrue("network_def" in json.loads(extra_files["inference.json"]))

@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
def test_default_value(self, key_in_ckpt, use_trace):
config_file = os.path.join(os.path.dirname(__file__), "testing_data", "inference.json")
with tempfile.TemporaryDirectory() as tempdir:
def_args = {"meta_file": "will be replaced by `meta_file` arg"}
def_args_file = os.path.join(tempdir, "def_args.yaml")
ckpt_file = os.path.join(tempdir, "models/model.pt")
ts_file = os.path.join(tempdir, "models/model.ts")

Check warning on line 85 in tests/test_bundle_ckpt_export.py

View check run for this annotation

Codecov / codecov/patch

tests/test_bundle_ckpt_export.py#L80-L85

Added lines #L80 - L85 were not covered by tests

parser = ConfigParser()
parser.export_config_file(config=def_args, filepath=def_args_file)
parser.read_config(config_file)
net = parser.get_parsed_content("network_def")
save_state(src=net if key_in_ckpt == "" else {key_in_ckpt: net}, path=ckpt_file)

Check warning on line 91 in tests/test_bundle_ckpt_export.py

View check run for this annotation

Codecov / codecov/patch

tests/test_bundle_ckpt_export.py#L87-L91

Added lines #L87 - L91 were not covered by tests

# check with default value
cmd = ["coverage", "run", "-m", "monai.bundle", "ckpt_export", "--key_in_ckpt", key_in_ckpt]
cmd += ["--config_file", config_file, "--bundle_root", tempdir]
if use_trace == "True":
cmd += ["--use_trace", use_trace, "--input_shape", "[1, 1, 96, 96, 96]"]
command_line_tests(cmd)
self.assertTrue(os.path.exists(ts_file))

Check warning on line 99 in tests/test_bundle_ckpt_export.py

View check run for this annotation

Codecov / codecov/patch

tests/test_bundle_ckpt_export.py#L94-L99

Added lines #L94 - L99 were not covered by tests


if __name__ == "__main__":
unittest.main()
Loading