MoE训练中的Top-K运算不会导致不可导(不连续)吗?
之前在了解混合专家的语言模型,比如Switch Transformers。它们的原理是用混合专家层替换掉全连接层。 混合专家层包含多个专家网络,和一个门控网络。门控网络会根据输入内容给每个专家打分,随后输入被传送给前k个高分的专家,得到的结果会被加权平均。 然而,“选择前k个专家”的操作是不连续的,这种情况下为什么能正常使用梯度下降?对于同一对输入输出(x,y),如果模型参数发生了微小的更新变化,就可能导致某一层选择了不同的专家,使得输出和损失值发生很大的不连续的变化。 这会使得梯度下降时,算出来的梯度不能指示附近最好的改进方向。 按照这种常见的梯度下降的可视化图,如果把它的函数换成一个不连续的函数,比如某种分段函数,那么梯度下降可能就不会收敛,或者会在函数的两段之间反复横跳 彭罗斯楼梯。这种情况下模型参数不会收敛到某个值,只会在某个轨道里循环。def f(x, y): return torch.complex(x, y).angle() * 0.3 z=max(x,y)虽然在x=y处不可导,但是作为分段函数,它的两段都是可导的(是z=x和z=y的两个平面)。不可导的地方也可以用次梯度解决。 max(x,y)虽然可导性有些问题,但连续性没有问题。max(x,y)对在R^2的所有位置都是连续的。输入值(x,y)略微变化的时候,输出值z不会发生突变,只有梯度会突变。函数图像只有“折痕”,没有“裂缝”或者说“垂直的悬崖”。 这一点和MoE不同,MoE的输出值真的会发生突变,并且出现上面几张图中的问题。 从图中也能看到梯度下降能正确运作。虽然模型参数(x,y)可能在函数的两段(折痕)之间反复切换,但是实际前进的方向相当于两种梯度的一种平均。损失仍然可以变小。 深度学习中的MaxPool和ReLU本质上也只是max函数,所以它们也都是连续的,不会出现MoE的问题。 全连接神经网络的例子。在折痕两边切换不会导致问题。