From 11aae676a9ee4dab7091424ef0ba7d9f7344f730 Mon Sep 17 00:00:00 2001 From: Michael Schmidt Date: Sat, 18 Nov 2023 22:10:39 +0100 Subject: [PATCH] Add `dump_dummy` script (#25) * Add `dump_dummy` script * type error --- scripts/dump_dummy.py | 57 ++++++++++++++++++++++++++++++++++++++ scripts/dump_state_dict.py | 24 ++++++++++------ 2 files changed, 72 insertions(+), 9 deletions(-) create mode 100644 scripts/dump_dummy.py diff --git a/scripts/dump_dummy.py b/scripts/dump_dummy.py new file mode 100644 index 00000000..9ab0eeb4 --- /dev/null +++ b/scripts/dump_dummy.py @@ -0,0 +1,57 @@ +# ruff: noqa: E402 +""" +This is a tool for dumping the state dict of a dummy model. + +Purpose: +When adding/testing model detection or model parameter detection code, +it is useful to see the effects a single parameter has on the state dict of a +model. Since there aren't pretrained models for every possible parameter +configuration, this script can be used to generate a dummy model with the given +parameters. + +Usage: +To use this script, you need to edit the `create_dummy` function below. Edit +the function to make it return a model with your desired parameters. As always, +VSCode is the recommended IDE for this task. + +After you edited the function, run this script, and it will dump the state dict +of the dummy model to `dump.yml`. + + python scripts/dump_dummy.py + +For more detail on the dump itself, see the docs of `dump_state_dict.py`. +""" + + +import inspect +import os +import sys +from textwrap import dedent + +import torch + +# This hack is necessary to make our module import +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src"))) + +from dump_state_dict import dump + +from spandrel.architectures import SCUNet + + +def create_dummy() -> torch.nn.Module: + """Edit this function""" + return SCUNet.SCUNet() + + +if __name__ == "__main__": + net = create_dummy() + state = net.state_dict() + + # get source code expression of network + source = inspect.getsource(create_dummy) + source = "\n".join(source.split("\n")[1:]) # remove "def create_dummy(): + source = dedent(source) + if source.startswith("return "): + source = source[7:] + + dump(state, source) diff --git a/scripts/dump_state_dict.py b/scripts/dump_state_dict.py index cc6d1aee..b505731e 100644 --- a/scripts/dump_state_dict.py +++ b/scripts/dump_state_dict.py @@ -44,11 +44,11 @@ import os import sys from dataclasses import dataclass -from typing import Dict, Generic, Iterable, TypeVar +from typing import Any, Dict, Generic, Iterable, TypeVar from torch import Tensor -# I fucking hate python. This hack is necessary to make our module import +# This hack is necessary to make our module import sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src"))) from spandrel import ModelLoader # noqa: E402 @@ -181,12 +181,18 @@ def dump(s: Fork[State] | State, level: int = 0): return lines -file = sys.argv[1] -print(f"Input file: {file}") -state = load_state(file) +def dump(state: dict[str, Any], comment: str, file: str = "dump.yml"): + with open(file, "w") as f: + comment = "\n".join("# " + s for s in comment.splitlines()) + f.write(f"{comment}\n") + f.write("\n".join(dump_lines(state))) -with open("dump.yml", "w") as f: - f.write(f"# {file}\n") - f.write("\n".join(dump_lines(state))) + print(f"Dumped {len(state)} keys to {file}") -print(f"Dumped {len(state)} keys to dump.yml") + +if __name__ == "__main__": + file = sys.argv[1] + print(f"Input file: {file}") + state = load_state(file) + + dump(state, file)