🚀 tf-keras
本項目基於GRN和VSN構建模型,用於結構化數據學習任務,尤其適用於二元分類,可判斷個人年收入是否超過50萬美元。該模型在相關數據集上訓練後,達到了約95%的準確率。
🚀 快速開始
本模型可用於二元分類任務,以確定一個人每年的收入是否超過50萬美元。
✨ 主要特性
本模型使用了Bryan Lim等人在 Temporal Fusion Transformers (TFT) for Interpretable Multi-horizon Time Series Forecasting 中提出的兩個重要架構組件GRN和VSN,它們對結構化數據學習任務非常有用。
- 門控殘差網絡(Gated Residual Networks,GRN):由跳躍連接和門控層組成,可有效促進信息流動。它們能夠靈活地僅在需要的地方應用非線性處理。
GRN利用 門控線性單元(Gated Linear Units,GLU)來抑制與給定任務無關的輸入。
GRN的工作原理如下:
- 首先對其輸入應用非線性ELU變換。
- 然後應用線性變換,接著進行丟棄(dropout)操作。
- 接下來應用GLU,並將原始輸入添加到GLU的輸出中,以執行跳躍(殘差)連接。
- 最後,應用層歸一化併產生輸出。
- 變量選擇網絡(Variable Selection Networks,VSN):有助於從輸入中仔細選擇最重要的特徵,並去除可能損害模型性能的任何不必要的噪聲輸入。
VSN的工作原理如下:
- 首先,對每個特徵單獨應用門控殘差網絡(GRN)。
- 然後將所有特徵連接起來,並對連接後的特徵應用GRN,接著應用softmax以產生特徵權重。
- 最後,產生各個GRN輸出的加權和。
注意:本模型並非基於上述論文中描述的整個TFT模型,僅使用了其GRN和VSN組件,這表明GRN和VSN本身對於結構化數據學習任務也非常有用。
📦 安裝指南
文檔未提及安裝步驟,故跳過此章節。
💻 使用示例
文檔未提供代碼示例,故跳過此章節。
📚 詳細文檔
訓練和評估數據
本模型使用由UCI機器學習庫提供的 美國人口普查收入數據集 進行訓練。
該數據集由加權人口普查數據組成,包含從1994年和1995年美國人口普查局進行的當前人口調查中提取的與人口統計和就業相關的變量。
數據集包含約29.9萬個樣本,有41個輸入變量和1個名為 income_level 的目標變量。變量 instance_weight 不用作模型的輸入,因此最終模型使用40個輸入特徵,其中包含7個數值特徵和33個分類特徵:
數值特徵 |
分類特徵 |
年齡 |
工人類別 |
每小時工資 |
行業代碼 |
資本收益 |
職業代碼 |
資本損失 |
調整後總收入 |
股票股息 |
教育程度 |
為僱主工作的人數 |
退伍軍人福利 |
一年中工作的週數 |
上週是否參加教育機構 |
|
婚姻狀況 |
|
主要行業代碼 |
|
主要職業代碼 |
|
種族 |
|
西班牙裔血統 |
|
性別 |
|
工會成員 |
|
失業原因 |
|
全職或兼職就業狀況 |
|
聯邦所得稅負債 |
|
報稅人身份 |
|
先前居住地區 |
|
先前居住州 |
|
詳細的家庭和家庭狀況 |
|
家庭中的詳細家庭摘要 |
|
遷移代碼 - MSA變更 |
|
遷移代碼 - 地區變更 |
|
遷移代碼 - 地區內移動 |
|
一年前是否居住在此房屋 |
|
先前居住地是否在陽光地帶 |
|
18歲以下家庭成員 |
|
個人總收入 |
|
父親的出生國家 |
|
母親的出生國家 |
|
自己的出生國家 |
|
公民身份 |
|
個人總收入 |
|
擁有自己的企業或自營職業 |
|
應納稅收入金額 |
|
是否為退伍軍人管理局填寫收入問卷 |
該數據集已經分為兩部分,分別用於訓練和測試。訓練數據集有199523個樣本,而測試數據集有99762個樣本。
訓練過程
- 準備數據:加載訓練和測試數據集,並將目標列 income_level 從字符串轉換為整數。訓練數據集進一步拆分為訓練集和驗證集。最後,將訓練和驗證數據集轉換為用於模型訓練和評估的tf.data.Dataset。
- 定義輸入特徵編碼邏輯:我們對分類和數值特徵進行如下編碼:
- 分類特徵:使用Keras提供的 Embedding 層進行編碼。嵌入的輸出維度等於 encoding_size。
- 數值特徵:通過使用Keras提供的 Dense 層應用線性變換,將其投影到 encoding_size 維向量中。
因此,所有編碼後的特徵將具有相同的維度,等於 encoding_size 的值。
- 創建模型:
- 模型將具有與給定數據集的數值和分類特徵相對應的輸入層。
- 輸入層接收到的特徵然後使用步驟2中定義的編碼邏輯進行編碼,encoding_size 為16,表示編碼後特徵的輸出維度。
- 編碼後的特徵通過變量選擇網絡(VSN)。如 模型描述 部分所述,VSN內部也使用了GRN。
- VSN產生的特徵通過具有sigmoid激活函數的最終 Dense 層,以產生模型的最終輸出,表示一個人的收入是否超過50萬美元的概率。
- 編譯、訓練和評估模型:
- 由於該模型用於二元分類,選擇的損失函數是二元交叉熵。
- 用於評估模型性能的指標是 準確率。
- 選擇的優化器是Adam,學習率為0.001。
- GRN的丟棄層的丟棄率為0.15。
- 選擇的批量大小為265,模型訓練了20個週期。
- 訓練過程中使用了Keras的 EarlyStopping 回調,這意味著一旦驗證指標停止改善,訓練將中斷。
- 最後,在測試數據集上評估了模型的性能,準確率達到了約95%。
訓練超參數
訓練期間使用了以下超參數:
超參數 |
值 |
名稱 |
Adam |
學習率 |
0.0010000000474974513 |
衰減 |
0.0 |
beta_1 |
0.8999999761581421 |
beta_2 |
0.9990000128746033 |
epsilon |
1e-07 |
amsgrad |
False |
訓練精度 |
float32 |
模型圖
查看模型圖

🔧 技術細節
本模型使用了GRN和VSN兩個重要組件,GRN通過跳躍連接和門控層有效促進信息流動,利用GLU抑制無關輸入;VSN則幫助選擇重要特徵,去除噪聲輸入。在訓練過程中,對數據進行了處理和特徵編碼,使用Adam優化器和二元交叉熵損失函數進行訓練,最終在測試集上達到了約95%的準確率。
📄 許可證
文檔未提及許可證信息,故跳過此章節。
🔗 相關鏈接