diff --git a/README.md b/README.md
index 6f12620..427d959 100644
--- a/README.md
+++ b/README.md
@@ -82,6 +82,7 @@ Can't decide which to use? See the [OpenVLM Leaderboard](https://huggingface.co/
 - [X] [Mistral AI](https://huggingface.co/mistralai)
 - - [X] [Pixtral-12B](https://huggingface.co/mistralai/Pixtral-12B-2409)
 - [X] [mx262/MiniMonkey](https://huggingface.co/mx262/MiniMonkey)
+- [X] [nvidia/NVLM-D-72B](https://huggingface.co/nvidia/NVLM-D-72B)
 - [X] [omlab/omchat-v2.0-13B-single-beta_hf](https://huggingface.co/omlab/omchat-v2.0-13B-single-beta_hf) (alt docker)
 - [X] [openbmb](https://huggingface.co/openbmb)
 - - [X] [MiniCPM-V-2_6](https://huggingface.co/openbmb/MiniCPM-V-2_6) (video not supported yet)
@@ -157,6 +158,10 @@ If you can't find your favorite model, you can [open a new issue](https://github
 
 ## Recent updates
 
+Version 0.37.0
+
+- new model support: nvidia/NVLM-D-72B
+
 Version 0.36.0
 
 - new model support: BAAI/Emu3-Chat
diff --git a/backend/nvlm.py b/backend/nvlm.py
new file mode 100644
index 0000000..b07e91c
--- /dev/null
+++ b/backend/nvlm.py
@@ -0,0 +1,155 @@
+from transformers import AutoTokenizer, AutoModel
+import torchvision.transforms as T
+from torchvision.transforms.functional import InterpolationMode
+
+from vision_qna import *
+
+# nvidia/NVLM-D-72B
+
+MAX_TILES = 6
+
+IMAGENET_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_STD = (0.229, 0.224, 0.225)
+
+def build_transform(input_size):
+    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
+    transform = T.Compose([
+        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
+        T.ToTensor(),
+        T.Normalize(mean=MEAN, std=STD)
+    ])
+    return transform
+
+def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
+    best_ratio_diff = float('inf')
+    best_ratio = (1, 1)
+    area = width * height
+    for ratio in target_ratios:
+        target_aspect_ratio = ratio[0] / ratio[1]
+        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+        if ratio_diff < best_ratio_diff:
+            best_ratio_diff = ratio_diff
+            best_ratio = ratio
+        elif ratio_diff == best_ratio_diff:
+            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+                best_ratio = ratio
+    return best_ratio
+
+def dynamic_preprocess(image, min_num=1, max_num=MAX_TILES, image_size=448, use_thumbnail=False):
+    orig_width, orig_height = image.size
+    aspect_ratio = orig_width / orig_height
+
+    # calculate the existing image aspect ratio
+    target_ratios = set(
+        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
+        i * j <= max_num and i * j >= min_num)
+    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+    # find the closest aspect ratio to the target
+    target_aspect_ratio = find_closest_aspect_ratio(
+        aspect_ratio, target_ratios, orig_width, orig_height, image_size)
+
+    # calculate the target width and height
+    target_width = image_size * target_aspect_ratio[0]
+    target_height = image_size * target_aspect_ratio[1]
+    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+    # resize the image
+    resized_img = image.resize((target_width, target_height))
+    processed_images = []
+    for i in range(blocks):
+        box = (
+            (i % (target_width // image_size)) * image_size,
+            (i // (target_width // image_size)) * image_size,
+            ((i % (target_width // image_size)) + 1) * image_size,
+            ((i // (target_width // image_size)) + 1) * image_size
+        )
+        # split the image
+        split_img = resized_img.crop(box)
+        processed_images.append(split_img)
+    assert len(processed_images) == blocks
+    if use_thumbnail and len(processed_images) != 1:
+        thumbnail_img = image.resize((image_size, image_size))
+        processed_images.append(thumbnail_img)
+    return processed_images
+
+
+def load_image(image, input_size=448, max_num=MAX_TILES):
+    #image = Image.open(image_file).convert('RGB')
+    transform = build_transform(input_size=input_size)
+    images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
+    pixel_values = [transform(image) for image in images]
+    pixel_values = torch.stack(pixel_values)
+    return pixel_values
+
+
+class VisionQnA(VisionQnABase):
+    model_name: str = "nvlm"
+    format: str = "chatml"
+    vision_layers: List[str] = ["vision_model"]
+
+    def __init__(self, model_id: str, device: str, device_map: str = 'auto', extra_params = {}, format = None):
+        super().__init__(model_id, device, device_map, extra_params, format)
+
+        self.max_tiles = extra_params.get('max_tiles', MAX_TILES)
+
+        self.tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False, trust_remote_code=self.params.get('trust_remote_code', False))
+        self.model = AutoModel.from_pretrained(**self.params).eval()
+
+        self.eos_token = '<|im_end|>'
+        self.IMG_CONTEXT_TOKEN = '<|vision_pad|>'
+        self.IMG_START_TOKEN = '' # <|vision_start|> ?
+        self.IMG_END_TOKEN = '' # <|vision_end|> ?
+        self.model.img_context_token_id = self.tokenizer.convert_tokens_to_ids(self.IMG_CONTEXT_TOKEN)
+
+        # bitsandbytes already moves the model to the device, so we don't need to do it again.
+        if not (extra_params.get('load_in_4bit', False) or extra_params.get('load_in_8bit', False)):
+            self.model = self.model.to(self.device)
+
+        self.loaded_banner()
+
+    async def stream_chat_with_images(self, request: ImageChatRequest) -> AsyncGenerator[str, None]:
+        images, prompt = await prompt_from_messages(request.messages, self.format)
+
+        if len(images) < 1:
+            pixel_values = None
+        else:
+            pixel_values = load_image(images[-1], max_num=self.max_tiles).to(self.model.dtype).cuda()
+
+            for num_patches in [pixel_values.shape[0]]:
+                tile_pos_identifiers = [f"<tile_{i}>" for i in range(1, num_patches)] + ["<tile_global_thumbnail>"]
+                image_tokens = ''
+                for tile_pos_identifier in tile_pos_identifiers:
+                    image_tokens += tile_pos_identifier + self.IMG_CONTEXT_TOKEN * self.model.num_image_token
+                image_tokens = self.IMG_START_TOKEN + image_tokens + self.IMG_END_TOKEN
+                prompt = prompt.replace('<image>', image_tokens, 1)
+
+        model_inputs = self.tokenizer(prompt, return_tensors='pt')
+        input_ids = model_inputs['input_ids'].cuda()
+        attention_mask = model_inputs['attention_mask'].cuda()
+
+        default_params = dict(
+            max_new_tokens=1024,
+            do_sample=False,
+            pad_token_id=self.tokenizer.eos_token_id,
+        )
+
+        params = self.get_generation_params(request, default_params)
+
+        del params['use_cache']
+
+        generation_kwargs = dict(
+            pixel_values=pixel_values,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            **params,
+        )
+
+        for new_text in threaded_streaming_generator(generate=self.model.generate, tokenizer=self.tokenizer, generation_kwargs=generation_kwargs):
+            end = new_text.find(self.eos_token)
+            if end == -1:
+                yield new_text
+            else:
+                yield new_text[:end]
+                break
diff --git a/model_conf_tests.json b/model_conf_tests.json
index a3b59cf..c9f499c 100644
--- a/model_conf_tests.json
+++ b/model_conf_tests.json
@@ -101,6 +101,7 @@
     ["mistralai/Pixtral-12B-2409"],
     ["mx262/MiniMonkey", "-A", "flash_attention_2", "--load-in-4bit"],
     ["mx262/MiniMonkey", "-A", "flash_attention_2"],
+    ["nvidia/NVLM-D-72B", "-A", "flash_attention_2", "--load-in-4bit"],
"--load-in-4bit"], ["openbmb/MiniCPM-V-2_6-int4", "-A", "flash_attention_2", "--device-map", "cuda:0"], ["openbmb/MiniCPM-V-2_6", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"], ["openbmb/MiniCPM-V-2_6", "-A", "flash_attention_2", "--device-map", "cuda:0"], diff --git a/test_api_model.py b/test_api_model.py index 26407ca..44812d6 100755 --- a/test_api_model.py +++ b/test_api_model.py @@ -98,11 +98,13 @@ def generate_response(image_url, prompt): ]}]) response = client.chat.completions.create(model=args.openai_model, messages=messages, **params) + completion_tokens = 0 answer = response.choices[0].message.content - return answer + if response.usage: + completion_tokens = response.usage.completion_tokens + return answer, completion_tokens def generate_stream_response(image_url, prompt): - messages = [{ "role": "system", "content": [{ 'type': 'text', 'text': args.system_prompt }] }] if args.system_prompt else [] messages.extend([ { "role": "user", "content": [ @@ -112,51 +114,45 @@ def generate_stream_response(image_url, prompt): response = client.chat.completions.create(model=args.openai_model, messages=messages, **params, stream=True) answer = '' + completion_tokens = 0 for chunk in response: if chunk.choices[0].delta.content: answer += chunk.choices[0].delta.content - - return answer + if chunk.usage: + completion_tokens = chunk.usage.completion_tokens + + return answer, completion_tokens if True: # XXX TODO: timeout results = [] ### Single round + timing = [] - test_time = time.time() - - # url tests - for name, url in urls.items(): - answer = generate_response(url, "What is the subject of the image?") + def single_test(url, question, label, generator=generate_response): + tps_time = time.time() + answer, tok = generator(url, question) + tps_time = time.time() - tps_time correct = name in answer.lower() results.extend([correct]) if not correct: - print(f"{name}[url]: fail, got: {answer}") - if args.abort_on_fail: - break + print(f"{name}[{label}]: fail, got: {answer}") + #if args.abort_on_fail: + # break else: - print(f"{name}[url]: pass{', got: ' + answer if args.verbose else ''}") + print(f"{name}[{label}]: pass{', got: ' + answer if args.verbose else ''}") + if tok > 1: + timing.extend([(tok, tps_time)]) - data_url = data_url_from_url(url) - answer = generate_response(data_url, "What is the subject of the image?") - correct = name in answer.lower() - results.extend([correct]) - if not correct: - print(f"{name}[data]: fail, got: {answer}") - if args.abort_on_fail: - break - else: - print(f"{name}[data]: pass{', got: ' + answer if args.verbose else ''}") + test_time = time.time() - answer = generate_stream_response(data_url, "What is the subject of the image?") - correct = name in answer.lower() - results.extend([correct]) - if not correct: - print(f"{name}[data_stream]: fail, got: {answer}") - if args.abort_on_fail: - break - else: - print(f"{name}[data_stream]: pass{', got: ' + answer if args.verbose else ''}") + # url tests + for name, url in urls.items(): + single_test(url, "What is the subject of the image?", "url", generate_response) + + data_url = data_url_from_url(url) + single_test(data_url, "What is the subject of the image?", "data", generate_response) + single_test(data_url, "What is the subject of the image?", "data_stream", generate_stream_response) ## OCR tests @@ -166,15 +162,7 @@ def generate_stream_response(image_url, prompt): } for name, question in quality_urls.items(): prompt, data_url = question - answer = generate_stream_response(data_url, 
-        correct = name in answer.lower() or 'wal-mart' in answer.lower()
-        results.extend([correct])
-        if not correct:
-            print(f"{name}[quality]: fail, got: {answer}")
-            if args.abort_on_fail:
-                break
-        else:
-            print(f"{name}[quality]: pass{', got: ' + answer if args.verbose else ''}")
+        single_test(data_url, prompt, "quality", generate_stream_response)
 
     # No image tests
     no_image = {
@@ -204,5 +192,13 @@ def no_image_response(prompt):
 
     result = all(results)
     note = f'{results.count(True)}/{len(results)} tests passed.'
+    if timing:
+        tok_total, tim_total = 0, 0.0
+        for tok, tim in timing:
+            if tok > 1 and tim > 0:
+                tok_total += tok
+                tim_total += tim
+        if tim_total > 0.0:
+            note += f', ({tok_total}/{tim_total:0.1f}s) {tok_total/tim_total:0.1f} T/s'
 
     print(f"test {green_pass if results else red_fail}, time: {test_time:.1f}s, {note}")
diff --git a/vision.sample.env b/vision.sample.env
index 4669e6e..b086afe 100644
--- a/vision.sample.env
+++ b/vision.sample.env
@@ -108,6 +108,7 @@ HF_HUB_ENABLE_HF_TRANSFER=1
 #CLI_COMMAND="python vision.py -m mistralai/Pixtral-12B-2409" # test pass✅, time: 16.0s, mem: 35.5GB, 13/13 tests passed (manual calc) 12.7 T/s
 #CLI_COMMAND="python vision.py -m mx262/MiniMonkey -A flash_attention_2 --load-in-4bit" # test pass✅, time: 11.1s, mem: 13.9GB, 13/13 tests passed, (37/3.1s) 11.7 T/s
 #CLI_COMMAND="python vision.py -m mx262/MiniMonkey -A flash_attention_2" # test pass✅, time: 10.0s, mem: 16.3GB, 13/13 tests passed, (37/2.8s) 13.0 T/s
+#CLI_COMMAND="python vision.py -m nvidia/NVLM-D-72B -A flash_attention_2 --load-in-4bit" # test pass✅, time: 62.0s, mem: 56.7GB, 13/13 tests passed, (66/19.7s) 3.3 T/s
 #CLI_COMMAND="python vision.py -m openbmb/MiniCPM-V-2_6-int4 -A flash_attention_2 --device-map cuda:0" # test pass✅, time: 19.0s, mem: 9.2GB, 13/13 tests passed, (93/5.2s) 18.0 T/s
 #CLI_COMMAND="python vision.py -m openbmb/MiniCPM-V-2_6 -A flash_attention_2 --device-map cuda:0 --load-in-4bit" # test pass✅, time: 15.8s, mem: 9.5GB, 13/13 tests passed, (99/4.4s) 22.5 T/s
 #CLI_COMMAND="python vision.py -m openbmb/MiniCPM-V-2_6 -A flash_attention_2 --device-map cuda:0" # test pass✅, time: 13.3s, mem: 18.8GB, 13/13 tests passed, (101/3.4s) 30.1 T/s
diff --git a/vision_qna.py b/vision_qna.py
index 1bbbce4..9b64891 100644
--- a/vision_qna.py
+++ b/vision_qna.py
@@ -945,6 +945,9 @@ def guess_backend(model_name: str) -> str:
     if 'florence' in model_id:
         return 'florence'
 
+    if 'nvlm' in model_id:
+        return 'nvlm'
+
     if 'internvl-chat' in model_id and '-v1-5' in model_id:
         return 'internvl-chat-v1-5'
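
For reference, a minimal client-side sketch of how the new backend can be exercised once the server is started (for example with the `CLI_COMMAND` line added to `vision.sample.env` above). It follows the same OpenAI-compatible `chat.completions` request shape that `test_api_model.py` uses; the `base_url`, `api_key`, and image URL below are placeholder assumptions, not values taken from this patch.

```python
# Hypothetical usage sketch (not part of this patch): send one image question
# to the nvidia/NVLM-D-72B backend through the OpenAI-compatible endpoint.
from openai import OpenAI

# base_url/api_key are assumptions -- point them at your running vision.py server.
client = OpenAI(base_url="http://localhost:5006/v1", api_key="skip")

response = client.chat.completions.create(
    model="nvidia/NVLM-D-72B",
    messages=[
        {"role": "user", "content": [
            {"type": "text", "text": "What is the subject of the image?"},
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.jpg"}},
        ]},
    ],
    max_tokens=256,
)
print(response.choices[0].message.content)
```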