Skip to content

Commit

Permalink
✏️ Add support for passing dict to image_processor_config in `ImageCaptioningDataset`
Browse files Browse the repository at this point in the history
  • Loading branch information
arxyzan committed Jun 9, 2024
1 parent 94ef704 commit 1de957c
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions hezar/data/datasets/image_captioning_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ class ImageCaptioningDatasetConfig(DatasetConfig):
text_column (str): Column name for text in the dataset.
images_paths_column (str): Column name for image paths in the dataset.
max_length (int): Maximum length of text.
test_split_size (float): Size of the test split.
image_processor_config (ImageProcessorConfig): Configuration for image processing.
"""
Expand All @@ -40,9 +39,13 @@ class ImageCaptioningDatasetConfig(DatasetConfig):
text_column: str = "label"
images_paths_column = "image_path"
max_length: int = None
test_split_size: float = 0.2
image_processor_config: ImageProcessorConfig = None

def __post_init__(self):
    """Finish initialization and normalize ``image_processor_config``.

    Runs the parent class's ``__post_init__`` first. Afterwards, if
    ``image_processor_config`` was supplied as a plain ``dict``, it is
    replaced by an ``ImageProcessorConfig`` built by unpacking that dict
    as keyword arguments. Any non-dict value is left untouched.
    """
    super().__post_init__()

    processor_cfg = self.image_processor_config
    if not isinstance(processor_cfg, dict):
        # Nothing to convert; keep the value as given.
        return

    self.image_processor_config = ImageProcessorConfig(**processor_cfg)


@register_dataset("image_captioning", config_class=ImageCaptioningDatasetConfig)
class ImageCaptioningDataset(Dataset):
Expand Down

0 comments on commit 1de957c

Please sign in to comment.