Implemented a Siamese Network Example #1003
Changes from all commits
933d976
aec9e70
11cb38f
57903ba
7cb4788
8c17275
332e138
.gitignore
@@ -17,3 +17,6 @@ docs/venv

# vi backups
*~

# development
.vscode
README.md
@@ -0,0 +1,7 @@
# Siamese Network Example

```bash
pip install -r requirements.txt
python main.py
# CUDA_VISIBLE_DEVICES=2 python main.py  # to specify GPU id, e.g. 2
```
Review comment: Can you also reference the paper you're using as a baseline for your implementation?

Author reply: I have referenced
main.py
@@ -0,0 +1,296 @@
from __future__ import print_function
import argparse
import random

import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import Dataset
from torchvision import datasets
from torch.optim.lr_scheduler import StepLR


class SiameseNetwork(nn.Module):
    """
    Siamese network for image similarity estimation.
    The network is composed of two identical networks, one for each input.
    The output of each network is concatenated and passed to a linear layer.
    The output of the linear layer is passed through a sigmoid function.
    `"FaceNet" <https://arxiv.org/pdf/1503.03832.pdf>`_ is a variant of the Siamese network.
    This implementation varies from FaceNet as we use the `ResNet-18` model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ as our feature extractor.
    In addition, we aren't using `TripletLoss`, as the MNIST dataset is simple enough for `BCELoss` to do the trick.
    """
    def __init__(self):
        super(SiameseNetwork, self).__init__()
        # get resnet model
        self.resnet = torchvision.models.resnet18(pretrained=False)

        # over-write the first conv layer to be able to read MNIST images,
        # as resnet18 expects (3, x, x) inputs, where 3 is the number of RGB channels,
        # whereas MNIST images are (1, x, x), with a single gray-scale channel
        self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        self.fc_in_features = self.resnet.fc.in_features

        # remove the last layer of resnet18 (the linear layer that follows the avgpool layer)
        self.resnet = torch.nn.Sequential(*(list(self.resnet.children())[:-1]))

        # add linear layers to compare the features of the two images
        self.fc = nn.Sequential(
            nn.Linear(self.fc_in_features * 2, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 1),
        )

        self.sigmoid = nn.Sigmoid()

        # initialize the weights
        self.resnet.apply(self.init_weights)
        self.fc.apply(self.init_weights)

    def init_weights(self, m):
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            m.bias.data.fill_(0.01)

    def forward_once(self, x):
        output = self.resnet(x)
        output = output.view(output.size()[0], -1)
        return output

    def forward(self, input1, input2):
        # get the features of the two images
        output1 = self.forward_once(input1)
        output2 = self.forward_once(input2)

        # concatenate the two images' features
        output = torch.cat((output1, output2), 1)

        # pass the concatenation to the linear layers
        output = self.fc(output)

        # pass the output of the linear layers to the sigmoid layer
        output = self.sigmoid(output)

        return output

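# Illustrative usage sketch (editorial addition, not part of the original PR):
# a forward pass on two batches of MNIST-shaped tensors yields one similarity
# score per pair, squashed into (0, 1) by the sigmoid.
#
#     net = SiameseNetwork()
#     x1 = torch.randn(8, 1, 28, 28)
#     x2 = torch.randn(8, 1, 28, 28)
#     scores = net(x1, x2)  # shape: (8, 1)
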
class APP_MATCHER(Dataset):
    def __init__(self, root, train, download=False):
        super(APP_MATCHER, self).__init__()

        # get MNIST dataset
        self.dataset = datasets.MNIST(root, train=train, download=download)

        # `self.dataset.data`'s shape is (Nx28x28), where N is the number of
        # examples in the MNIST dataset, so a single example has dimensions
        # (28x28) for (HxW), where H and W are the height and the width of the image.
        # However, every example passed to the network should have (CxHxW) dimensions,
        # where C is the number of channels. As MNIST contains gray-scale images,
        # we add an additional dimension to correspond to the number of channels.
        self.data = self.dataset.data.unsqueeze(1).clone()

        self.group_examples()

    def group_examples(self):
        """
        To ease access to examples of a given class, `group_examples` groups
        the examples by class.

        Every key in `grouped_examples` corresponds to a class in the MNIST
        dataset. For every key, the value is an array of all the indices of
        the MNIST examples that belong to that class.
        """

        # get the targets from the MNIST dataset
        np_arr = np.array(self.dataset.targets.clone())

        # group examples based on class
        self.grouped_examples = {}
        for i in range(0, 10):
            self.grouped_examples[i] = np.where((np_arr == i))[0]

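    # For instance, on the standard MNIST training split, grouped_examples[3]
    # would be a NumPy array holding the index of every image labeled 3
    # (editorial illustration, not part of the original PR).
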
    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        # Review comment: add more comments in this section

        """
        For every example, we will select two images. There are two cases,
        positive and negative examples. For positive examples, we will have two
        images from the same class. For negative examples, we will have two images
        from different classes.

        Given an index, if the index is even, we will pick the second image from
        the same class as the first image, but it won't be the very same image we
        picked first. This ensures the positive example isn't trivial, as the
        network could easily recognize two identical images as similar. Given two
        different images from the same class instead, the network has to learn the
        similarity between two different images representing the same class. If
        the index is odd, we will pick the second image from a different class
        than the first image.
        """

        # pick some random class for the first image
        selected_class = random.randint(0, 9)

        # pick a random index for the first image in the grouped indices based on
        # the label of the class
        random_index_1 = random.randint(0, self.grouped_examples[selected_class].shape[0] - 1)

        # pick the index to get the first image
        index_1 = self.grouped_examples[selected_class][random_index_1]

        # get the first image
        image_1 = self.data[index_1].clone().float()

        # same class
        if index % 2 == 0:
            # pick a random index for the second image
            random_index_2 = random.randint(0, self.grouped_examples[selected_class].shape[0] - 1)

            # ensure that the index of the second image isn't the same as the first image
            while random_index_2 == random_index_1:
                random_index_2 = random.randint(0, self.grouped_examples[selected_class].shape[0] - 1)

            # pick the index to get the second image
            index_2 = self.grouped_examples[selected_class][random_index_2]

            # get the second image
            image_2 = self.data[index_2].clone().float()

            # set the label for this example to be positive (1)
            target = torch.tensor(1, dtype=torch.float)

        # different class
        else:
            # pick a random class
            other_selected_class = random.randint(0, 9)

            # ensure that the class of the second image isn't the same as the first image
            while other_selected_class == selected_class:
                other_selected_class = random.randint(0, 9)

            # pick a random index for the second image in the grouped indices based on
            # the label of the class
            random_index_2 = random.randint(0, self.grouped_examples[other_selected_class].shape[0] - 1)

            # pick the index to get the second image
            index_2 = self.grouped_examples[other_selected_class][random_index_2]

            # get the second image
            image_2 = self.data[index_2].clone().float()

            # set the label for this example to be negative (0)
            target = torch.tensor(0, dtype=torch.float)

        return image_1, image_2, target

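# Illustrative only (editorial addition): indexing the dataset yields an
# (image_1, image_2, target) triple, e.g.
#
#     ds = APP_MATCHER('../data', train=True, download=True)
#     img1, img2, label = ds[0]  # even index -> positive pair, label == 1.0
#     print(img1.shape)          # torch.Size([1, 28, 28])
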
def train(args, model, device, train_loader, optimizer, epoch):
    model.train()

    # we aren't using `TripletLoss` as the MNIST dataset is simple, so `BCELoss` can do the trick.
    criterion = nn.BCELoss()

    for batch_idx, (images_1, images_2, targets) in enumerate(train_loader):
        images_1, images_2, targets = images_1.to(device), images_2.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(images_1, images_2).squeeze()
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(images_1), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            if args.dry_run:
                break

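# Editorial note: nn.BCELoss expects probabilities, hence the explicit sigmoid
# in SiameseNetwork.forward. A common alternative is to drop the sigmoid and
# train with nn.BCEWithLogitsLoss, which is more numerically stable.
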
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0

    # we aren't using `TripletLoss` as the MNIST dataset is simple, so `BCELoss` can do the trick.
    criterion = nn.BCELoss()

    with torch.no_grad():
        for (images_1, images_2, targets) in test_loader:
            images_1, images_2, targets = images_1.to(device), images_2.to(device), targets.to(device)
            outputs = model(images_1, images_2).squeeze()
            test_loss += criterion(outputs, targets).item()  # accumulate the batch-averaged loss
            pred = torch.where(outputs > 0.5, 1, 0)  # threshold the sigmoid output at 0.5 to get the predicted label
            correct += pred.eq(targets.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    # Review comment: Add a comment about how much loss is expected if default settings are used

    # for the 1st epoch, the average loss is 0.0001 and the accuracy 97-98%
    # using default settings. After completing the 10th epoch, the average
    # loss is 0.0000 and the accuracy 99.5-100% using default settings.
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch Siamese network Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    args = parser.parse_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    train_kwargs = {'batch_size': args.batch_size}
    test_kwargs = {'batch_size': args.test_batch_size}
    if use_cuda:
        cuda_kwargs = {'num_workers': 1,
                       'pin_memory': True,
                       'shuffle': True}
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    train_dataset = APP_MATCHER('../data', train=True, download=True)
    test_dataset = APP_MATCHER('../data', train=False)
    train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(test_dataset, **test_kwargs)

    model = SiameseNetwork().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)
        scheduler.step()

    if args.save_model:
        torch.save(model.state_dict(), "siamese_network.pt")


if __name__ == '__main__':
    main()
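
A quick way to exercise the script above (editorial sketch; only flags that exist in the argument parser are used):

```bash
# single-batch smoke test
python main.py --dry-run

# shorter run that saves the trained weights to siamese_network.pt
python main.py --epochs 5 --save-model
```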
requirements.txt
@@ -0,0 +1,2 @@
torch
torchvision
Review comment: OK, last thing: can you please quickly explain what a Siamese network is here, the high-level architecture, and what users should expect as input and output for this example. Not a tutorial. Also, post some proof in a screenshot here on the PR that this thing works.

Author reply: I have included an explanation in the class definition. I will update the README to include the same definition. Here is a screenshot to prove that the model works using the default setup.