0%

变分贝叶斯方法

前言:本文主要是关于变分贝叶斯方法, 这是一类用简单分布去逼近目标分布的方法, 主要参考了Coursera上国立高等经济大学Advanced Machine Learning系列课程Course3: Bayesian Methods for Machine Learning Week3.


本文主要是关于变分推断, 或者说变分贝叶斯方法, 它是一类用来计算难以计算积分的方法. 本文的定位是对EM算法的有益补充, 因此会结合EM算法去讲一讲什么是变分推断及平均场近似, 并得到EM算法的三种变体.

关于EM算法可以参见我的另一篇博文EM算法, 简单的说,EM算法是专门处理隐变量模型的算法. 下文我们会以隐狄利克雷分配模型(Latent Dirichlet Allocation: LDA)为例来说明本文的主角, 当然这是一个隐变量模型.


隐变量模型

隐变量模型, 隐变量模型含有未观测到值的变量, 且该隐变量解释了观测到的变量之间的相关性(局部独立性), 简言之引进的隐变量不能没有用, 举个例子来说明这个性质(来自维基百科):

本科生 读过机器学习 没读过机器学习
读过凸优化 260 140
没读过凸优化 240 360

容易看到读过机器学习的概率是\(\frac{500}{1000}=0.5\), 读过凸优化的概率是\(\frac{400}{1000}=0.4\), 但是两者都读过的概率是\(\frac{260}{1000}=0.26\), 不等于\(\frac{500}{1000}\times\frac{400}{1000}=0.2\), 因此并不独立. 若我们引进变量本科高年级生和本科低年级生:

高年级/低年级 读过机器学习 没读过机器学习
读过凸优化 240/20 60/80
没读过凸优化 160/80 40/320

就会发现, 对高年级生\(\frac{400}{500}\times\frac{300}{500}=\frac{240}{500}\); 对低年级生\(\frac{100}{500}\times\frac{100}{500}=\frac{20}{500}\), 此时相关性已经被隐变量解释掉了.


隐狄利克雷分配模型

下面举个隐变量模型的例子, 隐狄利克雷分配模型. 假如我读过东野圭吾的新参者, 我希望推荐系统能把白夜行也推荐给我. 这里直接给出模型的完整形式, 具体怎么来的请见博文Latent Dirichlet Allocation的第一节. \[P(W, Z, \Theta\mid \varphi, \alpha)=\prod_{d=1}^{D}P(\theta_d)\prod_{n=1}^{N_d}P(z_{dn}\mid\theta_d)P(w_{dn}\mid z_{dn})\]其中:

  • \(\theta_d\sim Dir(\alpha)=\frac{1}{B(\alpha)}\prod_{t=1}^{T}\theta_{dt}^{\alpha_t-1}\), 其中\(d\in\{1,2,...,D\}\), \(T\)给定.
  • \(P(z_{dn}\mid\theta_d)=\theta_{dz_{dn}}\), 其中\(n\in\{1,2,...,N_d\}\).
  • \(P(w_{dn}\mid z_{dn})=\varphi_{w_{dn}z_{dn}}\), 给定词汇表\(w_{set}\), \(w_{dn}\in\{1,2,...,\left|w_{set}\right|\}\).

大小字母表示小写字母的总和, 参数仍简记为小写字母, 总结如下:

  • 未知:\(Z, \Theta\)为隐变量; \(\varphi, \alpha\)未知参数.
  • 已知:\(W\)为数据集, \(D\)为文档数量, \(N_d\)为每篇文档单词数.
  • 超参:主题数\(T\), 词汇表\(w_{set}\)(可以取为这些文档出现过的单词组成的集合).

变分贝叶斯方法

EM算法

下面尝试用4种方法来解决这个模型, 首先从我们最熟悉的EM算法开始(不熟悉的话可以参见我的另一篇博文EM算法): EM算法的目标是极大化对数似然函数\(\log P(W\mid\varphi, \alpha)\), 且每一步迭代又给出隐变量\(Z, \Theta\)的概率分布以及\(\varphi, \alpha\)的点估计. 第k+1次迭代, 当前参数为\(\varphi^k, \alpha^k\):

  • E步求期望:求\(\mathbb{E}_{q^{k+1}}\log P(W, Z, \Theta\mid\varphi, \alpha)\), 其中\(q^{k+1}(Z, \Theta)=P(Z, \Theta\mid W, \varphi^k, \alpha^k)\).
  • M步求max: \((\varphi^{k+1}, \alpha^{k+1})=\arg\max_{(\varphi, \alpha)}\mathbb{E}_{q^{k+1}}\log P(W, Z, \Theta\mid\varphi, \alpha)\).

