Classification: Hierarchical Attention Networks

Andrew Fogarty

7/19/2020

# load python
library(reticulate)
use_condaenv("my_ml")
# load packages
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import TensorDataset, random_split, DataLoader, RandomSampler
from transformers import BertModel, BertTokenizer
from transformers import AdamW, get_linear_schedule_with_warmup
import time, datetime, random, re
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import matplotlib.pyplot as plt
import seaborn as sns
import optuna
from optuna.pruners import SuccessiveHalvingPruner
from optuna.samplers import TPESampler
from collections import Counter
import nltk
from keras.preprocessing.text import Tokenizer, text_to_word_sequence
## Using TensorFlow backend.
SEED = 15
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
## <torch._C.Generator object at 0x000000002175E090>
torch.backends.cudnn.deterministic = True
torch.cuda.amp.autocast(enabled=True)
## <torch.cuda.amp.autocast_mode.autocast object at 0x00000000360F0B48>


# tell pytorch to use cuda
device = torch.device("cuda")

1 Introduction

Hierarchical Attention Networks (HANs), as the name suggests, have a hierarchical structure that reflects the hierarchical nature of documents. Two levels of attention mechanisms, applied at the word level and at the sentence level, give the model the ability to weight more and less important content differently when evaluating documents. In other words, the model:

  1. Recognizes that not every word in a sentence, nor every sentence in a document, is equally important, and
  2. Recognizes that words are context dependent, so their surrounding context must be taken into account.

1.1 Architecture

HAN mirrors the hierarchical structure of documents through its representations of sentences and documents. Sentence representations are built by first encoding the words of a sentence and then aggregating those word encodings into a sentence vector. Document representations are then built in the same way, taking the sentence vectors of every sentence in the document as input. At a general level, the model consists of an encoder, which returns relevant context annotations, and an attention mechanism, which yields importance weights used to combine those annotations into a single vector.

1.1.1 Word Encoder

Recurrent Neural Networks maintain memory (hidden states) that lets them predict the next word given the previous words. A Gated Recurrent Unit (GRU) uses hidden states as information-transferring memory cells, with two gates that decide whether to keep (and update) or to forget information. The word-level GRU extracts relevant context, or annotations, for every word.
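Concretely, following the original paper, the forward and backward GRU annotations for word \(t\) of sentence \(i\) are concatenated into \(h_{it}\), and the word-level attention then computes

\[
u_{it} = \tanh(W_w h_{it} + b_w), \qquad
\alpha_{it} = \frac{\exp(u_{it}^\top u_w)}{\sum_{t}\exp(u_{it}^\top u_w)}, \qquad
s_i = \sum_{t} \alpha_{it} h_{it},
\]

where \(u_w\) is a learned word-level context vector (the word_context_weights parameter in the code further below) and \(s_i\) is the resulting sentence vector.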

1.1.2 Sentence Encoder

Similar to the word encoder level, contexts of sentences are summarized with a bi-directional GRU that goes through the document forwards and backwards.
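The same attention form is applied one level up: the bi-directional GRU annotations \(h_i\) for each sentence are weighted by a learned sentence-level context vector \(u_s\) and summed into a document vector \(v\), which is finally passed to the classifier:

\[
u_i = \tanh(W_s h_i + b_s), \qquad
\alpha_i = \frac{\exp(u_i^\top u_s)}{\sum_{i}\exp(u_i^\top u_s)}, \qquad
v = \sum_{i} \alpha_i h_i.
\]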

2 Modeling

In this guide, we recreate, in PyTorch, the model from the paper Hierarchical Attention Networks for Document Classification, and walk through how to create and structure the data appropriately to feed a HAN model.

2.1 Preparing the Data

# prepare and load data
def prepare_df(pkl_location):
    # read pkl as pandas
    df = pd.read_pickle(pkl_location)
    # just keep us/kabul labels
    df = df.loc[(df['target'] == 'US') | (df['target'] == 'Kabul')]
    # mask DV to recode
    us = df['target'] == 'US'
    kabul = df['target'] == 'Kabul'
    # apply mask
    df.loc[us, 'target'] = 1
    df.loc[kabul, 'target'] = 0
    # reset index
    df = df.reset_index(drop=True)
    return df


df = prepare_df('C:\\Users\\Andrew\\Desktop\\df.pkl')


# prepare data
def clean_df(df):
    # strip dash but keep a space
    df['body'] = df['body'].str.replace('-', ' ')
    # lower case the data
    df['body'] = df['body'].apply(lambda x: x.lower())
    # remove excess spaces near punctuation
    df['body'] = df['body'].apply(lambda x: re.sub(r'\s([?.!"](?:\s|$))', r'\1', x))
    # generate a word count
    df['word_count'] = df['body'].apply(lambda x: len(x.split()))
    # remove excess white spaces
    df['body'] = df['body'].apply(lambda x: " ".join(x.split()))
    return df


df = clean_df(df)


# let's remove rare words
def remove_rare_words(df):
    # get counts of each word -- necessary for vocab
    counts = Counter(" ".join(df['body'].values.tolist()).split(" "))
    # remove low counts -- keep those above 2
    counts = {key: value for key, value in counts.items() if value > 2}

    # remove rare words from the corpus
    def remove_rare(text):
        return ' '.join(word for word in text.split() if word in counts)

    # apply the function
    df['body'] = df['body'].apply(remove_rare)
    return df


df = remove_rare_words(df)


# what's the length of the vocab?
counts = Counter(" ".join(df['body'].values.tolist()).split(" "))
vocab = sorted(counts, key=counts.get, reverse=True)
print(len(vocab))
## 23522

2.2 GloVe Embeddings

Now we are ready to load the GloVe embeddings. The code below:

  1. Loads the GloVe vectors,
  2. Creates one nearly empty vector for the padding token,
  3. Creates a randomly initialized vector in \([-0.14, 0.14]\) (mimicking the variance of 200d GloVe) for unknown tokens, and
  4. Extends the GloVe dimensions from 200 to 201 to account for this variation.
## array([ 1.2289e-01,  5.8037e-01, -6.9635e-02, -5.0288e-01,  1.0503e-01,
##         3.9945e-01, -3.8635e-01, -8.4279e-02,  1.2219e-01,  8.0312e-02,
##         3.2337e-01,  4.7579e-01, -3.8375e-02, -7.0900e-03,  4.1524e-01,
##         3.2121e-01, -2.1185e-01,  3.6144e-01, -5.5623e-02, -3.0512e-02,
##         4.2854e-01,  2.8547e+00, -1.4623e-01, -1.7557e-01,  3.1197e-01,
##        -1.3118e-01,  3.3298e-02,  1.3093e-01,  8.9889e-02, -1.2417e-01,
##         2.3396e-03, -6.8954e-02, -1.0754e-01, -1.1551e-01, -3.1052e-01,
##        -1.2097e-01, -4.6691e-01, -8.3600e-02, -3.7664e-02, -7.1779e-02,
##        -1.1899e-01, -2.0381e-01, -1.2424e-01,  4.6339e-01, -1.9828e-01,
##        -8.0365e-03,  5.3718e-01,  3.1739e-02,  3.4331e-01,  7.9704e-03,
##         4.8744e-03,  3.0592e-02, -1.7615e-01,  8.2342e-01, -1.3793e-01,
##        -1.0075e-01, -1.2686e-01,  7.4735e-02, -8.8719e-02, -4.2719e-02,
##         7.6624e-02,  8.9263e-02,  6.4445e-02, -3.1958e-02,  1.5254e-01,
##        -1.0384e-01,  7.6604e-02,  3.4099e-01,  2.4331e-01, -1.0452e-01,
##         4.0714e-01, -1.8260e-01, -4.0667e-02,  5.0878e-01,  8.0760e-02,
##         2.2759e-01, -4.2162e-02, -1.8171e-01, -9.5025e-02,  3.0334e-02,
##         8.8202e-02, -3.9843e-06, -3.9877e-03,  1.5724e-01,  3.3167e-01,
##         8.4710e-02, -2.5919e-01, -4.1384e-01,  2.9920e-01, -5.4255e-01,
##         3.2129e-02,  1.0030e-01,  4.4202e-01,  4.4682e-02, -9.0681e-02,
##        -1.0481e-01, -1.1860e-01, -3.1972e-01, -2.0790e-01, -4.0203e-02,
##        -2.2988e-02,  2.2824e-01,  5.5238e-03,  1.2568e-01, -1.4640e-01,
##        -1.4904e-01, -1.1561e-01,  1.0517e+00, -1.9498e-01,  8.3958e-02,
##         4.4812e-02, -1.2965e-01, -9.3468e-02,  2.1237e-01, -8.8332e-02,
##        -1.8680e-01,  2.6521e-01,  1.3097e-01, -4.8102e-02, -2.2467e-01,
##         2.8412e-01,  3.4907e-01,  3.4833e-01,  1.7877e-02,  3.0504e-01,
##        -8.3453e-01,  4.8856e-02, -1.9330e-01,  2.0764e-01, -4.9701e-01,
##        -1.8747e-01, -7.6801e-02,  1.5558e-01, -4.6844e-01,  4.0944e-01,
##         2.1386e-01,  8.2392e-02, -2.6491e-01, -2.1224e-01, -1.3293e-01,
##         1.4738e-01, -1.4192e-01,  1.8994e-01, -1.5587e-01,  1.0738e+00,
##         4.0789e-01, -2.7452e-01, -1.8431e-01,  6.8679e-04, -8.7115e-02,
##         1.9672e-01,  4.0918e-01, -3.5462e-01, -6.3260e-02,  4.4920e-01,
##        -6.0568e-02, -4.1636e-02,  2.0531e-01,  1.7025e-02, -5.8448e-01,
##         7.5441e-02,  8.2116e-02, -4.6008e-01,  1.2393e-02, -2.5310e-02,
##         1.4177e-01, -9.2192e-02,  3.4505e-01, -5.2136e-01,  5.7304e-01,
##         1.1973e-02,  3.3196e-02,  2.9672e-01, -2.7899e-01,  1.9979e-01,
##         2.5666e-01,  8.2079e-02, -7.8436e-02,  9.3719e-02,  2.4202e-01,
##         1.3495e+00, -3.0434e-01, -3.0936e-01,  4.2047e-01, -7.9068e-02,
##        -1.4819e-01, -8.9404e-02,  6.6800e-02,  2.2405e-01,  2.7226e-01,
##        -3.5236e-02,  1.7688e-01, -5.3600e-02,  7.0031e-03, -3.3006e-02,
##        -8.0021e-02, -2.4451e-01, -3.9174e-02, -1.6236e-01, -9.6652e-02],
##       dtype=float32)
## 2
## torch.Size([400002, 201])
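The chunk that builds this embedding matrix is not rendered above; a minimal sketch of the step described in the list, assuming a local copy of glove.6B.200d.txt (400,000 tokens, 200 dimensions) and hypothetical variable names, might look like:

# load GloVe vectors and extend them with padding/unknown rows and an extra dimension
def load_glove(glove_path='glove.6B.200d.txt'):
    glove_words, glove_vectors = [], []
    with open(glove_path, encoding='utf-8') as f:
        for line in f:
            parts = line.rstrip().split(' ')
            glove_words.append(parts[0])
            glove_vectors.append(np.asarray(parts[1:], dtype=np.float32))
    vectors = np.vstack(glove_vectors)                      # (400000, 200)
    # padding token: a (nearly) all-zero vector
    pad_vector = np.zeros((1, vectors.shape[1]), dtype=np.float32)
    # unknown token: uniform in [-0.14, 0.14], mimicking the spread of 200d GloVe
    unk_vector = np.random.uniform(-0.14, 0.14, (1, vectors.shape[1])).astype(np.float32)
    vectors = np.vstack([vectors, pad_vector, unk_vector])  # (400002, 200)
    # extend 200 -> 201 dimensions with an extra column
    vectors = np.hstack([vectors, np.zeros((vectors.shape[0], 1), dtype=np.float32)])
    word2idx = {word: idx for idx, word in enumerate(glove_words)}
    return word2idx, torch.from_numpy(vectors)


# word2idx, pre_embed = load_glove()
# pre_embed.shape  # torch.Size([400002, 201])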

2.3 Prepare to Create 3D Data

## Average number of words in each sentence:  29
## Average number of sentences in each document:  10
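These averages guide the padding and truncation lengths used in the next step. A minimal sketch of how they can be computed, assuming NLTK's punkt sentence tokenizer and the cleaned df from above:

# estimate the average words per sentence and sentences per document
nltk.download('punkt', quiet=True)

sent_lengths, doc_lengths = [], []
for doc in df['body']:
    sentences = nltk.sent_tokenize(doc)
    doc_lengths.append(len(sentences))
    sent_lengths.extend(len(sentence.split()) for sentence in sentences)

print('Average number of words in each sentence: ', int(np.mean(sent_lengths)))
print('Average number of sentences in each document: ', int(np.mean(doc_lengths)))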

2.4 Create 3D Data Set

## (10045, 15, 35)
## array([    19,   9945,    249,      3, 400001,    547,    118, 400001,
##           959,  16119,      5, 400001,    104,    112,     33,     51,
##          1315,     53,    303,      4,   7255, 400001,    284,    226,
##        400001,   3485,      3,    400,     21, 400001,    104,    112,
##             5,     47,    205])
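The 3D data set stacks the corpus as (documents, sentences, words), padding or truncating each document to 15 sentences of 35 words. A minimal sketch, assuming the word2idx mapping from the GloVe sketch and hypothetical pad/unknown indices of 400000 and 400001 (the two extra embedding rows):

MAX_SENTS, MAX_WORDS = 15, 35
PAD_IDX, UNK_IDX = 400000, 400001  # assumed indices of the padding/unknown rows

def doc_to_3d(doc):
    # sentence-tokenize, word-tokenize, and map every word to its embedding index
    encoded = np.full((MAX_SENTS, MAX_WORDS), PAD_IDX, dtype=np.int64)
    for i, sentence in enumerate(nltk.sent_tokenize(doc)[:MAX_SENTS]):
        words = text_to_word_sequence(sentence)[:MAX_WORDS]
        encoded[i, :len(words)] = [word2idx.get(word, UNK_IDX) for word in words]
    return encoded

features = np.stack([doc_to_3d(doc) for doc in df['body']])
print(features.shape)  # (documents, sentences, words)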

2.5 Create Torch Tensor Data Sets

With the corpus tokenized, we now prepare our data for analysis in PyTorch. The code below creates a TensorDataset comprised of our features and our labels, then splits the data into train, validation, and test sets.
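A minimal sketch of that step, assuming the features array from the previous section and an 80/10/10 split (the exact proportions are an assumption):

# wrap the features and labels in a TensorDataset
labels = torch.tensor(df['target'].values.astype(np.int64))
dataset = TensorDataset(torch.from_numpy(features), labels)

# split into train, validation, and test sets
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(dataset,
                                                        [train_size, val_size, test_size])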

2.6 Data Loaders and Helper Functions

Since the corpus is imbalanced, we produce a weighted sampler to help balance the class distribution as data is fed through the data loaders.

# helper function to count target distribution inside tensor data sets
def target_count(tensor_dataset):
    # set empty count containers
    count0 = 0
    count1 = 0
    # set total container to turn into torch tensor
    total = []
    # for every item in the tensor data set
    for i in tensor_dataset:
        # if the target is equal to 0
        if i[1].item() == 0:
            count0 += 1
        # if the target is equal to 1
        elif i[1].item() == 1:
            count1 += 1
    total.append(count0)
    total.append(count1)
    return torch.tensor(total)


# prepare weighted sampling for imbalanced classification
def create_sampler(target_tensor, tensor_dataset):
    # class distribution [count_class0, count_class1], precomputed by target_count()
    class_sample_count = target_tensor
    # weight
    weight = 1. / class_sample_count.float()
    # produce weights for each observation in the data set
    samples_weight = torch.tensor([weight[t[1]] for t in tensor_dataset])
    # prepare sampler
    sampler = torch.utils.data.WeightedRandomSampler(weights=samples_weight,
                                                     num_samples=len(samples_weight),
                                                     replacement=True)
    return sampler


# create samplers for each data set
train_sampler = create_sampler(target_count(train_dataset), train_dataset)


# prepare data loaders
def data_loading(batch_size, data_set, **kwargs):
    # build the DataLoader; drop the last batch if it is uneven and forward
    # any extra arguments (e.g., sampler, shuffle) on to DataLoader
    temp_dataloader = DataLoader(data_set,
                                 batch_size=batch_size,
                                 drop_last=True,
                                 **kwargs)
    return temp_dataloader


# create DataLoaders with samplers
train_dataloader = data_loading(batch_size=80,
                                data_set=train_dataset,
                                sampler=train_sampler,
                                shuffle=False)

valid_dataloader = data_loading(batch_size=80,
                                data_set=val_dataset,
                                shuffle=True)

test_dataloader = data_loading(batch_size=80,
                               data_set=test_dataset,
                               shuffle=True)
                               
# time function
def format_time(elapsed):
    '''
    Takes a time in seconds and returns a string hh:mm:ss
    '''
    # round to the nearest second.
    elapsed_rounded = int(round((elapsed)))
    # format as hh:mm:ss
    return str(datetime.timedelta(seconds=elapsed_rounded))

2.7 Create Hierarchical Attention Network

2.7.2 HAN: Create Word Level RNN

class WordLevelRNN(nn.Module):

    def __init__(self, config):
        super().__init__()
        pre_embed = config.pre_embed  # embeddings
        word_num_hidden = config.word_num_hidden
        words_num = config.vocab_size
        words_dim = config.words_dim
        self.mode = config.mode
        if self.mode == 'rand':
            rand_embed_init = torch.Tensor(words_num, words_dim).uniform_(-0.25, 0.25)
            self.embed = nn.Embedding.from_pretrained(rand_embed_init, freeze=False)
        elif self.mode == 'static':
            self.static_embed = nn.Embedding.from_pretrained(pre_embed, freeze=True)
        elif self.mode == 'non-static':
            self.non_static_embed = nn.Embedding.from_pretrained(pre_embed, freeze=False)
        else:
            raise ValueError("Unsupported mode")
        self.word_context_weights = nn.Parameter(torch.rand(2 * word_num_hidden, 1))
        self.GRU = nn.GRU(words_dim, word_num_hidden, bidirectional=True)
        self.linear = nn.Linear(2 * word_num_hidden, 2 * word_num_hidden, bias=True)
        self.word_context_weights.data.uniform_(-0.25, 0.25)
        self.soft_word = nn.Softmax(dim=1)

    def forward(self, x):
        # x expected to be of dimensions--> (num_words, batch_size)
        if self.mode == 'rand':
            x = self.embed(x)
        elif self.mode == 'static':
            x = self.static_embed(x)
        elif self.mode == 'non-static':
            x = self.non_static_embed(x)
        else:
            raise ValueError("Unsupported mode")
        # bi-directional GRU annotations: (num_words, batch_size, 2 * word_num_hidden)
        h, _ = self.GRU(x)
        # word-level attention: u_it = tanh(W_w h_it + b_w)
        x = torch.tanh(self.linear(h))
        # similarity with the word context vector u_w
        x = torch.matmul(x, self.word_context_weights)
        x = x.squeeze(dim=2)
        # attention weights alpha_it via a softmax over the words
        x = self.soft_word(x.transpose(1, 0))
        # weighted sum of the annotations -> sentence vector s_i
        x = torch.mul(h.permute(2, 0, 1), x.transpose(1, 0))
        x = torch.sum(x, dim=1).transpose(1, 0).unsqueeze(0)
        return x
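The companion sentence-level module and the top-level HAN wrapper are not rendered in this guide; a sketch that mirrors WordLevelRNN and matches the attribute names used later (config.sent_num_hidden; config.num_classes is an assumed name) might look like:

# sketch of the sentence-level attention module
class SentLevelRNN(nn.Module):

    def __init__(self, config):
        super().__init__()
        sent_num_hidden = config.sent_num_hidden
        word_num_hidden = config.word_num_hidden
        num_classes = config.num_classes  # assumed attribute; 2 in this guide
        self.sentence_context_weights = nn.Parameter(torch.rand(2 * sent_num_hidden, 1))
        self.sentence_context_weights.data.uniform_(-0.25, 0.25)
        self.sentence_gru = nn.GRU(2 * word_num_hidden, sent_num_hidden, bidirectional=True)
        self.sentence_linear = nn.Linear(2 * sent_num_hidden, 2 * sent_num_hidden, bias=True)
        self.fc = nn.Linear(2 * sent_num_hidden, num_classes)
        self.soft_sent = nn.Softmax(dim=1)

    def forward(self, x):
        # x expected to be (num_sentences, batch_size, 2 * word_num_hidden)
        sentence_h, _ = self.sentence_gru(x)
        x = torch.tanh(self.sentence_linear(sentence_h))
        x = torch.matmul(x, self.sentence_context_weights)
        x = x.squeeze(dim=2)
        x = self.soft_sent(x.transpose(1, 0))
        x = torch.mul(sentence_h.permute(2, 0, 1), x.transpose(1, 0))
        x = torch.sum(x, dim=1).transpose(1, 0)
        return self.fc(x)


# sketch of the wrapper that runs the word-level RNN over each sentence
class HAN(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.word_attention_rnn = WordLevelRNN(config)
        self.sentence_attention_rnn = SentLevelRNN(config)

    def forward(self, x):
        # x expected to be (batch_size, num_sentences, num_words)
        x = x.permute(1, 2, 0)  # -> (num_sentences, num_words, batch_size)
        word_attentions = None
        for i in range(x.size(0)):
            word_attn = self.word_attention_rnn(x[i, :, :])
            if word_attentions is None:
                word_attentions = word_attn
            else:
                word_attentions = torch.cat((word_attentions, word_attn), dim=0)
        return self.sentence_attention_rnn(word_attentions)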

2.8 Create Training Functions

Now we prepare functions to train, validate, and test our model.

def train(model, dataloader, optimizer, criterion):

    # capture time
    total_t0 = time.time()

    # Perform one full pass over the training set.
    print("")
    print('======== Epoch {:} / {:} ========'.format(epoch + 1, epochs))
    print('Training...')

    # reset total loss for epoch
    train_total_loss = 0
    total_train_f1 = 0

    # put model into training mode
    model.train()

    # for each batch of training data...
    for step, batch in enumerate(dataloader):

        # progress update every 40 batches.
        if step % 40 == 0 and not step == 0:

            # Report progress.
            print('  Batch {:>5,}  of  {:>5,}.'.format(step, len(dataloader)))

        # Unpack this training batch from our dataloader:
        #
        # As we unpack the batch, we'll also copy each tensor to the GPU
        #
        # `batch` contains two pytorch tensors:
        #   [0]: input ids
        #   [1]: labels
        b_input_ids = batch[0].cuda().long()
        b_labels = batch[1].cuda().long()

        # clear previously calculated gradients
        optimizer.zero_grad()

        with autocast():
            # forward propagation (evaluate model on training batch)
            logits = model(b_input_ids)

        # calculate cross entropy loss
        loss = criterion(logits.view(-1, 2), b_labels.view(-1))

        # sum the training loss over all batches for average loss at end
        # loss is a tensor containing a single value
        train_total_loss += loss.item()

        # Scales loss.  Calls backward() on scaled loss to create scaled gradients.
        # Backward passes under autocast are not recommended.
        # Backward ops run in the same dtype autocast chose for corresponding forward ops.
        scaler.scale(loss).backward()

        # scaler.step() first unscales the gradients of the optimizer's assigned params.
        # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
        # otherwise, optimizer.step() is skipped.
        scaler.step(optimizer)

        # Updates the scale for next iteration.
        scaler.update()

        # update the learning rate
        scheduler.step()

        # get preds
        _, predicted = torch.max(logits, 1)

        # move logits and labels to CPU
        predicted = predicted.detach().cpu().numpy()
        y_true = b_labels.detach().cpu().numpy()

        # calculate f1
        total_train_f1 += f1_score(y_true, predicted,
                                   average='weighted',
                                   labels=np.unique(predicted))

    # calculate the average loss over all of the batches
    avg_train_loss = train_total_loss / len(dataloader)

    # calculate the average f1 over all of the batches
    avg_train_f1 = total_train_f1 / len(dataloader)

    # Record all statistics from this epoch.
    training_stats.append(
        {
            'Train Loss': avg_train_loss,
            'Train F1': avg_train_f1
        }
    )

    # training time end
    training_time = format_time(time.time() - total_t0)

    # print result summaries
    print("")
    print("summary results")
    print("epoch | trn loss | trn f1 | trn time ")
    print(f"{epoch+1:5d} | {avg_train_loss:.5f} | {avg_train_f1:.5f} | {training_time:}")

    torch.cuda.empty_cache()

    return None


def validating(model, dataloader, criterion):

    # capture validation time
    total_t0 = time.time()

    # After the completion of each training epoch, measure our performance on
    # our validation set.
    print("")
    print("Running Validation...")

    # put the model in evaluation mode
    model.eval()

    # track variables
    total_valid_accuracy = 0
    total_valid_loss = 0
    total_valid_f1 = 0
    total_valid_recall = 0
    total_valid_precision = 0

    # evaluate data for one epoch
    for batch in dataloader:

        # unpack batch from dataloader
        b_input_ids = batch[0].cuda().long()
        b_labels = batch[1].cuda().long()

        # tell pytorch not to bother calculating gradients
        # as they are only necessary for training
        with torch.no_grad():

            # forward propagation (evaluate model on training batch)
            logits = model(b_input_ids)

            # calculate cross entropy loss
            loss = criterion(logits.view(-1, 2), b_labels.view(-1))

            # calculate preds
            _, predicted = torch.max(logits, 1)

        # accumulate validation loss
        total_valid_loss += loss.item()

        # move logits and labels to CPU
        predicted = predicted.detach().cpu().numpy()
        y_true = b_labels.detach().cpu().numpy()

        # calculate f1
        total_valid_f1 += f1_score(y_true, predicted,
                                   average='weighted',
                                   labels=np.unique(predicted))

        # calculate accuracy
        total_valid_accuracy += accuracy_score(y_true, predicted)

        # calculate precision
        total_valid_precision += precision_score(y_true, predicted,
                                                 average='weighted',
                                                 labels=np.unique(predicted))

        # calculate recall
        total_valid_recall += recall_score(y_true, predicted,
                                           average='weighted',
                                           labels=np.unique(predicted))

    # report final accuracy of validation run
    avg_accuracy = total_valid_accuracy / len(dataloader)

    # report final f1 of validation run
    global avg_val_f1
    avg_val_f1 = total_valid_f1 / len(dataloader)

    # report final precision of validation run
    avg_precision = total_valid_precision / len(dataloader)

    # report final recall of validation run
    avg_recall = total_valid_recall / len(dataloader)

    # calculate the average loss over all of the batches.
    avg_val_loss = total_valid_loss / len(dataloader)

    # Record all statistics from this epoch.
    valid_stats.append(
        {
            'Val Loss': avg_val_loss,
            'Val Accur.': avg_accuracy,
            'Val precision': avg_precision,
            'Val recall': avg_recall,
            'Val F1': avg_val_f1
        }
    )

    # capture end validation time
    training_time = format_time(time.time() - total_t0)

    # print result summaries
    print("")
    print("summary results")
    print("epoch | val loss | val f1 | val time")
    print(f"{epoch+1:5d} | {avg_val_loss:.5f} | {avg_val_f1:.5f} | {training_time:}")

    return None


def testing(model, dataloader, criterion):

    print("")
    print("Running Testing...")

    # put the model in evaluation mode
    model.eval()

    # track variables
    total_test_accuracy = 0
    total_test_loss = 0
    total_test_f1 = 0
    total_test_recall = 0
    total_test_precision = 0

    # evaluate data for one epoch
    for step, batch in enumerate(dataloader):
        # progress update every 40 batches.
        if step % 40 == 0 and not step == 0:

            # Report progress.
            print('  Batch {:>5,}  of  {:>5,}.'.format(step, len(dataloader)))

        # unpack batch from dataloader
        b_input_ids = batch[0].cuda().long()
        b_labels = batch[1].cuda().long()

        # tell pytorch not to bother calculating gradients
        # only necessary for training
        with torch.no_grad():

            # forward propagation (evaluate model on training batch)
            logits = model(b_input_ids)

            # calculate cross entropy loss
            loss = criterion(logits.view(-1, 2), b_labels.view(-1))

            # calculate preds
            _, predicted = torch.max(logits, 1)

            # accumulate test loss
            total_test_loss += loss.item()

        # move logits and labels to CPU
        predicted = predicted.detach().cpu().numpy()
        y_true = b_labels.detach().cpu().numpy()

        # calculate f1
        total_test_f1 += f1_score(y_true, predicted,
                                  average='weighted',
                                  labels=np.unique(predicted))

        # calculate accuracy
        total_test_accuracy += accuracy_score(y_true, predicted)

        # calculate precision
        total_test_precision += precision_score(y_true, predicted,
                                                average='weighted',
                                                labels=np.unique(predicted))

        # calculate recall
        total_test_recall += recall_score(y_true, predicted,
                                          average='weighted',
                                          labels=np.unique(predicted))

    # report final accuracy of test run
    avg_accuracy = total_test_accuracy / len(dataloader)

    # report final f1 of test run
    avg_test_f1 = total_test_f1 / len(dataloader)

    # report final precision of test run
    avg_precision = total_test_precision / len(dataloader)

    # report final recall of test run
    avg_recall = total_test_recall / len(dataloader)

    # calculate the average loss over all of the batches.
    avg_test_loss = total_test_loss / len(dataloader)

    # Record all statistics from this epoch.
    test_stats.append(
        {
            'Test Loss': avg_test_loss,
            'Test Accur.': avg_accuracy,
            'Test precision': avg_precision,
            'Test recall': avg_recall,
            'Test F1': avg_test_f1
        }
    )
    return None

2.9 Create Models and Training Necessities

In order to use our HAN, we need to specify a config class that sets a number of hyperparameters that the class is expecting. The rest of the objects are our usual optimizer, scheduler, epochs, and loss function.
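The configuration and setup chunk is not rendered here; a minimal sketch, assuming attribute names that match WordLevelRNN and the Optuna section below (num_classes and the specific hyperparameter values are assumptions):

# sketch of the config object the HAN modules expect
class HANconfig:
    def __init__(self):
        self.pre_embed = pre_embed            # (400002, 201) embedding matrix
        self.vocab_size = pre_embed.shape[0]
        self.words_dim = pre_embed.shape[1]
        self.mode = 'non-static'              # fine-tune the pretrained embeddings
        self.word_num_hidden = 50
        self.sent_num_hidden = 50
        self.num_classes = 2                  # assumed attribute name


# instantiate the model and the usual training objects (placeholder hyperparameters)
configHAN = HANconfig()
model = HAN(configHAN).cuda()

epochs = 5
optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=0.5)
criterion = nn.CrossEntropyLoss()

# linear learning rate schedule over all training steps
total_steps = len(train_dataloader) * epochs
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=0,
                                            num_training_steps=total_steps)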

2.10 Train

Finally, we are ready to train. Two containers are created to store the results of each training and validation epoch.
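A minimal sketch of the loop that drives the output below, assuming the model, optimizer, and criterion from the previous section:

# containers for epoch-level results
training_stats = []
valid_stats = []

# gradient scaler for mixed precision
scaler = GradScaler()

for epoch in range(epochs):
    # train, then validate after each epoch
    train(model, train_dataloader, optimizer, criterion)
    validating(model, valid_dataloader, criterion)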

## 
## ======== Epoch 1 / 5 ========
## Training...
##   Batch    40  of    100.
##   Batch    80  of    100.
## 
## summary results
## epoch | trn loss | trn f1 | trn time 
##     1 | 0.35020 | 0.84468 | 0:00:07
## 
## Running Validation...
## 
## summary results
## epoch | val loss | val f1 | val time
##     1 | 0.25396 | 0.89940 | 0:00:00
## 
## ======== Epoch 2 / 5 ========
## Training...
##   Batch    40  of    100.
##   Batch    80  of    100.
## 
## summary results
## epoch | trn loss | trn f1 | trn time 
##     2 | 0.28292 | 0.87081 | 0:00:07
## 
## Running Validation...
## 
## summary results
## epoch | val loss | val f1 | val time
##     2 | 0.24733 | 0.88715 | 0:00:00
## 
## ======== Epoch 3 / 5 ========
## Training...
##   Batch    40  of    100.
##   Batch    80  of    100.
## 
## summary results
## epoch | trn loss | trn f1 | trn time 
##     3 | 0.26662 | 0.88033 | 0:00:07
## 
## Running Validation...
## 
## summary results
## epoch | val loss | val f1 | val time
##     3 | 0.24396 | 0.89047 | 0:00:00
## 
## ======== Epoch 4 / 5 ========
## Training...
##   Batch    40  of    100.
##   Batch    80  of    100.
## 
## summary results
## epoch | trn loss | trn f1 | trn time 
##     4 | 0.25348 | 0.88681 | 0:00:07
## 
## Running Validation...
## 
## summary results
## epoch | val loss | val f1 | val time
##     4 | 0.24355 | 0.89599 | 0:00:00
## 
## ======== Epoch 5 / 5 ========
## Training...
##   Batch    40  of    100.
##   Batch    80  of    100.
## 
## summary results
## epoch | trn loss | trn f1 | trn time 
##     5 | 0.24332 | 0.89265 | 0:00:07
## 
## Running Validation...
## 
## summary results
## epoch | val loss | val f1 | val time
##     5 | 0.24402 | 0.89838 | 0:00:00
## 
## C:\Users\Andrew\Anaconda3\envs\my_ml\lib\site-packages\torch\optim\lr_scheduler.py:123: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`.  Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
##   "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)

2.11 Show Results

After training, we organize the results nicely in pandas.

##        Train Loss  Train F1  Val Loss  ...  Val precision  Val recall  Val F1
## Epoch                                  ...                                   
## 1           0.350     0.845     0.254  ...          0.905       0.898   0.899
## 2           0.283     0.871     0.247  ...          0.891       0.890   0.887
## 3           0.267     0.880     0.244  ...          0.893       0.893   0.890
## 4           0.253     0.887     0.244  ...          0.898       0.898   0.896
## 5           0.243     0.893     0.244  ...          0.901       0.900   0.898
## 
## [5 rows x 7 columns]
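A minimal sketch of how such a table can be assembled from the two stats containers:

# combine epoch-level training and validation statistics into one frame
df_stats = pd.concat([pd.DataFrame(training_stats), pd.DataFrame(valid_stats)], axis=1)
df_stats.index = range(1, len(df_stats) + 1)
df_stats.index.name = 'Epoch'
print(df_stats.round(3))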

2.12 Plot Results

Then we plot our results like so:

Training Results
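A plotting chunk along these lines, assuming the df_stats frame assembled above, would produce a figure like the one referenced:

# plot training and validation loss by epoch
plt.figure(figsize=(8, 5))
sns.lineplot(x=df_stats.index, y=df_stats['Train Loss'], label='Train Loss')
sns.lineplot(x=df_stats.index, y=df_stats['Val Loss'], label='Val Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Results')
plt.legend()
plt.show()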

2.13 Test the Model

And lastly, we run our final test:

## <All keys matched successfully>
## 
## Running Testing...
##    Test Loss  Test Accur.  Test precision  Test recall  Test F1
## 0       0.31        0.858           0.859        0.858    0.856
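A sketch of the final test step, assuming the best weights were saved earlier to a hypothetical file han_best.pt:

# container for test results
test_stats = []

# reload the saved weights (hypothetical path) and evaluate on the held-out test set
model.load_state_dict(torch.load('han_best.pt'))
testing(model, test_dataloader, criterion)

# display the test metrics
print(pd.DataFrame(test_stats).round(3))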

3 HAN: Hyperband and ASHA Hyperparameter Search with Optuna

# optuna -- tune hyperparameters
# create gradient scaler for mixed precision
scaler = GradScaler()

training_stats = []
valid_stats = []
epochs = 5
def objective(trial):

    # alter hyperparameters
    sent_num_hidden = trial.suggest_int('sentence_num_hidden', low=25, high=175, step=5)
    word_num_hidden = trial.suggest_int('word_num_hidden', low=25, high=175, step=5)
    learning_rate = trial.suggest_loguniform('lr', 1e-5, 1e-3)
    weight_decay = trial.suggest_float('weight_decay', low=0.5, high=1, step=0.05)
    configHAN = HANconfig()
    configHAN.sent_num_hidden = sent_num_hidden
    configHAN.word_num_hidden = word_num_hidden

    # data loaders
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=80,
                                  sampler=train_sampler,
                                  shuffle=False)

    valid_dataloader = DataLoader(val_dataset,
                                  batch_size=80,
                                  shuffle=True)

    # instantiate model
    model = HAN(configHAN).cuda()

    # set optimizer
    optimizer = AdamW(model.parameters(),
                      lr=learning_rate,
                      weight_decay=weight_decay
                    )

    criterion = nn.BCEWithLogitsLoss()

    # set LR scheduler
    total_steps = len(train_dataloader) * epochs
    global scheduler
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=0,
                                                num_training_steps=total_steps)

    global epoch
    for epoch in range(epochs):
        # set containers
        train_total_loss = 0
        total_train_f1 = 0

        # put model into training mode
        model.train()

        # for each batch of training data...
        for step, batch in enumerate(train_dataloader):
            b_input_ids = batch[0].cuda().long()
            b_labels = batch[1].cuda().type(torch.cuda.FloatTensor)

            optimizer.zero_grad()

            with autocast():
                logits = model(b_input_ids)
                loss = criterion(logits, b_labels)

            train_total_loss += loss.item()

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

        # validation
        model.eval()

        total_valid_loss = 0
        total_valid_f1 = 0

        # evaluate data for one epoch
        for batch in valid_dataloader:

            b_input_ids = batch[0].cuda().long()
            b_labels = batch[1].cuda().type(torch.cuda.FloatTensor)

            with torch.no_grad():
                logits = model(b_input_ids)
                loss = criterion(logits, b_labels)

            total_valid_loss += loss.item()

        # generate predictions
        rounded_preds = torch.round(torch.sigmoid(logits))

        # move logits and labels to CPU
        rounded_preds = rounded_preds.detach().cpu().numpy()
        y_true = b_labels.detach().cpu().numpy()

        # calculate f1
        total_valid_f1 += f1_score(y_true, rounded_preds,
                                   average='weighted',
                                   labels=np.unique(rounded_preds))

        avg_val_f1 = total_valid_f1 / len(valid_dataloader)

        avg_val_loss = total_valid_loss / len(valid_dataloader)

        # report the intermediate value each epoch so the pruner can act
        trial.report(avg_val_loss, epoch)

        # handle pruning based on the intermediate value
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    return avg_val_loss



study = optuna.create_study(direction="minimize",
                            pruner=optuna.pruners.HyperbandPruner(min_resource=1,
                                                                  max_resource=7,
                                                                  reduction_factor=3,
                                                                  ))
study.optimize(objective, n_trials=35)


pruned_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.PRUNED]
complete_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]

print("Study statistics: ")
print("  Number of finished trials: ", len(study.trials))
print("  Number of pruned trials: ", len(pruned_trials))
print("  Number of complete trials: ", len(complete_trials))

print("Best trial:")
trial = study.best_trial

print("  Value: ", trial.value)

print("  Params: ")
for key, value in trial.params.items():
    print("    {}: {}".format(key, value))

4 Sources