
[Performance] Use optimized kernels for MQA/GQA #1880

Open
WoosukKwon opened this issue Dec 1, 2023 · 33 comments
Labels: help wanted, performance, stale

Comments

@WoosukKwon
Collaborator

WoosukKwon commented Dec 1, 2023

In theory, MQA/GQA can reduce the memory bandwidth needed to read the KV cache and enable the use of Tensor Cores for the dot products in the attention mechanism. However, this benefit can only be realized with optimized kernels that vLLM does not have at the moment.

  1. For prefill, vLLM explicitly expands the incoming keys and values before running the attention op (see the sketch after this list):

```python
key = key[:, :, None, :].expand(key.shape[0], self.num_kv_heads,
                                self.num_queries_per_kv, key.shape[-1])
value = value[:, :, None, :].expand(value.shape[0], self.num_kv_heads,
                                    self.num_queries_per_kv, value.shape[-1])
```

     This is necessary because neither xformers nor PyTorch SDPA supports MQA/GQA at the moment. It is bad for performance since 1) it causes the extra overhead of expanding the tensors, and 2) the attention kernel cannot leverage the advantage described above. While FlashAttention supports MQA/GQA efficiently, we need to use it carefully since it does not cover all of the GPUs/data types/head sizes that xformers supports.
  2. For decode, vLLM's current paged attention kernel also does not leverage the benefits of MQA/GQA. To realize the benefit, we need to either significantly rewrite the paged attention kernel or modify the FlashAttention kernel to support paged KV cache.
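A minimal sketch of the prefill path without the explicit expansion, assuming flash-attn >= 2.0 on an Ampere-or-newer GPU (the shapes are illustrative, not vLLM's actual code): FlashAttention broadcasts each KV head across its group of query heads inside the kernel, so the `.expand()` above becomes unnecessary.

```python
import torch
from flash_attn import flash_attn_func

batch, seqlen, num_heads, num_kv_heads, head_dim = 2, 1024, 32, 4, 128
q = torch.randn(batch, seqlen, num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(batch, seqlen, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
v = torch.randn(batch, seqlen, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")

# No expand of k/v: the kernel maps each KV head to num_heads // num_kv_heads query
# heads internally, avoiding the copy and the extra memory traffic.
out = flash_attn_func(q, k, v, causal=True)  # (batch, seqlen, num_heads, head_dim)
```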
@WoosukKwon added the help wanted and performance labels on Dec 1, 2023
@casper-hansen
Contributor

casper-hansen commented Dec 2, 2023

This is a feature that is high on my list for performance reasons. I have scavenged other people's benchmarks and found an interesting one from TensorRT-LLM that also uses PagedAttention:

Llama 2 7B with 1x A100 80GB:

  • num requests: 128
  • sequence length: 1024
  • input length: 32
  • output length: 992

Latency benchmark:

  • MHA: 133.40s
  • GQA: 44.74s
  • MQA 29.17s

Conclusion: Because MHA keeps a much larger KV cache, latency increases since reading it consumes more memory bandwidth. This is especially visible in scenarios where we aim for high throughput, i.e. when handling many requests at once. This means we can see a significant improvement in throughput if we spend less memory on the KV cache thanks to the head grouping in GQA.

Reference: NVIDIA/TensorRT-LLM#404
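As a rough back-of-envelope check of why the KV-cache footprint drives this (illustrative numbers only, using a Llama-2-7B-like config of 32 layers and head size 128 in fp16; the MHA/GQA/MQA head counts below are assumptions made for the comparison):

```python
def kv_cache_bytes(num_kv_heads, seq_len=1024, num_layers=32, head_dim=128, dtype_bytes=2):
    # 2x for K and V, per token, per layer, per KV head
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * dtype_bytes

for name, kv_heads in [("MHA", 32), ("GQA (8 KV heads)", 8), ("MQA", 1)]:
    gib = kv_cache_bytes(kv_heads) * 128 / 1024**3  # 128 concurrent requests
    print(f"{name}: ~{gib:.0f} GiB of KV cache for 128 sequences of length 1024")
# MHA ~64 GiB, GQA ~16 GiB, MQA ~2 GiB: far less cache to stream on every decode step.
```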

@whitelok

whitelok commented Dec 5, 2023

It seems possible to migrate https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention.h to fulfill vLLM's requirements.

@WoosukKwon
Collaborator Author

@whitelok Thanks for sharing. Unfortunately, it seems the core part of the kernel is provided as cubin files, which are not human-readable.

@whitelok

whitelok commented Dec 5, 2023

@WoosukKwon
Collaborator Author

@whitelok Oh, not really. While I'm not sure at the moment, it seems there is some code that we can leverage. I will look into it soon. Thanks again for sharing!

@whitelok

whitelok commented Dec 5, 2023

It seems the entry point is here: https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.5.0/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h#L1015. All vLLM needs to do is modify the KV cache buffer.

@beginlner
Contributor

FYI, I have modified the FlashAttention kernel to support paged KV cache with the restriction that block_size must match kBlockN in the kernel.

@casper-hansen
Contributor

@zhaoyang-star
Contributor

> I have modified the FlashAttention kernel to support paged KV cache with the restriction that block_size must match kBlockN in the kernel.

Great work! As this is very important for performance, do you plan to submit the feature? I would be glad to test it.

@casper-hansen
Contributor

@zhaoyang-star See the following:

Dao-AILab/flash-attention#678
#1679

@beginlner
Contributor

> I have modified the FlashAttention kernel to support paged KV cache with the restriction that block_size must match kBlockN in the kernel.

> Great work! As this is very important for performance, do you plan to submit the feature? I would be glad to test it.

I've tidied up the code a bit; you can test using the following two branches:
https://github.com/beginlner/flash-attention/tree/blocked_kvcache
and
https://github.com/beginlner/vllm/tree/blocked_flash_attn

@zhaoyang-star
Contributor

> I have modified the FlashAttention kernel to support paged KV cache with the restriction that block_size must match kBlockN in the kernel.

> Great work! As this is very important for performance, do you plan to submit the feature? I would be glad to test it.

> I've tidied up the code a bit; you can test using the following two branches: https://github.com/beginlner/flash-attention/tree/blocked_kvcache and https://github.com/beginlner/vllm/tree/blocked_flash_attn

Thanks for the info. I will share the latency comparison after benchmarking.

@zhaoyang-star
Contributor

@beginlner I see you wrote a new function, flash_attn_with_blocked_kvcache, based on FlashAttention. I just want to know how much work it would take to directly rewrite the paged attention kernel in vLLM instead.

@beginlner
Contributor

> @beginlner I see you wrote a new function, flash_attn_with_blocked_kvcache, based on FlashAttention. I just want to know how much work it would take to directly rewrite the paged attention kernel in vLLM instead.

I think it will be a little complicated.

@casper-hansen
Contributor

casper-hansen commented Jan 9, 2024

> 1. For prefill, I replaced xformers with FlashAttention and found that the latency of the attention calculation (softmax(Q @ K^T * softmax_scale) @ V) is reduced by more than 2x. I could submit a PR if necessary.
>    • Using the CodeLLaMA-34B config (num_query_heads=64, num_key_value_heads=8, head_size=128)
>    • Tested on A100-40GB
>
> | Test id | Batch size | Prompt length | Original xformers (us) | FA (us) | Speedup (Original / FA) |
> |---|---|---|---|---|---|
> | 1 | 1 | 1024 | 6.875 | 1.625 | 4.2 |
> | 2 | 10 | 1024 | 36.965 | 11.518 | 3.2 |
> | 3 | 100 | 1024 | 364.861 | 126.109 | 2.9 |
>
> 2. For decoding, would a lot of code need to change to support MQA/GQA in the paged attention kernel?

These are great results. I hope that decoding can get a speedup as well, as that is likely to yield a substantial improvement. @zhaoyang-star make sure to fall back to xformers attention since it supports older GPUs as well; FlashAttention supports Ampere and newer.

@beginlner Do you know what it takes for Tri Dao to accept flash_attn_with_blocked_kvcache into flash attention?
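A minimal sketch of the capability check behind that fallback (the function name is illustrative, not vLLM's actual dispatch code; the constraints shown are the commonly cited FlashAttention 2 requirements):

```python
import torch

def can_use_flash_attn(dtype: torch.dtype, head_size: int) -> bool:
    # FlashAttention 2 targets Ampere or newer, fp16/bf16, and head sizes up to 256
    # that are a multiple of 8; otherwise fall back to the xformers path.
    major, _ = torch.cuda.get_device_capability()
    return (
        major >= 8
        and dtype in (torch.float16, torch.bfloat16)
        and head_size % 8 == 0
        and head_size <= 256
    )
```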

@zhaoyang-star
Contributor

zhaoyang-star commented Jan 11, 2024

> I have modified the FlashAttention kernel to support paged KV cache with the restriction that block_size must match kBlockN in the kernel.

> Great work! As this is very important for performance, do you plan to submit the feature? I would be glad to test it.

> I've tidied up the code a bit; you can test using the following two branches: https://github.com/beginlner/flash-attention/tree/blocked_kvcache and https://github.com/beginlner/vllm/tree/blocked_flash_attn

@beginlner It seems the blocked flash attn unit test failed (see the attached screenshot).

@beginlner
Contributor

> I have modified the FlashAttention kernel to support paged KV cache with the restriction that block_size must match kBlockN in the kernel.

> Great work! As this is very important for performance, do you plan to submit the feature? I would be glad to test it.

> I've tidied up the code a bit; you can test using the following two branches: https://github.com/beginlner/flash-attention/tree/blocked_kvcache and https://github.com/beginlner/vllm/tree/blocked_flash_attn

> @beginlner It seems the blocked flash attn unit test failed (see the attached screenshot).

@zhaoyang-star Thanks for the reminder, the error tolerance should be relaxed.

@zhaoyang-star
Contributor

> @zhaoyang-star Thanks for the reminder, the error tolerance should be relaxed.

@beginlner The greatest relative difference is 9.48. Isn't this a little larger than expected? FYI, the original test_flash_attn.py also failed. Please have a look.

@beginlner
Contributor

> @zhaoyang-star Thanks for the reminder, the error tolerance should be relaxed.

> @beginlner The greatest relative difference is 9.48. Isn't this a little larger than expected? FYI, the original test_flash_attn.py also failed. Please have a look.

Based on my experience, a greatest absolute difference of 0.008 and a greatest relative difference of 9.48 are acceptable. I have pushed a more reliable test to the branch.

Additionally, how did the original test_flash_attn.py fail?

@zhaoyang-star
Contributor

@beginlner Sorry, I did not store the log. I suggest you run the test yourself.

@beginlner
Contributor

beginlner commented Jan 12, 2024

> @beginlner Sorry, I did not store the log. I suggest you run the test yourself.

It only failed on some tests, with OOM on an A100 40G, because someone else was also using the GPU.

@zhaoyang-star
Contributor

> @beginlner Sorry, I did not store the log. I suggest you run the test yourself.

> It only failed on some tests, with OOM on an A100 40G, because someone else was also using the GPU.

Good news.
BTW, have you benchmarked the latency or throughput using flash_attn_with_blocked_kvcache in vLLM? I am very interested in the perf data.

@beginlner
Contributor

> @beginlner Sorry, I did not store the log. I suggest you run the test yourself.

> It only failed on some tests, with OOM on an A100 40G, because someone else was also using the GPU.

> Good news. BTW, have you benchmarked the latency or throughput using flash_attn_with_blocked_kvcache in vLLM? I am very interested in the perf data.

The flash_attn_with_blocked_kvcache kernel is expected to provide a speedup of about n_heads // n_heads_kv times compared to the current paged attention kernel, based on my tests. Would you like to conduct a comprehensive end-to-end latency benchmark?
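A rough intuition for where the n_heads // n_heads_kv figure comes from (illustrative arithmetic only; it assumes decode attention is dominated by streaming the KV cache and that the existing kernel re-reads the shared KV head once per query head):

```python
num_heads, num_kv_heads = 64, 8              # e.g. a CodeLLaMA-34B-like GQA config
kv_reads_per_token_naive = num_heads         # KV streamed once per query head
kv_reads_per_token_gqa_aware = num_kv_heads  # KV streamed once per KV head, shared by its group
print(kv_reads_per_token_naive / kv_reads_per_token_gqa_aware)  # -> 8.0, i.e. n_heads // n_heads_kv
```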

@zhaoyang-star
Contributor

zhaoyang-star commented Jan 22, 2024

@beginlner The unit test has passed. From the kernel benchmark and the e2e benchmark we can see there is no speedup compared with the paged attention v2 kernel. Did you get similar results?
I used StarCoder, which is an MQA model.

[use_fa]root@50c663527862:/bigdata/zhaoyang/gerrit/vllm# python benchmarks/kernels/benchmark_paged_attention.py --version v1 --batch-size 1 --num-kv-heads 1
Namespace(batch_size=1, block_size=16, context_len=4096, dtype='half', head_size=128, num_kv_heads=1, num_query_heads=64, profile=False, seed=0, use_alibi=False, use_fp8_kv_cache=False, version='v1')
Warming up...
Kernel running time: 100.380 us
[use_fa]root@50c663527862:/bigdata/zhaoyang/gerrit/vllm# python benchmarks/kernels/benchmark_paged_attention.py --version v2 --batch-size 1 --num-kv-heads 1
Namespace(batch_size=1, block_size=16, context_len=4096, dtype='half', head_size=128, num_kv_heads=1, num_query_heads=64, profile=False, seed=0, use_alibi=False, use_fp8_kv_cache=False, version='v2')
Warming up...
Kernel running time: 52.862 us
[use_fa]root@50c663527862:/bigdata/zhaoyang/gerrit/vllm# python benchmarks/kernels/benchmark_paged_attention.py --version flash-attn --batch-size 1 --num-kv-heads 1
Namespace(batch_size=1, block_size=16, context_len=4096, dtype='half', head_size=128, num_kv_heads=1, num_query_heads=64, profile=False, seed=0, use_alibi=False, use_fp8_kv_cache=False, version='flash-attn')
Warming up...
Kernel running time: 70.118 us

[use_fa]root@50c663527862:/bigdata/zhaoyang/gerrit/vllm# python3 benchmarks/benchmark_latency.py --model /bigdata/shared/models/huggingface/starcoder/ --input-len 512 --output-len 512 --num-iters 10 --batch-size 1 --enforce-eager --use-flash-attn
Namespace(batch_size=1, dtype='auto', enforce_eager=True, input_len=512, kv_cache_dtype=None, model='/bigdata/shared/models/huggingface/starcoder/', n=1, num_iters=10, output_len=512, profile=False, profile_result_dir=None, quantization=None, tensor_parallel_size=1, tokenizer=None, trust_remote_code=False, use_beam_search=False, use_flash_attn=True)
INFO 01-22 22:20:30 llm_engine.py:71] Initializing an LLM engine with config: model='/bigdata/shared/models/huggingface/starcoder/', tokenizer='/bigdata/shared/models/huggingface/starcoder/', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=8192, download_dir=None, load_format=auto, tensor_parallel_size=1, quantization=None, enforce_eager=True, use_flash_attn=True, kv_cache_dtype=None, seed=0)
INFO 01-22 22:21:30 llm_engine.py:283] # GPU blocks: 1775, # CPU blocks: 1638
SamplingParams(n=1, best_of=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=1.0, top_p=1.0, top_k=-1, min_p=0.0, use_beam_search=False, length_penalty=1.0, early_stopping=False, stop=[], stop_token_ids=[], include_stop_str_in_output=False, ignore_eos=True, max_tokens=512, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True)
Warming up...
Profiling iterations: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [02:18<00:00, 13.90s/it]
Avg latency: 13.8989108757989 seconds
[use_fa]root@50c663527862:/bigdata/zhaoyang/gerrit/vllm# python3 benchmarks/benchmark_latency.py --model /bigdata/shared/models/huggingface/starcoder/ --input-len 512 --output-len 512 --num-iters 10 --batch-size 1 --enforce-eager 
Namespace(batch_size=1, dtype='auto', enforce_eager=True, input_len=512, kv_cache_dtype=None, model='/bigdata/shared/models/huggingface/starcoder/', n=1, num_iters=10, output_len=512, profile=False, profile_result_dir=None, quantization=None, tensor_parallel_size=1, tokenizer=None, trust_remote_code=False, use_beam_search=False, use_flash_attn=False)
INFO 01-22 22:24:44 llm_engine.py:71] Initializing an LLM engine with config: model='/bigdata/shared/models/huggingface/starcoder/', tokenizer='/bigdata/shared/models/huggingface/starcoder/', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=8192, download_dir=None, load_format=auto, tensor_parallel_size=1, quantization=None, enforce_eager=True, use_flash_attn_zte=False, kv_cache_dtype=None, seed=0)
INFO 01-22 22:25:46 llm_engine.py:283] # GPU blocks: 14201, # CPU blocks: 13107
SamplingParams(n=1, best_of=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=1.0, top_p=1.0, top_k=-1, min_p=0.0, use_beam_search=False, length_penalty=1.0, early_stopping=False, stop=[], stop_token_ids=[], include_stop_str_in_output=False, ignore_eos=True, max_tokens=512, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True)
Warming up...
Profiling iterations: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [02:21<00:00, 14.17s/it]
Avg latency: 14.166307100310224 seconds

@casper-hansen
Contributor

I would not expect a speedup for implementing GQA on prefilling, only on decoding (which is much harder because of PagedAttention).

@beginlner
Contributor

beginlner commented Jan 22, 2024

> @beginlner The unit test has passed. From the kernel benchmark and the e2e benchmark we can see there is no speedup compared with the paged attention v2 kernel. Did you get similar results? I used StarCoder, which is an MQA model. [benchmark logs quoted above]

Hi, here is my kernel benchmark result on an SXM A100 40GB (posted as a screenshot). I have updated the code at https://github.com/beginlner/vllm/tree/blocked_flash_attn. Note that the shape and the block size of the KV cache are different from vLLM's paged attention.

@beginlner
Contributor

> I would not expect a speedup for implementing GQA on prefilling, only on decoding (which is much harder because of PagedAttention).

Yes, GQA is compute-bound on prefilling and memory-bound on decoding, so there is a speedup only on decoding.
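A small illustration of that compute-bound vs. memory-bound split, under assumed shapes (head_dim 128, fp16; counting 2 FLOPs per multiply-add and only the K/V bytes read):

```python
head_dim, dtype_bytes = 128, 2

def flops_per_kv_byte(num_query_tokens: int) -> float:
    # Per cached KV position: QK^T and PV each cost head_dim MACs per query token.
    flops = 2 * 2 * head_dim * num_query_tokens
    bytes_read = 2 * head_dim * dtype_bytes  # one K row and one V row
    return flops / bytes_read

print(flops_per_kv_byte(1024))  # prefill with 1024 query tokens: ~1024 FLOPs/byte -> compute-bound
print(flops_per_kv_byte(1))     # decode with a single query token: ~1 FLOP/byte -> memory-bound
```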

@zhaoyang-star
Contributor

zhaoyang-star commented Jan 23, 2024

@beginlner Yes, the shape and block_size are different from paged attention. I used https://github.com/zhaoyang-star/vllm/tree/blocked_flash_attn_based_on_beginIner, which adds some options on top of your branch blocked_flash_attn.

Results using run_benchmark.sh in the branch blocked_flash_attn_based_on_beginIner are as follows. Blocked FA boosts both kernel and e2e performance when the batch size is larger. Note that the speedup can hardly be seen in the e2e latency benchmark when the batch size is small.

Env:

  • vllm main branch latest commit id 8cd5a99
  • A100-40GB-PCIE
  • context_len=4096
  • num query heads=56, num kv heads=8
  • kernel data type=bfloat16

| Batch size | Paged attention V1 (us) | Paged attention V2 (us) | FA with blocked KV cache (us) |
|---|---|---|---|
| 1 | 128.882 | 61.142 | 67.126 |
| 4 | 355.433 | 256.677 | 75.528 |
| 16 | 552.368 | 887.509 | 264.750 |
| 64 | 2504.497 | 3050.440 | 818.838 |
| 256 | 7878.213 | 8814.990 | 2550.177 |

Starcoder (MQA) e2e latency with In/Out length=512:

| Batch size | Baseline (sec) | FA with blocked KV cache (sec) |
|---|---|---|
| 1 | 13.76 | 13.79 |
| 4 | 14.41 | 14.33 |
| 16 | 17.45 | 16.37 |
| 64 | 29.79 | 24.37 |
| 256 | 94.41 | 73.03 |

@beginlner
Contributor

beginlner commented Jan 23, 2024

@zhaoyang-star It's expected that the speedup can hardly be seen in the e2e latency benchmark when the batch size is small: when the batch size is small, loading the model parameters is the bottleneck; when the batch size is large, loading the KV cache is the bottleneck.
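A back-of-envelope illustration of that bottleneck argument, using rough StarCoder-like numbers that are assumptions for this sketch (~15B fp16 parameters, 40 layers, 1 KV head, head size 128, context length 1024):

```python
params_bytes = 15e9 * 2  # fp16 weights, read roughly once per decode step

def kv_bytes(batch, seq_len=1024, layers=40, kv_heads=1, head_dim=128, dtype_bytes=2):
    # K and V for every cached token of every sequence in the batch
    return 2 * layers * kv_heads * head_dim * seq_len * dtype_bytes * batch

for batch in (1, 64, 256):
    print(batch, f"{kv_bytes(batch) / params_bytes:.1%}")
# batch 1:   KV traffic is <0.1% of weight traffic -> weights dominate, the attention kernel barely matters
# batch 256: KV traffic is ~18% of weight traffic (and grows with context) -> KV reads start to matter
```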

@casper-hansen
Contributor

@zhaoyang-star Blocked KV cache was added to flash attention in 2.5.0. I wonder if the newer implementation gives any performance boost? Either way, it's now in flash attention, which makes it easy to use in vLLM.
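A minimal decode-time sketch of that API, assuming flash-attn >= 2.5.0 and the argument names of flash_attn_with_kvcache (the shapes are illustrative; this is not vLLM's integration code):

```python
import torch
from flash_attn import flash_attn_with_kvcache

num_blocks, block_size, num_kv_heads, head_dim = 1024, 256, 8, 128
batch, num_heads = 4, 64

# Paged KV cache laid out as (num_blocks, block_size, num_kv_heads, head_dim).
k_cache = torch.randn(num_blocks, block_size, num_kv_heads, head_dim,
                      dtype=torch.float16, device="cuda")
v_cache = torch.randn_like(k_cache)

# One new query token per sequence during decode; GQA (64 query heads over 8 KV heads)
# is handled natively by the kernel.
q = torch.randn(batch, 1, num_heads, head_dim, dtype=torch.float16, device="cuda")

# block_table maps each sequence to its physical blocks; cache_seqlens gives current lengths.
block_table = torch.randint(0, num_blocks, (batch, 8), dtype=torch.int32, device="cuda")
cache_seqlens = torch.full((batch,), 1500, dtype=torch.int32, device="cuda")

out = flash_attn_with_kvcache(q, k_cache, v_cache,
                              cache_seqlens=cache_seqlens,
                              block_table=block_table,
                              causal=True)  # (batch, 1, num_heads, head_dim)
```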

@zhaoyang-star
Contributor

> @zhaoyang-star Blocked KV cache was added to flash attention in 2.5.0. I wonder if the newer implementation gives any performance boost? Either way, it's now in flash attention, which makes it easy to use in vLLM.

Great! @beginlner is the core contributor of this feature.

I have not benchmarked it under FA 2.5.0, but I think the results will be close to the data we had before.

The main issue is that FA only supports half and bfloat16 KV cache data types, so support for the fp8-e5m2 cache data type is needed. I am not familiar with FA's code, but I think only a limited number of lines would need to change to support an fp8 cache without a scale factor, because there are no additional quantization params.

We are looking for any contributions to deliver this feature :)
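An illustrative sketch of the scale-free fp8-e5m2 scheme described above, assuming PyTorch >= 2.1 with torch.float8_e5m2 (this only shows the cast-based storage idea, not any FlashAttention kernel change):

```python
import torch

k_fp16 = torch.randn(16, 8, 128, dtype=torch.float16, device="cuda")

k_fp8 = k_fp16.to(torch.float8_e5m2)   # store: a plain cast, no per-tensor scale factor
k_back = k_fp8.to(torch.float16)       # load: cast back before the fp16/bf16 attention kernel

# e5m2 keeps fp16's exponent range but only 2 mantissa bits, so the round-trip error stays
# small relative to magnitude and no extra quantization parameters are needed.
print((k_fp16 - k_back).abs().max())
```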

@sh1ng
Contributor

sh1ng commented Mar 1, 2024

github-actions bot commented Oct 30, 2024

This issue has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this issue should remain open. Thank you!

@github-actions github-actions bot added the stale label Oct 30, 2024