我们在做深度学习时经常不免需要处理和保留topk个元素,如何操作比较顺利呢,我来教大家两种方法。

方法一:使用topk+full_like+scatter填充

x = torch.randn(2, 3, 4)
value, index = x.topk(2, dim=-1)
mask = torch.full_like(x, float('-inf'))
mask = mask.scatter(-1, index, value)
print(mask)

方法二:使用索引

x = torch.randn(2, 3, 4)
value, index = x.topk(2, dim=-1)
x[x < value[:, :, -1].unsqueeze(2)] = float('-inf'))