🚀 tf-keras
本项目基于GRN和VSN构建模型,用于结构化数据学习任务,尤其适用于二元分类,可判断个人年收入是否超过50万美元。该模型在相关数据集上训练后,达到了约95%的准确率。
🚀 快速开始
本模型可用于二元分类任务,以确定一个人每年的收入是否超过50万美元。
✨ 主要特性
本模型使用了Bryan Lim等人在 Temporal Fusion Transformers (TFT) for Interpretable Multi-horizon Time Series Forecasting 中提出的两个重要架构组件GRN和VSN,它们对结构化数据学习任务非常有用。
- 门控残差网络(Gated Residual Networks,GRN):由跳跃连接和门控层组成,可有效促进信息流动。它们能够灵活地仅在需要的地方应用非线性处理。
GRN利用 门控线性单元(Gated Linear Units,GLU)来抑制与给定任务无关的输入。
GRN的工作原理如下:
- 首先对其输入应用非线性ELU变换。
- 然后应用线性变换,接着进行丢弃(dropout)操作。
- 接下来应用GLU,并将原始输入添加到GLU的输出中,以执行跳跃(残差)连接。
- 最后,应用层归一化并产生输出。
- 变量选择网络(Variable Selection Networks,VSN):有助于从输入中仔细选择最重要的特征,并去除可能损害模型性能的任何不必要的噪声输入。
VSN的工作原理如下:
- 首先,对每个特征单独应用门控残差网络(GRN)。
- 然后将所有特征连接起来,并对连接后的特征应用GRN,接着应用softmax以产生特征权重。
- 最后,产生各个GRN输出的加权和。
注意:本模型并非基于上述论文中描述的整个TFT模型,仅使用了其GRN和VSN组件,这表明GRN和VSN本身对于结构化数据学习任务也非常有用。
📦 安装指南
文档未提及安装步骤,故跳过此章节。
💻 使用示例
文档未提供代码示例,故跳过此章节。
📚 详细文档
训练和评估数据
本模型使用由UCI机器学习库提供的 美国人口普查收入数据集 进行训练。
该数据集由加权人口普查数据组成,包含从1994年和1995年美国人口普查局进行的当前人口调查中提取的与人口统计和就业相关的变量。
数据集包含约29.9万个样本,有41个输入变量和1个名为 income_level 的目标变量。变量 instance_weight 不用作模型的输入,因此最终模型使用40个输入特征,其中包含7个数值特征和33个分类特征:
数值特征 |
分类特征 |
年龄 |
工人类别 |
每小时工资 |
行业代码 |
资本收益 |
职业代码 |
资本损失 |
调整后总收入 |
股票股息 |
教育程度 |
为雇主工作的人数 |
退伍军人福利 |
一年中工作的周数 |
上周是否参加教育机构 |
|
婚姻状况 |
|
主要行业代码 |
|
主要职业代码 |
|
种族 |
|
西班牙裔血统 |
|
性别 |
|
工会成员 |
|
失业原因 |
|
全职或兼职就业状况 |
|
联邦所得税负债 |
|
报税人身份 |
|
先前居住地区 |
|
先前居住州 |
|
详细的家庭和家庭状况 |
|
家庭中的详细家庭摘要 |
|
迁移代码 - MSA变更 |
|
迁移代码 - 地区变更 |
|
迁移代码 - 地区内移动 |
|
一年前是否居住在此房屋 |
|
先前居住地是否在阳光地带 |
|
18岁以下家庭成员 |
|
个人总收入 |
|
父亲的出生国家 |
|
母亲的出生国家 |
|
自己的出生国家 |
|
公民身份 |
|
个人总收入 |
|
拥有自己的企业或自营职业 |
|
应纳税收入金额 |
|
是否为退伍军人管理局填写收入问卷 |
该数据集已经分为两部分,分别用于训练和测试。训练数据集有199523个样本,而测试数据集有99762个样本。
训练过程
- 准备数据:加载训练和测试数据集,并将目标列 income_level 从字符串转换为整数。训练数据集进一步拆分为训练集和验证集。最后,将训练和验证数据集转换为用于模型训练和评估的tf.data.Dataset。
- 定义输入特征编码逻辑:我们对分类和数值特征进行如下编码:
- 分类特征:使用Keras提供的 Embedding 层进行编码。嵌入的输出维度等于 encoding_size。
- 数值特征:通过使用Keras提供的 Dense 层应用线性变换,将其投影到 encoding_size 维向量中。
因此,所有编码后的特征将具有相同的维度,等于 encoding_size 的值。
- 创建模型:
- 模型将具有与给定数据集的数值和分类特征相对应的输入层。
- 输入层接收到的特征然后使用步骤2中定义的编码逻辑进行编码,encoding_size 为16,表示编码后特征的输出维度。
- 编码后的特征通过变量选择网络(VSN)。如 模型描述 部分所述,VSN内部也使用了GRN。
- VSN产生的特征通过具有sigmoid激活函数的最终 Dense 层,以产生模型的最终输出,表示一个人的收入是否超过50万美元的概率。
- 编译、训练和评估模型:
- 由于该模型用于二元分类,选择的损失函数是二元交叉熵。
- 用于评估模型性能的指标是 准确率。
- 选择的优化器是Adam,学习率为0.001。
- GRN的丢弃层的丢弃率为0.15。
- 选择的批量大小为265,模型训练了20个周期。
- 训练过程中使用了Keras的 EarlyStopping 回调,这意味着一旦验证指标停止改善,训练将中断。
- 最后,在测试数据集上评估了模型的性能,准确率达到了约95%。
训练超参数
训练期间使用了以下超参数:
超参数 |
值 |
名称 |
Adam |
学习率 |
0.0010000000474974513 |
衰减 |
0.0 |
beta_1 |
0.8999999761581421 |
beta_2 |
0.9990000128746033 |
epsilon |
1e-07 |
amsgrad |
False |
训练精度 |
float32 |
模型图
查看模型图

🔧 技术细节
本模型使用了GRN和VSN两个重要组件,GRN通过跳跃连接和门控层有效促进信息流动,利用GLU抑制无关输入;VSN则帮助选择重要特征,去除噪声输入。在训练过程中,对数据进行了处理和特征编码,使用Adam优化器和二元交叉熵损失函数进行训练,最终在测试集上达到了约95%的准确率。
📄 许可证
文档未提及许可证信息,故跳过此章节。
🔗 相关链接