🚀 tf-keras
このライブラリは、構造化データの学習タスクに非常に有用なGRNとVSNという2つの重要なアーキテクチャコンポーネントを使用して構築されたモデルを提供します。
🚀 クイックスタート
このモデルは、構造化データの学習タスクに使用できます。具体的には、人物の年収が$500Kを超えるかどうかを判断する二値分類タスクに利用できます。
✨ 主な機能
モデルの概要
このモデルは、Bryan Limらが2019年に論文 Temporal Fusion Transformers (TFT) for Interpretable Multi-horizon Time Series Forecasting で提案した2つの重要なアーキテクチャコンポーネントであるGRNとVSNを使用して構築されています。これらは構造化データの学習タスクに非常に有用です。
- ゲート付き残差ネットワーク (Gated Residual Networks, GRN): スキップ接続とゲート層から構成され、情報の流れを効率的に促進します。必要な場所にのみ非線形処理を適用する柔軟性を持っています。
GRNは ゲート付き線形ユニット (Gated Linear Units, GLU) を使用して、特定のタスクに関連しない入力を抑制します。
GRNの動作は以下の通りです:
- まず、入力に非線形のELU変換を適用します。
- 次に、線形変換を適用し、ドロップアウトを行います。
- その後、GLUを適用し、GLUの出力に元の入力を加えてスキップ(残差)接続を行います。
- 最後に、レイヤー正規化を適用して出力を生成します。
- 変数選択ネットワーク (Variable Selection Networks, VSN): 入力から最も重要な特徴を慎重に選択し、モデルの性能に害を与える可能性のある不要なノイズ入力を取り除くのに役立ちます。
VSNの動作は以下の通りです:
- まず、各特徴に個別にゲート付き残差ネットワーク (GRN) を適用します。
- 次に、すべての特徴を連結し、連結された特徴にGRNを適用し、その後ソフトマックスを適用して特徴の重みを生成します。
- 個々のGRNの出力の加重和を生成します。
注意: このモデルは、上記の論文で説明されている全TFTモデルに基づくものではなく、GRNとVSNのコンポーネントのみを使用しており、GRNとVSNが構造化データの学習タスクに独自で非常に有用であることを示しています。
想定される用途
このモデルは、人物の年収が$500Kを超えるかどうかを判断する二値分類タスクに使用できます。
学習と評価データ
このモデルは、UCI Machine Learning Repositoryが提供する 米国の国勢調査所得データセット を使用して学習されました。
このデータセットは、1994年と1995年に米国国勢調査局が実施した現在の人口調査から抽出された人口統計学的および雇用関連の変数を含む重み付けされた国勢調査データで構成されています。
このデータセットは、41個の入力変数と income_level という1つのターゲット変数を持つ約299Kのサンプルで構成されています。
instance_weight 変数はモデルの入力として使用されないため、最終的にモデルは7つの数値特徴と33のカテゴリカル特徴を含む40個の入力特徴を使用します:
数値特徴 |
カテゴリカル特徴 |
年齢 |
労働者の階級 |
時給 |
産業コード |
キャピタルゲイン |
職種コード |
キャピタル損失 |
調整後総所得 |
株式配当 |
教育レベル |
雇用主に勤務した人数 |
退役軍人福利 |
年間勤務週数 |
先週の教育機関登録状況 |
|
婚姻状況 |
|
主要産業コード |
|
主要職種コード |
|
人種 |
|
ヒスパニック出身 |
|
性別 |
|
労働組合員 |
|
失業理由 |
|
正社員またはパートタイム雇用状況 |
|
連邦所得税債務 |
|
納税者ステータス |
|
以前の居住地の地域 |
|
以前の居住地の州 |
|
詳細な世帯および家族状況 |
|
世帯内の詳細な世帯概要 |
|
大都市圏の移動コード |
|
地域の移動コード |
|
地域内の移動コード |
|
1年前にこの家に住んでいたか |
|
以前の居住地がサンベルト地域か |
|
18歳未満の家族構成員 |
|
個人の総収入 |
|
父親の出生国 |
|
母親の出生国 |
|
自身の出生国 |
|
市民権 |
|
個人の総所得 |
|
自営業または個人事業主 |
|
課税所得額 |
|
退役軍人行政機関の所得調査票を提出したか |
このデータセットは、学習用とテスト用の2つの部分に分かれています。
学習データセットには199523個のサンプルがあり、テストデータセットには99762個のサンプルがあります。
学習手順
-
データの準備: 学習データセットとテストデータセットを読み込み、ターゲット列 income_level を文字列から整数に変換します。学習データセットはさらに学習セットと検証セットに分割されます。
最後に、学習データセットと検証データセットは、モデルの学習と評価に使用するためのtf.data.Datasetに変換されます。
-
入力特徴のエンコードロジックの定義: カテゴリカル特徴と数値特徴を以下のようにエンコードします:
- カテゴリカル特徴: Kerasが提供する Embedding レイヤーを使用してエンコードされます。埋め込みの出力次元は encoding_size と等しくなります。
- 数値特徴: Kerasが提供する Dense レイヤーを使用して線形変換を適用することで、encoding_size 次元のベクトルに射影されます。
したがって、すべてのエンコードされた特徴は、encoding_size の値と等しい次元を持つことになります。
- モデルの作成:
- モデルは、与えられたデータセットの数値特徴とカテゴリカル特徴の両方に対応する入力レイヤーを持ちます。
- 入力レイヤーが受け取った特徴は、ステップ2で定義されたエンコードロジックを使用してエンコードされ、encoding_size は16であり、エンコードされた特徴の出力次元を示します。
- エンコードされた特徴は、変数選択ネットワーク (VSN) を通過します。VSNは内部的にGRNも使用しており、モデルの概要 セクションで説明されている通りです。
- VSNが生成した特徴は、シグモイド活性化関数を持つ最終の Dense レイヤーを通過し、人物の収入が$500Kを超えるかどうかの確率を示すモデルの最終出力を生成します。
- モデルのコンパイル、学習、評価:
- モデルは二値分類を目的としているため、損失関数として二値交差エントロピーが選択されました。
- モデルの性能を評価するための指標として accuracy が選択されました。
- オプティマイザとして、学習率0.001のAdamが選択されました。
- GRNのドロップアウトレイヤーの dropout_rate は0.15でした。
- バッチサイズは265に選択され、モデルは20エポックで学習されました。
- 学習は、EarlyStopping のKerasコールバックを使用して行われました。これは、検証指標が改善しなくなったらすぐに学習が中断されることを意味します。
- 最後に、モデルの性能はテストデータセットでも評価され、約95%の精度に達しました。
学習ハイパーパラメータ
学習中に以下のハイパーパラメータが使用されました:
ハイパーパラメータ |
値 |
名前 |
Adam |
学習率 |
0.0010000000474974513 |
減衰率 |
0.0 |
beta_1 |
0.8999999761581421 |
beta_2 |
0.9990000128746033 |
イプシロン |
1e-07 |
amsgrad |
False |
学習精度 |
float32 |
モデルのプロット
モデルのプロットを表示

クレジット