diff --git a/src/data.py b/src/data.py
index 4230fd2..112ef89 100644
--- a/src/data.py
+++ b/src/data.py
@@ -8,6 +8,7 @@
 import random
 import json
 import numpy as np
+import torch.nn.functional as F
 
 class Dataset(torch.utils.data.Dataset):
     def __init__(self,
@@ -70,28 +71,45 @@ def sort_data(self):
     def get_example(self, index):
         return self.data[index]
 
-def encode_passages(batch_text_passages, tokenizer, max_length):
+def get_padded_tensor(ten_list, value = 0):
+    max_len = max([x.shape[1] for x in ten_list])
+    padded_list = []
+    for tensor in ten_list:
+        if(tensor.shape[1] < max_len):
+            tensor = F.pad(input=tensor, pad=(0, max_len - tensor.shape[1], 0, 0), mode='constant', value=value)
+        padded_list.append(tensor)
+    return padded_list
+
+def encode_passages(batch_text_passages, tokenizer, max_length, pad_to_max_length):
+    # if every passage is padded to max_length, don't also pad to the longest passage, and vice versa.
+    padding = not pad_to_max_length
+
     passage_ids, passage_masks = [], []
     for k, text_passages in enumerate(batch_text_passages):
         p = tokenizer.batch_encode_plus(
             text_passages,
             max_length=max_length,
-            pad_to_max_length=True,
+            pad_to_max_length=pad_to_max_length,
+            padding=padding,
             return_tensors='pt',
             truncation=True
         )
-        passage_ids.append(p['input_ids'][None])
-        passage_masks.append(p['attention_mask'][None])
+        passage_ids.append(p['input_ids'])
+        passage_masks.append(p['attention_mask'])
+
+    passage_ids = get_padded_tensor(passage_ids, value=0)
+    passage_masks = get_padded_tensor(passage_masks, value=0)
 
-    passage_ids = torch.cat(passage_ids, dim=0)
-    passage_masks = torch.cat(passage_masks, dim=0)
+    passage_ids = torch.stack(passage_ids)
+    passage_masks = torch.stack(passage_masks)
     return passage_ids, passage_masks.bool()
 
 class Collator(object):
-    def __init__(self, text_maxlength, tokenizer, answer_maxlength=20):
+    def __init__(self, text_maxlength, tokenizer, answer_maxlength=20, pad_to_max_length=False):
         self.tokenizer = tokenizer
         self.text_maxlength = text_maxlength
         self.answer_maxlength = answer_maxlength
+        self.pad_to_max_length = pad_to_max_length
 
     def __call__(self, batch):
         assert(batch[0]['target'] != None)
@@ -115,7 +133,8 @@ def append_question(example):
         text_passages = [append_question(example) for example in batch]
         passage_ids, passage_masks = encode_passages(text_passages,
                                                      self.tokenizer,
-                                                     self.text_maxlength)
+                                                     self.text_maxlength,
+                                                     self.pad_to_max_length)
 
         return (index, target_ids, target_mask, passage_ids, passage_masks)
 
diff --git a/src/options.py b/src/options.py
index ce03207..3757071 100644
--- a/src/options.py
+++ b/src/options.py
@@ -89,6 +89,7 @@ def initialize_parser(self):
                         help='save model every steps during training')
         self.parser.add_argument('--eval_print_freq', type=int, default=1000,
                         help='print intermdiate results of evaluation every steps')
+        self.parser.add_argument('--pad_to_max_length', action='store_true', help='If set, pad every passage to the provided max length; otherwise pad only to the length of the longest passage in the batch.')
 
 
     def print_options(self, opt):
diff --git a/test_reader.py b/test_reader.py
index e53434e..25f8cbc 100644
--- a/test_reader.py
+++ b/test_reader.py
@@ -35,7 +35,7 @@ def evaluate(model, dataset, dataloader, tokenizer, opt):
     with torch.no_grad():
         for i, batch in enumerate(dataloader):
             (idx, _, _, context_ids, context_mask) = batch
-
+
             if opt.write_crossattention_scores:
                 model.reset_score_storage()
 
@@ -101,7 +101,7 @@ def evaluate(model, dataset, dataloader, tokenizer, opt):
 
     tokenizer = transformers.T5Tokenizer.from_pretrained('t5-base', return_dict=False)
 
-    collator_function = src.data.Collator(opt.text_maxlength, tokenizer)
+    collator_function = src.data.Collator(opt.text_maxlength, tokenizer, pad_to_max_length=opt.pad_to_max_length)
     eval_examples = src.data.load_data(
         opt.eval_data,
         global_rank=opt.global_rank, #use the global rank and world size attibutes to split the eval set on multiple gpus
diff --git a/train_reader.py b/train_reader.py
index f760992..3f288e6 100644
--- a/train_reader.py
+++ b/train_reader.py
@@ -154,7 +154,7 @@ def evaluate(model, dataset, tokenizer, collator, opt):
 
     #load data
     tokenizer = transformers.T5Tokenizer.from_pretrained(model_name)
-    collator = src.data.Collator(opt.text_maxlength, tokenizer, answer_maxlength=opt.answer_maxlength)
+    collator = src.data.Collator(opt.text_maxlength, tokenizer, answer_maxlength=opt.answer_maxlength, pad_to_max_length=opt.pad_to_max_length)
 
     # use golbal rank and world size to split the eval set on multiple gpus
     train_examples = src.data.load_data(
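
Note (not part of the patch): a minimal sketch of how the new flag changes the padded shapes, assuming the patch above is applied and a transformers version that still accepts pad_to_max_length in batch_encode_plus. The passage strings below are made up for illustration only.

    # Compare dynamic padding (new default) with padding everything to max_length.
    import transformers
    import src.data

    tokenizer = transformers.T5Tokenizer.from_pretrained('t5-base')

    # Two examples, each with two "question + passage" strings of different lengths.
    batch_text_passages = [
        ["question: who wrote Hamlet? context: Shakespeare wrote Hamlet.",
         "question: who wrote Hamlet? context: Hamlet is a tragedy written by William Shakespeare around 1600."],
        ["question: capital of France? context: Paris.",
         "question: capital of France? context: Paris is the capital of France."],
    ]

    # New default behaviour: pad only to the longest passage in the batch.
    ids_dyn, masks_dyn = src.data.encode_passages(batch_text_passages, tokenizer,
                                                  max_length=200, pad_to_max_length=False)
    # Previous behaviour, kept behind --pad_to_max_length: pad everything to max_length.
    ids_max, masks_max = src.data.encode_passages(batch_text_passages, tokenizer,
                                                  max_length=200, pad_to_max_length=True)

    print(ids_dyn.shape)  # (2, 2, longest_passage_length), shorter than 200
    print(ids_max.shape)  # (2, 2, 200)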