Deep Learning con Fastai V2 : Clasificador de cáncer de piel

Original article can be found here (source): Deep Learning on Medium

Deep Learning con Fastai V2 : Clasificador de cáncer de piel

En este blog aprenderás a crear algoritmos de clasificación utilizando técnicas de Deep Learning(Redes Neuronales Convolucionales) con la libreria fastai V2 para clasificar imágenes con cáncer de piel(melanoma y otras lesiones en la piel).

Contenido

  • Introducción
  • Conjunto de Datos
  • Exploración de Datos
  • Procesado de Datos
  • Modelo
  • Aprendizaje del modelo
  • Evaluación

Introducción

El Deep Learning(DL) es una de las ramas de la Inteligencia Artificial que ha tenido un gran impacto en la industria y en el mundo, debido a su capacidad para resolver problemas complejos. Los algoritmos de DL se basan en las redes neuronales artificiales(RNA), las cuales son modelos matemáticos que se encargan de aprender patrones en los datos. El DL actualmente es mas como un “arte” que una ciencia debido a que no siempre sabes si tienes el modelo correcto, los suficientes datos, y que hacer al respecto si no esta funcionando correctamente; en su mayoría el DL es experimental y requiere de mucha prueba y error para resolver sus problemas.

Las RNA reciben como entrada, datos(e.g. imágenes, texto, audio, etc.) →estos se procesan por una capa de “neuronas”→ estas producen una salida → que sirve de entrada para la siguiente capa y así progresivamente hasta llegar a una función que calcula el error entre la salida la RNA y los datos reales(clases). Una vez hecho esto el algoritmo toma la función de error y va modificando los parámetros iterativa-mente hasta encontrar los parámetros que contengan el mínimo error de la función y así aumentando su exactitud, a este proceso se le llama entrenamiento o aprendizaje.

Su gran ventaja contraria a los algoritmos comunes de aprendizaje de máquina es que estos pueden aprender funciones más complejas y eliminan parte del proceso de la extracción de características de los datos. También debido a su gran modularidad estas arquitecturas pueden ser fácilmente integradas a otros algoritmos. La desventaja es que estas requieren de bastante cómputo, recursos y suelen ser muy pesadas para resolver algunos problemas.

Las aplicaciones del Deep Learning varían bastante en industrias y áreas de aplicación como:

Visión por computador: clasificación de imágenes en satélites, agricultura, detección de objetos cotidianos para ayudar a personas invidentes, segmentación de imágenes para carros autónomos.

Procesamiento del Lenguaje Natural: chatbots, reconocimiento de voz, clasificación de texto, contestar preguntas, resumir documentos, encontrar significados, generación de voz, etc.

Medicina: detección de tumores en imágenes de Patologías, pneumonia en Rayos X, detección de lesiones en la piel o incluso predecir el tiempo de hospitalizacion de un paciente.

Biología: clasificación de proteínas, generación de moléculas, muchas tareas genéticas como la secuenciacion y clasificación de variantes genéticas, etc.

Sistemas de Recomendación: sistemas de recomendación de películas (Netflix), de compras (Amazon), cursos educativos y mucho mas.

Robotica: manipulación de compleja de objetos, robots o drones autónomos, sistemas de navegación, prótesis que predicen el siguiente movimiento.

Otras aplicaciones: generación de imágenes, super resolución, detección de fraudes, clasificación de audio, etc.

Una de las areas que atacaremos esta vez sera salud, desarrollando un algorimo de clasificacion de imagenes utilizando redes neuronales convolucionales(CNNs) que sea capaz de distinguir distintas clases de cáncer de piel como el melanoma.

Para el desarrollo de este programa, programamos en Python, utilizando la libreria fastai V2 la cual hace que sea mas sencillo el desarrollo debido a su gran flexibilidad y rapidez.

Aquí se encuentra nuestro codigo del programa, este se desarrollo en colaboratory. Una plataforma de google que ofrece computación con CPUs/GPUs/TPUs de forma gratuita y que se ha convertido en una de las herramientas de desarrollo principales en Deep Learning.

En el notebook del github, aqui abres el colab para correr el programa

Este tutorial es para personas que quieran aprender Deep learning, que tengan bases de programación y un hambre enorme por aprender.

