[ Docs ] Conceptual Guides #18
Open
robertgshaw2-neuralmagic wants to merge 22 commits into main from rs/concepts
Commits (22), all authored by robertgshaw2-neuralmagic:

- 53750c2 added examples
- 1b90311 added guides
- 6e9f5bf Update benchmark_offline.py
- 2d8ce6f added quantization schemes
- 97d7ae1 Merge branch 'rs/concepts' of https://github.com/vllm-project/llm-com…
- 6a551b2 finished online serving benchmark on A10
- bf17a75 cleanup
- eafb1ed added offline batch
- 44794bb save
- 3868193 nit
- ccbe6db nit
- 5789d9e nit
- 386c455 nits
- 3478161 format
- 2d36def format
- bfa605c cleanup nits
- ad4905e more cleanup
- 60708a8 Update docs/conceptual_guides/quantization_schemes.md
- 68becd8 Update docs/conceptual_guides/quantization_schemes.md
- 155a37b Update docs/conceptual_guides/quantization_schemes.md
- 832be25 Update docs/conceptual_guides/quantization_schemes.md
- 70add53 Update docs/conceptual_guides/quantization_schemes.md
@@ -0,0 +1,74 @@
# Inference Acceleration with Quantization

There are two "types" of quantization, each of which can accelerate inference:
* Weight-Only Quantization
* Weight and Activation Quantization

## Weight-Only Quantization

With weight-only quantization, weights are quantized to low precision (typically `int8`, `fp8`, or `int4`) while activations remain at a higher precision (`fp16` or `bf16`). To perform matrix multiplication, we upconvert each weight to `fp16` before computing `A*B`.
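To make the mechanics concrete, here is a minimal NumPy sketch of the math (illustrative only; the names and shapes are made up, and a real deployment uses a fused kernel such as Marlin rather than materializing the `fp16` weights):

```python
# Emulate weight-only quantization: weights stored in int8, activations in fp16.
import numpy as np

def quantize_weights_per_channel(W):
    """Symmetric int8 quantization with one scale per output column."""
    scales = np.abs(W).max(axis=0) / 127.0                 # shape [M]
    W_q = np.clip(np.round(W / scales), -128, 127).astype(np.int8)
    return W_q, scales.astype(np.float16)

def weight_only_matmul(A, W_q, scales):
    """Upconvert the int8 weights to fp16, then run a normal fp16 matmul."""
    W_fp16 = W_q.astype(np.float16) * scales               # the "upconversion" step
    return A @ W_fp16

A = np.random.randn(4, 4096).astype(np.float16)            # activations stay fp16
W = np.random.randn(4096, 4096).astype(np.float16)
W_q, scales = quantize_weights_per_channel(W)
out = weight_only_matmul(A, W_q, scales)                   # approximates A @ W
```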
### How Can We Speed Up Weight-Only Quantization?

Roughly speaking, the time required to execute a matrix multiplication on a GPU equals the sum of:
* Latency of moving the weights from main memory (DRAM) to the compute units (SRAM)
* Latency of the tensor-core compute operations

Weight-only quantization does not impact the latency of the tensor-core operations, but it can reduce the amount of data moving from DRAM to SRAM by using "fused" inference kernels that upconvert the weights to `fp16` after moving them into SRAM.
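As a rough back-of-envelope illustration (assuming an ~8B-parameter model, that every weight is read once per generated token, and ignoring activations, the KV cache, and quantization scales), the weight bytes crossing the DRAM bus shrink in proportion to the weight precision:

```python
# Rough DRAM traffic per decode step for an ~8B-parameter model (assumption: each
# weight is read once per generated token; scales and KV cache are ignored).
params = 8e9

for name, bits in [("fp16", 16), ("int8", 8), ("int4", 4)]:
    gigabytes = params * bits / 8 / 1e9
    print(f"{name}: ~{gigabytes:.0f} GB of weights read per token")
# fp16: ~16 GB, int8: ~8 GB, int4: ~4 GB  ->  2-4x less data to move per token
```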
### Accelerating Inference Serving

Since LLM serving is usually dominated by "decode" operations, which are "memory bandwidth bound", weight-only quantization is quite useful for accelerating online serving.

[`Marlin`](https://neuralmagic.com/blog/pushing-the-boundaries-of-mixed-precision-llm-inference-with-marlin/) is an optimized fused inference kernel for weight-only quantization, supporting `int4`, `int8`, and `fp8` weights with `fp16` and `bf16` activations. vLLM uses `Marlin` when executing inference for weight-only quantized models created via `llm-compressor`.

End-to-end speedups for `Meta-Llama-3-8B-Instruct` on an A10G at 1 QPS:

| Weight Precision | Activation Precision | Time Per Output Token (ms) | Speedup vs `fp16` |
| - | - | - | - |
| `fp16` | `fp16` | 42.52 | 1.0x |
| `fp8` | `fp16` | 22.95 | 1.9x |
| `int8` | `fp16` | 26.34 | 1.6x |
| `int4-g128` | `fp16` | 15.46 | 2.8x |

> Performance results computed as of `vllm==v0.5.1` via the [online serving performance benchmark](../../examples/benchmarking/online_serving)
### Examples
- [`int4` weight-only quantization with `Meta-Llama-3-8B-Instruct`](../../examples/quantization_w4a16)

## Weight and Activation Quantization

With weight and activation quantization, both the weights and activations are converted to `int8` or `fp8`. At inference time, we can use low-precision tensor cores, which have more FLOPS available:

| GPU | `fp16` | `int8` | `fp8` |
| - | - | - | - |
| `A10G` | 125 TFLOPS | 250 TOPS | Not supported |
| `A100` | 312 TFLOPS | 624 TOPS | Not supported |
| `H100` | 990 TFLOPS | 1979 TOPS | 1979 TFLOPS |

> [`A10G` datasheet](https://www.nvidia.com/en-us/data-center/products/a10-gpu/) // [`A100` datasheet](https://www.nvidia.com/en-us/data-center/a100/) // [`H100` datasheet](https://www.nvidia.com/en-us/data-center/h100/)

As a result, activation quantization is able to accelerate both "memory bandwidth bound" and "compute bound" operations.
### Accelerating Offline Batch Processing

With offline batch processing, we can crank up the batch size as high as possible to maximize throughput, making offline batch processing "compute bound". This means that activation quantization is very useful for accelerating performance.

vLLM supports activation quantization acceleration using custom Cutlass-based inference kernels for models created via `llm-compressor`.

End-to-end speedups for `Meta-Llama-3-8B-Instruct` on an A10G for offline batch processing:

| Weight Precision | Activation Precision | Generation Throughput | Speedup vs `fp16` |
| - | - | - | - |
| `fp16` | `fp16` | 488 tok/sec | 1.0x |
| `int8` | `int8` | 977 tok/sec | 2.2x |

> Performance results computed as of `vllm==v0.5.1` via the [offline performance benchmark](../../examples/benchmarking/offline_batch/)

### Examples
- [`w8a8 int8` quantization with `Meta-Llama-3-8B-Instruct`](../../examples/quantization_w8a8_int8)
- [`w8a8 fp8` quantization with `Meta-Llama-3-8B-Instruct`](../../examples/quantization_w8a8_fp8)

## Other Resources

- Horace He's blog [Making Deep Learning Go Brrrr From First Principles](https://horace.io/brrr_intro.html) for more conceptual background on compute- vs bandwidth-bound operations
- Neural Magic's blog [Pushing the Boundaries of Mixed-Precision LLM Inference With Marlin](https://neuralmagic.com/blog/pushing-the-boundaries-of-mixed-precision-llm-inference-with-marlin/) for more details on how Marlin works
@@ -0,0 +1,88 @@
# Quantization Schemes

Quantization is a technique to reduce the computational and memory costs of running inference by representing the weights and activations with low-precision data types like 8-bit integer (`int8`) instead of the usual 16-bit floating point (`float16`).

## Theory

Performing quantization to go from `float16` to `int8` (or lower) is tricky. Only 256 values can be represented in `int8`, while `float16` can represent a very wide range of values. The idea is to find the best way to project our range `[a, b]` of `float16` values to the `int8` space.

Let's consider a float `x` in `[a, b]`. We can then write the following quantization scheme:

```
x = S * (x_q - Z)
```

where:

- `x_q` is the quantized `int8` value associated with `x`
- `S` is the scale, a positive `float16`. It is used to "rescale" a distribution from the base range in `float16` to the desired width (i.e. 256 values for `int8`).
- `Z` is called the zero-point; it is the `int8` value corresponding to the value 0 in the `float16` realm. If the zero-point is omitted, we call this "symmetric" quantization because the default zero-point of 0 is in the true middle of the distribution.

The quantized value `x_q` of `x` in `[a, b]` can be computed as follows:

```
x_q = round(x/S + Z)
```

And `float16` values outside of the `[a, b]` range are clipped to the closest representable value, so for any floating-point number `x`:

```
x_q = clip(round(x/S + Z), round(a/S + Z), round(b/S + Z))
```
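Putting the formulas together, here is a small NumPy sketch (illustrative only) that derives `S` and `Z` for an asymmetric `int8` scheme from a chosen range `[a, b]`, then quantizes and dequantizes a few values:

```python
# Minimal asymmetric int8 quantization sketch following the formulas above.
import numpy as np

def make_qparams(a: float, b: float, n_levels: int = 256):
    """Compute scale S and zero-point Z that map [a, b] onto int8 [-128, 127]."""
    qmin, qmax = -(n_levels // 2), n_levels // 2 - 1   # -128, 127
    S = (b - a) / (qmax - qmin)                        # float range / int8 range
    Z = round(qmin - a / S)                            # int8 value representing 0.0
    return S, Z

def quantize(x, S, Z, qmin=-128, qmax=127):
    return np.clip(np.round(x / S + Z), qmin, qmax).astype(np.int8)

def dequantize(x_q, S, Z):
    return S * (x_q.astype(np.float32) - Z)

S, Z = make_qparams(a=-3.0, b=5.0)
x = np.array([-3.5, 0.0, 2.7, 6.0])                    # values outside [a, b] get clipped
x_q = quantize(x, S, Z)
print(x_q, dequantize(x_q, S, Z))                      # round-trip ≈ original, within ~S/2
```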
## Quantization Flavors

There are several flavors of quantization.

### Static vs Dynamic

The section above described how quantization from `float16` to `int8` works, but did not explain how to compute the scales and zero-points.

With weights, since the full range is known ahead of time, we can just compute the scales and zero-points statically (sometimes using a more sophisticated algorithm like `GPTQ`).

With activations, however, there are two approaches (contrasted in the sketch after the recommendations below):
* **Dynamic quantization**: the range for each activation is computed on the fly at runtime, so the quantization range matches the current runtime range exactly. This gives us the best possible values, but it can be a bit slower than static quantization because of the overhead of computing the range each time. It is also not an option on certain hardware.

* **Static quantization**: the range for each activation is computed in advance at quantization time. This is typically done by passing representative "calibration" data through the model and recording the range of activation values. In practice, we run a number of forward passes on a calibration dataset and compute the ranges according to the observed calibration data.

In general, it is best practice to start your experiments with:
- For `fp8`, use static activation quantization
- For `int8`, use dynamic activation quantization

> Reviewer comments on this recommendation: "why?" / "Why is it best practice?"
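A minimal NumPy sketch of the contrast (illustrative only; real implementations fuse this logic into the inference kernels): static quantization freezes the activation scale from calibration data, while dynamic quantization recomputes it from each live activation.

```python
# Contrast static vs dynamic activation scales (symmetric int8, per-tensor).
import numpy as np

def symmetric_scale(x: np.ndarray) -> float:
    return float(np.abs(x).max()) / 127.0

def quantize(x: np.ndarray, scale: float) -> np.ndarray:
    return np.clip(np.round(x / scale), -128, 127).astype(np.int8)

# --- Static: scale is frozen from a calibration pass and reused at runtime ---
calibration_batches = [np.random.randn(16, 4096) for _ in range(8)]
static_scale = max(symmetric_scale(x) for x in calibration_batches)

# --- Dynamic: scale is recomputed from the live activation on every forward ---
runtime_activation = np.random.randn(16, 4096) * 2.0   # distribution shifted vs calibration
dynamic_scale = symmetric_scale(runtime_activation)

x_q_static = quantize(runtime_activation, static_scale)    # may clip or waste range
x_q_dynamic = quantize(runtime_activation, dynamic_scale)  # tight fit, but extra work per pass
```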
### Granularity

Weight and activation quantization can be performed at different levels of granularity, depending on the accuracy / latency tradeoffs being targeted.

#### Weight Quantization Granularity

For weight quantization, there are three "levels" (in order of increasing granularity; see the sketch below):
* **Per-Tensor**: one pair of quantization parameters (`S, Z`) is used per tensor
* **Per-Channel**: one pair of quantization parameters (`S, Z`) is used per element of one of the dimensions of the tensor. For instance, with a weight matrix of shape `[N, M]`, the scales form a vector of shape `[M]`.
* **Per-Group**: one pair of quantization parameters (`S, Z`) is used per group of items in a tensor. For instance, with a weight matrix of shape `[N, M]` with `N=4096` and a group size of 128, the scales form a matrix of shape `[32, M]` (note: `4096 / 128 = 32`).
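The shapes of the scale tensors make the tradeoff concrete. A minimal NumPy sketch under the `[N, M]` convention above (illustrative only; the dimensions are made up):

```python
# Scale shapes for the three weight-quantization granularities ([N, M] weight).
import numpy as np

N, M, group_size = 4096, 4096, 128
W = np.random.randn(N, M).astype(np.float32)

# Per-tensor: a single scale for the whole matrix.
scale_per_tensor = np.abs(W).max() / 127.0                  # shape: ()

# Per-channel: one scale per column -> vector of shape [M].
scale_per_channel = np.abs(W).max(axis=0) / 127.0           # shape: (M,)

# Per-group: split each column into groups of 128 rows -> [N // group_size, M].
W_grouped = W.reshape(N // group_size, group_size, M)
scale_per_group = np.abs(W_grouped).max(axis=1) / 127.0     # shape: (N // group_size, M)

print(scale_per_tensor.shape, scale_per_channel.shape, scale_per_group.shape)
# ()  (4096,)  (32, 4096)  -- finer granularity = more scales to store and apply
```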
Increasing quantization granularity typically helps with accuracy at the expense of less memory reduction and slower inference performance. This is because we compute quantization ranges over smaller distributions, with the tradeoff of needing more memory to represent them. In general, it is best practice to start your experiments with:
- For `int4` weights, use `per-group (size=128)`
- For `int8` weights, use `per-channel`
- For `fp8` weights, use `per-tensor`

#### Activation Quantization Granularity

For activation quantization, there are two "levels" (in order of increasing granularity; sketched below):
* **Per-Tensor**: one pair of quantization parameters (`S, Z`) is used per activation tensor
* **Per-Token**: one pair of quantization parameters (`S, Z`) is used per token of the activation tensor. For LLMs, the activation tensor is of shape `[batch_size, seq_len, hidden_dim]`, so the scales will be a matrix of shape `[batch_size, seq_len]`.
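A minimal NumPy sketch of per-token symmetric `int8` quantization (illustrative only), showing where the `[batch_size, seq_len]` scale shape comes from:

```python
# Per-token symmetric int8 quantization of an activation tensor.
import numpy as np

batch_size, seq_len, hidden_dim = 4, 512, 4096
act = np.random.randn(batch_size, seq_len, hidden_dim).astype(np.float32)

# One scale per token: reduce over the hidden dimension only.
scales = np.abs(act).max(axis=-1, keepdims=True) / 127.0    # shape (4, 512, 1)
act_q = np.clip(np.round(act / scales), -128, 127).astype(np.int8)

print(scales.squeeze(-1).shape)   # (4, 512) -> one (S, Z=0) pair per token
```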
Increasing quantization granularity typically helps with accuracy at the expense of less memory reduction and slower inference performance.

In general, it is best practice to start your experiments with:
- For static activation quantization, always use `per-tensor`
- For `fp8` dynamic quantization, use `per-tensor`
- For `int8` dynamic quantization, use `per-token`

### Activation Reordering

Activations of LLMs are known to be problematic to work with because, for some inputs, they exhibit very large activation values in a few channels relative to all other channels. These very large activations are called "outliers", and preserving their propagation through the model is of high importance for good accuracy. Activation reordering quantizes the weight columns in order of decreasing activation magnitude, meaning that we first quantize the columns that correspond to outliers (to preserve them as well as possible), and then move on to the others (which correspond to smaller activations by magnitude). This can help preserve accuracy at the expense of some inference speed. The sketch below shows the ordering step.

In general, it is best practice to start your experiments with:
- For `int4` weights, use activation reordering with GPTQ
- For anything else, do not use activation reordering
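A minimal NumPy sketch of just the ordering step (illustrative only; the error-compensated column-by-column quantization that GPTQ performs after this ordering is not shown):

```python
# Order weight columns by the average magnitude of their corresponding activations.
import numpy as np

hidden_dim, out_dim = 4096, 4096
calib_acts = np.random.randn(1024, hidden_dim)      # pretend calibration activations
W = np.random.randn(out_dim, hidden_dim)            # columns correspond to input channels

# Average activation magnitude per input channel.
act_magnitude = np.abs(calib_acts).mean(axis=0)     # shape (hidden_dim,)

# Quantize the columns with the largest (outlier) activations first.
perm = np.argsort(-act_magnitude)
W_reordered = W[:, perm]                            # quantize columns in this order, then
                                                    # undo / fold the permutation at inference
```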
@@ -0,0 +1,53 @@
# Offline Batch Benchmarking

When evaluating LLM performance for offline batch processing, we should focus on throughput metrics such as tokens/second or requests/second.

## Speedups from Activation Quantization

On an Nvidia A10G GPU, we measure the following for offline batch processing:

| Model Stub | Precision | Generation Throughput | Speedup vs Fp16 |
| - | - | - | - |
| `meta-llama/Meta-Llama-3-8B-Instruct` | `fp16` | 488 tok/sec | 1.0x |
| `nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Per-Token-Test` | `int8` | 977 tok/sec | 2.2x |
## Generate Raw Benchmark Data

We can measure offline batch throughput by running vLLM on a sample dataset. We will use `meta-llama/Meta-Llama-3-8B-Instruct` as the sample model:

```bash
export MODEL=meta-llama/Meta-Llama-3-8B-Instruct
```

Install:

```bash
python -m venv vllm-venv
source vllm-venv/bin/activate
pip install vllm
```

Run sample workload:

```bash
python benchmark_offline.py --help
python benchmark_offline.py --model $MODEL
```

Results on A10G:

```bash
* ==========================================================
* Total Time:              461.90
* Total Generations:       1000

* Generations / Sec:       2.16
* Generation Tok / Sec:    488.13
* Prompt Tok / Sec:        1180.01

* Avg Generation Tokens:   225.47
* Avg Prompt Tokens:       545.05
* ==========================================================
```
@@ -0,0 +1,84 @@
import argparse
import time

from datasets import load_dataset
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="meta-llama/Meta-Llama-3-8B-Instruct")
parser.add_argument("--max-generated-tokens", type=int, default=250)
parser.add_argument("--num-samples", type=int, default=1000)
parser.add_argument("--max-num-seqs", type=int, default=256)
parser.add_argument("--kv-cache-dtype", type=str, default="auto")
parser.add_argument("--gpu-memory-utilization", type=float, default=.9)

DATASET_ID = "HuggingFaceH4/ultrachat_200k"
NUM_TURNS_PROMPT = 3

if __name__ == "__main__":
    args = parser.parse_args()
    MODEL_ID = args.model
    MAX_GENERATED_TOKENS = args.max_generated_tokens
    MAX_NUM_SEQS = args.max_num_seqs
    NUM_SAMPLES = args.num_samples
    KV_CACHE_DTYPE = args.kv_cache_dtype
    GPU_MEMORY_UTILIZATION = args.gpu_memory_utilization

    # Pre-process your dataset.
    # It's a good idea to use the chat template.
    def preprocess(example):
        return {"text": tokenizer.apply_chat_template(
            example["messages"][:NUM_TURNS_PROMPT], tokenize=False, add_generation_prompt=True
        )}

    dataset = load_dataset(DATASET_ID, split="train_sft")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    ds = dataset.shuffle().select(range(NUM_SAMPLES))
    ds = ds.map(preprocess)

    # BE CAREFUL WITH THE TOKENIZER
    # apply_chat_template already adds the bos_token,
    # so we set add_special_tokens to False
    examples = [
        tokenizer(example["text"], add_special_tokens=False).input_ids
        for example in ds
    ]

    # Initialize vLLM
    model = LLM(
        MODEL_ID,
        max_num_seqs=MAX_NUM_SEQS,
        kv_cache_dtype=KV_CACHE_DTYPE,
        gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
    )

    # Generate.
    start = time.perf_counter()
    generations = model.generate(
        prompt_token_ids=examples,
        use_tqdm=True,
        sampling_params=SamplingParams(
            max_tokens=MAX_GENERATED_TOKENS),
    )
    end = time.perf_counter()

    total_generations = len(generations)
    total_prompt_tokens = 0
    total_generation_tokens = 0
    total_time = end - start

    for generation in generations:
        total_prompt_tokens += len(generation.prompt_token_ids)
        total_generation_tokens += len(generation.outputs[0].token_ids)

    print("* ==========================================================")
    print(f"* Total Time: \t\t\t{total_time: 0.2f}")
    print(f"* Total Generations: \t\t{total_generations}")
    print("\n")
    print(f"* Generations / Sec: \t\t{total_generations / total_time :0.2f}")
    print(f"* Generation Tok / Sec: \t{total_generation_tokens / total_time :0.2f}")
    print(f"* Prompt Tok / Sec: \t\t{total_prompt_tokens / total_time :0.2f}")
    print("\n")
    print(f"* Avg Generation Tokens: \t{total_generation_tokens / total_generations :0.2f}")
    print(f"* Avg Prompt Tokens: \t\t{total_prompt_tokens / total_generations :0.2f}")
    print("* ==========================================================")
@@ -0,0 +1,91 @@
# Online Serving Benchmarking

When evaluating LLM performance for online serving, there are two latency metrics to consider (see the sketch below):
- `TTFT` (Time to first token) measures how long it takes to generate the first token.
- `TPOT` (Time per output token) measures how long it takes to generate each incremental token.
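As a hedged illustration of how these two metrics fall out of streaming timestamps (not the benchmark script's exact implementation), the sketch below computes TTFT and TPOT for a single request:

```python
# Compute TTFT and TPOT for one streamed request from per-token timestamps.
import time

def measure_request(stream):
    """`stream` is any iterable that yields generated tokens one at a time."""
    request_start = time.perf_counter()
    token_times = []
    for _ in stream:
        token_times.append(time.perf_counter())

    ttft = token_times[0] - request_start                     # time to first token
    # TPOT excludes the first token: average gap between subsequent tokens.
    tpot = (token_times[-1] - token_times[0]) / (len(token_times) - 1)
    return ttft, tpot

# Example with a fake stream that "generates" 10 tokens at ~20 ms each.
def fake_stream(n=10, delay_s=0.02):
    for i in range(n):
        time.sleep(delay_s)
        yield f"tok{i}"

ttft, tpot = measure_request(fake_stream())
print(f"TTFT: {ttft * 1000:.1f} ms, TPOT: {tpot * 1000:.1f} ms")
```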
## Speedups from Weight-Only Quantization

On an Nvidia A10G GPU, we measure the following for online serving with the `sharegpt` dataset at 1 QPS:

| Model Stub | Precision | TTFT (ms) | TPOT (ms) | Speedup vs Fp16 |
| - | - | - | - | - |
| `meta-llama/Meta-Llama-3-8B-Instruct` | `fp16` | 106 | 43 | 1.0x |
| `neuralmagic/Meta-Llama-3-8B-Instruct-FP8` | `fp8` | 98 | 23 | 1.9x |
| `astronomer/Llama-3-8B-Instruct-GPTQ-8-Bit` | `int8` | 106 | 25 | 1.6x |
| `nm-testing/Meta-Llama-3-8B-Instruct-GPTQ` | `int4` | 73 | 15 | 2.8x |
## Generate Raw Benchmark Data

We can measure online serving latency by spinning up a vLLM server and creating sample clients. We will use `meta-llama/Meta-Llama-3-8B-Instruct` as the sample model:

```bash
export MODEL=meta-llama/Meta-Llama-3-8B-Instruct
```

### Spin Up vLLM Server

Install:

```bash
python -m venv vllm-venv
source vllm-venv/bin/activate
pip install vllm
```

Launch:

```bash
python -m vllm.entrypoints.openai.api_server --model $MODEL
```

### Spin Up Clients

Install:

```bash
python3 -m venv benchmark-venv
source benchmark-venv/bin/activate
pip install -U aiohttp transformers
```

Download sample data:

```bash
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
```

Launch the clients (we launch 1 client per second here):

```bash
python3 benchmark_serving.py \
    --model $MODEL \
    --request-rate 1.0 \
    --num-prompts 100 \
    --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json
```

Results:

We achieve `43ms` of TPOT on an A10G.

```bash
============ Serving Benchmark Result ============
Successful requests:                     100
Benchmark duration (s):                  110.95
Total input tokens:                      22805
Total generated tokens:                  17981
Request throughput (req/s):              0.90
Input token throughput (tok/s):          205.53
Output token throughput (tok/s):         162.06
---------------Time to First Token----------------
Mean TTFT (ms):                          106.36
Median TTFT (ms):                        78.82
P99 TTFT (ms):                           286.84
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          42.52
Median TPOT (ms):                        43.07
P99 TPOT (ms):                           58.80
==================================================
```
> Review comment: good space for a diagram (bucketing the weight distribution to 256 buckets)