
Commit 8639a80

authored

beginner_source/flava_finetuning_tutorial.py translation (#778)

* beginner_source/nn_tutorial.py translation
1 parent dfc7f1e commit 8639a80

1 file changed: 48 additions & 60 deletions
@@ -1,36 +1,37 @@
 # -*- coding: utf-8 -*-
 """
-TorchMultimodal Tutorial: Finetuning FLAVA
+TorchMultimodal Tutorial: FLAVA Finetuning
 ============================================
+
+**Translation:** `Kim Chan <https://github.com/chanmuzi>`__
+
 """
 
+
 ######################################################################
-# Multimodal AI has recently become very popular owing to its ubiquitous
-# nature, from use cases like image captioning and visual search to more
-# recent applications like image generation from text. **TorchMultimodal
-# is a library powered by Pytorch consisting of building blocks and end to
-# end examples, aiming to enable and accelerate research in
-# multimodality**.
-#
-# In this tutorial, we will demonstrate how to use a **pretrained SoTA
-# model called** `FLAVA <https://arxiv.org/pdf/2112.04482.pdf>`__ **from
-# TorchMultimodal library to finetune on a multimodal task i.e. visual
-# question answering** (VQA). The model consists of two unimodal transformer
-# based encoders for text and image and a multimodal encoder to combine
-# the two embeddings. It is pretrained using contrastive, image text matching and
-# text, image and multimodal masking losses.
+# Multimodal AI has recently spread rapidly in use, from applications like
+# image captioning and visual search to more recent ones like generating
+# images from text. **TorchMultimodal is a PyTorch-based library that
+# provides building blocks and end-to-end examples to enable and
+# accelerate multimodal research**.
+#
+# In this tutorial, we show how to use **a pretrained SoTA model,**
+# `FLAVA <https://arxiv.org/pdf/2112.04482.pdf>`__ **, from the TorchMultimodal
+# library to finetune on a multimodal task, namely visual question answering (VQA).**
+# The model consists of two unimodal transformer-based encoders for text
+# and image, and a multimodal encoder that combines the two embeddings.
+# It is pretrained using contrastive, image-text matching, and text, image,
+# and multimodal masking losses.
+
 
 
 ######################################################################
-# Installation
+# Installation
 # -----------------
-# We will use TextVQA dataset and ``bert tokenizer`` from Hugging Face for this
-# tutorial. So you need to install datasets and transformers in addition to TorchMultimodal.
+# For this tutorial, we will use the TextVQA dataset and the ``bert tokenizer`` from Hugging Face,
+# so you need to install datasets and transformers in addition to TorchMultimodal.
 #
 # .. note::
-#
-#    When running this tutorial in Google Colab, install the required packages by
-#    creating a new cell and running the following commands:
+#
+#    When running this tutorial in Google Colab, create a new cell and run the
+#    following commands to install the required packages:
 #
 # .. code-block::
 #
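The body of this code block falls outside the hunk shown above. For reference, a minimal sketch of the Colab install cell, assuming the package names used by the published tutorial (``torchmultimodal-nightly`` from PyPI; versions may differ):

    !pip install torchmultimodal-nightly
    !pip install datasets
    !pip install transformers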
@@ -40,32 +41,27 @@
 #
 
 ######################################################################
-# Steps
+# Steps
 # -----
 #
-# 1. Download the Hugging Face dataset to a directory on your computer by running the following command:
+# 1. Download the Hugging Face dataset to a directory on your computer by running the following command:
 #
 # .. code-block::
 #
 #    wget http://dl.fbaipublicfiles.com/pythia/data/vocab.tar.gz
 #    tar xf vocab.tar.gz
 #
 # .. note::
-#    If you are running this tutorial in Google Colab, run these commands
-#    in a new cell and prepend these commands with an exclamation mark (!)
+#    If you are running this tutorial in Google Colab, run these commands in a new cell and prepend them with an exclamation mark (!).
 #
 #
-# 2. For this tutorial, we treat VQA as a classification task where
-#    the inputs are images and question (text) and the output is an answer class.
-#    So we need to download the vocab file with answer classes and create the answer to
-#    label mapping.
+# 2. In this tutorial, we treat VQA as a classification task where the inputs are an image and a question (text) and the output is an answer class.
+#    So we need to download the vocab file with the answer classes and create the answer-to-label mapping.
 #
-# We also load the `textvqa
-# dataset <https://arxiv.org/pdf/1904.08920.pdf>`__ containing 34602 training samples
-# (images,questions and answers) from Hugging Face
+# We also load the `textvqa dataset <https://arxiv.org/pdf/1904.08920.pdf>`__ from Hugging Face,
+# which contains 34602 training samples (images, questions, and answers).
 #
-# We see there are 3997 answer classes including a class representing
-# unknown answers.
+# We can see there are 3997 answer classes, including a class representing unknown answers.
 #
 
 with open("data/vocabs/answers_textvqa_more_than_1.txt") as f:
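The diff truncates the rest of this cell, which builds the answer-to-label mapping described in step 2. A minimal sketch, assuming the vocab file unpacked in step 1; the ``answer_to_idx`` name follows the published tutorial:

    with open("data/vocabs/answers_textvqa_more_than_1.txt") as f:
        vocab = f.readlines()

    # One answer class per line; map each answer string to an integer label
    answer_to_idx = {}
    for idx, entry in enumerate(vocab):
        answer_to_idx[entry.strip("\n")] = idx

    print(len(vocab))                       # expected: 3997 answer classes
    print(list(answer_to_idx.items())[:5])  # peek at the first few mappings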
@@ -81,7 +77,7 @@
 dataset = load_dataset("textvqa")
 
 ######################################################################
-# Lets display a sample entry from the dataset:
+# Let's display a sample entry from the dataset:
 #
 
 import matplotlib.pyplot as plt
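The display cell itself is cut off by the hunk boundary. A minimal sketch, assuming the Hugging Face ``textvqa`` dataset loaded above exposes ``image``, ``question``, and ``answers`` fields:

    import matplotlib.pyplot as plt

    idx = 5  # arbitrary sample index for illustration
    print("Question:", dataset["train"][idx]["question"])
    print("Answers:", dataset["train"][idx]["answers"])
    plt.imshow(dataset["train"][idx]["image"])
    plt.axis("off")
    plt.show()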
@@ -95,12 +91,10 @@
 
 
 ######################################################################
-# 3. Next, we write the transform function to convert the image and text into
-#    Tensors consumable by our model - For images, we use the transforms from
-#    torchvision to convert to Tensor and resize to uniform sizes - For text,
-#    we tokenize (and pad) them using the ``BertTokenizer`` from Hugging Face -
-#    For answers (i.e. labels), we take the most frequently occurring answer
-#    as the label to train with:
+# 3. Next, we write the transform function to convert the image and text into tensors our model can consume.
+#    - For images, we use the transforms from torchvision to convert them to tensors and resize them to a uniform size.
+#    - For text, we tokenize (and pad) it using ``BertTokenizer`` from Hugging Face.
+#    - For answers (i.e. labels), we use the most frequently occurring answer as the training label:
 #
 
 import torch
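The transform function itself is elided by the diff. A sketch consistent with the three bullets above; the field names, the 224x224 resize, the 512-token padding, and ``answer_to_idx`` (from the earlier sketch) follow the published tutorial but should be treated as assumptions here:

    import torch
    from collections import defaultdict
    from functools import partial
    from torchvision import transforms
    from transformers import BertTokenizer

    def transform(tokenizer, input):
        batch = {}
        # Images: convert to tensor and resize to a uniform size
        image_transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Resize([224, 224])]
        )
        image = image_transform(input["image"][0].convert("RGB"))
        batch["image"] = [image]
        # Text: tokenize and pad the question
        tokenized = tokenizer(
            input["question"], return_tensors="pt",
            padding="max_length", max_length=512,
        )
        batch.update(tokenized)
        # Labels: take the most frequent answer; fall back to class 0 (unknown)
        ans_to_count = defaultdict(int)
        for ans in input["answers"][0]:
            ans_to_count[ans] += 1
        max_value = max(ans_to_count, key=ans_to_count.get)
        batch["answers"] = torch.as_tensor([answer_to_idx.get(max_value, 0)])
        return batch

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    dataset.set_transform(partial(transform, tokenizer))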
@@ -133,25 +127,21 @@ def transform(tokenizer, input):
 
 
 ######################################################################
-# 4. Finally, we import the ``flava_model_for_classification`` from
-#    ``torchmultimodal``. It loads the pretrained FLAVA checkpoint by default and
-#    includes a classification head.
+# 4. Finally, we import ``flava_model_for_classification`` from ``torchmultimodal``.
+#    By default it loads the pretrained FLAVA checkpoint and includes a classification head.
 #
-# The model forward function passes the image through the visual encoder
-# and the question through the text encoder. The image and question
-# embeddings are then passed through the multimodal encoder. The final
-# embedding corresponding to the CLS token is passed through a MLP head
-# which finally gives the probability distribution over each possible
-# answers.
+# The model's forward function passes the image through the visual encoder and the question through the text encoder.
+# The image and question embeddings are then passed through the multimodal encoder.
+# The final embedding, corresponding to the CLS token, is passed through an MLP head that gives the probability distribution over each possible answer.
 #
 
 from torchmultimodal.models.flava.model import flava_model_for_classification
 model = flava_model_for_classification(num_classes=len(vocab))
 
 
 ######################################################################
-# 5. We put together the dataset and model in a toy training loop to
-#    demonstrate how to train the model for 3 iterations:
+# 5. We put the dataset and model together in a toy training loop to
+#    demonstrate how to train the model for 3 iterations:
 #
 
 from torch import nn
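The toy training loop falls outside the hunks shown. A minimal sketch matching step 5; the batch size, optimizer choice, and the ``model(text=..., image=..., labels=...)`` call signature follow the published tutorial but are assumptions here:

    from torch.utils.data import DataLoader

    BATCH_SIZE = 2
    MAX_STEPS = 3

    train_dataloader = DataLoader(dataset["train"], batch_size=BATCH_SIZE)
    optimizer = torch.optim.AdamW(model.parameters())

    model.train()
    for idx, batch in enumerate(train_dataloader):
        optimizer.zero_grad()
        # The classification output carries the cross-entropy loss when labels are given
        out = model(text=batch["input_ids"], image=batch["image"], labels=batch["answers"])
        out.loss.backward()
        optimizer.step()
        print(f"Loss at step {idx} = {out.loss}")
        if idx >= MAX_STEPS - 1:
            break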
@@ -177,14 +167,12 @@ def transform(tokenizer, input):
 
 
 ######################################################################
-# Conclusion
+# Conclusion
 # -------------------
 #
-# This tutorial introduced the basics around how to finetune on a
-# multimodal task using FLAVA from TorchMultimodal. Please also check out
-# other examples from the library like
-# `MDETR <https://github.com/facebookresearch/multimodal/tree/main/torchmultimodal/models/mdetr>`__
-# which is a multimodal model for object detection and
-# `Omnivore <https://github.com/facebookresearch/multimodal/blob/main/torchmultimodal/models/omnivore.py>`__
-# which is multitask model spanning image, video and 3d classification.
+# This tutorial introduced the basics of how to finetune on a multimodal task
+# using FLAVA from TorchMultimodal. Please also check out other examples from the library, such as
+# `MDETR <https://github.com/facebookresearch/multimodal/tree/main/torchmultimodal/models/mdetr>`__,
+# a multimodal model for object detection, and
+# `Omnivore <https://github.com/facebookresearch/multimodal/blob/main/torchmultimodal/models/omnivore.py>`__,
+# a multitask model spanning image, video, and 3D classification.
+#
 #
