BERT多任务学习

BERT多任务学习

栗友们,大家好!

我是Zarc,最近在做一个多任务学习的BERT模型,和大家做一个简单分享

多任务学习在近两年也算是一个比较流行的方向,诸如现在的大模型都或多或少使用多任务训练,如BERT预训练模型就采用的MLM(掩码语言模型)任务和NSP(下一句预测)任务来联合训练。

多任务学习作为迁移学习的一个分支,有着自己独特的优势,在给定几个相关联任务的输入数据和输出数据的情况下,多任务学习能够发挥任务之间的关系,同时学习多个特征。与单任务相比,有以下优势:

  • 在标签数据有限的情况下,单任务学习模型往往不能够学习到足够的信息,表现较差;多任务学习能够克服当前任务样本较少的缺点,从其他任务学习到有用信息,训练出效果更好、更具有鲁棒性的模型。
  • 多任务学习模型具有更好的泛化能力和更多的应用,通过多个任务的联合学习,得到的共享模型能够直接应用到下游的相关任务上,在相关的下游任务上表现往往更好。

最近在做的一个多任务学习模型是基于BERT模型,然后训练任务分为:命名实体识别、领域识别、标点纠正、意图识别。

其中命名实体识别和标点纠正是基于序列标注的思想来做,意图识别和领域识别是基于文本分类的思想来做,设计这四个任务主要是考虑到业务的具体下游任务中包含了上述几个子任务。多任务的设计主要包含以下一个方面:



01


数据处理



数据处理主要是对于训练数据的多标签进行定义以及训练数据的读取,我的数据集构造如下:

你 O O
听 O O
过 O O
大 B_Song O
地 I_Song O
么 O ?
::评价::音乐

if line.startswith("::"):
    intent_label = line.strip("::").split("::")[0]
    domain_label = line.strip("::").split("::")[1]
    continue
splits = line.split(" ")
if len(splits) > 1:
    punc_labels.append(splits[-1].replace("n"""))
    ner_labels.append(splits[1].replace("n"""))

最后一行分别为意图标签和领域标签,第一列为原始数据,第二列为命名实体标签,第三列为标点符号标签。基本上就是标准的序列标注任务和文本分类任务的标签设计。



02


模型设计



模型设计就是最基本的思想,Encoder为中文Bert,输出层接多任务分类器:

self.ner_classifier = nn.Linear(config.hidden_size, config.ner_num_labels)
self.punc_classifier = nn.Linear(config.hidden_size, config.punc_num_labels)
self.intent_classifier = nn.Linear(config.hidden_size, config.intent_num_labels)
self.domain_classifier = nn.Linear(config.hidden_size, config.domain_num_labels)

outputs = self.bert(input_ids = input_ids,attention_mask=attention_mask,token_type_ids=token_type_ids)
sequence_output, cls_output = outputs[0], outputs[1]

ner_logits = self.ner_classifier(sequence_output)
punc_logits = self.punc_classifier(sequence_output)
intent_logits = self.intent_classifier(cls_output)
domain_logits = self.domain_classifier(cls_output)

这里分别接了四个全连接层作为多任务输出结果,分类任务输入为[CLS]的输出,序列标注任务输入为BERT最后一层的输出。



03


Loss设计



在loss设计这部分,假设多个任务loss在同一个metric下,那几乎不用统一,除了一些先验干预,各个任务可以得到平等对待。如果同一个metric下不同任务loss差别巨大,那就是任务训练难度本身的问题。同步到一个数量级也不能解决一个任务训练不好的问题。

在实际过程中,目前就是对上述的多个任务的loss进行了叠加,没有进行缩放操作,当然还没有对模型进行仔细调参,或许loss进行精心设计可以达到更好的效果。

ner_loss = loss_fct(ner_logits.view(-1, self.ner_num_labels), ner_labels.view(-1))
punc_loss = loss_fct(punc_logits.view(-1, self.punc_num_labels), punc_labels.view(-1))
intent_loss = loss_fct(intent_logits, intent_label)
domain_loss = loss_fct(domain_logits, domain_label)

loss = ner_loss + punc_loss + intent_loss + domain_loss



04


训练结果



domain_f1 = 0.992931392931393
intent_f1 = 0.9941787941787942
loss = 0.08741999341892791
ner acc = 0.9820608182017064
ner f1 = 0.982705779334501
ner recall = 0.9833515881708653

预训练模型使用的是中文BERT,训练了3个epoch,效果好可能是我的数据集较为简单,接下来将会进行集外测试,看下模型效果。

最后还是有一点要注意,这里利用预训练模型进行多任务学习,主要是考虑到具体的下游任务比较类似,都可以建模成序列标注和文本分类任务。如果多任务之间的差异较大,相关参数和loss策略还是要进行调整设计,欢迎各位栗友进行交流讨论。

最后,关注六只栗子,面试不迷路!

作者    Zarc

编辑   一口栗子  


BERT多任务学习

BERT多任务学习BERT多任务学习BERT多任务学习


原文始发于微信公众号(六只栗子):BERT多任务学习

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

文章由极客之音整理,本文链接:https://www.bmabk.com/index.php/post/88240.html

(0)
小半的头像小半

相关推荐

发表回复

登录后才能评论
极客之音——专业性很强的中文编程技术网站,欢迎收藏到浏览器,订阅我们!