Nota: El algoritmo creado es de propósito meramente educativo y no puede utilizarse bajo ninguna circunstancia para propósito medico.

Conjunto de Datos

Los datos que utilizaremos son extraídos del Archivo ISIC, el cual contiene una gran colección de imágenes de lesiones en la piel capturadas en la dermatoscopia. Este contiene cerca de 13,000 imágenes dermatoscopicas, las cuales fueron recolectadas por clínicas alrededor del mundo y adquiridadas por distintos dispositivos dentro de cada clínica. Esto para asegurarse de tener una muestra lo suficientemente representativa, variada y clinicamente relevante.

Estas imágenes fueron capturadas de forma privada y asegurando su calidad como su diagnostico. La gran mayoria de las imagenes fueron anotadas por expertos en cancer de piel, el resto fue extraida de los metadatos de los expedientes electrónicos en la clinica.

Acerca del Melanoma

El cancer de piel es un gran problema, acerca de 5,000,000 casos de diagnostico al año tan solo en Estados Unidos. El Melanoma es uno de los canceres de piel mas mortales, responsable por la mayoria de las muertes por cancer de piel. Se estima que en el 2015, globalmente hubieron 350,000 casos onde 60,000 resultaron en muertes. A pesar de que la mortalidad sea alta, cuando es detectado a tiempo el 95%.

Figura 1. Imagenes del conjunto de datos

Clases de lesiones del conjunto de datos

Esas son las categorías que nuestro programa clasificara.

  • MEL: “Melanoma”
  • NV: “Melanocytic nevus”
  • BCC: “Basal cell carcinoma”
  • AK: “Actinic keratosis”
  • BKL: “Benign keratosis (solar lentigo / seborrheic keratosis / lichen planus-like keratosis)”
  • DF: “Dermatofibroma”
  • VASC: “Vascular lesion”
  • SCC: “Squamous cell carcinoma”
  • UNK: None of the others / “out of distribution

Exploracion de Datos

Una parte super importante en el Deep Learning es conocer tus datos, porque no importa que tantos datos tengas sino la calidad de ellos. Mejores datos significa: un mejor modelo, que tenga una muestra representativa de la población, variada, que tenga el menor sesgo posible y que los datos sean lo mas cercano a la vida real.

  1. Instalar las librerías o dependencias.
instalar fastai2

2. Descargar los datos:

3. Después comenzaremos por importar las librerías que utilizaremos.

4. Importar las “Etiquetas” o valores reales de los datos:

En este caso cada imagen ISIC_0000.jpg tiene asignado una etiqueta(e.g MEL ) que corresponde a una de las 8 clases del conjunto de datos.

En esta parte observamos las clases que pertenecen a cada imagen con un 1.0 si pertenece. A esto se le llama etiquetado o ground truth

5. Visualizar imágenes

Dentro del codigo del programa podras observar como visualizamos los metadatos de las imagenes junto con ellas, como esta:

Procesado de datos

El procesado de datos es una de las tareas más rigurosas debido a que el algoritmo depende enormemente de los datos para poder realizar un buen trabajo.

  1. Extraer archivos y directorios
  • Fastai utiliza un metodo get_image_files(path) para extraer una lista de los archivos de imágenes del directorio o path. El método get_image_files(train_path) obtienes una lista de los archivos:

ISIC_0000.jpg, ISIC_0001.jpg,ISIC_0002.jpg...

  • Una vez hecho esto creamos una funcion get_labels para extraer las etiquetas de los datos, esta captura el nombre del archivo (ISIC_000.jpg) del dataframe ground_truthy retorna su etiqueta (MEL). partial es una función que retorna una función pero menos argumentos.

2. Adquisición de datos y procesado

items = get_image_files(train_path)
splitter = RandomSplitter()(items)
  • items lista con los archivos de imagen *.jpg
  • RandomSplitter() es una función que divide la lista de archivos items en un conjunto de entrenamiento con 80% de los datos y otro de validación con 20%, aleatoriamente.
