0.34.0 +llama3.2,+molmo,+got-ocr2
matatonic committed Sep 26, 2024
1 parent 96cc4e3 commit 6c45b53
Showing 13 changed files with 485 additions and 151 deletions.
3 changes: 1 addition & 2 deletions Dockerfile
@@ -10,8 +10,7 @@ RUN git clone https://github.com/togethercomputer/Dragonfly --single-branch /app

COPY requirements.txt .
ARG VERSION=latest
RUN if [ "$VERSION" = "alt" ]; then echo "transformers==4.41.2" >> requirements.txt; else echo "git+https://github.com/huggingface/transformers" >> requirements.txt ; fi
# TODO: nvidia apex wheel
RUN if [ "$VERSION" = "alt" ]; then echo "transformers==4.41.2" >> requirements.txt; else echo "transformers>=4.45.0" >> requirements.txt ; fi
RUN --mount=type=cache,target=/root/.cache/pip pip install -U -r requirements.txt

WORKDIR /app/Mantis
20 changes: 18 additions & 2 deletions README.md
@@ -14,9 +14,14 @@ Can't decide which to use? See the [OpenVLM Leaderboard](https://huggingface.co/
<details>
<summary>Full list of supported models</summary>

- [X] [AIDC-AI]()
- [X] [AIDC-AI](https://huggingface.co/AIDC-AI)
- - [X] [Ovis1.5-Gemma2-9B](https://huggingface.co/AIDC-AI/Ovis1.5-Gemma2-9B)
- - [X] [Ovis1.5-Llama3-8B](https://huggingface.co/AIDC-AI/Ovis1.5-Llama3-8B)
- [X] [Ai2](https://huggingface.co/allenai)
- - [X] [Molmo-72B-0924](https://huggingface.co/allenai/Molmo-72B-0924)
- - [X] [Molmo-7B-O-0924](https://huggingface.co/allenai/Molmo-7B-O-0924)
- - [X] [Molmo-7B-D-0924](https://huggingface.co/allenai/Molmo-7B-D-0924)
- - [X] [MolmoE-1B-0924](https://huggingface.co/allenai/MolmoE-1B-0924)
- [X] [BAAI](https://huggingface.co/BAAI/)
- - [X] [BAAI/Bunny-v1_0-2B-zh](https://huggingface.co/BAAI/Bunny-v1_0-2B-zh)
- - [X] [BAAI/Bunny-v1_0-3B-zh](https://huggingface.co/BAAI/Bunny-v1_0-3B-zh)
@@ -63,6 +68,9 @@ Can't decide which to use? See the [OpenVLM Leaderboard](https://huggingface.co/
- - [X] [llava-v1.5-vicuna-7b-hf](https://huggingface.co/llava-hf/llava-v1.5-vicuna-7b-hf)
- - [X] [llava-v1.5-vicuna-13b-hf](https://huggingface.co/llava-hf/llava-v1.5-vicuna-13b-hf)
- - [ ] [llava-v1.5-bakLlava-7b-hf](https://huggingface.co/llava-hf/llava-v1.5-bakLlava-7b-hf) (currently errors)
- [X] [Meta Llama](https://huggingface.co/meta-llama)
- - [X] [Llama-3.2-90B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-90B-Vision-Instruct)
- - [X] [Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct)
- [X] [Microsoft](https://huggingface.co/microsoft/)
- - [X] [Phi-3.5-vision-instruct](https://huggingface.co/microsoft/Phi-3.5-vision-instruct)
- - [X] [Phi-3-vision-128k-instruct](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct)
@@ -125,6 +133,7 @@ Can't decide which to use? See the [OpenVLM Leaderboard](https://huggingface.co/
- - [X] [Qwen2-VL-2B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct)
- - [X] [Qwen2-VL-2B-Instruct-AWQ](https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct-AWQ)
- - [X] [Qwen-VL-Chat](https://huggingface.co/Qwen/Qwen-VL-Chat)
- [X] [stepfun-ai/GOT-OCR2_0](https://huggingface.co/stepfun-ai/GOT-OCR2_0) (OCR-only model)
- [X] [vikhyatk](https://huggingface.co/vikhyatk)
- - [X] [moondream2](https://huggingface.co/vikhyatk/moondream2)
- - [X] [moondream1](https://huggingface.co/vikhyatk/moondream1) (0.28.1-alt only)
@@ -145,6 +154,13 @@ If you can't find your favorite model, you can [open a new issue](https://github

## Recent updates

Version 0.34.0

- new model support: meta-llama Llama-3.2-11B-Vision-Instruct, Llama-3.2-90B-Vision-Instruct
- new model support: Ai2/allenai Molmo family of models (requires an additional `pip install tensorflow-cpu` for now, [see note](https://huggingface.co/allenai/Molmo-7B-D-0924/blob/main/image_preprocessing_molmo.py#L88-L90)); a usage sketch follows this list
- new model support: stepfun-ai/GOT-OCR2_0, an OCR-only model; all chat input is ignored
- Support moved to the alt image: Bunny-Llama-3-8B-V, Bunny-v1_1-Llama-3-8B-V, Mantis-8B-clip-llama3, Mantis-8B-siglip-llama3, omchat-v2.0-13B-single-beta_hf, qihoo360/360VL-8B
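As a quick illustration (an editorial sketch, not part of the repository), one of the newly added models can be queried through the server's OpenAI-compatible endpoint once it is running with that model loaded; the base URL, port, API key, and image URL below are placeholder assumptions to adapt to your deployment:

```python
from openai import OpenAI

# Sketch only: assumes the server is already running with allenai/Molmo-7B-D-0924 loaded
# (after `pip install tensorflow-cpu` as noted above); base_url, api_key and the image URL
# are placeholders.
client = OpenAI(base_url="http://localhost:5006/v1", api_key="skip")

response = client.chat.completions.create(
    model="allenai/Molmo-7B-D-0924",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image."},
            {"type": "image_url", "image_url": {"url": "https://example.com/sample.jpg"}},
        ],
    }],
    max_tokens=256,
)
print(response.choices[0].message.content)
```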

Version 0.33.0

- new model support: mx262/MiniMonkey, thanks [@white2018](https://github.com/white2018)
@@ -361,7 +377,7 @@ docker compose -f docker-compose.alt.yml pull
python -m venv .venv
source .venv/bin/activate
# install the python dependencies
pip install -U -r requirements.txt "git+https://github.com/huggingface/transformers"
pip install -U -r requirements.txt "transformers>=4.45.0"
# OR install the python dependencies for the alt version
pip install -U -r requirements.txt "transformers==4.41.2"
# run the server with your chosen model
43 changes: 43 additions & 0 deletions backend/got_ocr2.py
@@ -0,0 +1,43 @@
from transformers import AutoTokenizer, AutoModel

from vision_qna import *

# ucaslcl/GOT-OCR2_0

DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
DEFAULT_IM_START_TOKEN = '<img>'
DEFAULT_IM_END_TOKEN = '</img>'

class VisionQnA(VisionQnABase):
model_name: str = "got_ocr2"
format: str = "custom"
visual_layers: List[str] = ['vision_tower_high', 'mm_projector_vary']

def __init__(self, model_id: str, device: str, device_map: str = 'auto', extra_params = {}, format = None):
super().__init__(model_id, device, device_map, extra_params, format)

self.tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=self.params.get('trust_remote_code', False))
self.model = AutoModel.from_pretrained(**self.params).eval()

# bitsandbytes already moves the model to the device, so we don't need to do it again.
if not (extra_params.get('load_in_4bit', False) or extra_params.get('load_in_8bit', False)):
self.model = self.model.to(self.device)

self.loaded_banner()

async def chat_with_images(self, request: ImageChatRequest) -> str:
try:
image = None
for m in reversed(request.messages):
for c in m.content:
if c.type == 'image_url':
image = await url_to_file(c.image_url.url)
break

response = self.model.chat(self.tokenizer, image, ocr_type='ocr') # TODO: support format and maybe convert to markdown?

return response
finally:
if image:
os.remove(image)
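For context, the `self.model.chat(..., ocr_type='ocr')` call above follows the GOT-OCR2_0 model card; here is a minimal standalone sketch of that usage, with the file path as a placeholder and the `ocr_type='format'` variant mentioned only as the upstream-documented option behind the TODO:

```python
from transformers import AutoModel, AutoTokenizer

# Standalone GOT-OCR2_0 usage mirroring the upstream model card (sketch; 'page.png' is a placeholder path).
tokenizer = AutoTokenizer.from_pretrained('stepfun-ai/GOT-OCR2_0', trust_remote_code=True)
model = AutoModel.from_pretrained('stepfun-ai/GOT-OCR2_0', trust_remote_code=True,
                                  low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True).eval()

text = model.chat(tokenizer, 'page.png', ocr_type='ocr')            # plain OCR, as the backend uses
# formatted = model.chat(tokenizer, 'page.png', ocr_type='format')  # formatted output, per the model card
print(text)
```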
14 changes: 9 additions & 5 deletions backend/llava.py
@@ -1,16 +1,17 @@
from transformers import LlavaProcessor, LlavaForConditionalGeneration
from transformers import AutoProcessor, LlavaForConditionalGeneration # was LlavaProcessor
from vision_qna import *

#
# llava-hf/bakLlava-v1-hf # llama2
# llava-hf/llava-1.5-7b-hf # vicuna
# llava-hf/llava-1.5-13b-hf # vicuna
# Doesn't support execution without images
# "hf-internal-testing/pixtral-12b" soon?

class VisionQnA(VisionQnABase):
model_name: str = "llava"
format: str = 'vicuna'
vision_layers: List[str] = ["vision_model", "vision_tower", "multi_modal_projector"]
vision_layers: List[str] = ["vision_model", "vision_tower", "multi_modal_projector", "vision_encoder", "vision_language_adapter"]

def __init__(self, model_id: str, device: str, device_map: str = 'auto', extra_params = {}, format = None):
super().__init__(model_id, device, device_map, extra_params, format)
@@ -20,7 +21,7 @@ def __init__(self, model_id: str, device: str, device_map: str = 'auto', extra_p

        del self.params['trust_remote_code']

        self.processor = LlavaProcessor.from_pretrained(model_id)
        self.processor = AutoProcessor.from_pretrained(model_id)
        self.model = LlavaForConditionalGeneration.from_pretrained(**self.params)

        self.loaded_banner()
@@ -31,9 +32,12 @@ async def stream_chat_with_images(self, request: ImageChatRequest) -> AsyncGener

        if len(images) < 1:
            images = [ await url_to_image(black_pixel_url) ]
            prompt = "<image>\n" + prompt
            if self.format == 'pixtral':
                prompt = "[IMG]\n" + prompt
            else:
                prompt = "<image>\n" + prompt

        inputs = self.processor(prompt, images, return_tensors="pt").to(self.device)
        inputs = self.processor(images=images, text=prompt, return_tensors="pt").to(self.device)

        params = self.get_generation_params(request)

46 changes: 46 additions & 0 deletions backend/mllama.py
@@ -0,0 +1,46 @@
from transformers import MllamaForConditionalGeneration, AutoProcessor

from vision_qna import *

# meta-llama/Llama-3.2-11B-Vision-Instruct
# meta-llama/Llama-3.2-90B-Vision-Instruct

class VisionQnA(VisionQnABase):
model_name: str = "mllama"
format: str = "llama3"
visual_layers: List[str] = ['vision_model', 'multi_modal_projector']

def __init__(self, model_id: str, device: str, device_map: str = 'auto', extra_params = {}, format = None):
super().__init__(model_id, device, device_map, extra_params, format)

del self.params['trust_remote_code']

self.processor = AutoProcessor.from_pretrained(model_id)
self.model = MllamaForConditionalGeneration.from_pretrained(**self.params).eval()

# bitsandbytes already moves the model to the device, so we don't need to do it again.
if not (extra_params.get('load_in_4bit', False) or extra_params.get('load_in_8bit', False)):
self.model = self.model.to(self.device)

self.loaded_banner()

async def stream_chat_with_images(self, request: ImageChatRequest) -> AsyncGenerator[str, None]:
images, prompt = await llama3_prompt_from_messages(request.messages, img_tok = "<|image|>")

if len(images) < 1:
images = [ await url_to_image(black_pixel_url) ]
prompt = "<|image|>" + prompt

inputs = self.processor(images, prompt, return_tensors="pt").to(self.model.device)

default_params = dict(do_sample=True)

params = self.get_generation_params(request, default_params=default_params)

generation_kwargs = dict(
**inputs,
**params,
)

for new_text in threaded_streaming_generator(generate=self.model.generate, tokenizer=self.processor, generation_kwargs=generation_kwargs):
yield new_text
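Because this backend yields tokens incrementally, a streamed request is the natural way to consume it; a minimal sketch against the OpenAI-compatible endpoint, with base URL, API key, and image URL as placeholder assumptions:

```python
from openai import OpenAI

# Sketch only: assumes the server was started with meta-llama/Llama-3.2-11B-Vision-Instruct;
# base_url, api_key and the image URL are placeholders.
client = OpenAI(base_url="http://localhost:5006/v1", api_key="skip")

stream = client.chat.completions.create(
    model="meta-llama/Llama-3.2-11B-Vision-Instruct",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this image?"},
            {"type": "image_url", "image_url": {"url": "https://example.com/photo.jpg"}},
        ],
    }],
    stream=True,
)
for chunk in stream:
    delta = chunk.choices[0].delta.content if chunk.choices else None
    if delta:
        print(delta, end="", flush=True)
```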
77 changes: 77 additions & 0 deletions backend/molmo.py
@@ -0,0 +1,77 @@
from transformers import AutoProcessor, AutoModelForCausalLM, GenerationConfig

from vision_qna import *

# allenai/MolmoE-1B-0924
# allenai/Molmo-7B-D-0924
# allenai/Molmo-7B-O-0924
# allenai/Molmo-72B-0924

# XXX To use, pip install tensorflow-cpu
# https://huggingface.co/allenai/Molmo-7B-D-0924/blob/main/image_preprocessing_molmo.py#L88-L90

"""
["allenai/MolmoE-1B-0924", "-A", "flash_attention_2", "--load-in-4bit"],
["allenai/MolmoE-1B-0924", "-A", "flash_attention_2"],
["allenai/Molmo-7B-D-0924", "-A", "flash_attention_2", "--load-in-4bit"],
["allenai/Molmo-7B-D-0924", "-A", "flash_attention_2"],
["allenai/Molmo-7B-O-0924", "-A", "flash_attention_2", "--load-in-4bit"],
["allenai/Molmo-7B-O-0924", "-A", "flash_attention_2"],
["allenai/Molmo-72B-0924", "--load-in-4bit"],
"""

class VisionQnA(VisionQnABase):
model_name: str = "molmo"
format: str = "chatml"
visual_layers: List[str] = ['vision_backbone']

def __init__(self, model_id: str, device: str, device_map: str = 'auto', extra_params = {}, format = None):
super().__init__(model_id, device, device_map, extra_params, format)

self.dtype = self.params['torch_dtype'] = 'auto' # torch.float32

self.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=self.params.get('trust_remote_code', False))
self.model = AutoModelForCausalLM.from_pretrained(**self.params).eval()

# bitsandbytes already moves the model to the device, so we don't need to do it again.
if not (extra_params.get('load_in_4bit', False) or extra_params.get('load_in_8bit', False)):
self.model = self.model.to(self.device)

self.eos_token_id = self.processor.tokenizer.encode(self.processor.tokenizer.eos_token)[0]

self.loaded_banner()

async def stream_chat_with_images(self, request: ImageChatRequest) -> AsyncGenerator[str, None]:
images, prompt = await chatml_prompt_from_messages(request.messages, img_tok = "<|image|>")

if len(images) < 1:
images = [ await url_to_image(black_pixel_url) ]
prompt = "<|image|>" + prompt

# process the image and text
inputs = self.processor.process(
images=images,
text=prompt,
)

inputs = {k: v.to(self.model.device).unsqueeze(0) for k, v in inputs.items()}

default_params = dict(
eos_token_id=self.eos_token_id,
pad_token_id=self.eos_token_id
)

params = self.get_generation_params(request, default_params)

generation_kwargs = dict(
batch=inputs,
generation_config=GenerationConfig(**params)
)

for new_text in threaded_streaming_generator(generate=self.model.generate_from_batch, tokenizer=self.processor.tokenizer, generation_kwargs=generation_kwargs):
end = new_text.find(self.processor.tokenizer.eos_token)
if end == -1:
yield new_text
else:
yield new_text[:end]
break
43 changes: 43 additions & 0 deletions debug_info.py
@@ -0,0 +1,43 @@
#!/usr/bin/env python

import torch
import re
import sys
import pkg_resources

def get_cuda_info():
    print("---- cuda")
    print(f"CUDA version: {torch.version.cuda}")
    print(f"torch.cuda.is_available(): {torch.cuda.is_available()}")
    for i in range(0, torch.cuda.device_count()):
        print(f"CUDA device[{i}]{torch.cuda.get_device_capability(i)}: {torch.cuda.get_device_name(i)}")

def get_python_version():
    print("---- python")
    print(sys.version)

def get_pip_packages():
    print("---- pip")
    try:
        packages = set(["transformers"])
        with open("requirements.txt", "r") as f:
            for line in f.readlines():
                line = line.strip()
                if not line or line.startswith("#") or line.startswith("http:") or line.startswith("https:") or line.startswith("git+"):
                    continue
                package = re.split(r"[=<>;#\[ ]", line)[0]
                packages.add(package)

        for package in sorted(list(packages)):
            try:
                version = pkg_resources.get_distribution(package).version
                print(f"{package}=={version}")
            except pkg_resources.DistributionNotFound:
                print(f"{package}: Not found")

    except FileNotFoundError:
        print("requirements.txt not found")

get_cuda_info()
get_python_version()
get_pip_packages()
14 changes: 14 additions & 0 deletions model_conf_tests.alt.json
@@ -1,4 +1,9 @@
[
["omlab/omchat-v2.0-13B-single-beta_hf", "-A", "flash_attention_2"],
["BAAI/Bunny-Llama-3-8B-V", "--load-in-4bit"],
["BAAI/Bunny-Llama-3-8B-V"],
["BAAI/Bunny-v1_1-Llama-3-8B-V", "--load-in-4bit"],
["BAAI/Bunny-v1_1-Llama-3-8B-V"],
["HuggingFaceM4/idefics2-8b", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"],
["HuggingFaceM4/idefics2-8b", "-A", "flash_attention_2", "--device-map", "cuda:0"],
["HuggingFaceM4/idefics2-8b-AWQ", "-A", "flash_attention_2", "--device-map", "cuda:0"],
@@ -18,15 +23,24 @@
["THUDM/cogvlm2-llama3-chat-19B"],
["THUDM/cogvlm2-llama3-chinese-chat-19B", "--load-in-4bit"],
["THUDM/cogvlm2-llama3-chinese-chat-19B"],
["THUDM/glm-4v-9b", "--device-map", "cuda:0", "--load-in-4bit"],
["THUDM/glm-4v-9b", "--device-map", "cuda:0"],
["TIGER-Lab/Mantis-8B-clip-llama3", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"],
["TIGER-Lab/Mantis-8B-clip-llama3", "-A", "flash_attention_2", "--device-map", "cuda:0"],
["TIGER-Lab/Mantis-8B-siglip-llama3", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"],
["TIGER-Lab/Mantis-8B-siglip-llama3", "-A", "flash_attention_2", "--device-map", "cuda:0"],
["cognitivecomputations/dolphin-vision-72b", "-A", "flash_attention_2", "--load-in-4bit", "--device-map", "cuda:0"],
["cognitivecomputations/dolphin-vision-7b", "-A", "flash_attention_2", "--load-in-4bit", "--device-map", "cuda:0"],
["cognitivecomputations/dolphin-vision-7b", "-A", "flash_attention_2", "--device-map", "cuda:0"],
["llava-hf/llava-v1.6-mistral-7b-hf", "-A", "flash_attention_2", "--load-in-4bit"],
["llava-hf/llava-v1.6-mistral-7b-hf", "-A", "flash_attention_2"],
["omlab/omchat-v2.0-13B-single-beta_hf", "-A", "flash_attention_2"],
["openbmb/MiniCPM-V", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"],
["openbmb/MiniCPM-V", "-A", "flash_attention_2", "--device-map", "cuda:0"],
["openbmb/MiniCPM-V-2", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"],
["openbmb/MiniCPM-V-2", "-A", "flash_attention_2", "--device-map", "cuda:0"],
["qihoo360/360VL-8B", "-A", "flash_attention_2", "--load-in-4bit"],
["qihoo360/360VL-8B", "-A", "flash_attention_2"],
["tiiuae/falcon-11B-vlm", "-A", "flash_attention_2", "--load-in-4bit"],
["tiiuae/falcon-11B-vlm", "-A", "flash_attention_2"]
]
3 changes: 3 additions & 0 deletions model_conf_tests.json
@@ -82,6 +82,9 @@
["llava-hf/llava-v1.6-vicuna-7b-hf", "-A", "flash_attention_2"],
["lmms-lab/llava-onevision-qwen2-0.5b-ov", "-A", "flash_attention_2"],
["lmms-lab/llava-onevision-qwen2-7b-ov", "-A", "flash_attention_2"],
["meta-llama/Llama-3.2-11B-Vision-Instruct", "-A", "flash_attention_2", "--load-in-4bit"],
["meta-llama/Llama-3.2-11B-Vision-Instruct", "-A", "flash_attention_2"],
["meta-llama/Llama-3.2-90B-Vision-Instruct", "-A", "flash_attention_2", "--load-in-4bit"],
["microsoft/Florence-2-base-ft", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"],
["microsoft/Florence-2-base-ft", "-A", "flash_attention_2", "--device-map", "cuda:0"],
["microsoft/Florence-2-large-ft", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"],
7 changes: 7 additions & 0 deletions requirements.txt
@@ -56,3 +56,10 @@ git+https://github.com/LLaVA-VL/LLaVA-NeXT.git
# mistral
mistral_inference>=1.4.0
mistral_common[opencv]>=1.4.3

# got-ocr2
verovio

# molmo... 1GB of dependencies? wait till it's removed.
# https://huggingface.co/allenai/Molmo-7B-D-0924/blob/main/image_preprocessing_molmo.py#L88-L90
#tensorflow-cpu