@torch.no_grad()
def sample(self, shape):
    """Generate samples by running DPM-Solver from Gaussian noise.

    Calls ``self.eval()``, draws standard-normal noise of ``shape``, wraps
    ``self.model`` as an unconditional ``"flow"`` model and runs a 20-step,
    order-2 multistep DPM-Solver++ pass over that noise.

    Args:
        shape: Size of the initial noise tensor; forwarded to
            ``torch.randn``.

    Returns:
        The value returned by ``DPM_Solver.sample`` for the initial noise
        (presumably a tensor of the same shape -- confirm against the
        solver implementation).

    Note:
        ``self.eval()`` is called and the previous mode is not restored
        here; callers that keep training afterwards must switch back
        themselves.
    """
    self.eval()

    # Follow the device of the wrapped model instead of hard-coding CUDA, so
    # sampling also works on CPU-only hosts and on non-default GPUs. If the
    # device cannot be determined, keep the previous behaviour (current CUDA
    # device).
    try:
        device = next(iter(self.model.parameters())).device
    except (AttributeError, StopIteration):
        device = torch.device("cuda")

    # Draw on the CPU and then move, as before, so seeded runs consume the
    # same CPU RNG stream and stay reproducible.
    zt = torch.randn(shape).to(device)  # initial noise

    noise_schedule = NoiseScheduleFlow(schedule="discrete_flow")

    # Unconditional sampling: no condition and no classifier arguments.
    model_fn_continuous = model_wrapper(
        self.model,
        noise_schedule,
        model_type="flow",
        model_kwargs={},
        guidance_type="uncond",
        condition=None,
        classifier_kwargs={},
        interval_guidance=[0, 1.0],
    )

    dpm_solver = DPM_Solver(
        model_fn_continuous,
        noise_schedule,
        algorithm_type="dpmsolver++",
    )

    # NOTE(review): atol/rtol look like tolerances for an adaptive method and
    # may be ignored with method="multistep" -- confirm against DPM_Solver.
    zt = dpm_solver.sample(
        zt,
        steps=20,
        order=2,
        skip_type="time_uniform_flow",
        method="multistep",
        lower_order_final=False,
        denoise_to_zero=True,
        solver_type="dpmsolver",
        atol=0.0078,
        rtol=0.05,
        flow_shift=3.0,
    )
    return zt