Appendix G: Creating a CNN to Predict Bitmojis

Appendix G: Creating a CNN to Predict Bitmojis#

import numpy as np
import torch
from PIL import Image
from torch import nn, optim
from torchvision import datasets, transforms, utils
from torchsummary import summary
import matplotlib.pyplot as plt
plt.style.use('ggplot')
plt.rcParams.update({'font.size': 16, 'axes.labelweight': 'bold', 'axes.grid': False})
import plotly.io as pio
pio.renderers.default = 'notebook'

1. Introduction#


This code-based appendix contains the code needed to develop and save the Bitmoji CNN’s I use in Advanced CNNs Lecture.

2. CNN from Scratch#


TRAIN_DIR = "data/bitmoji_rgb/train/"
VALID_DIR = "data/bitmoji_rgb/valid/"
IMAGE_SIZE = 64
BATCH_SIZE = 64

# Transforms
data_transforms = transforms.Compose([transforms.Resize(IMAGE_SIZE), transforms.ToTensor()])
# Load data and create dataloaders
train_dataset = datasets.ImageFolder(root=TRAIN_DIR, transform=data_transforms)
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_dataset = datasets.ImageFolder(root=VALID_DIR, transform=data_transforms)
validloader = torch.utils.data.DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Plot samples
sample_batch = next(iter(trainloader))
plt.figure(figsize=(10, 8)); plt.axis("off"); plt.title("Sample Training Images")
plt.imshow(np.transpose(utils.make_grid(sample_batch[0], padding=1, normalize=True),(1,2,0)));
_images/f1515f061f54973fe51b06519b993c9355420d73c521babac292549981514789.png
class bitmoji_CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, 8, (5, 5)),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),
            nn.Conv2d(8, 4, (3, 3)),
            nn.ReLU(),
            nn.MaxPool2d((3, 3)),
            nn.Flatten(),
            nn.Linear(324, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )
        
    def forward(self, x):
        out = self.main(x)
        return out
def trainer(model, criterion, optimizer, trainloader, validloader, epochs=5, verbose=True):
    """Training wrapper for PyTorch network."""
    
    train_loss, valid_loss, valid_accuracy = [], [], []
    for epoch in range(epochs):  # for each epoch
        train_batch_loss = 0
        valid_batch_loss = 0
        valid_batch_acc = 0
        
        # Training
        for X, y in trainloader:
            optimizer.zero_grad()
            y_hat = model(X).flatten()
            loss = criterion(y_hat, y.type(torch.float32))
            loss.backward()
            optimizer.step()
            train_batch_loss += loss.item()
        train_loss.append(train_batch_loss / len(trainloader))
        
        # Validation
        with torch.no_grad():  # this stops pytorch doing computational graph stuff under-the-hood and saves memory and time
            for X, y in validloader:
                y_hat = model(X).flatten()
                y_hat_labels = torch.sigmoid(y_hat) > 0.5
                loss = criterion(y_hat, y.type(torch.float32))
                valid_batch_loss += loss.item()
                valid_batch_acc += (y_hat_labels == y).type(torch.float32).mean().item()
        valid_loss.append(valid_batch_loss / len(validloader))
        valid_accuracy.append(valid_batch_acc / len(validloader))  # accuracy
        
        # Print progress
        if verbose:
            print(f"Epoch {epoch + 1}:",
                  f"Train Loss: {train_loss[-1]:.3f}.",
                  f"Valid Loss: {valid_loss[-1]:.3f}.",
                  f"Valid Accuracy: {valid_accuracy[-1]:.2f}.")
    
    results = {"train_loss": train_loss,
               "valid_loss": valid_loss,
               "valid_accuracy": valid_accuracy}
    return results    
