How to use the SimpleImputer Class in Machine Learning with Python

Original article was published by Stephen Fordham on Artificial Intelligence on Medium

Simply use SimpleImputer

Image Courtesy of Unsplash via Ross Sneddon

Missing Value Imputation

Datasets often have missing values, and these can cause problems for many machine learning algorithms. It is considered good practice to identify and replace missing values in each column of your dataset before performing predictive modelling. This approach to replacing missing data is referred to as data imputation.

Missing values in a dataset can arise for many reasons. These include, but are not limited to: malfunctioning measuring equipment, the collation of non-identical datasets, and changes in data collection during an experiment.

A convenient strategy for missing data imputation is to replace all missing values in a column with a statistic calculated from the other (non-missing) values in that column. This strategy can often give good results, and it avoids discarding meaningful data when building machine learning models. Commonly used statistics are the mean, the median, and the mode (most frequent value) of the column. Alternatively, a constant value can be imputed.
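As a minimal illustration of the idea, here is a sketch using plain pandas (not SimpleImputer) on a made-up column; each statistic is computed from the observed values only and then written into the gaps:

```python
import numpy as np
import pandas as pd

# A toy column with two missing values (illustrative data, not from the dataset)
col = pd.Series([1.0, 2.0, np.nan, 2.0, 5.0, np.nan])

# pandas skips NaN when computing these statistics
mean_filled = col.fillna(col.mean())        # mean of 1, 2, 2, 5 = 2.5
median_filled = col.fillna(col.median())    # median of 1, 2, 2, 5 = 2.0
mode_filled = col.fillna(col.mode()[0])     # most frequent value = 2.0
constant_filled = col.fillna(0)             # a fixed constant chosen by the user

print(mean_filled.tolist())
```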

This tutorial is aimed at demonstrating the usage of the SimpleImputer class for statistical imputation.

Dataset and Missing data Assessment

The dataset used in this tutorial is the “Logistic regression To predict heart disease” dataset, available on Kaggle here or on my GitHub page.

The dataset is from an ongoing cardiovascular study on residents of the town of Framingham, Massachusetts. The classification goal is to predict whether the patient has a 10-year risk of future coronary heart disease (CHD). The dataset provides the patients’ information and includes over 4,000 records and 15 attributes, which are described on Kaggle. My main purpose here is not to explore this dataset in depth, but rather to demonstrate a use case for the SimpleImputer class in predictive modelling.

First, import the required libraries and perform some exploratory data analysis (EDA).

Calling head() on the dataframe and inspecting its columns attribute reveals all 15 data attributes.
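A sketch of this first step might look like the following. It assumes the Kaggle CSV has been saved locally as `framingham.csv` (a hypothetical filename) and that the target column is named `TenYearCHD`; if the file is not found, it falls back to a small random stand-in frame so the snippet still runs.

```python
from pathlib import Path

import numpy as np
import pandas as pd

# Assumption: the Kaggle CSV has been saved locally under this name
CSV_PATH = Path("framingham.csv")

if CSV_PATH.exists():
    df = pd.read_csv(CSV_PATH)
else:
    # Hypothetical synthetic stand-in so the sketch runs without the download;
    # the values are random and only a few column names are mimicked.
    rng = np.random.default_rng(0)
    n = 100
    df = pd.DataFrame({
        "age": rng.integers(30, 70, n).astype(float),
        "totChol": rng.normal(200, 40, n),
        "glucose": rng.normal(80, 20, n),
        "TenYearCHD": rng.integers(0, 2, n),
    })

print(df.head())    # first five rows
print(df.columns)   # column names
print(df.shape)     # (rows, columns)
```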

To calculate the number of missing values in each column, and what percentage of that column they represent, I first iterate over the columns. For each column, I count the NaN values by calling the isna() method and chaining a call to sum().

I then divide the number of missing values by the total number of rows and express the result as a percentage for each column, and print out this data.
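The loop described above can be sketched as follows, using a small hypothetical dataframe in place of the real data (the column names and values are only illustrative):

```python
import numpy as np
import pandas as pd

# Hypothetical toy frame standing in for the heart-disease data
df = pd.DataFrame({
    "age": [50.0, 61.0, 45.0, 39.0],
    "glucose": [np.nan, 85.0, np.nan, 70.0],
    "BMI": [26.1, np.nan, 24.3, 22.8],
})

n_rows = df.shape[0]
report = {}
for col in df.columns:
    n_missing = df[col].isna().sum()   # count NaNs in this column
    pct = n_missing / n_rows * 100     # share of rows that are missing, in percent
    report[col] = (int(n_missing), pct)
    print(f"> {col}: {n_missing} missing ({pct:.1f}%)")
```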

While this is informative, it takes some reading to interpret fully. A quicker way to spot missing values may be to produce a graphic using a seaborn heatmap.

We first pass in dataframe.isna(), which returns True wherever a value is missing. We can then set both cbar and yticklabels to False, choose a colour map (here, the string 'viridis'), and visualise the results.
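A sketch of the heatmap call, again on a hypothetical toy dataframe. The Agg backend and the output filename are assumptions made so the snippet runs without a display; in a notebook you would normally just let the figure render.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Hypothetical toy frame with a few gaps
df = pd.DataFrame({
    "age": [50.0, 61.0, 45.0, 39.0, 58.0],
    "glucose": [np.nan, 85.0, np.nan, 70.0, np.nan],
    "BMI": [26.1, np.nan, 24.3, 22.8, 27.5],
})

# Missing (True) and present (False) cells are drawn in different colours
ax = sns.heatmap(df.isna(), cbar=False, yticklabels=False, cmap="viridis")
ax.set_title("Missing values by column")
plt.tight_layout()
plt.savefig("missing_heatmap.png")  # hypothetical output file
```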

We now have a quick reference graphic showing which columns contain missing data. Here, we can see that the glucose column has the most missing values.

Statistical Imputation with the SimpleImputer Class

The scikit-learn machine learning library provides the SimpleImputer class, which implements statistical imputation.

To use SimpleImputer, first import the class, and then instantiate it with a string argument passed to the strategy parameter. For clarity, I have included ‘mean’ here; it is the default, so it does not need to be passed explicitly.
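A minimal sketch of the import and instantiation. The `strategy` values shown are the built-in options documented for SimpleImputer; `fill_value` is only used by the `'constant'` strategy, and for numeric data it defaults to 0 when left unset:

```python
from sklearn.impute import SimpleImputer

# 'mean' is the default strategy, so these two imputers are configured identically
imputer = SimpleImputer(strategy="mean")
default_imputer = SimpleImputer()

# The other built-in strategies accepted by the strategy parameter
median_imputer = SimpleImputer(strategy="median")
mode_imputer = SimpleImputer(strategy="most_frequent")
constant_imputer = SimpleImputer(strategy="constant", fill_value=0)

print(imputer.strategy, default_imputer.strategy)
```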

I convert the dataframe into a NumPy array by accessing its values attribute. This is not strictly necessary, but it is a habit I prefer. I then select the features and assign them to the variable X, and select the target variable and assign it to the variable y.
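A sketch of this step, assuming (as in the toy frame below) that the target is the last column:

```python
import numpy as np
import pandas as pd

# Hypothetical toy frame; the target column is assumed to come last
df = pd.DataFrame({
    "age": [50.0, 61.0, 45.0, 39.0],
    "glucose": [np.nan, 85.0, np.nan, 70.0],
    "TenYearCHD": [0, 1, 0, 0],
})

data = df.values    # DataFrame -> NumPy array (df.to_numpy() is equivalent)
X = data[:, :-1]    # every column except the last: the features
y = data[:, -1]     # the last column: the target label
print(X.shape, y.shape)
```

Because the frame mixes floats and integers, the resulting array is float64, so `y` holds `0.0`/`1.0`; scikit-learn classifiers accept these labels as they are.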

The imputer is fit on the dataset to calculate the statistic for each column. The fitted imputer is then applied to the dataset to create a copy in which the missing values in each column are replaced with that column's mean.
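The fit and transform steps can be sketched on a tiny made-up array, where the column means are easy to check by hand:

```python
import numpy as np
from sklearn.impute import SimpleImputer

# Toy feature matrix with missing entries (illustrative values only)
X = np.array([
    [1.0, 10.0],
    [np.nan, 20.0],
    [3.0, np.nan],
    [5.0, 30.0],
])

imputer = SimpleImputer(strategy="mean")
imputer.fit(X)                  # learns one mean per column, ignoring NaNs
print(imputer.statistics_)      # column means: [3.0, 20.0]

Xtrans = imputer.transform(X)   # returns a new array; X itself is unchanged
print(Xtrans)
```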

To confirm that the imputation has worked, we can count the missing values in the dataset with and without the transform applied. The transform imputes 645 values in total, leaving no missing values.
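A sketch of this check on random data with roughly 20% of values knocked out (the data, seed, and missing rate are arbitrary assumptions, so the counts will differ from the article's):

```python
import numpy as np
from sklearn.impute import SimpleImputer

rng = np.random.default_rng(1)
X = rng.normal(size=(50, 3))
X[rng.random(X.shape) < 0.2] = np.nan   # knock out roughly 20% of the values

n_before = np.isnan(X).sum()
Xtrans = SimpleImputer(strategy="mean").fit_transform(X)
n_after = np.isnan(Xtrans).sum()

print("Missing before:", n_before)
print("Missing after:", n_after)
```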

Model Evaluation

A robust way to evaluate the model is repeated k-fold cross-validation. To avoid data leakage, the statistic for each column must be calculated on the training portion of each fold only, and then applied to both the training and test portions of that fold.

To achieve this, we can create a modelling pipeline, where the first step is the statistical imputation, and the second step is the model itself.

Here, the pipeline uses the mean strategy for statistical imputation and a RandomForestClassifier to make predictions.
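A sketch of such a pipeline evaluated with repeated stratified k-fold cross-validation. The synthetic data, the step names `i` and `m`, and the fold settings (5 splits, 3 repeats) are assumptions for illustration, so the printed accuracy will not match the article's results:

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.impute import SimpleImputer
from sklearn.model_selection import RepeatedStratifiedKFold, cross_val_score
from sklearn.pipeline import Pipeline

# Synthetic stand-in data: 200 rows, 5 features, about 10% of values missing
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
y = (X[:, 0] + rng.normal(scale=0.5, size=200) > 0).astype(int)
X[rng.random(X.shape) < 0.1] = np.nan

# Imputation is the first step, so each fold's means come from its training split only
pipeline = Pipeline(steps=[
    ("i", SimpleImputer(strategy="mean")),
    ("m", RandomForestClassifier(random_state=1)),
])

# Assumed settings: 5 folds repeated 3 times
cv = RepeatedStratifiedKFold(n_splits=5, n_repeats=3, random_state=1)
scores = cross_val_score(pipeline, X, y, scoring="accuracy", cv=cv)
print(f"Mean accuracy: {scores.mean():.3f} ({scores.std():.3f})")
```

Putting the imputer inside the pipeline is what keeps the folds clean: cross_val_score refits the whole pipeline on each training split, so the test split never influences the imputed means.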

The pipeline achieved a mean accuracy of 84.7% with a standard deviation of 0.007. This is a reasonable result, but it would be useful to know which imputation strategy gives the best predictive performance.

Test Imputation Strategies

To test different imputation strategies, we can iterate through them. For each strategy, we construct a new pipeline, calculate the cross-validation scores, and append them to the results list.

We can use the score arrays in the results list, together with the strategy names, to produce a box and whisker plot showing which imputation strategy performs best. I have decided to check both the mean and the maximum scores for each strategy.
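A sketch combining the strategy loop and the box plot, on the same kind of synthetic data. The fold settings, the smaller forest (`n_estimators=50`), and the output filename are assumptions chosen to keep the run short; `showmeans=True` adds a marker for each strategy's mean score, and the maximum is printed alongside it:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.impute import SimpleImputer
from sklearn.model_selection import RepeatedStratifiedKFold, cross_val_score
from sklearn.pipeline import Pipeline

# Synthetic stand-in data with about 10% of values missing
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
y = (X[:, 0] + rng.normal(scale=0.5, size=200) > 0).astype(int)
X[rng.random(X.shape) < 0.1] = np.nan

strategies = ["mean", "median", "most_frequent", "constant"]
results = []
for s in strategies:
    # A fresh pipeline per strategy; 'constant' fills numeric gaps with 0 by default
    pipeline = Pipeline(steps=[
        ("i", SimpleImputer(strategy=s)),
        ("m", RandomForestClassifier(n_estimators=50, random_state=1)),
    ])
    cv = RepeatedStratifiedKFold(n_splits=5, n_repeats=2, random_state=1)
    scores = cross_val_score(pipeline, X, y, scoring="accuracy", cv=cv)
    results.append(scores)
    print(f"> {s}: mean={scores.mean():.3f}, max={scores.max():.3f}")

fig, ax = plt.subplots()
ax.boxplot(results, showmeans=True)   # one box per strategy
ax.set_xticks(range(1, len(strategies) + 1))
ax.set_xticklabels(strategies)
ax.set_ylabel("Accuracy")
fig.savefig("strategy_boxplot.png")   # hypothetical output file
```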

The box and whisker plots indicate the best accuracy scores are obtained using the constant (0) imputation strategy.

Testing on a sample

A prediction can now be made using the best imputation strategy. The pipeline is defined and fitted on all available data.

We can take the first row of our feature array, reshape it into a 2D array with a single row, assign it to the variable sample, and pass it to the pipeline's predict method.
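A sketch of the final step on synthetic data, assuming the constant strategy (which fills numeric gaps with 0 by default) as the chosen imputer:

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline

# Synthetic stand-in data with about 10% of values missing
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
y = (X[:, 0] + rng.normal(scale=0.5, size=200) > 0).astype(int)
X[rng.random(X.shape) < 0.1] = np.nan

pipeline = Pipeline(steps=[
    ("i", SimpleImputer(strategy="constant")),   # fills NaNs with 0
    ("m", RandomForestClassifier(random_state=1)),
])
pipeline.fit(X, y)  # fit on all available data

# predict expects a 2D array, so reshape the single row to (1, n_features)
sample = X[0].reshape(1, -1)
prediction = pipeline.predict(sample)
print(f"Predicted class: {prediction[0]}, true class: {y[0]}")
```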

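To check the result, we can compare it with the first label in y. Our model predicts class 0, and the true class is also 0. Because this row was part of the data the pipeline was fitted on, this is a sanity check rather than a measure of how well the model generalises.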


The SimpleImputer class can be an effective way to impute missing values using a calculated statistic. By using repeated k-fold cross-validation, we can quickly determine which strategy passed to the SimpleImputer class gives the best predictive modelling performance.

Link to Complete Jupyter Notebook

The link to the complete Jupyter notebook for this tutorial can be found here.