python - 在python中,不使用for循环,如何加速切片

我尝试加速以下python代码:


import torch


import numpy as np



A = torch.zeros(11, 16, 64)


B = torch.randn(11, 9, 64)



indices = np.random.randint(0,9,(11,16))



for i in range(len(A)):


 A[i,:,:] = B[i,indices[i],:]



有没有很好的方法不使用for循环?它很慢,特别是在处理大数据时,谢谢!

时间:

举个例子


# create a (11, 1) range array that broadcasts with indices which is (11, 16)


indices0 = np.expand_dims(np.arange(indices.shape[0]), 1)


A = B[indices0, indices, :]



或者如果indicestorch.LongTensor


indices0 = torch.arange(indices.shape[0]).unsqueeze(1)


A = B[indices0, indices, :]



...