Fine Tuning HuggingFace Models without Overwhelming Your Memory

A journey to scaling the training of HuggingFace models for large data through tokenizers and Trainer API.

Photo by Bernard Hermant on Unsplash

There are many example notebooks for the different NLP tasks that can be accomplished with the mighty HuggingFace library. When I tried to apply one of these tasks to a custom problem and a dataset of my own, I ran into a major issue with memory usage.

The examples presented by HuggingFace follow a pipeline of tokenizing the data first and then training a model. However, tokenizing your whole dataset up front can be a heavy load on your memory, and you might not even reach the model training step because of a MemoryError. Therefore, I have modified the custom NER task example given here and present a workaround that tokenizes and processes data points batch by batch through PyTorch’s Dataset class.

First, download the data into your current directory from the link here. We can then read and prepare the data for our use.
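The original reading code is not reproduced here. The following is a minimal sketch. It assumes the file is the WNUT 17 training set in CoNLL style, with one token<TAB>tag pair per line and a blank line between sentences. It also assumes a hypothetical file name, wnut17train.conll. Under those assumptions, the data could be read into parallel lists of words and tags like this:

```python
import re
from pathlib import Path


def parse_conll(raw_text):
    """Split CoNLL-style text into parallel lists of words and tags, one list per sentence."""
    # Sentences are separated by an empty line (possibly containing a lone tab).
    raw_docs = re.split(r"\n\t?\n", raw_text.strip())
    token_docs, tag_docs = [], []
    for doc in raw_docs:
        tokens, tags = [], []
        for line in doc.split("\n"):
            token, tag = line.split("\t")
            tokens.append(token)
            tags.append(tag)
        token_docs.append(tokens)
        tag_docs.append(tags)
    return token_docs, tag_docs


def read_wnut(file_path):
    """Read a WNUT-style file from disk and parse it (file name is an assumption)."""
    return parse_conll(Path(file_path).read_text(encoding="utf-8"))
```

With the file in place, texts, tags = read_wnut("wnut17train.conll") gives one list of words and one list of tags per sentence. Parsing is kept separate from file reading so it can be checked on a small string.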

Next, we split the data into training and testing sets and build the mappings from tags to ids and back.
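Below is a possible sketch of the split and the mappings. It uses a seeded shuffle from the standard library so that it stays self-contained. scikit-learn's train_test_split would do the same job. The function names are hypothetical. The mappings are built from all tags rather than only the training tags, so that every tag in the test set has an id.

```python
import random


def train_test_split_docs(texts, tags, test_size=0.2, seed=42):
    """Shuffle sentence indices with a fixed seed and split both parallel lists the same way."""
    indices = list(range(len(texts)))
    random.Random(seed).shuffle(indices)
    n_test = int(len(indices) * test_size)
    test_idx, train_idx = indices[:n_test], indices[n_test:]
    return (
        [texts[i] for i in train_idx],
        [texts[i] for i in test_idx],
        [tags[i] for i in train_idx],
        [tags[i] for i in test_idx],
    )


def build_tag_maps(tag_docs):
    """Map each tag string to an integer id and back; sorting keeps the ids reproducible."""
    unique_tags = sorted({tag for doc in tag_docs for tag in doc})
    tag2id = {tag: i for i, tag in enumerate(unique_tags)}
    id2tag = {i: tag for tag, i in tag2id.items()}
    return tag2id, id2tag
```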

Now we need to put the data into a format that a HuggingFace model can consume through the Trainer API. For that purpose, we create a subclass of PyTorch’s torch.utils.data.Dataset, which the Trainer API can use together with a HuggingFace PyTorch model.

Note that the subclass we create needs to override two methods. The first is __len__, which the sampler uses to know how many items exist when it builds batches. The second is __getitem__, which returns a single item of a batch given its index; that is why it accepts the parameter idx.
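To make the roles of the two methods concrete, here is a simplified, pure-Python imitation of how a sampler drives a map-style dataset. It is not PyTorch's DataLoader code, which also collates items into tensors and can use worker processes. The class and function names are hypothetical.

```python
import random


class ListDataset:
    """A tiny map-style dataset with the same two methods a torch Dataset subclass overrides."""

    def __init__(self, items):
        self.items = items

    def __len__(self):
        # The sampler calls this to know which indices exist: range(len(dataset)).
        return len(self.items)

    def __getitem__(self, idx):
        # Called once per index while a batch is being assembled.
        return self.items[idx]


def iterate_batches(dataset, batch_size, shuffle=True, seed=0):
    """Rough imitation of a DataLoader with a random sampler."""
    indices = list(range(len(dataset)))
    if shuffle:
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        # Only the items of the current batch are materialized here.
        yield [dataset[i] for i in indices[start:start + batch_size]]
```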

The __getitem__ method is called for each sample only when that sample is being loaded. Therefore, we can move our memory-heavy computations into this method to avoid a memory overhead. In this example, we perform tokenization inside __getitem__, which frees us from holding the tokenized version of every example in the dataset in memory. This way, we only tokenize and hold the current batch in memory.
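The sketch below illustrates the lazy-tokenization pattern under two stated assumptions. First, ToyTokenizer is a hypothetical stand-in for DistilBertTokenizerFast that maps each word to exactly one id, so word-level tags line up with tokens without any extra alignment. Second, LazyNERDataset is a duck-typed version of the subclass. In real code, the class would subclass torch.utils.data.Dataset and return tensors. The call counter shows that nothing is tokenized until an item is requested.

```python
class ToyTokenizer:
    """Hypothetical stand-in for DistilBertTokenizerFast: one id per word, no sub-tokens."""

    def __init__(self):
        self.vocab = {"[PAD]": 0}
        self.calls = 0  # counts how many times tokenization actually runs

    def __call__(self, words, max_length):
        self.calls += 1
        # Assign a new id to unseen words, then truncate to max_length.
        ids = [self.vocab.setdefault(w.lower(), len(self.vocab)) for w in words][:max_length]
        n_pad = max_length - len(ids)
        return {
            "input_ids": ids + [self.vocab["[PAD]"]] * n_pad,
            "attention_mask": [1] * len(ids) + [0] * n_pad,
        }


class LazyNERDataset:
    """Duck-typed sketch of the Dataset subclass: raw data in memory, tokens on demand."""

    def __init__(self, texts, tags, tokenizer, tag2id, max_length=64):
        # Keep __init__ cheap: store references only, tokenize nothing yet.
        self.texts = texts
        self.tags = tags
        self.tokenizer = tokenizer
        self.tag2id = tag2id
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # The memory-heavy step runs here, for a single sentence at a time.
        encoding = self.tokenizer(self.texts[idx], max_length=self.max_length)
        labels = [self.tag2id[t] for t in self.tags[idx]][: self.max_length]
        labels += [-100] * (self.max_length - len(labels))  # padding is ignored by the loss
        # Return only keys that match named parameters of the model's forward().
        return {
            "input_ids": encoding["input_ids"],
            "attention_mask": encoding["attention_mask"],
            "labels": labels,
        }
```

Because __init__ only stores references, creating the dataset is cheap regardless of its size. The tokenization cost is paid per item, on every epoch. That design trades some repeated computation for a much smaller memory footprint.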

We should also note that, to use such a PyTorch Dataset object with HuggingFace’s Trainer API, the dictionary returned by __getitem__ should only contain keys that are named parameters of the forward method of the HuggingFace PyTorch model. Therefore, in this example we only return the keys input_ids, attention_mask, and labels.
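One way to keep the returned dictionary in line with the model is to compare its keys with the parameters of forward, using Python's inspect module. In the sketch below, forward is a hypothetical stand-in with only three parameters; a real HuggingFace model's forward accepts more. Unpacking a dictionary that contains an unexpected key, such as offset_mapping, into a function call raises a TypeError. Depending on the transformers version, the Trainer may or may not drop such keys for you, so filtering explicitly is a safe default.

```python
import inspect


def forward(input_ids=None, attention_mask=None, labels=None):
    """Hypothetical stand-in for a model's forward(); a real HF model has more parameters."""
    return {"n_tokens": len(input_ids)}


def keep_forward_args(item, fn):
    """Drop every key that fn does not accept as a named parameter."""
    accepted = set(inspect.signature(fn).parameters)
    return {key: value for key, value in item.items() if key in accepted}
```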

We can also override __init__, which is called whenever a dataset is created. Memory-heavy operations should not go into __init__. We suggest limiting this method to initialization and basic reading of data from a file.

We have also defined another function, align_labels, which aligns the data tags with the text after tokenization. This is needed because the tokenizer splits some words into sub-tokens, so the tokenized sequence no longer has the same length as the original list of words. In line with the documentation of the transformers library, we tag the special tokens that a tokenizer like BertTokenizer adds with the index -100, which PyTorch's cross-entropy loss ignores by default.
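The author's align_labels is not shown here. The following is a minimal sketch of one common approach. It assumes a fast tokenizer whose BatchEncoding.word_ids() gives, for each token, the index of the word it came from, or None for special tokens and padding. The first sub-token of each word gets the word's tag. Special tokens, padding, and (by default) the remaining sub-tokens get -100. Labeling every sub-token instead is also a valid choice; it is exposed here through a hypothetical label_all_subtokens flag.

```python
def align_labels(word_ids, word_tags, tag2id, label_all_subtokens=False):
    """Turn word-level tags into one label per token.

    word_ids has one entry per token: the index of the source word,
    or None for special tokens ([CLS], [SEP]) and padding.
    """
    labels = []
    previous = None
    for word_id in word_ids:
        if word_id is None:
            labels.append(-100)  # special token or padding: ignored by the loss
        elif word_id != previous:
            labels.append(tag2id[word_tags[word_id]])  # first sub-token of a word
        else:
            # Later sub-tokens of the same word.
            labels.append(tag2id[word_tags[word_id]] if label_all_subtokens else -100)
        previous = word_id
    return labels
```

Inside __getitem__, it could be called as align_labels(encoding.word_ids(), self.tags[idx], self.tag2id), where encoding is the tokenizer output for one sentence. These variable names are illustrative.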

As a tokenizer, we use DistilBertTokenizerFast from the transformers library. In the custom subclass, we call it with is_split_into_words=True, because the WNUT dataset is already split into lists of words. We also pass max_length=64, padding='max_length' and truncation=True, because the maximum token length found in the data was 40. With these settings, every example is padded (or, if needed, truncated) to the same length of 64.

We can now create our training and validation datasets using our subclass and start training our model.

We use the DistilBertForTokenClassification model, which adds a token classification head on top of DistilBERT, as used for NER. We pass the number of output labels when loading the model, and we specify the necessary training arguments through a TrainingArguments object. Then we create a Trainer object from the model, the arguments, and the datasets we have defined. Note that we also add an EarlyStoppingCallback to the trainer so that training stops once the evaluation metric no longer improves, before the model overfits. This callback relies on load_best_model_at_end=True in the TrainingArguments and monitors the metric named by metric_for_best_model.

We also pass another argument to the Trainer, compute_metrics. It takes a metric computation function, which lets us track the performance of the model during training. The compute_metrics function we used relies on the seqeval package, which can be found here.
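The original compute_metrics code is not reproduced here. The sketch below makes two assumptions. First, the Trainer passes an EvalPrediction whose predictions are per-token logits and whose label_ids use -100 for ignored positions. Second, id2tag is the id-to-tag mapping built earlier. The helper drops the ignored positions and turns ids back into tag strings, which is the form seqeval's metric functions expect. make_compute_metrics and to_tag_sequences are hypothetical names. seqeval is imported inside the metric function, so that the helper can be used without it.

```python
def to_tag_sequences(logits, label_ids, id2tag):
    """Turn per-token logits and gold label ids into lists of tag strings,
    skipping positions labelled -100 (special tokens, padding, extra sub-tokens)."""
    pred_seqs, gold_seqs = [], []
    for sent_logits, sent_labels in zip(logits, label_ids):
        preds, golds = [], []
        for token_logits, gold in zip(sent_logits, sent_labels):
            if gold == -100:
                continue  # ignored by the loss, so also ignored by the metric
            # Argmax over the label dimension gives the predicted label id.
            pred_id = max(range(len(token_logits)), key=lambda j: token_logits[j])
            preds.append(id2tag[pred_id])
            golds.append(id2tag[int(gold)])
        pred_seqs.append(preds)
        gold_seqs.append(golds)
    return pred_seqs, gold_seqs


def make_compute_metrics(id2tag):
    """Build a compute_metrics function for the Trainer that closes over id2tag."""

    def compute_metrics(eval_pred):
        # Imported here so that to_tag_sequences works without seqeval installed.
        from seqeval.metrics import f1_score, precision_score, recall_score

        y_pred, y_true = to_tag_sequences(eval_pred.predictions, eval_pred.label_ids, id2tag)
        return {
            "precision": precision_score(y_true, y_pred),
            "recall": recall_score(y_true, y_pred),
            "f1": f1_score(y_true, y_pred),
        }

    return compute_metrics
```

The returned function would be passed as compute_metrics=make_compute_metrics(id2tag) when the Trainer is constructed. On large evaluation sets, a NumPy argmax over the logits would be faster than the pure-Python loop used here for clarity.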

We can now fine-tune a HuggingFace BERT model for an NER task without overwhelming our memory :)

Please follow me on Medium and Twitter for more to come :)

Machine Learning Engineer & Enthusiast
