Как я могу возобновить тренировку pl.Trainer после перерыва?

У меня есть объекты Model и Trainer pytorch-lightning, которые инициализируются следующим образом:

checkpoint_callback = ModelCheckpoint(
    filepath=os.path.join('experiments', experiment_name, '{epoch}-{avg_valid_iou:.4f}'),
    save_top_k=1,
    verbose=True,
    monitor='avg_valid_iou',
    mode='max',
    prefix=''
)
model = nn.DataParallel (FaultNetPL(batch_size=5)).cuda()
trainer = Trainer( checkpoint_callback=checkpoint_callback, 
                  max_epochs=500,gpus=1,
                  logger=logger)

Затем я начинаю тренироваться, используя:

trainer.fit(model)

Но обучение было прервано, и теперь я хотел бы возобновить его, используя контрольную точку с N-й итерации. Поэтому я попытался инициализировать модель и тренер как:

model = FaultNetPL.load_from_checkpoint('experiments/VNET/epoch=77-avg_valid_iou=0.7604.ckpt',batch_size=5)
trainer = Trainer(resume_from_checkpoint = 'epoch=77-avg_valid_iou=0.7604.ckpt', 
                  checkpoint_callback=checkpoint_callback, 
                  max_epochs=500,gpus=1,
                  logger=logger)

Но снова и снова обучение с нуля (с нулевой эпохи и огромная ошибка). Что я пропустил?


person Андрей Севостьянов    schedule 01.03.2021    source источник


Ответы (1)


Вы должны сохранить не только состояние модели, но также состояние оптимизатора и значение начальной эпохи. Например:

state({
       'epoch': epoch + 1,
       'state_dict': model.module.state_dict(),
       'optimizer': optimizer.state_dict(),
      })

После сохранения контрольной точки вы можете возобновить обучение с помощью следующих команд:

checkpoint = torch.load(state_file)
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_val = checkpoint['epoch']

for epoch in range(start_val, max_val):
   ...
   ...
person Shmn    schedule 01.03.2021
comment
Спасибо за ответ. Но что на самом деле должно быть в курсе .... ..... Также я использую молнию, а не питатель - person Андрей Севостьянов; 02.03.2021
comment
Я также упомянул в коде способ запуска ModelCheckpoint. - person Андрей Севостьянов; 02.03.2021