Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RuntimeError: NaNs detected in the generator loss #3

Open
zoezhang106 opened this issue Jun 10, 2024 · 0 comments
Open

RuntimeError: NaNs detected in the generator loss #3

zoezhang106 opened this issue Jun 10, 2024 · 0 comments

Comments

@zoezhang106
Copy link

Thank you very much for the excellent work.

I am running into the error "NaNs detected in the generator loss" when I try to reproduce experiments_02_sources_of_gain_parametric.ipynb.

I added print statements for real_output, fake_output, errG_noncen, errG_cen, and errG_discr to better understand where the NaNs come from.

I would very much appreciate it if you could help me figure out why this happens and how to avoid it.

Error message:
File ~/miniconda3/envs/py310/lib/python3.10/site-packages/synthcity/plugins/core/models/time_to_event/tte_date.py:266, in TimeEventGAN._train_epoch_generator(self, X, T, E)
264 print("errG_cen",errG_cen)
265 print("errG_discr",errG_discr)
--> 266 raise RuntimeError("NaNs detected in the generator loss")
268 # Calculate gradients for G
269 errG.backward()

RuntimeError: NaNs detected in the generator loss

System Information

  • OS: [iOS]
  • Language Version: [Python 3.10]

Complete error message

....after hundreds of iterations working fine ...
real_output tensor(8.1511, grad_fn=)
fake_output tensor(-7.9771, grad_fn=)
real_output tensor(5.5393, grad_fn=)
fake_output tensor(-7.9846, grad_fn=)
real_output tensor(5.3318, grad_fn=)
fake_output tensor(-8.4363, grad_fn=)
real_output tensor(5.8340, grad_fn=)
fake_output tensor(-8.7527, grad_fn=)
real_output tensor(4.9789, grad_fn=)
fake_output tensor(-8.6458, grad_fn=)
real_output tensor(nan, grad_fn=)
fake_output tensor(nan, grad_fn=)
errG_noncen tensor(0.1083, grad_fn=)
errG_cen tensor(0.0708, grad_fn=)
errG_discr tensor(nan, grad_fn=)

RuntimeError Traceback (most recent call last)
Cell In[2], line 1
----> 1 base_score = evaluate_dataset("aids", gain_scenarios)

Cell In[1], line 77, in evaluate_dataset(dataset, scenarios)
72 for scenario_name, scenario_args in scenarios:
73 bkp = (
74 out_dir / f"experiment_{experiment}{dataset}{scenario_name}_{repeats}.bkp"
75 )
---> 77 score = Benchmarks.evaluate(
78 [(scenario_name, "survival_gan", scenario_args)],
79 SurvivalAnalysisDataLoader(
80 df,
81 target_column=event_col,
82 time_to_event_column=duration_col,
83 time_horizons=time_horizons,
84 ),
85 task_type="survival_analysis",
86 synthetic_size=len(df),
87 repeats=repeats,
88 )
89 save_to_file(bkp, score)
91 print("Scenario", scenario_name, scenario_args)

File ~/miniconda3/envs/py310/lib/python3.10/site-packages/pydantic/decorator.py:40, in pydantic.decorator.validate_arguments.validate.wrapper_function()

File ~/miniconda3/envs/py310/lib/python3.10/site-packages/pydantic/decorator.py:134, in pydantic.decorator.ValidatedFunction.call()

File ~/miniconda3/envs/py310/lib/python3.10/site-packages/pydantic/decorator.py:206, in pydantic.decorator.ValidatedFunction.execute()

File ~/miniconda3/envs/py310/lib/python3.10/site-packages/synthcity/benchmark/__init__.py:194, in Benchmarks.evaluate(tests, X, X_test, metrics, repeats, synthetic_size, synthetic_constraints, synthetic_cache, synthetic_reuse_if_exists, augmented_reuse_if_exists, task_type, workspace, augmentation_rule, strict_augmentation, ad_hoc_augment_vals, use_metric_cache, **generate_kwargs)
188 else:
189 generator = Plugins(categories=plugin_cats).get(
190 plugin,
191 **kwargs,
192 )
--> 194 generator.fit(X.train())
196 if synthetic_cache:
197 save_to_file(generator_file, generator)

File ~/miniconda3/envs/py310/lib/python3.10/site-packages/pydantic/decorator.py:40, in pydantic.decorator.validate_arguments.validate.wrapper_function()

File ~/miniconda3/envs/py310/lib/python3.10/site-packages/pydantic/decorator.py:134, in pydantic.decorator.ValidatedFunction.call()

File ~/miniconda3/envs/py310/lib/python3.10/site-packages/pydantic/decorator.py:206, in pydantic.decorator.ValidatedFunction.execute()

File ~/miniconda3/envs/py310/lib/python3.10/site-packages/synthcity/plugins/core/plugin.py:256, in Plugin.fit(self, X, *args, **kwargs)
248 X, self.compress_context = load_from_file(bkp_file)
250 self._training_schema = Schema(
251 data=X,
252 sampling_strategy=self.sampling_strategy,
253 random_state=self.random_state,
254 )
--> 256 output = self._fit(X, *args, **kwargs)
257 self.fitted = True
259 return output

