2024-07-08 07:17:28 +00:00
|
|
|
|
import os

import chainlit as cl
from chainlit.input_widget import Slider

from llm.local.codegeex4 import CodegeexChatModel
from prompts.base_prompt import (
    judge_task_prompt,
    get_cur_base_user_prompt,
    web_judge_task_prompt,
)
from utils.bingsearch import bing_search_prompt
from utils.tools import unzip_file, get_project_files_with_content
|
2024-07-09 03:37:30 +00:00
|
|
|
|
|
|
|
|
|
# Path to the local CodeGeeX4 model. It can be supplied through the
# CODEGEEX4_MODEL_PATH environment variable; otherwise the placeholder below
# must be replaced before the app is started.
local_model_path = os.environ.get("CODEGEEX4_MODEL_PATH", "<your_local_model_path>")
# One model instance at module level, used by every handler in this file.
llm = CodegeexChatModel(local_model_path)
|
|
|
|
|
|
2024-07-09 03:37:30 +00:00
|
|
|
|
|
2024-07-08 07:17:28 +00:00
|
|
|
|
class StreamProcessor:
    """Convert a stream of cumulative strings into incremental pieces.

    Each value handed to ``get_new_part`` is assumed to be the whole text
    produced so far; only the characters beyond the previously seen length
    are returned.
    """

    def __init__(self):
        # Text received on the most recent call; empty before the first one.
        self.previous_str = ""

    def get_new_part(self, new_str):
        """Return ``new_str`` minus the prefix length already seen, then remember ``new_str``."""
        seen_len = len(self.previous_str)
        self.previous_str = new_str
        return new_str[seen_len:]
|
|
|
|
|
|
2024-07-09 03:37:30 +00:00
|
|
|
|
|
2024-07-08 07:17:28 +00:00
|
|
|
|
@cl.set_chat_profiles
async def chat_profile():
    """Declare the chat profiles shown in the UI.

    Three profiles are offered: plain multi-turn chat (with starter prompts),
    web-assisted Q&A, and project Q&A over an uploaded zip. Every starter
    uses the same text for its label and for the message it sends.
    """
    starter_prompts = (
        "请你用python写一个快速排序。",
        "请你介绍一下自己。",
        "用 Python 编写一个脚本来自动发送每日电子邮件报告,并指导我如何进行设置。",
        "我是一个python初学者,请你告诉我怎么才能学好python。",
    )
    plain_chat = cl.ChatProfile(
        name="chat聊天",
        markdown_description="聊天demo:支持多轮对话。",
        starters=[cl.Starter(label=text, message=text) for text in starter_prompts],
    )
    web_qa = cl.ChatProfile(
        name="联网问答",
        markdown_description="联网能力demo:支持联网回答用户问题。",
    )
    project_qa = cl.ChatProfile(
        name="上传本地项目",
        markdown_description="项目级能力demo:支持上传本地zip压缩包项目,可以进行项目问答和对项目进行修改。",
    )
    return [plain_chat, web_qa, project_qa]
|
|
|
|
|
|
2024-07-09 03:37:30 +00:00
|
|
|
|
|
2024-07-08 07:17:28 +00:00
|
|
|
|
@cl.on_chat_start
async def start():
    """Initialise a new chat session.

    Shows the temperature / top_p sliders and stores their initial values
    plus an empty message history in the user session. For the
    "上传本地项目" profile it also asks for a project zip, passes it to
    ``unzip_file`` with the "repodata" directory, and stores the result of
    ``get_project_files_with_content`` under "project_index".
    """
    settings = await cl.ChatSettings(
        [
            Slider(
                id="temperature",
                label="CodeGeeX4 - Temperature",
                initial=0.2,
                min=0,
                max=1,
                step=0.1,
            ),
            Slider(
                id="top_p",
                label="CodeGeeX4 - top_p",
                initial=0.95,
                min=0,
                max=1,
                step=0.1,
            ),
        ]
    ).send()
    cl.user_session.set("temperature", settings["temperature"])
    cl.user_session.set("top_p", settings["top_p"])
    cl.user_session.set("message_history", [])

    chat_profile = cl.user_session.get("chat_profile")
    # Only the project profile needs extra setup.
    if chat_profile != "上传本地项目":
        return

    extract_dir = "repodata"
    files = None
    # Keep asking until AskFileMessage.send() returns something other than None.
    while files is None:
        files = await cl.AskFileMessage(
            content="请上传项目zip压缩文件!",
            accept={"application/zip": [".zip"]},
            max_size_mb=50,
        ).send()

    zip_file = files[0]
    extracted_path = unzip_file(zip_file.path, extract_dir)
    files_list = get_project_files_with_content(extracted_path)
    cl.user_session.set("project_index", files_list)

    if files_list:
        await cl.Message(
            content="已成功上传,您可以开始对项目进行提问!",
        ).send()
    else:
        # Previously an empty project gave the user no feedback at all.
        await cl.Message(
            content="压缩包中没有找到可读取的项目文件,请检查压缩包内容。",
        ).send()
