Shortcuts

collie.models.moss.model 源代码

import gc
import json
import math
import os
from collections import OrderedDict
from typing import Optional, Tuple, Union

import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch.utils.checkpoint
from deepspeed.accelerator import get_accelerator
from deepspeed.pipe import LayerSpec, TiedLayerSpec
from einops import rearrange
from megatron.core import parallel_state, tensor_parallel
from torch import nn
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)
from transformers.modeling_utils import dtype_byte_size

from collie.config import CollieConfig
from collie.driver.io import IODriver
from collie.log.logger import logger
from collie.models.base import CollieModelForCausalLM
from collie.models.utils import (
    flash_attention, 
    kv_cache_to_inputs_for_layer, inputs_to_kv_cache_for_layer, 
    kv_cache_to_inputs_for_model, inputs_to_kv_cache_for_model, 
)
from collie.module import (
    ColumnParallelLinearWithoutBias,
    ColumnParallelLMHead,
    RowParallelLinearWithoutBias,
)
from collie.utils import concat_tensor, dict_as_params, env, progress


class RotaryPositionEmbedding(nn.Module):
    """Rotary position embedding applied to query and key tensors.

    Adjacent pairs along the last dimension are treated as complex numbers
    and multiplied by ``exp(i * position * inv_freq)``.

    :param head_dim: size of the last dimension of the query/key tensors.
    """

    def __init__(self, head_dim: int) -> None:
        super().__init__()
        inv_freq = 1.0 / (
            10000.0
            ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim)
        )
        self.register_buffer("inv_freq", inv_freq)

    def forward(
        self, query: torch.Tensor, key: torch.Tensor, seq_len: int, start_pos: int = 0
    ):
        """Rotate ``query`` and ``key`` for positions
        ``start_pos .. start_pos + seq_len - 1``.

        :param query: tensor whose dim 1 has length ``seq_len`` and whose
            last dim is ``head_dim``; it must have 4 dims because the result
            is flattened from dim 3.
        :param key: tensor with the same layout as ``query``.
        :param seq_len: number of positions to encode.
        :param start_pos: absolute position of the first token.
        :return: ``(query, key)`` rotated and cast back to the input dtype
            of ``query``.
        """
        t = query.dtype
        query = torch.view_as_complex(query.float().reshape(*query.shape[:-1], -1, 2))
        key = torch.view_as_complex(key.float().reshape(*key.shape[:-1], -1, 2))
        # Only build the angles for the positions that are actually needed,
        # instead of a fixed 131072-row table that is sliced afterwards. The
        # values are identical, and positions past the old table size work.
        positions = torch.arange(
            start_pos, start_pos + seq_len, device=self.inv_freq.device
        )
        freqs = torch.outer(positions, self.inv_freq).float()
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
        # Keep the sequence dim (1) and the last dim; broadcast over the rest.
        shape = [
            d if i == 1 or i == query.ndim - 1 else 1 for i, d in enumerate(query.shape)
        ]
        freqs_cis = freqs_cis.view(*shape)
        query = torch.view_as_real(query * freqs_cis).flatten(3)
        key = torch.view_as_real(key * freqs_cis).flatten(3)
        return query.type(t), key.type(t)


class RMSNormalize(nn.Module):
    """Root-mean-square normalization over the last dimension with a
    learnable scale.

    :param dim: size of the scale vector; used only when ``weight`` is None.
    :param dtype: dtype of the newly created scale.
    :param eps: constant added to the mean square before the reciprocal
        square root.
    :param weight: existing scale to use instead of creating a new one.
    """

    def __init__(self, dim=None, dtype=torch.float, eps=1e-5, weight=None):
        super().__init__()
        if weight is None:
            # No scale supplied: create an all-ones parameter on the
            # accelerator's current device.
            device = get_accelerator().current_device_name()
            weight = nn.Parameter(torch.ones(dim, dtype=dtype, device=device))
        self.weight = weight
        self.eps = eps

    def forward(self, hidden_states):
        """Return ``hidden_states`` scaled by the reciprocal RMS of its last
        dimension and by ``self.weight``."""
        # The mean square is computed in float32.
        mean_square = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        normalized = hidden_states * torch.rsqrt(mean_square + self.eps)
        # Cast down before the final multiply when the scale is half precision.
        if self.weight.dtype in (torch.float16, torch.bfloat16):
            normalized = normalized.to(self.weight.dtype)
        return normalized * self.weight


