DIST 实现方法

import torch.nn as nn


def cosine_similarity(a, b, eps=1e-8):
    return (a * b).sum(1) / (a.norm(dim=1) * b.norm(dim=1) + eps)


def pearson_correlation(a, b, eps=1e-8):
    return cosine_similarity(a - a.mean(1).unsqueeze(1),
                             b - b.mean(1).unsqueeze(1), eps)


def inter_class_relation(soft_student_outputs, soft_teacher_outputs):
    return 1 - pearson_correlation(soft_student_outputs, soft_teacher_outputs).mean()


def intra_class_relation(soft_student_outputs, soft_teacher_outputs):
    return inter_class_relation(soft_student_outputs.transpose(0, 1), soft_teacher_outputs.transpose(0, 1))


class DIST(nn.Module):
    def __init__(self, beta=1.0, gamma=1.0, temp=1.0):
        super(DIST, self).__init__()
        self.beta = beta
        self.gamma = gamma
        self.temp = temp

    def forward(self, student_preds, teacher_preds, **kwargs):
        soft_student_outputs = (student_preds / self.temp).softmax(dim=1)
        soft_teacher_outputs = (teacher_preds / self.temp).softmax(dim=1)
        inter_loss = self.temp ** 2 * inter_class_relation(soft_student_outputs, soft_teacher_outputs)
        intra_loss = self.temp ** 2 * intra_class_relation(soft_student_outputs, soft_teacher_outputs)
        kd_loss = self.beta * inter_loss + self.gamma * intra_loss
        return kd_loss

KLDiv方法

import torch.nn as nn
import torch.nn.functional as F

# loss = alpha * hard_loss + (1-alpha) * kd_loss,此处是单单的kd_loss
class KLDiv(nn.Module):
    def __init__(self, temp=1.0):
        super(KLDiv, self).__init__()
        self.temp = temp

    def forward(self, student_preds, teacher_preds, **kwargs):
        soft_student_outputs = F.log_softmax(student_preds / self.temp, dim=1)
        soft_teacher_outputs = F.softmax(teacher_preds / self.temp, dim=1)
        kd_loss = F.kl_div(soft_student_outputs, soft_teacher_outputs, reduction="none").sum(1).mean()
        kd_loss *= self.temp ** 2
        return kd_loss

关于知识蒸馏文章

FitNet(ICLR 2015)、Attention(ICLR 2017)、Relational KD(CVPR 2019)、ICKD (ICCV 2021)、Decoupled KD(CVPR 2022) 、ReviewKD(CVPR 2021)等方法介绍

https://zhuanlan.zhihu.com/p/603748226?utm_id=0

更新

原文地址:https://blog.csdn.net/qq_42864343/article/details/134768003

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任

如若转载,请注明出处:http://www.7code.cn/show_34380.html

如若内容造成侵权/违法违规/事实不符,请联系代码007邮箱suwngjj01@126.com进行投诉反馈,一经查实,立即删除

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注