本文介绍如何使用LogitsProcessor避免大模型生成过程中出现重复的问题

1. 准备工作

首先实例一个模型,以GLM2为例:

import re
import os
import json
import random
from typing import *
from copy import deepcopy

import torch
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM
from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteriaList, MaxNewTokensCriteria, StoppingCriteria

创建模型

tokenizer = AutoTokenizer.from_pretrained(".../ChatGLM2/", trust_remote_code=True)
model = AutoModel.from_pretrained(".../ChatGLM2/", trust_remote_code=True).half()
model.to('cuda:0')

2. 问题分析

接下来思考一下,如何防止模型不停的重复呢?重复分为几种情况,一个字符循环出现,或者多个字符循环出现,例如:

'abcdeeeeee'
'abcdededede'

生成过程考虑,防止模型生成重复的内容,第一步自然是要判断模型陷入了重复,第二步就是打断它重复的过程,也就是将重复的token,在当前step生成的时候,将其概率设置为-inf,那么重复的过程自然就停止了。

3. 创建processor

3.1 防止重复生成的processor

先来解决如何判定重复。这里直接去leetcode上找一个题,获取个字符串中最大的重复片段,解法如下

def longest_dup_substring(s: str) -> str:
    # 生成两个进制
    a1, a2 = random.randint(26, 100), random.randint(26, 100)
    # 生成两个
    mod1, mod2 = random.randint(10**9+7, 2**31-1), random.randint(10**9+7, 2**31-1)
    n = len(s)
    # 先对所有字符进行编码
    arr = [ord(c)-ord('a') for c in s]
    # 二分查找范围是[1, n-1]
    l, r = 1, n-1
    length, start = 0, -1
    while l <= r:
        m = l + (r - l + 1) // 2
        idx = check(arr, m, a1, a2, mod1, mod2)
        # 有重复子串移动左边界
        if idx != -1:
            l = m + 1
            length = m
            start = idx
        # 无重复子串移动右边界
        else:
            r = m - 1
    return s[start:start+length] if start != -1 else ""

def check(arr, m, a1, a2, mod1, mod2):
    n = len(arr)
    aL1, aL2 = pow(a1, m, mod1), pow(a2, m, mod2)
    h1, h2 = 0, 0
    for i in range(m):
        h1 = (h1 * a1 + arr[i]) % mod1
        h2 = (h2 * a2 + arr[i]) % mod2
    # 存储一个编码组合是否出现过
    seen = {(h1, h2)}
    for start in range(1, n - m + 1):
        h1 = (h1 * a1 - arr[start - 1] * aL1 + arr[start + m - 1]) % mod1
        h2 = (h2 * a2 - arr[start - 1] * aL2 + arr[start + m - 1]) % mod2
        # 如果重复,则返回重复串的起点
        if (h1, h2) in seen:
            return start
        seen.add((h1, h2))
    # 没有重复,则返回-1
    return -1

效果如下

longestDupSubstring('埃尔多安经济学可以重振经济,土耳其土耳其')
# '土耳其'

那么我们就可以写一个processor,在每一个step即将生成的时候,判定一下,是否之前已经生成的结果中,出现了重复。以及,如果出现了重复,则禁止重复部分的第一个token(例如上面例子中,土耳其的土字),在当前step被生成。

针对实际使用中由这个processor引发的一些其他的问题,我又对这个processor增加了一点规则限制,一个比较好用版本如下

其中的参数threshold判断重复多少的情况算作循环,例如将threshold设置为10,那么如果重复部分的长度是3,重复了3次,3×3=9,则不被判定为陷入了循环,而如果重复了4次,3×4=12,则被判定为循环,此时processor将发挥效果了。

