Fixing Small Photos with Deep Learning

Original article was published by Subhaditya Mukherjee on Deep Learning on Medium


Create the network

Welcome to the deep end. The network itself is simple to write in PyTorch: four convolutional layers, a ReLU after each of the first three, and a special layer called PixelShuffle at the end.

What is PixelShuffle? It is the defining idea of the paper we are considering, so much so that PyTorch ships it as a built-in layer (nn.PixelShuffle). In simple terms, the layer is a shuffler: it takes a tensor with C · r² channels and spatial size H (height) x W (width), where r is the upscale factor, and rearranges it into a tensor with C channels and spatial size rH x rW. This "shuffle" is also known as a sub-pixel convolutional layer. It is useful because all the convolutions run at the low input resolution, and the upscaling itself is just a cheap, parallelizable rearrangement of values.
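To make the rearrangement concrete, here is a minimal pure-Python sketch of the shuffle on nested lists. It is an illustration only, not PyTorch's implementation: the function name pixel_shuffle and the list-based layout are assumptions made for this example, and the index mapping follows the channel ordering documented for nn.PixelShuffle.

```python
def pixel_shuffle(x, r):
    """Rearrange a (C*r*r, H, W) nested list into (C, H*r, W*r).

    Illustrative helper (hypothetical name); no batch dimension.
    """
    in_channels = len(x)
    height, width = len(x[0]), len(x[0][0])
    if in_channels % (r * r) != 0:
        raise ValueError("channel count must be divisible by r*r")
    out_channels = in_channels // (r * r)

    out = [[[0] * (width * r) for _ in range(height * r)]
           for _ in range(out_channels)]
    for c in range(out_channels):
        for h in range(height):
            for w in range(width):
                # Each group of r*r input channels fills one r x r output block.
                for i in range(r):
                    for j in range(r):
                        out[c][h * r + i][w * r + j] = x[c * r * r + i * r + j][h][w]
    return out


# Four channels of size 1 x 2 with r = 2 become one channel of size 2 x 4.
x = [[[1, 2]], [[3, 4]], [[5, 6]], [[7, 8]]]
y = pixel_shuffle(x, 2)
print(y)  # [[[1, 3, 2, 4], [5, 7, 6, 8]]]
```

Note how the values of the four input channels at the same pixel end up next to each other in a 2 x 2 block of the output, which is why no learned parameters are needed for this step.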

We also need to initialize our weights so that the network starts off on a good footing. We use orthogonal initialization here, scaled by the gain recommended for ReLU on the three layers that are followed by one.
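As a rough check of what orthogonal initialization gives us, the sketch below initializes one convolution of the same shape as the third layer and verifies that its flattened filters are mutually orthogonal, each with squared norm equal to gain². The seed and the variable names are choices made for this example only.

```python
import torch
import torch.nn as nn
import torch.nn.init as init

torch.manual_seed(0)  # only to make the example repeatable

conv = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
gain = init.calculate_gain('relu')  # sqrt(2) for ReLU
init.orthogonal_(conv.weight, gain)

# Flatten each of the 32 filters into a row vector of length 64 * 3 * 3.
w = conv.weight.detach().view(32, -1)

# With fewer rows than columns, the rows are orthogonal: W W^T = gain^2 * I.
gram = w @ w.t()
print(torch.allclose(gram, gain ** 2 * torch.eye(32), atol=1e-4))
```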

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

# Main network
class Net(nn.Module):
    def __init__(self, upscale_factor):
        super(Net, self).__init__()
        self.relu = nn.ReLU()
        # Conv2d arguments: in_channels, out_channels, kernel_size, stride, padding.
        # The padding keeps height and width unchanged in every layer.
        self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
        self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
        self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
        # upscale_factor**2 channels are shuffled into a single upscaled channel.
        self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

        self._initialize_weights()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        # No ReLU after the last convolution; its output goes straight to the shuffle.
        x = self.pixel_shuffle(self.conv4(x))
        return x

    def _initialize_weights(self):
        # Layers followed by a ReLU use the ReLU gain; the last layer uses gain 1.
        init.orthogonal_(self.conv1.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv2.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv3.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv4.weight)
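
To sanity-check the architecture without running it, we can trace the shapes by hand. The sketch below uses the standard convolution output-size formula with the kernel, stride and padding values from the network above; the helper names conv2d_out and net_output_shape, and the 32 x 32 input, are made up for this example.

```python
def conv2d_out(size, kernel, stride=1, padding=0):
    """Output length of a convolution along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def net_output_shape(height, width, upscale_factor):
    """Trace (channels, height, width) through the four convs and the shuffle."""
    # (kernel, stride, padding) for conv1..conv4, as in the network above.
    layers = [(5, 1, 2), (3, 1, 1), (3, 1, 1), (3, 1, 1)]
    for kernel, stride, padding in layers:
        height = conv2d_out(height, kernel, stride, padding)
        width = conv2d_out(width, kernel, stride, padding)

    channels = upscale_factor ** 2  # output channels of conv4
    # The shuffle trades r*r channels for an r-times larger image.
    return (channels // upscale_factor ** 2,
            height * upscale_factor,
            width * upscale_factor)


print(net_output_shape(32, 32, 3))  # (1, 96, 96)
```

Every convolution preserves the spatial size, so the only change in resolution comes from the final shuffle: a single-channel 32 x 32 image with an upscale factor of 3 comes out as a single-channel 96 x 96 image.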