|
59 | 59 | from vllm.multimodal.processing import (BaseMultiModalProcessor, |
60 | 60 | BaseProcessingInfo, PromptReplacement) |
61 | 61 | from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs |
| 62 | +from vllm.platforms import current_platform |
62 | 63 | from vllm.sequence import IntermediateTensors |
63 | 64 |
|
64 | 65 | from .idefics2_vision_model import Idefics2VisionTransformer |
@@ -1184,7 +1185,8 @@ def init_resampler(self, |
1184 | 1185 | quant_config=quant_config, |
1185 | 1186 | prefix=prefix) |
1186 | 1187 |
|
1187 | | - return resampler.to(device="cuda", dtype=torch.get_default_dtype()) |
| 1188 | + return resampler.to(device=current_platform.device_type, |
| 1189 | + dtype=torch.get_default_dtype()) |
1188 | 1190 |
|
1189 | 1191 | def get_vision_embedding( |
1190 | 1192 | self, |
@@ -1266,7 +1268,8 @@ def init_resampler(self, |
1266 | 1268 | quant_config=quant_config, |
1267 | 1269 | prefix=prefix) |
1268 | 1270 |
|
1269 | | - return resampler.to(device="cuda", dtype=torch.get_default_dtype()) |
| 1271 | + return resampler.to(device=current_platform.device_type, |
| 1272 | + dtype=torch.get_default_dtype()) |
1270 | 1273 |
|
1271 | 1274 | def get_vision_embedding( |
1272 | 1275 | self, |
@@ -1360,7 +1363,8 @@ def init_resampler(self, |
1360 | 1363 | quant_config=quant_config, |
1361 | 1364 | prefix=prefix) |
1362 | 1365 |
|
1363 | | - return resampler.to(device="cuda", dtype=torch.get_default_dtype()) |
| 1366 | + return resampler.to(device=current_platform.device_type, |
| 1367 | + dtype=torch.get_default_dtype()) |
1364 | 1368 |
|
1365 | 1369 | def get_vision_embedding( |
1366 | 1370 | self, |
|
0 commit comments