PyTorch version:
```python
import torch

x1 = torch.tensor([[11, 21, 31], [21, 31, 41]], dtype=torch.int)
x1.shape  # torch.Size([2, 3])
# x2
x2 = torch.tensor([[12, 22, 32], [22, 32, 42]], dtype=torch.int)
x2.shape  # torch.Size([2, 3])

inputs = [x1, x2]
# print(inputs)
# dim=0 joins along the rows, so the result has shape torch.Size([4, 3])
output = torch.cat(inputs, dim=0)
print(output)
```
Equivalent NumPy version:
```python
import numpy as np

x1 = np.array([[11, 21, 31], [21, 31, 41]], dtype=np.int32)
x1.shape  # (2, 3)
# x2
x2 = np.array([[12, 22, 32], [22, 32, 42]], dtype=np.int32)
x2.shape  # (2, 3)

inputs = [x1, x2]
# print(inputs)
# axis=0 joins along the rows, so the result has shape (4, 3)
output = np.concatenate(inputs, axis=0)
print(output)
```
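To see how the axis argument changes the result, here is a small sketch using the same two arrays. It concatenates once along axis=0 (rows) and once along axis=1 (columns). It also shows that every dimension except the concatenation axis must match; the extra (2, 2) array is an illustrative value not taken from the original example.

```python
import numpy as np

x1 = np.array([[11, 21, 31], [21, 31, 41]], dtype=np.int32)
x2 = np.array([[12, 22, 32], [22, 32, 42]], dtype=np.int32)

# axis=0 stacks the rows: (2, 3) + (2, 3) -> (4, 3)
rows = np.concatenate([x1, x2], axis=0)
# axis=1 appends the columns: (2, 3) + (2, 3) -> (2, 6)
cols = np.concatenate([x1, x2], axis=1)
print(rows.shape, cols.shape)  # (4, 3) (2, 6)

# Illustrative mismatch: a (2, 2) array cannot be stacked row-wise onto a
# (2, 3) array, because the non-concatenated axis (columns) differs.
mismatched = np.zeros((2, 2), dtype=np.int32)
try:
    np.concatenate([x1, mismatched], axis=0)
except ValueError as err:
    print("ValueError:", err)
```

torch.cat behaves the same way with dim in place of axis: inputs must agree on every dimension except dim.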
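So torch.cat corresponds to np.concatenate in NumPy, with torch.cat's dim argument playing the role of np.concatenate's axis argument. NumPy itself has no cat function.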
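If you are porting PyTorch code to NumPy and want to keep the torch.cat call shape, a thin wrapper is enough. The helper name cat below is hypothetical (it is not part of NumPy); it only forwards dim to np.concatenate's axis. The sketch also checks that NumPy exposes no cat attribute.

```python
import numpy as np

def cat(tensors, dim=0):
    """Hypothetical torch.cat-style wrapper: forwards dim to np.concatenate's axis."""
    return np.concatenate(tensors, axis=dim)

x1 = np.array([[11, 21, 31], [21, 31, 41]], dtype=np.int32)
x2 = np.array([[12, 22, 32], [22, 32, 42]], dtype=np.int32)

print(hasattr(np, "cat"))         # False: NumPy has no cat function
print(cat([x1, x2]).shape)        # (4, 3), same as dim=0 in torch.cat
print(cat([x1, x2], dim=-1).shape)  # (2, 6); negative dims count from the end
```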