Skip to content

Commit

Permalink
Resolve test failures
Browse files Browse the repository at this point in the history
  • Loading branch information
Jeffwan committed Jul 9, 2024
1 parent 88e4bb5 commit 5987d94
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 21 deletions.
43 changes: 27 additions & 16 deletions tests/lora/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from unittest.mock import patch

import pytest
from huggingface_hub.utils import RepositoryNotFoundError
from huggingface_hub.utils import HfHubHTTPError
from torch import nn

from vllm.lora.utils import (get_lora_absolute_path,
Expand Down Expand Up @@ -188,41 +188,52 @@ def test_lru_cache():


# Unit tests for get_lora_absolute_path
@patch('os.path.isabs', True)
@patch('os.path.isabs')
def test_get_lora_absolute_path_absolute(mock_isabs):
path = '/absolute/path/to/lora'
mock_isabs.return_value = True
assert get_lora_absolute_path(path) == path


@patch('os.path.expanduser', '/home/user/relative/path/to/lora')
@patch('os.path.expanduser')
def test_get_lora_absolute_path_expanduser(mock_expanduser):
# Path with ~ that needs to be expanded
path = '~/relative/path/to/lora'
assert get_lora_absolute_path(path) == '/home/user/relative/path/to/lora'
absolute_path = '/home/user/relative/path/to/lora'
mock_expanduser.return_value = absolute_path
assert get_lora_absolute_path(path) == absolute_path


@patch('os.path.exists', True)
@patch('os.path.abspath', '/absolute/path/to/lora')
def test_get_lora_absolute_path_local_existing(mock_exists, mock_abspath):
@patch('os.path.exists')
@patch('os.path.abspath')
def test_get_lora_absolute_path_local_existing(mock_abspath, mock_exist):
# Relative path that exists locally
path = 'relative/path/to/lora'
assert get_lora_absolute_path(path) == '/absolute/path/to/lora'
absolute_path = '/absolute/path/to/lora'
mock_exist.return_value = True
mock_abspath.return_value = absolute_path
assert get_lora_absolute_path(path) == absolute_path


@patch('huggingface_hub.snapshot_download', '/mock/snapshot/path')
@patch('os.path.exists', False)
def test_get_lora_absolute_path_huggingface(mock_exists,
@patch('huggingface_hub.snapshot_download')
@patch('os.path.exists')
def test_get_lora_absolute_path_huggingface(mock_exist,
mock_snapshot_download):
# Hugging Face model identifier
path = 'org/repo'
assert get_lora_absolute_path(path) == '/mock/snapshot/path'
absolute_path = '/mock/snapshot/path'
mock_exist.return_value = False
mock_snapshot_download.return_value = absolute_path
assert get_lora_absolute_path(path) == absolute_path


@patch('huggingface_hub.snapshot_download',
side_effect=RepositoryNotFoundError)
@patch('os.path.exists', False)
def test_get_lora_absolute_path_huggingface_error(mock_exists,
@patch('huggingface_hub.snapshot_download')
@patch('os.path.exists')
def test_get_lora_absolute_path_huggingface_error(mock_exist,
mock_snapshot_download):
# Hugging Face model identifier with download error
path = 'org/repo'
mock_exist.return_value = False
mock_snapshot_download.side_effect = HfHubHTTPError(
"failed to query model info")
assert get_lora_absolute_path(path) == path
11 changes: 6 additions & 5 deletions vllm/lora/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import os
from typing import List, Optional, Set, Tuple, Type

from huggingface_hub import snapshot_download
from huggingface_hub.utils import (EntryNotFoundError, HFValidationError,
RepositoryNotFoundError)
import huggingface_hub
from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
HFValidationError, RepositoryNotFoundError)
from torch import nn
from transformers import PretrainedConfig

Expand Down Expand Up @@ -142,8 +142,9 @@ def get_lora_absolute_path(lora_path: str) -> str:

# If the path does not exist locally, assume it's a Hugging Face repo.
try:
local_snapshot_path = snapshot_download(repo_id=lora_path)
except (RepositoryNotFoundError, EntryNotFoundError,
local_snapshot_path = huggingface_hub.snapshot_download(
repo_id=lora_path)
except (HfHubHTTPError, RepositoryNotFoundError, EntryNotFoundError,
HFValidationError) as e:
# Handle errors that may occur during the download
# Return original path instead instead of throwing error here
Expand Down

0 comments on commit 5987d94

Please sign in to comment.