Как я могу использовать PyTorch DataLoader для обучения с подкреплением?

Я пытаюсь настроить обобщенную структуру обучения с подкреплением в PyTorch, чтобы воспользоваться преимуществами всех утилит высокого уровня, которые используют PyTorch DataSet и DataLoader, например Ignite или FastAI, но я столкнулся с блокировщиком с динамической природой Данные обучения с подкреплением:

  • Элементы данных генерируются из кода, а не считываются из файла, и они зависят от предыдущих действий и результатов модели, поэтому для каждого вызова nextItem требуется доступ к состоянию модели.
  • Учебные эпизоды не имеют фиксированной длины, поэтому мне нужен динамический размер пакета, а также динамический общий размер набора данных. Я бы предпочел использовать функцию завершающего условия вместо числа. Я мог бы «возможно» сделать это с помощью дополнений, как при обработке предложений НЛП, но это настоящий взлом.

Мои поиски в Google и StackOverflow пока не дали никаких результатов. Кто-нибудь знает о существующих решениях или обходных путях использования DataLoader или DataSet с обучением с подкреплением? Я ненавижу терять доступ ко всем существующим библиотекам, которые зависят от них.


person Ken Otwell    schedule 29.07.2019    source источник


Ответы (1)


Вот один фреймворк на основе PyTorch и вот что-то из Facebook.

Что касается вашего вопроса (и, без сомнения, благородного квеста):

Вы можете легко создать torch.utils.data.Dataset, зависящий от чего-либо, включая модель, что-то вроде этого (простите за слабую абстракцию, это просто для доказательства):

import typing

import torch
from torch.utils.data import Dataset


class Environment(Dataset):
    def __init__(self, initial_state, actor: torch.nn.Module, max_interactions: int):
        self.current_state = initial_state
        self.actor: torch.nn.Module = actor
        self.max_interactions: int = max_interactions

    # Just ignore the index
    def __getitem__(self, _):
        self.current_state = self.actor.update(self.current_state)
        return self.current_state.get_data()

    def __len__(self):
        return self.max_interactions

Предположим, что torch.nn.Module-подобная сеть имеет какое-то update изменяющееся состояние окружающей среды. В общем, это просто структура Python, и с ее помощью можно моделировать множество вещей.

Вы можете указать max_interactions как почти infinite или изменить его на лету, если необходимо, с помощью некоторых обратных вызовов во время обучения (поскольку __len__, вероятно, будет вызываться несколько раз по всему коду). Кроме того, среда может предоставлять batches вместо образцов.

torch.utils.data.DataLoader имеет аргумент batch_sampler, там вы можете создавать партии разной длины. Поскольку сеть не зависит от первого измерения, вы также можете вернуть любой размер пакета оттуда.

КСТАТИ. Заполнение следует использовать, если каждый образец будет иметь разную длину, разный размер партии не имеет к этому никакого отношения.

person Szymon Maszke    schedule 29.07.2019
comment
Спасибо, Шимон - это достойный подход. Это своего рода взлом, учитывая, что мы на самом деле не знаем, как и где вызывается len (находится ли он в диапазоне for?). Но это, вероятно, лучшее, что мы можем сделать. Но особенно спасибо за ссылку SLM - похоже, это действительно хорошая работа. Я собираюсь потратить некоторое время на то, чтобы убедиться, что не изобретаю велосипед заново. - person Ken Otwell; 30.07.2019