Summarization with blurr
blurr is a library I started that integrates huggingface transformers with the world of fastai v2, giving fastai devs everything they need to train, evaluate, and deploy transformer-specific models. In this article, I provide a simple example of how to use blurr's new summarization capabilities to train, evaluate, and deploy a BART summarization model.
# !pip install transformers -Uqq
# !pip install datasets -Uqq
# !pip install bert-score -Uqq
# !pip install sacremoses
# !pip install ohmeow-blurr -Uqq
import datasets
import pandas as pd
from fastai.text.all import *
from transformers import *
from blurr.text.data.all import *
from blurr.text.modeling.all import *
import nltk
nltk.download('punkt', quiet=True)
We're going to use the datasets library from huggingface to grab our raw data. This package gives you access to all kinds of NLP-related datasets, explanations of each, and various task-specific metrics to use in evaluating your model. The best part is that everything comes down to you in JSON, which makes it a breeze to get up and running quickly!
We'll just use a subset of the training set to build both our training and validation DataLoaders.
raw_data = datasets.load_dataset('cnn_dailymail', '3.0.0', split='train[:1%]')
df = pd.DataFrame(raw_data)
df.head()
We begin by getting our huggingface objects needed for this task (e.g., the architecture, tokenizer, config, and model). We'll use blurr's get_hf_objects helper method here.
pretrained_model_name = "facebook/bart-large-cnn"
hf_arch, hf_config, hf_tokenizer, hf_model = get_hf_objects(pretrained_model_name, model_cls=BartForConditionalGeneration)
hf_arch, type(hf_config), type(hf_tokenizer), type(hf_model)
Next we need to build out our DataBlock. Remember that a DataBlock is a blueprint describing how to move your raw data into something modelable. That blueprint is executed when we pass it a data source, which in our case will be the DataFrame we created above. We'll use a random subset to get things moving along a bit faster for the demo as well.
Notice that the blurr DataBlock has been dramatically simplified given the shift to on-the-fly, batch-time tokenization. All we need is to define a single Seq2SeqBatchTokenizeTransform instance, optionally passing a list to any of the tokenization arguments to differentiate the values for the input and summary sequences. In addition to specifying a custom max length for the inputs, we can also do the same for the output sequences ... and with the latest release of blurr, we can even customize the text generation by passing in text_gen_kwargs.
We pass noop as the type transform for our targets because everything is already handled by the batch transform now.
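For reference, a noop transform is essentially an identity function that ignores any extra arguments; a minimal sketch of what fastai's noop does looks like this (not fastcore's actual source, just an illustration):

```python
def noop(x=None, *args, **kwargs):
    """Return the input unchanged, ignoring any extra arguments."""
    return x

# The target text passes through untouched; the batch transform
# does the real tokenization work later.
print(noop("Two incidents involving Malaysia Airlines planes ..."))
```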
text_gen_kwargs = default_text_gen_kwargs(hf_config, hf_model, task='summarization'); text_gen_kwargs
hf_batch_tfm = Seq2SeqBatchTokenizeTransform(
hf_arch, hf_config, hf_tokenizer, hf_model, max_length=256, max_tgt_length=130, text_gen_kwargs=text_gen_kwargs
)
blocks = (Seq2SeqTextBlock(batch_tokenize_tfm=hf_batch_tfm), noop)
dblock = DataBlock(blocks=blocks, get_x=ColReader('article'), get_y=ColReader('highlights'), splitter=RandomSplitter())
dls = dblock.dataloaders(df, bs=2)
len(dls.train.items), len(dls.valid.items)
It's always a good idea to check out a batch of data and make sure the shapes look right.
b = dls.one_batch()
len(b), b[0]['input_ids'].shape, b[1].shape
Even better, we can take advantage of blurr's TypeDispatched version of show_batch to look at things a bit more intuitively. We pass in the dls via the dataloaders argument so we can access all the tokenization/modeling configuration stored in our batch transform above.
dls.show_batch(dataloaders=dls, max_n=2)
We'll prepare our BART model for training by wrapping it in blurr's BaseModelWrapper object and using the BaseModelCallback callback, as usual. A new Seq2SeqMetricsCallback object allows us to specify the Seq2Seq metrics we want to use: things like rouge and bertscore for tasks like summarization, as well as metrics such as meteor, bleu, and sacrebleu for translation tasks. Using huggingface's metrics library is as easy as specifying a metrics configuration like the one below.
Once we have everything in place, we'll freeze our model so that only the last layer group's parameters are trainable. See here for how discriminative learning rates work in fastai.
Note: This has been tested with a lot of other Seq2Seq models; see the docs for more information.
seq2seq_metrics = {
'rouge': {
'compute_kwargs': { 'rouge_types': ["rouge1", "rouge2", "rougeL"], 'use_stemmer': True },
'returns': ["rouge1", "rouge2", "rougeL"]
},
'bertscore': {
'compute_kwargs': { 'lang': 'en' },
'returns': ["precision", "recall", "f1"]
}
}
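To build some intuition for what the rouge1 entry in the config above actually measures, here's a minimal hand-rolled sketch of unigram-overlap precision, recall, and F1. The real rouge metric in huggingface's library also handles stemming, bigrams (rouge2), longest common subsequences (rougeL), and more, so treat this purely as an illustration:

```python
from collections import Counter

def rouge1(prediction: str, reference: str):
    """Unigram-overlap ROUGE-1: returns (precision, recall, f1)."""
    pred_counts = Counter(prediction.lower().split())
    ref_counts = Counter(reference.lower().split())
    # Clipped overlap: each reference token can only be matched once
    overlap = sum((pred_counts & ref_counts).values())
    precision = overlap / max(sum(pred_counts.values()), 1)
    recall = overlap / max(sum(ref_counts.values()), 1)
    f1 = 2 * precision * recall / (precision + recall) if overlap else 0.0
    return precision, recall, f1

p, r, f = rouge1("the cat sat on the mat", "the cat lay on the mat")
print(f"precision={p:.2f} recall={r:.2f} f1={f:.2f}")
```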
model = BaseModelWrapper(hf_model)
learn_cbs = [BaseModelCallback]
fit_cbs = [Seq2SeqMetricsCallback(custom_metrics=seq2seq_metrics)]
learn = Learner(dls,
model,
opt_func=ranger,
loss_func=CrossEntropyLossFlat(),
cbs=learn_cbs,
splitter=partial(blurr_seq2seq_splitter, arch=hf_arch)).to_fp16()
learn.create_opt()
learn.freeze()
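With the model frozen, only the last layer group will train; when we later pass a learning rate range, fastai spreads rates across the layer groups so earlier (more general) layers learn more slowly. A sketch of the geometric spacing idea (in the spirit of fastai's even_mults helper, not its actual implementation) might look like:

```python
def even_mults(start: float, stop: float, n: int):
    """Geometrically spaced learning rates from start to stop, one per layer group."""
    if n == 1:
        return [stop]
    mult = (stop / start) ** (1 / (n - 1))
    return [start * mult**i for i in range(n)]

# e.g., three layer groups: embeddings/encoder get smaller rates than the head
print(even_mults(3e-7, 3e-5, 3))
```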
I'm still experimenting with how to use fastai's learning rate finder for these kinds of models, so if you have any suggestions or interesting insights to share, please let me know. We're only going to train the frozen model for one epoch for this demo, but feel free to progressively unfreeze the model and train the other layers to see if you can best my results below.
learn.lr_find()
It's also not a bad idea to run a batch through your model and make sure the shape of what goes in, and comes out, looks right.
b = dls.one_batch()
preds = learn.model(b[0])
len(preds),preds[0], preds[1].shape
learn.fit_one_cycle(1, lr_max=3e-5, cbs=fit_cbs)
And now we can look at the generated predictions using our text_gen_kwargs from above.
learn.show_results(learner=learn, max_n=2)
Even better though, blurr augments the fastai Learner with a blurr_summarize method that allows you to use huggingface's PreTrainedModel.generate method to create something more human-like.
test_article = """
The past 12 months have been the worst for aviation fatalities so far this decade - with the total of number of people killed if airline
crashes reaching 1,050 even before the Air Asia plane vanished. Two incidents involving Malaysia Airlines planes - one over eastern Ukraine and the other in the Indian Ocean - led to the deaths of 537 people, while an Air Algerie crash in Mali killed 116 and TransAsia Airways crash in Taiwan killed a further 49 people. The remaining 456 fatalities were largely in incidents involving small commercial planes or private aircraft operating on behalf of companies, governments or organisations. Despite 2014 having the highest number of fatalities so far this decade, the total number of crashes was in fact the lowest since the first commercial jet airliner took off in 1949 - totalling just 111 across the whole world over the past 12 months. The all-time deadliest year for aviation was 1972 when a staggering 2,429 people were killed in a total of 55 plane crashes - including the crash of Aeroflot Flight 217, which killed 174 people in Russia, and Convair 990 Coronado, which claimed 155 lives in Spain. However this year's total death count of 1,212, including those presumed dead on board the missing Air Asia flight, marks a significant rise on the very low 265 fatalities in 2013 - which led to it being named the safest year in aviation since the end of the Second World War. Scroll down for videos. Deadly: The past 12 months have been the worst for aviation fatalities so far this decade - with the total of number of people killed if airline crashes reaching 1,158 even before the Air Asia plane (pictured) vanished. Fatal: Two incidents involving Malaysia Airlines planes - one over eastern Ukraine (pictured) and the other in the Indian Ocean - led to the deaths of 537 people. Surprising: Despite 2014 having the highest number of fatalities so far this decade, the total number of crashes was in fact the lowest since the first commercial jet airliner took off in 1949. 
2014 has been a horrific year for Malaysia-based airlines, with 537 people dying on Malaysia Airlines planes, and a further 162 people missing and feared dead in this week's Air Asia incident. In total more than half the people killed in aviation incidents this year had been flying on board Malaysia-registered planes. In January a total of 12 people lost their lives in five separate incidents, while the same number of crashes in February killed 107.
"""
We can override the text_gen_kwargs we specified for our DataLoaders when we generate text using blurr's Learner.blurr_generate method.
outputs = learn.blurr_summarize(test_article, early_stopping=True, num_beams=4, num_return_sequences=3)
for idx, o in enumerate(outputs):
print(f'=== Prediction {idx+1} ===\n{o}\n')
What about inference? Easy!
learn.metrics = None
learn.export(fname='ft_cnndm_export.pkl')
inf_learn = load_learner(fname='ft_cnndm_export.pkl')
inf_learn.blurr_summarize(test_article)
That's it!
blurr supports a number of huggingface transformer model tasks in addition to summarization (e.g., sequence classification, token classification, question answering, causal language modeling, and translation). The docs include examples for each of these tasks if you're curious to learn more.
For more information about ohmeow or to get in contact with me, head over to ohmeow.com for all the details.
Thanks!