目录
- 概
- 主要内容
Wang N., Choi J., Brand D., Chen C. and Gopalakrishnan K. Training deep neural networks with 8-bit floating point numbers. NeurIPS, 2018.
概
本文提出了一种 8-bit 的训练方式.
主要内容
-
本文想要实现 8-bit 的训练, 作者认为主要挑战是两个向量的点击 (各元素相乘再相加) 的过程会产生严重的 swamping 问题, 导致 smaller number 会被直接裁剪掉.
-
于是本文的做法是, 乘法依旧在 8-bit 上做, 然后 accumulation 过程在加法上实现. 主要通过两个手段实现这一目的:
- Chunk-based Accumulation: 两个长向量的点积分解为数个 chunk 的点积再相加的过程, 这个好处是每个 chunk 的点积不会产生很严重的 swamping 问题的, 然后会产生严重 swamping 情况的问题因为实际上使用的 FP16 格式进行, 所以能够规避;
- Stochastic Rounding: 随机 rounding 是一种有效的规避 swamping 的方法. 特别地, 这里采用的 floating point stochastic rounding: 对于一个浮点数 \(x = s \cdot 2^e \cdot (1 + m)\) (\((s, e, m)\) 分别表示 (符号, 指数, 尾数))
\[\text{Round}(x) =\left \{\begin{array}{ll}s \cdot 2^e \cdot (1 + \lfloor m \rfloor + \epsilon)& \text{with probability } \frac{m - \lfloor m \rfloor}{\epsilon}, \\s \cdot 2^e \cdot (1 + \lfloor m \rfloor)& \text{with probability } 1 - \frac{m - \lfloor m \rfloor}{\epsilon},\end{array} \right. \]其中 \(\epsilon=2^{-k}\).
注: 浮点数表示个数包括 (sign, exponent, mantissa), 本文所讨论的 FP8 为 (1, 5, 2) 格式, FP16 为 (1, 6, 9) 格式.