Serialization and Easier Cross-validation
Part I: Simple CSV Files to PyTorch Tensors Pipeline (towardsdatascience.com)
In Part I we discussed how to load a text dataset from CSV files, tokenize the texts, and put them into tensors via torchtext. Now we're going to address two issues in that solution (still using the Toxic Comment dataset):
- Long loading time: reloading the dataset from scratch every time you want to run a new experiment wastes a lot of time.
- Limited choice of cross-validation methods: there was effectively only one option available, a random split controlled by the seed and VAL_RATIO parameters.
First of all, TabularDataset is unfortunately not directly serializable. We start the search for alternatives from the observation that TabularDataset reads the files into a list of Examples:
```python
with io.open(os.path.expanduser(path), encoding="utf8") as f:
    examples = [make_example(line, fields) for line in f]
```
The next observation is that Dataset, the superclass of TabularDataset, accepts an examples parameter (a list of Examples). So it becomes clear that what we need to do is serialize the examples from the TabularDataset instances and create Dataset instances from them upon request. As a bonus, we can serialize the comment Field instance as well.
To be more specific, the following is the general workflow:

```python
def read_files():
    comment = data.Field(...)
    train = data.TabularDataset(...)
    test = data.TabularDataset(...)
    return train.examples, test.examples, comment

def restore_dataset(train_examples, test_examples, comment):
    train = data.Dataset(...)
    test = data.Dataset(...)
    return train, test
```
The first two returned variables are the essential components for rebuilding the datasets. You can refit the comment Field instance if you want, but it's faster not to: simply pass comment as one of the fields when initializing a dataset.
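The serialize/restore round trip can be sketched with plain pickle. This is a simplified illustration only: the SimpleExample class and the helper functions below are stand-ins I made up for real torchtext Example objects and the actual caching code.

```python
import os
import pickle
import tempfile


class SimpleExample:
    """Stand-in for a torchtext Example: one attribute per field."""

    def __init__(self, comment_text, toxic):
        self.comment_text = comment_text
        self.toxic = toxic


def save_examples(examples, path):
    # Examples are plain Python objects, so pickle handles them directly.
    with open(path, "wb") as f:
        pickle.dump(examples, f)


def load_examples(path):
    with open(path, "rb") as f:
        return pickle.load(f)


path = os.path.join(tempfile.mkdtemp(), "examples.pkl")
examples = [SimpleExample(["an", "example"], 0),
            SimpleExample(["another", "one"], 1)]
save_examples(examples, path)
restored = load_examples(path)
```

The restored list can then be handed to data.Dataset together with the (also pickled) Field instance, skipping the expensive CSV parsing and tokenization entirely.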
Now that we create dataset instances from a list of Examples instead of from CSV files, life is much easier. We can split the list of Examples in whatever way we want and create dataset instances for each split. For classification tasks I'd usually prefer stratified K-Fold validation, but because the Toxic Comment dataset is multi-label, stratification is harder to do, so we'll use simple K-Fold validation in the following sections.
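As a sketch of the splitting step, here is how a list of examples can be divided with scikit-learn's KFold (the plain strings stand in for Example objects; in the real pipeline each split would then be wrapped in a data.Dataset):

```python
from sklearn.model_selection import KFold

# Stand-ins for the deserialized Example objects.
examples = [f"example_{i}" for i in range(10)]

kf = KFold(n_splits=5, shuffle=True, random_state=42)
folds = []
for train_idx, val_idx in kf.split(examples):
    # Materialize each split as its own list of examples.
    train_examples = [examples[i] for i in train_idx]
    val_examples = [examples[i] for i in val_idx]
    folds.append((train_examples, val_examples))
```

Each example lands in the validation split of exactly one fold, so running through all five folds covers the whole training set.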
Putting It Together
Please refer to the end of the post for the complete code. Here are some comparisons between the new solution and the previous one in Part I:
- We use exactly the same tokenizer as before.
- Despite not having to create a train/validation split here, we still need a simplified prepare_csv function to remove \n characters from the raw CSV files.
- The need for explicit serialization via the pickle module can be eliminated by the neat library joblib. One of its features is transparent and fast disk caching of output values: just create a caching object MEMORY and use the MEMORY.cache decorator on any function you want to cache (in this case read_files, which reads in the CSV files and returns two lists of examples and a Field instance).
- The main function get_dataset now returns a generator and a test dataset. The generator yields a train dataset and a validation dataset on each iteration, giving you K-Fold validation once you've run through all K iterations.
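The joblib caching mentioned above can be sketched as follows. The cache directory and the toy read_files body are placeholders; the real function would parse the CSV files and build the examples and Field instance.

```python
import tempfile

from joblib import Memory

# In the real code the cache location would be a fixed directory
# such as "cache/", so results survive across runs.
MEMORY = Memory(location=tempfile.mkdtemp(), verbose=0)

call_count = []  # Records how many times the body actually runs.


@MEMORY.cache
def read_files(n):
    call_count.append(1)
    # Placeholder for the expensive CSV parsing and tokenization.
    return list(range(n))


first = read_files(5)   # Computed and written to the cache directory.
second = read_files(5)  # Served from the on-disk cache; body not re-run.
```

Because the cache lives on disk, even a fresh Python process gets the fast path, which is exactly what eliminates the long reload before every experiment.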
The following is an example of a script training 5 models under a 5-Fold validation scheme (the comments mark the parts you need to fill in):

```python
train_val_generator, test_dataset = get_dataset(
    fix_length=100, lower=True, vectors="fasttext.en.300d")
for fold, (train_dataset, val_dataset) in \
        enumerate(train_val_generator):
    # Initialize a model here...
    for batch in get_iterator(
            train_dataset, batch_size=32, train=True):
        pass  # Train the model here...
    # Create predictions for val_dataset and get a score...
    # Create inference for test_dataset...
```
In case you forget, here's an example of extracting the features and targets from a batch:

```python
x = batch.comment_text.data
y = torch.stack([
    batch.toxic, batch.severe_toxic, batch.obscene,
    batch.threat, batch.insult, batch.identity_hate
], dim=1)
```
How long does it take? That depends on the power of your CPU and the read speed of your disk. On my computer, get_dataset takes 6+ minutes the first time it is called, and around one minute after that.