Latent Factor Analysis via Dynamical Systems (LFADS)をchainerで実装した。

やったこと

Google BrainのDavid Sussilloらが先日Nature Methodsに出版したLFADSをchainerで実装しました。
GitHub - yu-takagi/chainer_lfads: Implementation of LFADS with chainer
公式の実装が既にこちらにあるのでそれを参考にしました。二つのpythonファイルだけと大変わかりやすい公式実装です。が、それぞれ800行と2100行と超重量級、加えてtensorflowにより書かれているため自分でいじるのが難しく、tensorflowとアルゴリズムの勉強も兼ねて再実装しました。

LFADSの概要

LFADSは時系列データ、特に神経科学の実験から取得できる脳信号解析のために開発されたRecurrent Neural Network (RNN)のモデルです。多数のセンサーから取られた大量の神経データを用いることで、単一試行レベルでの神経時系列データ解析をこれまでよりも高い精度で実現します。

ここで、「時系列データ」には気温のような1次元のデータから、市場全体の株価の動きのように多次元かつ複雑にものまで様々あります。我々の脳も時事刻々と異なるデータを生み出している時系列データ生成器であることから、神経科学の実験からも様々な時系列データが得られます。一方で、時系列データには周期性や過去の時点への依存性などモデリングに際して難しい点が多々あります。そのため、活発に研究が行われている領域です。時系列データモデリングの難しさについて詳しくは、たとえば以下の記事によくまとまっています。
tjo.hatenablog.com

LFADSはそのような時系列データ処理のためのディープラーニングモデルですが、大きく見るとVariational Autoencoder (VAE)の亜種です。論文でまず焦点を当てているのは、獲得された潜在表現が生信号やGaussian Process Factor Analysisなどの線形・非動的なモデルに比べ、行動をよく予測するという点です。加えて、LFADSはその構造上、観測データの背後に隠れている低次元の時系列をうまく推定している可能性があります(線形ICAなどのように必ず復元できる保証があるわけではないのに注意)。実際に、論文では意思決定に関わってそうな各種の時系列がモデルの潜在変数として捉えられていることを示唆しています。最適化や使用しているユニット等、手法面ではこれといった新規性はないので、理論家からすると退屈な論文かもしれませんが、個人的には安心して読めます。

LFADSはVAEの亜種であることから、深層生成モデルでもあります。そのため、論文では生成に関して一切触れていないものの、理屈の上ではそれらしい脳活動を生成する機械にもなっているはずです。今回の論文ではおそらく科学的な面白さが特にないため触れていないですが、このような生成モデルはデータ拡張や脳活動シミュレーションに使えるため、それ自体面白いテーマです。

まだやってないこと

論文では様々な設定で実験をしており、それらの中には一部実装できていない機能があります。特に複数のセッションを一つのネットワークで学習し、それにより用いるデータ量を大幅に増やす"Stitching"は色々と役に立つシチュエーションがありそうです。名前は大層ですが要は複数のセッションに合わせてそれぞれネットワークを用意するだけなので、実装は難しくありません。また、自分が今持っているデータが神経スパイクデータではないため、連続値を推定するガウス分布バージョンのみ実装していて、発火率を予測するポアソン分布バージョンは未実装です。今後実装次第更新したいと思っています。