[WIP] Added support for converting LayoutLMv3 to CoreML
salmanmaq committed Jun 10, 2024
1 parent 7a54597 commit f1135dc
Showing 4 changed files with 178 additions and 0 deletions.
60 changes: 60 additions & 0 deletions src/exporters/coreml/config.py
@@ -327,6 +327,40 @@ def _input_descriptions(self) -> "OrderedDict[str, InputDescription]":
),
]
)

if self.modality == "multimodal" and self.task in [
"token-classification",
]:
return OrderedDict(
[
(
"input_ids",
InputDescription(
"input_ids",
"Indices of input sequence tokens in the vocabulary",
sequence_length=self.input_ids_sequence_length,
)
),
(
"bbox",
                        InputDescription(
                            "bbox",
                            "Bounding box coordinates for each input token, normalized to the 0-1000 range",
                        )
),
(
"attention_mask",
InputDescription(
"attention_mask",
"Mask to avoid performing attention on padding token indices (1 = not masked, 0 = masked)",
)
),
(
"pixel_values",
InputDescription("image", "Input image", color_layout="RGB")
),
]
)

if self.task == "image-classification":
return OrderedDict(
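
For context, a minimal sketch of how these four inputs are typically produced, assuming the standard transformers LayoutLMv3Processor API and microsoft/layoutlmv3-base as an illustrative checkpoint:

from PIL import Image
from transformers import LayoutLMv3Processor

# apply_ocr=False because words and boxes are supplied manually below
processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)

image = Image.new("RGB", (224, 224))
words = ["hello", "world"]
boxes = [[10, 10, 50, 20], [60, 10, 110, 20]]  # 0-1000 normalized coordinates

encoding = processor(image, words, boxes=boxes, return_tensors="np")
# yields exactly the inputs described above: input_ids, attention_mask, bbox, pixel_values
print({k: v.shape for k, v in encoding.items()})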
@@ -931,6 +965,32 @@ def generate_dummy_inputs(
bool_masked_pos = np.random.randint(low=0, high=2, size=(1, num_patches)).astype(bool)
dummy_inputs["bool_masked_pos"] = (bool_masked_pos, bool_masked_pos.astype(np.int32))

elif self.modality == "multimodal":
input_ids_name = "input_ids"
attention_mask_name = "attention_mask"

input_desc = input_descs[input_ids_name]

# the dummy input will always use the maximum sequence length
sequence_length = self._get_max_sequence_length(input_desc, 64)

shape = (batch_size, sequence_length)

input_ids = np.random.randint(0, preprocessor.tokenizer.vocab_size, shape)
dummy_inputs[input_ids_name] = (input_ids, input_ids.astype(np.int32))

if attention_mask_name in input_descs:
attention_mask = np.ones(shape, dtype=np.int64)
dummy_inputs[attention_mask_name] = (attention_mask, attention_mask.astype(np.int32))

bbox_shape = (batch_size, sequence_length, 4)
bboxes = np.random.randint(low=0, high=1000, size=bbox_shape).astype(np.int64)
dummy_inputs["bbox"] = (bboxes, bboxes.astype(np.int32))

dummy_inputs["pixel_values"] = self._generate_dummy_image(preprocessor.image_processor, framework)

elif self.modality == "audio" and isinstance(preprocessor, ProcessorMixin):
if self.seq2seq != "decoder":
if "input_features" in input_descs:
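Concretely, with the default batch size of 1 and the 64-token fallback sequence length, the dummy inputs above come out with the following shapes (a sketch; the vocabulary size is illustrative):

import numpy as np

batch_size, sequence_length = 1, 64  # fallback passed to _get_max_sequence_length above
input_ids = np.random.randint(0, 50265, (batch_size, sequence_length))  # vocab size illustrative
attention_mask = np.ones((batch_size, sequence_length), dtype=np.int64)
bbox = np.random.randint(0, 1000, (batch_size, sequence_length, 4))  # (x0, y0, x1, y1) in 0-1000
print(input_ids.shape, attention_mask.shape, bbox.shape)  # (1, 64) (1, 64) (1, 64, 4)
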
57 changes: 57 additions & 0 deletions src/exporters/coreml/convert.py
@@ -254,6 +254,63 @@ def get_input_types(
)
)

elif config.modality == "multimodal":
input_ids_name = "input_ids"
attention_mask_name = "attention_mask"

input_desc = input_descs[input_ids_name]
dummy_input = dummy_inputs[input_ids_name]
shape = get_shape(config, input_desc, dummy_input)
input_types.append(
ct.TensorType(name=input_desc.name, shape=shape, dtype=np.int32)
)

        if attention_mask_name in input_descs:
            input_desc = input_descs[attention_mask_name]
            # the attention mask shares the (batch, sequence_length) shape of input_ids
            input_types.append(
                ct.TensorType(name=input_desc.name, shape=shape, dtype=np.int32)
            )
else:
logger.info(f"Skipping {attention_mask_name} input")

bbox_desc = input_descs["bbox"]
dummy_input = dummy_inputs["bbox"]
shape = get_shape(config, bbox_desc, dummy_input)
input_types.append(
ct.TensorType(name=bbox_desc.name, shape=shape, dtype=np.int32)
)

if hasattr(preprocessor.image_processor, "image_mean"):
bias = [
-preprocessor.image_processor.image_mean[0],
-preprocessor.image_processor.image_mean[1],
-preprocessor.image_processor.image_mean[2],
]
else:
bias = [0.0, 0.0, 0.0]

        # If the stddev values are all equal, they can be folded into `bias` and
        # `scale`. If not, the `Wrapper` model will insert an additional division operation.
if hasattr(preprocessor.image_processor, "image_std") and is_image_std_same(preprocessor.image_processor):
bias[0] /= preprocessor.image_processor.image_std[0]
bias[1] /= preprocessor.image_processor.image_std[1]
bias[2] /= preprocessor.image_processor.image_std[2]
scale = 1.0 / (preprocessor.image_processor.image_std[0] * 255.0)
else:
scale = 1.0 / 255

input_desc = input_descs["pixel_values"]
input_types.append(
ct.ImageType(
name=input_desc.name,
shape=dummy_inputs["pixel_values"][0].shape,
scale=scale,
bias=bias,
color_layout=input_desc.color_layout or "RGB",
channel_first=True,
)
)

elif config.modality == "audio":
if "input_features" in input_descs:
input_desc = input_descs["input_features"]
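The folding above works because CoreML's ImageType applies y = scale * x + bias per channel, so the usual (x / 255 - mean) / std normalization collapses to scale = 1 / (std * 255) and bias = -mean / std whenever all channels share one std. A quick check with illustrative values:

image_mean = [0.485, 0.456, 0.406]  # illustrative values
image_std = [0.229, 0.229, 0.229]   # equal across channels, so the folding applies

scale = 1.0 / (image_std[0] * 255.0)
bias = [-m / s for m, s in zip(image_mean, image_std)]

x = 128.0  # one raw pixel value
assert abs((scale * x + bias[0]) - ((x / 255.0 - image_mean[0]) / image_std[0])) < 1e-9
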
4 changes: 4 additions & 0 deletions src/exporters/coreml/features.py
@@ -278,6 +278,10 @@ class FeaturesManager:
"text-classification",
coreml_config_cls="models.gpt_neox.GPTNeoXCoreMLConfig",
),
"layoutlmv3": supported_features_mapping(
"token-classification",
coreml_config_cls="models.layoutlmv3.LayoutLMv3CoreMLConfig",
),
"levit": supported_features_mapping(
"feature-extraction", "image-classification", coreml_config_cls="models.levit.LevitCoreMLConfig"
),
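As a quick sanity check, and assuming FeaturesManager keeps the lookup helpers of the transformers.onnx design it derives from, the new entry should be discoverable like so:

from exporters.coreml.features import FeaturesManager

# expected to report ("token-classification",) per the mapping above
print(FeaturesManager.get_supported_features_for_model_type("layoutlmv3"))
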
57 changes: 57 additions & 0 deletions src/exporters/coreml/models.py
@@ -348,6 +348,63 @@ def inputs(self) -> OrderedDict[str, InputDescription]:
return input_descs


class LayoutLMv3CoreMLConfig(CoreMLConfig):
modality = "multimodal"

@property
def outputs(self) -> OrderedDict[str, OutputDescription]:
output_descs = super().outputs
self._add_pooler_output(output_descs)
return output_descs

def patch_pytorch_ops(self):
        def clip(context, node):
            # overrides the default torch clip conversion so integer inputs are
            # cast to fp32 for mb.clip (which only accepts float types) and back
from coremltools.converters.mil import Builder as mb
from coremltools.converters.mil.frontend.torch.ops import _get_inputs
from coremltools.converters.mil.mil.var import Var
import numpy as _np
from coremltools.converters.mil.mil import types
from coremltools.converters.mil.mil.ops.defs._utils import promote_input_dtypes
inputs = _get_inputs(context, node, expected=[1,2,3])
x = inputs[0]
min_val = inputs[1] if (len(inputs) > 1 and inputs[1]) else mb.const(val=_np.finfo(_np.float32).min)
max_val = inputs[2] if (len(inputs) > 2 and inputs[2]) else mb.const(val=_np.finfo(_np.float32).max)

if isinstance(min_val, Var) and isinstance(max_val, Var) and min_val.val >= max_val.val:
# When min >= max, PyTorch sets all values to max.
context.add(mb.fill(shape=mb.shape(x=x), value=max_val.val, name=node.name))
return

is_input_int = types.is_int(x.dtype)
if not types.is_float(x.dtype):
# The `mb.clip` op requires parameters from type domain ['fp16', 'fp32'].
x = mb.cast(x=x, dtype="fp32")
x, min_val, max_val = promote_input_dtypes([x, min_val, max_val])
if is_input_int:
clip_res = mb.clip(x=x, alpha=min_val, beta=max_val)
context.add(mb.cast(x=clip_res, dtype="int32", name=node.name))
else:
context.add(mb.clip(x=x, alpha=min_val, beta=max_val, name=node.name))

        def one_hot(context, node):
            # converts torch's one_hot via mb.one_hot, casting the indices and
            # the class count to int32 first
from coremltools.converters.mil import Builder as mb
from coremltools.converters.mil.frontend.torch.ops import _get_inputs
from coremltools.converters.mil.mil import types
from coremltools.converters.mil.mil.ops.defs._utils import promote_input_dtypes
inputs = _get_inputs(context, node, expected=[2,3])
indices = inputs[0]
num_classes = inputs[1]

if not types.is_int(indices.dtype):
indices = mb.cast(x=indices, dtype="int32")
if not types.is_int(num_classes.dtype):
num_classes = mb.cast(x=num_classes, dtype="int32")
indices, num_classes = promote_input_dtypes([indices, num_classes])
one_hot_res = mb.one_hot(indices=indices, one_hot_vector_size=num_classes)
context.add(one_hot_res, node.name)

return {"clip": clip, "one_hot": one_hot}

class LevitCoreMLConfig(CoreMLConfig):
modality = "vision"

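Putting it all together, an end-to-end export could then follow the same pattern the exporters README shows for other models. A sketch, with the checkpoint name illustrative and the export call assumed to match that pattern:

from transformers import LayoutLMv3ForTokenClassification, LayoutLMv3Processor
from exporters.coreml import export
from exporters.coreml.models import LayoutLMv3CoreMLConfig

model_ckpt = "microsoft/layoutlmv3-base"  # illustrative checkpoint
model = LayoutLMv3ForTokenClassification.from_pretrained(model_ckpt, torchscript=True)
preprocessor = LayoutLMv3Processor.from_pretrained(model_ckpt, apply_ocr=False)

coreml_config = LayoutLMv3CoreMLConfig(model.config, task="token-classification")
mlmodel = export(preprocessor, model, coreml_config)
mlmodel.save("LayoutLMv3.mlpackage")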
