Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fuyu: improve image processing #27007

Merged
merged 11 commits into from
Nov 2, 2023
Merged

Fuyu: improve image processing #27007

merged 11 commits into from
Nov 2, 2023

Conversation

molbap
Copy link
Contributor

@molbap molbap commented Oct 23, 2023

What does this PR do?

This PR aims at aligning the FuyuImageProcessor class with other vision/language models within transformers. Fuyu model expects a tensor of token ids, a tensor of patch embeddings, and an indexing tensor indicating where to put rows of patch embeddings into the token embeddings, separated by the input ids. Currently the image processor does not separate the steps necessary to achieve this output in the Processor. It also limits the inference size to batches of size 1. It also aims at improving readability and code quality of the processor to possibly enable pipelining later on.

Pending tasks:

  • Return a BatchFeature with arbitrary batch size
  • add do_rescale, do_normalize, do_pad arguments in the ImageProcessor constructor
  • align patch-ification methods to ViTMAE and possibly pix2struct
  • rework and refactor method process_images_for_model_input, currently hard to read
  • test long images, stretched images, usual processor edge cases
  • test images and no text, text and no image in Processor class leveraging tokenizer + ImageProcessor

Before submitting

Who can review?

Models:

pcuenca and others added 3 commits October 19, 2023 01:57
It could produce negative padding and hence inference errors for certain
image sizes.
@ArthurZucker ArthurZucker mentioned this pull request Oct 23, 2023
5 tasks
@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Oct 23, 2023

The documentation is not available anymore as the PR was closed or merged.

Copy link

@yaoxingcheng yaoxingcheng left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Discovered one bug in FuyuBatchEncoding

src/transformers/models/fuyu/processing_fuyu.py Outdated Show resolved Hide resolved
amyeroberts and others added 5 commits November 1, 2023 18:53
* Add file headers

* Add file headers

* First pass - preprocess method with standard args

* First pass image processor rework

* Small tweaks

* More args and docstrings

* Tidying iterating over batch

* Tidying up

* Modify to have quick tests (for now)

* Fix up

* BatchFeature

* Passing tests

* Add tests for processor

* Sense check when patchifying

* Add some tests

* FuyuBatchFeature

* Post-process box coordinates

* Update to `size` in processor

* Remove unused and duplicate constants

* Store unpadded dims after resize

* Fix up

* Return FuyuBatchFeature

* Get unpadded sizes after resize

* Update exception

* Fix return

* Convert input `<box>` coordinates to model format.

* Post-process point coords, support multiple boxes/points in a single
sequence

* Replace constants

* Update src/transformers/models/fuyu/image_processing_fuyu.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Preprocess List[List[image]]

* Update src/transformers/models/fuyu/image_processing_fuyu.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update to Amy's latest state.

* post-processing returns a list of tensors

* Fix error when target_sizes is None

Co-authored-by: Pablo Montalvo <pablo.montalvo.leroux@gmail.com>

* Update src/transformers/models/fuyu/image_processing_fuyu.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update src/transformers/models/fuyu/image_processing_fuyu.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update src/transformers/models/fuyu/image_processing_fuyu.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update src/transformers/models/fuyu/image_processing_fuyu.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Review comments

* Update src/transformers/models/fuyu/image_processing_fuyu.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Fix up

* Fix up

---------

Co-authored-by: Ubuntu <ubuntu@ip-172-31-72-126.ec2.internal>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Pablo Montalvo <pablo.montalvo.leroux@gmail.com>
fixing conflicts and updating on main
Revert "Fix conflicts in fuyu_follow_up_image_processing (#27228)"

This reverts commit acce10b.
…ace/transformers into fuyu_follow_up_image_processing
@molbap molbap marked this pull request as ready for review November 2, 2023 09:51
@molbap
Copy link
Contributor Author

molbap commented Nov 2, 2023

This version of the processor now correctly supports batching, dtype casting, and the left-padded batch generation yields the same results as single-input generation.

from PIL import Image
import requests
import io
from transformers import FuyuForCausalLM, FuyuProcessor, FuyuImageProcessor, AutoTokenizer
from PIL import Image

pretrained_path = "adept/fuyu-8b"

tokenizer = AutoTokenizer.from_pretrained(pretrained_path, pad_token_id=0)
image_processor = FuyuImageProcessor()
processor = FuyuProcessor(image_processor=image_processor, tokenizer=tokenizer)

text_prompt = "Answer the following DocVQA question based on the image. \n Which is the metro in California that has a good job Outlook?"
jobs_image_url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/jobs.png"
jobs_image_pil = Image.open(io.BytesIO(requests.get(jobs_image_url).content))

