| 
5 | 5 | from dataclasses import dataclass  | 
6 | 6 | from functools import cached_property  | 
7 | 7 | from pathlib import Path  | 
8 |  | -from typing import (Any, Dict, List, Literal, Optional, Tuple, TypedDict,  | 
9 |  | -                    TypeVar)  | 
 | 8 | +from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple,  | 
 | 9 | +                    TypedDict, TypeVar)  | 
10 | 10 | 
 
  | 
11 | 11 | import pytest  | 
12 | 12 | import torch  | 
13 | 13 | import torch.nn as nn  | 
14 | 14 | import torch.nn.functional as F  | 
15 | 15 | from PIL import Image  | 
16 | 16 | from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq,  | 
17 |  | -                          AutoProcessor, AutoTokenizer, BatchEncoding)  | 
 | 17 | +                          AutoTokenizer, BatchEncoding)  | 
18 | 18 | 
 
  | 
19 | 19 | from vllm import LLM, SamplingParams  | 
20 | 20 | from vllm.config import TokenizerPoolConfig, VisionLanguageConfig  | 
21 | 21 | from vllm.distributed import (destroy_distributed_environment,  | 
22 | 22 |                               destroy_model_parallel)  | 
23 | 23 | from vllm.inputs import TextPrompt  | 
24 | 24 | from vllm.logger import init_logger  | 
25 |  | -from vllm.multimodal import MultiModalData  | 
26 |  | -from vllm.multimodal.image import ImageFeatureData, ImagePixelData  | 
 | 25 | + | 
 | 26 | +if TYPE_CHECKING:  | 
 | 27 | +    from vllm.multimodal import MultiModalData  | 
 | 28 | +else:  | 
 | 29 | +    # a real import here would call torch.cuda.device_count()  | 
 | 30 | +    MultiModalData = None  | 
27 | 31 | from vllm.sequence import SampleLogprobs  | 
28 | 32 | from vllm.utils import cuda_device_count_stateless, is_cpu  | 
29 | 33 | 
 
  | 
@@ -63,6 +67,10 @@ def for_hf(self) -> Image.Image:  | 
63 | 67 |         return self.pil_image  | 
64 | 68 | 
 
  | 
65 | 69 |     def for_vllm(self, vision_config: VisionLanguageConfig) -> MultiModalData:  | 
 | 70 | +        # don't put this import at the top level  | 
 | 71 | +        # it will call torch.cuda.device_count()  | 
 | 72 | +        from vllm.multimodal.image import ImageFeatureData  # noqa: F401  | 
 | 73 | +        from vllm.multimodal.image import ImagePixelData  | 
66 | 74 |         image_input_type = vision_config.image_input_type  | 
67 | 75 |         ImageInputType = VisionLanguageConfig.ImageInputType  | 
68 | 76 | 
 
  | 
@@ -216,6 +224,9 @@ def __init__(  | 
216 | 224 |         )  | 
217 | 225 | 
 
  | 
218 | 226 |         try:  | 
 | 227 | +            # don't put this import at the top level  | 
 | 228 | +            # it will call torch.cuda.device_count()  | 
 | 229 | +            from transformers import AutoProcessor  # noqa: F401  | 
219 | 230 |             self.processor = AutoProcessor.from_pretrained(  | 
220 | 231 |                 model_name,  | 
221 | 232 |                 torch_dtype=torch_dtype,  | 
 | 
0 commit comments