class ForbidDuplicationProcessor(LogitsProcessor):
    """
    防止生成的内容陷入循环。
    当循环内容与循环次数之乘积大于指定次数
    则在生成下一个token时将循环内容的第一个token概率设置为0
    ---------------
    ver: 2023-08-17
    by: changhongyu
    """
    def __init__(self, tokenizer, threshold: int = 10):
        self.tokenizer = tokenizer
        self.threshold = threshold
        
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        current_sequence = self.tokenizer.decode(input_ids[0][current_token_len: ])
        current_dup_str = longest_dup_substring(current_sequence)
        if len(current_dup_str):
            # 如果存在重复子序列,则根据其长度与重复次数判断是否禁止循环
            if len(current_dup_str) > 1 or (len(current_dup_str) == 1 and current_dup_str * self.threshold in current_sequence):
                if len(current_dup_str) * current_sequence.count(current_dup_str) >= self.threshold:
                    token_ids = self.tokenizer.encode(current_dup_str)
                    # 获取截止目前的上一个token
                    last_token = input_ids[0][-1].detach().cpu().numpy().tolist()
                    if len(token_ids) and last_token == token_ids[-1]:
                        # 如果截止目前的上一个token,与重复部分的最后一个token一致
                        # 说明即将触发重复, 先把重复部分的第一个token禁掉
                        scores[:, token_ids[0]] = 0
                        # 然后按出现比率判断是否重复部分内还有其他重复
                        for token_id in token_ids:
                            if token_ids.count(token_id) * len(token_ids) > 1.2:
                                scores[:, token_id] = 0

        return scores

需要注意的是,为了获取当前序列已经生成的长度需要processor的外部,也就是与model.generate同级结构处,定义一个全局变量current_token_len

global current_token_len

3.2 防止数字规则循环的processor

出了上述的情况,还有一种常见的循环,无法利用上面的规则解决,即数字无规则循环的情况。针对这个场景创建另一个processor,只要连续出现的数字出现次数,大于一定的阈值,则禁止当前step再次生成数字

class MaxConsecutiveProcessor(LogitsProcessor):
    """
    给定一个集合集合中的字符最多连续若干次
    下一次生成时不能再出现该集合中的字符
    ---------------
    ver: 2023-08-17
    by: changhongyu
    ---------------
    修复bug
    ver: 2023-09-11
    """
    def __init__(self, consecutive_token_ids, max_num: int = 10):
        self.consecutive_token_ids = consecutive_token_ids
        self.max_num = max_num
    
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        input_ids_list = input_ids.squeeze(0).detach().cpu().numpy().tolist()
        cur_num = 0
        for token in input_ids_list[::-1]:
            if token in self.consecutive_token_ids:
                cur_num += 1
            else:
                break
                
        if cur_num >= self.max_num:
            # 如果连续次数超过阈值,那集合中的所有token在下一个step都不可以再出现
            for token_id in self.consecutive_token_ids:
                scores[..., token_id] = 0
        return scores

4. 使用

使用方法非常简单,首先创建processor容器。对processor不熟悉的同学,可以去看之前的文章,有非常详细的介绍

logits_processor = LogitsProcessorList()

然后对于ChatGLM而言,需要先添加默认processor:

logits_processor.append(InvalidScoreLogitsProcessor())

接下来,再添加防止陷入循环的两个processor:

number_tokens = [str(i) for i in range(10)] + ['.', '-']
number_token_ids = [tokenizer.convert_tokens_to_ids(tok) for tok in number_tokens]
logits_processor.append(ForbidDuplicationProcessor(tokenizer))
logits_processor.append(MaxConsecutiveProcessor(number_token_ids))

最后调用generate的时候,把logits_processor作为参数传进去就可以了。

以上便是使用logits_processor来防止大模型在生成过程中陷入循环的方法。经过我的反复调整,基本可以覆盖大多数情景,如果在使用中遇到了bug,也欢迎指出。

原文地址:https://blog.csdn.net/weixin_44826203/article/details/132837387

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任

如若转载,请注明出处:http://www.7code.cn/show_15899.html

如若内容造成侵权/违法违规/事实不符,请联系代码007邮箱suwngjj01@126.com进行投诉反馈,一经查实,立即删除

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注