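PyTorch has the same feature: indexing a tensor with an integer (LongTensor) index tensor gathers the corresponding rows, and repeated indices return repeated rows.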
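```python
import torch

# a: a 3x4 float tensor holding the values 0..11
a = torch.Tensor([x for x in range(12)]).view(3, 4)
# index: the pair [0, 1] repeated 5 times, flattened to a LongTensor of length 10
index = torch.LongTensor([[0, 1] for x in range(5)]).view(-1)
print("a:\n", a)
print("index:\n", index)
```

Output (older PyTorch releases print the `[torch.FloatTensor of size ...]` form shown here; current releases print `tensor([...])` with the same values):

```
a:
  0   1   2   3
  4   5   6   7
  8   9  10  11
[torch.FloatTensor of size 3x4]

index:
 0
 1
 0
 1
 0
 1
 0
 1
 0
 1
[torch.LongTensor of size 10]
```

```python
# Each entry of index selects one row of a, duplicates included
print(a[index], a[index].size())
```

```
 0  1  2  3
 4  5  6  7
 0  1  2  3
 4  5  6  7
 0  1  2  3
 4  5  6  7
 0  1  2  3
 4  5  6  7
 0  1  2  3
 4  5  6  7
[torch.FloatTensor of size 10x4]
 torch.Size([10, 4])
```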
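To see what `a[index]` computes, here is a minimal sketch assuming a current PyTorch release (it uses `torch.arange` and `torch.tensor` instead of the legacy `torch.Tensor`/`torch.LongTensor` constructors). It checks that integer-tensor indexing along dim 0 gives the same result as `torch.index_select(a, 0, index)`, and that row `i` of the result is a copy of row `index[i]` of `a`, so repeated indices simply repeat rows.

```python
import torch

# Same data as above: a is 3x4 holding 0..11, index repeats rows 0 and 1 five times
a = torch.arange(12, dtype=torch.float32).view(3, 4)
index = torch.tensor([0, 1] * 5)  # int64 (Long) dtype by default

gathered = a[index]  # advanced indexing along dim 0

# index_select along dim 0 produces the same tensor
same = torch.index_select(a, 0, index)
assert torch.equal(gathered, same)

# Row i of the result equals row index[i] of a
for i, row in enumerate(index.tolist()):
    assert torch.equal(gathered[i], a[row])

print(gathered.shape)  # torch.Size([10, 4])
```

Note that advanced indexing returns a new tensor rather than a view, so writing into `gathered` does not modify `a`.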