Условный GAN - одинаковое перемешивание / разделение двух наборов данных

Я пытаюсь использовать DCGAN для раскрашивания некоторых изображений. При этом я настраиваю свой GAN на версии изображений в оттенках серого. Затем я хочу обучить свой GAN / дискриминатор сначала с партией реальных изображений, а затем с партией поддельных изображений. Время от времени я хочу сравнивать цветную версию изображений, версию в градациях серого и истинную версию изображения. Поэтому мне нужно, чтобы партии реальных / серых изображений разделялись одинаково. Использую питторч. Глядя на код, который я включил, они должны дать одинаковые пакеты. Однако они этого не делают.

Пробовал без worker_init_fn. Я также пробовал разные случайные вызовы функций и безрезультатно передавал их worker_init_fn.

dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                          shuffle=True, num_workers=workers, worker_init_fn = random.seed(seed))

dataloader_gray = torch.utils.data.DataLoader(dataset_gray, batch_size=batch_size,
                                          shuffle=True, num_workers=workers, worker_init_fn = random.seed(seed))

for i, (data, data_gray) in enumerate(zip(dataloader, dataloader_gray)):
    doStuff()

person forTheRoad    schedule 22.04.2019    source источник
comment
Вы загружаете цветные изображения и серые изображения в отдельные загрузчики данных, как убедиться, что они правильно спарены?   -  person Haran Rajkumar    schedule 22.04.2019
comment
Харан, возможно, я плохо сформулировал свой вопрос. Мой вопрос в том, как правильно сочетать изображения в партиях. Разделив партии с использованием одного и того же ГСЧ / начального числа, изображения должны быть правильно соединены. Однако я не могу это понять ..   -  person forTheRoad    schedule 22.04.2019
comment
Попробуйте вместо этого использовать torch.manual_seed(). Вы также проверили, какие первые изображения загружает загрузчик? Соответствуют ли цвет и серый цвет? При другом запуске загрузчики загружают изображения в том же порядке?   -  person Haran Rajkumar    schedule 22.04.2019
comment
Я пробовал worker_init_fn = torch.manual_seed. Ничего не меняет. Результаты показаны здесь imgur.com/a/uCSDvmB. Серый цвет должен быть просто серыми версиями изображений, а не полностью другими.   -  person forTheRoad    schedule 22.04.2019
comment
torch.manual_seed(<some number>) должен быть установлен в начале вашего скрипта.   -  person Haran Rajkumar    schedule 22.04.2019
comment
Установка torch.manual_seed (12) в начале моего скрипта, к сожалению, не решает проблему.   -  person forTheRoad    schedule 22.04.2019
comment
Получаете ли вы одинаковые результаты при нескольких запусках?   -  person Haran Rajkumar    schedule 22.04.2019
comment
Я бы рекомендовал создать новый класс, наследующий набор данных, который загружает оба изображения последовательно. Это было бы более надежным решением по сравнению с игрой с начальными значениями.   -  person Haran Rajkumar    schedule 22.04.2019
comment
да. Если я использую torch.manual_seed (12), я получаю результаты изображения. Если я удалю посев вручную, я этого не сделаю. Поскольку мы выполняем итерацию по двум загрузчикам данных, я думаю, что они оба изменяют состояние RNG, что дает нам разные серые / цветные изображения.   -  person forTheRoad    schedule 22.04.2019


Ответы (1)


Как указывал в комментариях Харан Раджкумар, гораздо лучшим решением было бы объединить оба набора данных заранее и применить torch.utils.DataLoader после этого (при условии оба torch.utils.Dataset объекта содержат изображения в том же порядке в начале).

Обратите внимание, что для выполнения этой операции не нужно создавать отдельный класс, torch.utils.data.ConcatDataset предоставляет эту функциональность из коробки.

Не уверен в вашем точном коде, но этого должно быть достаточно (или, по крайней мере, достаточно, чтобы вы в правильном направлении):

import torch

dataloader = torch.utils.data.DataLoader(
    torch.utils.data.ConcatDataset(dataset, dataset_gray),
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers
)

for i, (data, data_gray) in enumerate(dataloader):
    doStuff()

Как видите, он намного читабельнее и проще в использовании.

person Szymon Maszke    schedule 22.04.2019
comment
Спасибо. Я последовал совету @HaranRajkumar и создал новый класс. Оно работает! - person forTheRoad; 22.04.2019
comment
ИМО, этот подход немного короче и удобочитаем, но рад, что он сработал для вас. - person Szymon Maszke; 22.04.2019