Text-summarization

Intro

First, we need to install the blurr Python module, which provides the Transformers integration.

# Install the Python package `ohmeow-blurr` through pip.
reticulate::py_install("ohmeow-blurr", pip = TRUE)

Dataset

Load the required packages and read the dataset from the link:

library(fastai)
library(magrittr)
library(zeallot)

# Read the sample CSV (`cnndm_sample.csv`) hosted in the blurr repository.
cnndm_df <- data.table::fread(
  "https://raw.githubusercontent.com/ohmeow/blurr/master/nbs/cnndm_sample.csv"
)

Pretrained model

Get the pretrained BART model from transformers:

# Grab the transformers handle and pick the BART class off it.
transformers <- transformers()
BartForConditionalGeneration <- transformers$BartForConditionalGeneration

pretrained_model_name <- "facebook/bart-large-cnn"

# zeallot's %<-% unpacks the four returned objects in a single assignment.
c(hf_arch, hf_config, hf_tokenizer, hf_model) %<-% get_hf_objects(
  pretrained_model_name,
  model_cls = BartForConditionalGeneration
)

Datablock

Create a datablock and inspect one batch:

# Batch-time transform built from the architecture and tokenizer.
# NOTE(review): max_length = c(256, 130) presumably caps the input and target
# token lengths respectively -- confirm against the blurr documentation.
before_batch_tfm <- HF_SummarizationBeforeBatchTransform(
  hf_arch,
  hf_tokenizer,
  max_length = c(256, 130)
)

# Two blocks: the text-to-text block carrying the transform, plus a no-op.
blocks <- list(
  HF_Text2TextBlock(
    before_batch_tfms = before_batch_tfm,
    input_return_type = HF_SummarizationInput
  ),
  noop()
)

# Inputs are read from the `article` column, targets from `highlights`.
dblock <- DataBlock(
  blocks = blocks,
  get_x = ColReader("article"),
  get_y = ColReader("highlights"),
  splitter = RandomSplitter()
)

dls <- dataloaders(dblock, cnndm_df, bs = 2)

# Pull a single batch to check what the model will receive.
one_batch(dls)
[[1]]
[[1]]$input_ids
tensor([[    0,    36, 16256,    43,   480,  2193,     7,    62,     7,   158,
           135,     9,    70,   684,  4707,     6,  1625,    16,  4984,    25,
            65,     9,     5,   144, 39865, 32273,  3806,    15,     5,  5518,
             4,    20,  9544,  3455,     9,  2147,   464,     8,  1050, 29779,
         26284,    15,  1632, 11534,    32,     6,   959,     6,  5608,     5,
          8066,     9,     5,   247,    18,  4066,  7892,     4,   178,    89,
            16,    10,   372,   432,     7,  2217,     4,    96,     5,   315,
          3076,  9356,  4928,    36,  4154,  9662,    43,   623, 12978, 23853,
          2521,    18,   889,     9, 10721,   625, 32273,   749,  1625,  5546,
           365,   212,     4,    20,   889,  3372,    10,   333,     9,   601,
           749,    14, 19655,     5,  1647,     9,     5,  3875,    18,  4707,
             8,    32,  3891,  1687,  2778, 39865, 32273,     4,  1740,    63,
         23491, 28919,    11,     5,  7599,  3939,     7,    63, 10602, 41166,
          1634,    11, 12718,  1115,   281,     8,     5,   854,  3964, 21785,
         14113,     8,    63, 38905,     8, 21950, 19947,    11,     5,  1926,
             6,  1625, 10902,    41,  5784,  4066,  3143,     9, 39456,     8,
           856, 25729,     4,   993,   195,  5243,    66,     9,   262,  1360,
         38950,  1848,  4707,   303,    11,  1625,   480,     5,   144,    11,
           143,   247,   480,    64,   129,    28, 13590,   624,    63,  7562,
             4,    85,    16,   184,     7, 39179,  3505,     9, 27772,     6,
         28482,  4707,     9,  7723,     6,   112,     6,  6115, 17576,     9,
          7723,     8,   973,     6,   151,  1380, 14868,     9,  3451,     4,
           221,  2839,  8367,   102,     6,    10,   786,    12,  7699,  1651,
            14,  1364,     7,  3720,  8360,     8,  5068,   709,    11,  1625,
             6,    34,  3919,   411,  4707,    61,    24,   161,  7648,  2072,
             5,  1272,  2713,    30,     5,     2],
        [    0,    36, 16256,    43,   480,  6871,    24,  1655, 22049,    50,
         41571,  1896,    54, 24509,    13,   244,     5,   363,     5,   601,
            12,   180,    12,   279,  1896,    21,   738,  1462,   116,   280,
           115,  6723,    15,    61,   985,     5,  3940,  2046,     4,  1868,
         22049,    18,     8,  1896,    18,  8826,  2327,   117, 28946,   273,
            11,  2559,   461,  4961,    25,     7,  1060, 28604,  2236,    16,
          1317, 11347,   148,    10,  7158,   486,    31,    14,   902,   973,
             6,  1125,     6,   363,    11, 23058,     6,  1261,    35,  4028,
            26,    24,    21,    69,   979,     4,   280, 34920,   480,    19,
          5767,  3809,  1243, 21605, 13875,    24,    21,    69,   979,     6,
         41571,     6,    54, 16670,    66,     6,   150, 15151,  2459, 22049,
            26,    24,    21,    69,   979,     6,  1655,     6,    54,    21,
         16600,    71,   145,  4487,    30,     5,  6066,   480,    21,  1353,
             7,   273,    18,   461,  7069,     6,     8,  1353,     7,     5,
           200,    12,  5743,  1900,   403, 25451,    11,  1353,  1261,     4,
         22049,    34,  4407,    45,  2181,     8,  1695,    37,   738,     5,
          7044,    11,  1403,    12, 21409,     4,    20,  7158,   486,   702,
          2330,    11,   461,    15,   273,     6,    39,  3969,  2026,     6,
           124,    62,    49, 19395,    14,    24,    21,  1896,     6,     8,
            45,    49,  3653,     6,    54,    21,     5, 29593,   368,     4,
          4500,  4945,   628,   273,  1390,     6, 15151,  2459, 22049,    26,
            79,    21,   686,  1655,    21,     5,    65, 16600,     4,  2612,
           116, 41039, 10105,    37,    18,   127,   979, 39732,   264,  7173,
         41039,  1250,     9,     5,  1065, 48149,    77,   553,   549,    79,
            56,   655,   137,  1317,    69,  1655, 22049,  7923, 24120,    50,
          8930,    66,    13,   244,     4,     2]], device='cuda:0')

