Classification: T5

Andrew Fogarty

7/21/2020

# load python
library(reticulate)
use_condaenv("my_ml")
# load packages
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import TensorDataset, random_split, DataLoader, RandomSampler
from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers import AdamW, get_linear_schedule_with_warmup
import time
import datetime
import random
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report
import re
import matplotlib.pyplot as plt
import seaborn as sns
import optuna
from optuna.pruners import SuccessiveHalvingPruner
from optuna.samplers import TPESampler


# note: autocast is applied as a context manager inside the training loop below
torch.cuda.amp.autocast(enabled=True)
## <torch.cuda.amp.autocast_mode.autocast object at 0x00000000348CA288>
SEED = 15
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
## <torch._C.Generator object at 0x000000001F53E050>
torch.backends.cudnn.deterministic = True

# tell pytorch to use cuda
device = torch.device("cuda")

1 Introduction

In this guide we use T5, a large pre-trained encoder-decoder Transformer model (its base variant is roughly twice the size of BERT-base), for a classification task. T5, devised by Google, is an important advancement in the field of Transformers because it achieves near human-level performance on a variety of benchmarks like GLUE and SQuAD. Another important advancement is that it treats every NLP problem as a text-to-text problem: the inputs are text and the outputs are also text. In this universal framework, T5 can therefore handle any NLP task (in English). T5 was pre-trained on the C4 (Colossal Clean Crawled Corpus) corpus, which amounts to roughly 750GB of clean English text. For comparative purposes, BERT was trained on roughly 13GB of text and XLNet on roughly 126GB. For these reasons, T5 is the state of the art, and its encoder-decoder architecture is likely the future of NLP models.

The guide proceeds by (1) preparing the data for text classification with T5-small, a smaller version of T5-base, and (2) training the model in PyTorch.

1.1 Data Preparation

Some unique pre-processing is required when using T5 for classification. Specifically, we need to append the end-of-sequence token </s> to every input text that needs to be classified and to every target text. T5’s tokenizer in the transformers library will handle the details from there.
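
As a rough sketch, assuming a pandas data frame df with illustrative column names text (the input document) and target (the class label written out as text):

# append the T5 end-of-sequence token to inputs and targets
df['text'] = df['text'].astype(str) + ' </s>'
df['target'] = df['target'].astype(str) + ' </s>'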

1.2 Instantiate Tokenizer

Next, we instantiate the T5 tokenizer from transformers and check some special token IDs.
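
A minimal sketch of this step, assuming the t5-small checkpoint; the three IDs checked below are the tokenizer's end-of-sequence, unknown, and padding tokens:

tokenizer = T5Tokenizer.from_pretrained('t5-small')

print(tokenizer.eos_token_id)  # </s>
print(tokenizer.unk_token_id)  # <unk>
print(tokenizer.pad_token_id)  # <pad>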

## 1
## 2
## 0

1.3 Tokenize the Corpus

Then, we proceed to tokenize our corpus as usual. Notice that we effectively do this process twice: once for the corpus and once for the targets.
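
A minimal sketch of the two tokenization passes, assuming text and target are lists of strings that already end in </s> and a recent version of transformers; the input max_length is illustrative, while the target max_length of 3 matches the short target tensors used later:

input_encodings = tokenizer.batch_encode_plus(text,
                                              max_length=512,
                                              padding='max_length',
                                              truncation=True,
                                              return_tensors='pt')

target_encodings = tokenizer.batch_encode_plus(target,
                                               max_length=3,
                                               padding='max_length',
                                               truncation=True,
                                               return_tensors='pt')

input_ids, attention_masks = input_encodings['input_ids'], input_encodings['attention_mask']
target_ids, target_masks = target_encodings['input_ids'], target_encodings['attention_mask']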

## 2.772822299651568
## 3.0
## 3

1.4 Prepare and Split Data

Next, we split our data into train, validation, and test sets.
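
A minimal sketch, assuming the tensors from the previous step and an illustrative 80/10/10 split:

# bundle the four tensors and carve out validation and test sets
dataset = TensorDataset(input_ids, attention_masks, target_ids, target_masks)

train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(dataset,
                                                        [train_size, val_size, test_size])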

1.5 Instantiate Training Models

Now we are ready to prepare our training scripts, which follow the other guides closely. T5ForConditionalGeneration asks that we supply four inputs to the model’s forward function: (1) corpus token IDs, (2) corpus attention masks, (3) target (label) token IDs, and (4) target attention masks.

def train(model, dataloader, optimizer):

    # capture time
    total_t0 = time.time()

    # Perform one full pass over the training set.
    print("")
    print('======== Epoch {:} / {:} ========'.format(epoch + 1, epochs))
    print('Training...')

    # reset total loss for epoch
    train_total_loss = 0
    total_train_f1 = 0

    # put model into training mode
    model.train()

    # for each batch of training data...
    for step, batch in enumerate(dataloader):

        # progress update every 40 batches.
        if step % 40 == 0 and not step == 0:

            # Report progress.
            print('  Batch {:>5,}  of  {:>5,}.'.format(step, len(dataloader)))

        # Unpack this training batch from our dataloader:
        #
        # As we unpack the batch, we'll also copy each tensor to the GPU using
        # the `.cuda()` method.
        #
        # `batch` contains four pytorch tensors:
        #   [0]: input tokens
        #   [1]: attention masks
        #   [2]: target tokens
        #   [3]: target attention masks
        b_input_ids = batch[0].cuda()
        b_input_mask = batch[1].cuda()
        b_target_ids = batch[2].cuda()
        b_target_mask = batch[3].cuda()

        # clear previously calculated gradients
        optimizer.zero_grad()

        # runs the forward pass with autocasting.
        with autocast():
            # forward propagation (evaluate model on training batch)
            outputs = model(input_ids=b_input_ids,
                            attention_mask=b_input_mask,
                            labels=b_target_ids,
                            decoder_attention_mask=b_target_mask)

            loss, prediction_scores = outputs[:2]

            # sum the training loss over all batches for average loss at end
            # loss is a tensor containing a single value
            train_total_loss += loss.item()

        # Scales loss.  Calls backward() on scaled loss to create scaled gradients.
        # Backward passes under autocast are not recommended.
        # Backward ops run in the same dtype autocast chose for corresponding forward ops.
        scaler.scale(loss).backward()

        # scaler.step() first unscales the gradients of the optimizer's assigned params.
        # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
        # otherwise, optimizer.step() is skipped.
        scaler.step(optimizer)

        # Updates the scale for next iteration.
        scaler.update()

        # update the learning rate
        # (if the scaler skipped optimizer.step() because of an inf/NaN overflow,
        # PyTorch emits the lr_scheduler warning shown after the training output)
        scheduler.step()

    # calculate the average loss over all of the batches
    avg_train_loss = train_total_loss / len(dataloader)

    # Record all statistics from this epoch.
    training_stats.append(
        {
            'Train Loss': avg_train_loss
        }
    )

    # training time end
    training_time = format_time(time.time() - total_t0)

    # print result summaries
    print("")
    print("summary results")
    print("epoch | trn loss | trn time ")
    print(f"{epoch+1:5d} | {avg_train_loss:.5f} | {training_time:}")

    return training_stats


def validating(model, dataloader):

    # capture validation time
    total_t0 = time.time()

    # After the completion of each training epoch, measure our performance on
    # our validation set.
    print("")
    print("Running Validation...")

    # put the model in evaluation mode
    model.eval()

    # track variables
    total_valid_loss = 0

    # evaluate data for one epoch
    for batch in dataloader:

        # Unpack this validation batch from our dataloader:
        # `batch` contains four pytorch tensors:
        #   [0]: input tokens
        #   [1]: attention masks
        #   [2]: target tokens
        #   [3]: target attention masks
        b_input_ids = batch[0].cuda()
        b_input_mask = batch[1].cuda()
        b_target_ids = batch[2].cuda()
        b_target_mask = batch[3].cuda()

        # tell pytorch not to bother calculating gradients
        # as it's only necessary for training
        with torch.no_grad():

            # forward propagation (evaluate model on validation batch)
            outputs = model(input_ids=b_input_ids,
                            attention_mask=b_input_mask,
                            labels=b_target_ids,
                            decoder_attention_mask=b_target_mask)

            loss, prediction_scores = outputs[:2]

            # sum the training loss over all batches for average loss at end
            # loss is a tensor containing a single value
            total_valid_loss += loss.item()

    # calculate the average loss over all of the batches.
    global avg_val_loss
    avg_val_loss = total_valid_loss / len(dataloader)

    # Record all statistics from this epoch.
    valid_stats.append(
        {
            'Val Loss': avg_val_loss,
            'Val PPL.': np.exp(avg_val_loss)
        }
    )

    # capture end validation time
    training_time = format_time(time.time() - total_t0)

    # print result summaries
    print("")
    print("summary results")
    print("epoch | val loss | val ppl | val time")
    print(f"{epoch+1:5d} | {avg_val_loss:.5f} | {np.exp(avg_val_loss):.3f} | {training_time:}")

    return valid_stats


def testing(model, dataloader):

    print("")
    print("Running Testing...")

    # measure testing time
    t0 = time.time()

    # put the model in evaluation mode
    model.eval()

    # track variables
    total_test_loss = 0
    total_test_acc = 0
    total_test_f1 = 0
    predictions = []
    actuals = []

    # evaluate data for one epoch
    for step, batch in enumerate(dataloader):
        # progress update every 40 batches.
        if step % 40 == 0 and not step == 0:
            # Calculate elapsed time in minutes.
            elapsed = format_time(time.time() - t0)
            # Report progress.
            print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(step, len(dataloader), elapsed))

        # Unpack this test batch from our dataloader:
        # `batch` contains four pytorch tensors:
        #   [0]: input tokens
        #   [1]: attention masks
        #   [2]: target tokens
        #   [3]: target attention masks
        b_input_ids = batch[0].cuda()
        b_input_mask = batch[1].cuda()
        b_target_ids = batch[2].cuda()
        b_target_mask = batch[3].cuda()

        # tell pytorch not to bother calculating gradients
        # as it's only necessary for training
        with torch.no_grad():

            # forward propagation (evaluate model on test batch)
            outputs = model(input_ids=b_input_ids,
                            attention_mask=b_input_mask,
                            labels=b_target_ids,
                            decoder_attention_mask=b_target_mask)

            loss, prediction_scores = outputs[:2]

            total_test_loss += loss.item()

            generated_ids = model.generate(
                    input_ids=b_input_ids,
                    attention_mask=b_input_mask,
                    max_length=3
                    )

            preds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids]
            target = [tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=True) for t in b_target_ids]

            total_test_acc += accuracy_score(target, preds)
            # f1_score expects y_true first, then y_pred
            total_test_f1 += f1_score(target, preds,
                                      average='weighted',
                                      labels=np.unique(preds))
            predictions.extend(preds)
            actuals.extend(target)

    # calculate the average loss over all of the batches.
    avg_test_loss = total_test_loss / len(dataloader)

    avg_test_acc = total_test_acc / len(dataloader)

    avg_test_f1 = total_test_f1 / len(dataloader)

    # Record all statistics from this epoch.
    test_stats.append(
        {
            'Test Loss': avg_test_loss,
            'Test PPL.': np.exp(avg_test_loss),
            'Test Acc.': avg_test_acc,
            'Test F1': avg_test_f1
        }
    )
    global df2
    temp_data = pd.DataFrame({'predicted': predictions, 'actual': actuals})
    df2 = df2.append(temp_data)

    return test_stats
    

# time function
def format_time(elapsed):
    '''
    Takes a time in seconds and returns a string hh:mm:ss
    '''
    # Round to the nearest second.
    elapsed_rounded = int(round((elapsed)))
    # Format as hh:mm:ss
    return str(datetime.timedelta(seconds=elapsed_rounded))

1.6 Dealing with Imbalanced Classification: Data Loaders

Since our corpus is imbalanced, we build a weighted sampler to help balance the class distribution as batches are fed through the data loaders.

# helper function to count target distribution inside tensor data sets
def target_count(tensor_dataset):
    # set empty count containers
    count0 = 0
    count1 = 0
    # set total container to turn into torch tensor
    total = []
    for i in tensor_dataset:
        # for kabul tensor
        if torch.all(torch.eq(i[2], torch.tensor([20716, 83, 1]))):
            count0 += 1
        # for us tensor
        elif torch.all(torch.eq(i[2], torch.tensor([837, 1, 0]))):
            count1 += 1
    total.append(count0)
    total.append(count1)
    return torch.tensor(total)


# prepare weighted sampling for imbalanced classification
def create_sampler(target_tensor, tensor_dataset):
    # class distribution [count_kabul, count_us], as produced by target_count()
    class_sample_count = target_tensor
    # inverse-frequency weight for each class
    weight = 1. / class_sample_count.float()
    # produce weights for each observation in the data set
    new_batch = []
    # for each obs
    for i in tensor_dataset:
        # if i is equal to kabul
        if torch.all(torch.eq(i[2], torch.tensor([20716, 83, 1]))):
            # append 0
            new_batch.append(0)
            # elif equal to US
        elif torch.all(torch.eq(i[2], torch.tensor([837, 1, 0]))):
            # append 1
            new_batch.append(1)
    samples_weight = torch.tensor([weight[t] for t in new_batch])
    # prepare sampler
    sampler = torch.utils.data.WeightedRandomSampler(weights=samples_weight,
                                                     num_samples=len(samples_weight),
                                                     replacement=True)
    return sampler


# build the weighted sampler for the training set
train_sampler = create_sampler(target_count(train_dataset), train_dataset)


# training data loader (shuffle must be False when a sampler is supplied)
train_dataloader = DataLoader(train_dataset,
                              batch_size=24,
                              sampler=train_sampler,
                              shuffle=False)

# let's check the class balance of each batch to see how the sampler is working
for i, (input_ids, input_masks, target_ids, target_masks) in enumerate(train_dataloader):
    count_kabul = 0
    count_us = 0
    if i in range(0, 10):
        for j in target_ids:
            if (torch.all(torch.eq(j, torch.tensor([20716, 83, 1])))):
                count_kabul += 1
            else:
                count_us += 1
        print("batch index {}, 0/1: {}/{}".format(i, count_kabul, count_us))
## batch index 0, 0/1: 12/12
## batch index 1, 0/1: 13/11
## batch index 2, 0/1: 9/15
## batch index 3, 0/1: 15/9
## batch index 4, 0/1: 17/7
## batch index 5, 0/1: 9/15
## batch index 6, 0/1: 4/20
## batch index 7, 0/1: 14/10
## batch index 8, 0/1: 10/14
## batch index 9, 0/1: 12/12

Before training, several preparatory objects are instantiated, such as the model, the data loaders, and the optimizer.

1.7 Prepare for Training
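
A minimal sketch of those objects, assuming t5-small (confirmed by the warning below) and the batch size of 24 used for the training loader; the epoch count of 6 matches the epoch headers in the output, while the learning rate and warmup steps are illustrative:

model = T5ForConditionalGeneration.from_pretrained('t5-small')
model.cuda()

epochs = 6

valid_dataloader = DataLoader(val_dataset, batch_size=24, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=24, shuffle=False)

# optimizer and linear learning-rate schedule
optimizer = AdamW(model.parameters(), lr=3e-4)
total_steps = len(train_dataloader) * epochs
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=0,
                                            num_training_steps=total_steps)

# gradient scaler for mixed-precision (AMP) training
scaler = GradScaler()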

## Some weights of T5ForConditionalGeneration were not initialized from the model checkpoint at t5-small and are newly initialized: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight', 'lm_head.weight']
## You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

1.8 Train and Validate

Finally, we are ready to train. Two containers are created to store the results of each training and validation epoch.
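
A minimal sketch of the epoch loop, assuming the objects prepared above; the checkpoint rule (save when validation loss improves) and the model file name are assumptions, while the save directory mirrors the tokenizer path shown in the output:

training_stats = []
valid_stats = []
best_valid_loss = float('inf')

for epoch in range(epochs):
    train(model, train_dataloader, optimizer)
    validating(model, valid_dataloader)

    # checkpoint when validation loss improves
    # (assumes the ./model_save/t5-classification/ directory already exists)
    if avg_val_loss < best_valid_loss:
        best_valid_loss = avg_val_loss
        torch.save(model.state_dict(), './model_save/t5-classification/model.pt')
        tokenizer.save_pretrained('./model_save/t5-classification/')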

## 
## ======== Epoch 1 / 6 ========
## Training...
##   Batch    40  of    335.
##   Batch    80  of    335.
##   Batch   120  of    335.
##   Batch   160  of    335.
##   Batch   200  of    335.
##   Batch   240  of    335.
##   Batch   280  of    335.
##   Batch   320  of    335.
## 
## summary results
## epoch | trn loss | trn time 
##     1 | 1.86750 | 0:01:28
## [{'Train Loss': 1.867504431307316}]
## 
## Running Validation...
## 
## summary results
## epoch | val loss | val ppl | val time
##     1 | 0.14962 | 1.161 | 0:00:05
## [{'Val Loss': 0.14961695591253893, 'Val PPL.': 1.1613892942138855}]
## ('./model_save/t5-classification/spiece.model', './model_save/t5-classification/special_tokens_map.json', './model_save/t5-classification/added_tokens.json')
## 
## ======== Epoch 2 / 6 ========
## Training...
##   Batch    40  of    335.
##   Batch    80  of    335.
##   Batch   120  of    335.
##   Batch   160  of    335.
##   Batch   200  of    335.
##   Batch   240  of    335.
##   Batch   280  of    335.
##   Batch   320  of    335.
## 
## summary results
## epoch | trn loss | trn time 
##     2 | 0.18869 | 0:01:40
## [{'Train Loss': 1.867504431307316}, {'Train Loss': 0.18868607740793655}]
## 
## Running Validation...
## 
## summary results
## epoch | val loss | val ppl | val time
##     2 | 0.13648 | 1.146 | 0:00:05
## [{'Val Loss': 0.14961695591253893, 'Val PPL.': 1.1613892942138855}, {'Val Loss': 0.1364755311182567, 'Val PPL.': 1.146226830543859}]
## ('./model_save/t5-classification/spiece.model', './model_save/t5-classification/special_tokens_map.json', './model_save/t5-classification/added_tokens.json')
## 
## ======== Epoch 3 / 6 ========
## Training...
##   Batch    40  of    335.
##   Batch    80  of    335.
##   Batch   120  of    335.
##   Batch   160  of    335.
##   Batch   200  of    335.
##   Batch   240  of    335.
##   Batch   280  of    335.
##   Batch   320  of    335.
## 
## summary results
## epoch | trn loss | trn time 
##     3 | 0.15885 | 0:01:27
## [{'Train Loss': 1.867504431307316}, {'Train Loss': 0.18868607740793655}, {'Train Loss': 0.15885204529361938}]
## 
## Running Validation...
## 
## summary results
## epoch | val loss | val ppl | val time
##     3 | 0.12382 | 1.132 | 0:00:05
## [{'Val Loss': 0.14961695591253893, 'Val PPL.': 1.1613892942138855}, {'Val Loss': 0.1364755311182567, 'Val PPL.': 1.146226830543859}, {'Val Loss': 0.12382005155086517, 'Val PPL.': 1.131812184825848}]
## ('./model_save/t5-classification/spiece.model', './model_save/t5-classification/special_tokens_map.json', './model_save/t5-classification/added_tokens.json')
## 
## ======== Epoch 4 / 6 ========
## Training...
##   Batch    40  of    335.
##   Batch    80  of    335.
##   Batch   120  of    335.
##   Batch   160  of    335.
##   Batch   200  of    335.
##   Batch   240  of    335.
##   Batch   280  of    335.
##   Batch   320  of    335.
## 
## summary results
## epoch | trn loss | trn time 
##     4 | 0.15560 | 0:01:33
## [{'Train Loss': 1.867504431307316}, {'Train Loss': 0.18868607740793655}, {'Train Loss': 0.15885204529361938}, {'Train Loss': 0.15560255393163483}]
## 
## Running Validation...
## 
## summary results
## epoch | val loss | val ppl | val time
##     4 | 0.11525 | 1.122 | 0:00:05
## [{'Val Loss': 0.14961695591253893, 'Val PPL.': 1.1613892942138855}, {'Val Loss': 0.1364755311182567, 'Val PPL.': 1.146226830543859}, {'Val Loss': 0.12382005155086517, 'Val PPL.': 1.131812184825848}, {'Val Loss': 0.1152467570666756, 'Val PPL.': 1.1221503019282884}]
## ('./model_save/t5-classification/spiece.model', './model_save/t5-classification/special_tokens_map.json', './model_save/t5-classification/added_tokens.json')
## 
## ======== Epoch 5 / 6 ========
## Training...
##   Batch    40  of    335.
##   Batch    80  of    335.
##   Batch   120  of    335.
##   Batch   160  of    335.
##   Batch   200  of    335.
##   Batch   240  of    335.
##   Batch   280  of    335.
##   Batch   320  of    335.
## 
## summary results
## epoch | trn loss | trn time 
##     5 | 0.14363 | 0:01:34
## [{'Train Loss': 1.867504431307316}, {'Train Loss': 0.18868607740793655}, {'Train Loss': 0.15885204529361938}, {'Train Loss': 0.15560255393163483}, {'Train Loss': 0.14363485894986053}]
## 
## Running Validation...
## 
## summary results
## epoch | val loss | val ppl | val time
##     5 | 0.10829 | 1.114 | 0:00:05
## [{'Val Loss': 0.14961695591253893, 'Val PPL.': 1.1613892942138855}, {'Val Loss': 0.1364755311182567, 'Val PPL.': 1.146226830543859}, {'Val Loss': 0.12382005155086517, 'Val PPL.': 1.131812184825848}, {'Val Loss': 0.1152467570666756, 'Val PPL.': 1.1221503019282884}, {'Val Loss': 0.10829435022813934, 'Val PPL.': 1.1143757138609098}]
## ('./model_save/t5-classification/spiece.model', './model_save/t5-classification/special_tokens_map.json', './model_save/t5-classification/added_tokens.json')
## 
## ======== Epoch 6 / 6 ========
## Training...
##   Batch    40  of    335.
##   Batch    80  of    335.
##   Batch   120  of    335.
##   Batch   160  of    335.
##   Batch   200  of    335.
##   Batch   240  of    335.
##   Batch   280  of    335.
##   Batch   320  of    335.
## 
## summary results
## epoch | trn loss | trn time 
##     6 | 0.13688 | 0:01:33
## [{'Train Loss': 1.867504431307316}, {'Train Loss': 0.18868607740793655}, {'Train Loss': 0.15885204529361938}, {'Train Loss': 0.15560255393163483}, {'Train Loss': 0.14363485894986053}, {'Train Loss': 0.13688372016064268}]
## 
## Running Validation...
## 
## summary results
## epoch | val loss | val ppl | val time
##     6 | 0.11283 | 1.119 | 0:00:05
## [{'Val Loss': 0.14961695591253893, 'Val PPL.': 1.1613892942138855}, {'Val Loss': 0.1364755311182567, 'Val PPL.': 1.146226830543859}, {'Val Loss': 0.12382005155086517, 'Val PPL.': 1.131812184825848}, {'Val Loss': 0.1152467570666756, 'Val PPL.': 1.1221503019282884}, {'Val Loss': 0.10829435022813934, 'Val PPL.': 1.1143757138609098}, {'Val Loss': 0.11282750214671805, 'Val PPL.': 1.1194388155003379}]
## 
## C:\Users\Andrew\Anaconda3\envs\my_ml\lib\site-packages\torch\optim\lr_scheduler.py:123: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`.  Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
##   "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)

Now we can present our training results nicely in a data frame.
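
A minimal sketch, assuming the training_stats and valid_stats lists built above:

# combine the per-epoch dictionaries into a single data frame
df_stats = pd.concat([pd.DataFrame(training_stats), pd.DataFrame(valid_stats)], axis=1)
df_stats.index = df_stats.index + 1
df_stats.index.name = 'Epoch'
print(df_stats.round(3))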

##        Train Loss  Val Loss  Val PPL.
## Epoch                                
## 1           1.868     0.150     1.161
## 2           0.189     0.136     1.146
## 3           0.159     0.124     1.132
## 4           0.156     0.115     1.122
## 5           0.144     0.108     1.114
## 6           0.137     0.113     1.119

1.9 Test and Generate Targets

While checking the loss on our held-out test data, we also generate predicted targets to compare with the actual targets for the final evaluation metrics.
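
A minimal sketch: reload the checkpointed weights and run the test pass (the state-dict file name reuses the illustrative one from the training loop):

test_stats = []
df2 = pd.DataFrame(columns=['predicted', 'actual'])

# load the best weights saved during training
model.load_state_dict(torch.load('./model_save/t5-classification/model.pt'))

testing(model, test_dataloader)
print(pd.DataFrame(test_stats).round(3))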

## <All keys matched successfully>
## 
## Running Testing...
##   Batch    40  of     42.    Elapsed: 0:00:15.
## [{'Test Loss': 0.13512379383402212, 'Test PPL.': 1.144678479718354, 'Test Acc.': 0.8438208616780045, 'Test F1': 0.8333258198269518}]
## 
## C:\Users\Andrew\Anaconda3\envs\my_ml\lib\site-packages\sklearn\metrics\classification.py:1437: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 in labels with no predicted samples.
##   'precision', 'predicted', average, warn_for)
##    Test Loss  Test PPL.  Test Acc.  Test F1
## 0      0.135      1.145      0.844    0.833

And below is an example of T5’s text-to-text output generation:
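
The example is simply the first few rows of the collected predictions data frame, e.g.:

print(df2.head())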

##   predicted actual
## 0     Kabul  Kabul
## 1     Kabul  Kabul
## 2     Kabul     US
## 3     Kabul  Kabul
## 4     Kabul  Kabul

2 Sources