How to read TFRecord files in PyTorch!

Original article was published by Soumo Chatterjee on Deep Learning on Medium


Step 1 → First of all, you need to know the contents of your data. To illustrate, I am going to use the Kaggle data for classifying 104 flower classes. There are 4 folders, each with 3 sub-folders for training, validation and testing images. I will use only the training and validation sub-folders, which contain .tfrec files, for this walkthrough. NOTE: Each .tfrec file in a sub-folder contains the id, the label (the class of the sample, for training & validation data) and the img (the actual pixels in byte string format).

Sample picture of our dataset

Step 2 → We will use the glob library to collect the paths of the files in the training and validation sub-folders

import glob
train_files = glob.glob('/kaggle/input/tpu-getting-started/*/train/*.tfrec')
val_files = glob.glob('/kaggle/input/tpu-getting-started/*/val/*.tfrec')

Let’s see what is loaded in the train_files variable
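A quick way to inspect the result is to print the number of matches and the first few paths. The sketch below is self-contained: it builds a throwaway directory tree that only mimics the Kaggle layout (the folder and file names are made up for illustration), then applies the same kind of glob pattern.

```python
import glob
import os
import tempfile

# Build a throwaway tree shaped like <root>/<folder>/<split>/<name>.tfrec.
# The folder and file names here are invented for illustration.
root = tempfile.mkdtemp()
for folder in ("tfrecords-jpeg-192x192", "tfrecords-jpeg-224x224"):
    for split in ("train", "val", "test"):
        os.makedirs(os.path.join(root, folder, split))
        for n in range(2):
            name = f"{n:02d}-{split}.tfrec"
            open(os.path.join(root, folder, split, name), "wb").close()

# Same pattern shape as in the article: each '*' matches within one path component.
# glob returns paths in arbitrary order, so sort them for a stable listing.
train_files = sorted(glob.glob(os.path.join(root, "*", "train", "*.tfrec")))
val_files = sorted(glob.glob(os.path.join(root, "*", "val", "*.tfrec")))

print(len(train_files), "training files,", len(val_files), "validation files")
for path in train_files[:3]:
    print(os.path.relpath(path, root))
```

On Kaggle, only the two `glob.glob` calls are needed; the printing part works unchanged on the real `train_files` list.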


Step 3 → Now we will collect the ids, class labels and images (as bytes) in three separate lists for the training & validation files

# importing tensorflow to read .tfrec files
import tensorflow as tf

importing tensorflow first

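# Lists that accumulate the features from every training file
train_ids, train_class, train_images = [], [], []

for i in train_files:
    train_image_dataset = tf.data.TFRecordDataset(i)
    train_feature_description = {
        'class': tf.io.FixedLenFeature([], tf.int64),
        'id': tf.io.FixedLenFeature([], tf.string),
        'image': tf.io.FixedLenFeature([], tf.string),
    }

Then, for each file path, we create a dictionary describing the features: class, id and the image as a byte string.

    def _parse_image_function(example_proto):
        return tf.io.parse_single_example(example_proto, train_feature_description)
    train_image_dataset = train_image_dataset.map(_parse_image_function)

Next, we create a function that parses the input tf.Example proto using the dictionary (train_feature_description) and map it over the dataset.

    # += keeps the results from earlier files (plain assignment would keep only the last file)
    train_ids += [str(image_features['id'].numpy())[2:-1] for image_features in train_image_dataset]
    train_class += [int(image_features['class'].numpy()) for image_features in train_image_dataset]
    train_images += [image_features['image'].numpy() for image_features in train_image_dataset]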

Finally, we store the features in 3 different lists. You can also create a dataframe from these lists for convenience. NOTE: [2:-1] removes the leading b' and the trailing ' from the string form of each id. The same steps apply to the validation .tfrec files.
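The [2:-1] slice works because str() of a bytes object gives its printed form, such as b'abc'. A less fragile alternative (a suggestion, not what the code above does) is to decode the bytes directly, which also copes with ids containing characters that the printed form would escape. The sample ids below are made up.

```python
# Made-up id, shaped like what .numpy() returns for a tf.string feature.
raw_id = b"338ab7bac"

sliced = str(raw_id)[2:-1]        # the article's approach: strip the b' prefix and trailing '
decoded = raw_id.decode("utf-8")  # alternative: decode bytes to str directly

print(sliced, decoded)

# The two approaches differ once the id contains bytes that the printed form escapes.
tricky = "café".encode("utf-8")
print(str(tricky)[2:-1])       # escaped byte sequences, not the original text
print(tricky.decode("utf-8"))  # the original text
```

For plain ASCII ids, as in this dataset's description, both approaches give the same string.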

To test this, we can display one of the images.

import IPython.display as display
display.display(display.Image(data=train_images[211]))
output for above code
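The reading code can also be checked end to end without the Kaggle files. The sketch below writes one record to a temporary .tfrec file and parses it back with the same feature description. The id, label and image bytes are placeholder values, and make_example is a helper introduced here purely for illustration.

```python
import os
import tempfile

import tensorflow as tf


def make_example(sample_id, label, image_bytes):
    """Hypothetical helper: pack one sample into a tf.train.Example."""
    feature = {
        "id": tf.train.Feature(bytes_list=tf.train.BytesList(value=[sample_id])),
        "class": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
        "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_bytes])),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))


# Write a single placeholder record (the bytes are not a real JPEG).
path = os.path.join(tempfile.mkdtemp(), "sample.tfrec")
with tf.io.TFRecordWriter(path) as writer:
    writer.write(make_example(b"abc123", 7, b"not-a-real-image").SerializeToString())

# Same feature layout as described for the flower .tfrec files.
feature_description = {
    "class": tf.io.FixedLenFeature([], tf.int64),
    "id": tf.io.FixedLenFeature([], tf.string),
    "image": tf.io.FixedLenFeature([], tf.string),
}


def parse(example_proto):
    return tf.io.parse_single_example(example_proto, feature_description)


# Read the file back and materialise the parsed records.
records = list(tf.data.TFRecordDataset(path).map(parse))
first = records[0]
print(first["id"].numpy(), int(first["class"].numpy()), len(first["image"].numpy()))
```

If the values read back match what was written, the feature description and parse function agree with the file layout.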

Step 4 → Finally, we will create our PyTorch dataset class from the extracted features, i.e. train_ids, train_class and train_images.

Making some imports first

from PIL import Image
import cv2
import albumentations
import torch
import numpy as np
import io
from torch.utils.data import Dataset

Here we use the albumentations library for transformations and then define our dataset class
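The class definition itself is not reproduced in this text, so the following is only a minimal sketch of what such a dataset might look like, not the author's original code. It assumes the three lists from Step 3, decodes each byte string with PIL, and accepts an optional albumentations-style transform, i.e. a callable invoked as transform(image=array) that returns a dict with an "image" key holding an HWC uint8 NumPy array. The constructor arguments and the keys of the returned dict are assumptions.

```python
import io

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset


class FlowerDataset(Dataset):
    """Sketch of a map-style dataset over encoded image bytes (assumed layout)."""

    def __init__(self, ids, classes, images, transform=None):
        self.ids = ids
        self.classes = classes
        self.images = images        # encoded image bytes (e.g. JPEG)
        self.transform = transform  # albumentations-style callable, or None

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        # Decode bytes into an RGB array of shape (H, W, 3), dtype uint8.
        img = np.array(Image.open(io.BytesIO(self.images[idx])).convert("RGB"))
        if self.transform is not None:
            # Assumes the transform returns an HWC uint8 array (no tensor conversion).
            img = self.transform(image=img)["image"]
        # HWC -> CHW, then scale to [0, 1] as float32.
        chw = np.ascontiguousarray(img.transpose(2, 0, 1))
        return {
            "id": self.ids[idx],
            "image": torch.from_numpy(chw).float() / 255.0,
            "label": torch.tensor(self.classes[idx], dtype=torch.long),
        }


# Dry run with one synthetic image in place of the real data.
buf = io.BytesIO()
Image.new("RGB", (8, 6), color=(255, 0, 0)).save(buf, format="PNG")
dataset = FlowerDataset(["abc123"], [7], [buf.getvalue()])
sample = dataset[0]
print(sample["id"], sample["image"].shape, sample["label"])
```

With the real data, the call would presumably look like FlowerDataset(train_ids, train_class, train_images, transform=...), where the transform could be an albumentations pipeline that resizes and augments the image.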

For dry run testing,

Creating an object of our FlowerDataset class →

output for the above code

Now we can load these train_dataset & val_dataset objects of FlowerDataset directly into PyTorch data loaders.
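As a rough sketch of that last step, any map-style dataset can be wrapped in torch.utils.data.DataLoader. To stay self-contained, the example uses a small stand-in dataset of random tensors (ToyDataset, a hypothetical name) rather than the real FlowerDataset; the batch size and image size are arbitrary.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Stand-in for FlowerDataset: random 'images' and integer labels."""

    def __init__(self, n=10):
        self.images = torch.rand(n, 3, 16, 16)
        self.labels = torch.randint(0, 104, (n,))

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return {"image": self.images[idx], "label": self.labels[idx]}


train_loader = DataLoader(ToyDataset(), batch_size=4, shuffle=True)

# The default collate function stacks the dict values into batched tensors.
batch = next(iter(train_loader))
print(batch["image"].shape, batch["label"].shape)
```

One caveat when swapping in the real data: the default collate function can only stack images of identical shape, and the glob pattern in Step 2 spans several folders, so a resize step in the transform is likely needed if those folders hold different image sizes.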

This is how we can read .tfrec files in PyTorch. Let me know in the comments if you have any questions, comments or concerns. Thanks for reading, and until then, enjoy learning.