Python实现对比学习异常检测方法
时间:2025-08-07 08:24:50 414浏览 收藏
积累知识,胜过积蓄金银!毕竟在文章开发的过程中,会遇到各种各样的问题,往往都是一些细节知识点还没有掌握好而导致的,因此基础知识点的积累是很重要的。下面本文《Python实现对比学习异常检测方法》,就带大家讲解一下知识点,若是你对本文感兴趣,或者是想搞懂其中某个知识点,就请你继续往下看吧~
对比学习在异常表示学习中的核心在于通过无监督或自监督方式,使模型将正常数据紧密聚集,异常数据远离该流形。1. 数据准备与增强:通过正常数据生成正样本对(同一数据不同增强)与负样本对(其他样本)。2. 模型架构选择:使用编码器(如ResNet、Transformer)提取特征,配合投影头映射到对比空间。3. 对比损失函数设计:采用InfoNCE Loss最大化正样本相似度,最小化负样本相似度。4. 训练策略:使用Adam优化器、余弦退火调度器,大批次训练,或结合MoCo解决负样本不足。5. 异常检测:利用编码器提取表示,结合距离、密度估计或One-Class模型计算异常分数。对比学习的优势在于无需异常标签,但挑战在于数据增强策略与负样本选择对性能影响显著。

基于对比学习实现异常表示学习,核心在于通过无监督或自监督的方式,让模型在学习数据内在结构时,能将正常数据点紧密地聚拢在一起,形成一个紧致的“正常流形”,而异常点则自然地远离这个流形。这本质上是利用了数据本身的相似性信息,来训练一个能区分正常与异常的强大特征提取器。在Python中,这通常涉及深度学习框架、精心设计的数据增强策略以及特定的对比损失函数。

解决方案
在Python中实现基于对比学习的异常表示学习,通常遵循以下步骤:
数据准备与增强:

- 核心理念:假设我们只有大量的正常数据(或至少是正常数据占绝大多数的混合数据)。对比学习的关键在于生成“正样本对”和“负样本对”。
- 正样本对:通常通过对同一个正常数据样本应用两种不同的随机数据增强变换(例如,对于图像是随机裁剪、颜色抖动;对于时间序列是随机抖动、缩放;对于文本是随机删除、替换词)来生成。这确保了即使经过变换,它们依然代表了“同一件事物”的不同视角。
- 负样本对:可以是同一批次中其他随机选择的样本,或者从一个动态更新的内存队列(如MoCo)中获取。
- Python实现:使用
torchvision.transforms(图像),tsaug(时间序列),或自定义函数进行数据增强。DataLoader用于批处理。
模型架构选择:
- 编码器(Encoder):一个深度神经网络,其任务是将原始数据映射到一个低维的、信息丰富的表示(嵌入向量)。
- 常见选择:
- 图像:ResNet、Vision Transformer (ViT)。
- 时间序列:CNN、RNN(LSTM/GRU)、Transformer。
- 表格数据:MLP。
- 投影头(Projection Head):在编码器之后,通常会添加一个小的MLP层(如2-3层),将编码器输出的表示进一步映射到一个用于计算对比损失的空间。在推理时,我们通常使用编码器输出的表示(不包括投影头)来进行异常检测。
- Python实现:使用
torch.nn构建模型,或利用timm等库加载预训练模型(并根据需要修改)。
对比损失函数设计:

