fix pep8 error

This commit is contained in:
XingYu-Zhong 2024-07-09 11:37:30 +08:00
parent 48fff1449b
commit 25870dc0a1
7 changed files with 289 additions and 262 deletions

View File

@ -1,41 +1,39 @@
import requests import requests
import json import json
URL = "" #the url you deploy codegeex service URL = "" # the url you deploy codegeex service
def codegeex4(prompt, temperature=0.8, top_p=0.8): def codegeex4(prompt, temperature=0.8, top_p=0.8):
url = URL url = URL
headers = { headers = {"Content-Type": "application/json"}
'Content-Type': 'application/json'
}
data = { data = {
'inputs': prompt, "inputs": prompt,
'parameters': { "parameters": {
'best_of':1, "best_of": 1,
'do_sample': True, "do_sample": True,
'max_new_tokens': 4012, "max_new_tokens": 4012,
'temperature': temperature, "temperature": temperature,
'top_p': top_p, "top_p": top_p,
'stop': ["<|endoftext|>", "<|user|>", "<|observation|>", "<|assistant|>"], "stop": ["<|endoftext|>", "<|user|>", "<|observation|>", "<|assistant|>"],
} },
} }
response = requests.post(url, json=data, headers=headers, verify=False, stream=True) response = requests.post(url, json=data, headers=headers, verify=False, stream=True)
if response.status_code == 200: if response.status_code == 200:
for line in response.iter_lines(): for line in response.iter_lines():
if line: if line:
decoded_line = line.decode('utf-8').replace('data:', '').strip() decoded_line = line.decode("utf-8").replace("data:", "").strip()
if decoded_line: if decoded_line:
try: try:
content = json.loads(decoded_line) content = json.loads(decoded_line)
token_text = content.get('token', {}).get('text', '') token_text = content.get("token", {}).get("text", "")
if '<|endoftext|>' in token_text: if "<|endoftext|>" in token_text:
break break
yield token_text yield token_text
except json.JSONDecodeError: except json.JSONDecodeError:
continue continue
else: else:
print('请求失败:', response.status_code) print("请求失败:", response.status_code)

View File

@ -3,45 +3,49 @@ from transformers import AutoModel, AutoTokenizer
from typing import Iterator from typing import Iterator
import torch import torch
class CodegeexChatModel():
class CodegeexChatModel:
device: str = Field(description="device to load the model") device: str = Field(description="device to load the model")
tokenizer = Field(description="model's tokenizer") tokenizer = Field(description="model's tokenizer")
model = Field(description="Codegeex model") model = Field(description="Codegeex model")
temperature: float = Field(description="temperature to use for the model.") temperature: float = Field(description="temperature to use for the model.")
def __init__(self,model_name_or_path): def __init__(self, model_name_or_path):
super().__init__() super().__init__()
self.device = "cuda" if torch.cuda.is_available() else "cpu" self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True) self.tokenizer = AutoTokenizer.from_pretrained(
self.model = AutoModel.from_pretrained( model_name_or_path, trust_remote_code=True
model_name_or_path, )
trust_remote_code=True self.model = (
).to(self.device).eval() AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
.to(self.device)
.eval()
)
print("Model has been initialized.") print("Model has been initialized.")
def chat(self, prompt,temperature=0.2,top_p=0.95): def chat(self, prompt, temperature=0.2, top_p=0.95):
try: try:
response, _ = self.model.chat( response, _ = self.model.chat(
self.tokenizer, self.tokenizer,
query=prompt, query=prompt,
max_length=120000, max_length=120000,
temperature=temperature, temperature=temperature,
top_p=top_p top_p=top_p,
) )
return response return response
except Exception as e: except Exception as e:
return f"error:{e}" return f"error:{e}"
def stream_chat(self,prompt,temperature=0.2,top_p=0.95): def stream_chat(self, prompt, temperature=0.2, top_p=0.95):
try: try:
for response, _ in self.model.stream_chat( for response, _ in self.model.stream_chat(
self.tokenizer, self.tokenizer,
query=prompt, query=prompt,
max_length=120000, max_length=120000,
temperature=temperature, temperature=temperature,
top_p=top_p top_p=top_p,
): ):
yield response yield response
except Exception as e: except Exception as e:
yield f'error: {e}' yield f"error: {e}"

View File

