feat: SysVI for cycle consistency loss and VampPrior. #3195
Conversation
Codecov Report
Attention: Patch coverage is …

Additional details and impacted files

@@            Coverage Diff             @@
##             main    #3195      +/-   ##
==========================================
- Coverage   89.27%   82.65%    -6.62%
==========================================
  Files         185      190        +5
  Lines       16265    16551      +286
==========================================
- Hits        14520    13680      -840
- Misses       1745     2871     +1126
Thank you very much. I will re-run the tutorial on this branch in the next few days.
It seems like the changes broke something numerically. Any ideas why?
The comments below refer to this change in EncoderDecoder.forward (src/scvi/external/sysvi/_base_components.py):

-    outputs = {"y_m": y_m, "y_v": y_v}
+    outputs = {"q_dist": Normal(q_m, q_v.sqrt())}
This line is causing an error in the tutorial. This did not happen before these latest changes, nor in any of our numerous benchmarks. However, I would expect that something else is the underlying cause.
Please try running https://github.com/Hrovatin/scvi-tutorials/blob/stable/scrna/sysVI.ipynb as it uses non-synthetic data.
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[5], line 5
3 # Train
4 max_epochs = 200
----> 5 model.train(
6 max_epochs=max_epochs,
7 check_val_every_n_epoch=1,
8 )
File ~/Documents/GitHub/scvi-tools-Hrovatin/src/scvi/external/sysvi/_model.py:145, in SysVI.train(self, plan_kwargs, **train_kwargs)
143 train_kwargs = train_kwargs or {}
144 train_kwargs["plan_kwargs"] = plan_kwargs
--> 145 super().train(**train_kwargs)
File ~/Documents/GitHub/scvi-tools-Hrovatin/src/scvi/model/base/_training_mixin.py:145, in UnsupervisedTrainingMixin.train(self, max_epochs, accelerator, devices, train_size, validation_size, shuffle_set_split, load_sparse_tensor, batch_size, early_stopping, datasplitter_kwargs, plan_kwargs, datamodule, **trainer_kwargs)
133 trainer_kwargs[es] = (
134 early_stopping if es not in trainer_kwargs.keys() else trainer_kwargs[es]
135 )
136 runner = self._train_runner_cls(
137 self,
138 training_plan=training_plan,
(...)
143 **trainer_kwargs,
144 )
--> 145 return runner()
File ~/Documents/GitHub/scvi-tools-Hrovatin/src/scvi/train/_trainrunner.py:96, in TrainRunner.__call__(self)
93 if hasattr(self.data_splitter, "n_val"):
94 self.training_plan.n_obs_validation = self.data_splitter.n_val
---> 96 self.trainer.fit(self.training_plan, self.data_splitter)
97 self._update_history()
99 # data splitter only gets these attrs after fit
File ~/Documents/GitHub/scvi-tools-Hrovatin/src/scvi/train/_trainer.py:201, in Trainer.fit(self, *args, **kwargs)
195 if isinstance(args[0], PyroTrainingPlan):
196 warnings.filterwarnings(
197 action="ignore",
198 category=UserWarning,
199 message="`LightningModule.configure_optimizers` returned `None`",
200 )
--> 201 super().fit(*args, **kwargs)
File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:544, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
542 self.state.status = TrainerStatus.RUNNING
543 self.training = True
--> 544 call._call_and_handle_interrupt(
545 self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
546 )
File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:44, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
42 if trainer.strategy.launcher is not None:
43 return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 44 return trainer_fn(*args, **kwargs)
46 except _TunerExitException:
47 _call_teardown_hook(trainer)
File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:580, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
573 assert self.state.fn is not None
574 ckpt_path = self._checkpoint_connector._select_ckpt_path(
575 self.state.fn,
576 ckpt_path,
577 model_provided=True,
578 model_connected=self.lightning_module is not None,
579 )
--> 580 self._run(model, ckpt_path=ckpt_path)
582 assert self.state.stopped
583 self.training = False
File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:989, in Trainer._run(self, model, ckpt_path)
984 self._signal_connector.register_signal_handlers()
986 # ----------------------------
987 # RUN THE TRAINER
988 # ----------------------------
--> 989 results = self._run_stage()
991 # ----------------------------
992 # POST-Training CLEAN UP
993 # ----------------------------
994 log.debug(f"{self.__class__.__name__}: trainer tearing down")
File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:1035, in Trainer._run_stage(self)
1033 self._run_sanity_check()
1034 with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1035 self.fit_loop.run()
1036 return None
1037 raise RuntimeError(f"Unexpected state {self.state}")
File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:202, in _FitLoop.run(self)
200 try:
201 self.on_advance_start()
--> 202 self.advance()
203 self.on_advance_end()
204 self._restarting = False
File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:359, in _FitLoop.advance(self)
357 with self.trainer.profiler.profile("run_training_epoch"):
358 assert self._data_fetcher is not None
--> 359 self.epoch_loop.run(self._data_fetcher)
File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/lightning/pytorch/loops/training_epoch_loop.py:136, in _TrainingEpochLoop.run(self, data_fetcher)
134 while not self.done:
135 try:
--> 136 self.advance(data_fetcher)
137 self.on_advance_end(data_fetcher)
138 self._restarting = False
File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/lightning/pytorch/loops/training_epoch_loop.py:240, in _TrainingEpochLoop.advance(self, data_fetcher)
237 with trainer.profiler.profile("run_training_batch"):
238 if trainer.lightning_module.automatic_optimization:
239 # in automatic optimization, there can only be one optimizer
--> 240 batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
241 else:
242 batch_output = self.manual_optimization.run(kwargs)
File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py:187, in _AutomaticOptimization.run(self, optimizer, batch_idx, kwargs)
180 closure()
182 # ------------------------------
183 # BACKWARD PASS
184 # ------------------------------
185 # gradient update with accumulated gradients
186 else:
--> 187 self._optimizer_step(batch_idx, closure)
189 result = closure.consume_result()
190 if result.loss is None:
File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py:265, in _AutomaticOptimization._optimizer_step(self, batch_idx, train_step_and_backward_closure)
262 self.optim_progress.optimizer.step.increment_ready()
264 # model hook
--> 265 call._call_lightning_module_hook(
266 trainer,
267 "optimizer_step",
268 trainer.current_epoch,
269 batch_idx,
270 optimizer,
271 train_step_and_backward_closure,
272 )
274 if not should_accumulate:
275 self.optim_progress.optimizer.step.increment_completed()
File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:157, in _call_lightning_module_hook(trainer, hook_name, pl_module, *args, **kwargs)
154 pl_module._current_fx_name = hook_name
156 with trainer.profiler.profile(f"[LightningModule]{pl_module.__class__.__name__}.{hook_name}"):
--> 157 output = fn(*args, **kwargs)
159 # restore current_fx when nested context
160 pl_module._current_fx_name = prev_fx_name
File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/lightning/pytorch/core/module.py:1291, in LightningModule.optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure)
1252 def optimizer_step(
1253 self,
1254 epoch: int,
(...)
1257 optimizer_closure: Optional[Callable[[], Any]] = None,
1258 ) -> None:
1259 r"""Override this method to adjust the default way the :class:`~lightning.pytorch.trainer.trainer.Trainer` calls
1260 the optimizer.
1261
(...)
1289
1290 """
-> 1291 optimizer.step(closure=optimizer_closure)
File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/lightning/pytorch/core/optimizer.py:151, in LightningOptimizer.step(self, closure, **kwargs)
148 raise MisconfigurationException("When `optimizer.step(closure)` is called, the closure should be callable")
150 assert self._strategy is not None
--> 151 step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
153 self._on_after_step()
155 return step_output
File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py:230, in Strategy.optimizer_step(self, optimizer, closure, model, **kwargs)
228 # TODO(fabric): remove assertion once strategy's optimizer_step typing is fixed
229 assert isinstance(model, pl.LightningModule)
--> 230 return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/lightning/pytorch/plugins/precision/precision.py:117, in Precision.optimizer_step(self, optimizer, model, closure, **kwargs)
115 """Hook to run the optimizer step."""
116 closure = partial(self._wrap_closure, model, optimizer, closure)
--> 117 return optimizer.step(closure=closure, **kwargs)
File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/torch/optim/optimizer.py:487, in Optimizer.profile_hook_step.<locals>.wrapper(*args, **kwargs)
482 else:
483 raise RuntimeError(
484 f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}."
485 )
--> 487 out = func(*args, **kwargs)
488 self._optimizer_step_code()
490 # call optimizer step post hooks
File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/torch/optim/optimizer.py:91, in _use_grad_for_differentiable.<locals>._use_grad(self, *args, **kwargs)
89 torch.set_grad_enabled(self.defaults["differentiable"])
90 torch._dynamo.graph_break()
---> 91 ret = func(self, *args, **kwargs)
92 finally:
93 torch._dynamo.graph_break()
File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/torch/optim/adam.py:202, in Adam.step(self, closure)
200 if closure is not None:
201 with torch.enable_grad():
--> 202 loss = closure()
204 for group in self.param_groups:
205 params_with_grad: List[Tensor] = []
File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/lightning/pytorch/plugins/precision/precision.py:104, in Precision._wrap_closure(self, model, optimizer, closure)
91 def _wrap_closure(
92 self,
93 model: "pl.LightningModule",
94 optimizer: Optimizer,
95 closure: Callable[[], Any],
96 ) -> Any:
97 """This double-closure allows makes sure the ``closure`` is executed before the ``on_before_optimizer_step``
98 hook is called.
99
(...)
102
103 """
--> 104 closure_result = closure()
105 self._after_closure(model, optimizer)
106 return closure_result
File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py:140, in Closure.__call__(self, *args, **kwargs)
139 def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:
--> 140 self._result = self.closure(*args, **kwargs)
141 return self._result.loss
File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
113 @functools.wraps(func)
114 def decorate_context(*args, **kwargs):
115 with ctx_factory():
--> 116 return func(*args, **kwargs)
File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py:126, in Closure.closure(self, *args, **kwargs)
124 @torch.enable_grad()
125 def closure(self, *args: Any, **kwargs: Any) -> ClosureResult:
--> 126 step_output = self._step_fn()
128 if step_output.closure_loss is None:
129 self.warning_cache.warn("`training_step` returned `None`. If this was on purpose, ignore this warning...")
File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py:315, in _AutomaticOptimization._training_step(self, kwargs)
312 trainer = self.trainer
314 # manually capture logged metrics
--> 315 training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
316 self.trainer.strategy.post_training_step() # unused hook - call anyway for backward compatibility
318 return self.output_result_cls.from_training_step_output(training_step_output, trainer.accumulate_grad_batches)
File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:309, in _call_strategy_hook(trainer, hook_name, *args, **kwargs)
306 return None
308 with trainer.profiler.profile(f"[Strategy]{trainer.strategy.__class__.__name__}.{hook_name}"):
--> 309 output = fn(*args, **kwargs)
311 # restore current_fx when nested context
312 pl_module._current_fx_name = prev_fx_name
File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py:382, in Strategy.training_step(self, *args, **kwargs)
380 if self.model != self.lightning_module:
381 return self._forward_redirection(self.model, self.lightning_module, "training_step", *args, **kwargs)
--> 382 return self.lightning_module.training_step(*args, **kwargs)
File ~/Documents/GitHub/scvi-tools-Hrovatin/src/scvi/train/_trainingplans.py:364, in TrainingPlan.training_step(self, batch, batch_idx)
362 self.loss_kwargs.update({"kl_weight": kl_weight})
363 self.log("kl_weight", kl_weight, on_step=True, on_epoch=False)
--> 364 _, _, scvi_loss = self.forward(batch, loss_kwargs=self.loss_kwargs)
365 self.log(
366 "train_loss",
367 scvi_loss.loss,
(...)
370 sync_dist=self.use_sync_dist,
371 )
372 self.compute_and_log_metrics(scvi_loss, self.train_metrics, "train")
File ~/Documents/GitHub/scvi-tools-Hrovatin/src/scvi/train/_trainingplans.py:294, in TrainingPlan.forward(self, *args, **kwargs)
292 def forward(self, *args, **kwargs):
293 """Passthrough to the module's forward method."""
--> 294 return self.module(
295 *args,
296 **kwargs,
297 get_inference_input_kwargs={"full_forward_pass": not self.update_only_decoder},
298 )
File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/Documents/GitHub/scvi-tools-Hrovatin/src/scvi/module/base/_decorators.py:32, in auto_move_data.<locals>.auto_transfer_args(self, *args, **kwargs)
30 # decorator only necessary after training
31 if self.training:
---> 32 return fn(self, *args, **kwargs)
34 device = list({p.device for p in self.parameters()})
35 if len(device) > 1:
File ~/Documents/GitHub/scvi-tools-Hrovatin/src/scvi/external/sysvi/_module.py:351, in SysVAE.forward(self, tensors, get_inference_input_kwargs, get_generative_input_kwargs, inference_kwargs, generative_kwargs, loss_kwargs, compute_loss)
349 # Inference
350 inference_inputs = self._get_inference_input(tensors, **get_inference_input_kwargs)
--> 351 inference_outputs = self.inference(**inference_inputs, **inference_kwargs)
352 # Generative
353 cycle_batch = self.random_select_batch(tensors[REGISTRY_KEYS.BATCH_KEY])
File ~/Documents/GitHub/scvi-tools-Hrovatin/src/scvi/module/base/_decorators.py:32, in auto_move_data.<locals>.auto_transfer_args(self, *args, **kwargs)
30 # decorator only necessary after training
31 if self.training:
---> 32 return fn(self, *args, **kwargs)
34 device = list({p.device for p in self.parameters()})
35 if len(device) > 1:
File ~/Documents/GitHub/scvi-tools-Hrovatin/src/scvi/external/sysvi/_module.py:296, in SysVAE.inference(self, x, batch_index, cont_covs, cat_covs, n_samples)
286 @auto_move_data
287 def inference(
288 self,
(...)
293 n_samples: int = 1,
294 ) -> dict[str, torch.Tensor | Distribution | None]:
295 """Inference: expression & covariates -> latent representation."""
--> 296 result = self.encoder(x, batch_index=batch_index, cat_list=cat_covs, cont=cont_covs)
297 z, qz = result["q"], result["q_dist"]
298 return {
299 MODULE_KEYS.Z_KEY: z,
300 MODULE_KEYS.QZ_KEY: qz,
301 }
File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/Documents/GitHub/scvi-tools-Hrovatin/src/scvi/external/sysvi/_base_components.py:129, in EncoderDecoder.forward(self, x, batch_index, cont, cat_list)
126 q_m = torch.nan_to_num(q_m)
127 q_v = self.var_encoder(q_)
--> 129 outputs = {"q_dist": Normal(q_m, q_v.sqrt())}
131 if self.sample:
132 outputs["q"] = outputs["q_dist"].rsample()
File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/torch/distributions/normal.py:59, in Normal.__init__(self, loc, scale, validate_args)
57 else:
58 batch_shape = self.loc.size()
---> 59 super().__init__(batch_shape, validate_args=validate_args)
File ~/miniconda3/envs/scvi/lib/python3.10/site-packages/torch/distributions/distribution.py:71, in Distribution.__init__(self, batch_shape, event_shape, validate_args)
69 valid = constraint.check(value)
70 if not valid.all():
---> 71 raise ValueError(
72 f"Expected parameter {param} "
73 f"({type(value).__name__} of shape {tuple(value.shape)}) "
74 f"of distribution {repr(self)} "
75 f"to satisfy the constraint {repr(constraint)}, "
76 f"but found invalid values:\n{value}"
77 )
78 super().__init__()
ValueError: Expected parameter scale (Tensor of shape (128, 15)) of distribution Normal(loc: torch.Size([128, 15]), scale: torch.Size([128, 15])) to satisfy the constraint GreaterThan(lower_bound=0.0), but found invalid values:
tensor([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]], grad_fn=<SqrtBackward0>)
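For context, the constraint failure at the bottom of this traceback can be reproduced in isolation. The snippet below is only a generic illustration of why the error fires and how a variance floor would avoid it; it is not the fix applied in this PR (which, per the next comment, was in how the encoder output dictionaries were merged), and var_eps is a hypothetical name, not a SysVI parameter.

import torch
from torch.distributions import Normal

q_m = torch.zeros(128, 15)
q_v = torch.zeros(128, 15)  # a collapsed variance output, as in the error above

# Fails: scale = sqrt(q_v) = 0 violates the GreaterThan(lower_bound=0.0) constraint.
try:
    Normal(q_m, q_v.sqrt())
except ValueError as err:
    print(type(err).__name__)  # ValueError

# Generic guard: floor the variance before taking the square root so the scale
# stays strictly positive even if the variance encoder outputs zeros.
var_eps = 1e-4
qz = Normal(q_m, (q_v + var_eps).sqrt())
print(qz.rsample().shape)  # torch.Size([128, 15])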
@Hrovatin merging the encoder dictionaries was wrong after the change (see commit). With that fixed, the performance looks fine in the tutorial.
@canergen thank you for resolving this. I also updated the tutorial (see scverse/scvi-tutorials#212). It would be great if the tutorial could now be added to the scvi-tutorials collection.
@Hrovatin @moinfar this supersedes #2421. I harmonized the code more with the other scvi-tools models. I don't think it should change anything in terms of output, but I'd be happy if you ran a small test case once. Otherwise this is ready to be merged, and I'll go ahead and merge it next week.
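For the requested small test case, a smoke test could look roughly like the sketch below. It assumes SysVI follows the usual scvi-tools pattern (setup_anndata, then train, then get_latent_representation) and uses the built-in synthetic dataset; the exact setup_anndata arguments (e.g. batch_key) are assumptions here and should be checked against the SysVI documentation and tutorial.

import scvi
from scvi.external import SysVI

# Small synthetic dataset shipped with scvi-tools; it has a "batch" column in .obs.
adata = scvi.data.synthetic_iid()

# Assumed registration call; argument names may differ in the final SysVI API.
SysVI.setup_anndata(adata, batch_key="batch")

model = SysVI(adata)
model.train(max_epochs=2, check_val_every_n_epoch=1)

latent = model.get_latent_representation()
print(latent.shape)  # (n_cells, n_latent)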