前言:本篇是TextCNN系列的第三篇,分享TextCNN的优化经验
前两篇可见:
调优模型的基本方法
大家如果跑过模型的话,不论是demo还是实际项目,通常的情况都是先跑一次看看效果,然后针对某些效果不好的地方做一些调优,俗称「调参狗」,调优有很多方法,基本的方法是:根据模型在测试集合的badcase 来分析有没有共性的问题,譬如做一个文本分类,我们在训练集上效果很好,但是测试集上,某一类的文本总是容易判断错误,那么这就属于一个共性问题。
如何解决这种问题呢?两种方式:
-
从特征的角度优化。我们可以通过做一些特征分析等来找到有区分度的特征来进行优化。特征与特征之间如何结合,如何判断与测试结果是否有相关性,都是我们重点需要分析的。
-
从模型结构上调整。深度学习虽然大家可能认为是一个黑盒,可解释性差,但是不同的网络模型确实能对结果产生一些变化。我们可以从网络模型上优化,适当调整结构,再来训练看测试的效果。不断的迭代再调整。
这是我们调优的基本思路和两个着手点。如果只是按照常规的手段直接去调参,往往达不到我们想要的效果,所以通过分析bascase,可以针对性的解决问题,更加符合我们的应用场景和实际项目。如果单纯的通过使用更fancy的模型来替换老模型来提高准确率,无异盲人摸像,很可能达不到我们最终想要的效果。
具体调优方法介绍
1:明确所有类别预测错误的badcase
假设我们的任务是一个三分类的任务,其中一个类别是"movie",即判断一个query 是不是属于视频。在测试集的预测结果中,有部分预测错误,如图一所示
label | predict | 错误分类原因 |
movie | music | 标注问题 |
movie | book | 数据分布不平衡 |
movie | music | 缺乏先验知识 |
movie | book | 过于依赖A特征 |
movie | book | 数据分布不平衡 |
分析示例一
从表中我们可以看到,一共有5条文本,原label是movie,预测结果与label都不相同,其中3条预测成book、2条预测成music等,这些就属于“badcase”。除了可以统计这一个label的错误预测,也可以把其他类别的badcase全部统一集合,做成list来具体分析。
2: 分析badcse预测错误的共性原因
我们把这些badcase挑出来以后,需要分析每个预测错误的badcase的具体原因是什么,譬如,表中有两条预测成book的原因是“数据分布不一致”,这就属于找到badcase的共性。把原始的特征拿出来再分析,就可以分析出预测错误的共性是什么。
总结出共性原因后,按照错误原因的频次从高到低排列,可以看到我们当前最需要解决的问题有哪些。假设我们一共分析出四类原因,分别是:
-
训练集合和测试集合部分特征抽取方式不一致,如训练集合是用python 来实现抽取特征,但是部署到线上的时候是用java实现来抽取特征,训练和预测抽取特征的时候不是同一套代码,虽然理论上两者应该一致,但是实现的时候由于语言差异以及二次开发的差异很容易造成某一特征在java版本和python版本有不同的结果 (占比约40%)
-
最后的结果过于依赖某一特征,比如我们的模型过于依赖A特征,但是A 特征的准确率只有88%左右,导致A特征的准确率成了我们模型优化的瓶颈 (占比约30%)
-
泛化能力较差,比如"我要看黄渤的电影"可以判别正确,但是"我要看周星驰的电影"就判别不对 (占比约20%)
-
缺乏先验知识,比如"青花瓷" 更容易出现在音乐当中, 而不是出现在baike 当中,但是我们在训练数据当中并没有特征可以体现这一先验知识(占比约10%)
3: 针对原因,专项优化
原因一:训练集合和测试集合部分特征抽取方式不一致
-
优化方法
-
由于测试集合是使用java 代码进行特征抽取以及预测的,训练集合是用python 代码进行特征抽取以及训练的,所以我们统一了预测代码与测试代码,共用一套代码
-
由于之前历史原因,我们训练集合与测试集合抽取特征使用的数据源不同,根据分析来看这种影响非常的大,本次优化使用了同一份数据源
-
-
优化结果:相对于原来的baseline,提升了0.4%
原因二:过于依赖效果不佳的特征
-
优化方法
-
在全连接层 增加dropout层,设置神经元随机失活的比例为0.3,即keep_rate= 0.7
-
在数据预处理的时候,随机去掉10%的A特征
-
-
优化结果
-
针对方法1,效果提升了0.11%,可以认为在这里随机扔掉30%的所有特征(包括A特征),让训练结果与A特征不强相关
-
针对方法2,效果提升了0.29%,因为我们只是过于依赖A特征,并没有过分依赖其他特征,故只在数据预处理的随机去掉10%的A特征,更有针对性的
-
原因三:泛化能力差
-
优化方法:增加槽位抽取:针对部分query, 增加槽位抽取的处理,比如将"黄渤", "周星驰" 等映射为aritist. 这样此类query, 模型见到的都是 我要看artist的电影,不存在缺乏泛化能力的问题. 瓶颈在于槽位抽取的准确率。
-
优化结果:没有单独实验, 增加槽位抽取后共同实验,分析badcase发现此类问题基本解决
原因四:缺乏先验知识
-
优化方法:引入200万外部词表以及 计算”青花瓷“在整体语料中出现在其他类别的频率来部分解决这个问题
-
优化结果:模型提升了0.5%
4.模型调优实操
1) 模型优化术语说明
在我们下面行文中,会有一些术语,为了方便说明我们先提前解释一下
-
baseline: 最开始的实验结果, 后续所有的实验结果,都会和baseline的结果进行对比,来体现每一次实验优化的提升
-
baseline diff:相比baseline的结果, 我们实验的提升有多少
-
测试集acc:即我们的模型在测试集合的准确率
-
epoch :一个epoch 表示将将所有训练样本学习一遍
-
iteration(step)或者叫迭代步数: 表示运行一次iteration或者(step) 更新一次参数。每运行一次iteration 都需要一个batch size 的数据进行学习,学习完一个batch size 的数据,更新一次参数, 下面的所有实验 迭代步数全部为26万步
-
batch size: 迭代一次需要的样本量,下面的实验,batch size 全部设置为64
-
举例说明:exampleNums(样本数) = 100000, epoch = 64, batch size = 64, 那么iteration = 100000
2) 模型优化过程说明
实验BaseLine
无论是论文还是日常的工作,最首要一点是跑出baseline 即我们实验的起点,我们本次的baseline 即不对textCNN的网络结构做任何调整在我们测试集合的准确率
实验0 baseline 测试基本参数在测试集合的表现 | |||
修改内容 | 测试集acc | 迭代步数 | baseline diff |
无任何修改 | 85.14% | 26w | 0 |
优化实验一:修改embedding_dim长度
Embeddinga_dim :128, 即将我们的query 映射成向量的时候,向量本身的维度
参数调整:我们将query 初始化成embedding的时候,修改了embed_dim 从128降低到64,相比于baseline 并没有明显的提升,降低了0.17%,在这里我们希望了解不同的向量维度对最后的实验结果的影响
实验一 修改测试embed_dim 长度对最后结果的影响 | |||
修改内容 | 测试集acc | 迭代步数 | baseline diff |
修改 embed_dim 128 -> 64 | 84.97% | 26w | -0.17% |
优化实验二:在全连接层增加dropout层,keep_rate = 0.7
keep rate 是dropout 的一个参数即:表示本层要保留多少比例神经单元
参数调整:我们在前面的分析当中发现,badcase 与A 特征高度相关,于是在最后的全链接层增加dropout,并将keep_rate 设置为0.7,即随机丢掉30%的数据,相比baseline有0.11%的提升
实验二 在全连接层增加dropout 层,keep_rate = 0.7 | |||
修改内容 | 测试集acc | 迭代步数 | baseline diff |
在全连接层增加dropout层, keep_rate = 0.7 | 85.25% | 26w | +0.11% |
优化实验三:在全连接层增加dropout层,keep_rate = 0.5
参数调整:在实验二的基础上,我们希望测试丢掉更多的数据,是否会有更加明显的提升,发现反而降低了0.04%,原因待进一步探究
实验三 在全连接层增加dropout 层,keep_rate = 0.5 | |||
修改内容 | 测试集acc | 迭代步数 | baseline diff |
在全连接层增加dropout层, keep_rate = 0.5 | 85.09% | 26w | -0.04% |
优化实验四:随机去掉10%A特征信息
参数调整:在实验三的基础上,进一步思考其实我们只是过于依赖A特征, 并没有过分依赖其他特征,故只在数据预处理的时候,随机去掉10%的A特征,相比baseline 提升0.29%
实验四 随机去掉10% A特征 打分 | |||
修改内容 | 测试集acc | 迭代步数 | baseline diff |
在数据增广的时候,设置概率值, 如果 随机数小于0.1就不输出所有A特征信息, 如果随机数高于该概率值就输出打分信息 | 85.43% | 26w | +0.29% |
优化实验五:限制高频query权重
参数调整:我们在实验的时候,会对数据数据重采样,使之符合一定的分布,在重采样的时候,我们限制了部分query的权重,相比baseline 提升0.03%,这个的目的是增加数据的多样性
实验五 限制高频query 权重 | |||
修改内容 | 测试集acc | 迭代步数 | baseline diff |
在数据增广的时候,限制高频query | 85.17% | 26w | +0.03% |
优化实验六:随机去掉20%A特征
参数调整:在实验四的基础上,在数据预处理的时候,随机去掉20%的A特征,相比baseline 提升0.28%, 并没有优于随机去掉10%,最终选择随机去掉10%A特征
实验六 随机去掉20% A 特征 | |||
修改内容 | 测试集acc | 迭代步数 | baseline diff |
在数据增广的时候,设置概率值, 如果 随机数小于0.2就不输出A特征, 如果随机数高于该概率值就输出A 特征 | 85.42% | 26w | +0.28% |
优化实验七:随机去掉10%A特征,使用一致的特征抽取方式
参数调整:在实验四的基础上,我们分析了测试集合与训练集合特征抽取不一致的来源,统一来源,相比baseline 提升了0.64%
实验七 随机去掉10% A特征 使用一致的特征抽取方式 | |||
修改内容 | 测试集acc | 迭代步数 | baseline diff |
在数据增广的时候,设置概率值, 如果 随机数小于0.1就不输出A特征, 如果随机数高于该概率值就输出A特征 | 85.78% | 26w | +0.64% |
优化实验八:随机去掉10% A 特征,使用一致的特征抽取方式,引入B特征
参数调整:在实验七的基础上,我们进一步分析出B特征对badcase 有很好的区分度,将B特征引入训练,相比baseline 提升了1.14%
实验八 随机去掉10% A 特征 使用一致的特征抽取方式,引入B特征 | |||
修改内容 | 测试集acc | 迭代步数 | baseline diff |
在数据增广的时候,设置概率值, 如果随机数小于0.1就不输出A特征, 如果随机数高于该概率值就输出A特征,训练与测试集合使用一致的特征抽取方式,额外引入B特征 | 86.28% | 26w | +1.14% |
总结
这篇文章主要总结了文本分类算法TextCNN调优的方法,先给大家分享了下一般基本做调优我们会采取什么方法——找到badcase,分析共性问题。
分析共性问题后,可以从模型和特征两个方面去优化。在这里我们以文本分类为例,针对具体出现的badcase,我们总结了四个原因:
-
抽取方式差异
-
过于依赖某个特征
-
泛化能力差
-
缺乏先验知识
这些问题可能不仅仅在文本分类这个领域中出现,在其他的应用场景中也可能会出现。那么,具体如何优化呢,我们做了八个实验来解决问题:其中实验五,实验八是从数据层面来优化模型,其余的都是从模型方面来优化
-
实验一:我们通过修改embedding_dim长度希望了解不同的向量维度对最后的实验结果的影响,发现embedding_dim 降低到64会对最后结果有比较明显的影响,下次要实验一下embedding dim 到512 会不会对实验结果有提升
-
实验二:在全连接层增加dropout 层,keep_rate = 0.7,随机丢掉30%,来降低最后结果与A特征相关性
-
实验三:在全连接层增加dropout 层,keep_rate = 0.5,随机丢掉50%,来降低最后结果与A特征相关性
-
实验四:只在数据预处理的时候,随机去掉10%的A特征,相比baseline 提升0.29%
-
实验五:对数据数据重采样,在重采样的时候,限制了部分query的权重,相比baseline 提升0.03%,目的是增加数据的多样性,提升了0.03%
-
实验六:在实验四的基础上,随机去掉20%的A特征,相比baseline 提升0.28%, 并没有优于随机去掉10%,最终选择随机去掉10%A特征
-
实验七:在实验四的基础上,统一来源,相比baseline 提升了0.64%
-
实验八:在实验七的基础上,将B特征引入训练,相比baseline 提升了1.14%
大家不要觉得这些实验很多觉得很麻烦,其实在实际工作中,通常来说,最有效的方法是:分析badcase,引入更有区分度的特征,如果你的项目效果不好,再试试引入dropout等等。具体的参数如何调整,是根据具体的项目而定,但是不管如何优化,都要先跑baseline。在baseline的基础上再进行优化。其次快速迭代实验,做好实验记录也是非常有必要的
希望这一次的优化经验总结能对大家有所帮助,下一篇会给大家讲解训练好的textcnn模型如何部署上线,实际应用于工业场景里。