Entendiendo nuestro modelo de DeepLearning: Stiffness

Source: Deep Learning on Medium


Mucho esfuerzo se ha invertido últimamente en entender cómo funcionan las redes neuronales profundas. Y es que, mientras más responsabilidades le damos a nuestros modelos en atender situaciones de la vida diaria de manera automática, más nos sentimos en la necesidad de entender por qué estos algoritmos funcionan de la manera que funcionan.

Una de las preguntas que todavía no se ha podido responder de manera clara es cómo los modelos de Redes Neuronales logran generalizar los datos de entrenamiento de manera tan exitosa, y adicionalmente, entender en qué momento el modelo deja de generalizar y comienza a hacer Overfitting. Y es que, podemos determinar empíricamente cuando cierta red está haciendo Overfitting de los datos, revisando que tan alta es la varianza del modelo (la diferencia entre la precisión de entrenamiento y de validación), pero no se conoce a ciencia cierta cómo se puede determinar esto de manera intrínseca en el modelo.

Para dar un poco de luces en este sentido, los investigadores Fort, Nowak y Narayanan desarrollaron un indicador llamado Stiffness que se podría utilizar para establecer que tan bien está aprendiendo a generalizar nuestro modelo, y en qué momento está dejando de hacerlo. Para entender este indicador, vamos a situarnos en el conjunto de datos de MNIST.

https://m-alcu.github.io/blog/2018/01/13/nmist-dataset/

Como sabemos, el entrenamiento de nuestro modelo de redes neuronales profundas cuenta con dos fases: el forward propagation y el backpropagation. Durante el forward propagation aplicamos las funciones de activación, pesos y bias sobre cada una de las observaciones en nuestro conjunto de entrenamiento, esto nos genera la predicción del modelo sobre esa observación (que numero dice el modelo es la foto observada). Al comparar esta predicción con la clasificación real de la observación vamos a obtener un valor para la pérdida de ese paso (una predicción acertada tendrá valores bajos, mientras una predicción incorrecta tendrá valores altos). Luego, en el backpropagation utilizamos la gradiente de la función de pérdida con respecto a cada uno de los parámetros para actualizar dichos valores. Es decir, vamos a actualizar todos los parámetros de nuestra red en la dirección contraria de la gradiente de la función de pérdida, esto con la finalidad de minimizar dicha función para todas las observaciones.

Como ya vimos, individualmente cada observación cuenta con un valor de pérdida (que tanto se equivoco el modelo categorizando) y un gradiente de dicha función en relación a los parámetros (hacia donde hay que mover los parámetros para minimizar el valor de la pérdida). Teniendo esto en cuenta, cabría revisar cómo el gradiente para una observación particular se relacion con los gradientes de observaciones que son nuevas para el modelo (va en la misma dirección? van en direcciones contrarias? son totalmente perpendiculares?). Esta relación es la que los investigadores llaman Stiffness y nos permite determinar que tanto contribuye la gradiente de una observación la minimización de valores de costo de observaciones que la red todavía no ha visto.

Stiffness: A New Perspective on Generalization in Neural Networks

Formalmente podemos definir el Stiffness de la siguiente manera: dado el gradiente de la función costo de una observación del conjunto de entrenamiento X1 que llamaremos g1 y dada una observación X2 del conjunto de validación llamaremos Stiffness a la relación entre la disminución de del valor del costo de X1 y X2. Donde el costo de X2 disminuye con una cambio en los parámetros en la dirección de g1 vamos a considerar un Stiffness positivo (el modelo usa lo aprendido en X1 para mejorar su predicción de X2). Donde el costo de X2 aumenta con un cambio de los parámetros W en dirección g1 vamos a considerar un Stiffness negativo (el modelo se comporta peor en X2 con lo aprendido en X1). Y de no haber cambios en X2 consideraremos un Stiffness igual a cero (como lo muestra la figura).

La propuesta resulta interesante ya que podemos ver de manera directa como nuestra red neuronal no solo aprende para el conjunto de entrenamiento, sino que extiende ese conocimiento al conjunto de validación. Podríamos teóricamente suponer que en el momento en el que el Stiffness comience a disminuir estamos frente a una red cuyos pasos de entrenamiento no están aportando nada a predecir mejor datos del conjunto de validación. Esta idea la es puesta a prueba en el trabajo de investigación resultando en la siguiente gráfica:

