Soy un principiante de tensorflow, tratando de leer matrices numpy almacenadas en el disco en TF usando TextLineReader. Pero cuando leo las matrices en TF, veo valores diferentes de la matriz original. ¿Podría alguien señalar el error que estoy cometiendo aquí? Por favor, vea un código de muestra a continuación. Gracias

import tensorflow as tf
import numpy as np
import csv

#Write two numpy arrays to disk 
a = np.arange(15).reshape(3, 5)
np.save("a.npy",a,allow_pickle=False)

b = np.arange(30).reshape(5, 6)
np.save("b.npy",b,allow_pickle=False)

with open('files.csv', 'w') as csvfile:
    filewriter = csv.writer(csvfile, delimiter=',')
    filewriter.writerow(['a.npy', 'b.npy'])


# Load a csv with the two array filenames

csv_filename = "files.csv"
filename_queue = tf.train.string_input_producer([csv_filename])

reader = tf.TextLineReader()
_, csv_filename_tf = reader.read(filename_queue)


record_defaults = [tf.constant([], dtype=tf.string), tf.constant([], dtype=tf.string)]
filename_i,filename_j = tf.decode_csv(
    csv_filename_tf, record_defaults=record_defaults)

file_contents_i = tf.read_file(filename_i)
file_contents_j = tf.read_file(filename_j)

bytes_i = tf.decode_raw(file_contents_i, tf.int16)
array_i = tf.reshape(tf.cast(tf.slice(bytes_i, [0], [3*5]), tf.int16), [3, 5])

bytes_j = tf.decode_raw(file_contents_j, tf.int16)
array_j = tf.reshape(tf.cast(tf.slice(bytes_j, [0], [5*6]), tf.int16), [5, 6])

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    a_out, b_out = (sess.run([array_i, array_j]))

    print(a)
    print(a_out)

    coord.request_stop()
    coord.join(threads)

Aquí está el resultado que obtengo:

Producto esperado (a)

[[ 0  1  2  3  4]
 [ 5  6  7  8  9]
 [10 11 12 13 14]]

Salida recibida: (a_out)

[[20115 19797 22864     1   118]
 [10107 25956 25459 10098  8250]
 [15399 14441 11303 10016 28518]]
0
ictguy1 1 mar. 2018 a las 16:01

3 respuestas

La mejor respuesta

No creo que tensorflow decode_raw y numpy's np.save sean compatibles.

0
Alexandre Passos 1 mar. 2018 a las 23:33

Para saber qué estaba pasando imprimí bytes_i en lugar de array_i.

a_out, b_out = (sess.run([bytes_i, bytes_j]))
print(a_out)

Y obtuve la siguiente lista:

[20115 19797 22864     1   118 10107 25956 25459 10098  8250 15399 14441
 11303 10016 28518 29810 24946 24430 29295 25956 10098  8250 24902 29548
 11365 10016 26739 28769 10085  8250 13096  8236 10549  8236  8317  8224
  8224  8224  8224  8224  8224  8224  8224  8224  8224  8224  8224  8224
  8224  8224  8224  8224  8224  8224  8224  8224  8224  8224  8224  8224
  8224  8224  8224  2592     0     0     0     0     1     0     0     0
     2     0     0     0     3     0     0     0     4     0     0     0
     5     0     0     0     6     0     0     0     7     0     0     0
     8     0     0     0     9     0     0     0    10     0     0     0
    11     0     0     0    12     0     0     0    13     0     0     0
    14     0     0     0]

Parece que hay un encabezado delante de los datos almacenados en el archivo numpy. Además, parece que los valores de datos se guardan como int64 y no como int16.

Solución

Primero especifique el tipo de valores en la matriz:

a = np.arange(15).reshape(3, 5).astype(np.int16)
b = np.arange(30).reshape(5, 6).astype(np.int16)

Luego lea los últimos bytes del archivo:

array_i = tf.reshape(tf.cast(tf.slice(bytes_i,
                                      begin=[tf.size(bytes_i) - (3*5)],
                                      size=[3*5]), tf.int16), [3, 5])
array_j = tf.reshape(tf.cast(tf.slice(bytes_j,
                                      begin=[tf.size(bytes_j) - (5*6)],
                                      size=[5*6]), tf.int16), [5, 6])
0
Mathumeo 24 oct. 2018 a las 08:35

Use los archivos zip de numpy .npz

Para guardar las variables a b:

weights = {w.name : sess.run(w) for w in [a, b]}
np.savez(path, **weights)

Cargar:

weights = [a, b]
npz_weights = np.load(path)
for i,k in enumerate([w.name for w in weights]):
    sess.run(weights[i].assign(npz_weights[k]))
0
Nimrod Morag 6 may. 2018 a las 05:58