# Define and train model
model = bitmoji_CNN()
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters())
results = trainer(model, criterion, optimizer, trainloader, validloader, epochs=20)
Epoch 1: Train Loss: 0.692. Valid Loss: 0.685. Valid Accuracy: 0.64.
Epoch 2: Train Loss: 0.670. Valid Loss: 0.666. Valid Accuracy: 0.56.
Epoch 3: Train Loss: 0.646. Valid Loss: 0.639. Valid Accuracy: 0.63.
Epoch 4: Train Loss: 0.623. Valid Loss: 0.620. Valid Accuracy: 0.67.
Epoch 5: Train Loss: 0.614. Valid Loss: 0.612. Valid Accuracy: 0.63.
Epoch 6: Train Loss: 0.589. Valid Loss: 0.550. Valid Accuracy: 0.75.
Epoch 7: Train Loss: 0.570. Valid Loss: 0.520. Valid Accuracy: 0.77.
Epoch 8: Train Loss: 0.555. Valid Loss: 0.510. Valid Accuracy: 0.76.
Epoch 9: Train Loss: 0.537. Valid Loss: 0.545. Valid Accuracy: 0.72.
Epoch 10: Train Loss: 0.527. Valid Loss: 0.530. Valid Accuracy: 0.74.
Epoch 11: Train Loss: 0.503. Valid Loss: 0.483. Valid Accuracy: 0.78.
Epoch 12: Train Loss: 0.479. Valid Loss: 0.434. Valid Accuracy: 0.81.
Epoch 13: Train Loss: 0.458. Valid Loss: 0.447. Valid Accuracy: 0.80.
Epoch 14: Train Loss: 0.438. Valid Loss: 0.435. Valid Accuracy: 0.80.
Epoch 15: Train Loss: 0.421. Valid Loss: 0.385. Valid Accuracy: 0.84.
Epoch 16: Train Loss: 0.408. Valid Loss: 0.389. Valid Accuracy: 0.83.
Epoch 17: Train Loss: 0.391. Valid Loss: 0.381. Valid Accuracy: 0.83.
Epoch 18: Train Loss: 0.368. Valid Loss: 0.346. Valid Accuracy: 0.87.
Epoch 19: Train Loss: 0.377. Valid Loss: 0.385. Valid Accuracy: 0.84.
Epoch 20: Train Loss: 0.350. Valid Loss: 0.315. Valid Accuracy: 0.88.
# Save model
PATH = "models/bitmoji_cnn.pt"
torch.save(model, PATH)

2. CNN from Scratch with Data Augmentation#


Consider the “Tom” Bitmoji below:

image = Image.open('img/tom-bitmoji.png')
image
_images/96f050236571f5237eb17952ed470c3f72079ca9c63674c921cc53d73b9b64b4.png

You can tell it’s a me (class “Tom”), can my CNN?

image_tensor = transforms.functional.to_tensor(image.resize((IMAGE_SIZE, IMAGE_SIZE))).unsqueeze(0)
prediction = int(torch.sigmoid(model(image_tensor)) > 0.5)
print(f"Prediction: {train_dataset.classes[prediction]}")
Prediction: tom

Great! But what happens if I flip my image like this:

image = image.rotate(180)
image
_images/b7409da1144b343bf2d8b26e081a93aef1e5baab2d9ec08fe38f14fdf6eb4ed4.png

You can still tell that it’s me, but can my CNN?

image_tensor = transforms.functional.to_tensor(image.resize((IMAGE_SIZE, IMAGE_SIZE))).unsqueeze(0)
prediction = int(torch.sigmoid(model(image_tensor)) > 0.5)
print(f"Prediction: {train_dataset.classes[prediction]}")
Prediction: tom

Looks like our CNN is not very robust to rotational changes in our input image. We could try and fix that using some data augmentation, let’s do that now:

# Transforms
data_transforms = transforms.Compose([transforms.Resize(IMAGE_SIZE),
                                      transforms.RandomVerticalFlip(p=0.5),
                                      transforms.RandomHorizontalFlip(p=0.5),
                                      transforms.ToTensor()])

# Load data and re-create training loader
train_dataset = datasets.ImageFolder(root=TRAIN_DIR, transform=data_transforms)
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Plot samples
sample_batch = next(iter(trainloader))
plt.figure(figsize=(10, 8)); plt.axis("off"); plt.title("Sample Training Images")
plt.imshow(np.transpose(utils.make_grid(sample_batch[0], padding=1, normalize=True),(1,2,0)));
_images/7fe4892a61056a9e2a1871501785751579d7fa92e88d2cc12b4505497a1a7b35.png

Okay, let’s train again with our new augmented dataset:

