Preface
During generation, base language models that have not been through alignment training will happily produce sensitive or otherwise unwanted words. The simplest, crudest remedy is to filter the generated text afterwards against a sensitive-word lexicon or similar rules, but by then time and compute have already been spent generating those tokens.
Using chatglm2-6B as an example, this article shows how to suppress specific words during generation itself with a custom LogitsProcessor.
LogitsProcessor
As the code below (from the transformers source, [1]) shows, a LogitsProcessor is simply a hook that rewrites the scores at each generation step, reshaping the probability distribution the model's next token is drawn from.
class LogitsProcessor:
    """Abstract base class for all logit processors that can be applied during generation."""

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
        )


class LogitsProcessorList(list):
    """
    This class can be used to create a list of [`LogitsProcessor`] or [`LogitsWarper`] to subsequently process a
    `scores` input tensor. This class inherits from list and adds a specific *__call__* method to apply each
    [`LogitsProcessor`] or [`LogitsWarper`] to the inputs.
    """

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
            scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
                Prediction scores of a language modeling head. These can be logits for each vocabulary when not using
                beam search or log softmax for each vocabulary token when using beam search
            kwargs (`Dict[str, Any]`, *optional*):
                Additional kwargs that are specific to a logits processor.

        Return:
            `torch.FloatTensor` of shape `(batch_size, config.vocab_size)`:
                The processed prediction scores.
        """
        for processor in self:
            function_args = inspect.signature(processor.__call__).parameters
            if len(function_args) > 2:
                if not all(arg in kwargs for arg in list(function_args.keys())[2:]):
                    raise ValueError(
                        f"Make sure that all the required parameters: {list(function_args.keys())} for "
                        f"{processor.__class__} are passed to the logits processor."
                    )
                scores = processor(input_ids, scores, **kwargs)
            else:
                scores = processor(input_ids, scores)
        return scores
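To see the mechanism in isolation, here is a minimal, self-contained sketch (not from the original post) that calls a LogitsProcessorList on a dummy scores tensor, using the built-in MinLengthLogitsProcessor; generate() does essentially this at every decoding step:

import torch
from transformers.generation.logits_process import LogitsProcessorList, MinLengthLogitsProcessor

eos_token_id = 2  # an assumed eos id, purely for illustration
processors = LogitsProcessorList([MinLengthLogitsProcessor(5, eos_token_id)])

input_ids = torch.tensor([[0, 1, 3]])  # (batch_size=1, sequence_length=3)
scores = torch.randn(1, 10)            # (batch_size=1, vocab_size=10)

# The list calls each processor in turn; MinLength pushes the eos score to
# -inf because only 3 tokens exist so far (< 5), so generation cannot stop yet.
new_scores = processors(input_ids, scores)
print(new_scores[0, eos_token_id])  # tensor(-inf)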
Custom LogitsProcessor in Practice
Back to the main topic: how do we write a custom LogitsProcessor to control the model's generation? Straight to the code:
class new_logits_processor(LogitsProcessor):
    def __init__(self, forbid_token_id_list: List[int] = None):
        self.forbid_token_id_list = forbid_token_id_list or []  # guard against None

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Push the score of every forbidden token to -inf for the whole batch,
        # so it can never be selected.
        for id_ in self.forbid_token_id_list:
            scores[:, id_] = -float('inf')
        return scores
forbid_token_id_list is the list of token ids for the words the model must not generate; the custom processor pushes the score of each of those tokens to negative infinity.
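Why negative infinity? After softmax, a -inf logit becomes a probability of exactly zero, so neither sampling nor beam search can ever select that token. A quick check:

import torch

scores = torch.tensor([[1.0, 2.0, -float('inf')]])
print(torch.softmax(scores, dim=-1))  # tensor([[0.2689, 0.7311, 0.0000]])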
The full chatglm2-6B example:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TextStreamer
from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList
from typing import List
import torch


class new_logits_processor(LogitsProcessor):
    def __init__(self, forbid_token_id_list: List[int] = None):
        self.forbid_token_id_list = forbid_token_id_list or []

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        for id_ in self.forbid_token_id_list:
            scores[:, id_] = -float('inf')
        return scores


model_path = "THUDM/chatglm2-6b"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# 'mps' targets Apple Silicon; use 'cuda' on an NVIDIA GPU.
model = AutoModelForSeq2SeqLM.from_pretrained(model_path, trust_remote_code=True).to('mps')


def add_forbid_words():
    """
    Build the list of token ids to suppress. The digits 0-9 and a few words
    are banned here so the before/after runs can be compared.
    :return: list
    """
    forbid_words = []
    for i in range(10):
        forbid_words.append(tokenizer.convert_tokens_to_ids(str(i)))
    forbid_words.append(tokenizer.convert_tokens_to_ids("首先"))
    forbid_words.append(tokenizer.convert_tokens_to_ids("積極"))
    forbid_words.append(tokenizer.convert_tokens_to_ids("回答"))
    forbid_words.append(tokenizer.convert_tokens_to_ids("勇敢"))
    forbid_words.append(tokenizer.convert_tokens_to_ids("勇氣"))
    return forbid_words


logits_processor = LogitsProcessorList()
logits_processor.append(new_logits_processor(add_forbid_words()))

streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
input = "列舉出10個(gè)積極的詞語(yǔ):"  # "List 10 positive words:"

outputs = model.generate(
    tokenizer(input, return_tensors='pt').input_ids.to("mps"),
    max_new_tokens=1024,
    logits_processor=logits_processor,  # comment this line out to disable suppression
    streamer=streamer,
)
decode_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
print(decode_text)
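A quick sanity check, a sketch assuming every banned word above maps to a real single token id (convert_tokens_to_ids falls back to the unknown-token id otherwise): verify that none of the newly generated ids is in the forbidden list.

# Slice off the prompt: generate() returns prompt + completion here, and the
# prompt itself contains "積極", which would trip the check.
input_len = tokenizer(input, return_tensors='pt').input_ids.shape[-1]
forbidden = set(add_forbid_words())
new_ids = outputs[0][input_len:].tolist()
assert not forbidden.intersection(new_ids), "a banned token slipped through"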
Output before suppression:
1. 勇敢 (brave)
2. 快樂(lè) (happy)
3. 成功 (success)
4. 努力 (hard-working)
5. 積極 (positive)
6. 樂(lè)觀 (optimistic)
7. 自信 (confident)
8. 開(kāi)朗 (cheerful)
9. 團(tuán)結(jié) (united)
10. 奮斗 (striving)
Output after suppression:
- 積極主動(dòng) (proactive)
- 樂(lè)觀向上 (optimistic and upbeat)
- 自信 (confident)
- 自律 (self-disciplined)
- 誠(chéng)實(shí)守信 (honest and trustworthy)
- 樂(lè)于助人 (helpful)
- 勇于嘗試 (willing to try)
- 堅(jiān)韌不拔 (persevering)
- 樂(lè)觀開(kāi)朗 (optimistic and cheerful)
- 團(tuán)結(jié)一心 (united as one)
Summary
This article walked through a simple trick: using a custom LogitsProcessor to block user-specified words while a large language model generates. The before/after outputs show the effect clearly: with the digits 0-9 banned, the numbered list collapses into dash bullets, and 勇敢, 積極 and 勇氣 no longer appear as standalone items. In real applications, it is well worth exploring how to use LogitsProcessor flexibly for scenario-specific, targeted control over the generation process.
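One subtlety is visible in the output above: 積極主動(dòng) still appears even though 積極 is banned, presumably because the longer word is tokenized into different ids, and the processor blocks exact ids only. A minimal, hypothetical sketch (the helper name is mine, not from the original post) of one way to widen the net is to ban every token id in the encoding of each word, at the cost of over-blocking anything else that shares those tokens:

def words_to_forbidden_ids(tokenizer, words):
    # Collect every token id that appears in the encoding of each banned word.
    # Over-blocks: any other word sharing one of these tokens is suppressed too.
    ids = set()
    for word in words:
        ids.update(tokenizer.encode(word, add_special_tokens=False))
    return list(ids)

# Usage with the processor defined earlier:
# logits_processor.append(new_logits_processor(words_to_forbidden_ids(tokenizer, ["積極", "勇敢"])))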
References
[1] https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/generation/logits_process.py