转自:https://www.cnblogs.com/catnofishing/p/13287322.html
-
detach到底有什么作用呢
首先要明确一个意识:pytorch是动态计算图,每次backward后,本次计算图自动销毁,但是计算图中的节点都还保留。
方向传播直到叶子节点为止,否者一直传播,直到找到叶子节点
我的答案是有用,但根本不是为了防止梯度开销过大(注释真的害人不浅啊),detach的真正作用是梯度节流,防止反向传播传播到隐藏状态时,因为上次小批量方向传播计算图的销毁导致继续向下传播而引起报错。啥意思呢,我以连续两次小批量迭代举例:
第一次小批量迭代,H0 是叶子节点,因为他没经过任何计算。剩余H1是非叶子节点。在第一次方向传播后,第一次的计算图已经销毁,但是节点数据仍然存在。
第二次小批量迭代,第一次批量迭代的最后时间节点的隐藏状态H2 成为第二批次小的初始隐藏状态( H0(第二次) = H2(第一次) ),这样第二次在方向传播时,当传播到H0时,发现H0 是 分支节点(grad_fn+requires_grad) ,就会继续向下传播直到找到叶子节点为止,但是可惜的是H0 之后的计算图(即第一次小批量的计算图)已经销毁,传播发生中断,因此就会导致出错。而使用detach之后,H0 自然与上次的计算图没有任何关系,H0自身变为叶子节点,这样传播到H0时自然就结束了。
好了,验证我所说的吧。
- 首先,不使用detach,会导致传播报错
将detach 操作删除
运行结果:
看到没,第二次在方向传播时出错了吧
-
使用detach,防止出错,并使H0 变为叶子节点
代码更改如下:
结果:全是true
综上:detach在这里作用,大家明白不,喜欢点个赞!!!!
至于书中为什么将detach的作用注释成那样呢,我想作者在翻译成torch的时候,忽略了MAXNET框架(原书是maxnet框架)与pytorch的区别。 MaxNet是支持静态图的,所以对于MaxNet ,detach的作用是与注释相同的,但是pytorch是动态图,所以作用在这里就不同了!!!