- 核心:InfoNCE Loss(也称为NT-Xent Loss,Normalized Temperature-scaled Cross Entropy Loss)是目前最流行的选择。
- 原理:它旨在最大化正样本对之间的相似度,同时最小化正样本与负样本之间的相似度。
- 数学形式:对于一个批次中的每个锚点
x_i,其正样本x_j,以及2N-2个负样本,损失函数会计算x_i与x_j相似度相对于x_i与所有其他样本相似度的对数比。 - 温度参数(Temperature Parameter,
tau):一个关键的超参数,它控制了相似度分布的平滑程度。较小的tau会使模型更关注区分最相似的负样本。 - Python实现:手动实现InfoNCE损失,或者使用
Pytorch Metric Learning等库中提供的现成实现。
训练策略:
- 优化器:Adam、SGD等。
- 学习率调度器:余弦退火(Cosine Annealing)等。
- 批次大小:对比学习通常需要较大的批次大小来提供足够多的负样本。如果硬件受限,可以考虑MoCo(Momentum Contrast)等策略,它使用一个动态更新的队列来存储负样本。
- 训练循环:标准深度学习训练循环,迭代数据批次,计算损失,反向传播,更新模型参数。
- Python实现:标准的PyTorch或TensorFlow训练脚本。
异常分数计算与检测:
- 推理阶段:训练完成后,我们只使用编码器(不包括投影头)来获取数据点的表示。
- 异常分数:
- 距离到质心:计算所有正常训练样本表示的质心,然后计算新样本表示到该质心的欧氏距离或余弦距离。距离越大,异常分数越高。
- K近邻距离:计算新样本表示到其K个最近的正常训练样本表示的平均距离。
- 密度估计:在嵌入空间中对正常样本进行密度估计(如使用高斯混合模型GMM或核密度估计KDE),异常点将位于低密度区域。
- One-Class SVM/Isolation Forest:在学到的嵌入空间上训练一个One-Class SVM或Isolation Forest模型。
- 阈值:根据异常分数的分布(通常是正态分布或偏态分布),设置一个阈值来区分正常和异常。
- Python实现:
scikit-learn库中的NearestNeighbors、OneClassSVM、IsolationForest等。
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np
from sklearn.metrics import roc_auc_score
from sklearn.ensemble import IsolationForest
# 假设我们处理的是图像数据,所以用torchvision
# 1. 模拟数据(实际应用中会加载真实数据集)
class SimpleImageDataset(Dataset):
def __init__(self, num_samples=1000, img_size=32, is_anomaly=False):
self.num_samples = num_samples
self.img_size = img_size
# 模拟正常数据:中心是0.5的噪声
self.data = torch.randn(num_samples, 3, img_size, img_size) * 0.1 + 0.5
if is_anomaly:
# 模拟异常数据:偏离中心,例如非常亮或非常暗
self.data = torch.randn(num_samples, 3, img_size, img_size) * 0.2 + (0 if np.random.rand() > 0.5 else 1.0)
# 简单的数据增强,用于对比学习
self.transform = transforms.Compose([
transforms.RandomResizedCrop(img_size, scale=(0.8, 1.0)),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
transforms.ToTensor(), # 已经转换为tensor了,这里只是为了兼容Compose
])
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
# 对于对比学习,每个样本需要生成两个增强视图
img = self.data[idx]
return self.transform(img), self.transform(img)
# 2. 模型架构:简单的CNN编码器 + 投影头
class Encoder(nn.Module):
def __init__(self, in_channels=3, hidden_dim=128):
super(Encoder, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(64, hidden_dim, kernel_size=3, padding=1),
nn.ReLU(),
nn.AdaptiveAvgPool2d((1, 1)) # 将特征图池化到1x1
)
self.flatten = nn.Flatten()
def forward(self, x):
x = self.features(x)
x = self.flatten(x)
return x
class ProjectionHead(nn.Module):
def __init__(self, input_dim, output_dim=128):
super(ProjectionHead, self).__init__()
self.net = nn.Sequential(
nn.Linear(input_dim, input_dim),
nn.ReLU(),
nn.Linear(input_dim, output_dim)
)
def forward(self, x):
return self.net(x)
# 3. InfoNCE Loss实现
class InfoNCELoss(nn.Module):
def __init__(self, temperature=0.07):
super(InfoNCELoss, self).__init__()
self.temperature = temperature
self.criterion = nn.CrossEntropyLoss()
def forward(self, features):
# features: [2*batch_size, feature_dim]
# 前半部分是view1,后半部分是view2
batch_size = features.shape[0] // 2
# 归一化特征
features = nn.functional.normalize(features, dim=1)
# 计算余弦相似度矩阵
# similarities: [2*batch_size, 2*batch_size]
similarities = torch.matmul(features, features.T) / self.temperature
# 构造标签:对角线是正样本对
# 比如 batch_size=2:
# view1_0, view1_1, view2_0, view2_1
# target for view1_0 is view2_0 (index 2)
# target for view2_0 is view1_0 (index 0)
# target for view1_1 is view2_1 (index 3)
# target for view2_1 is view1_1 (index 1)
labels = torch.arange(2 * batch_size).roll(shifts=batch_size, dims=0)
# 移除自相似性 (将对角线设为负无穷,避免自己和自己比较)
# 实际操作中,InfoNCE的labels通常是0, 1, ..., N-1,对应于正样本在相似度矩阵中的位置
# 这里为了简化,我们直接用交叉熵,把所有非正样本都看作负样本
# 构造正样本对的索引
# (0, batch_size), (1, batch_size+1), ..., (batch_size-1, 2*batch_size-1)
# 以及反向的
# (batch_size, 0), (batch_size+1, 1), ..., (2*batch_size-1, batch_size-1)
# 确保正样本对的索引是正确的
# 假设 f_i 是 view1 的第 i 个样本,f_j 是 view2 的第 j 个样本
# 正样本对是 (f_i, f_i')
# 我们的 features 结构是 [v1_0, v1_1, ..., v1_N-1, v2_0, v2_1, ..., v2_N-1]
# 那么 (v1_i, v2_i) 是正样本对
# 它们的索引是 (i, i+batch_size)
# 创建一个掩码,将正样本对的相似度设为0,避免被softmax影响
mask = torch.eye(2 * batch_size, dtype=torch.bool).to(features.device)
similarities = similarities.masked_fill(mask, float('-inf'))
# 计算交叉熵损失
# 目标是让每个样本的增强视图与自身对应的增强视图相似度最高
# 假设 batch_size=2
# features = [v1_0, v1_1, v2_0, v2_1]
# similarities[0] (v1_0与其他) -> 期望 v2_0 (idx 2) 相似度最高
# similarities[1] (v1_1与其他) -> 期望 v2_1 (idx 3) 相似度最高
# similarities[2] (v2_0与其他) -> 期望 v1_0 (idx 0) 相似度最高
# similarities[3] (v2_1与其他) -> 期望 v1_1 (idx 1) 相似度最高
# 目标索引
labels = torch.cat([torch.arange(batch_size, 2 * batch_size),
torch.arange(0, batch_size)], dim=0).to(features.device)
loss = self.criterion(similarities, labels)
return loss
# 4. 训练过程
def train_contrastive_model(encoder, projection_head, dataloader, epochs=50, lr=1e-3):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder.to(device)
projection_head.to(device)
optimizer = optim.Adam(list(encoder.parameters()) + list(projection_head.parameters()), lr=lr)
criterion = InfoNCELoss().to(device)
print("开始训练对比学习模型...")
for epoch in range(epochs):
total_loss = 0
for (img1, img2) in dataloader:
img1, img2 = img1.to(device), img2.to(device)
optimizer.zero_grad()
# 获取特征
feat1 = encoder(img1)
feat2 = encoder(img2)
# 通过投影头
proj1 = projection_head(feat1)
proj2 = projection_head(feat2)
# 合并特征用于计算损失
features = torch.cat([proj1, proj2], dim=0)
loss = criterion(features)
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss / len(dataloader):.4f}")
print("训练完成。")
# 5. 异常分数计算与评估
def evaluate_anomaly_detection(encoder, normal_dataloader, anomaly_dataloader):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder.eval() # 评估模式
normal_embeddings = []
with torch.no_grad():
for (img1, _) in normal_dataloader: # 只取一个增强视图
img1 = img1.to(device)
embedding = encoder(img1).cpu().numpy()
normal_embeddings.append(embedding)
normal_embeddings = np.concatenate(normal_embeddings, axis=0)
anomaly_embeddings = []
with torch.no_grad():
for (img1, _) in anomaly_dataloader:
img1 = img1.to(device)
embedding = encoder(img1).cpu().numpy()
anomaly_embeddings.append(embedding)
anomaly_embeddings = np.concatenate(anomaly_embeddings, axis=0)
# 简单地使用Isolation Forest在学习到的嵌入空间上进行异常检测
# 这是一个常见的后处理步骤,用于从表示中提取异常
print("在学习到的嵌入空间上训练Isolation Forest...")
clf = IsolationForest(random_state=42, contamination=0.01) # contamination是一个估计值
clf.fit(normal_embeddings)
normal_scores = clf.decision_function(normal_embeddings)
anomaly_scores = clf.decision_function(anomaly_embeddings)
# 标签:正常为1,异常为-1 (Isolation Forest的输出) 或 0/1
# 为了计算AUC,我们通常将正常标记为0,异常标记为1
y_true = np.concatenate([np.zeros(len(normal_scores)), np.ones(len(anomaly_scores))])
# Isolation Forest的decision_function输出越大越正常,所以我们需要取负数或者1-score
y_scores = np.concatenate([-normal_scores, -anomaly_scores])
auc_roc = roc_auc_score(y_true, y_scores)
print(f"AUC-ROC Score: {auc_roc:.4f}")
return auc_roc
# 主运行逻辑
if __name__ == "__main__":
BATCH_SIZE = 64
IMAGE_SIZE = 32
EMBEDDING_DIM = 128 # 编码器输出维度
PROJECTION_DIM = 128 # 投影头输出维度
# 准备数据集
normal_dataset = SimpleImageDataset(num_samples=2000, img_size=IMAGE_SIZE)
normal_dataloader = DataLoader(normal_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
# 用于评估的测试集(包含正常和异常)
test_normal_dataset = SimpleImageDataset(num_samples=500, img_size=IMAGE_SIZE)
test_anomaly_dataset = SimpleImageDataset(num_samples=100, img_size=IMAGE_SIZE, is_anomaly=True)
test_normal_dataloader = DataLoader(test_normal_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_anomaly_dataloader = DataLoader(test_anomaly_dataset, batch_size=BATCH_SIZE, shuffle=False)
# 初始化模型
encoder = Encoder(in_channels=3, hidden_dim=EMBEDDING_DIM)
projection_head = ProjectionHead(input_dim=EMBEDDING_DIM, output_dim=PROJECTION_DIM)
# 训练模型
train_contrastive_model(encoder, projection_head, normal_dataloader, epochs=50)
# 评估模型
evaluate_anomaly_detection(encoder, test_normal_dataloader, test_anomaly_dataloader)
对比学习在异常检测中的独特优势与挑战
说实话,我个人觉得对比学习在处理异常检测问题时,简直是找到了一个非常巧妙的突破口。传统方法经常苦恼于异常样本的稀缺性,或者干脆没有标签,这让监督学习无从下手。但对比学习不一样,它把“正常”这个概念掰开了揉碎了去学,通过让模型理解什么是“相似”,什么是“不相似”,从而间接定义了什么是“正常”。
它的优势非常明显:
- **无需异常标签
理论要掌握,实操不能落!以上关于《Python实现对比学习异常检测方法》的详细介绍,大家都掌握了吧!如果想要继续提升自己的能力,那么就来关注golang学习网公众号吧!
-
501 收藏
-
501 收藏
-
501 收藏
-
501 收藏
-
501 收藏
-
165 收藏
-
449 收藏
-
216 收藏
-
325 收藏
-
300 收藏
-
337 收藏
-
385 收藏
-
165 收藏
-
254 收藏
-
427 收藏
-
149 收藏
-
190 收藏
-
- 前端进阶之JavaScript设计模式
- 设计模式是开发人员在软件开发过程中面临一般问题时的解决方案,代表了最佳的实践。本课程的主打内容包括JS常见设计模式以及具体应用场景,打造一站式知识长龙服务,适合有JS基础的同学学习。
- 立即学习 543次学习
-
- GO语言核心编程课程
- 本课程采用真实案例,全面具体可落地,从理论到实践,一步一步将GO核心编程技术、编程思想、底层实现融会贯通,使学习者贴近时代脉搏,做IT互联网时代的弄潮儿。
- 立即学习 516次学习
-
- 简单聊聊mysql8与网络通信
- 如有问题加微信:Le-studyg;在课程中,我们将首先介绍MySQL8的新特性,包括性能优化、安全增强、新数据类型等,帮助学生快速熟悉MySQL8的最新功能。接着,我们将深入解析MySQL的网络通信机制,包括协议、连接管理、数据传输等,让
- 立即学习 500次学习
-
- JavaScript正则表达式基础与实战
- 在任何一门编程语言中,正则表达式,都是一项重要的知识,它提供了高效的字符串匹配与捕获机制,可以极大的简化程序设计。
- 立即学习 487次学习
-
- 从零制作响应式网站—Grid布局
- 本系列教程将展示从零制作一个假想的网络科技公司官网,分为导航,轮播,关于我们,成功案例,服务流程,团队介绍,数据部分,公司动态,底部信息等内容区块。网站整体采用CSSGrid布局,支持响应式,有流畅过渡和展现动画。
- 立即学习 485次学习