カテゴリ変数の分散表現を学習するEntity Embeddingの実装

機械学習においてカテゴリ変数を扱うとき、何らかの変換を施して任意の数値で表現しなければなりません。

今回はWord2Vecのように任意のカテゴリ変数の分散表現を学習する、Entity Embeddingの紹介とそのPythonの実装をライブラリとして公開したので紹介します。

実装はこちらです。

github.com

実はEmbeddingレイヤというものを知ったときに、Entity Embeddingと同じ方法を思いついてCategory2Vecなどという名前で自分では呼んでいたのですが、普通に既に提案されていて、まあそりゃ誰でも思いつくよなと思った次第です。

Entity Embedding

Entity EmbeddingはkaggleのRossmann Store Salesという店舗の売上を予測するコンペで3位になったチームが提案して使用した方法で、論文にもなっています。

arxiv.org

カテゴリ変数を扱う際色々な方法はあるものの、基本はOne-hot encodingを行い1つのカテゴリ変数を取りうる値の次元数のベクトルで、その値に対応するインデックスのみ1となりその他は0となるように数値に変換します。

ただこの場合、各特徴量はその値であるか・その値でないかの情報しか持たず、次元数も取りうる値に従って大きくなるので計算量の問題も出てきます。

この問題をWord2Vecなどでも使われているEmbeddingレイヤを用いて解決したのが、Entity Embeddingです。(と私は解釈しています)

そもそもEmbeddingレイヤとは何でしょうか?

Embeddingレイヤ

ゼロから作るDeep Learning ❷ ―自然言語処理編』の説明が分かりやすいです。

ニューラルネットワークの入力をOne-hotベクトルだとしてその次元数を10万とすると、10万次元のベクトルと(10万×中間層のユニット数)の行列の積を通常は計算します。

ただ実際やってることは、値が1となっているインデックスに対応する行を抜き出しているだけです。(以下の図の赤枠で囲った部分)

f:id:pompom168:20190324234328p:plain

よって特定のインデックスに対応する行(ベクトルを)抜き出すレイヤをEmbeddingレイヤと呼び、このレイヤを使って任意のネットワークを学習します。

Word2Vecの慣習に従えば、このEmbeddingレイヤの各行が各カテゴリ変数の値の分散表現です。

以下の図が、提案されている論文に載っているEntity Embeddingを表す図です。

上で説明したことをそのまま図にしただけで、各カテゴリ変数のOne-hotベクトルに対してEmbeddingレイヤを接続し、それを全カテゴリ変数分結合したものを入力として普通に全結合ニューラルネットワークを構成します。

そして、任意の目的変数に応じて学習を行えば、Embeddingレイヤの重みが分散表現となるわけです。

f:id:pompom168:20190324235255p:plain

このネットワーク自体は分散表現を入力としたニューラルネットワークなわけなので、これ自体を予測器として用いてもいいですし、学習した分散表現を別の任意の予測器の入力として再学習を行っても良いわけです。

(ただし、Entity Embeddingと別の予測器の学習データを完全に同一のものにすると超絶過学習を起こすので、そこは学習データを分けるなどの工夫が必要です。)

またEntity Embeddingによって得た分散表現を用いて、カテゴリ変数の各値同士の近さを可視化できます。

以下はドイツの州の各値の分散表現同士の近さを可視化したもので、各州の相対的な距離が実際の地図と似たものになったそうです。

f:id:pompom168:20190324235441p:plain

ライブラリ

このカテゴリ変数を扱う画期的な方法であるEntity Embeddingを簡単に扱うために、ライブラリaltenaとして公開したので紹介します。

PyPIにアップロードしてあるので、pipでインストール可能です。

$ pip install altena

まだドキュメントの整備が済んでいないので、使い方は上記リポジトリのexampes/以下を参照してください。

現段階では、任意のネットワーク構成におけるEntity Embeddingの学習と、学習したモデルを使った分散表現への変換に対応しています。

また、その他のカテゴリ変数から特徴抽出する方法についても実装予定です。

実際にEntity Embeddingを使った結果なども追記していと思ってます。

(今まで社内のデータで使っていたため、公開されてるデータセットでやったことがなかったです)