CodeGeeX4/local_mode/main.py
2024-07-11 23:07:02 +08:00

52 lines
1.4 KiB
Python

"""
coding : utf-8
@Date : 2024/7/10
@Author : Shaobo
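@Describe: Local-mode server exposing an OpenAI-style chat-completions endpoint (/v1/chat/completions) for CodeGeeX4.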
"""
import argparse
import logging

import torch
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from starlette.responses import StreamingResponse

from protocols.openai_api import ChatCompletionRequest
from services.chat import init_model, chat_with_codegeex, stream_chat_with_codegeex
# The single ASGI application; every route below is registered on it.
app = FastAPI()

# Accept cross-origin requests from any origin, using any HTTP method and any
# request header, so browser-hosted clients can call this server directly.
# NOTE(review): a wildcard origin also lets any web page open in the user's
# browser call the local model server — confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
def _str2bool(value):
    """Convert a textual boolean option value (e.g. ``"true"``, ``"0"``) to ``bool``.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised spelling;
            argparse turns this into a normal usage error.
    """
    normalized = value.strip().lower()
    if normalized in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if normalized in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_arguments(argv=None):
    """Parse the command-line options for the local CodeGeeX4 server.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, matching the previous no-argument call.

    Returns:
        An ``argparse.Namespace`` with:

        - ``model_name_or_path`` (str): default ``"THUDM/codegeex4-all-9b"``.
        - ``device`` (str): default ``"cuda"`` when
          ``torch.cuda.is_available()`` is true, otherwise ``"cpu"``.
        - ``bf16`` (bool): default ``False``. A bare ``--bf16`` enables it, and
          an explicit value such as ``--bf16 true`` or ``--bf16 false`` is
          also accepted.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="THUDM/codegeex4-all-9b",
        help="model identifier or local path passed to init_model",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="device to run on (defaults to cuda when available, else cpu)",
    )
    # ``type=bool`` is a trap: argparse hands over the raw string and
    # bool("False") is True, so "--bf16 False" would enable bf16. Parse the text
    # explicitly. nargs="?" with const=True also allows a bare "--bf16" switch.
    parser.add_argument(
        "--bf16",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
        help="use bfloat16 weights (bare flag, or an explicit true/false)",
    )
    return parser.parse_args(argv)
@app.post("/v1/chat/completions")
async def chat(request: ChatCompletionRequest):
    """Handle an OpenAI-style chat-completion request.

    Args:
        request: The parsed request body. ``request.stream`` selects between a
            server-sent-events stream and a single JSON response.

    Returns:
        - A ``StreamingResponse`` (media type ``text/event-stream``) over
          ``stream_chat_with_codegeex(request)`` when ``request.stream`` is
          truthy.
        - Otherwise a ``JSONResponse`` wrapping ``chat_with_codegeex(request)``.
        - If an exception is raised while building either response, a
          ``JSONResponse`` with status 500 and body
          ``{"error": {"message": ..., "type": ...}}``.
    """
    try:
        if request.stream:
            # NOTE(review): if stream_chat_with_codegeex is a generator, its
            # body only runs while Starlette iterates the response, i.e. after
            # this handler has returned, so errors raised mid-stream are NOT
            # caught by the except clause below.
            return StreamingResponse(
                stream_chat_with_codegeex(request), media_type="text/event-stream"
            )
        # NOTE(review): chat_with_codegeex is called synchronously (no await)
        # inside an async handler, so the event loop is blocked until it returns.
        return JSONResponse(chat_with_codegeex(request))
    except Exception as e:
        # Top-level request boundary: record the traceback, then report the
        # failure to the client. An exception object is not JSON serializable,
        # and JSONResponse serializes eagerly. It must therefore receive a
        # plain dict, or the handler itself would raise TypeError here.
        logging.getLogger(__name__).exception("chat completion request failed")
        return JSONResponse(
            {"error": {"message": str(e), "type": type(e).__name__}},
            status_code=500,
        )
if __name__ == "__main__":
    # Parse the CLI options and hand them to init_model before the server
    # starts, so initialisation happens once, ahead of any request.
    args = parse_arguments()
    init_model(args)

    # Serve on the loopback interface only, on a fixed port.
    server_host, server_port = "127.0.0.1", 8080
    uvicorn.run(app, host=server_host, port=server_port)