Text Classification: (Distil)BERT

Andrew Fogarty

7/15/2020

# load python
library(reticulate)
use_condaenv("my_ml")
# load packages
import torch
import torch.nn as nn
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer
from transformers import get_linear_schedule_with_warmup, AdamW
from torch.utils.data import TensorDataset, random_split, DataLoader, RandomSampler, SequentialSampler
import time, datetime, random, re
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import matplotlib.pyplot as plt
import seaborn as sns
import optuna
from optuna.pruners import SuccessiveHalvingPruner
from optuna.samplers import TPESampler
from torch.cuda.amp import autocast, GradScaler

SEED = 15
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
## <torch._C.Generator object at 0x000000001F4FD070>
torch.backends.cudnn.deterministic = True
torch.cuda.amp.autocast(enabled=True)
## <torch.cuda.amp.autocast_mode.autocast object at 0x00000000348A9F88>

# tell pytorch to use cuda
device = torch.device("cuda")

1 Introduction

Text or sequence classification aims to label a sentence or document based on its content. In this post, we use transformer models to classify a novel data set of insurgent propaganda messages that I created.

2 Encoder Transformer Architecture: BERT

Bidirectional Encoder Representations from Transformers (BERT) is an encoder-only transformer architecture that accepts up to 512 WordPiece tokens as input, is composed of 12 encoder layers (24 for BERT-large) with 12 attention heads each (16 for BERT-large), and outputs a 768-dimensional vector per token (1024 for BERT-large). BERT is pre-trained, which means it has already been trained on self-supervised tasks (masked language modeling and next sentence prediction) over large corpora (the BooksCorpus and English Wikipedia). BERT comes in cased, uncased, base, large, and multilingual variants, which we will access through the transformers (Hugging Face) library in PyTorch.

Among BERT’s most novel shifts, aside from the use of transformers, was the move away from the traditional next-word prediction objective toward taking an entire sentence, corrupting (masking) parts of it, and then predicting the corrupted parts. In turn, this lets the model work with an entire sentence at a time, using both left and right context.

To get value out of BERT, we need to adapt it to the task at hand, which amounts to placing a minimal amount of additional structure on top of the model. Every input sequence begins with a special CLS token, whose output vector summarizes the sequence (during pre-training it is used to predict whether one sentence follows another). For single-sentence classification, we take the output vector at the CLS position, attach a small feed-forward head (an affine layer followed by a sigmoid or softmax), and fine-tune the whole model to learn our specific classification labels.
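
To make this concrete, here is a minimal sketch of such a head in plain PyTorch. This is not the code used later in the post (we rely on DistilBertForSequenceClassification, which bundles an equivalent head); the hidden size of 768 and the two labels simply mirror the setup below.

import torch
import torch.nn as nn

class ClassificationHead(nn.Module):
    """Toy head: dropout plus an affine layer over the [CLS] vector."""
    def __init__(self, hidden_size=768, num_labels=2, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, encoder_output):
        # encoder_output: (batch, seq_len, hidden_size) from BERT/DistilBERT
        cls_vector = encoder_output[:, 0, :]  # hidden state at the [CLS] position
        return self.classifier(self.dropout(cls_vector))  # (batch, num_labels) logits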

The guide proceeds by (1) preparing the data for text classification with DistilBERT – a distilled version of BERT base, and (2) analyzing the data in PyTorch.

2.1 Load Our Data Set

While little data pre-processing is needed when using BERT, owing to its word piece tokenization, some minor cleanup is applied for clarity.
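
The loading chunk is not echoed in this rendering, so here is a minimal sketch; the file name and column names are placeholders rather than the actual corpus, and the cleanup shown is only illustrative.

# hypothetical file and column names -- substitute your own corpus
df = pd.read_csv('propaganda_messages.csv')

# minor cleanup: drop incomplete rows and collapse stray whitespace
df = df.dropna(subset=['text', 'label'])
df['text'] = df['text'].str.replace(r'\s+', ' ', regex=True).str.strip()

sentences = df['text'].values
labels = df['label'].values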

2.2 Tokenize the Data

Next, we instantiate the DistilBERT tokenizer from transformers and tokenize our entire corpus.

## 1103
## ['space', 'La', 'directed', 'smile', 'episode', 'hours', 'whole', '##de', '##less', 'Why']
## 28996
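
The tokenization chunk that produced the output above is hidden; the sketch below shows the general pattern, assuming the sentences and labels arrays from the loading sketch and a maximum length of 512 word pieces (the exact length used in the post is not shown).

# instantiate the cased DistilBERT tokenizer (28,996 word pieces)
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')

# tokenize the corpus: pad/truncate to a fixed length and build attention masks
encodings = tokenizer(list(sentences),
                      padding='max_length',
                      truncation=True,
                      max_length=512,
                      return_tensors='pt')

input_ids = encodings['input_ids']             # (num_docs, max_length) word-piece ids
attention_masks = encodings['attention_mask']  # 1 for real tokens, 0 for padding
labels_tensor = torch.tensor(labels)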

2.3 Prepare Train, Validation, and Testing Functions

def train(model, dataloader, optimizer):

    # capture time
    total_t0 = time.time()

    # Perform one full pass over the training set.
    print("")
    print('======== Epoch {:} / {:} ========'.format(epoch + 1, epochs))
    print('Training...')

    # reset total loss for epoch
    train_total_loss = 0
    total_train_f1 = 0

    # put model into training mode
    model.train()

    # for each batch of training data...
    for step, batch in enumerate(dataloader):

        # progress update every 40 batches.
        if step % 40 == 0 and not step == 0:

            # Report progress.
            print('  Batch {:>5,}  of  {:>5,}.'.format(step, len(dataloader)))

        # Unpack this training batch from our dataloader:
        #
        # As we unpack the batch, we'll also copy each tensor to the GPU using
        # the `to` method.
        #
        # `batch` contains three pytorch tensors:
        #   [0]: input ids
        #   [1]: attention masks
        #   [2]: labels
        b_input_ids = batch[0].cuda()
        b_input_mask = batch[1].cuda()
        b_labels = batch[2].cuda().long()

        # clear previously calculated gradients
        optimizer.zero_grad()

        # runs the forward pass with autocasting.
        with autocast():
            # forward propagation (evaluate model on training batch)
            loss, logits = model(input_ids=b_input_ids,
                                 attention_mask=b_input_mask,
                                 labels=b_labels)
            # sum the training loss over all batches for average loss at end
            # loss is a tensor containing a single value
            train_total_loss += loss.item()

        # Scales loss.  Calls backward() on scaled loss to create scaled gradients.
        # Backward passes under autocast are not recommended.
        # Backward ops run in the same dtype autocast chose for corresponding forward ops.
        scaler.scale(loss).backward()

        # scaler.step() first unscales the gradients of the optimizer's assigned params.
        # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
        # otherwise, optimizer.step() is skipped.
        scaler.step(optimizer)

        # Updates the scale for next iteration.
        scaler.update()

        # update the learning rate
        scheduler.step()

        # move logits and labels to CPU
        logits = logits.detach().cpu().numpy()
        y_true = b_labels.detach().cpu().numpy()

        # calculate preds
        rounded_preds = np.argmax(logits, axis=1).flatten()

        # calculate f1
        total_train_f1 += f1_score(y_true, rounded_preds,
                                   average='weighted',
                                   labels=np.unique(rounded_preds))

    # calculate the average loss over all of the batches
    avg_train_loss = train_total_loss / len(dataloader)

    # calculate the average f1 over all of the batches
    avg_train_f1 = total_train_f1 / len(dataloader)

    # Record all statistics from this epoch.
    training_stats.append(
        {
            'Train Loss': avg_train_loss,
            'Train F1': avg_train_f1
        }
    )

    # training time end
    training_time = format_time(time.time() - total_t0)

    # print result summaries
    print("")
    print("summary results")
    print("epoch | trn loss | trn f1 | trn time ")
    print(f"{epoch+1:5d} | {avg_train_loss:.5f} | {avg_train_f1:.5f} | {training_time:}")

    torch.cuda.empty_cache()

    return None


def validating(model, dataloader):

    # capture validation time
    total_t0 = time.time()

    # After the completion of each training epoch, measure our performance on
    # our validation set.
    print("")
    print("Running Validation...")

    # put the model in evaluation mode
    model.eval()

    # track variables
    total_valid_accuracy = 0
    total_valid_loss = 0
    total_valid_f1 = 0
    total_valid_recall = 0
    total_valid_precision = 0

    # evaluate data for one epoch
    for batch in dataloader:

        # Unpack this validation batch from our dataloader:
        # `batch` contains three pytorch tensors:
        #   [0]: input ids
        #   [1]: attention masks
        #   [2]: labels
        b_input_ids = batch[0].cuda()
        b_input_mask = batch[1].cuda()
        b_labels = batch[2].cuda().long()

        # tell pytorch not to bother calculating gradients
        # as it's only necessary for training
        with torch.no_grad():

            # forward propagation (evaluate model on validation batch)
            loss, logits = model(input_ids=b_input_ids,
                                 attention_mask=b_input_mask,
                                 labels=b_labels)

        # accumulate validation loss
        total_valid_loss += loss.item()

        # move logits and labels to CPU
        logits = logits.detach().cpu().numpy()
        y_true = b_labels.detach().cpu().numpy()

        # calculate preds
        rounded_preds = np.argmax(logits, axis=1).flatten()

        # calculate f1
        total_valid_f1 += f1_score(y_true, rounded_preds,
                                   average='weighted',
                                   labels=np.unique(rounded_preds))

        # calculate accuracy
        total_valid_accuracy += accuracy_score(y_true, rounded_preds)

        # calculate precision
        total_valid_precision += precision_score(y_true, rounded_preds,
                                                 average='weighted',
                                                 labels=np.unique(rounded_preds))

        # calculate recall
        total_valid_recall += recall_score(y_true, rounded_preds,
                                           average='weighted',
                                           labels=np.unique(rounded_preds))

    # report final accuracy of validation run
    avg_accuracy = total_valid_accuracy / len(dataloader)

    # report final f1 of validation run
    global avg_val_f1
    avg_val_f1 = total_valid_f1 / len(dataloader)

    # report final precision of validation run
    avg_precision = total_valid_precision / len(dataloader)

    # report final recall of validation run
    avg_recall = total_valid_recall / len(dataloader)

    # calculate the average loss over all of the batches.
    global avg_val_loss
    avg_val_loss = total_valid_loss / len(dataloader)

    # Record all statistics from this epoch.
    valid_stats.append(
        {
            'Val Loss': avg_val_loss,
            'Val Accur.': avg_accuracy,
            'Val precision': avg_precision,
            'Val recall': avg_recall,
            'Val F1': avg_val_f1
        }
    )

    # capture end validation time
    training_time = format_time(time.time() - total_t0)

    # print result summaries
    print("")
    print("summary results")
    print("epoch | val loss | val f1 | val time")
    print(f"{epoch+1:5d} | {avg_val_loss:.5f} | {avg_val_f1:.5f} | {training_time:}")

    return None


def testing(model, dataloader):

    print("")
    print("Running Testing...")

    # put the model in evaluation mode
    model.eval()

    # track variables
    total_test_accuracy = 0
    total_test_loss = 0
    total_test_f1 = 0
    total_test_recall = 0
    total_test_precision = 0

    # evaluate data for one epoch
    for batch in dataloader:

        # Unpack this test batch from our dataloader:
        # `batch` contains three pytorch tensors:
        #   [0]: input ids
        #   [1]: attention masks
        #   [2]: labels
        b_input_ids = batch[0].cuda()
        b_input_mask = batch[1].cuda()
        b_labels = batch[2].cuda().long()

        # tell pytorch not to bother calculating gradients
        # as it's only necessary for training
        with torch.no_grad():

            # forward propagation (evaluate model on test batch)
            loss, logits = model(input_ids=b_input_ids,
                                 attention_mask=b_input_mask,
                                 labels=b_labels)

        # accumulate test loss
        total_test_loss += loss.item()

        # move logits and labels to CPU
        logits = logits.detach().cpu().numpy()
        y_true = b_labels.detach().cpu().numpy()

        # calculate preds
        rounded_preds = np.argmax(logits, axis=1).flatten()

        # calculate f1
        total_test_f1 += f1_score(y_true, rounded_preds,
                                  average='weighted',
                                  labels=np.unique(rounded_preds))

        # calculate accuracy
        total_test_accuracy += accuracy_score(y_true, rounded_preds)

        # calculate precision
        total_test_precision += precision_score(y_true, rounded_preds,
                                                average='weighted',
                                                labels=np.unique(rounded_preds))

        # calculate recall
        total_test_recall += recall_score(y_true, rounded_preds,
                                          average='weighted',
                                          labels=np.unique(rounded_preds))

    # report final accuracy of test run
    avg_accuracy = total_test_accuracy / len(dataloader)

    # report final f1 of test run
    avg_test_f1 = total_test_f1 / len(dataloader)

    # report final precision of test run
    avg_precision = total_test_precision / len(dataloader)

    # report final recall of test run
    avg_recall = total_test_recall / len(dataloader)

    # calculate the average loss over all of the batches.
    avg_test_loss = total_test_loss / len(dataloader)

    # Record all statistics from this epoch.
    test_stats.append(
        {
            'Test Loss': avg_test_loss,
            'Test Accur.': avg_accuracy,
            'Test precision': avg_precision,
            'Test recall': avg_recall,
            'Test F1': avg_test_f1
        }
    )
    return None

2.4 Prepare Tensor Data Sets

With the corpus work out of the way, we now prepare our data for analysis in PyTorch. The code below creates a TensorDataset composed of our features (input ids), attention masks, and labels, and then splits it into train, validation, and test sets.
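
That chunk is hidden as well; a minimal sketch, assuming the tensors from the tokenization sketch above and an 80/10/10 split (the proportions are my assumption, not the ones used in the post):

# bundle features, attention masks, and labels into a single dataset
dataset = TensorDataset(input_ids, attention_masks, labels_tensor)

# assumed 80/10/10 split into train, validation, and test sets
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size])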

Since my corpus is imbalanced, I produce weighted samplers to help balance the class distribution of the batches drawn from my data loaders.

As you might have guessed, preparing data loaders for each of our train, dev, and test data sets is our next task.
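
Both of those chunks are hidden; the sketch below covers them together, building a WeightedRandomSampler from inverse class frequencies and then the three data loaders. The batch size of 24 is borrowed from the Optuna code later in the post; treat it and the variable names as assumptions.

from torch.utils.data import WeightedRandomSampler

# weight each training example by the inverse frequency of its (integer) class
train_labels = labels_tensor[train_dataset.indices]
class_counts = torch.bincount(train_labels)
example_weights = 1.0 / class_counts[train_labels].float()
train_sampler = WeightedRandomSampler(weights=example_weights,
                                      num_samples=len(example_weights),
                                      replacement=True)

# data loaders: the sampler handles shuffling for the training set
train_dataloader = DataLoader(train_dataset, batch_size=24, sampler=train_sampler)
valid_dataloader = DataLoader(val_dataset, batch_size=24, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=24, shuffle=False)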

Below we instantiate some helper functions for time keeping.
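
The helper chunk is not shown; a minimal sketch of the format_time function that the training and validation loops call:

def format_time(elapsed):
    """Round a duration in seconds and render it as h:mm:ss."""
    elapsed_rounded = int(round(elapsed))
    return str(datetime.timedelta(seconds=elapsed_rounded))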

Next, we load our DistilBERT model and tweak its hyperparameters.

## Some weights of the model checkpoint at distilbert-base-cased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias']
## - This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
## - This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
## Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-cased and are newly initialized: ['pre_classifier.weight', 'pre_classifier.bias', 'classifier.weight', 'classifier.bias']
## You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
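
The chunk that triggered the warning above is hidden; a minimal sketch, using the cased checkpoint and two labels implied by the warning and by the model printout below (any further hyperparameter tweaks the post applies are not shown here):

model = DistilBertForSequenceClassification.from_pretrained(
    'distilbert-base-cased',
    num_labels=2)

# attach the model to the GPU; echoing the result prints the architecture
model.cuda()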

Now we are almost ready to train. A few other preparatory objects are created, such as the loss criterion, the number of epochs, the optimizer, and the learning-rate scheduler; a sketch follows the model printout below.

## DistilBertForSequenceClassification(
##   (distilbert): DistilBertModel(
##     (embeddings): Embeddings(
##       (word_embeddings): Embedding(28996, 768, padding_idx=0)
##       (position_embeddings): Embedding(512, 768)
##       (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
##       (dropout): Dropout(p=0.1, inplace=False)
##     )
##     (transformer): Transformer(
##       (layer): ModuleList(
##         (0): TransformerBlock(
##           (attention): MultiHeadSelfAttention(
##             (dropout): Dropout(p=0.1, inplace=False)
##             (q_lin): Linear(in_features=768, out_features=768, bias=True)
##             (k_lin): Linear(in_features=768, out_features=768, bias=True)
##             (v_lin): Linear(in_features=768, out_features=768, bias=True)
##             (out_lin): Linear(in_features=768, out_features=768, bias=True)
##           )
##           (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
##           (ffn): FFN(
##             (dropout): Dropout(p=0.1, inplace=False)
##             (lin1): Linear(in_features=768, out_features=3072, bias=True)
##             (lin2): Linear(in_features=3072, out_features=768, bias=True)
##           )
##           (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
##         )
##         (1): TransformerBlock(
##           (attention): MultiHeadSelfAttention(
##             (dropout): Dropout(p=0.1, inplace=False)
##             (q_lin): Linear(in_features=768, out_features=768, bias=True)
##             (k_lin): Linear(in_features=768, out_features=768, bias=True)
##             (v_lin): Linear(in_features=768, out_features=768, bias=True)
##             (out_lin): Linear(in_features=768, out_features=768, bias=True)
##           )
##           (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
##           (ffn): FFN(
##             (dropout): Dropout(p=0.1, inplace=False)
##             (lin1): Linear(in_features=768, out_features=3072, bias=True)
##             (lin2): Linear(in_features=3072, out_features=768, bias=True)
##           )
##           (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
##         )
##         (2): TransformerBlock(
##           (attention): MultiHeadSelfAttention(
##             (dropout): Dropout(p=0.1, inplace=False)
##             (q_lin): Linear(in_features=768, out_features=768, bias=True)
##             (k_lin): Linear(in_features=768, out_features=768, bias=True)
##             (v_lin): Linear(in_features=768, out_features=768, bias=True)
##             (out_lin): Linear(in_features=768, out_features=768, bias=True)
##           )
##           (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
##           (ffn): FFN(
##             (dropout): Dropout(p=0.1, inplace=False)
##             (lin1): Linear(in_features=768, out_features=3072, bias=True)
##             (lin2): Linear(in_features=3072, out_features=768, bias=True)
##           )
##           (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
##         )
##         (3): TransformerBlock(
##           (attention): MultiHeadSelfAttention(
##             (dropout): Dropout(p=0.1, inplace=False)
##             (q_lin): Linear(in_features=768, out_features=768, bias=True)
##             (k_lin): Linear(in_features=768, out_features=768, bias=True)
##             (v_lin): Linear(in_features=768, out_features=768, bias=True)
##             (out_lin): Linear(in_features=768, out_features=768, bias=True)
##           )
##           (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
##           (ffn): FFN(
##             (dropout): Dropout(p=0.1, inplace=False)
##             (lin1): Linear(in_features=768, out_features=3072, bias=True)
##             (lin2): Linear(in_features=3072, out_features=768, bias=True)
##           )
##           (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
##         )
##         (4): TransformerBlock(
##           (attention): MultiHeadSelfAttention(
##             (dropout): Dropout(p=0.1, inplace=False)
##             (q_lin): Linear(in_features=768, out_features=768, bias=True)
##             (k_lin): Linear(in_features=768, out_features=768, bias=True)
##             (v_lin): Linear(in_features=768, out_features=768, bias=True)
##             (out_lin): Linear(in_features=768, out_features=768, bias=True)
##           )
##           (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
##           (ffn): FFN(
##             (dropout): Dropout(p=0.1, inplace=False)
##             (lin1): Linear(in_features=768, out_features=3072, bias=True)
##             (lin2): Linear(in_features=3072, out_features=768, bias=True)
##           )
##           (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
##         )
##         (5): TransformerBlock(
##           (attention): MultiHeadSelfAttention(
##             (dropout): Dropout(p=0.1, inplace=False)
##             (q_lin): Linear(in_features=768, out_features=768, bias=True)
##             (k_lin): Linear(in_features=768, out_features=768, bias=True)
##             (v_lin): Linear(in_features=768, out_features=768, bias=True)
##             (out_lin): Linear(in_features=768, out_features=768, bias=True)
##           )
##           (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
##           (ffn): FFN(
##             (dropout): Dropout(p=0.1, inplace=False)
##             (lin1): Linear(in_features=768, out_features=3072, bias=True)
##             (lin2): Linear(in_features=3072, out_features=768, bias=True)
##           )
##           (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
##         )
##       )
##     )
##   )
##   (pre_classifier): Linear(in_features=768, out_features=768, bias=True)
##   (classifier): Linear(in_features=768, out_features=2, bias=True)
##   (dropout): Dropout(p=0.2, inplace=False)
## )
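
The preparatory chunk is hidden; a minimal sketch of the globals that train() and validating() expect, assuming the train_dataloader from the earlier sketch. Five epochs match the run below, and the loss itself is computed inside the model whenever labels are passed, so no separate criterion appears here; the learning rate and weight decay are placeholders, not the values used in the post.

epochs = 5

# optimizer -- the learning rate and weight decay here are assumptions
optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)

# linear learning-rate schedule over all training steps, no warmup
total_steps = len(train_dataloader) * epochs
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=0,
                                            num_training_steps=total_steps)

# gradient scaler for mixed-precision training
scaler = GradScaler()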

Finally, we are ready to train. Two containers are created to store the results of each training and validation epoch.
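
The driver chunk itself is hidden; a plausible minimal sketch given the functions above and the saved-tokenizer tuples in the output below (the ./model_save path comes from that output, while the save-on-improvement bookkeeping is my assumption):

training_stats = []
valid_stats = []

best_valid_loss = float('inf')

for epoch in range(epochs):
    # one pass of training and one of validation (each prints its own summary)
    train(model, train_dataloader, optimizer)
    validating(model, valid_dataloader)

    # checkpoint whenever the validation loss improves
    if avg_val_loss < best_valid_loss:
        best_valid_loss = avg_val_loss
        model.save_pretrained('./model_save/')
        print(tokenizer.save_pretrained('./model_save/'))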

## 
## ======== Epoch 1 / 5 ========
## Training...
##   Batch    40  of    503.
##   Batch    80  of    503.
##   Batch   120  of    503.
##   Batch   160  of    503.
##   Batch   200  of    503.
##   Batch   240  of    503.
##   Batch   280  of    503.
##   Batch   320  of    503.
##   Batch   360  of    503.
##   Batch   400  of    503.
##   Batch   440  of    503.
##   Batch   480  of    503.
## 
## summary results
## epoch | trn loss | trn f1 | trn time 
##     1 | 0.40460 | 0.84251 | 0:01:18
## 
## Running Validation...
## 
## summary results
## epoch | val loss | val f1 | val time
##     1 | 0.34351 | 0.82082 | 0:00:05
## ('./model_save/vocab.txt', './model_save/special_tokens_map.json', './model_save/added_tokens.json')
## 
## ======== Epoch 2 / 5 ========
## Training...
##   Batch    40  of    503.
##   Batch    80  of    503.
##   Batch   120  of    503.
##   Batch   160  of    503.
##   Batch   200  of    503.
##   Batch   240  of    503.
##   Batch   280  of    503.
##   Batch   320  of    503.
##   Batch   360  of    503.
##   Batch   400  of    503.
##   Batch   440  of    503.
##   Batch   480  of    503.
## 
## summary results
## epoch | trn loss | trn f1 | trn time 
##     2 | 0.30051 | 0.88085 | 0:01:19
## 
## Running Validation...
## 
## summary results
## epoch | val loss | val f1 | val time
##     2 | 0.30222 | 0.83206 | 0:00:05
## ('./model_save/vocab.txt', './model_save/special_tokens_map.json', './model_save/added_tokens.json')
## 
## ======== Epoch 3 / 5 ========
## Training...
##   Batch    40  of    503.
##   Batch    80  of    503.
##   Batch   120  of    503.
##   Batch   160  of    503.
##   Batch   200  of    503.
##   Batch   240  of    503.
##   Batch   280  of    503.
##   Batch   320  of    503.
##   Batch   360  of    503.
##   Batch   400  of    503.
##   Batch   440  of    503.
##   Batch   480  of    503.
## 
## summary results
## epoch | trn loss | trn f1 | trn time 
##     3 | 0.27190 | 0.88991 | 0:01:18
## 
## Running Validation...
## 
## summary results
## epoch | val loss | val f1 | val time
##     3 | 0.31899 | 0.83935 | 0:00:05
## 
## ======== Epoch 4 / 5 ========
## Training...
##   Batch    40  of    503.
##   Batch    80  of    503.
##   Batch   120  of    503.
##   Batch   160  of    503.
##   Batch   200  of    503.
##   Batch   240  of    503.
##   Batch   280  of    503.
##   Batch   320  of    503.
##   Batch   360  of    503.
##   Batch   400  of    503.
##   Batch   440  of    503.
##   Batch   480  of    503.
## 
## summary results
## epoch | trn loss | trn f1 | trn time 
##     4 | 0.25377 | 0.90093 | 0:01:19
## 
## Running Validation...
## 
## summary results
## epoch | val loss | val f1 | val time
##     4 | 0.29434 | 0.85140 | 0:00:05
## ('./model_save/vocab.txt', './model_save/special_tokens_map.json', './model_save/added_tokens.json')
## 
## ======== Epoch 5 / 5 ========
## Training...
##   Batch    40  of    503.
##   Batch    80  of    503.
##   Batch   120  of    503.
##   Batch   160  of    503.
##   Batch   200  of    503.
##   Batch   240  of    503.
##   Batch   280  of    503.
##   Batch   320  of    503.
##   Batch   360  of    503.
##   Batch   400  of    503.
##   Batch   440  of    503.
##   Batch   480  of    503.
## 
## summary results
## epoch | trn loss | trn f1 | trn time 
##     5 | 0.24774 | 0.90097 | 0:01:20
## 
## Running Validation...
## 
## summary results
## epoch | val loss | val f1 | val time
##     5 | 0.29090 | 0.84772 | 0:00:06
## ('./model_save/vocab.txt', './model_save/special_tokens_map.json', './model_save/added_tokens.json')
## 
## C:\Users\Andrew\Anaconda3\envs\my_ml\lib\site-packages\sklearn\metrics\classification.py:1437: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 in labels with no predicted samples.
##   'precision', 'predicted', average, warn_for)
## C:\Users\Andrew\Anaconda3\envs\my_ml\lib\site-packages\sklearn\metrics\classification.py:1437: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples.
##   'precision', 'predicted', average, warn_for)

After training, we organize the results nicely in pandas.
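
The chunk is hidden; a minimal sketch that joins the two stats containers into the table shown below (the '...' in the output is just pandas truncating the display width):

# side-by-side training and validation metrics, one row per epoch
df_stats = pd.concat([pd.DataFrame(training_stats),
                      pd.DataFrame(valid_stats)], axis=1).round(3)
df_stats.index = range(1, epochs + 1)
df_stats.index.name = 'Epoch'
print(df_stats)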

##        Train Loss  Train F1  Val Loss  ...  Val precision  Val recall  Val F1
## Epoch                                  ...                                   
## 1           0.405     0.843     0.344  ...          0.852       0.840   0.821
## 2           0.301     0.881     0.302  ...          0.867       0.851   0.832
## 3           0.272     0.890     0.319  ...          0.862       0.854   0.839
## 4           0.254     0.901     0.294  ...          0.876       0.864   0.851
## 5           0.248     0.901     0.291  ...          0.864       0.863   0.848
## 
## [5 rows x 7 columns]
Training Results

And lastly, we run our final test:
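
The test chunk is hidden; a minimal sketch that reloads the best checkpoint and evaluates it with testing(), assuming the test_dataloader from the earlier sketch. The '<All keys matched successfully>' line below is what load_state_dict prints; pytorch_model.bin is the default file name save_pretrained writes, an assumption on my part.

# reload the best weights saved during training
print(model.load_state_dict(torch.load('./model_save/pytorch_model.bin')))

# run the held-out test set through the model
test_stats = []
testing(model, test_dataloader)
print(pd.DataFrame(test_stats).round(3))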

## <All keys matched successfully>
## 
## Running Testing...
## 
## C:\Users\Andrew\Anaconda3\envs\my_ml\lib\site-packages\sklearn\metrics\classification.py:1437: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 in labels with no predicted samples.
##   'precision', 'predicted', average, warn_for)
## C:\Users\Andrew\Anaconda3\envs\my_ml\lib\site-packages\sklearn\metrics\classification.py:1437: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples.
##   'precision', 'predicted', average, warn_for)
##    Test Loss  Test Accur.  Test precision  Test recall  Test F1
## 0      0.315        0.865            0.88        0.865    0.857

3 DistilBERT: Hyperband and ASHA Hyperparameter Search with Optuna

The code below shows how we can use state-of-the-art pruning and search algorithms to improve our model’s performance through hyperparameter selection.

training_stats = []
valid_stats = []
epochs = 5

# create gradient scaler for mixed precision
scaler = GradScaler()
def objective(trial):

    # sample hyperparameters for this trial
    dropout = trial.suggest_float('dropout', low=0.1, high=0.4, step=0.05)
    learning_rate = trial.suggest_loguniform('lr', 1e-7, 1e-4)
    weight_decay = trial.suggest_float('weight_decay', low=0.5, high=1, step=0.05)

    # pass the sampled dropout to from_pretrained so the Dropout modules are
    # actually built with it (editing model.config after construction would
    # not change layers that already exist)
    model = DistilBertForSequenceClassification.from_pretrained(
        "distilbert-base-cased",
        num_labels=2,
        dropout=dropout)

    # instantiate model - attach to GPU
    model.cuda()

    # gen data loaders
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=24,
                                  sampler=train_sampler,
                                  shuffle=False)

    valid_dataloader = DataLoader(val_dataset,
                                  batch_size=24,
                                  shuffle=True)

    # optimizer
    optimizer = AdamW(model.parameters(),
                      lr=learning_rate,
                      weight_decay=weight_decay)

    # set LR scheduler
    total_steps = len(train_dataloader) * epochs
    global scheduler
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=0,
                                                num_training_steps=total_steps)

    global epoch
    for epoch in range(epochs):
        # containers for metrics
        train_total_loss = 0
        total_train_f1 = 0

        # put model into training mode
        model.train()

        # for each batch of training data...
        for step, batch in enumerate(train_dataloader):
            b_input_ids = batch[0].cuda()
            b_input_mask = batch[1].cuda()
            b_labels = batch[2].cuda().long()

            optimizer.zero_grad()

            with autocast():
                loss, logits = model(input_ids=b_input_ids,
                                     attention_mask=b_input_mask,
                                     labels=b_labels)

                train_total_loss += loss.item()

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

        # validation
        model.eval()

        total_valid_loss = 0
        total_valid_f1 = 0

        # evaluate data for one epoch
        for batch in valid_dataloader:

            b_input_ids = batch[0].cuda()
            b_input_mask = batch[1].cuda()
            b_labels = batch[2].cuda().long()

            with torch.no_grad():
                loss, logits = model(input_ids=b_input_ids,
                                     attention_mask=b_input_mask,
                                     labels=b_labels)

            total_valid_loss += loss.item()

            logits = logits.detach().cpu().numpy()
            y_true = b_labels.detach().cpu().numpy()

            rounded_preds = np.argmax(logits, axis=1).flatten()

            total_valid_f1 += f1_score(y_true, rounded_preds,
                                       average='weighted',
                                       labels=np.unique(rounded_preds))

        global avg_val_f1
        avg_val_f1 = total_valid_f1 / len(valid_dataloader)

        # calculate the average loss over all of the batches.
        global avg_val_loss
        avg_val_loss = total_valid_loss / len(valid_dataloader)

        # report the intermediate value after each epoch so the pruner can
        # stop unpromising trials early
        trial.report(avg_val_loss, epoch)

        # Handle pruning based on the intermediate value.
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    return avg_val_loss


study = optuna.create_study(direction="minimize",
                            pruner=optuna.pruners.HyperbandPruner(min_resource=1,
                                                                  max_resource=5,
                                                                  reduction_factor=3,
                                                                  ))
study.optimize(objective, n_trials=50)


pruned_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.PRUNED]
complete_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]

print("Study statistics: ")
print("  Number of finished trials: ", len(study.trials))
print("  Number of pruned trials: ", len(pruned_trials))
print("  Number of complete trials: ", len(complete_trials))

print("Best trial:")
trial = study.best_trial

print("  Value: ", trial.value)

print("  Params: ")
for key, value in trial.params.items():
    print("    {}: {}".format(key, value))

4 Sources