[Model] Add support for Aya-23 8B Model by Cohere #2603

Merged: 14 commits merged into mlc-ai:main on Aug 5, 2024

Conversation

GunjanDhanuka (Contributor)

This PR adds support for the Aya-23 8B model, whose weights and config can be found here: https://huggingface.co/CohereForAI/aya-23-8B/tree/main

Also fixed a typo where LlamaForCausalLM was written as LlamaForCasualLM

There was an issue with CUDA graph while compiling the model, so pass --opt "flashinfer=1;cublas_gemm=1;cudagraph=0" to the mlc_llm compile command, as suggested by @MasterJH5574.
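For reference, a full compile invocation with that flag would look roughly like the following (the model directory, quantization, and output path here are illustrative, not taken from this PR):

    mlc_llm compile ./dist/aya-23-8B-q4f16_1-MLC/mlc-chat-config.json --device cuda --opt "flashinfer=1;cublas_gemm=1;cudagraph=0" -o ./dist/libs/aya-23-8B-q4f16_1-cuda.so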

Solves issue: mlc-ai/web-llm#483

@@ -129,7 +129,7 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b
prefill_chunk_size=model_config.prefill_chunk_size,
attention_sink_size=getattr(model_config, "attention_sink_size", -1),
tensor_parallel_shards=model_config.tensor_parallel_shards,
- conv_template=conversation,
+ conv_template=conversation,  # type: ignore
Member

Just curious why we disable mypy for this line?

GunjanDhanuka (Contributor, Author)

PyLance was raising a reportArgumentType issue for this line; the screenshot is attached below:

[screenshot: PyLance reportArgumentType warning]

return mapping


# def awq(model_config: CohereConfig, quantization: Quantization) -> ExternMapping:
Member

Please remove if we don't support.

GunjanDhanuka (Contributor, Author)

Just checked that AWQ weights for Aya-23 are available: https://huggingface.co/alijawad07/aya-23-8B-AWQ-GEMM

In that case we can support it; I will uncomment the awq part in cohere_loader.py.
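For context, a minimal sketch of what the uncommented function could look like; the signature matches the commented-out line in cohere_loader.py, but the body below is a hypothetical skeleton rather than the code added in this PR:

    def awq(model_config: CohereConfig, quantization: Quantization) -> ExternMapping:
        # Hypothetical skeleton: build a HuggingFace-AWQ -> MLC parameter mapping.
        mapping = ExternMapping()
        # The real implementation would register the AWQ tensors (e.g. qweight/qzeros/scales)
        # of every quantized linear layer and map them to the MLC parameter names.
        return mapping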

@alphaarea

Does this PR support CohereForAI/c4ai-command-r-plus? They're both CohereForCausalLM.

MasterJH5574 marked this pull request as draft on July 9, 2024, 04:24.
@MasterJH5574 (Member)

Will look into the tokenizer issue during mlc_llm chat.

@Ubospica (Collaborator)

@GunjanDhanuka The tokenizer issue is solved in this PR: #2649. Please tell me if there are any other related problems!

@tqchen (Contributor) commented on Jul 25, 2024

@GunjanDhanuka please rebase and cross check if things can run

@GunjanDhanuka (Contributor, Author) commented on Aug 2, 2024

> @GunjanDhanuka please rebase and cross check if things can run

Yes, the tokenizer issue is now resolved, but there was a delay because of a discrepancy between the outputs of mlc_llm chat and python -m mlc_llm.testing.debug_chat. The current commit (599f148) works fine with debug_chat, but mlc_llm chat is hit or miss at times.

Edit: The prompt is now aligned in both cases; the earlier mismatch was a misunderstanding of the blank template that mlc_llm chat passes in the first instance.
cc: @MasterJH5574 @tqchen

GunjanDhanuka marked this pull request as ready for review on August 2, 2024, 17:45.
@MasterJH5574 (Member) left a comment

Thank you @GunjanDhanuka!

MasterJH5574 merged commit c357be9 into mlc-ai:main on Aug 5, 2024
1 check passed
@DenisSergeevitch

Hi, I tried to convert a new version of aya-expanse-8b using this command:

mlc_llm compile "/content/dist/aya-expanse-8b-q4f16_1-MLC-webgpu/mlc-chat-config.json" --device webgpu -o /content/dist/aya-expanse-8b-q4f16_1-MLC-webgpu.wasm

It looks like the converted model does not perform well due to some conversion issue, though no errors were reported:

[2024-11-08 12:38:33] INFO auto_config.py:70: Found model configuration: /content/dist/aya-expanse-8b-q4f16_1-MLC-webgpu/mlc-chat-config.json
[2024-11-08 12:38:33] INFO auto_config.py:154: Found model type: cohere. Use `--model-type` to override.
Compiling with arguments:
  --config          CohereConfig(model_type='cohere', hidden_size=4096, vocab_size=256000, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=8, intermediate_size=14336, layer_norm_eps=1e-05, position_embedding_base=10000, context_window_size=8192, prefill_chunk_size=8192, head_dim=128, tensor_parallel_shards=1, max_batch_size=128, kwargs={})
  --quantization    GroupQuantize(name='q4f16_1', kind='group-quant', group_size=32, quantize_dtype='int4', storage_dtype='uint32', model_dtype='float16', linear_weight_layout='NK', quantize_embedding=True, quantize_final_fc=True, num_elem_per_storage=8, num_storage_per_group=4, max_int_value=7, tensor_parallel_shards=0)
  --model-type      cohere
  --target          {"host": {"kind": "llvm", "tag": "", "keys": ["cpu"], "mtriple": "wasm32-unknown-unknown-wasm"}, "max_num_threads": runtime.BoxInt(256), "kind": "webgpu", "tag": "", "keys": ["webgpu", "gpu"]}
  --opt             flashinfer=0;cublas_gemm=0;faster_transformer=0;cudagraph=0;cutlass=0;ipc_allreduce_strategy=NONE
  --system-lib-prefix ""
  --output          /content/dist/aya-expanse-8b-q4f16_1-MLC-webgpu.wasm
  --overrides       context_window_size=None;sliding_window_size=None;prefill_chunk_size=None;attention_sink_size=None;max_batch_size=None;tensor_parallel_shards=None;pipeline_parallel_stages=None
[2024-11-08 12:38:33] INFO compile.py:140: Creating model from: CohereConfig(model_type='cohere', hidden_size=4096, vocab_size=256000, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=8, intermediate_size=14336, layer_norm_eps=1e-05, position_embedding_base=10000, context_window_size=8192, prefill_chunk_size=8192, head_dim=128, tensor_parallel_shards=1, max_batch_size=128, kwargs={})
[2024-11-08 12:38:33] INFO compile.py:158: Exporting the model to TVM Unity compiler
[2024-11-08 12:38:36] INFO compile.py:164: Running optimizations using TVM Unity
[2024-11-08 12:38:36] INFO compile.py:185: Registering metadata: {'model_type': 'cohere', 'quantization': 'q4f16_1', 'context_window_size': 8192, 'sliding_window_size': -1, 'attention_sink_size': -1, 'prefill_chunk_size': 8192, 'tensor_parallel_shards': 1, 'pipeline_parallel_stages': 1, 'kv_state_kind': 'kv_cache', 'max_batch_size': 128}
[2024-11-08 12:38:36] WARNING auto_target.py:130: --system-lib-prefix is not specified when building a static library
[2024-11-08 12:38:38] INFO pipeline.py:54: Running TVM Relax graph-level optimizations
[2024-11-08 12:38:42] INFO pipeline.py:54: Lowering to TVM TIR kernels
[2024-11-08 12:38:47] INFO pipeline.py:54: Running TVM TIR-level optimizations
[2024-11-08 12:39:06] INFO pipeline.py:54: Running TVM Dlight low-level optimizations
[2024-11-08 12:39:08] INFO pipeline.py:54: Lowering to VM bytecode
[2024-11-08 12:39:11] INFO estimate_memory_usage.py:58: [Memory usage] Function `alloc_embedding_tensor`: 64.00 MB
[2024-11-08 12:39:11] INFO estimate_memory_usage.py:58: [Memory usage] Function `batch_decode`: 14.00 MB
[2024-11-08 12:39:11] INFO estimate_memory_usage.py:58: [Memory usage] Function `batch_prefill`: 897.00 MB
[2024-11-08 12:39:11] INFO estimate_memory_usage.py:58: [Memory usage] Function `batch_verify`: 896.00 MB
[2024-11-08 12:39:11] INFO estimate_memory_usage.py:58: [Memory usage] Function `create_tir_paged_kv_cache`: 0.00 MB
[2024-11-08 12:39:11] INFO estimate_memory_usage.py:58: [Memory usage] Function `decode`: 0.11 MB
[2024-11-08 12:39:11] INFO estimate_memory_usage.py:58: [Memory usage] Function `embed`: 64.00 MB
[2024-11-08 12:39:12] INFO estimate_memory_usage.py:58: [Memory usage] Function `prefill`: 896.01 MB
[2024-11-08 12:39:12] INFO estimate_memory_usage.py:58: [Memory usage] Function `softmax_with_temperature`: 0.00 MB
[2024-11-08 12:39:13] INFO pipeline.py:54: Compiling external modules
[2024-11-08 12:39:13] INFO pipeline.py:54: Compilation complete! Exporting to disk
[12:39:18] /workspace/tvm/src/target/llvm/codegen_llvm.cc:185: Warning: Set native vector bits to be 128 for wasm32
[2024-11-08 12:39:57] INFO model_metadata.py:95: Total memory usage without KV cache:: 5203.76 MB (Parameters: 4306.76 MB. Temporary buffer: 897.00 MB)
[2024-11-08 12:39:57] INFO model_metadata.py:103: To reduce memory usage, tweak `prefill_chunk_size`, `context_window_size` and `sliding_window_size`
[2024-11-08 12:39:57] INFO compile.py:207: Generated: /content/dist/aya-expanse-8b-q4f16_1-MLC-webgpu.wasm

Is this bug specific to the new version of Aya?

Model:
https://huggingface.co/shirman/aya-expanse-8b-MLC-WEBGPU

WebLLM demo:
https://shir-man.com/totally-not-a-deepl-clone/
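If the wasm compiles cleanly but generations are garbled, the earlier weight-conversion and config-generation steps are also worth double-checking; a typical sequence looks roughly like the following, with illustrative paths and a placeholder conversation template (not commands taken from this thread):

    mlc_llm convert_weight /path/to/aya-expanse-8b --quantization q4f16_1 -o ./dist/aya-expanse-8b-q4f16_1-MLC
    mlc_llm gen_config /path/to/aya-expanse-8b --quantization q4f16_1 --conv-template <template> -o ./dist/aya-expanse-8b-q4f16_1-MLC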

@time2bot

@DenisSergeevitch The good news is that we were able to load the model using your online demo page, even though we only used a Windows PC with 16 GB of RAM and an 8 GB integrated (onboard) GPU, the Intel UHD Graphics 630, closing every other open window to let the model load. That is great news, because we want to get it to work with minimal resources.

The bad news is that the output was almost gibberish: the model did respond in several languages, but the responses didn't make much sense. There were a lot of new lines and many repeats of the same few words.

We aim to let people use this model from anywhere on earth, even with low resources... Perhaps you can try converting the older model, aya-23 (link), and help humanity?

@DenisSergeevitch

@MasterJH5574 hello, could you please help guide us on how to debug the model re-compilation in this case?
