Leonurus-free

MLA

关于MLA,我想先简单记录下我了解它的心路历程:

我个人认为,理解MLA的难点之一,是它算法设计颇为“绕”,不管是用数学公式,或者是用流程图,似乎都很难一下找到它设计的核心思想。所以本文第一部分,将会抛开所有复杂的计算细节,根据我自己的理解,抽象出MLA的设计方案。基于此再来谈计算细节和各种优化,全文目录如下:

一、MLA的基本思想 1.1 MLA, MQA 和 GQA 1.2 MLA的整体设计思想

二、MLA的运作细节 2.1 CD(原生MLA) 2.2 CC 2.3 A_CC 2.4 A_CC_ME

三、MLA可以用TP吗?

一、MLA的基本思想

1.1 MLA,MQA与GQA

我们先来快速复习一下decoder架构的MHA的运作流程,如下图:

图片

这里head_num = 4,图中刻画了head0的运算过程,包括 attn_weights = Matmul(q, k) 以及 attn_sv = Matmul(attn_weights, v),图中虚线灰框表示在head0上的结果是由包括其在内的若干前置tokens计算结果汇总而来。为了表达简便,这里省去了诸如softmax,的计算过程。图中被红色虚线框圈起来的部分,就是大家熟知的将被用在推理阶段的KV cache

KV cache的存在,本来是为了避免在推理阶段对前置序列的重复计算的。但是,随着前置序列的长度变长(我们记为kv_len),需要读取的KV cache也将越来越大,数据的传输成本增加,这就使得attn计算逐渐变成memory bound我们采取了一些策略来缓解KV cache过大的问题,其中2种就是大家熟知的MQA和GQA

MQA和GQA的运作方式如下:

图片

但是,不管是MQA还是GQA,对于1个token来说,总是存在heads上k、v信息被压缩的情况。那么是否有一种办法,能在尽量不压缩head上k,v信息的情况下,节省kv cache,提高整体推理速度呢?那么接下来,我们就来大致看一下MLA的设计思想。

1.2 MLA的整体设计思想

在本节中,我们会以K cache为例,抽象出MLA的核心优化思想。V cache的优化思想也是同理,但不在这节赘述,而是合并到后文对于MLA的细节讲解中(参见2.3节A_CC)。

现在先让我们回到MHA上(图1.1),来思考一个问题:为什么对于一个token,我们需要保存它所有heads上的K值作为K cache呢?

图片

主要原因我们在上文解释过:这是因为每个k_head附带有不同的信息,它将用这份独有的信息和对应的q_head进行attn的计算,用公式表示即为,这里的是合并了所有head对应的param weight后的表达。

我们现在的总目标是节省K cache,当你再次端详上面这幅图时,一个idea在你的头脑中出现:

(在这里,你可以抽象地把理解成是4个k_heads共享的信息,但最终K cache的形式还会在这基础上有所变化。我知道此时你脑海中一定有很多疑惑。但我们先不要纠结细节的问题,因为在后文会展示全部细节,这里我们要做的是从宏观上理解MLA设计的核心思想。)

现在我们更具体地画出上面这套“信息转移”方案的具体流程:

图片

⚠️⚠️⚠️:再次说明,在本部分,我们侧重于抽象出MLA的优化思路,大家在阅读上面这幅图时,请不要带入任何具体的细节(例如矩阵尺寸)等去做计算,这部分细节我们会在下文详细介绍。

我们来详细看这幅图:

对v cache的优化也是同理,这里额外提几点:

好,到此为止,我们抽象出了MLA的整体优化思路,从中你可以发现:

二、MLA的运作流程

2.1 CD (CacheDecompressed, dpsk MLA的原生实现)

现在我们可以来看MLA的运作细节了。

图片

好,现在关于这个MLA的原生实现,我们来讨论几个有意思的点:

(1)在MLA中,每个head_dim的尺寸更大了。观察到原始hidden_size = 5120,如果按照num_heads = 128来看的话,正常来说一个head_dim = 40 (5120/128=40)。但是在MLA中,一个head_dim = 128,远大于40。也就说MLA其实是用比一般MHA更大的head_dim(或者也可能是num_heads)来做attn计算的,然后在最终的