@ -4,7 +4,7 @@ repo_system_prompt = """<|system|>\n你是一位智能编程助手你叫CodeG
judge_task_prompt = """<|system|>\n你是一位任务分类专家,请你对用户的输入进行分类(问答/修改/正常),如果用户的输入是对项目进行提问则只需要输出问答两个字,如果用户的输入是对项目进行修改或增加则只需要输出修改两个字,如果用户输入的是一个与项目无关的问题则只需要输出正常两个字。<|user|>\n{user_input}<|assistant|>\n""" judge_task_prompt = """<|system|>\n你是一位任务分类专家,请你对用户的输入进行分类(问答/修改/正常),如果用户的输入是对项目进行提问则只需要输出问答两个字,如果用户的输入是对项目进行修改或增加则只需要输出修改两个字,如果用户输入的是一个与项目无关的问题则只需要输出正常两个字。<|user|>\n{user_input}<|assistant|>\n"""
web_judge_task_prompt ="""<|system|>\n你是一位智能编程助手你叫CodeGeeX。你会为用户回答关于编程、代码、计算机方面的任何问题并提供格式规范、可以执行、准确安全的代码并在必要时提供详细的解释。<|user|>\n{user_input}\n这个问题需要进行联网来回答吗?仅回答“是”或者“否”。<|assistant|>\n""" web_judge_task_prompt = """<|system|>\n你是一位智能编程助手你叫CodeGeeX。你会为用户回答关于编程、代码、计算机方面的任何问题并提供格式规范、可以执行、准确安全的代码并在必要时提供详细的解释。<|user|>\n{user_input}\n这个问题需要进行联网来回答吗?仅回答“是”或者“否”。<|assistant|>\n"""
# judge_task_prompt = """<|system|>\n你是一位任务分类专家请你对用户的输入进行分类问答/修改),如果用户的输入是对项目进行提问则只需要输出问答两个字,如果用户的输入是对项目进行修改或增加则只需要输出修改两个字。<|user|>\n{user_input}<|assistant|>\n""" # judge_task_prompt = """<|system|>\n你是一位任务分类专家请你对用户的输入进行分类问答/修改),如果用户的输入是对项目进行提问则只需要输出问答两个字,如果用户的输入是对项目进行修改或增加则只需要输出修改两个字。<|user|>\n{user_input}<|assistant|>\n"""
web_search_prompy = """ web_search_prompy = """
@ -19,24 +19,27 @@ web_search_prompy = """
- 除了代码和特定的名称和引用外您的答案必须使用与问题相同的语言来撰写 - 除了代码和特定的名称和引用外您的答案必须使用与问题相同的语言来撰写
""".lstrip() """.lstrip()
def get_cur_base_user_prompt(message_history,index_prompt = None,judge_context = ""):
def get_cur_base_user_prompt(message_history, index_prompt=None, judge_context=""):
user_prompt_tmp = """<|user|>\n{user_input}""" user_prompt_tmp = """<|user|>\n{user_input}"""
assistant_prompt_tmp = """<|assistant|>\n{assistant_input}""" assistant_prompt_tmp = """<|assistant|>\n{assistant_input}"""
history_prompt = "" history_prompt = ""
for i,message in enumerate(message_history): for i, message in enumerate(message_history):
if message['role'] == 'user': if message["role"] == "user":
if i==0 and index_prompt is not None: if i == 0 and index_prompt is not None:
history_prompt += "<|user|>\n"+index_prompt+message['content'] history_prompt += "<|user|>\n" + index_prompt + message["content"]
else: else:
history_prompt += user_prompt_tmp.format(user_input=message['content']) history_prompt += user_prompt_tmp.format(user_input=message["content"])
elif message['role'] == 'assistant': elif message["role"] == "assistant":
history_prompt += assistant_prompt_tmp.format(assistant_input=message['content']) history_prompt += assistant_prompt_tmp.format(
assistant_input=message["content"]
)
# print("修改" not in judge_context) # print("修改" not in judge_context)
# print(judge_context) # print(judge_context)
if "修改" not in judge_context: if "修改" not in judge_context:
result = base_system_prompt+history_prompt+"""<|assistant|>\n""" result = base_system_prompt + history_prompt + """<|assistant|>\n"""
else: else:
result = repo_system_prompt+history_prompt+"""<|assistant|>\n""" result = repo_system_prompt + history_prompt + """<|assistant|>\n"""
print(result) print(result)
return result return result

View File

