-
Notifications
You must be signed in to change notification settings - Fork 838
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support for DDIMScheduler in Diffusion Policy #146
Support for DDIMScheduler in Diffusion Policy #146
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi! Thanks so much for this PR. I have one request. A couple of things to discuss:
- It's not passing style requirements. There's a section in CONTRIBUTING.md "6. Follow our style." that you can follow to set up
ruff
and a pre-commit hook. - This is great. There's just one slight hiccup that makes it less accessible. You can't use it with pretrained models without patching the code. That's because the
make_policy
callsPytorchModelHubMixin.from_pretrained
therefore ignoring any command line overrides you add. For example I didpython lerobot/scripts/eval.py -p lerobot/diffusion_pusht eval.n_episodes=10 eval.batch_size=10 policy.num_inference_steps=10 +policy.noise_scheduler_type=DDIM
and my policy arguments were ignored. We can handle this in another PR. (btw just left this withhuggingface_hub
as it should help Don't override 'config' in model_kwargs huggingface_hub#2274) (btw I have a temporary fix in the works! Override pretrained model config #147). - Just FYI, it's possible we will remove the dependency on diffusers in which case we'll port the noise scheduler code into
modeling_diffusion.py
. We'll keep the DDIM feature in place though :)
@@ -110,6 +110,7 @@ class DiffusionConfig: | |||
diffusion_step_embed_dim: int = 128 | |||
use_film_scale_modulation: bool = True | |||
# Noise scheduler. | |||
noise_scheduler_type: str = 'DDPM' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you please add a dosctring argument for this (preserving ordering)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
Thanks @alexander-soare!
|
@kashyapakshay There are still style issues somehow. Curious, did you follow the instructions in CONTRIBUTING? If so, we may need to fix the instructions! cc @aliberts I noticed that the style issue triggers on a yaml config file. Does the pre-commit handle that? Since we are doing another round I have a nitty nit for you please: can you please move the noise scheduler factory so that the main policy class stays as the first definition in the file? |
@alexander-soare I did follow the instructions and seems like ruff passed? Anyway, fixed that style issue and also moved the factory func (I like that placement too but wasnt sure if you had were trying to ordering methods before classes) |
Since pre-commits hooks only check/affect staged files, if commits were pushed before installing the pre-commits then they can stay there since subsequent changes may not stage the same files. To solve this, run |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Approved!
What this does
This PR adds DDIMScheduler as a choice for Diffusion Policy. DDPM can be too slow for real-world rollouts (takes ~0.7s on a 3080 with 100 timesteps, n_obs=2 and 216x288 images). This aims to get closer to real-time control (as is also mentioned in the DP paper).
How it was tested
I ran the standard test suite + diffusion-pushT training for some steps (screenshot attached).
I have not done a full training run and do not have performance numbers for this.
How to checkout & try? (for the reviewer)
Add/update the
noise_scheduler_type
field in your diffusion config yaml.