1. 集合数据处理的“天生难题”与Set Transformer的登场
如果你玩过乐高积木,应该有过这样的体验:把一堆形状各异的积木块倒在地上,无论你怎么打乱它们的顺序,最终用它们拼出来的城堡,样子总是一样的。这堆积木块,就是一个“集合”。在人工智能的世界里,处理这类“集合数据”的需求无处不在,比如分析一组用户评论(谁先谁后不重要,重要的是整体情感)、处理一个点云模型(三维空间中的点没有固定顺序)、或者判断一张图片中多个物体构成的场景类别。
传统的神经网络,比如我们熟悉的循环神经网络(RNN)或者卷积神经网络(CNN),在处理这类数据时,其实有点“水土不服”。RNN天然是为序列设计的,输入顺序一变,输出可能天差地别。CNN虽然对局部平移有一定鲁棒性,但其卷积核滑动的模式也隐含着对空间顺序的假设。这就引出了处理集合数据的核心要求:置换不变性。简单说,就是无论你把输入集合里的元素怎么排列、怎么打乱,模型都应该给出相同的结果。
几年前,一篇名为《Deep Sets》的论文从理论上证明了,一个满足置换不变性的函数,可以通过“对每个元素独立编码,然后对所有编码结果进行求和(或平均)池化”这样的形式来构造。这为处理集合数据提供了一个坚实的理论基础。然而,这个框架中的编码函数(φ)通常比较简单,难以捕捉集合内部元素之间复杂的、高阶的交互关系。这就好比,你只看了每个乐高积木块的颜色,就把它们简单加在一起,却完全忽略了这些积木块之间该如何拼接才能稳固。
这时候,注意力机制,尤其是Transformer中大放异彩的自注意力机制,进入了研究者的视野。自注意力机制能让集合中的每个元素都“看到”其他所有元素,并动态地决定与谁交互、交互多深,这简直是捕捉集合内部结构的完美工具。但是,问题也随之而来:标准自注意力需要对所有元素两两计算关联度,其计算复杂度是输入元素数量n的平方,即O(n²)。当n很大时(比如一个包含成千上万个点的点云),计算开销将变得无法承受。
Set Transformer 正是在这样的背景下诞生的。它的目标非常明确:既要利用自注意力强大的关系建模能力,来超越简单的Deep Sets框架;又要巧妙地改造注意力机制,把那个恼人的O(n²)复杂度给降下来,让它能真正用于处理大规模集合数据。我最初在点云分类任务中尝试它时,就被这种“鱼与熊掌兼得”的设计思路惊艳到了。接下来,我们就一起拆开这个“黑盒”,看看它是如何做到的。
2. 理论基石:从Deep Sets到注意力驱动的置换不变性
要真正理解Set Transformer的巧妙之处,我们得先回到那个理论起点。Deep Sets的工作告诉我们,任何一个对集合X的置换不变函数f(X),理论上都可以被分解成这样的形式:
f(X) = ρ ( Σ_{x∈X} φ(x) )
这个公式看似简单,却蕴含着强大的思想。我们来打个比方:φ(x) 就像是一个“特征提取器”,它独立地观察集合中的每一个元素x(比如一个乐高积木块),并提取出它的特征。然后,Σ(求和)操作是一个“池化器”,它把这些独立的特征全部加起来,抹除了元素之间的顺序信息,生成了一个集合级别的整体特征。最后,ρ 是一个“决策器”,它基于这个整体特征来做出最终的判断(比如这堆积木能拼出什么)。
Set Transformer完全拥抱了这个框架。它把整个模型结构清晰地分成了编码器(Encoder) 和解码器(Decoder) 两部分。编码器负责实现 φ 的功能,但它的内部不再是简单的全连接层,而是由我们后面要讲的、精心设计的注意力模块堆叠而成,目的是为了学习元素间复杂的相互作用。解码器则负责实现 ρ 的功能,它接收编码器输出的、已经融合了交互信息的集合表示,并产生最终的输出。
这里有一个关键点:即使编码器内部使用了注意力这种会让输出随输入顺序变化(即“置换等变”)的机制,但只要最后的池化操作(比如求和或Set Transformer中独特的注意力池化)是置换不变的,那么整个模型就是置换不变的。这为在模型内部使用强大的自注意力机制扫清了理论障碍。
所以,Set Transformer的设计哲学是:在Deep Sets的理论安全屋内,尽情使用注意力机制这把利器来装修(编码),同时想办法解决注意力带来的计算成本过高的问题。 它的所有创新,都围绕着如何高效、优雅地实现这个目标展开。
3. 核心模块拆解:MAB、SAB与效率杀手ISAB
Set Transformer并不是直接把标准Transformer拿过来用,它构建了几个自己的基础积木块。理解这些积木块,是理解其整个架构的关键。
3.1 基础单元:多头注意力块(MAB)
MAB是Set Transformer中最基础的组件,你可以把它看作一个“加强版”的Transformer层。我们回忆一下标准的多头自注意力(Multi-Head Attention, MHA):输入序列经过线性变换得到Q、K、V,然后计算注意力并输出。
MAB在此基础上做了两处重要的“装修”:
- 加入了残差连接(Residual Connection)和层归一化(LayerNorm):这是现代深度

8万+

被折叠的 条评论
为什么被折叠?



