Missing batch-execution compatibility in Dipole, Rotation matrix, etc. #154

Closed
cr-xu opened this issue May 14, 2024 · 4 comments · Fixed by #157
Labels: bug (Something isn't working)

Comments

cr-xu (Member) commented May 14, 2024

The rotation matrix is currently not batch-compatible. This causes problems when elements are rotated (tilt != 0):

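```python
import torch
from cheetah import ParticleBeam, Drift, Segment, Quadrupole

# Batch of two beams with different horizontal offsets mu_x
beam_in = ParticleBeam.from_parameters(num_particles=torch.tensor(1000), energy=torch.tensor([1e9, 1e9]), mu_x=torch.tensor([1e-5, 2e-5]))
batch_shape = beam_in.particles.shape[:-2]
# Tilted quadrupole (tilt = pi/4) followed by a drift, broadcast to the beam's batch shape
segment = Segment([Quadrupole(length=torch.tensor([0.5]), tilt=torch.tensor(torch.pi / 4)), Drift(length=torch.tensor([0.5]))]).broadcast(batch_shape)
segment(beam_in)  # breaks because the rotation matrix is not batch-compatible
```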
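A minimal sketch of what a batch-compatible rotation matrix could look like, assuming Cheetah-style 7D coordinates (x, px, y, py, tau, p, 1) and the sign convention written out in the code. `batched_rotation_matrix` is a hypothetical name, not Cheetah's actual function. The point is that the matrix indices are the last two dimensions and every entry is written with `[..., i, j]`, so a tilt tensor of any shape yields matrices of shape `(*tilt.shape, 7, 7)`:

```python
import torch


def batched_rotation_matrix(angle: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: x-y rotation matrices for tilt angles of any shape.

    Assumes 7D coordinates (x, px, y, py, tau, p, 1).
    Returns a tensor of shape (*angle.shape, 7, 7).
    """
    cs = torch.cos(angle)
    sn = torch.sin(angle)
    # Identity matrix copied once per batch element (repeat makes a real copy,
    # so the in-place writes below do not alias between batch elements)
    R = torch.eye(7, dtype=angle.dtype, device=angle.device).repeat(*angle.shape, 1, 1)
    # Fill the transverse 4x4 block; [..., i, j] works for any number of batch dims
    R[..., 0, 0] = cs
    R[..., 0, 2] = sn
    R[..., 1, 1] = cs
    R[..., 1, 3] = sn
    R[..., 2, 0] = -sn
    R[..., 2, 2] = cs
    R[..., 3, 1] = -sn
    R[..., 3, 3] = cs
    return R
```

For the example above, a scalar tilt could be broadcast to the beam's batch shape first, e.g. `batched_rotation_matrix(torch.tensor(torch.pi / 4).expand(batch_shape))`, which gives a tensor of shape `(2, 7, 7)`.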
cr-xu added the bug label May 14, 2024
cr-xu self-assigned this May 14, 2024
jank324 (Member) commented May 14, 2024

Hmm... you are only passing one tilt. Does it work if you pass a batch of tilts?

This raises the question of what we want the "correct" interface to look like.

cr-xu (Member, Author) commented May 14, 2024

> Hmm... you are only passing one tilt. Does it work if you pass a batch of tilts?
>
> This raises the question of what we want the "correct" interface to look like.

No, the problem is actually that the rotation matrix is not broadcast properly. Putting the tilts in the correct dimension doesn't help:

```python
import torch
from cheetah import ParticleBeam, Drift, Segment, Quadrupole

beam_in = ParticleBeam.from_parameters(num_particles=torch.tensor(1000), energy=torch.tensor([1e9, 1e9]), mu_x=torch.tensor([1e-5, 2e-5]))
batch_shape = beam_in.particles.shape[:-2]
# Element parameters, including the tilt, given explicitly with batch size 2
segment = Segment([Quadrupole(length=torch.tensor([0.5, 0.5]), tilt=torch.tensor([torch.pi / 4, torch.pi / 4])), Drift(length=torch.tensor([0.5, 0.5]))])
segment(beam_in)  # still breaks
```
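One way a rotation matrix can end up "not properly broadcast" is when its entries are stacked so that the matrix indices come before the batch dimensions. The sketch below illustrates this with a hypothetical 2x2 x-y rotation (`rotation_2d` is an illustrative name, and this is not a diagnosis of Cheetah's code). It then applies a tilt to a batch of transfer maps as R(-tilt) @ M @ R(tilt), which relies on `torch.matmul` broadcasting over all leading dimensions:

```python
import torch


def rotation_2d(angle: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: x-y rotation matrices of shape (*angle.shape, 2, 2)."""
    cs, sn = torch.cos(angle), torch.sin(angle)
    # Stack along the trailing dimensions so the batch dimensions stay in front
    row0 = torch.stack([cs, sn], dim=-1)
    row1 = torch.stack([-sn, cs], dim=-1)
    return torch.stack([row0, row1], dim=-2)


tilts = torch.tensor([0.0, torch.pi / 4, torch.pi / 2])  # batch of 3 tilts
cs, sn = torch.cos(tilts), torch.sin(tilts)

# Stacking with the default dim=0 puts the batch dimension last: shape (2, 2, 3)
wrong = torch.stack([torch.stack([cs, sn]), torch.stack([-sn, cs])])
# Stacking along the trailing dims keeps it first: shape (3, 2, 2)
right = rotation_2d(tilts)

# A batch of 3 hypothetical 2x2 maps acting on (x, y), shape (3, 2, 2)
M = torch.diag_embed(torch.tensor([[1.0, -1.0], [2.0, -2.0], [3.0, -3.0]]))

# Rotate into the element frame, apply M, rotate back. `@` broadcasts over the
# leading dimensions, so this works for any batch shape as long as the matrix
# indices are the last two dimensions.
M_tilted = rotation_2d(-tilts) @ M @ rotation_2d(tilts)
```

With a batch size of 3 the wrong layout shows up as a shape mismatch. With a batch size of 2, as in the reproduction above, `wrong` would have the same shape `(2, 2, 2)` as the correct result, so the mix-up could go unnoticed.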

cr-xu (Member, Author) commented May 14, 2024

Another thing: I found that quite a lot of the indexing is written as `[:, i, j]` instead of `[..., i, j]`, which would probably break down for multiple batch dimensions, right?

jank324 (Member) commented May 14, 2024

> Another thing: I found that quite a lot of the indexing is written as `[:, i, j]` instead of `[..., i, j]`, which would probably break down for multiple batch dimensions, right?

Yes, indeed. There might be situations where `[:, i, j]` makes sense, but most of the time it should be `[..., i, j]`. The initial vectorised implementation was not designed for multi-dimensional batches; I only added support for them afterwards, so it's possible I missed some places.
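A small sketch of the difference between `[:, i, j]` and `[..., i, j]` on a hypothetical batch of 7x7 transfer maps with two batch dimensions:

```python
import torch

# Hypothetical batch of 7x7 transfer maps with two batch dimensions, shape (4, 3, 7, 7)
tm = torch.zeros(4, 3, 7, 7)

# `...` stands for all leading batch dimensions, so this writes the (0, 1)
# entry of every one of the 4 * 3 matrices
tm[..., 0, 1] = torch.arange(12.0).reshape(4, 3)
entries = tm[..., 0, 1]  # shape (4, 3)

# `:` covers only the first dimension. With two batch dimensions, 0 then indexes
# the second batch dimension and 1 the matrix row, so this returns row 1 of the
# matrices tm[:, 0] instead of the (0, 1) entries: shape (4, 7)
wrong = tm[:, 0, 1]
```

With a single batch dimension, e.g. shape `(B, 7, 7)`, both forms select the same entries, which is why the difference only surfaces once multi-dimensional batches are used.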
