高版本transformers-4.24中的坑
创始人
2024-03-20 02:32:07
0

最近遇到一个很奇怪的BUG,好早之前写的一个Bert文本分类模型,拿给别人用的时候,发现不灵了,原本90多的acc,什么都没修改,再测一次发现只剩30多了,检查了一番之后,很快我发现他的transformers版本是4.24,而我一直用的是4.9,没有更新。

于是我试着分析问题出在哪里,然后就遇到了这个坑。首先这是我模型的基础结构,很简单,就是一个Encoder模型加一层分类器:

class BertClassifier(torch.nn.Module):def __init__(self, bert_model, num_classes):super(BertClassifier, self).__init__()self.bert = bert_modelself.dropout = torch.nn.Dropout(0.2)self.dense = torch.nn.Linear(768, num_classes)def forward(self,input_ids=None,token_type_ids=None,attention_mask=None,labels=None,):bert_out = self.bert(input_ids, token_type_ids, attention_mask, output_attentions=False)# print(list(self.bert.encoder.layer[0].attention.self.query.parameters()))# print(bert_out)sequence_output = bert_out.last_hidden_stateprint(sequence_output)sequence_output = self.dropout(sequence_output)pool_output = torch.mean(sequence_output, axis=1)logits = self.dense(pool_output)# print(logits)loss = Noneloss_fct = torch.nn.CrossEntropyLoss()if labels is not None:# labels = label.long()loss = loss_fct(logits, labels.view(-1))return loss if loss is not None else logits

为了分析问题出在哪里,我把类里的代码全都拿出来,逐行运行,发现最终的logits和正确的logits(在4.9版本的环境里执行的结果)是一致的,这就很奇怪了,但是我实例化模型,再用模型forward出来的结果却是错误的:

# 这个结果计算出来是对的
sequence_output = bert_cls_model.bert(**inputs).last_hidden_state
sequence_output = bert_cls_model.dropout(sequence_output)
pool_output = torch.mean(sequence_output, axis=1)
logits = bert_cls_model.dense(pool_output)
print(logits)# 这样计算出来是错的
logits = bert_cls_model(**inputs)
print(logits)

于是我又在模型类的定义里打印了各个阶段的结果,如上第一段代码中的print,发现从bert_out的打印结果来看全都是错的。

更进一步地,为了确认是不是模型加载权重的时候出现了问题(比如加载权重后的模型被重新初始化了),我又在模型定义代码里打印了模型的参数值,确认参数值也是没有问题的。这就让我感到有些匪夷所思了。

我又按照同样的对比方法,在模型里边打印一次,单独拿出来打印一次,试着找出问题所在,这次是从一开始embedding开始,结果发现在模型内部和外部打印embedding的结果是一致的:

# 这样打印的结果是正确的
bert_cls_model.bert.embeddings(input_ids=inputs['input_ids'], token_type_ids=inputs['token_type_ids'])# 在模型的forward方法里打印embedding的结果同样是正确的

更奇怪的是,我将embedding的结果输入给encoder手动计算,出来的sequence_out就变成正确的了:

class BertClassifier(torch.nn.Module):def __init__(self, bert_model, num_classes):super(BertClassifier, self).__init__()self.bert = bert_modelself.dropout = torch.nn.Dropout(0.2)self.dense = torch.nn.Linear(768, num_classes)def forward(self,input_ids=None,token_type_ids=None,attention_mask=None,labels=None,):# 直接调用self.bert计算出来结果是错误的# bert_out = self.bert(input_ids, token_type_ids, attention_mask, output_attentions=False)# 手动以此调用embedding和encoder,就算出来的结果就是正确的了embedding_res = self.bert.embeddings(input_ids=input_ids, token_type_ids=token_type_ids)encoder_out = self.bert.encoder(embedding_res)sequence_output = encoder_out[0]sequence_output = self.dropout(sequence_output)pool_output = torch.mean(sequence_output, axis=1)logits = self.dense(pool_output)# print(logits)loss = Noneloss_fct = torch.nn.CrossEntropyLoss()if labels is not None:# labels = label.long()loss = loss_fct(logits, labels.view(-1))return loss if loss is not None else logits

最后我又额外检查了一遍两个版本源码的差别,也没有发现什么端倪,感觉修改的地方都是些写法的差异,不应该有能够造成这个问题的地方。

解决的话,目前就是把transformers的版本降下来,或者像最后这样手动执行计算,还没有发现真正出问题的地方在哪里,如果有哪位也遇到这个问题并且有效解决了的话,还请在评论区指出,谢谢。

相关内容

热门资讯

监控摄像头接入GB28181平... 流程简介将监控摄像头的视频在网站和APP中直播,要解决的几个问题是:1&...
Windows10添加群晖磁盘... 在使用群晖NAS时,我们需要通过本地映射的方式把NAS映射成本地的一块磁盘使用。 通过...
protocol buffer... 目录 目录 什么是protocol buffer 1.protobuf 1.1安装  1.2使用...
在Word、WPS中插入AxM... 引言 我最近需要写一些文章,在排版时发现AxMath插入的公式竟然会导致行间距异常&#...
Fluent中创建监测点 1 概述某些仿真问题,需要创建监测点,用于获取空间定点的数据࿰...
educoder数据结构与算法...                                                   ...
MySQL下载和安装(Wind... 前言:刚换了一台电脑,里面所有东西都需要重新配置,习惯了所...
MFC文件操作  MFC提供了一个文件操作的基类CFile,这个类提供了一个没有缓存的二进制格式的磁盘...
有效的括号 一、题目 给定一个只包括 '(',')','{','}'...
【Ctfer训练计划】——(三... 作者名:Demo不是emo  主页面链接:主页传送门 创作初心ÿ...