@ -1,8 +1,12 @@
import chainlit as cl import chainlit as cl
from chainlit.input_widget import Slider from chainlit.input_widget import Slider
from llm.api.codegeex4 import codegeex4 from llm.api.codegeex4 import codegeex4
from prompts.base_prompt import judge_task_prompt,get_cur_base_user_prompt,web_judge_task_prompt from prompts.base_prompt import (
from utils.tools import unzip_file,get_project_files_with_content judge_task_prompt,
get_cur_base_user_prompt,
web_judge_task_prompt,
)
from utils.tools import unzip_file, get_project_files_with_content
from utils.bingsearch import bing_search_prompt from utils.bingsearch import bing_search_prompt
@ -12,41 +16,33 @@ async def chat_profile():
cl.ChatProfile( cl.ChatProfile(
name="chat聊天", name="chat聊天",
markdown_description="聊天demo支持多轮对话。", markdown_description="聊天demo支持多轮对话。",
starters = [ starters=[
cl.Starter( cl.Starter(
label="请你用python写一个快速排序。", label="请你用python写一个快速排序。",
message="请你用python写一个快速排序。", message="请你用python写一个快速排序。",
), ),
cl.Starter(
cl.Starter( label="请你介绍一下自己。",
label="请你介绍一下自己。", message="请你介绍一下自己。",
message="请你介绍一下自己。",
), ),
cl.Starter( cl.Starter(
label="用 Python 编写一个脚本来自动发送每日电子邮件报告,并指导我如何进行设置。", label="用 Python 编写一个脚本来自动发送每日电子邮件报告,并指导我如何进行设置。",
message="用 Python 编写一个脚本来自动发送每日电子邮件报告,并指导我如何进行设置。", message="用 Python 编写一个脚本来自动发送每日电子邮件报告,并指导我如何进行设置。",
), ),
cl.Starter( cl.Starter(
label="我是一个python初学者请你告诉我怎么才能学好python。", label="我是一个python初学者请你告诉我怎么才能学好python。",
message="我是一个python初学者请你告诉我怎么才能学好python。", message="我是一个python初学者请你告诉我怎么才能学好python。",
),
) ],
]
), ),
cl.ChatProfile( cl.ChatProfile(
name="联网问答", name="联网问答",
markdown_description="联网能力dome支持联网回答用户问题。", markdown_description="联网能力demo支持联网回答用户问题。",
), ),
cl.ChatProfile( cl.ChatProfile(
name="上传本地项目", name="上传本地项目",
markdown_description="项目级能力dome支持上传本地zip压缩包项目可以进行项目问答和对项目进行修改。", markdown_description="项目级能力demo支持上传本地zip压缩包项目可以进行项目问答和对项目进行修改。",
),
)
] ]
@ -74,34 +70,32 @@ async def start():
).send() ).send()
temperature = settings["temperature"] temperature = settings["temperature"]
top_p = settings["top_p"] top_p = settings["top_p"]
cl.user_session.set('temperature',temperature) cl.user_session.set("temperature", temperature)
cl.user_session.set('top_p',top_p) cl.user_session.set("top_p", top_p)
cl.user_session.set( cl.user_session.set("message_history", [])
"message_history",
[]
)
chat_profile = cl.user_session.get("chat_profile") chat_profile = cl.user_session.get("chat_profile")
extract_dir = 'repodata' extract_dir = "repodata"
if chat_profile == "chat聊天": if chat_profile == "chat聊天":
pass pass
elif chat_profile =="上传本地项目": elif chat_profile == "上传本地项目":
files = None files = None
while files == None: while files == None:
files = await cl.AskFileMessage( files = await cl.AskFileMessage(
content="请上传项目zip压缩文件!", accept={"application/zip": [".zip"]},max_size_mb=50 content="请上传项目zip压缩文件!",
accept={"application/zip": [".zip"]},
max_size_mb=50,
).send() ).send()
text_file = files[0] text_file = files[0]
extracted_path = unzip_file(text_file.path,extract_dir) extracted_path = unzip_file(text_file.path, extract_dir)
files_list = get_project_files_with_content(extracted_path) files_list = get_project_files_with_content(extracted_path)
cl.user_session.set("project_index",files_list) cl.user_session.set("project_index", files_list)
if len(files_list)>0: if len(files_list) > 0:
await cl.Message( await cl.Message(
content=f"已成功上传,您可以开始对项目进行提问!", content=f"已成功上传,您可以开始对项目进行提问!",
).send() ).send()
@cl.on_message @cl.on_message
async def main(message: cl.Message): async def main(message: cl.Message):
chat_profile = cl.user_session.get("chat_profile") chat_profile = cl.user_session.get("chat_profile")
@ -110,42 +104,56 @@ async def main(message: cl.Message):
if chat_profile == "chat聊天": if chat_profile == "chat聊天":
prompt_content = get_cur_base_user_prompt(message_history=message_history) prompt_content = get_cur_base_user_prompt(message_history=message_history)
elif chat_profile=="联网问答": elif chat_profile == "联网问答":
judge_tmp = codegeex4(web_judge_task_prompt.format(user_input=message.content),temperature=0.2,top_p = 0.95) judge_tmp = codegeex4(
judge_context = '\n'.join(judge_tmp) web_judge_task_prompt.format(user_input=message.content),
temperature=0.2,
top_p=0.95,
)
judge_context = "\n".join(judge_tmp)
print(judge_context) print(judge_context)
message_history.pop() message_history.pop()
if '' in judge_context: if "" in judge_context:
prompt_tmp = bing_search_prompt(message.content) prompt_tmp = bing_search_prompt(message.content)
message_history.append({"role": "user", "content": prompt_tmp}) message_history.append({"role": "user", "content": prompt_tmp})
else: else:
message_history.append({"role": "user", "content": message.content}) message_history.append({"role": "user", "content": message.content})
prompt_content = get_cur_base_user_prompt(message_history=message_history) prompt_content = get_cur_base_user_prompt(message_history=message_history)
elif chat_profile =="上传本地项目" : elif chat_profile == "上传本地项目":
judge_tmp = codegeex4(judge_task_prompt.format(user_input=message.content),temperature=0.2,top_p = 0.95) judge_tmp = codegeex4(
judge_context = '' judge_task_prompt.format(user_input=message.content),
temperature=0.2,
top_p=0.95,
)
judge_context = ""
for part in judge_tmp: for part in judge_tmp:
judge_context+=part judge_context += part
project_index = cl.user_session.get("project_index") project_index = cl.user_session.get("project_index")
index_prompt = "" index_prompt = ""
index_tmp = """###PATH:{path}\n{code}\n""" index_tmp = """###PATH:{path}\n{code}\n"""
for index in project_index: for index in project_index:
index_prompt+=index_tmp.format(path=index['path'],code=index['content']) index_prompt += index_tmp.format(path=index["path"], code=index["content"])
print(judge_context) print(judge_context)
prompt_content = get_cur_base_user_prompt(message_history=message_history,index_prompt=index_prompt,judge_context=judge_context) if '正常' not in judge_context else get_cur_base_user_prompt(message_history=message_history) prompt_content = (
get_cur_base_user_prompt(
message_history=message_history,
index_prompt=index_prompt,
judge_context=judge_context,
)
if "正常" not in judge_context
else get_cur_base_user_prompt(message_history=message_history)
)
msg = cl.Message(content="") msg = cl.Message(content="")
await msg.send() await msg.send()
temperature = cl.user_session.get("temperature") temperature = cl.user_session.get("temperature")
top_p = cl.user_session.get('top_p') top_p = cl.user_session.get("top_p")
if len(prompt_content)/4<120000: if len(prompt_content) / 4 < 120000:
stream = codegeex4(prompt_content,temperature=temperature,top_p = top_p) stream = codegeex4(prompt_content, temperature=temperature, top_p=top_p)
for part in stream: for part in stream:
if token := (part or " "): if token := (part or " "):

