Only Numpy: Recommending Optimal Treatment for Depression using Dilated Update Gate RNN (Google Brain + NIPS2017) with Interactive Code

Depressed man, image from Pixabay

The University of Florida’s Biostatistics site is a great place to find health-related data. One of its data sets comes from a depression study conducted by the National Institutes of Health. To quote the website directly: “109 clinically depressed patients were separated into three groups, and each group was given one of two active drugs (Imipramine or Lithium) or no drug at all. For each patient, the data set contains the treatment used, the outcome of the treatment, and several other interesting characteristics…..”

There are five columns in this data set; see below for a description of each column.

Screen shot from this web page

Using this data, we can build a classifier that predicts whether a given treatment will prevent depression from recurring. With that model, we can recommend the treatment that has the highest predicted probability of preventing recurrence.

Finally, for fun, let’s compare several variants of back propagation to see which gives the best results. Besides standard back propagation, the variants we are going to use are:

a. Google Brain’s Gradient Noise
b. Dilated Back Propagation
c. Dilated Back Propagation + Google Brain’s Gradient Noise

If you are not familiar with the differences between them, please read my earlier blog posts on them (see the Reference section).

NOTE: All of the CSV data comes from University of Florida Biostatistics. If you are planning to use the data set, please check their data usage policy. Specifically, I use the CSV data from the ‘Learn By Doing Depression’ data set.

Training Data

As seen above, I removed two columns from the original data set:

a. Hospt: The patient’s hospital, represented by a code for each of the 5 hospitals (1, 2, 3, 5, or 6)

b. Time: Either the time (days) till recurrence, or if no recurrence, the length (days) of the patient’s participation in the study.

I did this to simplify our objective: we just want the model to tell us whether a given treatment will prevent depression from recurring.
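The column removal can be sketched in plain NumPy as follows. Only the names "Hospt" and "Time" come from the description above; the other header names and every value in the toy rows are made-up placeholders, and `drop_columns` is a hypothetical helper rather than the code from the notebook.

```python
import numpy as np

# Hypothetical header: only "Hospt" and "Time" are named in the text above.
header = ["Hospt", "Treat", "Outcome", "Time", "AcuteT", "Age", "Gender"]

# Made-up placeholder rows, NOT real patient records.
rows = np.array([
    [1, 0, 1, 36.0, 211, 33, 1],
    [2, 1, 0, 105.0, 176, 49, 0],
    [5, 2, 1, 74.0, 191, 50, 1],
], dtype=float)

def drop_columns(header, rows, to_drop):
    """Return the header and data with the named columns removed."""
    keep = [i for i, name in enumerate(header) if name not in to_drop]
    return [header[i] for i in keep], rows[:, keep]

kept_header, kept_rows = drop_columns(header, rows, {"Hospt", "Time"})
print(kept_header)
print(kept_rows.shape)
```

Dropping by name instead of by position keeps the step readable if the column order in the CSV ever changes.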

Network Architecture (Mathematical Equation form)

Image from Paper

Red Box → General Equation for Update Gate RNN

As seen above, the Update Gate RNN is very simple to implement, since it is not that different from a vanilla recurrent neural network.

Network Architecture (Graphic Form / OOP Form) / Feed Forward Direction

Front View of the Network
Side view of the Network

Green Box → Where the dilated skip connections of our network exist.
Pinkish? Arrow → Direction of Feed Forward Process.

So the feed forward process is very simple to understand; there is nothing really different from a vanilla RNN. Now let’s take a look at the OOP implementation.

Green Box → Initializing all of the Weights
Blue Box → Standard Update Gate RNN Feed Forward Operation

For the C gate I used the arctan() activation function and for the G gate I used the tanh() activation function.
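As a rough illustration of the feed forward step, here is a minimal sketch of an Update Gate RNN layer using the activations mentioned above (arctan for C, tanh for G). The class name, weight names, bias terms, initialization scale, and the way the `dilation` argument picks the hidden state from several steps back are all assumptions made for this sketch; the notebook may wire things differently.

```python
import numpy as np

class UpdateGateRNN:
    """Sketch of an Update Gate RNN layer (names and init are assumptions)."""

    def __init__(self, input_dim, hidden_dim, seed=0):
        rng = np.random.default_rng(seed)
        self.hidden_dim = hidden_dim
        # Input-to-hidden and hidden-to-hidden weights for C and G.
        self.w_cx = rng.normal(0.0, 0.1, (input_dim, hidden_dim))
        self.w_ch = rng.normal(0.0, 0.1, (hidden_dim, hidden_dim))
        self.w_gx = rng.normal(0.0, 0.1, (input_dim, hidden_dim))
        self.w_gh = rng.normal(0.0, 0.1, (hidden_dim, hidden_dim))
        self.b_c = np.zeros(hidden_dim)
        self.b_g = np.zeros(hidden_dim)

    def step(self, x_t, h_skip):
        # Candidate state C (arctan) and gate G (tanh), as stated in the text.
        c_t = np.arctan(x_t @ self.w_cx + h_skip @ self.w_ch + self.b_c)
        g_t = np.tanh(x_t @ self.w_gx + h_skip @ self.w_gh + self.b_g)
        # The gate blends the earlier state with the new candidate.
        return g_t * h_skip + (1.0 - g_t) * c_t

    def forward(self, xs, dilation=1):
        # xs has shape (time steps, batch, input_dim).
        steps, batch, _ = xs.shape
        # The first `dilation` slots are zero states used before any history exists.
        hs = np.zeros((steps + dilation, batch, self.hidden_dim))
        for t in range(steps):
            # hs[t] holds the state from `dilation` steps before time t.
            hs[t + dilation] = self.step(xs[t], hs[t])
        return hs[dilation:]

rnn = UpdateGateRNN(input_dim=3, hidden_dim=4)
xs = np.ones((5, 2, 3))
states = rnn.forward(xs, dilation=2)
print(states.shape)  # (5, 2, 4)
```

With `dilation=1` this reduces to an ordinary recurrent loop; a larger value makes each step read the state from further back, which is the idea behind the dilated skip connection.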

Case 1: Standard Back Propagation

Front View of the Gradient Flow
Side view of Gradient Flow

Pinkish? Arrow → Standard Direction of Gradient Flow

Case 1 is a model that uses standard back propagation. (In an RNN it is properly called back propagation through time, but I’ll just call it back propagation for simplicity.)

Case 2: Google Brain’s added Gradient Noise

Front View of Gradient Flow
Side view of Gradient Flow

Pinkish? Arrow → Standard Direction of Gradient Flow
Blue Arrow → Added Gradient Noise at each Weight Update

Again, this is not much different from standard back propagation, since we are just adding gradient noise before each weight update.
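A minimal sketch of that idea is shown below, assuming the annealed Gaussian noise schedule from the gradient noise paper, where the variance at step t is eta / (1 + t) ** gamma. The function names, the default `eta=0.01` and `gamma=0.55`, the learning rate, and the toy weights are illustrative choices, not values taken from the notebook.

```python
import numpy as np

def noise_std(step, eta=0.01, gamma=0.55):
    """Standard deviation of the noise at a given training step."""
    return np.sqrt(eta / (1.0 + step) ** gamma)

def noisy_gradient(grad, step, rng, eta=0.01, gamma=0.55):
    """Add zero-mean Gaussian noise that shrinks as training goes on."""
    sigma = noise_std(step, eta, gamma)
    return grad + rng.normal(0.0, sigma, size=grad.shape)

rng = np.random.default_rng(0)
w = np.zeros((3, 2))      # toy weight matrix
grad = np.ones((3, 2))    # toy gradient, constant for illustration
learning_rate = 0.1       # hypothetical value

for step in range(3):
    # The noise is added to the gradient right before the weight update.
    w -= learning_rate * noisy_gradient(grad, step, rng)
```

Because the noise variance decays with the step count, the early updates explore more and the later updates approach plain gradient descent.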

Case 3: Dilated Back Propagation

Dilated Back Propagation Front View
Side view of Dilated Back propagation

Pinkish? Arrow → Standard Direction of Gradient Flow
Black Curved Arrow → Dilated Back Propagation where we pass on some portion of the gradient to the previous layers, which are not directly connected.

If the above diagram is a bit confusing, please read the blog post I made here. Simply put, we are just adding some portion of the gradient to layers that are not directly connected.
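One possible reading of this is sketched below. It assumes that every layer (or unrolled time step) produces a gradient of the same shape, that the extra gradient comes from the layer `skip` positions later, and that a fixed `portion` of it is added. The function name `dilated_gradients` and both parameters are hypothetical; the exact scheme used in the notebook may differ.

```python
import numpy as np

def dilated_gradients(direct_grads, skip=2, portion=0.1):
    """Add a portion of a later layer's gradient to an earlier, unconnected layer.

    direct_grads[k] is the gradient reaching layer k through the standard path.
    """
    total = [g.copy() for g in direct_grads]
    # Walk from the back so later layers are finalized before they are reused.
    for k in range(len(total) - skip - 1, -1, -1):
        total[k] += portion * total[k + skip]
    return total

grads = [np.ones(2) for _ in range(4)]
mixed = dilated_gradients(grads, skip=2, portion=0.5)
for k, g in enumerate(mixed):
    print(k, g)
```

In this toy run the last two layers keep their original gradient, while the first two each receive half of the gradient from the layer two positions later.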

Case 4: Dilated Back Propagation + Google Brain’s added Gradient Noise

Front View of Case 4
Side View of Case 4

Pinkish? Arrow → Standard Direction of Gradient Flow
Blue Arrow → Added Gradient Noise at each Weight Update
Black Curved Arrow → Dilated Back Propagation where we pass on some portion of the gradient to the previous layers, which are not directly connected.

Case 4 simply combines the previous methods: dilated back propagation plus added gradient noise.

Training and Results ( All Cases )

Left Image → Case 1: Cost Over Time
Right Image → Case 1: Performance on Test Set

The learning rate seems a bit high, since the cost wobbles a lot over time. However, the network seems to generalize well, since it did well on the test set.

Left Image → Case 2: Cost Over Time
Right Image → Case 2: Performance on Test Set

Case 2’s cost over time graph doesn’t look much different from Case 1’s, and neither does its performance on the test set.

Left Image → Case 3: Cost Over Time
Right Image → Case 3: Performance on Test Set

Case 3 was a bit interesting. As the cost over time graph shows, the cost value is a bit higher than in the other cases, and I think that is the reason why it didn’t do so well on the test set.

Left Image → Case 4: Cost Over Time
Right Image → Case 4: Performance on Test Set

The final case was the best of all: not only did the model have the smallest cost value at the end of training, but it also did well on the test set.

Interactive Use case

After all of the models finish training, the program asks how many days you have been depressed prior to today, your age, and your gender. The model then recommends the treatment with the highest probability of preventing depression from recurring. Above is an example use case for a 48-year-old female who has been depressed for 90 days.
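The recommendation step can be sketched as scoring each treatment with the trained model and picking the highest score. In the sketch below, `toy_model` is a stand-in with arbitrary placeholder scores (not results from the trained network and not medical guidance), and the function names, argument order, and gender encoding are assumptions.

```python
import numpy as np

# The three treatment groups named in the data set description.
TREATMENTS = ["Imipramine", "Lithium", "No drug"]

def recommend(predict_no_recurrence, days_depressed, age, gender):
    """Score every treatment and return the one with the highest probability."""
    probs = np.array([
        predict_no_recurrence(t, days_depressed, age, gender)
        for t in range(len(TREATMENTS))
    ])
    best = int(np.argmax(probs))
    return TREATMENTS[best], probs

def toy_model(treatment, days_depressed, age, gender):
    # Stand-in for a trained network: arbitrary placeholder scores only.
    return [0.4, 0.7, 0.2][treatment]

# Example inputs from the text: 90 days, age 48; gender=1 is an assumed encoding.
name, probs = recommend(toy_model, days_depressed=90, age=48, gender=1)
print(name, probs)
```

Passing the predictor in as a function keeps the recommendation logic independent of which of the four training cases produced the model.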

Interactive Code

I moved to Google Colab for interactive code! You will need a Google account to view the code. Also, you can’t run read-only scripts in Google Colab, so make a copy in your own playground. Finally, I will never ask for permission to access your files on Google Drive, just FYI. Happy coding!

Please click here to access the interactive code.

Final Words

Thankfully I’m not the type of person who gets depressed a lot. And this model is NEVER intended to be taken seriously; I just wanted to implement a Dilated Update Gate RNN. However, if anyone is suffering from depression, I really hope that you get help from a professional. There is nothing embarrassing about having depression.

If any errors are found, please email me at jae.duk.seo@gmail.com. If you wish to see the list of all of my writing, please view my website here.

Meanwhile, follow me on Twitter here, and visit my website or my YouTube channel for more content. I also did a comparison of Decoupled Neural Network here, if you are interested.

Reference

  1. Collins, J., Sohl-Dickstein, J., & Sussillo, D. (2016). Capacity and trainability in recurrent neural networks. arXiv preprint arXiv:1611.09913.
  2. Learn By Doing — Exploring a Dataset (Depression Data). (n.d.). Retrieved February 14, 2018, from http://bolt.mph.ufl.edu/2012/08/02/learn-by-doing-exploring-a-dataset/
  3. Chang, S., Zhang, Y., Han, W., Yu, M., Guo, X., Tan, W., … & Huang, T. S. (2017). Dilated recurrent neural networks. In Advances in Neural Information Processing Systems (pp. 76–86).
  4. Limiting floats to two decimal points. (n.d.). Retrieved February 19, 2018, from https://stackoverflow.com/questions/455612/limiting-floats-to-two-decimal-points
  5. F. (2018, January 26). Google Colab Free GPU Tutorial — Deep Learning Turkey — Medium. Retrieved February 19, 2018, from https://medium.com/deep-learning-turkey/google-colab-free-gpu-tutorial-e113627b9f5d
  6. How to get csv files using wget. (n.d.). Retrieved February 19, 2018, from https://stackoverflow.com/questions/30710102/how-to-get-csv-files-using-wget
  7. Convert pandas dataframe to numpy array, preserving index. (n.d.). Retrieved February 19, 2018, from https://stackoverflow.com/questions/13187778/convert-pandas-dataframe-to-numpy-array-preserving-index
  8. Replace string/value in entire dataframe. (n.d.). Retrieved February 19, 2018, from https://stackoverflow.com/questions/17142304/replace-string-value-in-entire-dataframe
  9. Seo, J. D. (2018, February 15). Only Numpy: Dilated Back Propagation and Google Brain’s Gradient Noise with Interactive Code. Retrieved February 19, 2018, from https://hackernoon.com/only-numpy-dilated-back-propagation-and-google-brains-gradient-noise-with-interactive-code-3a527fc8003c
  10. Seo, J. D. (2018, January 18). Only Numpy: Implementing “ADDING GRADIENT NOISE IMPROVES LEARNING FOR VERY DEEP NETWORKS” from… Retrieved February 19, 2018, from https://becominghuman.ai/only-numpy-implementing-adding-gradient-noise-improves-learning-for-very-deep-networks-with-adf23067f9f1
  11. National Institutes of Health. (n.d.). Retrieved February 19, 2018, from https://www.nih.gov/

Source: Deep Learning on Medium