本文介绍 pytorch里面的torch.gather操作

pytorch里面的torch.gather操作

This article was original written by Jin Tian, welcome re-post, first come with https://jinfagang.github.io . but please keep this copyright info, thanks, any question could be asked via wechat: jintianiloveu

torch.gather 只是一个引子,别看它简单,但能引出很多问题。我们先来看看,它是如何工作的。假如我们有一个矩阵:

[[34, 4, 6],
[45, 6, 7]]

我们想要对它每一个位置的点进行重新排列应该怎么做呢?比如我要得到这么一个矩阵:

[[4, 6, 34],
[6, 7, 45]]

可以看到,我把每一行的(此时的axis=1)位置进行了变换。具体来说,用torch.gather可以做这个事情:

r = torch.gather(a, 1, torch.tensor([[1, 2, 0], [1, 2, 0]]))

说变了,就是用一个矩阵来对它进行重排。那么到底在什么场合我们会用到这个函数呢?

其实一个很明显的作用就是在分类问题中,通过gather方法可以从一个矩阵里面挑选出最大值来完成分类任务。

最近遇到一个onnx2trt的问题,但是本质上并不是由于它造成的,跟gather没有太大的关系。