# torch.gather

### Function definition

``````
torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
``````

Gathers values along the axis specified by `dim`. For a 3-D tensor, the output is specified as follows:

``````
out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2
``````
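The formulas above can be made concrete with an explicit triple loop. In each output position `(i, j, k)`, the coordinate along `dim` is replaced by the value stored in `index` at that same position. This is a minimal sketch: `naive_gather` is a hypothetical helper written for illustration, not part of PyTorch. It only handles 3-D tensors and is far slower than the built-in.

```python
import torch

def naive_gather(input: torch.Tensor, dim: int, index: torch.Tensor) -> torch.Tensor:
    """Hypothetical loop-based reference for torch.gather on 3-D tensors."""
    assert input.dim() == index.dim() == 3, "this sketch only handles 3-D tensors"
    out = torch.empty(index.shape, dtype=input.dtype)  # out has the shape of index
    I, J, K = index.shape
    for i in range(I):
        for j in range(J):
            for k in range(K):
                pos = [i, j, k]                  # coordinate of the output element
                pos[dim] = int(index[i, j, k])   # swap in the index value along `dim`
                out[i, j, k] = input[pos[0], pos[1], pos[2]]
    return out

# Compare against torch.gather for every choice of dim.
torch.manual_seed(0)
x = torch.arange(2 * 3 * 4).reshape(2, 3, 4)
for d in range(3):
    idx = torch.randint(0, x.size(d), (2, 3, 4))  # valid indices along dimension d
    assert torch.equal(naive_gather(x, d, idx), torch.gather(x, d, idx))
```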

### Function parameters

- `input` (Tensor) – the source tensor
- `dim` (int) – the axis along which to index
- `index` (LongTensor) – the indices of elements to gather
- `sparse_grad` (bool, optional) – if `True`, the gradient w.r.t. `input` will be a sparse tensor
- `out` (Tensor, optional) – the destination tensor
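The `sparse_grad` flag changes only the layout of the gradient, not its values. The sketch below assumes a standard CPU build of PyTorch. It backpropagates through the same `gather` twice: once with a dense (strided) gradient and once with a sparse COO gradient. It then checks that both hold the same numbers. The helper `gather_grad` is hypothetical and exists only for this illustration.

```python
import torch

def gather_grad(sparse_grad: bool) -> torch.Tensor:
    """Hypothetical helper: gradient of sum(gather(x, 1, idx)) w.r.t. x."""
    x = torch.zeros(3, 4, requires_grad=True)
    idx = torch.tensor([[0, 0], [2, 3], [1, 1]])  # row 0 and row 2 pick one column twice
    torch.gather(x, 1, idx, sparse_grad=sparse_grad).sum().backward()
    return x.grad

dense = gather_grad(sparse_grad=False)
sparse = gather_grad(sparse_grad=True)

assert not dense.is_sparse and sparse.is_sparse
# Repeated picks accumulate: column 0 of row 0 receives a gradient of 2.
assert torch.equal(sparse.to_dense(), dense)
```

A sparse gradient mainly pays off when `input` is large and only a few of its entries are gathered.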

### Notes on the parameters

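- `input` and `index` must have the same number of dimensions, and `index.size(d) <= input.size(d)` must hold for every dimension `d != dim`.
- `out` has the same shape as `index`.
- `input` and `index` do not broadcast against each other.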

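The shape rules above can be checked directly. In this sketch the tensors are made up for illustration:

- `index` is smaller than `input` along the non-gathered dimension (2 rows versus 3).
- `index` is wider than `input` along `dim` (5 columns versus 4). This is allowed because only `d != dim` is constrained; the index values themselves must still lie in `[0, input.size(dim))`.
- A 1-D `index` against a 2-D `input` is rejected instead of being broadcast.

```python
import torch

x = torch.arange(12).reshape(3, 4)  # input, shape (3, 4)

# Same number of dims as x; along d = 0 (d != dim) it is no larger than x: 2 <= 3.
# Along dim = 1 it may have any size, as long as every value is in [0, 4).
idx = torch.tensor([[3, 0, 3, 0, 1],
                    [2, 2, 1, 1, 0]])  # shape (2, 5)
out = torch.gather(x, 1, idx)
assert out.shape == idx.shape  # out always takes the shape of index

# No broadcasting: an index with fewer dimensions than input raises an error.
try:
    torch.gather(x, 1, torch.tensor([0, 1]))
    rejected = False
except RuntimeError:
    rejected = True
assert rejected
```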
### Example

``````
>>> import torch
>>> t = torch.tensor([[1, 2], [3, 4]])
>>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))
tensor([[1, 1],
        [4, 3]])
``````
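A common use of `gather` with `dim=1` is picking one entry per row, for example the score of each sample's target class. In the sketch below, `scores` and `labels` are made-up data. Because `index` must have as many dimensions as `input`, the labels get a trailing axis before the call, and that axis is removed afterwards. The result is checked against ordinary advanced indexing, which performs the same selection.

```python
import torch

scores = torch.tensor([[0.1, 0.7, 0.2],
                       [0.5, 0.3, 0.2],
                       [0.2, 0.2, 0.6]])  # hypothetical per-class scores, shape (N, C)
labels = torch.tensor([1, 0, 2])         # hypothetical target class per row, shape (N,)

# labels.unsqueeze(1) has shape (N, 1), so gather returns shape (N, 1); squeeze to (N,).
picked = torch.gather(scores, 1, labels.unsqueeze(1)).squeeze(1)

# Equivalent selection with advanced indexing.
assert torch.equal(picked, scores[torch.arange(len(labels)), labels])
```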