CodeGeeX4/langchain_demo/chat.py
2024-07-05 09:33:53 +08:00

55 lines
1.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
References: https://python.langchain.com/v0.2/docs/tutorials/rag/
"""
import argparse
import gradio as gr
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from models.codegeex import CodegeexChatModel
from utils.prompts import CUSTOM_RAG_PROMPT
from utils.vector import load_vector_store
def parse_arguments():
    """Parse the demo's command-line options.

    Options (all optional):
        --vector_path: path to load the vectors (default ``vectors``).
        --model_name_or_path: model identifier (default ``THUDM/codegeex4-all-9b``).
        --device: ``cpu`` or ``cuda`` (default ``cpu``); a free-form string,
            not validated here.
        --temperature: the model's temperature as a float (default ``0.2``).

    Returns:
        The ``argparse.Namespace`` parsed from the process command line.
    """
    # Flag name paired with the keyword arguments handed to ``add_argument``;
    # registration order is kept so ``--help`` lists the options as before.
    option_specs = (
        ('--vector_path', {'type': str, 'help': "path to load the vectors", 'default': 'vectors'}),
        ('--model_name_or_path', {'type': str, 'default': 'THUDM/codegeex4-all-9b'}),
        ('--device', {'type': str, 'help': "cpu or cuda", 'default': "cpu"}),
        ('--temperature', {'type': float, 'help': "model's temperature", 'default': 0.2}),
    )
    cli = argparse.ArgumentParser()
    for flag, settings in option_specs:
        cli.add_argument(flag, **settings)
    return cli.parse_args()
def format_docs(docs):
    """Render documents as numbered citation blocks for the prompt context.

    Each item of ``docs`` must expose a ``page_content`` attribute. It is
    rendered as a ``[[citation:N]]`` header (N counts from 1 in iteration
    order) followed by the content wrapped in a markdown code fence.

    Returns:
        All rendered blocks joined by a blank line; an empty string when
        ``docs`` is empty.
    """
    rendered_blocks = []
    for number, document in enumerate(docs, start=1):
        rendered_blocks.append(
            f"[[citation:{number}]]\n```markdown\n{document.page_content}\n```"
        )
    return "\n\n".join(rendered_blocks)
def chat(query, history):
    """Answer ``query`` with retrieval-augmented generation, yielding progress.

    Relies on the module-level ``retriever`` and ``llm`` globals that the
    script entry point sets up before the UI is launched.

    Args:
        query: the question; it is sent to the retriever and also passed
            through unchanged as the prompt's ``question`` input.
        history: accepted but not used here.

    Yields:
        Three progressively longer strings: the rendered prompt text, then
        the prompt followed by a divider line, and finally the prompt,
        divider, and the model's reply.
    """
    # Step 1: retrieve documents, format them, and fill in the RAG prompt.
    context_chain = retriever | format_docs
    prompt_inputs = {"context": context_chain, "question": RunnablePassthrough()}
    prompt_chain = prompt_inputs | CUSTOM_RAG_PROMPT
    prompt_value = prompt_chain.invoke(query)

    prompt_text = prompt_value.text
    yield prompt_text

    # Step 2: show a divider so the reply is visually separated from the prompt.
    with_divider = prompt_text + ("模型回复".center(150, '-') + '\n')
    yield with_divider

    # Step 3: run the model on the rendered prompt and append its string output.
    reply = (llm | StrOutputParser()).invoke(prompt_value)
    yield with_divider + reply
if __name__ == '__main__':
    # Script entry point: parse options, build the model and retriever, then
    # serve the chat UI. ``llm`` and ``retriever`` are deliberately bound at
    # module level because chat() reads them as globals.
    args = parse_arguments()
    llm = CodegeexChatModel(args)
    try:
        retriever = load_vector_store(args.vector_path).as_retriever()
    except Exception as e:
        # Top-level boundary: without a vector store there is nothing to
        # retrieve from, so report the cause and stop. Exit with a non-zero
        # status so wrappers can detect the failure, and raise SystemExit
        # directly rather than relying on the ``exit`` helper installed by
        # the ``site`` module, which is not guaranteed to be available.
        print(f"Fail to load vectors, caused by {e}")
        raise SystemExit(1)
    demo = gr.ChatInterface(chat).queue()
    # Bound to the loopback interface only, on a fixed port.
    demo.launch(server_name="127.0.0.1", server_port=8080)