From 9ccf7b7581d29ab97e862ef9351b9f17ed28c2b6 Mon Sep 17 00:00:00 2001 From: XingYu-Zhong <1736101137@qq.com> Date: Mon, 8 Jul 2024 16:00:04 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E4=BA=86=E4=BB=A3=E7=A0=81?= =?UTF-8?q?=E4=B8=AD=E7=9A=84=E9=94=99=E8=AF=AF=EF=BC=8C=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E4=BA=86=E6=B5=81=E8=81=8A=E5=A4=A9=E5=8A=9F=E8=83=BD=EF=BC=8C?= =?UTF-8?q?=E6=8F=90=E9=AB=98=E4=BA=86=E6=9C=80=E5=A4=A7=E9=95=BF=E5=BA=A6?= =?UTF-8?q?=E9=99=90=E5=88=B6=E5=B9=B6=E6=B7=BB=E5=8A=A0=E4=BA=86=E9=94=99?= =?UTF-8?q?=E8=AF=AF=E5=A4=84=E7=90=86=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- repodemo/llm/local/codegeex4.py | 14 +++----------- repodemo/run_local.py | 8 ++++++-- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/repodemo/llm/local/codegeex4.py b/repodemo/llm/local/codegeex4.py index 6a67433..53df2f4 100644 --- a/repodemo/llm/local/codegeex4.py +++ b/repodemo/llm/local/codegeex4.py @@ -2,14 +2,7 @@ from pydantic import Field from transformers import AutoModel, AutoTokenizer from typing import Iterator import torch -class StreamProcessor: - def __init__(self): - self.previous_str = "" - def get_new_part(self, new_str): - new_part = new_str[len(self.previous_str):] - self.previous_str = new_str - return new_part class CodegeexChatModel(): device: str = Field(description="device to load the model") tokenizer = Field(description="model's tokenizer") @@ -31,7 +24,7 @@ class CodegeexChatModel(): response, _ = self.model.chat( self.tokenizer, query=prompt, - max_length=4012, + max_length=120000, temperature=temperature, top_p=top_p ) @@ -42,14 +35,13 @@ class CodegeexChatModel(): def stream_chat(self,prompt,temperature=0.2,top_p=0.95): try: - stream_processor = StreamProcessor() for response, _ in self.model.stream_chat( self.tokenizer, query=prompt, - max_length=4012, + max_length=120000, temperature=temperature, top_p=top_p ): - yield stream_processor.get_new_part(response) + yield response except Exception as e: yield f'error: {e}' \ No newline at end of file diff --git a/repodemo/run_local.py b/repodemo/run_local.py index c95ae42..f82a190 100644 --- a/repodemo/run_local.py +++ b/repodemo/run_local.py @@ -5,7 +5,6 @@ from prompts.base_prompt import judge_task_prompt,get_cur_base_user_prompt,web_j from utils.tools import unzip_file,get_project_files_with_content from utils.bingsearch import bing_search_prompt from llm.local.codegeex4 import CodegeexChatModel - local_model_path = '' llm = CodegeexChatModel(local_model_path) @@ -155,8 +154,13 @@ async def main(message: cl.Message): if len(prompt_content)/4<120000: stream = llm.stream_chat(prompt_content,temperature=temperature,top_p = top_p) + stream_processor = StreamProcessor() for part in stream: - if token := (part or " "): + if isinstance(part, str): + text = stream_processor.get_new_part(part) + elif isinstance(part, dict): + text = stream_processor.get_new_part(part['name']+part['content']) + if token := (text or " "): await msg.stream_token(token) else: await msg.stream_token("项目太大了,请换小一点的项目。")