廣州建設(shè)網(wǎng)站公司企業(yè)線上培訓(xùn)平臺
NLP文本匹配任務(wù)Text Matching [有監(jiān)督訓(xùn)練]:PointWise(單塔)、DSSM(雙塔)、Sentence BERT(雙塔)項目實踐
0 背景介紹以及相關(guān)概念
本項目對3種常用的文本匹配的方法進(jìn)行實現(xiàn):PointWise(單塔)、DSSM(雙塔)、Sentence BERT(雙塔)。
文本匹配(Text Matching)是 NLP 下的一個分支,通常用于計算兩個句子之間的相似程度,在推薦、推理等場景下都有著重要的作用。
舉例來講,今天我們有一堆評論數(shù)據(jù),我們想要找到指定類別的評論數(shù)據(jù),例如:
1. 為什么是開過的洗發(fā)水都流出來了,是用過的嗎?是這樣子包裝的嗎?
2. 喜歡折疊手機(jī)的我對這款手機(jī)情有獨鐘,簡潔的外觀設(shè)計非常符合當(dāng)代年輕人的口味,給攜帶增添了一份愉悅。
3. 物流很快,但是到貨時有的水果已經(jīng)不新鮮了,壞掉了,不滿意本次購物。
...
在這一堆評論中我們想找到跟「水果」相關(guān)的評論,那么第 3 條評論就應(yīng)該被召回。這個問題可以被建模為文本分類對吧,通過訓(xùn)練一個文本分類模型也能達(dá)到同樣的目的。
但,分類模型的主要問題是:分類標(biāo)簽是固定的。假如在訓(xùn)練的時候標(biāo)簽集合是「洗浴,電腦,水果」,今天再來一條「服飾」的評論,那么模型依舊只能在原有的標(biāo)簽集合里面進(jìn)行推理,無論推到哪個都是錯誤的。因此,我們需要一個能夠有一定自適應(yīng)能力的模型,在加入一些新標(biāo)簽后不用重訓(xùn)模型也能很好的完成任務(wù)。
文本匹配目前比較常用的有兩種結(jié)構(gòu):
-
單塔模型:準(zhǔn)確率高,但計算速度慢。
-
雙塔模型:計算速度快,準(zhǔn)確率相對低一些。
下面我們對這兩種方法分別進(jìn)行介紹。
0.1 單塔模型
單塔模型顧名思義,是指在整個過程中只進(jìn)行一次模型計算。這里的「塔」指的是進(jìn)行「幾次模型計算」,而不一定是「模型個數(shù)」,這個我們會放到雙塔部分解釋。在單塔模型下,我們需要把兩句文本通過 [SEP] 進(jìn)行拼接,將拼接好的數(shù)據(jù)喂給模型,通過 output 中的[CLS] token 做一個二分類任務(wù)。
單塔模型的 forward 部分長這樣,完整源碼在文末:
def __init__(self, encoder, dropout=None):"""init func.Args:encoder (transformers.AutoModel): backbone, 默認(rèn)使用 ernie 3.0dropout (float): dropout 比例"""super().__init__()self.encoder = encoderhidden_size = 768self.dropout = nn.Dropout(dropout if dropout is not None else 0.1)self.classifier = nn.Linear(768, 2)def forward(self,input_ids,token_type_ids,position_ids=None,attention_mask=None) -> torch.tensor:"""Foward 函數(shù),輸入匹配好的pair對,返回二維向量(相似/不相似)。Args:input_ids (torch.LongTensor): (batch, seq_len)token_type_ids (torch.LongTensor): (batch, seq_len)position_ids (torch.LongTensor): (batch, seq_len)attention_mask (torch.LongTensor): (batch, seq_len)Returns:torch.tensor: (batch, 2)"""pooled_embedding = self.encoder(input_ids=input_ids,token_type_ids=token_type_ids,position_ids=position_ids,attention_mask=attention_mask)["pooler_output"] # (batch, hidden_size)pooled_embedding = self.dropout(pooled_embedding) # (batch, hidden_size)logits = self.classifier(pooled_embedding) # (batch, 2)return logits
單塔模型的優(yōu)勢在于準(zhǔn)確率較高,但缺點在于:計算慢。
- 為什么慢呢?
舉例來講,如果今天我們有三個類別「電腦、水果、洗浴」,那我們就需要將一句話跟每個類別都做一次拼接,并喂給模型去做推理:
水果[SEP]蘋果不是很新鮮,不滿意這次購物[SEP]
電腦[SEP]蘋果不是很新鮮,不滿意這次購物[SEP]
洗浴[SEP]蘋果不是很新鮮,不滿意這次購物[SEP]
那如果類別數(shù)目到達(dá)成百上千時,就需要拼接上千次,為了判斷一個樣本就需要過上次模型,而大模型的計算通常來講是非常耗時的,這就導(dǎo)致了在類別數(shù)目很大的情況下,單塔模型的效率往往無法滿足人們的需求。
0.2 雙塔模型
單塔模型的劣勢很明顯,有多少類別就需要算多少次。但事實上,這些類別都是不會變的,唯一變的只有新的評論數(shù)據(jù)。所以我們能不能實現(xiàn)將這些不會變的「類別信息」「提前計算」存下來,只計算那些沒有見過的「評論數(shù)據(jù)」呢?這就是雙塔模型的思想。雙塔模型的「雙塔」含義就是:兩次模型計算。即,類別特征計算一次,評論特征計算一次。
通過上圖可以看到,「類別」和「評論」不再是拼接在一起喂給模型,而是單獨喂給模型,并分別得到各自的 embedding 向量,再進(jìn)行后續(xù)的計算。而上圖中左右兩個兩個模型可以使用同一個模型(這種方式叫:同構(gòu)模型),也可以用兩個不同的模型(這種方式叫:異構(gòu)模型)。因此「雙塔」并不一定代表存在兩個模型,而是代表需要需要進(jìn)行兩次模型計算。
0.2.1 DSSM(Deep Structured Semantic Models,深度結(jié)構(gòu)化語義模型)
Paper Reference: https://posenhuang.github.io/papers/cikm2013_DSSM_fullversion.pdf
DSSM 是一篇比較早期的 paper,我們主要借鑒其通過 embedding 之間的余弦相似度進(jìn)行召回排序的思想。我們分別將「類別」和「評論」文本過一遍模型,并得到兩段文本的 embedding。將匹配的 pair 之間的余弦相似度 label 置為 1,不匹配的 pair 之間余弦相似度 label 置為 0。
Note: 余弦相似度的取值范圍是 [-1, 1],但為了方便我將 label 置為 0 并用 MSE 去訓(xùn)練,也能取得不錯的效果。
訓(xùn)練數(shù)據(jù)集長這樣,匹配 pair 標(biāo)簽為 1,不匹配為 0:
蒙牛 不錯還好挺不錯 0
蒙牛 我喜歡demom制造的蒙牛奶 1
衣服 褲子太差了,剛穿一次屁股就起毛了。 1
...
實現(xiàn)中有兩個關(guān)鍵函數(shù):獲得句子的 embedding 函數(shù)(用于推理)、獲得句子對的余弦相似度(用于訓(xùn)練):
def forward(self,input_ids: torch.tensor,token_type_ids: torch.tensor,attention_mask: torch.tensor) -> torch.tensor:"""forward 函數(shù),輸入單句子,獲得單句子的embedding。Args:input_ids (torch.LongTensor): (batch, seq_len)token_type_ids (torch.LongTensor): (batch, seq_len)attention_mask (torch.LongTensor): (batch, seq_len)Returns:torch.tensor: embedding -> (batch, hidden_size)"""embedding = self.encoder(input_ids=input_ids,token_type_ids=token_type_ids,attention_mask=attention_mask)["pooler_output"] # (batch, hidden_size)return embeddingdef get_similarity(self,query_input_ids: torch.tensor,query_token_type_ids: torch.tensor,query_attention_mask: torch.tensor,doc_input_ids: torch.tensor,doc_token_type_ids: torch.tensor,doc_attention_mask: torch.tensor) -> torch.tensor:"""輸入query和doc的向量,返回query和doc兩個向量的余弦相似度。Args:query_input_ids (torch.LongTensor): (batch, seq_len)query_token_type_ids (torch.LongTensor): (batch, seq_len)query_attention_mask (torch.LongTensor): (batch, seq_len)doc_input_ids (torch.LongTensor): (batch, seq_len)doc_token_type_ids (torch.LongTensor): (batch, seq_len)doc_attention_mask (torch.LongTensor): (batch, seq_len)Returns:torch.tensor: (batch, 1)"""query_embedding = self.encoder(input_ids=query_input_ids,token_type_ids=query_token_type_ids,attention_mask=query_attention_mask)["pooler_output"] # (batch, hidden_size)query_embedding = self.dropout(query_embedding)doc_embedding = self.encoder(input_ids=doc_input_ids,token_type_ids=doc_token_type_ids,attention_mask=doc_attention_mask)["pooler_output"] # (batch, hidden_size)doc_embedding = self.dropout(doc_embedding)similarity = nn.functional.cosine_similarity(query_embedding, doc_embedding)return similarity
0.2.2 Sentence Transformer
Paper Reference:https://arxiv.org/pdf/1908.10084.pdf
Sentence Transformer 也是一個雙塔模型,只是在訓(xùn)練時不直接對兩個句子的 embedding 算余弦相似度,而是將這兩個 embedding 和 embedding 之間的差向量進(jìn)行拼接,將這三個向量拼好后喂給一個判別層做二分類任務(wù)。
原 paper 中在 inference 的時候不再使用訓(xùn)練架構(gòu),而是采用了余弦相似度的方法做召回。但我在實現(xiàn)的時候在推理部分仍然沿用了訓(xùn)練的模型架構(gòu),原因是想抹除結(jié)構(gòu)不一致的 gap,并且訓(xùn)練層也只是多了一層 Linear 層,在推理的時候也不至于消耗過多的時間。Sentence Transformer 在推理時需要同時傳入「當(dāng)前評論信息」以及事先計算好的「所有類別 embedding」,如下:
def forward(self,query_input_ids: torch.tensor,query_token_type_ids: torch.tensor,query_attention_mask: torch.tensor,doc_embeddings: torch.tensor,) -> torch.tensor:"""forward 函數(shù),輸入query句子和doc_embedding向量,將query句子過一遍模型得到query embedding再和doc_embedding做二分類。Args:input_ids (torch.LongTensor): (batch, seq_len)token_type_ids (torch.LongTensor): (batch, seq_len)attention_mask (torch.LongTensor): (batch, seq_len)doc_embedding (torch.LongTensor): 所有需要匹配的doc_embedding -> (batch, doc_embedding_numbers, hidden_size)Returns:torch.tensor: embedding_match_logits -> (batch, doc_embedding_numbers, 2)"""query_embedding = self.encoder(input_ids=query_input_ids,token_type_ids=query_token_type_ids,attention_mask=query_attention_mask)["last_hidden_state"] # (batch, seq_len, hidden_size)query_attention_mask = torch.unsqueeze(query_attention_mask, dim=-1) # (batch, seq_len, 1)query_embedding = query_embedding * query_attention_mask # (batch, seq_len, hidden_size)query_sum_embedding = torch.sum(query_embedding, dim=1) # (batch, hidden_size)query_sum_mask = torch.sum(query_attention_mask, dim=1) # (batch, 1)query_mean = query_sum_embedding / query_sum_mask # (batch, hidden_size)query_mean = query_mean.unsqueeze(dim=1).repeat(1, doc_embeddings.size()[1], 1) # (batch, doc_embedding_numbers, hidden_size)sub = torch.abs(torch.subtract(query_mean, doc_embeddings)) # (batch, doc_embedding_numbers, hidden_size)concat = torch.cat([query_mean, doc_embeddings, sub], dim=-1) # (batch, doc_embedding_numbers, hidden_size * 3)logits = self.classifier(concat) # (batch, doc_embedding_numbers, 2)return logits
1. 環(huán)境安裝
本項目基于 pytorch
+ transformers
實現(xiàn),運行前請安裝相關(guān)依賴包:
pip install -r ../../requirements.txttorch
transformers==4.22.1
datasets==2.4.0
evaluate==0.2.2
matplotlib==3.6.0
rich==12.5.1
scikit-learn==1.1.2
requests==2.28.1
2. 數(shù)據(jù)集準(zhǔn)備
項目中提供了一部分示例數(shù)據(jù),我們使用「商品評論」和「商品類別」來進(jìn)行文本匹配任務(wù),數(shù)據(jù)在 data/comment_classify
。
若想使用自定義數(shù)據(jù)
訓(xùn)練,只需要仿照示例數(shù)據(jù)構(gòu)建數(shù)據(jù)集即可:
衣服:指穿在身上遮體御寒并起美化作用的物品。 為什么是開過的洗發(fā)水都流出來了、是用過的嗎?是這樣子包裝的嗎? 0
衣服:指穿在身上遮體御寒并起美化作用的物品。 開始買回來大很多 后來換了回來又小了 號碼區(qū)別太不正規(guī) 建議各位謹(jǐn)慎 1
...
每一行用 \t
分隔符分開,第一部分部分為商品類型(text1)
,中間部分為商品評論(text2)
,最后一部分為商品評論和商品類型是否一致(label)
。
3. 有監(jiān)督-模型訓(xùn)練
3.1 PointWise(單塔)
3.1.1 模型訓(xùn)練
修改訓(xùn)練腳本 train_pointwise.sh
里的對應(yīng)參數(shù), 開啟模型訓(xùn)練:
python train_pointwise.py \--model "nghuyong/ernie-3.0-base-zh" \ # backbone--train_path "data/comment_classify/train.txt" \ # 訓(xùn)練集--dev_path "data/comment_classify/dev.txt" \ #驗證集--save_dir "checkpoints/comment_classify" \ # 訓(xùn)練模型存放地址--img_log_dir "logs/comment_classify" \ # loss曲線圖保存位置--img_log_name "ERNIE-PointWise" \ # loss曲線圖保存文件夾--batch_size 8 \--max_seq_len 128 \--valid_steps 50 \--logging_steps 10 \--num_train_epochs 10 \--device "cuda:0"
正確開啟訓(xùn)練后,終端會打印以下信息:
...
global step 10, epoch: 1, loss: 0.77517, speed: 3.43 step/s
global step 20, epoch: 1, loss: 0.67356, speed: 4.15 step/s
global step 30, epoch: 1, loss: 0.53567, speed: 4.15 step/s
global step 40, epoch: 1, loss: 0.47579, speed: 4.15 step/s
global step 50, epoch: 2, loss: 0.43162, speed: 4.41 step/s
Evaluation precision: 0.88571, recall: 0.87736, F1: 0.88152
best F1 performence has been updated: 0.00000 --> 0.88152
global step 60, epoch: 2, loss: 0.40301, speed: 4.08 step/s
global step 70, epoch: 2, loss: 0.37792, speed: 4.03 step/s
global step 80, epoch: 2, loss: 0.35343, speed: 4.04 step/s
global step 90, epoch: 2, loss: 0.33623, speed: 4.23 step/s
global step 100, epoch: 3, loss: 0.31319, speed: 4.01 step/s
Evaluation precision: 0.96970, recall: 0.90566, F1: 0.93659
best F1 performence has been updated: 0.88152 --> 0.93659
...
在 logs/comment_classify
文件下將會保存訓(xùn)練曲線圖:
3.1.2 模型推理
完成模型訓(xùn)練后,運行 inference_pointwise.py
以加載訓(xùn)練好的模型并應(yīng)用:
...test_inference('手機(jī):一種可以在較廣范圍內(nèi)使用的便攜式電話終端。', # 第一句話'味道非常好,京東送貨速度也非???#xff0c;特別滿意。', # 第二句話max_seq_len=128)
...
運行推理程序:
python inference_pointwise.py
得到以下推理結(jié)果:
tensor([[ 1.8477, -2.0484]], device='cuda:0') # 兩句話不相似(0)的概率更大
3.2 DSSM(雙塔)
3.2.1 模型訓(xùn)練
修改訓(xùn)練腳本 train_dssm.sh
里的對應(yīng)參數(shù), 開啟模型訓(xùn)練:
python train_dssm.py \--model "nghuyong/ernie-3.0-base-zh" \--train_path "data/comment_classify/train.txt" \--dev_path "data/comment_classify/dev.txt" \--save_dir "checkpoints/comment_classify/dssm" \--img_log_dir "logs/comment_classify" \--img_log_name "ERNIE-DSSM" \--batch_size 8 \--max_seq_len 256 \--valid_steps 50 \--logging_steps 10 \--num_train_epochs 10 \--device "cuda:0"
正確開啟訓(xùn)練后,終端會打印以下信息:
...
global step 0, epoch: 1, loss: 0.62319, speed: 15.16 step/s
Evaluation precision: 0.29912, recall: 0.96226, F1: 0.45638
best F1 performence has been updated: 0.00000 --> 0.45638
global step 10, epoch: 1, loss: 0.40931, speed: 3.64 step/s
global step 20, epoch: 1, loss: 0.36969, speed: 3.69 step/s
global step 30, epoch: 1, loss: 0.33927, speed: 3.69 step/s
global step 40, epoch: 1, loss: 0.31732, speed: 3.70 step/s
global step 50, epoch: 1, loss: 0.30996, speed: 3.68 step/s
...
在 logs/comment_classify
文件下將會保存訓(xùn)練曲線圖:
3.2.2 模型推理
和單塔模型不一樣的是,雙塔模型可以事先計算所有候選類別的Embedding,當(dāng)新來一個句子時,只需計算新句子的Embedding,并通過余弦相似度找到最優(yōu)解即可。
因此,在推理之前,我們需要提前計算所有類別的Embedding并保存。
類別Embedding計算
運行 get_embedding.py
文件以計算對應(yīng)類別embedding并存放到本地:
...
text_file = 'data/comment_classify/types_desc.txt' # 候選文本存放地址
output_file = 'embeddings/comment_classify/dssm_type_embeddings.json' # embedding存放地址device = 'cuda:0' # 指定GPU設(shè)備
model_type = 'dssm' # 使用DSSM還是Sentence Transformer
saved_model_path = './checkpoints/comment_classify/dssm/model_best/' # 訓(xùn)練模型存放地址
tokenizer = AutoTokenizer.from_pretrained(saved_model_path)
model = torch.load(os.path.join(saved_model_path, 'model.pt'))
model.to(device).eval()
...
其中,所有需要預(yù)先計算的內(nèi)容都存放在 types_desc.txt
文件中。
文件用 \t
分隔,分別代表 類別id
、類別名稱
、類別描述
:
0 水果 指多汁且主要味覺為甜味和酸味,可食用的植物果實。
1 洗浴 洗浴用品。
2 平板 也叫便攜式電腦,是一種小型、方便攜帶的個人電腦,以觸摸屏作為基本的輸入設(shè)備。
...
執(zhí)行 python get_embeddings.py
命令后,會在代碼中設(shè)置的embedding存放地址中找到對應(yīng)的embedding文件:
{"0": {"label": "水果", "text": "水果:指多汁且主要味覺為甜味和酸味,可食用的植物果實。", "embedding": [0.3363891839981079, -0.8757723569869995, -0.4140555262565613, 0.8288457989692688, -0.8255823850631714, 0.9906797409057617, -0.9985526204109192, 0.9907819032669067, -0.9326567649841309, -0.9372553825378418, 0.11966298520565033, -0.7452883720397949,...]},"1": ...,...
}
模型推理
完成預(yù)計算后,接下來就可以開始推理了。
我們構(gòu)建一條新評論:這個破筆記本卡的不要不要的,差評
。
運行 python inference_dssm.py
,得到下面結(jié)果:
[('平板', 0.9515482187271118),('電腦', 0.8216977119445801),('洗浴', 0.12220608443021774),('衣服', 0.1199738010764122),('手機(jī)', 0.07764233648777008),('酒店', 0.044791921973228455),('水果', -0.050112202763557434),('電器', -0.07554933428764343),('書籍', -0.08481660485267639),('蒙牛', -0.16164332628250122)
]
函數(shù)將輸出(類別,余弦相似度)的二元組,并按照相似度做倒排(相似度取值范圍:[-1, 1])。
3.3 Sentence Transformer(雙塔)
3.3.1 模型訓(xùn)練
修改訓(xùn)練腳本 train_sentence_transformer.sh
里的對應(yīng)參數(shù), 開啟模型訓(xùn)練:
python train_sentence_transformer.py \--model "nghuyong/ernie-3.0-base-zh" \--train_path "data/comment_classify/train.txt" \--dev_path "data/comment_classify/dev.txt" \--save_dir "checkpoints/comment_classify/sentence_transformer" \--img_log_dir "logs/comment_classify" \--img_log_name "Sentence-Ernie" \--batch_size 8 \--max_seq_len 256 \--valid_steps 50 \--logging_steps 10 \--num_train_epochs 10 \--device "cuda:0"
正確開啟訓(xùn)練后,終端會打印以下信息:
...
Evaluation precision: 0.81928, recall: 0.64151, F1: 0.71958
best F1 performence has been updated: 0.46120 --> 0.71958
global step 260, epoch: 2, loss: 0.58730, speed: 3.53 step/s
global step 270, epoch: 2, loss: 0.58171, speed: 3.55 step/s
global step 280, epoch: 2, loss: 0.57529, speed: 3.48 step/s
global step 290, epoch: 2, loss: 0.56687, speed: 3.55 step/s
global step 300, epoch: 2, loss: 0.56033, speed: 3.55 step/s
...
在 logs/comment_classify
文件下將會保存訓(xùn)練曲線圖:
3.2.2 模型推理
Sentence Transformer 同樣也是雙塔模型,因此我們需要事先計算所有候選文本的embedding值。
類別Embedding計算
運行 get_embedding.py
文件以計算對應(yīng)類別embedding并存放到本地:
...
text_file = 'data/comment_classify/types_desc.txt' # 候選文本存放地址
output_file = 'embeddings/comment_classify/sentence_transformer_type_embeddings.json' # embedding存放地址device = 'cuda:0' # 指定GPU設(shè)備
model_type = 'sentence_transformer' # 使用DSSM還是Sentence Transformer
saved_model_path = './checkpoints/comment_classify/sentence_transformer/model_best/' # 訓(xùn)練模型存放地址
tokenizer = AutoTokenizer.from_pretrained(saved_model_path)
model = torch.load(os.path.join(saved_model_path, 'model.pt'))
model.to(device).eval()
...
其中,所有需要預(yù)先計算的內(nèi)容都存放在 types_desc.txt
文件中。
文件用 \t
分隔,分別代表 類別id
、類別名稱
、類別描述
:
0 水果 指多汁且主要味覺為甜味和酸味,可食用的植物果實。
1 洗浴 洗浴用品。
2 平板 也叫便攜式電腦,是一種小型、方便攜帶的個人電腦,以觸摸屏作為基本的輸入設(shè)備。
...
執(zhí)行 python get_embeddings.py
命令后,會在代碼中設(shè)置的embedding存放地址中找到對應(yīng)的embedding文件:
{"0": {"label": "水果", "text": "水果:指多汁且主要味覺為甜味和酸味,可食用的植物果實。", "embedding": [0.32447007298469543, -1.0908259153366089, -0.14340722560882568, 0.058471400290727615, -0.33798110485076904, -0.050156619399785995, 0.041511114686727524, 0.671889066696167, 0.2313404232263565, 1.3200652599334717, -1.10829496383667, 0.4710233509540558, -0.08577515184879303, -0.41730815172195435, -0.1956728845834732, 0.05548520386219025, ...]}"1": ...,...
}
模型推理
完成預(yù)計算后,接下來就可以開始推理了。
我們構(gòu)建一條新評論:這個破筆記本卡的不要不要的,差評
。
運行 python inference_sentence_transformer.py
,函數(shù)會輸出所有類別里「匹配通過」的類別及其匹配值,得到下面結(jié)果:
Used 0.5233056545257568s.
[('平板', 1.136274814605713), ('電腦', 0.8851938247680664)
]
函數(shù)將輸出(匹配通過的類別,匹配值)的二元組,并按照匹配值(越大則越匹配)做倒排。
參考鏈接:https://github.com/HarderThenHarder/transformers_tasks/blob/main/text_matching/supervised
github無法連接的可以在:https://download.csdn.net/download/sinat_39620217/88214437 下載
更多優(yōu)質(zhì)內(nèi)容請關(guān)注公號:汀丶人工智能;會提供一些相關(guān)的資源和優(yōu)質(zhì)文章,免費獲取閱讀。