矩阵中映射回原来的hidden_size。对此我个人给出一些简单猜测:如果推理阶段KV cache造成的memory bound的问题已经得到解决的话,那么训练时我就能少一点后顾之忧,然后通过提升模型的复杂度来取得与MHA比肩或更好的效果(训练阶段还有别的优化方式)。这样当我回到推理阶段时,我的整体计算强度就上去了(每读1次,算的次数更多了)只要没有达到compute bound的界限,这样的提升就是有好处的。

(2)原生MLA的计算最终展开成了MHA的计算。这一点可以参见图中q(蓝色),k(绿色),v(黄色),它们最终都变成了标准MHA的计算。从理论上来说,这一点也不奇怪,因为我们在第一部分说过MLA就是MHA的变种,只是它在MHA的基础上做了信息从k/v_head向q_head的转移。嗯?!!但是等等,从上图这个原生MLA上来看,虽然产出了compress_kv,但是好像并没有做什么信息转移呀,也就是粗糙来看目前的计算流程还是而不是转移后的 呀:

基于这些,我们管原生MLA的实现方式为CD(CacheDecompressed),即存储的KV cache是没有经过任何压缩的。为什么dpsk放出来的原生MLA会这样呢?这一点我一直没有想通,这也是为什么我在MLA刚出来那阵,看完它的实践就决定先暂停探索的原因。当时没有实际的业务需求,自己也没动力去细想,以及考虑到MLA算法的复杂性,我还以为是我理解错了。但是随着时间推移,后续开源社区有一系列对MLA的优化实现,直到近期再次捡起来后,才使我对MLA有了更多的了解。目前来看,这个原生MLA似乎以提供“MLA的概念”为主,而具体的优化实践方式还是要看个人。我们马上就来看后一些做过“信息转移/吸收”的优化方法,不过在此之前,我们先对原生MLA的计算量和KV cache做一个分析。

(公众号编辑表格太难了,这里我直接从我笔记截图了,大家可以点开放大看)图片

我们对这张表格做一些说明:

2.2 CC (CacheCompressed)

好,在进入大家从第一部分开始就心心念念的“k/v_head信息向q转移(或者理解成被q吸收)”这个优化介绍前,我们先介绍基于原生实践和这个优化的一个中间态:CC (CacheCompressed)在这个中间态中,我们终于是以compress_kv为kv cache了,但是我们没做任何吸收。之所以要介绍这个中间态,是方便大家更好感受“吸收”的好处。

我们直接对着2.1的图,列出CC表格:

图片

不难发现,在这个中间态CC优化的MLA下:

2.3 A_CC(AbsorbCacheCompressed)

现在,终于来到我们心心念念的涉及吸收的优化了:

图片

这里解释下为什么A_CC相比于CC,总计算量降低了很多,但单token计算量却没有变化:

2.4 A_CC_ME

最后,这个优化其实就是在A_CC的基础上,在计算attn_weights的时候,把nope和rope的部分拆开算,然后再求和。这样做是为了避开无用的数据拷贝和广播(可以看代码,你会发现A_CC为了做数据拼接,是先初始化一个一个拼接好的空张量,再往里塞数据,这样就是2倍的显存开销。而避开拼接各自算各自的,可以直接复用已有的数据),实际测起来这种方法性能是最好的。

图片

三、MLA可以用TP吗

现在,回来看一个经常被讨论的问题:MLA可以做TP吗?因为看样子,对于每一个token来说,它所有的num_heads上的kv信息已经被压缩成compress_kv了,好像是不能再切分了?

这里先说结论:MLA可以做TP,但是它可能需要一些定制化的TP方式,而不是直接套用惯常decoder模型的TP方式。

为了解答这个问题,我们这里再贴出2.1中的流程图:

图片

我们着重关注流程图中红色部分(也就是param_weights),大家回想一下之前的介绍:尽管compress_kv已经被抽取成类似单头的形式了(1个token只有1个,且不区分heads),但是它能这样做的原因是因为kv_heads上的信息转移去q_heads了,对了!q还是有heads的呀!!

我们首先来看一下,dpsk官方是如何在上面这张流程图中做TP切分的,详细代码可以参见这里:https://github.com/deepseek-ai/DeepSeek-V3/blob/ee4c4ea32bfd89e197616b80e713466954c51c75/inference/model.py#L409,从图里来说: