|
5 | 5 | "colab": {
|
6 | 6 | "name": "Understanding BeitForMaskedImageModeling.ipynb",
|
7 | 7 | "provenance": [],
|
8 |
| - "collapsed_sections": [], |
9 |
| - "authorship_tag": "ABX9TyML+527/GMXCF12tSDUgQFX", |
| 8 | + "authorship_tag": "ABX9TyMNKksEST+khtV9qo1CbZT9", |
10 | 9 | "include_colab_link": true
|
11 | 10 | },
|
12 | 11 | "kernelspec": {
|
|
737 | 736 | "source": [
|
738 | 737 | "!pip install -q transformers"
|
739 | 738 | ],
|
740 |
| - "execution_count": 8, |
| 739 | + "execution_count": null, |
741 | 740 | "outputs": [
|
742 | 741 | {
|
743 | 742 | "output_type": "stream",
|
|
765 | 764 | "source": [
|
766 | 765 | "!git clone https://github.com/microsoft/unilm.git"
|
767 | 766 | ],
|
768 |
| - "execution_count": 2, |
| 767 | + "execution_count": null, |
769 | 768 | "outputs": [
|
770 | 769 | {
|
771 | 770 | "output_type": "stream",
|
|
790 | 789 | "source": [
|
791 | 790 | "!pip install -q einops"
|
792 | 791 | ],
|
793 |
| - "execution_count": 3, |
| 792 | + "execution_count": null, |
794 | 793 | "outputs": []
|
795 | 794 | },
|
796 | 795 | {
|
|
805 | 804 | "source": [
|
806 | 805 | "!pip install -q DALL-E"
|
807 | 806 | ],
|
808 |
| - "execution_count": 4, |
| 807 | + "execution_count": null, |
809 | 808 | "outputs": [
|
810 | 809 | {
|
811 | 810 | "output_type": "stream",
|
|
836 | 835 | "source": [
|
837 | 836 | "%cd unilm/beit"
|
838 | 837 | ],
|
839 |
| - "execution_count": 5, |
| 838 | + "execution_count": null, |
840 | 839 | "outputs": [
|
841 | 840 | {
|
842 | 841 | "output_type": "stream",
|
|
877 | 876 | "\n",
|
878 | 877 | "image"
|
879 | 878 | ],
|
880 |
| - "execution_count": 6, |
| 879 | + "execution_count": null, |
881 | 880 | "outputs": [
|
882 | 881 | {
|
883 | 882 | "output_type": "execute_result",
|
|
917 | 916 | "outputId": "02c13ee1-11ba-4157-f15d-2d23162b66e7"
|
918 | 917 | },
|
919 | 918 | "source": [
|
920 |
| - "from transformers import BeitFeatureExtractor\n", |
| 919 | + "from transformers import BeitImageProcessor\n", |
921 | 920 | "\n",
|
922 |
| - "feature_extractor = BeitFeatureExtractor()\n", |
| 921 | + "image_processor = BeitImageProcessor()\n", |
923 | 922 | "\n",
|
924 | 923 | "# create input 1 (pixel_values)\n",
|
925 |
| - "pixel_values = feature_extractor(image, return_tensors=\"pt\").pixel_values\n", |
| 924 | + "pixel_values = image_processor(image, return_tensors=\"pt\").pixel_values\n", |
926 | 925 | "pixel_values.shape"
|
927 | 926 | ],
|
928 |
| - "execution_count": 9, |
| 927 | + "execution_count": null, |
929 | 928 | "outputs": [
|
930 | 929 | {
|
931 | 930 | "output_type": "execute_result",
|
|
961 | 960 | "pixel_values_dall_e = visual_token_transform(image).unsqueeze(0)\n",
|
962 | 961 | "pixel_values_dall_e.shape"
|
963 | 962 | ],
|
964 |
| - "execution_count": 10, |
| 963 | + "execution_count": null, |
965 | 964 | "outputs": [
|
966 | 965 | {
|
967 | 966 | "output_type": "execute_result",
|
|
985 | 984 | "!mkdir -p dall_e_tokenizer\n",
|
986 | 985 | "!wget -o dall_e_tokenizer/encoder.pkl https://cdn.openai.com/dall-e/encoder.pkl"
|
987 | 986 | ],
|
988 |
| - "execution_count": 11, |
| 987 | + "execution_count": null, |
989 | 988 | "outputs": []
|
990 | 989 | },
|
991 | 990 | {
|
|
1027 | 1026 | "\n",
|
1028 | 1027 | "model = BeitForMaskedImageModeling.from_pretrained(\"microsoft/beit-base-patch16-224-pt22k\")"
|
1029 | 1028 | ],
|
1030 |
| - "execution_count": 12, |
| 1029 | + "execution_count": null, |
1031 | 1030 | "outputs": [
|
1032 | 1031 | {
|
1033 | 1032 | "output_type": "display_data",
|
|
1079 | 1078 | " min_num_patches=min_mask_patches_per_block,\n",
|
1080 | 1079 | " )"
|
1081 | 1080 | ],
|
1082 |
| - "execution_count": 13, |
| 1081 | + "execution_count": null, |
1083 | 1082 | "outputs": []
|
1084 | 1083 | },
|
1085 | 1084 | {
|
|
1094 | 1093 | "bool_masked_pos = mask_generator()\n",
|
1095 | 1094 | "bool_masked_pos = torch.from_numpy(bool_masked_pos).unsqueeze(0)"
|
1096 | 1095 | ],
|
1097 |
| - "execution_count": 34, |
| 1096 | + "execution_count": null, |
1098 | 1097 | "outputs": []
|
1099 | 1098 | },
|
1100 | 1099 | {
|
|
1109 | 1108 | "source": [
|
1110 | 1109 | "bool_masked_pos.shape"
|
1111 | 1110 | ],
|
1112 |
| - "execution_count": 35, |
| 1111 | + "execution_count": null, |
1113 | 1112 | "outputs": [
|
1114 | 1113 | {
|
1115 | 1114 | "output_type": "execute_result",
|
|
1129 | 1128 | "id": "fFad_m_s41Ru"
|
1130 | 1129 | },
|
1131 | 1130 | "source": [
|
1132 |
| - "from dall_e import map_pixels, load_model \n", |
| 1131 | + "from dall_e import map_pixels, load_model\n", |
1133 | 1132 | "import torch\n",
|
1134 | 1133 | "\n",
|
1135 | 1134 | "# step 2: get input_ids from OpenAI's DALL-E\n",
|
1136 | 1135 | "device = torch.device('cpu')\n",
|
1137 | 1136 | "encoder = load_model(\"https://cdn.openai.com/dall-e/encoder.pkl\", device)"
|
1138 | 1137 | ],
|
1139 |
| - "execution_count": 16, |
| 1138 | + "execution_count": null, |
1140 | 1139 | "outputs": []
|
1141 | 1140 | },
|
1142 | 1141 | {
|
|
1151 | 1150 | "source": [
|
1152 | 1151 | "pixel_values_dall_e.shape"
|
1153 | 1152 | ],
|
1154 |
| - "execution_count": 17, |
| 1153 | + "execution_count": null, |
1155 | 1154 | "outputs": [
|
1156 | 1155 | {
|
1157 | 1156 | "output_type": "execute_result",
|
|
1177 | 1176 | " bool_masked_pos = bool_masked_pos.flatten(1).to(torch.bool)\n",
|
1178 | 1177 | " labels = input_ids[bool_masked_pos]"
|
1179 | 1178 | ],
|
1180 |
| - "execution_count": 37, |
| 1179 | + "execution_count": null, |
1181 | 1180 | "outputs": []
|
1182 | 1181 | },
|
1183 | 1182 | {
|
|
1192 | 1191 | "source": [
|
1193 | 1192 | "input_ids.shape"
|
1194 | 1193 | ],
|
1195 |
| - "execution_count": 38, |
| 1194 | + "execution_count": null, |
1196 | 1195 | "outputs": [
|
1197 | 1196 | {
|
1198 | 1197 | "output_type": "execute_result",
|
|
1218 | 1217 | "source": [
|
1219 | 1218 | "labels.shape"
|
1220 | 1219 | ],
|
1221 |
| - "execution_count": 39, |
| 1220 | + "execution_count": null, |
1222 | 1221 | "outputs": [
|
1223 | 1222 | {
|
1224 | 1223 | "output_type": "execute_result",
|
|
1249 | 1248 | "source": [
|
1250 | 1249 | "outputs = model(pixel_values, bool_masked_pos)"
|
1251 | 1250 | ],
|
1252 |
| - "execution_count": 40, |
| 1251 | + "execution_count": null, |
1253 | 1252 | "outputs": []
|
1254 | 1253 | },
|
1255 | 1254 | {
|
|
1264 | 1263 | "source": [
|
1265 | 1264 | "labels"
|
1266 | 1265 | ],
|
1267 |
| - "execution_count": 41, |
| 1266 | + "execution_count": null, |
1268 | 1267 | "outputs": [
|
1269 | 1268 | {
|
1270 | 1269 | "output_type": "execute_result",
|
|
1292 | 1291 | "source": [
|
1293 | 1292 | "predictions = outputs.logits[bool_masked_pos].argmax(-1)"
|
1294 | 1293 | ],
|
1295 |
| - "execution_count": 42, |
| 1294 | + "execution_count": null, |
1296 | 1295 | "outputs": []
|
1297 | 1296 | },
|
1298 | 1297 | {
|
|
1307 | 1306 | "source": [
|
1308 | 1307 | "predictions"
|
1309 | 1308 | ],
|
1310 |
| - "execution_count": 43, |
| 1309 | + "execution_count": null, |
1311 | 1310 | "outputs": [
|
1312 | 1311 | {
|
1313 | 1312 | "output_type": "execute_result",
|
|
1356 | 1355 | "# prepare for model (simply resize + normalize)\n",
|
1357 | 1356 | "mean = (0.5, 0.5, 0.5)\n",
|
1358 | 1357 | "std = (0.5, 0.5, 0.5)\n",
|
1359 |
| - "transform = transforms.Compose([transforms.Resize((224, 224)), \n", |
1360 |
| - " transforms.ToTensor(), \n", |
| 1358 | + "transform = transforms.Compose([transforms.Resize((224, 224)),\n", |
| 1359 | + " transforms.ToTensor(),\n", |
1361 | 1360 | " transforms.Normalize(mean, std)])\n",
|
1362 | 1361 | "pixel_values = transform(image).unsqueeze(0)\n",
|
1363 | 1362 | "pixel_values.shape"
|
1364 | 1363 | ],
|
1365 |
| - "execution_count": 25, |
| 1364 | + "execution_count": null, |
1366 | 1365 | "outputs": [
|
1367 | 1366 | {
|
1368 | 1367 | "output_type": "execute_result",
|
|
1388 | 1387 | "source": [
|
1389 | 1388 | "pixel_values[0,:3,:3,:3]"
|
1390 | 1389 | ],
|
1391 |
| - "execution_count": 26, |
| 1390 | + "execution_count": null, |
1392 | 1391 | "outputs": [
|
1393 | 1392 | {
|
1394 | 1393 | "output_type": "execute_result",
|
|
1426 | 1425 | "# forward pass\n",
|
1427 | 1426 | "outputs = model(pixel_values, bool_masked_pos)"
|
1428 | 1427 | ],
|
1429 |
| - "execution_count": 27, |
| 1428 | + "execution_count": null, |
1430 | 1429 | "outputs": []
|
1431 | 1430 | },
|
1432 | 1431 | {
|
|
1441 | 1440 | "source": [
|
1442 | 1441 | "outputs.logits.shape"
|
1443 | 1442 | ],
|
1444 |
| - "execution_count": 28, |
| 1443 | + "execution_count": null, |
1445 | 1444 | "outputs": [
|
1446 | 1445 | {
|
1447 | 1446 | "output_type": "execute_result",
|
|
1467 | 1466 | "source": [
|
1468 | 1467 | "outputs.logits[bool_masked_pos][:3,:3]"
|
1469 | 1468 | ],
|
1470 |
| - "execution_count": 29, |
| 1469 | + "execution_count": null, |
1471 | 1470 | "outputs": [
|
1472 | 1471 | {
|
1473 | 1472 | "output_type": "execute_result",
|
|
1488 | 1487 | "metadata": {
|
1489 | 1488 | "id": "D-8mymKCOQZK"
|
1490 | 1489 | },
|
1491 |
| - "source": [ |
1492 |
| - "" |
1493 |
| - ], |
| 1490 | + "source": [], |
1494 | 1491 | "execution_count": null,
|
1495 | 1492 | "outputs": []
|
1496 | 1493 | }
|
|
0 commit comments