¿En qué train_on_batch() es diferente de fit()? ¿En qué casos deberíamos usar train_on_batch()?

27
Dhairya Verma 5 mar. 2018 a las 00:13

4 respuestas

La mejor respuesta

Para esta pregunta, es una respuesta simple del autor principal:

Con fit_generator, puede usar un generador para los datos de validación como bien. En general, recomendaría usar fit_generator, pero usar train_on_batch también funciona bien. Estos métodos solo existen por el bien de conveniencia en diferentes casos de uso, no existe un método "correcto".

train_on_batch le permite actualizar expresamente los pesos en función de una colección de muestras que proporcione, independientemente de cualquier tamaño de lote fijo. Utilizaría esto en los casos en que eso es lo que desea: entrenar en una colección explícita de muestras. Podría usar ese enfoque para mantener su propia iteración en varios lotes de un conjunto de entrenamiento tradicional, pero permitir que fit o fit_generator repita los lotes para usted probablemente sea más simple.

Un caso en el que podría ser bueno usar train_on_batch es para actualizar un modelo previamente entrenado en un nuevo lote de muestras. Supongamos que ya ha entrenado y desplegado un modelo, y en algún momento más tarde ha recibido un nuevo conjunto de muestras de entrenamiento que nunca antes se habían utilizado. Puede usar train_on_batch para actualizar directamente el modelo existente solo en esas muestras. Otros métodos también pueden hacer esto, pero es bastante explícito usar train_on_batch para este caso.

Además de casos especiales como este (ya sea que tenga alguna razón pedagógica para mantener su propio cursor en diferentes lotes de entrenamiento, o bien para algún tipo de actualización de entrenamiento semi-en línea en un lote especial), probablemente sea mejor usar siempre { {X0}} (para datos que caben en la memoria) o fit_generator (para transmitir lotes de datos como generador).

38
nbro 8 oct. 2019 a las 22:33

De hecho, la respuesta de @nbro ayuda, solo para agregar algunos escenarios más, digamos que está entrenando algún modelo de secuencia a secuencia o una red grande con uno o más codificadores. Podemos crear bucles de entrenamiento personalizados usando train_on_batch y usar una parte de nuestros datos para validar en el codificador directamente sin usar devoluciones de llamada. Escribir devoluciones de llamada para un proceso de validación complejo podría ser difícil. Hay varios casos en los que deseamos entrenar en lote.

Saludos, Karthick

0
karthick raja 7 feb. 2020 a las 09:18

train_on_batch() le brinda un mayor control del estado del LSTM, por ejemplo, cuando se utiliza un LSTM con estado y se necesitan llamadas de control a model.reset_states(). Es posible que tenga datos de varias series y necesite restablecer el estado después de cada serie, lo que puede hacer con train_on_batch(), pero si utilizó .fit(), la red se capacitaría en todas las series de datos sin restableciendo el estado. No hay correcto o incorrecto, depende de los datos que esté utilizando y de cómo desee que se comporte la red.

12
BigBadMe 2 ago. 2018 a las 08:45

Train_on_batch también verá un aumento del rendimiento sobre el generador de ajuste y ajuste si está utilizando conjuntos de datos grandes y no tiene datos fácilmente serializables (como matrices numpy de alto rango), para escribir en tfrecords.

En este caso, puede guardar los arreglos como archivos numpy y cargar subconjuntos más pequeños de ellos (traina.npy, trainb.npy, etc.) en la memoria, cuando todo el conjunto no cabe en la memoria. Luego puede usar tf.data.Dataset.from_tensor_slices y luego usar train_on_batch con su subdataset, luego cargar otro conjunto de datos y llamar a train en lote nuevamente, etc., ahora ha entrenado en todo su conjunto y puede controlar exactamente cuánto y qué de su conjunto de datos entrena su modelo. Luego, puede definir sus propias épocas, tamaños de lote, etc. con simples bucles y funciones para tomar de su conjunto de datos.

1
Adam Collins 1 ago. 2019 a las 13:20