Skip to content

Commit

Permalink
Fix typing issues
Browse files Browse the repository at this point in the history
  • Loading branch information
bdvllrs committed Mar 22, 2024
1 parent f9e5d46 commit edfc6ab
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 13 deletions.
2 changes: 1 addition & 1 deletion playground/save_v_latents.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def main():

visual_domain = cast(
VisualDomainModule,
load_pretrained_module(config.default_root_dir, config.domain_checkpoint)[0],
load_pretrained_module(config.default_root_dir, config.domain_checkpoint),
)
visual_domain.to(device)
visual_domain.freeze()
Expand Down
3 changes: 2 additions & 1 deletion playground/train_gw.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
)
from lightning.pytorch.loggers.wandb import WandbLogger
from migrate_ckpt.migrate import get_folder_migrations
from shimmer import ContrastiveLossType, LossCoefs
from shimmer import ContrastiveLossType, GlobalWorkspaceBase, LossCoefs
from shimmer.modules.global_workspace import (
GlobalWorkspace,
GlobalWorkspaceFusion,
Expand Down Expand Up @@ -95,6 +95,7 @@ def main():
torch.tensor([1 / 0.07]).log(),
)

module: GlobalWorkspaceBase
if config.global_workspace.has_uncertainty:
module = GlobalWorkspaceWithUncertainty(
domain_modules,
Expand Down
16 changes: 8 additions & 8 deletions simple_shapes_dataset/modules/domains/attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ def __init__(
nn.Tanh(),
)

def forward(self, z: torch.Tensor) -> list[torch.Tensor]:
out = self.decoder(z)
def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
out = self.decoder(x)
return [self.decoder_categories(out), self.decoder_attributes(out)]


Expand Down Expand Up @@ -125,7 +125,7 @@ def decode(self, z: torch.Tensor) -> list[torch.Tensor]:
out.append(torch.zeros_like(z[:, -1]))
return out

def forward(self, x: Sequence[torch.Tensor]) -> list[torch.Tensor]:
def forward(self, x: Sequence[torch.Tensor]) -> list[torch.Tensor]: # type: ignore
return self.decode(self.encode(x))

def generic_step(
Expand Down Expand Up @@ -168,21 +168,21 @@ def generic_step(
self.log(f"{mode}/loss", total_loss)
return total_loss

def validation_step(
def validation_step( # type: ignore
self, batch: Mapping[str, Sequence[torch.Tensor]], _
) -> torch.Tensor:
x = batch["attr"]
return self.generic_step(x, "val")

def training_step(
def training_step( # type: ignore
self,
batch: Mapping[frozenset[str], Mapping[str, Sequence[torch.Tensor]]],
_,
) -> torch.Tensor:
x = batch[frozenset(["attr"])]["attr"]
return self.generic_step(x, "train")

def configure_optimizers(
def configure_optimizers( # type: ignore
self,
) -> dict[str, Any]:
optimizer = torch.optim.AdamW(
Expand Down Expand Up @@ -251,7 +251,7 @@ def decode(self, z: torch.Tensor) -> list[torch.Tensor]:
out.append(unpaired)
return out

def forward(self, x: Sequence[torch.Tensor]) -> list[torch.Tensor]:
def forward(self, x: Sequence[torch.Tensor]) -> list[torch.Tensor]: # type: ignore
return self.decode(self.encode(x))

def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
Expand Down Expand Up @@ -294,5 +294,5 @@ def decode(self, z: torch.Tensor) -> list[torch.Tensor]:
unpaired = torch.zeros_like(z[:, 0])
return [categories, attr, unpaired]

def forward(self, x: Sequence[torch.Tensor]) -> list[torch.Tensor]:
def forward(self, x: Sequence[torch.Tensor]) -> list[torch.Tensor]: # type: ignore
return self.decode(self.encode(x))
6 changes: 3 additions & 3 deletions simple_shapes_dataset/modules/domains/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,21 +207,21 @@ def generic_step(
self.log(f"{mode}/loss", total_loss)
return total_loss

def validation_step(
def validation_step( # type: ignore
self, batch: Mapping[str, Mapping[str, torch.Tensor]], _
) -> torch.Tensor:
x = batch["t"]
return self.generic_step(x, "val")

def training_step(
def training_step( # type: ignore
self,
batch: Mapping[frozenset[str], Mapping[str, Mapping[str, torch.Tensor]]],
_,
) -> torch.Tensor:
x = batch[frozenset(["t"])]["t"]
return self.generic_step(x, "train")

def configure_optimizers(
def configure_optimizers( # type: ignore
self,
) -> dict[str, Any]:
optimizer = torch.optim.AdamW(
Expand Down

0 comments on commit edfc6ab

Please sign in to comment.