From f1135dc6371eded115f164c1892e05ce1b41eeb9 Mon Sep 17 00:00:00 2001 From: Salman Maqbool Date: Thu, 6 Jun 2024 11:53:54 +0500 Subject: [PATCH] [WIP] Added support for converting LayoutLMv3 to CoreML --- src/exporters/coreml/config.py | 60 ++++++++++++++++++++++++++++++++ src/exporters/coreml/convert.py | 57 ++++++++++++++++++++++++++++++ src/exporters/coreml/features.py | 4 +++ src/exporters/coreml/models.py | 57 ++++++++++++++++++++++++++++++ 4 files changed, 178 insertions(+) diff --git a/src/exporters/coreml/config.py b/src/exporters/coreml/config.py index 33f529d..7907728 100644 --- a/src/exporters/coreml/config.py +++ b/src/exporters/coreml/config.py @@ -327,6 +327,40 @@ def _input_descriptions(self) -> "OrderedDict[str, InputDescription]": ), ] ) + + if self.modality == "multimodal" and self.task in [ + "token-classification", + ]: + return OrderedDict( + [ + ( + "input_ids", + InputDescription( + "input_ids", + "Indices of input sequence tokens in the vocabulary", + sequence_length=self.input_ids_sequence_length, + ) + ), + ( + "bbox", + InputDescription( + "bbox", + "Bounding Boxes" + ) + ), + ( + "attention_mask", + InputDescription( + "attention_mask", + "Mask to avoid performing attention on padding token indices (1 = not masked, 0 = masked)", + ) + ), + ( + "pixel_values", + InputDescription("image", "Input image", color_layout="RGB") + ), + ] + ) if self.task == "image-classification": return OrderedDict( @@ -931,6 +965,32 @@ def generate_dummy_inputs( bool_masked_pos = np.random.randint(low=0, high=2, size=(1, num_patches)).astype(bool) dummy_inputs["bool_masked_pos"] = (bool_masked_pos, bool_masked_pos.astype(np.int32)) + elif self.modality == "multimodal": + input_ids_name = "input_ids" + attention_mask_name = "attention_mask" + + input_desc = input_descs[input_ids_name] + + # the dummy input will always use the maximum sequence length + sequence_length = self._get_max_sequence_length(input_desc, 64) + + shape = (batch_size, sequence_length) + + input_ids = np.random.randint(0, preprocessor.tokenizer.vocab_size, shape) + dummy_inputs[input_ids_name] = (input_ids, input_ids.astype(np.int32)) + + if attention_mask_name in input_descs: + attention_mask = np.ones(shape, dtype=np.int64) + dummy_inputs[attention_mask_name] = (attention_mask, attention_mask.astype(np.int32)) + + bbox_shape = (batch_size, sequence_length, 4) + bboxes = np.random.randint(low=0, high=1000, size=bbox_shape).astype(np.int64) + dummy_inputs["bbox"] = (bboxes, bboxes.astype(np.int32)) + + dummy_inputs["pixel_values"] = self._generate_dummy_image(preprocessor.image_processor, framework) + + print(dummy_inputs) + elif self.modality == "audio" and isinstance(preprocessor, ProcessorMixin): if self.seq2seq != "decoder": if "input_features" in input_descs: diff --git a/src/exporters/coreml/convert.py b/src/exporters/coreml/convert.py index d8ab75a..8ce9ff6 100644 --- a/src/exporters/coreml/convert.py +++ b/src/exporters/coreml/convert.py @@ -254,6 +254,63 @@ def get_input_types( ) ) + elif config.modality == "multimodal": + input_ids_name = "input_ids" + attention_mask_name = "attention_mask" + + input_desc = input_descs[input_ids_name] + dummy_input = dummy_inputs[input_ids_name] + shape = get_shape(config, input_desc, dummy_input) + input_types.append( + ct.TensorType(name=input_desc.name, shape=shape, dtype=np.int32) + ) + + if attention_mask_name in input_descs: + input_desc = input_descs[attention_mask_name] + input_types.append( + ct.TensorType(name=input_desc.name, shape=shape, dtype=np.int32) + ) + 
else: + logger.info(f"Skipping {attention_mask_name} input") + + bbox_desc = input_descs["bbox"] + dummy_input = dummy_inputs["bbox"] + shape = get_shape(config, bbox_desc, dummy_input) + input_types.append( + ct.TensorType(name=bbox_desc.name, shape=shape, dtype=np.int32) + ) + + if hasattr(preprocessor.image_processor, "image_mean"): + bias = [ + -preprocessor.image_processor.image_mean[0], + -preprocessor.image_processor.image_mean[1], + -preprocessor.image_processor.image_mean[2], + ] + else: + bias = [0.0, 0.0, 0.0] + + # If the stddev values are all equal, they can be folded into `bias` and + # `scale`. If not, Wrapper will insert an additional division operation. + if hasattr(preprocessor.image_processor, "image_std") and is_image_std_same(preprocessor.image_processor): + bias[0] /= preprocessor.image_processor.image_std[0] + bias[1] /= preprocessor.image_processor.image_std[1] + bias[2] /= preprocessor.image_processor.image_std[2] + scale = 1.0 / (preprocessor.image_processor.image_std[0] * 255.0) + else: + scale = 1.0 / 255 + + input_desc = input_descs["pixel_values"] + input_types.append( + ct.ImageType( + name=input_desc.name, + shape=dummy_inputs["pixel_values"][0].shape, + scale=scale, + bias=bias, + color_layout=input_desc.color_layout or "RGB", + channel_first=True, + ) + ) + elif config.modality == "audio": if "input_features" in input_descs: input_desc = input_descs["input_features"] diff --git a/src/exporters/coreml/features.py b/src/exporters/coreml/features.py index 0875ec6..f2d1fe6 100644 --- a/src/exporters/coreml/features.py +++ b/src/exporters/coreml/features.py @@ -278,6 +278,10 @@ class FeaturesManager: "text-classification", coreml_config_cls="models.gpt_neox.GPTNeoXCoreMLConfig", ), + "layoutlmv3": supported_features_mapping( + "token-classification", + coreml_config_cls="models.layoutlmv3.LayoutLMv3CoreMLConfig", + ), "levit": supported_features_mapping( "feature-extraction", "image-classification", coreml_config_cls="models.levit.LevitCoreMLConfig" ), diff --git a/src/exporters/coreml/models.py b/src/exporters/coreml/models.py index 15ae53d..555b4c4 100644 --- a/src/exporters/coreml/models.py +++ b/src/exporters/coreml/models.py @@ -348,6 +348,63 @@ def inputs(self) -> OrderedDict[str, InputDescription]: return input_descs +class LayoutLMv3CoreMLConfig(CoreMLConfig): + modality = "multimodal" + + @property + def outputs(self) -> OrderedDict[str, OutputDescription]: + output_descs = super().outputs + self._add_pooler_output(output_descs) + return output_descs + + def patch_pytorch_ops(self): + def clip(context, node): + from coremltools.converters.mil import Builder as mb + from coremltools.converters.mil.frontend.torch.ops import _get_inputs + from coremltools.converters.mil.mil.var import Var + import numpy as _np + from coremltools.converters.mil.mil import types + from coremltools.converters.mil.mil.ops.defs._utils import promote_input_dtypes + inputs = _get_inputs(context, node, expected=[1,2,3]) + x = inputs[0] + min_val = inputs[1] if (len(inputs) > 1 and inputs[1]) else mb.const(val=_np.finfo(_np.float32).min) + max_val = inputs[2] if (len(inputs) > 2 and inputs[2]) else mb.const(val=_np.finfo(_np.float32).max) + + if isinstance(min_val, Var) and isinstance(max_val, Var) and min_val.val >= max_val.val: + # When min >= max, PyTorch sets all values to max. 
+ context.add(mb.fill(shape=mb.shape(x=x), value=max_val.val, name=node.name)) + return + + is_input_int = types.is_int(x.dtype) + if not types.is_float(x.dtype): + # The `mb.clip` op requires parameters from type domain ['fp16', 'fp32']. + x = mb.cast(x=x, dtype="fp32") + x, min_val, max_val = promote_input_dtypes([x, min_val, max_val]) + if is_input_int: + clip_res = mb.clip(x=x, alpha=min_val, beta=max_val) + context.add(mb.cast(x=clip_res, dtype="int32", name=node.name)) + else: + context.add(mb.clip(x=x, alpha=min_val, beta=max_val, name=node.name)) + + def one_hot(context, node): + from coremltools.converters.mil import Builder as mb + from coremltools.converters.mil.frontend.torch.ops import _get_inputs + from coremltools.converters.mil.mil import types + from coremltools.converters.mil.mil.ops.defs._utils import promote_input_dtypes + inputs = _get_inputs(context, node, expected=[2,3]) + indices = inputs[0] + num_classes = inputs[1] + + if not types.is_int(indices.dtype): + indices = mb.cast(x=indices, dtype="int32") + if not types.is_int(num_classes.dtype): + num_classes = mb.cast(x=num_classes, dtype="int32") + indices, num_classes = promote_input_dtypes([indices, num_classes]) + one_hot_res = mb.one_hot(indices=indices, one_hot_vector_size=num_classes) + context.add(one_hot_res, node.name) + + return {"clip": clip, "one_hot": one_hot} + class LevitCoreMLConfig(CoreMLConfig): modality = "vision"
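
A minimal usage sketch for the new export path, following the export() pattern the exporters README shows for other architectures. The checkpoint name, the torchscript=True flag, and the output path below are illustrative assumptions and are not part of this diff; in practice a LayoutLMv3 checkpoint fine-tuned for token classification would be used.

    from transformers import AutoModelForTokenClassification, AutoProcessor

    from exporters.coreml import export
    from exporters.coreml.models import LayoutLMv3CoreMLConfig

    # Illustrative checkpoint; AutoProcessor returns the LayoutLMv3 processor that bundles the
    # tokenizer and image processor which generate_dummy_inputs / get_input_types expect for the
    # new "multimodal" modality.
    model_ckpt = "microsoft/layoutlmv3-base"
    model = AutoModelForTokenClassification.from_pretrained(model_ckpt, torchscript=True)
    preprocessor = AutoProcessor.from_pretrained(model_ckpt)

    # "token-classification" is the feature registered for layoutlmv3 in features.py above.
    coreml_config = LayoutLMv3CoreMLConfig(model.config, task="token-classification")
    mlmodel = export(preprocessor, model, coreml_config)
    mlmodel.save("LayoutLMv3.mlpackage")

The same path should also be reachable through the command-line exporter (roughly `python -m exporters.coreml --model=<checkpoint> --feature=token-classification exported/`), assuming the existing feature lookup picks up the new layoutlmv3 entry; the exact flags are as documented in the exporters README rather than defined by this patch.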