From d121b874985ae24830320c4027eb8bc1ddb9c83d Mon Sep 17 00:00:00 2001 From: cccclai Date: Tue, 15 Jul 2025 11:27:07 -0700 Subject: [PATCH] Revert "Qualcomm AI Engine Direct - GA Static QWEN2.5 0.5B (#12054)" This reverts commit fe3062a9906ef5329877ea0b0e389286e0c0ae03. --- backends/qualcomm/_passes/layout_transform.py | 2 +- backends/qualcomm/quantizer/annotators.py | 6 +- .../qualcomm/quantizer/custom_annotation.py | 11 +- backends/qualcomm/scripts/build.sh | 8 - backends/qualcomm/tests/test_qnn_delegate.py | 63 +----- examples/qualcomm/CMakeLists.txt | 4 +- .../qualcomm/oss_scripts/llama/CMakeLists.txt | 9 - examples/qualcomm/oss_scripts/llama/README.md | 3 +- .../llama/hf_converter/convert_config.py | 45 ---- examples/qualcomm/oss_scripts/llama/llama.py | 202 ++++++------------ .../oss_scripts/llama/model/static_llama.py | 52 +---- .../oss_scripts/llama/qnn_llama_runner.cpp | 19 +- .../oss_scripts/llama/runner/runner.cpp | 80 ++++--- .../oss_scripts/llama/runner/runner.h | 8 +- 14 files changed, 139 insertions(+), 373 deletions(-) delete mode 100644 examples/qualcomm/oss_scripts/llama/hf_converter/convert_config.py diff --git a/backends/qualcomm/_passes/layout_transform.py b/backends/qualcomm/_passes/layout_transform.py index 7cf7df33a14..0c6b3152561 100644 --- a/backends/qualcomm/_passes/layout_transform.py +++ b/backends/qualcomm/_passes/layout_transform.py @@ -101,8 +101,8 @@ class LayoutTransform(ExportPass): exir_ops.edge.aten.pow.Tensor_Scalar, exir_ops.edge.aten.prelu.default, exir_ops.edge.aten.repeat.default, - exir_ops.edge.aten.relu.default, exir_ops.edge.aten.round.default, + exir_ops.edge.aten.relu.default, exir_ops.edge.aten.sigmoid.default, exir_ops.edge.aten.split_with_sizes.default, exir_ops.edge.aten.split_with_sizes_copy.default, diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py index 0f064b627d6..cc7e0054ebe 100644 --- a/backends/qualcomm/quantizer/annotators.py +++ b/backends/qualcomm/quantizer/annotators.py @@ -275,9 +275,7 @@ def annotate_masked_fill(node: Node, quantization_config: QuantizationConfig) -> ) -@register_annotator( - [torch.ops.aten.mul, torch.ops.aten.mul.Tensor, torch.ops.aten.mul_.Tensor] -) +@register_annotator([torch.ops.aten.mul, torch.ops.aten.mul.Tensor]) def annotate_mul(node: Node, quantization_config: QuantizationConfig) -> None: annotate_binary(node, quantization_config) @@ -1300,7 +1298,7 @@ def annotate_where(node: Node, quantization_config: QuantizationConfig) -> None: ) -@register_annotator([torch.ops.aten.zeros.default, torch.ops.aten.zeros_like.default]) +@register_annotator([torch.ops.aten.zeros.default]) def annotate_zeros(node: Node, quantization_config: QuantizationConfig) -> None: if _is_annotated([node]) or not _is_float_tensor(node): return diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py index f1a0cd95dff..057d3ea93d2 100644 --- a/backends/qualcomm/quantizer/custom_annotation.py +++ b/backends/qualcomm/quantizer/custom_annotation.py @@ -153,9 +153,7 @@ def annotate_prefill_kv_output(gm: torch.fx.GraphModule, kv_quant_attrs: dict): ) -def annotate_matmul_16a8w( # noqa: C901 - gm: torch.fx.GraphModule, annotate_conv=True -) -> None: +def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None: # noqa: C901 """ This function is specific for matmul op 16a8w. For k, we will tag such as the below, and @@ -319,10 +317,9 @@ def annotate_matmul_input1(node: Node): # The arguments of cat op: (the past kv cache, the new kv cache) node = node.args[0][1] elif node.target == torch.ops.aten.conv2d.default: - if annotate_conv: - annotate_conv2d( - node, quantization_config=quantization_config_8a4w_per_channel - ) + annotate_conv2d( + node, quantization_config=quantization_config_8a4w_per_channel + ) break elif node.target in [torch.ops.aten.add.Tensor, torch.ops.aten.sub.Tensor]: break diff --git a/backends/qualcomm/scripts/build.sh b/backends/qualcomm/scripts/build.sh index 8c0f01feac6..8099ecb3de8 100755 --- a/backends/qualcomm/scripts/build.sh +++ b/backends/qualcomm/scripts/build.sh @@ -85,7 +85,6 @@ if [ "$BUILD_AARCH64" = true ]; then -DEXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR=ON \ -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ -DEXECUTORCH_ENABLE_EVENT_TRACER=ON \ - -DEXECUTORCH_ENABLE_LOGGING=ON \ -DQNN_SDK_ROOT=$QNN_SDK_ROOT \ -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK_ROOT/build/cmake/android.toolchain.cmake \ -DANDROID_ABI='arm64-v8a' \ @@ -105,9 +104,6 @@ if [ "$BUILD_AARCH64" = true ]; then -DANDROID_ABI='arm64-v8a' \ -DANDROID_PLATFORM=android-30 \ -DCMAKE_PREFIX_PATH=$CMAKE_PREFIX_PATH \ - -DSUPPORT_REGEX_LOOKAHEAD=ON \ - -DBUILD_TESTING=OFF \ - -DEXECUTORCH_ENABLE_LOGGING=ON \ -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \ -DCMAKE_FIND_ROOT_PATH_MODE_PACKAGE=BOTH \ -DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \ @@ -138,7 +134,6 @@ if [ "$BUILD_X86_64" = true ]; then -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \ -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ -DEXECUTORCH_ENABLE_EVENT_TRACER=ON \ - -DEXECUTORCH_ENABLE_LOGGING=ON \ -DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \ -S $PRJ_ROOT \ -B $BUILD_ROOT \ @@ -162,9 +157,6 @@ if [ "$BUILD_X86_64" = true ]; then -DCMAKE_PREFIX_PATH=$CMAKE_PREFIX_PATH \ -DCMAKE_FIND_ROOT_PATH_MODE_PACKAGE=BOTH \ -DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \ - -DSUPPORT_REGEX_LOOKAHEAD=ON \ - -DBUILD_TESTING=OFF \ - -DEXECUTORCH_ENABLE_LOGGING=ON \ -B$EXAMPLE_ROOT cmake --build $EXAMPLE_ROOT -j$BUILD_JOB_NUMBER diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index c3d2745d7b2..c61173ad852 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -3999,7 +3999,7 @@ def test_llama3_2_1b(self): "16a4w", "--temperature", "0", - "--decoder_model", + "--llama_model", "llama3_2", "--model_mode", "hybrid", @@ -4079,7 +4079,7 @@ def test_llama_stories_110m(self): "16a4w", "--temperature", "0", - "--decoder_model", + "--llama_model", "stories110m", "--model_mode", "hybrid", @@ -4121,65 +4121,6 @@ def test_llama_stories_110m(self): if not self.compile_only and not self.enable_x86_64: self.assertGreaterEqual(msg["inference_speed"], 220) # Lanai - def test_qwen2_5(self): - if not self.required_envs(): - self.skipTest("missing required envs") - - prompt = "My favourite condiment is " - cmds = [ - "python", - f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py", - "--artifact", - self.artifact_dir, - "--build_folder", - self.build_folder, - "--model", - self.model, - "--ip", - self.ip, - "--port", - str(self.port), - "--prompt", - f"{prompt}", - "--ptq", - "16a8w", - "--decoder_model", - "qwen2_5", - "--model_mode", - "hybrid", - "--prefill_ar_len", - "32", - "--max_seq_len", - "128", - ] - if self.compile_only: - cmds.extend(["--compile_only"]) - elif self.device: - cmds.extend(["--device", self.device]) - if self.host: - cmds.extend(["--host", self.host]) - elif self.enable_x86_64: - cmds.extend(["--enable_x86_64"]) - if self.pre_gen_pte: - cmds.extend(["--pre_gen_pte", self.pre_gen_pte]) - - # Accuracy is bad for now. Just check user's prompt is returned. - golden_start_with = "My favourite condiment is " - p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) - with Listener((self.ip, self.port)) as listener: - conn = listener.accept() - p.communicate() - msg = json.loads(conn.recv()) - if "Error" in msg: - self.fail(msg["Error"]) - else: - model_out = msg["result"][0] - self.assertTrue( - model_out.startswith(golden_start_with), - f"Expected Output: {golden_start_with}. Actual Output: {model_out}", - ) - self.assertGreaterEqual(msg["inference_speed"], 95) # Lanai - class TestExampleOssScript(TestQNN): def test_albert(self): diff --git a/examples/qualcomm/CMakeLists.txt b/examples/qualcomm/CMakeLists.txt index c10342bc247..69fa9a0b0d4 100644 --- a/examples/qualcomm/CMakeLists.txt +++ b/examples/qualcomm/CMakeLists.txt @@ -77,8 +77,8 @@ target_include_directories( # add tokenizers add_subdirectory( - ${EXECUTORCH_ROOT}/extension/llm/runner - ${CMAKE_CURRENT_BINARY_DIR}/../../extension/llm/runner + ${EXECUTORCH_ROOT}/extension/llm/tokenizers + ${CMAKE_CURRENT_BINARY_DIR}/../../extension/llm/tokenizers ) # build qnn_executor_runner diff --git a/examples/qualcomm/oss_scripts/llama/CMakeLists.txt b/examples/qualcomm/oss_scripts/llama/CMakeLists.txt index 95e82ba271e..dadf51bf298 100644 --- a/examples/qualcomm/oss_scripts/llama/CMakeLists.txt +++ b/examples/qualcomm/oss_scripts/llama/CMakeLists.txt @@ -4,7 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - # model sharding with custom op set(CUSTOM_OP_SRCS_FILE "${EXECUTORCH_SOURCE_DIR}/extension/llm/custom_ops/op_fallback.cpp" @@ -64,22 +63,14 @@ target_link_libraries( executorch_core extension_data_loader extension_flat_tensor - extension_llm_runner extension_module extension_tensor - tokenizers gflags custom_ops quantized_ops_lib quantized_kernels tokenizers ) - -target_include_directories( - qnn_llama_runner - PUBLIC ${EXECUTORCH_ROOT}/extension/llm/tokenizers/include -) - target_compile_options(qnn_llama_runner PUBLIC ${_common_compile_options}) set_target_properties( qnn_llama_runner PROPERTIES LINK_FLAGS "-Wl,-rpath='$ORIGIN'" diff --git a/examples/qualcomm/oss_scripts/llama/README.md b/examples/qualcomm/oss_scripts/llama/README.md index 6d10a935863..309de56cd89 100644 --- a/examples/qualcomm/oss_scripts/llama/README.md +++ b/examples/qualcomm/oss_scripts/llama/README.md @@ -1,11 +1,10 @@ # Summary ## Overview -This file provides you the instructions to run LLM Decoder model with different parameters via Qualcomm HTP backend. We currently support the following models: +This file provides you the instructions to run LLAMA model with different parameters via Qualcomm HTP backend. We currently support the following models: 1. LLAMA2 Stories 110M 2. LLAMA3.2 1B 3. LLAMA3.2 3B - 4. QWEN2.5 0.5B We offer the following modes to execute the model: diff --git a/examples/qualcomm/oss_scripts/llama/hf_converter/convert_config.py b/examples/qualcomm/oss_scripts/llama/hf_converter/convert_config.py deleted file mode 100644 index 250ad1ce0d0..00000000000 --- a/examples/qualcomm/oss_scripts/llama/hf_converter/convert_config.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright (c) Qualcomm Innovation Center, Inc. -# All rights reserved -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - - -def convert_configs(config): - # HF config keys are different from Llama configs. - # Convert the config keys to align with Llama. - if hasattr(config, "hidden_size"): - config.dim = config.hidden_size - delattr(config, "hidden_size") - - if hasattr(config, "num_attention_heads"): - config.n_heads = config.num_attention_heads - delattr(config, "num_attention_heads") - - if hasattr(config, "num_key_value_heads"): - config.n_kv_heads = config.num_key_value_heads - delattr(config, "num_key_value_heads") - - if hasattr(config, "rms_norm_eps"): - config.norm_eps = config.rms_norm_eps - delattr(config, "rms_norm_eps") - - if hasattr(config, "rope_theta"): - config.rope_freq_base = config.rope_theta - delattr(config, "rope_theta") - - if hasattr(config, "num_hidden_layers"): - config.n_layers = config.num_hidden_layers - delattr(config, "num_hidden_layers") - - if hasattr(config, "intermediate_size"): - config.hidden_dim = config.intermediate_size - delattr(config, "intermediate_size") - - if hasattr(config, "rope_scaling"): - config.use_scaled_rope = config.rope_scaling - # Use default value of precompute_freq_cis - if not hasattr(config, "rope_scale_factor"): - config.rope_scale_factor = 4 - - return config diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py index 1ee7b5ba9fe..db533986119 100755 --- a/examples/qualcomm/oss_scripts/llama/llama.py +++ b/examples/qualcomm/oss_scripts/llama/llama.py @@ -30,6 +30,7 @@ ) from executorch.backends.qualcomm.builders.utils import is_graph_output +from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner from executorch.backends.qualcomm.quantizer.custom_annotation import ( annotate_linear_16a8w_in_affine_layer, annotate_matmul_16a8w, @@ -55,16 +56,9 @@ ) from executorch.devtools.backend_debug import print_delegation_info - -from executorch.examples.models.llama.hf_download import ( - download_and_convert_hf_checkpoint, -) from executorch.examples.models.llama.source_transformation.quantize import ( get_quant_embedding_transform, ) -from executorch.examples.qualcomm.oss_scripts.llama.hf_converter.convert_config import ( - convert_configs, -) from executorch.examples.qualcomm.oss_scripts.llama.model.static_llama import ( LlamaModel, ModelArgs, @@ -75,26 +69,27 @@ setup_common_args_and_variables, SimpleADB, ) +from executorch.exir import EdgeProgramManager +from executorch.exir.backend.backend_api import ( + MethodProgramsPartitionerSpec, + to_backend, +) from executorch.exir.capture._config import ExecutorchBackendConfig from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass from executorch.extension.llm.custom_ops import model_sharding from executorch.extension.llm.export.builder import DType from pytorch_tokenizers import get_tokenizer, TiktokenTokenizer -from pytorch_tokenizers.hf_tokenizer import HuggingFaceTokenizer from pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizer from torchao.quantization.pt2e import MinMaxObserver from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e -from transformers import AutoConfig, AutoTokenizer sys.setrecursionlimit(4096) FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) logging.getLogger().setLevel(logging.INFO) -HUGGING_FACE_REPO_IDS = {"qwen2_5": "Qwen/Qwen2.5-0.5B"} - def next_power_of_two(n): if n == 0: @@ -154,7 +149,7 @@ def _kv_calibrate( token_list = [] # Llama2 tokenizer has no special tokens - if isinstance(tokenizer, (SentencePieceTokenizer, HuggingFaceTokenizer)): + if isinstance(tokenizer, SentencePieceTokenizer): token_list = tokenizer.encode(user_prompts, bos=True, eos=False) elif isinstance(tokenizer, TiktokenTokenizer): token_list = tokenizer.encode( @@ -162,6 +157,7 @@ def _kv_calibrate( ) else: raise RuntimeError("Unkown tokenizer") + pos = len(token_list) if len(token_list) < ar_len else ar_len dtype = torch.int64 if use_i64_token else torch.int32 @@ -225,7 +221,7 @@ def _prefill_calibrate( token_list = [] # Llama2 tokenizer has no special tokens - if isinstance(tokenizer, (SentencePieceTokenizer, HuggingFaceTokenizer)): + if isinstance(tokenizer, SentencePieceTokenizer): token_list = tokenizer.encode(user_prompts, bos=True, eos=False) elif isinstance(tokenizer, TiktokenTokenizer): token_list = tokenizer.encode( @@ -297,14 +293,14 @@ def calibrate( class SingleLlama: - def __init__(self, decoder_model, pte_filename) -> None: + def __init__(self, llama_model, pte_filename) -> None: super().__init__() - self.decoder_model = decoder_model + self.llama_model = llama_model self.passes_job = get_capture_program_passes() self.dep_table = get_passes_dependency_for_capture_program() self.quant_attrs = None self.quant_dtype = None - self.llama_meta = self.decoder_model.get_metadata() + self.llama_meta = self.llama_model.get_metadata() self.has_quant_io = False self.pte_filename = pte_filename if self.llama_meta["get_use_kv_cache"]: @@ -315,7 +311,7 @@ def __init__(self, decoder_model, pte_filename) -> None: else: tokens, atten_mask = self.get_example_inputs(use_kv_cache=False) self.inputs = (tokens, atten_mask) - self.llama_graph_module = decoder_model + self.llama_graph_module = llama_model self.io_shape = { # logit output ( @@ -412,7 +408,6 @@ def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()): fx_graph_module = prepare_pt2e(fx_graph_module, quantizer) logging.info("Quantizing the model...") - calibrate( self.get_example_inputs(self.llama_meta["get_use_kv_cache"]), args.prompt[0], @@ -484,7 +479,7 @@ def lowering_modules( exec_prog_mgr.write_to_file(file) def get_example_inputs(self, use_kv_cache=True): - return self.decoder_model.get_example_inputs(use_kv_cache) + return self.llama_model.get_example_inputs(use_kv_cache) def get_quant_attrs(self): return self.quant_attrs @@ -494,29 +489,21 @@ def compile(args, pte_filename, tokenizer): os.makedirs(args.artifact, exist_ok=True) start_ts = time.time() - kv_config, prefill_config = None, None - if args.params: - with open(args.params) as f: - kv_config = ModelArgs(**json.load(f)) - else: - # For huggingface decoder model, we need to convert config to match the keys - model_id = HUGGING_FACE_REPO_IDS[args.decoder_model] - kv_config = AutoConfig.from_pretrained(model_id) - kv_config = convert_configs(kv_config) - - if args.decoder_model == "qwen2_5": - kv_config.attention_qkv_bias = True - - if not hasattr(kv_config, "head_dim"): - kv_config.head_dim = kv_config.dim // kv_config.n_heads - # TODO: support batch inputs if necessary - kv_config.max_batch_size = 1 - kv_config.max_seq_len = args.max_seq_len - kv_config.use_kv_cache = True - - prefill_config = copy.copy(kv_config) - prefill_config.use_kv_cache = ( - False if args.max_seq_len == args.prefill_ar_len else True + with open(args.params) as f: + kv_config = ModelArgs(**json.load(f)) + # TODO: support batch inputs if necessary + kv_config.max_batch_size = 1 + kv_config.max_seq_len = args.max_seq_len + kv_config.use_kv_cache = True + + prefill_config = copy.copy(kv_config) + prefill_config.max_seq_len = args.max_seq_len + prefill_config.use_kv_cache = ( + False if args.max_seq_len == args.prefill_ar_len else True + ) + + state_dict = torch.load( + args.checkpoint, weights_only=True, map_location="cpu", mmap=True ) llama_instance_list = [] @@ -576,46 +563,30 @@ def compile(args, pte_filename, tokenizer): else: raise RuntimeError(f"Unknown model_mode: {args.model_mode}.") - if args.checkpoint is None: # HF models - model_id = HUGGING_FACE_REPO_IDS[args.decoder_model] - if args.decoder_model == "qwen2_5": - from executorch.examples.models.qwen2_5 import ( # pyre-ignore[21] - convert_weights, - ) - - checkpoint = download_and_convert_hf_checkpoint(model_id, convert_weights) - state_dict = torch.load( - checkpoint, weights_only=True, map_location="cpu", mmap=True - ) - else: - state_dict = torch.load( - args.checkpoint, weights_only=True, map_location="cpu", mmap=True + if "model" in state_dict: + state_dict = state_dict["model"] + + # Change to HuggingFace weight to improve the performance of RoPE in HTP backend. + def permute(w, heads): + dim_0 = w.size(0) + dim_1 = w.size(1) + return ( + w.view(heads, dim_0 // heads // 2, 2, dim_1) + .transpose(1, 2) + .reshape(dim_0, dim_1) ) - if "model" in state_dict: - state_dict = state_dict["model"] - - # Change to HuggingFace weight to improve the performance of RoPE in HTP backend. - def permute(w, heads): - dim_0 = w.size(0) - dim_1 = w.size(1) - return ( - w.view(heads, dim_0 // heads // 2, 2, dim_1) - .transpose(1, 2) - .reshape(dim_0, dim_1) - ) - - n_heads = llama_instance_list[0].n_heads - n_kv_heads = llama_instance_list[0].n_kv_heads - n_layers = llama_instance_list[0].n_layers + n_heads = llama_instance_list[0].n_heads + n_kv_heads = llama_instance_list[0].n_kv_heads + n_layers = llama_instance_list[0].n_layers - for layer_i in range(n_layers): - state_dict[f"layers.{layer_i}.attention.wq.weight"] = permute( - state_dict[f"layers.{layer_i}.attention.wq.weight"], n_heads - ) - state_dict[f"layers.{layer_i}.attention.wk.weight"] = permute( - state_dict[f"layers.{layer_i}.attention.wk.weight"], n_kv_heads - ) + for layer_i in range(n_layers): + state_dict[f"layers.{layer_i}.attention.wq.weight"] = permute( + state_dict[f"layers.{layer_i}.attention.wq.weight"], n_heads + ) + state_dict[f"layers.{layer_i}.attention.wk.weight"] = permute( + state_dict[f"layers.{layer_i}.attention.wk.weight"], n_kv_heads + ) for llama_instance in llama_instance_list: llama_instance.load_state_dict( @@ -640,17 +611,18 @@ def permute(w, heads): fixed_point_type["kv_type"] = torch.uint8 if args.ptq == "8a8w": fixed_point_type["io_type"] = torch.uint8 - elif args.ptq in ("16a4w", "16a4w_block", "16a8w"): + elif args.ptq in ("16a4w", "16a4w_block"): fixed_point_type["io_type"] = torch.uint16 else: assert args.ptq in [ "8a8w", "16a4w", "16a4w_block", - "16a8w", ], f"No support for quant type {args.ptq}. Support 8a8w, 16a4w and 16a4w_block." quant_dtype = getattr(QuantDtype, f"use_{args.ptq}") + assert args.tokenizer_model is not None, "Need tokenizer model for calibration" + if args.dtype_override is not None: dtype_override = DType[args.dtype_override] for i in range(len(llama_instance_list)): @@ -674,14 +646,8 @@ def permute(w, heads): if args.ptq: start_quantize_ts = time.time() - custom_annotations = ( - # For qwen2.5, skip annotate_conv can improve result. - partial( - annotate_matmul_16a8w, - annotate_conv=args.ptq != "16a8w", - ), - ) - if args.decoder_model == "stories110m": + custom_annotations = (annotate_matmul_16a8w,) + if args.llama_model == "stories110m": custom_annotations = custom_annotations + ( annotate_linear_16a8w_in_affine_layer, ) @@ -815,7 +781,7 @@ def permute(w, heads): return quant_attrs -def inference(args, pte_filename, runtime_tokenizer_path, decoder_model_version): +def inference(args, pte_filename, runtime_tokenizer_path, pre_gen_pte=""): workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/single_llama" if args.model_mode == "kv": @@ -828,8 +794,8 @@ def inference(args, pte_filename, runtime_tokenizer_path, decoder_model_version) raise RuntimeError(f"Unknown model_mode: {args.model_mode}.") pte_path = ( - f"{args.pre_gen_pte}/{pte_filename}.pte" - if args.pre_gen_pte + f"{pre_gen_pte}/{pte_filename}.pte" + if pre_gen_pte else f"{args.artifact}/{pte_filename}.pte" ) @@ -870,7 +836,6 @@ def post_process(): [ f"export LD_LIBRARY_PATH={qnn_sdk}/lib/{target}/:{args.build_folder}/lib &&", f"./{args.build_folder}/examples/qualcomm/oss_scripts/llama/qnn_llama_runner", - f"--decoder_model_version {decoder_model_version}", f"--tokenizer_path {runtime_tokenizer_path}", f"--model_path {pte_path}", f"--seq_len {seq_len}", @@ -892,7 +857,6 @@ def post_process(): [ f"cd {workspace} &&", f"./qnn_llama_runner", - f"--decoder_model_version {decoder_model_version}", f"--tokenizer_path {os.path.basename(runtime_tokenizer_path)}", f"--model_path {pte_filename}.pte", f"--seq_len {seq_len}", @@ -956,28 +920,28 @@ def _build_parser(): parser.add_argument( "-P", "--ptq", - help="If specified, will do PTQ quantization. default is 16bits activation and 4bits weight. Support 8a8w, 16a4w and 16a4w_block, 16a8w.", + help="If specified, will do PTQ quantization. default is 16bits activation and 4bits weight. Support 8a8w, 16a4w and 16a4w_block.", type=str, ) parser.add_argument( - "--decoder_model", - choices=["stories110m", "llama3_2", "qwen2_5"], - help="The Llama model to export. Current available options are: [stories110m, llama3_2, qwen2_5]", + "--llama_model", + choices=["stories110m", "llama3_2"], + help="The Llama model to export. Current available options are: [stories110m, llama3_2]", required=True, ) parser.add_argument( "--checkpoint", help="Pass llama checkpoint.", - required=False, + required=True, type=str, ) parser.add_argument( "--params", help="Pass llama params json file.", - required=False, + required=True, type=str, ) @@ -1125,10 +1089,9 @@ def export_llama(args) -> None: else: raise RuntimeError(f"Unknown model_mode: {args.model_mode}.") - tokenizer = None - runtime_tokenizer_path, decoder_model_version = "", "" - if args.decoder_model == "stories110m": - tokenizer = get_tokenizer(args.tokenizer_model) + tokenizer = get_tokenizer(args.tokenizer_model) + runtime_tokenizer_path = "" + if args.llama_model == "stories110m": assert isinstance( tokenizer, SentencePieceTokenizer ), f"Wrong tokenizer provided for stories110m." @@ -1136,36 +1099,13 @@ def export_llama(args) -> None: args.tokenizer_bin is not None ), "Please provide tokenizer_bin for stories110m." runtime_tokenizer_path = args.tokenizer_bin - decoder_model_version = "llama2" - elif args.decoder_model == "llama3_2": - tokenizer = get_tokenizer(args.tokenizer_model) + elif args.llama_model == "llama3_2": assert isinstance( tokenizer, TiktokenTokenizer ), f"Wrong tokenizer provided for llama3_2." runtime_tokenizer_path = args.tokenizer_model - decoder_model_version = "llama3" - elif args.decoder_model == "qwen2_5": - model_id = HUGGING_FACE_REPO_IDS[args.decoder_model] - tokenizer = AutoTokenizer.from_pretrained(model_id) - runtime_tokenizer_path = tokenizer.save_pretrained(args.artifact)[-1] - tokenizer = get_tokenizer(runtime_tokenizer_path) - decoder_model_version = args.decoder_model - - with open(runtime_tokenizer_path, "r+") as file: - data = json.load(file) - # TODO: Encountered the following error during runtime, so switched behavior for now. - # Error: libc++abi: terminating due to uncaught exception of type std::runtime_error: - # Unsupported behavior 'Isolated' for Split PreTokenizer. Only 'MergedWithPrevious' is supported. - behavior = data["pre_tokenizer"]["pretokenizers"][0]["behavior"] - if behavior == "Isolated": - data["pre_tokenizer"]["pretokenizers"][0][ - "behavior" - ] = "MergedWithPrevious" - file.seek(0) - json.dump(data, file, indent=4) - file.truncate() else: - raise RuntimeError(f"Unknown decoder_model: {args.llama_model}.") + raise RuntimeError(f"Unknown llama_model: {args.llama_model}.") if args.kv_updater == "smart_mask": args.shared_buffer = True @@ -1176,7 +1116,7 @@ def export_llama(args) -> None: raise RuntimeError(f"Using an unknown kv update {args.kv_updater}") if args.pre_gen_pte: - inference(args, pte_filename, runtime_tokenizer_path, decoder_model_version) + inference(args, pte_filename, runtime_tokenizer_path, args.pre_gen_pte) print(f"Finish the running pre_gen_pte from {args.pre_gen_pte}") return @@ -1198,7 +1138,7 @@ def export_llama(args) -> None: return compile(args, pte_filename, tokenizer) - inference(args, pte_filename, runtime_tokenizer_path, decoder_model_version) + inference(args, pte_filename, runtime_tokenizer_path) def main(): diff --git a/examples/qualcomm/oss_scripts/llama/model/static_llama.py b/examples/qualcomm/oss_scripts/llama/model/static_llama.py index 6bffaa7772c..f7893792e00 100755 --- a/examples/qualcomm/oss_scripts/llama/model/static_llama.py +++ b/examples/qualcomm/oss_scripts/llama/model/static_llama.py @@ -37,7 +37,6 @@ def apply_rotary_emb_single( class LlamaAttention(nn.Module): def __init__(self, config: ModelArgs, output_new_cache_only=False): super().__init__() - self.config = config self.dim = config.dim self.n_heads = config.n_heads self.head_dim = config.head_dim @@ -46,21 +45,9 @@ def __init__(self, config: ModelArgs, output_new_cache_only=False): self.max_seq_len = config.max_seq_len self.output_new_cache_only = output_new_cache_only - self.wq = nn.Linear( - self.dim, - self.n_heads * self.head_dim, - bias=getattr(config, "attention_qkv_bias", False), - ) - self.wk = nn.Linear( - self.dim, - self.n_kv_heads * self.head_dim, - bias=getattr(config, "attention_qkv_bias", False), - ) - self.wv = nn.Linear( - self.dim, - self.n_kv_heads * self.head_dim, - bias=getattr(config, "attention_qkv_bias", False), - ) + self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False) + self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False) self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False) self.attn_softmax = torch.nn.Softmax(dim=-1) @@ -70,34 +57,19 @@ def __init__(self, config: ModelArgs, output_new_cache_only=False): def prepare_sha(self): self.wq_sha = nn.ModuleList( [ - nn.Conv2d( - self.dim, - self.head_dim, - 1, - bias=getattr(self.config, "attention_qkv_bias", False), - ) + nn.Conv2d(self.dim, self.head_dim, 1, bias=False) for _ in range(self.n_heads) ] ) self.wk_sha = nn.ModuleList( [ - nn.Conv2d( - self.dim, - self.head_dim, - 1, - bias=getattr(self.config, "attention_qkv_bias", False), - ) + nn.Conv2d(self.dim, self.head_dim, 1, bias=False) for _ in range(self.n_kv_heads) ] ) self.wv_sha = nn.ModuleList( [ - nn.Conv2d( - self.dim, - self.head_dim, - 1, - bias=getattr(self.config, "attention_qkv_bias", False), - ) + nn.Conv2d(self.dim, self.head_dim, 1, bias=False) for _ in range(self.n_kv_heads) ] ) @@ -111,29 +83,17 @@ def prepare_sha(self): i * self.head_dim : (i + 1) * self.head_dim, :, None, None ] ) - if self.wq_sha[i].bias is not None: - self.wq_sha[i].bias.data.copy_( - self.wq.bias[i * self.head_dim : (i + 1) * self.head_dim] - ) for i in range(self.n_kv_heads): self.wk_sha[i].weight.data.copy_( self.wk.weight[ i * self.head_dim : (i + 1) * self.head_dim, :, None, None ] ) - if self.wk_sha[i].bias is not None: - self.wk_sha[i].bias.data.copy_( - self.wk.bias[i * self.head_dim : (i + 1) * self.head_dim] - ) self.wv_sha[i].weight.data.copy_( self.wv.weight[ i * self.head_dim : (i + 1) * self.head_dim, :, None, None ] ) - if self.wv_sha[i].bias is not None: - self.wv_sha[i].bias.data.copy_( - self.wv.bias[i * self.head_dim : (i + 1) * self.head_dim] - ) self.wo_sha.weight.data.copy_(self.wo.weight[:, :, None, None]) def forward_sha( diff --git a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp index 42873417488..5c10d3eade8 100644 --- a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp @@ -9,8 +9,8 @@ /** * @file * - * This tool can run Llama2 110M, Llama3.2 1B / 3B, Qwen2.5 0.5B with Qualcomm - * AI Engine Direct. + * This tool can run Llama2 110M, Llama3.2 1B / 3B(WIP) with Qualcomm AI Engine + * Direct. * */ @@ -21,7 +21,6 @@ #include #include -DEFINE_string(decoder_model_version, "llama2", "The decoder model to execute."); DEFINE_string( model_path, "kv_llama_qnn.pte", @@ -89,14 +88,13 @@ std::vector CollectPrompts(int argc, char** argv) { std::string get_formatted_prompt( const std::string& prompt, const std::string& system_prompt, - example::DecoderModelVersion decoder_model_version) { + example::LlamaVersion llama_version) { std::string formatted_prompt; - switch (decoder_model_version) { - case example::DecoderModelVersion::kLlama2: - case example::DecoderModelVersion::kQwen2_5: + switch (llama_version) { + case example::LlamaVersion::kLlama2: formatted_prompt.append(prompt); break; - case example::DecoderModelVersion::kLlama3: + case example::LlamaVersion::kLlama3: if (!system_prompt.empty()) { formatted_prompt.append( "<|start_header_id|>system<|end_header_id|>\n\n"); @@ -120,7 +118,6 @@ int main(int argc, char** argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); // create llama runner example::Runner runner( - FLAGS_decoder_model_version.c_str(), FLAGS_model_path.c_str(), FLAGS_tokenizer_path.c_str(), FLAGS_performance_output_path.c_str(), @@ -130,7 +127,7 @@ int main(int argc, char** argv) { FLAGS_ngram, FLAGS_window, FLAGS_gcap); - auto decoder_model_version = runner.get_decoder_model_version(); + auto llama_version = runner.get_llama_version(); std::vector buf; buf.reserve(5 * FLAGS_seq_len); // assume each token is around 5 char std::ofstream fout(FLAGS_output_path.c_str()); @@ -144,7 +141,7 @@ int main(int argc, char** argv) { for (const auto& prompt : prompts) { std::string formatted_prompt; formatted_prompt = get_formatted_prompt( - prompt, FLAGS_system_prompt, decoder_model_version.get()); + prompt, FLAGS_system_prompt, llama_version.get()); runner.generate(formatted_prompt.c_str(), FLAGS_seq_len, callback); } } diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp index f5c364e259e..30235332ebd 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp @@ -14,12 +14,10 @@ #include #include #include -#include #include #include #include #include -#include #include #include @@ -33,7 +31,6 @@ using executorch::extension::llm::time_in_ms; using executorch::runtime::Error; using executorch::runtime::MethodMeta; using executorch::runtime::Result; -namespace llm = ::executorch::extension::llm; namespace example { namespace { @@ -55,15 +52,7 @@ void print_performance_report( } } // namespace -std::unique_ptr<::tokenizers::Tokenizer> load_llama_tokenizer( - const std::string& tokenizer_path, - Version version) { - auto special_tokens = get_special_tokens(version); - return llm::load_tokenizer(tokenizer_path, std::move(special_tokens)); -} - Runner::Runner( - const std::string& decoder_model_version, const std::string& model_path, const std::string& tokenizer_path, const std::string& performance_output_path, @@ -92,17 +81,6 @@ Runner::Runner( } else { ET_CHECK_MSG(false, "kv updater (%s) not found", kv_updater.c_str()); } - - if (decoder_model_version == "llama2") { - decoder_model_version_ = DecoderModelVersion::kLlama2; - } else if (decoder_model_version == "llama3") { - decoder_model_version_ = DecoderModelVersion::kLlama3; - } else if (decoder_model_version == "qwen2_5") { - decoder_model_version_ = DecoderModelVersion::kQwen2_5; - } else { - ET_CHECK_MSG(false, "Unsupported Decoder Model"); - } - ET_LOG(Info, "creating module: model_path=%s", model_path.c_str()); ET_LOG(Info, "creating runner: tokenizer_path=%s", tokenizer_path_.c_str()); ET_LOG(Info, "eval mode=%d", eval_mode_); @@ -139,32 +117,50 @@ Error Runner::load() { break; } - tokenizer_ = load_llama_tokenizer(tokenizer_path_, Version::Default); - if (tokenizer_ == nullptr) { - ET_LOG(Error, "Failed to load tokenizer with %s", tokenizer_path_.c_str()); - return Error::Internal; - } - - auto eos_ids = std::make_unique>( - std::unordered_set{tokenizer_->eos_tok()}); - - if (decoder_model_version_ == DecoderModelVersion::kLlama3) { + auto eos_ids = std::make_unique>(); + // TODO: remove this once we could release the new tokens used for the + // tokenizer + if (tokenizer_ != nullptr) { eos_ids->insert(tokenizer_->encode("<|eot_id|>", 0, 0).get()[0]); + eos_ids->insert(tokenizer_->encode("<|eot|>", 0, 0).get()[0]); + eos_ids->insert(tokenizer_->encode("<|end_of_text|>", 0, 0).get()[0]); + } else { + // load tokenizer. Assuming tiktoken is the default tokenizer + tokenizer_ = get_tiktoken_for_llama(); + auto err = tokenizer_->load(tokenizer_path_); + auto eos_ids = std::make_unique>(); + // Rely on tiktoken to throw error if the artifact is incompatible. Then we + // fallback to BPE tokenizer. + if (err != tokenizers::Error::Ok) { + ET_LOG( + Info, + "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer", + tokenizer_path_.c_str()); + tokenizer_.reset(); + tokenizer_ = std::make_unique(); + err = tokenizer_->load(tokenizer_path_); + llama_version_ = LlamaVersion::kLlama2; + ET_CHECK_MSG( + err == tokenizers::Error::Ok, + "failed to load tokenizer %s", + tokenizer_path_.c_str()); + } else { + eos_ids->insert(tokenizer_->encode("<|eot_id|>", 0, 0).get()[0]); + llama_version_ = LlamaVersion::kLlama3; + } + eos_ids->insert(tokenizer_->eos_tok()); } - // Try avoid getMetadataHelper as it is time consuming. - Result method_meta = - module_->method_meta(token_generator_method_name); - // For some tokenizer.json, runtime vocab_size might be different, use output - // shape to get vocab size. - int32_t vocab_size = method_meta->output_tensor_meta(0)->sizes()[2]; + int32_t vocab_size = tokenizer_->vocab_size(); decoder_runner_ = std::make_unique(module_.get(), vocab_size, temperature_); ET_CHECK_OK_OR_RETURN_ERROR(decoder_runner_->load(method_names)); ET_LOG(Info, "Reading metadata from model"); - + // Try avoid getMetadataHelper as it is time consuming. + Result method_meta = + module_->method_meta(token_generator_method_name); // retrieve any method meta, can be either prefill or kv int64_t num_layers = ET_UNWRAP(module_->get("get_n_layers")).toScalar().to(); @@ -274,6 +270,7 @@ Error Runner::load() { module_->method_meta(prompt_processor_method_name)); token_generator_->init_io( buffer_manager_.get(), module_->method_meta(token_generator_method_name)); + return Error::Ok; } @@ -311,6 +308,7 @@ Error Runner::generate( if (token_callback) { token_callback(prompt); } + auto prefill_res = prompt_processor_->prefill(prompt_tokens, cur_pos_); ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error()); uint64_t cur_token = prefill_res.get(); @@ -352,13 +350,13 @@ Error Runner::generate( return Error::Ok; } -Result Runner::get_decoder_model_version() { +Result Runner::get_llama_version() { if (!is_loaded()) { stats_.model_load_start_ms = time_in_ms(); ET_CHECK_OK_OR_RETURN_ERROR(load()); stats_.model_load_end_ms = time_in_ms(); } - return decoder_model_version_; + return llama_version_; } } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.h b/examples/qualcomm/oss_scripts/llama/runner/runner.h index e616812988d..ec53e7463f6 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/runner.h +++ b/examples/qualcomm/oss_scripts/llama/runner/runner.h @@ -27,15 +27,13 @@ namespace example { -enum DecoderModelVersion { +enum LlamaVersion { kLlama2 = 0, kLlama3, - kQwen2_5, }; class Runner { public: explicit Runner( - const std::string& decoder_model, const std::string& model_path, const std::string& tokenizer_path, const std::string& performance_output_path, @@ -58,7 +56,7 @@ class Runner { bool echo = true, bool warming = false); void stop() {}; - executorch::runtime::Result get_decoder_model_version(); + executorch::runtime::Result get_llama_version(); private: enum EvalMode { @@ -80,7 +78,7 @@ class Runner { std::string performance_output_path_; float temperature_; EvalMode eval_mode_; - DecoderModelVersion decoder_model_version_; + LlamaVersion llama_version_; KVManagerMode kv_updater_; std::unique_ptr buffer_manager_; std::unique_ptr kv_manager_;