
Commit 3aeb19a

[Model] Add support for LightOnOCR (#26916)
Signed-off-by: Said Taghadouini <taghadouinisaid@gmail.com>
Signed-off-by: Said Taghadouini <84044788+staghado@users.noreply.github.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
1 parent 8c017b3 commit 3aeb19a

5 files changed: +225 lines added, 0 removed


docs/models/supported_models.md

Lines changed: 1 addition & 0 deletions
@@ -663,6 +663,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
 | `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-8B-Preview` | ✅︎ | ✅︎ |
 | `KeyeVL1_5ForConditionalGeneration` | Keye-VL-1_5-8B | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-1_5-8B` | ✅︎ | ✅︎ |
 | `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | ✅︎ |
+| `LightOnOCRForConditionalGeneration` | LightOnOCR-1B | T + I<sup>+</sup> | `lightonai/LightOnOCR-1B`, etc. | ✅︎ | ✅︎ |
 | `Llama4ForConditionalGeneration` | Llama 4 | T + I<sup>+</sup> | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | | ✅︎ |
 | `Llama_Nemotron_Nano_VL` | Llama Nemotron Nano VL | T + I<sup>E+</sup> | `nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1` | ✅︎ | ✅︎ |
 | `LlavaForConditionalGeneration` | LLaVA-1.5, Pixtral (HF Transformers) | T + I<sup>E+</sup> | `llava-hf/llava-1.5-7b-hf`, `TIGER-Lab/Mantis-8B-siglip-llama3` (see note), `mistral-community/pixtral-12b`, etc. | | ✅︎ |
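For reference, a minimal offline-inference sketch for the new table entry (not part of this commit); it reuses the chat-style prompt from the example added below, and the image file name is a placeholder:

from PIL import Image

from vllm import LLM, SamplingParams

# Single-image OCR request against the newly supported checkpoint.
llm = LLM(model="lightonai/LightOnOCR-1B", limit_mm_per_prompt={"image": 1})

prompt = (
    "<|im_start|>system<|im_end|>\n"
    "<|im_start|>user\n<|image_pad|><|im_end|>\n"
    "<|im_start|>assistant\n"
)

outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": Image.open("page.png")},  # placeholder path
    },
    SamplingParams(temperature=0.0, max_tokens=1024),
)
print(outputs[0].outputs[0].text)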

examples/offline_inference/vision_language.py

Lines changed: 21 additions & 0 deletions
@@ -734,6 +734,26 @@ def run_kimi_vl(questions: list[str], modality: str) -> ModelRequestData:
     )


+# LightOnOCR
+def run_lightonocr(questions: list[str], modality: str) -> ModelRequestData:
+    assert modality == "image"
+
+    prompts = [
+        "<|im_start|>system<|im_end|>\n<|im_start|>user\n<|image_pad|><|im_end|>\n<|im_start|>assistant\n"
+        for _ in questions
+    ]
+
+    engine_args = EngineArgs(
+        model="lightonai/LightOnOCR-1B",
+        limit_mm_per_prompt={modality: 1},
+    )
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompts=prompts,
+    )
+
+
 def run_llama4(questions: list[str], modality: str) -> ModelRequestData:
     assert modality == "image"


@@ -1709,6 +1729,7 @@ def run_tarsier2(questions: list[str], modality: str) -> ModelRequestData:
     "keye_vl": run_keye_vl,
     "keye_vl1_5": run_keye_vl1_5,
     "kimi_vl": run_kimi_vl,
+    "lightonocr": run_lightonocr,
     "llama4": run_llama4,
     "llava": run_llava,
     "llava-next": run_llava_next,

tests/models/registry.py

Lines changed: 4 additions & 0 deletions
@@ -652,6 +652,10 @@ def check_available_online(
         extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"},
         trust_remote_code=True,
     ),
+    "LightOnOCRForConditionalGeneration": _HfExamplesInfo(
+        "lightonai/LightOnOCR-1B",
+        is_available_online=False,
+    ),
     "Llama4ForConditionalGeneration": _HfExamplesInfo(
         "meta-llama/Llama-4-Scout-17B-16E-Instruct",
         max_model_len=10240,
vllm/model_executor/models/lightonocr.py (new file)

Lines changed: 195 additions & 0 deletions

@@ -0,0 +1,195 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Mapping, Sequence
from typing import TypeVar

import torch
import torch.nn as nn
from transformers import (
    BatchFeature,
    PixtralVisionConfig,
)

from vllm.config import VllmConfig
from vllm.model_executor.models.mistral3 import (
    Mistral3DummyInputsBuilder,
    Mistral3ForConditionalGeneration,
    Mistral3MultiModalProjector,
    Mistral3ProcessingInfo,
    _build_mistral3_info,
    init_vision_tower_for_llava,
)
from vllm.model_executor.models.pixtral import PixtralHFEncoderInfo
from vllm.model_executor.models.utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder

_I = TypeVar("_I", bound=Mistral3ProcessingInfo)


class LightOnOCRMultiModalProcessor(BaseMultiModalProcessor[Mistral3ProcessingInfo]):
    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        # NOTE: LightOnOCR does not use break/end tokens, so we remove them here.
        input_ids = processed_outputs.get("input_ids")
        if input_ids is not None:
            processor = self.info.get_hf_processor()
            tokenizer = self.info.get_tokenizer()
            vocab = tokenizer.get_vocab()

            break_id = vocab.get(processor.image_break_token)
            end_id = vocab.get(processor.image_end_token)

            # create mask to remove break/end tokens
            keep_mask = ~torch.isin(
                input_ids,
                torch.tensor([break_id, end_id]),
            )

            processed_outputs["input_ids"] = input_ids[keep_mask].unsqueeze(0)
            if "attention_mask" in processed_outputs:
                processed_outputs["attention_mask"] = processed_outputs[
                    "attention_mask"
                ][keep_mask].unsqueeze(0)

        # un-pad pixel_values per-image so caches remain independent.
        pixel_values = processed_outputs.get("pixel_values")
        if pixel_values is not None:
            image_sizes = processed_outputs["image_sizes"]
            assert len(pixel_values) == len(image_sizes)
            processed_outputs["pixel_values"] = [
                p[:, :h, :w] for p, (h, w) in zip(pixel_values, image_sizes)
            ]

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        hf_config = self.info.get_hf_config()
        image_token_id = hf_config.image_token_index

        assert isinstance(hf_config.vision_config, PixtralVisionConfig)
        encoder_info = PixtralHFEncoderInfo(hf_config)

        def replace(item_idx: int):
            images = mm_items.get_items("image", ImageProcessorItems)
            size = images.get_image_size(item_idx)
            ncols, nrows = encoder_info.get_patch_grid_size(
                image_width=size.width, image_height=size.height
            )
            # break/end tokens are not used in LightOnOCR
            tokens = [image_token_id] * (ncols * nrows)
            return PromptUpdateDetails.select_token_id(tokens, image_token_id)

        return [
            PromptReplacement(
                modality="image", target=[image_token_id], replacement=replace
            )
        ]


def _build_LightOnOCR_processor(
    info: _I,
    dummy_inputs: BaseDummyInputsBuilder[_I],
    *,
    cache: BaseMultiModalProcessorCache | None = None,
):
    assert isinstance(info, Mistral3ProcessingInfo)
    return LightOnOCRMultiModalProcessor(info, dummy_inputs, cache=cache)


@MULTIMODAL_REGISTRY.register_processor(
    _build_LightOnOCR_processor,
    info=_build_mistral3_info,
    dummy_inputs=Mistral3DummyInputsBuilder,
)
class LightOnOCRForConditionalGeneration(Mistral3ForConditionalGeneration):
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.vision_encoder.": "vision_tower.",
            "model.vision_projection.": "multi_modal_projector.",
            "lm_head.": "language_model.lm_head.",
            "model.language_model.": "language_model.model.",
        }
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        nn.Module.__init__(self)
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"),
        )

        self.multi_modal_projector = Mistral3MultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act,
            spatial_merge_size=config.spatial_merge_size,
            patch_size=config.vision_config.patch_size,
            multimodal_projector_bias=config.multimodal_projector_bias,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "multi_modal_projector"),
        )

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
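As a quick illustration of the `keep_mask` trick used in `_call_hf_processor` above, here is a standalone sketch with toy token ids standing in for the real break/end ids; the boolean mask drops the unwanted ids and the batch dimension is then restored:

import torch

# Toy (1, seq_len) input_ids; 12 and 13 stand in for the image-break/image-end ids.
input_ids = torch.tensor([[1, 10, 10, 12, 10, 10, 13, 2]])
break_id, end_id = 12, 13

# Same pattern as the processor: mask out the two ids, then re-add the batch dim.
keep_mask = ~torch.isin(input_ids, torch.tensor([break_id, end_id]))
filtered = input_ids[keep_mask].unsqueeze(0)

print(filtered)  # tensor([[ 1, 10, 10, 10, 10,  2]])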

vllm/model_executor/models/registry.py

Lines changed: 4 additions & 0 deletions
@@ -297,6 +297,10 @@
     ),
     "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
     "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
+    "LightOnOCRForConditionalGeneration": (
+        "lightonocr",
+        "LightOnOCRForConditionalGeneration",
+    ),
     "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
     "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
     "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
