Add Scheduler trait/enum #36

Open
rockerBOO opened this issue Jan 9, 2023 · 2 comments

Comments

@rockerBOO

rockerBOO commented Jan 9, 2023

Right now we are adding schedulers, but they are difficult to work with because swapping one scheduler for another doesn't work well. This also slows down testing and evaluating the schedulers, since a separate script has to be written each time to test a sampler. I was also integrating these into an application, and swapping schedulers didn't work there either (each scheduler is a different type, so one can't simply be chosen at runtime).

I experimented with adding a trait so we can use impl Scheduler. I came up with the following, but it raises some points of contention:

  • step needs &mut self for some schedulers
  • step currently takes timesteps of different numeric types (f64 and usize)
  • add_noise doesn't take a noise input in every scheduler
  • timesteps returns a slice of usize or f64 (or possibly other types?)

```rust
use tch::Tensor;

pub trait Scheduler {
    // Placeholder bound: some trait that both f64 and usize would satisfy.
    fn step<T: SomeTraitThatWouldTakef64AndUsize>(&mut self, model_output: &Tensor, timestep: T, sample: &Tensor) -> Tensor;
    fn timesteps(&self) -> &[usize];
    fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: usize) -> Tensor;
    fn init_noise_sigma(&self) -> f64;
    fn scale_model_input(&self, sample: Tensor, timestep: usize) -> Tensor;
}
```

And then I could do the following (I'm still learning traits):

```rust
fn sample<T: Scheduler>(
    /* ..., */
    mut scheduler: T,
) {
    // ...
}
```
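
One thing worth noting about the draft trait: a generic method like step<T: ...> (without a where Self: Sized bound) makes the trait unusable as a trait object (dyn Scheduler), so sample<T: Scheduler> still fixes the scheduler type at compile time. Below is a minimal sketch under several assumptions: timesteps are normalised to f64 so every scheduler shares one step signature, Vec<f64> stands in for tch::Tensor, timesteps returns an owned Vec (instead of &[usize]) so the loop can call step(&mut self) while iterating, and ToyDiscrete, ToyMultistep and make_scheduler are hypothetical toys rather than real schedulers.

```rust
// Sketch: Vec<f64> stands in for tch::Tensor; the schedulers are toys with
// made-up update rules, not real DDIM/DPM-Solver math.

/// No generic methods, so the trait is dyn-compatible (Box<dyn Scheduler> works).
trait Scheduler {
    /// Owned Vec so callers can iterate while calling step(&mut self).
    fn timesteps(&self) -> Vec<f64>;
    fn step(&mut self, model_output: &[f64], timestep: f64, sample: &[f64]) -> Vec<f64>;
}

struct ToyDiscrete;

impl Scheduler for ToyDiscrete {
    fn timesteps(&self) -> Vec<f64> {
        vec![999.0, 500.0, 0.0]
    }
    fn step(&mut self, model_output: &[f64], _timestep: f64, sample: &[f64]) -> Vec<f64> {
        sample.iter().zip(model_output).map(|(s, m)| s - 0.5 * m).collect()
    }
}

/// Keeps the previous model output, which is why step needs &mut self.
struct ToyMultistep {
    prev_output: Option<Vec<f64>>,
}

impl Scheduler for ToyMultistep {
    fn timesteps(&self) -> Vec<f64> {
        vec![0.5, 0.25]
    }
    fn step(&mut self, model_output: &[f64], timestep: f64, sample: &[f64]) -> Vec<f64> {
        // Blend with the previous model output when there is one.
        let blended: Vec<f64> = match &self.prev_output {
            Some(prev) => model_output.iter().zip(prev).map(|(m, p)| 0.5 * (m + p)).collect(),
            None => model_output.to_vec(),
        };
        self.prev_output = Some(model_output.to_vec());
        sample.iter().zip(&blended).map(|(s, m)| s - timestep * m).collect()
    }
}

/// ?Sized lets this accept both concrete schedulers and dyn Scheduler.
fn sample<T: Scheduler + ?Sized>(scheduler: &mut T, mut latents: Vec<f64>) -> Vec<f64> {
    for t in scheduler.timesteps() {
        let model_output = vec![1.0; latents.len()]; // stand-in for the UNet call
        latents = scheduler.step(&model_output, t, &latents);
    }
    latents
}

/// Runtime selection, e.g. from a CLI flag or a config file.
fn make_scheduler(name: &str) -> Option<Box<dyn Scheduler>> {
    match name {
        "discrete" => Some(Box::new(ToyDiscrete)),
        "multistep" => Some(Box::new(ToyMultistep { prev_output: None })),
        _ => None,
    }
}

fn main() {
    for name in ["discrete", "multistep"] {
        let mut scheduler = make_scheduler(name).expect("known scheduler");
        println!("{}: {:?}", name, sample(&mut *scheduler, vec![2.0]));
    }
}
```

The cost of this version is that every scheduler has to agree on one timestep type; the benefit is that a single evaluation loop can iterate over all schedulers by name.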

Alternatively (or in addition), we could use a scheduler enum:

```rust
enum SamplerScheduler {
    Dpmpp2m(dpmsolver_multistep::DPMSolverMultistepScheduler),
    Dpmpp2s(dpmsolver_singlestep::DPMSolverSinglestepScheduler),
    Ddim(ddim::DDIMScheduler),
    Ddpm(ddpm::DDPMScheduler),
    EulerDiscrete(euler_discrete::EulerDiscreteScheduler),
}
```
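
A sketch of how enum dispatch could look, assuming hypothetical toy schedulers (ToyDdim, ToyEuler) in place of the real DDIMScheduler and EulerDiscreteScheduler, and Vec<f64> in place of tch::Tensor. Each variant keeps its own timestep type (usize vs f64); the enum's methods match on the variant and look up the timestep from a step index, so callers only ever see one concrete type that can be picked at runtime.

```rust
// Sketch: ToyDdim/ToyEuler are hypothetical stand-ins with made-up update rules.
struct ToyDdim {
    timesteps: Vec<usize>,
}
struct ToyEuler {
    sigmas: Vec<f64>,
}

impl ToyDdim {
    fn step(&mut self, model_output: &[f64], _timestep: usize, sample: &[f64]) -> Vec<f64> {
        sample.iter().zip(model_output).map(|(s, m)| s - 0.5 * m).collect()
    }
}

impl ToyEuler {
    fn step(&mut self, model_output: &[f64], sigma: f64, sample: &[f64]) -> Vec<f64> {
        sample.iter().zip(model_output).map(|(s, m)| s - sigma * m).collect()
    }
}

/// One concrete type wrapping each scheduler; every method forwards via match.
enum SamplerScheduler {
    Ddim(ToyDdim),
    EulerDiscrete(ToyEuler),
}

impl SamplerScheduler {
    /// Chosen at runtime, e.g. from a CLI flag.
    fn from_name(name: &str) -> Option<Self> {
        match name {
            "ddim" => Some(Self::Ddim(ToyDdim { timesteps: vec![20, 10, 0] })),
            "euler" => Some(Self::EulerDiscrete(ToyEuler { sigmas: vec![1.0, 0.25] })),
            _ => None,
        }
    }

    fn num_steps(&self) -> usize {
        match self {
            Self::Ddim(s) => s.timesteps.len(),
            Self::EulerDiscrete(s) => s.sigmas.len(),
        }
    }

    /// Each variant looks up its own timestep type, so callers pass only an index.
    fn step_at(&mut self, i: usize, model_output: &[f64], sample: &[f64]) -> Vec<f64> {
        match self {
            Self::Ddim(s) => {
                let t = s.timesteps[i];
                s.step(model_output, t, sample)
            }
            Self::EulerDiscrete(s) => {
                let sigma = s.sigmas[i];
                s.step(model_output, sigma, sample)
            }
        }
    }
}

fn run(name: &str) -> Option<Vec<f64>> {
    let mut scheduler = SamplerScheduler::from_name(name)?;
    let mut sample = vec![4.0];
    for i in 0..scheduler.num_steps() {
        let model_output = vec![1.0]; // stand-in for the UNet prediction
        sample = scheduler.step_at(i, &model_output, &sample);
    }
    Some(sample)
}

fn main() {
    for name in ["ddim", "euler"] {
        println!("{}: {:?}", name, run(name));
    }
}
```

Compared with a trait, the enum is closed: each new scheduler needs a new variant plus an arm in every match, but no boxing is required and the timestep types never have to be unified.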

I'm not 100% sure which approach is best.

@mspronesti
Contributor

mspronesti commented Jan 9, 2023

One possible solution to the timestep type issue is to give the trait an associated type, say timestep_t, that is required to be Copy and Clone and convertible to a primitive type (e.g. via ToPrimitive from the num-traits crate):

```rust
// Inside `pub trait Scheduler { ... }`
type timestep_t: Copy + Clone + num_traits::ToPrimitive;
```

and then use it in the trait's method signatures as Self::timestep_t. When implementing the trait for a particular scheduler, one then only needs to set it appropriately:

```rust
impl Scheduler for MyScheduler {
    type timestep_t = usize;
    // ... remaining trait methods
}
```
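
A self-contained sketch of the associated-type idea, assuming a hypothetical std-only helper trait AsF64 in place of num_traits::ToPrimitive (so no external crates are needed), Vec<f64> in place of tch::Tensor, and toy schedulers ToyDdpm (usize timesteps) and ToyEuler (f64 timesteps). The associated type is spelled Timestep here, following Rust's CamelCase convention for types; it plays the role of timestep_t.

```rust
// Sketch: AsF64 is a hypothetical stand-in for num_traits::ToPrimitive;
// the update rules are toys, not real DDPM/Euler math.
trait AsF64 {
    fn as_f64(self) -> f64;
}
impl AsF64 for usize {
    fn as_f64(self) -> f64 {
        self as f64
    }
}
impl AsF64 for f64 {
    fn as_f64(self) -> f64 {
        self
    }
}

trait Scheduler {
    /// Plays the role of timestep_t.
    type Timestep: Copy + Clone + AsF64;
    fn timesteps(&self) -> Vec<Self::Timestep>;
    fn step(&mut self, model_output: &[f64], timestep: Self::Timestep, sample: &[f64]) -> Vec<f64>;
}

struct ToyDdpm;
impl Scheduler for ToyDdpm {
    type Timestep = usize;
    fn timesteps(&self) -> Vec<usize> {
        vec![3, 2, 1]
    }
    fn step(&mut self, model_output: &[f64], t: usize, sample: &[f64]) -> Vec<f64> {
        let w = t.as_f64() / 8.0; // toy weight derived from the integer timestep
        sample.iter().zip(model_output).map(|(s, m)| s - w * m).collect()
    }
}

struct ToyEuler;
impl Scheduler for ToyEuler {
    type Timestep = f64;
    fn timesteps(&self) -> Vec<f64> {
        vec![0.5, 0.25]
    }
    fn step(&mut self, model_output: &[f64], sigma: f64, sample: &[f64]) -> Vec<f64> {
        sample.iter().zip(model_output).map(|(s, m)| s - sigma * m).collect()
    }
}

/// One generic loop serves both: t has type S::Timestep throughout.
fn sample<S: Scheduler>(mut scheduler: S, mut latents: Vec<f64>) -> Vec<f64> {
    for t in scheduler.timesteps() {
        let model_output = vec![1.0; latents.len()]; // stand-in for the UNet call
        latents = scheduler.step(&model_output, t, &latents);
    }
    latents
}

/// Generic code can still get a number out of any timestep when it needs one.
fn timestep_sum<S: Scheduler>(scheduler: &S) -> f64 {
    scheduler.timesteps().into_iter().map(|t| t.as_f64()).sum()
}

fn main() {
    println!("{:?} {:?}", sample(ToyDdpm, vec![1.0]), sample(ToyEuler, vec![1.0]));
    println!("{} {}", timestep_sum(&ToyDdpm), timestep_sum(&ToyEuler));
}
```

One trade-off: a trait with an associated type can only be used as a trait object when that type is fixed (e.g. dyn Scheduler<Timestep = usize>), so schedulers with usize and f64 timesteps still cannot sit behind a single Box<dyn Scheduler>.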

I've already implemented all of this, but I'm holding off on opening a PR because I have two more pending, plus another I'll open soon that implements the Heun Discrete scheduler 😅

@mspronesti
Contributor

On second thought, I'm not so sure this is a good idea, at least not if we want to port other diffusion models. Some of the missing schedulers follow a different implementation logic: some don't implement add_noise, and others have two different kinds of steps (a predictor step and a corrector step, with different return types), e.g. the stochastic differential equation (SDE) scheduler.
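
Given that concern, one option is to split the interface by capability instead of forcing every scheduler into one trait. A minimal sketch under assumptions: the trait names Scheduler, AddNoise and PredictorCorrector, the PredOutput struct, the toy update rules, and the corrector-then-predictor loop order are all hypothetical, and Vec<f64> stands in for tch::Tensor. Pipelines then declare exactly the capabilities they need as trait bounds.

```rust
// Sketch: hypothetical capability traits; the update rules are toys, not real scheduler math.

/// What every scheduler can do.
trait Scheduler {
    fn timesteps(&self) -> Vec<f64>;
}

/// Only schedulers that support noising an existing sample implement this.
trait AddNoise {
    fn add_noise(&self, original: &[f64], noise: &[f64], timestep: f64) -> Vec<f64>;
}

/// Output of a predictor step: the noisy next sample plus its mean.
struct PredOutput {
    prev_sample: Vec<f64>,
    prev_sample_mean: Vec<f64>,
}

/// Predictor-corrector schedulers expose two steps with different return types.
trait PredictorCorrector {
    fn step_pred(&mut self, model_output: &[f64], timestep: f64, sample: &[f64]) -> PredOutput;
    fn step_correct(&mut self, model_output: &[f64], sample: &[f64]) -> Vec<f64>;
}

struct ToyEuler;
impl Scheduler for ToyEuler {
    fn timesteps(&self) -> Vec<f64> {
        vec![1.0, 0.5]
    }
}
impl AddNoise for ToyEuler {
    fn add_noise(&self, original: &[f64], noise: &[f64], timestep: f64) -> Vec<f64> {
        original.iter().zip(noise).map(|(x, n)| x + timestep * n).collect()
    }
}

struct ToySde {
    noise_scale: f64,
}
impl Scheduler for ToySde {
    fn timesteps(&self) -> Vec<f64> {
        vec![0.5]
    }
}
impl PredictorCorrector for ToySde {
    fn step_pred(&mut self, model_output: &[f64], t: f64, sample: &[f64]) -> PredOutput {
        let mean: Vec<f64> = sample.iter().zip(model_output).map(|(s, m)| s - t * m).collect();
        // Deterministic toy "noise" so the example is reproducible.
        let prev_sample = mean.iter().map(|x| x + self.noise_scale * t).collect();
        PredOutput { prev_sample, prev_sample_mean: mean }
    }
    fn step_correct(&mut self, model_output: &[f64], sample: &[f64]) -> Vec<f64> {
        sample.iter().zip(model_output).map(|(s, m)| s - 0.125 * m).collect()
    }
}

/// Only compiles for schedulers that implement AddNoise.
fn noised_start<S: Scheduler + AddNoise>(s: &S, image: &[f64], noise: &[f64]) -> Vec<f64> {
    let t0 = s.timesteps()[0];
    s.add_noise(image, noise, t0)
}

/// Only compiles for predictor-corrector schedulers; returns the final mean.
fn run_pc<S: Scheduler + PredictorCorrector>(s: &mut S, mut sample: Vec<f64>) -> Vec<f64> {
    let mut mean = sample.clone();
    for t in s.timesteps() {
        let model_output = vec![1.0; sample.len()]; // stand-in for the model
        sample = s.step_correct(&model_output, &sample);
        let out = s.step_pred(&model_output, t, &sample);
        sample = out.prev_sample;
        mean = out.prev_sample_mean;
    }
    mean
}

fn main() {
    println!("{:?}", noised_start(&ToyEuler, &[0.0, 1.0], &[1.0, 1.0]));
    println!("{:?}", run_pc(&mut ToySde { noise_scale: 1.0 }, vec![1.0]));
    // noised_start(&ToySde { noise_scale: 1.0 }, ...) would not compile: no AddNoise impl.
}
```

Code that must handle every scheduler uniformly would still need an enum or one code path per capability; the gain is that a missing capability, such as add_noise on an SDE scheduler, becomes a compile error rather than a stub method.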
