Qualcomm AI Engine Direct - Support Hybrid Mode for Llama3.2 #7175

Merged

Conversation

@winskuo-quic (Collaborator) commented Dec 4, 2024

Summary

  • Enable exporting Llama3.2 1B in hybrid mode.
  • Enable the runner to support hybrid mode.
  • Handle multi-IO ordering for the weight-sharing scenario.
  • Use 64-bit flatbuffers to enable exporting larger models.
  • Align the timer for the runner.

Script to run with a prefill context length of 32 and a KV context length of 512:

python examples/qualcomm/oss_scripts/llama3_2/llama.py -a ${ARCHIVE} -b build-android -H ${HOST} -m ${SOC} --checkpoint Llama3.2-1B-Instruct/consolidated.00.pth --params Llama3.2-1B-Instruct/params.json --tokenizer_model Llama3.2-1B-Instruct/tokenizer.model --prompt "what is 1+1" --temperature 0 --model_size 1B --model_mode hybrid --prefill_seq_len 32 --kv_seq_len 512 --ptq 16a4w

pytorch-bot bot commented Dec 4, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/7175

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit db86fbd with merge base f370e78:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the "CLA Signed" label on Dec 4, 2024
@winskuo-quic marked this pull request as a draft on December 4, 2024 10:29
@dbort added the "partner: qualcomm" label on Dec 4, 2024
@haowhsu-quic (Collaborator) commented Dec 5, 2024

Hi @cccclai, I wonder if you have plans to support reuse of processed bytes? It looks like the current approach only works for processed bytes staying inside a method rather than across all methods:

self.emitter_state.delegate_cache[processed_bytes] = delegate_index

This might be critical to the .pte size, since we have different methods all pointing to the same processed binary. Any suggestion for how we could work this out?

@cccclai (Contributor) commented Dec 5, 2024

Hmm, I'm trying to follow. In #6657 I remember we observed a size reduction; how is it different from the test example added in the PR?

@haowhsu-quic (Collaborator):

Hmm, I'm trying to follow. In #6657 I remember we observed a size reduction; how is it different from the test example added in the PR?

Yes, the context binary size will be smaller than the version without weight_sharing, but this does not apply to the generated .pte file.

  • Current approach (take llama as an example; prefill & decode are two methods in the final pte):
    flowchart TB
        prefill --> id1(weight_sharing_context_prefill) --> output_prefill
        decode --> id2(weight_sharing_context_decode) --> output_decode
    Since weight_sharing_context_prefill & weight_sharing_context_decode are actually the same and we only use one of them at runtime, we are hoping these two processed bytes could be merged instead of having identical copies in the final pte.
  • Expected version:
    flowchart TB
        prefill --> id1(weight_sharing_context) --> output_prefill
        decode --> id1(weight_sharing_context) --> output_decode
Hopefully this makes sense to you.

@cccclai (Contributor) commented Dec 5, 2024

oh hmm, I thought this line

self.emitter_state.delegate_cache[processed_bytes] = delegate_index

will just deduplicate directly, unless they're not a bit-exact match. Did I misunderstand anything?

@cccclai (Contributor) commented Dec 5, 2024

What is the len(self.emitter_state.delegate_cache) in your test case? (Maybe I should just grab the PR and repro...)

@haowhsu-quic (Collaborator):

oh hmm, I thought this line

self.emitter_state.delegate_cache[processed_bytes] = delegate_index

will just deduplicate directly, unless they're not a bit-exact match. Did I misunderstand anything?

No, your understanding is correct. But I think the delegate_cache only exists in a per-method manner:
https://github.com/pytorch/executorch/blob/cd306d356660fb4a8cdb4639fd3300e25fd412ef/exir/emit/_emit_program.py#L149C1-L159C10

    # emit each entry point in order according to name.
    for name, exported_program in sorted(methods.items()):
        # create empty state
        emitter_state = _EmitterState(
            values=[],
            operators=[],
            delegates=[],
            operator_cache={},
            delegate_cache={},
            emit_stacktrace=emit_stacktrace,
        )

Our scenario for weight_sharing today is multiple methods connecting to the same processed bytes, but the cache is not reused since the emitter_state cannot be shared.
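
For illustration only, a toy sketch of the problem (not the real emitter; the method names are made up): a per-method cache stores the identical processed bytes once per method, while a program-level cache keyed on the bytes stores them once per program.

    # Toy model of the dedup problem: the same processed context binary is emitted by two methods.
    processed = b"\x7fQNN-context" * 1024

    def emit_program(share_cache_across_methods: bool) -> int:
        program_blobs: list[bytes] = []      # what would end up in the .pte
        shared_cache: dict[bytes, int] = {}
        for _method in ("prefill_forward", "kv_forward"):
            # per-method cache is recreated for every method, so it cannot dedupe across methods
            cache = shared_cache if share_cache_across_methods else {}
            if processed not in cache:
                cache[processed] = len(program_blobs)
                program_blobs.append(processed)
        return len(program_blobs)

    print(emit_program(share_cache_across_methods=False))  # 2 -> identical copies in the .pte
    print(emit_program(share_cache_across_methods=True))   # 1 -> both methods reference one blob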

@cccclai (Contributor) commented Dec 5, 2024

Hmm, I feel like something is off. If we look at the ExecuTorch .pte schema, the delegate data lives outside of each method: https://github.com/pytorch/executorch/blob/main/exir/schema.py#L292

@cccclai (Contributor) commented Dec 5, 2024

Hmm, I feel like something is off. If we look at the ExecuTorch .pte schema, the delegate data lives outside of each method: https://github.com/pytorch/executorch/blob/main/exir/schema.py#L292

There is a chance that we have a bug somewhere internally; I'd need to check.

@haowhsu-quic (Collaborator):

This is the patch that works on my side:

diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py
index 33be00ed5..fcb7eb667 100644
--- a/backends/qualcomm/utils/utils.py
+++ b/backends/qualcomm/utils/utils.py
@@ -733,7 +733,7 @@ def from_context_binary(  # noqa: C901
     bundle_prog = build_graph(inputs, outputs)
     bundle_prog.update({"inputs": inputs, "outputs": outputs})
     edge_prog_mgr = to_edge(
-        programs={graph_name: bundle_prog["exported_program"]},
+        {graph_name: bundle_prog["exported_program"]},
         # do not alter name for custom op
         compile_config=EdgeCompileConfig(_use_edge_ops=False),
     )
@@ -791,7 +791,7 @@ def generate_multi_graph_program(
     ]
     # leverage ExecutorchProgramManager for generating pte with multi-methods
     edge_prog_mgr = to_edge(
-        programs={
+        {
             graph_name: bundle_prog["exported_program"]
             for graph_name, bundle_prog in zip(graph_names, bundle_progs)
         },
diff --git a/exir/_serialize/_program.py b/exir/_serialize/_program.py
index 00a3d4700..26e28589c 100644
--- a/exir/_serialize/_program.py
+++ b/exir/_serialize/_program.py
@@ -254,6 +254,7 @@ def _extract_delegate_segments(
     """
     remaining_inline: List[BackendDelegateInlineData] = []
     inline_indices_seen: set[int] = set()
+    segment_index_map: dict[bytes, int] = {}
     for plan in program.execution_plan:
         for delegate in plan.delegates:
             if delegate.processed.location != DataLocation.INLINE:
@@ -279,8 +280,12 @@ def _extract_delegate_segments(
             inline_indices_seen.add(delegate.processed.index)
             if inline.data:
                 # Move the delegate data out of the program.
-                segment_index = len(segments)
-                segments.append(Cord(inline.data))
+                segment_index = segment_index_map.get(inline.data)
+                if segment_index is None:
+                    segment_index = len(segments)
+                    segments.append(Cord(inline.data))
+                    segment_index_map[inline.data] = segment_index
+
                 delegate.processed = BackendDelegateDataReference(
                     location=DataLocation.SEGMENT,
                     index=segment_index,
diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py
index 381bab618..cf19c0479 100644
--- a/exir/emit/_emitter.py
+++ b/exir/emit/_emitter.py
@@ -119,6 +119,7 @@ class _ProgramState:
     # Delegate data stored directly in the flatbuffer. Pointed to by BackendDelegateDataReference,
     # and should be copied to Program.backend_delegate_data.
     backend_delegate_data: List[BackendDelegateInlineData] = field(default_factory=list)
+    backend_delegate_data_cache: Dict[bytes, int] = field(default_factory=dict)
 
 
 @dataclass
@@ -1049,10 +1050,13 @@ class _Emitter(torch.fx.Interpreter):
         if delegate_index is None:
             # Allocate an entry for the data. TODO(T150113674): Reuse any duplicate entries if
             # present.
-            data_index: int = len(self.program_state.backend_delegate_data)
-            self.program_state.backend_delegate_data.append(
-                BackendDelegateInlineData(data=processed_bytes)
-            )
+            data_index = self.program_state.backend_delegate_data_cache.get(processed_bytes)
+            if data_index is None:
+                data_index: int = len(self.program_state.backend_delegate_data)
+                self.program_state.backend_delegate_data_cache[processed_bytes] = data_index
+                self.program_state.backend_delegate_data.append(
+                    BackendDelegateInlineData(data=processed_bytes)
+                )
 
             backend_delegate = BackendDelegate(
                 id=lowered_module.backend_id,

You could inspect the generated pte size with/without the patch via this test case:

python backends/qualcomm/tests/test_qnn_delegate.py -k TestQNNQuantizedOperators.test_qnn_backend_multi_graphs -s $DEVICE_SERIAL -m $SOC_MODEL -b build-android/ -a $PATH_TO_ARTIFACTS

@chunit-quic (Collaborator):

Script to run with a prefill context length of 32 and a KV context length of 512:

python examples/qualcomm/oss_scripts/llama3_2/llama.py -a ${ARCHIVE} -b build-android -H ${HOST} -m ${SOC} --checkpoint Llama3.2-1B-Instruct/consolidated.00.pth --params Llama3.2-1B-Instruct/params.json --tokenizer_model Llama3.2-1B-Instruct/tokenizer.model --prompt "what is 1+1" --temperature 0 --model_size 1B --model_mode hybrid --prefill_seq_len 32 --kv_seq_len 512 --ptq 16a4w

Hi Chen, just in case you still encounter problems during runtime, please feel free to let us know. We will try to assist you.

@cccclai (Contributor) commented Dec 6, 2024

Script to run with a prefill context length of 32 and a KV context length of 512:

python examples/qualcomm/oss_scripts/llama3_2/llama.py -a ${ARCHIVE} -b build-android -H ${HOST} -m ${SOC} --checkpoint Llama3.2-1B-Instruct/consolidated.00.pth --params Llama3.2-1B-Instruct/params.json --tokenizer_model Llama3.2-1B-Instruct/tokenizer.model --prompt "what is 1+1" --temperature 0 --model_size 1B --model_mode hybrid --prefill_seq_len 32 --kv_seq_len 512 --ptq 16a4w

Hi Chen, just in case you still encounter problems during runtime, please feel free to let us know. We will try to assist you.

Hi @chunit-quic, thank you for offering. I'm preparing an internal design review tomorrow for supporting multi-methods, so that we don't need a workaround for weight sharing. I'll get back to this repro right away.

@winskuo-quic marked this pull request as ready for review on December 6, 2024 06:53
@winskuo-quic (Collaborator, Author) commented Dec 6, 2024

Hi @cccclai,
We have verified that hybrid mode can be successfully lowered and executed on Snapdragon 8 Gen 2 and Gen 3. The performance appears to be consistent with running either prefill or KV mode.
However, we are still working on enabling Snapdragon 8 Gen 1 to execute hybrid mode, as Gen 1 encounters a memory-related error.
Please have a look. Thanks!

@facebook-github-bot (Contributor):

@cccclai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@cccclai (Contributor) commented Dec 9, 2024

I've verified the stories model PR, and batch prefill/decode standalone seem to work well. Trying the 1B model.

@cccclai (Contributor) commented Dec 9, 2024

In the meantime, it looks like we ran into a flatbuffers compilation issue in the internal build:

error: /re_cwd/fbcode/executorch/backends/qualcomm/serialization/qc_binary_info.fbs:17: 33: error: user define attributes must be declared before use: vector64

Is vector64 required?

@cccclai (Contributor) commented Dec 9, 2024

If alignment is required, can we do something similar to

data: [ubyte] (force_align: 16); // @executorch-delegate-alignment

to force alignment?

flatbuffers::FlatBufferBuilder builder_;
flatbuffers::Verifier::Options fb_opt_;
Contributor:

This flatbuffers::Verifier::Options is also causing a build error; maybe the flatbuffers version is different...

Contributor:

I feel like internally we pinned to https://github.com/google/flatbuffers/commits/338393f8/. Any chance we can limit ourselves to features available before this version? If not, we may need to add patches for flatbuffers::Verifier::Options and FlatBufferBuilder64.

Contributor:

I'll also try to bump the internal flatbuffers version... will provide more info after trying it out.

Contributor:

Drafted an internal diff with the flatbuffers version bump - will watch the CI signal and hopefully it can pass

Collaborator Author:

Hi @cccclai,
Thanks for reviewing the PR. We tried to use the original 32-bit flatbuffers; however, for larger models we encounter errors complaining that the 32-bit flatbuffer size is insufficient. The 64-bit flatbuffer allows the size to be larger than 2 GB.
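
For context, just arithmetic (assuming the commonly cited 2 GiB cap for 32-bit flatbuffers; the artifact size below is illustrative, not a measurement):

    # 32-bit flatbuffer offsets cap a single buffer at 2**31 bytes (~2 GiB),
    # which a hybrid-mode export bundling prefill + decode context binaries can exceed.
    cap_bytes = 2**31
    example_artifact_bytes = int(2.2 * 2**30)  # e.g. a ~2.2 GiB 16a4w hybrid artifact (assumed)
    print(f"32-bit flatbuffer cap: {cap_bytes / 2**30:.1f} GiB")  # 2.0 GiB
    print(example_artifact_bytes > cap_bytes)                     # True -> needs the 64-bit builder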

Contributor:

Ah, did you wrap the context binary inside flatbuffers? That's likely the reason. In ExecuTorch, we move the context binary outside of flatbuffers so it's no longer limited by 32-bit flatbuffers; see #885 (comment).

Contributor:

The CI seems fine with the flatbuffers version bump... I'll try to get it landed.

Collaborator Author:

Yes, we tried wrapping the context binary inside the flatbuffer. If the version bump does not work, we could try moving it outside of the flatbuffer.

Contributor:

Actually, the CI is really bad when I bump the flatbuffers version... We tried to bump the flatbuffers version before and it didn't go well. I guess we can try again, but no guarantees.

@winskuo-quic (Collaborator, Author):

If alignment is required, can we do something similar to

data: [ubyte] (force_align: 16); // @executorch-delegate-alignment

to force alignment?

I believe force alignment is not required for our case.

@winskuo-quic force-pushed the dev1/winskuo/llama3_2_hybrid_mode branch from 645fb63 to 43dfe13 on December 10, 2024 09:39
@cccclai (Contributor) commented Dec 11, 2024

I was able to repro the following numbers.
For input=16:

[INFO] [Qnn ExecuTorch]: Use cached delegate handle for current method: kv_forward
PyTorchObserver {"prompt_tokens":16,"generated_tokens":495,"model_load_start_ms":1733884854462,"model_load_end_ms":1733884855367,"inference_start_ms":1733884855367,"inference_end_ms":1733884864383,"prompt_eval_end_ms":1733884855423,"first_token_ms":1733884855423,"aggregate_sampling_time_ms":352,"SCALING_FACTOR_UNITS_PER_SECOND":1000}
I 00:00:09.991087 executorch:runner.cpp:390] 	Prompt Tokens: 16    Generated Tokens: 495
I 00:00:09.991112 executorch:runner.cpp:396] 	Model Load Time:		0.905000 (seconds)
I 00:00:09.991183 executorch:runner.cpp:406] 	Total inference time:		9.016000 (seconds)		 Rate: 	54.902396 (tokens/second)
I 00:00:09.991199 executorch:runner.cpp:414] 		Prompt evaluation:	0.056000 (seconds)		 Rate: 	285.714286 (tokens/second)
I 00:00:09.991210 executorch:runner.cpp:425] 		Generated 495 tokens:	8.960000 (seconds)		 Rate: 	55.245536 (tokens/second)
I 00:00:09.991222 executorch:runner.cpp:433] 	Time to first generated token:	0.056000 (seconds)
I 00:00:09.991230 executorch:runner.cpp:440] 	Sampling time over 495 tokens:	0.352000 (seconds)
[INFO] [Qnn ExecuTorch]: Destroy Qnn backend parameters
[INFO] [Qnn ExecuTorch]: Destroy Qnn context

For input=31:

PyTorchObserver {"prompt_tokens":31,"generated_tokens":480,"model_load_start_ms":1733890181248,"model_load_end_ms":1733890182105,"inference_start_ms":1733890182105,"inference_end_ms":1733890190885,"prompt_eval_end_ms":1733890182163,"first_token_ms":1733890182163,"aggregate_sampling_time_ms":349,"SCALING_FACTOR_UNITS_PER_SECOND":1000}
I 00:00:09.711373 executorch:runner.cpp:390] 	Prompt Tokens: 31    Generated Tokens: 480
I 00:00:09.711406 executorch:runner.cpp:396] 	Model Load Time:		0.857000 (seconds)
I 00:00:09.711436 executorch:runner.cpp:406] 	Total inference time:		8.780000 (seconds)		 Rate: 	54.669704 (tokens/second)
I 00:00:09.711453 executorch:runner.cpp:414] 		Prompt evaluation:	0.058000 (seconds)		 Rate: 	534.482759 (tokens/second)
I 00:00:09.711466 executorch:runner.cpp:425] 		Generated 480 tokens:	8.722000 (seconds)		 Rate: 	55.033249 (tokens/second)
I 00:00:09.711480 executorch:runner.cpp:433] 	Time to first generated token:	0.058000 (seconds)
I 00:00:09.711491 executorch:runner.cpp:440] 	Sampling time over 480 tokens:	0.349000 (seconds)
[INFO] [Qnn ExecuTorch]: Destroy Qnn backend parameters
[INFO] [Qnn ExecuTorch]: Destroy Qnn context

but the model size is

2.1G	./hybrid_llama3_2_qnn.pte

@winskuo-quic (Collaborator, Author):

I was able to repro the following numbers.

[INFO] [Qnn ExecuTorch]: Use cached delegate handle for current method: kv_forward
PyTorchObserver {"prompt_tokens":16,"generated_tokens":495,"model_load_start_ms":1733884854462,"model_load_end_ms":1733884855367,"inference_start_ms":1733884855367,"inference_end_ms":1733884864383,"prompt_eval_end_ms":1733884855423,"first_token_ms":1733884855423,"aggregate_sampling_time_ms":352,"SCALING_FACTOR_UNITS_PER_SECOND":1000}
I 00:00:09.991087 executorch:runner.cpp:390] 	Prompt Tokens: 16    Generated Tokens: 495
I 00:00:09.991112 executorch:runner.cpp:396] 	Model Load Time:		0.905000 (seconds)
I 00:00:09.991183 executorch:runner.cpp:406] 	Total inference time:		9.016000 (seconds)		 Rate: 	54.902396 (tokens/second)
I 00:00:09.991199 executorch:runner.cpp:414] 		Prompt evaluation:	0.056000 (seconds)		 Rate: 	285.714286 (tokens/second)
I 00:00:09.991210 executorch:runner.cpp:425] 		Generated 495 tokens:	8.960000 (seconds)		 Rate: 	55.245536 (tokens/second)
I 00:00:09.991222 executorch:runner.cpp:433] 	Time to first generated token:	0.056000 (seconds)
I 00:00:09.991230 executorch:runner.cpp:440] 	Sampling time over 495 tokens:	0.352000 (seconds)
[INFO] [Qnn ExecuTorch]: Destroy Qnn backend parameters
[INFO] [Qnn ExecuTorch]: Destroy Qnn context

but the model size is

2.1G	./hybrid_llama3_2_qnn.pte

Thank you for providing the numbers.
Could you also share the QNN version you are currently using?

Regarding the model size, I believe @haowhsu-quic mentioned this at the top of the PR. While the context binary is indeed smaller, the PTE size is not. It would be appreciated if you could review the patch above. If you think it works, we can create another PR for it.

@cccclai (Contributor) commented Dec 11, 2024

I'm using QNN 2.26. In the meantime, I think I'll apply the patch above.

@cccclai (Contributor) commented Dec 11, 2024

I actually tested the patch and it looks correct to me, though I'm still not sure why the .pte size doubles...

cccclai added a commit to cccclai/executorch-1 that referenced this pull request Dec 11, 2024
Summary: Reported by pytorch#7175 that delegates are not deduplicated when they're exactly the same

Differential Revision: D67067997
@facebook-github-bot (Contributor):

@cccclai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@cccclai (Contributor) commented Jan 2, 2025

The CI looks clean now! A separate question: does the quantized_linear in QNN use the quantized weight from prefill or from decode?

@winskuo-quic (Collaborator, Author):

The CI looks clean now! A separate question, does the quantized_linear in qnn use the quantized weight from prefill or decode?

If you are referring to the weights for the linear op, I believe prefill and decode should have the same quantized weights, which means it should be using weight sharing.

@cccclai (Contributor) commented Jan 2, 2025

If you are referring to the weights for linear op, I believe that the prefill and decode should have the same quantized weights, which means it should be using weight sharing.

Yeah, it's for the linear op. However, prefill and decode are calibrated differently, so the quantized weights may not be the same?

@winskuo-quic (Collaborator, Author):

If you are referring to the weights for linear op, I believe that the prefill and decode should have the same quantized weights, which means it should be using weight sharing.

Yeah it's for linear op. However, prefill and decode are calibrated differently, meaning the quantized weights may not be the same?

I believe the weights' scale/offset should be the same, which is why we can achieve almost half the .pte size when enabling weight sharing. The fp values of the weights are always the same for both modes, so we should always get the same scale/offset for both modes. On the other hand, the activations' scale/offset will be different due to the different inputs seen during calibration.
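
To make that concrete, here is a minimal min/max PTQ sketch (an illustration, not the actual QNN quantizer): identical fp weights always yield identical scale/offset, while activations calibrated on different data generally do not.

    import torch

    def minmax_qparams(t: torch.Tensor, qmin: int = -8, qmax: int = 7):
        """Affine per-tensor scale/zero-point from the observed min/max (4-bit range as in 16a4w weights)."""
        t_min, t_max = float(t.min()), float(t.max())
        scale = (t_max - t_min) / (qmax - qmin)
        zero_point = qmin - round(t_min / scale)
        return scale, zero_point

    weight = torch.randn(128, 128)
    # Same fp weight tensor observed by the prefill and decode graphs -> same qparams.
    assert minmax_qparams(weight) == minmax_qparams(weight.clone())

    # Activations depend on the calibration data, so prefill vs. decode can differ.
    act_prefill = torch.randn(32, 128)   # e.g. a full prompt chunk (illustrative)
    act_decode = torch.randn(1, 128)     # e.g. a single decode step (illustrative)
    print(minmax_qparams(act_prefill), minmax_qparams(act_decode))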

@winskuo-quic (Collaborator, Author):

Just another separate question: we are still unable to achieve great accuracy for hybrid mode by applying PTQ. We are looking forward to enabling QAT on static llama, and we would like to know if there are currently any barriers or difficulties that you are facing. Thanks.

Current results that we are getting:
prompt: "what is 1+1"
response: "<|begin_of_text|><|begin_of_text|><|start_header_id|>user<|end_header_id|>

what is 1+1<|eot_id|><|start_header_id|>assistant<|end_header_id|>

There is one or one!<|eot_id|>"

@cccclai (Contributor) commented Jan 3, 2025

Just another separate question: we are still unable to achieve great accuracy for hybrid mode by applying PTQ. We are looking forward to enabling QAT on static llama, and we would like to know if there are currently any barriers or difficulties that you are facing. Thanks.

Yeah, we've done some work on QAT. By eyeballing it, the result is reasonable. Currently we QAT the weights and calibrate prefill then decode separately; however, the model size increases and we're trying to figure out if it's due to the quantized weight difference.

@winskuo-quic (Collaborator, Author):

Just another separate question: we are still unable to achieve great accuracy for hybrid mode by applying PTQ. We are looking forward to enabling QAT on static llama, and we would like to know if there are currently any barriers or difficulties that you are facing. Thanks.

Yeah, we've done some work on QAT. By eyeballing it, the result is reasonable. Currently we QAT the weights and calibrate prefill then decode separately; however, the model size increases and we're trying to figure out if it's due to the quantized weight difference.

I would like to confirm whether you can reproduce the model size reduction with the patch @haowhsu-quic shared in this thread. If applied correctly on this PR, we should see a size reduction for the 16a4w 1B hybrid mode, from approximately 2.2 GB to 1.1 GB.

Additionally, I would like to confirm that QAT is performed only once and that the fp weights (before calibration) for prefill and KV mode are the same. If the fp weights are the same, the calibration process should generate the exact same scale/offset, leading to a significant size reduction through weight sharing.

Thanks.

@cccclai (Contributor) commented Jan 3, 2025

I would like to confirm whether you can reproduce the model size reduction with the patch @haowhsu-quic shared in this thread. If applied correctly on this PR, we should see a size reduction for the 16a4w 1B hybrid mode, from approximately 2.2 GB to 1.1 GB.

With this PR, yes, I observe the smaller size, down to 1.1 GB. But after we apply QAT + PTQ, the size seems to be larger, up to 1.6 GB.

@cccclai (Contributor) commented Jan 3, 2025

Additionally, I would like to confirm that QAT is performed only once and that the fp weights (before calibration) for prefill and KV mode are the same. If the fp weights are the same, the calibration process should generate the exact same scale/offset, leading to a significant size reduction through weight sharing.

The fp weights will be the same; however, they will be calibrated differently (calibrate prefill and calibrate decode), meaning the quantization parameters (the scale and zero point) might be slightly different.

@winskuo-quic (Collaborator, Author) commented Jan 3, 2025

Additionally, I would like to confirm that QAT is performed only once and that the fp weights (before calibration) for prefill and KV mode are the same. If the fp weights are the same, the calibration process should generate the exact same scale/offset, leading to a significant size reduction through weight sharing.

The fp weights will be the same; however, they will be calibrated differently (calibrate prefill and calibrate decode), meaning the quantization parameters (the scale and zero point) might be slightly different.

I think by calibrating the data with only PTQ, we do get the same scale/offset for the weights, which aligns with the size reduction from 2.2 GB to 1.1 GB.
Just to double-check: since the fp weights are the same, does that mean QAT is only performed once in hybrid mode?

@cccclai (Contributor) commented Jan 3, 2025

Just to double-check: since the fp weights are the same, does that mean QAT is only performed once in hybrid mode?

QAT is only performed once, and PTQ is performed twice, once with the prefill model and once with the decode model.
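
For reference, a rough sketch of that flow with the PT2E quantization APIs (the graph, quantizer, and batch names are placeholders, not the exact code in this PR): the fp weights are shared by both graphs, so their weight qparams match, while each graph is calibrated on its own inputs.

    import torch
    from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e

    def ptq_calibrate_and_convert(graph_module, quantizer, calibration_batches):
        """One PTQ pass: insert observers, run representative inputs, then convert."""
        prepared = prepare_pt2e(graph_module, quantizer)
        with torch.no_grad():
            for batch in calibration_batches:
                prepared(*batch)  # observers record activation ranges for this graph
        return convert_pt2e(prepared)

    # Hybrid mode (sketch, assumed names): QAT/weight training happens once beforehand,
    # then prefill and decode graphs are calibrated and converted separately:
    # quantized_prefill = ptq_calibrate_and_convert(prefill_graph, quantizer, prefill_batches)
    # quantized_decode  = ptq_calibrate_and_convert(decode_graph, quantizer, decode_batches)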

@facebook-github-bot merged commit 3c83553 into pytorch:main on Jan 3, 2025
45 of 46 checks passed
@cccclai (Contributor) commented Jan 5, 2025

How long does it currently take to export a model? I have been waiting for 2 hours...

@winskuo-quic (Collaborator, Author):

How long does it currently take to export a model? I have been waiting for 2 hours...

In my case, with 512 KV and 32 prefill, I think it took about 3 hours. The time required to export the model can vary depending on the prefill/KV sequence lengths provided.

@winskuo-quic (Collaborator, Author):

Just to double-check: since the fp weights are the same, does that mean QAT is only performed once in hybrid mode?

QAT is only performed once, and PTQ is performed twice, once with the prefill model and once with the decode model.

For the issue where some weights have a slight difference, would you mind providing a draft PR so we can try to reproduce the issue and see if it can be resolved? We can use dummy inputs to test the scale/offset.

@cccclai (Contributor) commented Jan 6, 2025

Yeah, let me see how to repro.

Additionally, it seems like the model is bigger with the new commit. Without applying #7281, the model bumps to 4 GB; after applying it, it's 2 GB.

@winskuo-quic (Collaborator, Author):

Yeah, let me see how to repro.

Additionally, it seems like the model is bigger with the new commit. Without applying #7281, the model bumps to 4 GB; after applying it, it's 2 GB.

Thanks!
As for the issue you mentioned where the model size bumps to 2 GB, may I know if you are running QAT?
I tried exporting this model on mainline (commit 507c767) plus the deduplicate-delegate-cache patch and can still get a 1.1 GB pte size for hybrid mode.
This is the command I used to export the model, and the image below shows the file size I am getting:
python examples/qualcomm/oss_scripts/llama3_2/llama.py -b build-android -m SM8650 --checkpoint ../llama/llama/Llama3.2-1B-Instruct/consolidated.00.pth --params ../llama/llama/Llama3.2-1B-Instruct/params.json --tokenizer_model ../llama/llama/Llama3.2-1B-Instruct/tokenizer.model --prompt "what is 1+1" --temperature 0 --model_size 1B --kv_seq_len 128 --prefill_seq_len 32 --model_mode hybrid --ptq 16a4w -a check_pte_size --compile_only
[screenshot: generated .pte file size]

@cccclai (Contributor) commented Jan 7, 2025

I'm currently getting the following data with #7281

PyTorchObserver {"prompt_tokens":31,"generated_tokens":96,"model_load_start_ms":1736230916307,"model_load_end_ms":1736230917194,"inference_start_ms":1736230917194,"inference_end_ms":1736230919028,"prompt_eval_end_ms":1736230917250,"first_token_ms":1736230917250,"aggregate_sampling_time_ms":71,"SCALING_FACTOR_UNITS_PER_SECOND":1000}
I 00:00:02.721382 executorch:runner.cpp:414] 	Prompt Tokens: 31    Generated Tokens: 96
I 00:00:02.721396 executorch:runner.cpp:420] 	Model Load Time:		0.887000 (seconds)
I 00:00:02.721407 executorch:runner.cpp:430] 	Total inference time:		1.834000 (seconds)		 Rate: 	52.344602 (tokens/second)
I 00:00:02.721417 executorch:runner.cpp:438] 		Prompt evaluation:	0.056000 (seconds)		 Rate: 	553.571429 (tokens/second)
I 00:00:02.721427 executorch:runner.cpp:449] 		Generated 96 tokens:	1.778000 (seconds)		 Rate: 	53.993251 (tokens/second)
I 00:00:02.721437 executorch:runner.cpp:457] 	Time to first generated token:	0.056000 (seconds)
I 00:00:02.721443 executorch:runner.cpp:464] 	Sampling time over 96 tokens:	0.071000 (seconds)
[INFO] [Qnn ExecuTorch]: Destroy Qnn backend parameters
[INFO] [Qnn ExecuTorch]: Destroy Qnn context
[INFO] [Qnn ExecuTorch]: Destroy Qnn device
[INFO] [Qnn ExecuTorch]: Destroy Qnn backend
2.2G    ./llama3_2_qnn/llama3_2_sm8650_seq_512_qnn_uncalibrated.pte

@cccclai (Contributor) commented Jan 7, 2025

This is the only change we made for calibration:

--- a/executorch/examples/qualcomm/oss_scripts/llama3_2/llama.py
+++ b/executorch/examples/qualcomm/oss_scripts/llama3_2/llama.py
@@ -73,37 +73,46 @@
     max_seq_len=512,
 ):
     sp_model = get_tokenizer(tokenizer_model_path)
-    _, atten_mask, _, k_caches, v_caches = example_inputs
 
     # TODO: change criteria & support batch inputs if necessary
-    pos = torch.tensor(0, dtype=torch.int32)
     max_cache_len = max_seq_len - 1
-    token_list = sp_model.encode(user_prompts, bos=True, eos=False)
-
-    with torch.no_grad():
-        while token_list[-1] != sp_model.eos_id and pos < max_cache_len:
-            logits, new_k_caches, new_v_caches = module(
-                torch.full((1, 1), token_list[pos], dtype=torch.int32),
-                atten_mask,
-                torch.full((1, 1), pos),
-                *k_caches,
-                *v_caches,
-            )
-            k_caches = [
-                torch.cat([k_cache[:, :, 1:], new_k_caches[i]], dim=-1)
-                for i, k_cache in enumerate(k_caches)
-            ]
-            v_caches = [
-                torch.cat([v_cache[:, 1:, :], new_v_caches[i]], dim=1)
-                for i, v_cache in enumerate(v_caches)
-            ]
+    # token_list = sp_model.encode(user_prompts, bos=True, eos=False)
+	
+    user_token_list = [
+        # what is the capital of the united states
+        [128000, 128006, 882, 128007, 271, 12840, 374, 279, 6864, 315, 279, 29292, 5415, 128009, 128006, 78191, 128007, 271],
+        # what is 1 + 1
+        [128000, 128006, 882, 128007, 271, 12840, 374, 220, 16, 489, 220, 16, 128009, 128006, 78191, 128007, 271],
+        # what is the meaning of life
+        [128000, 128006, 882, 128007, 271, 12840, 374, 279, 7438, 315, 2324, 128009, 128006, 78191, 128007, 271],
+    ]
+    for token_list in user_token_list:
+        _, atten_mask, _, k_caches, v_caches = copy.deepcopy(example_inputs)
+        pos = torch.tensor(0, dtype=torch.int32)
+        with torch.no_grad():
+            while token_list[-1] != sp_model.eos_id and pos < max_cache_len:
+                logits, new_k_caches, new_v_caches = module(
+                    torch.full((1, 1), token_list[pos], dtype=torch.int32),
+                    atten_mask,
+                    torch.full((1, 1), pos),
+                    *k_caches,
+                    *v_caches,
+                )
+                k_caches = [
+                    torch.cat([k_cache[:, :, 1:], new_k_caches[i]], dim=-1)
+                    for i, k_cache in enumerate(k_caches)
+                ]
+                v_caches = [
+                    torch.cat([v_cache[:, 1:, :], new_v_caches[i]], dim=1)
+                    for i, v_cache in enumerate(v_caches)
+                ]
 
-            pos += 1
-            atten_mask[0][-pos - 1] = 0
-            if pos >= len(token_list):
-                token_list.append(torch.argmax(logits[:, -1], dim=-1).item())
+                pos += 1
+                atten_mask[0][-pos - 1] = 0
+                if pos >= len(token_list):
+                    token_list.append(torch.argmax(logits[:, -1], dim=-1).item())
 
-    print(f"calibration data:\n{sp_model.decode(token_list)}")
+            logging.info(f"calibration data:\n{sp_model.decode(token_list)}")
 
 
 def _prefill_calibrate(
@@ -114,32 +123,43 @@
     max_seq_len=512,
 ):
     sp_model = get_tokenizer(tokenizer_model_path)
-    _, atten_mask = example_inputs
     max_cache_len = max_seq_len - 1
 
     # TODO: change criteria & support batch inputs if necessary
-    token_list = sp_model.encode(user_prompts, bos=True, eos=False)
-    token_list = torch.tensor(token_list)[:max_cache_len].reshape(1, -1)
-    last_prompt_pos = token_list.numel()
-    if last_prompt_pos < max_cache_len:
-        token_list = torch.cat(
-            [
-                token_list,
-                torch.zeros((1, max_cache_len - last_prompt_pos), dtype=torch.int32),
-            ],
-            dim=1,
-        )
-    else:
-        token_list = token_list[:, :max_cache_len]
+    # token_list = sp_model.encode(user_prompts, bos=True, eos=False)
+
+    user_token_list = [
+        # what is the capital of the united states
+        [128000, 128006, 882, 128007, 271, 12840, 374, 279, 6864, 315, 279, 29292, 5415, 128009, 128006, 78191, 128007, 271],
+        # what is 1 + 1
+        [128000, 128006, 882, 128007, 271, 12840, 374, 220, 16, 489, 220, 16, 128009, 128006, 78191, 128007, 271],
+        # what is the meaning of life
+        [128000, 128006, 882, 128007, 271, 12840, 374, 279, 7438, 315, 2324, 128009, 128006, 78191, 128007, 271],
+    ]
+
+    for token_list in user_token_list:
+        _, atten_mask = copy.deepcopy(example_inputs)
+        token_list = torch.tensor(token_list)[:max_cache_len].reshape(1, -1)
+        last_prompt_pos = token_list.numel()
+        if last_prompt_pos < max_cache_len:
+            token_list = torch.cat(
+                [
+                    token_list,
+                    torch.zeros((1, max_cache_len - last_prompt_pos), dtype=torch.int32),
+                ],
+                dim=1,
+            )
+        else:
+            token_list = token_list[:, :max_cache_len]
 
-    with torch.no_grad():
-        logits, new_k_caches, new_v_caches = module(
-            token_list,
-            atten_mask,
-        )
-        predict = [torch.argmax(logits[:, last_prompt_pos - 1], dim=-1).item()]
+        with torch.no_grad():
+            logits, new_k_caches, new_v_caches = module(
+                token_list,
+                atten_mask,
+            )
+            predict = [torch.argmax(logits[:, last_prompt_pos - 1], dim=-1).item()]
 
-    print(f"calibration data:\n{sp_model.decode(predict)}")
+        logging.info(f"calibration data:\n{sp_model.decode(predict)}")
 
 
 def calibrate(
@@ -249,7 +269,17 @@
             max_seq_len=self.llama_meta["get_max_seq_len"],
         )
 
-        self.llama_model = convert_pt2e(fx_graph_module)
+        fx_graph_module = convert_pt2e(fx_graph_module)
+
+        logging.info("Evaluating the converted model...")
+        calibrate(
+            self.get_example_inputs(self.llama_meta["get_use_kv_cache"]),
+            args.prompt,
+            fx_graph_module,
+            tokenizer_model_path=args.tokenizer_model,
+            max_seq_len=self.llama_meta["get_max_seq_len"],
+        )
+        self.llama_model = fx_graph_module
 
     def lowering_modules(
         self,

@winskuo-quic (Collaborator, Author):

I'm currently getting the following data with #7281

PyTorchObserver {"prompt_tokens":31,"generated_tokens":96,"model_load_start_ms":1736230916307,"model_load_end_ms":1736230917194,"inference_start_ms":1736230917194,"inference_end_ms":1736230919028,"prompt_eval_end_ms":1736230917250,"first_token_ms":1736230917250,"aggregate_sampling_time_ms":71,"SCALING_FACTOR_UNITS_PER_SECOND":1000}
I 00:00:02.721382 executorch:runner.cpp:414] 	Prompt Tokens: 31    Generated Tokens: 96
I 00:00:02.721396 executorch:runner.cpp:420] 	Model Load Time:		0.887000 (seconds)
I 00:00:02.721407 executorch:runner.cpp:430] 	Total inference time:		1.834000 (seconds)		 Rate: 	52.344602 (tokens/second)
I 00:00:02.721417 executorch:runner.cpp:438] 		Prompt evaluation:	0.056000 (seconds)		 Rate: 	553.571429 (tokens/second)
I 00:00:02.721427 executorch:runner.cpp:449] 		Generated 96 tokens:	1.778000 (seconds)		 Rate: 	53.993251 (tokens/second)
I 00:00:02.721437 executorch:runner.cpp:457] 	Time to first generated token:	0.056000 (seconds)
I 00:00:02.721443 executorch:runner.cpp:464] 	Sampling time over 96 tokens:	0.071000 (seconds)
[INFO] [Qnn ExecuTorch]: Destroy Qnn backend parameters
[INFO] [Qnn ExecuTorch]: Destroy Qnn context
[INFO] [Qnn ExecuTorch]: Destroy Qnn device
[INFO] [Qnn ExecuTorch]: Destroy Qnn backend
2.2G    ./llama3_2_qnn/llama3_2_sm8650_seq_512_qnn_uncalibrated.pte

I have applied the calibration patch you provided above together with #7281.
However, I am getting a better inference speed, and the generated model is only 1.1 GB.
Would it help if I share the .pte file I have generated, so we can check whether it is the device that is causing the performance drop?

[screenshots: inference log and generated .pte file size]

@cccclai (Contributor) commented Jan 8, 2025

@winskuo-quic yes, that will be very helpful.

In the meantime, it looks like the decoding speed of the hybrid model is a bit slower than the previous decode speed from KV mode. Do we know the reason?

@winskuo-quic (Collaborator, Author) commented Jan 9, 2025

@winskuo-quic yes, that will be very helpful.

In the meantime, it looks like the decoding speed of the hybrid model is a bit slower than the previous decode speed from KV mode. Do we know the reason?

Great!
I will send it to you later.
As for the slight speed drop, I believe there are some small fluctuations between inferences; also, I am using QNN 2.26 when exporting the model, while QNN 2.28 should give a slightly better score.

@winskuo-quic (Collaborator, Author) commented Jan 9, 2025

Hi @cccclai,
I have created a separate draft PR that applies all the patches, so it will be easier for us to stay on the same page.
I have also switched from QNN 2.26 to QNN 2.28 so it is easier for us to compare inference speed. Thanks.
#7569

cccclai added commits to cccclai/executorch-1 that referenced this pull request on Jan 15, 2025
Summary: Reported by pytorch#7175 that delegates are not deduplicated when they're exactly the same
Reviewed By: tarun292
Differential Revision: D67067997
Labels: CLA Signed, partner: qualcomm, topic: not user facing