File ~/miniconda3/envs/py310/lib/python3.10/site-packages/synthcity/plugins/survival_analysis/plugin_survival_gan.py:203, in SurvivalGANPlugin._fit(self, X, cond, *args, **kwargs)
188 train_conditional = BinEncoder().fit_transform(precond)
190 self.model = SurvivalPipeline(
191 "adsgan",
192 strategy=self.tte_strategy,
(...)
201 **self.kwargs,
202 )
--> 203 self.model.fit(X, cond=train_conditional, *args, **kwargs)
205 return self

File ~/miniconda3/envs/py310/lib/python3.10/site-packages/pydantic/decorator.py:40, in pydantic.decorator.validate_arguments.validate.wrapper_function()

File ~/miniconda3/envs/py310/lib/python3.10/site-packages/pydantic/decorator.py:134, in pydantic.decorator.ValidatedFunction.call()

File ~/miniconda3/envs/py310/lib/python3.10/site-packages/pydantic/decorator.py:206, in pydantic.decorator.ValidatedFunction.execute()

File ~/miniconda3/envs/py310/lib/python3.10/site-packages/synthcity/plugins/core/plugin.py:256, in Plugin.fit(self, X, *args, **kwargs)
248 X, self.compress_context = load_from_file(bkp_file)
250 self._training_schema = Schema(
251 data=X,
252 sampling_strategy=self.sampling_strategy,
253 random_state=self.random_state,
254 )
--> 256 output = self._fit(X, *args, **kwargs)
257 self.fitted = True
259 return output

File ~/miniconda3/envs/py310/lib/python3.10/site-packages/synthcity/plugins/survival_analysis/_survival_pipeline.py:133, in SurvivalPipeline._fit(self, X, cond, *args, **kwargs)
131 if self.uncensoring_model is not None:
132 log.info("Train the uncensoring model")
--> 133 self.uncensoring_model.fit(Xcov, T, E)
135 log.info("Train the synthetic generator")
136 if self.strategy == "uncensoring":

File ~/miniconda3/envs/py310/lib/python3.10/site-packages/pydantic/decorator.py:40, in pydantic.decorator.validate_arguments.validate.wrapper_function()

File ~/miniconda3/envs/py310/lib/python3.10/site-packages/pydantic/decorator.py:134, in pydantic.decorator.ValidatedFunction.call()

File ~/miniconda3/envs/py310/lib/python3.10/site-packages/pydantic/decorator.py:206, in pydantic.decorator.ValidatedFunction.execute()

File ~/miniconda3/envs/py310/lib/python3.10/site-packages/synthcity/plugins/core/models/time_to_event/tte_date.py:424, in DATETimeToEvent.fit(self, X, T, Y)
421 self.scaler_T = MinMaxScaler()
422 enc_T = self.scaler_T.fit_transform(T.values.reshape(-1, 1)).squeeze()
--> 424 self.model.fit(enc_X, enc_T, Y)
426 return self

File ~/miniconda3/envs/py310/lib/python3.10/site-packages/synthcity/plugins/core/models/time_to_event/tte_date.py:143, in TimeEventGAN.fit(self, X, T, E)
140 Tt = self._check_tensor(T)
141 Et = self._check_tensor(E)
--> 143 self._train(
144 Xt,
145 Tt,
146 Et,
147 )
149 return self

File ~/miniconda3/envs/py310/lib/python3.10/site-packages/synthcity/plugins/core/models/time_to_event/tte_date.py:361, in TimeEventGAN._train(self, X, T, E)
359 # Train loop
360 for i in range(self.generator_n_iter):
--> 361 g_loss, d_loss = self._train_epoch(loader)
362 # Check how the generator is doing by saving G's output on fixed_noise
363 if (i + 1) % self.n_iter_print == 0:

File ~/miniconda3/envs/py310/lib/python3.10/site-packages/synthcity/plugins/core/models/time_to_event/tte_date.py:331, in TimeEventGAN._train_epoch(self, loader)
327 D_losses = []
329 for i, data in enumerate(loader):
330 G_losses.append(
--> 331 self._train_epoch_generator(
332 *data,
333 )
334 )
335 D_losses.append(
336 self._train_epoch_discriminator(
337 *data,
338 )
339 )
341 return np.mean(G_losses), np.mean(D_losses)

File ~/miniconda3/envs/py310/lib/python3.10/site-packages/synthcity/plugins/core/models/time_to_event/tte_date.py:266, in TimeEventGAN._train_epoch_generator(self, X, T, E)
264 print("errG_cen",errG_cen)
265 print("errG_discr",errG_discr)
--> 266 raise RuntimeError("NaNs detected in the generator loss")
268 # Calculate gradients for G
269 errG.backward()

RuntimeError: NaNs detected in the generator loss

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant