Skip to content

Commit 6faef7d

Browse files
committed
Test that breaks lengthless element fix
1 parent 2a27cd7 commit 6faef7d

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

tests/test_vectorized.py

+18
Original file line numberDiff line numberDiff line change
@@ -449,3 +449,21 @@ def test_cavity_with_zero_and_non_zero_voltage():
449449
).broadcast((3,))
450450

451451
_ = cavity.track(beam)
452+
453+
454+
def test_screen_length_shape():
455+
"""
456+
Test that the shape of a screen's length matches the shape of its misalignment.
457+
"""
458+
screen = cheetah.Screen(misalignment=torch.tensor([[0.1, 0.2], [0.3, 0.4]]))
459+
assert screen.length.shape == screen.misalignment.shape[:-1]
460+
461+
462+
def test_screen_length_broadcast_shape():
463+
"""
464+
Test that the shape of a screen's length matches the shape of its misalignment
465+
after broadcasting.
466+
"""
467+
screen = cheetah.Screen(misalignment=torch.tensor([[0.1, 0.2]]))
468+
broadcast_screen = screen.broadcast((3, 10))
469+
assert broadcast_screen.length.shape == broadcast_screen.misalignment.shape[:-1]

0 commit comments

Comments
 (0)