Image clustering using Transfer learning

Source: Deep Learning on Medium

ResNet50 + KMeans-based image clustering model for dogs and cats!

Clustering is an interesting field of unsupervised machine learning in which we divide a dataset into groups of similar items. Because it is unsupervised, there is no prior training and the dataset is unlabeled. Clustering can be done with different techniques such as K-means clustering, Mean Shift clustering, DBSCAN and hierarchical clustering. The key assumption behind all clustering algorithms is that points that lie close together in the feature space share similar qualities and can be clustered together.
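To make the "nearby points belong together" assumption concrete, here is a minimal NumPy sketch of K-means (Lloyd's algorithm) on made-up 2-D points. It is only an illustration of the idea; the article itself uses scikit-learn's KMeans, and the function name `kmeans` and the toy data are assumptions for this example.

```python
import numpy as np

def _assign(X, centroids):
    """Index of the nearest centroid (squared Euclidean distance) for every point."""
    dists = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)  # shape (n, k)
    return dists.argmin(axis=1)

def kmeans(X, k, n_iter=100, seed=0):
    """Minimal Lloyd's algorithm: alternate between assigning points to the
    nearest centroid and moving each centroid to the mean of its points."""
    rng = np.random.default_rng(seed)
    # Initialise the centroids with k distinct data points chosen at random
    centroids = X[rng.choice(len(X), size=k, replace=False)].astype(float)
    for _ in range(n_iter):
        labels = _assign(X, centroids)
        # Recompute each centroid; keep the old one if its cluster became empty
        new_centroids = np.array([
            X[labels == j].mean(axis=0) if np.any(labels == j) else centroids[j]
            for j in range(k)
        ])
        if np.allclose(new_centroids, centroids):
            break
        centroids = new_centroids
    # Final assignment against the final centroids
    return _assign(X, centroids), centroids

# Two made-up, well-separated groups of 20 points each
rng = np.random.default_rng(42)
X = np.vstack([rng.normal(0.0, 0.5, (20, 2)), rng.normal(5.0, 0.5, (20, 2))])
labels, centroids = kmeans(X, k=2)
print(labels)
```

Each group ends up with a single label, because every point is closer to its own group's centroid than to the other one.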

In this article, we will cluster images. Images can be treated just like data points in regular ML, so the problem is essentially the same. But the big question is:

How do we define the similarity of images?

Similarity could mean similar-looking images, similar size, similar pixel distribution, similar background and so on. Different use cases call for different image vectors: for example, a vector that captures the content of an image (whether it contains a cat or a dog) will be different from a vector that captures its pixel distribution.
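As a small illustration of a "pixel distribution" image vector, the sketch below turns a grayscale image into a normalised 16-bin intensity histogram and compares vectors with cosine similarity. Both the histogram feature and the cosine metric are assumptions chosen for illustration (the article does not name a similarity metric, and KMeans itself uses Euclidean distance); the random "images" stand in for real photos.

```python
import numpy as np

def intensity_histogram(image, bins=16):
    """'Pixel distribution' vector: normalised histogram of 8-bit pixel intensities."""
    hist, _ = np.histogram(image, bins=bins, range=(0, 256))
    return hist / hist.sum()

def cosine_similarity(a, b):
    """Cosine of the angle between two feature vectors (1.0 = same direction)."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

# Random grayscale stand-ins: two dark images and one bright image
rng = np.random.default_rng(0)
dark_1 = rng.integers(0, 64, size=(64, 64))
dark_2 = rng.integers(0, 64, size=(64, 64))
bright = rng.integers(192, 256, size=(64, 64))

h_dark_1 = intensity_histogram(dark_1)
h_dark_2 = intensity_histogram(dark_2)
h_bright = intensity_histogram(bright)
print(cosine_similarity(h_dark_1, h_dark_2), cosine_similarity(h_dark_1, h_bright))
```

Under this vector the two dark images are near-identical and the bright one shares nothing with them, even though none of them "contains" anything. That is exactly why a content-aware vector such as a CNN embedding is needed for cat-versus-dog clustering.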

In this article we have a set of images of cats and dogs, and we will try to cluster them into cat photos and dog photos. For this purpose, we can derive the image vector from a pretrained CNN model such as ResNet50: we remove its final (classification) layer and take the 2048-dimensional vector that comes before it. Once we have the vectors, we apply KMeans clustering to the data points.
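Where do the 2048 numbers come from? For a 224x224 input, ResNet50's last convolutional block produces a 7x7 grid with 2048 channels, and the `pooling='avg'` option used in the model code (global average pooling) averages each channel over that grid. The sketch below mimics only the pooling step in NumPy, using a random array as an assumed stand-in for the real feature map.

```python
import numpy as np

# Random stand-in for ResNet50's last convolutional output for one 224x224 image:
# a 7x7 spatial grid with 2048 channels (channels-last layout)
feature_map = np.random.default_rng(0).random((7, 7, 2048))

# Global average pooling: average every channel over the 7x7 grid
image_vector = feature_map.mean(axis=(0, 1))
print(image_vector.shape)  # (2048,)
```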

Here are some of the pictures in my dataset, which has around 60 images of dogs and cats pulled at random from the internet.

Code Walk Through

The first step is to load the required libraries and the pretrained ResNet50 model. Remember to remove the final softmax layer from the model.

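from keras.applications.resnet50 import ResNet50, preprocess_input
from keras.models import Sequential

resnet_weights_path = '../input/resnet50/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'

my_new_model = Sequential()
# include_top=False drops the final softmax classifier; pooling='avg' yields one 2048-d vector per image
my_new_model.add(ResNet50(include_top=False, pooling='avg', weights=resnet_weights_path))

# Do not train the first layer (the ResNet50 model); it is already trained
my_new_model.layers[0].trainable = False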

Once the model is loaded, we can write a function that loads all the images, resizes each one to the fixed size (224, 224), passes it through the model and extracts the feature set.

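import glob

import cv2
import numpy as np

def extract_vector(path):
    """Return an (N, 2048) array of ResNet50 features for every image matching the glob pattern `path`."""
    resnet_feature_list = []

    # sorted() gives a stable order, so row i can be matched back to its file later
    for im_path in sorted(glob.glob(path)):
        im = cv2.imread(im_path)
        # OpenCV loads images as BGR, while ResNet50's preprocess_input expects RGB
        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
        im = cv2.resize(im, (224, 224))
        img = preprocess_input(np.expand_dims(im.astype('float32'), axis=0))
        resnet_feature = my_new_model.predict(img)  # shape (1, 2048)
        # Keep the 2048-d vector for this image (the original loop never stored it)
        resnet_feature_list.append(resnet_feature[0])

    return np.array(resnet_feature_list)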

Once we have extracted the feature set, we can run KMeans clustering over the dataset. K has to be decided in advance, or we can plot the loss (the within-cluster sum of squares) against K and pick it from the curve. Since we know K is 2 (cats and dogs), we can set it directly.
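If K were not known, the "loss vs K" curve mentioned above (often called the elbow method) could be computed roughly as follows. This is a sketch: the helper name `inertia_curve` is hypothetical, and two synthetic 10-dimensional blobs stand in for the real (N, 2048) ResNet50 features. scikit-learn exposes the loss as `inertia_`, the sum of squared distances from each point to its cluster centre.

```python
import numpy as np
from sklearn.cluster import KMeans

def inertia_curve(features, k_values, seed=0):
    """Fit KMeans for each k and record inertia_ (within-cluster sum of squares)."""
    return {
        k: KMeans(n_clusters=k, n_init=10, random_state=seed).fit(features).inertia_
        for k in k_values
    }

# Synthetic stand-in for the image feature matrix: two well-separated groups
rng = np.random.default_rng(0)
features = np.vstack([rng.normal(0.0, 1.0, (30, 10)),
                      rng.normal(5.0, 1.0, (30, 10))])

curve = inertia_curve(features, range(1, 7))
for k, inertia in curve.items():
    print(k, round(inertia, 1))
```

The loss always shrinks as K grows; the usual heuristic is to pick the K where the drop flattens out (here, K = 2).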

from sklearn.cluster import KMeans
kmeans = KMeans(n_clusters=2, random_state=0).fit(array)  # array: (N, 2048) output of extract_vector

That's all! We are done with our image clustering model. Let's see how well it clusters the images.
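To look at the images in each cluster, the cluster id of every image (`kmeans.labels_`) has to be paired with its file name. A minimal sketch, assuming the paths come from the same `sorted(glob.glob(path))` order used when extracting features; the helper `group_by_cluster` and the file names below are hypothetical.

```python
from collections import defaultdict

def group_by_cluster(paths, labels):
    """Map each cluster id to the image paths assigned to it.
    paths[i] must be the file whose feature vector is row i of the KMeans input."""
    clusters = defaultdict(list)
    for path, label in zip(paths, labels):
        clusters[int(label)].append(path)
    return dict(clusters)

# Hypothetical stand-ins for sorted(glob.glob(path)) and kmeans.labels_
paths = ["cat_01.jpg", "dog_01.jpg", "cat_02.jpg", "dog_02.jpg"]
labels = [0, 1, 0, 1]
print(group_by_cluster(paths, labels))
# {0: ['cat_01.jpg', 'cat_02.jpg'], 1: ['dog_01.jpg', 'dog_02.jpg']}
```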

Below are some of the images in the first cluster:

And here is the other cluster:

Overall, the clustering performance looks very good. Out of the 60 images that I clustered, only two were placed in the wrong cluster. Here are those images:
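Counting wrongly clustered images needs ground-truth labels (for example, tagging each file as cat or dog by hand) and one extra step: KMeans cluster ids are arbitrary, so "cluster 1" may correspond to cats or to dogs. For two clusters, a simple approach is to try both mappings and keep the one with fewer errors. The helper name and the labels below are made up for illustration.

```python
import numpy as np

def count_misclustered(true_labels, cluster_labels):
    """Errors for 2 clusters with 0/1 labels: cluster ids are arbitrary,
    so compare against both the direct and the flipped mapping."""
    t = np.asarray(true_labels)
    c = np.asarray(cluster_labels)
    errors_same = int((t != c).sum())
    errors_flipped = int((t != 1 - c).sum())
    return min(errors_same, errors_flipped)

# Made-up example: 0 = cat, 1 = dog; KMeans happened to call the cat cluster "1"
true_labels = [0] * 30 + [1] * 30
cluster_labels = [1] * 30 + [0] * 30
cluster_labels[30] = 1  # two dogs landed in the cat cluster
cluster_labels[31] = 1
print(count_misclustered(true_labels, cluster_labels))  # 2
```

For more than two clusters, the same idea generalises to finding the best one-to-one matching between cluster ids and classes.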

These two dogs were wrongly clustered as cats. Perhaps their feature vectors looked very similar to those of the cats. :)

We can investigate the distribution of the images further using the t-SNE algorithm. t-SNE is a dimensionality reduction algorithm that maps the 2048-dimensional image vectors to a small number of dimensions (here two) so they can be plotted. Below is the result I got for the 60-image dataset.
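A minimal sketch of the t-SNE step with scikit-learn, using a synthetic (60, 2048) matrix as an assumed stand-in for the real feature array. The perplexity of 15 is an arbitrary choice; it must be smaller than the number of samples.

```python
import numpy as np
from sklearn.manifold import TSNE

# Synthetic stand-in for the (60, 2048) ResNet50 feature matrix
rng = np.random.default_rng(0)
features = np.vstack([rng.normal(0.0, 1.0, (30, 2048)),
                      rng.normal(0.5, 1.0, (30, 2048))])

# Reduce every 2048-d vector to a 2-d point for plotting
tsne = TSNE(n_components=2, perplexity=15, init="pca", random_state=0)
embedding = tsne.fit_transform(features)
print(embedding.shape)  # (60, 2)
```

The two columns of `embedding` can then be scatter-plotted, coloured by `kmeans.labels_`, to produce a plot like the one described here. t-SNE layouts depend on the random seed and perplexity, so distances between far-apart groups should not be over-interpreted.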

Blue dots represent cluster 1 (cats) and green dots represent cluster 2 (dogs). Note that the thumbnail photos are not part of the t-SNE output; they were added only for illustration. The overlapping region is where the model found it difficult to separate the two clusters.


I hope this gives you a good understanding of how to build a basic image clustering method using transfer learning. As I said earlier, in some situations the CNN output may not be the best choice of image features. We can also consider HSV (Hue-Saturation-Value) with a bagging technique to create the vectors, when similar pixel distribution is our basis for clustering.
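The article does not spell out the HSV "bagging" step. One common reading is to bin each pixel's hue, saturation and value into a fixed-size colour histogram, which this sketch assumes. The helper names and bin counts (8 hue x 4 saturation x 4 value) are illustrative choices, and the HSV conversion is written in plain NumPy so the example has no other dependencies.

```python
import numpy as np

def rgb_to_hsv(rgb):
    """Convert an (H, W, 3) uint8 RGB image to HSV with every channel in [0, 1]."""
    rgb = rgb.astype(float) / 255.0
    r, g, b = rgb[..., 0], rgb[..., 1], rgb[..., 2]
    maxc = rgb.max(axis=-1)
    delta = maxc - rgb.min(axis=-1)
    v = maxc
    s = np.where(maxc > 0, delta / np.where(maxc > 0, maxc, 1), 0.0)
    # Hue in sixths of a turn, depending on which channel is largest
    safe = np.where(delta > 0, delta, 1)  # avoid division by zero for grey pixels
    h = np.where(maxc == r, ((g - b) / safe) % 6,
                 np.where(maxc == g, (b - r) / safe + 2, (r - g) / safe + 4))
    h = np.where(delta > 0, h / 6.0, 0.0)
    return np.stack([h, s, v], axis=-1)

def hsv_histogram(rgb, bins=(8, 4, 4)):
    """Binned HSV colour histogram, flattened and normalised to sum to 1."""
    hsv = rgb_to_hsv(rgb).reshape(-1, 3)
    hist, _ = np.histogramdd(hsv, bins=bins, range=[(0, 1), (0, 1), (0, 1)])
    hist = hist.flatten()
    return hist / hist.sum()

# Two tiny solid-colour test images
red = np.zeros((8, 8, 3), dtype=np.uint8)
red[..., 0] = 255
blue = np.zeros((8, 8, 3), dtype=np.uint8)
blue[..., 2] = 255
h_red, h_blue = hsv_histogram(red), hsv_histogram(blue)
print(h_red.shape, float(h_red @ h_blue))
```

These 128-d vectors can be stacked into a matrix and passed to KMeans in place of the ResNet50 features. With real photos, remember that `cv2.imread` returns BGR, so convert with `cv2.cvtColor(img, cv2.COLOR_BGR2RGB)` first; alternatively, `cv2.COLOR_BGR2HSV` converts directly, but for 8-bit images OpenCV stores hue in the range 0 to 179.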

Happy Learning :)