Tengo dos redes, que necesito concatenar para mi modelo completo. Sin embargo, mi primer modelo está preentrenado y necesito que no se pueda entrenar cuando entreno el modelo completo. ¿Cómo puedo lograr esto en PyTorch?

Puedo concatenar dos modelos usando esta respuesta

class MyModelA(nn.Module):
    def __init__(self):
        super(MyModelA, self).__init__()
        self.fc1 = nn.Linear(10, 2)
        
    def forward(self, x):
        x = self.fc1(x)
        return x
    

class MyModelB(nn.Module):
    def __init__(self):
        super(MyModelB, self).__init__()
        self.fc1 = nn.Linear(20, 2)
        
    def forward(self, x):
        x = self.fc1(x)
        return x


class MyEnsemble(nn.Module):
    def __init__(self, modelA, modelB):
        super(MyEnsemble, self).__init__()
        self.modelA = modelA
        self.modelB = modelB
        
    def forward(self, x):
        x1 = self.modelA(x)
        x2 = self.modelB(x1)
        return x2

# Create models and load state_dicts    
modelA = MyModelA()
modelB = MyModelB()
# Load state dicts
modelA.load_state_dict(torch.load(PATH))

model = MyEnsemble(modelA, modelB)
x = torch.randn(1, 10)
output = model(x)

Básicamente aquí, quiero cargar modelA previamente entrenado y hacerlo no entrenable al entrenar el modelo Ensemble.

0
Nagabhushan S N 9 dic. 2020 a las 15:03

2 respuestas

La mejor respuesta

Puede congelar todos los parámetros del modelo que no desea entrenar, estableciendo requires_grad en falso. Me gusta esto:

for param in model.parameters():
    param.requires_grad = False

Esto debería funcionar para ti.

Otra forma es manejar esto en su bucle de tren:

modelA = MyModelA()
modelB = MyModelB()

criterionB = nn.MSELoss()
optimizerB = torch.optim.Adam(modelB.parameters(), lr=0.001)

for epoch in range(epochs):
    for samples, targets in dataloader:
        optimizerB.zero_grad()

        x = modelA.train()(samples)
        predictions = modelB.train()(samples)
    
        loss = criterionB(predictions, targets)
        loss.backward()
        optimizerB.step()

Así que pasa la salida de modelo a modelo pero optimiza solo modeloB.

1
Theodor Peifer 9 dic. 2020 a las 12:14

Una forma fácil de hacerlo es detach el tensor de salida del modelo que no desea actualizar y no retrocederá el gradiente al modelo conectado. En su caso, puede simplemente detach x2 tensor justo antes de concatinar con x1 en la función de avance del modelo MyEnsemble para mantener el peso de modelB sin cambios.

Entonces, la nueva función de avance debería ser la siguiente:

def forward(self, x1, x2):
        x1 = self.modelA(x1)
        x2 = self.modelB(x2)
        x = torch.cat((x1, x2.detach()), dim=1)  # Detaching x2, so modelB wont be updated
        x = self.classifier(F.relu(x))
        return x
1
Kaushik Roy 9 dic. 2020 a las 12:46
65216411