¿Dónde hay una conexión explícita entre optimizer y loss?

¿Cómo sabe el optimizador dónde obtener los gradientes de la pérdida sin una llamada como esto optimizer.step(loss)?

-Más contexto-

Cuando minimizo la pérdida, no tuve que pasar los gradientes al optimizador.

loss.backward() # Back Propagation
optimizer.step() # Gardient Descent
26
Aerin 30 dic. 2018 a las 09:30

3 respuestas

La mejor respuesta

Sin profundizar demasiado en las partes internas de pytorch, puedo ofrecer una respuesta simplista:

Recuerde que cuando inicializa optimizer explícitamente le dice qué parámetros (tensores) del modelo debe actualizar. Los propios tensores "almacenan" los gradientes (tienen un {{X1 }} y un requires_grad atributos) una vez que llame backward() en la pérdida. Después de calcular los gradientes para todos los tensores en el modelo, llamar a optimizer.step() hace que el optimizador repita todos los parámetros (tensores), se supone que debe actualizar y utilizar sus grad almacenados internamente para actualizar sus valores.

25
Shai 30 dic. 2018 a las 06:39

Digamos que definimos un modelo: model, y la función de pérdida: criterion y tenemos la siguiente secuencia de pasos:

pred = model(input)
loss = criterion(pred, true_labels)
loss.backward()

pred tendrá un atributo grad_fn, que hace referencia a una función que lo creó y lo vincula al modelo. Por lo tanto, loss.backward() tendrá información sobre el modelo con el que está trabajando.

Intente eliminar el atributo grad_fn, por ejemplo con:

pred = pred.clone().detach()

Entonces los pesos del modelo no se actualizarán.

Y el optimizador está vinculado al modelo porque pasamos model.parameters() cuando creamos el optimizador.

1
Akavall 25 may. 2020 a las 23:49

Cuando llama a loss.backward(), todo lo que hace es calcular el gradiente de pérdida con todos los parámetros en pérdida que tienen requires_grad = True y almacenarlos en el atributo parameter.grad para cada parámetro.

optimizer.step() actualiza todos los parámetros basados en parameter.grad

11
Morteza Jalambadani 27 feb. 2019 a las 14:32