# only run this cell if you are in Colab
# !pip install ohmeow-blurr
# !pip install nlp
import nlp
import pandas as pd
from fastai.text.all import *
from transformers import *

from blurr.data.all import *
from blurr.modeling.all import *

Data Preparation

We're going to use the new nlp library from huggingface to grab our raw data. This package gives you access to all kinds of NLP-related datasets, explanations of each, and various task-specific metrics to use in evaluating your model. Best of all, everything comes down to you in JSON, which makes it a breeze to get up and running quickly!

We'll just use a subset (the first 1%) of the training set to build both our training and validation DataLoaders.

raw_data = nlp.load_dataset('cnn_dailymail', '3.0.0', split='train[:1%]')
df = pd.DataFrame(raw_data)
df.head()
article highlights id
0 It's official: U.S. President Barack Obama wants lawmakers to weigh in on whether to use military force in Syria. Obama sent a letter to the heads of the House and Senate on Saturday night, hours after announcing that he believes military action against Syrian targets is the right step to take over the alleged use of chemical weapons. The proposed legislation from Obama asks Congress to approve the use of military force "to deter, disrupt, prevent and degrade the potential for future uses of chemical weapons or other weapons of mass destruction." It's a step that is set to turn an internat... Syrian official: Obama climbed to the top of the tree, "doesn't know how to get down"\nObama sends a letter to the heads of the House and Senate .\nObama to seek congressional approval on military action against Syria .\nAim is to determine whether CW were used, not by whom, says U.N. spokesman . 0001d1afc246a7964130f43ae940af6bc6c57f01
1 (CNN) -- Usain Bolt rounded off the world championships Sunday by claiming his third gold in Moscow as he anchored Jamaica to victory in the men's 4x100m relay. The fastest man in the world charged clear of United States rival Justin Gatlin as the Jamaican quartet of Nesta Carter, Kemar Bailey-Cole, Nickel Ashmeade and Bolt won in 37.36 seconds. The U.S finished second in 37.56 seconds with Canada taking the bronze after Britain were disqualified for a faulty handover. The 26-year-old Bolt has now collected eight gold medals at world championships, equaling the record held by American trio... Usain Bolt wins third gold of world championship .\nAnchors Jamaica to 4x100m relay victory .\nEighth gold at the championships for Bolt .\nJamaica double up in women's 4x100m relay . 0002095e55fcbd3a2f366d9bf92a95433dc305ef
2 Kansas City, Missouri (CNN) -- The General Services Administration, already under investigation for lavish spending, allowed an employee to telecommute from Hawaii even though he is based at the GSA's Kansas City, Missouri, office, a CNN investigation has found. It cost more than $24,000 for the business development specialist to travel to and from the mainland United States over the past year. He is among several hundred GSA "virtual" workers who also travel to various conferences and their home offices, costing the agency millions of dollars over the past three years. Under the program, ... The employee in agency's Kansas City office is among hundreds of "virtual" workers .\nThe employee's travel to and from the mainland U.S. last year cost more than $24,000 .\nThe telecommuting program, like all GSA practices, is under review . 00027e965c8264c35cc1bc55556db388da82b07f
3 Los Angeles (CNN) -- A medical doctor in Vancouver, British Columbia, said Thursday that California arson suspect Harry Burkhart suffered from severe mental illness in 2010, when she examined him as part of a team of doctors. Dr. Blaga Stancheva, a family physician and specialist in obstetrics, said both Burkhart and his mother, Dorothee, were her patients in Vancouver while both were applying for refugee status in Canada. "I was asked to diagnose and treat Harry to support a claim explaining why he was unable to show up in a small-claims court case," Stancheva told CNN in a phone intervie... NEW: A Canadian doctor says she was part of a team examining Harry Burkhart in 2010 .\nNEW: Diagnosis: "autism, severe anxiety, post-traumatic stress disorder and depression"\nBurkhart is also suspected in a German arson probe, officials say .\nProsecutors believe the German national set a string of fires in Los Angeles . 0002c17436637c4fe1837c935c04de47adb18e9a
4 (CNN) -- Police arrested another teen Thursday, the sixth suspect jailed in connection with the gang rape of a 15-year-old girl on a northern California high school campus. Jose Carlos Montano, 18, was arrested on charges of felony rape, rape in concert with force, and penetration with a foreign object, said Richmond Police Lt. Mark Gagan. Montano was arrested Thursday evening in San Pablo, California, a small town about two miles from the city of Richmond, where the crime took place. Montano, who was held in lieu of $1.3 million bail, is accused of taking part in what police said was a 2½... Another arrest made in gang rape outside California school .\nInvestigators say up to 20 people took part or stood and watched the assault .\nFour suspects appeared in court Thursday; three wore bulletproof vests . 0003ad6ef0c37534f80b55b4235108024b407f0b
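Before picking sequence lengths for tokenization, it can help to get a rough feel for how long the articles and highlights are. Below is a minimal sketch that assumes whitespace-delimited word counts are a good enough proxy for token counts. The two records are abbreviated, illustrative stand-ins for the rows above, and word_count_stats is a hypothetical helper, not part of nlp or blurr.

```python
from statistics import mean

# Illustrative records only (abbreviated by hand); in the notebook these could
# come from something like df.to_dict('records').
records = [
    {"article": "Obama sent a letter to the heads of the House and Senate on Saturday night .",
     "highlights": "Obama sends a letter to Congress ."},
    {"article": "Usain Bolt rounded off the world championships Sunday by claiming his third gold in Moscow .",
     "highlights": "Usain Bolt wins third gold ."},
]

def word_count_stats(rows, column):
    """Return (min, mean, max) whitespace-delimited word counts for one column."""
    counts = [len(row[column].split()) for row in rows]
    return min(counts), mean(counts), max(counts)

article_stats = word_count_stats(records, "article")
summary_stats = word_count_stats(records, "highlights")
print(article_stats, summary_stats)
```

Real tokenizers split text into subword pieces, so actual token counts will usually be higher than these word counts.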

We begin by getting the huggingface objects needed for this task (e.g., the architecture, tokenizer, config, and model). We'll use blurr's get_hf_objects helper method here.

pretrained_model_name = "facebook/bart-large-cnn"
hf_arch, hf_config, hf_tokenizer, hf_model = BLURR_MODEL_HELPER.get_hf_objects(pretrained_model_name, 
                                                                               model_cls=BartForConditionalGeneration)

hf_arch, type(hf_config), type(hf_tokenizer), type(hf_model)
('bart',
 transformers.configuration_bart.BartConfig,
 transformers.tokenization_bart.BartTokenizer,
 transformers.modeling_bart.BartForConditionalGeneration)

Next we need to build out our DataBlock. Remember that a DataBlock is a blueprint describing how to move your raw data into something modelable. That blueprint is executed when we pass it a data source, which in our case will be the DataFrame we created above. We'll use a random subset to get things moving along a bit faster for the demo as well.

Notice that the blurr DataBlock has been dramatically simplified given the shift to on-the-fly batch-time tokenization. All we need is to define a single HF_SummarizationBatchTransform instance, optionally passing a list to any of the tokenization arguments to differentiate the values for the input and summary sequences. We pass noop as a type transform for our targets because everything is already handled by the batch transform now.

hf_batch_tfm = HF_SummarizationBatchTransform(hf_arch, hf_tokenizer, max_length=[256, 130])

blocks = (HF_TextBlock(hf_batch_tfm=hf_batch_tfm), noop)

dblock = DataBlock(blocks=blocks, 
                   get_x=ColReader('article'), 
                   get_y=ColReader('highlights'), 
                   splitter=RandomSplitter())
dls = dblock.dataloaders(df, bs=2)
/usr/local/lib/python3.6/dist-packages/transformers/tokenization_utils.py:542: FutureWarning: `is_pretokenized` is deprecated and will be removed in a future version, use `is_split_into_words` instead.
  FutureWarning,
len(dls.train.items), len(dls.valid.items)
(2297, 574)
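Those counts are consistent with a random 80/20 split of 2,871 rows, which matches RandomSplitter's default of holding out 20% for validation. Here is a minimal sketch of the idea in plain Python. The random_split helper is hypothetical and uses the standard library's shuffle rather than whatever fastai uses internally, so the actual indices will differ.

```python
import random

def random_split(n_items, valid_pct=0.2, seed=None):
    """Shuffle indices 0..n_items-1 and cut off the first valid_pct as validation."""
    rng = random.Random(seed)
    idxs = list(range(n_items))
    rng.shuffle(idxs)
    cut = int(valid_pct * n_items)  # number of validation items, rounded down
    return idxs[cut:], idxs[:cut]   # (train indices, validation indices)

train_idxs, valid_idxs = random_split(2871, seed=42)
print(len(train_idxs), len(valid_idxs))  # 2297 574
```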

It's always a good idea to check out a batch of data and make sure the shapes look right.

b = dls.one_batch()
len(b), b[0]['input_ids'].shape, b[1].shape
(2, torch.Size([2, 256]), torch.Size([2, 71]))
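The shapes make sense if each side of the batch is truncated to its max_length (256 for inputs, 130 for targets) and then padded to the longest sequence left in the batch. That would explain why the targets come out 71 wide rather than 130. The sketch below shows that behavior under those assumptions. pad_batch is a hypothetical helper, the token ids are fake, and pad_id=1 is assumed to be BART's padding id. A real tokenizer also keeps the end-of-sequence token when truncating, which this ignores.

```python
def pad_batch(token_id_seqs, max_length, pad_id=1):
    """Truncate each sequence to max_length, then right-pad to the longest one."""
    truncated = [seq[:max_length] for seq in token_id_seqs]
    width = max(len(seq) for seq in truncated)
    return [seq + [pad_id] * (width - len(seq)) for seq in truncated]

# Fake token ids: two long "articles" and two shorter "summaries".
inputs = pad_batch([list(range(300)), list(range(280))], max_length=256)
targets = pad_batch([list(range(71)), list(range(40))], max_length=130)
print(len(inputs), len(inputs[0]), len(targets[0]))  # 2 256 71
```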

Even better, we can take advantage of blurr's TypeDispatched version of show_batch to look at things a bit more intuitively. We pass in the dls via the dataloaders argument so we can access all tokenization/modeling configuration stored in our batch transform above.

dls.show_batch(dataloaders=dls, max_n=2)
text target
0 While Iraq's military claimed Wednesday to have driven back militants battling for control of the country, the chairman of the Joint Chiefs of Staff told Congress that the United States has received a request from the Iraqi government to use its air power in the conflict. Gen. Martin Dempsey, the senior ranking member of the U.S. armed forces, spoke before the Senate Appropriations Committee Wednesday on Capitol Hill in Washington, saying that the United States' "national security interest (is) to counter (ISIS) where we find them." ISIS is the Islamic State in Iraq and Syria. Comprising mostly Sunni Muslims, ISIS is an al Qaeda splinter group that wants to establish a caliphate, or Islamic state, that would stretch from Iraq into northern Syria. The group has had substantial success in Syria battling Syrian President Bashar al-Assad's security forces. Since launching their offensive in Iraq, ISIS claims to have killed at least 1,700 Shiites. Hundreds of thousands of Iraqis have fled, prompting fears of a brewing humanitarian crisis. Qassim Atta, a spokesman for Iraqi security forces, on Wednesday night said an investigation had been ordered into 59 high-ranking security officials accused of leaving their posts. The officials could be executed if found guilty, Atta said. Concerns over Vice President Biden stresses need for national unity in talk with Iraqi PM.\nA cleric called for attacks against U.S. embassies in the case of airstrikes.\nInvestigation will probe Iraqi security forces who left posts, general commander says.\nSaudi Arabia responds to Iraq's accusation that it's helping ISIS, calling allegation a "falsehood"
1 Los Angeles (CNN) -- A former Los Angeles cop with military training vowed war against other men in blue Thursday, leaving one officer dead days after he allegedly killed two other people to begin a wave of retribution for being fired, police said. The focus of the intensive, expansive manhunt is Christopher Jordan Dorner, a 270-pound former Navy lieutenant who has professed his venom against LAPD officers he claimed ruined his life by forcing him out of his dream job. Dorner blames one retired officer for bungling his appeal to get his job back in an 11-page manifesto, in which he also complained of mistreatment by the LAPD. In that letter -- provided to CNN by an LAPD source -- he vowed to violently target police officers and their families, whoever and wherever they are. "I will bring unconventional and asymmetrical warfare to those in LAPD uniform whether on or off duty," Dorner wrote. "I never had the opportunity to have a family of my own, I'm terminating yours." Authorities believe he followed through on his threats early Thursday by shooting a Riverside, California, police officer and two others. A day earlier, Irvine police named Dorner a suspect in the double slayings Sunday of a woman -- identified by Los Angeles police as the daughter of a NEW: With snow coming, authorities continue to hunt for the suspect near Big Bear Lake.\nPolice believe former cop Christopher Jordan Dorner shot three officers, killing one.\nThis was days after he allegedly killed two people, one a retired LAPD officer's daughter.\nIn an 11-page manifesto, Dorner promises "war" on police and their families.

Training

We'll prepare our BART model for training by wrapping it in blurr's HF_BaseModelWrapper model object and defining a new callback, HF_SummarizationModelCallback. This class will handle ensuring all our inputs get translated into the proper arguments needed by a huggingface conditional generation model. We'll also use a custom model splitter that will allow us to apply discriminative learning rates over the various layers in our huggingface model.

Once we have everything in place, we'll freeze our model so that only the last layer group's parameters are trainable. See here for how discriminative learning rates work in fastai.
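To make discriminative learning rates a bit more concrete: when you pass something like slice(lr_min, lr_max), fastai spreads the rates geometrically across the parameter groups, so earlier layers get smaller updates than later ones. Here is a minimal sketch of that spacing in plain Python. The three groups and the 3e-7 lower bound are assumptions for illustration, not values taken from summarization_splitter.

```python
def even_mults(start, stop, n):
    """Return n learning rates spaced geometrically from start to stop."""
    if n == 1:
        return [stop]
    step = (stop / start) ** (1 / (n - 1))  # constant multiplier between groups
    return [start * step ** i for i in range(n)]

# Assumed: three layer groups, earliest group gets the smallest rate.
lrs = even_mults(3e-7, 3e-5, 3)
print([f"{lr:.1e}" for lr in lrs])  # ['3.0e-07', '3.0e-06', '3.0e-05']
```

With the model frozen, only the last group is actually trained, so these per-group rates only start to matter once you unfreeze.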

Note: This has been tested with BART only thus far (if you try any other conditional generation transformer models they may or may not work ... if you do, lmk either way)

text_gen_kwargs = { **hf_config.task_specific_params['summarization'], **{'max_length': 130, 'min_length': 30} }
text_gen_kwargs
{'early_stopping': True,
 'length_penalty': 2.0,
 'max_length': 130,
 'min_length': 30,
 'no_repeat_ngram_size': 3,
 'num_beams': 4}
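The double-unpacking pattern works because later keys win when two dicts are merged this way, so our max_length and min_length replace the model's defaults while everything else is kept. A small sketch follows. The default max_length and min_length values in task_defaults are assumed for illustration; check hf_config.task_specific_params for the real ones.

```python
# Assumed stand-in for hf_config.task_specific_params['summarization'].
task_defaults = {
    'early_stopping': True,
    'length_penalty': 2.0,
    'max_length': 142,   # assumed default, overridden below
    'min_length': 56,    # assumed default, overridden below
    'no_repeat_ngram_size': 3,
    'num_beams': 4,
}
overrides = {'max_length': 130, 'min_length': 30}

# Keys from the right-hand dict replace matching keys from the left-hand one.
merged = {**task_defaults, **overrides}
print(merged)
```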
model = HF_BaseModelWrapper(hf_model)
model_cb = HF_SummarizationModelCallback(text_gen_kwargs=text_gen_kwargs)

learn = Learner(dls, 
                model,
                opt_func=ranger,
                loss_func=HF_MaskedLMLoss(),
                cbs=[model_cb],
                splitter=partial(summarization_splitter, arch=hf_arch))#.to_fp16()

learn.create_opt() 
learn.freeze()

Still experimenting with how to use fastai's learning rate finder for these kinds of models. If you all have any suggestions or interesting insights to share, please let me know. We're only going to train the frozen model for one epoch for this demo, but feel free to progressively unfreeze the model and train the other layers to see if you can best my results below.

learn.lr_find(suggestions=True)
SuggestedLRs(lr_min=0.00014454397605732084, lr_steep=0.00013182566908653826)

It's also not a bad idea to run a batch through your model and make sure the shape of what goes in, and comes out, looks right.

b = dls.one_batch()
preds = learn.model(b[0])
len(preds),preds[0], preds[1].shape
(3,
 tensor(4.6251, device='cuda:0', grad_fn=<NllLossBackward>),
 torch.Size([2, 76, 50264]))
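The first element is the loss and the second is the logits, shaped batch size x target length x vocabulary size. A loss like this is typically an average negative log-likelihood over the target tokens, with padded positions skipped. Below is a toy, pure-Python sketch of that calculation. It assumes padded positions are marked with -100, PyTorch's default ignore index for cross-entropy. masked_nll is a hypothetical helper, not blurr's HF_MaskedLMLoss.

```python
import math

def masked_nll(logits, labels, ignore_index=-100):
    """Mean negative log-likelihood over positions whose label is not ignore_index."""
    total, count = 0.0, 0
    for row, label in zip(logits, labels):
        if label == ignore_index:
            continue  # padded position: contributes nothing to the loss
        peak = max(row)  # subtract the max before exponentiating for stability
        log_z = peak + math.log(sum(math.exp(x - peak) for x in row))
        total += log_z - row[label]
        count += 1
    # Toy choice: return 0.0 when every position is ignored.
    return total / count if count else 0.0

# One "sequence" of three positions over a two-token vocabulary; the last is padding.
logits = [[0.0, 0.0], [0.0, 0.0], [5.0, 1.0]]
labels = [0, 1, -100]
loss = masked_nll(logits, labels)
print(round(loss, 4))  # 0.6931
```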
learn.fit_one_cycle(1, lr_max=3e-5)
epoch train_loss valid_loss rouge1 rouge2 rougeL time
0 1.844189 1.670262 0.388298 0.173981 0.267575 17:09
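If the rouge columns are unfamiliar: ROUGE-1 measures unigram overlap between the generated summary and the reference, ROUGE-2 does the same for bigrams, and ROUGE-L is based on the longest common subsequence. Here is a simplified sketch of a ROUGE-1 F1 score, assuming lowercased whitespace tokenization and no stemming. The metric used during training will differ in those details, so don't expect identical numbers.

```python
from collections import Counter

def rouge1_f1(prediction, reference):
    """Simplified ROUGE-1 F1: clipped unigram overlap between two strings."""
    pred_counts = Counter(prediction.lower().split())
    ref_counts = Counter(reference.lower().split())
    overlap = sum((pred_counts & ref_counts).values())  # clipped matches
    if overlap == 0:
        return 0.0
    precision = overlap / sum(pred_counts.values())
    recall = overlap / sum(ref_counts.values())
    return 2 * precision * recall / (precision + recall)

score = rouge1_f1("the cat", "the cat sat on")
print(round(score, 3))  # 0.667
```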

And now we can look at the "greedy decoded" predictions ...

learn.show_results(learner=learn, max_n=2)
text target prediction
0 The news from Pakistan is generally bad news. In the past week, which was far from atypical, suicide bombers attacked a court building in the northwestern city of Peshawar taking hostages and killing four people. In the southern city of Karachi the director of a renowned social program working in the megacity's poorest neighborhoods was shot and killed. And gunmen kidnapped two female Czech tourists in southwestern Pakistan. But this past week also saw more than a glimmer of good news from Pakistan: Saturday, March 16 marked an extraordinary moment in Pakistani history, as this is the first time a civilian government has served its entire five-year term (from 2008 to 2013). And, for the first time in its history, the Pakistani military appears unwilling to mount a coup against the civilian government. The military has successfully executed three coups and attempted a number of others since Pakistan's independence in 1947. Today the army understands that the most recent coup by General Pervez Musharraf who took power in 1999 has tarnished its brand. Musharraf hung on to power for almost a decade and his imposition of emergency rule in 2007 triggered massive street protests and eventually his ouster. On Saturday, Musharaf announced he is returning to Pakistan from self-imposed exile on March 24 to Peter Bergen: For the first time, Pakistan government served its full term.\nHe says lack of military coup attempt shows government is more stable than many think.\nElections in Pakistan, Afghanistan likely to be crucial for those two nations.\nBergen: He says Afghan economy is resilient and corruption may be receding. 
This is the first time a civilian government has served its entire five-year term (from 2008 to 2013) in Pakistan's history .\nThe military has successfully executed three coups and attempted a number of others since Pakistan's independence in 1947 .\nGeneral Pervez Musharraf announced he is returning to Pakistan from self-imposed exile on March 24 .\nMusharraf hung on to power for almost a decade and his imposition of emergency rule in 2007 triggered protests .
1 (CNN) -- A fledgling force of Syrian military deserters said it struck an important government security complex on the outskirts of the capital overnight, a bold strike reflecting the resolve and confidence of the regime's opposition. The assault came ahead of an Arab League meeting Wednesday to reaffirm a decision to suspend Syria's membership, a move the group made over the weekend after President Bashar al-Assad's government failed to abide by a proposal to end a brutal crackdown on protesters. Also Wednesday, France recalled its ambassador to Syria, the French Foreign Ministry said. The move followed attacks on French missions in Syria. The defector group, called the Free Syrian Army, said it attacked an air intelligence base in Harasta and planted "powerful explosions inside and around the compound that shook its foundations." Andrew Tabler, an expert on Syria at the Washington Institute for Near East Policy, said air intelligence has been deeply involved in the eight-month-long crackdown by the Syrian government against protesters, a grinding civil conflict that the United Nations says has left more than 3,500 people dead. Tabler said the strike reflects the growing sophistication of the Free Syrian Army, which has brigades across the country and has been in existence since the summer. "It opens up a new era of the Free Syrian Army says it knows of no casualties in assault on government complex.\n"It opens up a new era of the conflict," a scholar says.\nA Turkish diplomat says "helping hands" to Syria have been wasted.\nMore deaths are reported in Syria on Wednesday. NEW: France recalls its ambassador to Syria after attacks on French missions in Syria .\nThe Free Syrian Army says it attacked an air intelligence base in Harasta .\nAn Arab League meeting will reaffirm a decision to suspend Syria's membership .\nMore than 3,500 people have died in Syria's civil conflict, the United Nations says .

Even better though, blurr augments the fastai Learner with a blurr_summarize method that allows you to use huggingface's PreTrainedModel.generate method to create something more human-like.
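Why bother with generate and its num_beams argument instead of greedy decoding? Greedy decoding commits to the single most likely token at each step, while beam search keeps several partial sequences alive and can end up with a more probable sequence overall. The toy sketch below illustrates the difference. TOY_MODEL and beam_search are hypothetical and leave out everything else generate does (length penalty, n-gram blocking, and so on).

```python
import math

# Hypothetical toy "language model": next-token probabilities given the last token.
TOY_MODEL = {
    "<s>": {"the": 0.6, "a": 0.4},
    "the": {"crash": 0.55, "plane": 0.45},
    "a": {"plane": 0.9, "crash": 0.1},
    "crash": {"</s>": 1.0},
    "plane": {"</s>": 1.0},
}

def beam_search(num_beams, num_return_sequences=1, max_length=5):
    """Return the highest-scoring finished sequences; num_beams=1 is greedy decoding."""
    beams = [(0.0, ["<s>"])]  # (summed log-probability, tokens so far)
    finished = []
    for _ in range(max_length):
        candidates = []
        for score, tokens in beams:
            for token, prob in TOY_MODEL[tokens[-1]].items():
                candidates.append((score + math.log(prob), tokens + [token]))
        candidates.sort(key=lambda c: c[0], reverse=True)
        beams = []
        for score, tokens in candidates[:num_beams]:  # keep only the best few
            if tokens[-1] == "</s>":
                finished.append((score, tokens))
            else:
                beams.append((score, tokens))
        if not beams:
            break
    finished.sort(key=lambda c: c[0], reverse=True)
    return [" ".join(tokens[1:-1]) for _, tokens in finished[:num_return_sequences]]

print(beam_search(num_beams=1))                          # ['the crash']
print(beam_search(num_beams=4, num_return_sequences=3))  # ['a plane', 'the crash', 'the plane']
```

Greedy decoding picks "the" first because it is the likeliest opening token, and so it never finds "a plane", which has the higher overall probability (0.36 vs. 0.33).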

test_article = """
The past 12 months have been the worst for aviation fatalities so far this decade - with the total of number of people killed if airline crashes reaching 1,050 even before the Air Asia plane vanished. Two incidents involving Malaysia Airlines planes - one over eastern Ukraine and the other in the Indian Ocean - led to the deaths of 537 people, while an Air Algerie crash in Mali killed 116 and TransAsia Airways crash in Taiwan killed a further 49 people. The remaining 456 fatalities were largely in incidents involving small commercial planes or private aircraft operating on behalf of companies, governments or organisations. Despite 2014 having the highest number of fatalities so far this decade, the total number of crashes was in fact the lowest since the first commercial jet airliner took off in 1949 - totalling just 111 across the whole world over the past 12 months. The all-time deadliest year for aviation was 1972 when a staggering 2,429 people were killed in a total of 55 plane crashes - including the crash of Aeroflot Flight 217, which killed 174 people in Russia, and Convair 990 Coronado, which claimed 155 lives in Spain. However this year's total death count of 1,212, including those presumed dead on board the missing Air Asia flight, marks a significant rise on the very low 265 fatalities in 2013 - which led to it being named the safest year in aviation since the end of the Second World War. Scroll down for videos. Deadly: The past 12 months have been the worst for aviation fatalities so far this decade - with the total of number of people killed if airline crashes reaching 1,158 even before the Air Asia plane (pictured) vanished. Fatal: Two incidents involving Malaysia Airlines planes - one over eastern Ukraine (pictured) and the other in the Indian Ocean - led to the deaths of 537 people. 
Surprising: Despite 2014 having the highest number of fatalities so far this decade, the total number of crashes was in fact the lowest since the first commercial jet airliner took off in 1949. 2014 has been a horrific year for Malaysia-based airlines, with 537 people dying on Malaysia Airlines planes, and a further 162 people missing and feared dead in this week's Air Asia incident. In total more than half the people killed in aviation incidents this year had been flying on board Malaysia-registered planes. In January a total of 12 people lost their lives in five separate incidents, while the same number of crashes in February killed 107. 
"""
outputs = learn.blurr_summarize(test_article, early_stopping=True, num_beams=4, num_return_sequences=3)

for idx, o in enumerate(outputs):
    print(f'=== Prediction {idx+1} ===\n{o}\n')
=== Prediction 1 ===
 The past 12 months have been the worst for aviation fatalities so far this decade - with 1,158 deaths .
Two incidents involving Malaysia Airlines planes led to the deaths of 537 people, while an Air Algerie crash in Mali killed 116 and TransAsia Airways crash in Taiwan killed 49 people .
The remaining 456 fatalities were largely in incidents involving small commercial planes or private aircraft operating on behalf of companies, governments or organisations .
This year's total death count of 1,212 marks a significant rise on the very low 265 fatalities in 2013 - which led to it being named the safest year in aviation since the end of the Second

=== Prediction 2 ===
 The past 12 months have been the worst for aviation fatalities so far this decade - with 1,158 deaths .
Two incidents involving Malaysia Airlines planes led to the deaths of 537 people, while an Air Algerie crash in Mali killed 116 and TransAsia Airways crash in Taiwan killed 49 people .
The remaining 456 fatalities were largely in incidents involving small commercial planes or private aircraft operating on behalf of companies, governments or organisations .
This year's total death count of 1,212 marks a significant rise on the very low 265 fatalities in 2013 - which led to it being named the safest year in aviation since end of Second World War

=== Prediction 3 ===
 The past 12 months have been the worst for aviation fatalities so far this decade - with 1,158 deaths .
Two incidents involving Malaysia Airlines planes led to the deaths of 537 people, while an Air Algerie crash in Mali killed 116 and TransAsia Airways crash in Taiwan killed 49 people .
The remaining 456 fatalities were largely in incidents involving small commercial planes or private aircraft operating on behalf of companies, governments or organisations .
This year's total death count of 1,212 marks a significant rise on the very low 265 fatalities in 2013 - which led to it being named the safest year in aviation since the end of Second World

What about inference? Easy!

learn.export(fname='ft_cnndm_export.pkl')
inf_learn = load_learner(fname='ft_cnndm_export.pkl')
inf_learn.blurr_summarize(test_article)
[" The past 12 months have been the worst for aviation fatalities so far this decade - with 1,158 deaths .\nTwo incidents involving Malaysia Airlines planes led to the deaths of 537 people, while an Air Algerie crash in Mali killed 116 and TransAsia Airways crash in Taiwan killed 49 people .\nThe remaining 456 fatalities were largely in incidents involving small commercial planes or private aircraft operating on behalf of companies, governments or organisations .\nThis year's total death count of 1,212 marks a significant rise on the very low 265 fatalities in 2013 - which led to it being named the safest year in aviation since the end of the Second"]

That's it

blurr supports a number of huggingface transformer model tasks in addition to summarization (e.g., sequence classification, token classification, and question answering). The docs include examples for each of these tasks if you're curious to learn more.

For more information about ohmeow or to get in contact with me, head over to ohmeow.com for all the details.

Thanks!