CodeGeeX4/local_mode/models/codegeex.py

"""
coding : utf-8
@Date : 2024/7/10
@Author : Shaobo
@Describe:
"""
import torch
from protocols.openai_api import ChatCompletionRequest, ChatCompletionStreamResponse, ChatCompletionResponse
from sseclient import Event
from transformers import AutoTokenizer, AutoModel
# System prompt (in Chinese): "You are an intelligent programming assistant named CodeGeeX. You answer
# users' questions about programming, code, and computers, and provide well-formatted, executable,
# accurate, and safe code, with detailed explanations when necessary."
SYS_PROMPT = "你是一位智能编程助手你叫CodeGeeX。你会为用户回答关于编程、代码、计算机方面的任何问题并提供格式规范、可以执行、准确安全的代码并在必要时提供详细的解释。"


class CodegeexChatModel:
    def __init__(self, args):
        self.tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
        if args.bf16:
            # Load weights in bfloat16 to reduce memory use when the flag is set.
            self.model = AutoModel.from_pretrained(
                args.model_name_or_path,
                trust_remote_code=True,
                torch_dtype=torch.bfloat16,
            ).to(args.device).eval()
        else:
            self.model = AutoModel.from_pretrained(
                args.model_name_or_path,
                trust_remote_code=True
            ).to(args.device).eval()
        print("Model is initialized.")

    def stream_chat(self, request: ChatCompletionRequest):
        """Stream a chat completion as server-sent events, one event per generated chunk."""
        try:
            length = 0
            for i, (response, _) in enumerate(self.model.stream_chat(
                    self.tokenizer,
                    query=request.messages[-1].content,
                    history=[msg.model_dump() for msg in request.messages[:-1]],
                    max_new_tokens=request.max_tokens,
                    temperature=request.temperature,
                    top_p=request.top_p,
                    repetition_penalty=request.presence_penalty
            )):
                resp = ChatCompletionStreamResponse()
                resp.choices[0].index = i
                # `response` holds the full text generated so far; send only the new suffix.
                resp.choices[0].delta.content = response[length:]
                event = Event(id=resp.id, data=resp.json(), event='message')
                yield event.dump()
                length = len(response)
            # Emit a final, empty chunk marking the end of the stream.
            resp = ChatCompletionStreamResponse()
            resp.choices[0].finish_reason = 'stop'
            event = Event(id=resp.id, data=resp.json(), event='message')
            yield event.dump()
        except Exception as e:
            resp = ChatCompletionStreamResponse()
            resp.choices[0].finish_reason = 'stop'
            # Error message (Chinese): "Request failed, reason: {e}"
            event = Event(id=resp.id, data=f"请求报错,错误原因:{e}", event='message')
            yield event.dump()

    def chat(self, request: ChatCompletionRequest):
        """Run a single, non-streaming chat completion and return it as a plain dict."""
        try:
            response, _ = self.model.chat(
                self.tokenizer,
                # The last message is the current user query; earlier messages form the history.
                query=request.messages[-1].content,
                history=[msg.model_dump() for msg in request.messages[:-1]],
                max_new_tokens=request.max_tokens,
                temperature=request.temperature,
                top_p=request.top_p,
                repetition_penalty=request.presence_penalty
            )
            resp = ChatCompletionResponse()
            resp.choices[0].message.content = response
            resp.choices[0].finish_reason = 'stop'
            # event = Event(id=resp.id, data=resp.json(), event='message')
            # return event.dump()
            return resp.model_dump()
        except Exception as e:
            # Error message (Chinese): "Request failed, reason: {e}"
            return f"请求报错,错误原因:{e}"