论文笔记 Self-critical Sequence Training for Image Captioning

本文最后更新于:2023年3月16日 下午

论文笔记 Self-critical Sequence Training for Image Captioning

论文链接:CVPR 2017 Open Access Repository (thecvf.com)

代码链接:非官方 ruotianluo/self-critical.pytorch: (github.com)一个用到了SCST训练的代码 AliceMind/mPLUG at main · alibaba/AliceMind (github.com)

本文发表在CVPR2017,这篇文章提出了SCST的训练方式,其使用了一种强化学习的方式来提升Image Captioning模型的性能,作者将评价时使用的不可微分的指标直接作为优化对象,能够简单有效地提分,后面各路模型在做Image Captioning的时候也会带上它。

简介

目前的类似Image Captioning的生成任务的训练方式主要是:根据当前word之前的事实标签来生成下一个word,最小化这个生成的word与当前事实标签的概率差距。这种训练方式叫做“Teacher-Forcing”。然而这会造成exposure bias的问题,因为在测试阶段是使用之前生成的word来预测下一个word,没有事实标签,所以预测的错误容易累积下来。

目前有的方法在解决这个问题,比如一些方法训练的时候也按照一定概率使用之前生成的word来预测,逐渐增加这个概率后训练和测试的误差就会被减小。

但是这类方法在训练时都还是使用的交叉熵作为损失函数,在测试的时候,一般会使用一些不可微分的NLP指标,比如BLEU、ROUGE、METEOR、CIDEr。要直接优化这种指标是极好的,可以通过强化学习的方式,目前有许多研究围绕这个进行,但是他们都面临着不稳定的问题。

本文就提出了**Self-Critical Sequence Training(SCST)**的方法,这种方法是一种REINFORCE算法,它使用自身测试阶段的算法来规范化reward。模型在测试中进行采样时,那些优于baseline的样本会给予正权重,否则给予负权重。文章发现直接对CIDEr这种指标进行优化效果很好,不仅能大幅提升这个指标,还可以连带着提升其他指标。

方法

概念确定

Captioning模型可以被看作是强化学习中的Agent,输入进模型的图像和文本看作是Environment,模型的参数θ\theta就是Policy pθp_\theta,模型预测的下一个单词就是Agent选择的Action。在生成完句子之后,Agent才会获得对于生成的评价,也就是Reward rr,训练的目标就是最小化负的Reward:

L(θ)=Ewspθ[r(ws)]L(\theta) = -\mathbb{E}_{w^s \sim p_\theta}[r(\boldsymbol{w}^s)]

其中ws\boldsymbol{w}^s是模型的采样结果,wisw^s_i是第ii步的采样单词。

⚠️这里ws\boldsymbol{w}^s是采样结果而不是greedy search的结果,因为目前Agent是一个随机模型,输出要是不确定的才能进行优化。

基于Baseline的REINFORCE算法

为了计算损失的梯度θL(θ)\nabla_\theta L(\theta),会通过REINFORCE算法的方式,

θL(θ)=θEwspθ[r(ws)]=θ wsr(ws)pθ(ws)=wsr(ws)pθ(ws)pθ(ws)pθ(ws)=wsr(ws)pθ(ws)logpθ(ws)1Nn=1Nr(ws)logpθ(ws)1Nn=1Nr(ws)t=1Tlogpθ(wtssts)\begin{align} \nabla_\theta L(\theta) &= -\nabla_\theta \mathbb{E}_{w^s \sim p_\theta}[r(\boldsymbol{w}^s) ] \\ &= -\nabla_\theta \ \sum_{w^s} r(\boldsymbol{w}^s)p_\theta(\boldsymbol{w}^s) \\ &= -\sum_{w^s} r(\boldsymbol{w}^s)p_\theta(\boldsymbol{w}^s) \frac{\nabla p_\theta(\boldsymbol{w}^s)}{p_\theta(\boldsymbol{w}^s)} \\ &= -\sum_{w^s} r(\boldsymbol{w}^s)p_\theta(\boldsymbol{w}^s) \nabla \log p_\theta(\boldsymbol{w}^s) \\ &\approx -\frac{1}{N}\sum^N_{n=1} r(\boldsymbol{w}^s) \nabla \log p_\theta(\boldsymbol{w}^s) \\ &\approx -\frac{1}{N}\sum^N_{n=1} r(\boldsymbol{w}^s) \sum^{T}_{t=1} \nabla \log p_\theta(w^s_t|s^s_t) \end{align}

公式第一行就是加上了梯度符号,第二行将reward的期望转换成 wsw^s出现的概率 ×\times wsw^s的reward。

