Building a Deep Learning model with PyTorch to classify fruits and vegetables

Original article was published on Deep Learning on Medium


Good, we got 2 right and 1 wrong. The wrong prediction labeled a corn husk as corn, which isn’t much of a concern and should be resolved with a bit more training data.

Testing on the entire test data

Identifying where our model performs poorly can help us improve it by collecting more training data, increasing or decreasing the complexity of the model, or changing the hyperparameters.
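One way to find the weak spots is to break accuracy down by class. The snippet below is a minimal sketch, assuming you have already collected the true and predicted class indices for the test set as two plain lists. The function name per_class_accuracy and the toy labels are hypothetical and are not part of the original project.

```python
from collections import defaultdict

def per_class_accuracy(true_labels, predicted_labels, class_names):
    """Return {class_name: accuracy} for every class present in true_labels."""
    correct = defaultdict(int)
    total = defaultdict(int)
    for truth, pred in zip(true_labels, predicted_labels):
        total[truth] += 1
        if truth == pred:
            correct[truth] += 1
    return {class_names[c]: correct[c] / total[c] for c in total}

# Hypothetical toy data: 3 classes, 6 test images
class_names = ["Apple", "Corn", "Corn Husk"]
true_labels = [0, 0, 1, 1, 2, 2]
predicted_labels = [0, 0, 1, 1, 1, 2]  # one corn husk mistaken for corn

accuracies = per_class_accuracy(true_labels, predicted_labels, class_names)

# List the classes from worst to best so the problem cases come first
worst_first = sorted(accuracies.items(), key=lambda item: item[1])
for name, acc in worst_first:
    print(f"{name}: {acc:.2f}")
```

Classes at the top of this list are the natural candidates for more training images.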

As a final step, let’s also look at the overall loss and accuracy of the model on the test set. We expect these values to be similar to those for the validation set. If not, we might need a better validation set that has similar data and distribution as the test set.
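The evaluation step can be sketched roughly as follows. This is not the project's own evaluation code: the evaluate function, the tiny linear model, and the random tensors are stand-ins chosen so the example runs on its own. In the real project you would pass the trained model and a DataLoader built from the test images.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

@torch.no_grad()
def evaluate(model, loader):
    """Return (mean loss, accuracy) computed over every sample in loader."""
    model.eval()
    total_loss, total_correct, total_seen = 0.0, 0, 0
    for images, labels in loader:
        outputs = model(images)
        # Sum the per-sample losses so the final mean is not skewed
        # by a smaller last batch
        total_loss += F.cross_entropy(outputs, labels, reduction="sum").item()
        total_correct += (outputs.argmax(dim=1) == labels).sum().item()
        total_seen += labels.size(0)
    return total_loss / total_seen, total_correct / total_seen

# Stand-in data and model (assumptions, for illustration only)
torch.manual_seed(0)
images = torch.randn(20, 3, 8, 8)
labels = torch.randint(0, 4, (20,))
test_loader = DataLoader(TensorDataset(images, labels), batch_size=8)
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 4))

test_loss, test_acc = evaluate(model, test_loader)
print(f"test_loss: {test_loss:.4f}, test_acc: {test_acc:.4f}")
```

The numbers printed here come from an untrained toy model on random data, so they only demonstrate the mechanics, not the results reported in this article.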

As we can see, we got only 91% accuracy on the test data, compared with 98% on the validation data. This might be because the validation set does not represent the distribution of images in the test set well. We should be able to improve the test accuracy by adding more data to the training set.

Saving and loading the model

Since we’ve trained our model for a long time and achieved a reasonable accuracy, it would be a good idea to save the weights of the model to disk, so that we can reuse the model later and avoid retraining from scratch. Here’s how you can save the model.

torch.save(model.state_dict(), 'fruits360-cnn.pth')

The .state_dict method returns an OrderedDict that maps each parameter name in the model to its weight or bias tensor. torch.save then writes this dictionary to a file called ‘fruits360-cnn.pth’.
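To get a feel for what is stored, you can print the keys and tensor shapes of a state dict. The sketch below uses a small hypothetical network rather than the Fruits360CnnModel from this article, so the names and shapes will differ from the real model.

```python
import torch.nn as nn

# Hypothetical small CNN, used only to illustrate the state dict layout
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 4 * 4, 5),
)

state = model.state_dict()

# Each entry maps a parameter name to its tensor
for name, tensor in state.items():
    print(name, tuple(tensor.shape))
```

Layers without learnable parameters, such as ReLU and Flatten, do not appear in the dictionary.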

To load the model weights, we can redefine the model with the same structure, and use the .load_state_dict method.

model2 = to_device(Fruits360CnnModel(), device)
model2.load_state_dict(torch.load('fruits360-cnn.pth'))

Just as a sanity check, let’s verify that this model has the same loss and accuracy on the test set as before.
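A quick way to convince yourself that saving and loading round-trips correctly is to compare the outputs of the original and reloaded models on the same inputs. The following is a minimal sketch under stated assumptions: make_model is a hypothetical stand-in for Fruits360CnnModel, the inputs are random tensors, and the weights are written to a temporary directory instead of ‘fruits360-cnn.pth’.

```python
import os
import tempfile

import torch
import torch.nn as nn

def make_model():
    # Hypothetical stand-in for Fruits360CnnModel
    return nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 4))

torch.manual_seed(0)
model = make_model()
inputs = torch.randn(5, 3, 8, 8)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "weights.pth")
    torch.save(model.state_dict(), path)

    # A fresh instance starts with different random weights...
    model2 = make_model()
    # ...until the saved state dict is loaded into it
    model2.load_state_dict(torch.load(path, map_location="cpu"))

model.eval()
model2.eval()
with torch.no_grad():
    same_outputs = torch.equal(model(inputs), model2(inputs))
print(same_outputs)
```

If the two models produce identical outputs, their loss and accuracy on any dataset will match as well.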

So we’re all set. The re-initialized model gives the same accuracy as our original model.

Conclusion

We have successfully created and trained a deep learning model based on CNNs to classify images of fruits and vegetables. We have also seen the Adam optimization algorithm, which differs from classical stochastic gradient descent but gave us great results. The accuracy could be improved a bit more by using the test data as the validation set during training, although the test set would then no longer provide an independent estimate of performance.

The model can be further challenged and improved by introducing images of fruits and vegetables that are harder to differentiate into the training and test datasets. You can find the full code for this project here if you want to check it out. You can also use the same model architecture to tackle other similar classification problems.