RE2 for Text Matching in Practice

Introduction

Today we implement RE2 for text matching. The model implementation follows the official code at https://github.com/alibaba-edu/simple-effective-text-matching-pytorch.

Model implementation

[Figure: the RE2 model architecture]

The RE2 architecture is shown in the figure above. Its input is two text segments, and all component parameters are shared between the two sides except those of the prediction layer and the alignment layer. The dashed box in the figure marks one block; N such blocks are stacked, the two text segments interact within each block through the alignment layer, and consecutive blocks are connected by augmented residual connections.

Below we implement it from the bottom up, following the official implementation.

Embedding

The embedding layer is simple: no character embeddings, just plain word embeddings.

# imports used throughout the model code
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.utils import weight_norm


class Embedding(nn.Module):
    def __init__(self, vocab_size: int, embedding_dim: int, dropout: float) -> None:
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x (Tensor): (batch_size, seq_len)

        Returns:
            Tensor: (batch_size, seq_len, embedding_dim)
        """
        return self.dropout(self.embedding(x))

Encoder

GeLU

First we implement GeLU, a smooth variant of ReLU that was later adopted in BERT. Its curve looks like this:

[Figure: the GELU activation curve]

class GeLU(nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        # tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
        return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
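
As a quick sanity check (my own addition, not from the original post), the tanh approximation should stay very close to PyTorch's exact F.gelu:

import torch
import torch.nn.functional as F

x = torch.linspace(-3, 3, steps=101)
approx = 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
# the maximum deviation should be tiny (on the order of 1e-3 or smaller)
print(torch.max(torch.abs(approx - F.gelu(x))))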

Linear

A rewritten linear layer: the weight is wrapped in weight_norm, and activations controls whether a GeLU activation is appended.

class Linear(nn.Module):
    def __init__(self, in_features: int, out_features: int, activations: bool = True) -> None:
        super().__init__()
        linear = nn.Linear(in_features, out_features)
        modules = [weight_norm(linear)]
        if activations:
            modules.append(GeLU())
        self.model = nn.Sequential(*modules)
        self.reset_parameters(activations)

    def reset_parameters(self, activations: bool) -> None:
        linear = self.model[0]
        nn.init.normal_(
            linear.weight,
            std=math.sqrt((2.0 if activations else 1.0) / linear.in_features),
        )
        nn.init.zeros_(linear.bias)

    def forward(self, x):
        return self.model(x)

nn.Conv1d

In the compare-aggregate model write-up we already covered torch.nn.Conv2d and some CNN basics in detail.

Here we use torch.nn.Conv1d to implement the paper's multi-layer convolutional network, so this subsection looks at Conv1d in detail.

The main parameters of torch.nn.Conv1d:

in_channels: number of input channels; for text this is the embedding dimension
out_channels: number of output channels; each filter produces one output channel
kernel_size: size of the convolution kernel
stride: stride of the convolution, default 1
padding: zero-padding added to both sides of the input, default 0
bias (bool): whether to add a bias, default True

Let's illustrate the computation with an example. Suppose the input is "W B G 是 冠 军" ("WBG is the champion"), with randomly initialized embeddings:

(Hoping WBG can beat T1 in the S13 final this afternoon.)

import numpy as np
import torch
import torch.nn as nn

batch_size = 1
seq_len = 6
embed_size = 3

input_tensor = torch.rand(batch_size, seq_len, embed_size)
print(input_tensor)
print(input_tensor.shape)

tensor([[[0.9291, 0.8333, 0.5160],
         [0.0543, 0.8149, 0.5704],
         [0.7831, 0.2263, 0.9279],
         [0.0898, 0.0758, 0.4401],
         [0.4321, 0.2098, 0.6666],
         [0.6183, 0.0609, 0.2330]]])
torch.Size([1, 6, 3])

Each character now corresponds to a 3-dimensional embedding vector:

W — [0.9291, 0.8333, 0.5160]
B — [0.0543, 0.8149, 0.5704]
G — [0.7831, 0.2263, 0.9279]
是 — [0.0898, 0.0758, 0.4401]
冠 — [0.4321, 0.2098, 0.6666]
军 — [0.6183, 0.0609, 0.2330]

However, Conv1d expects in_channels (here the embedding dimension) to be the first dimension after batch_size, so we permute the shape from [1, 6, 3] to [1, 3, 6]:

input_tensor = input_tensor.permute(0, 2, 1)
# (batch_size, embed_size, seq_len)

Illustrated below:

[Figure: the permuted (1, 3, 6) tensor, one row per embedding dimension]

(Before this article was even published, WBG got swept 3:0.)

Next we define a 1D convolution:

input_channels = embed_size  # equals the embedding size
output_channels = 2
kernel_size = 2

conv1d = nn.Conv1d(in_channels=input_channels, out_channels=output_channels, kernel_size=kernel_size)

We can print the filter weight matrix:

print(conv1d.weight)
print(conv1d.weight.shape)

Parameter containing:
tensor([[[ 0.0025,  0.3353],
         [ 0.0620, -0.3916],
         [-0.3458, -0.0610]],

        [[-0.1731, -0.0787],
         [-0.0419, -0.2555],
         [-0.1429,  0.1656]]], requires_grad=True)
torch.Size([2, 3, 2])

The filter weight has shape (2, 3, 2): shape[0]=2 is the number of filters; shape[1]=3 is the input embedding size; shape[2]=2 is the filter size.

A bias is added by default, one per filter:

print(conv1d.bias)
print(conv1d.bias.shape)

Parameter containing:
tensor([ 0.3760, -0.2881], requires_grad=True)
torch.Size([2])

We have two filters here, hence two biases. Since kernel_size=2 and stride=1, a filter covers two character embeddings at a time and slides one position to the right at each step, as shown below:

[Figure: the first filter covering the embeddings of the first two characters]

The first filter's convolution at this position is computed as:

sum(  [[0.9291, 0.0543],        [[ 0.0025,  0.3353],
       [0.8333, 0.8149],    *    [ 0.0620, -0.3916],   )  +  0.3760 (bias)
       [0.5160, 0.5704]]         [-0.3458, -0.0610]]

The first filter's weights are multiplied element-wise with these two embeddings, the products are summed into a scalar, and the first filter's bias is added.

In code:

# perform the convolution manually
# element-wise multiply the first two embeddings with the kernel weights
result = input_tensor[:, :, :2] * conv1d.weight
print(result)
# sum the products and add the bias
print(torch.sum(result[0]) + conv1d.bias[0])
print(torch.sum(result[1]) + conv1d.bias[1])

tensor([[[ 0.0024,  0.0182],
         [ 0.0517, -0.3191],
         [-0.1784, -0.0348]],

        [[-0.1608, -0.0043],
         [-0.0349, -0.2082],
         [-0.0737,  0.0944]]], grad_fn=<MulBackward0>)
tensor(-0.0841, grad_fn=<AddBackward0>)  # result of the first filter
tensor(-0.6756, grad_fn=<AddBackward0>)  # result of the second filter

That is the result of the first convolution step; for the second step, the red box slides one position to the right and produces another value.

[Figure: the filter after sliding one position to the right]

Finally, it reaches the last position of the input and the computation is done:

[Figure: the filter at the last position of the input]

Five steps are needed in total, so each filter outputs 5 scalars; we have 2 filters and a batch size of 1.

Doing the same with a single call:

output = conv1d(input_tensor)
print(output)
print(output.shape)

tensor([[[-0.0841,  0.3468,  0.0447,  0.2508,  0.3288],
         [-0.6756, -0.3790, -0.5193, -0.3470, -0.4926]]],
       grad_fn=<ConvolutionBackward0>)
torch.Size([1, 2, 5])

The output has shape [1, 2, 5], and the first column matches our manual computation above.

shape[0]=1 is the number of samples in the batch; shape[1]=2 is the number of filters, i.e. the desired number of output channels; shape[2]=5 is the sequence length after convolution.

Ignoring dilation, the output length after convolution is determined by the kernel size kernel_size, the stride, the padding, and the input sequence length seq_len:

$$\left\lfloor \frac{\text{seq\_len} + 2 \times \text{padding} - \text{kernel\_size}}{\text{stride}} + 1 \right\rfloor$$

where $\lfloor * \rfloor$ denotes rounding down.

Plugging in the parameters above to verify:

$$\frac{6 + 2 \times 0 - 2}{1} + 1 = 6 + 0 - 2 + 1 = 5$$

The result is 5, matching the output.
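
To make the formula reusable, here is a small helper of my own (not part of the original code) that computes the output length:

def conv1d_out_len(seq_len: int, kernel_size: int, stride: int = 1, padding: int = 0) -> int:
    # floor((seq_len + 2 * padding - kernel_size) / stride) + 1
    return (seq_len + 2 * padding - kernel_size) // stride + 1

print(conv1d_out_len(seq_len=6, kernel_size=2))  # 5, matching output.shape[2] above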

Conv1d

Now we implement RE2's multi-layer convolutional network. First, a customized Conv1d that applies weight_norm to the weights and uses the GeLU activation.

class Conv1d(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, kernel_sizes: list[int]) -> None:
        """
        Args:
            in_channels (int): the embedding_dim
            out_channels (int): number of filters
            kernel_sizes (list[int]): the size of kernel
        """
        super().__init__()
        out_channels = out_channels // len(kernel_sizes)
        convs = []
        # L_in is seq_len, L_out is the output dim of the conv
        # L_out = L_in + 2 * padding - kernel_size + 1
        # with padding = (kernel_size - 1) // 2:
        # L_out = L_in + kernel_size - 1 - kernel_size + 1 = L_in
        for kernel_size in kernel_sizes:
            conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=(kernel_size - 1) // 2)
            convs.append(nn.Sequential(weight_norm(conv), GeLU()))
        # output shape of each conv is (batch_size, out_channels(new), seq_len)
        self.model = nn.ModuleList(convs)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        for seq in self.model:
            conv = seq[0]
            nn.init.normal_(
                conv.weight,
                std=math.sqrt(2.0 / (conv.in_channels * conv.kernel_size[0])),
            )
            nn.init.zeros_(conv.bias)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x (Tensor): (batch_size, embedding_dim, seq_len)

        Returns:
            Tensor: (batch_size, out_channels, seq_len)
        """
        return torch.cat([encoder(x) for encoder in self.model], dim=1)

out_channels // len(kernel_sizes) splits the output size evenly across the kernel sizes; the final torch.cat concatenates the pieces back into out_channels channels.

padding=(kernel_size - 1) // 2 makes the output length equal to the input seq_len. This requires kernel_size to be odd, since padding only accepts integers.
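
A quick check of the length-preserving padding, reusing the Conv1d class above (my own sketch with made-up sizes):

conv = Conv1d(in_channels=300, out_channels=150, kernel_sizes=[3])
x = torch.rand(1, 300, 20)  # (batch_size, embedding_dim, seq_len)
print(conv(x).shape)        # torch.Size([1, 150, 20]) -- seq_len is preserved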

weight_norm decomposes a weight into a magnitude and a direction, which can speed up training and improve generalization. The direction of the original weight is kept, while the magnitude is learned by the weight-normalization layer itself:

$$\pmb w = g\frac{\pmb v}{||\pmb v||}$$
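
We can inspect the decomposition directly: the classic torch.nn.utils.weight_norm API registers the magnitude as weight_g and the direction as weight_v (a small sketch of mine; newer PyTorch versions also offer a parametrization-based variant):

from torch.nn.utils import weight_norm

linear = weight_norm(nn.Linear(4, 8))
print(linear.weight_g.shape)  # torch.Size([8, 1]) -- one learned magnitude per output unit
print(linear.weight_v.shape)  # torch.Size([8, 4]) -- the direction vectors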

Encoder implementation

class Encoder(nn.Module):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        kernel_sizes: list[int],
        encoder_layers: int,
        dropout: float,
    ) -> None:
        """
        Args:
            input_size (int): embedding_dim or embedding_dim + hidden_size
            hidden_size (int): hidden size
            kernel_sizes (list[int]): the size of kernels
            encoder_layers (int): number of conv layers
            dropout (float): dropout ratio
        """
        super().__init__()
        self.encoders = nn.ModuleList(
            [
                Conv1d(
                    in_channels=input_size if i == 0 else hidden_size,
                    out_channels=hidden_size,
                    kernel_sizes=kernel_sizes,
                )
                for i in range(encoder_layers)
            ]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: Tensor, mask: Tensor) -> Tensor:
        """
        Args:
            x (Tensor): (batch_size, seq_len, input_size)
            mask (Tensor): (batch_size, seq_len, 1)

        Returns:
            Tensor: (batch_size, seq_len, hidden_size)
        """
        # x (batch_size, input_size, seq_len)
        x = x.transpose(1, 2)
        # mask (batch_size, 1, seq_len)
        mask = mask.transpose(1, 2)
        for i, encoder in enumerate(self.encoders):
            # fills elements of x with 0.0 where mask is False
            x.masked_fill_(~mask, 0.0)
            # apply dropout between layers
            if i > 0:
                x = self.dropout(x)
            # returned x (batch_size, hidden_size, seq_len)
            x = encoder(x)
        # apply dropout to the final output
        x = self.dropout(x)
        # (batch_size, seq_len, hidden_size)
        return x.transpose(1, 2)

Here multiple Conv1d layers form the encoder. Note the difference between layer 0 and the rest: layer 0's input dimension is input_size (embedding_dim, or embedding_dim + hidden_size in later blocks); after the first Conv1d the dimension becomes hidden_size, so every subsequent layer uses in_channels=hidden_size.

x.masked_fill_(~mask, 0.0) zeros out the padding positions indicated by the mask matrix.

An RNN is deliberately not used as the encoder: the author found RNNs slower without any performance gain.
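
A quick shape trace of the encoder with made-up sizes (my own check):

encoder = Encoder(input_size=300, hidden_size=150, kernel_sizes=[3], encoder_layers=2, dropout=0.2)
x = torch.rand(2, 10, 300)                     # (batch_size, seq_len, embedding_dim)
mask = torch.ones(2, 10, 1, dtype=torch.bool)  # no padding in this toy batch
print(encoder(x, mask).shape)                  # torch.Size([2, 10, 150])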

Alignment

Next comes the alignment layer. "Alignment" here means letting the two sequences interact, and the interaction is attention-based.

class Alignment(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, dropout: float, project_func: str) -> None:
        """
        Args:
            input_size (int): embedding_dim + hidden_size or embedding_dim + hidden_size * 2
            hidden_size (int): hidden size
            dropout (float): dropout ratio
            project_func (str): identity or linear
        """
        super().__init__()
        self.temperature = nn.Parameter(torch.tensor(1 / math.sqrt(hidden_size)))
        if project_func != "identity":
            self.projection = nn.Sequential(nn.Dropout(dropout), Linear(input_size, hidden_size))
        else:
            self.projection = nn.Identity()

    def forward(self, a: Tensor, b: Tensor, mask_a: Tensor, mask_b: Tensor) -> tuple[Tensor, Tensor]:
        """
        Args:
            a (Tensor): (batch_size, seq_len_a, input_size)
            b (Tensor): (batch_size, seq_len_b, input_size)
            mask_a (Tensor): (batch_size, seq_len_a, 1)
            mask_b (Tensor): (batch_size, seq_len_b, 1)

        Returns:
            tuple[Tensor, Tensor]: the aligned features of a and b
        """
        # if project_func == 'linear':   self.projection(*) -> (batch_size, seq_len, hidden_size)
        # if project_func == 'identity': self.projection(*) -> (batch_size, seq_len, input_size)
        # attn (batch_size, seq_len_a, seq_len_b)
        attn = (
            torch.matmul(self.projection(a), self.projection(b).transpose(1, 2))
            * self.temperature
        )
        # mask (batch_size, seq_len_a, seq_len_b)
        mask = torch.matmul(mask_a.float(), mask_b.transpose(1, 2).float())
        mask = mask.bool()
        # fills elements of attn with -1e7 (≈ 0 after softmax) where mask is False
        attn.masked_fill_(~mask, -1e7)
        # attn_a (batch_size, seq_len_a, seq_len_b), normalized over seq_len_a
        attn_a = F.softmax(attn, dim=1)
        # attn_b (batch_size, seq_len_a, seq_len_b), normalized over seq_len_b
        attn_b = F.softmax(attn, dim=2)
        # feature_b: (batch_size, seq_len_b, seq_len_a) x (batch_size, seq_len_a, input_size)
        # -> (batch_size, seq_len_b, input_size)
        feature_b = torch.matmul(attn_a.transpose(1, 2), a)
        # feature_a: (batch_size, seq_len_a, seq_len_b) x (batch_size, seq_len_b, input_size)
        # -> (batch_size, seq_len_a, input_size)
        feature_a = torch.matmul(attn_b, b)
        return feature_a, feature_b
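
The two softmax directions are easy to mix up: attn_b normalizes over b's positions, so feature_a re-expresses each position of a as a weighted average of b, and vice versa for attn_a. A small shape check of mine, with two sequences of different lengths:

align = Alignment(input_size=450, hidden_size=150, dropout=0.2, project_func="linear")
a = torch.rand(2, 7, 450)
b = torch.rand(2, 9, 450)
mask_a = torch.ones(2, 7, 1, dtype=torch.bool)
mask_b = torch.ones(2, 9, 1, dtype=torch.bool)
feature_a, feature_b = align(a, b, mask_a, mask_b)
print(feature_a.shape)  # torch.Size([2, 7, 450]) -- aligned features for a, built from b
print(feature_b.shape)  # torch.Size([2, 9, 450]) -- aligned features for b, built from a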

Augmented residual connection

[Figure: how a block's input combines the embeddings with the summed outputs of the previous two blocks]

class AugmentedResidualConnection(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: Tensor, res: Tensor, i: int) -> Tensor:
        """
        Args:
            x (Tensor): the output of the previous block (batch_size, seq_len, hidden_size)
            res (Tensor): (batch_size, seq_len, embedding_size) or (batch_size, seq_len, embedding_size + hidden_size)
                res[:, :, hidden_size:] is the output of the Embedding layer
                res[:, :, :hidden_size] is the output of the previous two blocks
            i (int): block index

        Returns:
            Tensor: (batch_size, seq_len, hidden_size + embedding_size)
        """
        if i == 1:
            # (batch_size, seq_len, hidden_size + embedding_size)
            return torch.cat([x, res], dim=-1)
        hidden_size = x.size(-1)
        # (res[:, :, :hidden_size] + x) is the summation of the outputs of the previous two blocks
        # x (batch_size, seq_len, hidden_size)
        x = (res[:, :, :hidden_size] + x) * math.sqrt(0.5)
        # (batch_size, seq_len, hidden_size + embedding_size)
        return torch.cat([x, res[:, :, hidden_size:]], dim=-1)

To provide richer features to the alignment process, RE2 uses an augmented version of the residual connection between blocks.

For a sequence of length $l$, denote the input and output of the $n$-th block as $x^{(n)} = (x^{(n)}_1, x^{(n)}_2, \cdots, x^{(n)}_l)$ and $o^{(n)} = (o^{(n)}_1, o^{(n)}_2, \cdots, o^{(n)}_l)$, with $o^{(0)}$ denoting a sequence of zero vectors.

The input of the first block, $x^{(1)}$, is the output of the embedding layer (the hollow rectangles in Figure 1). The input of the $n$-th block for $n \geq 2$, $x^{(n)}$, is the concatenation of the first block's input $x^{(1)}$ with the element-wise sum of the outputs of the two preceding blocks (the diagonally striped rectangles in the figure):

$$x^{(n)}_i = [x^{(1)}_i; o^{(n-1)}_i + o^{(n-2)}_i]$$

In words: the input of block $n$ concatenates two vectors, the input of the first block and the element-wise sum of the outputs of the two blocks before it. This is the augmented residual connection.
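
Tracing the dimensions across two connections makes the formula concrete (my own sketch, with embedding_dim=300 and hidden_size=150):

conn = AugmentedResidualConnection()
emb = torch.rand(2, 10, 300)  # x^(1): the embedding-layer output
o1 = torch.rand(2, 10, 150)   # o^(1): output of the first block
x2 = conn(o1, emb, i=1)       # [o^(1); x^(1)] -> (2, 10, 450)
o2 = torch.rand(2, 10, 150)   # o^(2): output of the second block
x3 = conn(o2, x2, i=2)        # [(o^(1) + o^(2)) * sqrt(0.5); x^(1)] -> (2, 10, 450)
print(x2.shape, x3.shape)     # torch.Size([2, 10, 450]) torch.Size([2, 10, 450])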

Fusion layer

class Fusion(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, dropout: float) -> None:
        """
        Args:
            input_size (int): embedding_dim + hidden_size or embedding_dim + hidden_size * 2
            hidden_size (int): hidden size
            dropout (float): dropout ratio
        """
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.fusion1 = Linear(input_size * 2, hidden_size, activations=True)
        self.fusion2 = Linear(input_size * 2, hidden_size, activations=True)
        self.fusion3 = Linear(input_size * 2, hidden_size, activations=True)
        self.fusion = Linear(hidden_size * 3, hidden_size, activations=True)

    def forward(self, x: Tensor, align: Tensor) -> Tensor:
        """
        Args:
            x (Tensor): input (batch_size, seq_len, input_size)
            align (Tensor): output of Alignment (batch_size, seq_len, input_size)

        Returns:
            Tensor: (batch_size, seq_len, hidden_size)
        """
        # x1 (batch_size, seq_len, hidden_size)
        x1 = self.fusion1(torch.cat([x, align], dim=-1))
        # x2 (batch_size, seq_len, hidden_size)
        x2 = self.fusion2(torch.cat([x, x - align], dim=-1))
        # x3 (batch_size, seq_len, hidden_size)
        x3 = self.fusion3(torch.cat([x, x * align], dim=-1))
        # x (batch_size, seq_len, hidden_size * 3)
        x = torch.cat([x1, x2, x3], dim=-1)
        x = self.dropout(x)
        # (batch_size, seq_len, hidden_size)
        return self.fusion(x)

The fusion layer compares the local and aligned representations (the input and output of the alignment layer, respectively) in three ways and then fuses them.

For the first sequence, the fusion layer's output $\bar a$ is computed as:

$$\begin{aligned} \bar a_i^1 &= G_1([a_i;a_i^\prime]), \\ \bar a_i^2 &= G_2([a_i;a_i - a_i^\prime]), \\ \bar a_i^3 &= G_3([a_i;a_i \circ a_i^\prime]), \\ \bar a_i &= G([\bar a_i^1;\bar a_i^2;\bar a_i^3]), \end{aligned}$$

where $G_1$, $G_2$, $G_3$ and $G$ are single-layer feed-forward networks with independent parameters, and $\circ$ denotes element-wise multiplication.

The subtraction ($-$) operation highlights the difference between the two vectors, while the multiplication highlights their similarity. The computation of $\bar b$ for the other sequence is analogous.

These operations are quite similar to ESIM's, with an additional feed-forward network on top.

Afterwards, a pooling layer is applied to obtain fixed-length vectors.

Pooling layer

class Pooling(nn.Module):
    def forward(self, x: Tensor, mask: Tensor) -> Tensor:
        """
        Args:
            x (Tensor): (batch_size, seq_len, hidden_size)
            mask (Tensor): (batch_size, seq_len, 1)

        Returns:
            Tensor: (batch_size, hidden_size)
        """
        # max returns a namedtuple (values, indices); we only need the values
        return x.masked_fill(~mask, -float("inf")).max(dim=1)[0]

The pooling layer takes the maximum over the time-step dimension; padding positions are filled with -inf beforehand so they can never be selected.
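
A tiny example of my own showing that padded positions can never win the max:

pool = Pooling()
x = torch.tensor([[[1.0, 5.0],
                   [3.0, 2.0],
                   [9.0, 9.0]]])                   # (1, 3, 2); the last step is padding
mask = torch.tensor([[[True], [True], [False]]])   # (1, 3, 1)
print(pool(x, mask))                               # tensor([[3., 5.]]) -- the 9s are masked out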

Prediction layer

class Prediction(nn.Module):
    def __init__(self, hidden_size: int, num_classes: int, dropout: float) -> None:
        super().__init__()
        self.dense = nn.Sequential(
            nn.Dropout(dropout),
            Linear(hidden_size * 4, hidden_size, activations=True),
            nn.Dropout(dropout),
            Linear(hidden_size, num_classes),
        )

    def forward(self, a: Tensor, b: Tensor) -> Tensor:
        """
        Args:
            a (Tensor): (batch_size, hidden_size)
            b (Tensor): (batch_size, hidden_size)

        Returns:
            Tensor: (batch_size, num_classes)
        """
        return self.dense(torch.cat([a, b, a - b, a * b], dim=-1))

The prediction layer is simple; it fuses the two input vectors once more:

$$\hat{\pmb y} = H([v_1; v_2; v_1 - v_2; v_1 \circ v_2])$$

RE2 implementation

The RE2 model is a stack of the layers defined above:

class RE2(nn.Module):
    def __init__(self, args) -> None:
        super().__init__()
        self.embedding = Embedding(args.vocab_size, args.embedding_dim, args.dropout)
        self.connection = AugmentedResidualConnection()
        self.blocks = nn.ModuleList(
            [
                nn.ModuleDict(
                    {
                        "encoder": Encoder(
                            args.embedding_dim
                            if i == 0
                            else args.embedding_dim + args.hidden_size,
                            args.hidden_size,
                            args.kernel_sizes,
                            args.encoder_layers,
                            args.dropout,
                        ),
                        "alignment": Alignment(
                            args.embedding_dim + args.hidden_size
                            if i == 0
                            else args.embedding_dim + args.hidden_size * 2,
                            args.hidden_size,
                            args.dropout,
                            args.project_func,
                        ),
                        "fusion": Fusion(
                            args.embedding_dim + args.hidden_size
                            if i == 0
                            else args.embedding_dim + args.hidden_size * 2,
                            args.hidden_size,
                            args.dropout,
                        ),
                    }
                )
                for i in range(args.num_blocks)
            ]
        )
        self.pooling = Pooling()
        self.prediction = Prediction(args.hidden_size, args.num_classes, args.dropout)

    def forward(self, a: Tensor, b: Tensor, mask_a: Tensor, mask_b: Tensor) -> Tensor:
        """
        Args:
            a (Tensor): (batch_size, seq_len)
            b (Tensor): (batch_size, seq_len)
            mask_a (Tensor): (batch_size, seq_len, 1)
            mask_b (Tensor): (batch_size, seq_len, 1)

        Returns:
            Tensor: (batch_size, num_classes)
        """
        # a (batch_size, seq_len, embedding_dim)
        a = self.embedding(a)
        # b (batch_size, seq_len, embedding_dim)
        b = self.embedding(b)

        res_a, res_b = a, b

        for i, block in enumerate(self.blocks):
            if i > 0:
                # a (batch_size, seq_len, embedding_dim + hidden_size)
                a = self.connection(a, res_a, i)
                # b (batch_size, seq_len, embedding_dim + hidden_size)
                b = self.connection(b, res_b, i)
                # now the embeddings are saved in res_a[:, :, hidden_size:]
                res_a, res_b = a, b
            # a_enc (batch_size, seq_len, hidden_size)
            a_enc = block["encoder"](a, mask_a)
            # b_enc (batch_size, seq_len, hidden_size)
            b_enc = block["encoder"](b, mask_b)
            # concatenating the input and output of the encoder
            # a (batch_size, seq_len, embedding_dim + hidden_size or embedding_dim + hidden_size * 2)
            a = torch.cat([a, a_enc], dim=-1)
            # b (batch_size, seq_len, embedding_dim + hidden_size or embedding_dim + hidden_size * 2)
            b = torch.cat([b, b_enc], dim=-1)
            # align_a (batch_size, seq_len, embedding_dim + hidden_size or embedding_dim + hidden_size * 2)
            # align_b (batch_size, seq_len, embedding_dim + hidden_size or embedding_dim + hidden_size * 2)
            align_a, align_b = block["alignment"](a, b, mask_a, mask_b)
            # a (batch_size, seq_len, hidden_size)
            a = block["fusion"](a, align_a)
            # b (batch_size, seq_len, hidden_size)
            b = block["fusion"](b, align_b)

        # a (batch_size, hidden_size)
        a = self.pooling(a, mask_a)
        # b (batch_size, hidden_size)
        b = self.pooling(b, mask_b)
        # (batch_size, num_classes)
        return self.prediction(a, b)

Note how the input dimensions differ between the first block and the later ones.
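
An end-to-end smoke test (my own, with a throwaway config) to confirm the shapes line up:

from types import SimpleNamespace

args = SimpleNamespace(
    vocab_size=1000, embedding_dim=300, hidden_size=150, kernel_sizes=[3],
    encoder_layers=2, num_blocks=2, dropout=0.2, project_func="linear", num_classes=2,
)
model = RE2(args)
a = torch.randint(1, 1000, (4, 20))              # token ids of the first segment
b = torch.randint(1, 1000, (4, 20))              # token ids of the second segment
mask_a = torch.ones(4, 20, 1, dtype=torch.bool)
mask_b = torch.ones(4, 20, 1, dtype=torch.bool)
print(model(a, b, mask_a, mask_b).shape)         # torch.Size([4, 2])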

Data preparation

The data preparation part was explained in detail in a previous article.

from collections import defaultdict
from tqdm import tqdm
import numpy as np
import json
from torch.utils.data import Dataset
import pandas as pd
from typing import Tuple

UNK_TOKEN = "<UNK>"
PAD_TOKEN = "<PAD>"


class Vocabulary:
    """Class to process text and extract vocabulary for mapping"""

    def __init__(self, token_to_idx: dict = None, tokens: list[str] = None) -> None:
        """
        Args:
            token_to_idx (dict, optional): a pre-existing map of tokens to indices. Defaults to None.
            tokens (list[str], optional): a list of unique tokens with no duplicates. Defaults to None.
        """
        assert any(
            [tokens, token_to_idx]
        ), "At least one of these parameters should be set as not None."
        if token_to_idx:
            self._token_to_idx = token_to_idx
        else:
            self._token_to_idx = {}
            if PAD_TOKEN not in tokens:
                tokens = [PAD_TOKEN] + tokens

            for idx, token in enumerate(tokens):
                self._token_to_idx[token] = idx

        self._idx_to_token = {idx: token for token, idx in self._token_to_idx.items()}

        self.unk_index = self._token_to_idx[UNK_TOKEN]
        self.pad_index = self._token_to_idx[PAD_TOKEN]

    @classmethod
    def build(
        cls,
        sentences: list[list[str]],
        min_freq: int = 2,
        reserved_tokens: list[str] = None,
    ) -> "Vocabulary":
        """Construct the Vocabulary from sentences

        Args:
            sentences (list[list[str]]): a list of tokenized sequences
            min_freq (int, optional): the minimum word frequency to be saved. Defaults to 2.
            reserved_tokens (list[str], optional): the reserved tokens to add into the Vocabulary. Defaults to None.

        Returns:
            Vocabulary: a Vocabulary instance
        """
        token_freqs = defaultdict(int)
        for sentence in tqdm(sentences):
            for token in sentence:
                token_freqs[token] += 1

        unique_tokens = (reserved_tokens if reserved_tokens else []) + [UNK_TOKEN]
        unique_tokens += [
            token
            for token, freq in token_freqs.items()
            if freq >= min_freq and token != UNK_TOKEN
        ]
        return cls(tokens=unique_tokens)

    def __len__(self) -> int:
        return len(self._idx_to_token)

    def __getitem__(self, tokens: list[str] | str) -> list[int] | int:
        """Retrieve the indices associated with the tokens or the index with the single token

        Args:
            tokens (list[str] | str): a list of tokens or a single token

        Returns:
            list[int] | int: the indices or the single index
        """
        if not isinstance(tokens, (list, tuple)):
            return self._token_to_idx.get(tokens, self.unk_index)
        return [self.__getitem__(token) for token in tokens]

    def lookup_token(self, indices: list[int] | int) -> list[str] | str:
        """Retrieve the tokens associated with the indices or the token with the single index

        Args:
            indices (list[int] | int): a list of indices or a single index

        Returns:
            list[str] | str: the corresponding tokens (or token)
        """
        if not isinstance(indices, (list, tuple)):
            return self._idx_to_token[indices]
        return [self._idx_to_token[index] for index in indices]

    def to_serializable(self) -> dict:
        """Returns a dictionary that can be serialized"""
        return {"token_to_idx": self._token_to_idx}

    @classmethod
    def from_serializable(cls, contents: dict) -> "Vocabulary":
        """Instantiates the Vocabulary from a serialized dictionary

        Args:
            contents (dict): a dictionary generated by `to_serializable`

        Returns:
            Vocabulary: the Vocabulary instance
        """
        return cls(**contents)

    def __repr__(self):
        return f"<Vocabulary(size={len(self)})>"


class TMVectorizer:
    """The Vectorizer which vectorizes the Vocabulary"""

    def __init__(self, vocab: Vocabulary, max_len: int) -> None:
        """
        Args:
            vocab (Vocabulary): maps characters to integers
            max_len (int): the max length of the sequence in the dataset
        """
        self.vocab = vocab
        self.max_len = max_len
        self.padding_index = vocab.pad_index

    def _vectorize(self, indices: list[int], vector_length: int = -1) -> np.ndarray:
        """Vectorize the provided indices

        Args:
            indices (list[int]): a list of integers that represent a sequence
            vector_length (int, optional): an argument for forcing the length of the index vector. Defaults to -1.

        Returns:
            np.ndarray: the vectorized index array
        """
        if vector_length <= 0:
            vector_length = len(indices)

        vector = np.zeros(vector_length, dtype=np.int64)
        if len(indices) > vector_length:
            vector[:] = indices[:vector_length]
        else:
            vector[: len(indices)] = indices
            vector[len(indices) :] = self.padding_index

        return vector

    def _get_indices(self, sentence: list[str]) -> list[int]:
        """Return the vectorized sentence

        Args:
            sentence (list[str]): list of tokens

        Returns:
            indices (list[int]): list of integers representing the sentence
        """
        return [self.vocab[token] for token in sentence]

    def vectorize(self, sentence: list[str], use_dataset_max_length: bool = True) -> np.ndarray:
        """Return the vectorized sequence

        Args:
            sentence (list[str]): raw sentence from the dataset
            use_dataset_max_length (bool): whether to use the global max vector length

        Returns:
            the vectorized sequence with padding
        """
        vector_length = -1
        if use_dataset_max_length:
            vector_length = self.max_len

        indices = self._get_indices(sentence)
        vector = self._vectorize(indices, vector_length=vector_length)
        return vector

    @classmethod
    def from_serializable(cls, contents: dict) -> "TMVectorizer":
        """Instantiates the TMVectorizer from a serialized dictionary

        Args:
            contents (dict): a dictionary generated by `to_serializable`

        Returns:
            TMVectorizer:
        """
        vocab = Vocabulary.from_serializable(contents["vocab"])
        max_len = contents["max_len"]
        return cls(vocab=vocab, max_len=max_len)

    def to_serializable(self) -> dict:
        """Returns a dictionary that can be serialized

        Returns:
            dict: a dict containing the Vocabulary instance and the max_len attribute
        """
        return {"vocab": self.vocab.to_serializable(), "max_len": self.max_len}

    def save_vectorizer(self, filepath: str) -> None:
        """Dump this TMVectorizer instance to file

        Args:
            filepath (str): the path to store the file
        """
        with open(filepath, "w") as f:
            json.dump(self.to_serializable(), f)

    @classmethod
    def load_vectorizer(cls, filepath: str) -> "TMVectorizer":
        """Load TMVectorizer from a file

        Args:
            filepath (str): the path that stored the file

        Returns:
            TMVectorizer:
        """
        with open(filepath) as f:
            return TMVectorizer.from_serializable(json.load(f))


class TMDataset(Dataset):
    """Dataset for text matching"""

    def __init__(self, text_df: pd.DataFrame, vectorizer: TMVectorizer) -> None:
        """
        Args:
            text_df (pd.DataFrame): a DataFrame which contains the processed data examples
            vectorizer (TMVectorizer): a TMVectorizer instance
        """
        self.text_df = text_df
        self._vectorizer = vectorizer

    def __getitem__(self, index: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int]:
        row = self.text_df.iloc[index]
        vector1 = self._vectorizer.vectorize(row.sentence1)
        vector2 = self._vectorizer.vectorize(row.sentence2)
        mask1 = vector1 != self._vectorizer.padding_index
        mask2 = vector2 != self._vectorizer.padding_index
        return (vector1, vector2, mask1, mask2, row.label)

    def get_vectorizer(self) -> TMVectorizer:
        return self._vectorizer

    def __len__(self) -> int:
        return len(self.text_df)

This is almost the same as in the previous article; the only difference is the added mask marking the padding positions.
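
For example (with a tiny hypothetical vocabulary of mine), a sentence shorter than max_len gets False at its padded tail:

vocab = Vocabulary(tokens=[PAD_TOKEN, UNK_TOKEN, "猫", "喜欢", "鱼"])
vectorizer = TMVectorizer(vocab, max_len=6)
vector = vectorizer.vectorize(["猫", "喜欢", "鱼"])
mask = vector != vectorizer.padding_index
print(vector)  # [2 3 4 0 0 0]
print(mask)    # [ True  True  True False False False]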

Model training

learning_rate=1e-3,
batch_size=256,
num_epochs=10,
max_len=50,
embedding_dim=300,
hidden_size=150,
encoder_layers=2,
num_blocks=2,
kernel_sizes=[3],
dropout=0.2,
min_freq=2,
project_func="linear",
grad_clipping=2.0,
print_every=300,
num_classes=2,

After a few experiments, the best-performing configuration is the one above, with a learning rate of 1e-3 and gradient clipping at 2.0.

As described in the paper, gradient clipping is applied; the paper's exponential learning-rate decay is dropped in favor of simply using AdamW.
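
If you want to keep the paper's exponential decay as well (note the lr_decay_rate=0.95 that still shows up in the arguments below), a minimal sketch of mine showing the wiring with a throwaway model:

demo_model = nn.Linear(10, 2)  # stand-in for RE2, just to demonstrate the scheduler
demo_optimizer = torch.optim.AdamW(demo_model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ExponentialLR(demo_optimizer, gamma=0.95)

for epoch in range(10):
    # ... one training epoch would go here ...
    scheduler.step()                # lr <- lr * 0.95 after each epoch
    print(scheduler.get_last_lr())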

The evaluation and training functions:

from torch.utils.data import DataLoader

# device and the metrics() helper are defined in the full training script (see the complete code link at the end)


def evaluate(
    data_iter: DataLoader, model: nn.Module
) -> Tuple[float, float, float, float]:
    y_list, y_pred_list = [], []
    model.eval()
    for x1, x2, mask1, mask2, y in tqdm(data_iter):
        x1 = x1.to(device).long()
        x2 = x2.to(device).long()
        mask1 = mask1.to(device).bool().unsqueeze(2)
        mask2 = mask2.to(device).bool().unsqueeze(2)
        y = y.float().to(device)

        output = model(x1, x2, mask1, mask2)
        pred = torch.argmax(output, dim=1).long()

        y_pred_list.append(pred)
        y_list.append(y)

    y_pred = torch.cat(y_pred_list, 0)
    y = torch.cat(y_list, 0)
    acc, p, r, f1 = metrics(y, y_pred)
    return acc, p, r, f1


def train(
    data_iter: DataLoader,
    model: nn.Module,
    criterion: nn.CrossEntropyLoss,
    optimizer: torch.optim.Optimizer,
    grad_clipping: float,
    print_every: int = 500,
    verbose=True,
) -> None:
    model.train()

    for step, (x1, x2, mask1, mask2, y) in enumerate(tqdm(data_iter)):
        x1 = x1.to(device).long()
        x2 = x2.to(device).long()
        mask1 = mask1.to(device).bool().unsqueeze(2)
        mask2 = mask2.to(device).bool().unsqueeze(2)
        y = torch.LongTensor(y).to(device)

        output = model(x1, x2, mask1, mask2)
        loss = criterion(output, y)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clipping)
        optimizer.step()

        if verbose and (step + 1) % print_every == 0:
            pred = torch.argmax(output, dim=1).long()
            acc, p, r, f1 = metrics(y, pred)
            print(
                f" TRAIN iter={step+1} loss={loss.item():.6f} accuracy={acc:.3f} precision={p:.3f} recall={r:.3f} f1 score={f1:.4f}"
            )

The core training code:

    
import os
from datetime import datetime

model = RE2(args)
print(f"Model: {model}")

model_saved_path = os.path.join(args.save_dir, args.model_state_file)
if args.reload_model and os.path.exists(model_saved_path):
    model.load_state_dict(torch.load(model_saved_path))
    print("Reloaded model")
else:
    print("New model")

model = model.to(device)

model_save_path = os.path.join(
    args.save_dir, f"{datetime.now().strftime('%Y%m%d%H%M%S')}-model.pth"
)

train_data_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
dev_data_loader = DataLoader(dev_dataset, batch_size=args.batch_size)
test_data_loader = DataLoader(test_dataset, batch_size=args.batch_size)

optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
criterion = nn.CrossEntropyLoss()

for epoch in range(args.num_epochs):
    train(
        train_data_loader,
        model,
        criterion,
        optimizer,
        args.grad_clipping,
        print_every=args.print_every,
        verbose=args.verbose,
    )
    print("Begin evaluate on dev set.")
    with torch.no_grad():
        acc, p, r, f1 = evaluate(dev_data_loader, model)
    print(
        f"EVALUATE [{epoch+1}/{args.num_epochs}]  accuracy={acc:.3f} precision={p:.3f} recall={r:.3f} f1 score={f1:.4f}"
    )

model.eval()
acc, p, r, f1 = evaluate(test_data_loader, model)
print(f"TEST accuracy={acc:.3f} precision={p:.3f} recall={r:.3f} f1 score={f1:.4f}")
Arguments : Namespace(dataset_csv='text_matching/data/lcqmc/{}.txt', vectorizer_file='vectorizer.json', model_state_file='model.pth', pandas_file='dataframe.{}.pkl', save_dir='D:\\workspace\\nlp-in-action\\text_matching\\re2\\model_storage', reload_model=False, cuda=True, learning_rate=0.001, batch_size=256, num_epochs=10, max_len=50, embedding_dim=300, hidden_size=150, encoder_layers=2, num_blocks=2, kernel_sizes=[3], dropout=0.2, min_freq=2, project_func='linear', grad_clipping=2.0, print_every=300, lr_decay_rate=0.95, num_classes=2, verbose=True)
Using device: cuda:0.
Loads cached dataframes.
Loads vectorizer file.
Model: RE2(
  (embedding): Embedding(
    (embedding): Embedding(35925, 300, padding_idx=0)
    (dropout): Dropout(p=0.2, inplace=False)
  )
  (connection): AugmentedResidualConnection()
  (blocks): ModuleList(
    (0): ModuleDict(
      (encoder): Encoder(
        (encoders): ModuleList(
          (0): Conv1d(
            (model): ModuleList(
              (0): Sequential(
                (0): Conv1d(300, 150, kernel_size=(3,), stride=(1,), padding=(1,))
                (1): GeLU()
              )
            )
          )
          (1): Conv1d(
            (model): ModuleList(
              (0): Sequential(
                (0): Conv1d(150, 150, kernel_size=(3,), stride=(1,), padding=(1,))
                (1): GeLU()
              )
            )
          )
        )
        (dropout): Dropout(p=0.2, inplace=False)
      )
      (alignment): Alignment(
        (projection): Sequential(
          (0): Dropout(p=0.2, inplace=False)
          (1): Linear(
            (model): Sequential(
              (0): Linear(in_features=450, out_features=150, bias=True)
              (1): GeLU()
            )
          )
        )
      )
      (fusion): Fusion(
        (dropout): Dropout(p=0.2, inplace=False)
        (fusion1): Linear(
          (model): Sequential(
            (0): Linear(in_features=900, out_features=150, bias=True)
            (1): GeLU()
          )
        )
        (fusion2): Linear(
          (model): Sequential(
            (0): Linear(in_features=900, out_features=150, bias=True)
            (1): GeLU()
          )
        )
        (fusion3): Linear(
          (model): Sequential(
            (0): Linear(in_features=900, out_features=150, bias=True)
            (1): GeLU()
          )
        )
        (fusion): Linear(
          (model): Sequential(
            (0): Linear(in_features=450, out_features=150, bias=True)
            (1): GeLU()
          )
        )
      )
    )
    (1): ModuleDict(
      (encoder): Encoder(
        (encoders): ModuleList(
          (0): Conv1d(
            (model): ModuleList(
              (0): Sequential(
                (0): Conv1d(450, 150, kernel_size=(3,), stride=(1,), padding=(1,))
                (1): GeLU()
              )
            )
          )
          (1): Conv1d(
            (model): ModuleList(
              (0): Sequential(
                (0): Conv1d(150, 150, kernel_size=(3,), stride=(1,), padding=(1,))
                (1): GeLU()
              )
            )
          )
        )
        (dropout): Dropout(p=0.2, inplace=False)
      )
      (alignment): Alignment(
        (projection): Sequential(
          (0): Dropout(p=0.2, inplace=False)
          (1): Linear(
            (model): Sequential(
              (0): Linear(in_features=600, out_features=150, bias=True)
              (1): GeLU()
            )
          )
        )
      )
      (fusion): Fusion(
        (dropout): Dropout(p=0.2, inplace=False)
        (fusion1): Linear(
          (model): Sequential(
            (0): Linear(in_features=1200, out_features=150, bias=True)
            (1): GeLU()
          )
        )
        (fusion2): Linear(
          (model): Sequential(
            (0): Linear(in_features=1200, out_features=150, bias=True)
            (1): GeLU()
          )
        )
        (fusion3): Linear(
          (model): Sequential(
            (0): Linear(in_features=1200, out_features=150, bias=True)
            (1): GeLU()
          )
        )
        (fusion): Linear(
          (model): Sequential(
            (0): Linear(in_features=450, out_features=150, bias=True)
            (1): GeLU()
          )
        )
      )
    )
  )
  (pooling): Pooling()
  (prediction): Prediction(
    (dense): Sequential(
      (0): Dropout(p=0.2, inplace=False)
      (1): Linear(
        (model): Sequential(
          (0): Linear(in_features=600, out_features=150, bias=True)
          (1): GeLU()
        )
      )
      (2): Dropout(p=0.2, inplace=False)
      (3): Linear(
        (model): Sequential(
          (0): Linear(in_features=150, out_features=2, bias=True)
          (1): GeLU()
        )
      )
    )
  )
)
New model
TRAIN iter=300 loss=0.273509 accuracy=0.887 precision=0.885 recall=0.926 f1 score=0.9049
TRAIN iter=600 loss=0.296151 accuracy=0.859 precision=0.897 recall=0.861 f1 score=0.8784
TRAIN iter=900 loss=0.262893 accuracy=0.875 precision=0.887 recall=0.887 f1 score=0.8873
Begin evaluate on dev set.
EVALUATE [1/10]  accuracy=0.752 precision=0.737 recall=0.783 f1 score=0.7592
TRAIN iter=300 loss=0.272779 accuracy=0.898 precision=0.919 recall=0.907 f1 score=0.9133
TRAIN iter=600 loss=0.238999 accuracy=0.898 precision=0.907 recall=0.930 f1 score=0.9187
TRAIN iter=900 loss=0.225822 accuracy=0.910 precision=0.929 recall=0.909 f1 score=0.9187
Begin evaluate on dev set.
EVALUATE [2/10]  accuracy=0.787 precision=0.763 recall=0.831 f1 score=0.7956
TRAIN iter=300 loss=0.260889 accuracy=0.902 precision=0.929 recall=0.912 f1 score=0.9206
TRAIN iter=600 loss=0.216830 accuracy=0.910 precision=0.929 recall=0.923 f1 score=0.9256
TRAIN iter=900 loss=0.162659 accuracy=0.945 precision=0.944 recall=0.958 f1 score=0.9510
Begin evaluate on dev set.
EVALUATE [3/10]  accuracy=0.816 precision=0.809 recall=0.827 f1 score=0.8179
TRAIN iter=300 loss=0.228807 accuracy=0.906 precision=0.909 recall=0.922 f1 score=0.9155
TRAIN iter=600 loss=0.186292 accuracy=0.926 precision=0.932 recall=0.938 f1 score=0.9347
TRAIN iter=900 loss=0.160805 accuracy=0.953 precision=0.957 recall=0.957 f1 score=0.9568
Begin evaluate on dev set.
EVALUATE [4/10]  accuracy=0.814 precision=0.804 recall=0.832 f1 score=0.8176
TRAIN iter=300 loss=0.190363 accuracy=0.910 precision=0.926 recall=0.919 f1 score=0.9226
TRAIN iter=600 loss=0.190028 accuracy=0.918 precision=0.901 recall=0.967 f1 score=0.9325
TRAIN iter=900 loss=0.170661 accuracy=0.930 precision=0.957 recall=0.918 f1 score=0.9375
Begin evaluate on dev set.
EVALUATE [5/10]  accuracy=0.810 precision=0.775 recall=0.873 f1 score=0.8212
TRAIN iter=300 loss=0.125980 accuracy=0.965 precision=0.974 recall=0.968 f1 score=0.9709
TRAIN iter=600 loss=0.160912 accuracy=0.930 precision=0.928 recall=0.953 f1 score=0.9404
TRAIN iter=900 loss=0.159766 accuracy=0.930 precision=0.922 recall=0.959 f1 score=0.9400
Begin evaluate on dev set.
EVALUATE [6/10]  accuracy=0.815 precision=0.777 recall=0.885 f1 score=0.8271
TRAIN iter=300 loss=0.144144 accuracy=0.941 precision=0.973 recall=0.929 f1 score=0.9508
TRAIN iter=600 loss=0.149635 accuracy=0.934 precision=0.922 recall=0.975 f1 score=0.9477
TRAIN iter=900 loss=0.151699 accuracy=0.938 precision=0.926 recall=0.974 f1 score=0.9497
Begin evaluate on dev set.
EVALUATE [7/10]  accuracy=0.831 precision=0.806 recall=0.874 f1 score=0.8383
TRAIN iter=300 loss=0.191586 accuracy=0.922 precision=0.908 recall=0.967 f1 score=0.9367
TRAIN iter=600 loss=0.188188 accuracy=0.930 precision=0.947 recall=0.935 f1 score=0.9412
TRAIN iter=900 loss=0.196099 accuracy=0.910 precision=0.939 recall=0.892 f1 score=0.9151
Begin evaluate on dev set.
EVALUATE [8/10]  accuracy=0.838 precision=0.817 recall=0.870 f1 score=0.8426
TRAIN iter=300 loss=0.136444 accuracy=0.953 precision=0.986 recall=0.934 f1 score=0.9592
TRAIN iter=600 loss=0.137828 accuracy=0.949 precision=0.953 recall=0.959 f1 score=0.9559
TRAIN iter=900 loss=0.148434 accuracy=0.934 precision=0.947 recall=0.941 f1 score=0.9439
Begin evaluate on dev set.
EVALUATE [9/10]  accuracy=0.840 precision=0.814 recall=0.883 f1 score=0.8471
TRAIN iter=300 loss=0.223042 accuracy=0.918 precision=0.904 recall=0.968 f1 score=0.9350
TRAIN iter=600 loss=0.105175 accuracy=0.965 precision=0.971 recall=0.964 f1 score=0.9677
TRAIN iter=900 loss=0.110603 accuracy=0.953 precision=0.934 recall=0.986 f1 score=0.9592
Begin evaluate on dev set.
EVALUATE [10/10]  accuracy=0.836 precision=0.819 recall=0.863 f1 score=0.8406
TEST accuracy=0.822 precision=0.762 recall=0.936 f1 score=0.8403

This accuracy is achieved without any pretrained word embeddings; when I get the chance, I will train a word2vec model and plug the vectors in to see how much they help.

Complete code

https://github.com/nlp-greyfoss/nlp-in-action-public/blob/master/text_matching/re2/model.py
