Pokemon Classifier with PyTorch

Original article was published on Deep Learning on Medium


The data looks fairly balanced, barring a few classes. Now let's move on to data transformation and augmentation.

Transformations

1. Resize

Since image sizes vary across the Pokemon images, we resize every image to 400 x 400. Passing a (height, width) tuple forces this exact size, so the original aspect ratio is not preserved.

transforms.Resize((400, 400))

2. Horizontal Flip

Random horizontal flipping is usually a good idea during training: it acts as data augmentation for our model. By default, RandomHorizontalFlip() flips each image left to right with probability 0.5.

torchvision.transforms.RandomHorizontalFlip()

3. Rotation

Pokemon in images are not always in a 'straight' posture: some are jumping (Machop), flying (Farfetchd) or sleeping (Snorlax), so images can reach our model in many different orientations. It is therefore a good idea to apply a random rotation to the training images. RandomRotation(15) rotates each image by an angle drawn uniformly between -15 and +15 degrees.

torchvision.transforms.RandomRotation(15)

4. Normalization

torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

Normalize does the following for each channel:

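image = (image - mean) / std

The mean and std parameters are both 0.5 for every channel in our case. Since ToTensor produces values in [0, 1], this maps each channel to the range [-1, 1]. For example, the minimum value 0 becomes (0 - 0.5) / 0.5 = -1, and the maximum value 1 becomes (1 - 0.5) / 0.5 = 1.

The inputs are now centered around zero within [-1, 1]. Strictly speaking, this does not give the dataset a mean of exactly 0 and a std of exactly 1 (that would require the dataset's actual per-channel statistics), but zero-centered, consistently scaled inputs generally make training more stable and less prone to vanishing or exploding gradients.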

5. Convert to Tensor

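This step is required because PyTorch models operate on tensors. ToTensor converts a PIL image (height x width x channels, values 0 to 255) into a float tensor of shape channels x height x width with values in [0, 1]. It must therefore come before Normalize, which expects a tensor.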

transforms.ToTensor()

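We compose all five transformations listed above and pass the result as the transform parameter to ImageFolder, which loads the data. In the composed pipeline, ToTensor comes before Normalize because Normalize works on tensors, not PIL images.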

from torchvision import transforms

train_tfms = transforms.Compose([
    transforms.Resize((400, 400)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

With the transformations defined, let's load one image to see them in action.

Now that the transforms are ready, let's split the data into training and validation sets in an 85:15 ratio.

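from torch.utils.data import random_split

train_size = int(len(dataset) * 0.85)
val_size = len(dataset) - train_size  # remainder, so the two sizes always add up to len(dataset)
train_ds, val_ds = random_split(dataset, [train_size, val_size])
len(train_ds), len(val_ds)
Out: (5797, 1023)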
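Computing the validation size as the remainder matters because random_split requires the lengths to sum exactly to the dataset size. This sketch uses a stand-in TensorDataset (only its length matters) and shows that, for a hypothetical 6821 images, truncating both 85% and 15% with int() loses one image. The seeded generator is an optional addition, not part of the article, that makes the split reproducible.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# With a hypothetical 6821 images, two separate int() truncations lose one image,
# so random_split would reject [5797, 1023] because it does not sum to 6821
n = 6821
print(int(n * 0.85) + int(n * 0.15))  # 6820

# Stand-in dataset with the article's 6820 items
dataset = TensorDataset(torch.arange(6820))
train_size = int(len(dataset) * 0.85)
val_size = len(dataset) - train_size  # remainder: sizes always sum to len(dataset)
train_ds, val_ds = random_split(dataset, [train_size, val_size],
                                generator=torch.Generator().manual_seed(42))
print(len(train_ds), len(val_ds))  # 5797 1023
```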

Next, we use DataLoader to load the data in batches of 32 images. The training loader reshuffles the data every epoch; the validation loader keeps a fixed order.

from torch.utils.data import DataLoader

trainloader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=2)
testloader = DataLoader(val_ds, batch_size=32, shuffle=False, num_workers=2)
Images loaded in one iteration of the train data loader
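The batching behaviour can be sketched with a small stand-in dataset: 70 random 3 x 32 x 32 "images" (smaller than the article's 400 x 400 to keep it light) with labels in [0, 150). It uses num_workers=0 so it runs anywhere; the article's num_workers=2 loads batches in background processes, which on platforms using the spawn start method (such as Windows) requires the script's entry point to sit under an `if __name__ == "__main__":` guard.

```python
import math

import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in data: 70 random "images" with hypothetical labels in [0, 150)
images = torch.randn(70, 3, 32, 32)
labels = torch.randint(0, 150, (70,))
ds = TensorDataset(images, labels)

loader = DataLoader(ds, batch_size=32, shuffle=True, num_workers=0)
batch_shapes = [tuple(x.shape) for x, _ in loader]
print(batch_shapes)  # [(32, 3, 32, 32), (32, 3, 32, 32), (6, 3, 32, 32)]
print(len(loader))  # 3 batches per epoch

# For the article's 5797 training images: 182 batches, the last one holding 5 images
print(math.ceil(5797 / 32), 5797 % 32)  # 182 5
```

Each batch stacks images along a new leading dimension, so the model receives tensors of shape (batch, 3, height, width), and the final batch is smaller unless drop_last=True is set.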

Training

Our training function has the following signature:

train_model(model, criterion, optimizer, scheduler, n_epochs = 5)
1. Model

We will use a pre-trained ResNet-34 model, replacing its final fully connected layer with a new one that has 150 outputs, one per Pokemon class.

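from torch import nn
from torchvision import models

model_ft = models.resnet34(pretrained=True)  # on torchvision >= 0.13, prefer weights=models.ResNet34_Weights.DEFAULT
num_ftrs = model_ft.fc.in_features  # input size of the original classification layer
model_ft.fc = nn.Linear(num_ftrs, 150)  # new head: one output per Pokemon class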

2. Loss Function

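We will use cross-entropy as our loss function, as it works well for multi-class classification. nn.CrossEntropyLoss takes the model's raw scores (logits) and integer class indices, and applies log-softmax internally.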

criterion = nn.CrossEntropyLoss()
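As a sanity check of what criterion computes, the sketch below compares it with log-softmax followed by negative log-likelihood on random, hypothetical logits for a batch of 4 images over 150 classes. This is also why the model's last layer should output raw logits rather than softmax probabilities.

```python
import torch
import torch.nn.functional as F
from torch import nn

criterion = nn.CrossEntropyLoss()

torch.manual_seed(0)
logits = torch.randn(4, 150)             # raw model outputs: 4 images, 150 classes
targets = torch.tensor([3, 0, 149, 42])  # hypothetical integer class indices

loss = criterion(logits, targets)
# CrossEntropyLoss = log_softmax + negative log-likelihood, averaged over the batch
manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)
print(loss.item(), torch.allclose(loss, manual))  # same scalar, True
```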

3. Optimizer

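We will use stochastic gradient descent (SGD) with momentum. Loosely explained, momentum keeps an exponentially decaying accumulation of past gradients and uses it, instead of the raw gradient, to update the weights; this smooths out noisy updates and often leads to faster convergence.

We set this decay factor (often called beta; the momentum argument in PyTorch) to 0.9, a common value that usually works well with SGD, together with a learning rate of 0.01.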

from torch import optim

optimizer = optim.SGD(model_ft.parameters(), lr=0.01, momentum=0.9)
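The momentum update can be made concrete with a single parameter. This sketch runs optim.SGD with the article's settings on a toy loss (0.5 * p^2, whose gradient is p itself) and replicates the update rule PyTorch documents for momentum without dampening: v = momentum * v + grad, then p = p - lr * v. Note that PyTorch accumulates gradients rather than averaging them with a (1 - beta) weight, as some textbook formulations do.

```python
import torch
from torch import optim

lr, momentum = 0.01, 0.9
p = torch.nn.Parameter(torch.tensor([1.0]))
opt = optim.SGD([p], lr=lr, momentum=momentum)

p_manual, v = 1.0, 0.0  # hand-rolled copy of the same update, in plain Python floats
for step in range(5):
    opt.zero_grad()
    loss = 0.5 * (p ** 2).sum()  # toy loss; its gradient is p
    loss.backward()
    opt.step()

    grad = p_manual
    v = momentum * v + grad       # velocity: decaying sum of past gradients
    p_manual = p_manual - lr * v  # step along the velocity

print(p.item(), p_manual)  # the two values agree up to float32 rounding
```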

4. Scheduler

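The scheduler adjusts the learning rate during training. ReduceLROnPlateau monitors a metric, here the training accuracy (mode='max', since higher is better). If the accuracy does not improve on its best value for more than 3 consecutive epochs (patience=3), the learning rate is multiplied by 0.1 (the default factor). Note that with the default threshold_mode='rel', threshold=0.9 only counts an epoch as an improvement if the accuracy beats the best value so far by more than 90% (new > best x 1.9).

lrscheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=3, threshold=0.9)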
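To see how patience and the relative threshold interact, this sketch steps the scheduler exactly as configured above on a hypothetical sequence of steadily rising training accuracies (the values are made up, not the article's results) and records the learning rate after each epoch.

```python
import math

import torch
from torch import optim

p = torch.nn.Parameter(torch.zeros(1))  # dummy parameter so the optimizer has something to hold
optimizer = optim.SGD([p], lr=0.01, momentum=0.9)
lrscheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=3, threshold=0.9)

# Hypothetical per-epoch training accuracies that keep rising
accuracies = [0.50, 0.70, 0.80, 0.88, 0.92, 0.95]
lrs = []
for acc in accuracies:
    lrscheduler.step(acc)  # pass the monitored metric at the end of each epoch
    lrs.append(optimizer.param_groups[0]['lr'])

# None of the later epochs beats 0.50 * 1.9 = 0.95, so all count as "bad";
# after the 4th bad epoch (more than patience=3) the lr drops from 0.01 to about 0.001
print([round(lr, 6) for lr in lrs])
```

With these settings, even a steady rise in accuracy is treated as a plateau, so the threshold value largely decides when the learning rate is cut.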

5. Epochs

The number of epochs defaults to 5.
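The body of train_model is not shown here (the full version is in the linked notebook). The following is a hypothetical sketch of a loop with the same signature that ties the pieces together: forward pass, cross-entropy loss, SGD step, and a ReduceLROnPlateau step on the epoch's training accuracy. The trainloader and device keyword arguments are assumptions added so the sketch is self-contained, and the toy usage runs on random 10-feature data instead of Pokemon images.

```python
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset


def train_model(model, criterion, optimizer, scheduler, n_epochs=5, trainloader=None, device="cpu"):
    """Hypothetical training loop; trainloader and device are assumed extra parameters."""
    model.to(device)
    history = []  # (mean training loss, training accuracy) per epoch
    for epoch in range(n_epochs):
        model.train()
        running_loss, correct, total = 0.0, 0, 0
        for images, labels in trainloader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)  # raw logits, shape (batch, n_classes)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * labels.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
        epoch_loss, epoch_acc = running_loss / total, correct / total
        scheduler.step(epoch_acc)  # ReduceLROnPlateau(mode='max') monitors training accuracy
        history.append((epoch_loss, epoch_acc))
        print(f"epoch {epoch + 1}: loss {epoch_loss:.4f}, train acc {epoch_acc:.3f}")
    return model, history


# Toy usage on random, hypothetical data: 2 classes decided by the first feature
torch.manual_seed(0)
features = torch.randn(256, 10)
targets = (features[:, 0] > 0).long()
loader = DataLoader(TensorDataset(features, targets), batch_size=32, shuffle=True)
model = nn.Linear(10, 2)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=3, threshold=0.9)
model, history = train_model(model, criterion, optimizer, scheduler, n_epochs=5, trainloader=loader)
```

With the article's setup, the same call would pass model_ft, trainloader and lrscheduler, and device="cuda" when a GPU is available.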

With this model and these hyperparameters, we achieve 95% accuracy in just 5 epochs and less than 10 minutes of training, which is really cool, isn't it?

Training results

Testing and Evaluating the trained model

Now comes the main part, where we evaluate our model on never-seen-before Pokemon images randomly picked from Google.

1. Gengar

2. Farfetchd

3. Persian/Meowth

4. Machop

5. MrMime

Our model performs outstandingly on these 5 never-before-seen images, classifying all of them correctly (100% accuracy).

(Note: in test case no. 3, where the image contains two Pokemon (Persian and Meowth), the model picks only one of them, because it is trained as a single-label classifier that predicts exactly one Pokemon per image.)

I hope you had a great time reading the article.

Full Implementation Notebook: https://jovian.ml/gitrohitjain/pokemon-classification