mirror of
https://github.com/JasonYANG170/CodeGeeX4.git
synced 2024-11-23 20:26:29 +00:00
55 lines
1.7 KiB
Python
55 lines
1.7 KiB
Python
|
"""
|
|||
|
References: https://python.langchain.com/v0.2/docs/tutorials/rag/
|
|||
|
"""
|
|||
|
import argparse
|
|||
|
|
|||
|
import gradio as gr
|
|||
|
from langchain_core.output_parsers import StrOutputParser
|
|||
|
from langchain_core.runnables import RunnablePassthrough
|
|||
|
|
|||
|
from models.codegeex import CodegeexChatModel
|
|||
|
from utils.prompts import CUSTOM_RAG_PROMPT
|
|||
|
from utils.vector import load_vector_store
|
|||
|
|
|||
|
|
|||
|
def parse_arguments():
|
|||
|
parser = argparse.ArgumentParser()
|
|||
|
parser.add_argument('--vector_path', type=str, help="path to load the vectors", default='vectors')
|
|||
|
parser.add_argument('--model_name_or_path', type=str, default='THUDM/codegeex4-all-9b')
|
|||
|
parser.add_argument('--device', type=str, help="cpu or cuda", default="cpu")
|
|||
|
parser.add_argument('--temperature', type=float, help="model's temperature", default=0.2)
|
|||
|
return parser.parse_args()
|
|||
|
|
|||
|
|
|||
|
def format_docs(docs):
|
|||
|
return "\n\n".join(
|
|||
|
[f"[[citation:{i + 1}]]\n```markdown\n{doc.page_content}\n```" for i, doc in enumerate(docs)]
|
|||
|
)
|
|||
|
|
|||
|
|
|||
|
def chat(query, history):
|
|||
|
retrieve_chain = ({"context": retriever | format_docs, "question": RunnablePassthrough()} | CUSTOM_RAG_PROMPT)
|
|||
|
retrieve_output = retrieve_chain.invoke(query)
|
|||
|
|
|||
|
ans = retrieve_output.text
|
|||
|
yield ans
|
|||
|
|
|||
|
ans += "模型回复".center(150, '-') + '\n'
|
|||
|
yield ans
|
|||
|
|
|||
|
parse_chain = (llm | StrOutputParser())
|
|||
|
ans += parse_chain.invoke(retrieve_output)
|
|||
|
yield ans
|
|||
|
|
|||
|
|
|||
|
if __name__ == '__main__':
|
|||
|
args = parse_arguments()
|
|||
|
llm = CodegeexChatModel(args)
|
|||
|
try:
|
|||
|
retriever = load_vector_store(args.vector_path).as_retriever()
|
|||
|
except Exception as e:
|
|||
|
print(f"Fail to load vectors,caused by {e}")
|
|||
|
exit()
|
|||
|
demo = gr.ChatInterface(chat).queue()
|
|||
|
demo.launch(server_name="127.0.0.1", server_port=8080)
|