torch.gather

Function definition

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor

Gathers values along an axis specified by dim.

For a 3-D tensor, the output is specified by:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2
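To make the formula concrete, here is a minimal pure-Python sketch that applies it to nested lists. The helper name gather3d is hypothetical, and this is not how PyTorch implements gather: it only handles 3-D inputs and ignores sparse_grad and out.

```python
from typing import List

Tensor3D = List[List[List[int]]]


def gather3d(inp: Tensor3D, dim: int, index: Tensor3D) -> Tensor3D:
    """Apply the 3-D gather formula to nested lists (illustrative sketch)."""
    if dim not in (0, 1, 2):
        raise ValueError("dim must be 0, 1 or 2 for a 3-D input")
    out: Tensor3D = []
    # The output has the shape of index, so iterate over index's extents.
    for i, plane in enumerate(index):
        out_plane = []
        for j, row in enumerate(plane):
            out_row = []
            for k, idx in enumerate(row):
                if idx < 0:
                    # Reject negatives instead of letting Python wrap them.
                    raise IndexError(f"negative index {idx}")
                # Replace the coordinate at position `dim` with index[i][j][k];
                # the other two coordinates stay as (i, j, k).
                coord = [i, j, k]
                coord[dim] = idx
                a, b, c = coord
                out_row.append(inp[a][b][c])
            out_plane.append(out_row)
        out.append(out_plane)
    return out


inp = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
print(gather3d(inp, 2, [[[1, 0], [0, 0]], [[1, 1], [0, 1]]]))
# [[[2, 1], [3, 3]], [[6, 6], [7, 8]]]
print(gather3d(inp, 0, [[[1, 0], [0, 1]]]))
# [[[5, 2], [3, 8]]]
```

Because the output takes its shape from index, the dim=0 call returns a 1 x 2 x 2 result even though inp is 2 x 2 x 2.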

Function parameters

  • input (Tensor) – the source tensor
  • dim (int) – the axis along which to index
  • index (LongTensor) – the indices of elements to gather
  • sparse_grad (bool, optional) – If True, gradient w.r.t. input will be a sparse tensor.
  • out (Tensor, optional) – the destination tensor

Parameter notes

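  • input and index must have the same number of dimensions, and index.size(d) <= input.size(d) is required for every dimension d != dim.
  • out will have the same shape as index.
  • input and index do not broadcast against each other.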
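The shape rules above can be checked ahead of time on plain shape tuples. The sketch below uses a hypothetical helper, gather_output_shape, and raises ValueError when a rule is violated; PyTorch reports these cases with its own error type and message, which this sketch does not try to reproduce.

```python
from typing import Sequence, Tuple


def gather_output_shape(
    input_shape: Sequence[int], index_shape: Sequence[int], dim: int
) -> Tuple[int, ...]:
    """Check the gather shape rules and return the output shape (sketch)."""
    ndim = len(input_shape)
    if len(index_shape) != ndim:
        raise ValueError(f"index has {len(index_shape)} dims but input has {ndim}")
    if not -ndim <= dim < ndim:
        raise ValueError(f"dim {dim} out of range for a {ndim}-D input")
    dim %= ndim  # normalize a negative dim, e.g. -1 -> ndim - 1
    for d, (i_size, x_size) in enumerate(zip(index_shape, input_shape)):
        # No broadcasting: each non-gather dimension of index must fit in input.
        if d != dim and i_size > x_size:
            raise ValueError(
                f"index.size({d}) = {i_size} exceeds input.size({d}) = {x_size}"
            )
    # index.size(dim) itself is not constrained; out simply takes index's shape.
    return tuple(index_shape)


print(gather_output_shape((3, 4), (3, 10), 1))  # (3, 10)
print(gather_output_shape((3, 4), (2, 1), -1))  # (2, 1)
```

Note that a size-1 dimension in index is not expanded to match input: the result keeps index's shape exactly.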

Example

>>> t = torch.tensor([[1, 2], [3, 4]])
>>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))
tensor([[ 1,  1],
        [ 4,  3]])
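The example can be traced by hand with the 2-D analogue of the formula, out[i][j] = input[i][index[i][j]] for dim == 1. The sketch below uses a hypothetical helper, gather2d, on plain nested lists (not PyTorch code); it reproduces the result above, applies the same index along dim 0, and shows a common one-element-per-row selection.

```python
from typing import List

Matrix = List[List[int]]


def gather2d(inp: Matrix, dim: int, index: Matrix) -> Matrix:
    """2-D analogue of the gather formula, using nested lists (sketch)."""
    if dim not in (0, 1):
        raise ValueError("dim must be 0 or 1 for a 2-D input")
    out: Matrix = []
    for i, row in enumerate(index):
        out_row = []
        for j, idx in enumerate(row):
            if idx < 0:
                # Reject negatives instead of letting Python wrap them.
                raise IndexError(f"negative index {idx}")
            # dim == 1: out[i][j] = inp[i][index[i][j]]
            # dim == 0: out[i][j] = inp[index[i][j]][j]
            out_row.append(inp[i][idx] if dim == 1 else inp[idx][j])
        out.append(out_row)
    return out


t = [[1, 2], [3, 4]]
print(gather2d(t, 1, [[0, 0], [1, 0]]))  # [[1, 1], [4, 3]]
print(gather2d(t, 0, [[0, 0], [1, 0]]))  # [[1, 2], [3, 2]]
print(gather2d(t, 1, [[1], [0]]))        # [[2], [3]]
```

In PyTorch the last call corresponds to torch.gather(t, 1, torch.tensor([[1], [0]])), which returns a tensor of shape (2, 1) holding 2 and 3, since out takes the shape of index.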