p.s. EM算法的迭代过程没有涉及参数\(\varphi, \alpha\)分布, 且没有引进先验分布, 因而仍然算是频率学派的方法, 下面来看一种明显是属于贝叶斯学派的方法:变分推断

变分推断

根据贝叶斯思想, 我们将参数\(\varphi, \alpha\)视为随机变量, 因此必须给出参数\(\varphi, \alpha\)的先验假设\(P(\varphi), P(\alpha)\), 如果能直接算出隐变量及未知参数\(Z, \Theta, \varphi, \alpha\)的概率分布\(P(Z, \Theta, \varphi, \alpha\mid W)\), 那么就能很方便的得到想要的参数值(比如说通过极大后验分布MAP). 事实上, 我们已经有了\(P(W, Z, \Theta\mid \varphi, \alpha)\), 根据贝叶斯公式:\[P(Z, \Theta, \varphi, \alpha\mid W)=\frac{P(W, Z, \Theta\mid \varphi, \alpha)P(\varphi, \alpha)}{P(W)}=\frac{P(W, Z, \Theta\mid \varphi, \alpha)P(\varphi, \alpha)}{\int_{Z, \Theta, \varphi, \alpha} P(W, Z, \Theta\mid \varphi, \alpha)P(\varphi, \alpha)}\]通过分析不难发现, 如果给出的参数\(\varphi, \alpha\)的先验恰为共轭先验, 那么上式就非常好求, 否则就不可避免的要计算\(P(W)\), 而这事实上相当难处理.

变分推断则是试图绕过求\(P(W)\), 去找一个对\(P(Z, \Theta, \varphi, \alpha\mid W)\)的良好逼近\(q(Z, \Theta, \varphi, \alpha)\), 其中\(q(Z, \Theta, \varphi, \alpha)\)会来自一个简单得多的分布族, 比如\(Q=N(\mu, \Sigma)\), 其中\(\Sigma\)为对角阵; 而逼近的方法一般则是最小化\(\mathcal{KL}\)散度:\[q=\arg\min_{q\in Q}\mathcal{KL}(q(Z, \Theta, \varphi, \alpha)\Vert P(Z, \Theta, \varphi, \alpha\mid W))\]关于\(\mathcal{KL}\)散度的定义见维基百科, 这里简单理解为衡量两个分布的距离就好. 下式告诉我们, 变分推断确实可以绕开\(P(W)\), 下面把\(q(Z, \Theta, \varphi, \alpha)\)简记为\(q\): \[\begin{aligned} \mathcal{KL}(q\Vert P(Z, \Theta, \varphi, \alpha\mid W)) & = \mathcal{KL}(q\Vert \frac{P(W, Z, \Theta\mid \varphi, \alpha)P(\varphi, \alpha)}{P(W)}) \\\ & =\int q\log \frac{q}{P(W, Z, \Theta\mid \varphi, \alpha)P(\varphi, \alpha)/P(W)} dZd \Theta d\varphi d\alpha\\\ & = \int \left( q\log \frac{q}{P(W, Z, \Theta\mid \varphi, \alpha)P(\varphi, \alpha)}dZd \Theta + q\log P(W) \right)dZd \Theta d\varphi d\alpha\\\ & = \mathcal{KL}(q\Vert P(W, Z, \Theta\mid \varphi, \alpha)P(\varphi, \alpha))+\log P(W) \end{aligned}\]其中第一个等号为简单带入, 第二个等号由\(\mathcal{KL}\)散度的定义, 第三个等号为简单化简, 第四个等号由\(\mathcal{KL}\)散度定义. 容易看到, 最后一个等式的后一项\(\log P(W)\)\(q\)无关, 因而成功绕开了\(P(W)\), 这便是变分推断.

p.s. 变分推断需要给出参数的先验分布和分布族\(Q\). 如果\(Q\)足够大, 则有\[q(Z, \Theta, \varphi, \alpha)=P(Z, \Theta, \varphi, \alpha\mid W)\]并没有简化计算. 所以这个分布族\(Q\)的选取也是一件麻烦的事情, 既要考虑\(Z, \Theta\)实际上是有关的, 又要考虑参数\(\varphi, \alpha\)与隐变量\(Z, \Theta\)是无关的. 一个合理的想法是预先假定\(\varphi, \alpha\)\(Z, \Theta\)是无关的, 在变分推断的基础上于是有了平均场近似的变分推断:

平均场近变分推断