View File

@ -1,63 +1,62 @@
import chainlit as cl import chainlit as cl
from chainlit.input_widget import Slider from chainlit.input_widget import Slider
from llm.api.codegeex4 import codegeex4 from llm.api.codegeex4 import codegeex4
from prompts.base_prompt import judge_task_prompt,get_cur_base_user_prompt,web_judge_task_prompt from prompts.base_prompt import (
from utils.tools import unzip_file,get_project_files_with_content judge_task_prompt,
get_cur_base_user_prompt,
web_judge_task_prompt,
)
from utils.tools import unzip_file, get_project_files_with_content
from utils.bingsearch import bing_search_prompt from utils.bingsearch import bing_search_prompt
from llm.local.codegeex4 import CodegeexChatModel from llm.local.codegeex4 import CodegeexChatModel
local_model_path = '<your_local_model_path>'
local_model_path = "<your_local_model_path>"
llm = CodegeexChatModel(local_model_path) llm = CodegeexChatModel(local_model_path)
class StreamProcessor: class StreamProcessor:
def __init__(self): def __init__(self):
self.previous_str = "" self.previous_str = ""
def get_new_part(self, new_str): def get_new_part(self, new_str):
new_part = new_str[len(self.previous_str):] new_part = new_str[len(self.previous_str) :]
self.previous_str = new_str self.previous_str = new_str
return new_part return new_part
@cl.set_chat_profiles @cl.set_chat_profiles
async def chat_profile(): async def chat_profile():
return [ return [
cl.ChatProfile( cl.ChatProfile(
name="chat聊天", name="chat聊天",
markdown_description="聊天demo支持多轮对话。", markdown_description="聊天demo支持多轮对话。",
starters = [ starters=[
cl.Starter( cl.Starter(
label="请你用python写一个快速排序。", label="请你用python写一个快速排序。",
message="请你用python写一个快速排序。", message="请你用python写一个快速排序。",
), ),
cl.Starter(
cl.Starter( label="请你介绍一下自己。",
label="请你介绍一下自己。", message="请你介绍一下自己。",
message="请你介绍一下自己。",
), ),
cl.Starter( cl.Starter(
label="用 Python 编写一个脚本来自动发送每日电子邮件报告,并指导我如何进行设置。", label="用 Python 编写一个脚本来自动发送每日电子邮件报告,并指导我如何进行设置。",
message="用 Python 编写一个脚本来自动发送每日电子邮件报告,并指导我如何进行设置。", message="用 Python 编写一个脚本来自动发送每日电子邮件报告,并指导我如何进行设置。",
), ),
cl.Starter( cl.Starter(
label="我是一个python初学者请你告诉我怎么才能学好python。", label="我是一个python初学者请你告诉我怎么才能学好python。",
message="我是一个python初学者请你告诉我怎么才能学好python。", message="我是一个python初学者请你告诉我怎么才能学好python。",
),
) ],
]
), ),
cl.ChatProfile( cl.ChatProfile(
name="联网问答", name="联网问答",
markdown_description="联网能力dome支持联网回答用户问题。", markdown_description="联网能力demo支持联网回答用户问题。",
), ),
cl.ChatProfile( cl.ChatProfile(
name="上传本地项目", name="上传本地项目",
markdown_description="项目级能力dome支持上传本地zip压缩包项目可以进行项目问答和对项目进行修改。", markdown_description="项目级能力demo支持上传本地zip压缩包项目可以进行项目问答和对项目进行修改。",
),
)
] ]
@ -85,34 +84,32 @@ async def start():
).send() ).send()
temperature = settings["temperature"] temperature = settings["temperature"]
top_p = settings["top_p"] top_p = settings["top_p"]
cl.user_session.set('temperature',temperature) cl.user_session.set("temperature", temperature)
cl.user_session.set('top_p',top_p) cl.user_session.set("top_p", top_p)
cl.user_session.set( cl.user_session.set("message_history", [])
"message_history",
[]
)
chat_profile = cl.user_session.get("chat_profile") chat_profile = cl.user_session.get("chat_profile")
extract_dir = 'repodata' extract_dir = "repodata"
if chat_profile == "chat聊天": if chat_profile == "chat聊天":
pass pass
elif chat_profile =="上传本地项目": elif chat_profile == "上传本地项目":
files = None files = None
while files == None: while files == None:
files = await cl.AskFileMessage( files = await cl.AskFileMessage(
content="请上传项目zip压缩文件!", accept={"application/zip": [".zip"]},max_size_mb=50 content="请上传项目zip压缩文件!",
accept={"application/zip": [".zip"]},
max_size_mb=50,
).send() ).send()
text_file = files[0] text_file = files[0]
extracted_path = unzip_file(text_file.path,extract_dir) extracted_path = unzip_file(text_file.path, extract_dir)
files_list = get_project_files_with_content(extracted_path) files_list = get_project_files_with_content(extracted_path)
cl.user_session.set("project_index",files_list) cl.user_session.set("project_index", files_list)
if len(files_list)>0: if len(files_list) > 0:
await cl.Message( await cl.Message(
content=f"已成功上传,您可以开始对项目进行提问!", content=f"已成功上传,您可以开始对项目进行提问!",
).send() ).send()
@cl.on_message @cl.on_message
async def main(message: cl.Message): async def main(message: cl.Message):
chat_profile = cl.user_session.get("chat_profile") chat_profile = cl.user_session.get("chat_profile")
@ -121,45 +118,54 @@ async def main(message: cl.Message):
if chat_profile == "chat聊天": if chat_profile == "chat聊天":
prompt_content = get_cur_base_user_prompt(message_history=message_history) prompt_content = get_cur_base_user_prompt(message_history=message_history)
elif chat_profile=="联网问答": elif chat_profile == "联网问答":
judge_context = llm.chat(web_judge_task_prompt.format(user_input=message.content),temperature=0.2) judge_context = llm.chat(
web_judge_task_prompt.format(user_input=message.content), temperature=0.2
)
print(judge_context) print(judge_context)
message_history.pop() message_history.pop()
if '' in judge_context: if "" in judge_context:
prompt_tmp = bing_search_prompt(message.content) prompt_tmp = bing_search_prompt(message.content)
message_history.append({"role": "user", "content": prompt_tmp}) message_history.append({"role": "user", "content": prompt_tmp})
else: else:
message_history.append({"role": "user", "content": message.content}) message_history.append({"role": "user", "content": message.content})
prompt_content = get_cur_base_user_prompt(message_history=message_history) prompt_content = get_cur_base_user_prompt(message_history=message_history)
elif chat_profile =="上传本地项目" : elif chat_profile == "上传本地项目":
judge_context = llm.chat(judge_task_prompt.format(user_input=message.content),temperature=0.2) judge_context = llm.chat(
judge_task_prompt.format(user_input=message.content), temperature=0.2
)
project_index = cl.user_session.get("project_index") project_index = cl.user_session.get("project_index")
index_prompt = "" index_prompt = ""
index_tmp = """###PATH:{path}\n{code}\n""" index_tmp = """###PATH:{path}\n{code}\n"""
for index in project_index: for index in project_index:
index_prompt+=index_tmp.format(path=index['path'],code=index['content']) index_prompt += index_tmp.format(path=index["path"], code=index["content"])
print(judge_context) print(judge_context)
prompt_content = get_cur_base_user_prompt(message_history=message_history,index_prompt=index_prompt,judge_context=judge_context) if '正常' not in judge_context else get_cur_base_user_prompt(message_history=message_history) prompt_content = (
get_cur_base_user_prompt(
message_history=message_history,
index_prompt=index_prompt,
judge_context=judge_context,
)
if "正常" not in judge_context
else get_cur_base_user_prompt(message_history=message_history)
)
msg = cl.Message(content="") msg = cl.Message(content="")
await msg.send() await msg.send()
temperature = cl.user_session.get("temperature") temperature = cl.user_session.get("temperature")
top_p = cl.user_session.get('top_p') top_p = cl.user_session.get("top_p")
if len(prompt_content)/4<120000: if len(prompt_content) / 4 < 120000:
stream = llm.stream_chat(prompt_content,temperature=temperature,top_p = top_p) stream = llm.stream_chat(prompt_content, temperature=temperature, top_p=top_p)
stream_processor = StreamProcessor() stream_processor = StreamProcessor()
for part in stream: for part in stream:
if isinstance(part, str): if isinstance(part, str):
text = stream_processor.get_new_part(part) text = stream_processor.get_new_part(part)
elif isinstance(part, dict): elif isinstance(part, dict):
text = stream_processor.get_new_part(part['name']+part['content']) text = stream_processor.get_new_part(part["name"] + part["content"])
if token := (text or " "): if token := (text or " "):
await msg.stream_token(token) await msg.stream_token(token)
else: else:

View File

@ -2,7 +2,9 @@ import requests
from bs4 import BeautifulSoup as BS4 from bs4 import BeautifulSoup as BS4
import requests import requests
BING_API_KEY = '<your_bing_api_key>' BING_API_KEY = "<your_bing_api_key>"
def search_with_bing(query: str, search_timeout=30, top_k=6) -> list[dict]: def search_with_bing(query: str, search_timeout=30, top_k=6) -> list[dict]:
""" """
Search with bing and return the contexts. Search with bing and return the contexts.
@ -13,9 +15,9 @@ def search_with_bing(query: str, search_timeout=30, top_k=6) -> list[dict]:
headers={"Ocp-Apim-Subscription-Key": BING_API_KEY}, headers={"Ocp-Apim-Subscription-Key": BING_API_KEY},
params={ params={
"q": query, "q": query,
"responseFilter": ['webpages'], "responseFilter": ["webpages"],
"freshness": 'month', "freshness": "month",
"mkt": 'zh-CN' "mkt": "zh-CN",
}, },
timeout=search_timeout, timeout=search_timeout,
) )
@ -23,25 +25,29 @@ def search_with_bing(query: str, search_timeout=30, top_k=6) -> list[dict]:
json_content = response.json() json_content = response.json()
# print(json_content) # print(json_content)
contexts = json_content["webPages"]["value"][:top_k] contexts = json_content["webPages"]["value"][:top_k]
#logger.info("Web搜索完成") # logger.info("Web搜索完成")
return contexts return contexts
except Exception as e: except Exception as e:
#logger.error(f"搜索失败,错误原因: {e}") # logger.error(f"搜索失败,错误原因: {e}")
print(f"搜索失败,错误原因: {e}") print(f"搜索失败,错误原因: {e}")
return [] return []
def fetch_url(url): def fetch_url(url):
response = requests.get(url) response = requests.get(url)
#use beautifulsoup4 to parse html # use beautifulsoup4 to parse html
soup = BS4(response.text, 'html.parser') soup = BS4(response.text, "html.parser")
plain_text = soup.get_text() plain_text = soup.get_text()
return plain_text return plain_text
def bing_search_prompt(input): def bing_search_prompt(input):
contents = search_with_bing(input, search_timeout=5, top_k=6) contents = search_with_bing(input, search_timeout=5, top_k=6)
citations = "\n\n".join( citations = "\n\n".join(
[f"[[citation:{i + 1}]]\n```markdown\n{item['snippet']}\n```" for i, item in enumerate(contents)] [
f"[[citation:{i + 1}]]\n```markdown\n{item['snippet']}\n```"
for i, item in enumerate(contents)
]
) )
prompt = f"[引用]\n{citations}\n问:{input}\n" prompt = f"[引用]\n{citations}\n问:{input}\n"
return prompt return prompt

