Как изменить DataLoader в PyTorch для чтения одного изображения для предсказания?

В настоящее время у меня есть предварительно обученная модель, которая использует DataLoader для чтения пакета изображений для обучения модели.

self.data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, 
   num_workers=1, pin_memory=True)

...

model.eval()
for step, inputs in enumerate(test_loader.data_loader):
   outputs = model(torch.cat([inputs], 1))

...

Я хочу обрабатывать (делать прогнозы) изображения по мере их поступления из очереди. Он должен быть похож на код, который считывает одно изображение и запускает модель, чтобы делать на нем прогнозы. Что-то вроде следующего:

from PIL import Image

new_input = Image.open(image_path)
model.eval()
outputs = model(torch.cat([new_input ], 1))

Мне было интересно, не могли бы вы подсказать мне, как это сделать, и применить те же преобразования в DataLoader.


person Hamid R. Darabi    schedule 02.04.2020    source источник
comment
Это действительно будет зависеть от того, как работает ваш dataset. Детали которых в вопросе не приводятся.   -  person jodag    schedule 03.04.2020
comment
@jodag Я отредактировал вопрос, чтобы предоставить больше контекста. Я ценю вашу помощь.   -  person Hamid R. Darabi    schedule 03.04.2020


Ответы (2)


Это можно сделать с помощью IterableDataset:

from torch.utils.data import IterableDataset

class MyDataset(IterableDataset):
    def __init__(self, image_queue):
      self.queue = image_queue

    def read_next_image(self):
        while self.queue.qsize() > 0:
            # you can add transform here
            yield self.queue.get()
        return None

    def __iter__(self):
        return self.read_next_image()

и batch_size = 1:

import queue
import torchvision.transforms.functional as TF

buffer = queue.Queue()
new_input = Image.open(image_path)
buffer.put(TF.to_tensor(new_input)) 
# ... Populate queue here

dataset = MyDataset(buffer)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1)
for data in dataloader:
   model(data) # data is one-image batch of size [1,3,H,W] where 3 - number of color channels
person Anton Ganichev    schedule 04.04.2020
comment
Спасибо @Anton Ganichev - person Hamid R. Darabi; 04.04.2020

Я не знаю о dataLoader, но вы можете загрузить одно изображение, используя следующую функцию:

def safe_pil_loader(path, from_memory=False):
try:
    if from_memory:
        img = Image.open(path)
        res = img.convert('RGB')
    else:
        with open(path, 'rb') as f:
            img = Image.open(f)
            res = img.convert('RGB')
except:
    res = Image.new('RGB', (227, 227), color=0)
return res

А для применения трансформации вы можете сделать следующее:

trans = transforms.Compose([
            transforms.Resize(299),
            transforms.CenterCrop(299),
            transforms.ToTensor(),
            normalize,
        ])
img=trans(img)
person Marzi Heidari    schedule 03.04.2020
comment
Спасибо, @Marzieh Heidari! - person Hamid R. Darabi; 03.04.2020