其想法很简单, 在变分推断\[q=\arg\min_{q\in Q}\mathcal{KL}(q(Z, \Theta, \varphi, \alpha)\Vert P(Z, \Theta, \varphi, \alpha\mid W))\]的基础上, 预先假定参数\(\varphi, \alpha\)与隐变量\(Z, \Theta\)的独立性, 得到 \[\begin{aligned}q(Z, \Theta)&=\arg\min_{q(Z, \Theta)\in Q_1}\mathcal{KL}(q(Z, \Theta) q(\varphi, \alpha)\Vert P(Z, \Theta, \varphi, \alpha\mid W)) \\\ q(\varphi, \alpha)&=\arg\min_{q(\varphi, \alpha)\in Q_2}\mathcal{KL}(q(Z, \Theta)q(\varphi, \alpha)\Vert P(Z, \Theta, \varphi, \alpha\mid W))\end{aligned}\]这样计算上有了一定程度的减小, 当然准确度也有可能有一定的损失, 但其实际上不失为一个好的假设. 下面来看平均长近似后的形式, 为了表述方便, 记\(q(Z, \Theta)=q_1\), \(q(\varphi, \alpha)=q_2\), \(P(Z, \Theta, \varphi, \alpha\mid W)=\hat{P}\), \(dZd \Theta d\varphi d\alpha=d\hat{Z}\), 目标则简化为\[q_i=\arg\min_{q_i\in Q_i}\mathcal{KL}(\prod_{i=1}^2 q_i\Vert \hat{P})\]下面来看, 目标是对\(q_k\in Q_k\)求最小值点, \(k=1, 2\): \[\begin{aligned} \mathcal{KL}(\prod_{i=1}^2 q_i\Vert \hat{P}) & = \int \prod_{j=1}^2 q_j\log \frac{\prod_{i=1}^2 q_i}{\hat{P}} d\hat{Z} \\\ & =\sum_{i=1}^2\int \prod_{j=1}^2q_j\log q_i d\hat{Z}-\int \prod_{j=1}^2q_j\log \hat{P}d\hat{Z}\\\ & = \int \prod_{j=1}^2q_j\log q_k d\hat{Z}+\sum_{i\neq k}^2\int \prod_{j=1}^2q_j\log q_i d\hat{Z}-\int \prod_{j=1}^2q_j\log \hat{P} d\hat{Z} \\\ &=\int q_k\log q_k d\hat{Z}_k+\sum_{i\neq k}^2\int \prod_{j\neq k}^2q_j\log q_i d\hat{Z}_{-k}-\int \prod_{j=1}^2q_j\log \hat{P} d\hat{Z} \\\ &= \int q_k\log q_k d\hat{Z}_k-\int q_k\left[\int \prod_{j\neq k}^2q_j\log \hat{P} d\hat{Z}_{-k}\right]d\hat{Z}_k + C\\\ &=\int q_k\left[\log q_k -\int \prod_{j\neq k}^2q_j\log \hat{P} d\hat{Z}_{-k}\right]d\hat{Z}_k + C \end{aligned}\]这里的\(d\hat{Z}\)\(dZd \Theta d\varphi d\alpha\), \(d\hat{Z}_k\)为对\(q_k\)的变量积分, 比方说\(k=1\), 即为\(dZ d\Theta\), \(d\hat{Z}_{-k}\)为对除\(q_k\)的之外的变量积分.

稍作解释:第一个等号由\(\mathcal{KL}\)散度的定义; 第二个等号为简单化简; 第三个等号为将第一项拆成\(i=k\)\(i\neq k\)两项; 第四个等号为对第一第二项中能积分的先积掉; 第五个等号为将第三项对\(q_k\)的积分拎到最前面, 且此时第二项与\(q_k\)无关, 视为常数\(C\); 第六个等号即为简单合并.

