# load python
library(reticulate)
use_condaenv("my_ml")
# load packages
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import get_linear_schedule_with_warmup, AdamW
from torch.utils.data import TensorDataset, random_split, DataLoader, RandomSampler, SequentialSampler
import time, datetime, random, optuna, re, string
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import matplotlib.pyplot as plt
import seaborn as sns
from optuna.pruners import SuccessiveHalvingPruner
from optuna.samplers import TPESampler
from torch.cuda.amp import autocast, GradScaler
from sklearn.model_selection import train_test_split
from collections import Counter
from transformers import BertModel, BertTokenizer
SEED = 15
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
## <torch._C.Generator object at 0x000000002189E070>
torch.backends.cudnn.deterministic = True
torch.cuda.amp.autocast(enabled=True)
## <torch.cuda.amp.autocast_mode.autocast object at 0x00000000347738C8>
# tell pytorch to use cuda
device = torch.device("cuda")
In this guide, we prepare a BERT-CNN ensemble which takes the embeddings generated by the BERT base model and feeds them into a CNN. The general logic from this guide can be used to replace the CNN with any other NN of your choice. Future guides will explore other models like Bi-Directional LSTMs and the use of self-attention in embedding layer aggregation.
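To make the data flow concrete, here is a minimal sketch (with hypothetical random tensors, not the real model) of how the pieces fit together: BERT returns one hidden state per layer, we stack the last four, and the CNN head treats that stack as a 4-channel input.
# minimal sketch of the ensemble's tensor flow (hypothetical shapes, random data)
import torch
batch_size, seq_len, hidden_dim = 2, 512, 768
# BERT base with output_hidden_states=True returns 13 hidden states (embeddings + 12 layers)
hidden_states = [torch.randn(batch_size, seq_len, hidden_dim) for _ in range(13)]
stacked = torch.stack(hidden_states, dim=1)   # (2, 13, 512, 768)
last_four = stacked[:, -4:]                   # (2, 4, 512, 768) -- input to the CNN head
print(last_four.shape)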
Like other guides, this walk-through provides a complete treatment of the data preparation and training of the BERT-CNN ensemble in PyTorch.
We begin by loading and lightly editing our data prior to tokenization.
# prepare and load data
def prepare_df(pkl_location):
    # read pkl as pandas
    df = pd.read_pickle(pkl_location)
    # just keep us/kabul labels
    df = df.loc[(df['target'] == 'US') | (df['target'] == 'Kabul')]
    # mask DV to recode
    us = df['target'] == 'US'
    kabul = df['target'] == 'Kabul'
    # apply mask
    df.loc[us, 'target'] = 1
    df.loc[kabul, 'target'] = 0
    # reset index
    df = df.reset_index(drop=True)
    return df

df = prepare_df('C:\\Users\\Andrew\\Desktop\\df.pkl')
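Because the weighted sampler used later depends on the class balance, a quick hedged check of the recoded labels (counts will vary with your own pickle) is worth running here:
# quick look at the recoded label distribution (results depend on your own data)
print(df.shape)
print(df['target'].value_counts())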
# prepare data
def clean_df(df):
# strip dash but keep a space
'body'] = df['body'].str.replace('-', ' ')
df[# prepare keys for punctuation removal
= str.maketrans(dict.fromkeys(string.punctuation))
translator # lower case the data
'body'] = df['body'].apply(lambda x: x.lower())
df[# remove excess spaces near punctuation
'body'] = df['body'].apply(lambda x: re.sub(r'\s([?.!"](?:\s|$))', r'\1', x))
df[# remove punctuation -- f1 improves by .05 by disabling this
#df['body'] = df['body'].apply(lambda x: x.translate(translator))
# generate a word count
'word_count'] = df['body'].apply(lambda x: len(x.split()))
df[# remove excess white spaces
'body'] = df['body'].apply(lambda x: " ".join(x.split()))
df[
return df
= clean_df(df)
df
# lets remove rare words
def remove_rare_words(df):
# get counts of each word -- necessary for vocab
= Counter(" ".join(df['body'].values.tolist()).split(" "))
counts # remove low counts -- keep those above 2
= {key: value for key, value in counts.items() if value > 2}
counts
# remove rare words from corpus
def remove_rare(x):
return ' '.join(list(filter(lambda x: x in counts.keys(), x.split())))
# apply funx
'body'] = df['body'].apply(remove_rare)
df[return df
= remove_rare_words(df)
df
# remove transliterated words that GloVe can't find
no_matches = np.load('C:\\Users\\Andrew\\translit_no_match.npy')
no_matches = dict(zip(set(no_matches), range(len(set(no_matches)))))

# remove transliterated words from corpus
df['body'] = df['body'].apply(lambda x: ' '.join(list(filter(lambda x: x not in no_matches.keys(), x.split()))))
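As a rough, purely illustrative check, we can report how large the vocabulary is after dropping rare and unmatched transliterated words:
# rough post-cleaning vocabulary size (illustrative check)
remaining_vocab = set(" ".join(df['body'].values.tolist()).split())
print(len(remaining_vocab))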
Next, we instantiate the BERT tokenizer from transformers
and tokenize our entire corpus.
# instantiate BERT tokenizer with lower case vocab
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# a look at some of the BERT vocab
word_map = dict(zip(tokenizer.vocab.keys(), range(len(tokenizer))))
word_map.get('the')  # find index value
## 1996
list(tokenizer.vocab.keys())[2000:2010]
## ['to', 'was', 'he', 'is', 'as', 'for', 'on', 'with', 'that', 'it']
len(tokenizer)
## 30522
# tokenize corpus using BERT
def tokenize_corpus(df, tokenizer, max_len):
    # token ID storage
    input_ids = []
    # attention mask storage
    attention_masks = []
    # for every document:
    for doc in df:
        # `encode_plus` will:
        # (1) Tokenize the sentence.
        # (2) Prepend the `[CLS]` token to the start.
        # (3) Append the `[SEP]` token to the end.
        # (4) Map tokens to their IDs.
        # (5) Pad or truncate the sentence to `max_length`
        # (6) Create attention masks for [PAD] tokens.
        encoded_dict = tokenizer.encode_plus(
            doc,                          # document to encode
            add_special_tokens=True,      # add '[CLS]' and '[SEP]'
            max_length=max_len,           # set max length -- 512 is the BERT maximum
            truncation=True,              # truncate longer messages
            pad_to_max_length=True,       # add padding
            return_attention_mask=True,   # create attn. masks
            return_tensors='pt'           # return pytorch tensors
        )
        # add the tokenized sentence to the list
        input_ids.append(encoded_dict['input_ids'])
        # and its attention mask (differentiates padding from non-padding)
        attention_masks.append(encoded_dict['attention_mask'])

    return torch.cat(input_ids, dim=0), torch.cat(attention_masks, dim=0)

# create tokenized data
input_ids, attention_masks = tokenize_corpus(df['body'].values, tokenizer, 512)

# convert the labels into tensors
labels = torch.tensor(df['target'].values.astype(np.float32))
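Before building datasets, a small hedged sanity check is to decode one encoded document back to text; the exact output depends on your corpus.
# confirm shapes and that the tokenizer round-trips sensibly
print(input_ids.shape)                                 # (num_docs, 512)
print(tokenizer.decode(input_ids[0][:20].tolist()))    # first 20 token ids of document 0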
With the corpus tokenized, we now prepare our data for analysis in PyTorch. The code below creates a TensorDataset composed of our features, attention masks, and labels, and then splits the data into train, validation, and test sets.
Note that BERT itself carries no task-specific head here and never consumes the labels directly; the CNN is the head we place on top of the network, and it is the part trained against the labels.
# prepare tensor data sets
def prepare_dataset(padded_tokens, attention_masks, target):
    # prepare target into np array
    target = np.array(target.values, dtype=np.int64).reshape(-1, 1)
    # create tensor data sets
    tensor_df = TensorDataset(padded_tokens, attention_masks, torch.from_numpy(target))
    # 80% of df
    train_size = int(0.8 * len(df))
    # 20% of df
    val_size = len(df) - train_size
    # 50% of validation
    test_size = int(val_size - 0.5 * val_size)
    # divide the dataset by randomly selecting samples
    train_dataset, val_dataset = random_split(tensor_df, [train_size, val_size])
    # divide validation by randomly selecting samples
    val_dataset, test_dataset = random_split(val_dataset, [test_size, val_size - test_size])

    return train_dataset, val_dataset, test_dataset

# create tensor data sets
train_dataset, val_dataset, test_dataset = prepare_dataset(input_ids,
                                                           attention_masks,
                                                           df['target'])
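A quick illustrative check confirms the resulting 80/10/10 split sizes:
# confirm split sizes
print(len(train_dataset), len(val_dataset), len(test_dataset))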
Since the corpus is imbalanced, we create a weighted sampler to help balance the class distribution as batches are fed through the data loaders.
# helper function to count target distribution inside tensor data sets
def target_count(tensor_dataset):
# set empty count containers
= 0
count0 = 0
count1 # set total container to turn into torch tensor
= []
total # for every item in the tensor data set
for i in tensor_dataset:
# if the target is equal to 0
if i[2].item() == 0:
+= 1
count0 # if the target is equal to 1
elif i[2].item() == 1:
+= 1
count1
total.append(count0)
total.append(count1)return torch.tensor(total)
# prepare weighted sampling for imbalanced classification
def create_sampler(target_tensor, tensor_dataset):
# generate class distributions [x, y]
= target_count(tensor_dataset)
class_sample_count # weight
= 1. / class_sample_count.float()
weight # produce weights for each observation in the data set
= torch.tensor([weight[t[2]] for t in tensor_dataset])
samples_weight # prepare sampler
= torch.utils.data.WeightedRandomSampler(weights=samples_weight,
sampler =len(samples_weight),
num_samples=True)
replacementreturn sampler
# create samplers for just the training set
= create_sampler(target_count(train_dataset), train_dataset)
train_sampler
# time function
def format_time(elapsed):
    '''
    Takes a time in seconds and returns a string hh:mm:ss
    '''
    # round to the nearest second
    elapsed_rounded = int(round((elapsed)))
    # format as hh:mm:ss
    return str(datetime.timedelta(seconds=elapsed_rounded))
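For example, 3661 seconds formats as '1:01:01':
# example usage of format_time
print(format_time(3661))  # 1:01:01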
Now we instantiate the data loaders.
# create DataLoaders with samplers
train_dataloader = DataLoader(train_dataset,
                              batch_size=8,
                              sampler=train_sampler,
                              shuffle=False)

valid_dataloader = DataLoader(val_dataset,
                              batch_size=8,
                              shuffle=True)

test_dataloader = DataLoader(test_dataset,
                             batch_size=8,
                             shuffle=True)
Here, we modify the previously used CNN class. We strip out the nn.Embedding layer since we are no longer providing a look-up table for embedding vectors; instead, we inject the embedding vectors from BERT directly into the CNN.
# Build Kim Yoon CNN
class KimCNN(nn.Module):
    def __init__(self, config):
        super().__init__()
        output_channel = config.output_channel  # number of kernels
        num_classes = config.num_classes        # number of targets to predict
        dropout = config.dropout                # dropout value
        embedding_dim = config.embedding_dim    # length of embedding dim

        ks = 3  # three conv nets here

        # input_channel = word embeddings at a value of 1; 3 for RGB images
        # here we feed the last 4 BERT layers, so input_channel = 4 (for a single embedding it would be 1)
        input_channel = 4

        # [3, 4, 5] = window height
        # padding = padding to account for height of search window
        # 3 convolutional nets
        self.conv1 = nn.Conv2d(input_channel, output_channel, (3, embedding_dim), padding=(2, 0), groups=4)
        self.conv2 = nn.Conv2d(input_channel, output_channel, (4, embedding_dim), padding=(3, 0), groups=4)
        self.conv3 = nn.Conv2d(input_channel, output_channel, (5, embedding_dim), padding=(4, 0), groups=4)

        # apply dropout
        self.dropout = nn.Dropout(dropout)

        # fully connected layer for classification
        # 3x conv nets * output channel
        self.fc1 = nn.Linear(ks * output_channel, num_classes)

    def forward(self, x, **kwargs):
        #x = x.unsqueeze(1)  # get another dimension at first index pos
        # squeeze to get size; (batch, channel_output, ~=sent_len) * ks
        x = [F.relu(self.conv1(x)).squeeze(3), F.relu(self.conv2(x)).squeeze(3), F.relu(self.conv3(x)).squeeze(3)]
        # max-over-time pooling; (batch, channel_output) * ks
        x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x]
        # concat results; (batch, channel_output * ks)
        x = torch.cat(x, 1)
        # add dropout
        x = self.dropout(x)
        # generate logits (batch, target_size)
        logit = self.fc1(x)
        return logit
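Before wiring the CNN to BERT, a quick smoke test with random tensors is useful; the `_TestConfig` below is a hypothetical stand-in that mirrors the config values used later (16 kernels, 768 dims, 2 classes).
# smoke test the CNN head on CPU with a fake "last four BERT layers" tensor
class _TestConfig:  # hypothetical config, for illustration only
    num_classes = 2
    output_channel = 16
    embedding_dim = 768
    dropout = 0.4

test_cnn = KimCNN(_TestConfig())
fake_layers = torch.randn(2, 4, 512, 768)  # (batch, channels, seq_len, embed_dim)
print(test_cnn(fake_layers).shape)          # torch.Size([2, 2]) -- one logit pair per document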
Now, we prepare functions to train, validate, and test our data.
def train(model, dataloader, optimizer):
    # capture time
    total_t0 = time.time()

    # perform one full pass over the training set
    print("")
    print('======== Epoch {:} / {:} ========'.format(epoch + 1, epochs))
    print('Training...')

    # reset total loss for epoch
    train_total_loss = 0
    total_train_f1 = 0

    # put both models into training mode
    model.train()
    kim_model.train()

    # for each batch of training data...
    for step, batch in enumerate(dataloader):
        # progress update every 40 batches
        if step % 40 == 0 and not step == 0:
            # report progress
            print(' Batch {:>5,} of {:>5,}.'.format(step, len(dataloader)))

        # Unpack this training batch from our dataloader:
        #
        # As we unpack the batch, we'll also copy each tensor to the GPU
        #
        # `batch` contains three pytorch tensors:
        # [0]: input ids
        # [1]: attention masks
        # [2]: labels
        b_input_ids = batch[0].cuda()
        b_input_mask = batch[1].cuda()
        b_labels = batch[2].cuda().long()

        # clear previously calculated gradients
        optimizer.zero_grad()

        # runs the forward pass with autocasting
        with autocast():
            # forward propagation (evaluate model on training batch)
            outputs = model(input_ids=b_input_ids, attention_mask=b_input_mask)

            # get hidden layers
            hidden_layers = outputs[2]

            # stack the layers
            hidden_layers = torch.stack(hidden_layers, dim=1)

            # get the last 4 layers
            hidden_layers = hidden_layers[:, -4:]

            logits = kim_model(hidden_layers)

            loss = criterion(logits.view(-1, 2), b_labels.view(-1))

            # sum the training loss over all batches for average loss at end
            # loss is a tensor containing a single value
            train_total_loss += loss.item()

        # Scales loss. Calls backward() on scaled loss to create scaled gradients.
        # Backward passes under autocast are not recommended.
        # Backward ops run in the same dtype autocast chose for corresponding forward ops.
        scaler.scale(loss).backward()

        # scaler.step() first unscales the gradients of the optimizer's assigned params.
        # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
        # otherwise, optimizer.step() is skipped.
        scaler.step(optimizer)

        # updates the scale for next iteration
        scaler.update()

        # update the scheduler
        scheduler.step()

        # calculate preds
        _, predicted = torch.max(logits, 1)

        # move logits and labels to CPU
        predicted = predicted.detach().cpu().numpy()
        y_true = b_labels.detach().cpu().numpy()

        # calculate f1
        total_train_f1 += f1_score(predicted, y_true,
                                   average='weighted',
                                   labels=np.unique(predicted))

    # calculate the average loss over all of the batches
    avg_train_loss = train_total_loss / len(dataloader)

    # calculate the average f1 over all of the batches
    avg_train_f1 = total_train_f1 / len(dataloader)

    # training time end
    training_time = format_time(time.time() - total_t0)

    # record all statistics from this epoch
    training_stats.append(
        {
            'Train Loss': avg_train_loss,
            'Train F1': avg_train_f1,
            'Train Time': training_time
        }
    )

    # print result summaries
    print("")
    print("summary results")
    print("epoch | trn loss | trn f1 | trn time ")
    print(f"{epoch+1:5d} | {avg_train_loss:.5f} | {avg_train_f1:.5f} | {training_time:}")

    #torch.cuda.empty_cache()

    return None
def validating(model, dataloader):
    # capture validation time
    total_t0 = time.time()

    # After the completion of each training epoch, measure our performance on
    # our validation set.
    print("")
    print("Running Validation...")

    # put both models in evaluation mode
    model.eval()
    kim_model.eval()

    # track variables
    total_valid_accuracy = 0
    total_valid_loss = 0
    total_valid_f1 = 0
    total_valid_recall = 0
    total_valid_precision = 0
    total_bert_valid_loss = 0

    # evaluate data for one epoch
    for batch in dataloader:
        # Unpack this validation batch from our dataloader:
        # `batch` contains three pytorch tensors:
        # [0]: input ids
        # [1]: attention masks
        # [2]: labels
        b_input_ids = batch[0].cuda()
        b_input_mask = batch[1].cuda()
        b_labels = batch[2].cuda().long()

        # tell pytorch not to bother calculating gradients
        with torch.no_grad():
            # forward propagation (evaluate model on validation batch)
            outputs = model(input_ids=b_input_ids, attention_mask=b_input_mask)

            # get hidden layers
            hidden_layers = outputs[2]

            # stack the layers
            hidden_layers = torch.stack(hidden_layers, dim=1)

            # get the last 4 layers
            hidden_layers = hidden_layers[:, -4:]

            logits = kim_model(hidden_layers)

            loss = criterion(logits.view(-1, 2), b_labels.view(-1))

        # accumulate validation loss
        total_valid_loss += loss.item()

        # calculate preds
        _, predicted = torch.max(logits, 1)

        # move logits and labels to CPU
        predicted = predicted.detach().cpu().numpy()
        y_true = b_labels.detach().cpu().numpy()

        # calculate f1
        total_valid_f1 += f1_score(predicted, y_true,
                                   average='weighted',
                                   labels=np.unique(predicted))

        # calculate accuracy
        total_valid_accuracy += accuracy_score(predicted, y_true)

        # calculate precision
        total_valid_precision += precision_score(predicted, y_true,
                                                 average='weighted',
                                                 labels=np.unique(predicted))

        # calculate recall
        total_valid_recall += recall_score(predicted, y_true,
                                           average='weighted',
                                           labels=np.unique(predicted))

    # report final accuracy of validation run
    avg_accuracy = total_valid_accuracy / len(dataloader)

    # report final f1 of validation run
    global avg_val_f1
    avg_val_f1 = total_valid_f1 / len(dataloader)

    # report final precision of validation run
    avg_precision = total_valid_precision / len(dataloader)

    # report final recall of validation run
    avg_recall = total_valid_recall / len(dataloader)

    # calculate the average loss over all of the batches
    global avg_val_loss
    avg_val_loss = total_valid_loss / len(dataloader)

    # capture end validation time
    training_time = format_time(time.time() - total_t0)

    # record all statistics from this epoch
    valid_stats.append(
        {
            'Val Loss': avg_val_loss,
            'Val Accur.': avg_accuracy,
            'Val precision': avg_precision,
            'Val recall': avg_recall,
            'Val F1': avg_val_f1,
            'Val Time': training_time
        }
    )

    # print result summaries
    print("")
    print("summary results")
    print("epoch | val loss | val f1 | val time")
    print(f"{epoch+1:5d} | {avg_val_loss:.5f} | {avg_val_f1:.5f} | {training_time:}")

    return None
def testing(model, dataloader):
print("")
print("Running Testing...")
# capture test time
= time.time()
total_t0
# put both models in evaluation mode
eval()
model.eval()
kim_model.
# track variables
= 0
total_test_accuracy = 0
total_test_loss = 0
total_test_f1 = 0
total_test_recall = 0
total_test_precision
# evaluate data for one epoch
for batch in dataloader:
# Unpack this training batch from our dataloader:
# `batch` contains three pytorch tensors:
# [0]: input ids
# [1]: attention masks
# [2]: labels
= batch[0].cuda()
b_input_ids = batch[1].cuda()
b_input_mask = batch[2].cuda().long()
b_labels
# tell pytorch not to bother calculating gradients
with torch.no_grad():
# forward propagation (evaluate model on training batch)
= model(input_ids=b_input_ids, attention_mask=b_input_mask)
outputs
= outputs[2] # get hidden layers
hidden_layers
= torch.stack(hidden_layers, dim=1) # stack the layers
hidden_layers
= hidden_layers[:, -4:] # get the last 4 layers
hidden_layers
= kim_model(hidden_layers)
logits
= criterion(logits.view(-1, 2), b_labels.view(-1))
loss
# accumulate validation loss
+= loss.item()
total_test_loss
# calculate preds
= torch.max(logits, 1)
_, predicted
# move logits and labels to CPU
= predicted.detach().cpu().numpy()
predicted = b_labels.detach().cpu().numpy()
y_true
# calculate f1
+= f1_score(predicted, y_true,
total_test_f1 ='weighted',
average=np.unique(predicted))
labels
# calculate accuracy
+= accuracy_score(predicted, y_true)
total_test_accuracy
# calculate precision
+= precision_score(predicted, y_true,
total_test_precision ='weighted',
average=np.unique(predicted))
labels
# calculate recall
+= recall_score(predicted, y_true,
total_test_recall ='weighted',
average=np.unique(predicted))
labels
# report final accuracy of test run
= total_test_accuracy / len(dataloader)
avg_accuracy
# report final f1 of test run
= total_test_f1 / len(dataloader)
avg_test_f1
# report final f1 of test run
= total_test_precision / len(dataloader)
avg_precision
# report final f1 of test run
= total_test_recall / len(dataloader)
avg_recall
# calculate the average loss over all of the batches.
= total_test_loss / len(dataloader)
avg_test_loss
# capture end testing time
= format_time(time.time() - total_t0)
training_time
# Record all statistics from this epoch.
test_stats.append(
{'Test Loss': avg_test_loss,
'Test Accur.': avg_accuracy,
'Test precision': avg_precision,
'Test recall': avg_recall,
'Test F1': avg_test_f1,
'Test Time': training_time
}
)# print result summaries
print("")
print("summary results")
print("epoch | test loss | test f1 | test time")
print(f"{epoch+1:5d} | {avg_test_loss:.5f} | {avg_test_f1:.5f} | {training_time:}")
return None
Now we instantiate our models and attach them to the GPU. A few other preparatory objects are also created: the loss criterion, the number of epochs, the optimizer, and the learning-rate scheduler.
# instantiate BERT model with hidden states
model = BertModel.from_pretrained('bert-base-uncased', output_hidden_states=True).cuda()

# instantiate CNN config
class config:
def __init__(self):
        self.num_classes = 2        # binary
        self.output_channel = 16    # number of kernels
        self.embedding_dim = 768    # embed dimension
        self.dropout = 0.4          # dropout value
        return None

# create config
config1 = config()

# instantiate CNN
kim_model = KimCNN(config1).cuda()

# set loss
criterion = nn.CrossEntropyLoss()

# set number of epochs
epochs = 4

# only train the last 4 layers; saves ~600mb of GPU mem and 30s of compute
BERT_parameters = []
allowed_layers = [11, 10, 9, 8]

for name, param in model.named_parameters():
    for layer_num in allowed_layers:
        layer_num = str(layer_num)
        if ".{}.".format(layer_num) in name:
            BERT_parameters.append(param)

# set optimizer
optimizer = AdamW([{'params': BERT_parameters, 'lr': 2e-5}], weight_decay=1.0)

# set LR scheduler
total_steps = len(train_dataloader) * epochs
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=0,
                                            num_training_steps=total_steps)

# create gradient scaler for mixed precision
scaler = GradScaler()
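As a rough check on the memory and compute claim above, we can compare how many BERT parameters were selected for training against the total (illustrative; exact counts depend on the checkpoint):
# compare trainable (last 4 encoder layers) vs. total BERT parameters
trainable = sum(p.numel() for p in BERT_parameters)
total = sum(p.numel() for p in model.parameters())
print(f"training {trainable:,} of {total:,} BERT parameters")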
Finally, we are ready to train. Two containers are created to store the results of each training and validation epoch, along with a tracker for the best validation loss.
# create training result storage
training_stats = []
valid_stats = []
best_valid_loss = float('inf')

# for each epoch
for epoch in range(epochs):
    # train
    train(model, train_dataloader, optimizer)
    # validate
    validating(model, valid_dataloader)
    # check validation loss
    if valid_stats[epoch]['Val Loss'] < best_valid_loss:
        best_valid_loss = valid_stats[epoch]['Val Loss']
        # save best model for use later
        torch.save(model.state_dict(), 'bert-cnn-model1.pt')  # torch save
        model_to_save = model.module if hasattr(model, 'module') else model
        model_to_save.save_pretrained('./model_save/bert-cnn/')  # transformers save
        tokenizer.save_pretrained('./model_save/bert-cnn/')  # transformers save
##
## ======== Epoch 1 / 4 ========
## Training...
## Batch 40 of 1,005.
## Batch 80 of 1,005.
## Batch 120 of 1,005.
## Batch 160 of 1,005.
## Batch 200 of 1,005.
## Batch 240 of 1,005.
## Batch 280 of 1,005.
## Batch 320 of 1,005.
## Batch 360 of 1,005.
## Batch 400 of 1,005.
## Batch 440 of 1,005.
## Batch 480 of 1,005.
## Batch 520 of 1,005.
## Batch 560 of 1,005.
## Batch 600 of 1,005.
## Batch 640 of 1,005.
## Batch 680 of 1,005.
## Batch 720 of 1,005.
## Batch 760 of 1,005.
## Batch 800 of 1,005.
## Batch 840 of 1,005.
## Batch 880 of 1,005.
## Batch 920 of 1,005.
## Batch 960 of 1,005.
## Batch 1,000 of 1,005.
##
## summary results
## epoch | trn loss | trn f1 | trn time
## 1 | 0.39960 | 0.83102 | 0:12:21
##
## Running Validation...
##
## summary results
## epoch | val loss | val f1 | val time
## 1 | 0.27511 | 0.84472 | 0:00:17
## ('./model_save/bert-cnn/vocab.txt', './model_save/bert-cnn/special_tokens_map.json', './model_save/bert-cnn/added_tokens.json')
##
## ======== Epoch 2 / 4 ========
## Training...
## Batch 40 of 1,005.
## Batch 80 of 1,005.
## Batch 120 of 1,005.
## Batch 160 of 1,005.
## Batch 200 of 1,005.
## Batch 240 of 1,005.
## Batch 280 of 1,005.
## Batch 320 of 1,005.
## Batch 360 of 1,005.
## Batch 400 of 1,005.
## Batch 440 of 1,005.
## Batch 480 of 1,005.
## Batch 520 of 1,005.
## Batch 560 of 1,005.
## Batch 600 of 1,005.
## Batch 640 of 1,005.
## Batch 680 of 1,005.
## Batch 720 of 1,005.
## Batch 760 of 1,005.
## Batch 800 of 1,005.
## Batch 840 of 1,005.
## Batch 880 of 1,005.
## Batch 920 of 1,005.
## Batch 960 of 1,005.
## Batch 1,000 of 1,005.
##
## summary results
## epoch | trn loss | trn f1 | trn time
## 2 | 0.29237 | 0.88316 | 0:12:24
##
## Running Validation...
##
## summary results
## epoch | val loss | val f1 | val time
## 2 | 0.30805 | 0.82737 | 0:00:17
##
## ======== Epoch 3 / 4 ========
## Training...
## Batch 40 of 1,005.
## Batch 80 of 1,005.
## Batch 120 of 1,005.
## Batch 160 of 1,005.
## Batch 200 of 1,005.
## Batch 240 of 1,005.
## Batch 280 of 1,005.
## Batch 320 of 1,005.
## Batch 360 of 1,005.
## Batch 400 of 1,005.
## Batch 440 of 1,005.
## Batch 480 of 1,005.
## Batch 520 of 1,005.
## Batch 560 of 1,005.
## Batch 600 of 1,005.
## Batch 640 of 1,005.
## Batch 680 of 1,005.
## Batch 720 of 1,005.
## Batch 760 of 1,005.
## Batch 800 of 1,005.
## Batch 840 of 1,005.
## Batch 880 of 1,005.
## Batch 920 of 1,005.
## Batch 960 of 1,005.
## Batch 1,000 of 1,005.
##
## summary results
## epoch | trn loss | trn f1 | trn time
## 3 | 0.27105 | 0.89014 | 0:12:11
##
## Running Validation...
##
## summary results
## epoch | val loss | val f1 | val time
## 3 | 0.30127 | 0.84481 | 0:00:17
##
## ======== Epoch 4 / 4 ========
## Training...
## Batch 40 of 1,005.
## Batch 80 of 1,005.
## Batch 120 of 1,005.
## Batch 160 of 1,005.
## Batch 200 of 1,005.
## Batch 240 of 1,005.
## Batch 280 of 1,005.
## Batch 320 of 1,005.
## Batch 360 of 1,005.
## Batch 400 of 1,005.
## Batch 440 of 1,005.
## Batch 480 of 1,005.
## Batch 520 of 1,005.
## Batch 560 of 1,005.
## Batch 600 of 1,005.
## Batch 640 of 1,005.
## Batch 680 of 1,005.
## Batch 720 of 1,005.
## Batch 760 of 1,005.
## Batch 800 of 1,005.
## Batch 840 of 1,005.
## Batch 880 of 1,005.
## Batch 920 of 1,005.
## Batch 960 of 1,005.
## Batch 1,000 of 1,005.
##
## summary results
## epoch | trn loss | trn f1 | trn time
## 4 | 0.25257 | 0.89834 | 0:11:23
##
## Running Validation...
##
## summary results
## epoch | val loss | val f1 | val time
## 4 | 0.27479 | 0.85058 | 0:00:16
## ('./model_save/bert-cnn/vocab.txt', './model_save/bert-cnn/special_tokens_map.json', './model_save/bert-cnn/added_tokens.json')
##
## C:\Users\Andrew\Anaconda3\envs\my_ml\lib\site-packages\sklearn\metrics\classification.py:1437: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 in labels with no predicted samples.
## 'precision', 'predicted', average, warn_for)
## C:\Users\Andrew\Anaconda3\envs\my_ml\lib\site-packages\sklearn\metrics\classification.py:1437: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples.
## 'precision', 'predicted', average, warn_for)
After training, we organize the results nicely in pandas.
# organize results
pd.set_option('precision', 3)
df_train_stats = pd.DataFrame(data=training_stats)
df_valid_stats = pd.DataFrame(data=valid_stats)
df_stats = pd.concat([df_train_stats, df_valid_stats], axis=1)
df_stats.insert(0, 'Epoch', range(1, len(df_stats)+1))
df_stats = df_stats.set_index('Epoch')
df_stats
## Train Loss Train F1 Train Time ... Val recall Val F1 Val Time
## Epoch ...
## 1 0.400 0.831 0:12:21 ... 0.861 0.845 0:00:17
## 2 0.292 0.883 0:12:24 ... 0.850 0.827 0:00:17
## 3 0.271 0.890 0:12:11 ... 0.862 0.845 0:00:17
## 4 0.253 0.898 0:11:23 ... 0.867 0.851 0:00:16
##
## [4 rows x 9 columns]
And lastly, we run our final test:
# test the model
test_stats = []
model.load_state_dict(torch.load('bert-cnn-model1.pt'))
## <All keys matched successfully>
testing(model, test_dataloader)
##
## Running Testing...
##
## summary results
## epoch | test loss | test f1 | test time
## 4 | 0.31259 | 0.83967 | 0:00:16
##
## C:\Users\Andrew\Anaconda3\envs\my_ml\lib\site-packages\sklearn\metrics\classification.py:1437: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 in labels with no predicted samples.
## 'precision', 'predicted', average, warn_for)
## C:\Users\Andrew\Anaconda3\envs\my_ml\lib\site-packages\sklearn\metrics\classification.py:1437: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples.
## 'precision', 'predicted', average, warn_for)
df_test_stats = pd.DataFrame(data=test_stats)
df_test_stats
## Test Loss Test Accur. Test precision Test recall Test F1 Test Time
## 0 0.313 0.854 0.862 0.854 0.84 0:00:16
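Since matplotlib and seaborn were imported at the top but not yet used, a short hedged sketch for visualizing the loss curves from df_stats might look like this (column names follow the stats dictionaries above):
# plot training vs. validation loss per epoch
plt.figure(figsize=(8, 5))
sns.lineplot(x=df_stats.index, y=df_stats['Train Loss'], label='train loss')
sns.lineplot(x=df_stats.index, y=df_stats['Val Loss'], label='validation loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('BERT-CNN loss by epoch')
plt.show()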
The results show a slight improvement over our standard BERT model at the cost of 4-5x the training time.
Kim, Yoon. “Convolutional neural networks for sentence classification.” arXiv preprint arXiv:1408.5882 (2014).
Devlin, Jacob, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. “Bert: Pre-training of deep bidirectional transformers for language understanding.” arXiv preprint arXiv:1810.04805 (2018).