En Tensorflow Federated (TFF), puede pasar a tff.learning.build_federated_averaging_process un broadcast_process y un aggregation_process, que pueden incorporar codificadores personalizados, p. Ej. para aplicar compresiones personalizadas.

Llegando al punto de mi pregunta, estoy tratando de implementar un codificador para dispersar las actualizaciones / pesos del modelo.

Estoy tratando de construir un codificador de este tipo implementando el EncodingStageInterface, de tensorflow_model_optimization.python.core.internal. Sin embargo, estoy luchando por implementar un estado (local) para acumular las coordenadas cerradas de las actualizaciones del modelo / pesos del modelo ronda por ronda. Tenga en cuenta que este estado no debe comunicarse y solo debe mantenerse localmente (por lo que AdaptiveEncodingStageInterface no debería ser útil). En general, la pregunta es cómo mantener un estado local dentro de un codificador para luego pasarlo al proceso fedavg.

Adjunto el código de la implementación de mi codificador (que, además del estado que me gustaría agregar, funciona bien como sin estado como se esperaba). Luego adjunto el extracto de mi código donde uso la implementación del codificador. Si descommento las partes comentadas en stateful_encoding_stage_topk.py , el código no funciona: no puedo entender cómo administrar el estado (que es un tensor) en el modo TF no ansioso.

stateful_encoding_stage_topk.py

import tensorflow as tf
import numpy as np
from tensorflow_model_optimization.python.core.internal import tensor_encoding as te


@te.core.tf_style_encoding_stage
class StatefulTopKEncodingStage(te.core.EncodingStageInterface):

  ENCODED_VALUES_KEY = 'stateful_topk_values'
  INDICES_KEY = 'indices'
  
  
  def __init__(self):
    super().__init__()
    # Here I would like to init my state
    #self.A = tf.zeros([800], dtype=tf.float32)

  @property
  def name(self):
    """See base class."""
    return 'stateful_topk'

  @property
  def compressible_tensors_keys(self):
    """See base class."""
    return [self.ENCODED_VALUES_KEY]

  @property
  def commutes_with_sum(self):
    """See base class."""
    return True

  @property
  def decode_needs_input_shape(self):
    """See base class."""
    return True

  def get_params(self):
    """See base class."""
    return {}, {}

  def encode(self, x, encode_params):
    """See base class."""
    del encode_params  # Unused.

    dW = tf.reshape(x, [-1])
    # Here I would like to retrieve the state
    A = tf.zeros([800], dtype=tf.float32)
    #A = self.residual
    
    dW_and_A = tf.math.add(A, dW)

    percentage = tf.constant(0.4, dtype=tf.float32)
    k_float = tf.multiply(percentage, tf.cast(tf.size(dW), tf.float32))
    k_int = tf.cast(tf.math.round(k_float), dtype=tf.int32)

    values, indices = tf.math.top_k(tf.math.abs(dW_and_A), k = k_int, sorted = False)
    indices = tf.expand_dims(indices, 1)
    sparse_dW = tf.scatter_nd(indices, values, tf.shape(dW_and_A))
    
    # Here I would like to update the state
    A_updated = tf.math.subtract(dW_and_A, sparse_dW)
    #self.A = A_updated
    
    encoded_x = {self.ENCODED_VALUES_KEY: values,
                 self.INDICES_KEY: indices}

    return encoded_x

  def decode(self,
             encoded_tensors,
             decode_params,
             num_summands=None,
             shape=None):
    """See base class."""
    del decode_params, num_summands  # Unused.
    
    indices = encoded_tensors[self.INDICES_KEY]
    values = encoded_tensors[self.ENCODED_VALUES_KEY]
    tensor = tf.fill([800], 0.0)
    decoded_values = tf.tensor_scatter_nd_update(tensor, indices, values)
    
    return tf.reshape(decoded_values, shape)



def sparse_quantizing_encoder():
  encoder = te.core.EncoderComposer(
      StatefulTopKEncodingStage() )  
  return encoder.make()

fedavg_with_sparsification.py

[...]

def sparsification_broadcast_encoder_fn(value):
  spec = tf.TensorSpec(value.shape, value.dtype)
  return te.encoders.as_simple_encoder(te.encoders.identity(), spec)

def sparsification_mean_encoder_fn(value):
  spec = tf.TensorSpec(value.shape, value.dtype)
  
  if value.shape.num_elements() == 800:
    return te.encoders.as_gather_encoder(
        stateful_encoding_stage_topk.sparse_quantizing_encoder(), spec)

  else:
    return te.encoders.as_gather_encoder(te.encoders.identity(), spec)
  
encoded_broadcast_process = (
    tff.learning.framework.build_encoded_broadcast_process_from_model(
        model_fn, sparsification_broadcast_encoder_fn))

encoded_mean_process = (
    tff.learning.framework.build_encoded_mean_process_from_model(
        model_fn, sparsification_mean_encoder_fn))


iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.004),
    client_weight_fn=lambda _: tf.constant(1.0),
    broadcast_process=encoded_broadcast_process,
    aggregation_process=encoded_mean_process)

[...]

Estoy usando:

  • tensorflow 2.4.0
  • tensorflow-federated 0.17.0
3
Alessio Mora 21 ene. 2021 a las 18:17

1 respuesta

La mejor respuesta

Intentaré responder en dos partes; (1) codificador top_k sin estado y (2) darse cuenta de la idea con estado que parece querer en TFF.

(1)

Para que el TopKEncodingStage funcione sin estado, veo algunos detalles para cambiar.

La propiedad commutes_with_sum debe establecerse en False. En pseudocódigo, su significado es si sum_x(decode(encode(x))) == decode(sum_x(encode(x))). Esto no es cierto para la representación que devuelve su método encode; la suma de indices no funcionaría bien. Creo que la implementación del método decode se puede simplificar para

return tf.scatter_nd(
    indices=encoded_tensors[self.INDICES_KEY],
    updates=encoded_tensors[self.ENCODED_VALUES_KEY],
    shape=shape)

(2)

A lo que te refieres no se puede lograr de esta manera usando tff.learning.build_federated_averaging_process. El proceso devuelto por este método no tiene ningún mecanismo para mantener el estado local / cliente. Cualquiera que sea el estado expresado en su StatefulTopKEncodingStage terminaría siendo el estado del servidor, no el estado local.

Para trabajar con el cliente / estado local, es posible que deba escribir más código personalizado. Para empezar, consulte examples/stateful_clients que puede adaptar para almacenar el estado al que se refiere.

Tenga en cuenta que en TFF, esto deberá representarse como transformaciones funcionales. Almacenar valores en atributos de una clase y usarlos en otro lugar puede llevar a errores sorprendentes.

3
Jakub Konecny 22 ene. 2021 a las 08:39