个人的简单理解:
repeat
可以理解为多次复制张量后在指定维度上concate
上去,即x.repeat(n,dim=k)
等价成torch.cat([x for _ in range(n)],dim=k)
repeat_interleave
实际上等价于repeat
在高一维的基础上运算后再view
,即x.repeat_interleave(n,dim=k)
等价成x.repeat(n,dim=k+1).view(N0, N1, ..., n*Nk, Nk+1, ...)
,其中N0
,N1
, Nk
, Nk+1
分别指x
的第0
,1
,k
,k+1
维的长度。当k
是最后一维时自动unsqueeze(-1)