@@ -59,6 +59,7 @@
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors

 from .idefics2_vision_model import Idefics2VisionTransformer
@@ -1184,7 +1185,7 @@ def init_resampler(self,
                                    quant_config=quant_config,
                                    prefix=prefix)

-        return resampler.to(device="cuda", dtype=torch.get_default_dtype())
+        return resampler.to(device=current_platform.device_type, dtype=torch.get_default_dtype())

     def get_vision_embedding(
         self,
@@ -1266,7 +1267,7 @@ def init_resampler(self,
                                    quant_config=quant_config,
                                    prefix=prefix)

-        return resampler.to(device="cuda", dtype=torch.get_default_dtype())
+        return resampler.to(device=current_platform.device_type, dtype=torch.get_default_dtype())

     def get_vision_embedding(
         self,
@@ -1360,7 +1361,7 @@ def init_resampler(self,
                                    quant_config=quant_config,
                                    prefix=prefix)

-        return resampler.to(device="cuda", dtype=torch.get_default_dtype())
+        return resampler.to(device=current_platform.device_type, dtype=torch.get_default_dtype())

     def get_vision_embedding(
         self,
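For context, a minimal standalone sketch of the pattern this diff applies, assuming only what the diff itself shows: that vllm.platforms.current_platform exposes a device_type string for the active backend. The helper name below is hypothetical, not part of the change.

# Hypothetical sketch: move a module to whatever device the platform layer
# reports, instead of hard-coding "cuda" as the earlier code did.
import torch

from vllm.platforms import current_platform


def move_to_platform_device(module: torch.nn.Module) -> torch.nn.Module:
    # current_platform.device_type is a device string such as "cuda" or "cpu",
    # depending on the platform vLLM detected at startup.
    return module.to(device=current_platform.device_type,
                     dtype=torch.get_default_dtype())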