item_img_tfms= [ImageResizer(224), ToTensor]
gpu_tfms = [IntToFloatTensor, *aug_transforms(flip_vert=True)]
tfms= [[PILImage.create], [get_labels, Categorize]]
  • tfms es una lista de listas que contiene las transformaciones de los archivos items donde la primera lista contiene las transformaciones para las imagenes como PILImage.create que se encarga de importar la imagen de un archivo, la segunda lista se encarga de procesar las etiquetas empezando por la funcion get_labels y seguido por Categorize que se encarga de asignar un numero a cada clase (e.g. {clase1: 0, clase2:1,clase3:2, clase4:3} y retorna una lista(e.g yb=[1,0,3,2]).
  • item_img_tfms : es una lista de las transformaciones que se le harán a las imágenes estén en el formato correcto. Empezando por ImageResizer() que se encarga de cambiar el tamaño de la imagen original a uno de 224×224, seguido por ToTensor() que convierte la imagen tipo PILImage a un tensor.
  • gpu_tfms es una lista de las transformaciones que se haran en el GPU el cual procesa varias imagenes en paralelo lo que hace que esas operaciones se realizen ahi. IntToFloatTensor() simplemente convierte el batch de datos de int a float. La funcion aug_transforms se encarga de realizar una serie de transformacionesa las imagenes, a estas se les llama aumentacion de datos. Estas son trasformaciones de los datos como Rotacion de la Imagen, Zoom, Cambio de color , contraste, brillo, agregar ruido, cortar la imagen, etc. Estas variaciones en cada imagen nos permiten agregar mas datos, a nuestro conjunto y el hecho de tener mas datos y mas variados hace que nuestro modelo sea mas robusto, mas efectivo y aprenda mejor de nuevos datos nunca antes vistos.
dsrc = Datasets(items, tfms, splits= splitter)
data = dsrc.dataloaders(bs=64, num_workers=4, after_item=item_img_tfms, after_batch=gpu_tfms, device=default_device())
  • Datasets es una clase que toma los archivos items y aplica las transformaciones de datos tfms a los conjuntos de entrenamiento y validación.
  • Una vez hecho esto el atributo .dataloaders se encarga aplicar las trasformaciones necesarias (after_item y after_batch) a las imágenes para que estén en el formato correcto para entrenar la red neuronal. Esta toma porciones de los datos(bs=64), a eso se le llama batch size(e.g. Toma 64 imágenes de 2000, cada ciclo de aprendizaje) este sirve para que el algoritmo sea mas rápido de entrenar, entre mayor bs mayor sera el computo y memoria consumida, pero no puede ser muy poco porque este puede hacer el entrenamiento inestable.

Modelo

Los modelos de Deep learning usan redes neuronales para aprender patrones en los datos. Las Redes Neuronales Convolucionales(CNNs) son las que utilizaremos para clasificar las imagenes. En especifico, para nuestro modelo utilizaremos la arquitectura(que tipo de modelo CNN utilizara) Resnet34, esta tomara como entrada las imagenes y les asignara un “peso” o parametros; estos parametros son los que cambiaran o actualizaran iterativamente por el algoritmo de optimizacion para aumentar el rendimiento del modelo. El algoritmo de optimizacion que utilizaremos es ADAM este se encargara de encontrar los parametros que contengan el minimo error de la funcion de error “Cross Entropy”. Para nuestro modelo utilizaremos la CNN(ResNet34) pre-entrenada, que quiere decir que ha sido entrenada antes con otro conjunto de datos y ahora utilizaremos esa base de conocimiento para que modificar(fine-tuning) nuestra CNN y que esta aprenda mas rapido con nuevos datos, que bien pueden ser similares o distintos al conjunto de datos; a esta tecnica se le llama Transfer Learning.

La funcion cnn_learner se encarga de tomar los datos(data), con la arquitectura de CNN(Resnet34), una funcion de error(loss_func: Cross Entropy) y un algoritmo de optimizacion(ADAM) para poder entrenar la red neuronal. Las metricas son funciones que sirven para medir la calidad de las predicciones de nuestro modelo en el conjunto de validacion; la metrica que utilizaremos sera la exactitud (accuracy).

learner = cnn_learner(data, resnet34, metrics=[accuracy])

Hay parametros que gobiernan a nuestros parametros y tienen mucha influencia en el entrenamiento, estos son los hyper-parametros. El learning rate es un valor(hyper-parametro) mas importante, el cual determina que “tanto” avanzara el algoritmo de optimizacion

El metodo .lr_find() se encarga de encontrar el learning rate mas optimo para nuestro modelo. Lo que hace es que entrena sobre un batch del conjunto entrenamiento mientras incrementa 10x el learning rate con respecto al cambio de la funcion de error. De esta forma elejimos al LR mas optimo como aquel que tiene la mayor pendiente(1e-2) lo cual indica un mayor cambio.

learner.lr_find()
Curva del Learning Rate

Entrenamiento del modelo

El entrenamiento es el proceso de aprendizaje del modelo. Este se encarga de encontrar los parametros de forma iterativa, que maximicen el rendimiento y reduzcan el error. El proceso basico de entrenamiento tomas tu batch de datos(xb,yb) , propagas el batch imagenes xb por el modelo, el cual realizara una prediccion preds.Estas predicciones seran evaluadas por una funcion de error y comparadas con las etiquetas reales yb. Despues calculamos que tanto cambia la funcion de error con respecto a los parametros y actualizamos los parametros para que nuestras predicciones se acerquen mas y mas a nuestras etiquetas reales.

A continuacion describire el codigo del proceso de entrenamiento basico.

  • epoch: paso completo por todo el conjunto de datos
  • xb, yb: batch de imagenes [bs, 224,224,3] y etiquetas [bs]
  • model() arquitectura ResNet34
  • loss: funcion de error categorica Cross Entropy
  • loss.backward() esta calcula que tanto cambia la funcion de error con respecto a los parametros, a esto se le llama backpropagation.
  • optim.step(): actualizacion de los parametros
for epoch in epochs:
for xb, yb in data:
preds = model(xb)
loss = loss_function(preds, yb)
loss.backward()
optim.step()
optim.zero_grad()

En fastai el metodo fit_one_cycle(epochs, learning_rate) se encarga de entrenar al algoritmo y aplica fine-tuning donde solo entrenan los parametros de la cabeza que se le agrego a la base de conocimineto de la CNN y se congelan(freeze ) los parametros de la base.

learner.fit_one_cycle(10, 1e-2)

Resultado de 10 epochs con un LR de 1e-2:

Ahora haremos uso de una tecnica de Transfer Learning con el metodo unfreeze() el cual hace que todos los parametros de la CNN se actualizen. Esto permite que el modelo aprenda datos mas detallados.

Volvemos a entrenar, con diferentes valores del LR

Al terminar esta session de entrenamiento obtuvimos una exactitud del 81%, pero es necesario evaluar mas rigurosamente para determinar la validez de nuestro modelo.

OJO Este resultado puede ser mucho mejor, todo depende de los experimentos, pruebas con diferentes arquitecturas, hyper-parametros, etc. Esta en ti, hacerlo mejor!

Evaluación

La evaluación es una de las fases mas importantes ya que determina si estamos cerca de cumplir nuestro objetivo. Aquí observamos los resultados de nuestro modelo en el conjunto de validación. La etiqueta de arriba es la correcta y la de abajo es la predicta por nuestro modelo.

Ahora evaluaremos nuestro clasificador con una matriz de confusion. Esta se encarga de evaluar los valores de las predicciones y los de las categorias reales. La matriz muestra los Falsos Positivos y Falsos Negativos, estos son las clases donde se “confunde” el modelo. La diagonal demuestra los valores correctos.

interp = ClassificationInterpretation.from_learner(learner)
interp.plot_confusion_matrix(figsize=(10,8))

Aqui observamos que la clase NV es la que mas valores correctos predijo el modelo, en cambio las demas demuestran muchos falsos positivos y negativos como VASC o DF. En gran medida se debe a que hay mas casos de una clase que de otra o un desbalance de clases.

Para evaluar mas efectivamente a un conjunto de datos con clases desbalanceadas se utiliza la curva de Precision-Recall. Esta nos dara una idea mas clara de como se comporta el modelo con cada clase.

Conclusión

En conclusión desarrollamos un algoritmo de clasificación de imágenes básico, el cual se dio la tarea de clasificar cáncer de piel en imágenes con una exactitud del 83%, la cual puede ser fácilmente mejorada si se sigue experimentando. Por ultimo, aprendimos las bases del Deep Learning con la librería fastai V2, manejo de datos como Imágenes, el proceso de entrenamiento y evaluación de modelos para la clasificación.