From bcc89fa3133ce9a4d4d2c771b07a6657b882865b Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Sat, 17 Jun 2023 22:05:12 +0200 Subject: [PATCH] [wip] converting to Core ML --- coreml/to-coreml.ipynb | 538 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 538 insertions(+) create mode 100644 coreml/to-coreml.ipynb diff --git a/coreml/to-coreml.ipynb b/coreml/to-coreml.ipynb new file mode 100644 index 00000000..9a49f8db --- /dev/null +++ b/coreml/to-coreml.ipynb @@ -0,0 +1,538 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "35dec937", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import numpy as np\n", + "\n", + "from muse import PipelineMuse" + ] + }, + { + "cell_type": "markdown", + "id": "a1edbc3d", + "metadata": {}, + "source": [ + "We load a proof-of-concept version trained on conceptual captions." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "073adacf", + "metadata": {}, + "outputs": [], + "source": [ + "pipe = PipelineMuse.from_pretrained(\"openMUSE/muse-laiona6-uvit-clip-220k\")" + ] + }, + { + "cell_type": "markdown", + "id": "00d5976a", + "metadata": {}, + "source": [ + "The pipeline contains the following components:\n", + "- `text_encoder`\n", + "- `transformer`\n", + "- `vae`" + ] + }, + { + "cell_type": "markdown", + "id": "7bac4ddc", + "metadata": {}, + "source": [ + "We'll first attempt conversion of the `transformer` component, which is the image generation module." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5309e0b0", + "metadata": {}, + "outputs": [], + "source": [ + "model = pipe.transformer.eval()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b4a9592e", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "scikit-learn version 1.2.2 is not supported. Minimum required version: 0.17. Maximum required version: 1.1.2. Disabling scikit-learn conversion API.\n", + "Torch version 2.0.1+cu117 has not been tested with coremltools. You may run into unexpected errors. Torch 2.0.0 is the most recent version that has been tested.\n" + ] + }, + { + "data": { + "text/plain": [ + "'7.0b1'" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import coremltools as ct\n", + "ct.__version__" + ] + }, + { + "cell_type": "markdown", + "id": "3d819634", + "metadata": {}, + "source": [ + "We can do 6-bit palettization with this version of `coremltools`. We'll convert without it first, and then we'll measure any differences in quality we observe." + ] + }, + { + "cell_type": "markdown", + "id": "b357adf5", + "metadata": {}, + "source": [ + "## Inputs" + ] + }, + { + "cell_type": "markdown", + "id": "df33c57c", + "metadata": {}, + "source": [ + "### Text conditioning" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b766a605", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 77])" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "text_input_ids = pipe.tokenizer(\n", + " \"Labrador in the style of Vermeer\",\n", + " return_tensors=\"pt\",\n", + " padding=\"max_length\",\n", + " truncation=True,\n", + " max_length=pipe.tokenizer.model_max_length,\n", + ").input_ids\n", + "text_input_ids.shape" + ] + }, + { + "cell_type": "markdown", + "id": "f686a36a", + "metadata": {}, + "source": [ + "Like in Stable Diffusion, note that we are _not_ using attention masks." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6a4c480c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 77, 768])" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "encoder_hidden_states = pipe.text_encoder(text_input_ids).last_hidden_state\n", + "encoder_hidden_states.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "076047f9", + "metadata": {}, + "outputs": [], + "source": [ + "negative_input_ids = pipe.tokenizer(\n", + " \"ugly, bad anatomy\",\n", + " return_tensors=\"pt\",\n", + " padding=\"max_length\",\n", + " truncation=True,\n", + " max_length=pipe.tokenizer.model_max_length,\n", + ").input_ids\n", + "negative_encoder_hidden_states = pipe.text_encoder(negative_input_ids).last_hidden_state" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d2fdc1bb", + "metadata": {}, + "outputs": [], + "source": [ + "bs = 2 # cfg" + ] + }, + { + "cell_type": "markdown", + "id": "77a60e2e", + "metadata": {}, + "source": [ + "Conditioning (encoder_hidden_states and negative_encoder_hidden_states). We could use `np.random.normal` but I'm not sure what the distribution is, so let's just use the previous examples." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "91a8664f", + "metadata": {}, + "outputs": [], + "source": [ + "sequence_length = pipe.tokenizer.model_max_length # 77\n", + "embed_size = pipe.text_encoder.config.hidden_size # 768\n", + "\n", + "conditioning_shape = (bs, sequence_length, embed_size)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "167ef635", + "metadata": {}, + "outputs": [], + "source": [ + "conditioning = np.concatenate((encoder_hidden_states.detach().numpy(), negative_encoder_hidden_states.detach().numpy()))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f0d5b9d6", + "metadata": {}, + "outputs": [], + "source": [ + "assert conditioning.shape == conditioning_shape" + ] + }, + { + "cell_type": "markdown", + "id": "b4bc7470", + "metadata": {}, + "source": [ + "### Image input" + ] + }, + { + "cell_type": "markdown", + "id": "1fbaf005", + "metadata": {}, + "source": [ + "Image input token ids. Each image is made of `model.config.num_vq_tokens` (256) tokens taken from a codebook of size `codebook_size` (8192)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1c5cd771", + "metadata": {}, + "outputs": [], + "source": [ + "input_ids_shape = (bs, model.config.num_vq_tokens)\n", + "input_ids = np.random.randint(0, model.config.codebook_size, input_ids_shape)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ca03c4e9", + "metadata": {}, + "outputs": [], + "source": [ + "inputs = {\n", + " \"input_ids\": input_ids,\n", + " \"encoder_hidden_states\": conditioning,\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "3b28ef8e", + "metadata": {}, + "source": [ + "### Model output (single step)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "58cc9c92", + "metadata": {}, + "outputs": [], + "source": [ + "t_inputs = {\n", + " \"input_ids\": torch.tensor(input_ids, dtype=torch.int32),\n", + " \"encoder_hidden_states\": torch.tensor(conditioning),\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e633fc7c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([2, 256, 8192])" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "outputs = model(**t_inputs)\n", + "outputs.shape" + ] + }, + { + "cell_type": "markdown", + "id": "a56b4465", + "metadata": {}, + "source": [ + "Outputs are: `cond_logits`, `uncond_logits`.\n", + "\n", + "**TODO** We could chunk them here for convenience. We could also apply some more post-processing inside a model wrapper." + ] + }, + { + "cell_type": "markdown", + "id": "a52aa476", + "metadata": {}, + "source": [ + "## JIT" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "079c8405", + "metadata": {}, + "outputs": [], + "source": [ + "jit_inputs = list(t_inputs.values())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3874753d", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/pedro/code/hf/muse/open-muse/muse/modeling_transformer.py:598: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n", + " height, width = int(seq_length**0.5), int(seq_length**0.5)\n", + "/home/pedro/code/hf/muse/open-muse/muse/modeling_transformer.py:636: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n", + " height, width = int(seq_length**0.5), int(seq_length**0.5)\n" + ] + } + ], + "source": [ + "jitted_model = torch.jit.trace(model, jit_inputs)\n", + "jitted_model.eval();" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "055b690a", + "metadata": {}, + "outputs": [], + "source": [ + "with torch.no_grad():\n", + " output_jit = jitted_model(*jit_inputs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "661fc6c0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(8.3923e-05, grad_fn=)" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "(output_jit - outputs).abs().max()" + ] + }, + { + "cell_type": "markdown", + "id": "1cfc9927", + "metadata": {}, + "source": [ + "Close enough." + ] + }, + { + "cell_type": "markdown", + "id": "abc31bad", + "metadata": {}, + "source": [ + "## Core ML" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bcfabce6", + "metadata": {}, + "outputs": [], + "source": [ + "def _get_coreml_inputs(sample_inputs):\n", + " return [\n", + " ct.TensorType(\n", + " name=k,\n", + " shape=v.shape,\n", + " dtype=v.numpy().dtype if isinstance(v, torch.Tensor) else v.dtype,\n", + " ) for k, v in sample_inputs.items()\n", + " ]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2c1a65f7", + "metadata": {}, + "outputs": [], + "source": [ + "coreml_input_types = _get_coreml_inputs(t_inputs)\n", + "coreml_output_types = [ct.TensorType(name=\"logits\")] # Update when chunking/post-processing" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e274db50", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Converting PyTorch Frontend ==> MIL Ops: 100%|█████████████████████████████████████████████████████████████████████████████████████████████▉| 3810/3811 [00:01<00:00, 2142.19 ops/s]\n", + "Running MIL frontend_pytorch pipeline: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 16.02 passes/s]\n", + "Running MIL default pipeline: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 64/64 [00:58<00:00, 1.09 passes/s]\n", + "Running MIL backend_mlprogram pipeline: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 100.53 passes/s]\n" + ] + } + ], + "source": [ + "coreml_model = ct.convert(\n", + " jitted_model,\n", + " convert_to = \"mlprogram\",\n", + " minimum_deployment_target = ct.target.macOS13,\n", + " inputs = coreml_input_types,\n", + " outputs = coreml_output_types,\n", + " compute_precision = ct.precision.FLOAT16,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "78d4862f", + "metadata": {}, + "outputs": [], + "source": [ + "coreml_model.save(\"muse_transformer.mlpackage\")" + ] + }, + { + "cell_type": "markdown", + "id": "2e93d45e", + "metadata": {}, + "source": [ + "### Inference" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ffe8f2a6", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "1cfc866b", + "metadata": {}, + "source": [ + "## TODO\n", + "\n", + "- Verify inference on Mac\n", + "- Convert text encoder, VAE\n", + "- Python pipeline\n", + "- Swift pipeline\n", + "- Palettization" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3455f01c", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "python3", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}