DIST Implementation
import torch.nn as nn

def cosine_similarity(a, b, eps=1e-8):
    # Row-wise cosine similarity between two batches of vectors.
    return (a * b).sum(1) / (a.norm(dim=1) * b.norm(dim=1) + eps)

def pearson_correlation(a, b, eps=1e-8):
    # Pearson correlation = cosine similarity of the mean-centered vectors.
    return cosine_similarity(a - a.mean(1).unsqueeze(1),
                             b - b.mean(1).unsqueeze(1), eps)

def inter_class_relation(soft_student_outputs, soft_teacher_outputs):
    # Match each sample's class distribution between student and teacher (row-wise).
    return 1 - pearson_correlation(soft_student_outputs, soft_teacher_outputs).mean()

def intra_class_relation(soft_student_outputs, soft_teacher_outputs):
    # The same criterion applied column-wise, i.e. across samples within each class.
    return inter_class_relation(soft_student_outputs.transpose(0, 1),
                                soft_teacher_outputs.transpose(0, 1))

class DIST(nn.Module):
    def __init__(self, beta=1.0, gamma=1.0, temp=1.0):
        super(DIST, self).__init__()
        self.beta = beta    # weight of the inter-class loss
        self.gamma = gamma  # weight of the intra-class loss
        self.temp = temp    # softmax temperature

    def forward(self, student_preds, teacher_preds, **kwargs):
        # Soften both sets of logits with the same temperature.
        soft_student_outputs = (student_preds / self.temp).softmax(dim=1)
        soft_teacher_outputs = (teacher_preds / self.temp).softmax(dim=1)
        inter_loss = self.temp ** 2 * inter_class_relation(soft_student_outputs, soft_teacher_outputs)
        intra_loss = self.temp ** 2 * intra_class_relation(soft_student_outputs, soft_teacher_outputs)
        kd_loss = self.beta * inter_loss + self.gamma * intra_loss
        return kd_loss
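A minimal usage sketch of the DIST loss above (the batch shape, the random logits, and temp=4.0 are illustrative assumptions, not from the original post):

import torch

# Hypothetical smoke test: a batch of 4 samples over 10 classes.
student_logits = torch.randn(4, 10)
teacher_logits = torch.randn(4, 10)

criterion = DIST(beta=1.0, gamma=1.0, temp=4.0)  # temp=4.0 is an illustrative choice
loss = criterion(student_logits, teacher_logits)
print(loss.item())  # scalar KD loss, combined with the task loss during training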
KLDiv Implementation
import torch.nn as nn
import torch.nn.functional as F

# The full objective is loss = alpha * hard_loss + (1 - alpha) * kd_loss;
# this module computes only the kd_loss term.
class KLDiv(nn.Module):
    def __init__(self, temp=1.0):
        super(KLDiv, self).__init__()
        self.temp = temp  # softmax temperature

    def forward(self, student_preds, teacher_preds, **kwargs):
        # KL divergence expects log-probabilities for the student
        # and probabilities for the teacher.
        soft_student_outputs = F.log_softmax(student_preds / self.temp, dim=1)
        soft_teacher_outputs = F.softmax(teacher_preds / self.temp, dim=1)
        # Sum over classes, then average over the batch.
        kd_loss = F.kl_div(soft_student_outputs, soft_teacher_outputs, reduction="none").sum(1).mean()
        # Scale by temp**2 to keep gradient magnitudes comparable across temperatures.
        kd_loss *= self.temp ** 2
        return kd_loss
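The comment above mentions combining this KD term with a hard-label loss as loss = alpha * hard_loss + (1 - alpha) * kd_loss. A minimal sketch of that combination (alpha, the labels, the random logits, and the temperature are illustrative assumptions):

import torch
import torch.nn.functional as F

student_logits = torch.randn(4, 10, requires_grad=True)
teacher_logits = torch.randn(4, 10)
labels = torch.randint(0, 10, (4,))  # hypothetical ground-truth labels

alpha = 0.5  # illustrative weight between hard and soft targets
criterion = KLDiv(temp=4.0)

hard_loss = F.cross_entropy(student_logits, labels)
kd_loss = criterion(student_logits, teacher_logits.detach())  # detach: no gradient flows to the teacher
loss = alpha * hard_loss + (1 - alpha) * kd_loss
loss.backward()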
Further Reading on Knowledge Distillation
An overview of FitNet (ICLR 2015), Attention (ICLR 2017), Relational KD (CVPR 2019), ICKD (ICCV 2021), Decoupled KD (CVPR 2022), ReviewKD (CVPR 2021), and other methods:
https://zhuanlan.zhihu.com/p/603748226?utm_id=0
To be updated.
Original article: https://blog.csdn.net/qq_42864343/article/details/134768003