Problem
Problem link
A recurrent neural network (RNN) is a neural network for sequence data; its defining feature is that the hidden state computed at one time step is fed back into the network at the next time step.
BPTT (backpropagation through time) is the standard training method for RNNs. The update procedure is broadly the same as backpropagation in a feed-forward network, except that the gradient of each shared weight must be accumulated across all time steps. The full derivation can be found in the references and is not repeated here.
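Still, the heart of BPTT is worth stating: because a recurrent weight influences the loss through every hidden state, its gradient is a sum of per-step contributions. As a sketch, in the notation introduced just below (writing \(L_t\) for the loss at step \(t\) and \(\odot\) for element-wise multiplication):

\[ \frac{\partial L}{\partial W_{hh}} = \sum_t \frac{\partial L}{\partial h_t} \frac{\partial h_t}{\partial W_{hh}}, \qquad \frac{\partial L}{\partial h_t} = \frac{\partial L_t}{\partial h_t} + W_{hh}^{\top} \left[ (1 - h_{t+1}^2) \odot \frac{\partial L}{\partial h_{t+1}} \right] \]

The second term carries the gradient back from step \(t+1\) through the tanh nonlinearity; this recursion over time is what distinguishes BPTT from plain backpropagation.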
The formulas used in this problem are the following:
\[ h_t = \tanh(W_{xh} x_t + W_{hh} h_{t-1} + b_h) \]
\[ y_t = W_{hy} h_t + b_y \]
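To make the shapes concrete, here is one forward step in NumPy; the layer sizes are arbitrary placeholders, not values from the problem:

import numpy as np

input_size, hidden_size, output_size = 3, 4, 2   # arbitrary example sizes

W_xh = np.random.randn(hidden_size, input_size) * 0.01
W_hh = np.random.randn(hidden_size, hidden_size) * 0.01
W_hy = np.random.randn(output_size, hidden_size) * 0.01
b_h = np.zeros((hidden_size, 1))
b_y = np.zeros((output_size, 1))

x_t = np.random.randn(input_size, 1)   # input at time t, as a column vector
h_prev = np.zeros((hidden_size, 1))    # h_{t-1}; zero at the first step

h_t = np.tanh(W_xh @ x_t + W_hh @ h_prev + b_h)   # hidden-state update
y_t = W_hy @ h_t + b_y                            # linear readout, shape (output_size, 1)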
The weights and biases are updated by gradient descent:
\[ W_{xh} \leftarrow W_{xh} - \eta \frac{\partial L}{\partial W_{xh}} \]
\[ W_{hh} \leftarrow W_{hh} - \eta \frac{\partial L}{\partial W_{hh}} \]
\[ W_{hy} \leftarrow W_{hy} - \eta \frac{\partial L}{\partial W_{hy}} \]
\[ b_h \leftarrow b_h - \eta \frac{\partial L}{\partial b_h} \]
\[ b_y \leftarrow b_y - \eta \frac{\partial L}{\partial b_y} \]
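The gradients themselves come from BPTT. The reference code below computes dy as predicted minus actual, which corresponds to a squared-error loss of the form \(L = \frac{1}{2}\sum_t \lVert \hat{y}_t - y_t \rVert^2\) (an assumption here; the problem statement fixes the actual loss). Under that loss, the accumulated gradients are:

\[ \frac{\partial L}{\partial W_{hy}} = \sum_t (\hat{y}_t - y_t) h_t^{\top}, \qquad \frac{\partial L}{\partial b_y} = \sum_t (\hat{y}_t - y_t) \]
\[ \delta_t = (1 - h_t^2) \odot \left( W_{hy}^{\top} (\hat{y}_t - y_t) + W_{hh}^{\top} \delta_{t+1} \right), \qquad \delta_{T+1} = 0 \]
\[ \frac{\partial L}{\partial W_{xh}} = \sum_t \delta_t x_t^{\top}, \qquad \frac{\partial L}{\partial W_{hh}} = \sum_t \delta_t h_{t-1}^{\top}, \qquad \frac{\partial L}{\partial b_h} = \sum_t \delta_t \]

These are exactly the quantities dW_hy, db_y, dh_raw, dW_xh, dW_hh, and db_h accumulated in the backward pass below.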
In this problem the learning rate \(\eta\) is fixed at 0.01.
The reference code is as follows:
import numpy as np

class SimpleRNN:
    def __init__(self, input_size, hidden_size, output_size):
        self.hidden_size = hidden_size
        # Small random init for weights, zeros for biases
        self.W_xh = np.random.randn(hidden_size, input_size) * 0.01
        self.W_hh = np.random.randn(hidden_size, hidden_size) * 0.01
        self.W_hy = np.random.randn(output_size, hidden_size) * 0.01
        self.b_h = np.zeros((hidden_size, 1))
        self.b_y = np.zeros((output_size, 1))

    def forward(self, x):
        h = np.zeros((self.hidden_size, 1))  # Initialize hidden state
        outputs = []
        self.last_inputs = []
        self.last_hiddens = [h]  # Cache h_0 .. h_T for BPTT
        for t in range(len(x)):
            self.last_inputs.append(x[t].reshape(-1, 1))
            h = np.tanh(np.dot(self.W_xh, self.last_inputs[t]) + np.dot(self.W_hh, h) + self.b_h)
            y = np.dot(self.W_hy, h) + self.b_y
            outputs.append(y)
            self.last_hiddens.append(h)
        self.last_outputs = outputs
        return np.array(outputs)

    def backward(self, x, y, learning_rate):
        # Accumulate gradients over all time steps
        dW_xh = np.zeros_like(self.W_xh)
        dW_hh = np.zeros_like(self.W_hh)
        dW_hy = np.zeros_like(self.W_hy)
        db_h = np.zeros_like(self.b_h)
        db_y = np.zeros_like(self.b_y)
        dh_next = np.zeros((self.hidden_size, 1))  # Gradient flowing back from step t+1
        for t in reversed(range(len(x))):
            dy = self.last_outputs[t] - y[t].reshape(-1, 1)  # (Predicted - Actual)
            dW_hy += np.dot(dy, self.last_hiddens[t + 1].T)
            db_y += dy
            dh = np.dot(self.W_hy.T, dy) + dh_next
            dh_raw = (1 - self.last_hiddens[t + 1] ** 2) * dh  # Derivative of tanh
            dW_xh += np.dot(dh_raw, self.last_inputs[t].T)
            dW_hh += np.dot(dh_raw, self.last_hiddens[t].T)
            db_h += dh_raw
            dh_next = np.dot(self.W_hh.T, dh_raw)
        # Update weights and biases
        self.W_xh -= learning_rate * dW_xh
        self.W_hh -= learning_rate * dW_hh
        self.W_hy -= learning_rate * dW_hy
        self.b_h -= learning_rate * db_h
        self.b_y -= learning_rate * db_y
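A minimal usage sketch follows, assuming inputs of shape (T, input_size) and targets of shape (T, output_size); the sizes, data, and epoch count are made up for illustration and are not from the problem:

np.random.seed(0)                            # reproducible toy run
rnn = SimpleRNN(input_size=2, hidden_size=8, output_size=1)

x = np.random.randn(5, 2)                    # toy sequence of T=5 steps
y = np.random.randn(5, 1)                    # toy targets

for epoch in range(200):
    rnn.forward(x)                           # caches activations needed by BPTT
    rnn.backward(x, y, learning_rate=0.01)

preds = rnn.forward(x).reshape(len(x), -1)
print(0.5 * np.sum((preds - y) ** 2))        # loss should have decreased over training

Note that forward must be called before each backward, since backward reads the inputs, hidden states, and outputs cached during the forward pass.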