AutoTrain: Train a CNN in Just One Line

Source: Deep Learning on Medium

A simple and effective interface to train Deep Learning Models

A couple of days ago, I started working on a project called AutoTrain, which makes training deep neural networks easy and effective. It achieves this by taking advantage of the command line.

AutoTrain is a CLI (command-line interface) designed for fast prototyping and quick model training. It lets you create different versions of your model without changing a single line of code.

Let’s dive in.

Data

AutoTrain is built with Keras and relies heavily on the ImageDataGenerator class, so the data must be arranged strictly in the folder layout it expects, with separate training and validation directories.

I have a dataset with two classes, Rot and Golden, containing images of Rottweilers and Golden Retrievers respectively.
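As a quick sanity check before training, the folder layout can be verified with a few lines of Python. This is only a sketch: it assumes the one-subfolder-per-class convention used by Keras's flow_from_directory (for example data/train/Rot and data/train/Golden, mirrored under data/val), and the helper name find_layout_problems is hypothetical, not part of AutoTrain.

```python
from pathlib import Path


def find_layout_problems(root, splits=("train", "val"), classes=("Rot", "Golden")):
    """Return a list of problems found; an empty list means the layout looks fine.

    Assumed layout (an assumption based on Keras's flow_from_directory convention):
        root/train/<class>/...  and  root/val/<class>/...
    """
    problems = []
    for split in splits:
        for cls in classes:
            class_dir = Path(root) / split / cls
            if not class_dir.is_dir():
                problems.append(f"missing folder: {class_dir}")
            elif not any(p.is_file() for p in class_dir.iterdir()):
                problems.append(f"no files in: {class_dir}")
    return problems


# Example: report anything wrong under ./data before starting a run.
for problem in find_layout_problems("data"):
    print(problem)
```

If nothing is printed, every class folder exists in both splits and contains at least one file.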

Alright, time to clone the AutoTrain repo.

AutoTrain Repository

git clone the repository into your project folder. The AutoTrain GitHub repo holds two boilerplates so far. One is cnn_scratch_cli.py, which is meant to train a CNN from scratch (i.e., without using any pre-trained weights), while cnn_transfer_learning.py should be used when training a CNN using transfer learning.

We will use the transfer learning file to build an image classifier on our dogs dataset. Let’s get familiar with the arguments we need to pass in order to train the network.

Args

Let’s fire up a terminal and enter the arguments:

python3 cnn_transfer_learning.py --train_dir data/train --val_dir data/val --num_classes 2 --dense_neurons 256 --batch_size 2 --learning_rate 0.0001 --epochs 20 --model_name dogsV1
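To see how a script can pick up flags like these, here is a minimal sketch using Python's argparse module. It is not AutoTrain's actual source: the flag names come from the command above, but the types, help texts, and the choice to make every flag required are assumptions.

```python
import argparse


def build_parser():
    """Build a parser for the flags in the command above (a sketch, not AutoTrain's code)."""
    parser = argparse.ArgumentParser(description="Sketch of a transfer-learning training CLI")
    parser.add_argument("--train_dir", required=True, help="folder with training images")
    parser.add_argument("--val_dir", required=True, help="folder with validation images")
    parser.add_argument("--num_classes", type=int, required=True, help="number of output classes")
    parser.add_argument("--dense_neurons", type=int, required=True, help="units in the dense layer")
    parser.add_argument("--batch_size", type=int, required=True)
    parser.add_argument("--learning_rate", type=float, required=True)
    parser.add_argument("--epochs", type=int, required=True)
    parser.add_argument("--model_name", required=True, help="name used for the output folder")
    return parser


# Parse the same values as the command above, passed as an explicit list
# so the example runs without real command-line input.
example_argv = (
    "--train_dir data/train --val_dir data/val --num_classes 2 "
    "--dense_neurons 256 --batch_size 2 --learning_rate 0.0001 "
    "--epochs 20 --model_name dogsV1"
).split()
args = build_parser().parse_args(example_argv)
print(args.learning_rate, args.epochs, args.model_name)
```

Because every hyperparameter arrives as a flag, a new model version only needs a different command, not a code change.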

Outputs

The good thing about using AutoTrain is that once training is done, the learned .h5 weights file, the model accuracy and loss plots, and the model architecture are saved in a separate directory, which is generated from the --model_name argument.

Having everything available in one folder saves the time otherwise spent writing separate code for saving weights, graphs, and so on. This keeps the process organized and efficient.
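The idea of one folder per model version can be sketched in a few lines. This is an illustration under stated assumptions, not AutoTrain's implementation: the helper prepare_output_dir and all file names below are hypothetical placeholders.

```python
from pathlib import Path


def prepare_output_dir(model_name, base_dir="."):
    """Create <base_dir>/<model_name>/ and return paths a training run could write to.

    The file names are hypothetical placeholders, not AutoTrain's actual names.
    """
    out_dir = Path(base_dir) / model_name
    out_dir.mkdir(parents=True, exist_ok=True)
    return {
        "weights": out_dir / f"{model_name}.h5",
        "accuracy_plot": out_dir / "accuracy.png",
        "loss_plot": out_dir / "loss.png",
        "architecture": out_dir / "architecture.json",
    }
```

Calling prepare_output_dir("dogsV1") and later prepare_output_dir("dogsV2") keeps the artifacts of each run apart, so earlier versions are not overwritten.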

Try it yourself. Find the code here. Make sure to star it on GitHub ❤. Happy coding!