Siamese Network Example In PyTorch

by Alex Braham

Let's dive into building a Siamese network using PyTorch! This article explains what Siamese networks are and how they are structured, then walks through a practical PyTorch example. Whether you're a seasoned deep learning practitioner or just getting started, this guide will equip you with the knowledge to implement and train your own Siamese networks.

Understanding Siamese Networks

Siamese networks are a class of neural networks containing two or more identical subnetworks. "Identical" here means that the subnetworks have the same architecture and share the same weights. In practice, this is usually a single subnetwork applied to each input in turn. Weight sharing is a crucial aspect of Siamese networks. Because every input is mapped by the same function, the resulting embeddings live in the same feature space and can be compared directly, and the model has only one set of parameters to learn. These networks are particularly useful when you need to determine how similar or dissimilar two inputs are. Think of tasks like facial recognition, signature verification, or identifying duplicate questions on platforms like Quora.

Key Concepts

Before we jump into the code, let's establish a clear understanding of the core concepts:

  • Shared Weights: The subnetworks in a Siamese network share the same weights, so every input is embedded by the same function. Identical inputs therefore always map to the same point, and training can then pull similar inputs close together in the feature space.
  • Feature Embedding: Each subnetwork transforms its input into a feature vector, also known as an embedding. This embedding represents the input in a lower-dimensional space.
  • Distance Metric: A distance or similarity measure, such as Euclidean distance or cosine similarity, compares the embeddings produced by the subnetworks.
  • Loss Function: The loss function guides the training process. Two losses are common for Siamese networks. Contrastive loss works on pairs and pushes similar pairs together and dissimilar pairs apart. Triplet loss works on (anchor, positive, negative) triplets and requires the anchor to be closer to the positive than to the negative by a margin.
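The concepts above map onto a few lines of PyTorch. The following is a minimal illustrative sketch, not part of the network built later. The toy MLP `encoder` and the hand-picked tensors are hypothetical stand-ins. `F.pairwise_distance`, `F.cosine_similarity`, and `nn.TripletMarginLoss` are standard PyTorch APIs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Shared weights: one (toy, hypothetical) encoder is applied to both inputs
encoder = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 8))
x = torch.randn(4, 1, 28, 28)
e1, e2 = encoder(x), encoder(x.clone())
print(torch.equal(e1, e2))  # True: identical inputs map to identical embeddings

# Distance / similarity measures on two hand-picked batches of 2-D embeddings
a = torch.tensor([[1.0, 0.0], [0.0, 2.0]])
b = torch.tensor([[0.0, 1.0], [0.0, 4.0]])
euclid = F.pairwise_distance(a, b)         # per-row L2 distance: about [1.4142, 2.0]
cosine = F.cosine_similarity(a, b, dim=1)  # per-row cosine similarity: about [0.0, 1.0]
cosine_distance = 1 - cosine               # a common way to turn similarity into a distance

# Triplet loss: max(d(anchor, positive) - d(anchor, negative) + margin, 0)
triplet = nn.TripletMarginLoss(margin=1.0)
anchor = torch.tensor([[0.0, 0.0]])
positive = torch.tensor([[0.5, 0.0]])  # distance 0.5 from the anchor
negative = torch.tensor([[1.0, 0.0]])  # distance 1.0 from the anchor
triplet_value = triplet(anchor, positive, negative)
print(triplet_value)  # about 0.5 = 0.5 - 1.0 + 1.0
```

Cosine similarity is a similarity, not a distance, which is why `1 - cosine` is often used when a loss expects distances.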

Applications of Siamese Networks

Siamese networks shine in various applications, including:

  • Facial Recognition: Verifying if two images belong to the same person.
  • Signature Verification: Authenticating signatures by comparing them to known samples.
  • Duplicate Question Detection: Identifying duplicate questions on online forums.
  • Image Retrieval: Finding similar images in a large database.
  • One-Shot Learning: Learning to classify new objects from just one or a few examples.

Building a Siamese Network with PyTorch

Now, let's get our hands dirty and build a Siamese network using PyTorch. We'll start by defining the network architecture, then move on to implementing the loss function and training loop.

Prerequisites

Make sure you have the following libraries installed:

  • PyTorch
  • torchvision (for the MNIST dataset)
  • NumPy
  • Matplotlib (optional, for visualization)

You can install these libraries using pip:

pip install torch torchvision numpy matplotlib

Defining the Network Architecture

First, we'll define the architecture of our subnetwork. For simplicity, we'll use a basic convolutional neural network (CNN). This CNN will serve as the feature extractor, transforming the input images into feature embeddings.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SiameseNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        # Shared feature extractor. For 1x28x28 MNIST input the spatial size goes
        # 28 -> 29 -> 14 -> 15 -> 7 -> 8 -> 4 (conv, pool, conv, pool, conv, pool)
        self.cnn1 = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=4, stride=1, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(64, 128, kernel_size=4, stride=1, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(128, 128, kernel_size=4, stride=1, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        # Maps the flattened 128 x 4 x 4 feature map to a 5-dimensional embedding
        self.fc1 = nn.Sequential(
            nn.Linear(128 * 4 * 4, 512),
            nn.ReLU(inplace=True),

            nn.Linear(512, 512),
            nn.ReLU(inplace=True),

            nn.Linear(512, 5)
        )

    def forward_once(self, x):
        # Embed one batch of images: (N, 1, 28, 28) -> (N, 5)
        output = self.cnn1(x)
        output = output.view(output.size(0), -1)  # flatten to (N, 2048)
        output = self.fc1(output)
        return output

    def forward(self, input1, input2):
        # Both inputs pass through the same modules, so the weights are shared
        output1 = self.forward_once(input1)
        output2 = self.forward_once(input2)
        return output1, output2

In this code:

  • SiameseNetwork is a PyTorch nn.Module that defines the Siamese network.
  • cnn1 is a sequential block of convolutional layers, ReLU activation functions, and max-pooling layers. This part extracts features from the input images.
  • fc1 is a sequential block of fully connected layers that maps the convolutional features to the embedding space.
  • forward_once performs the forward pass for a single input.
  • forward takes two inputs and passes each one through forward_once. Because both calls use the same cnn1 and fc1 modules, the two branches share their weights.
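The in_features of the first nn.Linear must equal the flattened size of cnn1's output, and that size depends on the input resolution. The following sketch assumes 28x28 MNIST images, matching the dataset used later. It applies the standard output-size formula for Conv2d and MaxPool2d and then confirms the result with a dummy forward pass. The helper `conv_out` is a hypothetical name introduced for illustration.

```python
import torch
import torch.nn as nn


def conv_out(size, kernel, stride=1, padding=0):
    # Output-size formula for Conv2d / MaxPool2d (dilation 1, ceil_mode=False)
    return (size + 2 * padding - kernel) // stride + 1


size = 28  # assumed input resolution (MNIST)
for _ in range(3):
    size = conv_out(size, kernel=4, stride=1, padding=2)  # Conv2d
    size = conv_out(size, kernel=2, stride=2)             # MaxPool2d
print(size)  # 4

# Cross-check with a dummy forward pass through the same convolutional stack
cnn = nn.Sequential(
    nn.Conv2d(1, 64, kernel_size=4, stride=1, padding=2), nn.ReLU(), nn.MaxPool2d(2, 2),
    nn.Conv2d(64, 128, kernel_size=4, stride=1, padding=2), nn.ReLU(), nn.MaxPool2d(2, 2),
    nn.Conv2d(128, 128, kernel_size=4, stride=1, padding=2), nn.ReLU(), nn.MaxPool2d(2, 2),
)
with torch.no_grad():
    features = cnn(torch.zeros(1, 1, 28, 28))
flat = torch.flatten(features, 1)
print(features.shape)  # torch.Size([1, 128, 4, 4])
print(flat.shape)      # torch.Size([1, 2048])
```

With 28x28 inputs, the flattened size is 128 * 4 * 4 = 2048. A different input resolution or convolutional stack changes this number, so the check is worth rerunning whenever either changes.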

Implementing the Contrastive Loss Function

The contrastive loss function is commonly used for training Siamese networks. It encourages similar pairs to have small distances and dissimilar pairs to have large distances. Here's the implementation:

class ContrastiveLoss(torch.nn.Module):
    def __init__(self, margin=2.0):
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        # label is 0 for similar pairs and 1 for dissimilar pairs
        euclidean_distance = F.pairwise_distance(output1, output2)
        # Similar pairs are penalized by their squared distance; dissimilar pairs
        # are penalized only when they are closer than the margin
        loss_contrastive = torch.mean((1 - label) * torch.pow(euclidean_distance, 2) +
                                      label * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))

        return loss_contrastive

In this code:

  • ContrastiveLoss is a PyTorch nn.Module that defines the contrastive loss function.
  • margin is a hyperparameter that sets how far apart dissimilar pairs should be. A dissimilar pair whose distance already exceeds the margin contributes zero loss.
  • forward computes the contrastive loss from the embeddings output1 and output2 and the label. The label is 0 for a similar pair and 1 for a dissimilar pair.
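Written out, the loss for a pair at distance d with label y and margin m is (1 - y) * d^2 + y * max(0, m - d)^2, averaged over the batch. The worked example below evaluates that formula on three hand-picked pairs so each term can be checked by hand. It uses hypothetical 2-D embeddings and re-implements the same computation as `ContrastiveLoss.forward` as a standalone function.

```python
import torch
import torch.nn.functional as F


def contrastive_loss(output1, output2, label, margin=2.0):
    # Same computation as ContrastiveLoss.forward above
    d = F.pairwise_distance(output1, output2)
    similar_term = (1 - label) * d.pow(2)                              # label 0: pull together
    dissimilar_term = label * torch.clamp(margin - d, min=0.0).pow(2)  # label 1: push apart up to margin
    return torch.mean(similar_term + dissimilar_term)


origin = torch.zeros(1, 2)
far = torch.tensor([[3.0, 4.0]])   # distance 5 from the origin
near = torch.tensor([[1.0, 0.0]])  # distance 1 from the origin

loss_similar_far = contrastive_loss(origin, far, torch.tensor([0.0]))
loss_dissimilar_far = contrastive_loss(origin, far, torch.tensor([1.0]))
loss_dissimilar_near = contrastive_loss(origin, near, torch.tensor([1.0]))

print(loss_similar_far)      # about 25: a similar pair that is far apart (5^2)
print(loss_dissimilar_far)   # 0: a dissimilar pair already beyond the margin
print(loss_dissimilar_near)  # about 1: (2 - 1)^2, a dissimilar pair inside the margin
```

The values are approximate because `F.pairwise_distance` adds a small `eps` (1e-6 by default) to the difference before taking the norm.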

Preparing the Dataset

To train our Siamese network, we need a dataset of image pairs with labels indicating whether each pair is similar or dissimilar. For this example, we'll use the MNIST dataset of handwritten digits and build pairs from it. If both images show the same digit, the label is 0 (similar). If they show different digits, the label is 1 (dissimilar).

import random

import torch
import torchvision
import torchvision.transforms as transforms

# Load the MNIST training set; ToTensor() yields 1x28x28 float tensors in [0, 1]
transform = transforms.Compose([
    transforms.ToTensor()
])

mnist_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# Create pairs of images. Each iteration adds one similar and one dissimilar
# pair, so the function returns 2 * num_pairs pairs in total.
def create_pairs(dataset, num_pairs=10000):
    pairs = []
    labels = []
    num_classes = 10
    digit_indices = [[] for _ in range(num_classes)]

    # Collect indices for each digit from the label tensor, which avoids
    # loading and transforming every image just to read its label
    for i, label in enumerate(dataset.targets.tolist()):
        digit_indices[label].append(i)

    for _ in range(num_pairs):
        # Randomly choose two *different* digits so the dissimilar pair is never mislabelled
        digit1 = random.randint(0, num_classes - 1)
        digit2 = random.choice([d for d in range(num_classes) if d != digit1])

        # Randomly choose two images of the first digit
        idx1_1, idx1_2 = random.sample(digit_indices[digit1], 2)
        img1_1, _ = dataset[idx1_1]
        img1_2, _ = dataset[idx1_2]

        # Similar pair (same digit) -> label 0
        pairs.append([img1_1, img1_2])
        labels.append(torch.tensor(0, dtype=torch.float32))

        # Randomly choose an image of the second digit
        idx2 = random.choice(digit_indices[digit2])
        img2, _ = dataset[idx2]

        # Dissimilar pair (different digits) -> label 1
        pairs.append([img1_1, img2])
        labels.append(torch.tensor(1, dtype=torch.float32))

    return pairs, labels

pairs, labels = create_pairs(mnist_dataset)

# Wrap the pairs in a Dataset so a DataLoader can batch and shuffle them
class SiameseDataset(torch.utils.data.Dataset):
    def __init__(self, pairs, labels):
        self.pairs = pairs
        self.labels = labels

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        img1, img2 = self.pairs[idx]
        label = self.labels[idx]
        return img1, img2, label


siamese_dataset = SiameseDataset(pairs, labels)
dataloader = torch.utils.data.DataLoader(siamese_dataset, batch_size=64, shuffle=True)

This code does the following:

  • Loads the MNIST dataset using torchvision.datasets.MNIST.
  • The create_pairs function builds the pairs. Each iteration adds one similar pair (same digit) and one dissimilar pair (different digits), so half the pairs are of each kind, and num_pairs=10000 yields 20,000 pairs in total.
  • SiameseDataset is a custom dataset class that returns pairs of images and their corresponding labels.
  • Creates a DataLoader to efficiently load the data in batches during training.
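Storing 20,000 image-tensor pairs up front works for MNIST, but it keeps every image in memory twice. One alternative is to store only index triples and fetch images lazily in `__getitem__`. The sketch below follows the same labeling convention and checks it on a toy label list. The names `make_pair_indices` and `IndexPairDataset` are hypothetical, and the fake dataset stands in for MNIST.

```python
import random

import torch
from torch.utils.data import Dataset


def make_pair_indices(targets, num_pairs, num_classes=10, seed=None):
    """Return (idx1, idx2, label) triples: label 0 = same class, 1 = different class.
    Each iteration emits one similar and one dissimilar pair."""
    rng = random.Random(seed)
    by_class = [[] for _ in range(num_classes)]
    for idx, cls in enumerate(targets):
        by_class[cls].append(idx)

    triples = []
    for _ in range(num_pairs):
        c1 = rng.randrange(num_classes)
        c2 = rng.choice([c for c in range(num_classes) if c != c1])  # guaranteed different class
        i1, i2 = rng.sample(by_class[c1], 2)
        triples.append((i1, i2, 0))
        triples.append((i1, rng.choice(by_class[c2]), 1))
    return triples


class IndexPairDataset(Dataset):
    """Hypothetical lazy variant of SiameseDataset: stores indices, loads images on demand."""

    def __init__(self, base_dataset, triples):
        self.base_dataset = base_dataset
        self.triples = triples

    def __len__(self):
        return len(self.triples)

    def __getitem__(self, idx):
        i1, i2, label = self.triples[idx]
        img1, _ = self.base_dataset[i1]
        img2, _ = self.base_dataset[i2]
        return img1, img2, torch.tensor(label, dtype=torch.float32)


# Toy stand-in for MNIST: 1,000 fake "images" whose pixels encode their class
targets = [i % 10 for i in range(1000)]
fake_base = [(torch.full((1, 28, 28), float(t)), t) for t in targets]

triples = make_pair_indices(targets, num_pairs=500, seed=0)
pair_dataset = IndexPairDataset(fake_base, triples)
print(len(pair_dataset))           # 1000 (two pairs per iteration)
print(sum(t[2] for t in triples))  # 500 dissimilar pairs
```

With the real data, the equivalent call would be something like `make_pair_indices(mnist_dataset.targets.tolist(), num_pairs=10000)` followed by `IndexPairDataset(mnist_dataset, triples)`.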

Training the Siamese Network

Now, we can train our Siamese network using the contrastive loss function and the prepared dataset.

# Initialize the network, loss function, and optimizer
model = SiameseNetwork()
criterion = ContrastiveLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)

# Training loop
epochs = 10
for epoch in range(epochs):
    for i, (img1, img2, label) in enumerate(dataloader):
        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass
        output1, output2 = model(img1, img2)

        # Calculate the loss
        loss = criterion(output1, output2, label)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Print the loss
        if i % 50 == 0:
            print(f'Epoch [{epoch+1}/{epochs}], Step [{i+1}/{len(dataloader)}], Loss: {loss.item():.4f}')

print('Training finished')

In this code:

  • We initialize the SiameseNetwork, ContrastiveLoss, and Adam optimizer.
  • The training loop iterates over the dataset for a specified number of epochs.
  • In each iteration, we perform a forward pass, calculate the loss, and update the network's parameters using backpropagation.
  • The loss is printed every 50 steps to monitor the training progress.
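Before committing to ten epochs on MNIST, it can help to confirm that the loss and optimizer are wired up correctly on a problem that trains in seconds. The sketch below makes several assumptions: the two Gaussian blobs, the small two-layer `encoder`, the learning rate, and the helper names are all hypothetical stand-ins. The loop itself has the same zero_grad / forward / backward / step structure as the one above, and it checks that the loss on a fixed held-out batch goes down.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)


def contrastive_loss(out1, out2, label, margin=2.0):
    # Same formula as ContrastiveLoss above: label 0 = similar, 1 = dissimilar
    d = F.pairwise_distance(out1, out2)
    return torch.mean((1 - label) * d.pow(2) + label * torch.clamp(margin - d, min=0.0).pow(2))


def make_batch(n=128):
    # Hypothetical toy data: 2-D points from two Gaussian blobs stand in for images
    centers = torch.tensor([[-2.0, 0.0], [2.0, 0.0]])
    classes = torch.randint(0, 2, (2, n))  # class of the first and second item of each pair
    x1 = centers[classes[0]] + 0.5 * torch.randn(n, 2)
    x2 = centers[classes[1]] + 0.5 * torch.randn(n, 2)
    label = (classes[0] != classes[1]).float()  # 1 = dissimilar, 0 = similar
    return x1, x2, label


# One encoder used for both inputs, playing the role of forward_once
encoder = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 2))
optimizer = torch.optim.Adam(encoder.parameters(), lr=0.01)

# Fixed held-out batch for measuring progress
x1_val, x2_val, label_val = make_batch(512)
with torch.no_grad():
    initial_loss = contrastive_loss(encoder(x1_val), encoder(x2_val), label_val).item()

for step in range(200):
    x1, x2, label = make_batch()
    optimizer.zero_grad()
    loss = contrastive_loss(encoder(x1), encoder(x2), label)
    loss.backward()
    optimizer.step()

with torch.no_grad():
    final_loss = contrastive_loss(encoder(x1_val), encoder(x2_val), label_val).item()

print(f"held-out loss: {initial_loss:.4f} -> {final_loss:.4f}")
```

If the held-out loss does not drop on a toy problem like this, the bug is most likely in the loss, the labels, or the optimizer setup rather than in the CNN.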

Evaluating the Siamese Network

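After training, it's important to evaluate the performance of our Siamese network. We can do this by feeding pairs of images to the network and checking whether the distance between their embeddings agrees with the ground-truth label. The distance should be small for similar pairs (label 0) and large for dissimilar pairs (label 1).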

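# Evaluate on pairs drawn from the MNIST test split, which the model never saw during training
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

model.eval()
with torch.no_grad():
    # Choose a few pairs to test (num_pairs=10 yields 20 pairs)
    test_pairs, test_labels = create_pairs(test_dataset, num_pairs=10)

    for i in range(len(test_pairs)):
        img1, img2 = test_pairs[i]
        label = test_labels[i]

        # Forward pass; unsqueeze(0) adds the batch dimension: (1, 28, 28) -> (1, 1, 28, 28)
        output1, output2 = model(img1.unsqueeze(0), img2.unsqueeze(0))

        # Calculate the Euclidean distance between the two embeddings
        euclidean_distance = F.pairwise_distance(output1, output2)

        # Print the results
        print(f'Pair {i+1}: Label = {int(label.item())}, Distance = {euclidean_distance.item():.4f}')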

This code:

  • Sets the model to evaluation mode using model.eval().
  • Creates a few test pairs using the create_pairs function.
  • Iterates over the test pairs and calculates the Euclidean distance between the embeddings generated by the network.
  • Prints the ground truth label and the calculated distance for each pair. Lower distances should correspond to similar pairs, and higher distances to dissimilar pairs.
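The printed distances still have to be turned into a same/different decision. One common approach is to pick a distance threshold and predict "dissimilar" whenever the distance exceeds it. The sketch below makes assumptions: the helpers `verification_accuracy` and `best_threshold` are hypothetical, and the sample distances are made up. In practice, you would collect the distances from the loop above and tune the threshold on a separate validation set, because picking it on the test pairs gives an optimistic accuracy.

```python
import torch


def verification_accuracy(distances, labels, threshold):
    # Predict "dissimilar" (1) when the distance exceeds the threshold,
    # matching the article's convention: label 0 = similar, 1 = dissimilar
    preds = (distances > threshold).float()
    return (preds == labels).float().mean().item()


def best_threshold(distances, labels):
    # Try every observed distance as a candidate threshold (torch.unique returns them sorted)
    candidates = torch.unique(distances)
    accs = torch.tensor([verification_accuracy(distances, labels, t) for t in candidates])
    best = torch.argmax(accs)
    return candidates[best].item(), accs[best].item()


# Hypothetical distances and labels, e.g. gathered from F.pairwise_distance above
distances = torch.tensor([0.2, 0.4, 0.5, 1.6, 1.9, 0.9])
labels = torch.tensor([0.0, 0.0, 0.0, 1.0, 1.0, 1.0])
threshold, acc = best_threshold(distances, labels)
print(threshold, acc)  # 0.5 1.0
```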

Conclusion

In this article, we've walked through the process of building a Siamese network using PyTorch. We covered the key concepts behind Siamese networks, implemented the contrastive loss function, prepared a dataset of paired images, and trained the network. By following this guide, you should now have a solid understanding of how to implement and train your own Siamese networks for various applications. Remember to experiment with different architectures, loss functions, and datasets to further enhance your understanding and improve performance. Good luck!