second_text_prompt = "Answer the following DocVQA question based on the image. \n What if the maximum male life expectancy?"
chart_image_url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/chart.png"
chart_image_pil = Image.open(io.BytesIO(requests.get(chart_image_url).content))

third_text_prompt = "Answer the following DocVQA question based on the image. \n What sport is that?"
skate_image_url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/skateboard.png"
skate_image_pil = Image.open(io.BytesIO(requests.get(skate_image_url).content))

fourth_text_prompt = "Answer the following DocVQA question based on the image. \n What was the fair amount of paid vacation days in the United Kingdom?"
vacations_image_url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/vacation_days_hr.png"
vacations_image_pil = Image.open(io.BytesIO(requests.get(vacations_image_url).content)).convert('RGB')

texts = [text_prompt, second_text_prompt, third_text_prompt, fourth_text_prompt]
images = [jobs_image_pil, chart_image_pil, skate_image_pil, vacations_image_pil]

model_inputs = processor(text=texts, images=images).to('cuda')


model = FuyuForCausalLM.from_pretrained(pretrained_path, device_map='auto')

generation = processor.tokenizer.batch_decode(model.generate(
    **model_inputs, max_new_tokens=10)[:, -10:], skip_special_tokens=True)

single_generations = ['Los Angeles', '80.7',
                      'skateboarding', '28']


for single_generation, batched_generation in zip(single_generations, generation):
    answer = batched_generation.split('\x04 ', 1)[1] if '\x04' in batched_generation else ''
    assert (single_generation == answer)

Copy link
Collaborator

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

There's quite a lot of changes to the image processing code which were done by me - so I might be a bit blind to any issues in the diff. @pcuenca gave a detailed review however, so I think we're good :)

@molbap molbap merged commit 8a31295 into main Nov 2, 2023
21 checks passed
@molbap molbap deleted the fuyu_follow_up_image_processing branch November 2, 2023 11:25
Narsil pushed a commit that referenced this pull request Nov 2, 2023
* Fix Fuyu image scaling bug

It could produce negative padding and hence inference errors for certain
image sizes.

* initial rework commit

* add batching capabilities, refactor image processing

* add functional batching for a list of images and texts

* make args explicit

