
feat: SysVI for cycle consistency loss and VampPrior. #3195

Merged: 9 commits merged into main on Feb 23, 2025
Conversation

@canergen (Member) commented Feb 19, 2025

@Hrovatin @moinfar superseding #2421. I harmonized the code more closely with other scvi-tools models. I don't think it should change anything in terms of output, but I'd be happy if you could run a small test case once. Otherwise this is ready to be merged, and I'll go ahead and merge it next week.
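For a quick sanity check, something along the following lines should suffice. This is a minimal sketch assuming the usual scvi-tools external-model API and the built-in synthetic dataset; the exact `SysVI.setup_anndata` arguments (e.g. `batch_key`) may differ.

```python
# Hypothetical smoke test for the harmonized SysVI model.
# Assumptions: SysVI follows the standard scvi-tools API and the synthetic
# dataset's "batch" column can serve as the system/batch covariate.
import scvi
from scvi.external import SysVI

adata = scvi.data.synthetic_iid()  # small built-in synthetic AnnData

SysVI.setup_anndata(adata, batch_key="batch")  # register the system/batch covariate
model = SysVI(adata)
model.train(max_epochs=5)  # a few epochs are enough to exercise the training loop

latent = model.get_latent_representation()
print(latent.shape)
```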

@canergen added the on-merge: backport to 1.3.x label Feb 19, 2025
@canergen changed the title from "Merge branch 'main' of https://github.com/Hrovatin/scvi-tools" to "feat: SysVI for cycle consistency loss and VampPrior." Feb 19, 2025
codecov bot commented Feb 19, 2025

Codecov Report

Attention: Patch coverage is 92.80822% with 21 lines in your changes missing coverage. Please review.

Project coverage is 82.65%. Comparing base (c3926eb) to head (b4624b3).
Report is 1 commit behind head on main.

Files with missing lines                       Patch %   Lines
src/scvi/external/sysvi/_base_components.py    87.27%    7 Missing ⚠️
src/scvi/external/sysvi/_module.py             94.40%    7 Missing ⚠️
src/scvi/external/sysvi/_model.py              88.67%    6 Missing ⚠️
src/scvi/external/sysvi/_priors.py             97.82%    1 Missing ⚠️

❗ There is a different number of reports uploaded between BASE (c3926eb) and HEAD (b4624b3).

HEAD has 18 fewer uploads than BASE.

Flag    BASE (c3926eb)   HEAD (b4624b3)
        21               3
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #3195      +/-   ##
==========================================
- Coverage   89.27%   82.65%   -6.62%     
==========================================
  Files         185      190       +5     
  Lines       16265    16551     +286     
==========================================
- Hits        14520    13680     -840     
- Misses       1745     2871    +1126     
Files with missing lines                       Coverage Δ
src/scvi/external/__init__.py                  100.00% <100.00%> (ø)
src/scvi/external/sysvi/__init__.py            100.00% <100.00%> (ø)
src/scvi/nn/_base_components.py                94.52% <100.00%> (+0.06%) ⬆️
src/scvi/external/sysvi/_priors.py             97.82% <97.82%> (ø)
src/scvi/external/sysvi/_model.py              88.67% <88.67%> (ø)
src/scvi/external/sysvi/_base_components.py    87.27% <87.27%> (ø)
src/scvi/external/sysvi/_module.py             94.40% <94.40%> (ø)

... and 27 files with indirect coverage changes

@Hrovatin

Thank you very much. I will re-run the tutorial on this branch in the next few days.

@canergen requested a review from Hrovatin, February 20, 2025 15:45
@Hrovatin left a review comment

It seems like the changes broke something numerically. Any ideas why?


outputs = {"y_m": y_m, "y_v": y_v}
outputs = {"q_dist": Normal(q_m, q_v.sqrt())}
@Hrovatin commented Feb 21, 2025

This line is causing an error in the tutorial. It did not happen before these latest changes, nor in any of our numerous benchmarks.

However, I would expect that something else is the underlying cause.

Please try running https://github.com/Hrovatin/scvi-tutorials/blob/stable/scrna/sysVI.ipynb, as it uses non-synthetic data.

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[5], line 5
      3 # Train
      4 max_epochs = 200
----> 5 model.train(
      6     max_epochs=max_epochs,
      7     check_val_every_n_epoch=1,
      8 )

File ~/Documents/GitHub/scvi-tools-Hrovatin/src/scvi/external/sysvi/_model.py:145, in SysVI.train(self, plan_kwargs, **train_kwargs)
    143 train_kwargs = train_kwargs or {}
    144 train_kwargs["plan_kwargs"] = plan_kwargs
--> 145 super().train(**train_kwargs)

File ~/Documents/GitHub/scvi-tools-Hrovatin/src/scvi/model/base/_training_mixin.py:145, in UnsupervisedTrainingMixin.train(self, max_epochs, accelerator, devices, train_size, validation_size, shuffle_set_split, load_sparse_tensor, batch_size, early_stopping, datasplitter_kwargs, plan_kwargs, datamodule, **trainer_kwargs)
    133 trainer_kwargs[es] = (
    134     early_stopping if es not in trainer_kwargs.keys() else trainer_kwargs[es]
    135 )
    136 runner = self._train_runner_cls(
    137     self,
    138     training_plan=training_plan,
   (...)
    143     **trainer_kwargs,
    144 )
--> 145 return runner()

File ~/Documents/GitHub/scvi-tools-Hrovatin/src/scvi/train/_trainrunner.py:96, in TrainRunner.__call__(self)
     93 if hasattr(self.data_splitter, "n_val"):
     94     self.training_plan.n_obs_validation = self.data_splitter.n_val
---> 96 self.trainer.fit(self.training_plan, self.data_splitter)
     97 self._update_history()
     99 # data splitter only gets these attrs after fit

File ~/Documents/GitHub/scvi-tools-Hrovatin/src/scvi/train/_trainer.py:201, in Trainer.fit(self, *args, **kwargs)
    195 if isinstance(args[0], PyroTrainingPlan):
    196     warnings.filterwarnings(
    197         action="ignore",
    198         category=UserWarning,
    199         message="`LightningModule.configure_optimizers` returned `None`",
    200     )
--> 201 super().fit(*args, **kwargs)

File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:544, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    542 self.state.status = TrainerStatus.RUNNING
    543 self.training = True
--> 544 call._call_and_handle_interrupt(
    545     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    546 )

File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:44, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     42     if trainer.strategy.launcher is not None:
     43         return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 44     return trainer_fn(*args, **kwargs)
     46 except _TunerExitException:
     47     _call_teardown_hook(trainer)

File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:580, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    573 assert self.state.fn is not None
    574 ckpt_path = self._checkpoint_connector._select_ckpt_path(
    575     self.state.fn,
    576     ckpt_path,
    577     model_provided=True,
    578     model_connected=self.lightning_module is not None,
    579 )
--> 580 self._run(model, ckpt_path=ckpt_path)
    582 assert self.state.stopped
    583 self.training = False

File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:989, in Trainer._run(self, model, ckpt_path)
    984 self._signal_connector.register_signal_handlers()
    986 # ----------------------------
    987 # RUN THE TRAINER
    988 # ----------------------------
--> 989 results = self._run_stage()
    991 # ----------------------------
    992 # POST-Training CLEAN UP
    993 # ----------------------------
    994 log.debug(f"{self.__class__.__name__}: trainer tearing down")

File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:1035, in Trainer._run_stage(self)
   1033         self._run_sanity_check()
   1034     with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1035         self.fit_loop.run()
   1036     return None
   1037 raise RuntimeError(f"Unexpected state {self.state}")

File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:202, in _FitLoop.run(self)
    200 try:
    201     self.on_advance_start()
--> 202     self.advance()
    203     self.on_advance_end()
    204     self._restarting = False

File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:359, in _FitLoop.advance(self)
    357 with self.trainer.profiler.profile("run_training_epoch"):
    358     assert self._data_fetcher is not None
--> 359     self.epoch_loop.run(self._data_fetcher)

File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/lightning/pytorch/loops/training_epoch_loop.py:136, in _TrainingEpochLoop.run(self, data_fetcher)
    134 while not self.done:
    135     try:
--> 136         self.advance(data_fetcher)
    137         self.on_advance_end(data_fetcher)
    138         self._restarting = False

File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/lightning/pytorch/loops/training_epoch_loop.py:240, in _TrainingEpochLoop.advance(self, data_fetcher)
    237 with trainer.profiler.profile("run_training_batch"):
    238     if trainer.lightning_module.automatic_optimization:
    239         # in automatic optimization, there can only be one optimizer
--> 240         batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
    241     else:
    242         batch_output = self.manual_optimization.run(kwargs)

File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py:187, in _AutomaticOptimization.run(self, optimizer, batch_idx, kwargs)
    180         closure()
    182 # ------------------------------
    183 # BACKWARD PASS
    184 # ------------------------------
    185 # gradient update with accumulated gradients
    186 else:
--> 187     self._optimizer_step(batch_idx, closure)
    189 result = closure.consume_result()
    190 if result.loss is None:

File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py:265, in _AutomaticOptimization._optimizer_step(self, batch_idx, train_step_and_backward_closure)
    262     self.optim_progress.optimizer.step.increment_ready()
    264 # model hook
--> 265 call._call_lightning_module_hook(
    266     trainer,
    267     "optimizer_step",
    268     trainer.current_epoch,
    269     batch_idx,
    270     optimizer,
    271     train_step_and_backward_closure,
    272 )
    274 if not should_accumulate:
    275     self.optim_progress.optimizer.step.increment_completed()

File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:157, in _call_lightning_module_hook(trainer, hook_name, pl_module, *args, **kwargs)
    154 pl_module._current_fx_name = hook_name
    156 with trainer.profiler.profile(f"[LightningModule]{pl_module.__class__.__name__}.{hook_name}"):
--> 157     output = fn(*args, **kwargs)
    159 # restore current_fx when nested context
    160 pl_module._current_fx_name = prev_fx_name

File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/lightning/pytorch/core/module.py:1291, in LightningModule.optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure)
   1252 def optimizer_step(
   1253     self,
   1254     epoch: int,
   (...)
   1257     optimizer_closure: Optional[Callable[[], Any]] = None,
   1258 ) -> None:
   1259     r"""Override this method to adjust the default way the :class:`~lightning.pytorch.trainer.trainer.Trainer` calls
   1260     the optimizer.
   1261 
   (...)
   1289 
   1290     """
-> 1291     optimizer.step(closure=optimizer_closure)

File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/lightning/pytorch/core/optimizer.py:151, in LightningOptimizer.step(self, closure, **kwargs)
    148     raise MisconfigurationException("When `optimizer.step(closure)` is called, the closure should be callable")
    150 assert self._strategy is not None
--> 151 step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
    153 self._on_after_step()
    155 return step_output

File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py:230, in Strategy.optimizer_step(self, optimizer, closure, model, **kwargs)
    228 # TODO(fabric): remove assertion once strategy's optimizer_step typing is fixed
    229 assert isinstance(model, pl.LightningModule)
--> 230 return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)

File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/lightning/pytorch/plugins/precision/precision.py:117, in Precision.optimizer_step(self, optimizer, model, closure, **kwargs)
    115 """Hook to run the optimizer step."""
    116 closure = partial(self._wrap_closure, model, optimizer, closure)
--> 117 return optimizer.step(closure=closure, **kwargs)

File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/torch/optim/optimizer.py:487, in Optimizer.profile_hook_step.<locals>.wrapper(*args, **kwargs)
    482         else:
    483             raise RuntimeError(
    484                 f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}."
    485             )
--> 487 out = func(*args, **kwargs)
    488 self._optimizer_step_code()
    490 # call optimizer step post hooks

File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/torch/optim/optimizer.py:91, in _use_grad_for_differentiable.<locals>._use_grad(self, *args, **kwargs)
     89     torch.set_grad_enabled(self.defaults["differentiable"])
     90     torch._dynamo.graph_break()
---> 91     ret = func(self, *args, **kwargs)
     92 finally:
     93     torch._dynamo.graph_break()

File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/torch/optim/adam.py:202, in Adam.step(self, closure)
    200 if closure is not None:
    201     with torch.enable_grad():
--> 202         loss = closure()
    204 for group in self.param_groups:
    205     params_with_grad: List[Tensor] = []

File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/lightning/pytorch/plugins/precision/precision.py:104, in Precision._wrap_closure(self, model, optimizer, closure)
     91 def _wrap_closure(
     92     self,
     93     model: "pl.LightningModule",
     94     optimizer: Optimizer,
     95     closure: Callable[[], Any],
     96 ) -> Any:
     97     """This double-closure allows makes sure the ``closure`` is executed before the ``on_before_optimizer_step``
     98     hook is called.
     99 
   (...)
    102 
    103     """
--> 104     closure_result = closure()
    105     self._after_closure(model, optimizer)
    106     return closure_result

File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py:140, in Closure.__call__(self, *args, **kwargs)
    139 def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:
--> 140     self._result = self.closure(*args, **kwargs)
    141     return self._result.loss

File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py:126, in Closure.closure(self, *args, **kwargs)
    124 @torch.enable_grad()
    125 def closure(self, *args: Any, **kwargs: Any) -> ClosureResult:
--> 126     step_output = self._step_fn()
    128     if step_output.closure_loss is None:
    129         self.warning_cache.warn("`training_step` returned `None`. If this was on purpose, ignore this warning...")

File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py:315, in _AutomaticOptimization._training_step(self, kwargs)
    312 trainer = self.trainer
    314 # manually capture logged metrics
--> 315 training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
    316 self.trainer.strategy.post_training_step()  # unused hook - call anyway for backward compatibility
    318 return self.output_result_cls.from_training_step_output(training_step_output, trainer.accumulate_grad_batches)

File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:309, in _call_strategy_hook(trainer, hook_name, *args, **kwargs)
    306     return None
    308 with trainer.profiler.profile(f"[Strategy]{trainer.strategy.__class__.__name__}.{hook_name}"):
--> 309     output = fn(*args, **kwargs)
    311 # restore current_fx when nested context
    312 pl_module._current_fx_name = prev_fx_name

File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py:382, in Strategy.training_step(self, *args, **kwargs)
    380 if self.model != self.lightning_module:
    381     return self._forward_redirection(self.model, self.lightning_module, "training_step", *args, **kwargs)
--> 382 return self.lightning_module.training_step(*args, **kwargs)

File ~/Documents/GitHub/scvi-tools-Hrovatin/src/scvi/train/_trainingplans.py:364, in TrainingPlan.training_step(self, batch, batch_idx)
    362     self.loss_kwargs.update({"kl_weight": kl_weight})
    363     self.log("kl_weight", kl_weight, on_step=True, on_epoch=False)
--> 364 _, _, scvi_loss = self.forward(batch, loss_kwargs=self.loss_kwargs)
    365 self.log(
    366     "train_loss",
    367     scvi_loss.loss,
   (...)
    370     sync_dist=self.use_sync_dist,
    371 )
    372 self.compute_and_log_metrics(scvi_loss, self.train_metrics, "train")

File ~/Documents/GitHub/scvi-tools-Hrovatin/src/scvi/train/_trainingplans.py:294, in TrainingPlan.forward(self, *args, **kwargs)
    292 def forward(self, *args, **kwargs):
    293     """Passthrough to the module's forward method."""
--> 294     return self.module(
    295         *args,
    296         **kwargs,
    297         get_inference_input_kwargs={"full_forward_pass": not self.update_only_decoder},
    298     )

File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/Documents/GitHub/scvi-tools-Hrovatin/src/scvi/module/base/_decorators.py:32, in auto_move_data.<locals>.auto_transfer_args(self, *args, **kwargs)
     30 # decorator only necessary after training
     31 if self.training:
---> 32     return fn(self, *args, **kwargs)
     34 device = list({p.device for p in self.parameters()})
     35 if len(device) > 1:

File ~/Documents/GitHub/scvi-tools-Hrovatin/src/scvi/external/sysvi/_module.py:351, in SysVAE.forward(self, tensors, get_inference_input_kwargs, get_generative_input_kwargs, inference_kwargs, generative_kwargs, loss_kwargs, compute_loss)
    349 # Inference
    350 inference_inputs = self._get_inference_input(tensors, **get_inference_input_kwargs)
--> 351 inference_outputs = self.inference(**inference_inputs, **inference_kwargs)
    352 # Generative
    353 cycle_batch = self.random_select_batch(tensors[REGISTRY_KEYS.BATCH_KEY])

File ~/Documents/GitHub/scvi-tools-Hrovatin/src/scvi/module/base/_decorators.py:32, in auto_move_data.<locals>.auto_transfer_args(self, *args, **kwargs)
     30 # decorator only necessary after training
     31 if self.training:
---> 32     return fn(self, *args, **kwargs)
     34 device = list({p.device for p in self.parameters()})
     35 if len(device) > 1:

File ~/Documents/GitHub/scvi-tools-Hrovatin/src/scvi/external/sysvi/_module.py:296, in SysVAE.inference(self, x, batch_index, cont_covs, cat_covs, n_samples)
    286 @auto_move_data
    287 def inference(
    288     self,
   (...)
    293     n_samples: int = 1,
    294 ) -> dict[str, torch.Tensor | Distribution | None]:
    295     """Inference: expression & covariates -> latent representation."""
--> 296     result = self.encoder(x, batch_index=batch_index, cat_list=cat_covs, cont=cont_covs)
    297     z, qz = result["q"], result["q_dist"]
    298     return {
    299         MODULE_KEYS.Z_KEY: z,
    300         MODULE_KEYS.QZ_KEY: qz,
    301     }

File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/Documents/GitHub/scvi-tools-Hrovatin/src/scvi/external/sysvi/_base_components.py:129, in EncoderDecoder.forward(self, x, batch_index, cont, cat_list)
    126     q_m = torch.nan_to_num(q_m)
    127 q_v = self.var_encoder(q_)
--> 129 outputs = {"q_dist": Normal(q_m, q_v.sqrt())}
    131 if self.sample:
    132     outputs["q"] = outputs["q_dist"].rsample()

File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/torch/distributions/normal.py:59, in Normal.__init__(self, loc, scale, validate_args)
     57 else:
     58     batch_shape = self.loc.size()
---> 59 super().__init__(batch_shape, validate_args=validate_args)

File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/torch/distributions/distribution.py:71, in Distribution.__init__(self, batch_shape, event_shape, validate_args)
     69         valid = constraint.check(value)
     70         if not valid.all():
---> 71             raise ValueError(
     72                 f"Expected parameter {param} "
     73                 f"({type(value).__name__} of shape {tuple(value.shape)}) "
     74                 f"of distribution {repr(self)} "
     75                 f"to satisfy the constraint {repr(constraint)}, "
     76                 f"but found invalid values:\n{value}"
     77             )
     78 super().__init__()

ValueError: Expected parameter scale (Tensor of shape (128, 15)) of distribution Normal(loc: torch.Size([128, 15]), scale: torch.Size([128, 15])) to satisfy the constraint GreaterThan(lower_bound=0.0), but found invalid values:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], grad_fn=<SqrtBackward0>)
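For reference, the ValueError above comes from PyTorch's argument validation on torch.distributions.Normal, which requires a strictly positive scale; the zero tensor produced by the variance encoder violates that constraint. A minimal sketch of the failure, plus an illustrative clamp-based workaround (not necessarily the fix that was ultimately committed), looks like this:

```python
import torch
from torch.distributions import Normal

q_m = torch.zeros(128, 15)
q_v = torch.zeros(128, 15)  # a variance head that collapsed to exactly zero

try:
    Normal(q_m, q_v.sqrt(), validate_args=True)  # scale must satisfy GreaterThan(0.0)
except ValueError as err:
    print(err)  # same constraint error as in the traceback above

# Illustrative workaround: keep the scale strictly positive by clamping the variance.
eps = 1e-4
qz = Normal(q_m, q_v.clamp(min=eps).sqrt())
print(qz.rsample().shape)  # torch.Size([128, 15])
```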

@canergen (Member, Author) commented:

@Hrovatin the merging of the encoder output dictionaries was wrong. After changing it (see commit), the performance looks fine in the tutorial.
[Screenshot: results from the tutorial run, 2025-02-22 9:56 PM]
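As a general pattern (a hypothetical sketch, not necessarily what the referenced commit changes), the encoder can guarantee a strictly positive scale before the distribution is constructed, for example via softplus plus a small epsilon on the variance head:

```python
import torch
from torch import nn
from torch.distributions import Normal


class GaussianHead(nn.Module):
    """Toy encoder head that always produces a strictly positive variance."""

    def __init__(self, n_in: int, n_out: int, eps: float = 1e-4):
        super().__init__()
        self.mean = nn.Linear(n_in, n_out)
        self.var = nn.Linear(n_in, n_out)
        self.eps = eps

    def forward(self, h: torch.Tensor) -> dict:
        q_m = self.mean(h)
        # softplus keeps the variance positive; eps keeps it away from exact zero
        q_v = nn.functional.softplus(self.var(h)) + self.eps
        q_dist = Normal(q_m, q_v.sqrt())
        return {"q_m": q_m, "q_v": q_v, "q_dist": q_dist, "q": q_dist.rsample()}


head = GaussianHead(n_in=32, n_out=15)
out = head(torch.randn(128, 32))
print(out["q"].shape)  # torch.Size([128, 15])
```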

@Hrovatin commented Feb 23, 2025

@canergen thank you for resolving.

I also updated the tutorial (see scverse/scvi-tutorials#212). It would be great if the tutorial could now be added to the scvi-tools tutorials.

@canergen merged commit a9e45bf into main Feb 23, 2025
14 of 15 checks passed
meeseeksmachine pushed a commit to meeseeksmachine/scvi-tools that referenced this pull request Feb 23, 2025