
Fix slowdown of Cheetah 0.7.1 compared to 0.6.3 #367

Merged · 21 commits · Mar 12, 2025

Conversation

@jank324 (Member) commented Mar 7, 2025

Description

Introduces a couple of changes that improve the speed of Cheetah by about 2x (measured on the ARES RL example):

  • Removes unneeded memory allocation in Screen.
  • More efficient transfer map reduction logic in Segment.track that avoids the use of Segment.is_skippable, which is more expensive than expected.
  • Makes more use of register_buffer and register_parameter in the __init__s of beams and elements, because these are significantly more efficient than property assignments on nn.Modules, which are actually very slow (see the sketch after this list).
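
As a rough illustration of the last point, here is a minimal sketch contrasting plain attribute assignment on an nn.Module with register_buffer. The ToyElement class and its length attribute are made up for this example and are not part of Cheetah:

import torch
from torch import nn


class ToyElement(nn.Module):
    """Hypothetical element, used only to illustrate the two registration styles."""

    def __init__(self, length: torch.Tensor) -> None:
        super().__init__()

        # Plain attribute assignment goes through nn.Module.__setattr__, which
        # performs several dictionary checks on every call and is therefore
        # comparatively slow when repeated for many attributes:
        #     self.length = torch.as_tensor(length)
        #
        # register_buffer puts the tensor straight into the module's buffer
        # registry (still part of the state dict and of .to() device moves),
        # which is the faster path this PR switches to where possible:
        self.register_buffer("length", torch.as_tensor(length))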

Motivation and Context

  • I have raised an issue to propose this change (required for new features and bug fixes)

Running the ARES RL code again with a prototype of Cheetah 0.7.1, I found that the samples per second were reduced by a factor of about 3.5x compared to 0.6.1 (which this code originally ran with), and also compared to 0.6.3 (the last non-vectorised version).

This matches observations from the vectorisation PR #116, where at one point I wrote:

Looking at the optimize_speed.ipynb example Notebook, it seems like this change currently slows down Cheetah between 2x and 4x for single samples. On the other hand, if you run with 1,000 samples, we are looking at a 40x to 55x speed improvement on an M1 Pro CPU. This would likely be more with more samples and/or on GPU.

The goal of Cheetah is also to be fast, especially for the purpose of RL, so we should check if this can be fixed.

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)

Checklist

  • I have updated the changelog accordingly (required).
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.
  • I have reformatted the code and checked that formatting passes (required).
  • I have fixed all issues found by flake8 (required).
  • I have ensured that all pytest tests pass (required).
  • I have run pytest on a machine with a CUDA GPU and made sure all tests pass (required).
  • I have checked that the documentation builds (required).

Note: We are using a maximum length of 88 characters per line.

@jank324 jank324 added the bug and enhancement labels Mar 7, 2025
@jank324 jank324 self-assigned this Mar 7, 2025
@jank324 (Member, Author) commented Mar 7, 2025

Some speed benchmarks:

Env benchmark

%%timeit

observation, info = env.reset()
done = False
while not done:
    observation, reward, terminated, truncated, info = env.step(env.action_space.sample())
    done = terminated or truncated
  • 0.6.1: 65.4 ms ± 225 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
  • 0.7.1: 166 ms ± 4.38 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Raw Cheetah benchmark

segment.AREABSCR1.is_active = True
%%timeit
outgoing = segment.track(incoming)
  • 0.6.1: 841 μs ± 822 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
  • 0.7.1: 2.25 ms ± 21.9 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)

@jank324 (Member, Author) commented Mar 7, 2025

It appears there are two main causes: Segment.track and Screen.track. In Screen.track, the slow part is a torch.full inside of Screen.set_read_beam, which creates a dummy screen image every time a new beam is read and the old screen image becomes invalid.

Screenshot 2025-03-07 at 09 47 04

I just ran some tests on how this compares, for example, against setting Screen.cached_reading to None. Clearly this is the cause.

Screenshot 2025-03-07 at 09 49 05

The last commit therefore sets Screen.cached_reading to None. The resulting speeds are:

  • Env benchmark: 106 ms ± 18.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
  • Raw Cheetah benchmark: 1.11 ms ± 7.32 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
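
For illustration of this change, here is a minimal sketch of the caching pattern the commit moves to; apart from cached_reading and set_read_beam, the attribute and method names are simplified stand-ins rather than Cheetah's actual API:

import torch


class SimplifiedScreen:
    """Simplified stand-in for Screen, showing only the cache invalidation logic."""

    def __init__(self, resolution: tuple[int, int] = (2448, 2040)) -> None:
        self.resolution = resolution
        self.cached_reading = None
        self._read_beam = None

    def set_read_beam(self, beam) -> None:
        self._read_beam = beam
        # Old variant: allocate a full-size dummy image on every tracked beam
        # just to mark the cached reading as stale, roughly
        #     self.cached_reading = torch.full(self.resolution, float("nan"))
        # New variant: invalidate the cache without any allocation.
        self.cached_reading = None

    @property
    def reading(self) -> torch.Tensor:
        # Recompute the screen image lazily, only when the cache is invalid.
        if self.cached_reading is None:
            self.cached_reading = self._compute_reading()
        return self.cached_reading

    def _compute_reading(self) -> torch.Tensor:
        # Placeholder for the actual histogramming of the read beam.
        return torch.zeros(self.resolution)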

@jank324 (Member, Author) commented Mar 7, 2025

It seems that compute_relativistic_factors and base_rmatrix are now the slowest parts.

Screenshot 2025-03-07 at 10 09 23

@jank324 jank324 changed the title Fix slowdown of Cheetah 0.7.0 compared to 0.6.3 Fix slowdown of Cheetah 0.7.1 compared to 0.6.3 Mar 7, 2025
@jank324 (Member, Author) commented Mar 7, 2025

I just made a modification to the __init__ of both beam classes that, based on the Before profiling results, I thought might improve the speed. According to the timeit benchmarks it didn't make any difference at all, but the profiling would suggest that it made quite a major difference ... I'm confused 🤔

It would also have to be checked with @Hespe why the original order of buffer registration was introduced. Also, if we choose to keep the modification, it should be made in all Element subclasses as well.

  • Env benchmark: 96.9 ms ± 3.21 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
  • Raw Cheetah benchmark: 1.12 ms ± 2.46 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Before
Screenshot 2025-03-07 at 19 23 26

After
Screenshot 2025-03-07 at 19 23 54

@jank324 (Member, Author) commented Mar 7, 2025

Another realisation: the main cost in Segment.track seems to come from base_rmatrix, where in turn the dominant expense is broadcasting. This is presumably just a price we have to pay for vectorising Cheetah.

Screenshot 2025-03-07 at 19 35 17

For reference, here are a few experiments I did on the cost of running broadcasting in different ways by itself.

Screenshot 2025-03-07 at 19 36 34
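
The screenshots are not reproduced here, but the experiments were along these lines (shapes and repetition counts are illustrative, and the exact variants tried may differ):

import timeit

import torch

a = torch.randn(3, 1)
b = torch.randn(1, 5)

# Explicit broadcasting from Python, as done e.g. in base_rmatrix.
t_tensors = timeit.timeit(lambda: torch.broadcast_tensors(a, b), number=100_000)

# Broadcasting only the shapes, without touching the tensor data.
t_shapes = timeit.timeit(
    lambda: torch.broadcast_shapes(a.shape, b.shape), number=100_000
)

print(f"broadcast_tensors: {t_tensors / 100_000 * 1e6:.2f} µs per call")
print(f"broadcast_shapes:  {t_shapes / 100_000 * 1e6:.2f} µs per call")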

@Hespe (Member) commented Mar 10, 2025

Would be interesting to see how much we pay for broadcasting if all tensors are scalars.

@jank324 (Member, Author) commented Mar 10, 2025

I just ran the benchmarks again (with the parameter registration change):

  • Env benchmark: 91.8 ms ± 630 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
  • Raw Cheetah benchmark: 950 μs ± 10.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
Screenshot 2025-03-10 at 10 23 46 Screenshot 2025-03-10 at 10 24 02

@jank324 (Member, Author) commented Mar 10, 2025

And here another run with the parameter registration change undone:

  • Env benchmark: 109 ms ± 833 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
  • Raw Cheetah benchmark: 1.28 ms ± 22 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
Screenshot 2025-03-10 at 10 22 04 Screenshot 2025-03-10 at 10 22 24

@Hespe (Member) commented Mar 10, 2025

Ok, so it's a roughly 20% increase that we have to pay in these two cases if we want to be able to assign nn.Parameters at initialization. Not really that great. I assume modifying the beam to pass it along instead of allocating a new object is not an option?

@jank324 (Member, Author) commented Mar 10, 2025

Ok, so it's a roughly 20% increase that we have to pay in these two cases if we want to be able to assign nn.Parameters at initialization. Not really that great. I assume modifying the beam to pass it along instead of allocating a new object is not an option?

Hmm ... so from a user experience point of view I never liked that idea. It invites programming errors if your incoming beam is different after tracking. On the other hand, allocating new memory does introduce a non-negligible overhead.

I'm thinking right now about introducing an inplace argument to track, which is False by default, but which elements might internally set to True so that they reuse the beam object; the user never sees this (unless the user wants to, but then it's their own fault).

@jank324 (Member, Author) commented Mar 10, 2025

I also just added the profiling outputs for both cases.

What still confuses me there is that the profiler output shows more than a 2x difference with the separated buffer registration.

@Hespe (Member) commented Mar 10, 2025

I'm thinking right now about introducing an inplace argument to track, which is False by default, but which elements might internally set to True so that they reuse the beam object; the user never sees this (unless the user wants to, but then it's their own fault).

With the idea being that Segment does so, except for the very first element? Such that the user does not have the incoming beam modified, but internally we do not have to allocate new tensors for each element? I guess that would not even be noticeable from the outside if all elements properly implement that behaviour.

@jank324 (Member, Author) commented Mar 10, 2025

One more note: in the profiler output it seems that the more recent runs are about 10x to 20x slower ... this is simply because I was using a tenth of the sampling interval at one point.

@jank324 (Member, Author) commented Mar 10, 2025

I'm thinking right now about introducing an inplace argument to track, which is False by default, but which elements might internally set to True so that they reuse the beam object; the user never sees this (unless the user wants to, but then it's their own fault).

With the idea being that Segment does so, except for the very first element? Such that the user does not have the incoming beam modified, but internally we do not have to allocate new tensors for each element? I guess that would not even be noticeable from the outside if all elements properly implement that behaviour.

It could probably even be implemented in the Element base class. I guess the general template would be:

def track(self, incoming: Beam, inplace: bool = True) -> Beam:
    outgoing = incoming if inplace else incoming.clone()

    outgoing.particles = ...  # do the actual computations here
    return outgoing

On the other hand, ParameterBeam.__setattr__ appears to be the truly expensive part ... and that would remain.
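
To make the idea concrete, here is a hypothetical sketch (not what this PR implements) of how Segment could use such a flag so that only the segment boundary pays for a clone, while the elements reuse the same beam object internally:

def track(self, incoming: Beam, inplace: bool = False) -> Beam:
    # Clone once at the segment boundary so the caller's beam stays untouched,
    # then let every element work on that same beam object in place.
    outgoing = incoming if inplace else incoming.clone()
    for element in self.elements:
        outgoing = element.track(outgoing, inplace=True)
    return outgoing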

@Hespe (Member) commented Mar 10, 2025

It could probably even be implemented in the Element base class. I guess the general template would be:

def track(self, incoming: Beam, inplace: bool = True) -> Beam:
    outgoing = incoming if inplace else incoming.clone()

    outgoing.particles = ...  # do the actual computations here
    return outgoing

On the other hand, ParameterBeam.__setattr__ appears to be the truly expensive part ... and that would remain.

Oh yes, you are right. In-place tensor operations, on the other hand, would be troublesome here, right? They would be incompatible with automatic broadcasting within the elements, and could lead to problems with differentiation?

@jank324 (Member, Author) commented Mar 10, 2025

I didn't necessarily mean in-place tensor operations actually. What I meant was more along these lines ... the first is what Element currently does, the second would be what it could do in an inplace case.

Screenshot 2025-03-10 at 12 01 37

Btw ... I discovered a bug, where the species is not passed in Element.track.

@jank324 (Member, Author) commented Mar 10, 2025

According to the benchmark I just posted, the inplace operation would be faster. We should consider this for the future, but for the specific ARES RL example that I need to speed up right now, the code actually combines all transfer maps, and then does only one multiplication. Even in the suggested case there would have to be at least one new beam copy created, and here it only creates that one. So this optimisation would not give an advantage in this case. Similarly the copy for Screen.read_beam slows things down, but that one is unavoidable as well.
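
For context, the transfer map combination mentioned here boils down to multiplying the first-order maps of consecutive elements and applying the product to the beam once. A minimal sketch follows; the exact transfer_map signature (e.g. the species argument mentioned above) and the 7x7 convention are simplifications, not the exact Cheetah API:

import torch


def combined_transfer_map(elements, energy: torch.Tensor) -> torch.Tensor:
    """Multiply the transfer matrices of consecutive elements into a single map."""
    tm = torch.eye(7, dtype=energy.dtype, device=energy.device)
    for element in elements:
        # Later elements act after earlier ones, hence left-multiplication.
        tm = element.transfer_map(energy) @ tm
    return tm


# The combined map is then applied to the particle coordinates in a single
# matrix multiplication instead of once per element, roughly:
#     outgoing_particles = incoming_particles @ tm.transpose(-2, -1)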

The only other source of slowdown as far as I can tell is the broadcasting in base_rmatrix ... we need that too. Regarding this, I'm not sure @Hespe what you meant with

Would be interesting to see how much we pay for broadcasting if all tensors are scalars.

The example I'm running uses a vectorised version of Cheetah (i.e. >=0.7.0), but none of the inputs have a vector dimension.

@Hespe (Member) commented Mar 10, 2025

The only other source of slowdown as far as I can tell is the broadcasting in base_rmatrix ... we need that too. Regarding this, I'm not sure @Hespe what you meant with

Would be interesting to see how much we pay for broadcasting if all tensors are scalars.

The example I'm running uses a vectorised version of Cheetah (i.e. >=0.7.0), but none of the inputs have a vector dimension.

I was referring to your micro benchmark above, which just does some broadcasting of tensors a and b of different shapes. If they have different shapes, the cost of broadcasting is somehow unavoidable (not sure actually, operations that internally broadcast without an explicit call to broadcast_tensors are likely very efficient). If broadcast_tensors still takes 2 µs for scalar tensors, that could add quite a bit of time to some scalar simulations.

@jank324 (Member, Author) commented Mar 10, 2025

Here is the same benchmark for scalar tensors. Other broadcasting operations must do the same thing.

Screenshot 2025-03-10 at 13 07 34

For the fun of it, I just ran a broadcasting math operation. This is clearly faster. I'm not sure why, but my best guess is that the broadcasting here is done in compiled code rather than in interpreted Python.

Screenshot 2025-03-10 at 13 10 46 Screenshot 2025-03-10 at 13 11 09

Maybe this means we should look through the uses of manual broadcasting operations and see if they can be replaced with other operations. My assumption is that they cannot. Specifically the one in base_rmatrix would have a significant impact on the speed of Cheetah.
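
Roughly what this comparison looks like (timings vary by machine; the actual numbers are in the screenshots above): an explicit torch.broadcast_tensors call on scalar tensors versus a math operation that handles any broadcasting inside the dispatched kernel.

import timeit

import torch

a = torch.tensor(1.0)
b = torch.tensor(2.0)

# Explicit broadcasting call, paying Python-level overhead on every call.
t_explicit = timeit.timeit(lambda: torch.broadcast_tensors(a, b), number=100_000)

# A math operation that resolves broadcasting internally during dispatch.
t_implicit = timeit.timeit(lambda: a * b, number=100_000)

print(f"torch.broadcast_tensors: {t_explicit / 100_000 * 1e6:.2f} µs per call")
print(f"a * b:                   {t_implicit / 100_000 * 1e6:.2f} µs per call")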

@jank324 (Member, Author) commented Mar 10, 2025

This micro test is probably also worth being aware of, especially because it's the exact opposite of what I expected.

Screenshot 2025-03-10 at 15 16 46

@jank324 (Member, Author) commented Mar 11, 2025

@Hespe and I just went through everything and optimised every last bit on the ARES example that we think can be optimised.

The current numbers are:

  • Env benchmark: 95.8 ms ± 1.62 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
  • Raw Cheetah benchmark: 1.06 ms ± 11.2 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

This is around a 2x improvement over the times we had at the beginning of this PR. The rest of the slowdown is simply needed for vectorised Cheetah, and is, by the way, also paid for as soon as you run at least two samples.

@jank324 (Member, Author) commented Mar 11, 2025

The question came up whether is_skippable should become a function to denote that computing it is expensive ... a lot more expensive than we thought.

To be discussed with @Hespe.

@jank324 (Member, Author) commented Mar 11, 2025

I just reran everything with the most recent commit, to make sure that the changes didn't negatively affect anything.

The current numbers are:

  • Env benchmark: 85.8 ms ± 1.59 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
  • Raw Cheetah benchmark: 926 μs ± 13.3 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

These were all run on an Apple M1 Pro. We also ran a test earlier on a Dell Windows laptop, and saw that compute times were about 1.7x longer.

Funnily enough, the results of optimize_speed.ipynb appear to be about the same as before this PR. 🤔

@jank324 jank324 marked this pull request as ready for review March 11, 2025 17:53
@jank324 jank324 requested review from Hespe and cr-xu March 11, 2025 17:53
@jank324 (Member, Author) commented Mar 11, 2025

For reference: The RL training went from 1,450 fps with v0.6.1 to 1,343 fps with v0.7.1 ... so less than 8% worse, which I'm pretty happy with. It was around 400 fps when this PR was opened.

@cr-xu (Member) commented Mar 12, 2025

@jank324 @Hespe On a side note, it would be helpful to gather the insights & experience here into a written document (e.g. to clone or not, when to broadcast, what operations can be substituted by more efficient ones...) so that in the future new contributors can also implement the elements in a consistent & fast-executing way.

@cr-xu cr-xu merged commit ffedcb1 into master Mar 12, 2025
8 checks passed