[Bug]: --enable-lora raises error while trying to start api_server #405

Closed
JHLEE17 opened this issue Oct 18, 2024 · 5 comments · Fixed by #415
Labels
bug Something isn't working

Comments

@JHLEE17

JHLEE17 commented Oct 18, 2024

Your current environment

The output of `python collect_env.py`
Collecting environment information...
/home/irteamsu/miniconda3/envs/jongho/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:366: UserWarning: torch.distributed.reduce_op is deprecated, please use torch.distributed.ReduceOp instead
  warnings.warn(
PyTorch version: 2.3.1a0+git4989238
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (conda-forge gcc 12.1.0-17) 12.1.0
Clang version: Could not collect
CMake version: version 3.30.2
Libc version: glibc-2.35

Python version: 3.10.14 (main, May  6 2024, 19:42:50) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-113-generic-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      52 bits physical, 57 bits virtual
Byte Order:                         Little Endian
CPU(s):                             160
On-line CPU(s) list:                0-159
Vendor ID:                          GenuineIntel
Model name:                         Intel(R) Xeon(R) Platinum 8380 CPU @ 2.30GHz
CPU family:                         6
Model:                              106
Thread(s) per core:                 2
Core(s) per socket:                 40
Socket(s):                          2
Stepping:                           6
CPU max MHz:                        3400.0000
CPU min MHz:                        800.0000
BogoMIPS:                           4600.00
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 ds_cpl smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 invpcid_single intel_ppin ssbd mba ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect wbnoinvd dtherm ida arat pln pts avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid fsrm md_clear pconfig flush_l1d arch_capabilities
L1d cache:                          3.8 MiB (80 instances)
L1i cache:                          2.5 MiB (80 instances)
L2 cache:                           100 MiB (80 instances)
L3 cache:                           120 MiB (2 instances)
NUMA node(s):                       2
NUMA node0 CPU(s):                  0-39,80-119
NUMA node1 CPU(s):                  40-79,120-159
Vulnerability Gather data sampling: Mitigation; Microcode
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Enhanced IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI Syscall hardening, KVM SW loop
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

Versions of relevant libraries:
[pip3] habana-torch-dataloader==1.17.0.495
[pip3] habana-torch-plugin==1.17.0.495
[pip3] numpy==1.26.4
[pip3] pynvml==8.0.4
[pip3] pytorch-lightning==2.4.0
[pip3] pyzmq==26.2.0
[pip3] torch==2.3.1a0+git4989238
[pip3] torch_tb_profiler==0.4.0
[pip3] torchaudio==2.3.0+952ea74
[pip3] torchdata==0.7.1+5e6f7b7
[pip3] torchmetrics==1.4.1
[pip3] torchtext==0.18.0a0+9bed85d
[pip3] torchvision==0.18.1a0+fe70bc8
[pip3] transformers==4.45.2
[pip3] triton==3.1.0
[conda] habana-torch-dataloader   1.17.0.495               pypi_0    pypi
[conda] habana-torch-plugin       1.17.0.495               pypi_0    pypi
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] pynvml                    8.0.4                    pypi_0    pypi
[conda] pytorch-lightning         2.4.0                    pypi_0    pypi
[conda] pyzmq                     26.2.0                   pypi_0    pypi
[conda] torch                     2.3.1a0+git4989238          pypi_0    pypi
[conda] torch-tb-profiler         0.4.0                    pypi_0    pypi
[conda] torchaudio                2.3.0+952ea74            pypi_0    pypi
[conda] torchdata                 0.7.1+5e6f7b7            pypi_0    pypi
[conda] torchmetrics              1.4.1                    pypi_0    pypi
[conda] torchtext                 0.18.0a0+9bed85d          pypi_0    pypi
[conda] torchvision               0.18.1a0+fe70bc8          pypi_0    pypi
[conda] transformers              4.45.2                   pypi_0    pypi
[conda] triton                    3.1.0                    pypi_0    pypi
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.6.3.dev553+g9276ccca
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
Could not collect

Model Input Dumps

No response

🐛 Describe the bug

I encountered an error while trying to start the api_server with Multi-LoRA.

The command I used is as follows (the error also reproduces without the last three lines of the command):

python -m vllm.entrypoints.openai.api_server \
    --model /home/irteamsu/models/Meta-Llama-3.1-8B-Instruct \
    --block-size 128 \
    --max-model-len 2048 \
    --disable-log-requests \
    --enable-lora \
    --max-loras 2 \
    --max-lora-rank 8 \
    --lora-modules lora-1=/home/irteamsu/models/Gaudi_LoRA_Llama-3-8B-Instruct lora-2=/home/irteamsu/models/Gaudi_LoRA_Llama-3-8B-Instruct
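For reference, `--lora-modules` takes space-separated `name=path` pairs; the args dump in the log below shows both entries parsed into `LoRAModulePath(name=..., path=..., base_model_name=None)`. The following is a minimal sketch of that mapping, assuming a hypothetical `parse_lora_modules` helper and a stand-in `LoRAModulePath` dataclass; it is not vLLM's actual argument parser.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class LoRAModulePath:
    # Stand-in mirroring the fields printed in the server's args dump.
    name: str
    path: str
    base_model_name: Optional[str] = None


def parse_lora_modules(values: List[str]) -> List[LoRAModulePath]:
    """Hypothetical parser: turn 'name=path' strings into LoRAModulePath entries."""
    modules = []
    for value in values:
        # Split on the first '=' only, so an '=' inside the path is preserved.
        name, sep, path = value.partition("=")
        if not sep or not name or not path:
            raise ValueError(f"expected name=path, got {value!r}")
        modules.append(LoRAModulePath(name=name, path=path))
    return modules


if __name__ == "__main__":
    adapter = "/home/irteamsu/models/Gaudi_LoRA_Llama-3-8B-Instruct"
    for module in parse_lora_modules([f"lora-1={adapter}", f"lora-2={adapter}"]):
        print(module)
```

Splitting on the first `=` only is a choice made for this sketch, not documented vLLM behavior.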

However, when I run this command, the following error occurs:

Error logs
~/habanaAI/vllm-fork$ python -m vllm.entrypoints.openai.api_server     --model /home/irteamsu/models/Meta-Llama-3-8B-Instruct     --block-size 128     --max-model-len 2048     --disable-log-requests     --enable-lora     --max-loras 2     --max-lora-rank 8     --lora-modules lora-1=/home/irteamsu/models/Gaudi_LoRA_Llama-3-8B-Instruct lora-2=/home/irteamsu/models/Gaudi_LoRA_Llama-3-8B-Instruct --port 8001
/home/irteamsu/miniconda3/envs/jongho/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:366: UserWarning: torch.distributed.reduce_op is deprecated, please use torch.distributed.ReduceOp instead
  warnings.warn(
INFO 10-18 16:07:06 api_server.py:527] vLLM API server version 0.6.3.dev553+g9276ccca
INFO 10-18 16:07:06 api_server.py:528] args: Namespace(host=None, port=8001, uvicorn_log_level='info', allow_credentials=False, allowed_origins=['*'], allowed_methods=['*'], allowed_headers=['*'], api_key=None, lora_modules=[LoRAModulePath(name='lora-1', path='/home/irteamsu/models/Gaudi_LoRA_Llama-3-8B-Instruct', base_model_name=None), LoRAModulePath(name='lora-2', path='/home/irteamsu/models/Gaudi_LoRA_Llama-3-8B-Instruct', base_model_name=None)], prompt_adapters=None, chat_template=None, response_role='assistant', ssl_keyfile=None, ssl_certfile=None, ssl_ca_certs=None, ssl_cert_reqs=0, root_path=None, middleware=[], return_tokens_as_token_ids=False, disable_frontend_multiprocessing=False, enable_auto_tool_choice=False, tool_call_parser=None, tool_parser_plugin='', model='/home/irteamsu/models/Meta-Llama-3-8B-Instruct', tokenizer=None, skip_tokenizer_init=False, revision=None, code_revision=None, tokenizer_revision=None, tokenizer_mode='auto', trust_remote_code=False, download_dir=None, load_format='auto', weights_load_device=None, config_format='auto', dtype='auto', kv_cache_dtype='auto', quantization_param_path=None, max_model_len=2048, guided_decoding_backend='outlines', distributed_executor_backend=None, worker_use_ray=False, pipeline_parallel_size=1, tensor_parallel_size=1, max_parallel_loading_workers=None, ray_workers_use_nsight=False, block_size=128, enable_prefix_caching=False, disable_sliding_window=False, use_v2_block_manager=True, num_lookahead_slots=0, seed=0, swap_space=4, cpu_offload_gb=0, gpu_memory_utilization=0.9, num_gpu_blocks_override=None, max_num_batched_tokens=None, max_num_seqs=256, max_logprobs=20, disable_log_stats=False, quantization=None, rope_scaling=None, rope_theta=None, enforce_eager=False, max_context_len_to_capture=None, max_seq_len_to_capture=8192, disable_custom_all_reduce=False, tokenizer_pool_size=0, tokenizer_pool_type='ray', tokenizer_pool_extra_config=None, limit_mm_per_prompt=None, mm_processor_kwargs=None, 
enable_lora=True, max_loras=2, max_lora_rank=8, lora_extra_vocab_size=256, lora_dtype='auto', long_lora_scaling_factors=None, max_cpu_loras=None, fully_sharded_loras=False, enable_prompt_adapter=False, max_prompt_adapters=1, max_prompt_adapter_token=0, device='auto', num_scheduler_steps=1, multi_step_stream_outputs=True, scheduler_delay_factor=0.0, enable_chunked_prefill=None, speculative_model=None, speculative_model_quantization=None, num_speculative_tokens=None, speculative_disable_mqa_scorer=False, speculative_draft_tensor_parallel_size=None, speculative_max_model_len=None, speculative_disable_by_batch_size=None, ngram_prompt_lookup_max=None, ngram_prompt_lookup_min=None, spec_decoding_acceptance_method='rejection_sampler', typical_acceptance_sampler_posterior_threshold=None, typical_acceptance_sampler_posterior_alpha=None, disable_logprobs_during_spec_decoding=None, model_loader_extra_config=None, ignore_patterns=[], preemption_mode=None, served_model_name=None, qlora_adapter_name_or_path=None, otlp_traces_endpoint=None, collect_detailed_traces=None, disable_async_output_proc=False, override_neuron_config=None, scheduling_policy='fcfs', disable_log_requests=True, max_log_len=None, disable_fastapi_docs=False)
INFO 10-18 16:07:06 api_server.py:165] Multiprocessing frontend to use ipc:///tmp/73468b3b-0de1-4f35-a284-2c7025932957 for IPC Path.
INFO 10-18 16:07:06 api_server.py:178] Started engine process with PID 954103
/home/irteamsu/miniconda3/envs/jongho/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:366: UserWarning: torch.distributed.reduce_op is deprecated, please use torch.distributed.ReduceOp instead
  warnings.warn(
INFO 10-18 16:07:12 llm_engine.py:238] Initializing an LLM engine (v0.6.3.dev553+g9276ccca) with config: model='/home/irteamsu/models/Meta-Llama-3-8B-Instruct', speculative_config=None, tokenizer='/home/irteamsu/models/Meta-Llama-3-8B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=2048, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, weights_load_device=hpu, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=hpu, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=/home/irteamsu/models/Meta-Llama-3-8B-Instruct, use_v2_block_manager=True, num_scheduler_steps=1, chunked_prefill_enabled=False multi_step_stream_outputs=True, enable_prefix_caching=False, use_async_output_proc=True, use_cached_outputs=True, mm_processor_kwargs=None)
WARNING 10-18 16:07:12 utils.py:794] Pin memory is not supported on HPU.
INFO 10-18 16:07:12 selector.py:147] Using HPUAttention backend.
INFO 10-18 16:07:12 hpu_model_runner.py:98] VLLM_PROMPT_BS_BUCKET_MIN=1 (default:1)
INFO 10-18 16:07:12 hpu_model_runner.py:98] VLLM_PROMPT_BS_BUCKET_STEP=32 (default:32)
INFO 10-18 16:07:12 hpu_model_runner.py:98] VLLM_PROMPT_BS_BUCKET_MAX=64 (default:64)
INFO 10-18 16:07:12 hpu_model_runner.py:98] VLLM_DECODE_BS_BUCKET_MIN=1 (default:1)
INFO 10-18 16:07:12 hpu_model_runner.py:98] VLLM_DECODE_BS_BUCKET_STEP=32 (default:32)
INFO 10-18 16:07:12 hpu_model_runner.py:98] VLLM_DECODE_BS_BUCKET_MAX=256 (default:256)
INFO 10-18 16:07:12 hpu_model_runner.py:98] VLLM_PROMPT_SEQ_BUCKET_MIN=128 (default:128)
INFO 10-18 16:07:12 hpu_model_runner.py:98] VLLM_PROMPT_SEQ_BUCKET_STEP=128 (default:128)
INFO 10-18 16:07:12 hpu_model_runner.py:98] VLLM_PROMPT_SEQ_BUCKET_MAX=1024 (default:1024)
INFO 10-18 16:07:12 hpu_model_runner.py:98] VLLM_DECODE_BLOCK_BUCKET_MIN=128 (default:128)
INFO 10-18 16:07:12 hpu_model_runner.py:98] VLLM_DECODE_BLOCK_BUCKET_STEP=128 (default:128)
INFO 10-18 16:07:12 hpu_model_runner.py:98] VLLM_DECODE_BLOCK_BUCKET_MAX=4096 (default:4096)
INFO 10-18 16:07:12 hpu_model_runner.py:706] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 64], seq:[128, 128, 1024]
INFO 10-18 16:07:12 hpu_model_runner.py:711] Decode bucket config (min, step, max_warmup) bs:[1, 32, 256], block:[128, 128, 4096]
============================= HABANA PT BRIDGE CONFIGURATION =========================== 
 PT_HPU_LAZY_MODE = 1
 PT_RECIPE_CACHE_PATH = 
 PT_CACHE_FOLDER_DELETE = 0
 PT_HPU_RECIPE_CACHE_CONFIG = 
 PT_HPU_MAX_COMPOUND_OP_SIZE = 9223372036854775807
 PT_HPU_LAZY_ACC_PAR_MODE = 1
 PT_HPU_ENABLE_REFINE_DYNAMIC_SHAPES = 0
---------------------------: System Configuration :---------------------------
Num CPU Cores : 160
CPU RAM       : 2113407780 KB
------------------------------------------------------------------------------
INFO 10-18 16:07:16 selector.py:147] Using HPUAttention backend.
INFO 10-18 16:07:16 loader.py:405] Loading weights on hpu...
Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:01<00:03,  1.08s/it]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:01<00:01,  1.05it/s]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:02<00:00,  1.11it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:03<00:00,  1.58it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:03<00:00,  1.33it/s]

INFO 10-18 16:07:19 hpu_model_runner.py:609] Pre-loading model weights on hpu:0 took 14.98 GiB of device memory (14.98 GiB/94.62 GiB used) and 882.2 MiB of host memory (115.9 GiB/1.968 TiB used)
INFO 10-18 16:07:20 hpu_model_runner.py:656] Wrapping in HPU Graph took 0 B of device memory (15.07 GiB/94.62 GiB used) and -252 KiB of host memory (115.9 GiB/1.968 TiB used)
INFO 10-18 16:07:20 hpu_model_runner.py:660] Loading model weights took in total 15.07 GiB of device memory (15.07 GiB/94.62 GiB used) and 935.3 MiB of host memory (115.9 GiB/1.968 TiB used)
Process SpawnProcess-1:
Traceback (most recent call last):
  File "/home/irteamsu/miniconda3/envs/jongho/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/home/irteamsu/miniconda3/envs/jongho/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/irteamsu/works/jh/habanaAI/vllm-fork/vllm/engine/multiprocessing/engine.py", line 390, in run_mp_engine
    engine = MQLLMEngine.from_engine_args(engine_args=engine_args,
  File "/home/irteamsu/works/jh/habanaAI/vllm-fork/vllm/engine/multiprocessing/engine.py", line 138, in from_engine_args
    return cls(
  File "/home/irteamsu/works/jh/habanaAI/vllm-fork/vllm/engine/multiprocessing/engine.py", line 78, in __init__
    self.engine = LLMEngine(*args,
  File "/home/irteamsu/works/jh/habanaAI/vllm-fork/vllm/engine/llm_engine.py", line 354, in __init__
    self._initialize_kv_caches()
  File "/home/irteamsu/works/jh/habanaAI/vllm-fork/vllm/engine/llm_engine.py", line 489, in _initialize_kv_caches
    self.model_executor.determine_num_available_blocks())
  File "/home/irteamsu/works/jh/habanaAI/vllm-fork/vllm/executor/hpu_executor.py", line 77, in determine_num_available_blocks
    return self.driver_worker.determine_num_available_blocks()
  File "/home/irteamsu/miniconda3/envs/jongho/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/irteamsu/works/jh/habanaAI/vllm-fork/vllm/worker/hpu_worker.py", line 180, in determine_num_available_blocks
    self.model_runner.profile_run()
  File "/home/irteamsu/works/jh/habanaAI/vllm-fork/vllm/worker/hpu_model_runner.py", line 1265, in profile_run
    self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches,
  File "/home/irteamsu/works/jh/habanaAI/vllm-fork/vllm/worker/hpu_model_runner.py", line 1344, in warmup_scenario
    self.execute_model(inputs, kv_caches, warmup_mode=True)
  File "/home/irteamsu/miniconda3/envs/jongho/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/irteamsu/works/jh/habanaAI/vllm-fork/vllm/worker/hpu_model_runner.py", line 1868, in execute_model
    self.set_active_loras(model_input.lora_requests,
  File "/home/irteamsu/works/jh/habanaAI/vllm-fork/vllm/worker/hpu_model_runner.py", line 1362, in set_active_loras
    self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
  File "/home/irteamsu/works/jh/habanaAI/vllm-fork/vllm/lora/worker_manager.py", line 136, in set_active_adapters
    set_active_adapters_worker(requests, mapping, self._apply_adapters,
  File "/home/irteamsu/works/jh/habanaAI/vllm-fork/vllm/adapter_commons/utils.py", line 53, in set_active_adapters_worker
    set_adapter_mapping_func(mapping)
  File "/home/irteamsu/works/jh/habanaAI/vllm-fork/vllm/lora/models.py", line 745, in set_adapter_mapping
    self._last_mapping = set_adapter_mapping(mapping, self._last_mapping,
  File "/home/irteamsu/works/jh/habanaAI/vllm-fork/vllm/adapter_commons/utils.py", line 28, in set_adapter_mapping
    set_mapping_func(mapping)
  File "/home/irteamsu/works/jh/habanaAI/vllm-fork/vllm/lora/models.py", line 541, in _set_adapter_mapping
    self.punica_wrapper.update_metadata(
  File "/home/irteamsu/works/jh/habanaAI/vllm-fork/vllm/lora/punica.py", line 241, in update_metadata
    self._update_prefill_metada(self.token_lora_indices)
  File "/home/irteamsu/works/jh/habanaAI/vllm-fork/vllm/lora/punica.py", line 288, in _update_prefill_metada
    no_lora) = compute_meta(token_lora_tensor)
  File "/home/irteamsu/works/jh/habanaAI/vllm-fork/vllm/lora/punica.py", line 41, in compute_meta
    lora_indices_tensor, seq_length_tensor = torch.unique_consecutive(
  File "/home/irteamsu/miniconda3/envs/jongho/lib/python3.10/site-packages/torch/_jit_internal.py", line 497, in fn
    return if_false(*args, **kwargs)
  File "/home/irteamsu/miniconda3/envs/jongho/lib/python3.10/site-packages/torch/_jit_internal.py", line 495, in fn
    return if_true(*args, **kwargs)
  File "/home/irteamsu/miniconda3/envs/jongho/lib/python3.10/site-packages/torch/functional.py", line 1048, in _consecutive_return_counts
    output, _, counts = _unique_consecutive_impl(input, return_inverse, return_counts, dim)
  File "/home/irteamsu/miniconda3/envs/jongho/lib/python3.10/site-packages/torch/functional.py", line 975, in _unique_consecutive_impl
    output, inverse_indices, counts = _VF.unique_consecutive(  # type: ignore[attr-defined]
RuntimeError: synNodeCreateWithId failed for node: slice_insert with synStatus 1 [Invalid argument]. .

The issue occurs on both commit 9276ccc (habana_main) and commit d6bd375 (remove-lora-warmup-constraints). I've verified that the model paths are correct and that the models are accessible.

Any guidance on resolving this issue would be greatly appreciated.
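The last vLLM frame in the traceback is `compute_meta` in `vllm/lora/punica.py`, which calls `torch.unique_consecutive(..., return_counts=True)` on the per-token LoRA index tensor. For a 1-D tensor, that call run-length-encodes the indices into the LoRA ID of each consecutive token segment and that segment's length. Below is a pure-Python sketch of the same computation, plus per-segment start offsets from a cumulative sum. `segment_lora_tokens` is a hypothetical name, the start-offset step is an assumption about how such metadata is typically consumed, and `-1` is used in the example for tokens without an adapter purely for illustration.

```python
from itertools import accumulate, groupby
from typing import List, Tuple


def segment_lora_tokens(
    token_lora_indices: List[int],
) -> Tuple[List[int], List[int], List[int]]:
    """Hypothetical helper: run-length-encode per-token LoRA indices.

    Returns (lora_ids, seq_lengths, seq_starts). The first two correspond to
    torch.unique_consecutive(t, return_counts=True) for a 1-D tensor t.
    """
    lora_ids: List[int] = []
    seq_lengths: List[int] = []
    # groupby merges only adjacent equal values, like unique_consecutive.
    for lora_id, run in groupby(token_lora_indices):
        lora_ids.append(lora_id)
        seq_lengths.append(sum(1 for _ in run))
    # Start offset of each segment = total length of all preceding segments.
    seq_starts = [end - n for end, n in zip(accumulate(seq_lengths), seq_lengths)]
    return lora_ids, seq_lengths, seq_starts


if __name__ == "__main__":
    ids, lengths, starts = segment_lora_tokens([0, 0, 0, 1, 1, -1, -1, -1, -1])
    print(ids, lengths, starts)  # [0, 1, -1] [3, 2, 4] [0, 3, 5]
```

Only adjacent equal indices are merged, so `[0, 1, 0]` yields three segments. The sketch only illustrates the shapes involved; per the fix later in this thread, the actual failure came from out-of-bound slicing caused by a redundant `set_active_loras` call during warmup.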

JHLEE17 added the bug (Something isn't working) label Oct 18, 2024
JHLEE17 changed the title from "[Bug]: --enable-lora raises error" to "[Bug]: --enable-lora raises error while trying to start api_server" Oct 18, 2024
@vivekgoe

@JHLEE17 Thanks for raising this issue; we will check it immediately. Can you please share which SynapseAI release you are using to run this test?

@JHLEE17
Author

JHLEE17 commented Oct 18, 2024

I ran it on version 1.17, but I got a similar error on version 1.18.

vivekgoe self-assigned this Oct 18, 2024
@vivekgoe

We seem to have a backward-compatibility issue: #382 works with the latest SynapseAI code (not yet released) but throws the above error with SynapseAI 1.18.0. We will work on a fix and get back to you as soon as possible.

@michalkuligowski

@vivekgoe @JHLEE17 This issue occurs with SynapseAI 1.18.0 and HabanaAI/vllm-fork 1.18.0; do I understand correctly?

@vivekgoe

@michalkuligowski No, the issue occurs only with SynapseAI 1.18.0 + HabanaAI/vllm-fork (master_next branch), whereas SynapseAI 1.19.0 (not yet released) + HabanaAI/vllm-fork (master_next branch) works correctly.
We are using HabanaAI/vllm-fork (master_next) because that is where we added a fix to improve decode performance with LoRA enabled. This fix appears to expose an issue in SynapseAI 1.18.0 that was fixed in 1.19.0.

michalkuligowski pushed a commit that referenced this issue Oct 22, 2024
CUDA uses `capture` for warmup runs and `execute_model` for actual runs.
During each phase they call `set_active_loras` only once. HPU uses
`execute_model` for both warmup and actual runs. Since `execute_model`
already takes care of `set_active_loras` internally, the redundant call
can be removed.

This special handling is redundant and incorrect: it causes
out-of-bounds slicing in the decode phase, as reported in
#405.

This PR removes the special handling of the `set_active_loras` call
from warmup runs and resolves the issue reported in
#405.
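The call-flow change in this commit can be sketched with a toy runner. `ToyHPUModelRunner` and its `warmup_scenario_*` methods are hypothetical; only the `execute_model` and `set_active_loras` names come from the commit message, and where the removed warmup-specific call sat is illustrative. The sketch shows only the redundancy (the mapping is set twice for one warmup pass before the fix, once after); it does not model the out-of-bound slicing itself.

```python
from typing import List


class ToyHPUModelRunner:
    """Hypothetical stand-in for an HPU model runner that logs LoRA activations."""

    def __init__(self) -> None:
        self.activation_log: List[str] = []

    def set_active_loras(self, phase: str) -> None:
        # The real runner would update the LoRA mapping/metadata here.
        self.activation_log.append(phase)

    def execute_model(self, phase: str) -> str:
        # execute_model activates LoRAs itself, for warmup and actual runs alike.
        self.set_active_loras(phase)
        return f"ran {phase}"

    def warmup_scenario_before_fix(self) -> str:
        # Hypothetical placement of the warmup-specific call removed by the fix:
        # the mapping is set here and again inside execute_model.
        self.set_active_loras("warmup")
        return self.execute_model("warmup")

    def warmup_scenario_after_fix(self) -> str:
        # After the fix, warmup relies on execute_model alone.
        return self.execute_model("warmup")


if __name__ == "__main__":
    before, after = ToyHPUModelRunner(), ToyHPUModelRunner()
    before.warmup_scenario_before_fix()
    after.warmup_scenario_after_fix()
    print(len(before.activation_log), len(after.activation_log))  # 2 1
```

Routing warmup through the same `execute_model` path as real runs keeps a single place responsible for activating LoRAs, which mirrors how the commit message describes the CUDA flow calling `set_active_loras` once per phase.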
xuechendi pushed a commit to xuechendi/vllm-fork that referenced this issue Oct 23, 2024
5 participants