|
2024-07-09 03:37:30 +00:00
|
|
|
|
|
2024-07-08 07:17:28 +00:00
|
|
|
|
|
|
|
|
|
# Prompt-size budget: len(prompt) / 4 is used as a rough token estimate, and
# generation is refused once that estimate reaches this value.
_MAX_ESTIMATED_PROMPT_TOKENS = 120000


def _build_web_prompt(message_history, user_input):
    """Ask the model whether a web search is needed, then build the chat prompt."""
    judge_context = llm.chat(
        web_judge_task_prompt.format(user_input=user_input), temperature=0.2
    )
    print(judge_context)
    # Replace the raw user turn appended by main() with either a
    # search-augmented prompt or the unchanged user text.
    message_history.pop()
    # NOTE(review): plain substring test — an answer like "不是" also contains
    # "是" and would trigger a search; confirm against the expected judge output.
    if "是" in judge_context:
        content = bing_search_prompt(user_input)
    else:
        content = user_input
    message_history.append({"role": "user", "content": content})
    return get_cur_base_user_prompt(message_history=message_history)


def _build_project_prompt(message_history, user_input):
    """Build a prompt that embeds the uploaded project unless the judge answers "正常"."""
    judge_context = llm.chat(
        judge_task_prompt.format(user_input=user_input), temperature=0.2
    )
    print(judge_context)
    if "正常" in judge_context:
        return get_cur_base_user_prompt(message_history=message_history)

    # project_index is stored by start(); treat a missing value as an empty project.
    project_index = cl.user_session.get("project_index") or []
    index_prompt = "".join(
        f"###PATH:{entry['path']}\n{entry['content']}\n" for entry in project_index
    )
    return get_cur_base_user_prompt(
        message_history=message_history,
        index_prompt=index_prompt,
        judge_context=judge_context,
    )


async def _stream_reply(msg, prompt_content):
    """Stream the model's answer to ``prompt_content`` into ``msg``."""
    temperature = cl.user_session.get("temperature")
    top_p = cl.user_session.get("top_p")

    if len(prompt_content) / 4 >= _MAX_ESTIMATED_PROMPT_TOKENS:
        await msg.stream_token("项目太大了,请换小一点的项目。")
        return

    stream = llm.stream_chat(prompt_content, temperature=temperature, top_p=top_p)
    stream_processor = StreamProcessor()
    for part in stream:
        if isinstance(part, str):
            text = stream_processor.get_new_part(part)
        elif isinstance(part, dict):
            text = stream_processor.get_new_part(part["name"] + part["content"])
        else:
            # Unknown chunk type: skip it rather than re-sending a stale delta.
            continue
        # Forward only real deltas; empty ones used to be sent as a stray " ".
        if text:
            await msg.stream_token(text)


@cl.on_message
async def main(message: cl.Message):
    """Handle one user message.

    Builds the prompt for the active chat profile, streams the model's reply
    into a new message, and records the reply in the session's message history.
    """
    chat_profile = cl.user_session.get("chat_profile")
    message_history = cl.user_session.get("message_history")
    message_history.append({"role": "user", "content": message.content})

    if chat_profile == "联网问答":
        prompt_content = _build_web_prompt(message_history, message.content)
    elif chat_profile == "上传本地项目":
        prompt_content = _build_project_prompt(message_history, message.content)
    else:
        # "chat聊天" — also the fallback, so an unexpected profile no longer
        # leaves prompt_content unbound.
        prompt_content = get_cur_base_user_prompt(message_history=message_history)

    msg = cl.Message(content="")
    await msg.send()
    await _stream_reply(msg, prompt_content)

    message_history.append({"role": "assistant", "content": msg.content})
    await msg.update()
|