class MossBlock(nn.Module):
    """One MOSS decoder layer: RMS-normed self-attention with rotary position
    embedding followed by an RMS-normed gated (SiLU) MLP, each wrapped in a
    residual connection.

    :param config: model configuration.
    :param layer_idx: index of this layer, used to address its key/value
        cache entry in the ``inputs`` dict.
    """

    def __init__(self, config: CollieConfig, layer_idx) -> None:
        super().__init__()
        self.idx = layer_idx
        self.config = config
        # Read by the PREFIX_TUNING branches of ``_forward``. It was never
        # defined before, so those branches raised AttributeError.
        self.num_heads = config.num_attention_heads
        self.self_attn = nn.ModuleDict(
            {
                "q_proj": ColumnParallelLinearWithoutBias(
                    config.hidden_size,
                    config.hidden_size,
                    bias=False,
                    gather_output=False,
                    init_method=lambda x: x,
                ),
                "k_proj": ColumnParallelLinearWithoutBias(
                    config.hidden_size,
                    config.hidden_size,
                    bias=False,
                    gather_output=False,
                    init_method=lambda x: x,
                ),
                "v_proj": ColumnParallelLinearWithoutBias(
                    config.hidden_size,
                    config.hidden_size,
                    bias=False,
                    gather_output=False,
                    init_method=lambda x: x,
                ),
                "o_proj": RowParallelLinearWithoutBias(
                    config.hidden_size,
                    config.hidden_size,
                    bias=False,
                    input_is_parallel=True,
                    init_method=lambda x: x,
                ),
                "rotary_emb": RotaryPositionEmbedding(
                    self.config.hidden_size // self.config.num_attention_heads
                ),
            }
        )
        self.input_layernorm = RMSNormalize(
            dim=config.hidden_size, eps=config.rms_norm_eps
        )
        self.mlp = nn.ModuleDict(
            {
                "gate_proj": ColumnParallelLinearWithoutBias(
                    config.hidden_size,
                    config.intermediate_size,
                    bias=False,
                    gather_output=False,
                    init_method=lambda x: x,
                ),
                "up_proj": ColumnParallelLinearWithoutBias(
                    config.hidden_size,
                    config.intermediate_size,
                    bias=False,
                    gather_output=False,
                    init_method=lambda x: x,
                ),
                "down_proj": RowParallelLinearWithoutBias(
                    config.intermediate_size,
                    config.hidden_size,
                    bias=False,
                    input_is_parallel=True,
                    init_method=lambda x: x,
                ),
            }
        )
        self.post_attention_layernorm = RMSNormalize(
            dim=config.hidden_size, eps=config.rms_norm_eps
        )
        # Keep these attribute names exactly as they are (presumably helpers
        # such as ``_set_use_cache`` / ``_clean_hidden_states`` look them up
        # by name -- TODO confirm).
        self.use_cache = self.config.model_config.use_cache
        self.hidden_states = None

    def _forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """Run attention and MLP for this layer.

        :param hidden_states: tensor of shape ``(B, N, H)``.
        :param attention_mask: optional padding mask; all ones when omitted.
        :param past_key_values: optional ``(key, value)`` cache of this layer.
        :return: ``(hidden_states, new_layer_past)``; ``new_layer_past`` is
            ``None`` unless caching is enabled and the module is in eval mode.
        """
        # The layer input is kept on the module only outside of training.
        if not self.training:
            self.hidden_states = hidden_states
        else:
            self.hidden_states = None
        layer_past = past_key_values
        assert (
            hidden_states.ndim == 3
        ), f"hidden_states.shape must be (B, N, H), but got {hidden_states.shape}"
        batch_size, seq_len, _ = hidden_states.shape
        head_dim = self.config.hidden_size // self.config.num_attention_heads
        _hidden_states = self.input_layernorm(hidden_states)
        query, key, value = (
            self.self_attn["q_proj"](_hidden_states),
            self.self_attn["k_proj"](_hidden_states),
            self.self_attn["v_proj"](_hidden_states),
        )
        query, key, value = (
            rearrange(query, "b n (h d) -> b n h d", d=head_dim),
            rearrange(key, "b n (h d) -> b n h d", d=head_dim),
            rearrange(value, "b n (h d) -> b n h d", d=head_dim),
        )
        # Number of cached positions. The cache keeps the sequence on dim 2
        # in the prefix-tuning layout and on dim 1 otherwise.
        if layer_past is not None:
            if self.config.peft_config and self.config.peft_config.peft_type == "PREFIX_TUNING":
                start_pos = layer_past[0].shape[2]
            else:
                start_pos = layer_past[0].shape[1]
        else:
            start_pos = 0
        query, key = self.self_attn["rotary_emb"](query, key, seq_len, start_pos)
        if layer_past is not None:
            # past_key: batch_size, num_heads, seq_len, head_dim
            if self.config.peft_config and self.config.peft_config.peft_type == "PREFIX_TUNING":
                # NOTE(review): this reshape uses the full head count while
                # the inverse conversion below divides by tp_size -- confirm
                # the intended behavior when tp_size > 1.
                past_key =  layer_past[0].reshape(*layer_past[0].shape[:-1], 2, -1)\
                                        .permute(0, 2, 1, 4, 3)\
                                        .reshape(batch_size, start_pos, self.num_heads, -1)
                past_value = layer_past[1].permute([0, 2, 1, 3])
            else:
                past_key, past_value = layer_past
            # The cached keys are also prepended to the query so that query
            # and key have the same length; the extra leading positions are
            # sliced off ``output`` further down.
            query = torch.cat([past_key, query], dim=1)
            key = torch.cat([past_key, key], dim=1)
            value = torch.cat([past_value, value], dim=1)
        new_layer_past = None
        if self.use_cache and not self.training:
            # Convert to the HF-compatible layout, which makes prefix tuning
            # easier.
            if self.config.peft_config and self.config.peft_config.peft_type == "PREFIX_TUNING":
                present_key = key.reshape(*key.shape[:-1], -1, 2)\
                                .permute(0, 2, 1, 4, 3)\
                                .reshape(batch_size, self.num_heads // self.config.tp_size, seq_len + start_pos, -1)
                present_value = value.permute([0, 2, 1, 3])
                new_layer_past = (present_key, present_value)
            else:
                new_layer_past = (key, value)
        attention_mask = (
            attention_mask
            if attention_mask is not None
            else torch.ones((query.shape[0], query.shape[1])).to(hidden_states.device)
        )
        if self.config.use_flash:
            output = flash_attention(query, key, value, attention_mask)
        else:
            query, key, value = (
                query.permute(0, 2, 1, 3),
                key.permute(0, 2, 1, 3),
                value.permute(0, 2, 1, 3),
            )
            attention_score = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(
                head_dim
            )
            if seq_len + start_pos > 1:
                # Causal mask: -inf strictly above the diagonal.
                mask = torch.full(
                    (1, 1, seq_len + start_pos, seq_len + start_pos), float("-inf")
                )
                mask = torch.triu(mask, diagonal=1).to(attention_score.device)
                attention_score = attention_score + mask
            # Positions where attention_mask is 0 get the most negative
            # finite value of the score dtype.
            key_padding_mask = (
                1.0 - attention_mask.unsqueeze(1).unsqueeze(2)
            ) * torch.finfo(attention_score.dtype).min
            attention_score = F.softmax(
                attention_score + key_padding_mask, dim=-1
            ).type_as(value)
            output = torch.matmul(attention_score, value)
            output = (
                output.transpose(1, 2)
                .contiguous()
                .view(batch_size, seq_len + start_pos, -1)
            )
        output = F.dropout(output, p=self.config.dropout, training=self.training)
        # Drop the positions that came from the cache-padded query.
        output = output[:, start_pos:, :]
        hidden_states = hidden_states + self.self_attn["o_proj"](output)
        _hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = hidden_states + F.dropout(
            self.mlp["down_proj"](
                F.silu(self.mlp["gate_proj"](_hidden_states))
                * self.mlp["up_proj"](_hidden_states)
            ),
            p=self.config.dropout,
            training=self.training,
        )
        return hidden_states, new_layer_past

    def forward(self, inputs: dict):
        """Apply the layer to ``inputs["hidden_states"]`` in place.

        :param inputs: dict holding ``"hidden_states"``, optionally
            ``"attention_mask"``, and this layer's cache entries.
        :return: the same dict with updated hidden states and cache entries.
        """
        layer_past = inputs_to_kv_cache_for_layer(idx=self.idx, inputs=inputs)

        if self.config.checkpointing and self.training:
            # Activation checkpointing is only used while training.
            hidden_states, new_layer_past = torch.utils.checkpoint.checkpoint(
                self._forward,
                inputs["hidden_states"],
                inputs.get("attention_mask", None),
                layer_past,
            )
        else:
            hidden_states, new_layer_past = self._forward(
                inputs["hidden_states"],
                inputs.get("attention_mask", None),
                layer_past,
            )
        inputs["hidden_states"] = hidden_states

        inputs.update(kv_cache_to_inputs_for_layer(idx=self.idx, new_layer_past=new_layer_past))

        return inputs


class MossModel(nn.Module):
    """MOSS decoder stack: token embedding, ``num_hidden_layers`` MossBlock
    layers and a final RMS normalization.

    :param config: model configuration.
    """

    def __init__(self, config: CollieConfig):
        super().__init__()
        self.config = config
        self.embed_tokens = tensor_parallel.VocabParallelEmbedding(
            config.vocab_size, config.hidden_size
        )
        self.layers = nn.ModuleList(
            [MossBlock(self.config, i) for i in range(self.config.num_hidden_layers)]
        )
        self.norm = RMSNormalize(
            dim=self.config.hidden_size, eps=self.config.rms_norm_eps
        )

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[torch.Tensor]] = None,
        **kwargs,
    ):
        """Run the decoder stack.

        :param input_ids: token ids; when ``None``, ``kwargs["inputs_embeds"]``
            is used as the initial hidden states instead.
        :param attention_mask: optional mask forwarded to every layer.
        :param past_key_values: optional key/value cache.
        :return: ``BaseModelOutputWithPast`` whose ``hidden_states`` holds the
            input of every layer followed by the final normalized output.
        """
        inputs = {"input_ids": input_ids}
        if attention_mask is not None:
            inputs["attention_mask"] = attention_mask
        # Identity check: ``== None`` would go through the tensor's ``__eq__``.
        if input_ids is None:
            inputs["hidden_states"] = kwargs["inputs_embeds"]
        else:
            inputs["hidden_states"] = self.embed_tokens(inputs["input_ids"])

        inputs.update(kv_cache_to_inputs_for_model(past_key_values))

        all_hidden_states = ()
        for layer in self.layers:
            all_hidden_states += (inputs["hidden_states"],)
            inputs.update(layer(inputs))
        inputs["hidden_states"] = self.norm(inputs["hidden_states"])
        all_hidden_states += (inputs["hidden_states"],)

        past_key_values = inputs_to_kv_cache_for_model(self.config.num_hidden_layers, inputs)

        return BaseModelOutputWithPast(
            last_hidden_state=inputs["hidden_states"],
            hidden_states=all_hidden_states,
            past_key_values=past_key_values,
        )

    @classmethod
    def pipeline_layers(cls, config: Union[CollieConfig, str]):
        """
        Get layers of pipeline.

        :param config: a ``CollieConfig``, or a string that is passed to
            ``CollieConfig.from_pretrained``.
        :return: list of ``(name, spec)`` pairs for the embedding, the decoder
            layers and the final norm.
        """
        if isinstance(config, str):
            config = CollieConfig.from_pretrained(config)

        # The embedding is declared as a tied layer only when the
        # configuration ties the word embeddings.
        if config.model_config.tie_word_embeddings:
            embed_tokens = TiedLayerSpec(
                "embed_tokens",
                dict_as_params(input_keys="input_ids", output_keys="hidden_states"),
                tensor_parallel.VocabParallelEmbedding,
                config.vocab_size,
                config.hidden_size,
            )
        else:
            embed_tokens = LayerSpec(
                dict_as_params(input_keys="input_ids", output_keys="hidden_states"),
                tensor_parallel.VocabParallelEmbedding,
                config.vocab_size,
                config.hidden_size,
            )

        layers = [
            LayerSpec(MossBlock, config, i) for i in range(config.num_hidden_layers)
        ]
        norm = LayerSpec(
            dict_as_params(input_keys="hidden_states", output_keys="hidden_states"),
            RMSNormalize,
            dim=config.hidden_size,
            eps=config.rms_norm_eps,
        )

        return [
            ("embed_tokens", embed_tokens),
            ("layers", layers),
            ("norm", norm),
        ]


