"""
A Gradio chat demo for retrieval-augmented Q&A over a local vector index,
with answers synthesized by the CodeGeeX4 model.

References: https://docs.llamaindex.ai/en/stable/use_cases/q_and_a/
"""
import argparse

import gradio as gr
from llama_index.core import Settings

from models.embedding import GLMEmbeddings
from models.synthesizer import CodegeexSynthesizer
from utils.vector import load_vectors


def parse_arguments():
    """Parse the demo's command-line options."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--vector_path', type=str, help="path to the stored vectors", default='vectors')
    parser.add_argument('--model_name_or_path', type=str, default='THUDM/codegeex4-all-9b')
    parser.add_argument('--device', type=str, help="cpu or cuda", default="cpu")
    parser.add_argument('--temperature', type=float, help="sampling temperature for the model", default=0.2)
    return parser.parse_args()
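

# A sketch of how the demo is typically launched; the script filename
# (`demo.py`) is an assumption, not fixed by this file:
#
#   python demo.py --vector_path vectors --device cuda --temperature 0.2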


def chat(query, history):
    """Streaming chat handler for gr.ChatInterface.

    `history` is supplied by Gradio but unused here: each query is answered
    independently against the vector index.
    """
    resp = query_engine.query(query)

    # Stream back the retrieved source documents first...
    ans = "Related Documents".center(150, '-') + '\n'
    yield ans
    for i, node in enumerate(resp.source_nodes):
        file_name = node.metadata['filename']
        ext = node.metadata['extension']
        text = node.text
        ans += f"File{i + 1}: {file_name}\n```{ext}\n{text}\n```\n"
        yield ans

    # ...then append the synthesized answer from the model.
    ans += "Model Response".center(150, '-') + '\n'
    ans += resp.response
    yield ans
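

# Note: because `chat` is a generator, gr.ChatInterface streams each yielded
# string to the UI, re-rendering the reply as it grows. A minimal standalone
# sketch of the same pattern (hypothetical tokens, not this demo's output):
#
#   def stream(message, history):
#       text = ""
#       for token in ["Hello", ", ", "world"]:
#           text += token
#           yield text
#
#   gr.ChatInterface(stream).launch()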


if __name__ == '__main__':
    args = parse_arguments()
    # Use the GLM embedding model for query-time embeddings.
    Settings.embed_model = GLMEmbeddings()
    try:
        # Load the persisted vector index and attach the CodeGeeX synthesizer.
        query_engine = load_vectors(args.vector_path).as_query_engine(
            response_synthesizer=CodegeexSynthesizer(args)
        )
    except Exception as e:
        print(f"Failed to load vectors: {e}")
        exit()
    demo = gr.ChatInterface(chat).queue()
    demo.launch(server_name="127.0.0.1", server_port=8080)
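

# A minimal sketch of building the vector store this demo expects, assuming
# `load_vectors` reads an index persisted in llama_index's default storage
# layout (the repo's own indexing utility may differ):
#
#   from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
#   documents = SimpleDirectoryReader('docs').load_data()
#   index = VectorStoreIndex.from_documents(documents)
#   index.storage_context.persist(persist_dir='vectors')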