Simple Scaleable Preprocessing with PyTorch and Ray

Background

I have been using PyTorch for a few months now and I really like the Dataset and DataLoader workflow (see torch.utils.data). I realized I might be able to use this workflow for every step in my machine learning pipeline, i.e. preprocessing, training, and inference. I further realized I could use Ray to coordinate multi-node parallelism with few changes to my original code.

Escape Hatch: if you would rather explore the code with no explanation, there is a Jupyter Notebook on GitHub

I believe most folks are using Dataset/DataLoader to handle training and inference pipelines, but let's consider a more general preprocessing workflow. A data scientist needs to write a function which processes their entire data set; the function has the approximate signature:

InputFile -> (OutputFiles, Metadata)

Here, InputFile is an input file in your dataset. The function may produce one or more OutputFiles and some Metadata related to the operation performed. As a practical example, I often have to split large audio files into multiple audio files of a fixed size and retain some metadata (source audio, destination audio, labels).

In this blog post, I'll discuss how to get PyTorch's Dataset and DataLoader workflow running in parallel for this general use case. I will also go over some of the mistakes I made while first exploring this workflow. I will assume the reader knows basic Python.

Why should you care?

I believe this workflow is really easy to teach to beginners. A user only needs to know how to write a function to process an input file and the relationship between batches and parallelism. With the exception of the collate_fn (explained later), the code is essentially boilerplate. If you can implement a Dataset, the parallelism comes almost for free, which is a massive win for beginners.

Up and Running

I am going to build an example data set which mimics the audio splitting example I introduced. I will have a dataset.csv file which contains the following:

input
a.txt
b.txt
c.txt
d.txt

Each TXT file will contain a word (simple, scaleable, preprocessing, and pytorch, respectively). The files will be located in an inputs/ directory. The goal is to split each word into parts with a given number of characters and a given overlap, e.g.

a = "hello"
b = split_word(a, num_chars=2, overlap=1)
assert b == ["he", "el", "ll", "lo"]
c = split_word(a, num_chars=3, overlap=2)
assert c == ["hel", "ell", "llo"]
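
The implementation of split_word isn't shown in this post; here is a minimal sketch consistent with the assertions above, assuming a sliding window that advances by num_chars - overlap characters each step:

def split_word(word, num_chars=2, overlap=1):
    # Slide a window of width num_chars across the word,
    # advancing by (num_chars - overlap) characters each step.
    # Assumes num_chars > overlap so the window always moves forward.
    step = num_chars - overlap
    return [
        word[i:i + num_chars]
        for i in range(0, len(word) - num_chars + 1, step)
    ]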

We can build a Dataset which performs this action on all of the input files. First, let's generate a list of input files. I'll use the built-in CSV library:

import csv

with open("dataset.csv", "r") as csv_file:
    reader = csv.DictReader(csv_file)
    input_files = [f"inputs/{row['input']}" for row in reader]

assert input_files == ["inputs/a.txt", "inputs/b.txt", "inputs/c.txt", "inputs/d.txt"]

To use Dataset, you'll need PyTorch (e.g. pip3 install torch==1.5.0):

from torch.utils.data import Dataset

class WordSplitter(Dataset):
    def __init__(self, inputs, num_chars=2, overlap=1):
        self.inputs = inputs
        self.num_chars = num_chars
        self.overlap = overlap

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        filename = self.inputs[idx]

        with open(filename, "r") as f:
            word = f.read().strip()

        return split_word(
            word,
            num_chars=self.num_chars,
            overlap=self.overlap
        )

For the Dataset to work, we need to define 3 "dunder" methods: __init__, __len__, and __getitem__. The __init__ method stores the input files and the parameters needed to run split_word. The __len__ method returns the length of input_files. The __getitem__ method is where the computation happens. First, we look up the filename at the given index. Second, we read the word from the file and remove any whitespace surrounding it. Finally, we feed the word to split_word with the appropriate parameters. Let's see if it works:

word_splitter = WordSplitter(input_files, num_chars=3, overlap=2)
assert word_splitter[0] == ['sim', 'imp', 'mpl', 'ple']

Awesome. It is really important to make sure your Dataset works before moving on to the next steps. Remember our signature from before:

InputFile -> (OutputFiles, Metadata)

Think of the __getitem__ method in WordSplitter as taking an InputFile, writing no OutputFiles, and producing Metadata related to the operation. In the realistic audio splitting example, the OutputFiles could be written to an outputs/ directory. As a sketch of that variant (the outputs/ directory, file naming scheme, and metadata fields are illustrative assumptions, not code from this project):
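
import os

def split_and_write(filename, num_chars, overlap, out_dir="outputs"):
    # Read the InputFile, split it, write each part as its own OutputFile,
    # and return Metadata describing the operation.
    # Assumes out_dir already exists and split_word is defined as above.
    with open(filename, "r") as f:
        word = f.read().strip()
    metadata = []
    parts = split_word(word, num_chars=num_chars, overlap=overlap)
    for i, part in enumerate(parts):
        # Hypothetical naming scheme: outputs/a.txt.0, outputs/a.txt.1, ...
        out_path = os.path.join(out_dir, f"{os.path.basename(filename)}.{i}")
        with open(out_path, "w") as out:
            out.write(part)
        metadata.append({"source": filename, "destination": out_path})
    return metadata

A __getitem__ could then simply return split_and_write(self.inputs[idx], self.num_chars, self.overlap). Back to our example: we can now wrap word_splitter in a DataLoader and run our analysis in parallel!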

from torch.utils.data import DataLoader

loader = DataLoader(
    word_splitter,
    batch_size=1,
    shuffle=False,
    num_workers=len(word_splitter),
)

The DataLoader bundles our work into batches to be operated on. It takes the word_splitter Dataset object we initialized previously. With batch_size=1, the loader splits our work into 4 batches of 1 file each (batch_size=2 would mean 2 batches of 2 files each). With 4 batches it is possible to split the work over 4 cores on our machine by setting num_workers=len(word_splitter). Important: with batch_size=4 there is only 1 batch to process, so no parallelism can be extracted (i.e. setting num_workers higher will have no effect). The shuffle=False argument asks the loader to process inputs in order (this is the default). The loader behaves like other iterables, i.e. we can print the results in a for loop:

for metadata in loader:
    print(metadata)

Let's look at the output:

[('sim',), ('imp',), ('mpl',), ('ple',)]
[('sca',), ('cal',), ('ale',), ('lea',), ('eab',), ('abl',), ('ble',)]
[('pre',), ('rep',), ('epr',), ('pro',), ('roc',), ('oce',), ('ces',), ('ess',), ('ssi',), ('sin',), ('ing',)]
[('pyt',), ('yto',), ('tor',), ('orc',), ('rch',)]

Hmm... something looks weird: each string is wrapped in a tuple. The issue is that PyTorch applies a default collation function designed for its Tensor type, which doesn't work well here. For a batch of lists of strings, the default collation roughly transposes the batch, zipping together the elements of each sample. A quick sketch of the effect in plain Python (not the actual PyTorch internals):
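
batch = [["sim", "imp", "mpl", "ple"]]  # one sample, since batch_size=1

# Roughly what the default collation does to a batch of sequences:
print([tuple(samples) for samples in zip(*batch)])
# [('sim',), ('imp',), ('mpl',), ('ple',)]

Luckily, we can define our own collate_fn to fix this! In the following code I will use ... to represent code shown above. First, we need to figure out what the input to collate_fn even looks like. Add a debugging collate_fn to WordSplitter: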

class WordSplitter(Dataset):
    ...

    @classmethod
    def collate_fn(*batch):
        print(f"BATCH: {batch}")
        return []

The @classmethod decorator allows us to call WordSplitter.collate_fn (you'll see why in a moment). I use *batch to collect all of the positional arguments into a tuple, however many there are. The collate_fn isn't complete, but it lets us inspect the inputs to the function. Second, we add our new function to the DataLoader:

loader = DataLoader(
    ...,
    collate_fn=WordSplitter.collate_fn,
)

Note, you don't want to run this test over your entire data set; I would suggest doing it on a small subset of inputs. If we loop over the loader again,

BATCH: (<class '__main__.WordSplitter'>, [['sim', 'imp', 'mpl', 'ple']])
BATCH: (<class '__main__.WordSplitter'>, [['sca', 'cal', 'ale', 'lea', 'eab', 'abl', 'ble']])
BATCH: (<class '__main__.WordSplitter'>, [['pre', 'rep', 'epr', 'pro', 'roc', 'oce', 'ces', 'ess', 'ssi', 'sin', 'ing']])
BATCH: (<class '__main__.WordSplitter'>, [['pyt', 'yto', 'tor', 'orc', 'rch']])
[]
[]
[]
[]

Let's set batch_size=2 in the loader and see what happens when there is actual batching,

BATCH: (<class '__main__.WordSplitter'>, [['sim', 'imp', 'mpl', 'ple'], ['sca', 'cal', 'ale', 'lea', 'eab', 'abl', 'ble']])
BATCH: (<class '__main__.WordSplitter'>, [['pre', 'rep', 'epr', 'pro', 'roc', 'oce', 'ces', 'ess', 'ssi', 'sin', 'ing'], ['pyt', 'yto', 'tor', 'orc', 'rch']])
[]
[]

Okay, so our collate_fn receives something like (WordSplitter, [metadata0, metadata1, ...]): the class itself comes first because of the @classmethod binding, and the list of per-sample metadata is what PyTorch actually passes. All we need to do is extract the list of metadata from the tuple and return it, i.e.

@classmethod
def collate_fn(*batch):
    return batch[1]
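
As an aside, an arguably more conventional spelling is a @staticmethod that takes the batch directly, avoiding the implicit class argument (a sketch; the behavior here is equivalent):

@staticmethod
def collate_fn(batch):
    # PyTorch passes the list of per-sample metadata directly;
    # return it unchanged instead of letting the default collation run.
    return batch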

In the for loop, we additionally need to loop over the returned list of metadata, i.e.

for metadatas in loader:
    for metadata in metadatas:
        print(metadata)

Result with batch_size=1,

['sim', 'imp', 'mpl', 'ple']
['sca', 'cal', 'ale', 'lea', 'eab', 'abl', 'ble']
['pre', 'rep', 'epr', 'pro', 'roc', 'oce', 'ces', 'ess', 'ssi', 'sin', 'ing']
['pyt', 'yto', 'tor', 'orc', 'rch']

With batch_size=2,

['sim', 'imp', 'mpl', 'ple']
['sca', 'cal', 'ale', 'lea', 'eab', 'abl', 'ble']
['pre', 'rep', 'epr', 'pro', 'roc', 'oce', 'ces', 'ess', 'ssi', 'sin', 'ing']
['pyt', 'yto', 'tor', 'orc', 'rch']

With batch_size=4,

['sim', 'imp', 'mpl', 'ple']
['sca', 'cal', 'ale', 'lea', 'eab', 'abl', 'ble']
['pre', 'rep', 'epr', 'pro', 'roc', 'oce', 'ces', 'ess', 'ssi', 'sin', 'ing']
['pyt', 'yto', 'tor', 'orc', 'rch']

Heck yes, this is exactly what we want! You could easily write this metadata somewhere for further use. The key thing to remember is that the parallelism happens over batches; in this case, the maximum number of cores that can be used for each batch size is:

batch_size  max cores
1           4
2           2
4           1
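
In general, the number of batches, and hence an upper bound on how many workers can do useful work, is ceil(len(dataset) / batch_size). A quick sanity check (the helper function is just for illustration):

import math

def max_useful_workers(dataset_len, batch_size):
    # Workers consume whole batches, so parallelism is capped by the batch count
    return math.ceil(dataset_len / batch_size)

assert max_useful_workers(4, 1) == 4
assert max_useful_workers(4, 2) == 2
assert max_useful_workers(4, 4) == 1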

The full code is available in a Jupyter Notebook on GitHub. This concludes part 0. Next time we'll look into Ray and let it coordinate the Dataset/DataLoader workflow over multiple nodes!

If you have any suggestions or improvements, please message me on Twitter @chiroptical or submit an issue on GitHub.
