SHAP特征顺序自定义教程详解
时间:2025-09-30 16:00:34 145浏览 收藏
本文详细介绍了如何自定义SHAP summary_plot中的特征顺序,提升模型可解释性和图表可控性。SHAP摘要图默认按特征重要性排序,但有时我们需要根据业务逻辑或其他需求自定义特征显示顺序。文章阐述了通过设置`sort=False`参数,并结合Pandas DataFrame对特征数据和SHAP值进行手动重排的核心策略。通过将特征数据和SHAP值转换为DataFrame,利用DataFrame的列操作功能进行重排,再转换回NumPy数组,最后调用`shap.summary_plot`函数并传入自定义顺序的特征名称,即可实现自定义排序的SHAP摘要图。文章提供了一个完整的CNN示例代码,展示了如何准备数据、定义目标特征顺序、重排数据以及绘制自定义排序的摘要图,并强调了`feature_names`参数和数据维度匹配的重要性。通过这种方法,用户可以根据特定需求更有效地解读模型解释结果,增强模型的可解释性和沟通效率。

1. 理解SHAP summary_plot 及其默认行为
SHAP (SHapley Additive exPlanations) 是一种流行的模型可解释性框架,能够解释单个预测以及模型整体的行为。shap.summary_plot 是其核心可视化工具之一,它能够以多种形式(如条形图、点图)展示每个特征对模型输出的平均影响。默认情况下,summary_plot 会根据特征的平均绝对SHAP值(即特征重要性)从高到低进行排序,将最重要的特征显示在顶部。
然而,在某些场景下,用户可能希望按照特定的业务逻辑、预设顺序或为了与其他图表保持一致性来排列特征,而非单纯依赖模型计算出的重要性。例如,你可能希望将一组相关的特征放在一起,或者按照数据输入的原始顺序进行展示。
2. 自定义特征排序的核心策略
要实现自定义特征顺序,主要依赖于 shap.summary_plot 函数的一个关键参数:sort。
- sort=False 参数: 当此参数设置为 False 时,summary_plot 将不再对特征进行自动排序,而是按照你传入的特征数据和SHAP值的列顺序进行绘制。
- 手动重排数据: 由于 sort=False 只是禁用了自动排序,因此你需要确保传入 summary_plot 的 shap_values 和特征数据(通常是 X 或 features)已经按照你期望的顺序进行了排列。这通常通过重新组织这些数据的列来实现。
3. 实践指南:通过Pandas DataFrame实现特征重排
以下是一个详细的步骤,演示如何使用Pandas DataFrame来方便地重排特征数据和SHAP值,从而控制 summary_plot 的显示顺序。
3.1 准备数据与模型解释器
首先,我们需要一个训练好的模型和相应的SHAP解释器及SHAP值。我们将使用一个简单的卷积神经网络(CNN)示例来生成SHAP值。
import matplotlib.pyplot as plt
import numpy as np
import shap
import pandas as pd # 导入pandas用于数据操作
from tensorflow import keras
from tensorflow.keras import layers
# 示例数据
X = np.array([[(1,2,3,3,1),(3,2,1,3,2),(3,2,2,3,3),(2,2,1,1,2),(2,1,1,1,1)],
[(4,5,6,4,4),(5,6,4,3,2),(5,5,6,1,3),(3,3,3,2,2),(2,3,3,2,1)],
[(7,8,9,4,7),(7,7,6,7,8),(5,8,7,8,8),(6,7,6,7,8),(5,7,6,6,6)],
[(7,8,9,8,6),(6,6,7,8,6),(8,7,8,8,8),(8,6,7,8,7),(8,6,7,8,8)],
[(4,5,6,5,5),(5,5,5,6,4),(6,5,5,5,6),(4,4,3,3,3),(5,5,4,4,5)],
[(4,5,6,5,5),(5,5,5,6,4),(6,5,5,5,6),(4,4,3,3,3),(5,5,4,4,5)],
[(1,2,3,3,1),(3,2,1,3,2),(3,2,2,3,3),(2,2,1,1,2),(2,1,1,1,1)]])
y = np.array([0, 1, 2, 2, 1, 1, 0])
# 构建并训练一个简单的CNN模型
model = keras.Sequential([
layers.Conv1D(128, kernel_size=3, activation='relu', input_shape=(5,5)),
layers.MaxPooling1D(pool_size=2),
layers.LSTM(128, return_sequences=True),
layers.Flatten(),
layers.Dense(128, activation='relu'),
layers.Dense(3, activation='softmax') # 假设有3个类别
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(X, y, epochs=10, verbose=0) # verbose=0 减少训练输出
# 解释器和SHAP值计算
explainer = shap.GradientExplainer(model, X)
shap_values = explainer.shap_values(X)
# 原始问题中指定了用于绘图的数据切片
cls = 0 # 针对第一个类别
idx = 0 # 针对X的第一个"时间步"或"特征组"
X_for_plot = X[:, idx, :] # 形状为 (num_samples, num_features)
shap_values_for_plot = shap_values[cls][:, idx, :] # 形状为 (num_samples, num_features)
# 定义原始特征名称
original_feature_names = ["Feature1", "Feature2", "Feature3", "Feature4", "Feature5"]
# 绘制默认排序的摘要图(可选,用于对比)
print("--- 默认排序的SHAP摘要图 ---")
shap.summary_plot(shap_values_for_plot, X_for_plot, plot_type="bar", feature_names=original_feature_names)
plt.title("Default SHAP Summary Plot (Sorted by Importance)")
plt.show()3.2 定义目标特征顺序
现在,我们来定义一个自定义的特征顺序。这个顺序将决定特征在图表中的排列方式。
# 定义你期望的特征顺序
# 假设我们想将Feature3放在最前面,然后是Feature5,接着是Feature1,以此类推
custom_feature_order = ["Feature3", "Feature5", "Feature1", "Feature4", "Feature2"]
# 确保自定义顺序中的所有特征名称都存在于原始特征名称中
if not all(f in original_feature_names for f in custom_feature_order):
raise ValueError("自定义特征顺序中包含不在原始特征列表中的名称!")3.3 重排特征数据与SHAP值
这是实现自定义排序的核心步骤。我们将 X_for_plot 和 shap_values_for_plot 转换为Pandas DataFrame,利用DataFrame的列操作功能进行重排,然后再转换回NumPy数组以供 shap.summary_plot 使用。
# 将特征数据转换为DataFrame features_df = pd.DataFrame(X_for_plot, columns=original_feature_names) # 将SHAP值转换为DataFrame shap_df = pd.DataFrame(shap_values_for_plot, columns=original_feature_names) # 根据自定义顺序重排DataFrame的列 features_df_ordered = features_df[custom_feature_order] shap_df_ordered = shap_df[custom_feature_order] # 将重排后的DataFrame转换回NumPy数组 X_ordered_for_plot = features_df_ordered.to_numpy() shap_values_ordered_for_plot = shap_df_ordered.to_numpy()
3.4 绘制自定义顺序的SHAP摘要图
最后,使用重排后的数据和 sort=False 参数来生成图表。
# 绘制自定义排序的摘要图
print("\n--- 自定义排序的SHAP摘要图 ---")
shap.summary_plot(
shap_values_ordered_for_plot,
X_ordered_for_plot,
plot_type="bar",
feature_names=custom_feature_order, # 注意这里传入的是自定义顺序的特征名称
sort=False # 禁用自动排序
)
plt.title("Custom Ordered SHAP Summary Plot")
plt.show()4. 完整示例代码
将上述所有步骤整合到一个可运行的脚本中:
import matplotlib.pyplot as plt
import numpy as np
import shap
import pandas as pd
from tensorflow import keras
from tensorflow.keras import layers
# 示例数据
X = np.array([[(1,2,3,3,1),(3,2,1,3,2),(3,2,2,3,3),(2,2,1,1,2),(2,1,1,1,1)],
[(4,5,6,4,4),(5,6,4,3,2),(5,5,6,1,3),(3,3,3,2,2),(2,3,3,2,1)],
[(7,8,9,4,7),(7,7,6,7,8),(5,8,7,8,8),(6,7,6,7,8),(5,7,6,6,6)],
[(7,8,9,8,6),(6,6,7,8,6),(8,7,8,8,8),(8,6,7,8,7),(8,6,7,8,8)],
[(4,5,6,5,5),(5,5,5,6,4),(6,5,5,5,6),(4,4,3,3,3),(5,5,4,4,5)],
[(4,5,6,5,5),(5,5,5,6,4),(6,5,5,5,6),(4,4,3,3,3),(5,5,4,4,5)],
[(1,2,3,3,1),(3,2,1,3,2),(3,2,2,3,3),(2,2,1,1,2),(2,1,1,1,1)]])
y = np.array([0, 1, 2, 2, 1, 1, 0])
# 构建并训练一个简单的CNN模型
model = keras.Sequential([
layers.Conv1D(128, kernel_size=3, activation='relu', input_shape=(5,5)),
layers.MaxPooling1D(pool_size=2),
layers.LSTM(128, return_sequences=True),
layers.Flatten(),
layers.Dense(128, activation='relu'),
layers.Dense(3, activation='softmax')
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(X, y, epochs=10, verbose=0)
# 解释器和SHAP值计算
explainer = shap.GradientExplainer(model, X)
shap_values = explainer.shap_values(X)
# 原始问题中指定了用于绘图的数据切片
cls = 0
idx = 0
X_for_plot = X[:, idx, :]
shap_values_for_plot = shap_values[cls][:, idx, :]
# 定义原始特征名称
original_feature_names = ["Feature1", "Feature2", "Feature3", "Feature4", "Feature5"]
# --- 默认排序的SHAP摘要图(用于对比)---
print("--- 默认排序的SHAP摘要图 ---")
shap.summary_plot(shap_values_for_plot, X_for_plot, plot_type="bar", feature_names=original_feature_names)
plt.title("Default SHAP Summary Plot (Sorted by Importance)")
plt.show()
# --- 自定义特征排序 ---
# 1. 定义你期望的特征顺序
custom_feature_order = ["Feature3", "Feature5", "Feature1", "Feature4", "Feature2"]
# 确保自定义顺序中的所有特征名称都存在于原始特征名称中
if not all(f in original_feature_names for f in custom_feature_order):
raise ValueError("自定义特征顺序中包含不在原始特征列表中的名称!")
# 2. 将特征数据和SHAP值转换为DataFrame
features_df = pd.DataFrame(X_for_plot, columns=original_feature_names)
shap_df = pd.DataFrame(shap_values_for_plot, columns=original_feature_names)
# 3. 根据自定义顺序重排DataFrame的列
features_df_ordered = features_df[custom_feature_order]
shap_df_ordered = shap_df[custom_feature_order]
# 4. 将重排后的DataFrame转换回NumPy数组
X_ordered_for_plot = features_df_ordered.to_numpy()
shap_values_ordered_for_plot = shap_df_ordered.to_numpy()
# 5. 绘制自定义排序的摘要图
print("\n--- 自定义排序的SHAP摘要图 ---")
shap.summary_plot(
shap_values_ordered_for_plot,
X_ordered_for_plot,
plot_type="bar",
feature_names=custom_feature_order, # 传入自定义顺序的特征名称
sort=False # 禁用自动排序
)
plt.title("Custom Ordered SHAP Summary Plot")
plt.show()5. 注意事项
- feature_names 参数: 确保在调用 shap.summary_plot 时,feature_names 参数传入的列表与你重排后的数据列顺序严格一致。这是图表正确显示特征名称的关键。
- 数据维度匹配: 传入 shap.summary_plot 的 shap_values 和特征数据 (X) 必须具有相同的样本数和特征数。在进行重排操作时,务必保持这种对应关系。
- Pandas的便利性: 使用Pandas DataFrame进行列重排非常方便直观。如果你的数据已经是DataFrame格式,则可以省去 to_numpy() 的转换步骤(尽管 shap.summary_plot 也能接受DataFrame作为输入)。
- plot_type 的选择: summary_plot 支持多种 plot_type,如 "bar" (条形图) 和 "dot" (点图)。自定义排序的方法适用于所有这些类型。
6. 总结
通过灵活运用 shap.summary_plot 的 sort=False 参数,并结合Pandas DataFrame强大的数据操作能力,我们可以轻松地实现SHAP摘要图中特征的自定义排序。这不仅提高了图表的可控性,也使得我们能够根据特定的分析需求或业务背景,更有效地解读模型解释结果,从而增强模型的可解释性和沟通效率。
以上就是《SHAP特征顺序自定义教程详解》的详细内容,更多关于的资料请关注golang学习网公众号!
-
501 收藏
-
501 收藏
-
501 收藏
-
501 收藏
-
501 收藏
-
296 收藏
-
351 收藏
-
157 收藏
-
485 收藏
-
283 收藏
-
349 收藏
-
291 收藏
-
204 收藏
-
401 收藏
-
227 收藏
-
400 收藏
-
327 收藏
-
- 前端进阶之JavaScript设计模式
- 设计模式是开发人员在软件开发过程中面临一般问题时的解决方案,代表了最佳的实践。本课程的主打内容包括JS常见设计模式以及具体应用场景,打造一站式知识长龙服务,适合有JS基础的同学学习。
- 立即学习 543次学习
-
- GO语言核心编程课程
- 本课程采用真实案例,全面具体可落地,从理论到实践,一步一步将GO核心编程技术、编程思想、底层实现融会贯通,使学习者贴近时代脉搏,做IT互联网时代的弄潮儿。
- 立即学习 516次学习
-
- 简单聊聊mysql8与网络通信
- 如有问题加微信:Le-studyg;在课程中,我们将首先介绍MySQL8的新特性,包括性能优化、安全增强、新数据类型等,帮助学生快速熟悉MySQL8的最新功能。接着,我们将深入解析MySQL的网络通信机制,包括协议、连接管理、数据传输等,让
- 立即学习 500次学习
-
- JavaScript正则表达式基础与实战
- 在任何一门编程语言中,正则表达式,都是一项重要的知识,它提供了高效的字符串匹配与捕获机制,可以极大的简化程序设计。
- 立即学习 487次学习
-
- 从零制作响应式网站—Grid布局
- 本系列教程将展示从零制作一个假想的网络科技公司官网,分为导航,轮播,关于我们,成功案例,服务流程,团队介绍,数据部分,公司动态,底部信息等内容区块。网站整体采用CSSGrid布局,支持响应式,有流畅过渡和展现动画。
- 立即学习 485次学习