下面继续, 把我们把\(\int \prod_{j\neq k}^2q_j\log \hat{P} d\hat{Z}_{-k}=\mathbb{E}_{q_{-k}}\log \hat{P}\)记作\(h(\hat{Z}_k)\), 再令\(t(\hat{Z}_k)=\frac{exp(h(\hat{Z}_k))}{\int exp(h(\hat{Z}_k))d\hat{Z}_k}\), 除以正则项确保是个概率分布, 上式可继续化为: \[\begin{aligned} \mathcal{KL}(\prod_{i=1}^2 q_i\Vert \hat{P}) & = \int q_k\left[\log q_k -\int \prod_{j\neq k}^2q_j\log \hat{P} d\hat{Z}_{-k}\right]d\hat{Z}_k + C\\\ & = \int q_k\left[\log q_k d\hat{Z}_k-h(\hat{Z}_k)\right]d\hat{Z}_k + C\\\ &=\int q_k\log \frac{q_k}{t} d\hat{Z}_k + C \end{aligned}\]注意到第一项即为KL散度, 那么此时关于\(q_k\)取极小只需令\(q_k=t=\frac{exp(h(\hat{Z}_k))}{\int exp(h(\hat{Z}_k))d\hat{Z}_k}\), 其中\(h(\hat{Z}_k)=\mathbb{E}_{q_{-k}}\log \hat{P}\)再取对数得到\[\log q_k=\mathbb{E}_{q_{-k}}\log \hat{P}+C\]其中\(\hat{P}=P(Z, \Theta, \varphi, \alpha\mid W)\)在LDA模型这个例子中, 上述结果化为: \[\begin{aligned}\log q(Z, \Theta)&=\mathbb{E}_{q(\varphi, \alpha)}\log P(Z, \Theta, \varphi, \alpha\mid W)+C_1 \\\ \log q(\varphi, \alpha)&=\mathbb{E}_{q(Z, \Theta)}\log P(Z, \Theta, \varphi, \alpha\mid W)+C_2\end{aligned}\]

p.s. 相比变分推断, 这个平均场近似的变分推断实际是将参数\(\varphi, \alpha\)与隐变量\(Z, \Theta\)进行了分开处理, 降低了模型复杂度, 同时明显也降低了计算的维度, 但由于对参数\(\varphi, \alpha\)与隐变量\(Z, \Theta\)的独立性假设非常合理, 所以也不会有太多精度上的损失, 甚至会使得模型更稳健. 另外顺便得到了一个非常有用的式子, 这个式子能在变分贝叶斯方法 的维基百科的Mean field approximation一节找到, 只是没有上述的推导过程.

EM算法也是将参数\(\varphi, \alpha\)与隐变量\(Z, \Theta\)进行了分开处理, 并且对隐变量\(Z, \Theta\)同样是通过对\(\mathcal{KL}\)散度最小化求得分布, 而EM算法对未知参数则是直接采用了点估计的方法而不是变分推断. 因此相比平均场近似的变分推断, EM算法进一步降低了计算复杂度, , 我想这大概就是EM算法流行的理由之一.

平均场EM算法

平均场EM算法实际上是在EM算法基础上进一步对隐变量\(Z, \Theta\)进行了独立性假设. 下面分两步在EM算法的基础上推导出平均场EM算法:

  1. EM算法: 第k+1次迭代, 当前参数为\(\varphi^k, \alpha^k\):
    • E步: \(q^{k+1}(Z, \Theta)=\arg\min_{q(Z, \Theta)\in Q_1}\mathcal{KL}(q(Z, \Theta)\Vert P(Z, \Theta\mid W, \varphi, \alpha))\).
    • M步: \((\varphi^{k+1}, \alpha^{k+1})=\arg\max_{(\varphi, \alpha)}\mathbb{E}_{q^{k+1}}\log P(W, Z, \Theta\mid\varphi, \alpha)\).
  2. 假定\(q(Z, \Theta)=q(Z)q(\Theta)\), 由平均场近似得到的公式, 得到平均场EM算法: 第k+1次迭代, 当前参数为\(\varphi^k, \alpha^k\):
    • E步:\(\begin{aligned}\log q^{k+1}(Z)&=\mathbb{E}_{q(\Theta)}\log P(Z, \Theta\mid \varphi^k, \alpha^k, W)+C_1 \\\ \log q^{k+1}(\Theta)&=\mathbb{E}_{q(Z)}\log P(Z, \Theta\mid \varphi^k, \alpha^k, W)+C_2\end{aligned}\)
    • M步: \((\varphi^{k+1}, \alpha^{k+1})=\arg\max_{(\varphi, \alpha)}\mathbb{E}_{q^{k+1}(Z, \Theta)}\log P(W, Z, \Theta\mid\varphi, \alpha)\)

p.s. 平均场EM算法比EM算法多了一个隐变量之间独立的假设, 但一般隐变量之间应该不独立, 比方说这个LDA模型中\(Z\)代表每篇文档每个单词的主题分布, \(\Theta\)代表每篇文档的主题分布, 两者应该不独立, 所以平均场EM算法不适合这个例子. 另外相比EM算法, 平均场EM算法因为多了一个独立性假设, 模型应该会变得简单一些, 但精度有可能会有些损失.

总结

最后用一张表格总结上述4中算法:

算法 模型复杂度 需要的计算资源
变分推断 4 4
平均场变分推断 3 3
EM算法 2 2
平均场EM算法 1 1

EM算法作为其中比较折中的选择, 兼顾了模型复杂度与计算资源.