diff --git a/backends/qualcomm/_passes/layout_transform.py b/backends/qualcomm/_passes/layout_transform.py index 919c445fafa..67073afc042 100644 --- a/backends/qualcomm/_passes/layout_transform.py +++ b/backends/qualcomm/_passes/layout_transform.py @@ -103,8 +103,8 @@ class LayoutTransform(ExportPass): exir_ops.edge.aten.pow.Tensor_Scalar, exir_ops.edge.aten.prelu.default, exir_ops.edge.aten.repeat.default, - exir_ops.edge.aten.round.default, exir_ops.edge.aten.relu.default, + exir_ops.edge.aten.round.default, exir_ops.edge.aten.sigmoid.default, exir_ops.edge.aten.split_with_sizes.default, exir_ops.edge.aten.split_with_sizes_copy.default, diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py index 58b1a036955..d1ff0bc0e56 100644 --- a/backends/qualcomm/quantizer/annotators.py +++ b/backends/qualcomm/quantizer/annotators.py @@ -278,7 +278,9 @@ def annotate_masked_fill(node: Node, quantization_config: QuantizationConfig) -> ) -@register_annotator([torch.ops.aten.mul, torch.ops.aten.mul.Tensor]) +@register_annotator( + [torch.ops.aten.mul, torch.ops.aten.mul.Tensor, torch.ops.aten.mul_.Tensor] +) def annotate_mul(node: Node, quantization_config: QuantizationConfig) -> None: annotate_binary(node, quantization_config) @@ -1311,7 +1313,7 @@ def annotate_where(node: Node, quantization_config: QuantizationConfig) -> None: ) -@register_annotator([torch.ops.aten.zeros.default]) +@register_annotator([torch.ops.aten.zeros.default, torch.ops.aten.zeros_like.default]) def annotate_zeros(node: Node, quantization_config: QuantizationConfig) -> None: if _is_annotated([node]) or not _is_float_tensor(node): return diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py index 057d3ea93d2..f1a0cd95dff 100644 --- a/backends/qualcomm/quantizer/custom_annotation.py +++ b/backends/qualcomm/quantizer/custom_annotation.py @@ -153,7 +153,9 @@ def annotate_prefill_kv_output(gm: torch.fx.GraphModule, kv_quant_attrs: dict): ) -def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None: # noqa: C901 +def annotate_matmul_16a8w( # noqa: C901 + gm: torch.fx.GraphModule, annotate_conv=True +) -> None: """ This function is specific for matmul op 16a8w. 
For k, we will tag such as the below, and @@ -317,9 +319,10 @@ def annotate_matmul_input1(node: Node): # The arguments of cat op: (the past kv cache, the new kv cache) node = node.args[0][1] elif node.target == torch.ops.aten.conv2d.default: - annotate_conv2d( - node, quantization_config=quantization_config_8a4w_per_channel - ) + if annotate_conv: + annotate_conv2d( + node, quantization_config=quantization_config_8a4w_per_channel + ) break elif node.target in [torch.ops.aten.add.Tensor, torch.ops.aten.sub.Tensor]: break diff --git a/backends/qualcomm/scripts/build.sh b/backends/qualcomm/scripts/build.sh index 8099ecb3de8..8c0f01feac6 100755 --- a/backends/qualcomm/scripts/build.sh +++ b/backends/qualcomm/scripts/build.sh @@ -85,6 +85,7 @@ if [ "$BUILD_AARCH64" = true ]; then -DEXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR=ON \ -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ -DEXECUTORCH_ENABLE_EVENT_TRACER=ON \ + -DEXECUTORCH_ENABLE_LOGGING=ON \ -DQNN_SDK_ROOT=$QNN_SDK_ROOT \ -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK_ROOT/build/cmake/android.toolchain.cmake \ -DANDROID_ABI='arm64-v8a' \ @@ -104,6 +105,9 @@ if [ "$BUILD_AARCH64" = true ]; then -DANDROID_ABI='arm64-v8a' \ -DANDROID_PLATFORM=android-30 \ -DCMAKE_PREFIX_PATH=$CMAKE_PREFIX_PATH \ + -DSUPPORT_REGEX_LOOKAHEAD=ON \ + -DBUILD_TESTING=OFF \ + -DEXECUTORCH_ENABLE_LOGGING=ON \ -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \ -DCMAKE_FIND_ROOT_PATH_MODE_PACKAGE=BOTH \ -DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \ @@ -134,6 +138,7 @@ if [ "$BUILD_X86_64" = true ]; then -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \ -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ -DEXECUTORCH_ENABLE_EVENT_TRACER=ON \ + -DEXECUTORCH_ENABLE_LOGGING=ON \ -DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \ -S $PRJ_ROOT \ -B $BUILD_ROOT \ @@ -157,6 +162,9 @@ if [ "$BUILD_X86_64" = true ]; then -DCMAKE_PREFIX_PATH=$CMAKE_PREFIX_PATH \ -DCMAKE_FIND_ROOT_PATH_MODE_PACKAGE=BOTH \ -DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \ + -DSUPPORT_REGEX_LOOKAHEAD=ON \ + -DBUILD_TESTING=OFF \ + -DEXECUTORCH_ENABLE_LOGGING=ON \ -B$EXAMPLE_ROOT cmake --build $EXAMPLE_ROOT -j$BUILD_JOB_NUMBER diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index a7942404d18..2f580cb71b2 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -4049,7 +4049,7 @@ def test_llama3_2_1b(self): "16a4w", "--temperature", "0", - "--llama_model", + "--decoder_model", "llama3_2", "--model_mode", "hybrid", @@ -4129,7 +4129,7 @@ def test_llama_stories_110m(self): "16a4w", "--temperature", "0", - "--llama_model", + "--decoder_model", "stories110m", "--model_mode", "hybrid", @@ -4171,6 +4171,65 @@ def test_llama_stories_110m(self): if not self.compile_only and not self.enable_x86_64: self.assertGreaterEqual(msg["inference_speed"], 220) # Lanai + def test_qwen2_5(self): + if not self.required_envs(): + self.skipTest("missing required envs") + + prompt = "My favourite condiment is " + cmds = [ + "python", + f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py", + "--artifact", + self.artifact_dir, + "--build_folder", + self.build_folder, + "--model", + self.model, + "--ip", + self.ip, + "--port", + str(self.port), + "--prompt", + f"{prompt}", + "--ptq", + "16a8w", + "--decoder_model", + "qwen2_5", + "--model_mode", + "hybrid", + "--prefill_ar_len", + "32", + "--max_seq_len", + "128", + ] + if self.compile_only: + cmds.extend(["--compile_only"]) + elif self.device: + cmds.extend(["--device", self.device]) + if self.host: + 
cmds.extend(["--host", self.host]) + elif self.enable_x86_64: + cmds.extend(["--enable_x86_64"]) + if self.pre_gen_pte: + cmds.extend(["--pre_gen_pte", self.pre_gen_pte]) + + # Accuracy is bad for now. Just check user's prompt is returned. + golden_start_with = "My favourite condiment is " + p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) + with Listener((self.ip, self.port)) as listener: + conn = listener.accept() + p.communicate() + msg = json.loads(conn.recv()) + if "Error" in msg: + self.fail(msg["Error"]) + else: + model_out = msg["result"][0] + self.assertTrue( + model_out.startswith(golden_start_with), + f"Expected Output: {golden_start_with}. Actual Output: {model_out}", + ) + self.assertGreaterEqual(msg["inference_speed"], 95) # Lanai + class TestExampleOssScript(TestQNN): def test_albert(self): diff --git a/examples/models/qwen2_5/config/0_5b_config.json b/examples/models/qwen2_5/config/0_5b_config.json new file mode 100644 index 00000000000..0b9a2a2d4ce --- /dev/null +++ b/examples/models/qwen2_5/config/0_5b_config.json @@ -0,0 +1,14 @@ +{ + "dim": 896, + "ffn_dim_multiplier": 1, + "hidden_dim": 4864, + "n_heads": 14, + "n_kv_heads": 2, + "n_layers": 24, + "norm_eps": 1e-06, + "rope_theta": 1000000.0, + "use_scaled_rope": false, + "vocab_size": 151936, + "use_hf_rope": true, + "attention_qkv_bias": true +} diff --git a/examples/qualcomm/CMakeLists.txt b/examples/qualcomm/CMakeLists.txt index 67aa9bb4b05..6bceec128f2 100644 --- a/examples/qualcomm/CMakeLists.txt +++ b/examples/qualcomm/CMakeLists.txt @@ -77,8 +77,8 @@ target_include_directories( # add tokenizers add_subdirectory( - ${EXECUTORCH_ROOT}/extension/llm/tokenizers - ${CMAKE_CURRENT_BINARY_DIR}/../../extension/llm/tokenizers + ${EXECUTORCH_ROOT}/extension/llm/runner + ${CMAKE_CURRENT_BINARY_DIR}/../../extension/llm/runner ) # build qnn_executor_runner diff --git a/examples/qualcomm/oss_scripts/llama/CMakeLists.txt b/examples/qualcomm/oss_scripts/llama/CMakeLists.txt index aee40fd3f18..ba483abe92b 100644 --- a/examples/qualcomm/oss_scripts/llama/CMakeLists.txt +++ b/examples/qualcomm/oss_scripts/llama/CMakeLists.txt @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + # model sharding with custom op set(CUSTOM_OP_SRCS_FILE "${EXECUTORCH_SOURCE_DIR}/extension/llm/custom_ops/op_fallback.cpp" @@ -63,14 +64,22 @@ target_link_libraries( executorch_core extension_data_loader extension_flat_tensor + extension_llm_runner extension_module extension_tensor + tokenizers gflags custom_ops quantized_ops_lib quantized_kernels tokenizers ) + +target_include_directories( + qnn_llama_runner + PUBLIC ${EXECUTORCH_ROOT}/extension/llm/tokenizers/include +) + target_compile_options(qnn_llama_runner PUBLIC ${_common_compile_options}) set_target_properties( qnn_llama_runner PROPERTIES LINK_FLAGS "-Wl,-rpath='$ORIGIN'" diff --git a/examples/qualcomm/oss_scripts/llama/README.md b/examples/qualcomm/oss_scripts/llama/README.md index 309de56cd89..6d10a935863 100644 --- a/examples/qualcomm/oss_scripts/llama/README.md +++ b/examples/qualcomm/oss_scripts/llama/README.md @@ -1,10 +1,11 @@ # Summary ## Overview -This file provides you the instructions to run LLAMA model with different parameters via Qualcomm HTP backend. We currently support the following models: +This file provides you the instructions to run LLM Decoder model with different parameters via Qualcomm HTP backend. We currently support the following models: 1. LLAMA2 Stories 110M 2. 
LLAMA3.2 1B 3. LLAMA3.2 3B + 4. QWEN2.5 0.5B We offer the following modes to execute the model: diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py index cbd9f711bae..21a61e33992 100755 --- a/examples/qualcomm/oss_scripts/llama/llama.py +++ b/examples/qualcomm/oss_scripts/llama/llama.py @@ -31,7 +31,6 @@ ) from executorch.backends.qualcomm.builders.utils import is_graph_output -from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner from executorch.backends.qualcomm.quantizer.custom_annotation import ( annotate_linear_16a8w_in_affine_layer, annotate_matmul_16a8w, @@ -57,6 +56,10 @@ ) from executorch.devtools.backend_debug import print_delegation_info + +from executorch.examples.models.llama.hf_download import ( + download_and_convert_hf_checkpoint, +) from executorch.examples.models.llama.source_transformation.quantize import ( get_quant_embedding_transform, ) @@ -78,29 +81,28 @@ setup_common_args_and_variables, SimpleADB, ) -from executorch.exir import EdgeProgramManager -from executorch.exir.backend.backend_api import ( - MethodProgramsPartitionerSpec, - to_backend, -) from executorch.exir.capture._config import ExecutorchBackendConfig from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass from executorch.extension.llm.custom_ops import model_sharding from executorch.extension.llm.export.builder import DType from pytorch_tokenizers import get_tokenizer, TiktokenTokenizer +from pytorch_tokenizers.hf_tokenizer import HuggingFaceTokenizer from pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizer from torchao.prototype.spinquant import apply_spinquant from torchao.quantization.pt2e import MinMaxObserver from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e +from transformers import AutoConfig, AutoTokenizer sys.setrecursionlimit(4096) FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) logging.getLogger().setLevel(logging.INFO) +HUGGING_FACE_REPO_IDS = {"qwen2_5": "Qwen/Qwen2.5-0.5B"} + def next_power_of_two(n): if n == 0: @@ -160,7 +162,7 @@ def _kv_calibrate( token_list = [] # Llama2 tokenizer has no special tokens - if isinstance(tokenizer, SentencePieceTokenizer): + if isinstance(tokenizer, (SentencePieceTokenizer, HuggingFaceTokenizer)): token_list = tokenizer.encode(user_prompts, bos=True, eos=False) elif isinstance(tokenizer, TiktokenTokenizer): token_list = tokenizer.encode( @@ -168,7 +170,6 @@ def _kv_calibrate( ) else: raise RuntimeError("Unkown tokenizer") - pos = len(token_list) if len(token_list) < ar_len else ar_len dtype = torch.int64 if use_i64_token else torch.int32 @@ -232,7 +233,7 @@ def _prefill_calibrate( token_list = [] # Llama2 tokenizer has no special tokens - if isinstance(tokenizer, SentencePieceTokenizer): + if isinstance(tokenizer, (SentencePieceTokenizer, HuggingFaceTokenizer)): token_list = tokenizer.encode(user_prompts, bos=True, eos=False) elif isinstance(tokenizer, TiktokenTokenizer): token_list = tokenizer.encode( @@ -304,14 +305,14 @@ def calibrate( class SingleLlama: - def __init__(self, llama_model, pte_filename) -> None: + def __init__(self, decoder_model, pte_filename) -> None: super().__init__() - self.llama_model = llama_model + self.decoder_model = decoder_model self.passes_job = get_capture_program_passes() self.dep_table = get_passes_dependency_for_capture_program() 
self.quant_attrs = None self.quant_dtype = None - self.llama_meta = self.llama_model.get_metadata() + self.llama_meta = self.decoder_model.get_metadata() self.has_quant_io = False self.pte_filename = pte_filename if self.llama_meta["get_use_kv_cache"]: @@ -322,7 +323,7 @@ def __init__(self, llama_model, pte_filename) -> None: else: tokens, atten_mask = self.get_example_inputs(use_kv_cache=False) self.inputs = (tokens, atten_mask) - self.llama_graph_module = llama_model + self.llama_graph_module = decoder_model self.io_shape = { # logit output ( @@ -499,7 +500,7 @@ def lowering_modules( exec_prog_mgr.write_to_file(file) def get_example_inputs(self, use_kv_cache=True): - return self.llama_model.get_example_inputs(use_kv_cache) + return self.decoder_model.get_example_inputs(use_kv_cache) def get_quant_attrs(self): return self.quant_attrs @@ -509,21 +510,34 @@ def compile(args, pte_filename, tokenizer): os.makedirs(args.artifact, exist_ok=True) start_ts = time.time() - with open(args.params) as f: + kv_config, prefill_config = None, None + params_path = "" + if args.params: + params_path = args.params + else: + if args.decoder_model == "qwen2_5": + cur_dir = os.path.dirname(__file__) + params_path = os.path.join( + cur_dir, + "..", + "..", + "..", + "models", + "qwen2_5", + "config", + "0_5b_config.json", + ) + with open(params_path) as f: kv_config = ModelArgs(**json.load(f)) - # TODO: support batch inputs if necessary - kv_config.max_batch_size = 1 - kv_config.max_seq_len = args.max_seq_len - kv_config.use_kv_cache = True - - prefill_config = copy.copy(kv_config) - prefill_config.max_seq_len = args.max_seq_len - prefill_config.use_kv_cache = ( - False if args.max_seq_len == args.prefill_ar_len else True - ) - state_dict = torch.load( - args.checkpoint, weights_only=True, map_location="cpu", mmap=True + # TODO: support batch inputs if necessary + kv_config.max_batch_size = 1 + kv_config.max_seq_len = args.max_seq_len + kv_config.use_kv_cache = True + + prefill_config = copy.copy(kv_config) + prefill_config.use_kv_cache = ( + False if args.max_seq_len == args.prefill_ar_len else True ) llama_instance_list = [] @@ -583,31 +597,47 @@ def compile(args, pte_filename, tokenizer): else: raise RuntimeError(f"Unknown model_mode: {args.model_mode}.") - if "model" in state_dict: - state_dict = state_dict["model"] - - # Change to HuggingFace weight to improve the performance of RoPE in HTP backend. 
- def permute(w, heads): - dim_0 = w.size(0) - dim_1 = w.size(1) - return ( - w.view(heads, dim_0 // heads // 2, 2, dim_1) - .transpose(1, 2) - .reshape(dim_0, dim_1) - ) - - n_heads = llama_instance_list[0].n_heads - n_kv_heads = llama_instance_list[0].n_kv_heads - n_layers = llama_instance_list[0].n_layers + if args.checkpoint is None: # HF models + model_id = HUGGING_FACE_REPO_IDS[args.decoder_model] + if args.decoder_model == "qwen2_5": + from executorch.examples.models.qwen2_5 import ( # pyre-ignore[21] + convert_weights, + ) - for layer_i in range(n_layers): - state_dict[f"layers.{layer_i}.attention.wq.weight"] = permute( - state_dict[f"layers.{layer_i}.attention.wq.weight"], n_heads + checkpoint = download_and_convert_hf_checkpoint(model_id, convert_weights) + state_dict = torch.load( + checkpoint, weights_only=True, map_location="cpu", mmap=True ) - state_dict[f"layers.{layer_i}.attention.wk.weight"] = permute( - state_dict[f"layers.{layer_i}.attention.wk.weight"], n_kv_heads + else: + state_dict = torch.load( + args.checkpoint, weights_only=True, map_location="cpu", mmap=True ) + if "model" in state_dict: + state_dict = state_dict["model"] + + # Change to HuggingFace weight to improve the performance of RoPE in HTP backend. + def permute(w, heads): + dim_0 = w.size(0) + dim_1 = w.size(1) + return ( + w.view(heads, dim_0 // heads // 2, 2, dim_1) + .transpose(1, 2) + .reshape(dim_0, dim_1) + ) + + n_heads = llama_instance_list[0].n_heads + n_kv_heads = llama_instance_list[0].n_kv_heads + n_layers = llama_instance_list[0].n_layers + + for layer_i in range(n_layers): + state_dict[f"layers.{layer_i}.attention.wq.weight"] = permute( + state_dict[f"layers.{layer_i}.attention.wq.weight"], n_heads + ) + state_dict[f"layers.{layer_i}.attention.wk.weight"] = permute( + state_dict[f"layers.{layer_i}.attention.wk.weight"], n_kv_heads + ) + for llama_instance in llama_instance_list: llama_instance.load_state_dict( state_dict, @@ -680,18 +710,17 @@ def permute(w, heads): fixed_point_type["kv_type"] = torch.uint8 if args.ptq == "8a8w": fixed_point_type["io_type"] = torch.uint8 - elif args.ptq in ("16a4w", "16a4w_block"): + elif args.ptq in ("16a4w", "16a4w_block", "16a8w"): fixed_point_type["io_type"] = torch.uint16 else: assert args.ptq in [ "8a8w", "16a4w", "16a4w_block", + "16a8w", ], f"No support for quant type {args.ptq}. Support 8a8w, 16a4w and 16a4w_block." quant_dtype = getattr(QuantDtype, f"use_{args.ptq}") - assert args.tokenizer_model is not None, "Need tokenizer model for calibration" - if args.dtype_override is not None: dtype_override = DType[args.dtype_override] for i in range(len(llama_instance_list)): @@ -715,8 +744,14 @@ def permute(w, heads): if args.ptq: start_quantize_ts = time.time() - custom_annotations = (annotate_matmul_16a8w,) - if args.llama_model == "stories110m": + custom_annotations = ( + # For qwen2.5, skip annotate_conv can improve result. 
+ partial( + annotate_matmul_16a8w, + annotate_conv=args.ptq != "16a8w", + ), + ) + if args.decoder_model == "stories110m": custom_annotations = custom_annotations + ( annotate_linear_16a8w_in_affine_layer, ) @@ -851,7 +886,7 @@ def permute(w, heads): return quant_attrs -def inference(args, pte_filename, runtime_tokenizer_path, pre_gen_pte=""): +def inference(args, pte_filename, runtime_tokenizer_path, decoder_model_version): workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/single_llama" if args.model_mode == "kv": @@ -864,8 +899,8 @@ def inference(args, pte_filename, runtime_tokenizer_path, pre_gen_pte=""): raise RuntimeError(f"Unknown model_mode: {args.model_mode}.") pte_path = ( - f"{pre_gen_pte}/{pte_filename}.pte" - if pre_gen_pte + f"{args.pre_gen_pte}/{pte_filename}.pte" + if args.pre_gen_pte else f"{args.artifact}/{pte_filename}.pte" ) @@ -906,6 +941,7 @@ def post_process(): [ f"export LD_LIBRARY_PATH={qnn_sdk}/lib/{target}/:{args.build_folder}/lib &&", f"./{args.build_folder}/examples/qualcomm/oss_scripts/llama/qnn_llama_runner", + f"--decoder_model_version {decoder_model_version}", f"--tokenizer_path {runtime_tokenizer_path}", f"--model_path {pte_path}", f"--seq_len {seq_len}", @@ -927,6 +963,7 @@ def post_process(): [ f"cd {workspace} &&", f"./qnn_llama_runner", + f"--decoder_model_version {decoder_model_version}", f"--tokenizer_path {os.path.basename(runtime_tokenizer_path)}", f"--model_path {pte_filename}.pte", f"--seq_len {seq_len}", @@ -990,28 +1027,28 @@ def _build_parser(): parser.add_argument( "-P", "--ptq", - help="If specified, will do PTQ quantization. default is 16bits activation and 4bits weight. Support 8a8w, 16a4w and 16a4w_block.", + help="If specified, will do PTQ quantization. default is 16bits activation and 4bits weight. Support 8a8w, 16a4w and 16a4w_block, 16a8w.", type=str, ) parser.add_argument( - "--llama_model", - choices=["stories110m", "llama3_2"], - help="The Llama model to export. Current available options are: [stories110m, llama3_2]", + "--decoder_model", + choices=["stories110m", "llama3_2", "qwen2_5"], + help="The Llama model to export. Current available options are: [stories110m, llama3_2, qwen2_5]", required=True, ) parser.add_argument( "--checkpoint", help="Pass llama checkpoint.", - required=True, + required=False, type=str, ) parser.add_argument( "--params", help="Pass llama params json file.", - required=True, + required=False, type=str, ) @@ -1171,9 +1208,10 @@ def export_llama(args) -> None: else: raise RuntimeError(f"Unknown model_mode: {args.model_mode}.") - tokenizer = get_tokenizer(args.tokenizer_model) - runtime_tokenizer_path = "" - if args.llama_model == "stories110m": + tokenizer = None + runtime_tokenizer_path, decoder_model_version = "", "" + if args.decoder_model == "stories110m": + tokenizer = get_tokenizer(args.tokenizer_model) assert isinstance( tokenizer, SentencePieceTokenizer ), f"Wrong tokenizer provided for stories110m." @@ -1181,13 +1219,31 @@ def export_llama(args) -> None: args.tokenizer_bin is not None ), "Please provide tokenizer_bin for stories110m." runtime_tokenizer_path = args.tokenizer_bin - elif args.llama_model == "llama3_2": + decoder_model_version = "llama2" + elif args.decoder_model == "llama3_2": + tokenizer = get_tokenizer(args.tokenizer_model) assert isinstance( tokenizer, TiktokenTokenizer ), f"Wrong tokenizer provided for llama3_2." 
runtime_tokenizer_path = args.tokenizer_model + decoder_model_version = "llama3" + elif args.decoder_model == "qwen2_5": + model_id = HUGGING_FACE_REPO_IDS[args.decoder_model] + tokenizer = AutoTokenizer.from_pretrained(model_id) + runtime_tokenizer_path = tokenizer.save_pretrained(args.artifact)[-1] + tokenizer = get_tokenizer(runtime_tokenizer_path) + decoder_model_version = args.decoder_model + + with open(runtime_tokenizer_path, "r+") as file: + data = json.load(file) + # TODO: Encountered the following error during runtime, so switched behavior for now. + # Error: libc++abi: terminating due to uncaught exception of type std::runtime_error: Unsupported Normalizer type: NFC. + data.pop("normalizer") + file.seek(0) + json.dump(data, file, indent=4) + file.truncate() else: - raise RuntimeError(f"Unknown llama_model: {args.llama_model}.") + raise RuntimeError(f"Unknown decoder_model: {args.decoder_model}.") if args.kv_updater == "smart_mask": args.shared_buffer = True @@ -1198,7 +1254,7 @@ def export_llama(args) -> None: raise RuntimeError(f"Using an unknown kv update {args.kv_updater}") if args.pre_gen_pte: - inference(args, pte_filename, runtime_tokenizer_path, args.pre_gen_pte) + inference(args, pte_filename, runtime_tokenizer_path, decoder_model_version) print(f"Finish the running pre_gen_pte from {args.pre_gen_pte}") return @@ -1220,7 +1276,7 @@ def export_llama(args) -> None: return compile(args, pte_filename, tokenizer) - inference(args, pte_filename, runtime_tokenizer_path) + inference(args, pte_filename, runtime_tokenizer_path, decoder_model_version) def main(): diff --git a/examples/qualcomm/oss_scripts/llama/model/static_llama.py b/examples/qualcomm/oss_scripts/llama/model/static_llama.py index e710443f07a..5c1f17abe47 100755 --- a/examples/qualcomm/oss_scripts/llama/model/static_llama.py +++ b/examples/qualcomm/oss_scripts/llama/model/static_llama.py @@ -37,6 +37,7 @@ def apply_rotary_emb_single( class LlamaAttention(nn.Module): def __init__(self, config: ModelArgs, output_new_cache_only=False): super().__init__() + self.config = config self.dim = config.dim self.n_heads = config.n_heads self.head_dim = config.head_dim @@ -45,9 +46,21 @@ def __init__(self, config: ModelArgs, output_new_cache_only=False): self.max_seq_len = config.max_seq_len self.output_new_cache_only = output_new_cache_only - self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False) - self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False) - self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wq = nn.Linear( + self.dim, + self.n_heads * self.head_dim, + bias=getattr(config, "attention_qkv_bias", False), + ) + self.wk = nn.Linear( + self.dim, + self.n_kv_heads * self.head_dim, + bias=getattr(config, "attention_qkv_bias", False), + ) + self.wv = nn.Linear( + self.dim, + self.n_kv_heads * self.head_dim, + bias=getattr(config, "attention_qkv_bias", False), + ) self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False) self.attn_softmax = torch.nn.Softmax(dim=-1) @@ -57,19 +70,34 @@ def __init__(self, config: ModelArgs, output_new_cache_only=False): def prepare_sha(self): self.wq_sha = nn.ModuleList( [ - nn.Conv2d(self.dim, self.head_dim, 1, bias=False) + nn.Conv2d( + self.dim, + self.head_dim, + 1, + bias=getattr(self.config, "attention_qkv_bias", False), + ) for _ in range(self.n_heads) ] ) self.wk_sha = nn.ModuleList( [ - nn.Conv2d(self.dim, self.head_dim, 1, bias=False) + nn.Conv2d( + self.dim, + self.head_dim, + 1, + 
bias=getattr(self.config, "attention_qkv_bias", False), + ) for _ in range(self.n_kv_heads) ] ) self.wv_sha = nn.ModuleList( [ - nn.Conv2d(self.dim, self.head_dim, 1, bias=False) + nn.Conv2d( + self.dim, + self.head_dim, + 1, + bias=getattr(self.config, "attention_qkv_bias", False), + ) for _ in range(self.n_kv_heads) ] ) @@ -83,17 +111,29 @@ def prepare_sha(self): i * self.head_dim : (i + 1) * self.head_dim, :, None, None ] ) + if self.wq_sha[i].bias is not None: + self.wq_sha[i].bias.data.copy_( + self.wq.bias[i * self.head_dim : (i + 1) * self.head_dim] + ) for i in range(self.n_kv_heads): self.wk_sha[i].weight.data.copy_( self.wk.weight[ i * self.head_dim : (i + 1) * self.head_dim, :, None, None ] ) + if self.wk_sha[i].bias is not None: + self.wk_sha[i].bias.data.copy_( + self.wk.bias[i * self.head_dim : (i + 1) * self.head_dim] + ) self.wv_sha[i].weight.data.copy_( self.wv.weight[ i * self.head_dim : (i + 1) * self.head_dim, :, None, None ] ) + if self.wv_sha[i].bias is not None: + self.wv_sha[i].bias.data.copy_( + self.wv.bias[i * self.head_dim : (i + 1) * self.head_dim] + ) self.wo_sha.weight.data.copy_(self.wo.weight[:, :, None, None]) def forward_sha( diff --git a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp index 5c10d3eade8..42873417488 100644 --- a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp @@ -9,8 +9,8 @@ /** * @file * - * This tool can run Llama2 110M, Llama3.2 1B / 3B(WIP) with Qualcomm AI Engine - * Direct. + * This tool can run Llama2 110M, Llama3.2 1B / 3B, Qwen2.5 0.5B with Qualcomm + * AI Engine Direct. * */ @@ -21,6 +21,7 @@ #include #include +DEFINE_string(decoder_model_version, "llama2", "The decoder model to execute."); DEFINE_string( model_path, "kv_llama_qnn.pte", @@ -88,13 +89,14 @@ std::vector CollectPrompts(int argc, char** argv) { std::string get_formatted_prompt( const std::string& prompt, const std::string& system_prompt, - example::LlamaVersion llama_version) { + example::DecoderModelVersion decoder_model_version) { std::string formatted_prompt; - switch (llama_version) { - case example::LlamaVersion::kLlama2: + switch (decoder_model_version) { + case example::DecoderModelVersion::kLlama2: + case example::DecoderModelVersion::kQwen2_5: formatted_prompt.append(prompt); break; - case example::LlamaVersion::kLlama3: + case example::DecoderModelVersion::kLlama3: if (!system_prompt.empty()) { formatted_prompt.append( "<|start_header_id|>system<|end_header_id|>\n\n"); @@ -118,6 +120,7 @@ int main(int argc, char** argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); // create llama runner example::Runner runner( + FLAGS_decoder_model_version.c_str(), FLAGS_model_path.c_str(), FLAGS_tokenizer_path.c_str(), FLAGS_performance_output_path.c_str(), @@ -127,7 +130,7 @@ int main(int argc, char** argv) { FLAGS_ngram, FLAGS_window, FLAGS_gcap); - auto llama_version = runner.get_llama_version(); + auto decoder_model_version = runner.get_decoder_model_version(); std::vector buf; buf.reserve(5 * FLAGS_seq_len); // assume each token is around 5 char std::ofstream fout(FLAGS_output_path.c_str()); @@ -141,7 +144,7 @@ int main(int argc, char** argv) { for (const auto& prompt : prompts) { std::string formatted_prompt; formatted_prompt = get_formatted_prompt( - prompt, FLAGS_system_prompt, llama_version.get()); + prompt, FLAGS_system_prompt, decoder_model_version.get()); runner.generate(formatted_prompt.c_str(), FLAGS_seq_len, 
callback); } } diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp index 30235332ebd..f5c364e259e 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp @@ -14,10 +14,12 @@ #include #include #include +#include #include #include #include #include +#include #include #include @@ -31,6 +33,7 @@ using executorch::extension::llm::time_in_ms; using executorch::runtime::Error; using executorch::runtime::MethodMeta; using executorch::runtime::Result; +namespace llm = ::executorch::extension::llm; namespace example { namespace { @@ -52,7 +55,15 @@ void print_performance_report( } } // namespace +std::unique_ptr<::tokenizers::Tokenizer> load_llama_tokenizer( + const std::string& tokenizer_path, + Version version) { + auto special_tokens = get_special_tokens(version); + return llm::load_tokenizer(tokenizer_path, std::move(special_tokens)); +} + Runner::Runner( + const std::string& decoder_model_version, const std::string& model_path, const std::string& tokenizer_path, const std::string& performance_output_path, @@ -81,6 +92,17 @@ Runner::Runner( } else { ET_CHECK_MSG(false, "kv updater (%s) not found", kv_updater.c_str()); } + + if (decoder_model_version == "llama2") { + decoder_model_version_ = DecoderModelVersion::kLlama2; + } else if (decoder_model_version == "llama3") { + decoder_model_version_ = DecoderModelVersion::kLlama3; + } else if (decoder_model_version == "qwen2_5") { + decoder_model_version_ = DecoderModelVersion::kQwen2_5; + } else { + ET_CHECK_MSG(false, "Unsupported Decoder Model"); + } + ET_LOG(Info, "creating module: model_path=%s", model_path.c_str()); ET_LOG(Info, "creating runner: tokenizer_path=%s", tokenizer_path_.c_str()); ET_LOG(Info, "eval mode=%d", eval_mode_); @@ -117,50 +139,32 @@ Error Runner::load() { break; } - auto eos_ids = std::make_unique>(); - // TODO: remove this once we could release the new tokens used for the - // tokenizer - if (tokenizer_ != nullptr) { + tokenizer_ = load_llama_tokenizer(tokenizer_path_, Version::Default); + if (tokenizer_ == nullptr) { + ET_LOG(Error, "Failed to load tokenizer with %s", tokenizer_path_.c_str()); + return Error::Internal; + } + + auto eos_ids = std::make_unique>( + std::unordered_set{tokenizer_->eos_tok()}); + + if (decoder_model_version_ == DecoderModelVersion::kLlama3) { eos_ids->insert(tokenizer_->encode("<|eot_id|>", 0, 0).get()[0]); - eos_ids->insert(tokenizer_->encode("<|eot|>", 0, 0).get()[0]); - eos_ids->insert(tokenizer_->encode("<|end_of_text|>", 0, 0).get()[0]); - } else { - // load tokenizer. Assuming tiktoken is the default tokenizer - tokenizer_ = get_tiktoken_for_llama(); - auto err = tokenizer_->load(tokenizer_path_); - auto eos_ids = std::make_unique>(); - // Rely on tiktoken to throw error if the artifact is incompatible. Then we - // fallback to BPE tokenizer. 
- if (err != tokenizers::Error::Ok) { - ET_LOG( - Info, - "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer", - tokenizer_path_.c_str()); - tokenizer_.reset(); - tokenizer_ = std::make_unique(); - err = tokenizer_->load(tokenizer_path_); - llama_version_ = LlamaVersion::kLlama2; - ET_CHECK_MSG( - err == tokenizers::Error::Ok, - "failed to load tokenizer %s", - tokenizer_path_.c_str()); - } else { - eos_ids->insert(tokenizer_->encode("<|eot_id|>", 0, 0).get()[0]); - llama_version_ = LlamaVersion::kLlama3; - } - eos_ids->insert(tokenizer_->eos_tok()); } + // Try avoid getMetadataHelper as it is time consuming. + Result method_meta = + module_->method_meta(token_generator_method_name); - int32_t vocab_size = tokenizer_->vocab_size(); + // For some tokenizer.json, runtime vocab_size might be different, use output + // shape to get vocab size. + int32_t vocab_size = method_meta->output_tensor_meta(0)->sizes()[2]; decoder_runner_ = std::make_unique(module_.get(), vocab_size, temperature_); ET_CHECK_OK_OR_RETURN_ERROR(decoder_runner_->load(method_names)); ET_LOG(Info, "Reading metadata from model"); - // Try avoid getMetadataHelper as it is time consuming. - Result method_meta = - module_->method_meta(token_generator_method_name); + // retrieve any method meta, can be either prefill or kv int64_t num_layers = ET_UNWRAP(module_->get("get_n_layers")).toScalar().to(); @@ -270,7 +274,6 @@ Error Runner::load() { module_->method_meta(prompt_processor_method_name)); token_generator_->init_io( buffer_manager_.get(), module_->method_meta(token_generator_method_name)); - return Error::Ok; } @@ -308,7 +311,6 @@ Error Runner::generate( if (token_callback) { token_callback(prompt); } - auto prefill_res = prompt_processor_->prefill(prompt_tokens, cur_pos_); ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error()); uint64_t cur_token = prefill_res.get(); @@ -350,13 +352,13 @@ Error Runner::generate( return Error::Ok; } -Result Runner::get_llama_version() { +Result Runner::get_decoder_model_version() { if (!is_loaded()) { stats_.model_load_start_ms = time_in_ms(); ET_CHECK_OK_OR_RETURN_ERROR(load()); stats_.model_load_end_ms = time_in_ms(); } - return llama_version_; + return decoder_model_version_; } } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.h b/examples/qualcomm/oss_scripts/llama/runner/runner.h index ec53e7463f6..e616812988d 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/runner.h +++ b/examples/qualcomm/oss_scripts/llama/runner/runner.h @@ -27,13 +27,15 @@ namespace example { -enum LlamaVersion { +enum DecoderModelVersion { kLlama2 = 0, kLlama3, + kQwen2_5, }; class Runner { public: explicit Runner( + const std::string& decoder_model, const std::string& model_path, const std::string& tokenizer_path, const std::string& performance_output_path, @@ -56,7 +58,7 @@ class Runner { bool echo = true, bool warming = false); void stop() {}; - executorch::runtime::Result get_llama_version(); + executorch::runtime::Result get_decoder_model_version(); private: enum EvalMode { @@ -78,7 +80,7 @@ class Runner { std::string performance_output_path_; float temperature_; EvalMode eval_mode_; - LlamaVersion llama_version_; + DecoderModelVersion decoder_model_version_; KVManagerMode kv_updater_; std::unique_ptr buffer_manager_; std::unique_ptr kv_manager_;
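A note on the quantization hook-up in llama.py: `annotate_matmul_16a8w` now takes an `annotate_conv` flag, and the export script binds it with `functools.partial` so that conv2d nodes keep the default annotation under the new 16a8w scheme (the in-diff comment notes that skipping the conv annotation improves Qwen2.5 results), while the existing 16a4w flows still get the 8a4w per-channel conv annotation. A minimal sketch of that wiring, using imports the script already has; `build_custom_annotations` is an illustrative helper, not part of the change:

```python
from functools import partial

from executorch.backends.qualcomm.quantizer.custom_annotation import (
    annotate_linear_16a8w_in_affine_layer,
    annotate_matmul_16a8w,
)


def build_custom_annotations(ptq: str, decoder_model: str) -> tuple:
    # Skip the 8a4w per-channel conv annotation when quantizing 16a8w (the
    # scheme used for qwen2_5 here); conv2d then falls back to the default
    # annotator, which was observed to give better results.
    annotations = (partial(annotate_matmul_16a8w, annotate_conv=ptq != "16a8w"),)
    if decoder_model == "stories110m":
        annotations = annotations + (annotate_linear_16a8w_in_affine_layer,)
    return annotations
```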
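The new `examples/models/qwen2_5/config/0_5b_config.json` plays the role of the `--params` file: when `--params` is omitted for `qwen2_5`, `compile()` falls back to this bundled config and derives both the decode-time (KV cache) and prefill `ModelArgs` from it. A sketch of that derivation; the `ModelArgs` import path is an assumption (llama.py already has `ModelArgs` in scope), and `build_configs` is an illustrative helper:

```python
import copy
import json
import os

# Assumed import location for ModelArgs.
from executorch.examples.models.llama.model_args import ModelArgs


def build_configs(args):
    # Fall back to the bundled Qwen2.5-0.5B config when --params is omitted.
    params_path = args.params
    if not params_path and args.decoder_model == "qwen2_5":
        params_path = os.path.join(
            os.path.dirname(__file__),
            "..", "..", "..", "models", "qwen2_5", "config", "0_5b_config.json",
        )

    with open(params_path) as f:
        kv_config = ModelArgs(**json.load(f))
    kv_config.max_batch_size = 1  # batch inputs are not supported yet
    kv_config.max_seq_len = args.max_seq_len
    kv_config.use_kv_cache = True

    # The prefill graph only needs a KV cache when it runs with a smaller
    # autoregressive length than the full sequence (hybrid mode).
    prefill_config = copy.copy(kv_config)
    prefill_config.use_kv_cache = args.max_seq_len != args.prefill_ar_len
    return kv_config, prefill_config
```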
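`--checkpoint` is now optional: for HF-hosted models, the weights are pulled from the hub and converted to the Meta-style state dict consumed by the static llama definition. Roughly, the path taken for Qwen2.5-0.5B is:

```python
import torch

from executorch.examples.models.llama.hf_download import (
    download_and_convert_hf_checkpoint,
)
from executorch.examples.models.qwen2_5 import convert_weights

# "Qwen/Qwen2.5-0.5B" is the repo id registered in HUGGING_FACE_REPO_IDS.
checkpoint = download_and_convert_hf_checkpoint("Qwen/Qwen2.5-0.5B", convert_weights)
state_dict = torch.load(checkpoint, weights_only=True, map_location="cpu", mmap=True)
```

Locally supplied Meta-format checkpoints still go through the original `torch.load` branch, including the RoPE row reordering sketched next.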
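The `permute` helper that moved inside the `else` branch converts Meta/llama checkpoints from the interleaved RoPE layout to the HuggingFace half-split layout, which the original comment notes performs better for RoPE on the HTP backend; the HF-download branch loads the converted state dict without this permutation, and the bundled Qwen config sets `use_hf_rope: true`. A small standalone illustration of what the row reordering does:

```python
import torch


def permute(w: torch.Tensor, heads: int) -> torch.Tensor:
    # Regroup the rows of a (out_features, in_features) Q/K projection weight:
    # within each head, interleaved (real, imag) row pairs become
    # first-half / second-half blocks, matching the HF rotate-half RoPE layout.
    dim_0, dim_1 = w.size(0), w.size(1)
    return (
        w.view(heads, dim_0 // heads // 2, 2, dim_1)
        .transpose(1, 2)
        .reshape(dim_0, dim_1)
    )


# One head with head_dim = 4: rows [0, 1, 2, 3] -> [0, 2, 1, 3].
w = torch.arange(4.0).unsqueeze(-1)
print(permute(w, heads=1).squeeze(-1))  # tensor([0., 2., 1., 3.])
```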
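For `qwen2_5`, `export_llama()` materializes the HuggingFace tokenizer into the artifact directory and uses the saved `tokenizer.json` as the runtime tokenizer, after stripping its `normalizer` entry; the on-device tokenizer otherwise aborts with `Unsupported Normalizer type: NFC`. A condensed restatement under the same assumptions as the script (the helper name and `artifact_dir` argument are illustrative, and the last path returned by `save_pretrained` is taken to be the tokenizer.json, as in the diff):

```python
import json

from pytorch_tokenizers import get_tokenizer
from transformers import AutoTokenizer


def prepare_qwen_tokenizer(artifact_dir: str):
    hf_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
    # save_pretrained returns the written file paths; the last one is used as
    # the runtime tokenizer.json.
    runtime_tokenizer_path = hf_tokenizer.save_pretrained(artifact_dir)[-1]
    with open(runtime_tokenizer_path, "r+") as f:
        data = json.load(f)
        data.pop("normalizer", None)  # drop the NFC normalizer the runtime rejects
        f.seek(0)
        json.dump(data, f, indent=4)
        f.truncate()
    return get_tokenizer(runtime_tokenizer_path), runtime_tokenizer_path
```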
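In `static_llama.py`, the single-head-attention (SHA) split now carries the optional Q/K/V bias (`attention_qkv_bias`, enabled in the Qwen2.5 config): each fused `nn.Linear` projection is split into per-head 1x1 `nn.Conv2d` modules, and both the weight slice and the matching bias slice are copied over. A self-contained sketch of that split for the Q projection, with simplified shapes (not the repository code):

```python
import torch
from torch import nn

dim, n_heads, head_dim = 8, 2, 4
wq = nn.Linear(dim, n_heads * head_dim, bias=True)

# One 1x1 conv per head; bias mirrors attention_qkv_bias on the fused linear.
wq_sha = nn.ModuleList(
    [nn.Conv2d(dim, head_dim, 1, bias=True) for _ in range(n_heads)]
)
for i, conv in enumerate(wq_sha):
    conv.weight.data.copy_(
        wq.weight[i * head_dim : (i + 1) * head_dim, :, None, None]
    )
    conv.bias.data.copy_(wq.bias[i * head_dim : (i + 1) * head_dim])

# The per-head convs reproduce the fused projection.
x = torch.randn(1, dim)
ref = wq(x).view(n_heads, head_dim)
sha = torch.cat(
    [conv(x.view(1, dim, 1, 1)).view(1, head_dim) for conv in wq_sha]
)
print(torch.allclose(ref, sha, atol=1e-6))  # True
```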