Как объединить 2 модели pytorch и сделать первую не обучаемой в PyTorch

У меня две сети, которые мне нужно объединить для моей полной модели. Однако моя первая модель предварительно обучена, и мне нужно сделать ее необучаемой при обучении полной модели. Как я могу добиться этого в PyTorch.

Я могу объединить две модели, используя этот ответ

class MyModelA(nn.Module):
    def __init__(self):
        super(MyModelA, self).__init__()
        self.fc1 = nn.Linear(10, 2)
        
    def forward(self, x):
        x = self.fc1(x)
        return x
    

class MyModelB(nn.Module):
    def __init__(self):
        super(MyModelB, self).__init__()
        self.fc1 = nn.Linear(20, 2)
        
    def forward(self, x):
        x = self.fc1(x)
        return x


class MyEnsemble(nn.Module):
    def __init__(self, modelA, modelB):
        super(MyEnsemble, self).__init__()
        self.modelA = modelA
        self.modelB = modelB
        
    def forward(self, x):
        x1 = self.modelA(x)
        x2 = self.modelB(x1)
        return x2

# Create models and load state_dicts    
modelA = MyModelA()
modelB = MyModelB()
# Load state dicts
modelA.load_state_dict(torch.load(PATH))

model = MyEnsemble(modelA, modelB)
x = torch.randn(1, 10)
output = model(x)

В основном здесь я хочу загрузить предварительно обученный modelA и сделать его необучаемым при обучении модели Ensemble.


person Nagabhushan S N    schedule 09.12.2020    source источник


Ответы (2)


Вы можете заморозить все параметры модели, которую не хотите тренировать, установив для requires_grad значение false. Нравится:

for param in model.parameters():
    param.requires_grad = False

Это должно сработать для вас.

Другой способ - справиться с этим в вашем цикле поезда:

modelA = MyModelA()
modelB = MyModelB()

criterionB = nn.MSELoss()
optimizerB = torch.optim.Adam(modelB.parameters(), lr=0.001)

for epoch in range(epochs):
    for samples, targets in dataloader:
        optimizerB.zero_grad()

        x = modelA.train()(samples)
        predictions = modelB.train()(samples)
    
        loss = criterionB(predictions, targets)
        loss.backward()
        optimizerB.step()

Таким образом, вы передаете вывод modelA в modelB, но оптимизируете только modelB.

person Theodor Peifer    schedule 09.12.2020

Один из простых способов сделать это - detach выходной тензор модели, который вы не хотите обновлять, и он не будет распространять градиент обратно на подключенную модель. В вашем случае вы можете просто detach x2 тензор непосредственно перед объединением с x1 в прямой функции модели MyEnsemble, чтобы сохранить вес modelB неизменным.

Итак, новая функция пересылки должна выглядеть следующим образом:

def forward(self, x1, x2):
        x1 = self.modelA(x1)
        x2 = self.modelB(x2)
        x = torch.cat((x1, x2.detach()), dim=1)  # Detaching x2, so modelB wont be updated
        x = self.classifier(F.relu(x))
        return x
person Kaushik Roy    schedule 09.12.2020