您现在的位置是:首页 > 技术交流 > 大模型推理——MLA实现方案 | 网站首页 | 技术交流

大模型推理——MLA实现方案

凯尔哥 2025-08-17 00:01:03
简介:大模型推理——MLA实现方案

1.整体流程

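先上一张图来整体理解下MLA的计算过程(原图未随文保留,流程概括如下):输入x先经wq_a降维、RMSNorm归一化、wq_b升维得到q,并切分为q_nope和q_pe两部分;同时x经wkv_a得到压缩后的kv和k_pe;对q_pe、k_pe施加旋转位置编码;q_nope先吸收wkv_b中对应k的部分,再与kv计算得分,并与旋转位置编码部分的得分相加、缩放后做softmax;最后依次乘以kv和wkv_b中对应v的部分,经wo输出。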

2.实现代码

import math
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """Root-mean-square layer normalisation with a learnable per-feature gain.

    Computes ``weight * x * rsqrt(mean(x ** 2, dim=-1) + eps)`` over the last
    dimension. The statistics are evaluated in float32 for numerical
    stability, and the result is cast back to the dtype of the input so the
    layer can be placed between half-precision linear layers.

    Args:
        hidden_size: Size of the last (normalised) dimension.
        eps: Small constant added to the mean square to avoid division by zero.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # Learnable gain g of RMSNorm, initialised to ones.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        # Keeps the denominator away from zero.
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalise ``hidden_states`` along its last dimension.

        Args:
            hidden_states: Tensor whose last dimension equals ``hidden_size``.

        Returns:
            Tensor of the same shape. Floating-point inputs keep their dtype;
            any other input dtype yields a float32 result.
        """
        input_dtype = hidden_states.dtype
        restore_dtype = hidden_states.is_floating_point()

        # Do the reduction in float32 regardless of the input precision.
        hidden_states = hidden_states.float()
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        output = self.weight * hidden_states

        # Casting back only for floating inputs avoids truncating the result
        # to an integer dtype.
        return output.to(input_dtype) if restore_dtype else output


def rotate_half(x):
    """Rotate the two halves of the last dimension: ``[x1, x2] -> [-x2, x1]``.

    This is the companion operation of rotary position embedding: combined
    with ``x * cos + rotate_half(x) * sin`` it applies a 2-D rotation to each
    pair ``(x1[i], x2[i])``.

    Args:
        x: Tensor whose last dimension is split into two chunks.

    Returns:
        Tensor of the same shape as ``x`` with the second half negated and
        moved in front of the first half.
    """
    x1, x2 = x.chunk(2, dim=-1)
    # Concatenating (x1, x2) unchanged would make this the identity and
    # silently disable the rotation.
    return torch.cat((-x2, x1), dim=-1)


def apply_rotate_pos_emb(q, k, cos, sin, unsqueeze_dim=2):
    """Apply rotary position embedding to a query and a key tensor.

    Args:
        q: Query tensor.
        k: Key tensor.
        cos: Cosine table; a singleton axis is inserted at ``unsqueeze_dim``
            so it can broadcast against ``q`` and ``k``.
        sin: Sine table, treated the same way as ``cos``.
        unsqueeze_dim: Position at which the singleton axis is inserted into
            ``cos`` and ``sin``.

    Returns:
        Tuple ``(q_embed, k_embed)`` of the rotated query and key.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    # The rotated half must be weighted by sin for the key as well, exactly
    # as for the query; weighting it by cos does not produce a rotation.
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class RotaryEmbedding(nn.Module):
    """Precomputed cosine/sine tables for rotary position embedding.

    The tables are built once in the constructor for positions
    ``0 .. max_seq_len - 1`` and stored as the buffers ``cos_cached`` and
    ``sin_cached``, each of shape ``[max_seq_len, 2 * len(arange(0, dim, 2))]``.

    Args:
        dim: Rotary dimension; the inverse frequencies are
            ``1 / 10000 ** (i / dim)`` for ``i = 0, 2, 4, ...`` below ``dim``.
        max_seq_len: Number of positions for which the tables are built.
    """

    def __init__(self, dim, max_seq_len=1024):
        super().__init__()
        self.dim, self.max_seq_len = dim, max_seq_len

        exponents = torch.arange(0, dim, 2).float() / dim
        inv_freq = 1.0 / (10000 ** exponents)
        positions = torch.arange(max_seq_len).float()

        # Outer product position x frequency, written as a matrix product of
        # a column vector with a row vector.
        angles = positions[:, None] @ inv_freq[None, :]
        # The frequencies are repeated so the table spans both halves of the
        # rotary dimension.
        table = torch.cat((angles, angles), dim=-1)

        for buffer_name, values in (("cos_cached", table.cos()),
                                    ("sin_cached", table.sin())):
            self.register_buffer(buffer_name, values)

    def forward(self, q, k):
        """Rotate ``q`` and ``k`` using the first ``q.shape[1]`` table rows.

        Args:
            q: Query tensor; the size of its dimension 1 selects how many
                positions are taken from the tables.
            k: Key tensor, rotated with the same table rows.

        Returns:
            The ``(q, k)`` pair returned by ``apply_rotate_pos_emb``.
        """
        n_positions = q.shape[1]
        # A leading axis of size one is added so the tables broadcast over
        # the first dimension of q and k.
        cos = self.cos_cached[:n_positions, :].unsqueeze(0)
        sin = self.sin_cached[:n_positions, :].unsqueeze(0)
        return apply_rotate_pos_emb(q, k, cos, sin)


class MLA(nn.Module):
    """Multi-head Latent Attention using the weight-absorption formulation.

    Queries go through a low-rank bottleneck (``wq_a`` -> RMSNorm -> ``wq_b``)
    and are split per head into a non-rotary and a rotary part. Keys and
    values share one compressed latent of width ``kv_lora_rank`` produced by
    ``wkv_a``, plus a single rotary key of width ``qk_rope_head_dim`` that is
    shared by all heads. Instead of expanding the latent into per-head keys
    and values, the key part of ``wkv_b`` is folded into the query and the
    value part is applied after the attention-weighted sum, so attention is
    computed directly against the latent.

    Args:
        dim: Width of the input and output hidden states.
        n_heads: Number of attention heads.
        q_lora_rank: Width of the low-rank query bottleneck.
        kv_lora_rank: Width of the compressed key/value latent.
        qk_nope_head_dim: Per-head query/key width without rotary embedding.
        qk_rope_head_dim: Per-head query/key width with rotary embedding.
        v_head_dim: Per-head value width.
        max_seq_len: Number of positions held by the caches and the rotary
            tables.
        max_batch_size: Number of batch rows held by the caches.

    NOTE(review): ``wkv_b`` is created with ``nn.Linear``'s default bias, but
    ``forward`` only ever reads ``wkv_b.weight``, so that bias is never used.
    """

    def __init__(self,
                 dim,
                 n_heads,
                 q_lora_rank,
                 kv_lora_rank,
                 qk_nope_head_dim,
                 qk_rope_head_dim,
                 v_head_dim,
                 max_seq_len,
                 max_batch_size):
        super().__init__()
        # Hidden size of the model.
        self.dim = dim
        # Number of attention heads.
        self.n_heads = n_heads
        # Width the query is compressed to.
        self.q_lora_rank = q_lora_rank
        # Width the shared key/value latent is compressed to.
        self.kv_lora_rank = kv_lora_rank
        # Per-head q/k width that does not receive rotary embedding.
        self.qk_nope_head_dim = qk_nope_head_dim
        # Per-head q/k width that receives rotary embedding.
        self.qk_rope_head_dim = qk_rope_head_dim
        # Total per-head q/k width.
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        # Per-head value width.
        self.v_head_dim = v_head_dim
        self.max_seq_len = max_seq_len
        self.max_batch_size = max_batch_size

        self.wq_a = nn.Linear(self.dim, self.q_lora_rank)
        self.q_norm = RMSNorm(self.q_lora_rank)
        self.wq_b = nn.Linear(self.q_lora_rank, self.n_heads * self.qk_head_dim)
        self.wkv_a = nn.Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
        self.kv_norm = RMSNorm(self.kv_lora_rank)
        self.wkv_b = nn.Linear(self.kv_lora_rank,
                               self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
        self.wo = nn.Linear(self.n_heads * self.v_head_dim, self.dim)

        # Size the rotary tables from max_seq_len so they always cover every
        # position the caches can hold, instead of relying on the default.
        self.rotary_emb = RotaryEmbedding(self.qk_rope_head_dim,
                                          max_seq_len=self.max_seq_len)

        self.register_buffer(
            "kv_cache",
            torch.zeros(self.max_batch_size, self.max_seq_len, self.kv_lora_rank))
        self.register_buffer(
            "pe_cache",
            torch.zeros(self.max_batch_size, self.max_seq_len, self.qk_rope_head_dim))

    def forward(self, x, mask=None):
        """Run attention over ``x`` and refresh the latent caches.

        Args:
            x: Input of shape ``[bs, seq_len, dim]``.
            mask: Optional additive attention mask (added to the scores before
                the softmax, e.g. ``0`` to keep and ``-inf`` to drop). A 2-D
                mask is read as ``[seq_len, seq_len]`` (query, key) and a 3-D
                mask as ``[bs, seq_len, seq_len]``; both are broadcast over
                the heads. A mask of any other rank is added as given and must
                broadcast against ``[bs, seq_len, n_heads, seq_len]``.

        Returns:
            Tensor of shape ``[bs, seq_len, dim]``.
        """
        bs, seq_len, _ = x.shape

        # [bs, seq_len, q_lora_rank] -> normalised -> [bs, seq_len, n_heads * qk_head_dim]
        q = self.wq_b(self.q_norm(self.wq_a(x)))
        # [bs, seq_len, n_heads, qk_head_dim]
        q = q.view(bs, seq_len, self.n_heads, self.qk_head_dim)
        # Split the last dimension into
        #   q_nope: [bs, seq_len, n_heads, qk_nope_head_dim]
        #   q_pe:   [bs, seq_len, n_heads, qk_rope_head_dim]
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # [bs, seq_len, kv_lora_rank + qk_rope_head_dim]
        kv = self.wkv_a(x)
        # Split the last dimension into
        #   kv:   [bs, seq_len, kv_lora_rank]
        #   k_pe: [bs, seq_len, qk_rope_head_dim]
        kv, k_pe = torch.split(
            kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)

        # Give k_pe a singleton head axis so it has the same rank as q_pe,
        # apply the rotary embedding, then drop the axis again.
        q_pe, k_pe = self.rotary_emb(q_pe, k_pe.unsqueeze(2))
        k_pe = k_pe.squeeze(2)
        kv = self.kv_norm(kv)

        # Cache the shared latent (later expanded into both k and v) and the
        # rotary key. Detached copies are stored so the buffers never become
        # part of an autograd graph; the attention below uses the live
        # tensors, which hold the same values, so gradients still flow.
        self.kv_cache[:bs, :seq_len, :] = kv.detach()
        self.pe_cache[:bs, :seq_len, :] = k_pe.detach()

        # [n_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
        #   -> [n_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank]
        wkv_b = self.wkv_b.weight.view(self.n_heads, -1, self.kv_lora_rank)

        # ------------------------- core of MLA -------------------------
        # Fold the key half of wkv_b into the query (roughly x @ w_q @ w_k).
        # Result: [bs, seq_len, n_heads, kv_lora_rank].
        q_nope = torch.einsum(
            "bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
        # Similarity of the non-rotary part against the compressed latent:
        # [bs, seq_len, n_heads, seq_len].
        scores_nope = torch.einsum("bshc,btc->bsht", q_nope, kv)
        # Similarity of the rotary part: [bs, seq_len, n_heads, seq_len].
        scores_pe = torch.einsum("bshr,btr->bsht", q_pe, k_pe)
        # ---------------------------------------------------------------

        # Sum both contributions, then scale by sqrt of the full q/k width.
        scores = (scores_nope + scores_pe) / math.sqrt(self.qk_head_dim)
        if mask is not None:
            # Scores are laid out [bs, query, head, key], so the head axis has
            # to be inserted at position 2 of a batched mask.
            if mask.dim() == 2:
                mask = mask[None, :, None, :]
            elif mask.dim() == 3:
                mask = mask.unsqueeze(2)
            scores = scores + mask

        scores = scores.softmax(dim=-1)

        # Weighted sum over the latent first: [bs, seq_len, n_heads, kv_lora_rank],
        # then expand with the value half of wkv_b:
        # [bs, seq_len, n_heads, v_head_dim].
        x = torch.einsum("bsht,btc->bshc", scores, kv)
        x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])

        # Merge the heads and project back to the model width.
        x = x.contiguous().view(bs, seq_len, -1)
        return self.wo(x)


def main():
    """Run one MLA forward pass on random input and print the results.

    Prints the attention output followed by the full ``kv_cache`` buffer.
    """
    torch.manual_seed(0)
    torch.set_printoptions(precision=3, sci_mode=False)

    dim = 16
    n_heads = 2
    q_lora_rank = 10
    kv_lora_rank = 6
    qk_nope_head_dim = 8
    qk_rope_head_dim = 4
    v_head_dim = 8
    max_seq_len = 10
    max_batch_size = 4

    # The input is drawn before the model is built so the seeded random
    # stream is consumed in the same order as before; its width follows
    # ``dim`` instead of repeating the literal.
    x = torch.randn(1, 4, dim)

    mla = MLA(dim=dim,
              n_heads=n_heads,
              q_lora_rank=q_lora_rank,
              kv_lora_rank=kv_lora_rank,
              qk_nope_head_dim=qk_nope_head_dim,
              qk_rope_head_dim=qk_rope_head_dim,
              v_head_dim=v_head_dim,
              max_seq_len=max_seq_len,
              max_batch_size=max_batch_size)

    # This is an inference demo, so no autograd graph is needed.
    with torch.no_grad():
        output = mla(x)

    print(output)
    print(mla.kv_cache)


if __name__ == '__main__':
    main()

参考资料:

https://zhuanlan.zhihu.com/p/16730036197

llm_related/deepseek_learn at main · wyf3/llm_related · GitHub

风语者!平时喜欢研究各种技术,目前在从事后端开发工作,热爱生活、热爱工作。