Merge pull request huggingface#2 from RUFFY-369/imagebind_hf
Imagebind hf changes
EduardoPach authored Sep 1, 2024
2 parents 8d717d0 + e2f3064 commit 030027d
Showing 5 changed files with 571 additions and 26 deletions.
2 changes: 1 addition & 1 deletion docs/source/en/model_doc/imagebind.md
@@ -22,7 +22,7 @@ The abstract from the paper is the following:
 
 *We present ImageBind, an approach to learn a joint embedding across six different modalities - images, text, audio, depth, thermal, and IMU data. We show that all combinations of paired data are not necessary to train such a joint embedding, and only image-paired data is sufficient to bind the modalities together. ImageBind can leverage recent large scale vision-language models, and extends their zero-shot capabilities to new modalities just by using their natural pairing with images. It enables novel emergent applications 'out-of-the-box' including cross-modal retrieval, composing modalities with arithmetic, cross-modal detection and generation. The emergent capabilities improve with the strength of the image encoder and we set a new state-of-the-art on emergent zero-shot recognition tasks across modalities, outperforming specialist supervised models. Finally, we show strong few-shot recognition results outperforming prior work, and that ImageBind serves as a new way to evaluate vision models for visual and non-visual tasks.*
 
-This model was contributed by [EduardoPacheco](https://huggingface.co/EduardoPacheco) and [dg845](https://huggingface.co/dg845) and [shehan97](https://huggingface.co/shehan97).
+This model was contributed by [EduardoPacheco](https://huggingface.co/EduardoPacheco) and [ruffy369](https://huggingface.co/ruffy369) and [dg845](https://huggingface.co/dg845) and [shehan97](https://huggingface.co/shehan97).
 The original code can be found [here](https://github.com/facebookresearch/ImageBind).
 
 ## Usage tips
@@ -365,6 +365,7 @@ def __init__(
         self.fps = fps
         self._valid_processor_keys = [
             "images",
+            "videos",
             "do_resize",
             "size",
             "resample",
@@ -379,6 +380,7 @@
             "do_chunk",
             "chunk_duration",
             "num_chunks",
+            "num_frames_per_chunk",
             "fps",
             "return_tensors",
             "data_format",
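The two entries added above ("videos" and "num_frames_per_chunk") only take effect because the processor checks incoming kwargs against `_valid_processor_keys`. A rough, self-contained sketch of that gating pattern follows; the helper name and the warn-and-filter behavior are illustrative assumptions, not the actual transformers implementation:

# Illustrative gating sketch; the key list is copied from the diff above,
# the warn-and-filter behavior is an assumption about the general pattern.
VALID_PROCESSOR_KEYS = [
    "images", "videos", "do_resize", "size", "resample",
    "do_chunk", "chunk_duration", "num_chunks",
    "num_frames_per_chunk", "fps", "return_tensors", "data_format",
]

def filter_processor_kwargs(**kwargs):
    unknown = sorted(set(kwargs) - set(VALID_PROCESSOR_KEYS))
    if unknown:
        print(f"Ignoring unrecognized kwargs: {unknown}")
    return {k: v for k, v in kwargs.items() if k in VALID_PROCESSOR_KEYS}

filter_processor_kwargs(num_frames_per_chunk=16, fps=30, bogus=1)  # drops and reports 'bogus'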
72 changes: 49 additions & 23 deletions src/transformers/models/imagebind/processing_imagebind.py
@@ -15,18 +15,38 @@
 Image/Text processor class for ImageBind
 """
 
-from ...processing_utils import ProcessingKwargs, ProcessorMixin
-from ...tokenization_utils_base import BatchEncoding
+from typing import List, Optional, Union
+
+try:
+    from typing import Unpack
+except ImportError:
+    from typing_extensions import Unpack
+
+from ...image_utils import ImageInput
+from ...processing_utils import AudioKwargs, ImagesKwargs, ProcessingKwargs, ProcessorMixin
+from ...tokenization_utils_base import AudioInput, BatchEncoding, PreTokenizedInput, TextInput
 
 
+class ImageBindProcessorImagesKwargs(ImagesKwargs, total=False):
+    do_convert_rgb: bool = None
+    do_chunk: bool = None
+    chunk_duration: float = None
+    num_chunks: int = None
+    num_frames_per_chunk: int = None
+    fps: int = None
+
+class ImageBindProcessorAudioKwargs(AudioKwargs, total=False):
+    do_normalize: Optional[bool] = None
+    mean: Optional[float] = None
+    std: Optional[float] = None
+    do_chunk: Optional[bool] = None
+    chunk_duration: Optional[float] = None
+    num_chunks: Optional[int] = None
+
 class ImageBindProcessorKwargs(ProcessingKwargs, total=False):
-    # see processing_utils.ProcessingKwargs documentation for usage.
-    _defaults = {
-        "text_kwargs": {
-            "padding": "max_length",
-            "max_length": 64,
-        },
-    }
+    images_kwargs: ImageBindProcessorImagesKwargs
+    audio_kwargs: ImageBindProcessorAudioKwargs
+    _defaults = {}
 
 
 class ImageBindProcessor(ProcessorMixin):
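The new kwargs classes lean on `total=False` TypedDict semantics plus `Unpack`, so type checkers can validate `**kwargs` per modality. A minimal, self-contained sketch of that pattern, assuming only the standard library and `typing_extensions` (class and function names here are invented for illustration):

from typing import TypedDict

try:
    from typing import Unpack  # available since Python 3.11
except ImportError:
    from typing_extensions import Unpack  # backport for older interpreters

class VideoKwargs(TypedDict, total=False):
    do_chunk: bool
    chunk_duration: float
    num_chunks: int

def preprocess(video, **kwargs: Unpack[VideoKwargs]) -> float:
    # total=False makes every key optional; .get() supplies a fallback,
    # much like the processor falling back to its _defaults.
    return kwargs.get("chunk_duration", 2.0)

print(preprocess(object(), chunk_duration=1.5))  # -> 1.5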
@@ -53,23 +73,29 @@ class ImageBindProcessor(ProcessorMixin):
     def __init__(self, image_processor, tokenizer, feature_extractor):
         super().__init__(image_processor, tokenizer, feature_extractor)
 
-    def __call__(self, images=None, text=None, audio=None, return_tensors=None, **kwargs):
+    def __call__(
+        self,
+        images=None,
+        text=None,
+        audio=None,
+        **kwargs: Unpack[ImageBindProcessorKwargs],
+    ) -> BatchEncoding:
         """
         Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`
         and `kwargs` arguments to ImageBindTokenizerFast's [`~ImageBindTokenizerFast.__call__`] if `text` is not `None` to encode
         the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
         ImageBindImageProcessor's [`~ImageBindImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
         of the above two methods for more information.
 
         Args:
-            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+            images (`ImageInput`, *optional*):
                 The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                 tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
                 number of channels, H and W are image height and width.
-            text (`str`, `List[str]`, `List[List[str]]`):
+            text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`, *optional*):
                 The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                 (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                 `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
-            audio (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`, `List[List[List[float]]]`):
+            audio (`AudioInput`, `List[float]`, `List[List[float]]`, `List[List[List[float]]]`, *optional*):
                 The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of numpy
                 arrays or a (possibly nested) list of float values. The supported input types are as follows:
@@ -78,12 +104,6 @@ def __call__(self, images=None, text=None, audio=None, return_tensors=None, **kwargs):
                 - batched with clips: `List[List[List[float]]]`, `List[List[np.ndarray]]` (`ndim=1`), `List[np.ndarray]` (`ndim=2`), np.ndarray (`ndim=3`)
                 The input will always be interpreted as mono channel audio, not stereo, i.e. a single float per timestep.
-            return_tensors (`str` or [`~utils.TensorType`], *optional*):
-                If set, will return tensors of a particular framework. Acceptable values are:
-                - `'tf'`: Return TensorFlow `tf.constant` objects.
-                - `'pt'`: Return PyTorch `torch.Tensor` objects.
-                - `'np'`: Return NumPy `np.ndarray` objects.
-                - `'jax'`: Return JAX `jnp.ndarray` objects.
 
         Returns:
             [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
             - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
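To make the accepted audio layouts from the docstring above concrete, a small NumPy illustration; the axis meaning in the 3-D case is an assumption based on the "batched with clips" wording:

import numpy as np

unbatched = np.zeros(16000, dtype=np.float32)     # ndim=1 array, or a plain List[float]
batched = [np.zeros(16000), np.zeros(16000)]      # List[np.ndarray] (ndim=1); an ndim=2 array also works
with_clips = np.zeros((4, 3, 16000), np.float32)  # ndim=3, assumed (batch, clips, samples)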
@@ -97,21 +117,27 @@ def __call__(self, images=None, text=None, audio=None, return_tensors=None, **kwargs):
         if text is None and images is None and audio is None:
             raise ValueError("You have to specify either text, images or audio. Both cannot be none.")
 
+        output_kwargs = self._merge_kwargs(
+            ImageBindProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+
         data = {}
 
         if text is not None:
-            encoding = self.tokenizer(text, return_tensors=return_tensors, **kwargs)
+            encoding = self.tokenizer(text, **output_kwargs["text_kwargs"])
             data.update(encoding)
 
         if images is not None:
-            image_features = self.image_processor(images, return_tensors=return_tensors)
+            image_features = self.image_processor(images, **output_kwargs["images_kwargs"])
             data.update(image_features)
 
         if audio is not None:
-            audio_features = self.feature_extractor(audio, return_tensors=return_tensors)
+            audio_features = self.feature_extractor(audio, **output_kwargs["audio_kwargs"])
             data.update(audio_features)
 
-        return BatchEncoding(data=data, tensor_type=return_tensors)
+        return BatchEncoding(data=data, tensor_type=output_kwargs["common_kwargs"].get("return_tensors"))
 
     def batch_decode(self, *args, **kwargs):
         """
