Skip to content

Commit 7bba4d1

Browse files
authored
Fix video processing channel format (#41603)
fix
1 parent ab92534 commit 7bba4d1

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

src/transformers/video_processing_utils.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,12 @@
5757
VideoInput,
5858
VideoMetadata,
5959
group_videos_by_shape,
60+
infer_channel_dimension_format,
6061
is_valid_video,
6162
load_video,
6263
make_batched_metadata,
6364
make_batched_videos,
6465
reorder_videos,
65-
to_channel_dimension_format,
6666
)
6767

6868

@@ -338,10 +338,16 @@ def _prepare_input_videos(
338338
for video in videos:
339339
# `make_batched_videos` always returns a 4D array per video
340340
if isinstance(video, np.ndarray):
341-
video = to_channel_dimension_format(video, ChannelDimension.FIRST, input_data_format)
342341
# not using F.to_tensor as it doesn't handle (C, H, W) numpy arrays
343342
video = torch.from_numpy(video).contiguous()
344343

344+
# Infer the channel dimension format if not provided
345+
if input_data_format is None:
346+
input_data_format = infer_channel_dimension_format(video)
347+
348+
if input_data_format == ChannelDimension.LAST:
349+
video = video.permute(0, 3, 1, 2).contiguous()
350+
345351
if device is not None:
346352
video = video.to(device)
347353

0 commit comments

Comments
 (0)