Estoy usando las API de conjunto de datos en TF 2.4. Actualmente tengo una pieza de código que funciona como

def map_func(a:int, b:int) -> typing.Tuple[typing.List[float],typing.List[int]]:
    # some complex logics here, for example, protobuf message deserialization
    return [0.0],[0] if some_condition() else [1.0],[1]

some_dataset \
  .map(lambda a, b: tf.numpy_function(map_func, inp=[a,b], Tout=(tf.float32, tf.int32))) \
  .filter(lambda features, labels: any(labels)) \ # filter out results whose labels are all zeros, regardless whatever features are
  .some_other_apis()

La función map_func definida anteriormente devuelve una tupla de (características, etiquetas), donde las etiquetas pueden contener ceros o no ceros. Al encadenar una llamada filter, filtro las muestras cuyas etiquetas son todas ceros.

Cuál es el problema
Me pregunto si es posible "integrar" la filter lógica dentro de map_func, porque la implementación actual parece algo fea y redundante. Traté de devolver una tupla de ([], []) o (Ninguno, Ninguno) cuando quiero abandonar los resultados, pero TF se queja de que los tipos de retorno no coinciden.

0
Li Xiaoming 22 ene. 2021 a las 12:27

1 respuesta

La mejor respuesta

Puede utilizar tf.where y tf.gather:

import tensorflow as tf
import numpy as np

def map_func(a) :
    return tf.gather_nd(a, tf.where(a > 0.5))

inputs = np.random.rand(10, 5)

np.round(inputs, 3)
array([[0.952, 0.329, 0.786, 0.714, 0.819],
       [0.048, 0.98 , 0.363, 0.03 , 0.078],
       [0.779, 0.833, 0.368, 0.216, 0.669],
       [0.807, 0.332, 0.217, 0.594, 0.254],
       [0.787, 0.453, 0.943, 0.915, 0.76 ],
       [0.047, 0.014, 0.555, 0.57 , 0.422],
       [0.195, 0.167, 0.077, 0.562, 0.586],
       [0.693, 0.434, 0.055, 0.213, 0.021],
       [0.459, 0.34 , 0.785, 0.938, 0.979],
       [0.08 , 0.667, 0.781, 0.092, 0.644]])
ds = tf.data.Dataset.from_tensor_slices(inputs)

ds = ds.map(map_func) 

for i in ds:
    print(np.round(i.numpy(), 3))
[0.952 0.786 0.714 0.819]
[0.98]
[0.779 0.833 0.669]
[0.807 0.594]
[0.787 0.943 0.915 0.76 ]
[0.555 0.57 ]
[0.562 0.586]
[0.693]
[0.785 0.938 0.979]
[0.667 0.781 0.644]
0
Nicolas Gervais 22 ene. 2021 a las 14:03