修复了代码中的错误,优化了流聊天功能,提高了最大长度限制并添加了错误处理。

This commit is contained in:
XingYu-Zhong 2024-07-08 16:00:04 +08:00
parent 724df40103
commit 9ccf7b7581
2 changed files with 9 additions and 13 deletions

View File

@ -2,14 +2,7 @@ from pydantic import Field
from transformers import AutoModel, AutoTokenizer from transformers import AutoModel, AutoTokenizer
from typing import Iterator from typing import Iterator
import torch import torch
class StreamProcessor:
def __init__(self):
self.previous_str = ""
def get_new_part(self, new_str):
new_part = new_str[len(self.previous_str):]
self.previous_str = new_str
return new_part
class CodegeexChatModel(): class CodegeexChatModel():
device: str = Field(description="device to load the model") device: str = Field(description="device to load the model")
tokenizer = Field(description="model's tokenizer") tokenizer = Field(description="model's tokenizer")
@ -31,7 +24,7 @@ class CodegeexChatModel():
response, _ = self.model.chat( response, _ = self.model.chat(
self.tokenizer, self.tokenizer,
query=prompt, query=prompt,
max_length=4012, max_length=120000,
temperature=temperature, temperature=temperature,
top_p=top_p top_p=top_p
) )
@ -42,14 +35,13 @@ class CodegeexChatModel():
def stream_chat(self,prompt,temperature=0.2,top_p=0.95): def stream_chat(self,prompt,temperature=0.2,top_p=0.95):
try: try:
stream_processor = StreamProcessor()
for response, _ in self.model.stream_chat( for response, _ in self.model.stream_chat(
self.tokenizer, self.tokenizer,
query=prompt, query=prompt,
max_length=4012, max_length=120000,
temperature=temperature, temperature=temperature,
top_p=top_p top_p=top_p
): ):
yield stream_processor.get_new_part(response) yield response
except Exception as e: except Exception as e:
yield f'error: {e}' yield f'error: {e}'

View File

@ -5,7 +5,6 @@ from prompts.base_prompt import judge_task_prompt,get_cur_base_user_prompt,web_j
from utils.tools import unzip_file,get_project_files_with_content from utils.tools import unzip_file,get_project_files_with_content
from utils.bingsearch import bing_search_prompt from utils.bingsearch import bing_search_prompt
from llm.local.codegeex4 import CodegeexChatModel from llm.local.codegeex4 import CodegeexChatModel
local_model_path = '<your_local_model_path>' local_model_path = '<your_local_model_path>'
llm = CodegeexChatModel(local_model_path) llm = CodegeexChatModel(local_model_path)
@ -155,8 +154,13 @@ async def main(message: cl.Message):
if len(prompt_content)/4<120000: if len(prompt_content)/4<120000:
stream = llm.stream_chat(prompt_content,temperature=temperature,top_p = top_p) stream = llm.stream_chat(prompt_content,temperature=temperature,top_p = top_p)
stream_processor = StreamProcessor()
for part in stream: for part in stream:
if token := (part or " "): if isinstance(part, str):
text = stream_processor.get_new_part(part)
elif isinstance(part, dict):
text = stream_processor.get_new_part(part['name']+part['content'])
if token := (text or " "):
await msg.stream_token(token) await msg.stream_token(token)
else: else:
await msg.stream_token("项目太大了,请换小一点的项目。") await msg.stream_token("项目太大了,请换小一点的项目。")