CodeGeeX4/repodemo/llm/local/codegeex4.py

51 lines
1.6 KiB
Python
Raw Normal View History

import torch
from pydantic import Field
from transformers import AutoModel, AutoTokenizer
2024-07-09 03:37:30 +00:00
class CodegeexChatModel:
    """Wrapper that loads a CodeGeeX chat model locally and queries it.

    The tokenizer and model are loaded once in ``__init__`` via
    ``AutoTokenizer`` / ``AutoModel`` and reused by :meth:`chat` (one-shot)
    and :meth:`stream_chat` (incremental).
    """

    # Plain annotations only: this class is not a pydantic model, so assigning
    # ``Field(...)`` here would merely leave inert ``FieldInfo`` objects behind
    # as class attributes instead of describing or validating anything.
    device: str  # "cuda" when torch reports CUDA is available, else "cpu"
    tokenizer: object  # tokenizer loaded from ``model_name_or_path``
    model: object  # model loaded from ``model_name_or_path``, in eval mode
    # Declared for parity with the original attribute list; this class never
    # assigns it. The sampling temperature is supplied per call instead.
    temperature: float

    def __init__(self, model_name_or_path):
        """Load the tokenizer and model and move the model to the device.

        Args:
            model_name_or_path: Identifier or path forwarded unchanged to
                ``AutoTokenizer.from_pretrained`` and
                ``AutoModel.from_pretrained``.
        """
        super().__init__()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # SECURITY NOTE: ``trust_remote_code=True`` lets the checkpoint's own
        # Python code run during loading. Only point this at checkpoints you
        # trust.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path, trust_remote_code=True
        )
        self.model = (
            AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
            .to(self.device)
            .eval()
        )
        print("Model has been initialized.")

    @staticmethod
    def _generation_kwargs(prompt, temperature, top_p, max_length):
        """Build the keyword arguments shared by ``chat`` and ``stream_chat``."""
        return {
            "query": prompt,
            "max_length": max_length,
            "temperature": temperature,
            "top_p": top_p,
        }

    def chat(self, prompt, temperature=0.2, top_p=0.95, max_length=120000):
        """Run one chat turn and return the model's response.

        Args:
            prompt: Text passed to the model as ``query``.
            temperature: Forwarded to ``model.chat`` as ``temperature``.
            top_p: Forwarded to ``model.chat`` as ``top_p``.
            max_length: Forwarded to ``model.chat`` as ``max_length``.

        Returns:
            The first element of the pair returned by ``model.chat``. If the
            call raises, the string ``"error: <exception message>"`` is
            returned instead of propagating the exception.
        """
        try:
            response, _ = self.model.chat(
                self.tokenizer,
                **self._generation_kwargs(prompt, temperature, top_p, max_length),
            )
            return response
        except Exception as e:
            # Deliberate best-effort: failures are reported to the caller as
            # text rather than raised.
            return f"error: {e}"

    def stream_chat(self, prompt, temperature=0.2, top_p=0.95, max_length=120000):
        """Stream a chat turn, yielding each response the model produces.

        Args:
            prompt: Text passed to the model as ``query``.
            temperature: Forwarded to ``model.stream_chat`` as ``temperature``.
            top_p: Forwarded to ``model.stream_chat`` as ``top_p``.
            max_length: Forwarded to ``model.stream_chat`` as ``max_length``.

        Yields:
            The first element of every pair produced by ``model.stream_chat``.
            If the call or the iteration raises, a final
            ``"error: <exception message>"`` string is yielded and the stream
            ends.
        """
        try:
            for response, _ in self.model.stream_chat(
                self.tokenizer,
                **self._generation_kwargs(prompt, temperature, top_p, max_length),
            ):
                yield response
        except Exception as e:
            # Same best-effort contract as ``chat``: surface the failure as a
            # final streamed item instead of raising.
            yield f"error: {e}"