TensorFlow 2.0 RC1

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Multiply

import numpy as np

Salida esperada:

Multiply()([np.array([1,2,3,4,4,4]).reshape(2,3), np.array([1,0])])

enter image description here

Problema:

input_1 = Input(shape=(None,3))
mask_1 = Input(shape=(None,))

net = Multiply()([input_1, mask_1])
net = Model(inputs=[input_1, mask_1], outputs=net)

net.predict([np.array([1,2,3,4,4,4]).reshape(1,2,3), np.array([1,0]).reshape(1,2)]) # 1 = batch size

enter image description here

¿Cómo solucionar este problema?

0
Alexey Golyshev 30 sep. 2019 a las 14:16

3 respuestas

La mejor respuesta

El número de dimensiones debe coincidir, modificando la forma de entrada de la segunda entrada a (None, 1) y agregando una dimensión adicional a la matriz [1, 0]

import numpy as np
from tensorflow.keras.layers import Multiply
from tensorflow.keras import Model, Input

input_1 = Input(shape=(2,3))
mask_1 = Input(shape=(2,1))

net = Multiply()([input_1, mask_1])
net = Model(inputs=[input_1, mask_1], outputs=net)

net.summary()

print(net.predict([np.array([1,2,3,4,4,4]).reshape((1,2,3)), np.array([1,0]).reshape((1,2,1))]))
2
Raphael Meudec 30 sep. 2019 a las 12:03

Cambie la forma de la segunda matriz en la última línea de código como np.array([1,0]).reshape(-1)

net.predict([np.array([1,2,3,4,4,4]).reshape(1,2,3), np.array([1,0]).reshape(-1)]) # 1 = batch size
3
stephen_mugisha 30 sep. 2019 a las 11:30

Depende de cómo se especifique la forma de entrada. En el ejemplo Multiply () (multiplicación por elementos), el tamaño del lote es 2 y el tamaño de la característica es 3 para Input y 1 para mask. Entonces, al especificar la forma de entrada en Keras, solo se debe especificar el tamaño de la característica.

input_1 = Input(shape=(3,))
mask_1 = Input(shape=(1,))
net = Multiply()([input_1, mask_1])
net = Model(inputs=[input_1, mask_1], outputs=net)
output = net.predict([np.array([1,2,3,4,4,4]).reshape(2,3), np.array([1,0])])
print(output)

[[1. 2. 3.] [0. 0. 0.]]

3
Manoj Mohan 30 sep. 2019 a las 12:08
58166818