torch.where 会根据条件去选择元素,返回一个tensor。¶
torch.where(condition, input, other, *, out=None) → Tensor
condition 是一个BoolTensor
input 和 other 可以常量,也可以是张量,
返回的张量形状与condition相同
In [2]:
import torch
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"a = torch.randint(0,10,[2,3])
torch.where(a > 5, 6,7)
Out[2]:
tensor([[6, 7, 6],[6, 7, 7]])
当input为张量的情况
In [9]:
a = torch.randint(0,10,[2,3])
a
b = torch.rand([1,3]) # b张量会被广播成(2,3)的形状
b
torch.where(a > 5, b,7)
Out[9]:
tensor([[7, 4, 1],[1, 6, 3]])
Out[9]:
tensor([[0.3639, 0.3413, 0.0434]])
Out[9]:
tensor([[0.3639, 7.0000, 7.0000],[7.0000, 0.3413, 7.0000]])
当input 和 other 都为张量的时候
In [11]:
a = torch.randint(0,10,[2,3])
a
b = torch.rand([1,3]) # b张量会被广播成(2,3)的形状
b
c = torch.randint(10,20,[1,3])
c
torch.where(a > 5, b,c)
Out[11]:
tensor([[7, 1, 5],[1, 7, 4]])
Out[11]:
tensor([[0.4111, 0.3227, 0.3590]])
Out[11]:
tensor([[19, 16, 19]])
Out[11]:
tensor([[ 0.4111, 16.0000, 19.0000],[19.0000, 0.3227, 19.0000]])
当不传入 input 和 other¶
torch.where(condition) 会返回满足 condition 为非零元素的坐标.
In [26]:
a = torch.randint(0,5,[2,3])
a
b,c = torch.where(a) # b 和 c 的长度相等,表示非零元素的数量。
b #行索引
c #列索引
for b_index in range(len(b)):a[b[b_index], c[b_index]]
Out[26]:
tensor([[4, 4, 3],[0, 1, 0]])
Out[26]:
tensor([0, 0, 0, 1])
Out[26]:
tensor([0, 1, 2, 1])
Out[26]:
tensor(4)
Out[26]:
tensor(4)
Out[26]:
tensor(3)
Out[26]:
tensor(1)