Как мне увеличить данные после разделения набора данных для обучения на набор для обучения и проверки для CIFAR10 с помощью PyTorch?

При классификации CIFAR10 в PyTorch обычно используется 50 000 обучающих образцов и 10 000 тестовых образцов. Однако, если мне нужно создать набор для проверки, я могу сделать это, разделив обучающий набор на 40000 образцов поездов и 10000 образцов проверки. Я использовал следующие коды

train_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])
test_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])

cifar_train_L = CIFAR10('./data',download=True, train= True, transform = train_transform)
cifar_test = CIFAR10('./data',download=True, train = False, transform= test_transform) 

train_size = int(0.8*len(cifar_training))
val_size = len(cifar_training) - train_size
cifar_train, cifar_val = torch.utils.data.random_split(cifar_train_L,[train_size,val_size])

train_dataloader = torch.utils.data.DataLoader(cifar_train, batch_size= BATCH_SIZE, shuffle= True, num_workers=2)
test_dataloader = torch.utils.data.DataLoader(cifar_test,batch_size= BATCH_SIZE, shuffle= True, num_workers= 2)
val_dataloader = torch.utils.data.DataLoader(cifar_val,batch_size= BATCH_SIZE, shuffle= True, num_workers= 2)

Обычно при расширении данных в PyTorch в функции transforms.Compose используются различные процессы дополнения (т.е. transforms.RandomHorizontalFlip ()). Однако, если я использую эти процессы дополнения перед разделением обучающего набора и набора проверки, расширенные данные также будут включены в набор проверки. Есть ли способ исправить эту проблему?

Короче говоря, я хочу вручную разделить обучающий набор данных на обучающий и проверочный набор, а также хочу использовать метод увеличения данных в новом обучающем наборе.


person Bloodstone Programmer    schedule 04.11.2019    source источник


Ответы (1)


Вы можете вручную переопределить transforms набора данных:

cifar_train, cifar_val = torch.utils.data.random_split(cifar_train_L,[train_size,val_size])
cifar_val.transforms = test_transform 
person Shai    schedule 04.11.2019