TabTransformer转换器提升多层感知机性能深度解析
来源:51CTO.COM
时间:2023-04-26 13:19:58 499浏览 收藏
本篇文章向大家介绍《TabTransformer转换器提升多层感知机性能深度解析》,主要包括,具有一定的参考价值,需要的朋友可以参考一下。
如今,转换器(Transformers)成为大多数先进的自然语言处理(NLP)和计算机视觉(CV)体系结构中的关键模块。然而,表格式数据领域仍然主要以梯度提升决策树(GBDT)算法为主导。于是,有人试图弥合这一差距。其中,第一篇基于转换器的表格数据建模论文是由Huang等人于2020年发表的论文《TabTransformer:使用上下文嵌入的表格数据建模》。
本文旨在提供该论文内容的基本展示,同时将深入探讨TabTransformer模型的实现细节,并向您展示如何针对我们自己的数据来具体使用TabTransformer。
一、论文概述
上述论文的主要思想是,如果使用转换器将常规的分类嵌入转换为上下文嵌入,那么,常规的多层感知器(MLP)的性能将会得到显著提高。接下来,让我们更为深入地理解这一描述。
1.分类嵌入(Categorical Embeddings)
在深度学习模型中,使用分类特征的经典方法是训练其嵌入性。这意味着,每个类别值都有一个唯一的密集型向量表示,并且可以传递给下一层。例如,由下图您可以看到,每个分类特征都使用一个四维数组表示。然后,这些嵌入与数字特征串联,并用作MLP的输入。
带有分类嵌入的MLP
2.上下文嵌入(Contextual Embeddings)
论文作者认为,分类嵌入缺乏上下文含义,即它们并没有对分类变量之间的任何交互和关系信息进行编码。为了将嵌入内容更加具体化,有人建议使用NLP领域当前所使用的转换器来实现这一目的。
TabTransformer转换器中的上下文嵌入
为了以可视化方式形象地展示上述想法,我们不妨考虑下面这个训练后得到的上下文嵌入图像。其中,突出显示了两个分类特征:关系(黑色)和婚姻状况(蓝色)。这些特征是相关的;所以,“已婚(Married)”、“丈夫(Husband)”和“妻子(Wife)”的值应该在向量空间中彼此接近,即使它们来自不同的变量。
经训练后的TabTransformer转换器嵌入结果示例
通过上图中经过训练的上下文嵌入结果,我们可以看到,“已婚(Married)”的婚姻状况更接近“丈夫(Husband)”和“妻子(Wife)”的关系水平,而“未结婚(non-married)”的分类值则来自右侧的单独数据簇。这种类型的上下文使这样的嵌入更加有用,而使用简单形式的类别嵌入技术是不可能实现这种效果的。
3.TabTransformer架构
为了达到上述目的,论文作者提出了以下架构:
TabTransformer转换器架构示意图
(摘取自Huang等人2020年发表的论文)
我们可以将此体系结构分解为5个步骤:
- 标准化数字特征并向前传递
- 嵌入分类特征
- 嵌入经过N次转换器块处理,以便获得上下文嵌入
- 把上下文分类嵌入与数字特征进行串联
- 通过MLP进行串联获得所需的预测
虽然模型架构非常简单,但论文作者表示,添加转换器层可以显著提高计算性能。当然,所有的“魔术”发生在这些转换器块内部;所以,接下来让我们更加详细地研究一下其中的实现过程。
4.转换器
转换器(Transformer)架构示意
(选自Vaswani等人于2017年发表的论文)
您可能以前见过转换器架构,但为了快速介绍起见,请记住该转换器是由编码器和解码器两部分组成(见上图)。对于TabTransformer,我们只关心将输入的嵌入内容上下文化的编码器部分(解码器部分将这些嵌入内容转换为最终输出结果)。但它到底是如何做到的呢?答案是——多头注意力机制。
5.多头注意力机制(Multi-head-attention)
引用我最喜欢的关于注意力机制的文章的描述,是这样的:
“自我关注(self attention)背后的关键概念是,这种机制允许神经网络学习如何在输入序列的各个片段之间以最好的路由方案进行信息调度。”
换句话说,自我关注(self-attention)有助于模型找出在表示某个单词/类别时,输入的哪些部分更重要,哪些部分相对不重要。为此,我强烈建议您阅读一下上面引用的这篇文章,以便对自我关注为什么如此有效有一个更为直观的理解。
多头注意力机制
(选自Vaswani等人于2017年发表的论文)
注意力是通过3个学习过的矩阵来计算的——Q、K和V,它们代表查询(Query)、键(Key)和值(Value)。首先,我们将矩阵Q和K相乘得到注意力矩阵。该矩阵被缩放并通过softmax层传递。然后,我们将其乘以V矩阵,得出最终值。为了更直观地理解起见,请考虑下面的示意图,它显示了我们如何使用矩阵Q、K和V实现从输入嵌入转换到上下文嵌入。
自我关注流程可视化
通过重复该过程h次(使用不同的Q、K、V矩阵),我们就能够得到多个上下文嵌入,它们形成我们最终的多头注意力。
6.简短回顾
让我们总结一下上面所介绍的内容:
- 简单的分类嵌入不包含上下文信息
- 通过转换器编码器传递分类嵌入,我们就能够将嵌入上下文化
- 转换器部分能够将嵌入上下文化,因为它使用了多头注意力机制
- 多头注意力机制在编码变量时使用矩阵Q、K和V来寻找有用的相互作用和相关性信息
- 在TabTransformer中,被上下文化的嵌入与数字输入相串联,并通过一个简单的MLP输出预测
虽然TabTransformer背后的想法很简单,但您可能需要一些时间才能掌握注意力机制。因此,我强烈建议您重新阅读以上解释。如果您感到有些迷茫,请认真阅读本文中所有建议的链接相关内容。我保证,做到这些后,您就不难搞明白注意力机制的原理了。
7.试验结果展示
结果数据(选自Huang等人2020年发表的论文)
根据报告的结果,TabTransformer转换器优于所有其他深度学习表格模型,此外,它接近GBDT的性能水平,这非常令人鼓舞。该模型对缺失数据和噪声数据也相对稳健,并且在半监督环境下优于其他模型。然而,这些数据集显然不是详尽无遗的,正如以后发表的一些相关论文所证实的那样,仍有很大的改进空间。
二、构建我们自己的示例程序
现在,让我们最终来确定一下如何将模型应用于我们自己的数据。接下来的示例数据取自著名的Tabular Playground Kaggle比赛。为了方便使用TabTransformer转换器,我创建了一个tabtransformertf包。它可以使用如下pip命令进行安装:
pip install tabtransformertf
并允许我们使用该模型,而无需进行大量预处理。
1.数据预处理
第一步是设置适当的数据类型,并将我们的训练和验证数据转换为TF数据集。其中,前面安装的软件包中就提供了一个很好的实用程序可以做到这一点。
from tabtransformertf.utils.preprocessing import df_to_dataset, build_categorical_prep # 设置数据类型 train_data[CATEGORICAL_FEATURES] = train_data[CATEGORICAL_FEATURES].astype(str) val_data[CATEGORICAL_FEATURES] = val_data[CATEGORICAL_FEATURES].astype(str) train_data[NUMERIC_FEATURES] = train_data[NUMERIC_FEATURES].astype(float) val_data[NUMERIC_FEATURES] = val_data[NUMERIC_FEATURES].astype(float) # 转换成TF数据集 train_dataset = df_to_dataset(train_data[FEATURES + [LABEL]], LABEL, batch_size=1024) val_dataset = df_to_dataset(val_data[FEATURES + [LABEL]], LABEL, shuffle=False, batch_size=1024)
下一步是为分类数据准备预处理层。该分类数据稍后将被传递给我们的主模型。
from tabtransformertf.utils.preprocessing import build_categorical_prep category_prep_layers = build_categorical_prep(train_data, CATEGORICAL_FEATURES) # 输出结果是一个字典结构,其中键部分是特征名称,值部分是StringLookup层 # category_prep_layers -> # {'product_code':, #'attribute_0': , #'attribute_1': }
这就是预处理!现在,我们可以开始构建模型了。
2.构建TabTransformer模型
初始化模型很容易。其中,有几个参数需要指定,但最重要的几个参数是:embeding_dim、depth和heads。所有参数都是在超参数调整后选择的。
from tabtransformertf.models.tabtransformer import TabTransformer tabtransformer = TabTransformer( numerical_features = NUMERIC_FEATURES,# 带有数字特征名称的列表 categorical_features = CATEGORICAL_FEATURES, # 带有分类特征名称的列表 categorical_lookup=category_prep_layers, # 带StringLookup层的Dict numerical_discretisers=None,# None代表我们只是简单地传递数字特征 embedding_dim=32,# 嵌入维数 out_dim=1,# Dimensionality of output (binary task) out_activatinotallow='sigmoid',# 输出层激活 depth=4,# 转换器块层的个数 heads=8,# 转换器块中注意力头的个数 attn_dropout=0.1,# 在转换器块中的丢弃率 ff_dropout=0.1,# 在最后MLP中的丢弃率 mlp_hidden_factors=[2, 4],# 我们为每一层划分最终嵌入的因子 use_column_embedding=True,#如果我们想使用列嵌入,设置此项为真 ) # 模型运行中摘要输出: # 总参数个数: 1,778,884 # 可训练的参数个数: 1,774,064 # 不可训练的参数个数: 4,820
模型初始化后,我们可以像任何其他Keras模型一样安装它。训练参数也可以调整,所以可以随意调整学习速度和提前停止。
LEARNING_RATE = 0.0001 WEIGHT_DECAY = 0.0001 NUM_EPOCHS = 1000 optimizer = tfa.optimizers.AdamW( learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY ) tabtransformer.compile( optimizer = optimizer, loss = tf.keras.losses.BinaryCrossentropy(), metrics= [tf.keras.metrics.AUC(name="PR AUC", curve='PR')], ) out_file = './tabTransformerBasic' checkpoint = ModelCheckpoint( out_file, mnotallow="val_loss", verbose=1, save_best_notallow=True, mode="min" ) early = EarlyStopping(mnotallow="val_loss", mode="min", patience=10, restore_best_weights=True) callback_list = [checkpoint, early] history = tabtransformer.fit( train_dataset, epochs=NUM_EPOCHS, validation_data=val_dataset, callbacks=callback_list )
3.评价
竞赛中最关键的指标是ROC AUC。因此,让我们将其与PR AUC指标一起输出来评估一下模型的性能。
val_preds = tabtransformer.predict(val_dataset) print(f"PR AUC: {average_precision_score(val_data['isFraud'], val_preds.ravel())}") print(f"ROC AUC: {roc_auc_score(val_data['isFraud'], val_preds.ravel())}") # PR AUC: 0.26 # ROC AUC: 0.58
您也可以自己给测试集评分,然后将结果值提交给Kaggle官方。我现在选择的这个解决方案使我跻身前35%,这并不坏,但也不太好。那么,为什么TabTransfromer在上述方案中表现不佳呢?可能有以下几个原因:
- 数据集太小,而深度学习模型以需要大量数据著称
- TabTransformer很容易在表格式数据示例领域出现过拟合
- 没有足够的分类特征使模型有用
三、结论
本文探讨了TabTransformer背后的主要思想,并展示了如何使用Tabtransformertf包来具体应用此转换器。
归纳起来看,TabTransformer的确是一种有趣的体系结构,它在当时的表现明显优于大多数深度表格模型。它的主要优点是将分类嵌入语境化,从而增强其表达能力。它使用在分类特征上的多头注意力机制来实现这一点,而这是在表格数据领域使用转换器的第一个应用实例。
TabTransformer体系结构的一个明显缺点是,数字特征被简单地传递到最终的MLP层。因此,它们没有语境化,它们的价值也没有在分类嵌入中得到解释。在下一篇文章中,我将探讨如何修复此缺陷并进一步提高性能。
译者介绍
朱先忠,51CTO社区编辑,51CTO专家博客、讲师,潍坊一所高校计算机教师,自由编程界老兵一枚。
原文链接:https://towardsdatascience.com/transformers-for-tabular-data-tabtransformer-deep-dive-5fb2438da820?source=collection_home---------4----------------------------
好了,本文到此结束,带大家了解了《TabTransformer转换器提升多层感知机性能深度解析》,希望本文对你有所帮助!关注golang学习网公众号,给大家分享更多科技周边知识!
-
501 收藏
-
501 收藏
-
501 收藏
-
501 收藏
-
501 收藏
-
177 收藏
-
367 收藏
-
325 收藏
-
151 收藏
-
205 收藏
-
268 收藏
-
- 前端进阶之JavaScript设计模式
- 设计模式是开发人员在软件开发过程中面临一般问题时的解决方案,代表了最佳的实践。本课程的主打内容包括JS常见设计模式以及具体应用场景,打造一站式知识长龙服务,适合有JS基础的同学学习。
- 立即学习 542次学习
-
- GO语言核心编程课程
- 本课程采用真实案例,全面具体可落地,从理论到实践,一步一步将GO核心编程技术、编程思想、底层实现融会贯通,使学习者贴近时代脉搏,做IT互联网时代的弄潮儿。
- 立即学习 507次学习
-
- 简单聊聊mysql8与网络通信
- 如有问题加微信:Le-studyg;在课程中,我们将首先介绍MySQL8的新特性,包括性能优化、安全增强、新数据类型等,帮助学生快速熟悉MySQL8的最新功能。接着,我们将深入解析MySQL的网络通信机制,包括协议、连接管理、数据传输等,让
- 立即学习 497次学习
-
- JavaScript正则表达式基础与实战
- 在任何一门编程语言中,正则表达式,都是一项重要的知识,它提供了高效的字符串匹配与捕获机制,可以极大的简化程序设计。
- 立即学习 487次学习
-
- 从零制作响应式网站—Grid布局
- 本系列教程将展示从零制作一个假想的网络科技公司官网,分为导航,轮播,关于我们,成功案例,服务流程,团队介绍,数据部分,公司动态,底部信息等内容区块。网站整体采用CSSGrid布局,支持响应式,有流畅过渡和展现动画。
- 立即学习 484次学习