2024-02-10-【探索】6G显存畅玩无限长度的LLM角色扮演.md 6.6 KB


title: 【探索】6G显存畅玩无限长度的LLM角色扮演 urlname: Enjoy-unlimited-length-LLM-role-playing-with-6GB-of-VRAM index_img: https://api.limour.top/randomImg?d=2024-02-10 01:02:10 date: 2024-02-10 09:02:10

tags: [llama, 探索]

角色扮演的体验是否舒适主要受角色卡、大模型和生成时间三个因素的影响。

优秀的角色卡往往附带大量的设定,这会极大的拖慢第一次生成的时间,并且随着对话的进行,上下文长度很容易超过kv_cache的上限,这些很破坏沉浸式的体验。

此外,大模型在进行角色扮演时,除了进行必要的对话生成外,还需要生成旁白增加想象空间。

对博主这些相比填空更喜欢选项的玩家,给出提问建议也是非常必要的:在建议的基础上修改比自己从零写一个情景更简单,同时也完整保留了控制剧情走向的权力。

以上这些都让本就稀缺的kv_cache更加雪上加霜。

万幸,StreamingLLM 发现了kv_cache具有良好的平移性,而 llama.cpp 也提供了对kv_cache进行底层操作的api:可以指定范围的 kv_cache_seq_rm 和 kv_cache_seq_shift。基于这两个api,我们将实现对kv_cache的 token 级微操,榨干kv_cache的全部价值。

博主实践表明,在充分利用kv_cache的基础上,哪怕是 huggingface space 免费的2vCPU容器也可以游玩角色扮演,而笔记本端6G显存的1660Ti可以做到畅玩角色扮演。

体验 DEMO

  • Limour/llama-python-streamingllm
  • 同一时间仅支持一个人用,用之前点 Reset 按钮恢复初始的 kv_cache
  • 按 Submit 没反应,说明有人在用,等一段时间后再 Reset
  • 最好是 Duplicate 后,设为私密来使用

代码仓库

  • llama-python-streamingllm
  • 安装conda
  • 学术上网(管理员权限)
  • 使用前需要修改 rp_config.json 里的模型路径和参数,指定为你已经下载了的GGUF格式模型的路径
  • 推荐 causallm_7b.Q5_K_M.gguf
  • 或者自己用 Galgame 解包的对话数据集微调一个合适的模型。

    二选一:GPU版本的环境

    conda create -n llamaCpp libcublas cuda-toolkit git -c nvidia -c conda-forge
    conda activate llamaCpp
    conda install python=3.10 gradio -c conda-forge
    # 然后去 release 下载相应的包 https://github.com/Limour-dev/llama-cpp-python-cuBLAS-wheels/releases
    pip install --force-reinstall llama_cpp_python-0.2.39+cu122-cp310-cp310-win_amd64.whl
    

    二选一:CPU版本的环境

    conda create -n llamaCpp python=3.10 gradio git -c conda-forge
    conda activate llamaCpp
    pip install llama-cpp-python==0.2.39
    

    下载并运行

    conda activate llamaCpp
    git clone --depth=1 https://github.com/Limour-dev/llama-python-streamingllm.git
    cd llama-python-streamingllm
    mkdir cache
    python .\gradio_streamingllm.py
    

    核心内容

  • Submit 会将 msg 发送给模型,然后流式生成回答

  • Retry 会重新生成最近一次的 msg 所对应的回答

  • 旁白 会流式生成一份旁白到 VO

  • 建议 会以 usr 的身份流式生成一份 msg 供修改

  • 上面四个功能的基础就是下面的基于 StreamingLLM 原理的 venv 开头的函数

    class StreamingLLM(Llama):
    pass
    def kv_cache_seq_trim(self):
        self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
    
    def kv_cache_seq_ltrim(self, n_keep, n_discard=256, n_past=-1):
        if n_past < 0:
            n_past = self.n_tokens
        self._ctx.kv_cache_seq_rm(-1, n_keep, n_keep + n_discard)
        self._ctx.kv_cache_seq_shift(0, n_keep + n_discard, n_past, -n_discard)
        self.input_ids[n_keep:n_past - n_discard] = self.input_ids[n_keep + n_discard:n_past]
        self.n_tokens = n_past - n_discard
    
    def _venv_init(self):
        self.venv = [0]
        self.venv_idx_map = []
    
    def venv_create(self, name: str):
        self.venv.append(0)
        self.venv_idx_map.append(name)
        return name
    
    def venv_disband(self, name_set):
        if len(self.venv) <= 1:
            return False
        name_set = {x for x in name_set if x in self.venv_idx_map}
        if not name_set:
            return False
        while self.venv_idx_map:
            if self.venv_idx_map[0] in name_set:
                self.venv_idx_map.pop(0)  # 删除
                tmp = self.venv.pop(1)  # 对应的 venv 移入上一层
                self.venv[0] += tmp
            else:
                break
        return True
    
    def venv_revision(self, name: str):
        if len(self.venv) <= 1:
            return False
        if name not in self.venv_idx_map:
            return False
        _s = 0
        while self.venv_idx_map:
            if self.venv_idx_map[-1] == name:
                break
            self.venv_idx_map.pop()  # 删除
            _s += self.venv.pop()
        if _s:
            self.n_tokens -= min(_s, self.n_tokens)
            self.kv_cache_seq_trim()
        return True
    
    def venv_remove(self, name: str):
        if len(self.venv) <= 1:
            return False
        if name not in self.venv_idx_map:
            return False
        venv_idx = self.venv_idx_map.index(name) + 1
        while self.venv_idx_map:
            self.venv_idx_map.pop(venv_idx - 1)  # 删除
            if venv_idx == len(self.venv) - 1:
                # 最后一层
                self.n_tokens -= min(self.venv.pop(), self.n_tokens)
                self.kv_cache_seq_trim()
                break
            else:
                # 非最后一层
                n_keep = self.n_tokens - sum(self.venv[i] for i in range(venv_idx, len(self.venv)))
                n_discard = self.venv.pop(venv_idx)
                self.kv_cache_seq_ltrim(n_keep, n_discard)
                try:
                    venv_idx = self.venv_idx_map.index(name, venv_idx - 1) + 1
                except ValueError:  # 没有了
                    break
        return True
    
    def eval_t(self, tokens, n_keep=4, n_discard=256, im_start=None):
        if self._n_ctx < self.n_tokens + len(tokens):
            tmp_n_discard = max(n_discard, self.n_tokens + len(tokens) - self._n_ctx)
            self.kv_cache_seq_ltrim(n_keep, tmp_n_discard)
        for i in range(0, len(tokens), self.n_batch):
            pass
            self.n_tokens += n_tokens
            self.venv[-1] += n_tokens