Chapter 7: Image-to-Image Translation with GANs
Image-to-Image translation with Generative Adversarial Networks (GANs) is a technique for converting images from one domain to another. A GAN is trained to learn the mapping between images in two different domains, for example turning sketches into colorful images or daytime scenes into nighttime scenes.
Vanilla Solution with Random Noise:
In a vanilla solution that uses random noise, we have no model or rule for translating images. We simply generate random noise (random pixel values) and treat it as an image in the target domain. This might be acceptable for some tasks, such as adding noise to images or creating abstract art. It does not work for structured translation tasks, however: the output ignores the input image entirely and captures none of the underlying patterns in the data.
GAN-based Image-to-Image Translation:
Image-to-Image translation with GANs takes a more sophisticated approach. It trains two neural networks simultaneously:
a. Generator (G): The generator takes input from one domain (e.g., sketches) and tries to generate images in the target domain (e.g., colorful images). Instead of random noise, it learns the mapping between the domains using training data.
b. Discriminator (D): The discriminator acts as a critic. It examines an image and tries to distinguish real images from the target domain from fake images produced by the generator, and it is trained to get better at telling the two apart.
The training process occurs in a competitive manner:
- The generator produces fake target-domain images from input-domain images.
- The discriminator evaluates the generated fake images and real images from the target domain.
- Both networks continuously improve based on their performance. The generator aims to produce more realistic images to fool the discriminator, while the discriminator improves its ability to distinguish between real and fake images.
Through this process, the generator becomes more skilled at creating realistic target-domain images from the given input. If training converges, the generator should be able to translate images from one domain to the other accurately.
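The competition described above is commonly written as a minimax objective. For a conditional GAN that translates an input image x into a target image y, one standard form (the conditional-GAN objective used by Pix2Pix, shown here without its noise input) is:

```latex
\min_G \max_D \; \mathbb{E}_{x,y}\big[\log D(x, y)\big]
  + \mathbb{E}_{x}\big[\log\big(1 - D(x, G(x))\big)\big]
```

The discriminator D is rewarded for outputting values near 1 on real pairs (x, y) and near 0 on generated pairs (x, G(x)); the generator G is rewarded when D(x, G(x)) moves toward 1. In practice, the generator is usually trained to maximize log D(x, G(x)) rather than to minimize log(1 - D(x, G(x))), because that gives stronger gradients early in training, when the discriminator easily rejects its outputs.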
Image-to-Image translation with GANs uses two networks (generator and discriminator) that compete against each other during training. The generator learns to translate images from one domain to another, and the discriminator provides feedback to help the generator improve its translations, resulting in high-quality and meaningful image conversions.
Image-to-Image Translation with Paired Data
A simplified sketch of a basic Generative Adversarial Network (GAN) for paired image-to-image translation, written as runnable PyTorch. The two networks are deliberately tiny placeholders, and each batch of training_data is assumed to be an (input_images, target_images) pair of tensors scaled to [-1, 1]:
import torch
import torch.nn as nn

# Generator
class Generator(nn.Module):
    """Takes an input-domain image and produces a generated target-domain image."""
    def __init__(self, channels=3):
        super().__init__()
        # Tiny placeholder network; real generators (e.g. a U-Net) are much deeper
        self.net = nn.Sequential(
            nn.Conv2d(channels, 64, kernel_size=3, padding=1), nn.ReLU(),
            nn.Conv2d(64, channels, kernel_size=3, padding=1), nn.Tanh(),  # output in [-1, 1]
        )

    def forward(self, input_image):
        return self.net(input_image)

# Discriminator
class Discriminator(nn.Module):
    """Sees the input image together with a real or generated target image
    and predicts the probability that the target image is real."""
    def __init__(self, channels=3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(2 * channels, 64, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2),
            nn.Conv2d(64, 1, kernel_size=4, stride=2, padding=1),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Sigmoid(),  # shape (batch, 1)
        )

    def forward(self, input_image, target_image):
        # Conditioning: stack input and candidate target along the channel axis
        return self.net(torch.cat([input_image, target_image], dim=1))

# GAN loss: binary cross-entropy, i.e. the mean negative log-likelihood
def binary_cross_entropy(predictions, labels, eps=1e-7):
    p = predictions.clamp(eps, 1 - eps)  # keep log() away from 0
    return -(labels * torch.log(p) + (1 - labels) * torch.log(1 - p)).mean()

# GAN Training
def train_gan(training_data, num_epochs, lr=2e-4):
    generator, discriminator = Generator(), Discriminator()
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
    for epoch in range(num_epochs):
        for input_images, target_images in training_data:  # paired batches
            real_labels = torch.ones(input_images.size(0), 1)
            generated_labels = torch.zeros(input_images.size(0), 1)
            generated_images = generator(input_images)

            # Update the discriminator (detach so no gradient flows into the generator)
            real_loss = binary_cross_entropy(discriminator(input_images, target_images), real_labels)
            generated_loss = binary_cross_entropy(
                discriminator(input_images, generated_images.detach()), generated_labels)
            discriminator_loss = real_loss + generated_loss
            optimizer_D.zero_grad()
            discriminator_loss.backward()
            optimizer_D.step()

            # Update the generator: it wants its output labeled as real
            generator_loss = binary_cross_entropy(
                discriminator(input_images, generated_images), real_labels)
            optimizer_G.zero_grad()
            generator_loss.backward()
            optimizer_G.step()

        # Print or log the losses after each epoch
        print(f"Epoch: {epoch} Discriminator Loss: {discriminator_loss.item():.4f} "
              f"Generator Loss: {generator_loss.item():.4f}")
    return generator, discriminator
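As a sanity check on the binary cross-entropy loss, consider a discriminator that cannot tell real from generated images and outputs 0.5 for everything. Each loss term is then ln 2 ≈ 0.693, so the discriminator loss (real + generated) is 2 ln 2 ≈ 1.386. The pure-Python sketch below reproduces those numbers and contrasts them with a confident discriminator; the helper name `bce` is illustrative, not a library function.

```python
import math

def bce(predictions, labels, eps=1e-7):
    """Mean binary cross-entropy over paired lists of probabilities and 0/1 labels."""
    total = 0.0
    for p, y in zip(predictions, labels):
        p = min(max(p, eps), 1 - eps)  # clamp to avoid log(0)
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(predictions)

# A maximally unsure discriminator outputs 0.5 for every image
real_loss = bce([0.5, 0.5], [1, 1])
generated_loss = bce([0.5, 0.5], [0, 0])
discriminator_loss = real_loss + generated_loss
print(round(real_loss, 3), round(discriminator_loss, 3))  # 0.693 1.386

# A confident, correct discriminator drives its own loss toward 0 ...
confident_d_loss = bce([0.99], [1]) + bce([0.01], [0])
# ... while the generator's loss on those same fakes becomes large
generator_loss = bce([0.01], [1])
print(round(confident_d_loss, 3), round(generator_loss, 3))  # 0.02 4.605
```

This is why a discriminator that wins too easily can stall training: the generator's loss is large, but the signal it receives says little about how to improve.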
Pix2Pix (Paired Images)
Let's imagine you are an artist who wants to turn sketches into colorful paintings.
Pix2Pix is like your art helper! You draw a simple black and white sketch, and Pix2Pix works its magic to add vibrant colors to your sketch, creating a colorful painting. It's like having a super cool painting partner!
But here's the interesting part: Pix2Pix needs to be trained first. It looks at many pairs of sketches and their colorful paintings to learn how to add the right colors. Once it's trained, it becomes an amazing artist's assistant, helping you create beautiful paintings from your sketches.
So, Pix2Pix is like a fantastic art buddy that takes your sketches and turns them into colorful masterpieces, making your art journey even more fun and exciting!
Importantly, Pix2Pix performs paired image-to-image translation: every input image in the training set comes with a corresponding, aligned ground-truth target image. Paired training samples can be difficult to obtain, but when they are available this type of translation often leads to great results.
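A common convention for paired datasets is to store each input and its target side by side in a single image file, so the data loader only has to cut every image in half to recover an aligned pair. The NumPy sketch below assumes that side-by-side layout (an assumption about the storage format, not something this chapter specifies); `split_pair` is a hypothetical helper.

```python
import numpy as np

def split_pair(combined):
    """Split an H x 2W x C array into its left (A) and right (B) halves."""
    _, width, _ = combined.shape
    half = width // 2
    return combined[:, :half], combined[:, half:]

# Toy 4 x 8 RGB "side-by-side" image: left half all 0s, right half all 255s
combined = np.zeros((4, 8, 3), dtype=np.uint8)
combined[:, 4:] = 255
img_A, img_B = split_pair(combined)
print(img_A.shape, img_B.shape)  # (4, 4, 3) (4, 4, 3)
```

Because both halves come from the same file, the pairing can never drift out of sync, which is exactly the property paired translation relies on.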
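Example of paired images: [figure not reproduced]
The following PyTorch script trains Pix2Pix. It relies on local models.py and datasets.py modules (not shown in this chapter), which supply GeneratorUNet, Discriminator, weights_init_normal, and ImageDataset.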
import argparse
import datetime
import os
import sys
import time
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
# Local helper modules (not shown); together they provide GeneratorUNet, Discriminator,
# weights_init_normal, and ImageDataset
from models import *
from datasets import *

# Command-line options used below (default values are illustrative)
parser = argparse.ArgumentParser()
parser.add_argument("--epoch", type=int, default=0, help="epoch to start training from")
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--dataset_name", type=str, default="facades", help="name of the dataset")
parser.add_argument("--batch_size", type=int, default=1, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: beta1")
parser.add_argument("--b2", type=float, default=0.999, help="adam: beta2")
parser.add_argument("--n_cpu", type=int, default=8, help="number of data-loading workers")
parser.add_argument("--img_height", type=int, default=256, help="image height")
parser.add_argument("--img_width", type=int, default=256, help="image width")
parser.add_argument("--sample_interval", type=int, default=500, help="batches between saved samples")
parser.add_argument("--checkpoint_interval", type=int, default=-1, help="epochs between checkpoints (-1 disables)")
opt = parser.parse_args()
print(opt)

os.makedirs("images/%s" % opt.dataset_name, exist_ok=True)
os.makedirs("saved_models/%s" % opt.dataset_name, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Loss functions
criterion_GAN = torch.nn.MSELoss()  # least-squares adversarial loss
criterion_pixelwise = torch.nn.L1Loss()

# Loss weight of L1 pixel-wise loss between translated image and real image
lambda_pixel = 100

# Calculate output of image discriminator (PatchGAN): the image is downsampled by 2**4
patch = (1, opt.img_height // 2 ** 4, opt.img_width // 2 ** 4)

# Initialize generator and discriminator
generator = GeneratorUNet().to(device)
discriminator = Discriminator().to(device)

if opt.epoch != 0:
    # Load pretrained models
    generator.load_state_dict(torch.load("saved_models/%s/generator_%d.pth" % (opt.dataset_name, opt.epoch), map_location=device))
    discriminator.load_state_dict(torch.load("saved_models/%s/discriminator_%d.pth" % (opt.dataset_name, opt.epoch), map_location=device))
else:
    # Initialize weights
    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

# Configure dataloaders
transforms_ = [
    transforms.Resize((opt.img_height, opt.img_width), transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # scale pixels to [-1, 1]
]
dataloader = DataLoader(
    ImageDataset("../../data/%s" % opt.dataset_name, transforms_=transforms_),
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=opt.n_cpu,
)
val_dataloader = DataLoader(
    ImageDataset("../../data/%s" % opt.dataset_name, transforms_=transforms_, mode="val"),
    batch_size=10,
    shuffle=True,
    num_workers=1,
)

def sample_images(batches_done):
    """Saves a generated sample from the validation set"""
    imgs = next(iter(val_dataloader))
    # The script translates B -> A: the "B" half is the input, the "A" half the target
    real_A = imgs["B"].to(device)
    real_B = imgs["A"].to(device)
    with torch.no_grad():
        fake_B = generator(real_A)
    # Stack input, translation, and ground truth vertically (dim -2 is height)
    img_sample = torch.cat((real_A, fake_B, real_B), -2)
    save_image(img_sample, "images/%s/%s.png" % (opt.dataset_name, batches_done), nrow=5, normalize=True)

# ----------
#  Training
# ----------
prev_time = time.time()
for epoch in range(opt.epoch, opt.n_epochs):
    for i, batch in enumerate(dataloader):
        # Model inputs (B -> A, as in sample_images)
        real_A = batch["B"].to(device)
        real_B = batch["A"].to(device)

        # Adversarial ground truths, one value per PatchGAN output patch
        valid = torch.ones((real_A.size(0), *patch), device=device)
        fake = torch.zeros((real_A.size(0), *patch), device=device)

        # ------------------
        #  Train Generators
        # ------------------
        optimizer_G.zero_grad()
        # GAN loss: the discriminator is conditioned on the input image
        fake_B = generator(real_A)
        pred_fake = discriminator(fake_B, real_A)
        loss_GAN = criterion_GAN(pred_fake, valid)
        # Pixel-wise loss
        loss_pixel = criterion_pixelwise(fake_B, real_B)
        # Total loss
        loss_G = loss_GAN + lambda_pixel * loss_pixel
        loss_G.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()
        # Real loss
        pred_real = discriminator(real_B, real_A)
        loss_real = criterion_GAN(pred_real, valid)
        # Fake loss (detach so the generator is not updated here)
        pred_fake = discriminator(fake_B.detach(), real_A)
        loss_fake = criterion_GAN(pred_fake, fake)
        # Total loss
        loss_D = 0.5 * (loss_real + loss_fake)
        loss_D.backward()
        optimizer_D.step()

        # --------------
        #  Log Progress
        # --------------
        # Determine approximate time left from the duration of the last batch
        batches_done = epoch * len(dataloader) + i
        batches_left = opt.n_epochs * len(dataloader) - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()
        # Print log
        sys.stdout.write(
            "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, pixel: %f, adv: %f] ETA: %s"
            % (
                epoch,
                opt.n_epochs,
                i,
                len(dataloader),
                loss_D.item(),
                loss_G.item(),
                loss_pixel.item(),
                loss_GAN.item(),
                time_left,
            )
        )
        # If at sample interval save image
        if batches_done % opt.sample_interval == 0:
            sample_images(batches_done)

    if opt.checkpoint_interval != -1 and epoch % opt.checkpoint_interval == 0:
        # Save model checkpoints
        torch.save(generator.state_dict(), "saved_models/%s/generator_%d.pth" % (opt.dataset_name, epoch))
        torch.save(discriminator.state_dict(), "saved_models/%s/discriminator_%d.pth" % (opt.dataset_name, epoch))
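Two quantities in the Pix2Pix script are worth working through by hand. For 256 x 256 images, the PatchGAN output shape is `patch = (1, 256 // 2**4, 256 // 2**4) = (1, 16, 16)`: the discriminator returns a 16 x 16 grid of scores, one per image region, rather than a single real/fake verdict. The generator loss is `loss_GAN + lambda_pixel * loss_pixel` with `lambda_pixel = 100`. The NumPy sketch below recomputes both; the discriminator scores and images are made-up stand-ins, and only the patch formula, the MSE/L1 choices, and `lambda_pixel` come from the script.

```python
import numpy as np

img_height, img_width = 256, 256
patch = (1, img_height // 2 ** 4, img_width // 2 ** 4)
print(patch)  # (1, 16, 16)

lambda_pixel = 100
rng = np.random.default_rng(0)

# Made-up stand-ins: the discriminator scores every patch 0.6, and the
# generated image is off from its target by 0.05 at every pixel
pred_fake = np.full((1, *patch), 0.6)
valid = np.ones_like(pred_fake)
real_B = rng.uniform(-1, 1, size=(1, 3, img_height, img_width))
fake_B = real_B + 0.05

loss_GAN = np.mean((pred_fake - valid) ** 2)   # MSE, like criterion_GAN
loss_pixel = np.mean(np.abs(fake_B - real_B))  # L1, like criterion_pixelwise
loss_G = loss_GAN + lambda_pixel * loss_pixel
print(round(loss_GAN, 4), round(loss_pixel, 4), round(loss_G, 4))  # 0.16 0.05 5.16
```

With these stand-in values, a 0.05 average pixel error contributes 5.0 to the total while the adversarial term contributes 0.16, which shows how strongly `lambda_pixel = 100` pulls the generator toward the ground-truth image.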
CycleGAN (Unpaired Images)
Imagine you have two boxes of crayons, one with colors of the ocean, and the other with colors of the forest. Now, you want to turn a picture of the ocean into a picture that looks like it was drawn in the forest crayon style, and vice versa.
CycleGAN helps you do that! It's like having a magical crayon changer. You give it a picture of the ocean, and it uses its magic to change the colors to look like the forest crayon style. Then, it also works the other way around, turning a forest picture into an ocean crayon-style picture.
But wait, there's something special! The CycleGAN also checks if the changed picture of the ocean can be turned back into the original ocean picture using another magical crayon. If it works, then the CycleGAN knows it did a good job!
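The "turn it back" check is CycleGAN's cycle-consistency loss. With G_AB translating domain A to B and G_BA translating B back to A, it is the L1 distance between an image and its round-trip reconstruction, averaged over both directions (A -> B -> A and B -> A -> B). The NumPy sketch below uses hypothetical stand-in "generators" that are plain functions rather than neural networks, purely to show how the loss behaves when G_BA does or does not undo G_AB.

```python
import numpy as np

def l1(a, b):
    """Mean absolute error, the same quantity torch.nn.L1Loss computes by default."""
    return float(np.mean(np.abs(a - b)))

def cycle_loss(G_AB, G_BA, real_A, real_B):
    """Average of the A -> B -> A and B -> A -> B reconstruction errors."""
    loss_cycle_A = l1(G_BA(G_AB(real_A)), real_A)
    loss_cycle_B = l1(G_AB(G_BA(real_B)), real_B)
    return (loss_cycle_A + loss_cycle_B) / 2

rng = np.random.default_rng(0)
real_A = rng.uniform(-1, 1, size=(2, 3, 4, 4))  # toy image batches in [-1, 1]
real_B = rng.uniform(-1, 1, size=(2, 3, 4, 4))

# Stand-in translators where G_BA exactly inverts G_AB: every cycle is lossless
G_AB = lambda x: 0.5 * x + 0.1
G_BA = lambda y: (y - 0.1) / 0.5
print(cycle_loss(G_AB, G_BA, real_A, real_B))  # ~0, up to floating-point error

# A G_BA that ignores its input cannot reconstruct anything, so the loss is large
G_BA_bad = lambda y: np.zeros_like(y)
print(cycle_loss(G_AB, G_BA_bad, real_A, real_B))
```

The adversarial losses alone only ask each output to look like its target domain; the cycle term additionally forces the translation to keep enough information about the input to undo it.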
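So, in the end, CycleGAN can turn ocean pictures into forest-style pictures and then turn them back into ocean pictures, making a cool transformation between the two styles with its magical crayon changer. It's like having a fun art adventure with pictures!
Unlike Pix2Pix, CycleGAN does not need paired training examples: it learns from two separate collections of images, one per domain, with no one-to-one correspondence between them. The following PyTorch script trains CycleGAN. It relies on local models.py, datasets.py, and utils.py modules (not shown in this chapter), which supply GeneratorResNet, Discriminator, weights_init_normal, ImageDataset, ReplayBuffer, and LambdaLR.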
import argparse
import datetime
import itertools
import os
import sys
import time
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image, make_grid
# Local helper modules (not shown); together they provide GeneratorResNet, Discriminator,
# weights_init_normal, ImageDataset, ReplayBuffer, and LambdaLR
from models import *
from datasets import *
from utils import *

# Command-line options used below (default values are illustrative)
parser = argparse.ArgumentParser()
parser.add_argument("--epoch", type=int, default=0, help="epoch to start training from")
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--dataset_name", type=str, default="monet2photo", help="name of the dataset")
parser.add_argument("--batch_size", type=int, default=1, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: beta1")
parser.add_argument("--b2", type=float, default=0.999, help="adam: beta2")
parser.add_argument("--decay_epoch", type=int, default=100, help="epoch from which to start lr decay")
parser.add_argument("--n_cpu", type=int, default=8, help="number of data-loading workers")
parser.add_argument("--img_height", type=int, default=256, help="image height")
parser.add_argument("--img_width", type=int, default=256, help="image width")
parser.add_argument("--channels", type=int, default=3, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=100, help="batches between saved samples")
parser.add_argument("--checkpoint_interval", type=int, default=-1, help="epochs between checkpoints (-1 disables)")
parser.add_argument("--n_residual_blocks", type=int, default=9, help="residual blocks in each generator")
parser.add_argument("--lambda_cyc", type=float, default=10.0, help="cycle loss weight")
parser.add_argument("--lambda_id", type=float, default=5.0, help="identity loss weight")
opt = parser.parse_args()
print(opt)

# Create sample and checkpoint directories
os.makedirs("images/%s" % opt.dataset_name, exist_ok=True)
os.makedirs("saved_models/%s" % opt.dataset_name, exist_ok=True)

# Losses
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
input_shape = (opt.channels, opt.img_height, opt.img_width)

# Initialize generators (A -> B and B -> A) and one discriminator per domain
G_AB = GeneratorResNet(input_shape, opt.n_residual_blocks).to(device)
G_BA = GeneratorResNet(input_shape, opt.n_residual_blocks).to(device)
D_A = Discriminator(input_shape).to(device)
D_B = Discriminator(input_shape).to(device)

if opt.epoch != 0:
    # Load pretrained models
    G_AB.load_state_dict(torch.load("saved_models/%s/G_AB_%d.pth" % (opt.dataset_name, opt.epoch), map_location=device))
    G_BA.load_state_dict(torch.load("saved_models/%s/G_BA_%d.pth" % (opt.dataset_name, opt.epoch), map_location=device))
    D_A.load_state_dict(torch.load("saved_models/%s/D_A_%d.pth" % (opt.dataset_name, opt.epoch), map_location=device))
    D_B.load_state_dict(torch.load("saved_models/%s/D_B_%d.pth" % (opt.dataset_name, opt.epoch), map_location=device))
else:
    # Initialize weights
    G_AB.apply(weights_init_normal)
    G_BA.apply(weights_init_normal)
    D_A.apply(weights_init_normal)
    D_B.apply(weights_init_normal)

# Optimizers: one shared by both generators, one per discriminator
optimizer_G = torch.optim.Adam(
    itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=opt.lr, betas=(opt.b1, opt.b2)
)
optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

# Learning rate update schedulers
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
    optimizer_G, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_A, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_B, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)

# Buffers of previously generated samples
fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

# Image transformations
transforms_ = [
    transforms.Resize(int(opt.img_height * 1.12), transforms.InterpolationMode.BICUBIC),
    transforms.RandomCrop((opt.img_height, opt.img_width)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

# Training data loader (unaligned=True requests unpaired A/B samples)
dataloader = DataLoader(
    ImageDataset("../../data/%s" % opt.dataset_name, transforms_=transforms_, unaligned=True),
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=opt.n_cpu,
)
# Test data loader
val_dataloader = DataLoader(
    ImageDataset("../../data/%s" % opt.dataset_name, transforms_=transforms_, unaligned=True, mode="test"),
    batch_size=5,
    shuffle=True,
    num_workers=1,
)

def sample_images(batches_done):
    """Saves a generated sample from the test set"""
    imgs = next(iter(val_dataloader))
    G_AB.eval()
    G_BA.eval()
    real_A = imgs["A"].to(device)
    real_B = imgs["B"].to(device)
    with torch.no_grad():
        fake_B = G_AB(real_A)
        fake_A = G_BA(real_B)
    # Arrange images along x-axis
    real_A = make_grid(real_A, nrow=5, normalize=True)
    real_B = make_grid(real_B, nrow=5, normalize=True)
    fake_A = make_grid(fake_A, nrow=5, normalize=True)
    fake_B = make_grid(fake_B, nrow=5, normalize=True)
    # Arrange images along y-axis
    image_grid = torch.cat((real_A, fake_B, real_B, fake_A), 1)
    save_image(image_grid, "images/%s/%s.png" % (opt.dataset_name, batches_done), normalize=False)

# ----------
#  Training
# ----------
prev_time = time.time()
for epoch in range(opt.epoch, opt.n_epochs):
    for i, batch in enumerate(dataloader):
        # Set model input
        real_A = batch["A"].to(device)
        real_B = batch["B"].to(device)

        # Adversarial ground truths, shaped like the discriminator output
        valid = torch.ones((real_A.size(0), *D_A.output_shape), device=device)
        fake = torch.zeros((real_A.size(0), *D_A.output_shape), device=device)

        # ------------------
        #  Train Generators
        # ------------------
        G_AB.train()
        G_BA.train()
        optimizer_G.zero_grad()
        # Identity loss: each generator should leave images already in its output domain unchanged
        loss_id_A = criterion_identity(G_BA(real_A), real_A)
        loss_id_B = criterion_identity(G_AB(real_B), real_B)
        loss_identity = (loss_id_A + loss_id_B) / 2
        # GAN loss
        fake_B = G_AB(real_A)
        loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
        fake_A = G_BA(real_B)
        loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)
        loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2
        # Cycle loss: A -> B -> A and B -> A -> B should reconstruct the originals
        recov_A = G_BA(fake_B)
        loss_cycle_A = criterion_cycle(recov_A, real_A)
        recov_B = G_AB(fake_A)
        loss_cycle_B = criterion_cycle(recov_B, real_B)
        loss_cycle = (loss_cycle_A + loss_cycle_B) / 2
        # Total loss
        loss_G = loss_GAN + opt.lambda_cyc * loss_cycle + opt.lambda_id * loss_identity
        loss_G.backward()
        optimizer_G.step()

        # -----------------------
        #  Train Discriminator A
        # -----------------------
        optimizer_D_A.zero_grad()
        # Real loss
        loss_real = criterion_GAN(D_A(real_A), valid)
        # Fake loss (on batch of previously generated samples)
        fake_A_ = fake_A_buffer.push_and_pop(fake_A)
        loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
        # Total loss
        loss_D_A = (loss_real + loss_fake) / 2
        loss_D_A.backward()
        optimizer_D_A.step()

        # -----------------------
        #  Train Discriminator B
        # -----------------------
        optimizer_D_B.zero_grad()
        # Real loss
        loss_real = criterion_GAN(D_B(real_B), valid)
        # Fake loss (on batch of previously generated samples)
        fake_B_ = fake_B_buffer.push_and_pop(fake_B)
        loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
        # Total loss
        loss_D_B = (loss_real + loss_fake) / 2
        loss_D_B.backward()
        optimizer_D_B.step()

        loss_D = (loss_D_A + loss_D_B) / 2

        # --------------
        #  Log Progress
        # --------------
        # Determine approximate time left from the duration of the last batch
        batches_done = epoch * len(dataloader) + i
        batches_left = opt.n_epochs * len(dataloader) - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()
        # Print log
        sys.stdout.write(
            "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, cycle: %f, identity: %f] ETA: %s"
            % (
                epoch,
                opt.n_epochs,
                i,
                len(dataloader),
                loss_D.item(),
                loss_G.item(),
                loss_GAN.item(),
                loss_cycle.item(),
                loss_identity.item(),
                time_left,
            )
        )
        # If at sample interval save image
        if batches_done % opt.sample_interval == 0:
            sample_images(batches_done)

    # Update learning rates once per epoch
    lr_scheduler_G.step()
    lr_scheduler_D_A.step()
    lr_scheduler_D_B.step()

    if opt.checkpoint_interval != -1 and epoch % opt.checkpoint_interval == 0:
        # Save model checkpoints
        torch.save(G_AB.state_dict(), "saved_models/%s/G_AB_%d.pth" % (opt.dataset_name, epoch))
        torch.save(G_BA.state_dict(), "saved_models/%s/G_BA_%d.pth" % (opt.dataset_name, epoch))
        torch.save(D_A.state_dict(), "saved_models/%s/D_A_%d.pth" % (opt.dataset_name, epoch))
        torch.save(D_B.state_dict(), "saved_models/%s/D_B_%d.pth" % (opt.dataset_name, epoch))
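The CycleGAN script uses `ReplayBuffer` and `LambdaLR` from a local `utils.py` that this chapter does not show. The sketch below is one plausible implementation of each, assuming the behavior their names and call sites suggest: `ReplayBuffer` as the image history buffer described in the CycleGAN paper (the discriminators see a pool of up to 50 earlier fakes, and each returned sample has a 50% chance of being an older one), and `LambdaLR` as a learning-rate multiplier that stays at 1.0 until `decay_start_epoch` and then falls linearly to 0 at `n_epochs`. It works on plain Python lists rather than tensors to stay self-contained; treat both classes as hypothetical stand-ins, not the actual `utils.py`.

```python
import random

class ReplayBuffer:
    """Hypothetical sketch of a pool of previously generated samples."""

    def __init__(self, max_size=50, seed=None):
        self.max_size = max_size
        self.data = []
        self.rng = random.Random(seed)

    def push_and_pop(self, batch):
        """Store each new sample; return, per sample, either it or an older stored one."""
        to_return = []
        for element in batch:
            if len(self.data) < self.max_size:
                # Pool not full yet: keep the new sample and return it unchanged
                self.data.append(element)
                to_return.append(element)
            elif self.rng.random() > 0.5:
                # Return a random older sample and store the new one in its place
                i = self.rng.randrange(self.max_size)
                to_return.append(self.data[i])
                self.data[i] = element
            else:
                # Return the new sample; the pool is left unchanged
                to_return.append(element)
        return to_return

class LambdaLR:
    """Hypothetical sketch: multiplier of 1.0 until decay_start_epoch, then linear decay to 0."""

    def __init__(self, n_epochs, offset, decay_start_epoch):
        assert n_epochs > decay_start_epoch, "decay must start before training ends"
        self.n_epochs = n_epochs
        self.offset = offset  # epoch that training (re)starts from
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        decayed = max(0, epoch + self.offset - self.decay_start_epoch)
        return 1.0 - decayed / (self.n_epochs - self.decay_start_epoch)

schedule = LambdaLR(n_epochs=200, offset=0, decay_start_epoch=100)
print([schedule.step(e) for e in (0, 99, 100, 150, 200)])  # [1.0, 1.0, 1.0, 0.5, 0.0]

buffer = ReplayBuffer(max_size=2, seed=0)
print(buffer.push_and_pop(["fake1", "fake2"]))  # ['fake1', 'fake2'] while the pool fills
print(buffer.push_and_pop(["fake3"]))  # 'fake3' or one of the stored fakes
```

Feeding the discriminators a mix of current and older fakes is meant to keep them from overfitting to the generators' latest outputs, which tends to make adversarial training less oscillatory.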