View File

@ -2,6 +2,7 @@ import zipfile
import os import os
import json import json
def unzip_file(zip_path, extract_dir): def unzip_file(zip_path, extract_dir):
""" """
解压zip文件到指定目录并在指定目录下创建一个新的目录存放解压后的文件 解压zip文件到指定目录并在指定目录下创建一个新的目录存放解压后的文件
@ -23,7 +24,7 @@ def unzip_file(zip_path, extract_dir):
if not os.path.exists(new_extract_dir): if not os.path.exists(new_extract_dir):
os.makedirs(new_extract_dir) os.makedirs(new_extract_dir)
with zipfile.ZipFile(zip_path, 'r') as zip_ref: with zipfile.ZipFile(zip_path, "r") as zip_ref:
zip_ref.extractall(new_extract_dir) zip_ref.extractall(new_extract_dir)
return new_extract_dir return new_extract_dir
@ -48,83 +49,85 @@ def get_project_files_with_content(project_dir):
relative_path = os.path.relpath(file_path, project_dir) relative_path = os.path.relpath(file_path, project_dir)
if "__MACOSX" in relative_path: if "__MACOSX" in relative_path:
continue continue
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f: with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
content = f.read() content = f.read()
files_list.append({'path': relative_path, 'content': content}) files_list.append({"path": relative_path, "content": content})
else: else:
continue continue
return files_list return files_list
def filter_data(obj): def filter_data(obj):
LANGUAGE_TAG = { LANGUAGE_TAG = {
"c++" : "// C++", "c++": "// C++",
"cpp" : "// C++", "cpp": "// C++",
"c" : "// C", "c": "// C",
"c#" : "// C#", "c#": "// C#",
"c-sharp" : "// C#", "c-sharp": "// C#",
"css" : "/* CSS */", "css": "/* CSS */",
"cuda" : "// Cuda", "cuda": "// Cuda",
"fortran" : "! Fortran", "fortran": "! Fortran",
"go" : "// Go", "go": "// Go",
"html" : "<!-- HTML -->", "html": "<!-- HTML -->",
"java" : "// Java", "java": "// Java",
"js" : "// JavaScript", "js": "// JavaScript",
"javascript" : "// JavaScript", "javascript": "// JavaScript",
"kotlin" : "// Kotlin", "kotlin": "// Kotlin",
"lean" : "-- Lean", "lean": "-- Lean",
"lua" : "-- Lua", "lua": "-- Lua",
"objectivec" : "// Objective-C", "objectivec": "// Objective-C",
"objective-c" : "// Objective-C", "objective-c": "// Objective-C",
"objective-c++": "// Objective-C++", "objective-c++": "// Objective-C++",
"pascal" : "// Pascal", "pascal": "// Pascal",
"php" : "// PHP", "php": "// PHP",
"python" : "# Python", "python": "# Python",
"r" : "# R", "r": "# R",
"rust" : "// Rust", "rust": "// Rust",
"ruby" : "# Ruby", "ruby": "# Ruby",
"scala" : "// Scala", "scala": "// Scala",
"shell" : "# Shell", "shell": "# Shell",
"sql" : "-- SQL", "sql": "-- SQL",
"tex" : f"% TeX", "tex": f"% TeX",
"typescript" : "// TypeScript", "typescript": "// TypeScript",
"vue" : "<!-- Vue -->", "vue": "<!-- Vue -->",
"assembly": "; Assembly",
"assembly" : "; Assembly", "dart": "// Dart",
"dart" : "// Dart", "perl": "# Perl",
"perl" : "# Perl", "prolog": f"% Prolog",
"prolog" : f"% Prolog", "swift": "// swift",
"swift" : "// swift", "lisp": "; Lisp",
"lisp" : "; Lisp", "vb": "' Visual Basic",
"vb" : "' Visual Basic", "visual basic": "' Visual Basic",
"visual basic" : "' Visual Basic", "matlab": f"% Matlab",
"matlab" : f"% Matlab", "delphi": "{ Delphi }",
"delphi" : "{ Delphi }", "scheme": "; Scheme",
"scheme" : "; Scheme", "basic": "' Basic",
"basic" : "' Basic", "assembly": "; Assembly",
"assembly" : "; Assembly", "groovy": "// Groovy",
"groovy" : "// Groovy", "abap": "* Abap",
"abap" : "* Abap", "gdscript": "# GDScript",
"gdscript" : "# GDScript", "haskell": "-- Haskell",
"haskell" : "-- Haskell", "julia": "# Julia",
"julia" : "# Julia", "elixir": "# Elixir",
"elixir" : "# Elixir", "excel": "' Excel",
"excel" : "' Excel", "clojure": "; Clojure",
"clojure" : "; Clojure", "actionscript": "// ActionScript",
"actionscript" : "// ActionScript", "solidity": "// Solidity",
"solidity" : "// Solidity", "powershell": "# PowerShell",
"powershell" : "# PowerShell", "erlang": f"% Erlang",
"erlang" : f"% Erlang", "cobol": "// Cobol",
"cobol" : "// Cobol", "batchfile": ":: Batch file",
"batchfile" : ":: Batch file", "makefile": "# Makefile",
"makefile" : "# Makefile", "dockerfile": "# Dockerfile",
"dockerfile" : "# Dockerfile", "markdown": "<!-- Markdown -->",
"markdown" : "<!-- Markdown -->", "cmake": "# CMake",
"cmake" : "# CMake", "dockerfile": "# Dockerfile",
"dockerfile" : "# Dockerfile",
} }
programming_languages_to_file_extensions = json.load(open('utils/programming-languages-to-file-extensions.json')) programming_languages_to_file_extensions = json.load(
open("utils/programming-languages-to-file-extensions.json")
)
need2del = [] need2del = []
for key in programming_languages_to_file_extensions.keys(): for key in programming_languages_to_file_extensions.keys():
if key.lower() not in LANGUAGE_TAG: if key.lower() not in LANGUAGE_TAG:
@ -140,15 +143,14 @@ def filter_data(obj):
ext_to_programming_languages[item] = key ext_to_programming_languages[item] = key
want_languages.append(item) want_languages.append(item)
ext = '.'+obj.split('.')[-1] ext = "." + obj.split(".")[-1]
with open('utils/keep.txt', 'r') as f: with open("utils/keep.txt", "r") as f:
keep_files = f.readlines() keep_files = f.readlines()
keep_files = [l.strip() for l in keep_files] keep_files = [l.strip() for l in keep_files]
#print(ext) # print(ext)
if ext not in want_languages: if ext not in want_languages:
if obj in keep_files: if obj in keep_files:
return True return True
return False return False
else: else:
return True return True