Mirror of https://github.com/JasonYANG170/CodeGeeX4.git, synced 2024-10-20 01:21:49 +00:00.
55 lines · 1.9 KiB · Python
|
from pydantic import Field
|
||
|
from transformers import AutoModel, AutoTokenizer
|
||
|
from typing import Iterator
|
||
|
import torch
|
||
|
class StreamProcessor:
    """Track a growing string across streaming callbacks.

    Streaming chat APIs emit the full text generated so far on every step;
    this helper remembers what was already seen and returns only the newly
    appended suffix each call.
    """

    def __init__(self):
        # Text that has already been handed to the consumer.
        self.previous_str = ""

    def get_new_part(self, new_str):
        """Return the suffix of `new_str` not yet seen, and remember `new_str`."""
        already_seen = len(self.previous_str)
        self.previous_str = new_str
        return new_str[already_seen:]
class CodegeexChatModel:
    """Thin wrapper around a CodeGeeX chat model loaded via `transformers`.

    Loads the tokenizer and model (onto GPU when available, CPU otherwise)
    and exposes blocking (`chat`) and streaming (`stream_chat`) generation.
    Both methods are best-effort: failures are returned/yielded as
    "error..." strings instead of raising, so callers never see exceptions.
    """

    # NOTE(review): the original declared these with pydantic `Field(...)`,
    # but this class does not subclass `pydantic.BaseModel`, so those Field
    # objects were inert `FieldInfo` instances sitting as class attributes.
    # Plain annotations document the same attribute set without the
    # misleading pydantic usage; the real values are set in __init__.
    device: str            # "cuda" or "cpu", chosen once at init time
    tokenizer: object      # the model's tokenizer (transformers AutoTokenizer)
    model: object          # the CodeGeeX model (transformers AutoModel)
    temperature: float     # documented in the original; never set here — passed per call

    def __init__(self, model_name_or_path):
        """Load tokenizer and model from a local path or hub name.

        `trust_remote_code=True` is required because CodeGeeX ships custom
        modeling code with the checkpoint.
        """
        super().__init__()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path, trust_remote_code=True
        )
        self.model = AutoModel.from_pretrained(
            model_name_or_path,
            trust_remote_code=True,
        ).to(self.device).eval()
        print("Model has been initialized.")

    def chat(self, prompt, temperature=0.2, top_p=0.95):
        """Generate a complete reply for `prompt`.

        Returns the model's response string, or an "error:..." string if
        generation fails for any reason (deliberate best-effort API).
        """
        try:
            response, _ = self.model.chat(
                self.tokenizer,
                query=prompt,
                max_length=4012,  # NOTE(review): 4096 was likely intended — confirm
                temperature=temperature,
                top_p=top_p,
            )
            return response
        except Exception as e:
            return f"error:{e}"

    def stream_chat(self, prompt, temperature=0.2, top_p=0.95):
        """Yield incremental chunks of the reply as the model streams it.

        The underlying `stream_chat` yields the full text so far on every
        step; `StreamProcessor` trims each step down to just the newly
        generated part. On failure a single "error: ..." string is yielded
        instead of raising.
        """
        try:
            stream_processor = StreamProcessor()
            for response, _ in self.model.stream_chat(
                self.tokenizer,
                query=prompt,
                max_length=4012,  # NOTE(review): 4096 was likely intended — confirm
                temperature=temperature,
                top_p=top_p,
            ):
                yield stream_processor.get_new_part(response)
        except Exception as e:
            yield f'error: {e}'