Recommendation System implementation with Deep Learning and PyTorch

Original article was published by Rudraa Labs on Deep Learning on Medium

A machine learning algorithm accepts only arrays of numerical values, so we cannot feed the above dataset to it directly. There are many embedding approaches; the most widely used is one-hot encoding, followed by word2vec. In one-hot encoding, each category in a categorical column gets its own column holding 0 or 1. We have more than 1,000 categories, which would make one-hot vectors very wide and sparse, so we created a neural-network-based embedding of the data instead.
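To make the difference concrete, here is a minimal sketch that compares a one-hot encoding with a learned embedding for the same category IDs. The sizes (`n_categories`, `n_factors`) and the example IDs are assumptions chosen for illustration, not values from the dataset.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

n_categories = 1000  # assumed number of distinct categories
n_factors = 8        # assumed embedding size, chosen for illustration

# Three example category IDs (hypothetical values)
ids = torch.tensor([0, 42, 999])

# One-hot: one column per category, a single 1 per row
one_hot = F.one_hot(ids, num_classes=n_categories).float()
print(one_hot.shape)  # 3 rows, 1000 columns

# Embedding: a trainable lookup table with one dense row per category
embedding = nn.Embedding(n_categories, n_factors)
dense = embedding(ids)
print(dense.shape)    # 3 rows, 8 columns

# An embedding lookup is the same as multiplying the one-hot row
# by the embedding weight matrix, just without the wasted zeros.
same = torch.allclose(one_hot @ embedding.weight, dense)
print(same)
```

The embedding keeps one short dense vector per category, and those vectors are learned together with the rest of the network.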


With the PyTorch framework, we created an embedding network whose constructor takes the number of users and the number of movies. The network looks up the movie embedding and the user embedding and concatenates them into a single vector (array). The network has 4 layers: it starts with a dropout layer, followed by 3 fully connected layers with ReLU activation and dropout. Dropout is added to randomly switch off units during training, which reduces overfitting and improves the network's ability to generalize. Finally, an output layer with a sigmoid activation function produces the prediction.

The snippet below shows the network implementation with the PyTorch framework.
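A minimal sketch of such a network, reconstructed from the description above rather than taken from the original source, might look like this. The default values, the `embedding_dropout` parameter, the attribute names, and the rule that hidden layers without a matching dropout rate get no dropout are all assumptions.

```python
import torch
import torch.nn as nn


class EmbeddingNet(nn.Module):
    """Sketch of an embedding-based rating model (not the author's exact code)."""

    def __init__(self, n_users, n_movies, n_factors=50,
                 embedding_dropout=0.02, hidden=(10,), dropouts=(0.2,)):
        super().__init__()
        hidden, dropouts = list(hidden), list(dropouts)
        # Assumption: hidden layers without a dropout rate get no dropout.
        dropouts += [0.0] * (len(hidden) - len(dropouts))

        # One trainable vector per user and per movie
        self.u = nn.Embedding(n_users, n_factors)
        self.m = nn.Embedding(n_movies, n_factors)
        self.drop = nn.Dropout(embedding_dropout)

        # Fully connected layers with ReLU and optional dropout
        layers, n_in = [], 2 * n_factors
        for n_out, rate in zip(hidden, dropouts):
            layers += [nn.Linear(n_in, n_out), nn.ReLU()]
            if rate > 0:
                layers.append(nn.Dropout(rate))
            n_in = n_out
        self.hidden = nn.Sequential(*layers)
        self.fc = nn.Linear(n_in, 1)

    def forward(self, users, movies):
        # Concatenate user and movie embeddings into one feature vector
        x = torch.cat([self.u(users), self.m(movies)], dim=1)
        x = self.hidden(self.drop(x))
        # Sigmoid output lies in (0, 1); ratings are assumed to be
        # scaled to that range (or the output rescaled afterwards).
        return torch.sigmoid(self.fc(x))
```

Calling the model with two equal-length tensors of integer user and movie IDs returns one prediction per pair, with shape `(batch_size, 1)`.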

For example, take a network with hidden layers of 100, 200, and 300 units, plus dropout. The resulting network is shown below.

EmbeddingNet(n, m, n_factors=150, hidden=[100, 200, 300], dropouts=[0.25, 0.5])

Training Loop

Below is the snippet of the training loop, with mean squared error (MSE) loss as the quality metric and Adam as the optimizer. The training parameters are chosen based on my previous experience; you can change them to suit your own needs.
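As a rough illustration of that setup, the sketch below trains a small stand-in model with `nn.MSELoss` and `torch.optim.Adam`. The synthetic data, the `TinyRecommender` class, the batch size, the learning rate, and the epoch count are all hypothetical placeholders, not the author's values.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Synthetic stand-in data (assumption): random IDs, ratings scaled to [0, 1]
n_users, n_movies, n_samples = 50, 80, 2000
users = torch.randint(0, n_users, (n_samples,))
movies = torch.randint(0, n_movies, (n_samples,))
ratings = torch.rand(n_samples, 1)

split = int(0.8 * n_samples)
train_dl = DataLoader(
    TensorDataset(users[:split], movies[:split], ratings[:split]),
    batch_size=256, shuffle=True)
valid_dl = DataLoader(
    TensorDataset(users[split:], movies[split:], ratings[split:]),
    batch_size=256)


class TinyRecommender(nn.Module):
    """Small hypothetical stand-in for the embedding network."""

    def __init__(self, n_users, n_movies, n_factors=16):
        super().__init__()
        self.u = nn.Embedding(n_users, n_factors)
        self.m = nn.Embedding(n_movies, n_factors)
        self.head = nn.Sequential(
            nn.Linear(2 * n_factors, 32), nn.ReLU(),
            nn.Linear(32, 1), nn.Sigmoid())

    def forward(self, users, movies):
        return self.head(torch.cat([self.u(users), self.m(movies)], dim=1))


model = TinyRecommender(n_users, n_movies)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

history = []  # (train MSE, validation MSE) per epoch
for epoch in range(5):
    model.train()  # enables dropout, if the model has any
    train_loss = 0.0
    for u, m, r in train_dl:
        optimizer.zero_grad()
        loss = criterion(model(u, m), r)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * len(r)

    model.eval()  # disables dropout for validation
    valid_loss = 0.0
    with torch.no_grad():
        for u, m, r in valid_dl:
            valid_loss += criterion(model(u, m), r).item() * len(r)

    history.append((train_loss / split, valid_loss / (n_samples - split)))
    print(f"epoch {epoch + 1}: train MSE {history[-1][0]:.4f}, "
          f"valid MSE {history[-1][1]:.4f}")
```

Tracking the training and validation loss per epoch, as `history` does here, is what makes it possible to plot a learning curve afterwards.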

Please follow this link to see the full source code required to prepare the dataset and to train the model.


After 10 to 15 minutes of training, the network reached an accuracy of 0.8853. The above graph shows the learning curve of the network during training and validation over 16 epochs. PyTorch is a powerful framework that can easily scale up to large datasets. There is another dataset available on Kaggle that would be great to try.