1. 参考
M3-Embedding
https://github.com/FlagOpen/FlagEmbedding
https://arxiv.org/pdf/2402.03216
https://huggingface.co/BAAI/bge-m3
2. Dense Retrieval
import torch
import torch.nn as nn


class DenseRetrieval(nn.Module):
    """Dual-encoder scorer for dense retrieval.

    Queries and documents go through two independent MLP towers
    (Linear -> ReLU -> Linear). Every query is then scored against every
    document by cosine similarity.

    Args:
        embedding_dim: Size of the last dimension of the input embeddings.
        hidden_dim: Width of each tower's hidden layer. Defaults to 128,
            the value that was previously hard-coded.
        output_dim: Size of the projected vectors that are compared.
            Defaults to 64, the value that was previously hard-coded.
    """

    def __init__(self, embedding_dim, hidden_dim=128, output_dim=64):
        super().__init__()
        # The two towers have separate weights: queries and documents are
        # encoded asymmetrically.
        self.query_encoder = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )
        self.doc_encoder = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, query_embeddings, doc_embeddings):
        """Score every query against every document.

        Args:
            query_embeddings: Tensor of shape ``(num_queries, embedding_dim)``.
            doc_embeddings: Tensor of shape ``(num_docs, embedding_dim)``.

        Returns:
            Tensor of shape ``(num_queries, num_docs)`` holding cosine
            similarities in ``[-1, 1]``.
        """
        query_vectors = self.query_encoder(query_embeddings)
        doc_vectors = self.doc_encoder(doc_embeddings)
        # Broadcast (Q, 1, D) against (1, N, D) so the similarity along the
        # feature axis yields the full (Q, N) score matrix.
        scores = torch.cosine_similarity(
            query_vectors.unsqueeze(1), doc_vectors.unsqueeze(0), dim=2
        )
        return scores
3. Lexical Retrieval
import torch
import torch.nn as nn


class LexicalRetrieval(nn.Module):
    """Sparse lexical-matching scorer in the style of a lexical-weight head.

    Each token's contextual hidden state is mapped to a non-negative weight
    ``relu(W h + b)``. Weights are pooled per vocabulary id, keeping the max
    over repeated occurrences of the same id. The score of a query/document
    pair is the sum, over ids present in both, of
    ``w_query(id) * w_doc(id)``.

    Args:
        embedding_dim: Size of the last dimension of the hidden states.
        vocab_size: Number of token ids; every id must lie in
            ``[0, vocab_size)``.
        pad_token_id: Optional id whose weight is forced to zero, so that
            padding never contributes to a match.
    """

    def __init__(self, embedding_dim, vocab_size, pad_token_id=None):
        super().__init__()
        self.vocab_size = vocab_size
        self.pad_token_id = pad_token_id
        # One scalar importance weight per token.
        self.sparse_linear = nn.Linear(embedding_dim, 1)

    def _sparse_vectors(self, hidden_states, token_ids):
        """Return ``(batch, vocab_size)`` max-pooled token weights."""
        token_weights = torch.relu(self.sparse_linear(hidden_states)).squeeze(-1)
        if self.pad_token_id is not None:
            # Out-of-place mask so autograd through scatter_reduce stays valid.
            token_weights = token_weights.masked_fill(token_ids == self.pad_token_id, 0.0)
        # Dense (batch, vocab_size) buffer; large vocabularies cost memory here.
        sparse = token_weights.new_zeros(token_ids.size(0), self.vocab_size)
        # Weights are >= 0 after ReLU, so starting from zeros and taking
        # "amax" keeps the largest weight seen for each id.
        return sparse.scatter_reduce(1, token_ids, token_weights, reduce="amax")

    def forward(self, query_hidden, query_ids, doc_hidden, doc_ids):
        """Score every query against every document by lexical overlap.

        Args:
            query_hidden: Tensor ``(num_queries, q_len, embedding_dim)``.
            query_ids: Int64 tensor ``(num_queries, q_len)`` of token ids.
            doc_hidden: Tensor ``(num_docs, d_len, embedding_dim)``.
            doc_ids: Int64 tensor ``(num_docs, d_len)`` of token ids.

        Returns:
            Tensor of shape ``(num_queries, num_docs)``; non-negative, and
            zero for pairs that share no (non-pad) token id.
        """
        query_sparse = self._sparse_vectors(query_hidden, query_ids)
        doc_sparse = self._sparse_vectors(doc_hidden, doc_ids)
        # Only ids with non-zero weight on both sides contribute.
        return query_sparse @ doc_sparse.T
4. Multi-Vector Retrieval
class MultiVectorRetrieval(nn.Module):
    """Late-interaction scorer over projected multi-vectors.

    A single linear layer, shared by queries and documents, expands each
    input embedding into ``num_vectors`` vectors of size ``embedding_dim``.
    For each aligned (query, document) pair, every query vector takes its
    maximum dot product over the document's vectors, and these maxima are
    averaged. No normalization is applied to the vectors.

    Args:
        embedding_dim: Size of the input embeddings and of each projected
            vector.
        num_vectors: Number of vectors each input is expanded into.
    """

    def __init__(self, embedding_dim, num_vectors):
        super().__init__()
        # Kept on the instance because forward() needs it to reshape.
        self.embedding_dim = embedding_dim
        self.num_vectors = num_vectors
        self.projection = nn.Linear(embedding_dim, embedding_dim * num_vectors)

    def forward(self, query_embeddings, doc_embeddings):
        """Score each query against the document at the same batch index.

        Args:
            query_embeddings: Tensor whose last dimension is
                ``embedding_dim``; leading dimensions are flattened into
                one batch axis.
            doc_embeddings: Same layout. It must flatten to the same
                number of rows as ``query_embeddings`` (``torch.bmm``
                requires equal batch sizes).

        Returns:
            Tensor of shape ``(batch,)`` with one score per aligned pair.
        """
        projected_query = self.projection(query_embeddings).view(
            -1, self.num_vectors, self.embedding_dim
        )
        projected_doc = self.projection(doc_embeddings).view(
            -1, self.num_vectors, self.embedding_dim
        )
        # (batch, num_vectors, num_vectors): every query vector against every doc vector.
        similarities = torch.bmm(projected_query, projected_doc.transpose(1, 2))
        # For each query vector, keep its best-matching doc vector, then average.
        max_similarities, _ = torch.max(similarities, dim=-1)
        avg_similarity = torch.mean(max_similarities, dim=1)
        return avg_similarity