Stiffness: A New Perspective on Generalization in Neural Networks

En la segunda figura podemos ver un gráfico de Pérdida vs Observaciones entrenadas, el conjunto de entrenamiento en azul y el de validación en verde. Como hablábamos al principio, esta gráfica nos permite determinar cuando el modelo empieza a hacer Overfitting (justo en la línea punteada amarilla, cuando la línea verde y azul empiezan a separarse). En la figura de arriba, podemos ver los valores de Stiffness para cuando X1 y X2 son de la misma clase (en rojo), cuando son de clases diferentes (en azul) y el promedio (en verde). Es bastante claro como luego de la línea amarilla los valores de Stiffness empiezan a caer, demostrando que el modelo dejó de generalizar el aprendizaje de sus observaciones de entrenamiento a las de validación, y por ende, haciendo Overfitting.

Aun cuando en la gráfica anterior estamos viendo el desarrollo del Stiffness solo a través del tiempo de entrenamiento, podemos hacer un poco más de detalle y ver cómo afecta a las diferentes clases en específico. Digamos, tomando a X1 como una clase fija (por ejemplo, un 1) y viendo el desarrollo del Stiffness para cada una de las otras clases. De esta forma podremos ver como el modelo generaliza de una clase a cada una de las demás. Para eso los investigadores construyeron una matriz donde se tiene la clase de X1 en la parte de abajo y la clase de X2 en la parte lateral. Los valores más altos de Stiffness se ven en rojo y los más bajos en azul. Como se puede observar en la figura para la iteración 0 la diagonal (Stiffness de la misma clase) muestra valores cercanos a 1. En los valores inter clase se pueden apreciar distintos rangos, desde los más elevados (clase 2 con 8) que son la minoría, hasta los menores (clase 8 con 6).

Stiffness: A New Perspective on Generalization in Neural Networks

Esta primera figura muestra el Stiffness en el momento de inicialización del modelo, en la iteración 0. Los investigadores exponen que de manera intuitiva se puede decir que en esta fase del modelo el Stiffness entre la misma clase es alta debido a que la red está aprendiendo características crudas que solo comparten los miembros de la misma clase (intensidad de los pieles, posición, etc). Pero es baja entre distintas clases ya que la red no es lo suficientemente robusta para generalizar características comunes a todas las clases. En la siguiente figura vamos a ver la misma matriz pero en una etapa temprana de optimización (iteración 800).

Stiffness: A New Perspective on Generalization in Neural Networks

En esta etapa el Stiffness entre observaciones de la misma clase sigue bastante alto, pero comenzamos a ver como el Stiffness entre observaciones de distintas clases comienza a subir de manera importante (clases 0 con 1 y 2 con 4). En esta etapa de la optimización los investigadores intuyen que se están aprendiendo características comunes para todas las clases, donde por cada observación que se entrena, el desempeño mejora de manera universal. En la última figura podemos ver dos etapas más en el proceso de entrenamiento (iteracion 638176 y 5627552).

Stiffness: A New Perspective on Generalization in Neural Networks

En el gráfico de la izquierda se puede observar como el modelo sigue generando características comunes para todas las clases, aunque también podríamos intuir que la red empieza a entender no solo como identificar la clase de una observación con base en sus características, sino también a identificar que una observación NO es parte de una clase (si el modelo inicialmente confunde un 3 por un 2, entendiendo mejor como se ven los 2 va a dejar de clasificar el 3 de esa manera). Por el contrario, en el gráfico de la derecha vemos como el modelo ya empieza a hacer Overfitting, tendiendo a 0 en todos los renglones de Stiffness.

Las futuras lineas de investigación que dejan los autores son las siguientes:

  1. Estudiar el Stiffness para distintas arquitecturas y poder estudiar como una arquitectura puede ser mejor que otra para generalizar.
  2. Utilizar el Stiffness para determinar como el orden en que se exponen los datos de entrenamiento al modelo hacen que este tenga un mejor resultado.

De mi parte, creo que puede ser interesante estudiar lo siguiente:

  1. Como las distintos métodos de regularización impactan el Stiffness de un modelo.
  2. Creación de un callback de Keras que nos permita monitorear el Stiffness de manera automática.