Fine Tuning HuggingFace Models without Overwhelming Your Memory
A journey to scaling the training of HuggingFace models to large datasets with tokenizers and the Trainer API.
There are a lot of example notebooks for different NLP tasks that can be accomplished with the mighty HuggingFace library. When I tried to apply one of these tasks to a custom problem and a dataset of my own, I ran into a major issue with memory usage.
The examples presented by HuggingFace follow a pipeline that first applies a tokenizer and then a model. However, tokenizing your whole dataset up front can be heavy on memory, and you might not even get to the model training part because of a MemoryError. Therefore, I have modified the example custom NER task given here and presented a workaround that tokenizes and processes data points batch by batch through PyTorch's Dataset class.
First, download the data into your current directory from the link here. We can then read and prepare the data for our use.
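The reading code is not reproduced in this version of the post. As a minimal sketch, assuming the WNUT file uses the usual CoNLL-style layout (one token and its tag per line, separated by a tab, with a blank line between documents), the data could be read into parallel lists of tokens and tags as follows. The function name read_wnut and the file name are assumptions.

```python
import re
from pathlib import Path


def read_wnut(file_path):
    """Read a CoNLL-style WNUT file into parallel lists of tokens and tags.

    Assumes one "token<TAB>tag" pair per line and a blank line (optionally
    containing a single tab) between documents.
    """
    raw_text = Path(file_path).read_text(encoding="utf-8").strip()
    raw_docs = re.split(r"\n\t?\n", raw_text)  # one chunk per document

    token_docs, tag_docs = [], []
    for doc in raw_docs:
        tokens, tags = [], []
        for line in doc.split("\n"):
            token, tag = line.split("\t")
            tokens.append(token)
            tags.append(tag)
        token_docs.append(tokens)
        tag_docs.append(tags)
    return token_docs, tag_docs


# Hypothetical file name; point this at wherever you saved the data.
# texts, tags = read_wnut("wnut17train.conll")
```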
Next, we split the data into training and testing sets and build the mappings from tags to ids and back.
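A stdlib-only sketch of this step follows. The 80/20 split, the fixed seed, and the helper names train_test_split_docs and build_tag_maps are assumptions for illustration; scikit-learn's train_test_split would serve equally well for the split itself.

```python
import random


def train_test_split_docs(texts, tags, test_size=0.2, seed=42):
    """Shuffle whole documents and split them into train and test sets.

    Indices are shuffled rather than the lists themselves, so each token list
    stays paired with its tag list.
    """
    indices = list(range(len(texts)))
    random.Random(seed).shuffle(indices)
    n_test = int(len(indices) * test_size)
    test_idx, train_idx = indices[:n_test], indices[n_test:]
    train_texts = [texts[i] for i in train_idx]
    test_texts = [texts[i] for i in test_idx]
    train_tags = [tags[i] for i in train_idx]
    test_tags = [tags[i] for i in test_idx]
    return train_texts, test_texts, train_tags, test_tags


def build_tag_maps(tags):
    """Assign an integer id to each distinct tag and build the reverse map."""
    unique_tags = sorted({tag for doc in tags for tag in doc})  # sorted: stable ids across runs
    tag2id = {tag: i for i, tag in enumerate(unique_tags)}
    id2tag = {i: tag for tag, i in tag2id.items()}
    return tag2id, id2tag
```

The tag maps are built from all tags rather than only the training ones, so every tag that appears in the test set also has an id.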
Now, we need to put the data in a format that a HuggingFace model can process via the Trainer API. For that purpose, we create a subclass of torch.utils.data.Dataset, which the Trainer API can use together with a HuggingFace PyTorch model.
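The following is a minimal sketch of such a subclass, not the exact code used in this post. It assumes a fast tokenizer such as DistilBertTokenizerFast, whose output provides word_ids(); the class name WNUTDataset, the max_length default, padding every item to that length, and labeling only the first sub-token of each word are assumptions. The tokenizer is passed in rather than created inside, and the import falls back to object when PyTorch is not installed so the sketch can still be loaded.

```python
try:
    from torch.utils.data import Dataset
except ImportError:  # lets the sketch load without PyTorch; use the real class in practice
    Dataset = object


class WNUTDataset(Dataset):
    """Map-style dataset that tokenizes one example at a time, on demand."""

    def __init__(self, texts, tags, tag2id, tokenizer, max_length=64):
        # Only cheap work here: keep references to the raw words and tags.
        self.texts = texts
        self.tags = tags
        self.tag2id = tag2id
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # Tokenization happens here, so only the requested example is encoded.
        encoding = self.tokenizer(
            self.texts[idx],
            is_split_into_words=True,  # WNUT examples are already lists of words
            truncation=True,
            padding="max_length",      # fixed length so items can be stacked into a batch
            max_length=self.max_length,
        )
        word_labels = [self.tag2id[tag] for tag in self.tags[idx]]
        labels = self.align_labels(encoding.word_ids(), word_labels)
        # Return only keys that the model's forward() accepts.
        return {
            "input_ids": encoding["input_ids"],
            "attention_mask": encoding["attention_mask"],
            "labels": labels,
        }

    @staticmethod
    def align_labels(word_ids, word_labels):
        """Give each token a label; special, padding, and continuation tokens get -100."""
        labels, previous = [], None
        for word_id in word_ids:
            if word_id is None or word_id == previous:
                labels.append(-100)  # ignored by the loss
            else:
                labels.append(word_labels[word_id])  # first sub-token of a word
            previous = word_id
        return labels
```

With transformers installed, one might create the tokenizer with DistilBertTokenizerFast.from_pretrained("distilbert-base-cased") (the checkpoint name is an assumption) and build WNUTDataset(train_texts, train_tags, tag2id, tokenizer), and likewise for validation. Returning plain lists should be fine, because Trainer's default data collator converts them to tensors when it stacks a batch.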
Note that the subclass we created needs to override two functions: __len__, which is used when sampling the different batches, and __getitem__, which is used when a single item of a batch is requested (that is why it also takes an index parameter).
The __getitem__ function is called on each sample when it is being processed. Therefore, we can put our memory-heavy computations in this function to avoid a memory overhead. In this example, we have placed the tokenization step in this function, which frees us from holding the tokenized version of every example in the dataset in memory. This way, we only tokenize and hold the current batch in memory.
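To see the deferral at work, here is a stdlib-only toy. CountingTokenizer and LazyDataset are hypothetical stand-ins that only count how often tokenization happens; a real data loader requests items index by index in the same way.

```python
class CountingTokenizer:
    """Toy stand-in for a tokenizer: encodes words by length and counts calls."""

    def __init__(self):
        self.calls = 0

    def __call__(self, words):
        self.calls += 1
        return {"input_ids": [len(word) for word in words]}


class LazyDataset:
    """Keeps raw words in memory and tokenizes an item only when it is requested."""

    def __init__(self, texts, tokenizer):
        self.texts = texts          # cheap: just a reference to the raw data
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return self.tokenizer(self.texts[idx])  # tokenization deferred to access time


if __name__ == "__main__":
    tokenizer = CountingTokenizer()
    dataset = LazyDataset([["New", "York"]] * 10_000, tokenizer)
    print(tokenizer.calls)  # 0: creating the dataset tokenized nothing
    batch = [dataset[i] for i in range(4)]  # what a loader does for one batch of 4
    print(tokenizer.calls)  # 4: only the requested examples were tokenized
```

The trade-off is compute: each example is tokenized again every time it is drawn, so once per epoch, in exchange for never holding the whole tokenized dataset in memory.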
We should also note that, to use such a PyTorch Dataset object through HuggingFace's Trainer API, the __getitem__ function should only return keys that are named parameters of the forward function of the HuggingFace PyTorch model. Therefore, in this example we only return keys of that kind, such as attention_mask.
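One way to see which keys are safe is to read them off the signature of the model's forward method. The helper keep_forward_args below is hypothetical; with a loaded model you would pass model.forward. A key such as offset_mapping, which a fast tokenizer returns when asked for offsets, is a typical example that forward does not accept. Newer transformers releases may drop unused keys on their own when the remove_unused_columns training argument is enabled, but filtering explicitly does not depend on that behavior.

```python
import inspect


def keep_forward_args(features, forward):
    """Return only the entries of `features` that `forward` accepts by name.

    Trainer calls the model as model(**batch), so a key that forward() does
    not declare can raise a TypeError unless something removes it first.
    """
    params = inspect.signature(forward).parameters
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(features)  # forward(**kwargs) accepts any name
    return {key: value for key, value in features.items() if key in params}


if __name__ == "__main__":
    # Hypothetical forward() with the parameter names a token classifier uses.
    def toy_forward(input_ids=None, attention_mask=None, labels=None):
        return None

    item = {"input_ids": [101, 102], "attention_mask": [1, 1],
            "labels": [-100, -100], "offset_mapping": [(0, 0), (0, 0)]}
    print(sorted(keep_forward_args(item, toy_forward)))
    # ['attention_mask', 'input_ids', 'labels']
```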
We can also override the __init__ function, which is called whenever a dataset is created. Memory-heavy operations should not be placed in __init__; we suggest limiting it to initialization and basic reading of data from a file.
We have also defined another function called align_labels, which aligns the data tags with the text after tokenization: passing the text through the tokenizer splits words into sub-tokens, so the sequence length changes. Following the documentation of the transformers library, we label the special tokens that a tokenizer such as BertTokenizer adds (for example [CLS] and [SEP]) with the index -100, which the loss function ignores.
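A sketch of such an alignment function follows, assuming a fast tokenizer whose word_ids() gives, for each token, the index of the word it came from (None for special tokens and padding). The name matches the one used in this post, but the body is an illustration and may differ from the original; the label_all_subtokens switch is an added assumption. PyTorch's cross-entropy loss ignores the target -100 by default, which is why it works as the "no label" marker.

```python
def align_labels(word_ids, word_labels, label_all_subtokens=False):
    """Expand word-level label ids to one label per token.

    word_ids: for each token, the index of its source word, or None for
        special tokens ([CLS], [SEP]) and padding, as given by a fast
        tokenizer's BatchEncoding.word_ids().
    word_labels: one integer label id per original word.
    """
    labels = []
    previous_word = None
    for word_id in word_ids:
        if word_id is None:
            labels.append(-100)  # special or padding token: ignored by the loss
        elif word_id != previous_word:
            labels.append(word_labels[word_id])  # first sub-token of a word
        else:
            # Later sub-token of the same word: either repeat the label or ignore it.
            labels.append(word_labels[word_id] if label_all_subtokens else -100)
        previous_word = word_id
    return labels
```

For example, align_labels([None, 0, 1, 1, 2, None], [3, 0, 4]) returns [-100, 3, 0, -100, 4, -100]; with label_all_subtokens=True, the second piece of word 1 is labeled 0 instead of -100.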
As a tokenizer, we use DistilBertTokenizerFast from the transformers library. In the custom subclass we prepared, we tell the tokenizer that its input is already split into words, since the WNUT dataset is already in a list format, and we limit the sequence length with truncation=True, since the maximum token length found in the data was 40.
We can now create our training and validation datasets using our subclass and start training our model.
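A minimal sketch of this training setup follows, assuming train_dataset and val_dataset are instances of the custom Dataset subclass and compute_metrics returns an "f1" key. The checkpoint name, output directory, epoch count, batch sizes, and early-stopping patience are assumed values, and the argument eval_strategy follows recent transformers releases (older ones call it evaluation_strategy). The imports sit inside the function so the sketch can be loaded without transformers installed.

```python
def train_ner_model(train_dataset, val_dataset, tag2id, id2tag, compute_metrics):
    """Fine-tune DistilBERT for token classification with early stopping."""
    from transformers import (
        DistilBertForTokenClassification,
        EarlyStoppingCallback,
        Trainer,
        TrainingArguments,
    )

    model = DistilBertForTokenClassification.from_pretrained(
        "distilbert-base-cased",  # assumed checkpoint; must match the tokenizer
        num_labels=len(tag2id),
        id2label=id2tag,
        label2id=tag2id,
    )

    args = TrainingArguments(
        output_dir="./results",           # checkpoints are written here
        num_train_epochs=10,              # upper bound; early stopping may end sooner
        per_device_train_batch_size=16,
        per_device_eval_batch_size=64,
        eval_strategy="epoch",            # evaluate once per epoch
        save_strategy="epoch",            # must match eval_strategy to reload the best model
        load_best_model_at_end=True,      # EarlyStoppingCallback expects this
        metric_for_best_model="f1",       # a key returned by compute_metrics
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
    )
    trainer.train()
    return trainer
```

If each item is padded to a fixed length, Trainer's default collator can stack items directly; if items are left unpadded, a collator such as DataCollatorForTokenClassification would be needed to pad each batch.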
We use the DistilBertForTokenClassification model, which puts a token classification (NER) head on top of DistilBERT. We specify the number of output labels when loading the model, and we specify the necessary training arguments via a TrainingArguments object. Then, we create a Trainer object from the model, the arguments, and the datasets we have defined. Note that we also add an EarlyStoppingCallback to the trainer so that training stops once the validation metric stops improving, before the model overfits. We also pass another argument to the Trainer called compute_metrics, which lets us supply a metric computation function that tracks the model's performance during training. The compute_metrics function we used depends on the seqeval package, which can be found here.
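The function itself is not reproduced in this version of the post. A minimal sketch in the same spirit is shown below, assuming numpy and seqeval are installed and that id2tag maps label ids back to tag strings; the names to_tag_sequences and make_compute_metrics are hypothetical. Positions labeled -100 are dropped before scoring, so special tokens, padding, and ignored sub-tokens do not affect the entity-level scores.

```python
import numpy as np


def to_tag_sequences(predictions, label_ids, id2tag):
    """Convert logits and gold label ids into tag strings, skipping -100 positions.

    predictions: array of shape (batch, seq_len, num_labels) with model logits.
    label_ids: array of shape (batch, seq_len) with gold ids; -100 marks
        special tokens, padding, and other ignored positions.
    """
    pred_ids = np.argmax(predictions, axis=2)
    true_tags, pred_tags = [], []
    for pred_row, label_row in zip(pred_ids, label_ids):
        kept = [(pred, gold) for pred, gold in zip(pred_row, label_row) if gold != -100]
        true_tags.append([id2tag[int(gold)] for _, gold in kept])
        pred_tags.append([id2tag[int(pred)] for pred, _ in kept])
    return true_tags, pred_tags


def make_compute_metrics(id2tag):
    """Build a compute_metrics function for Trainer that scores entities with seqeval."""
    # Imported lazily so to_tag_sequences can be used without seqeval.
    from seqeval.metrics import accuracy_score, f1_score, precision_score, recall_score

    def compute_metrics(eval_pred):
        predictions, label_ids = eval_pred  # EvalPrediction unpacks like a tuple
        true_tags, pred_tags = to_tag_sequences(predictions, label_ids, id2tag)
        return {
            "precision": precision_score(true_tags, pred_tags),
            "recall": recall_score(true_tags, pred_tags),
            "f1": f1_score(true_tags, pred_tags),
            "accuracy": accuracy_score(true_tags, pred_tags),
        }

    return compute_metrics
```

Calling make_compute_metrics(id2tag) then yields the function to pass to the Trainer as compute_metrics.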
We can now fine-tune a HuggingFace BERT model for an NER task without overwhelming our memory :)