Fix Llava chat template examples #30130

Merged · 1 commit · Apr 11, 2024
docs/source/en/model_doc/llava.md: 2 additions & 2 deletions
@@ -43,13 +43,13 @@ The original code can be found [here](https://github.com/haotian-liu/LLaVA/tree/
- For better results, we recommend users to prompt the model with the correct prompt format:

```bash
"USER: <image>\n<prompt>ASSISTANT:"
"USER: <image>\n<prompt> ASSISTANT:"
```

Member Author commented:

Technically we should prepend the system prompt here, but I think it can be excluded for the purposes of keeping the demo simple.
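To make that suggestion concrete, here is a minimal sketch (not part of this PR) of the single-turn prompt with a system prompt prepended; it assumes the Vicuna-style system prompt LLaVA-1.5 was trained with and a single-space separator before the first turn:

```python
# A minimal sketch, assuming LLaVA-1.5's Vicuna-style default system prompt
# and a single space separating it from the first "USER:" turn.
system = (
    "A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions."
)
prompt = f"{system} USER: <image>\nWhat's in this image? ASSISTANT:"
```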

For multi-turn conversations:

```bash
"USER: <image>\n<prompt1>ASSISTANT: <answer1>USER: <prompt2>ASSISTANT: <answer2>USER: <prompt3>ASSISTANT:"
"USER: <image>\n<prompt1> ASSISTANT: <answer1></s>USER: <prompt2> ASSISTANT: <answer2></s>USER: <prompt3> ASSISTANT:"
```
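For clarity, here is a minimal sketch (not part of this PR; `build_llava_prompt` is a hypothetical helper) of how this multi-turn string can be assembled from a list of turns:

```python
# Hypothetical helper: assembles the multi-turn template above from
# (question, answer) pairs. The final turn has answer=None, leaving a
# trailing "ASSISTANT:" for the model to complete.
def build_llava_prompt(turns, has_image=True):
    prompt = ""
    for i, (question, answer) in enumerate(turns):
        image_tag = "<image>\n" if (i == 0 and has_image) else ""
        prompt += f"USER: {image_tag}{question} ASSISTANT:"
        if answer is not None:
            prompt += f" {answer}</s>"
    return prompt


turns = [
    ("What's in this image?", "A stop sign on a street corner."),
    ("What color is the sign?", None),
]
prompt = build_llava_prompt(turns)
# prompt == "USER: <image>\nWhat's in this image? ASSISTANT: A stop sign on a street corner.</s>USER: What color is the sign? ASSISTANT:"
```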

### Using Flash Attention 2
src/transformers/models/llava/modeling_llava.py: 5 additions & 4 deletions
@@ -12,7 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Llava model."""
"""PyTorch Llava model."""

from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

@@ -388,16 +389,16 @@ def forward(
>>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
>>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

->>> prompt = "<image>\nUSER: What's the content of the image?\nASSISTANT:"
+>>> prompt = "USER: <image>\nWhat's the content of the image? ASSISTANT:"
Member Author commented:
AFAICT this prompt put the <image> token in the wrong place altogether.

>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> inputs = processor(text=prompt, images=image, return_tensors="pt")

>>> # Generate
->>> generate_ids = model.generate(**inputs, max_length=30)
+>>> generate_ids = model.generate(**inputs, max_new_tokens=15)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"\nUSER: What's the content of the image?\nASSISTANT: The image features a stop sign on a street corner"
"USER: \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed"
Collaborator commented:

Are we expecting to have this double spacing here: `"USER:  \nWhat's"`? We're skipping `<image>` but from the prompt I would have expected: `"USER: \nWhat's"`.

Member Author replied:

Yes, I did find this a bit weird, but it seems to be the default behaviour of the decoding.
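A quick way to check this (a sketch, not from the PR; it assumes `<image>` is registered as an added special token in the llava-hf tokenizer):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("llava-hf/llava-1.5-7b-hf")
ids = tokenizer("USER: <image>\nWhat's the content of the image? ASSISTANT:").input_ids

# skip_special_tokens=True drops the <image> token itself, but the whitespace
# around it is not cleaned up, which is presumably where the doubled spacing
# after "USER:" in the decoded text comes from.
print(repr(tokenizer.decode(ids, skip_special_tokens=True)))
```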

```"""

output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
tests/models/llava/test_modeling_llava.py: 6 additions & 6 deletions
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch Llava model. """
"""Testing suite for the PyTorch Llava model."""

import copy
import gc
@@ -398,13 +398,13 @@ def test_small_model_integration_test_llama(self):
model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf", load_in_4bit=True)
processor = AutoProcessor.from_pretrained(model_id)

prompt = "USER: <image>\nWhat are the things I should be cautious about when I visit this place?\nASSISTANT:"
prompt = "USER: <image>\nWhat are the things I should be cautious about when I visit this place? ASSISTANT:"
image_file = "https://llava-vl.github.io/static/images/view.jpg"
raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16)

output = model.generate(**inputs, max_new_tokens=900, do_sample=False)
-EXPECTED_DECODED_TEXT = "USER: \nWhat are the things I should be cautious about when I visit this place?\nASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, there are a few things to be cautious about. First, be aware of the weather conditions, as sudden changes in weather can make the pier unsafe to walk on. Second, be mindful of the water depth and any potential hazards, such as submerged rocks or debris, that could cause accidents or injuries. Additionally, be cautious of the presence of wildlife, such as birds or fish, and avoid disturbing their natural habitats. Lastly, be aware of any local regulations or guidelines for the use of the pier, as some areas may be restricted or prohibited for certain activities." # fmt: skip
+EXPECTED_DECODED_TEXT = "USER:  \nWhat are the things I should be cautious about when I visit this place? ASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, there are a few things to be cautious about. First, be aware of the weather conditions, as sudden changes in weather can make the pier unsafe to walk on. Second, be mindful of the water depth and any potential hazards, such as submerged rocks or debris, that could cause accidents or injuries. Additionally, be cautious of the tides and currents, as they can change rapidly and pose a risk to swimmers or those who venture too close to the edge of the pier. Finally, be respectful of the environment and other visitors, and follow any posted rules or guidelines for the area." # fmt: skip

self.assertEqual(
processor.decode(output[0], skip_special_tokens=True),
@@ -421,8 +421,8 @@ def test_small_model_integration_test_llama_batched(self):
processor = AutoProcessor.from_pretrained(model_id)

prompts = [
"USER: <image>\nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT:",
"USER: <image>\nWhat is this?\nASSISTANT:",
"USER: <image>\nWhat are the things I should be cautious about when I visit this place? What should I bring with me? ASSISTANT:",
"USER: <image>\nWhat is this? ASSISTANT:",
]
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
@@ -431,7 +431,7 @@

output = model.generate(**inputs, max_new_tokens=20)

-EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, which appears to be a dock or pier extending over a body of water', 'USER: \nWhat is this?\nASSISTANT: The image features two cats lying down on a pink couch. One cat is located on'] # fmt: skip
+EXPECTED_DECODED_TEXT = ['USER:  \nWhat are the things I should be cautious about when I visit this place? What should I bring with me? ASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, you', 'USER:  \nWhat is this? ASSISTANT: The image features two cats lying down on a pink couch. One cat is located on'] # fmt: skip

self.assertEqual(processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
