torchtext BucketIterator минимальное заполнение

Я пытаюсь использовать функцию BucketIterator.splits в torchtext для загрузки данных из CSV-файлов для использования в CNN. Все работает нормально, если у меня нет пакета, в котором самое длинное предложение короче, чем самый большой размер фильтра.

В моем примере у меня есть фильтры размеров 3, 4 и 5, поэтому, если в самом длинном предложении нет хотя бы 5 слов, я получаю сообщение об ошибке. Есть ли способ позволить BucketIterator динамически устанавливать заполнение для пакетов, а также устанавливать минимальную длину заполнения?

Это мой код, который я использую для своего BucketIterator:

train_iter, val_iter, test_iter = BucketIterator.splits((train, val, test), sort_key=lambda x: len(x.text), batch_size=batch_size, repeat=False, device=device)

Я надеюсь, есть способ установить минимальную длину для sort_key или что-то в этом роде?

Я пробовал это, но это не работает:

FILTER_SIZES = [3,4,5]
train_iter, val_iter, test_iter = BucketIterator.splits((train, val, test), sort_key=lambda x: len(x.text) if len(x.text) >= FILTER_SIZES[-1] else FILTER_SIZES[-1], batch_size=batch_size, repeat=False, device=device) 

person paul41    schedule 09.07.2018    source источник


Ответы (2)


Я просмотрел исходный код torchtext, чтобы лучше понять, что делает sort_key, и понял, почему моя первоначальная идея не сработала.

Я не уверен, лучшее это решение или нет, но я нашел решение, которое работает. Я создал функцию токенизатора, которая дополняет текст, если он короче самой длинной длины фильтра, а затем создает оттуда BucketIterator.

FILTER_SIZES = [3,4,5]
spacy_en = spacy.load('en')

def tokenizer(text):
    token = [t.text for t in spacy_en.tokenizer(text)]
    if len(token) < FILTER_SIZES[-1]:
        for i in range(0, FILTER_SIZES[-1] - len(token)):
            token.append('<PAD>')
    return token

TEXT = Field(sequential=True, tokenize=tokenizer, lower=True, tensor_type=torch.cuda.LongTensor)

train_iter, val_iter, test_iter = BucketIterator.splits((train, val, test), sort_key=lambda x: len(x.text), batch_size=batch_size, repeat=False, device=device)
person paul41    schedule 09.07.2018

Хотя подход @ paul41 работает, это несколько неправильное использование. Правильный способ сделать это - использовать preprocessing или postprocessing (до или после нумерации соответственно). Вот пример postprocessing:

def get_pad_to_min_len_fn(min_length):
    def pad_to_min_len(batch, vocab, min_length=min_length):
        pad_idx = vocab.stoi['<pad>']
        for idx, ex in enumerate(batch):
            if len(ex) < min_length:
                batch[idx] = ex + [pad_idx] * (min_length - len(ex))
        return batch
    return pad_to_min_len

FILTER_SIZES = [3,4,5]
min_len_padding = get_pad_to_min_len_fn(min_length=max(FILTER_SIZES))

TEXT = Field(sequential=True, use_vocab=True, lower=True, batch_first=True, 
             postprocessing=min_len_padding)

Вложенные функции, необходимые для передачи параметров внутренней функции, если они определены в основном цикле (например, min_length = max(FILTER_SIZES)), но параметры могут быть жестко запрограммированы внутри функции, если это работает.

person Nikita    schedule 28.04.2020