Site cover image

ベイズ統計と時々バイオインフォ

ベイズ統計と機械学習について学んだことを書き留めるブログです。 2023.4~ wet→dryに転向、ベイズ統計を使ったバイオインフォの研究をしています。

Post title icon EMアルゴリズム(理論編①)

0. はじめに


『しくみががわかる ベイズ統計と機械学習』を読みつつ、より深く理解するために学んだ内容を簡潔にまとめています。今回は「6章 EMアルゴリズム」に対応する内容です。主にEMアルゴリズムの理論について説明し、後日pythonで実装します。

1. EMアルゴリズムの概要


1.1 背景

EMアルゴリズムは、観測値が混合分布から生成されるとき、混合分布の未知パラメータを点推定する手法です。

一般に、観測値 x ~\bm{x}~が従う確率分布 p(xθ) ~p(\bm{x}|\bm{\theta})~の形状を知りたいとき、未知パラメータ θ ~\bm{\theta}~の推定量を点推定することがあります。「尤もらしい」推定量を得る手法、すなわち最尤推定は代表的な点推定の手法で、以下の式で表されます。

θ^ML=arg maxθlogp(xθ)\hat{\bm{\theta}}_{ML}=\argmax_{\bm{\theta}}\log p(\bm{x}|\bm{\theta})

しかし、尤度関数 p(xθ) ~p(\bm{x}|\bm{\theta})~が複雑な場合、対数尤度 logp(xθ) ~\log p(\bm{x}|{\bm{\theta}})~を解析的に最大化することは困難である場合が多いです。

1.2 EMアルゴリズムの場合

EMアルゴリズムがどのように上記の問題点を解決しているのか説明するため、具体例を考えます。

観測値 x=(x1,,xn) ~\bm{x}=(x_1, \cdots,x_n)~ k ~k~個のガウス分布で構成される混合ガウス分布から得られるとします。このとき、「各観測値がどのガウス分布に属するか」が潜在変数 z=(z1,,zn) ~\bm{z}=(\bm{z}_1, \cdots, \bm{z_n})~によって決まるとすると、 p(xθ) ~p(\bm{x}|\bm{\theta})~不完全データ p(x,zθ) ~p(\bm{x},\bm{z}|\bm{\theta})~完全データとなります。EMアルゴリズムでは最適化するのが難しい不完全データ p(xθ) ~p(\bm{x}|\bm{\theta})~ではなく、完全データ p(x,zθ) ~p(\bm{x},\bm{z}|\bm{\theta})~の最適化を行います。

具体的には、以下で定義されるQ関数を最大化することで、間接的に対数尤度の最大化を実現しています( θ^t ~\hat{\bm{\theta}}_t~は時刻 t ~t~のパラメータ推定量)。Q関数は事後分布 p(zx,θ^) ~p(\bm{z}|\bm{x},\hat{\bm{\theta}})~についての完全データの期待値であるため、 z ~\bm{z}~は周辺化され、 z ~\bm{z}~に依存しない値となります。

 Q(θ,θ^t)=Ep(zx,θ^t)[logp(x,zθ)]=zp(zx,θ^t)logp(x,zθ) ~Q(\bm{\theta},\hat{\bm{\theta}}_t)=\mathbb{E}_{p(\bm{z}|\bm{x},\hat{\bm{\theta}}_t)}[\log p(\bm{x},\bm{z}|\bm{\theta})]= \sum_{\bm{z}}p(\bm{z}|\bm{x},\hat{\bm{\theta}}_t)\log p(\bm{x},\bm{z}|\bm{\theta})~

ひとことで言ってしまえば、EMアルゴリズムでは「パラメータ推定量を用いてQ関数を計算→Q関数を最大化するパラメータを新たなパラメータ推定量とする」ことを繰り返しています。これは以下の式で表されます。

 θ^t+1=arg maxθQ(θ,θ^t) \displaystyle{ ~\bm{\hat{\theta}}_{t+1}=\argmax_{\bm{\theta}}Q(\bm{\theta},\hat{\bm{\theta}}_{t})~ }

2. EMアルゴリズムの学習ステップ


より詳細には、EMアルゴリズムの学習ステップはEステップとMステップに分けられます。


E step(Expectation step;期待値計算ステップ)

  • 事後分布の算出

     Q(θ,θ^t)=Ep(zx,θ^t)[logp(x,zθ)] ~Q(\bm{\theta},\hat{\bm{\theta}}_t)=\mathbb{E}_{p(\bm{z}|\bm{x},\hat{\bm{\theta}}_t)}[\log p(\bm{x},\bm{z}|\bm{\theta})]~の算出には事後分布 p(zx,θ^t) ~p(\bm{z}|\bm{x},\hat{\bm{\theta}}_t)~が必要。

     t=0 ~t=0~ならランダムな初期値 θ^0 ~\hat{\bm{\theta}}_0~を、 t1 ~t\geqq 1~ならMステップの出力 θ^t ~\hat{\bm{\theta}}_t~を用いて事後分布 p(zx,θ^t) ~p(\bm{z}|\bm{x},\hat{\bm{\theta}}_t)~を求める。

     p(zx,θ^t)=p(x,zθ^t)P(xθ^t) ~p(\bm{z}|\bm{x},\hat{\bm{\theta}}_t)=\dfrac{p(\bm{x},\bm{z}|\hat{\bm{\theta}}_t)}{P(\bm{x}|\hat{\bm{\theta}}_t)}~

  • 期待値計算(Q関数の算出)

     Q(θ,θ^t)=Ep(zx,θ^t)[logp(x,zθ)] ~Q(\bm{\theta},\hat{\bm{\theta}}_t)=\mathbb{E}_{p(\bm{z}|\bm{x},\hat{\bm{\theta}}_t)}[\log p(\bm{x},\bm{z}|\bm{\theta})]~