# Define and train model
model = bitmoji_CNN()
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters())
results = trainer(model, criterion, optimizer, trainloader, validloader, epochs=40)
Epoch 1: Train Loss: 0.694. Valid Loss: 0.693. Valid Accuracy: 0.50.
Epoch 2: Train Loss: 0.692. Valid Loss: 0.691. Valid Accuracy: 0.56.
Epoch 3: Train Loss: 0.689. Valid Loss: 0.677. Valid Accuracy: 0.62.
Epoch 4: Train Loss: 0.674. Valid Loss: 0.674. Valid Accuracy: 0.53.
Epoch 5: Train Loss: 0.656. Valid Loss: 0.637. Valid Accuracy: 0.65.
Epoch 6: Train Loss: 0.643. Valid Loss: 0.636. Valid Accuracy: 0.60.
Epoch 7: Train Loss: 0.632. Valid Loss: 0.599. Valid Accuracy: 0.67.
Epoch 8: Train Loss: 0.612. Valid Loss: 0.611. Valid Accuracy: 0.66.
Epoch 9: Train Loss: 0.606. Valid Loss: 0.584. Valid Accuracy: 0.67.
Epoch 10: Train Loss: 0.591. Valid Loss: 0.555. Valid Accuracy: 0.74.
Epoch 11: Train Loss: 0.580. Valid Loss: 0.614. Valid Accuracy: 0.63.
Epoch 12: Train Loss: 0.575. Valid Loss: 0.546. Valid Accuracy: 0.74.
Epoch 13: Train Loss: 0.562. Valid Loss: 0.532. Valid Accuracy: 0.77.
Epoch 14: Train Loss: 0.552. Valid Loss: 0.531. Valid Accuracy: 0.75.
Epoch 15: Train Loss: 0.550. Valid Loss: 0.511. Valid Accuracy: 0.77.
Epoch 16: Train Loss: 0.533. Valid Loss: 0.518. Valid Accuracy: 0.77.
Epoch 17: Train Loss: 0.522. Valid Loss: 0.505. Valid Accuracy: 0.77.
Epoch 18: Train Loss: 0.511. Valid Loss: 0.507. Valid Accuracy: 0.75.
Epoch 19: Train Loss: 0.520. Valid Loss: 0.467. Valid Accuracy: 0.79.
Epoch 20: Train Loss: 0.509. Valid Loss: 0.462. Valid Accuracy: 0.80.
Epoch 21: Train Loss: 0.494. Valid Loss: 0.504. Valid Accuracy: 0.75.
Epoch 22: Train Loss: 0.479. Valid Loss: 0.439. Valid Accuracy: 0.81.
Epoch 23: Train Loss: 0.463. Valid Loss: 0.455. Valid Accuracy: 0.81.
Epoch 24: Train Loss: 0.455. Valid Loss: 0.420. Valid Accuracy: 0.81.
Epoch 25: Train Loss: 0.444. Valid Loss: 0.412. Valid Accuracy: 0.81.
Epoch 26: Train Loss: 0.430. Valid Loss: 0.453. Valid Accuracy: 0.77.
Epoch 27: Train Loss: 0.432. Valid Loss: 0.402. Valid Accuracy: 0.82.
Epoch 28: Train Loss: 0.426. Valid Loss: 0.402. Valid Accuracy: 0.82.
Epoch 29: Train Loss: 0.393. Valid Loss: 0.374. Valid Accuracy: 0.81.
Epoch 30: Train Loss: 0.389. Valid Loss: 0.350. Valid Accuracy: 0.85.
Epoch 31: Train Loss: 0.375. Valid Loss: 0.354. Valid Accuracy: 0.85.
Epoch 32: Train Loss: 0.365. Valid Loss: 0.315. Valid Accuracy: 0.87.
Epoch 33: Train Loss: 0.354. Valid Loss: 0.342. Valid Accuracy: 0.85.
Epoch 34: Train Loss: 0.350. Valid Loss: 0.320. Valid Accuracy: 0.88.
Epoch 35: Train Loss: 0.333. Valid Loss: 0.326. Valid Accuracy: 0.87.
Epoch 36: Train Loss: 0.335. Valid Loss: 0.286. Valid Accuracy: 0.88.
Epoch 37: Train Loss: 0.315. Valid Loss: 0.299. Valid Accuracy: 0.88.
Epoch 38: Train Loss: 0.310. Valid Loss: 0.284. Valid Accuracy: 0.88.
Epoch 39: Train Loss: 0.288. Valid Loss: 0.285. Valid Accuracy: 0.88.
Epoch 40: Train Loss: 0.290. Valid Loss: 0.288. Valid Accuracy: 0.89.
# Save model
PATH = "models/bitmoji_cnn_augmented.pt"
torch.save(model, PATH)

Let’s try predict this one again:

image
_images/b7409da1144b343bf2d8b26e081a93aef1e5baab2d9ec08fe38f14fdf6eb4ed4.png
image_tensor = transforms.functional.to_tensor(image.resize((IMAGE_SIZE, IMAGE_SIZE))).unsqueeze(0)
prediction = int(torch.sigmoid(model(image_tensor)) > 0.5)
print(f"Prediction: {train_dataset.classes[prediction]}")
Prediction: tom

Got it now!