寧波網(wǎng)站建設(shè)哪里有今天新聞?wù)畻l
一.論文
1.1 P-tuning
區(qū)別于之前的工作,這篇工作認(rèn)為promote可以在句子中的任意位置起到作用,可以將它們插入上下文或目標(biāo)中
上圖中,左圖是不使用任何操作,右圖是選擇在居首和目標(biāo)前插入promote的embedding,插入promote的過程可以表示為
其中x代表一系列離散的輸入令牌,y代表目標(biāo)(可以理解為希望模型想要給你的回答),e()表示對應(yīng)的embedding,其實(shí)就是將其參數(shù)化映射成為偽tokens,即
通過最小化這些參數(shù)
1.2 promote生成
嵌入的promote實(shí)際上可以理解為不一定離散且不相互關(guān)聯(lián)的,而實(shí)際上的promote其實(shí)應(yīng)該是高度離散的且具有關(guān)聯(lián)性的,因此作者選擇使用雙向長短期記憶網(wǎng)絡(luò)(LSTM),激活函數(shù)和MLP來建模這種關(guān)系
在推理中,我們只需要輸出嵌入h,并且可以丟棄LSTM頭
二.代碼
本質(zhì)上是使用一個PromptEncoder來生成偽的embedding添加到原先的embedding中
2.1 訓(xùn)練
訓(xùn)練過程只更新promote_encoder中的參數(shù)
?2.1.1 PromptEncoder
在PTuneForLAMA中實(shí)例化了PromptEncoder
?PromptEncoder本質(zhì)上是一個(嵌入 + LSTM + MLP)
import torch
import torch.nn as nnclass PromptEncoder(torch.nn.Module):def __init__(self, template, hidden_size, tokenizer, device, args):super().__init__()self.device = deviceself.spell_length = sum(template)self.hidden_size = hidden_sizeself.tokenizer = tokenizerself.args = args# ent embeddingself.cloze_length = templateself.cloze_mask = [[1] * self.cloze_length[0] # first cloze+ [1] * self.cloze_length[1] # second cloze+ [1] * self.cloze_length[2] # third cloze]self.cloze_mask = torch.LongTensor(self.cloze_mask).bool().to(self.device)self.seq_indices = torch.LongTensor(list(range(len(self.cloze_mask[0])))).to(self.device)# embeddingself.embedding = torch.nn.Embedding(len(self.cloze_mask[0]), self.hidden_size).to(self.device)# LSTMself.lstm_head = torch.nn.LSTM(input_size=self.hidden_size,hidden_size=self.hidden_size // 2,num_layers=2,dropout=self.args.lstm_dropout,bidirectional=True,batch_first=True)self.mlp_head = nn.Sequential(nn.Linear(self.hidden_size, self.hidden_size),nn.ReLU(),nn.Linear(self.hidden_size, self.hidden_size))print("init prompt encoder...")def forward(self):input_embeds = self.embedding(self.seq_indices).unsqueeze(0)output_embeds = self.mlp_head(self.lstm_head(input_embeds)[0]).squeeze()return output_embeds
2.1.2 調(diào)用
在PTuneForLAMA的forward函數(shù)中調(diào)用了embed_input來實(shí)現(xiàn)