From 1738e00668312277e43f8cb1dc08441cbca2f50d Mon Sep 17 00:00:00 2001 From: bdvllrs Date: Fri, 22 Mar 2024 15:31:23 +0000 Subject: [PATCH 1/9] Update shimmer version. Use shimmer fn to compute cycles and demi_cycles --- poetry.lock | 8 ++++---- simple_shapes_dataset/logging.py | 9 ++++++--- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/poetry.lock b/poetry.lock index e8178ea..df4e1da 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2427,7 +2427,7 @@ torch = "^2.0.1" type = "git" url = "git@github.com:bdvllrs/shimmer.git" reference = "main" -resolved_reference = "5b210058800e8717f7cddce9099cd57b4a3083ac" +resolved_reference = "9fa9217493173b1fa42b747fc5ba9ac96ad4d3d6" [[package]] name = "six" @@ -2747,13 +2747,13 @@ telegram = ["requests"] [[package]] name = "transformers" -version = "4.38.2" +version = "4.39.0" description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" optional = false python-versions = ">=3.8.0" files = [ - {file = "transformers-4.38.2-py3-none-any.whl", hash = "sha256:c4029cb9f01b3dd335e52f364c52d2b37c65b4c78e02e6a08b1919c5c928573e"}, - {file = "transformers-4.38.2.tar.gz", hash = "sha256:c5fc7ad682b8a50a48b2a4c05d4ea2de5567adb1bdd00053619dbe5960857dd5"}, + {file = "transformers-4.39.0-py3-none-any.whl", hash = "sha256:7801785b1f016d667467e8c372c1c3653c18fe32ba97952059e3bea79ba22b08"}, + {file = "transformers-4.39.0.tar.gz", hash = "sha256:517a13cd633b10bea01c92ab0b3059762872c7c29da3d223db9d28e926fe330d"}, ] [package.dependencies] diff --git a/simple_shapes_dataset/logging.py b/simple_shapes_dataset/logging.py index fa287a7..3447976 100644 --- a/simple_shapes_dataset/logging.py +++ b/simple_shapes_dataset/logging.py @@ -13,6 +13,7 @@ from matplotlib import gridspec from matplotlib.figure import Figure from PIL import Image +from shimmer import batch_cycles, batch_demi_cycles, batch_translations from shimmer.modules.global_workspace import GlobalWorkspaceBase from torchvision.utils import make_grid @@ -384,9 +385,11 @@ def on_callback( with torch.no_grad(): pl_module.eval() - prediction_demi_cycles = pl_module.batch_demi_cycles(latents) - prediction_cycles = pl_module.batch_cycles(latents) - prediction_translations = pl_module.batch_translations(latents) + prediction_demi_cycles = batch_demi_cycles(pl_module.gw_mod, latents) + prediction_cycles = batch_cycles( + pl_module.gw_mod, latents, pl_module.domain_mods.keys() + ) + prediction_translations = batch_translations(pl_module.gw_mod, latents) pl_module.train() for logger in loggers: From f9e5d46c59f19161434fd49978a47c080a8f9f28 Mon Sep 17 00:00:00 2001 From: bdvllrs Date: Fri, 22 Mar 2024 15:40:43 +0000 Subject: [PATCH 2/9] Fix LossCoefs type issue --- playground/train_gw.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/playground/train_gw.py b/playground/train_gw.py index 5200d65..d71d3a2 100644 --- a/playground/train_gw.py +++ b/playground/train_gw.py @@ -79,7 +79,7 @@ def main(): bias=config.global_workspace.linear_domains_use_bias, ) - loss_coefs: dict[str, float] = { + loss_coefs: LossCoefs = { "demi_cycles": config.global_workspace.loss_coefficients.demi_cycles, "cycles": config.global_workspace.loss_coefficients.cycles, "translations": config.global_workspace.loss_coefficients.translations, @@ -101,7 +101,7 @@ def main(): gw_encoders, gw_decoders, config.global_workspace.latent_dim, - LossCoefs(**loss_coefs), + loss_coefs, config.global_workspace.cont_loss_with_uncertainty, config.training.optim.lr, config.training.optim.weight_decay, @@ -133,7 +133,7 @@ def main(): gw_encoders, gw_decoders, config.global_workspace.latent_dim, - LossCoefs(**loss_coefs), + loss_coefs, config.training.optim.lr, config.training.optim.weight_decay, scheduler_args=SchedulerArgs( From edfc6ab18581a87b5fc6817dc8a2a14fc0bb9748 Mon Sep 17 00:00:00 2001 From: bdvllrs Date: Fri, 22 Mar 2024 15:45:50 +0000 Subject: [PATCH 3/9] Fix typing issues --- playground/save_v_latents.py | 2 +- playground/train_gw.py | 3 ++- .../modules/domains/attribute.py | 16 ++++++++-------- simple_shapes_dataset/modules/domains/text.py | 6 +++--- 4 files changed, 14 insertions(+), 13 deletions(-) diff --git a/playground/save_v_latents.py b/playground/save_v_latents.py index 030326a..69ed85f 100644 --- a/playground/save_v_latents.py +++ b/playground/save_v_latents.py @@ -43,7 +43,7 @@ def main(): visual_domain = cast( VisualDomainModule, - load_pretrained_module(config.default_root_dir, config.domain_checkpoint)[0], + load_pretrained_module(config.default_root_dir, config.domain_checkpoint), ) visual_domain.to(device) visual_domain.freeze() diff --git a/playground/train_gw.py b/playground/train_gw.py index d71d3a2..719e28b 100644 --- a/playground/train_gw.py +++ b/playground/train_gw.py @@ -11,7 +11,7 @@ ) from lightning.pytorch.loggers.wandb import WandbLogger from migrate_ckpt.migrate import get_folder_migrations -from shimmer import ContrastiveLossType, LossCoefs +from shimmer import ContrastiveLossType, GlobalWorkspaceBase, LossCoefs from shimmer.modules.global_workspace import ( GlobalWorkspace, GlobalWorkspaceFusion, @@ -95,6 +95,7 @@ def main(): torch.tensor([1 / 0.07]).log(), ) + module: GlobalWorkspaceBase if config.global_workspace.has_uncertainty: module = GlobalWorkspaceWithUncertainty( domain_modules, diff --git a/simple_shapes_dataset/modules/domains/attribute.py b/simple_shapes_dataset/modules/domains/attribute.py index b5d826e..2471e0f 100644 --- a/simple_shapes_dataset/modules/domains/attribute.py +++ b/simple_shapes_dataset/modules/domains/attribute.py @@ -73,8 +73,8 @@ def __init__( nn.Tanh(), ) - def forward(self, z: torch.Tensor) -> list[torch.Tensor]: - out = self.decoder(z) + def forward(self, x: torch.Tensor) -> list[torch.Tensor]: + out = self.decoder(x) return [self.decoder_categories(out), self.decoder_attributes(out)] @@ -125,7 +125,7 @@ def decode(self, z: torch.Tensor) -> list[torch.Tensor]: out.append(torch.zeros_like(z[:, -1])) return out - def forward(self, x: Sequence[torch.Tensor]) -> list[torch.Tensor]: + def forward(self, x: Sequence[torch.Tensor]) -> list[torch.Tensor]: # type: ignore return self.decode(self.encode(x)) def generic_step( @@ -168,13 +168,13 @@ def generic_step( self.log(f"{mode}/loss", total_loss) return total_loss - def validation_step( + def validation_step( # type: ignore self, batch: Mapping[str, Sequence[torch.Tensor]], _ ) -> torch.Tensor: x = batch["attr"] return self.generic_step(x, "val") - def training_step( + def training_step( # type: ignore self, batch: Mapping[frozenset[str], Mapping[str, Sequence[torch.Tensor]]], _, @@ -182,7 +182,7 @@ def training_step( x = batch[frozenset(["attr"])]["attr"] return self.generic_step(x, "train") - def configure_optimizers( + def configure_optimizers( # type: ignore self, ) -> dict[str, Any]: optimizer = torch.optim.AdamW( @@ -251,7 +251,7 @@ def decode(self, z: torch.Tensor) -> list[torch.Tensor]: out.append(unpaired) return out - def forward(self, x: Sequence[torch.Tensor]) -> list[torch.Tensor]: + def forward(self, x: Sequence[torch.Tensor]) -> list[torch.Tensor]: # type: ignore return self.decode(self.encode(x)) def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput: @@ -294,5 +294,5 @@ def decode(self, z: torch.Tensor) -> list[torch.Tensor]: unpaired = torch.zeros_like(z[:, 0]) return [categories, attr, unpaired] - def forward(self, x: Sequence[torch.Tensor]) -> list[torch.Tensor]: + def forward(self, x: Sequence[torch.Tensor]) -> list[torch.Tensor]: # type: ignore return self.decode(self.encode(x)) diff --git a/simple_shapes_dataset/modules/domains/text.py b/simple_shapes_dataset/modules/domains/text.py index a43e71c..843843b 100644 --- a/simple_shapes_dataset/modules/domains/text.py +++ b/simple_shapes_dataset/modules/domains/text.py @@ -207,13 +207,13 @@ def generic_step( self.log(f"{mode}/loss", total_loss) return total_loss - def validation_step( + def validation_step( # type: ignore self, batch: Mapping[str, Mapping[str, torch.Tensor]], _ ) -> torch.Tensor: x = batch["t"] return self.generic_step(x, "val") - def training_step( + def training_step( # type: ignore self, batch: Mapping[frozenset[str], Mapping[str, Mapping[str, torch.Tensor]]], _, @@ -221,7 +221,7 @@ def training_step( x = batch[frozenset(["t"])]["t"] return self.generic_step(x, "train") - def configure_optimizers( + def configure_optimizers( # type: ignore self, ) -> dict[str, Any]: optimizer = torch.optim.AdamW( From 44d623502cc5a5b79f369c6a1097ed00a471b041 Mon Sep 17 00:00:00 2001 From: bdvllrs Date: Mon, 25 Mar 2024 15:07:24 +0000 Subject: [PATCH 4/9] update dependencies --- poetry.lock | 49 ++++++++++++++++++++++++++----------------------- 1 file changed, 26 insertions(+), 23 deletions(-) diff --git a/poetry.lock b/poetry.lock index df4e1da..e6ab1f4 100644 --- a/poetry.lock +++ b/poetry.lock @@ -453,18 +453,18 @@ six = ">=1.4.0" [[package]] name = "filelock" -version = "3.13.1" +version = "3.13.2" description = "A platform independent file lock." optional = false python-versions = ">=3.8" files = [ - {file = "filelock-3.13.1-py3-none-any.whl", hash = "sha256:57dbda9b35157b05fb3e58ee91448612eb674172fab98ee235ccb0b5bee19a1c"}, - {file = "filelock-3.13.1.tar.gz", hash = "sha256:521f5f56c50f8426f5e03ad3b281b490a87ef15bc6c526f168290f0c7148d44e"}, + {file = "filelock-3.13.2-py3-none-any.whl", hash = "sha256:e4c33bc026ace328551af557d4d34f59566c98acd4ed66c13b4335f114f04f7a"}, + {file = "filelock-3.13.2.tar.gz", hash = "sha256:9e2106260b5f65600a31bc503721e3db7e64598bb406ebc5921aeaafe441ba34"}, ] [package.extras] -docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.24)"] -testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"] +docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"] typing = ["typing-extensions (>=4.8)"] [[package]] @@ -689,13 +689,13 @@ test = ["black", "coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock", "mypy", "pre [[package]] name = "huggingface-hub" -version = "0.21.4" +version = "0.22.0" description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" optional = false python-versions = ">=3.8.0" files = [ - {file = "huggingface_hub-0.21.4-py3-none-any.whl", hash = "sha256:df37c2c37fc6c82163cdd8a67ede261687d80d1e262526d6c0ce73b6b3630a7b"}, - {file = "huggingface_hub-0.21.4.tar.gz", hash = "sha256:e1f4968c93726565a80edf6dc309763c7b546d0cfe79aa221206034d50155531"}, + {file = "huggingface_hub-0.22.0-py3-none-any.whl", hash = "sha256:72dea96299751699180184c06a4689e54cbfacecb1a3d08ac7a269c884bb17c3"}, + {file = "huggingface_hub-0.22.0.tar.gz", hash = "sha256:304f1e235c68c0a9f58bced47f13d6df241a5b4e3678f4981aa1e4f4bce63f6d"}, ] [package.dependencies] @@ -708,15 +708,16 @@ tqdm = ">=4.42.1" typing-extensions = ">=3.7.4.3" [package.extras] -all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] +all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "minijinja (>=1.0)", "mypy (==1.5.1)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.3.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] cli = ["InquirerPy (==0.3.4)"] -dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] +dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "minijinja (>=1.0)", "mypy (==1.5.1)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.3.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"] hf-transfer = ["hf-transfer (>=0.1.4)"] -inference = ["aiohttp", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)"] -quality = ["mypy (==1.5.1)", "ruff (>=0.1.3)"] +inference = ["aiohttp", "minijinja (>=1.0)"] +quality = ["mypy (==1.5.1)", "ruff (>=0.3.0)"] tensorflow = ["graphviz", "pydot", "tensorflow"] -testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"] +tensorflow-testing = ["keras (<3.0)", "tensorflow"] +testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "minijinja (>=1.0)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"] torch = ["safetensors", "torch"] typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"] @@ -947,13 +948,13 @@ test = ["click (==8.1.7)", "cloudpickle (>=1.3,<3.0)", "coverage (==7.3.1)", "fa [[package]] name = "lightning-utilities" -version = "0.11.0" +version = "0.11.1" description = "Lightning toolbox for across the our ecosystem." optional = false python-versions = ">=3.8" files = [ - {file = "lightning-utilities-0.11.0.tar.gz", hash = "sha256:dd704795785ceba1e0cd60ba3a9b0553c7902ec9efc1578a74e893a291416e62"}, - {file = "lightning_utilities-0.11.0-py3-none-any.whl", hash = "sha256:bf576a421027fdbaf48e80cbc2fdf900a3316a469748a953c33a8ca2b2718a20"}, + {file = "lightning-utilities-0.11.1.tar.gz", hash = "sha256:5002eaadfb8caa2cd2bb7b845748d4e8e77ac275137b66b10182db3ecb0e4867"}, + {file = "lightning_utilities-0.11.1-py3-none-any.whl", hash = "sha256:89381384c79efcc958682d0a618acde2d9ae737ad3c673353839ba8f12cd0085"}, ] [package.dependencies] @@ -1643,13 +1644,13 @@ testing = ["pytest", "pytest-benchmark"] [[package]] name = "pre-commit" -version = "3.6.2" +version = "3.7.0" description = "A framework for managing and maintaining multi-language pre-commit hooks." optional = false python-versions = ">=3.9" files = [ - {file = "pre_commit-3.6.2-py2.py3-none-any.whl", hash = "sha256:ba637c2d7a670c10daedc059f5c49b5bd0aadbccfcd7ec15592cf9665117532c"}, - {file = "pre_commit-3.6.2.tar.gz", hash = "sha256:c3ef34f463045c88658c5b99f38c1e297abdcc0ff13f98d3370055fbbfabc67e"}, + {file = "pre_commit-3.7.0-py2.py3-none-any.whl", hash = "sha256:5eae9e10c2b5ac51577c3452ec0a490455c45a0533f7960f993a0d01e59decab"}, + {file = "pre_commit-3.7.0.tar.gz", hash = "sha256:e209d61b8acdcf742404408531f0c37d49d2c734fd7cff2d6076083d191cb060"}, ] [package.dependencies] @@ -2417,8 +2418,10 @@ files = [] develop = false [package.dependencies] +click = "^8.1.7" lightning = "^2.1.0" matplotlib = "^3.7.1" +migrate-ckpt = {git = "https://github.com/bdvllrs/migrate-ckpt.git", rev = "main"} numpy = "^1.25" pandas = "^2.0.2" torch = "^2.0.1" @@ -2427,7 +2430,7 @@ torch = "^2.0.1" type = "git" url = "git@github.com:bdvllrs/shimmer.git" reference = "main" -resolved_reference = "9fa9217493173b1fa42b747fc5ba9ac96ad4d3d6" +resolved_reference = "54b057f08a1bcad93b48aababd8eb51dd34cfaa1" [[package]] name = "six" @@ -2747,13 +2750,13 @@ telegram = ["requests"] [[package]] name = "transformers" -version = "4.39.0" +version = "4.39.1" description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" optional = false python-versions = ">=3.8.0" files = [ - {file = "transformers-4.39.0-py3-none-any.whl", hash = "sha256:7801785b1f016d667467e8c372c1c3653c18fe32ba97952059e3bea79ba22b08"}, - {file = "transformers-4.39.0.tar.gz", hash = "sha256:517a13cd633b10bea01c92ab0b3059762872c7c29da3d223db9d28e926fe330d"}, + {file = "transformers-4.39.1-py3-none-any.whl", hash = "sha256:df167e08b27ab254044a38bb7c439461cd3916332205416e9b6b1592b517a1a5"}, + {file = "transformers-4.39.1.tar.gz", hash = "sha256:ab9c1e1912843b9976e6cc62b27cd5434284fc0dab465e1b660333acfa81c6bc"}, ] [package.dependencies] From 57d04e7be8b8dc3f9895e3d13d7669bdda41255d Mon Sep 17 00:00:00 2001 From: bdvllrs Date: Mon, 25 Mar 2024 15:11:34 +0000 Subject: [PATCH 5/9] Use SaveMigration from shimmer for the GW --- playground/train_gw.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/playground/train_gw.py b/playground/train_gw.py index 719e28b..7e0dfc1 100644 --- a/playground/train_gw.py +++ b/playground/train_gw.py @@ -10,8 +10,7 @@ RichProgressBar, ) from lightning.pytorch.loggers.wandb import WandbLogger -from migrate_ckpt.migrate import get_folder_migrations -from shimmer import ContrastiveLossType, GlobalWorkspaceBase, LossCoefs +from shimmer import ContrastiveLossType, GlobalWorkspaceBase, LossCoefs, SaveMigrations from shimmer.modules.global_workspace import ( GlobalWorkspace, GlobalWorkspaceFusion, @@ -21,7 +20,6 @@ from torch import set_float32_matmul_precision from simple_shapes_dataset import DEBUG_MODE, PROJECT_DIR -from simple_shapes_dataset.ckpt_migrations import SaveMigrations from simple_shapes_dataset.config import load_config from simple_shapes_dataset.dataset import SimpleShapesDataModule from simple_shapes_dataset.dataset.pre_process import ( @@ -236,9 +234,7 @@ def main(): ) callbacks.extend( [ - SaveMigrations( - get_folder_migrations(PROJECT_DIR / "migrations" / "gw") - ), + SaveMigrations(), ModelCheckpoint( dirpath=checkpoint_dir, filename="{epoch}", From 5006441004c5d4bbeade6d71b1cda1f75878f737 Mon Sep 17 00:00:00 2001 From: bdvllrs Date: Mon, 25 Mar 2024 15:14:32 +0000 Subject: [PATCH 6/9] Use migrations in shimmer for GW --- migrations/gw/.gitignore | 0 migrations/gw/1_add_gw_interfaces.py | 21 ------------------ migrations/gw/2_del_gw_interfaces_hparam.py | 8 ------- migrations/gw/3_del_gw_interfaces.py | 22 ------------------- .../gw/4_del_buffers_put_models_in_gw_mod.py | 15 ------------- migrations/gw/5_del_coef_buffers.py | 11 ---------- simple_shapes_dataset/ckpt_migrations.py | 3 +++ 7 files changed, 3 insertions(+), 77 deletions(-) create mode 100644 migrations/gw/.gitignore delete mode 100644 migrations/gw/1_add_gw_interfaces.py delete mode 100644 migrations/gw/2_del_gw_interfaces_hparam.py delete mode 100644 migrations/gw/3_del_gw_interfaces.py delete mode 100644 migrations/gw/4_del_buffers_put_models_in_gw_mod.py delete mode 100644 migrations/gw/5_del_coef_buffers.py diff --git a/migrations/gw/.gitignore b/migrations/gw/.gitignore new file mode 100644 index 0000000..e69de29 diff --git a/migrations/gw/1_add_gw_interfaces.py b/migrations/gw/1_add_gw_interfaces.py deleted file mode 100644 index 3f2d1b6..0000000 --- a/migrations/gw/1_add_gw_interfaces.py +++ /dev/null @@ -1,21 +0,0 @@ -from migrate_ckpt import CkptType - - -def handle(ckpt: CkptType) -> CkptType: - new_state_dict = {} - for name, val in ckpt["state_dict"].items(): - new_name = name.replace( - "gw_mod.encoders.resnet", "gw_mod.gw_interfaces.resnet.encoder" - ) - new_name = new_name.replace( - "gw_mod.encoders.bge", "gw_mod.gw_interfaces.bge.encoder" - ) - new_name = new_name.replace( - "gw_mod.decoders.resnet", "gw_mod.gw_interfaces.resnet.decoder" - ) - new_name = new_name.replace( - "gw_mod.decoders.bge", "gw_mod.gw_interfaces.bge.decoder" - ) - new_state_dict[new_name] = val - ckpt["state_dict"] = new_state_dict - return ckpt diff --git a/migrations/gw/2_del_gw_interfaces_hparam.py b/migrations/gw/2_del_gw_interfaces_hparam.py deleted file mode 100644 index 7b3b844..0000000 --- a/migrations/gw/2_del_gw_interfaces_hparam.py +++ /dev/null @@ -1,8 +0,0 @@ -from migrate_ckpt import CkptType - - -def handle(ckpt: CkptType) -> CkptType: - if "hyper_parameters" in ckpt.keys(): - if "gw_interfaces" in ckpt["hyper_parameters"].keys(): - del ckpt["hyper_parameters"]["gw_interfaces"] - return ckpt diff --git a/migrations/gw/3_del_gw_interfaces.py b/migrations/gw/3_del_gw_interfaces.py deleted file mode 100644 index 52ccf75..0000000 --- a/migrations/gw/3_del_gw_interfaces.py +++ /dev/null @@ -1,22 +0,0 @@ -from migrate_ckpt import CkptType - - -def handle(ckpt: CkptType) -> CkptType: - new_state_dict = {} - for name, val in ckpt["state_dict"].items(): - if "gw_mod.gw_interfaces" in name and "domain_module" in name: - continue - elif "gw_mod.gw_interfaces" in name and "encoder" in name: - new_name = name.replace(".gw_interfaces", ".gw_encoders") - new_name = new_name.replace(".encoder", "") - new_state_dict[new_name] = val - elif "gw_mod.gw_interfaces" in name and "decoder" in name: - new_name = name.replace(".gw_interfaces", ".gw_decoders") - new_name = new_name.replace(".decoder", "") - new_state_dict[new_name] = val - elif "gw_interfaces" in name: - print(name) - else: - new_state_dict[name] = val - ckpt["state_dict"] = new_state_dict - return ckpt diff --git a/migrations/gw/4_del_buffers_put_models_in_gw_mod.py b/migrations/gw/4_del_buffers_put_models_in_gw_mod.py deleted file mode 100644 index bdc294e..0000000 --- a/migrations/gw/4_del_buffers_put_models_in_gw_mod.py +++ /dev/null @@ -1,15 +0,0 @@ -from migrate_ckpt import CkptType - - -def handle(ckpt: CkptType) -> CkptType: - new_state_dict = {} - for name, val in ckpt["state_dict"].items(): - if "loss_coefs.buffer" in name: - continue - if name[:12] == "domain_mods.": - name = "gw_mod." + name - if name[:18] == "gw_mod.domain_mods": - new_state_dict["loss_mod." + name] = val - new_state_dict[name] = val - ckpt["state_dict"] = new_state_dict - return ckpt diff --git a/migrations/gw/5_del_coef_buffers.py b/migrations/gw/5_del_coef_buffers.py deleted file mode 100644 index ecbd508..0000000 --- a/migrations/gw/5_del_coef_buffers.py +++ /dev/null @@ -1,11 +0,0 @@ -from migrate_ckpt import CkptType - - -def handle(ckpt: CkptType) -> CkptType: - new_state_dict = {} - for name, val in ckpt["state_dict"].items(): - if "coef_buffers." in name: - continue - new_state_dict[name] = val - ckpt["state_dict"] = new_state_dict - return ckpt diff --git a/simple_shapes_dataset/ckpt_migrations.py b/simple_shapes_dataset/ckpt_migrations.py index 4a691b5..0c90cfe 100644 --- a/simple_shapes_dataset/ckpt_migrations.py +++ b/simple_shapes_dataset/ckpt_migrations.py @@ -10,11 +10,14 @@ ckpt_migration_key, migrate_from_folder, ) +from shimmer import migrate_model as migrate_shimmer_model from simple_shapes_dataset import LOGGER def migrate_model(ckpt_path: str | PathLike, migration_path: str | PathLike, **kwargs): + migrate_shimmer_model(ckpt_path, **kwargs) + ckpt_path = Path(ckpt_path) ckpt = torch.load(ckpt_path, **kwargs) new_ckpt, done_migrations = migrate_from_folder(ckpt, migration_path) From 97658e7f4222bb7424b78a3f061bdccbdb283d68 Mon Sep 17 00:00:00 2001 From: bdvllrs Date: Mon, 25 Mar 2024 15:15:35 +0000 Subject: [PATCH 7/9] Use shimmer migrations in migrate script --- playground/migrate_model.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/playground/migrate_model.py b/playground/migrate_model.py index 643b5b3..695a6de 100644 --- a/playground/migrate_model.py +++ b/playground/migrate_model.py @@ -1,3 +1,5 @@ +from shimmer import migrate_model as migrate_shimmer_model + from simple_shapes_dataset import DEBUG_MODE, PROJECT_DIR from simple_shapes_dataset.ckpt_migrations import migrate_model from simple_shapes_dataset.config import load_config @@ -10,6 +12,7 @@ def main(): ) if config.global_workspace.checkpoint is not None: + migrate_shimmer_model(config.global_workspace.checkpoint) migrate_model( config.global_workspace.checkpoint, PROJECT_DIR / "migrations" / "gw" ) From c0b896aea8be753a93a6f92ea48d1f5c718bd8f6 Mon Sep 17 00:00:00 2001 From: bdvllrs Date: Tue, 26 Mar 2024 10:34:14 +0000 Subject: [PATCH 8/9] update dependencies --- poetry.lock | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/poetry.lock b/poetry.lock index e6ab1f4..73c26f5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -453,13 +453,13 @@ six = ">=1.4.0" [[package]] name = "filelock" -version = "3.13.2" +version = "3.13.3" description = "A platform independent file lock." optional = false python-versions = ">=3.8" files = [ - {file = "filelock-3.13.2-py3-none-any.whl", hash = "sha256:e4c33bc026ace328551af557d4d34f59566c98acd4ed66c13b4335f114f04f7a"}, - {file = "filelock-3.13.2.tar.gz", hash = "sha256:9e2106260b5f65600a31bc503721e3db7e64598bb406ebc5921aeaafe441ba34"}, + {file = "filelock-3.13.3-py3-none-any.whl", hash = "sha256:5ffa845303983e7a0b7ae17636509bc97997d58afeafa72fb141a17b152284cb"}, + {file = "filelock-3.13.3.tar.gz", hash = "sha256:a79895a25bbefdf55d1a2a0a80968f7dbb28edcd6d4234a0afb3f37ecde4b546"}, ] [package.extras] @@ -2430,7 +2430,7 @@ torch = "^2.0.1" type = "git" url = "git@github.com:bdvllrs/shimmer.git" reference = "main" -resolved_reference = "54b057f08a1bcad93b48aababd8eb51dd34cfaa1" +resolved_reference = "b4fbb8d18f5102196d9a806fb0a894c64418df66" [[package]] name = "six" From 677b4060fa9b313585e9dbb3249e17ac9db60c6e Mon Sep 17 00:00:00 2001 From: bdvllrs Date: Tue, 26 Mar 2024 10:39:58 +0000 Subject: [PATCH 9/9] Only migrate from shimmer if it's a global workspace --- simple_shapes_dataset/ckpt_migrations.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/simple_shapes_dataset/ckpt_migrations.py b/simple_shapes_dataset/ckpt_migrations.py index 0c90cfe..b9f1420 100644 --- a/simple_shapes_dataset/ckpt_migrations.py +++ b/simple_shapes_dataset/ckpt_migrations.py @@ -16,7 +16,8 @@ def migrate_model(ckpt_path: str | PathLike, migration_path: str | PathLike, **kwargs): - migrate_shimmer_model(ckpt_path, **kwargs) + if Path(migration_path).name == "gw": + migrate_shimmer_model(ckpt_path, **kwargs) ckpt_path = Path(ckpt_path) ckpt = torch.load(ckpt_path, **kwargs)