
Commit 39f93a2

Merge pull request #12 from anyangml/fix/activation-func

Fix: fix activation and loss

anyangml authored Jun 25, 2024
2 parents 7c06a93 + 09d7887 commit 39f93a2
Showing 5 changed files with 32 additions and 21 deletions.
14 changes: 9 additions & 5 deletions clip/clip/image/vit.py
@@ -39,18 +39,22 @@ def __init__(self, config: ViTConfig):
         self.rearrange = Rearrange(
             "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=self.psize, p2=self.psize
         )
-        self.pos_embd = nn.Embedding(self.n_patch + 1, self.config.n_embd)
+        self.pos_embd = nn.Embedding(self.n_patch + 1, config.n_embd)
         self.flatten = nn.Linear(
-            self.c * self.psize**2, self.config.n_embd, bias=False
+            self.c * self.psize**2, config.n_embd, bias=False
         )
-        self.cls_token = nn.Parameter(torch.zeros(1, self.config.n_embd))
+        self.cls_token = nn.Parameter(torch.randn(1, config.n_embd))

-        self.ln = nn.LayerNorm(self.config.n_embd)
+        self.ln = nn.LayerNorm(config.n_embd)
         self.transformer = nn.ModuleList(
             [TransformerBlock(config) for _ in range(config.n_layer)]
         )

-        self.mlp_head = nn.Linear(config.n_embd, config.out_dim)
+        self.mlp_head = nn.Sequential(
+            nn.LayerNorm(config.n_embd),
+            nn.Linear(config.n_embd, config.mlp_size),
+            nn.Linear(config.mlp_size, config.out_dim)
+        )

         self.apply(self._init_weights)

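The mlp_head change swaps a single projection for a LayerNorm followed by two stacked Linear layers (note: as written, there is no activation between them). A minimal sketch of the new head applied to the pooled CLS token, with illustrative sizes since ViTConfig's actual values are not shown in this diff; the cls_token init also moves from torch.zeros to torch.randn, giving the learnable token a non-zero starting point:

```python
import torch
import torch.nn as nn

# Illustrative dimensions; the real ViTConfig values are not part of this diff.
n_embd, mlp_size, out_dim, batch = 768, 2048, 512, 4

mlp_head = nn.Sequential(
    nn.LayerNorm(n_embd),          # normalize the pooled representation
    nn.Linear(n_embd, mlp_size),   # (B, n_embd) -> (B, mlp_size)
    nn.Linear(mlp_size, out_dim),  # (B, mlp_size) -> (B, out_dim)
)

cls_embedding = torch.randn(batch, n_embd)  # pooled CLS token per image
print(mlp_head(cls_embedding).shape)        # torch.Size([4, 512])
```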
15 changes: 10 additions & 5 deletions clip/clip/languange/gpt.py
@@ -14,8 +14,10 @@ class GPTConfig:
     ) # 2**16
     seq_len: int = field(default=MAX_SEQ_LENGTH, metadata={"help": "sequence length"})
     n_layer: int = field(default=12, metadata={"help": "number of layers"})
+    mlp_size: int = field(default=2048, metadata={"help": "size of mlp"})
     n_head: int = field(default=8, metadata={"help": "number of heads"})
     n_embd: int = field(default=768, metadata={"help": "embedding dimension"})
+    out_dim: int = field(default=768, metadata={"help": "output dimension"})


 class GPT(nn.Module):
@@ -35,12 +37,11 @@ def __init__(self, config: GPTConfig):
         self.token_embd = nn.Embedding(config.vocab_size, config.n_embd)
         self.pos_embd = nn.Embedding(config.seq_len, config.n_embd)
         self.ln = nn.LayerNorm(config.n_embd)
-        self.ff = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+        self.ff = nn.Linear(config.n_embd, config.out_dim)

         self.transformer = nn.ModuleList(
             [TransformerBlock(config) for _ in range(config.n_layer)]
         )
-        self.token_embd.weight = self.ff.weight
         self.apply(self._init_weights)

     def _init_weights(self, module):
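Dropping the weight tying follows from the ff change: an embedding's weight only matches a Linear that maps back to vocab_size, so once ff projects to out_dim the shapes no longer line up. A quick shape check using the config defaults above (vocab_size = 2**16):

```python
import torch.nn as nn

vocab_size, n_embd, out_dim = 65536, 768, 768  # defaults per GPTConfig above

embd = nn.Embedding(vocab_size, n_embd)
lm_head = nn.Linear(n_embd, vocab_size, bias=False)  # the old ff
proj = nn.Linear(n_embd, out_dim)                    # the new ff

print(embd.weight.shape)     # torch.Size([65536, 768]) -- tieable with lm_head
print(lm_head.weight.shape)  # torch.Size([65536, 768])
print(proj.weight.shape)     # torch.Size([768, 768]) -- can no longer be tied
```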
@@ -69,10 +70,10 @@ def forward(self, x):
             x = block(x)
         x = self.ln(x)

-        # (B, L, D) --> (B, L, V)
+        # (B, L, D) --> (B, L, out_dim)
         x = self.ff(x)

-        # getting the EOS
+        # getting the EOS (B, out_dim)
         x = x[eos_mask]

         return x
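For readers of the forward pass: x[eos_mask] is boolean-mask pooling, picking one embedding per sequence at its EOS position. A self-contained sketch, assuming eos_mask marks the EOS token id (50256, per the test at the bottom of this diff) exactly once per row; the mask's construction is outside this hunk:

```python
import torch

B, L, D = 2, 5, 8
x = torch.randn(B, L, D)                # per-token embeddings after self.ln
tokens = torch.randint(0, 100, (B, L))  # dummy ids, all below 50256
tokens[0, 3] = 50256                    # <|endoftext|> in the GPT-2 vocab
tokens[1, 4] = 50256

eos_mask = tokens == 50256              # (B, L) boolean, one True per row
pooled = x[eos_mask]                    # (B, D): one vector per sequence
print(pooled.shape)                     # torch.Size([2, 8])
```

If a row contained zero or several EOS tokens, the result would no longer be one vector per sequence, so the upstream mask construction has to guarantee exactly one hit per row.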
@@ -84,7 +85,11 @@ def __init__(self, config: GPTConfig):
         self.ln1 = nn.LayerNorm(config.n_embd)
         self.ln2 = nn.LayerNorm(config.n_embd)
         self.attn = AttentionBlock(config)
-        self.ff = nn.Linear(config.n_embd, config.n_embd, bias=False)
+        self.ff = nn.Sequential(
+            nn.Linear(config.n_embd, config.mlp_size),
+            nn.GELU(),
+            nn.Linear(config.mlp_size, config.n_embd),
+        )

     def forward(self, x):
         x = x + self.attn(self.ln1(x))
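This hunk is the activation fix the PR title refers to: the block's feed-forward becomes the standard expand-GELU-contract MLP instead of a single bias-free Linear, which, having no nonlinearity, added little expressive power. A sketch of the feed-forward path through the residual, using the GPTConfig defaults above and mirroring the pre-LN pattern of the attn line in forward() (the ff line itself sits below the visible hunk):

```python
import torch
import torch.nn as nn

n_embd, mlp_size = 768, 2048  # GPTConfig defaults per this diff

ff = nn.Sequential(
    nn.Linear(n_embd, mlp_size),  # expand
    nn.GELU(),                    # the newly added activation
    nn.Linear(mlp_size, n_embd),  # contract back to the residual width
)
ln2 = nn.LayerNorm(n_embd)

x = torch.randn(4, 16, n_embd)  # (B, L, D)
out = x + ff(ln2(x))            # pre-LN residual connection
print(out.shape)                # torch.Size([4, 16, 768])
```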
11 changes: 6 additions & 5 deletions clip/clip/loss.py
@@ -4,16 +4,17 @@


 class CLIPLoss(nn.Module):
-    def __init__(self, batch_size: int, device = DEVICE):
+    def __init__(self, device = DEVICE):
         super().__init__()
-        self.batch_size = batch_size
-        self.label = torch.arange(0, self.batch_size, dtype=torch.long, device=device)
+        self.device = device
         self.img_loss = nn.CrossEntropyLoss()
         self.txt_loss = nn.CrossEntropyLoss()

     def forward(self, txt_log, img_log):
         # Loss function
-        loss_images = self.img_loss(img_log, self.label)
-        loss_text = self.txt_loss(txt_log, self.label)
+        batch_size = txt_log.size(0)
+        label = torch.arange(0, batch_size, dtype=torch.long, device=self.device)
+        loss_images = self.img_loss(img_log, label)
+        loss_text = self.txt_loss(txt_log, label)
         loss = (loss_images + loss_text) / 2
         return loss
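The fix here makes the loss batch-size-agnostic: instead of labels frozen at construction time, the targets 0..B-1 (the diagonal of the (B, B) similarity matrix) are rebuilt from each incoming batch, so a smaller final batch at the end of an epoch no longer crashes against a fixed self.label. A usage sketch, assuming the CLIPLoss above is in scope and DEVICE resolves to "cpu":

```python
import torch

B = 4
txt_log = torch.randn(B, B)       # text-to-image similarity logits
img_log = txt_log.T               # image-to-text logits, as CLIP.forward returns

loss_fn = CLIPLoss(device="cpu")  # no fixed batch size required any more
print(loss_fn(txt_log, img_log))  # scalar: mean of the two cross-entropies
```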
11 changes: 6 additions & 5 deletions clip/clip/model.py
@@ -8,28 +8,29 @@ class CLIP(nn.Module):
     def __init__(self, txt_encoder, img_encoder, embd_dim, temperature):
         super().__init__()

-        assert 0 <= temperature <= 1, "temperature must be in range [0,1]"
+        assert 0 < temperature, "temperature must be greater than zero."
         self.temperature = temperature
         self.embd_dim = embd_dim

         self.txt_encoder = txt_encoder
         self.img_encoder = img_encoder

         self.txt_proj = nn.Linear(
-            self.txt_encoder.config.vocab_size, self.embd_dim, bias=False
+            self.txt_encoder.config.out_dim, self.embd_dim, bias=False
         )
         self.img_proj = nn.Linear(
             self.img_encoder.config.out_dim, self.embd_dim, bias=False
         )
+        self.temperature = nn.Parameter(torch.log(torch.tensor(1/temperature)))

     def forward(self, text, image):
         encoded_text = self.txt_encoder(text)
         encoded_image = self.img_encoder(image)

-        embd_text = F.normalize(self.txt_proj(encoded_text), dim=1)  # (B, D)
-        embd_image = F.normalize(self.img_proj(encoded_image), dim=1)  # (B, D)
+        embd_text = F.normalize(self.txt_proj(encoded_text), p=2, dim=1)  # L2 norm (B, D)
+        embd_image = F.normalize(self.img_proj(encoded_image), p=2, dim=1)  # L2 norm (B, D)

         # scaled pairwise cosine similarities (B, B)
-        logits = torch.mm(embd_text, embd_image.T) * np.exp(self.temperature)
+        logits = torch.mm(embd_text, embd_image.T) * torch.clamp(torch.exp(self.temperature), min=0.01, max=100.0)

         return logits, logits.T  # text, image
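The temperature is now stored as a learnable log-scale: the model holds t = log(1/T) and multiplies the cosine similarities by exp(t), clamped to [0.01, 100] so the logit scale stays positive and bounded. A small numeric sketch (the 0.07 init below is only an example, echoing the CLIP paper's choice; the diff leaves the init value to the caller):

```python
import torch

T = 0.07  # example init; any T > 0 passes the new assert
log_scale = torch.nn.Parameter(torch.log(torch.tensor(1 / T)))

scale = torch.clamp(torch.exp(log_scale), min=0.01, max=100.0)
print(scale)  # tensor(14.2857, grad_fn=<ClampBackward1>)
```

Unlike the old np.exp(self.temperature) on a plain float, this expression is built from torch ops on an nn.Parameter, so the scale receives gradients and can be trained along with the encoders.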
2 changes: 1 addition & 1 deletion clip/tests/language/test_gpt.py
@@ -12,4 +12,4 @@ def test_GPT_forward_shape():
     dummy_txts[1, 928] = 50256

     ecoded = gpt(dummy_txts)
-    assert ecoded.shape == (2, config.vocab_size)
+    assert ecoded.shape == (2, config.out_dim)
