
Add DETR #9998

Closed

wants to merge 20 commits into from

Conversation

@NielsRogge (Contributor) commented Feb 4, 2021

What does this PR do?

It adds the first vision-only Transformer to the library! Namely DETR, End-to-End Object Detection with Transformers, by Facebook AI. The main contribution of DETR is its simplicity: it replaces a lot of hand-engineered components that models like Faster R-CNN and Mask R-CNN include, such as non-maximum suppression and anchor generation, with just an end-to-end model and a clever loss function, while matching the performance of those much more complex models.

For a really good explanation (which helped me a lot), see Yannic Kilcher's video on DETR. I'll provide a TL;DR below:

The main thing to know is that an image of shape (batch_size, num_channels, height, width), so in the case of a single image a tensor of shape (1, 3, height, width), is first sent through a CNN backbone, which outputs a lower-resolution feature map, typically of shape (1, 2048, height/32, width/32). This is then projected with an nn.Conv2d layer to match the hidden dimension of the Transformer, which is 256 by default, so we now have a tensor of shape (1, 256, height/32, width/32). Next, the feature map is flattened and transposed to obtain a tensor of shape (batch_size, seq_len, d_model) = (1, width/32*height/32, 256). So a difference with NLP models is that the sequence length is actually longer than usual, but with a smaller hidden_size (which in NLP is typically 768 or higher).
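To make the shapes concrete, here is a minimal sketch of this part using torchvision's ResNet-50 as the backbone (illustration only, not the actual modeling code):

import torch
import torch.nn as nn
import torchvision

# ResNet-50 (randomly initialized here) without the average-pooling and
# classification head: outputs a (batch_size, 2048, height/32, width/32) feature map.
resnet = torchvision.models.resnet50()
backbone = nn.Sequential(*list(resnet.children())[:-2])
projection = nn.Conv2d(2048, 256, kernel_size=1)   # project to the Transformer's d_model

pixel_values = torch.randn(1, 3, 640, 640)          # (batch_size, num_channels, height, width)
feature_map = backbone(pixel_values)                # (1, 2048, 20, 20)
projected = projection(feature_map)                 # (1, 256, 20, 20)
sequence = projected.flatten(2).transpose(1, 2)     # (batch_size, seq_len, d_model) = (1, 400, 256)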

This is sent through the encoder, outputting encoder_hidden_states of the same shape. Next, so-called object queries are sent through the decoder. This is just a tensor of shape (batch_size, num_queries, d_model), with num_queries typically set to 100, initialized with zeros. Each object query looks for a particular object in the image. The decoder updates these object queries through multiple self-attention and encoder-decoder attention layers to output decoder_hidden_states of the same shape: (batch_size, num_queries, d_model). Next, two heads are added on top for object detection: a linear layer for classifying each object query into one of the objects or "no object", and an MLP to predict bounding boxes for each query. So the number of queries actually determines the maximum number of objects the model can detect in an image.
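As a rough sketch of those two heads (the names below are illustrative, not necessarily the ones used in modeling_detr.py):

import torch
import torch.nn as nn

d_model, num_queries, num_classes = 256, 100, 91    # 91 COCO classes

class_head = nn.Linear(d_model, num_classes + 1)    # +1 for the "no object" class
bbox_head = nn.Sequential(                          # small MLP predicting (center_x, center_y, width, height)
    nn.Linear(d_model, d_model), nn.ReLU(),
    nn.Linear(d_model, d_model), nn.ReLU(),
    nn.Linear(d_model, 4),
)

decoder_hidden_states = torch.randn(1, num_queries, d_model)   # decoder output
logits = class_head(decoder_hidden_states)                     # (1, 100, 92)
pred_boxes = bbox_head(decoder_hidden_states).sigmoid()        # (1, 100, 4), normalized to [0, 1]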

The model is trained using a "bipartite matching loss": what we actually do is compare the predicted classes + bounding boxes of each of the N = 100 object queries to the ground truth annotations, padded up to the same length N (so if an image only contains 4 objects, the remaining 96 annotations will just have "no object" as class and "no bounding box" as bounding box). The Hungarian matching algorithm is used to create an optimal one-to-one mapping between each of the N queries and each of the N annotations. Next, standard cross-entropy for the classes and L1 regression loss for the bounding boxes are used to optimize the parameters of the model.
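A toy version of the matching step, using SciPy's linear_sum_assignment as the Hungarian solver (the real matching cost also includes a generalized IoU term, which I omit here):

import torch
from scipy.optimize import linear_sum_assignment

num_queries, num_targets = 100, 4
pred_probs = torch.randn(num_queries, 92).softmax(-1)   # class probabilities per query (91 classes + "no object")
pred_boxes = torch.rand(num_queries, 4)                 # normalized (center_x, center_y, width, height)
target_classes = torch.tensor([17, 17, 63, 75])         # labels of the 4 annotated objects
target_boxes = torch.rand(num_targets, 4)

# Cost of assigning query i to annotation j: negative probability of the
# correct class plus the L1 distance between the boxes.
cost = -pred_probs[:, target_classes] + torch.cdist(pred_boxes, target_boxes, p=1)   # (100, 4)

query_idx, target_idx = linear_sum_assignment(cost.numpy())
# The 4 matched queries are trained with cross-entropy + L1 box loss against their
# targets; the remaining 96 queries are trained to predict "no object".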

Paper: https://arxiv.org/abs/2005.12872
Original repo: https://github.com/facebookresearch/detr

Usage

Quick demo of my current implementation (with some cool attention visualizations): https://colab.research.google.com/drive/1aJ00yPxT4-PCMhSx2BipbTKqMSBQ80vJ?usp=sharing

(Old demo: https://colab.research.google.com/drive/1G4oWTOg_Jotp_2jJhdYkYVfkcT9ucX4P?usp=sharing)

Note that the authors released 7 model variants (4 for object detection, 3 for panoptic segmentation). Currently I've defined two models: the base DetrModel (which outputs the raw hidden states of the decoder) and DetrForObjectDetection, which adds object detection heads (classes + bounding boxes) on top. I've currently only converted and tested the base model for object detection (DETR-resnet-50). Adding the other models for object detection seems quite easy (as these only use a different backbone, and I copied the code of the backbone from the original repo). Adding the models for panoptic segmentation (DetrForPanopticSegmentation) is on the to-do list, as can be seen below.

Done

  • load pretrained weights into the model
  • make sure forward pass yields equal outputs on the same input data
  • successful transcription
  • add tokenizer (not sure if DETR needs one, see discussion below)
  • add model tests: currently added 2 integration tests which pass, more tests to follow
  • add tokenizer tests (not sure if DETR needs one, see discussion below)
  • add docstrings
  • fill in rst file

Discussion

Writing DETR in modeling_detr.py went quite fast thanks to the CookieCutter template (seriously, whoever added this, thank you!!). The main thing to write was the conversion script (basically translating PyTorch's default nn.MultiheadAttention to the self-attention mechanism defined in this library). DETR is an encoder-decoder Transformer with only some minor differences, namely:

  • it uses parallel decoding instead of autoregressive decoding, so I assume I can delete all the past_key_values and causal_mask mechanisms? cc @patrickvonplaten
  • it adds positional embeddings to the hidden states (in both the encoder and decoder) in each self-attention and encoder-decoder attention layer, right before projecting to queries and keys (see the sketch after this list)
  • it uses the "relu" activation function instead of the default "gelu" one.
  • during training, it helps to supervise the outputs of each decoder layer: the authors predict classes + bounding boxes from the output of every decoder layer and train on those as well. This is controlled by a hyperparameter of DetrConfig called auxiliary_loss. This is also why I defined an additional ModelOutput called BaseModelOutputWithCrossAttentionsAndIntermediateHiddenStates, which adds intermediate activations of the decoder layers as output.
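To make the second point concrete, here is a minimal sketch of how the positional embeddings enter the attention computation (illustrative only, not the exact code in modeling_detr.py):

import torch
import torch.nn as nn

d_model = 256
q_proj, k_proj, v_proj = nn.Linear(d_model, d_model), nn.Linear(d_model, d_model), nn.Linear(d_model, d_model)

def with_pos(hidden_states, pos_embed):
    # add positional embeddings only where they are wanted
    return hidden_states if pos_embed is None else hidden_states + pos_embed

hidden_states = torch.randn(1, 400, d_model)   # flattened image features in the encoder
pos_embed = torch.randn(1, 400, d_model)       # e.g. fixed sine positional embeddings

queries = q_proj(with_pos(hidden_states, pos_embed))
keys = k_proj(with_pos(hidden_states, pos_embed))
values = v_proj(hidden_states)                 # the values get no positional information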

I wonder whether DETR needs a tokenizer. Currently, it accepts a NestedTensor as input to the encoder, not the usual input_ids, attention_mask and token_type_ids. The authors of DETR really like this data type because of its flexibility: it basically allows you to batch images of different sizes by padding them up to the biggest image in the batch, while also providing a mask indicating which pixels are real and which are padding. See here for their motivation for choosing this data type (the authors of PyTorch are also experimenting with this, see their project here). So maybe NestedTensor is something we could use as well, since it automatically batches different images and adds a mask, which Transformer models require?
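For illustration, this is roughly what that batching amounts to (a simplified sketch, not the authors' NestedTensor implementation; the mask convention of True for real pixels is just one possible choice):

import torch

def batch_images(images):
    """Pad a list of (3, H, W) tensors to a common size and return a pixel mask."""
    max_h = max(img.shape[1] for img in images)
    max_w = max(img.shape[2] for img in images)
    pixel_values = torch.zeros(len(images), 3, max_h, max_w)
    pixel_mask = torch.zeros(len(images), max_h, max_w, dtype=torch.bool)
    for i, img in enumerate(images):
        _, h, w = img.shape
        pixel_values[i, :, :h, :w] = img
        pixel_mask[i, :h, :w] = True    # True marks real pixels, False marks padding
    return pixel_values, pixel_mask

pixel_values, pixel_mask = batch_images([torch.randn(3, 480, 640), torch.randn(3, 600, 800)])
# pixel_values: (2, 3, 600, 800), pixel_mask: (2, 600, 800)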

Also, no special tokens are used, as the input to the encoder is just a flattened image. The decoder, on the other hand, accepts object queries as input (which are created in DetrModel), instead of the regular input_ids, attention_mask and token_type_ids. So I wonder whether these can also be removed.

Future to-do

  • Add DetrForPanopticSegmentation
  • Let DETR support any backbone, perhaps those of the timm library as well as any model in the torchvision package

Who can review?

@LysandreJik @patrickvonplaten @sgugger

Fixes #4663

Unfortunately, self-attention and MultiheadAttention seem to be easier to understand than git... I'm having some issues with line endings on Windows. Any help is greatly appreciated. I'm mainly opening this PR to discuss how to finish DETR.

@LysandreJik (Member)

I'll have a look at the git issue in the evening

@sgugger (Collaborator) commented Feb 4, 2021

Thanks for the PR, a few quick comments:

This is also why I defined an additional ModelOutput called BaseModelOutputWithCrossAttentionsAndIntermediateHiddenStates, which adds intermediate activations of the decoder layers as output.

I will strongly object to a name that long as a matter of principle 😅 But just so I understand what it adds, are those intermediate activations of the decoder layers not in the hidden_states attribute already?

I wonder whether DETR needs a tokenizer.

I think the "tokenization" file (we can rename it if we want) should exist and contain the NestedTensor class and the utilities for padding. Like the Wav2Vec2 tokenizer Patrick added recently, the tokenizer call would only take care of padding, resizing to a max size (if given) and normalizing. The tokenizer could also have a method that loads images from filenames, and its call could accept one or a list of decoded images (as np.array or tensor) or one or a list of filenames (and decode them with PIL, for instance). It could also have a decode method which would, in this case, rescale the bounding boxes and map label IDs to label names, so it's easier to then plot the results.

The inputs of the models should be completely renamed to reflect the types of objects expected (so pixel_values and pixel_mask would probably be better names than input_ids etc.), and the tokenizer call should output a dictionary with those names as keys (so we can keep the usual API of feeding the output of the tokenizer directly to the model).

I imagine something like this as a final, easy API:

inputs = tokenizer([filename1, filename2])
outputs = model(**inputs)
processed_outputs = tokenizer.decode(outputs)

@NielsRogge (Contributor, Author) commented Feb 9, 2021

I will strongly object to a name that long as a matter of principle 😅 But just so I understand what it adds, are those intermediate activations of the decoder layers not in the hidden_states attribute already?

Yes, the intermediate activations are the hidden states of the decoder layers, each of them followed by a LayerNorm. I agree that the name is too long 😅

I think the "tokenization" file (we can rename it if we want) should exist and contain the NestedTensor class and the utilities for padding. Like the Wav2Vec2 tokenizer Patrick added recently, the tokenizer call would only take care of padding, resizing to a max size (if given) and normalizing. The tokenizer could also have a method that loads images from filenames, and its call could accept one or a list of decoded images (as np.array or tensor) or one or a list of filenames (and decode them with PIL, for instance).

I've created a first draft of DetrTokenizer as you requested. The API looks as follows:

from PIL import Image
import requests
from transformers import DetrTokenizer

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

tokenizer = DetrTokenizer() # later, this is gonna be .from_pretrained("facebook/detr-resnet-50")
encoding = tokenizer(image)

Currently it accepts PIL images, Numpy arrays and PyTorch tensors. The encoding (which is a BatchEncoding) has 2 keys, namely pixel_values and pixel_mask. You can call the tokenizer with the following parameters:

  • resize: whether to resize images to a given size.
  • size: arbitrary integer to which you want to resize the images
  • max_size: the largest size an image dimension can have (otherwise it's capped).
  • normalize: whether to apply mean-std normalization.

An additional complexity with object detection is that if you resize images, the annotated bounding boxes must be resized accordingly. So if you want to prepare data for training, you can also pass in annotations in the __call__ method of DetrTokenizer. In that case, the encoding will also include a key named labels.
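For example, something like this (hypothetical: the annotations keyword and the COCO-style dictionary below are just how I currently imagine the API, and the values are placeholders; names may still change):

# Continuing the example above: pass COCO-style annotations along with the image
# so that the bounding boxes are rescaled together with the resized image.
annotations = {
    "image_id": 39769,
    "annotations": [
        {"bbox": [10.0, 20.0, 200.0, 150.0], "category_id": 17, "area": 30000.0},  # illustrative values
    ],
}
encoding = tokenizer(image, annotations=annotations, resize=True, size=800, max_size=1333)
# encoding now has "pixel_values", "pixel_mask" and "labels", with the boxes
# resized consistently with the image.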

@LysandreJik (Member)

Resolution of the git issue: #10119

@sgugger (Collaborator) commented Feb 10, 2021

Currently it accepts PIL images, Numpy arrays and PyTorch tensors.

Pretty cool! Can we pass strings or pathlib.Paths too?

About the general API, I'm not sure we should inherit from PreTrainedTokenizer, since the from_pretrained/save_pretrained methods are not going to work. Wdyt @LysandreJik? This is also not really a tokenizer, more like an AnnotatedImagePreProcessor or something like that.

An additional complexity with object detection is that if you resize images, the annotated bounding boxes must be resized accordingly. So if you want to prepare data for training, you can also pass in annotations in the call method of DetrTokenizer

Yes, this is expected. Maybe we could create a new type, a bit like BatchEncoding, that groups together the image (in all possible formats: string, PIL, array, tensor) with its annotation, so we can then just pass that object (or a list of those objects) to the tokenizer. What do you think?

@NielsRogge (Contributor, Author) commented Feb 10, 2021

Pretty cool! Can we pass strings or pathlib.Paths too?

About the general API, I'm not sure we should inherit from PreTrainedTokenizer, since the from_pretrained/save_pretrained methods are not going to work. Wdyt @LysandreJik? This is also not really a tokenizer, more like an AnnotatedImagePreProcessor or something like that.

Sure, it's best to make a similar API for ViT, right? (And more Transformer-based image models that will come after that). I've heard some people are working on ViT? To be fair, I could write a conversion script for ViT if you want, I see it's available in timm.

Yes, this is expected. Maybe we could create a new type, a bit like BatchEncoding, that groups together the image (in all possible formats: string, PIL, array, tensor) with its annotation, so we can then just pass that object (or a list of those objects) to the tokenizer. What do you think?

You mean pass that object to the model, rather than the tokenizer? For me, BatchEncoding seems like a good name.

@sgugger (Collaborator) commented Feb 10, 2021

Sure, it's best to make a similar API for ViT, right? (And more Transformer-based image models that will come after that).

Since ViT is not ported yet, this is where we decide the API that will be used for other vision/multi-modal models :-)

You mean pass that object to the model, rather than the tokenizer? For me, BatchEncoding seems like a good name.

No, I meant to the tokenizer (though I'm not too sure about this part, it may end up over-complicating things). BatchEncoding comes with its text-related methods (word_ids, sequence_ids etc) so I don't think it should be used here since they won't be available.

@LysandreJik (Member)

Regarding the tokenizer I think we can have a bit more freedom here than we would with NLP models as it's the first vision model, but as you've said @sgugger I think that it should still be somewhat aligned with NLP tokenizers:

  • It should take care of all the pre-processing steps
    • Creation of batches of images, with padding & truncation
    • All the functionalities you mentioned, @NielsRogge: resize/size/normalize, etc.
  • Ideally it should have a very similar API to existing NLP tokenizers: applying processing with the __call__ method, loading/saving with from_pretrained/save_pretrained. I didn't dive into the implementation, but if parameters like resize/size/normalize etc. are checkpoint-specific, then it's a good opportunity to save these configuration values in tokenizer_config.json, leveraging the loading/saving methods mentioned above.
  • If there needs to be some decoding done after the model has processed the image, then that object should be able to handle it as well.

@sgugger regarding what the tokenizer accepts, I'm not sure I see the advantage of handling paths directly. We don't handle paths to text files or paths to CSVs in our other tokenizers. We don't handle paths to sound files either for Wav2Vec2, for all of that we rely on external tools and I think that's fine.

Furthermore, handling images directly in the tokenizer sounds especially memory-heavy, and relying on the datasets library, which can handle memory mapping, seems like a better approach than leveraging the tokenizer to load files into memory.

@sgugger (Collaborator) commented Feb 15, 2021

Yes, at least the normalization statistics (mean and std) are checkpoint-specific, so they should be loaded/saved with the usual API.

@sgugger regarding what the tokenizer accepts, I'm not sure I see the advantage of handling paths directly. We don't handle paths to text files or paths to CSVs in our other tokenizers. We don't handle paths to sound files either for Wav2Vec2, for all of that we rely on external tools and I think that's fine.

The difference is that a tokenizer accepts strings, which are a universal type, whereas this image processor accepts PIL images, which is the format given by one specific library (so you can't load your image with OpenCV and feed it to the tokenizer). Since we already have a privileged image preprocessing library, I really think it makes sense to let it also accept filenames. An alternative is to accept only numpy arrays and tensors, but there is a conversion back to PIL images inside the function (we could avoid it and do everything on tensors if we wanted to, btw), so I don't think that makes sense.

In any case the user can still use their own preprocessing and pass the final numpy array/torch tensor with the API, so I don't see the downside of accepting filenames. Usual tokenizers would have a hard time telling the difference between a string that is text and a string that is a path, but this is not the case for images (or sounds; we could have that API there too, and I think we should). It's just free functionality.

In NLP we have datasets as lists of texts since text is light in memory, but in CV all the datasets will come as lists of filenames that you have to load lazily (except maybe CIFAR10 and MNIST, since they are tiny). I'm just trying to make it as easy as possible for the user.

Furthermore, handling images directly in the tokenizer sounds especially memory-heavy

The memory will be used in any case, as the images passed to the tokenizer are already loaded if you don't pass filenames. The memory usage shouldn't change between passing n filenames and n images.

@LysandreJik (Member)

I think this goes against the API we've defined up to now for all existing modalities (text, speech, tabular), and it adds additional work to the tokenizer, whereas I think data loading should be handled by PyTorch DataLoaders/Datasets or by the datasets library.

However, your points resonate with me, and I have less experience than you both in vision, so if you feel that such an API is what would be best for vision, then I'm happy to drop my objection; feel free to implement it this way.

@sgugger (Collaborator) commented Feb 15, 2021

Let's not add the file support for now and discuss it at our next internal meeting, then. I agree it is new functionality that would differ from our other APIs.

@NielsRogge (Contributor, Author) commented Mar 1, 2021

Any update on this?

The tokenizer (I know we should rename it to something else) that I currently implemented accepts images as PIL images, Numpy arrays or PyTorch tensors, and creates 2 things: pixel_values and pixel_mask. It could be used for both DETR and ViT.

We should probably define some base utils similar to what Patrick did for the speech models.

cc @LysandreJik @sgugger @patrickvonplaten

@sgugger (Collaborator) commented Mar 1, 2021

Thanks for reaching out!

So the "tokenizer" as you wrote it is good, but it should be renamed to DetrFeatureExtractor and subclass PreTrainedFeatureExtractor (following the example of Wav2Vec2). All the necessary info to create one should live in a single JSON file in the model repo (basically the same API as Wav2Vec2, but just the feature extractor part, since there is no tokenizer in DETR). For ViT we can copy the same approach (we will refactor down the road if many models end up sharing the same functionality, but for now we'll just use copies with # Copied from xxx markers).

There is no need for new base utils; the base utils Patrick defined are the ones to use for this case. As for the inputs, we agreed to stay with PIL Images, NumPy arrays and torch Tensors, so all good on this side.
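Concretely, the end-user API would then look something like this (just a sketch: the checkpoint name and the return_tensors argument are assumptions, following the Wav2Vec2 pattern; none of this exists yet):

from transformers import DetrFeatureExtractor, DetrForObjectDetection
from PIL import Image

image = Image.open("cats.jpg")   # any local image, just for illustration

feature_extractor = DetrFeatureExtractor.from_pretrained("facebook/detr-resnet-50")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")

encoding = feature_extractor(image, return_tensors="pt")   # pixel_values + pixel_mask
outputs = model(**encoding)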

@NielsRogge (Contributor, Author)

The PreTrainedFeatureExtractor seems to be quite specifically defined for speech recognition (it requires a sampling_rate at initialization, for instance).

@sgugger (Collaborator) commented Mar 1, 2021

cc @patrickvonplaten but I thought this one was supposed to be generic.

@sgugger (Collaborator) commented Mar 1, 2021

I talked offline with Patrick, and I had misunderstood the plan. PreTrainedFeatureExtractor is for all kinds of inputs that are representable as 1d arrays of floats (like speech). For images, we should create a new base class that implements the same methods. If you can take inspiration from PreTrainedFeatureExtractor to create an ImageProcessor, that would be great! The only thing that should be exactly the same is the name of the saved config: preprocessing_config.json.
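Something very roughly like this (purely illustrative; the class and attribute names below are placeholders, only the config file name is fixed):

import json
import os

class ImageProcessorBase:
    # the saved config name should match the one used by PreTrainedFeatureExtractor
    config_name = "preprocessing_config.json"

    def __init__(self, size=None, max_size=None, image_mean=None, image_std=None, **kwargs):
        self.size = size
        self.max_size = max_size
        self.image_mean = image_mean
        self.image_std = image_std

    def save_pretrained(self, save_directory):
        # serialize the checkpoint-specific preprocessing parameters to JSON
        os.makedirs(save_directory, exist_ok=True)
        with open(os.path.join(save_directory, self.config_name), "w") as f:
            json.dump(self.__dict__, f, indent=2)

    @classmethod
    def from_pretrained(cls, directory):
        # reload the parameters saved by save_pretrained
        with open(os.path.join(directory, cls.config_name)) as f:
            return cls(**json.load(f))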

Does that make sense?

@github-actions (bot)

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions bot closed this on Apr 23, 2021