mirror of
https://github.com/JasonYANG170/CodeGeeX4.git
synced 2024-11-23 12:16:33 +00:00
修复了代码中的错误,优化了流聊天功能,提高了最大长度限制并添加了错误处理。
This commit is contained in:
parent
724df40103
commit
9ccf7b7581
|
@ -2,14 +2,7 @@ from pydantic import Field
|
|||
from transformers import AutoModel, AutoTokenizer
|
||||
from typing import Iterator
|
||||
import torch
|
||||
class StreamProcessor:
    """Track the text already emitted by a streaming model and expose only the delta."""

    def __init__(self):
        # Full text observed so far; every delta is computed against this.
        self.previous_str = ""

    def get_new_part(self, new_str):
        """Return the suffix of *new_str* that has not been emitted yet.

        The model streams the complete response on every step, so the new
        content is whatever extends beyond the previously recorded string.
        """
        already_emitted = len(self.previous_str)
        delta = new_str[already_emitted:]
        self.previous_str = new_str
        return delta
|
||||
class CodegeexChatModel():
|
||||
device: str = Field(description="device to load the model")
|
||||
tokenizer = Field(description="model's tokenizer")
|
||||
|
@ -31,7 +24,7 @@ class CodegeexChatModel():
|
|||
response, _ = self.model.chat(
|
||||
self.tokenizer,
|
||||
query=prompt,
|
||||
max_length=4012,
|
||||
max_length=120000,
|
||||
temperature=temperature,
|
||||
top_p=top_p
|
||||
)
|
||||
|
@ -42,14 +35,13 @@ class CodegeexChatModel():
|
|||
def stream_chat(self, prompt, temperature=0.2, top_p=0.95):
    """Stream the model's reply to *prompt*, yielding only newly generated text.

    Args:
        prompt: User query forwarded to the underlying model.
        temperature: Sampling temperature passed through to the model.
        top_p: Nucleus-sampling threshold passed through to the model.

    Yields:
        Incremental string chunks (the delta since the previous step). On any
        failure a single 'error: ...' string is yielded instead of raising, so
        consumers can surface the problem in the stream itself.
    """
    try:
        # The model re-emits the full response each step; StreamProcessor
        # turns that into delta-only chunks for the caller.
        stream_processor = StreamProcessor()
        for response, _ in self.model.stream_chat(
            self.tokenizer,
            query=prompt,
            # The garbled diff carried both the old (4012) and new (120000)
            # max_length keywords — a duplicate-kwarg SyntaxError; keep the
            # post-commit limit only.
            max_length=120000,
            temperature=temperature,
            top_p=top_p,
        ):
            # Emit only the new suffix, not the stale full-response yield
            # that the diff left behind.
            yield stream_processor.get_new_part(response)
    except Exception as e:
        # Best-effort contract: report the error as a stream item.
        yield f'error: {e}'
|
|
@ -5,7 +5,6 @@ from prompts.base_prompt import judge_task_prompt,get_cur_base_user_prompt,web_j
|
|||
from utils.tools import unzip_file,get_project_files_with_content
|
||||
from utils.bingsearch import bing_search_prompt
|
||||
from llm.local.codegeex4 import CodegeexChatModel
|
||||
|
||||
local_model_path = '<your_local_model_path>'
|
||||
llm = CodegeexChatModel(local_model_path)
|
||||
|
||||
|
@ -155,8 +154,13 @@ async def main(message: cl.Message):
|
|||
|
||||
if len(prompt_content)/4<120000:
|
||||
stream = llm.stream_chat(prompt_content,temperature=temperature,top_p = top_p)
|
||||
stream_processor = StreamProcessor()
|
||||
for part in stream:
|
||||
if token := (part or " "):
|
||||
if isinstance(part, str):
|
||||
text = stream_processor.get_new_part(part)
|
||||
elif isinstance(part, dict):
|
||||
text = stream_processor.get_new_part(part['name']+part['content'])
|
||||
if token := (text or " "):
|
||||
await msg.stream_token(token)
|
||||
else:
|
||||
await msg.stream_token("项目太大了,请换小一点的项目。")
|
||||
|
|
Loading…
Reference in New Issue
Block a user