本文介绍: 在这个掩码中,对角线以下和对角线上的元素设置为负无穷和零,以确保在自注意力机制中,模型只能关注当前位置之前的信息。是一个三角矩阵,其中对角线及其以下的元素为负无穷,而对角线以上的元素为0。这样的矩阵在自注意力机制中被用作掩码,确保模型生成每个位置时只关注之前的位置,而不会使用未来的信息。这样,在计算注意力权重时,这些位置的值经过 softmax 函数后将保持为。通常用于在自注意力机制中,确保模型生成序列时只能注意当前位置之前的信息,而。的上三角矩阵,其中上三角元素为1,下三角元素为0。

nopeek掩码通常用于在自注意力机制中,确保模型生成序列时只能注意当前位置之前的信息,而不能“窥视”未来的信息

def gen_nopeek_mask(length):    
    mask = (torch.triu(torch.ones(length, length)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask.to(device)
  1. torch.triu(torch.ones(length, length)) == 1: 创建一个大小(length, length) 的上三角矩阵,其中上三角的元素为1,下三角的元素为0。

  2. .transpose(0, 1): 将矩阵进行转置,得到对角线上方的区域

  3. mask = mask.float(): 将布尔类型的矩阵转换浮点数类型

  4. .masked_fill(mask == 0, float('-inf')): 将矩阵中值为0的位置用负无穷(-∞)填充。这样,在计算注意力权重时,这些位置的值经过 softmax 函数后将趋近于零,表示模型在这些位置应该关注

  5. .masked_fill(mask == 1, float(0.0)): 将矩阵中值为1的位置用0填充。这样,在计算注意力权重时,这些位置的值经过 softmax 函数后将保持为1,表示模型在这些位置应该关注

最终,mask一个三角矩阵,其中对角线及其以下的元素为负无穷,而对角线以上的元素为0。这样的矩阵在自注意力机制中被用作掩码,确保模型生成每个位置时只关注之前的位置,而不会使用未来的信息。

我们使用一个具体的长度演示 gen_nopeek_mask 函数比如 length = 4。以下是运行这个函数示例

import torch

def gen_nopeek_mask(length):
    mask = (torch.triu(torch.ones(length, length)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

# 生成长度为 4 的 nopeek mask
mask_example = gen_nopeek_mask(4)
print(mask_example)
运行这个示例,将得到一个 4x4 的矩阵,其中包含了上三角区域以及对角线以下的部分
tensor([[ 0., -inf, -inf, -inf],
        [ 0.,  0., -inf, -inf],
        [ 0.,  0.,  0., -inf],
        [ 0.,  0.,  0.,  0.]])

这个矩阵是一个示例的 “nopeek” 掩码。在这个掩码中,对角线以下和对角线上的元素被设置为负无穷和零,以确保在自注意力机制中,模型只能关注当前位置之前的信息。这种掩码通常在 Transformer 模型中的解码器中使用。

将矩阵中值为0的位置用无穷(-∞)填充。这样,在计算注意力权重时,这些位置的值经过 softmax 函数后将趋近于0表示模型在这些位置应该关注

将矩阵中值为1的位置用0填充。这样,在计算注意力权重时,这些位置的值经过 softmax 函数后将保持为1表示模型在这些位置应该关注

原文地址:https://blog.csdn.net/qq_42536162/article/details/134675618

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任

如若转载,请注明出处:http://www.7code.cn/show_46330.html

如若内容造成侵权/违法违规/事实不符,请联系代码007邮箱suwngjj01@126.com进行投诉反馈,一经查实,立即删除

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注