* Fuyu processing update (#27133)

* Add file headers

* Add file headers

* First pass - preprocess method with standard args

* First pass image processor rework

* Small tweaks

* More args and docstrings

* Tidying iterating over batch

* Tidying up

* Modify to have quick tests (for now)

* Fix up

* BatchFeature

* Passing tests

* Add tests for processor

* Sense check when patchifying

* Add some tests

* FuyuBatchFeature

* Post-process box coordinates

* Update to `size` in processor

* Remove unused and duplicate constants

* Store unpadded dims after resize

* Fix up

* Return FuyuBatchFeature

* Get unpadded sizes after resize

* Update exception

* Fix return

* Convert input `<box>` coordinates to model format.

* Post-process point coords, support multiple boxes/points in a single
sequence

* Replace constants

* Update src/transformers/models/fuyu/image_processing_fuyu.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Preprocess List[List[image]]

* Update src/transformers/models/fuyu/image_processing_fuyu.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update to Amy's latest state.

* post-processing returns a list of tensors

* Fix error when target_sizes is None

Co-authored-by: Pablo Montalvo <pablo.montalvo.leroux@gmail.com>

* Update src/transformers/models/fuyu/image_processing_fuyu.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update src/transformers/models/fuyu/image_processing_fuyu.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update src/transformers/models/fuyu/image_processing_fuyu.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update src/transformers/models/fuyu/image_processing_fuyu.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Review comments

* Update src/transformers/models/fuyu/image_processing_fuyu.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Fix up

* Fix up

---------

Co-authored-by: Ubuntu <ubuntu@ip-172-31-72-126.ec2.internal>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Pablo Montalvo <pablo.montalvo.leroux@gmail.com>

* Fix conflicts in fuyu_follow_up_image_processing (#27228)

fixing conflicts and updating on main

* Revert "Fix conflicts in fuyu_follow_up_image_processing" (#27232)

Revert "Fix conflicts in fuyu_follow_up_image_processing (#27228)"

This reverts commit acce10b.

---------

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-72-126.ec2.internal>
@Victorwz
Copy link

Victorwz commented Nov 3, 2023

I think current version of image processing and tokenization does not support the usage sample code in the original release, right?

from transformers import FuyuProcessor, FuyuForCausalLM
from PIL import Image

# load model and processor
model_id = "adept/fuyu-8b"
processor = FuyuProcessor.from_pretrained(model_id)
model = FuyuForCausalLM.from_pretrained(model_id, device_map="cuda:0")

# prepare inputs for the model
text_prompt = "Generate a coco-style caption.\n"
image_path = "bus.png"  # https://huggingface.co/adept-hf-collab/fuyu-8b/blob/main/bus.png
image = Image.open(image_path)

inputs = processor(text=text_prompt, images=image, return_tensors="pt")
for k, v in inputs.items():
    inputs[k] = v.to("cuda:0")

# autoregressively generate text
generation_output = model.generate(**inputs, max_new_tokens=7)
generation_text = processor.batch_decode(generation_output[:, -7:], skip_special_tokens=True)
assert generation_text == ['A bus parked on the side of a road.']

I am trying to run the above code, and the error occurs that the inputs['image_patches'] is now a list and cannot be put to device.

I suggested that either you can also support this type of processing or you can directly update the sample code on the huggingface release page in link

@NielsRogge
Copy link
Contributor

Hi,

I've updated the code snippet on the model card, it works for me as expected (note that you need to install Transformers from the main branch: pip install -q git+https://github.com/huggingface/transformers.git)

EduardoPach pushed a commit to EduardoPach/transformers that referenced this pull request Nov 19, 2023
* Fix Fuyu image scaling bug

It could produce negative padding and hence inference errors for certain
image sizes.

* initial rework commit

* add batching capabilities, refactor image processing

* add functional batching for a list of images and texts

* make args explicit

* Fuyu processing update (huggingface#27133)

* Add file headers

* Add file headers

* First pass - preprocess method with standard args

* First pass image processor rework

* Small tweaks

* More args and docstrings

* Tidying iterating over batch

* Tidying up

* Modify to have quick tests (for now)

* Fix up

* BatchFeature

* Passing tests

* Add tests for processor

* Sense check when patchifying

* Add some tests

* FuyuBatchFeature

* Post-process box coordinates

* Update to `size` in processor

* Remove unused and duplicate constants

* Store unpadded dims after resize

* Fix up

* Return FuyuBatchFeature

* Get unpadded sizes after resize

* Update exception

* Fix return

* Convert input `<box>` coordinates to model format.

* Post-process point coords, support multiple boxes/points in a single
sequence

* Replace constants

* Update src/transformers/models/fuyu/image_processing_fuyu.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Preprocess List[List[image]]

* Update src/transformers/models/fuyu/image_processing_fuyu.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update to Amy's latest state.

* post-processing returns a list of tensors

* Fix error when target_sizes is None

Co-authored-by: Pablo Montalvo <pablo.montalvo.leroux@gmail.com>

* Update src/transformers/models/fuyu/image_processing_fuyu.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update src/transformers/models/fuyu/image_processing_fuyu.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update src/transformers/models/fuyu/image_processing_fuyu.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update src/transformers/models/fuyu/image_processing_fuyu.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Review comments

* Update src/transformers/models/fuyu/image_processing_fuyu.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Fix up

* Fix up

---------

Co-authored-by: Ubuntu <ubuntu@ip-172-31-72-126.ec2.internal>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Pablo Montalvo <pablo.montalvo.leroux@gmail.com>

* Fix conflicts in fuyu_follow_up_image_processing (huggingface#27228)

fixing conflicts and updating on main

* Revert "Fix conflicts in fuyu_follow_up_image_processing" (huggingface#27232)

Revert "Fix conflicts in fuyu_follow_up_image_processing (huggingface#27228)"

This reverts commit acce10b.

---------

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-72-126.ec2.internal>
@cyrilzakka
Copy link

cyrilzakka commented Mar 2, 2024

Have there been any updates to this? Still running into the same issue. Thanks!

@pcuenca
Copy link
Member

pcuenca commented Mar 2, 2024

Hello @cyrilzakka 👋 I've read the thread and according to Niels the sample code in the model card should work, can you please give more details about the issue you are facing? Thank you! :)

@cyrilzakka
Copy link

cyrilzakka commented Mar 2, 2024

Hey @pcuenca! Sorry for the trouble but having issues running Fuyu inference on a multi-GPU (4x25GB GPUs) setup: RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cuda:3):

from transformers import FuyuProcessor, FuyuForCausalLM
from PIL import Image
import requests
import torch

# load model and processor
model_id = "adept/fuyu-8b"
processor = FuyuProcessor.from_pretrained(model_id)
model = FuyuForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)

# # prepare inputs for the model
text_prompt = "What do you see in the image?\n"
image = Image.open("/home/cyril/Downloads/image.jpg").convert("RGB")

inputs = processor(text=text_prompt, images=image, return_tensors="pt").to('cuda')
generation_output = model.generate(**inputs, max_new_tokens=7)

@oops343
Copy link

oops343 commented Mar 19, 2024

Hey @pcuenca! Sorry for the trouble but having issues running Fuyu inference on a multi-GPU (4x25GB GPUs) setup: RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cuda:3):

from transformers import FuyuProcessor, FuyuForCausalLM
from PIL import Image
import requests
import torch

# load model and processor
model_id = "adept/fuyu-8b"
processor = FuyuProcessor.from_pretrained(model_id)
model = FuyuForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)

# # prepare inputs for the model
text_prompt = "What do you see in the image?\n"
image = Image.open("/home/cyril/Downloads/image.jpg").convert("RGB")

inputs = processor(text=text_prompt, images=image, return_tensors="pt").to('cuda')
generation_output = model.generate(**inputs, max_new_tokens=7)

same here

@ArthurZucker
Copy link
Collaborator

cc @SunMarc maybe a device placement issue, anyway we need the full traceback @oops343 not the same here traceback 😈

@oops343
Copy link

oops343 commented Mar 25, 2024

The code I used :

from transformers import FuyuForCausalLM, FuyuProcessor, FuyuImageProcessor, AutoTokenizer
from PIL import Image
import torch
torch.manual_seed(1234)

# load model and processor
model_id = "adept/fuyu-8b"

tokenizer = AutoTokenizer.from_pretrained(model_id, pad_token_id=0)
image_processor = FuyuImageProcessor()
processor = FuyuProcessor(image_processor=image_processor, tokenizer=tokenizer)

model = FuyuForCausalLM.from_pretrained(model_id, device_map="auto",torch_dtype=torch.float16).eval()

def baseline_for_fuyu(model, processor, batch):
    img_path = batch['img_path'][0]
    text = batch['text'][0]

    query = f"Given the meme, with the text [{text}] accompanied by the image, is this meme {adj}?\n."
    image = Image.open(img_path)
    if image.mode != "RGB":
        image = image.convert("RGB")
    inputs = processor(text=[query], images=[image]).to(model.device)
    with torch.no_grad():  
        generation_output = model.generate(**inputs, max_new_tokens=20)
        generation_text = processor.batch_decode(generation_output[:, -20:], skip_special_tokens=True)
        print("Generated text: ", generation_text[0].split('\x04', 1)[-1])
        torch.cuda.empty_cache()
        return generation_text[0].split('\x04', 1)[-1]

Traceback here:

Traceback (most recent call last):
File "/public/home/baseline/fuyu_baseline.py", line 151, in
response = baseline_for_fuyu(model, processor, batch)
File "/public/home/baseline/fuyu_baseline.py", line 142, in baseline_for_fuyu
generation_output = model.generate(**inputs, max_new_tokens=20)
File "/public/home/.conda/envs/baseEnv/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/public/home/.conda/envs/baseEnv/lib/python3.9/site-packages/transformers/generation/utils.py", line 1527, in generate
result = self._greedy_search(
File "/public/home/.conda/envs/baseEnv/lib/python3.9/site-packages/transformers/generation/utils.py", line 2411, in _greedy_search
outputs = self(
File "/public/home/.conda/envs/baseEnv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/public/home/.conda/envs/baseEnv/lib/python3.9/site-packages/accelerate/hooks.py", line 165, in new_forward
output = old_forward(*args, **kwargs)
File "/public/home/.conda/envs/baseEnv/lib/python3.9/site-packages/transformers/models/fuyu/modeling_fuyu.py", line 296, in forward
inputs_embeds = self.gather_continuous_embeddings(
File "/public/home/.conda/envs/baseEnv/lib/python3.9/site-packages/transformers/models/fuyu/modeling_fuyu.py", line 207, in gather_continuous_embeddings
output_embeddings[batch_idx, dst_indices] = continuous_embeddings[batch_idx][src_indices]
RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cuda:1)

Sry about the simple traceback lol
This happened with 2 x RTX3090 GPU, CUDA version 11.7, transformers 4.39.0, torch 2.0.1
@ArthurZucker please check , I think this won't happen for the other transformer models I'm using, with the same .to(model.device)

@NielsRogge
Copy link
Contributor

@oops343 would it be possible to open a new issue for this?

@oops343
Copy link

oops343 commented Mar 25, 2024

@NielsRogge sure, pls check

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.