M step(Maximization step;最大化ステップ)

  • Q関数を最大化し、t  t+1 t~\rightarrow~t+1~ とする。

     θ^t+1=arg maxθQ(θ,θ^t) \displaystyle{ ~\hat{\bm{\theta}}_{t+1}=\argmax_{\bm{\theta}}Q(\bm{\theta},\hat{\bm{\theta}}_{t})~ }


3. EMアルゴリズムの原理


EMアルゴリズムはなぜ上手くいくのでしょうか?ここでは、「Q関数を最大化することで対数尤度が最大化される」ことを説明します。

logp(xθ)=logzp(x,zθ)  (1)=logzp(zx,θ^)p(x,zθ)p(zx,θ^)  (2)zp(zx,θ^)logp(x,zθ)p(zx,θ^)  (3)=Ep(zx,θ^)[logp(x,zθ)p(zx,θ^)]=B(θ,θ^)\begin{aligned} \log p(\bm{x}|\bm{\theta}) &=\log \sum_{\bm{z}} p(\bm{x},\bm{z}|\bm{\theta})~~\cdots(1) \\ &=\log\sum_{\bm{z}}p(\bm{z}|\bm{x},\hat{\bm{\theta}})\dfrac{p(\bm{x},\bm{z}|{\bm{\theta}}) }{p(\bm{z}|\bm{x},\hat{\bm{\theta}})}~~\cdots(2) \\ &\geqq \sum_{\bm{z}}p(\bm{z}|\bm{x},\hat{\bm{\theta}})\log \dfrac{p(\bm{x},\bm{z}|{\bm{\theta}}) }{p(\bm{z}|\bm{x},\hat{\bm{\theta}})}~~\cdots(3) \\ &=\mathbb{E}_{p(\bm{z}|\bm{x},\hat{\bm{\theta}})}\left[ \log \dfrac{p(\bm{x},\bm{z}|{\bm{\theta}}) }{p(\bm{z}|\bm{x},\hat{\bm{\theta}})} \right] \\ &=\mathcal{B}(\bm{\theta},\hat{\bm{\theta}}) \end{aligned}
  1. まず、対数尤度を同時分布 p(x,zθ) ~p(\bm{x},\bm{z}|\bm{\theta})~が出現するように変形します (1) ~(1)~
  2. 事後分布 p(zx,θ^t) ~p(\bm{z}|\bm{x},\hat{\bm{\theta}}_t)~が出現するように変形します (2) ~(2)~ p(zx,θ^t) ~p(\bm{z}|\bm{x},\hat{\bm{\theta}}_t)~である必然性はないですが、後々の都合を考えて p(zx,θ^t) ~p(\bm{z}|\bm{x},\hat{\bm{\theta}}_t)~とします。
  3. (2) (2)~ log() ~\log (\sum)~の形になっています。また、 p ~p~は指数分布族であることが多いので、log(exp) \log(\sum \exp)~という形になってしまいがちです。これが計算を難しくしています。

    そこで、Jensenの不等式を用いて (2) ~(2)~を下から評価します (3) ~(3)~。その結果得られるのが変分下界 B(θ,θ^) ~\mathcal{B}(\bm{\theta}, \hat{\bm{\theta}})~です。変分下界の最大化により、対数尤度の最大化が実現されます。

更に変分下界の変形を行います。

B(θ,θ^)=Ep(zx,θ^)[logp(x,zθ)p(zx,θ^)]=Ep(zx,θ^)[logp(x,zθ)]+Ep(zx,θ^)[logp(zx,θ^)]=Q(θ,θ^)+H(p(zx,θ^))\begin{aligned} \mathcal{B}(\bm{\theta},\hat{\bm{\theta}}) &=\mathbb{E}_{p(\bm{z}|\bm{x},\hat{\bm{\theta}})}\left[ \log \dfrac{p(\bm{x},\bm{z}|{\bm{\theta}}) }{p(\bm{z}|\bm{x},\hat{\bm{\theta}})} \right]\\ &=\mathbb{E}_{p(\bm{z}|\bm{x},\hat{\bm{\theta}})}\left[ \log p(\bm{x},\bm{z}|{\bm{\theta}}) \right]+ \mathbb{E}_{p(\bm{z}|\bm{x},\hat{\bm{\theta}})}\left[ -\log p(\bm{z}|\bm{x},\hat{\bm{\theta}}) \right]\\ &=Q(\bm{\theta},\hat{\bm{\theta}})+H(p(\bm{z}|\bm{x},\hat{\bm{\theta}})) \end{aligned}

変分下界=Q関数+エントロピー」となりました。 θ^ ~\hat{\bm{\theta}}~を定数とみなしてQ関数の最大化を行う際には、エントロピーも定数となります。したがって、Q関数の最大化は変分下界の最大化に相当します。

以上から、「Q関数を最大化することで対数尤度が最大化される」ことがわかりました。

3. おわりに


『しくみががわかる ベイズ統計と機械学習』はたとえ話や具体例が多いため、要約するとだいぶ異なる表現になってしまいました。余裕があればJensenの不等式やエントロピーなどについても追記する予定です。

次回、混合ガウスモデルについてEMアルゴリズムを適用し、pythonでの実装を行います。

4. 参考図書


『しくみががわかる ベイズ統計と機械学習』手塚太郎 著. 2019年. 朝倉書店