第三行把求导符号放了进去,由于r(w)r(w)项和参数θ\theta无关,所以不用对它求导,然后凑了一项pθ(ws)p_\theta(\boldsymbol{w}^s)。凑的这一项在第四行发挥作用,把最后一项变成了log的梯度。

由于无法穷举所有的wsw^s,所以第五行实现用一个Batch内的reward来近似。

第六行将多个单词组成的一句话拆开,整句话的概率pθ(ws)p_\theta(\boldsymbol{w}^s)是每个状态下生成下一个单词的概率的乘积:pθ(w1ss1s)pθ(w2ss2s)pθ(wTssTs)p_\theta(w^s_1|s^s_1)p_\theta(w^s_2|s^s_2)\dots p_\theta(w^s_T|s^s_T),然后根据log的性质变成了相加。

对于第六行这个式子,我们直观地看,假如最后得到的reward是正的,那么这个式子就会提高在每一步中在stss_t^s状态下生成wtsw^s_t的概率,反之则降低。因为采样过程是随机的,假如这一次采样到了cat这个单词,然后reward是正的,提升了cat的概率,下一次可能会采样到dog这个单词,然后发现reward是负的,就会降低dog的概率。

现在这个公式其实还有一个问题,就是其实现在reward不会是负数,因为我们用的评价指标CIDEr(或者其他)都是衡量相似度,相似度0%就到底了。所以我们要让reward减去一个baseline,这就到了SCST的核心思想——使用模型在推理算法下的reward作为baseline。

SCST

紧接上节,我们要将reward减去一个Baseline,而SCST则是使用模型在推理算法下生成caption的reward,论文中是将Greedy Search的结果作为了Baseline,实际上也可以使用Beam Search的结果。使用这种baseline可以减小方差,并且实施起来也很方便,不需要像一些Actor-Critic的方法需要第二个critic网络,在训练时,也只是增加了一次不用梯度的前向传播(算r(w^))r(\hat{\boldsymbol{w}})))。

θL(θ)1Nn=1Nt=1T(r(ws)b)logpθ(wtssts)1Nn=1Nt=1T(r(ws)r(w^))logpθ(wtssts)\begin{align} \nabla_\theta L(\theta) &\approx -\frac{1}{N}\sum^N_{n=1} \sum^{T}_{t=1} (r(\boldsymbol{w}^s)-b) \nabla \log p_\theta(w^s_t|s^s_t) \\ &\approx -\frac{1}{N}\sum^N_{n=1} \sum^{T}_{t=1} (r(\boldsymbol{w}^s)-r(\hat{\boldsymbol{w}})) \nabla \log p_\theta(w^s_t|s^s_t) \end{align}

论文给出了一个整体的流程图:

从强化学习外考虑SCST

作者是套用了强化学习的框架来解释SCST这个方法的,这里我给出一个从强化学习领域外看待这个方法的角度:

上面是我手绘的图,模型已经生成了“一个人在喝”这些词,下一步的概率算出来如图所示,同时对应的Ground Truth有4个。一般训练的情况下,假设是使用了第一个Ground Truth,那么下一个字应该是“水”,所以计算交叉熵的时候就是yilogpi,yi=1-y_i\log p_i,y_i=1

而SCST的情况下,模型在这一步会根据概率来进行随机采样。假如采样到了水,然后后面结束之后计算整句的reward,发现GT里面提到“水”不多,然后可能比baseline稍微低那么一点(假设),结果reward是-0.02,loss就成了负数。假如采样到了“奶”,之后就也很可能采样到“茶”,最终计算reward的时候发现GT中奶茶频率比较高,所以reward也比较高,然后loss就是正数,并且是对“奶”的概率进行调整。

这里就能看出两种方法的区别,一般的Teacher-Forcing训练方法比较死板,像上面这个例子,它就是强迫模型应该输出“一个人在喝水”,然而这个情况下输出“一个人在喝奶茶”的效果应该是更好的,即使是采用了label-smoothing的方法,也只是把log0.63-\log 0.63变成了0.9log0.63-0.9 \log 0.63,同样是死板地认为下一个词应该90%是“水”。

但是对于SCST,它在生成句子的时候会根据概率的采样冒出一些新的句子,有点像“突发奇想”,它认为“虽然下一个词很可能是xxx,但是也有可能是yyy,用yyy造的句子会不会更好?”于是它就尝试下一个词用“奶”而不是“水”,然后最终计算reward的时候评价它这个选择,调节的也是0.23这个概率。