35 changes: 27 additions & 8 deletions src/data.py
@@ -8,6 +8,7 @@
 import random
 import json
 import numpy as np
+import torch.nn.functional as F
 
 class Dataset(torch.utils.data.Dataset):
     def __init__(self,
@@ -70,28 +71,45 @@ def sort_data(self):
     def get_example(self, index):
         return self.data[index]
 
-def encode_passages(batch_text_passages, tokenizer, max_length):
+def get_padded_tensor(ten_list, value = 0):
+    max_len = max([x.shape[1] for x in ten_list])
+    padded_list = []
+    for tensor in ten_list:
+        if(tensor.shape[1] < max_len):
+            tensor = F.pad(input=tensor, pad=(0, max_len - tensor.shape[1], 0, 0), mode='constant', value=value)
+        padded_list.append(tensor)
+    return padded_list
+
+def encode_passages(batch_text_passages, tokenizer, max_length, pad_to_max_length):
+    # if padding to the max length, no padding to max passage length, and vice versa.
+    padding = not pad_to_max_length
+
     passage_ids, passage_masks = [], []
     for k, text_passages in enumerate(batch_text_passages):
         p = tokenizer.batch_encode_plus(
             text_passages,
             max_length=max_length,
-            pad_to_max_length=True,
+            pad_to_max_length=pad_to_max_length,
+            padding=padding,
             return_tensors='pt',
             truncation=True
         )
-        passage_ids.append(p['input_ids'][None])
-        passage_masks.append(p['attention_mask'][None])
+        passage_ids.append(p['input_ids'])
+        passage_masks.append(p['attention_mask'])
+
+    passage_ids = get_padded_tensor(passage_ids, value=0)
+    passage_masks = get_padded_tensor(passage_masks, value=0)
 
-    passage_ids = torch.cat(passage_ids, dim=0)
-    passage_masks = torch.cat(passage_masks, dim=0)
+    passage_ids = torch.stack(passage_ids)
+    passage_masks = torch.stack(passage_masks)
     return passage_ids, passage_masks.bool()
 
 class Collator(object):
-    def __init__(self, text_maxlength, tokenizer, answer_maxlength=20):
+    def __init__(self, text_maxlength, tokenizer, answer_maxlength=20, pad_to_max_length=False):
         self.tokenizer = tokenizer
         self.text_maxlength = text_maxlength
         self.answer_maxlength = answer_maxlength
+        self.pad_to_max_length = pad_to_max_length
 
     def __call__(self, batch):
         assert(batch[0]['target'] != None)
@@ -115,7 +133,8 @@ def append_question(example):
         text_passages = [append_question(example) for example in batch]
         passage_ids, passage_masks = encode_passages(text_passages,
                                                      self.tokenizer,
-                                                     self.text_maxlength)
+                                                     self.text_maxlength,
+                                                     self.pad_to_max_length)
 
         return (index, target_ids, target_mask, passage_ids, passage_masks)
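For reference, here is a minimal, self-contained sketch (not part of the PR itself) of what the new dynamic-padding path in encode_passages does: when pad_to_max_length is off, each example's passages are padded only to that example's longest sequence, so the per-example tensors can have different widths, and get_padded_tensor right-pads them to a common width before they are stacked into one batch tensor. The shapes below are illustrative only.

import torch
import torch.nn.functional as F

def get_padded_tensor(ten_list, value=0):
    # right-pad the last dimension of every tensor to the longest one in the list
    max_len = max(x.shape[1] for x in ten_list)
    padded_list = []
    for tensor in ten_list:
        if tensor.shape[1] < max_len:
            tensor = F.pad(tensor, pad=(0, max_len - tensor.shape[1], 0, 0),
                           mode='constant', value=value)
        padded_list.append(tensor)
    return padded_list

# two examples whose passages tokenized to different lengths
ids_a = torch.ones(10, 7, dtype=torch.long)  # 10 passages, 7 tokens each
ids_b = torch.ones(10, 5, dtype=torch.long)  # 10 passages, 5 tokens each

batch = torch.stack(get_padded_tensor([ids_a, ids_b], value=0))
print(batch.shape)  # torch.Size([2, 10, 7]): padded to the longest passage, not to max_length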
1 change: 1 addition & 0 deletions src/options.py
@@ -89,6 +89,7 @@ def initialize_parser(self):
                                  help='save model every <save_freq> steps during training')
         self.parser.add_argument('--eval_print_freq', type=int, default=1000,
                                  help='print intermdiate results of evaluation every <eval_print_freq> steps')
+        self.parser.add_argument('--pad_to_max_length', action='store_true', help='Should the passages all be padded to the max length provided (True value here), or to the max length of the longest passage? (False value here).')
 
 
     def print_options(self, opt):
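A quick, hypothetical illustration (not part of the PR) of how the new store_true flag behaves when parsed: leaving --pad_to_max_length off keeps the new dynamic padding as the default, while passing it restores padding to the full text_maxlength.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--pad_to_max_length', action='store_true',
                    help='pad every passage to text_maxlength instead of to the longest passage in the batch')

opt = parser.parse_args([])                       # flag omitted: False, dynamic padding
print(opt.pad_to_max_length)                      # False
opt = parser.parse_args(['--pad_to_max_length'])  # flag present: True, fixed-length padding
print(opt.pad_to_max_length)                      # True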
4 changes: 2 additions & 2 deletions test_reader.py
@@ -35,7 +35,7 @@ def evaluate(model, dataset, dataloader, tokenizer, opt):
     with torch.no_grad():
         for i, batch in enumerate(dataloader):
             (idx, _, _, context_ids, context_mask) = batch
-
+
             if opt.write_crossattention_scores:
                 model.reset_score_storage()
 
@@ -101,7 +101,7 @@ def evaluate(model, dataset, dataloader, tokenizer, opt):

     tokenizer = transformers.T5Tokenizer.from_pretrained('t5-base', return_dict=False)
 
-    collator_function = src.data.Collator(opt.text_maxlength, tokenizer)
+    collator_function = src.data.Collator(opt.text_maxlength, tokenizer, pad_to_max_length=opt.pad_to_max_length)
     eval_examples = src.data.load_data(
         opt.eval_data,
         global_rank=opt.global_rank, #use the global rank and world size attibutes to split the eval set on multiple gpus
2 changes: 1 addition & 1 deletion train_reader.py
@@ -154,7 +154,7 @@ def evaluate(model, dataset, tokenizer, collator, opt):

     #load data
     tokenizer = transformers.T5Tokenizer.from_pretrained(model_name)
-    collator = src.data.Collator(opt.text_maxlength, tokenizer, answer_maxlength=opt.answer_maxlength)
+    collator = src.data.Collator(opt.text_maxlength, tokenizer, answer_maxlength=opt.answer_maxlength, pad_to_max_length=opt.pad_to_max_length)
 
     # use golbal rank and world size to split the eval set on multiple gpus
     train_examples = src.data.load_data(
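Putting the pieces together, a hedged end-to-end sketch of how the flag travels from the command line into batching. Only the Collator and load_data calls are taken from this diff; the option values, the data path, the Dataset constructor, and the DataLoader wiring are assumptions made for illustration, not part of the PR.

import argparse
import torch
import transformers
import src.data

# placeholder options; in the real scripts these come from src/options.py
opt = argparse.Namespace(text_maxlength=200, answer_maxlength=20,
                         pad_to_max_length=False, per_gpu_batch_size=2,
                         n_context=10, global_rank=0, world_size=1,
                         train_data='open_domain_data/NQ/train.json')  # hypothetical path

tokenizer = transformers.T5Tokenizer.from_pretrained('t5-base')
collator = src.data.Collator(opt.text_maxlength, tokenizer,
                             answer_maxlength=opt.answer_maxlength,
                             pad_to_max_length=opt.pad_to_max_length)

train_examples = src.data.load_data(opt.train_data,
                                    global_rank=opt.global_rank,
                                    world_size=opt.world_size)
train_dataset = src.data.Dataset(train_examples, opt.n_context)  # assumed constructor
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=opt.per_gpu_batch_size,
                                           shuffle=True,
                                           collate_fn=collator)

# with pad_to_max_length=False, passage_ids.shape[-1] is the longest passage in the
# batch; with the flag set it would always be opt.text_maxlength
index, target_ids, target_mask, passage_ids, passage_masks = next(iter(train_loader))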