mirror of https://github.com/JasonYANG170/CodeGeeX4.git (synced 2024-11-23 12:16:33 +00:00)
Fixed errors in the code, improved the stream-chat feature, raised the maximum length limit, and added error handling.
commit 9ccf7b7581
parent 724df40103
llm/local/codegeex4.py

@@ -2,14 +2,7 @@ from pydantic import Field
 from transformers import AutoModel, AutoTokenizer
 from typing import Iterator
 import torch
-class StreamProcessor:
-    def __init__(self):
-        self.previous_str = ""
 
-    def get_new_part(self, new_str):
-        new_part = new_str[len(self.previous_str):]
-        self.previous_str = new_str
-        return new_part
 class CodegeexChatModel():
     device: str = Field(description="device to load the model")
     tokenizer = Field(description="model's tokenizer")
@@ -31,7 +24,7 @@ class CodegeexChatModel():
         response, _ = self.model.chat(
             self.tokenizer,
             query=prompt,
-            max_length=4012,
+            max_length=120000,
             temperature=temperature,
             top_p=top_p
         )
@@ -42,14 +35,13 @@
     def stream_chat(self,prompt,temperature=0.2,top_p=0.95):
 
         try:
-            stream_processor = StreamProcessor()
             for response, _ in self.model.stream_chat(
                 self.tokenizer,
                 query=prompt,
-                max_length=4012,
+                max_length=120000,
                 temperature=temperature,
                 top_p=top_p
             ):
-                yield stream_processor.get_new_part(response)
+                yield response
         except Exception as e:
             yield f'error: {e}'
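With StreamProcessor removed here, stream_chat now yields the full accumulated response on every step instead of only the newly generated suffix, and the delta extraction moves to the caller (see the demo-side hunks below). A minimal sketch of that delta pattern, reusing the removed class verbatim; the snapshot sequence is a hypothetical stand-in for model output:

# Delta-extraction pattern: the model emits the full response so far on
# each step; the consumer slices off the previously seen prefix.
class StreamProcessor:
    def __init__(self):
        self.previous_str = ""

    def get_new_part(self, new_str):
        new_part = new_str[len(self.previous_str):]
        self.previous_str = new_str
        return new_part

# Hypothetical snapshots standing in for streamed model output:
sp = StreamProcessor()
for snapshot in ["Hel", "Hello", "Hello, wor", "Hello, world"]:
    print(sp.get_new_part(snapshot), end="")  # each fragment printed exactly once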
@@ -5,7 +5,6 @@ from prompts.base_prompt import judge_task_prompt,get_cur_base_user_prompt,web_j
 from utils.tools import unzip_file,get_project_files_with_content
 from utils.bingsearch import bing_search_prompt
 from llm.local.codegeex4 import CodegeexChatModel
-
 local_model_path = '<your_local_model_path>'
 llm = CodegeexChatModel(local_model_path)
 
@@ -155,8 +154,13 @@ async def main(message: cl.Message):
 
     if len(prompt_content)/4<120000:
         stream = llm.stream_chat(prompt_content,temperature=temperature,top_p = top_p)
+        stream_processor = StreamProcessor()
         for part in stream:
-            if token := (part or " "):
+            if isinstance(part, str):
+                text = stream_processor.get_new_part(part)
+            elif isinstance(part, dict):
+                text = stream_processor.get_new_part(part['name']+part['content'])
+            if token := (text or " "):
                 await msg.stream_token(token)
     else:
         await msg.stream_token("项目太大了,请换小一点的项目。")
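The guard len(prompt_content)/4<120000 is a rough characters-to-tokens estimate (about four characters per token) sized to the new max_length of 120000, and the fallback message translates to "The project is too large, please switch to a smaller one." A self-contained sketch of the new consumer logic, assuming stream parts are either accumulated-response strings or dicts carrying 'name' and 'content' keys as in the hunk above; fits_context, render, and fake_stream are hypothetical names for illustration:

MAX_LENGTH = 120000  # mirrors the new max_length passed to the model

def fits_context(prompt_content):
    # Rough heuristic from the guard above: ~4 characters per token.
    return len(prompt_content) / 4 < MAX_LENGTH

class StreamProcessor:
    def __init__(self):
        self.previous_str = ""

    def get_new_part(self, new_str):
        new_part = new_str[len(self.previous_str):]
        self.previous_str = new_str
        return new_part

def render(stream):
    # Mirrors the new loop: strings are full-response snapshots; dicts
    # (assumed to be tool-call chunks) concatenate 'name' and 'content'.
    sp = StreamProcessor()
    for part in stream:
        if isinstance(part, str):
            text = sp.get_new_part(part)
        elif isinstance(part, dict):
            text = sp.get_new_part(part['name'] + part['content'])
        if token := (text or " "):
            print(token, end="")

fake_stream = ["Hi", "Hi there", {"name": "Hi there", "content": "!"}]
if fits_context("".join(str(p) for p in fake_stream)):
    render(fake_stream)  # prints "Hi there!"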