У меня есть объекты Model и Trainer pytorch-lightning, которые инициализируются следующим образом:
checkpoint_callback = ModelCheckpoint(
filepath=os.path.join('experiments', experiment_name, '{epoch}-{avg_valid_iou:.4f}'),
save_top_k=1,
verbose=True,
monitor='avg_valid_iou',
mode='max',
prefix=''
)
model = nn.DataParallel (FaultNetPL(batch_size=5)).cuda()
trainer = Trainer( checkpoint_callback=checkpoint_callback,
max_epochs=500,gpus=1,
logger=logger)
Затем я начинаю тренироваться, используя:
trainer.fit(model)
Но обучение было прервано, и теперь я хотел бы возобновить его, используя контрольную точку с N-й итерации. Поэтому я попытался инициализировать модель и тренер как:
model = FaultNetPL.load_from_checkpoint('experiments/VNET/epoch=77-avg_valid_iou=0.7604.ckpt',batch_size=5)
trainer = Trainer(resume_from_checkpoint = 'epoch=77-avg_valid_iou=0.7604.ckpt',
checkpoint_callback=checkpoint_callback,
max_epochs=500,gpus=1,
logger=logger)
Но снова и снова обучение с нуля (с нулевой эпохи и огромная ошибка). Что я пропустил?