Estoy tratando de agregar un tensor 2d a cada tensor 2d de un tensor 3d. Digamos que tengo un tensor a
con forma (2,3,2)
y un tensor b
con forma (2,2)
.
a = [[[1,2],
[1,2],
[1,2]],
[[3,4],
[3,4],
[3,4]]]
b = [[1,2], [3,4]]
#the result i want to get
a[:, 0, :] + b
a[:, 1, :] + b
a[:, 2, :] + b
Quiero saber si hay un método en pytorch que pueda hacer esto.
3 respuestas
La forma más eficiente de hacer esto sería agregar una segunda dimensión adicional a b
y usar la transmisión para agregar:
a = torch.Tensor([[[1,2],[1,2],[1,2]],[[3,4],[3,4],[3,4]]])
b = torch.Tensor([[1,2],[3,4]])
a += b.unsqueeze(1)
La solución propuesta por @SinaAfrooze es correcta pero no es la más rápida.
TL; DR: torch.add(b.unsqueeze(1), a)
es más rápido.
Puntos de referencia:
import torch
a = torch.Tensor([[[1,2],[1,2],[1,2]],[[3,4],[3,4],[3,4]]])
b = torch.Tensor([[1,2],[3,4]])
z = a + b.unsqueeze(1)
%timeit k = torch.add(b.unsqueeze(1), a)
4.08 µs ± 25.8 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit z = a + b.unsqueeze(1)
4.14 µs ± 29 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
torch.equal(k, z)
True
Que quieres hacer:
a = [[[1,2],
[1,2],
[1,2]],
[[3,4],
[3,4],
[3,4]]]
b = [[1,2], [3,4]]
a = torch.LongTensor(a)
b = torch.LongTensor(b)
a[:, 0, :] += b
a[:, 1, :] += b
a[:, 2, :] += b
print(a)
Salida:
tensor([[[2, 4],
[2, 4],
[2, 4]],
[[6, 8],
[6, 8],
[6, 8]]])
Puede hacer lo mismo de la siguiente manera.
a = (a.transpose(0, 1) + b).transpose(0, 1)
print(a) # prints the same tensor
Preguntas relacionadas
Nuevas preguntas
python
Python es un lenguaje de programación multipropósito, de tipificación dinámica y de múltiples paradigmas. Está diseñado para ser rápido de aprender, comprender y usar, y hacer cumplir una sintaxis limpia y uniforme. Tenga en cuenta que Python 2 está oficialmente fuera de soporte a partir del 01-01-2020. Aún así, para preguntas de Python específicas de la versión, agregue la etiqueta [python-2.7] o [python-3.x]. Cuando utilice una variante de Python (por ejemplo, Jython, PyPy) o una biblioteca (por ejemplo, Pandas y NumPy), inclúyala en las etiquetas.