How to build your very first SimSiam model with FashionMNIST
Contrastive learning has many use cases these days. From NLP and computer vision to recommendation systems, contrastive learning can be used to learn underlying data representations without any explicit labels, which can then be used for downstream classification, detection, similarity search, etc.
There are many online resources that explain the basic ideas of contrastive learning, so I won’t add yet another blog post repeating them. Instead, this article shows how to convert a supervised learning problem into a contrastive learning problem. Specifically, I will start with a basic classification model for the FashionMNIST dataset (MIT licence). Then, I will move on to a harder problem with limited training labels (e.g., reducing the full training set of 60,000 labels to 1,000). I will introduce SimSiam, a state-of-the-art method for contrastive learning, and give step-by-step instructions for modifying the original linear layers in the SimSiam style. Finally, I’ll show the results: SimSiam improves the F1 score by 15% even with a very basic configuration.
Now, let’s start. First, we’ll load the FashionMNIST dataset. A custom FashionMNIST class is used to obtain a subset of the training set named finetune_dataset (through its first_k argument). The source code for the custom FashionMNIST class is available in the GitHub repository linked at the end of this article.
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from FashionMNIST import FashionMNIST
train_dataset = FashionMNIST("./FashionMNIST",
                             train=True,
                             transform=transforms.ToTensor(),
                             download=True,
                             )
test_dataset = FashionMNIST("./FashionMNIST",
                            train=False,
                            transform=transforms.ToTensor(),
                            download=True,
                            )
finetune_dataset = FashionMNIST("./FashionMNIST",
                                train=True,
                                transform=transforms.ToTensor(),
                                download=True,
                                first_k=1000,
                                )
# Create a subplot with 4x4 grid
fig, axs = plt.subplots(4, 4, figsize=(8, 8))
# Loop through each subplot and plot an image
for i in range(4):
    for j in range(4):
        image, label = train_dataset[i * 4 + j]  # Get image and label
        image_numpy = image.numpy().squeeze()  # Convert image tensor to numpy array
        axs[i, j].imshow(image_numpy, cmap='gray')  # Plot the image
        axs[i, j].axis('off')  # Turn off axis
        axs[i, j].set_title(f"Label: {label}")  # Set title with label
plt.tight_layout()  # Adjust layout
plt.show()  # Show plot
The code shows a 4x4 grid of images from train_dataset, each titled with its label.
Next, we’ll define the supervised classification model. The architecture contains a backbone of convolutional layers and an MLP head of two linear layers. This will set a consistent baseline for the following experiments, as SimSiam will only replace the MLP head for contrastive learning purposes.
import torch.nn as nn
class supervised_classification(nn.Module):
    def __init__(self):
        super(supervised_classification, self).__init__()
        # Convolutional backbone: three stride-2 convolutions take 28x28 inputs down to 4x4 feature maps
        self.backbone = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
        )
        # MLP head: flattened 128*4*4 features -> 32 hidden units -> 10 class logits
        self.fc = nn.Sequential(
            nn.Linear(128*4*4, 32),
            nn.ReLU(),
            nn.Linear(32, 10),
        )
    def forward(self, x):
        x = self.backbone(x).view(-1, 128 * 4 * 4)
        return self.fc(x)
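The flattened size 128*4*4 used in the head follows from the three stride-2 convolutions, each of which roughly halves the spatial size (28 → 14 → 7 → 4). The quick check below is a standalone sketch that rebuilds the same backbone layers and confirms the shape with a dummy batch.

```python
import torch
import torch.nn as nn


def conv_out(size, kernel=3, stride=2, padding=1):
    # Conv2d output size (dilation 1): floor((size + 2*padding - kernel) / stride) + 1
    return (size + 2 * padding - kernel) // stride + 1


sizes = [28]
for _ in range(3):  # three stride-2 convolutions in the backbone
    sizes.append(conv_out(sizes[-1]))
print(sizes)  # [28, 14, 7, 4]

# Same layer stack as supervised_classification.backbone
backbone = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1), nn.ReLU(), nn.BatchNorm2d(32),
    nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), nn.ReLU(), nn.BatchNorm2d(64),
    nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), nn.ReLU(), nn.BatchNorm2d(128),
)
backbone.eval()
with torch.no_grad():
    features = backbone(torch.zeros(2, 1, 28, 28))
print(features.shape)                # torch.Size([2, 128, 4, 4])
print(features.flatten(1).shape[1])  # 2048, i.e. 128 * 4 * 4
```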
We’ll train the model for 10 epochs:
import os
import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import wandb
wandb_config = {
    "learning_rate": 0.001,
    "architecture": "fashion mnist classification full training",
    "dataset": "FashionMNIST",
    "epochs": 10,
    "batch_size": 64,
}
wandb.init(
    # set the wandb project where this run will be logged
    project="supervised_classification",
    # track hyperparameters and run metadata
    config=wandb_config,
)
# Initialize model and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
supervised = supervised_classification().to(device)
optimizer = optim.SGD(supervised.parameters(),
                      lr=wandb_config["learning_rate"],
                      momentum=0.9,
                      weight_decay=1e-5,
                      )
train_dataloader = DataLoader(train_dataset,
                              batch_size=wandb_config["batch_size"],
                              shuffle=True,
                              )
# Training loop
loss_fun = nn.CrossEntropyLoss()
for epoch in range(wandb_config["epochs"]):
    supervised.train()
    for batch_idx, (image, target) in enumerate(tqdm.tqdm(train_dataloader, total=len(train_dataloader))):
        image, target = image.to(device), target.to(device)  # move the batch to the model's device
        optimizer.zero_grad()
        prediction = supervised(image)
        loss = loss_fun(prediction, target)
        loss.backward()
        optimizer.step()
        wandb.log({"training loss": loss.item()})  # log a Python float rather than a tensor
# Make sure the output folder exists before saving the weights
os.makedirs("weights", exist_ok=True)
torch.save(supervised.state_dict(), "weights/fully_supervised.pt")
Using the classification_report from the scikit-learn package, we’ll get the following results:
import tqdm
import torch
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
supervised = supervised_classification()
supervised.load_state_dict(torch.load("weights/fully_supervised.pt", map_location=device))
supervised.eval()
supervised.to(device)
# The test loader has to be created before it is used
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)
target_list = []
prediction_list = []
for batch_idx, (image, target) in enumerate(tqdm.tqdm(test_dataloader, total=len(test_dataloader))):
    with torch.no_grad():
        prediction = supervised(image.to(device))
    prediction_list.extend(torch.argmax(prediction, dim=1).cpu().numpy())
    target_list.extend(target.numpy())
print(classification_report(target_list, prediction_list))
# Create a subplot with 4x4 grid
fig, axs = plt.subplots(4, 4, figsize=(8, 8))
# Loop through each subplot and plot an image
for i in range(4):
    for j in range(4):
        image, label = test_dataset[i * 4 + j]  # Get image and label
        image_numpy = image.numpy().squeeze()  # Convert image tensor to numpy array
        with torch.no_grad():
            prediction = supervised(torch.unsqueeze(image, dim=0).to(device))
        prediction = torch.argmax(prediction, dim=1).item()  # predicted class index
        axs[i, j].imshow(image_numpy, cmap='gray')  # Plot the image
        axs[i, j].axis('off')  # Turn off axis
        axs[i, j].set_title(f"Label: {label}, Pred: {prediction}")  # Set title with label and prediction
plt.tight_layout()  # Adjust layout
plt.show()  # Show plot
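The article reports an "average F1 score" without saying which average of classification_report it uses. As a sketch on toy labels (not the article's data), the "macro avg" row is the unweighted mean of the per-class F1 scores, and f1_score can return it as a single number, which is handy for comparing runs.

```python
from sklearn.metrics import classification_report, f1_score

# Toy labels for illustration only (not results from this article)
y_true = [0, 0, 1, 1]
y_pred = [0, 1, 1, 1]

# Per-class F1: class 0 -> 2/3, class 1 -> 0.8
macro_f1 = f1_score(y_true, y_pred, average="macro")        # unweighted mean of per-class F1
weighted_f1 = f1_score(y_true, y_pred, average="weighted")  # mean weighted by class support
print(round(macro_f1, 4), round(weighted_f1, 4))  # 0.7333 0.7333
print(classification_report(y_true, y_pred))      # its "macro avg" row shows the same F1
```

On the FashionMNIST test set every class has 1,000 images, so the macro and weighted averages coincide there as well.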
Now, let’s think about a new problem. What should we do if only a limited subset of the training labels is available, e.g., only 1,000 of the 60,000 training images are annotated? The natural idea is to simply train the model on the limited annotated dataset. So, keeping the architecture unchanged, we train the model on the limited subset for 100 epochs (more epochs than before, for a fair comparison with our SimSiam training):
import os
import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import wandb
wandb_config = {
    "learning_rate": 0.001,
    "architecture": "fashion mnist classification full training on finetune set",
    "dataset": "FashionMNIST",
    "epochs": 100,
    "batch_size": 64,
}
wandb.init(
    # set the wandb project where this run will be logged
    project="supervised_classification",
    # track hyperparameters and run metadata
    config=wandb_config,
)
# Initialize model and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
supervised = supervised_classification().to(device)
optimizer = optim.SGD(supervised.parameters(),
                      lr=wandb_config["learning_rate"],
                      momentum=0.9,
                      weight_decay=1e-5,
                      )
# Only the 1,000 labelled images of finetune_dataset are used here
finetune_dataloader = DataLoader(finetune_dataset,
                                 batch_size=wandb_config["batch_size"],
                                 shuffle=True,
                                 )
# Training loop
loss_fun = nn.CrossEntropyLoss()
for epoch in range(wandb_config["epochs"]):
    supervised.train()
    for batch_idx, (image, target) in enumerate(tqdm.tqdm(finetune_dataloader, total=len(finetune_dataloader))):
        image, target = image.to(device), target.to(device)
        optimizer.zero_grad()
        prediction = supervised(image)
        loss = loss_fun(prediction, target)
        loss.backward()
        optimizer.step()
        wandb.log({"training loss": loss.item()})
os.makedirs("weights", exist_ok=True)
torch.save(supervised.state_dict(), "weights/fully_supervised_finetunedataset.pt")
Now it’s time for some contrastive learning. To mitigate the shortage of annotation labels and make full use of the large quantity of unlabelled data, contrastive learning can help the backbone learn data representations without a specific task. For a given downstream task, the backbone can then be frozen, and only a shallow network is trained on the limited annotated dataset to achieve satisfactory results.
The most commonly used contrastive learning approaches include SimCLR, SimSiam, and MoCo (see my previous article on MoCo). Here, we compare SimCLR and SimSiam.
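SimCLR computes its loss over positive and negative pairs within the data batch: the two augmented views of an image form a positive pair, and every other sample in the batch serves as a negative. It uses the NT-Xent loss (a temperature-scaled cross-entropy over cosine similarities across the batch), which benefits from a large batch size, and it relies on the LARS optimizer to train stably with such large batches.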
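For contrast with the SimSiam loss later on, here is a minimal PyTorch sketch of the NT-Xent loss as described in the SimCLR paper, where z1[i] and z2[i] are embeddings of two augmented views of the same image. The function name nt_xent_loss and the default temperature of 0.5 are our own choices, not taken from this article or from an official implementation.

```python
import torch
import torch.nn.functional as F


def nt_xent_loss(z1, z2, temperature=0.5):
    """Sketch of NT-Xent: z1[i] and z2[i] form a positive pair; all other
    2N - 2 embeddings in the batch act as negatives."""
    n = z1.shape[0]
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)  # (2N, D), unit length
    sim = z @ z.t() / temperature                        # scaled cosine similarities
    # A sample must never count as its own negative: mask the diagonal out
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(self_mask, float("-inf"))
    # Row i < N has its positive at i + N; row i >= N has it at i - N
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    # Cross-entropy per row: the positive competes against all in-batch negatives
    return F.cross_entropy(sim, targets)


z1 = torch.eye(4)
matched = nt_xent_loss(z1, z1.clone(), temperature=0.1)       # views agree
mismatched = nt_xent_loss(z1, z1[[1, 0, 3, 2]], temperature=0.1)  # positives swapped
print(matched.item(), mismatched.item())  # small vs. roughly 10
```

The loss only becomes informative when each row has many negatives, which is why SimCLR favours large batches.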
SimSiam, however, uses a Siamese architecture that avoids negative pairs altogether and therefore does not need large batch sizes. The differences between SimSiam and SimCLR are given in the table below.
We can see from the figure above that the SimSiam architecture contains only two parts: the encoder (backbone) and the predictor. During training, the negative cosine similarity is computed between the predictor output of one view and the encoder output of the other view, and the gradient is stopped on the encoder-output side (the stop-gradient).
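In the notation of the SimSiam paper, with encoder output z = f(x) and predictor output p = h(z) for each of the two augmented views, the symmetrized loss (averaged over the batch) is written below. In the code that follows, z corresponds to the flattened backbone output returned first by the model (proj in the training loop) and p to the predictor output (pred).

$$
\mathcal{D}(p_1, z_2) = -\frac{p_1}{\lVert p_1 \rVert_2} \cdot \frac{z_2}{\lVert z_2 \rVert_2}
$$

$$
\mathcal{L} = \frac{1}{2}\,\mathcal{D}\big(p_1, \operatorname{stopgrad}(z_2)\big) + \frac{1}{2}\,\mathcal{D}\big(p_2, \operatorname{stopgrad}(z_1)\big)
$$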
So, how do we implement this architecture in practice? Continuing from the supervised classification design, we keep the backbone the same and only modify the MLP layers. In the supervised architecture, the MLP outputs a 10-element vector of logits, one score per class. For SimSiam, however, the goal is not “classification” but learning a “representation,” so the predictor output must have the same dimension as the backbone output (128*4*4 = 2048) for the loss calculation. The model and the negative cosine similarity loss with stop-gradient are given below:
import torch
import torch.nn as nn
class SimSiam(nn.Module):
    def __init__(self):
        super(SimSiam, self).__init__()
        # Same convolutional backbone (encoder) as the supervised model
        self.backbone = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
        )
        # Bottleneck predictor MLP (2048 -> 64 -> 2048): its output matches the backbone output size
        self.prediction_mlp = nn.Sequential(nn.Linear(128*4*4, 64),
                                            nn.BatchNorm1d(64),
                                            nn.ReLU(),
                                            nn.Linear(64, 128*4*4),
                                            )
    def forward(self, x):
        x = self.backbone(x)
        x = x.view(-1, 128 * 4 * 4)  # flattened backbone output ("proj")
        pred_output = self.prediction_mlp(x)  # predictor output ("pred")
        return x, pred_output
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
def negative_cosine_similarity_stopgradient(pred, proj):
    # proj.detach() implements the stop-gradient on the other branch
    return -cos(pred, proj.detach()).mean()
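A quick sanity check of the stop-gradient, written as a standalone sketch in which random tensors stand in for the model outputs: after backward(), only the predictor-side tensor receives a gradient, and identical inputs give the minimum loss of -1.

```python
import torch
import torch.nn as nn

cos = nn.CosineSimilarity(dim=1, eps=1e-6)


def negative_cosine_similarity_stopgradient(pred, proj):
    # Same definition as above: detach() stops gradients flowing into proj
    return -cos(pred, proj.detach()).mean()


torch.manual_seed(0)
pred = torch.randn(8, 2048, requires_grad=True)  # stands in for predictor outputs
proj = torch.randn(8, 2048, requires_grad=True)  # stands in for the other view's backbone outputs

loss = negative_cosine_similarity_stopgradient(pred, proj)
loss.backward()
print(pred.grad is not None)  # True: gradients reach the predictor branch
print(proj.grad is None)      # True: the stop-gradient branch receives none

# The loss is bounded below by -1, reached when pred and proj point the same way
same = torch.ones(4, 2048)
min_value = negative_cosine_similarity_stopgradient(same, same).item()
print(min_value)  # approximately -1.0
```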
The pseudo-code for training SimSiam, taken from the original paper, is shown below:
We convert it into real training code:
import os
import tqdm
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.transforms import RandAugment
import wandb
wandb_config = {
    "learning_rate": 0.0001,
    "architecture": "simsiam",
    "dataset": "FashionMNIST",
    "epochs": 100,
    "batch_size": 256,
}
wandb.init(
    # set the wandb project where this run will be logged
    project="simsiam",
    # track hyperparameters and run metadata
    config=wandb_config,
)
# Initialize model and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
simsiam = SimSiam().to(device)
random_augmenter = RandAugment(num_ops=5)
optimizer = optim.SGD(simsiam.parameters(),
                      lr=wandb_config["learning_rate"],
                      momentum=0.9,
                      weight_decay=1e-5,
                      )
train_dataloader = DataLoader(train_dataset, batch_size=wandb_config["batch_size"], shuffle=True)
os.makedirs("weights", exist_ok=True)
# Training loop (labels are ignored)
for epoch in range(wandb_config["epochs"]):
    simsiam.train()
    print(f"Epoch {epoch}")
    for batch_idx, (image, _) in enumerate(tqdm.tqdm(train_dataloader, total=len(train_dataloader))):
        optimizer.zero_grad()
        # RandAugment is applied to uint8 images, so convert [0, 1] floats to [0, 255] and back
        image_uint8 = (image * 255).to(dtype=torch.uint8)
        aug1 = (random_augmenter(image_uint8).to(dtype=torch.float32) / 255.0).to(device)
        aug2 = (random_augmenter(image_uint8).to(dtype=torch.float32) / 255.0).to(device)
        proj1, pred1 = simsiam(aug1)
        proj2, pred2 = simsiam(aug2)
        # Symmetrized loss: each view's prediction is matched to the other view's stop-gradient backbone output
        loss = negative_cosine_similarity_stopgradient(pred1, proj2) / 2 + negative_cosine_similarity_stopgradient(pred2, proj1) / 2
        loss.backward()
        optimizer.step()
        wandb.log({"training loss": loss.item()})
    # Save a checkpoint every 10 epochs
    if (epoch+1) % 10 == 0:
        torch.save(simsiam.state_dict(), f"weights/simsiam_epoch{epoch+1}.pt")
We trained for 100 epochs for a fair comparison with the limited supervised training; the training loss is shown below. Note: because of its Siamese design, SimSiam can be very sensitive to hyperparameters such as the learning rate and the size of the MLP hidden layers. The original SimSiam paper provides a detailed configuration for the ResNet-50 backbone. For ViT-based backbones, we recommend reading the MoCo v3 paper, which adopts the SimSiam model in a momentum-update scheme.
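Because SimSiam can collapse to a constant output when the hyperparameters are off, the SimSiam paper monitors the standard deviation of the L2-normalized output: it is about 1/sqrt(d) when d-dimensional outputs follow a zero-mean isotropic Gaussian, and it drops to zero under collapse. A minimal sketch of that metric follows (normalized_output_std is a hypothetical name); in the training loop above you could log normalized_output_std(proj1) alongside the loss.

```python
import torch
import torch.nn.functional as F


def normalized_output_std(z):
    """Per-dimension std of L2-normalized outputs, averaged over dimensions."""
    return F.normalize(z, dim=1).std(dim=0).mean().item()


torch.manual_seed(0)
d = 128 * 4 * 4                  # size of the flattened backbone output
healthy = torch.randn(1024, d)   # zero-mean isotropic Gaussian outputs
collapsed = torch.ones(1024, d)  # every image mapped to the same vector
healthy_std = normalized_output_std(healthy)
collapsed_std = normalized_output_std(collapsed)
print(healthy_std, 1 / d ** 0.5)  # both roughly 0.022
print(collapsed_std)              # roughly 0.0
```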
Then, we run the trained SimSiam on the testing set and visualize the representations using UMAP reduction:
import tqdm
import numpy as np
import torch
from torch.utils.data import DataLoader
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
simsiam = SimSiam()
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)
simsiam.load_state_dict(torch.load("weights/simsiam_epoch100.pt", map_location=device))
simsiam.eval()
simsiam.to(device)
features = []
labels = []
for batch_idx, (image, target) in enumerate(tqdm.tqdm(test_dataloader, total=len(test_dataloader))):
    with torch.no_grad():
        proj, pred = simsiam(image.to(device))
    # Collect the predictor outputs as the features to visualize
    features.extend(np.squeeze(pred.cpu().numpy()).tolist())
    labels.extend(target.numpy().tolist())
import plotly.express as px
import umap.umap_ as umap
reducer = umap.UMAP(n_components=3, n_neighbors=10, metric="cosine")
projections = reducer.fit_transform(np.array(features))
# Plot the first two UMAP components, colored by the ground-truth labels
fig = px.scatter(projections, x=0, y=1,
                 color=labels, labels={'color': 'Fashion MNIST Labels'}
                 )
fig.show()
Interestingly, there are two small islands in the reduced-dimension map above, containing classes 5, 7, 8, and some of class 9. Looking up the FashionMNIST class list, these classes are “Sandal,” “Sneaker,” “Bag,” and “Ankle boot”: footwear plus bags. The big purple cluster corresponds to clothing classes such as “T-shirt/top,” “Trouser,” “Pullover,” “Dress,” “Coat,” and “Shirt.” This shows that SimSiam has learned a meaningful representation in the vision domain.
Now that we have meaningful representations, how can they benefit our classification problem? We simply load the trained SimSiam backbone into our classification model. However, instead of fine-tuning the whole architecture on the limited training set, we fine-tune only the linear layers and freeze the backbone, because we don’t want to corrupt the representation it has already learned.
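For reference, the FashionMNIST class indices, in the order listed in the Zalando Fashion-MNIST repository, are below; mapping the islands' indices through this list gives the names above. FASHION_MNIST_CLASSES is our own variable name.

```python
# Class index -> name, in the order given by the Fashion-MNIST repository
FASHION_MNIST_CLASSES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]

island_classes = [5, 7, 8, 9]
island_names = [FASHION_MNIST_CLASSES[i] for i in island_classes]
print(island_names)  # ['Sandal', 'Sneaker', 'Bag', 'Ankle boot']
```

Passing `color=[FASHION_MNIST_CLASSES[l] for l in labels]` to px.scatter would show these names in the plot legend instead of indices.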
import os
import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import wandb
wandb_config = {
    "learning_rate": 0.001,
    "architecture": "supervised learning with simsiam backbone",
    "dataset": "FashionMNIST",
    "epochs": 100,
    "batch_size": 64,
}
wandb.init(
    # set the wandb project where this run will be logged
    project="simsiam-finetune",
    # track hyperparameters and run metadata
    config=wandb_config,
)
# Initialize model and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
supervised = supervised_classification()
# Copy only the backbone weights from the SimSiam checkpoint saved during pre-training (the predictor
# MLP is discarded); both models name these tensors backbone.*, so the keys match
simsiam_state = torch.load("weights/simsiam_epoch100.pt", map_location="cpu")
backbone_state = {k: v for k, v in simsiam_state.items() if k.startswith("backbone.")}
supervised.load_state_dict(backbone_state, strict=False)
supervised.to(device)
finetune_dataloader = DataLoader(finetune_dataset, batch_size=32, shuffle=True)
# Freeze the backbone: only the linear head is trained
for param in supervised.backbone.parameters():
    param.requires_grad = False
parameters = [para for para in supervised.parameters() if para.requires_grad]
optimizer = optim.SGD(parameters,
                      lr=wandb_config["learning_rate"],
                      momentum=0.9,
                      weight_decay=1e-5,
                      )
loss_fun = nn.CrossEntropyLoss()
# Training loop
for epoch in range(wandb_config["epochs"]):
    supervised.train()
    # requires_grad=False does not stop BatchNorm running-stat updates, so keep the backbone in eval mode
    supervised.backbone.eval()
    for batch_idx, (image, target) in enumerate(tqdm.tqdm(finetune_dataloader)):
        image, target = image.to(device), target.to(device)
        optimizer.zero_grad()
        prediction = supervised(image)
        loss = loss_fun(prediction, target)
        loss.backward()
        optimizer.step()
        wandb.log({"training loss": loss.item()})
os.makedirs("weights", exist_ok=True)
torch.save(supervised.state_dict(), "weights/supervised_with_simsiam.pt")
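Two details matter when reusing the backbone: the checkpoint's backbone.* tensors must actually be copied into the classifier, and setting requires_grad = False does not stop BatchNorm layers from updating their running statistics while the model is in train mode. The toy sketch below uses a hypothetical tiny Classifier (not the article's model) to show both: load_state_dict(strict=False) reports which keys were left out, and putting the backbone in eval mode keeps its statistics fixed.

```python
import torch
import torch.nn as nn


class Classifier(nn.Module):
    """Hypothetical toy model with the same backbone/head split as supervised_classification."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, stride=2, padding=1), nn.ReLU(), nn.BatchNorm2d(8)
        )
        self.fc = nn.Linear(8 * 14 * 14, 10)

    def forward(self, x):
        return self.fc(self.backbone(x).flatten(1))


torch.manual_seed(0)
pretrained = Classifier()  # plays the role of the trained SimSiam checkpoint
model = Classifier()

# Copy only the backbone.* entries; strict=False returns the keys that were skipped
backbone_state = {k: v for k, v in pretrained.state_dict().items() if k.startswith("backbone.")}
result = model.load_state_dict(backbone_state, strict=False)
print(result.missing_keys)  # ['fc.weight', 'fc.bias']: the head stays randomly initialized
weights_copied = torch.equal(model.backbone[0].weight, pretrained.backbone[0].weight)

# Freeze the backbone parameters
for p in model.backbone.parameters():
    p.requires_grad = False

bn = model.backbone[2]
model.train()
before = bn.running_mean.clone()
model(torch.randn(16, 1, 28, 28))
stats_changed_in_train_mode = not torch.equal(before, bn.running_mean)  # True: BN still updates

model.backbone.eval()  # keep the frozen backbone in eval mode while the head trains
before = bn.running_mean.clone()
model(torch.randn(16, 1, 28, 28))
stats_fixed_in_eval_mode = torch.equal(before, bn.running_mean)  # True
```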
Here is the evaluation result of the SimSiam-pre-trained classification model. The average F1 score increases by 15% compared with the supervised-only model trained on the same limited subset.
Summary. We showcased a simple but intuitive example of contrastive learning on FashionMNIST. By using SimSiam for backbone pre-training and fine-tuning only the linear layers on the limited training set (1,000 labels, under 2% of the 60,000 in the full training set), we increased the average F1 score by 15% over purely supervised training on the same subset. The trained weights, the notebook, and the customized FashionMNIST dataset class are all included in this GitHub repository.
Give it a try!
References:
- Chen et al., Exploring Simple Siamese Representation Learning. CVPR 2021.
- Chen et al., A Simple Framework for Contrastive Learning of Visual Representations. ICML 2020.
- Chen et al., An Empirical Study of Training Self-Supervised Vision Transformers. ICCV 2021.
- Xiao et al., Fashion-MNIST: a Novel Image Dataset for Benchmarking Machine Learning Algorithms. arXiv preprint, 2017. GitHub: https://github.com/zalandoresearch/fashion-mnist
A Practical Guide to Contrastive Learning was originally published in Towards Data Science on Medium.