[[1]]$attention_mask
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')

[[1]]$decoder_input_ids
tensor([[    0,  1625,  4452,     7,    62,     7,   158,   135,     9,    70,
           684,  4707,    15,  3875,   479, 50118,   243,    16,   184,     7,
         39179,  3505,     9, 27772,     6, 28482,  5103,  4707,     8,   973,
             6,   151,  3505,     9,  3451,   479, 50118, 33837,   709,     8,
          2147,   464,    16,  9405,    10,   380,  8793,    15,    63, 28809,
           479, 50118,   133,  3274,  9587,    16,   223,  1856,    11, 14117,
             9,   145,     5,   247,    18,   632,  7648,   479,     2,     1,
             1,     1,     1,     1,     1,     1],
        [    0,  5178,    35,    83,  1443,  2470,   161,    97,  1283,     6,
            45,     5,  7158,   486,     6,    40,  3094,     5,   403,   479,
         50118,  5341,    35,    83,  2470,    13,  1896,    18,   284,   161,
            37,  4265,     5,  3940,    40,   465, 22049,  2181,   479, 50118,
           534, 15356,  2459, 22049,   161,    79,  2215,     5, 28604,  2236,
            16,    14,     9,    69,   979,   479, 50118, 30541,     6, 41571,
          1896,    18, 18631,  1843,    26,    14,    24,    69,   979,    18,
          2236,    15,     5,  7158,   486,   479]], device='cuda:0')

[[1]]$labels
tensor([[ 1625,  4452,     7,    62,     7,   158,   135,     9,    70,   684,
          4707,    15,  3875,   479, 50118,   243,    16,   184,     7, 39179,
          3505,     9, 27772,     6, 28482,  5103,  4707,     8,   973,     6,
           151,  3505,     9,  3451,   479, 50118, 33837,   709,     8,  2147,
           464,    16,  9405,    10,   380,  8793,    15,    63, 28809,   479,
         50118,   133,  3274,  9587,    16,   223,  1856,    11, 14117,     9,
           145,     5,   247,    18,   632,  7648,   479,     2,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100],
        [ 5178,    35,    83,  1443,  2470,   161,    97,  1283,     6,    45,
             5,  7158,   486,     6,    40,  3094,     5,   403,   479, 50118,
          5341,    35,    83,  2470,    13,  1896,    18,   284,   161,    37,
          4265,     5,  3940,    40,   465, 22049,  2181,   479, 50118,   534,
         15356,  2459, 22049,   161,    79,  2215,     5, 28604,  2236,    16,
            14,     9,    69,   979,   479, 50118, 30541,     6, 41571,  1896,
            18, 18631,  1843,    26,    14,    24,    69,   979,    18,  2236,
            15,     5,  7158,   486,   479,     2]], device='cuda:0')


