[Python Code Tips 2] Array Index Expansion Trick

xiaoxiao  2021-02-28

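Indexing a NumPy array with an integer array selects whole rows along the first axis. Indices may repeat, so the result can have more rows than the source array: below, `index` lists rows 0 and 1 five times, and `a[index]` stacks those rows into a 10×4 array.

```python
import numpy as np

a = np.arange(12).reshape(3, 4)                         # 3x4 array holding 0..11
index = np.array([[0, 1] for _ in range(5)]).flatten()  # [0 1 0 1 0 1 0 1 0 1]
print("a:")
print(a, a.shape)
print("index:")
print(index, index.shape)
print(a[index], a[index].shape)  # each entry of index picks one row of a
```

Output:

```
a:
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]] (3, 4)
index:
[0 1 0 1 0 1 0 1 0 1] (10,)
[[0 1 2 3]
 [4 5 6 7]
 [0 1 2 3]
 [4 5 6 7]
 [0 1 2 3]
 [4 5 6 7]
 [0 1 2 3]
 [4 5 6 7]
 [0 1 2 3]
 [4 5 6 7]] (10, 4)
```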
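The same row gather can be written in a few other ways. The sketch below is an illustration added here, not part of the original post: it checks that `a[index]` agrees with `np.tile` and `np.take`, shows how `np.repeat` builds an index that repeats each row `k` times, and confirms that integer-array indexing returns a copy rather than a view. The names `cycled`, `k`, `repeated`, and `expanded` are hypothetical.

```python
import numpy as np

# Same 3x4 array as in the example above.
a = np.arange(12).reshape(3, 4)

# The index from above: rows 0 and 1, cycled five times.
cycled = np.tile([0, 1], 5)  # [0 1 0 1 0 1 0 1 0 1]

# Gathering with a cycled index equals tiling the selected rows.
assert np.array_equal(a[cycled], np.tile(a[:2], (5, 1)))

# np.take along axis 0 performs the same row gather as a[cycled].
assert np.array_equal(np.take(a, cycled, axis=0), a[cycled])

# Repeating each row k times in place: build the index with np.repeat.
k = 2
repeated = np.repeat(np.arange(a.shape[0]), k)  # [0 0 1 1 2 2]
expanded = a[repeated]
assert np.array_equal(expanded, np.repeat(a, k, axis=0))

# Integer-array (advanced) indexing returns a copy, so writing to the
# result leaves the source array unchanged.
expanded[0, 0] = 100
print(a[0, 0], expanded.shape)  # 0 (6, 4)
```

`np.repeat` keeps copies of a row adjacent, while `np.tile` cycles through the whole selection; which ordering is right depends on what the downstream code expects.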

PyTorch tensors support the same kind of indexing.

```python
import torch

a = torch.arange(12, dtype=torch.float32).view(3, 4)       # float tensor (FloatTensor)
index = torch.tensor([[0, 1] for _ in range(5)]).view(-1)  # int64 indices (LongTensor)
print("a:")
print(a)
print("index:")
print(index)
print(a[index])
print(a[index].size())
```

Output:

```
a:
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.]])
index:
tensor([0, 1, 0, 1, 0, 1, 0, 1, 0, 1])
tensor([[0., 1., 2., 3.],
        [4., 5., 6., 7.],
        [0., 1., 2., 3.],
        [4., 5., 6., 7.],
        [0., 1., 2., 3.],
        [4., 5., 6., 7.],
        [0., 1., 2., 3.],
        [4., 5., 6., 7.],
        [0., 1., 2., 3.],
        [4., 5., 6., 7.]])
torch.Size([10, 4])
```
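A similar sketch for PyTorch, again an addition rather than the original author's code, compares `a[index]` with `torch.index_select` and `Tensor.repeat_interleave`, and checks how gradients flow through a gather with repeated indices. It assumes a reasonably recent PyTorch release where `torch.tensor` and `repeat_interleave` are available; the names `selected`, `k`, `row_ids`, and `w` are hypothetical.

```python
import torch

a = torch.arange(12, dtype=torch.float32).view(3, 4)
index = torch.tensor([0, 1]).repeat(5)  # tensor([0, 1, 0, 1, 0, 1, 0, 1, 0, 1])

# index_select along dim 0 gathers the same rows as a[index].
selected = torch.index_select(a, 0, index)
assert torch.equal(selected, a[index])

# repeat_interleave repeats each row k times; indexing with repeated
# row ids produces the same tensor.
k = 2
row_ids = torch.arange(a.size(0)).repeat_interleave(k)  # tensor([0, 0, 1, 1, 2, 2])
assert torch.equal(a[row_ids], a.repeat_interleave(k, dim=0))

# Gradients flow back through the gather: a source row receives one
# gradient contribution for every time it was selected.
w = a.clone().requires_grad_(True)
w[index].sum().backward()
print(w.grad[:, 0])  # tensor([5., 5., 0.])
```

Rows 0 and 1 each appear five times in `index` and row 2 never does, which is why their accumulated gradients are 5, 5, and 0.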
When reposting, please credit the original URL: https://www.6miu.com/read-27855.html
