CodeGeeX4/repodemo/run.py

166 lines
5.9 KiB
Python
Raw Normal View History

2024-07-05 01:33:53 +00:00
import chainlit as cl
from chainlit.input_widget import Slider
from llm.api.codegeex4 import codegeex4
2024-07-09 03:37:30 +00:00
from prompts.base_prompt import (
judge_task_prompt,
get_cur_base_user_prompt,
web_judge_task_prompt,
)
from utils.tools import unzip_file, get_project_files_with_content
2024-07-05 01:33:53 +00:00
from utils.bingsearch import bing_search_prompt
@cl.set_chat_profiles
async def chat_profile():
    """Return the chat profiles offered in the UI.

    Three profiles are exposed: plain multi-turn chat (with starter prompts),
    web-augmented Q&A, and project-level Q&A over an uploaded zip archive.
    """
    # Each starter uses the same text for its button label and for the
    # message that is sent when the button is clicked.
    starter_prompts = (
        "请你用python写一个快速排序。",
        "请你介绍一下自己。",
        "用 Python 编写一个脚本来自动发送每日电子邮件报告,并指导我如何进行设置。",
        "我是一个python初学者请你告诉我怎么才能学好python。",
    )
    plain_chat = cl.ChatProfile(
        name="chat聊天",
        markdown_description="聊天demo支持多轮对话。",
        starters=[cl.Starter(label=text, message=text) for text in starter_prompts],
    )
    web_chat = cl.ChatProfile(
        name="联网问答",
        markdown_description="联网能力demo支持联网回答用户问题。",
    )
    project_chat = cl.ChatProfile(
        name="上传本地项目",
        markdown_description="项目级能力demo支持上传本地zip压缩包项目可以进行项目问答和对项目进行修改。",
    )
    return [plain_chat, web_chat, project_chat]
2024-07-09 03:37:30 +00:00
2024-07-05 01:33:53 +00:00
@cl.on_chat_start
async def start():
    """Initialise a new chat session.

    Sends the CodeGeeX4 sampling sliders (temperature, top_p). Stores their
    initial values and an empty ``message_history`` list in the user session.
    For the "上传本地项目" profile it then:

    - asks for a zip archive,
    - extracts it with ``unzip_file``,
    - stores the result of ``get_project_files_with_content`` under
      ``"project_index"``,
    - tells the user whether any project files were found.
    """
    settings = await cl.ChatSettings(
        [
            Slider(
                id="temperature",
                label="CodeGeeX4 - Temperature",
                initial=0.2,
                min=0,
                max=1,
                step=0.1,
            ),
            Slider(
                id="top_p",
                label="CodeGeeX4 - top_p",
                initial=0.95,
                min=0,
                max=1,
                step=0.1,
            ),
        ]
    ).send()
    cl.user_session.set("temperature", settings["temperature"])
    cl.user_session.set("top_p", settings["top_p"])
    cl.user_session.set("message_history", [])

    # Only the project profile needs extra setup; "chat聊天" and "联网问答"
    # are ready as soon as the settings are stored.
    if cl.user_session.get("chat_profile") != "上传本地项目":
        return

    # NOTE(review): every session extracts into the same "repodata" directory;
    # concurrent uploads may overwrite each other depending on how unzip_file
    # lays out its output. TODO confirm.
    extract_dir = "repodata"

    files = None
    # AskFileMessage can come back with None (presumably when the user does
    # not answer in time), so keep asking until a file is provided.
    while files is None:
        files = await cl.AskFileMessage(
            content="请上传项目zip压缩文件!",
            accept={"application/zip": [".zip"]},
            max_size_mb=50,
        ).send()

    archive = files[0]
    extracted_path = unzip_file(archive.path, extract_dir)
    files_list = get_project_files_with_content(extracted_path)
    cl.user_session.set("project_index", files_list)

    if files_list:
        await cl.Message(
            content="已成功上传,您可以开始对项目进行提问!",
        ).send()
    else:
        # Previously the user got no feedback at all in this case.
        await cl.Message(
            content="未在压缩包中找到可读取的项目文件,请检查压缩包内容后新建对话重新上传。",
        ).send()
