Я пытаюсь использовать функцию BucketIterator.splits в torchtext для загрузки данных из CSV-файлов для использования в CNN. Все работает нормально, если у меня нет пакета, в котором самое длинное предложение короче, чем самый большой размер фильтра.
В моем примере у меня есть фильтры размеров 3, 4 и 5, поэтому, если в самом длинном предложении нет хотя бы 5 слов, я получаю сообщение об ошибке. Есть ли способ позволить BucketIterator динамически устанавливать заполнение для пакетов, а также устанавливать минимальную длину заполнения?
Это мой код, который я использую для своего BucketIterator:
train_iter, val_iter, test_iter = BucketIterator.splits((train, val, test), sort_key=lambda x: len(x.text), batch_size=batch_size, repeat=False, device=device)
Я надеюсь, есть способ установить минимальную длину для sort_key или что-то в этом роде?
Я пробовал это, но это не работает:
FILTER_SIZES = [3,4,5]
train_iter, val_iter, test_iter = BucketIterator.splits((train, val, test), sort_key=lambda x: len(x.text) if len(x.text) >= FILTER_SIZES[-1] else FILTER_SIZES[-1], batch_size=batch_size, repeat=False, device=device)