【深度学习笔记】初步理解交叉熵(Cross Entropy)

xiaoxiao2021-02-28  21

交叉熵是信息论中的一个重要概念,用于衡量两个概率分布之间的差异。在机器学习中常用作损失函数。本文主要参考《Visual Information Theory》,结合本科课程学习过的《通信原理》,希望能较为直观地初步理解交叉熵概念

一、何为熵

    熵可以理解为通信中的最短平均码长。作者假想了一个人叫Bob,Bob和作者通信,只使用四个单词:"dog"、“cat”、“fish”、"bird”。这四个单词的第一种编码方式如下:

    这时候平均码长是2bits。但是呢,Bob很喜欢狗,他经常说"dog"这个单词,具体的频率是这样的:

    这个时候如果所有单词都采用2比特编码就显得不合适了,会造成信道的浪费(若要深究要复习通信原理,这里可以简单理解一下,对于经常使用的那些单词,采用更短的编码,而对于那些不常用的单词采用长的编码,那么信道的效率就能提高)。于是就有了第二种编码方式:

    这时候计算平均码长(概率乘以码长求和):0.5*1+0.25*2+0.125*3*2=1.75bits。看,平均码长变小了。事实上,这就是理论能够得到的最优码长(optimal average length),我们把它叫做熵。

二、如何计算熵

    先说一下编码中码长和码元的损失之间权衡的问题,然后再引出熵的公式。

    首先来看一个问题:采用这种码长不同的编码方式该如何解码呢?如果是第一种编码方式,码长都是2,那么2个2个地解码就可以了,这里没有问题。但如果是第二种长度可以改变的编码方式呢?如果我们将0和01同时作为码元,分别指代狗和猫,那么0100111就有很多解读方式了……为了避免这个问题,可以采用前缀编码(prefix code)的方式:保证所有的码元不能有相同的前缀。

    但是这也会产生一个问题,那就是码元的浪费:

    看图,如果采用码元01的话(白色部分),那么011,010,0110....其它以01开头的码元都不能用,所以这里浪费了1/4的码元,我们把1/4叫做损失值,也就是cost(loss),计算公式上图也给出了。

    损失值和码长的关系可以看上图,横坐标是码长,纵坐标是损失值,深色部分的面积也可视作码长为L(x)时的损失值。可以看到,码长越长,损失值越小。

    平均码长和码元码长的关系可以看上图,可联想概率密度函数。可以看到,码长越长,平均码长也会越长,这是我们不想要的。所以平均码长和损失值之间就存在一个冲突关系。来张图直观感受一下:

    所以该如何编码才能取得一个完美的平衡呢?也就是如何编码能得到最优平均码长——熵呢?答案是让每一个码元的损失值等于其概率。也就是p(x)=cost(x)。举个例子,如果"cat”出现的概率是50%,那么就取损失值1/(2^L)=0.5,也就是码长为1。作者在文中还用可视化的方式证明了这一点,但是阅读理解起来依然比较困难(作者自己说的),所以我就愉悦地跳过了……

    现在熵的公式就可以得到了:平均码长的计算是概率乘以码长求和,由于cost(x)=1/(2^L(x))=p(x),把码长用p(x)替换就可以得到熵的公式:

三、何为交叉熵

    Bob结婚了,妻子叫Alice。Alice不喜欢讨论狗而喜欢猫,这是一个频率对照图:

    那么对于Alice来说,采用Bob之前采用的编码方式得到的平均码长是多少呢?0.125*1+0.5*2+0.25*3+0.125*3=2.25bits

这个就是交叉熵!

    我们把Bob对应的概率分布叫做p(x),Alice的是q(x),那么Alice单词频率相对于Bob单词频率的交叉熵就是Hp(q),计算方式就是用Alice的单词概率乘以Bob最优编码模型中对应的单词码长最后求和。

    交叉熵有什么用呢?交叉熵表示了两个概率分布是多么的不一样。p和q的差异性越大,那么Hp(q)(或Hq(p))相对于H(p)(或H(q))也就越大。

   我们把上图中的D叫做Kullback–Leibler divergence,或者叫KL divergence。这个值就像两个概率分布之间的距离,它定量评价了两个概率之间的差异。

四、待学习

    机器学习中的交叉熵作为代价函数是怎么回事?怎么能够收敛到0的?不用减去H(p)吗?

    

    

转载请注明原文地址: https://www.6miu.com/read-1650048.html

最新回复(0)