
Commit 1eb85ec

ngxson and theo77186 authored and committed
model : add LightOnOCR-1B model (ggml-org#16764)
* model : add LightOnOCR-1B model
* add test
1 parent 19b0a00 commit 1eb85ec
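In short: LightOnOCR-1B reuses the existing Pixtral path end to end (the LlavaVisionModel conversion code and the build_pixtral graph in clip.cpp), with two deviations handled in this commit: its checkpoint keeps vision tensors under model.vision_encoder.* / model.vision_projection.* prefixes, and it has no [IMG_BREAK] row-separator token, which the new use_break_tok flag and the token_embd_img_break guards make optional.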

File tree

6 files changed: +56 −5 lines changed

convert_hf_to_gguf.py

Lines changed: 24 additions & 2 deletions

@@ -2460,18 +2460,21 @@ def set_gguf_parameters(self):
 )
 class LlavaVisionModel(MmprojModel):
     img_break_tok_id = -1
+    use_break_tok = True
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         if self.hparams.get("model_type") == "pixtral":
             # layer_norm_eps is not in config.json, it is hard-coded in modeling_pixtral.py
             self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5)
-            self.img_break_tok_id = self.get_token_id("[IMG_BREAK]")
+            if self.use_break_tok:
+                self.img_break_tok_id = self.get_token_id("[IMG_BREAK]")
         elif self.is_mistral_format:
             # hparams is already vision config here so norm_eps is only defined in global_config.
             self.hparams["norm_eps"] = self.global_config.get("norm_eps", None)
             assert self.hparams["norm_eps"] is not None, "norm_eps not found in params.json"
-            self.img_break_tok_id = self.find_vparam(["image_break_token_id"])
+            if self.use_break_tok:
+                self.img_break_tok_id = self.find_vparam(["image_break_token_id"])
         else:
             raise ValueError(f"Unsupported model type: {self.hparams['model_type']}")
         logger.info(f"Image break token id: {self.img_break_tok_id}")

@@ -3998,6 +4001,10 @@ def _get_cls_out_tensor(self, data_torch: Tensor) -> Tensor:
         return torch.stack([true_row, false_row], dim=0)
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if "model.vision_" in name:
+            # skip multimodal tensors
+            return []
+
         if self.is_rerank:
             is_tied_head = self.is_tied_embeddings and "embed_tokens" in name
             is_real_head = not self.is_tied_embeddings and "lm_head" in name

@@ -9670,6 +9677,21 @@ def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", "
         return super().map_tensor_name(name, try_suffixes)
 
 
+@ModelBase.register("LightOnOCRForConditionalGeneration")
+class LightOnOCRVisionModel(LlavaVisionModel):
+    is_mistral_format = False
+    use_break_tok = False
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.LIGHTONOCR)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
+        name = name.replace("model.vision_encoder.", "vision_tower.")
+        name = name.replace("model.vision_projection.", "multi_modal_projector.")
+        return super().modify_tensors(data_torch, name, bid)
+
+
 @ModelBase.register("KimiVLForConditionalGeneration")
 class KimiVLModel(MmprojModel):
     def __init__(self, *args, **kwargs):
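The rename in LightOnOCRVisionModel.modify_tensors is pure prefix substitution before delegating to the inherited LlavaVisionModel machinery. A minimal standalone sketch of the effect (the two tensor names below are hypothetical examples, not taken from an actual checkpoint):

# Sketch of the prefix remapping performed by LightOnOCRVisionModel.modify_tensors.
# The example tensor names are hypothetical, for illustration only.
examples = [
    "model.vision_encoder.transformer.layers.0.attention.q_proj.weight",  # hypothetical
    "model.vision_projection.norm.weight",                                # hypothetical
]
for name in examples:
    name = name.replace("model.vision_encoder.", "vision_tower.")
    name = name.replace("model.vision_projection.", "multi_modal_projector.")
    print(name)
# vision_tower.transformer.layers.0.attention.q_proj.weight
# multi_modal_projector.norm.weight

After the rename, the existing Pixtral-style name mapping applies unchanged, which is why the class body stays this small.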

gguf-py/gguf/constants.py

Lines changed: 1 addition & 0 deletions

@@ -3113,6 +3113,7 @@ class VisionProjectorType:
     VOXTRAL = "voxtral"
     LFM2 = "lfm2"
     KIMIVL = "kimivl"
+    LIGHTONOCR = "lightonocr"
 
 
 # Items here are (block size, type size)

tools/mtmd/clip-impl.h

Lines changed: 2 additions & 0 deletions

@@ -144,6 +144,7 @@ enum projector_type {
     PROJECTOR_TYPE_VOXTRAL,
     PROJECTOR_TYPE_LFM2,
     PROJECTOR_TYPE_KIMIVL,
+    PROJECTOR_TYPE_LIGHTONOCR,
     PROJECTOR_TYPE_UNKNOWN,
 };
 
@@ -167,6 +168,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_VOXTRAL,   "voxtral"},
     { PROJECTOR_TYPE_LFM2,      "lfm2"},
     { PROJECTOR_TYPE_KIMIVL,    "kimivl"},
+    { PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"},
 };
 
 static projector_type clip_projector_type_from_string(const std::string & str) {
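Note that the enum and the name map must stay in sync: clip_projector_type_from_string (declared just below) resolves the projector string stored in the GGUF metadata by looking it up in PROJECTOR_TYPE_NAMES, so an enum value missing from the map could never be selected at load time.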

tools/mtmd/clip.cpp

Lines changed: 23 additions & 3 deletions

@@ -631,7 +631,7 @@ struct clip_graph {
         }
 
         // arrangement of the [IMG_BREAK] token
-        {
+        if (model.token_embd_img_break) {
             // not efficient, but works
             // the trick is to view the embeddings as a 3D tensor with shape [n_embd, n_patches_per_row, n_rows]
             // and then concatenate the [IMG_BREAK] token to the end of each row, aka n_patches_per_row dimension

@@ -2289,6 +2289,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
                 res = graph.build_siglip();
             } break;
         case PROJECTOR_TYPE_PIXTRAL:
+        case PROJECTOR_TYPE_LIGHTONOCR:
             {
                 res = graph.build_pixtral();
             } break;

@@ -2581,6 +2582,7 @@ struct clip_model_loader {
                     get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
                 } break;
             case PROJECTOR_TYPE_PIXTRAL:
+            case PROJECTOR_TYPE_LIGHTONOCR:
                 {
                     hparams.rope_theta = 10000.0f;
                     hparams.warmup_image_size = hparams.patch_size * 8;

@@ -2966,6 +2968,15 @@ struct clip_model_loader {
                     model.mm_input_norm_w   = get_tensor(TN_MM_INP_NORM, false);
                     model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false);
                 } break;
+            case PROJECTOR_TYPE_LIGHTONOCR:
+                {
+                    model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
+                    model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"), false);
+                    model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
+                    model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false);
+                    model.mm_input_norm_w   = get_tensor(TN_MM_INP_NORM, false);
+                    model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false);
+                } break;
             case PROJECTOR_TYPE_ULTRAVOX:
                 {
                     model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));

@@ -3881,7 +3892,9 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
         res_imgs->entries.push_back(std::move(img_f32));
         return true;
 
-    } else if (ctx->proj_type() == PROJECTOR_TYPE_PIXTRAL) {
+    } else if (ctx->proj_type() == PROJECTOR_TYPE_PIXTRAL
+            || ctx->proj_type() == PROJECTOR_TYPE_LIGHTONOCR
+            ) {
         clip_image_u8 resized_image;
         auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, params.patch_size, params.image_size);
         image_manipulation::bilinear_resize(*img, resized_image, new_size.width, new_size.height);

@@ -4125,12 +4138,17 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
                 n_patches = x_patch * y_patch;
             } break;
         case PROJECTOR_TYPE_PIXTRAL:
+        case PROJECTOR_TYPE_LIGHTONOCR:
             {
                 // dynamic size
                 int n_merge = params.spatial_merge_size;
                 int n_patches_x = img->nx / patch_size / (n_merge > 0 ? n_merge : 1);
                 int n_patches_y = img->ny / patch_size / (n_merge > 0 ? n_merge : 1);
-                n_patches = n_patches_y * n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
+                if (ctx->model.token_embd_img_break) {
+                    n_patches = n_patches_y * n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
+                } else {
+                    n_patches = n_patches_y * n_patches_x;
+                }
             } break;
         case PROJECTOR_TYPE_VOXTRAL:
         case PROJECTOR_TYPE_ULTRAVOX:
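A quick worked example of the two branches above, as a runnable sketch (the image dimensions and hparam values are illustrative assumptions, not read from the actual LightOnOCR config):

# Token-count arithmetic mirroring the PIXTRAL vs. LIGHTONOCR cases above,
# assuming an illustrative 1024x768 image with patch_size=16, spatial_merge_size=2.
nx, ny, patch_size, n_merge = 1024, 768, 16, 2
n_patches_x = nx // patch_size // n_merge   # 32
n_patches_y = ny // patch_size // n_merge   # 24
with_break    = n_patches_y * n_patches_x + n_patches_y - 1  # 791: one [IMG_BREAK] per row except the last
without_break = n_patches_y * n_patches_x                    # 768: no [IMG_BREAK] embedding loaded
print(with_break, without_break)  # 791 768

So for the same image a LightOnOCR mmproj yields n_patches_y - 1 fewer output tokens than Pixtral, keeping clip_n_output_tokens consistent with the token_embd_img_break guard in the graph code.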
@@ -4508,6 +4526,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
             } break;
         case PROJECTOR_TYPE_PIXTRAL:
         case PROJECTOR_TYPE_KIMIVL:
+        case PROJECTOR_TYPE_LIGHTONOCR:
             {
                 // set the 2D positions
                 int n_patches_per_col = image_size_width / patch_size;

@@ -4638,6 +4657,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
             return ctx->model.mm_model_peg_0_b->ne[0];
         case PROJECTOR_TYPE_MLP:
         case PROJECTOR_TYPE_PIXTRAL:
+        case PROJECTOR_TYPE_LIGHTONOCR:
             return ctx->model.mm_2_w->ne[1];
         case PROJECTOR_TYPE_MLP_NORM:
             return ctx->model.mm_3_b->ne[0];

tools/mtmd/mtmd.cpp

Lines changed: 5 additions & 0 deletions

@@ -275,6 +275,11 @@ struct mtmd_context {
             img_beg = "<img>";
             img_end = "</img>";
 
+        } else if (proj == PROJECTOR_TYPE_LIGHTONOCR) {
+            // <|im_start|> ... (image embeddings) ... <|im_end|>
+            img_beg = "<|im_start|>";
+            img_end = "<|im_end|>";
+
         }
     }
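With this branch, mtmd emits the literal <|im_start|> marker before and <|im_end|> after the image embeddings when building the multimodal sequence, as the in-diff comment describes; the other projectors in this chain do the same with their own marker pairs (e.g. <img> ... </img> just above).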

tools/mtmd/tests.sh

Lines changed: 1 addition & 0 deletions

@@ -70,6 +70,7 @@ add_test_vision "ggml-org/InternVL3-1B-Instruct-GGUF:Q8_0"
 add_test_vision "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M"
 add_test_vision "ggml-org/LFM2-VL-450M-GGUF:Q8_0"
 add_test_vision "ggml-org/granite-docling-258M-GGUF:Q8_0"
+add_test_vision "ggml-org/LightOnOCR-1B-1025-GGUF:Q8_0"
 
 add_test_audio "ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF:Q8_0"
 add_test_audio "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M"
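Running tools/mtmd/tests.sh should now also fetch ggml-org/LightOnOCR-1B-1025-GGUF at Q8_0 and smoke-test it through the same add_test_vision path as the other vision models (the exact runner behavior is defined earlier in the script and not shown in this hunk).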