class MossForCausalLM(CollieModelForCausalLM):
    """MOSS causal language model: a ``MossModel`` followed by an LM head."""

    base_model_prefix = "model"

    def __init__(self, config: CollieConfig) -> None:
        super().__init__(config)
        self.model = MossModel(config)
        self.lm_head = ColumnParallelLinearWithoutBias(
            self.collie_config.hidden_size, self.collie_config.vocab_size, bias=False
        )
        # Extra attribute required by GenerationMixin.
        self.config.is_decoder = True
        if config.model_config.tie_word_embeddings:
            # The embedding lives on ``self.model``; this class has no
            # ``embed_tokens`` attribute of its own.
            self.lm_head.weight = self.model.embed_tokens.weight
        self.main_input_name = "input_ids"

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[torch.Tensor]] = None,
        **kwargs,
    ):
        """Run the decoder and project its last hidden state to logits.

        :return: ``CausalLMOutputWithPast`` with ``loss`` and ``attentions``
            set to ``None``.
        """
        output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            **kwargs,
        )
        logits = self.lm_head(output.last_hidden_state)
        return CausalLMOutputWithPast(
            loss=None,
            logits=logits,
            past_key_values=output.past_key_values,
            hidden_states=output.hidden_states,
            attentions=None,
        )

    def clean_cache(self):
        """Clear stored hidden states and turn caching off on every layer."""
        self._clean_hidden_states([*self.model.layers, self.lm_head])
        self._set_use_cache(self.model.layers, False)

    def set_cache(self, use_cache):
        """Set ``use_cache`` on every decoder layer."""
        self._set_use_cache(self.model.layers, use_cache)

    @classmethod
    def pipeline_layers(cls, config: CollieConfig):
        """
        Get layers of pipeline.

        :return: list
        """
        if isinstance(config, str):
            config = CollieConfig.from_pretrained(config)
        # NOTE(review): the head is always declared as a tied layer under the
        # key "embed_tokens", while MossModel.pipeline_layers only ties the
        # embedding when ``tie_word_embeddings`` is set -- confirm this is
        # intended.
        lm_head = TiedLayerSpec(
            "embed_tokens",
            dict_as_params(input_keys="hidden_states", output_keys="logits"),
            ColumnParallelLMHead,
            config.hidden_size,
            config.vocab_size,
            bias=False,
        )
        return [("model", MossModel.pipeline_layers(config)), ("lm_head", lm_head)]

    # Placeholder signature; it is rebound by the full definition right below.
    @staticmethod
    def load_parallel_state_dict(
        path: str,
        config: Union[CollieConfig, str],
        process_exclusion: bool = False,
        **kwargs,
    ):
        ...

    @staticmethod
    def load_parallel_state_dict(
        path: str,
        config: Union[CollieConfig, str],
        process_exclusion: bool = False,
        protocol: str = "file",
        **kwargs,
    ):
        """
        Load state_dict from ``path``.

        The format of pretrained model should be the same as that of
        `huggingface`.

        :param path: folder that holds the checkpoint files.
        :param config: a ``CollieConfig``, or a string that is passed to
            ``CollieConfig.from_pretrained``.
        :param process_exclusion: when True (and torch.distributed is
            initialized) ranks load one after another instead of all at once.
        :param protocol: protocol name passed to ``IODriver.from_protocol``.
        :raises FileNotFoundError: if ``path`` does not exist.
        :return: state_dict. Note that the state_dict should be processed
            properly to match the current rank.
        """
        if isinstance(config, str):
            config = CollieConfig.from_pretrained(config)
        io_driver = IODriver.from_protocol(protocol)
        if not io_driver.exists(path):
            raise FileNotFoundError(f"folder {path} not found.")
        state_dict = OrderedDict()
        weights = []
        parts = None
        current_rank = int(os.environ.get("RANK", "0"))
        # With process exclusion every process shows a progress bar;
        # otherwise only RANK 0 does.
        hide_progress = not process_exclusion and current_rank != 0
        if dist.is_initialized() and process_exclusion:
            # With process exclusion, loop dist.get_world_size() times.
            rank_order = range(dist.get_world_size())
        else:
            # Otherwise a single pass is enough.
            rank_order = range(1)
        col_filter = (
            "q_proj.weight",
            "k_proj.weight",
            "v_proj.weight",
            "gate_proj.weight",
            "up_proj.weight",
            "embed_tokens.weight",
            "lm_head.weight",
        )
        row_filter = ("o_proj.weight", "down_proj.weight")
        index_file = os.path.join(path, "pytorch_model.bin.index.json")

        for rank in rank_order:
            # With process exclusion only the matching RANK enters the body;
            # without it every process does.
            if current_rank == rank or not process_exclusion:
                # The pipeline partition is stored in
                # os.environ["COLLIE_PP_PARTS"], e.g. [0, 17, 35]
                # (left-closed, right-open).
                if env.is_pipeline:
                    parts = env.pipeline_parts
                layers = env.pipeline_layers_idx if parts is not None else None

                def needed_on_this_stage(key):
                    # Pipeline re-filter: keep only what this stage owns.
                    # Both "model.layers.N..." (huggingface naming, also
                    # assumed by the index filter below) and "layers.N..."
                    # are recognized.
                    bare = key[len("model."):] if key.startswith("model.") else key
                    if bare.startswith("layers."):
                        return int(bare.split(".")[1]) + 1 in layers
                    if key.endswith("embed_tokens.weight"):
                        return 0 in layers
                    if bare == "norm.weight":
                        return max(parts) - 2 in layers
                    if key.endswith("lm_head.weight"):
                        return max(parts) - 1 in layers
                    return True

                # With a pytorch_model.bin.index.json file each pipeline
                # process can load only the files it needs.
                if io_driver.exists(index_file) and "COLLIE_PP_PARTS" in os.environ:
                    weight_map = json.loads(io_driver.load(index_file, mode="r"))[
                        "weight_map"
                    ]
                    # ``layers`` lists the layers owned by this process.
                    layers = env.pipeline_layers_idx
                    # Select entries that look like model.layers.0: the third
                    # component is a number, and that number plus one is in
                    # ``layers`` (the embedding occupies the first slot).
                    weights.extend(
                        [
                            value
                            for key, value in weight_map.items()
                            if len(key.split(".")) > 2
                            and key.split(".")[2].isdigit()
                            and (int(key.split(".")[2]) + 1) in layers
                        ]
                    )
                    # Slot 0 needs the embedding, the last slot needs lm_head
                    # and the second-to-last slot needs the final norm.
                    if 0 in layers:
                        weights.append(weight_map["model.embed_tokens.weight"])
                    if max(parts) - 1 in layers:
                        weights.append(weight_map["lm_head.weight"])
                    if max(parts) - 2 in layers:
                        weights.append(weight_map["model.norm.weight"])
                    # De-duplicate after all appends so no file is read twice.
                    weights = list(dict.fromkeys(weights))
                else:
                    # Without an index file, load every .bin file.
                    weights = [
                        weight
                        for weight in io_driver.list(path)
                        if weight.endswith(".bin")
                    ]
                with progress(
                    weights,
                    desc="Loading state dict",
                    total=len(weights),
                    disable=hide_progress,
                ) as pbar:
                    for weight in pbar:
                        part_state_dict = io_driver.load(
                            os.path.join(path, weight), mode="rb"
                        )
                        for key in list(part_state_dict.keys()):
                            if parts is not None and not needed_on_this_stage(key):
                                # Drop weights of other stages before merging.
                                part_state_dict.pop(key)
                                continue
                            # Reorder the rows of q_proj.weight / k_proj.weight.
                            if key.endswith(("q_proj.weight", "k_proj.weight")):
                                part_state_dict[key] = (
                                    rearrange(
                                        part_state_dict[key],
                                        "(h two t) d -> h two t d",
                                        h=config.num_attention_heads,
                                        two=2,
                                    )
                                    .transpose(1, 2)
                                    .reshape(config.hidden_size, config.hidden_size)
                                )
                        state_dict.update(part_state_dict)
                        del part_state_dict
                # Split according to the tp size configured by the user.
                for key in list(state_dict.keys()):
                    if key.endswith(col_filter):
                        split_dim = 0
                    elif key.endswith(row_filter):
                        split_dim = 1
                    else:
                        continue
                    tensor = (
                        list(
                            torch.chunk(state_dict[key], config.tp_size, dim=split_dim)
                        )[env.tp_rank]
                        .detach()
                        .clone()
                    )
                    del state_dict[key]
                    if process_exclusion:
                        # Reclaim CPU memory (slow).
                        gc.collect()
                    state_dict[key] = tensor
            if dist.is_initialized() and process_exclusion:
                # With process exclusion, processes that do not load in this
                # iteration wait here.
                dist.barrier()
        return state_dict

    # Placeholder signature; it is rebound by the full definition right below.
    @staticmethod
    def save_parallel_state_dict(
        state_dict: dict,
        path: str,
        config: CollieConfig,
        process_exclusion: bool = False,
        **kwargs,
    ):
        ...

    @staticmethod
    def save_parallel_state_dict(
        state_dict: dict,
        path: str,
        config: CollieConfig,
        process_exclusion: bool = False,
        protocol: str = "file",
    ):
        """
        Save state_dict to ``path``.

        The format of saved state dict should be the same as that of
        `huggingface`.

        :param state_dict: this rank's state dict; entries are replaced in
            place by the gathered tensors on tensor-parallel rank 0.
        :param path: destination folder.
        :param config: model configuration.
        :param process_exclusion: when True (and torch.distributed is
            initialized) pipeline stages save one after another.
        :param protocol: protocol name passed to ``IODriver.from_protocol``.
        """
        io_driver = IODriver.from_protocol(protocol)

        def reshape_wq_wk(w: torch.Tensor):
            # Inverse of the row reordering done when loading q/k weights.
            return (
                w.view(
                    config.num_attention_heads,
                    config.hidden_size // config.num_attention_heads // 2,
                    2,
                    config.hidden_size,
                )
                .transpose(1, 2)
                .reshape(config.hidden_size, config.hidden_size)
            )

        # Gather to tp rank 0.
        if dist.is_initialized() and process_exclusion:
            # With process exclusion, loop pp_size times.
            rank_order = range(config.pp_size)
        else:
            # Otherwise a single pass is enough.
            rank_order = range(1)
        dst = parallel_state.get_tensor_model_parallel_src_rank()
        col_filter = (
            "q_proj.weight",
            "k_proj.weight",
            "v_proj.weight",
            "gate_proj.weight",
            "up_proj.weight",
            "embed_tokens.weight",
            "lm_head.weight",
        )
        row_filter = ("o_proj.weight", "down_proj.weight")
        with progress(
            rank_order,
            desc="Saving model",
            disable=int(os.environ.get("RANK", "0")) != 0,
        ) as pbar:
            for rank in pbar:
                if env.dp_rank == 0 and (env.pp_rank == rank or not process_exclusion):
                    for key in sorted(state_dict):
                        tensor_list = None
                        if env.tp_rank == 0:
                            tensor_list = [
                                torch.zeros_like(state_dict[key])
                                .to(state_dict[key].dtype)
                                .cuda()
                                for _ in range(config.tp_size)
                            ]
                        dist.gather(
                            state_dict[key].cuda(),
                            dst=dst,
                            gather_list=tensor_list,
                            group=env.tp_group,
                        )
                        if env.tp_rank == 0:
                            if key.endswith(col_filter):
                                state_dict[key] = concat_tensor(tensor_list, dim=0)
                                if process_exclusion:
                                    # Reclaim CPU memory (slow).
                                    gc.collect()
                            elif key.endswith(row_filter):
                                state_dict[key] = concat_tensor(tensor_list, dim=1)
                                if process_exclusion:
                                    # Reclaim CPU memory (slow).
                                    gc.collect()
                            # Done on tp rank 0 only: the view needs the
                            # full (hidden_size, hidden_size) tensor.
                            if key.endswith(("q_proj.weight", "k_proj.weight")):
                                state_dict[key] = reshape_wq_wk(state_dict[key])
                    if env.tp_rank == 0:
                        # Save gathered weights.
                        if env.is_pipeline:
                            ckpt_name = f"pytorch_model-{env.pp_rank+1:05d}-of-{config.pp_size:05d}.bin"
                            total_size = 0
                            weight_map = {}
                            for name, weight in state_dict.items():
                                weight_size = weight.numel() * dtype_byte_size(
                                    weight.dtype
                                )
                                weight_map[name] = ckpt_name
                                total_size += weight_size
                            index_dict = dict(
                                total_size=total_size, weight_map=weight_map
                            )
                            index_dicts = [None for _ in range(env.pp_size)]
                            # NOTE(review): this is a collective over the
                            # pipeline group; confirm it cannot stall when
                            # process_exclusion lets stages in one at a time.
                            dist.gather_object(
                                index_dict,
                                index_dicts if env.pp_rank == 0 else None,
                                group=env.pp_group,
                            )
                            if env.pp_rank == 0:
                                # Merge the per-stage indices into one file.
                                total_size = 0
                                weight_map = {}
                                for _index_dict in index_dicts:
                                    total_size += _index_dict["total_size"]
                                    weight_map.update(_index_dict["weight_map"])
                                merged_dict = {
                                    "metadata": {"total_size": total_size},
                                    "weight_map": weight_map,
                                }
                                io_driver.save(
                                    json.dumps(merged_dict, indent=2, sort_keys=True)
                                    + "\n",
                                    os.path.join(path, "pytorch_model.bin.index.json"),
                                )
                        else:
                            ckpt_name = "pytorch_model.bin"
                        ckpt_path = os.path.join(path, ckpt_name)
                        io_driver.save(state_dict, ckpt_path)
                if dist.is_initialized() and process_exclusion:
                    dist.barrier()
        if env.rank == 0:
            config.save_pretrained(path, protocol=protocol)
        dist.barrier()