2024-07-09 03:37:30 +00:00
2024-07-05 01:33:53 +00:00
# Upper bound on the rough token estimate (characters / 4) of a prompt that is
# still sent to the model; larger prompts get an error message instead.
_MAX_PROMPT_TOKENS = 120000

# One entry of the project-index section of the prompt.
_INDEX_ENTRY_TEMPLATE = """###PATH:{path}\n{code}\n"""


def _join_stream(parts):
    """Concatenate the chunks yielded by a streaming ``codegeex4`` call.

    Chunks are joined with no separator, so a keyword that the model emits
    across several chunks is reassembled intact. Empty or ``None`` chunks are
    skipped.
    """
    return "".join(part for part in parts if part)


def _build_project_index_prompt(project_index):
    """Render the uploaded project's files as ``###PATH:`` prompt sections.

    Each entry of ``project_index`` is read through its ``"path"`` and
    ``"content"`` keys. ``None`` (nothing stored in the session) yields an
    empty string.
    """
    return "".join(
        _INDEX_ENTRY_TEMPLATE.format(path=entry["path"], code=entry["content"])
        for entry in project_index or []
    )


@cl.on_message
async def main(message: cl.Message):
    """Answer one user message according to the session's chat profile.

    Branches by profile:

    - "联网问答": ask the model whether the question needs a web search. The
      user turn is then replaced with either a Bing-search prompt or the raw
      message.
    - "上传本地项目": ask the model to classify the request. Unless the
      classification contains "正常", the uploaded project's files and the
      classification are added to the prompt.
    - "chat聊天" (and any other or missing profile): plain multi-turn chat.

    The answer is streamed into a new message and appended to the session
    history as an assistant turn.
    """
    chat_profile = cl.user_session.get("chat_profile")
    message_history = cl.user_session.get("message_history")
    message_history.append({"role": "user", "content": message.content})

    if chat_profile == "联网问答":
        judge_context = _join_stream(
            codegeex4(
                web_judge_task_prompt.format(user_input=message.content),
                temperature=0.2,
                top_p=0.95,
            )
        )
        print(judge_context)
        # Drop the raw user turn; it is re-added below, possibly augmented
        # with search results.
        message_history.pop()
        # NOTE(review): `"" in judge_context` is always True, so every message
        # goes through Bing search. The marker the judge is expected to emit
        # appears to have been lost; confirm it against web_judge_task_prompt
        # and restore it here.
        if "" in judge_context:
            user_turn = bing_search_prompt(message.content)
        else:
            user_turn = message.content
        message_history.append({"role": "user", "content": user_turn})
        prompt_content = get_cur_base_user_prompt(message_history=message_history)
    elif chat_profile == "上传本地项目":
        judge_context = _join_stream(
            codegeex4(
                judge_task_prompt.format(user_input=message.content),
                temperature=0.2,
                top_p=0.95,
            )
        )
        print(judge_context)
        if "正常" in judge_context:
            prompt_content = get_cur_base_user_prompt(message_history=message_history)
        else:
            # Only build the (potentially large) project index when it is used.
            prompt_content = get_cur_base_user_prompt(
                message_history=message_history,
                index_prompt=_build_project_index_prompt(
                    cl.user_session.get("project_index")
                ),
                judge_context=judge_context,
            )
    else:
        # "chat聊天", or an unexpected/missing profile: plain chat instead of
        # failing later on an unbound prompt_content.
        prompt_content = get_cur_base_user_prompt(message_history=message_history)

    msg = cl.Message(content="")
    await msg.send()
    temperature = cl.user_session.get("temperature")
    top_p = cl.user_session.get("top_p")
    # NOTE(review): len(...) / 4 is a rough characters-per-token estimate and
    # assumes get_cur_base_user_prompt returns a string. TODO confirm.
    if len(prompt_content) / 4 < _MAX_PROMPT_TOKENS:
        for part in codegeex4(prompt_content, temperature=temperature, top_p=top_p):
            # Skip empty/None chunks instead of streaming a placeholder space.
            if part:
                await msg.stream_token(part)
    else:
        await msg.stream_token("项目太大了,请换小一点的项目。")
    message_history.append({"role": "assistant", "content": msg.content})
    await msg.update()