[[2]]
tensor([[ 1625,  4452,     7,    62,     7,   158,   135,     9,    70,   684,
          4707,    15,  3875,   479, 50118,   243,    16,   184,     7, 39179,
          3505,     9, 27772,     6, 28482,  5103,  4707,     8,   973,     6,
           151,  3505,     9,  3451,   479, 50118, 33837,   709,     8,  2147,
           464,    16,  9405,    10,   380,  8793,    15,    63, 28809,   479,
         50118,   133,  3274,  9587,    16,   223,  1856,    11, 14117,     9,
           145,     5,   247,    18,   632,  7648,   479,     2,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100],
        [ 5178,    35,    83,  1443,  2470,   161,    97,  1283,     6,    45,
             5,  7158,   486,     6,    40,  3094,     5,   403,   479, 50118,
          5341,    35,    83,  2470,    13,  1896,    18,   284,   161,    37,
          4265,     5,  3940,    40,   465, 22049,  2181,   479, 50118,   534,
         15356,  2459, 22049,   161,    79,  2215,     5, 28604,  2236,    16,
            14,     9,    69,   979,   479, 50118, 30541,     6, 41571,  1896,
            18, 18631,  1843,    26,    14,    24,    69,   979,    18,  2236,
            15,     5,  7158,   486,   479,     2]], device='cuda:0')

Train

Define a Learner object and train the model:

# Start from the summarization parameters stored on the model config, then
# override the length bounds used for text generation.
text_gen_kwargs <- hf_config$task_specific_params["summarization"][[1]]
text_gen_kwargs["max_length"] <- 130L
text_gen_kwargs["min_length"] <- 30L

# Show the resulting generation arguments.
text_gen_kwargs

model <- HF_BaseModelWrapper(hf_model)
model_cb <- HF_SummarizationModelCallback(text_gen_kwargs = text_gen_kwargs)

# Alternatives the original author left commented out:
#   loss_func = HF_PreCalculatedLoss()
#   .to_native_fp16() / .to_fp16() chained onto the Learner
learn <- Learner(
  dls,
  model,
  opt_func = partial(Adam),
  loss_func = CrossEntropyLossFlat(),
  cbs = model_cb,
  splitter = partial(summarization_splitter, arch = hf_arch)
)

learn$create_opt()
learn$freeze()

fit_one_cycle(learn, 1, lr_max = 4e-5)

Conclusion

Generate summaries for a new article:

# Article to summarize. The line breaks inside the literal are part of the
# text handed to the model.
# Fix: the original literal read "lRicense plates"; the stray "R" is a typo
# that fed a corrupted word to the tokenizer.
test_article <- c("About 10 men armed with pistols and small machine guns raided a casino in Switzerland 
and made off into France with several hundred thousand Swiss francs in the early hours 
of Sunday morning, police said. The men, dressed in black clothes and black ski masks, 
split into two groups during the raid on the Grand Casino Basel, Chief Inspector Peter 
Gill told CNN. One group tried to break into the casino's vault on the lower level
but could not get in, but they did rob the cashier of the money that was not secured, 
he said. The second group of armed robbers entered the upper level where the roulette 
and blackjack tables are located and robbed the cashier there, he said. As the thieves 
were leaving the casino, a woman driving by and unaware of what was occurring unknowingly 
blocked the armed robbers' vehicles. A gunman pulled the woman from her vehicle, beat
her, and took off for the French border. The other gunmen followed into France, which 
is only about 100 meters (yards) from the casino, Gill said. There were about 600 people 
in the casino at the time of the robbery. There were no serious injuries, although one 
guest on the Casino floor was kicked in the head by one of the robbers when he moved, 
the police officer said. Swiss authorities are working closely with French authorities,
Gill said. The robbers spoke French and drove vehicles with French license plates. 
CNN's Andreena Narayan contributed to this report.")

# Request three candidate summaries for the article and print them.
outputs <- learn$blurr_summarize(test_article, num_return_sequences = 3L)
cat(outputs)

 Robbers made off with several hundred thousand Swiss francs in the early hours of Sunday morning, police say .
The men, dressed in black clothes and black ski masks, split into two groups during the raid on the Grand Casino Basel .
There were no serious injuries, although one guest was kicked in the head by one of the robbers when he moved, police officer says .
A woman driving by unknowingly blocked the robbers' vehicles and was beaten to death .
Swiss authorities are working closely with French authorities, police chief says .


 Robbers made off with several hundred thousand Swiss francs in the early hours of Sunday morning, police say .
The men, dressed in black clothes and black ski masks, split into two groups during the raid on the Grand Casino Basel .
There were no serious injuries, although one guest was kicked in the head by one of the robbers when he moved, police officer says .
A woman driving by unknowingly blocked the robbers' vehicles and was beaten to death .
Swiss authorities are working closely with French authorities, an officer says.


 Robbers made off with several hundred thousand Swiss francs in the early hours of Sunday morning, police say .
The men, dressed in black clothes and black ski masks, split into two groups during the raid on the Grand Casino Basel .
There were no serious injuries, although one guest was kicked in the head by one of the robbers when he moved, police officer says .
A woman driving by unknowingly blocked the robbers' vehicles and was beaten to death .