"""
coding    : utf-8
@Date     : 2024/7/10
@Author   : Shaobo
@Describe : Wrapper around a CodeGeeX chat model that serves streaming and
            blocking chat completions using the OpenAI-style protocol objects.
"""
import logging

import torch
from sseclient import Event
from transformers import AutoTokenizer, AutoModel

from protocols.openai_api import ChatCompletionRequest, ChatCompletionStreamResponse, ChatCompletionResponse

logger = logging.getLogger(__name__)


class CodegeexChatModel:
    """Load a tokenizer/model pair once and answer chat completion requests.

    Attributes:
        tokenizer: Tokenizer loaded from ``args.model_name_or_path``.
        model: Model loaded from the same path, moved to ``args.device`` and
            switched to eval mode.
    """

    def __init__(self, args):
        """Load the tokenizer and the model described by ``args``.

        Args:
            args: Object exposing ``model_name_or_path``, ``bf16`` and
                ``device`` attributes. When ``args.bf16`` is truthy the model
                is loaded with ``torch_dtype=torch.bfloat16``; otherwise no
                dtype is passed to ``from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)

        # Build the keyword arguments once so that both precision modes share
        # a single loading call.
        load_kwargs = {"trust_remote_code": True}
        if args.bf16:
            load_kwargs["torch_dtype"] = torch.bfloat16
        self.model = AutoModel.from_pretrained(args.model_name_or_path, **load_kwargs).to(args.device).eval()

        print("Model is initialized.")

    @staticmethod
    def _sse_event(resp, data=None):
        """Wrap ``resp`` in a 'message' event and return ``Event.dump()``.

        Args:
            resp: Stream response object; its ``id`` becomes the event id.
            data: Event payload. Defaults to the JSON serialization of
                ``resp`` when omitted.
        """
        if data is None:
            data = resp.model_dump_json()
        return Event(id=resp.id, data=data, event='message').dump()

    def stream_chat(self, request: ChatCompletionRequest):
        """Generate a reply incrementally and yield it as dumped events.

        One event is yielded per item produced by ``model.stream_generate``,
        carrying only the text that was not part of the previous event. A
        final event with ``finish_reason == 'stop'`` closes the stream.

        If an exception is raised, it is logged and a single event is yielded
        whose data is a plain-text error message (not JSON) and whose
        ``finish_reason`` is ``'stop'``.

        Args:
            request: Chat completion request; ``messages``, ``max_tokens``,
                ``temperature``, ``top_p`` and ``presence_penalty`` are read.

        Yields:
            The value returned by ``Event.dump()`` for each event.
        """
        try:
            inputs = self.tokenizer.apply_chat_template(
                conversation=[msg.model_dump() for msg in request.messages],
                add_generation_prompt=True,
                return_tensors="pt",
                return_dict=True,
            ).to(self.model.device)
            gen_configs = {
                "max_new_tokens": request.max_tokens,
                "temperature": request.temperature,
                "top_p": request.top_p,
                # NOTE(review): the request's presence_penalty is forwarded as
                # repetition_penalty; confirm the two share the same scale.
                "repetition_penalty": request.presence_penalty,
            }

            # The prompt length is constant for the whole generation, so it is
            # computed once instead of on every chunk.
            prompt_len = len(inputs["input_ids"][0])
            length = 0  # number of characters already sent to the client
            for i, outputs in enumerate(self.model.stream_generate(**inputs, **gen_configs)):
                # Skip the prompt tokens and the last token of the sequence
                # (presumably an end/stop token -- TODO confirm).
                response = self.tokenizer.decode(outputs.tolist()[0][prompt_len:-1])
                resp = ChatCompletionStreamResponse()
                # NOTE(review): index is set to the chunk counter, not to a
                # choice position; confirm clients expect this.
                resp.choices[0].index = i
                resp.choices[0].delta.content = response[length:]
                yield self._sse_event(resp)
                length = len(response)

            resp = ChatCompletionStreamResponse()
            resp.choices[0].finish_reason = 'stop'
            yield self._sse_event(resp)
        except Exception as e:
            # Boundary handler: keep the traceback on the server side and
            # still report the failure to the client.
            logger.exception("stream_chat failed")
            resp = ChatCompletionStreamResponse()
            resp.choices[0].finish_reason = 'stop'
            yield self._sse_event(resp, data=f"请求报错,错误原因:{e}")

    def chat(self, request: ChatCompletionRequest):
        """Generate a complete reply for ``request`` in a single call.

        The last message's content is used as the query and all preceding
        messages are passed as history to ``model.chat``.

        Args:
            request: Chat completion request; ``messages``, ``max_tokens``,
                ``temperature``, ``top_p`` and ``presence_penalty`` are read.

        Returns:
            ``ChatCompletionResponse.model_dump()`` with the generated text
            and ``finish_reason == 'stop'`` on success; a plain error-message
            string if an exception is raised.
        """
        try:
            response, _ = self.model.chat(
                self.tokenizer,
                query=request.messages[-1].content,
                history=[msg.model_dump() for msg in request.messages[:-1]],
                max_new_tokens=request.max_tokens,
                temperature=request.temperature,
                top_p=request.top_p,
                repetition_penalty=request.presence_penalty,
            )
            resp = ChatCompletionResponse()
            resp.choices[0].message.content = response
            resp.choices[0].finish_reason = 'stop'
            return resp.model_dump()
        except Exception as e:
            # Boundary handler: log the traceback, then return the same error
            # string callers already receive.
            logger.exception("chat failed")
            return f"请求报错,错误原因:{e}"