|
5 | 5 | from dataclasses import dataclass |
6 | 6 | from functools import cached_property |
7 | 7 | from pathlib import Path |
8 | | -from typing import (Any, Dict, List, Literal, Optional, Tuple, TypedDict, |
9 | | - TypeVar) |
| 8 | +from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, |
| 9 | + TypedDict, TypeVar) |
10 | 10 |
|
11 | 11 | import pytest |
12 | 12 | import torch |
13 | 13 | import torch.nn as nn |
14 | 14 | import torch.nn.functional as F |
15 | 15 | from PIL import Image |
16 | 16 | from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq, |
17 | | - AutoProcessor, AutoTokenizer, BatchEncoding) |
| 17 | + AutoTokenizer, BatchEncoding) |
18 | 18 |
|
19 | 19 | from vllm import LLM, SamplingParams |
20 | 20 | from vllm.config import TokenizerPoolConfig, VisionLanguageConfig |
21 | 21 | from vllm.distributed import (destroy_distributed_environment, |
22 | 22 | destroy_model_parallel) |
23 | 23 | from vllm.inputs import TextPrompt |
24 | 24 | from vllm.logger import init_logger |
25 | | -from vllm.multimodal import MultiModalData |
26 | | -from vllm.multimodal.image import ImageFeatureData, ImagePixelData |
| 25 | + |
| 26 | +if TYPE_CHECKING: |
| 27 | + from vllm.multimodal import MultiModalData |
| 28 | +else: |
| 29 | + # importing vllm.multimodal at runtime will call torch.cuda.device_count() |
| 30 | + MultiModalData = None |
27 | 31 | from vllm.sequence import SampleLogprobs |
28 | 32 | from vllm.utils import cuda_device_count_stateless, is_cpu |
29 | 33 |
|
@@ -63,6 +67,10 @@ def for_hf(self) -> Image.Image: |
63 | 67 | return self.pil_image |
64 | 68 |
|
65 | 69 | def for_vllm(self, vision_config: VisionLanguageConfig) -> MultiModalData: |
| 70 | + # don't put this import at the top level |
| 71 | + # it will call torch.cuda.device_count() |
| 72 | + from vllm.multimodal.image import ImageFeatureData # noqa: F401 |
| 73 | + from vllm.multimodal.image import ImagePixelData |
66 | 74 | image_input_type = vision_config.image_input_type |
67 | 75 | ImageInputType = VisionLanguageConfig.ImageInputType |
68 | 76 |
|
@@ -217,6 +225,9 @@ def __init__( |
217 | 225 | ) |
218 | 226 |
|
219 | 227 | try: |
| 228 | + # don't put this import at the top level |
| 229 | + # it will call torch.cuda.device_count() |
| 230 | + from transformers import AutoProcessor # noqa: F401 |
220 | 231 | self.processor = AutoProcessor.from_pretrained( |
221 | 232 | model_name, |
222 | 233 | torch_dtype=torch_dtype, |
|
0 commit comments