Add 'local mode' tutorials

shaobo 2024-07-11 23:07:02 +08:00
parent 8be6c39820
commit 81ff1bea82
10 changed files with 299 additions and 0 deletions

local_mode/README.md Normal file

@@ -0,0 +1,37 @@
![](../resources/logo.jpeg)
[English](README.md) | [中文](README_zh.md)
## Local Mode
The new version of the CodeGeeX plugin **supports offline mode**, allowing you to use a locally deployed model for code completion and simple chat.
## Usage Tutorial
### 1. Install Dependencies
```bash
cd local_mode
pip install -r requirements.txt
```
### 2. Run the Project
```bash
python main.py --model_name_or_path THUDM/codegeex4-all-9b --device cuda --bf16 true
>>> Running on local URL: http://127.0.0.1:8080
```
### 3. Set API Address and Model Name
As shown in the figure below, open the plugin, switch to local mode, and enter the API address and model name in the settings.
![](resources/pic1.png)
### 4. Start Using
Click 'Connect' to test the connection, then click 'Ask CodeGeeX' to start chatting.
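Once the server is running, the plugin talks to it over an OpenAI-style `/v1/chat/completions` endpoint, so you can also verify the setup from a script. A minimal sketch, assuming the default address from step 2 and the pinned `requests` package (the prompt text and `model` value are placeholders):
```python
import requests

# Hypothetical smoke test against the server started in step 2.
resp = requests.post(
    "http://127.0.0.1:8080/v1/chat/completions",
    json={
        "model": "codegeex4",  # must match the model name set in the plugin
        "messages": [{"role": "user", "content": "Write hello world in Python."}],
        "stream": False,  # ask for a single JSON response instead of SSE
    },
)
print(resp.json()["choices"][0]["message"]["content"])
```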
## Demo
![](resources/demo.gif)

local_mode/README_zh.md Normal file

@@ -0,0 +1,37 @@
![](../resources/logo.jpeg)
[English](README.md) | [中文](README_zh.md)
## Local Mode
The new version of the CodeGeeX plugin **supports offline mode**, allowing you to use a locally deployed model for code completion and simple chat.
## Usage Tutorial
### 1. Install Dependencies
```bash
cd local_mode
pip install -r requirements.txt
```
### 2. Run the Project
```bash
python main.py --model_name_or_path THUDM/codegeex4-all-9b --device cuda --bf16 true
>>> Running on local URL: http://127.0.0.1:8080
```
### 3. Set the API Address and Model Name
As shown in the figure below, open the plugin, switch to local mode, and enter the API address and model name in the settings.
![](resources/pic1.png)
### 4. Start Using
Click 'Connect' to test the connection, then click 'Ask CodeGeeX' to start chatting.
## Demo
![](resources/demo_zh.gif)

local_mode/main.py Normal file

@@ -0,0 +1,51 @@
"""
coding : utf-8
@Date : 2024/7/10
@Author : Shaobo
@Describe:
"""
import argparse
import torch
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from starlette.responses import StreamingResponse
from protocols.openai_api import ChatCompletionRequest
from services.chat import init_model, chat_with_codegeex, stream_chat_with_codegeex
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument("--model_name_or_path", type=str, default="THUDM/codegeex4-all-9b")
parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
parser.add_argument("--bf16", type=bool, default=False)
return parser.parse_args()
@app.post("/v1/chat/completions")
async def chat(request: ChatCompletionRequest):
try:
if request.stream:
return StreamingResponse(stream_chat_with_codegeex(request), media_type="text/event-stream")
else:
return JSONResponse(chat_with_codegeex(request))
except Exception as e:
return JSONResponse(e, status_code=500)
if __name__ == "__main__":
args = parse_arguments()
init_model(args)
uvicorn.run(app, host="127.0.0.1", port=8080)
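
When `stream` is true (the request default), the endpoint replies with Server-Sent Events rather than a single JSON body. A minimal sketch of a streaming consumer, assuming the default host and port above; the SSE parsing is hand-rolled here for illustration:
```python
import json

import requests

# Hypothetical client for the streaming branch of the endpoint above.
with requests.post(
    "http://127.0.0.1:8080/v1/chat/completions",
    json={
        "model": "codegeex4",
        "messages": [{"role": "user", "content": "Explain list comprehensions."}],
        "stream": True,
    },
    stream=True,
) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data:"):
            continue  # skip SSE id/event fields and blank separator lines
        chunk = json.loads(line[len("data:"):].strip())
        # Each chunk carries only the newly generated text in delta.content.
        print(chunk["choices"][0]["delta"]["content"], end="", flush=True)
```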

local_mode/models/codegeex.py Normal file

@@ -0,0 +1,78 @@
"""
coding : utf-8
@Date : 2024/7/10
@Author : Shaobo
@Describe:
"""
import torch
from protocols.openai_api import ChatCompletionRequest, ChatCompletionStreamResponse, ChatCompletionResponse
from sseclient import Event
from transformers import AutoTokenizer, AutoModel
SYS_PROMPT = "你是一位智能编程助手你叫CodeGeeX。你会为用户回答关于编程、代码、计算机方面的任何问题并提供格式规范、可以执行、准确安全的代码并在必要时提供详细的解释。"
class CodegeexChatModel:
def __init__(self, args):
self.tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
if args.bf16:
self.model = AutoModel.from_pretrained(
args.model_name_or_path,
trust_remote_code=True,
torch_dtype=torch.bfloat16,
).to(args.device).eval()
else:
self.model = AutoModel.from_pretrained(
args.model_name_or_path,
trust_remote_code=True
).to(args.device).eval()
print("Model is initialized.")
def stream_chat(self, request: ChatCompletionRequest):
try:
length = 0
for i, (response, _) in enumerate(self.model.stream_chat(
self.tokenizer,
query=request.messages[-1].content,
history=[msg.model_dump() for msg in request.messages[:-1]],
max_new_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
repetition_penalty=request.presence_penalty
)):
resp = ChatCompletionStreamResponse()
resp.choices[0].index = i
resp.choices[0].delta.content = response[length:]
event = Event(id=resp.id, data=resp.json(), event='message')
yield event.dump()
length = len(response)
resp = ChatCompletionStreamResponse()
resp.choices[0].finish_reason = 'stop'
event = Event(id=resp.id, data=resp.json(), event='message')
yield event.dump()
except Exception as e:
resp = ChatCompletionStreamResponse()
resp.choices[0].finish_reason = 'stop'
event = Event(id=resp.id, data=f"请求报错,错误原因:{e}", event='message')
yield event.dump()
def chat(self, request: ChatCompletionRequest):
try:
response, _ = self.model.chat(
self.tokenizer,
query=request.messages[0].content,
history=[msg.model_dump() for msg in request.messages[:-1]],
max_new_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
repetition_penalty=request.presence_penalty
)
resp = ChatCompletionResponse()
resp.choices[0].message.content = response
resp.choices[0].finish_reason = 'stop'
# event = Event(id=resp.id, data=resp.json(), event='message')
# return event.dump()
return resp.model_dump()
except Exception as e:
return f"请求报错,错误原因:{e}"

local_mode/protocols/openai_api.py Normal file

@@ -0,0 +1,61 @@
"""
coding : utf-8
@Date : 2024/7/11
@Author : Shaobo
@Describe:
"""
import time
from typing import Literal
import shortuuid
from pydantic import BaseModel
class ChatMessage(BaseModel):
role: str
content: str
class ChatCompletionRequest(BaseModel):
model: str = "codegeex4"
messages: list[ChatMessage]
temperature: float = 0.2
top_p: float = 1.0
max_tokens: int = 1024
stop: list[str] = ['<|user|>', '<|assistant|>', '<|observation|>', '<|endoftext|>']
stream: bool = True
presence_penalty: float = None
class DeltaMessage(BaseModel):
role: str
content: str
class ChatCompletionResponseStreamChoice(BaseModel):
index: int = 0
delta: DeltaMessage = DeltaMessage(role='assistant', content='')
finish_reason: Literal["stop", "length"] = None
class ChatCompletionStreamResponse(BaseModel):
id: str = f"chatcmpl-{shortuuid.random()}"
object: str = "chat.completion.chunk"
created: int = int(time.time())
model: str = "codegeex4"
choices: list[ChatCompletionResponseStreamChoice] = [ChatCompletionResponseStreamChoice()]
class ChatCompletionResponseChoice(BaseModel):
index: int = 0
message: ChatMessage = ChatMessage(role="assistant", content="")
finish_reason: Literal["stop", "length"] = None
class ChatCompletionResponse(BaseModel):
id: str = f"chatcmpl-{shortuuid.random()}"
object: str = "chat.completion"
created: int = int(time.time())
model: str = "codegeex4"
choices: list[ChatCompletionResponseChoice] = [ChatCompletionResponseChoice()]
# usage: UsageInfo
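
The defaults above are worth a quick illustration: requests validate plain dicts into `ChatMessage` objects, and `default_factory` runs once per instance. A small sketch (values are just for demonstration):
```python
# Hypothetical demonstration of the models above.
req = ChatCompletionRequest(messages=[{"role": "user", "content": "hi"}])
assert req.stream is True          # streaming is the default
assert req.temperature == 0.2

a = ChatCompletionStreamResponse()
b = ChatCompletionStreamResponse()
assert a.id != b.id                # default_factory runs once per instance
```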

local_mode/requirements.txt Normal file

@@ -0,0 +1,13 @@
accelerate==0.31.0
fastapi==0.111.0
openai==1.35.12
pydantic==2.8.2
regex==2024.5.15
requests==2.32.3
shortuuid==1.0.13
sseclient==0.0.27
starlette==0.37.2
tiktoken==0.7.0
torch==2.3.1
transformers==4.39.0
uvicorn==0.30.1

local_mode/resources/demo.gif Normal file (binary, 4.2 MiB)

local_mode/resources/demo_zh.gif Normal file (binary, 4.3 MiB)

local_mode/resources/pic1.png Normal file (binary, 88 KiB)

local_mode/services/chat.py Normal file

@@ -0,0 +1,22 @@
"""
coding : utf-8
@Date : 2024/7/10
@Author : Shaobo
@Describe:
"""
from models.codegeex import CodegeexChatModel
model: CodegeexChatModel
def stream_chat_with_codegeex(request):
yield from model.stream_chat(request)
def chat_with_codegeex(request):
return model.chat(request)
def init_model(args):
global model
model = CodegeexChatModel(args)
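
Because the model lives in module state, the service layer can also be exercised directly, without going through FastAPI. A minimal offline smoke test, assuming a `Namespace` with the same fields that `parse_arguments` in `main.py` produces:
```python
# Hypothetical direct smoke test that bypasses the HTTP layer.
from argparse import Namespace

from protocols.openai_api import ChatCompletionRequest
from services.chat import chat_with_codegeex, init_model

args = Namespace(model_name_or_path="THUDM/codegeex4-all-9b", device="cpu", bf16=False)
init_model(args)  # downloads/loads tokenizer and weights; slow on first run

request = ChatCompletionRequest(
    messages=[{"role": "user", "content": "Write fizzbuzz in Python."}],
    stream=False,
)
print(chat_with_codegeex(request)["choices"][0]["message"]["content"])
```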