Entity Embeddings

Andrew Fogarty

9/23/2020

# load python
library(reticulate)
use_condaenv("my_ml")
# Transformers
import torch
import torch.nn as nn
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer
from transformers import get_linear_schedule_with_warmup, AdamW
from torch.utils.data import TensorDataset, random_split, DataLoader, RandomSampler, SequentialSampler
import time
import datetime
import random
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import matplotlib.pyplot as plt
import seaborn as sns
import re, os
from torch.cuda.amp import autocast, GradScaler
from sklearn.model_selection import train_test_split
from collections import Counter
from transformers import BertModel, BertTokenizer, BertForSequenceClassification, DistilBertModel
import string
from torch.utils.data import Dataset, Subset
from sklearn.preprocessing import LabelEncoder
from torchvision import transforms

SEED = 15
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
## <torch._C.Generator object at 0x000000001F72E070>
torch.backends.cudnn.deterministic = True
# note: this call only constructs an autocast context manager (echoed below);
# mixed precision is actually applied later via `with autocast():` in the training loop
torch.cuda.amp.autocast(enabled=True)
## <torch.cuda.amp.autocast_mode.autocast object at 0x00000000338B53C8>

# tell pytorch to use cuda
device = torch.device("cuda")

1 Introduction

In this guide, we will implement entity embeddings in two ways with PyTorch: (1) via nn.Embedding(), and (2) via transformers. We will also show how to load data more efficiently through a custom PyTorch data set class. This style of data management takes a little more work to set up, but it is precisely how we want to load data when dealing with: (1) big data, or (2) a memory-constrained environment. Frankly, it is how we should always load our data, but when first learning PyTorch there are of course other things to master before worrying about these efficiencies.

Entity embeddings refer to the idea of transforming categorical variables into continuous embeddings, avoiding one-hot encoding and sparse matrices. Embeddings, as we know, represent words as continuous vectors in a low-dimensional space that capture lexical and semantic properties of words. Embeddings can be obtained from the internal representations of neural network models of text or by low-rank approximation of co-occurrence statistics.
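
To make this concrete, here is a minimal sketch (toy values of my own, not from the data set used below) contrasting a one-hot encoding with an nn.Embedding() lookup for a single categorical column: the one-hot matrix grows with the number of categories, while the embedding stays a small, learnable dense matrix.

# toy categorical column with 4 distinct levels, already label-encoded to 0..3
cat_codes = torch.tensor([0, 2, 1, 3, 2])

# one-hot: a sparse 5 x 4 matrix whose width grows with the cardinality
one_hot = torch.nn.functional.one_hot(cat_codes, num_classes=4).float()

# entity embedding: a dense, learnable lookup of fixed width (here 2)
emb = nn.Embedding(num_embeddings=4, embedding_dim=2)
dense = emb(cat_codes)

print(one_hot.shape, dense.shape)  # torch.Size([5, 4]) torch.Size([5, 2])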

2 Entity Embeddings with nn.Embedding()

In this section, categorical embeddings built with nn.Embedding() are incorporated into our transformer model. In the next section, we will use the transformer itself to embed both the text and the categorical variables.

2.1 Torch Data Set Class

We begin by creating our custom data set. This class exemplifies lazy loading, which means that PyTorch only pulls the observations it needs and does so on the fly while other computations are ongoing. Since this is the ideal way of loading and preparing data, let’s go over a few things. Setting the embeddings aside for a moment, this class pulls one row of data from the specified CSV file at a time via the __getitem__() method. The class is a bit more involved, however, because it also uses nn.Embedding() to create continuous representations of the categorical variables in the data set. This happens over several steps. First, the categorical data is encoded from strings to integers. Second, each categorical column is scanned for its number of distinct categories and an embedding layer is created for it. Lastly, the encoded categorical values are passed through their embeddings and the results are stored in a dict for later use.
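
Before the class itself, here is a quick sketch of the sizing rule it uses: the embedding width for a column with x distinct categories is min(50, (x + 1) // 2). Assuming the four categorical columns used below are binary (which is consistent with the 768 + 4 = 772 classifier width later on), that gives four 1-dimensional embeddings.

# toy illustration of the embedding-size rule used in the class below
toy_cat_dims = [2, 2, 2, 2]
toy_emb_dims = [(x, min(50, (x + 1) // 2)) for x in toy_cat_dims]
print(toy_emb_dims)  # [(2, 1), (2, 1), (2, 1), (2, 1)] -> 4 total embedding dimensions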

# Create Dataset
class CSVDataset(Dataset):
    """Propaganda data set."""

    def __init__(self, csv_file, text_col, cat_cols, target, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            text_col (string): column containing the text for analysis.
            cat_cols (string): column(s) containing string categorical data.
            target (string): column containing the dependent variable.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        # initialize
        self.data_frame = pd.read_csv(csv_file)
        self.categorical_features = cat_cols
        self.text_features = text_col
        self.target = target
        self.transform = transform

        # encode categorical variables
        label_encoders = {}
        for cat_col in self.categorical_features:
            label_encoders[cat_col] = LabelEncoder()
            self.data_frame[cat_col] = label_encoders[cat_col].fit_transform(self.data_frame[cat_col])

        # encode outcome
        self.data_frame[target] = LabelEncoder().fit_transform(self.data_frame[target])

        # embedding info
        self.cat_dims = [int(self.data_frame[col].nunique()) for col in self.categorical_features]
        self.emb_dims = [(x, min(50, (x + 1) // 2)) for x in self.cat_dims]
        self.all_embeddings = nn.ModuleList([nn.Embedding(ni, nf) for ni, nf in self.emb_dims])

    # get length of df
    def __len__(self):
        return len(self.data_frame)

    # get target
    def __get_target__(self):
        return self.data_frame.target

    # get df filtered by indices
    def __get_values__(self, indices):
        return self.data_frame.iloc[indices]

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # pull a sample of data
        text = self.data_frame.iloc[idx][self.text_features]
        cats = self.data_frame.iloc[idx][self.categorical_features]
        cats = torch.tensor(cats).long()
        target = self.data_frame.iloc[idx][self.target]

        # create embeddings
        self.embeddings = []
        for i, emb in enumerate(self.all_embeddings):
            self.embeddings.append(emb(cats[i]))
        self.embeddings = torch.cat(self.embeddings, 0)

        # hold sample in a dict
        sample = {'text': text,
                  'cats': self.embeddings,
                  'target': target,
                  'idx': torch.tensor(idx)}

        if self.transform:
            sample = self.transform(sample)

        return sample
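
As a quick sanity check of the class, the sketch below writes a couple of made-up rows to a scratch CSV (the column names mirror the real file used later; the values and the file name are invented) and pulls a single sample with no tokenizing transform, in the same environment as the rest of this guide.

# tiny end-to-end check of CSVDataset with made-up rows
toy = pd.DataFrame({'body': ['some text', 'other text'],
                    'sas_active': ['yes', 'no'],
                    'peace_talks_active': ['no', 'no'],
                    'isisk_active': ['yes', 'yes'],
                    'administration': ['obama', 'trump'],
                    'target': ['propaganda', 'not_propaganda']})
toy.to_csv('toy_entity_embeddings.csv', index=False)

toy_ds = CSVDataset(csv_file='toy_entity_embeddings.csv',
                    text_col=['body'],
                    cat_cols=['sas_active', 'peace_talks_active', 'isisk_active', 'administration'],
                    target=['target'],
                    transform=None)  # skip tokenization for this check
print(toy_ds[0]['cats'])             # a small float tensor: one embedded value per categorical column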

2.2 Torch Data Set Transforms

Next, a separate class is prepared to handle tokenization on the fly. It receives the dictionary from the class above, unpacks it, and sends the text through the usual tokenization process provided by Hugging Face via tokenizer.encode_plus().

class Tokenize_Transform():

    # retrieve sample and unpack it
    def __call__(self, sample):
        text, cats, target, idx = (sample['text']['body'],
                              sample['cats'],
                              sample['target'].values.astype(np.int64),
                              sample['idx'])

        # transform text to input ids and attn masks
        tokenizer_output = tokenizer.encode_plus(
                            text,  # document to encode.
                            add_special_tokens=True,  # add '[CLS]' and '[SEP]'
                            max_length=512,  # set max length
                            truncation=True,  # truncate longer messages
                            pad_to_max_length=True,  # add padding
                            return_attention_mask=True,  # create attn. masks
                            return_tensors='pt'  # return pytorch tensors
                       )
        input_ids, attn_mask = tokenizer_output['input_ids'], tokenizer_output['attention_mask']

        # yield another dict
        return {'input_ids': input_ids,
                'attn_mask': attn_mask,
                'cats': cats,
                'target': torch.from_numpy(target),
                'idx': idx}
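
One detail worth seeing is the shape encode_plus() returns for a single document: each sample carries a leading batch dimension of 1, which is why the training loops below call .squeeze(1) on the stacked input_ids and attn_mask. The toy string is mine, and the tokenizer load simply repeats the one in the next section.

# quick check of what encode_plus() returns for one toy document
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
out = tokenizer.encode_plus('a short toy document',
                            add_special_tokens=True,
                            max_length=512,
                            truncation=True,
                            pad_to_max_length=True,
                            return_attention_mask=True,
                            return_tensors='pt')
print(out['input_ids'].shape, out['attention_mask'].shape)  # torch.Size([1, 512]) torch.Size([1, 512])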

2.3 Load Torch Data Sets

Now we are ready to instantiate our data sets and split them into train, valid, and test sets.

# load the tokenizer
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

# instantiate the lazy data set
csv_dataset = CSVDataset(csv_file='C:\\Users\\Andrew\\Desktop\\test_export.csv',
                         text_col=['body'],
                         cat_cols=["sas_active", "peace_talks_active", "isisk_active", "administration"],
                         target=['target'],
                         transform=Tokenize_Transform())

## C:\Users\Andrew\Anaconda3\envs\my_ml\lib\site-packages\sklearn\preprocessing\label.py:235: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().
##   y = column_or_1d(y, warn=True)

# set train, valid, and test size
train_size = int(0.8 * len(csv_dataset))
valid_size = int(0.1 * len(csv_dataset))

# use random split to create three data sets; +1 for odd number of data
train_ds, valid_ds, test_ds = torch.utils.data.random_split(csv_dataset, [train_size, valid_size, valid_size+1])
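
A slightly more general variant of the split (my tweak, not the author's code) derives the test length from the remainder, so the three lengths always sum to len(csv_dataset) regardless of whether the row count is odd.

# derive the test size from the remainder so the lengths always sum to len(csv_dataset)
test_size = len(csv_dataset) - train_size - valid_size
train_ds, valid_ds, test_ds = torch.utils.data.random_split(csv_dataset,
                                                            [train_size, valid_size, test_size])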

2.4 Prepare Custom Model to Accept Categorical Embeddings

Then, we create a custom transformer class that allows us to concatenate the custom categorical embeddings with the transformer-derived embeddings.

# create custom transformer that concats the text and categorical embeddings
class DistillBERT_FE(torch.nn.Module):
    def __init__(self):
        super(DistillBERT_FE, self).__init__()
        # load model
        self.l1 = DistilBertModel.from_pretrained("distilbert-base-uncased")
        # pre-classifier layer
        self.pre_classifier = torch.nn.Linear(772, 772)  # 768 (text CLS) + 4 (categorical embedding dims)
        # drop out
        self.dropout = torch.nn.Dropout(0.3)
        # final classification layer
        self.classifier = torch.nn.Linear(772, 2)  # 768 (text CLS) + 4 (categorical embedding dims)

    def forward(self, input_ids, attention_mask):
        # generate outputs from BERT
        output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask)
        hidden_state = output_1[0]  # last hidden layer
        pooled_output = hidden_state[:, 0]  # just the cls token

        # cat transformer embeddings with entity embeddings
        pooled_output = torch.cat([pooled_output, b_cats], dim=1)

        # send through pre-classifying linear layer
        pooled_output = self.pre_classifier(pooled_output)
        # relu
        pooled_output = torch.nn.ReLU()(pooled_output)
        # add dropout
        pooled_output = self.dropout(pooled_output)
        # final classifying layer to yield logits
        logits = self.classifier(pooled_output)

        return logits
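
A quick shape check with dummy inputs can confirm the concatenation logic before training. In this sketch the dummy token ids, the zero-filled b_cats, and the throwaway name check_model are all mine; it instantiates the model on the CPU and will download the DistilBERT weights if they are not already cached.

# optional shape check with dummy inputs
check_model = DistillBERT_FE()
b_cats = torch.zeros(2, 4)                        # forward() reads this global: 2 samples x 4 embedding dims
dummy_ids = torch.randint(0, 30522, (2, 512))     # 2 samples x 512 token ids
dummy_mask = torch.ones(2, 512, dtype=torch.long)
with torch.no_grad():
    print(check_model(input_ids=dummy_ids, attention_mask=dummy_mask).shape)  # torch.Size([2, 2])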

2.5 Weighted Random Sampler

Since the data is imbalanced, a weighted random sampler is prepared so that the training data loader draws a roughly balanced stream of examples.

# prepare weighted sampling for imbalanced classification
def create_sampler(train_ds, csv_dataset):
    # get indices from the train split
    train_indices = train_ds.indices
    # generate class distributions [y1, y2, etc...]
    bin_count = np.bincount(csv_dataset.__get_target__()[train_indices])
    # weight gen
    weight = 1. / bin_count.astype(np.float32)
    # produce weights for each observation in the data set
    samples_weight = torch.tensor([weight[t] for t in csv_dataset.__get_target__()[train_indices]])
    # prepare sampler
    sampler = torch.utils.data.WeightedRandomSampler(weights=samples_weight,
                                                     num_samples=len(samples_weight),
                                                     replacement=True)
    return sampler

# create sampler for the training ds
train_sampler = create_sampler(train_ds, csv_dataset)    
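
The weighting rule is simply the inverse class frequency; a toy illustration with made-up counts shows how the minority class is up-sampled.

# toy illustration of the weighting rule: with 800 negatives and 200 positives,
# each positive example is drawn four times as often as each negative
toy_bin_count = np.array([800, 200])
toy_weight = 1. / toy_bin_count.astype(np.float32)
print(toy_weight)                      # roughly [0.00125 0.005]
print(toy_weight[1] / toy_weight[0])   # 4.0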

2.6 Training Functions

Now, a time helper function and the train, valid, and test workflows are prepared.

# time function
def format_time(elapsed):
    '''
    Takes a time in seconds and returns a string hh:mm:ss
    '''
    # round to the nearest second.
    elapsed_rounded = int(round((elapsed)))
    # format as hh:mm:ss
    return str(datetime.timedelta(seconds=elapsed_rounded))

def train(model, dataloader, optimizer):

    # capture time
    total_t0 = time.time()

    # Perform one full pass over the training set.
    print("")
    print('======== Epoch {:} / {:} ========'.format(epoch + 1, epochs))
    print('Training...')

    # reset total loss for epoch
    train_total_loss = 0
    total_train_f1 = 0

    # put model into training mode
    model.train()

    # for each batch of training data...
    for step, batch in enumerate(dataloader):

        # progress update every 40 batches.
        if step % 40 == 0 and not step == 0:

            # Report progress.
            print('  Batch {:>5,}  of  {:>5,}.'.format(step, len(dataloader)))

        # Unpack this training batch from our dataloader:
        #
        # As we unpack the batch, we'll also copy each tensor to the GPU using
        # the `to` method.
        #
        b_input_ids = batch['input_ids'].squeeze(1).cuda()
        b_input_mask = batch['attn_mask'].squeeze(1).cuda()
        # the model's forward() reads b_cats as a module-level global, so set it here
        global b_cats
        b_cats = batch['cats'].cuda()
        b_labels = batch['target'].cuda().long()

        # clear previously calculated gradients
        optimizer.zero_grad()

        # runs the forward pass with autocasting.
        with autocast():
            # forward propagation (evaluate model on training batch)
            logits = model(input_ids=b_input_ids, attention_mask=b_input_mask)

            # loss
            loss = criterion(logits.view(-1, 2), b_labels.view(-1))
            # sum the training loss over all batches for average loss at end
            # loss is a tensor containing a single value
            train_total_loss += loss.item()

        # Scales loss.  Calls backward() on scaled loss to create scaled gradients.
        # Backward passes under autocast are not recommended.
        # Backward ops run in the same dtype autocast chose for corresponding forward ops.
        scaler.scale(loss).backward()

        # scaler.step() first unscales the gradients of the optimizer's assigned params.
        # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
        # otherwise, optimizer.step() is skipped.
        scaler.step(optimizer)

        # Updates the scale for next iteration.
        scaler.update()

        # update the learning rate
        scheduler.step()

        # move logits and labels to CPU
        logits = logits.detach().cpu().numpy()
        y_true = b_labels.detach().cpu().numpy()

        # calculate preds
        rounded_preds = np.argmax(logits, axis=1).flatten()

        # calculate f1
        total_train_f1 += f1_score(rounded_preds, y_true,
                                   average='weighted',
                                   labels=np.unique(rounded_preds))

    # calculate the average loss over all of the batches
    avg_train_loss = train_total_loss / len(dataloader)

    # calculate the average f1 over all of the batches
    avg_train_f1 = total_train_f1 / len(dataloader)

    # Record all statistics from this epoch.
    training_stats.append(
        {
            'Train Loss': avg_train_loss,
            'Train F1': avg_train_f1
        }
    )

    # training time end
    training_time = format_time(time.time() - total_t0)

    # print result summaries
    print("")
    print("summary results")
    print("epoch | trn loss | trn f1 | trn time ")
    print(f"{epoch+1:5d} | {avg_train_loss:.5f} | {avg_train_f1:.5f} | {training_time:}")

    return None


def validating(model, dataloader):

    # capture validation time
    total_t0 = time.time()

    # After the completion of each training epoch, measure our performance on
    # our validation set.
    print("")
    print("Running Validation...")

    # put the model in evaluation mode
    model.eval()

    # track variables
    total_valid_accuracy = 0
    total_valid_loss = 0
    total_valid_f1 = 0
    total_valid_recall = 0
    total_valid_precision = 0

    # evaluate data for one epoch
    for batch in dataloader:

        # Unpack this training batch from our dataloader:
        b_input_ids = batch['input_ids'].squeeze(1).cuda()
        b_input_mask = batch['attn_mask'].squeeze(1).cuda()
        global b_cats
        b_cats = batch['cats'].cuda()
        b_labels = batch['target'].cuda().long()

        # tell pytorch not to bother calculating gradients
        # as its only necessary for training
        with torch.no_grad():

            logits = model(input_ids=b_input_ids, attention_mask=b_input_mask)

            # loss
            loss = criterion(logits.view(-1, 2), b_labels.view(-1))

        # accumulate validation loss
        total_valid_loss += loss.item()

        # move logits and labels to CPU
        logits = logits.detach().cpu().numpy()
        y_true = b_labels.detach().cpu().numpy()

        # calculate preds
        rounded_preds = np.argmax(logits, axis=1).flatten()

        # calculate f1
        total_valid_f1 += f1_score(rounded_preds, y_true,
                                   average='weighted',
                                   labels=np.unique(rounded_preds))

        # calculate accuracy
        total_valid_accuracy += accuracy_score(rounded_preds, y_true)

        # calculate precision
        total_valid_precision += precision_score(rounded_preds, y_true,
                                                 average='weighted',
                                                 labels=np.unique(rounded_preds))

        # calculate recall
        total_valid_recall += recall_score(rounded_preds, y_true,
                                                 average='weighted',
                                                 labels=np.unique(rounded_preds))

    # report final accuracy of validation run
    avg_accuracy = total_valid_accuracy / len(dataloader)

    # report final f1 of validation run
    global avg_val_f1
    avg_val_f1 = total_valid_f1 / len(dataloader)

    # report final precision of validation run
    avg_precision = total_valid_precision / len(dataloader)

    # report final recall of validation run
    avg_recall = total_valid_recall / len(dataloader)

    # calculate the average loss over all of the batches.
    global avg_val_loss
    avg_val_loss = total_valid_loss / len(dataloader)

    # Record all statistics from this epoch.
    valid_stats.append(
        {
            'Val Loss': avg_val_loss,
            'Val Accur.': avg_accuracy,
            'Val precision': avg_precision,
            'Val recall': avg_recall,
            'Val F1': avg_val_f1
        }
    )

    # capture end validation time
    training_time = format_time(time.time() - total_t0)

    # print result summaries
    print("")
    print("summary results")
    print("epoch | val loss | val f1 | val time")
    print(f"{epoch+1:5d} | {avg_val_loss:.5f} | {avg_val_f1:.5f} | {training_time:}")

    return None

2.7 Preparing to Train

Now we are ready to instantiate the model, the data loaders, and several other training helper objects.

# Load DistillBERT_FE
model = DistillBERT_FE().cuda()

# optimizer
optimizer = AdamW(model.parameters(),
                  lr=3.2696465645595003e-06,
                  weight_decay=1.0
                )

# set loss
criterion = nn.CrossEntropyLoss()


# set number of epochs
epochs = 5

# create DataLoaders with samplers
train_dataloader = DataLoader(train_ds,
                              batch_size=16,
                              sampler=train_sampler,
                              shuffle=False)

valid_dataloader = DataLoader(valid_ds,
                              batch_size=16,
                              shuffle=True)

test_dataloader = DataLoader(test_ds,
                              batch_size=16,
                              shuffle=True)
                              
# set LR scheduler
total_steps = len(train_dataloader) * epochs
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=0.1*total_steps,
                                            num_training_steps=total_steps)
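
One small optional tidy-up (my assumption, not something the author flags): the warmup argument is conceptually a step count, so it can be cast to an integer explicitly.

# optional: pass the warmup step count as an explicit integer
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=int(0.1 * total_steps),
                                            num_training_steps=total_steps)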

Let’s check to see what a batch of data looks like and whether or not our weighted random sampler is functioning as intended.

# lets check class balance for each batch to see how the sampler is working
for i, batch in enumerate(train_dataloader):
    print("batch index {}, 0/1: {}/{}".format(
        i, (batch['target'] == 0).sum(), (batch['target'] == 1).sum()))
    if i == 14:
        break

## batch index 0, 0/1: 9/7
## batch index 1, 0/1: 5/11
## batch index 2, 0/1: 6/10
## batch index 3, 0/1: 9/7
## batch index 4, 0/1: 5/11
## batch index 5, 0/1: 7/9
## batch index 6, 0/1: 7/9
## batch index 7, 0/1: 10/6
## batch index 8, 0/1: 6/10
## batch index 9, 0/1: 10/6
## batch index 10, 0/1: 6/10
## batch index 11, 0/1: 10/6
## batch index 12, 0/1: 10/6
## batch index 13, 0/1: 9/7
## batch index 14, 0/1: 8/8

# lets have a look at a single batch of categorical embeddings
batch['cats']
## tensor([[-0.7056, -0.5454,  1.0682, -0.1769],
##         [-0.7056, -0.5454,  1.0682, -0.1769],
##         [ 0.6741,  0.9107,  1.0682, -0.1769],
##         [ 0.6741,  0.9107,  1.0682, -1.2754],
##         [-0.7056, -0.5454,  1.0682, -0.1769],
##         [ 0.6741,  0.9107,  1.0682, -1.2754],
##         [ 0.6741,  0.9107,  0.1424, -1.2754],
##         [-0.7056,  0.9107,  1.0682, -0.1769],
##         [ 0.6741,  0.9107,  1.0682, -1.2754],
##         [ 0.6741,  0.9107,  0.1424, -1.2754],
##         [ 0.6741,  0.9107,  1.0682, -1.2754],
##         [ 0.6741,  0.9107,  0.1424, -1.2754],
##         [ 0.6741,  0.9107,  1.0682, -1.2754],
##         [ 0.6741,  0.9107,  1.0682, -0.1769],
##         [ 0.6741,  0.9107,  1.0682, -1.2754],
##         [-0.7056,  0.9107,  1.0682, -0.1769]], grad_fn=<StackBackward>)

2.8 Training

Now, we are ready to train.

# create gradient scaler for mixed precision
scaler = GradScaler()

# create training result storage
training_stats = []
valid_stats = []
best_valid_loss = float('inf')

# for each epoch
for epoch in range(epochs):
    # train
    train(model, train_dataloader, optimizer)
    # validate
    validating(model, valid_dataloader)
    # check validation loss
    if valid_stats[epoch]['Val Loss'] < best_valid_loss:
        best_valid_loss = valid_stats[epoch]['Val Loss']
        # save best model for use later
        torch.save(model.state_dict(), 'bert-model1.pt')  # torch save
## 
## ======== Epoch 1 / 5 ========
## Training...
##   Batch    40  of    503.
##   Batch    80  of    503.
##   Batch   120  of    503.
##   Batch   160  of    503.
##   Batch   200  of    503.
##   Batch   240  of    503.
##   Batch   280  of    503.
##   Batch   320  of    503.
##   Batch   360  of    503.
##   Batch   400  of    503.
##   Batch   440  of    503.
##   Batch   480  of    503.
## 
## summary results
## epoch | trn loss | trn f1 | trn time 
##     1 | 0.46083 | 0.81059 | 0:02:36
## 
## Running Validation...
## 
## summary results
## epoch | val loss | val f1 | val time
##     1 | 0.31722 | 0.85896 | 0:00:13
## 
## ======== Epoch 2 / 5 ========
## Training...
##   Batch    40  of    503.
##   Batch    80  of    503.
##   Batch   120  of    503.
##   Batch   160  of    503.
##   Batch   200  of    503.
##   Batch   240  of    503.
##   Batch   280  of    503.
##   Batch   320  of    503.
##   Batch   360  of    503.
##   Batch   400  of    503.
##   Batch   440  of    503.
##   Batch   480  of    503.
## 
## summary results
## epoch | trn loss | trn f1 | trn time 
##     2 | 0.31257 | 0.87632 | 0:02:34
## 
## Running Validation...
## 
## summary results
## epoch | val loss | val f1 | val time
##     2 | 0.30928 | 0.85402 | 0:00:13
## 
## ======== Epoch 3 / 5 ========
## Training...
##   Batch    40  of    503.
##   Batch    80  of    503.
##   Batch   120  of    503.
##   Batch   160  of    503.
##   Batch   200  of    503.
##   Batch   240  of    503.
##   Batch   280  of    503.
##   Batch   320  of    503.
##   Batch   360  of    503.
##   Batch   400  of    503.
##   Batch   440  of    503.
##   Batch   480  of    503.
## 
## summary results
## epoch | trn loss | trn f1 | trn time 
##     3 | 0.27499 | 0.89340 | 0:02:32
## 
## Running Validation...
## 
## summary results
## epoch | val loss | val f1 | val time
##     3 | 0.27319 | 0.87726 | 0:00:12
## 
## ======== Epoch 4 / 5 ========
## Training...
##   Batch    40  of    503.
##   Batch    80  of    503.
##   Batch   120  of    503.
##   Batch   160  of    503.
##   Batch   200  of    503.
##   Batch   240  of    503.
##   Batch   280  of    503.
##   Batch   320  of    503.
##   Batch   360  of    503.
##   Batch   400  of    503.
##   Batch   440  of    503.
##   Batch   480  of    503.
## 
## summary results
## epoch | trn loss | trn f1 | trn time 
##     4 | 0.24393 | 0.90583 | 0:02:31
## 
## Running Validation...
## 
## summary results
## epoch | val loss | val f1 | val time
##     4 | 0.27130 | 0.87891 | 0:00:13
## 
## ======== Epoch 5 / 5 ========
## Training...
##   Batch    40  of    503.
##   Batch    80  of    503.
##   Batch   120  of    503.
##   Batch   160  of    503.
##   Batch   200  of    503.
##   Batch   240  of    503.
##   Batch   280  of    503.
##   Batch   320  of    503.
##   Batch   360  of    503.
##   Batch   400  of    503.
##   Batch   440  of    503.
##   Batch   480  of    503.
## 
## summary results
## epoch | trn loss | trn f1 | trn time 
##     5 | 0.26196 | 0.89868 | 0:02:33
## 
## Running Validation...
## 
## summary results
## epoch | val loss | val f1 | val time
##     5 | 0.28235 | 0.88039 | 0:00:13
## 
## C:\Users\Andrew\Anaconda3\envs\my_ml\lib\site-packages\sklearn\metrics\classification.py:1437: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 in labels with no predicted samples.
##   'precision', 'predicted', average, warn_for)
## C:\Users\Andrew\Anaconda3\envs\my_ml\lib\site-packages\sklearn\metrics\classification.py:1437: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples.
##   'precision', 'predicted', average, warn_for)

Lastly, another benefit of using the torch data set class is that it makes error analysis easier: we can use the batch indices to subset the original data frame, like so:

# for error analysis
batch_idx = np.array(batch['idx'])
csv_dataset.__get_values__(batch_idx)
##             date  ... word_count
## 1244   9/10/2018  ...        173
## 291    1/21/2020  ...        132
## 2411    5/4/2017  ...        317
## 7044    2/3/2015  ...         60
## 828    3/20/2019  ...        179
## 3494   5/10/2016  ...         95
## 8943   5/19/2014  ...         55
## 1949  10/26/2017  ...        552
## 4529    1/2/2016  ...         61
## 9416   3/21/2014  ...         64
## 2777    1/6/2017  ...         89
## 7854   9/22/2014  ...         49
## 2898  10/20/2016  ...        356
## 2382   5/13/2017  ...         75
## 2976    9/4/2016  ...        399
## 1332   7/13/2018  ...        318
## 
## [16 rows x 17 columns]
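
A possible extension (a sketch of mine, not part of the original workflow) is to score this same batch with the trained model and attach its predictions to the recovered rows for side-by-side inspection.

# score the recovered batch and attach predictions to the rows
model.eval()
b_cats = batch['cats'].cuda()  # forward() reads this module-level global
with torch.no_grad():
    logits = model(input_ids=batch['input_ids'].squeeze(1).cuda(),
                   attention_mask=batch['attn_mask'].squeeze(1).cuda())
preds = logits.argmax(dim=1).cpu().numpy()

rows = csv_dataset.__get_values__(batch_idx).copy()
rows['predicted'] = preds
print(rows[['date', 'target', 'predicted']])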

3 Transformer-derived Categorical Embeddings

In this section, the code is provided with less discussion, as it mostly follows that of the section above. The logic here is that we use the transformer itself to generate contextual embeddings for our categorical variables by joining them into a short "sentence" and tokenizing it. Then, by concatenating these categorical embeddings with the text embeddings, we rely on the transformer architecture to learn how the two relate and, in effect, to do the feature engineering for us.
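
Concretely, each row's categorical values are joined into one short pseudo-sentence and tokenized with a small max_length. The row values below are invented and the tokenizer load repeats section 2.3; this only illustrates what the Tokenize_Cats transform produces.

# toy illustration of the "categorical values as one short sentence" idea
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
cats = ['1', '0', '1', 'trump']   # hypothetical sas_active, peace_talks_active, isisk_active, administration
cat_sentence = ' '.join(cats)     # "1 0 1 trump"
enc = tokenizer.encode_plus(cat_sentence,
                            add_special_tokens=True,
                            max_length=10,
                            truncation=True,
                            pad_to_max_length=True,
                            return_attention_mask=True,
                            return_tensors='pt')
print(enc['input_ids'])           # roughly: [CLS] + the category tokens + [SEP] + padding, shape [1, 10]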

# Create Dataset
class CSVDataset(Dataset):
    """Propaganda data set."""

    def __init__(self, csv_file, text_col, cat_cols, target, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            text_col (string): column containing the text for analysis.
            cat_cols (string): column(s) containing string categorical data.
            target (string): column containing the dependent variable.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        # initialize
        self.data_frame = pd.read_csv(csv_file)
        self.categorical_features = cat_cols
        self.text_features = text_col
        self.target = target
        self.transform = transform
        
        # encode outcome
        self.data_frame[target] = LabelEncoder().fit_transform(self.data_frame[target])

    def __len__(self):
        return len(self.data_frame)

    def __get_target__(self):
        return self.data_frame.target

    def __get_values__(self, indices):
        return self.data_frame.iloc[indices]

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        text = self.data_frame.iloc[idx][self.text_features]
        cats = self.data_frame.iloc[idx][self.categorical_features]
        target = self.data_frame.iloc[idx][self.target]

        sample = {'text': text, 'cats': cats.values, 'target': target, 'idx': torch.tensor(idx)}

        if self.transform:
            sample = self.transform(sample)

        return sample


class Tokenize_Transform():

    def __call__(self, sample):
        text, cats, target, idx = (sample['text']['body'],
                              sample['cats'],
                              sample['target'].values.astype(np.int64),
                              sample['idx'])

        # transform text to input ids and attn masks
        tokenizer_output = tokenizer.encode_plus(
                            text,  # document to encode.
                            add_special_tokens=True,  # add '[CLS]' and '[SEP]'
                            max_length=512,  # set max length
                            truncation=True,  # truncate longer messages
                            pad_to_max_length=True,  # add padding
                            return_attention_mask=True,  # create attn. masks
                            return_tensors='pt'  # return pytorch tensors
                       )
        input_ids, attn_mask = tokenizer_output['input_ids'], tokenizer_output['attention_mask']

        return {'input_ids': input_ids,
                'attn_mask': attn_mask,
                'cats': cats,
                'target': target,
                'idx': idx}


class Tokenize_Cats():

    def __call__(self, sample):
        text, cats, target, idx = ((sample['input_ids'], sample['attn_mask']),
                              sample['cats'],
                              sample['target'],
                              sample['idx'])

        # transform text to input ids and attn masks
        cat_input_ids = []
        cat_attn_mask = []
        encoded_dict = tokenizer.encode_plus(
                                ' '.join(cats),                      # Sentence to encode.
                                add_special_tokens = True, # Add '[CLS]' and '[SEP]'
                                max_length = 10,           # Pad & truncate all sentences.
                                truncation = True,
                                pad_to_max_length = True,
                                return_attention_mask = True,   # Construct attn. masks.
                                return_tensors = 'pt',     # Return pytorch tensors.
                           )
        cat_input_ids.append(encoded_dict['input_ids'])
        cat_attn_mask.append(encoded_dict['attention_mask'])

        # Convert the lists into tensors.
        cat_input_ids = torch.cat(cat_input_ids, dim=1)
        cat_attn_mask = torch.cat(cat_attn_mask, dim=1)

        return {'input_ids': text[0],
                'attn_mask': text[1],
                'cats_ids': cat_input_ids,
                'cats_mask': cat_attn_mask,
                'target': torch.from_numpy(target),
                'idx': idx}

csv_dataset = CSVDataset(csv_file='C:\\Users\\Andrew\\Desktop\\test_export.csv',
                         text_col=['body'],
                         cat_cols=["sas_active", "peace_talks_active", "isisk_active", "administration"],
                         target=['target'],
                         transform=transforms.Compose([Tokenize_Transform(), Tokenize_Cats()]))
## C:\Users\Andrew\Anaconda3\envs\my_ml\lib\site-packages\sklearn\preprocessing\label.py:235: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().
##   y = column_or_1d(y, warn=True)
train_size = int(0.8 * len(csv_dataset))
valid_size = int(0.1 * len(csv_dataset))

train_ds, valid_ds, test_ds = torch.utils.data.random_split(csv_dataset, [train_size, valid_size, valid_size+1])

class DistillBERT_FE(torch.nn.Module):
    def __init__(self):
        super(DistillBERT_FE, self).__init__()
        self.l1 = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.pre_classifier = torch.nn.Linear(1536, 1536)  # 768 (text CLS) + 768 (categorical CLS)
        self.dropout = torch.nn.Dropout(0.3)
        self.classifier = torch.nn.Linear(1536, 2)  # 768 (text CLS) + 768 (categorical CLS)

    def forward(self, input_ids, attention_mask):
        output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask)
        output_2 = self.l1(input_ids=cats_ids, attention_mask=cats_mask)
        hidden_state = output_1[0]  # last hidden layer
        pooled_output = hidden_state[:, 0]  # just the cls token

        hidden_state2 = output_2[0]  # last hidden layer
        pooled_output2 = hidden_state2[:, 0]  # just the cls token

        # cat transformer embeddings with entity embeddings
        pooled_output = torch.cat([pooled_output, pooled_output2], dim=1)

        # send through pre-classifying linear layer
        pooled_output = self.pre_classifier(pooled_output)
        # relu
        pooled_output = torch.nn.ReLU()(pooled_output)
        # add dropout
        pooled_output = self.dropout(pooled_output)
        # final classifying layer to yield logits
        logits = self.classifier(pooled_output)
        return logits

# prepare weighted sampling for imbalanced classification
def create_sampler(train_ds):
    # get indices from the train split
    train_indices = train_ds.indices
    # generate class distributions [x, y]
    bin_count = np.bincount(csv_dataset.__get_target__()[train_indices])
    # weight gen
    weight = 1. / bin_count.astype(np.float32)
    # produce weights for each observation in the data set
    samples_weight = torch.tensor([weight[t] for t in csv_dataset.__get_target__()[train_indices]])
    # prepare sampler
    sampler = torch.utils.data.WeightedRandomSampler(weights=samples_weight,
                                                     num_samples=len(samples_weight),
                                                     replacement=True)
    return sampler

# create sampler for the training ds
train_sampler = create_sampler(train_ds)

# set loss
criterion = nn.CrossEntropyLoss()

# time function
def format_time(elapsed):
    '''
    Takes a time in seconds and returns a string hh:mm:ss
    '''
    # round to the nearest second.
    elapsed_rounded = int(round((elapsed)))
    # format as hh:mm:ss
    return str(datetime.timedelta(seconds=elapsed_rounded))

def train(model, dataloader, optimizer):

    # capture time
    total_t0 = time.time()

    # Perform one full pass over the training set.
    print("")
    print('======== Epoch {:} / {:} ========'.format(epoch + 1, epochs))
    print('Training...')

    # reset total loss for epoch
    train_total_loss = 0
    total_train_f1 = 0

    # put model into training mode
    model.train()

    # for each batch of training data...
    for step, batch in enumerate(dataloader):

        # progress update every 40 batches.
        if step % 40 == 0 and not step == 0:

            # Report progress.
            print('  Batch {:>5,}  of  {:>5,}.'.format(step, len(dataloader)))

        # Unpack this training batch from our dataloader:
        #
        # As we unpack the batch, we'll also copy each tensor to the GPU using
        # the `to` method.
        #
        b_input_ids = batch['input_ids'].squeeze(1).cuda()
        b_input_mask = batch['attn_mask'].squeeze(1).cuda()
        # the model's forward() reads cats_ids and cats_mask as module-level globals, so set them here
        global cats_ids, cats_mask
        cats_ids = batch['cats_ids'].squeeze(1).cuda()
        cats_mask = batch['cats_mask'].squeeze(1).cuda()
        b_labels = batch['target'].cuda().long()

        # clear previously calculated gradients
        optimizer.zero_grad()

        # runs the forward pass with autocasting.
        with autocast():
            # forward propagation (evaluate model on training batch)
            logits = model(input_ids=b_input_ids, attention_mask=b_input_mask)

            # loss
            loss = criterion(logits.view(-1, 2), b_labels.view(-1))
            # sum the training loss over all batches for average loss at end
            # loss is a tensor containing a single value
            train_total_loss += loss.item()

        # Scales loss.  Calls backward() on scaled loss to create scaled gradients.
        # Backward passes under autocast are not recommended.
        # Backward ops run in the same dtype autocast chose for corresponding forward ops.
        scaler.scale(loss).backward()

        # scaler.step() first unscales the gradients of the optimizer's assigned params.
        # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
        # otherwise, optimizer.step() is skipped.
        scaler.step(optimizer)

        # Updates the scale for next iteration.
        scaler.update()

        # update the learning rate
        scheduler.step()

        # move logits and labels to CPU
        logits = logits.detach().cpu().numpy()
        y_true = b_labels.detach().cpu().numpy()

        # calculate preds
        rounded_preds = np.argmax(logits, axis=1).flatten()

        # calculate f1
        total_train_f1 += f1_score(rounded_preds, y_true,
                                   average='weighted',
                                   labels=np.unique(rounded_preds))

    # calculate the average loss over all of the batches
    avg_train_loss = train_total_loss / len(dataloader)

    # calculate the average f1 over all of the batches
    avg_train_f1 = total_train_f1 / len(dataloader)

    # Record all statistics from this epoch.
    training_stats.append(
        {
            'Train Loss': avg_train_loss,
            'Train F1': avg_train_f1
        }
    )

    # training time end
    training_time = format_time(time.time() - total_t0)

    # print result summaries
    print("")
    print("summary results")
    print("epoch | trn loss | trn f1 | trn time ")
    print(f"{epoch+1:5d} | {avg_train_loss:.5f} | {avg_train_f1:.5f} | {training_time:}")

    return None


def validating(model, dataloader):

    # capture validation time
    total_t0 = time.time()

    # After the completion of each training epoch, measure our performance on
    # our validation set.
    print("")
    print("Running Validation...")

    # put the model in evaluation mode
    model.eval()

    # track variables
    total_valid_accuracy = 0
    total_valid_loss = 0
    total_valid_f1 = 0
    total_valid_recall = 0
    total_valid_precision = 0

    # evaluate data for one epoch
    for batch in dataloader:

        # Unpack this training batch from our dataloader:
        b_input_ids = batch['input_ids'].squeeze(1).cuda()
        b_input_mask = batch['attn_mask'].squeeze(1).cuda()
        global cats_ids, cats_mask
        cats_ids = batch['cats_ids'].squeeze(1).cuda()
        cats_mask = batch['cats_mask'].squeeze(1).cuda()
        b_labels = batch['target'].cuda().long()

        # tell pytorch not to bother calculating gradients
        # as its only necessary for training
        with torch.no_grad():

            logits = model(input_ids=b_input_ids, attention_mask=b_input_mask)

            # loss
            loss = criterion(logits.view(-1, 2), b_labels.view(-1))

        # accumulate validation loss
        total_valid_loss += loss.item()

        # move logits and labels to CPU
        logits = logits.detach().cpu().numpy()
        y_true = b_labels.detach().cpu().numpy()

        # calculate preds
        rounded_preds = np.argmax(logits, axis=1).flatten()

        # calculate f1
        total_valid_f1 += f1_score(rounded_preds, y_true,
                                   average='weighted',
                                   labels=np.unique(rounded_preds))

        # calculate accuracy
        total_valid_accuracy += accuracy_score(rounded_preds, y_true)

        # calculate precision
        total_valid_precision += precision_score(rounded_preds, y_true,
                                                 average='weighted',
                                                 labels=np.unique(rounded_preds))

        # calculate recall
        total_valid_recall += recall_score(rounded_preds, y_true,
                                                 average='weighted',
                                                 labels=np.unique(rounded_preds))

    # report final accuracy of validation run
    avg_accuracy = total_valid_accuracy / len(dataloader)

    # report final f1 of validation run
    global avg_val_f1
    avg_val_f1 = total_valid_f1 / len(dataloader)

    # report final precision of validation run
    avg_precision = total_valid_precision / len(dataloader)

    # report final recall of validation run
    avg_recall = total_valid_recall / len(dataloader)

    # calculate the average loss over all of the batches.
    global avg_val_loss
    avg_val_loss = total_valid_loss / len(dataloader)

    # Record all statistics from this epoch.
    valid_stats.append(
        {
            'Val Loss': avg_val_loss,
            'Val Accur.': avg_accuracy,
            'Val precision': avg_precision,
            'Val recall': avg_recall,
            'Val F1': avg_val_f1
        }
    )

    # capture end validation time
    training_time = format_time(time.time() - total_t0)

    # print result summaries
    print("")
    print("summary results")
    print("epoch | val loss | val f1 | val time")
    print(f"{epoch+1:5d} | {avg_val_loss:.5f} | {avg_val_f1:.5f} | {training_time:}")

    return None


# Load DistillBERT_FE
model = DistillBERT_FE().cuda()

# optimizer
optimizer = AdamW(model.parameters(),
                  lr=3.2696465645595003e-06,
                  weight_decay=1.0
                )


# set number of epochs
epochs = 5

# create DataLoaders with samplers
train_dataloader = DataLoader(train_ds,
                              batch_size=8,
                              sampler=train_sampler,
                              shuffle=False)

valid_dataloader = DataLoader(valid_ds,
                              batch_size=8,
                              shuffle=True)

test_dataloader = DataLoader(test_ds,
                              batch_size=8,
                              shuffle=True)


# set LR scheduler
total_steps = len(train_dataloader) * epochs
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=total_steps*0.1,
                                            num_training_steps=total_steps)


# lets check class balance for each batch to see how the sampler is working
for i, batch in enumerate(train_dataloader):
    print("batch index {}, 0/1: {}/{}".format(
        i, (batch['target'] == 0).sum(), (batch['target'] == 1).sum()))
    if i == 14:
        break
## batch index 0, 0/1: 5/3
## batch index 1, 0/1: 4/4
## batch index 2, 0/1: 3/5
## batch index 3, 0/1: 5/3
## batch index 4, 0/1: 2/6
## batch index 5, 0/1: 2/6
## batch index 6, 0/1: 3/5
## batch index 7, 0/1: 6/2
## batch index 8, 0/1: 4/4
## batch index 9, 0/1: 3/5
## batch index 10, 0/1: 3/5
## batch index 11, 0/1: 2/6
## batch index 12, 0/1: 4/4
## batch index 13, 0/1: 3/5
## batch index 14, 0/1: 3/5

# create gradient scaler for mixed precision
scaler = GradScaler()

# create training result storage
training_stats = []
valid_stats = []
best_valid_loss = float('inf')

# for each epoch
for epoch in range(epochs):
    # train
    train(model, train_dataloader, optimizer)
    # validate
    validating(model, valid_dataloader)
    # check validation loss
    if valid_stats[epoch]['Val Loss'] < best_valid_loss:
        best_valid_loss = valid_stats[epoch]['Val Loss']
        # save best model for use later
        torch.save(model.state_dict(), 'bert-model1.pt')  # torch save
## 
## ======== Epoch 1 / 5 ========
## Training...
##   Batch    40  of  1,005.
##   Batch    80  of  1,005.
##   Batch   120  of  1,005.
##   Batch   160  of  1,005.
##   Batch   200  of  1,005.
##   Batch   240  of  1,005.
##   Batch   280  of  1,005.
##   Batch   320  of  1,005.
##   Batch   360  of  1,005.
##   Batch   400  of  1,005.
##   Batch   440  of  1,005.
##   Batch   480  of  1,005.
##   Batch   520  of  1,005.
##   Batch   560  of  1,005.
##   Batch   600  of  1,005.
##   Batch   640  of  1,005.
##   Batch   680  of  1,005.
##   Batch   720  of  1,005.
##   Batch   760  of  1,005.
##   Batch   800  of  1,005.
##   Batch   840  of  1,005.
##   Batch   880  of  1,005.
##   Batch   920  of  1,005.
##   Batch   960  of  1,005.
##   Batch 1,000  of  1,005.
## 
## summary results
## epoch | trn loss | trn f1 | trn time 
##     1 | 0.43771 | 0.80962 | 0:03:02
## 
## Running Validation...
## 
## summary results
## epoch | val loss | val f1 | val time
##     1 | 0.36443 | 0.81212 | 0:00:13
## 
## ======== Epoch 2 / 5 ========
## Training...
##   Batch    40  of  1,005.
##   Batch    80  of  1,005.
##   Batch   120  of  1,005.
##   Batch   160  of  1,005.
##   Batch   200  of  1,005.
##   Batch   240  of  1,005.
##   Batch   280  of  1,005.
##   Batch   320  of  1,005.
##   Batch   360  of  1,005.
##   Batch   400  of  1,005.
##   Batch   440  of  1,005.
##   Batch   480  of  1,005.
##   Batch   520  of  1,005.
##   Batch   560  of  1,005.
##   Batch   600  of  1,005.
##   Batch   640  of  1,005.
##   Batch   680  of  1,005.
##   Batch   720  of  1,005.
##   Batch   760  of  1,005.
##   Batch   800  of  1,005.
##   Batch   840  of  1,005.
##   Batch   880  of  1,005.
##   Batch   920  of  1,005.
##   Batch   960  of  1,005.
##   Batch 1,000  of  1,005.
## 
## summary results
## epoch | trn loss | trn f1 | trn time 
##     2 | 0.29222 | 0.88269 | 0:03:05
## 
## Running Validation...
## 
## summary results
## epoch | val loss | val f1 | val time
##     2 | 0.33665 | 0.83473 | 0:00:13
## 
## ======== Epoch 3 / 5 ========
## Training...
##   Batch    40  of  1,005.
##   Batch    80  of  1,005.
##   Batch   120  of  1,005.
##   Batch   160  of  1,005.
##   Batch   200  of  1,005.
##   Batch   240  of  1,005.
##   Batch   280  of  1,005.
##   Batch   320  of  1,005.
##   Batch   360  of  1,005.
##   Batch   400  of  1,005.
##   Batch   440  of  1,005.
##   Batch   480  of  1,005.
##   Batch   520  of  1,005.
##   Batch   560  of  1,005.
##   Batch   600  of  1,005.
##   Batch   640  of  1,005.
##   Batch   680  of  1,005.
##   Batch   720  of  1,005.
##   Batch   760  of  1,005.
##   Batch   800  of  1,005.
##   Batch   840  of  1,005.
##   Batch   880  of  1,005.
##   Batch   920  of  1,005.
##   Batch   960  of  1,005.
##   Batch 1,000  of  1,005.
## 
## summary results
## epoch | trn loss | trn f1 | trn time 
##     3 | 0.26064 | 0.89359 | 0:03:05
## 
## Running Validation...
## 
## summary results
## epoch | val loss | val f1 | val time
##     3 | 0.33270 | 0.83736 | 0:00:13
## 
## ======== Epoch 4 / 5 ========
## Training...
##   Batch    40  of  1,005.
##   Batch    80  of  1,005.
##   Batch   120  of  1,005.
##   Batch   160  of  1,005.
##   Batch   200  of  1,005.
##   Batch   240  of  1,005.
##   Batch   280  of  1,005.
##   Batch   320  of  1,005.
##   Batch   360  of  1,005.
##   Batch   400  of  1,005.
##   Batch   440  of  1,005.
##   Batch   480  of  1,005.
##   Batch   520  of  1,005.
##   Batch   560  of  1,005.
##   Batch   600  of  1,005.
##   Batch   640  of  1,005.
##   Batch   680  of  1,005.
##   Batch   720  of  1,005.
##   Batch   760  of  1,005.
##   Batch   800  of  1,005.
##   Batch   840  of  1,005.
##   Batch   880  of  1,005.
##   Batch   920  of  1,005.
##   Batch   960  of  1,005.
##   Batch 1,000  of  1,005.
## 
## summary results
## epoch | trn loss | trn f1 | trn time 
##     4 | 0.23057 | 0.91137 | 0:03:03
## 
## Running Validation...
## 
## summary results
## epoch | val loss | val f1 | val time
##     4 | 0.33422 | 0.84537 | 0:00:13
## 
## ======== Epoch 5 / 5 ========
## Training...
##   Batch    40  of  1,005.
##   Batch    80  of  1,005.
##   Batch   120  of  1,005.
##   Batch   160  of  1,005.
##   Batch   200  of  1,005.
##   Batch   240  of  1,005.
##   Batch   280  of  1,005.
##   Batch   320  of  1,005.
##   Batch   360  of  1,005.
##   Batch   400  of  1,005.
##   Batch   440  of  1,005.
##   Batch   480  of  1,005.
##   Batch   520  of  1,005.
##   Batch   560  of  1,005.
##   Batch   600  of  1,005.
##   Batch   640  of  1,005.
##   Batch   680  of  1,005.
##   Batch   720  of  1,005.
##   Batch   760  of  1,005.
##   Batch   800  of  1,005.
##   Batch   840  of  1,005.
##   Batch   880  of  1,005.
##   Batch   920  of  1,005.
##   Batch   960  of  1,005.
##   Batch 1,000  of  1,005.
## 
## summary results
## epoch | trn loss | trn f1 | trn time 
##     5 | 0.21859 | 0.91334 | 0:03:07
## 
## Running Validation...
## 
## summary results
## epoch | val loss | val f1 | val time
##     5 | 0.33059 | 0.84985 | 0:00:13
## 
## C:\Users\Andrew\Anaconda3\envs\my_ml\lib\site-packages\sklearn\metrics\classification.py:1437: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 in labels with no predicted samples.
##   'precision', 'predicted', average, warn_for)
## C:\Users\Andrew\Anaconda3\envs\my_ml\lib\site-packages\sklearn\metrics\classification.py:1437: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples.
##   'precision', 'predicted', average, warn_for)