"""
|
|
|
|
|
coding : utf-8
|
|
|
|
|
@Date : 2024/7/10
|
|
|
|
|
@Author : Shaobo
|
|
|
|
|
@Describe:
|
|
|
|
|
"""
import torch
from sseclient import Event
from transformers import AutoModel, AutoTokenizer

from protocols.openai_api import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionStreamResponse,
)


class CodegeexChatModel:
    """Wraps a CodeGeeX checkpoint behind blocking `chat` and streaming `stream_chat` calls."""

    def __init__(self, args):
        self.tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
        # Load in bfloat16 when requested; otherwise keep the checkpoint's default dtype.
        self.model = AutoModel.from_pretrained(
            args.model_name_or_path,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16 if args.bf16 else None,
        ).to(args.device).eval()
        print("Model is initialized.")

    def stream_chat(self, request: ChatCompletionRequest):
        """Generate a reply incrementally, yielding one SSE event per decoded chunk."""
        try:
            inputs = self.tokenizer.apply_chat_template(
                conversation=[msg.model_dump() for msg in request.messages],
                add_generation_prompt=True,
                return_tensors="pt",
                return_dict=True,
            ).to(self.model.device)
            gen_configs = {
                "max_new_tokens": request.max_tokens,
                "temperature": request.temperature,
                "top_p": request.top_p,
                # The OpenAI presence_penalty field is mapped onto HF's repetition_penalty.
                "repetition_penalty": request.presence_penalty,
                # Sample only when a non-zero temperature is set; otherwise decode greedily.
                "do_sample": bool(request.temperature),
            }
            length = 0  # number of characters already sent to the client
            for outputs in self.model.stream_generate(**inputs, **gen_configs):
                # Drop the prompt tokens and the trailing EOS token before decoding.
                response = self.tokenizer.decode(outputs.tolist()[0][len(inputs["input_ids"][0]):-1])
                # Skip empty chunks and chunks ending in U+FFFD, i.e. a partially decoded character.
                if not response or response[-1] == "\ufffd":
                    continue
                resp = ChatCompletionStreamResponse()
                resp.choices[0].delta.content = response[length:]
                event = Event(data=resp.model_dump_json(), event='message')
                yield event.dump()
                length = len(response)
            # Close the stream with an empty delta carrying finish_reason='stop'.
            resp = ChatCompletionStreamResponse()
            resp.choices[0].finish_reason = 'stop'
            event = Event(data=resp.model_dump_json(), event='message')
            yield event.dump()
        except Exception as e:
            # Surface the failure to the client as a plain-text SSE event.
            event = Event(data=f"Request failed: {e}", event='message')
            yield event.dump()

    def chat(self, request: ChatCompletionRequest):
        """Generate a complete reply in one call and return an OpenAI-style response dict."""
        try:
            # The last message is the current query; everything before it is the history.
            response, _ = self.model.chat(
                self.tokenizer,
                query=request.messages[-1].content,
                history=[msg.model_dump() for msg in request.messages[:-1]],
                max_new_tokens=request.max_tokens,
                temperature=request.temperature,
                top_p=request.top_p,
                repetition_penalty=request.presence_penalty,
            )
            resp = ChatCompletionResponse()
            resp.choices[0].message.content = response
            resp.choices[0].finish_reason = 'stop'
            return resp.model_dump()
        except Exception as e:
            return f"Request failed: {e}"
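

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal sketch of how this class might be driven, assuming the protocols.openai_api
# models follow the OpenAI chat schema (messages / max_tokens / temperature / top_p /
# presence_penalty). The args object, checkpoint name, and request values below are
# hypothetical stand-ins for the server's real CLI arguments.
if __name__ == "__main__":
    from types import SimpleNamespace

    # Hypothetical arguments; the real server would build these from argparse.
    args = SimpleNamespace(
        model_name_or_path="THUDM/codegeex4-all-9b",  # assumed checkpoint name
        device="cuda",
        bf16=True,
    )
    model = CodegeexChatModel(args)

    # Hypothetical request; field names mirror those accessed in the methods above,
    # and the message dict is assumed to be coerced into the pydantic message model.
    request = ChatCompletionRequest(
        messages=[{"role": "user", "content": "Write a quicksort in Python."}],
        max_tokens=512,
        temperature=0.2,
        top_p=0.95,
        presence_penalty=1.0,
    )

    # Blocking call: returns a dict shaped like an OpenAI chat completion.
    print(model.chat(request))

    # Streaming call: each item is a serialized SSE event.
    for chunk in model.stream_chat(request):
        print(chunk, end="")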