Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update shimmer to latest version (WIP) #8

Merged
merged 9 commits
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added migrations/gw/.gitignore
Empty file.
21 changes: 0 additions & 21 deletions migrations/gw/1_add_gw_interfaces.py

This file was deleted.

8 changes: 0 additions & 8 deletions migrations/gw/2_del_gw_interfaces_hparam.py

This file was deleted.

22 changes: 0 additions & 22 deletions migrations/gw/3_del_gw_interfaces.py

This file was deleted.

15 changes: 0 additions & 15 deletions migrations/gw/4_del_buffers_put_models_in_gw_mod.py

This file was deleted.

11 changes: 0 additions & 11 deletions migrations/gw/5_del_coef_buffers.py

This file was deleted.

3 changes: 3 additions & 0 deletions playground/migrate_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from shimmer import migrate_model as migrate_shimmer_model

from simple_shapes_dataset import DEBUG_MODE, PROJECT_DIR
from simple_shapes_dataset.ckpt_migrations import migrate_model
from simple_shapes_dataset.config import load_config
Expand All @@ -10,6 +12,7 @@ def main():
)

if config.global_workspace.checkpoint is not None:
migrate_shimmer_model(config.global_workspace.checkpoint)
migrate_model(
config.global_workspace.checkpoint, PROJECT_DIR / "migrations" / "gw"
)
Expand Down
2 changes: 1 addition & 1 deletion playground/save_v_latents.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def main():

visual_domain = cast(
VisualDomainModule,
load_pretrained_module(config.default_root_dir, config.domain_checkpoint)[0],
load_pretrained_module(config.default_root_dir, config.domain_checkpoint),
)
visual_domain.to(device)
visual_domain.freeze()
Expand Down
15 changes: 6 additions & 9 deletions playground/train_gw.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
RichProgressBar,
)
from lightning.pytorch.loggers.wandb import WandbLogger
from migrate_ckpt.migrate import get_folder_migrations
from shimmer import ContrastiveLossType, LossCoefs
from shimmer import ContrastiveLossType, GlobalWorkspaceBase, LossCoefs, SaveMigrations
from shimmer.modules.global_workspace import (
GlobalWorkspace,
GlobalWorkspaceFusion,
Expand All @@ -21,7 +20,6 @@
from torch import set_float32_matmul_precision

from simple_shapes_dataset import DEBUG_MODE, PROJECT_DIR
from simple_shapes_dataset.ckpt_migrations import SaveMigrations
from simple_shapes_dataset.config import load_config
from simple_shapes_dataset.dataset import SimpleShapesDataModule
from simple_shapes_dataset.dataset.pre_process import (
Expand Down Expand Up @@ -79,7 +77,7 @@ def main():
bias=config.global_workspace.linear_domains_use_bias,
)

loss_coefs: dict[str, float] = {
loss_coefs: LossCoefs = {
"demi_cycles": config.global_workspace.loss_coefficients.demi_cycles,
"cycles": config.global_workspace.loss_coefficients.cycles,
"translations": config.global_workspace.loss_coefficients.translations,
Expand All @@ -95,13 +93,14 @@ def main():
torch.tensor([1 / 0.07]).log(),
)

module: GlobalWorkspaceBase
if config.global_workspace.has_uncertainty:
module = GlobalWorkspaceWithUncertainty(
domain_modules,
gw_encoders,
gw_decoders,
config.global_workspace.latent_dim,
LossCoefs(**loss_coefs),
loss_coefs,
config.global_workspace.cont_loss_with_uncertainty,
config.training.optim.lr,
config.training.optim.weight_decay,
Expand Down Expand Up @@ -133,7 +132,7 @@ def main():
gw_encoders,
gw_decoders,
config.global_workspace.latent_dim,
LossCoefs(**loss_coefs),
loss_coefs,
config.training.optim.lr,
config.training.optim.weight_decay,
scheduler_args=SchedulerArgs(
Expand Down Expand Up @@ -235,9 +234,7 @@ def main():
)
callbacks.extend(
[
SaveMigrations(
get_folder_migrations(PROJECT_DIR / "migrations" / "gw")
),
SaveMigrations(),
ModelCheckpoint(
dirpath=checkpoint_dir,
filename="{epoch}",
Expand Down
49 changes: 26 additions & 23 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions simple_shapes_dataset/ckpt_migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,15 @@
ckpt_migration_key,
migrate_from_folder,
)
from shimmer import migrate_model as migrate_shimmer_model

from simple_shapes_dataset import LOGGER


def migrate_model(ckpt_path: str | PathLike, migration_path: str | PathLike, **kwargs):
if Path(migration_path).name == "gw":
migrate_shimmer_model(ckpt_path, **kwargs)

ckpt_path = Path(ckpt_path)
ckpt = torch.load(ckpt_path, **kwargs)
new_ckpt, done_migrations = migrate_from_folder(ckpt, migration_path)
Expand Down
9 changes: 6 additions & 3 deletions simple_shapes_dataset/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from matplotlib import gridspec
from matplotlib.figure import Figure
from PIL import Image
from shimmer import batch_cycles, batch_demi_cycles, batch_translations
from shimmer.modules.global_workspace import GlobalWorkspaceBase
from torchvision.utils import make_grid

Expand Down Expand Up @@ -384,9 +385,11 @@ def on_callback(

with torch.no_grad():
pl_module.eval()
prediction_demi_cycles = pl_module.batch_demi_cycles(latents)
prediction_cycles = pl_module.batch_cycles(latents)
prediction_translations = pl_module.batch_translations(latents)
prediction_demi_cycles = batch_demi_cycles(pl_module.gw_mod, latents)
prediction_cycles = batch_cycles(
pl_module.gw_mod, latents, pl_module.domain_mods.keys()
)
prediction_translations = batch_translations(pl_module.gw_mod, latents)
pl_module.train()

for logger in loggers:
Expand Down
Loading
Loading