                            SequenceData)

 from .interfaces import SupportsMultiModal
-from .utils import merge_multimodal_embeddings
+from .utils import flatten_bn, merge_multimodal_embeddings

 # Cannot find the following 2 numbers from hf config.
 _IMAGE_TOKEN_ID = 71011
@@ -165,7 +165,7 @@ def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs): |
             model_config.model)

         model_image_input = _fuyu_image_preprocess(image_processor, image_data)
-        image_patches = torch.stack([
+        image_patches = torch.cat([
             image_patch[0]
             for image_patch in model_image_input["image_patches"]
         ])
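
A note on the stack-to-cat switch in the hunk above: torch.stack creates a new leading
dimension and requires every input tensor to have exactly the same shape, while torch.cat
joins tensors along the existing first dimension, so it also works when images yield
different numbers of patches and it produces a single 2-D patch tensor, which matches the
last-dimension-only check added later in this diff. A minimal sketch of the difference;
the patch counts and the 2700-dim patch size (3 channels x 30 x 30) are illustrative
assumptions, not values taken from the diff:

import torch

# Hypothetical per-image patch tensors with different patch counts.
a = torch.randn(4, 2700)  # image 1: 4 flattened patches
b = torch.randn(6, 2700)  # image 2: 6 flattened patches

# torch.stack([a, b]) would raise a RuntimeError because the shapes differ;
# torch.cat joins along dim 0 and yields one flat patch tensor instead.
patches = torch.cat([a, b])
print(patches.shape)  # torch.Size([10, 2700])
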
@@ -210,7 +210,7 @@ def input_mapper_for_fuyu(ctx: InputContext, data: object): |
         ])

     # image has been processed with prompt in input processor
-    return MultiModalInputs({"image_patches": data})
+    return MultiModalInputs({"pixel_values": data})


 @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_fuyu)
@@ -242,23 +242,42 @@ def __init__(self, |
                                                    cache_config=cache_config,
                                                    quant_config=quant_config)

+    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
+
+        h = w = self.config.patch_size
+        num_channels = self.config.num_channels
+        expected_dims = num_channels * h * w
+
+        def _validate_shape(d: torch.Tensor):
+            actual_dims = d.size(-1)
+
+            if actual_dims != expected_dims:
+                expected_expr = str(expected_dims)
+                raise ValueError(
+                    "The expected shape of pixel values per image per batch "
+                    f" per patch is {expected_expr}. "
+                    f"You supplied {tuple(d.shape)}.")
+
+        for d in data:
+            _validate_shape(d)
+
+        return data.to(self.vision_embed_tokens.weight.dtype)
+
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[FuyuImagePixelInputs]:
-        image_patches = kwargs.pop("image_patches", None)
+        pixel_values = kwargs.pop("pixel_values", None)

-        if isinstance(image_patches, torch.Tensor):
-            # Remove the N dimension until multiple images are supported.
-            image_patches = image_patches.squeeze(1)
+        if pixel_values is not None:
+            if not isinstance(pixel_values, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of image patches. "
+                                 f"Got type: {type(pixel_values)}")
+
+            return FuyuImagePixelInputs(
+                type="pixel_values",
+                data=self._validate_pixel_values(
+                    flatten_bn(pixel_values, concat=True)),
+            )

-            expected_feature_size = self.image_feature_size
-            if image_patches.size(-1) != expected_feature_size:
-                raise ValueError(
-                    f"Expected image patches to have the last dimension of "
-                    f"{expected_feature_size}, got {image_patches.size(-1)}")
-            image_patches = image_patches.to(
-                self.vision_embed_tokens.weight.dtype)
-            return FuyuImagePixelInputs(type="pixel_values",
-                                        data=image_patches)
         return None

     def _process_image_input(
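
For context on the flatten_bn helper newly imported at the top of the diff: the sketch
below shows roughly what flatten_bn(pixel_values, concat=True) is assumed to do before the
new _validate_pixel_values check runs. The helper itself comes from .utils in the diff, but
this re-implementation, its docstring, and the toy shapes are illustrative assumptions
rather than the library's actual code.

import torch
from typing import List, Union

def flatten_bn_sketch(x: Union[torch.Tensor, List[torch.Tensor]],
                      *, concat: bool = False) -> torch.Tensor:
    """Hypothetical stand-in for .utils.flatten_bn (illustration only)."""
    if isinstance(x, torch.Tensor):
        # Batched tensor input: collapse the leading batch and
        # images-per-prompt dimensions into one.
        return x.flatten(0, 1)
    if concat:
        # List of per-image patch tensors with varying patch counts:
        # concatenate along dim 0 into a single (total_patches, dim) tensor.
        return torch.cat(x)
    raise NotImplementedError("only the concat=True path is sketched here")

# Toy batch: two images that produced 4 and 6 flattened patches of 2700 values
# each (assuming Fuyu's 3-channel 30x30 patches).
pixel_values = [torch.randn(4, 2700), torch.randn(6, 2700)]
flat = flatten_bn_sketch(pixel_values, concat=True)
print(flat.shape)  # torch.Size([10, 2700])

Flattening first means _validate_pixel_values only has to compare the trailing dimension
against num_channels * patch_size * patch_size, no matter how many patches each image
contributed.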