mirror of https://github.com/JasonYANG170/CodeGeeX4.git (synced 2024-11-23 12:16:33 +00:00)

Add 'local mode' tutorials

commit 81ff1bea82 (parent 8be6c39820)
37  local_mode/README.md  Normal file

@@ -0,0 +1,37 @@
![](../resources/logo.jpeg)

[English](README.md) | [中文](README_zh.md)

## Local Mode

The new version of the CodeGeeX plugin **supports an offline mode**: an offline-deployed model can handle code completion and simple chat.

## Usage Tutorial

### 1. Install Dependencies

```bash
cd local_mode
pip install -r requirements.txt
```

### 2. Run the Project

```bash
python main.py --model_name_or_path THUDM/codegeex4-all-9b --device cuda --bf16 true
>>> Running on local URL: http://127.0.0.1:8080
```
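
To verify the server independently of the plugin, you can call its OpenAI-compatible endpoint directly. Below is a minimal sketch using the `openai` client pinned in requirements.txt; it assumes the server from step 2 is running and uses the `codegeex4` model name defined in protocols/openai_api.py (the server performs no authentication, so any placeholder API key works):

```python
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:8080/v1", api_key="EMPTY")

# The server defaults to streaming, so request a plain JSON response explicitly.
completion = client.chat.completions.create(
    model="codegeex4",
    messages=[{"role": "user", "content": "Write a quicksort in Python."}],
    stream=False,
)
print(completion.choices[0].message.content)
```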

### 3. Set the API Address and Model Name

As shown in the figure below, open the plugin, switch to local mode, and enter the API address and model name in the settings.

![](resources/pic1.png)
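
For the server started in step 2, the model name is `codegeex4` (the default in protocols/openai_api.py); the address is presumably `http://127.0.0.1:8080`, though the exact URL format the plugin expects may include the `/v1/chat/completions` path.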

### 4. Start Using

Click 'Connect' to test the connection, then click 'Ask CodeGeeX' to start using the plugin.

## Demo

![](resources/demo.gif)
37  local_mode/README_zh.md  Normal file

@@ -0,0 +1,37 @@
![](../resources/logo.jpeg)

[English](README.md) | [中文](README_zh.md)

## Local Mode

The new version of the CodeGeeX plugin **supports an offline mode**: an offline-deployed model can handle code completion and simple chat.

## Usage Tutorial

### 1. Install Dependencies

```bash
cd local_mode
pip install -r requirements.txt
```

### 2. Run the Project

```bash
python main.py --model_name_or_path THUDM/codegeex4-all-9b --device cuda --bf16 true
>>> Running on local URL: http://127.0.0.1:8080
```

### 3. Set the API Address and Model Name

As shown in the figure below, open the plugin, switch to local mode, and enter the API address and model name in the settings.

![](resources/pic1.png)

### 4. Start Using

Click 'Connect' to test the connection, then click 'Ask CodeGeeX' to start using the plugin.

## Demo

![](resources/demo_zh.gif)
51  local_mode/main.py  Normal file

@@ -0,0 +1,51 @@
"""
coding : utf-8
@Date    : 2024/7/10
@Author  : Shaobo
@Describe: FastAPI server exposing an OpenAI-compatible chat endpoint for a locally deployed CodeGeeX4 model.
"""
import argparse

import torch
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from starlette.responses import StreamingResponse

from protocols.openai_api import ChatCompletionRequest
from services.chat import init_model, chat_with_codegeex, stream_chat_with_codegeex

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


def str2bool(value):
    # argparse's type=bool treats any non-empty string (including "false") as
    # True, so boolean flags like "--bf16 true" must be parsed explicitly.
    return str(value).lower() in ("true", "1", "yes")


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_or_path", type=str, default="THUDM/codegeex4-all-9b")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--bf16", type=str2bool, default=False)
    return parser.parse_args()


@app.post("/v1/chat/completions")
async def chat(request: ChatCompletionRequest):
    try:
        if request.stream:
            return StreamingResponse(stream_chat_with_codegeex(request), media_type="text/event-stream")
        else:
            return JSONResponse(chat_with_codegeex(request))
    except Exception as e:
        # Exception objects are not JSON-serializable; return the message instead.
        return JSONResponse({"error": str(e)}, status_code=500)


if __name__ == "__main__":
    args = parse_arguments()
    init_model(args)
    uvicorn.run(app, host="127.0.0.1", port=8080)
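
For a quick smoke test of the streaming path, something like the following should work — a sketch using `requests` (already pinned in requirements.txt); the SSE framing assumed here matches the `Event.dump()` calls in models/codegeex.py, and the error path sends a non-JSON payload this sketch does not handle:

```python
import json
import requests

payload = {
    "messages": [{"role": "user", "content": "Write a hello-world in C."}],
    "stream": True,  # the server replies with text/event-stream
}
with requests.post("http://127.0.0.1:8080/v1/chat/completions",
                   json=payload, stream=True) as r:
    for line in r.iter_lines(decode_unicode=True):
        # Each SSE event carries one ChatCompletionStreamResponse as JSON.
        if line.startswith("data:"):
            chunk = json.loads(line[len("data:"):])
            print(chunk["choices"][0]["delta"]["content"], end="", flush=True)
```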
78  local_mode/models/codegeex.py  Normal file

@@ -0,0 +1,78 @@
"""
coding : utf-8
@Date    : 2024/7/10
@Author  : Shaobo
@Describe: Chat-model wrapper: loads CodeGeeX4 and serves streaming and non-streaming chat.
"""

import torch
from protocols.openai_api import ChatCompletionRequest, ChatCompletionStreamResponse, ChatCompletionResponse
from sseclient import Event
from transformers import AutoTokenizer, AutoModel

# System prompt (in Chinese): "You are an intelligent programming assistant named CodeGeeX.
# You answer any questions about programming, code, and computers, and provide well-formatted,
# executable, accurate, and safe code, with detailed explanations when necessary."
SYS_PROMPT = "你是一位智能编程助手,你叫CodeGeeX。你会为用户回答关于编程、代码、计算机方面的任何问题,并提供格式规范、可以执行、准确安全的代码,并在必要时提供详细的解释。"


class CodegeexChatModel:
    def __init__(self, args):
        self.tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
        if args.bf16:
            self.model = AutoModel.from_pretrained(
                args.model_name_or_path,
                trust_remote_code=True,
                torch_dtype=torch.bfloat16,
            ).to(args.device).eval()
        else:
            self.model = AutoModel.from_pretrained(
                args.model_name_or_path,
                trust_remote_code=True
            ).to(args.device).eval()
        print("Model is initialized.")

    def stream_chat(self, request: ChatCompletionRequest):
        try:
            length = 0
            for i, (response, _) in enumerate(self.model.stream_chat(
                    self.tokenizer,
                    query=request.messages[-1].content,
                    history=[msg.model_dump() for msg in request.messages[:-1]],
                    max_new_tokens=request.max_tokens,
                    temperature=request.temperature,
                    top_p=request.top_p,
                    repetition_penalty=request.presence_penalty or 1.0  # 1.0 (no penalty) when unset
            )):
                # Emit only the newly generated suffix as one SSE chunk.
                resp = ChatCompletionStreamResponse()
                resp.choices[0].index = i
                resp.choices[0].delta.content = response[length:]
                event = Event(id=resp.id, data=resp.model_dump_json(), event='message')
                yield event.dump()
                length = len(response)
            # Final chunk marks the end of the stream.
            resp = ChatCompletionStreamResponse()
            resp.choices[0].finish_reason = 'stop'
            event = Event(id=resp.id, data=resp.model_dump_json(), event='message')
            yield event.dump()
        except Exception as e:
            resp = ChatCompletionStreamResponse()
            resp.choices[0].finish_reason = 'stop'
            # Error payload (in Chinese): "Request failed, reason: {e}"
            event = Event(id=resp.id, data=f"请求报错,错误原因:{e}", event='message')
            yield event.dump()

    def chat(self, request: ChatCompletionRequest):
        try:
            response, _ = self.model.chat(
                self.tokenizer,
                query=request.messages[-1].content,  # the query is the latest message; earlier ones form the history
                history=[msg.model_dump() for msg in request.messages[:-1]],
                max_new_tokens=request.max_tokens,
                temperature=request.temperature,
                top_p=request.top_p,
                repetition_penalty=request.presence_penalty or 1.0
            )
            resp = ChatCompletionResponse()
            resp.choices[0].message.content = response
            resp.choices[0].finish_reason = 'stop'
            return resp.model_dump()
        except Exception as e:
            # Error message (in Chinese): "Request failed, reason: {e}"
            return f"请求报错,错误原因:{e}"
61  local_mode/protocols/openai_api.py  Normal file

@@ -0,0 +1,61 @@
"""
coding : utf-8
@Date    : 2024/7/11
@Author  : Shaobo
@Describe: Pydantic models mirroring the OpenAI chat-completions protocol.
"""
import time
from typing import Literal, Optional

import shortuuid
from pydantic import BaseModel, Field


class ChatMessage(BaseModel):
    role: str
    content: str


class ChatCompletionRequest(BaseModel):
    model: str = "codegeex4"
    messages: list[ChatMessage]
    temperature: float = 0.2
    top_p: float = 1.0
    max_tokens: int = 1024
    stop: list[str] = ['<|user|>', '<|assistant|>', '<|observation|>', '<|endoftext|>']
    stream: bool = True
    presence_penalty: Optional[float] = None


class DeltaMessage(BaseModel):
    role: str
    content: str


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int = 0
    delta: DeltaMessage = DeltaMessage(role='assistant', content='')
    finish_reason: Optional[Literal["stop", "length"]] = None


class ChatCompletionStreamResponse(BaseModel):
    # default_factory gives every response a fresh id and timestamp instead of
    # values computed once at import time and shared by all responses.
    id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}")
    object: str = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str = "codegeex4"
    choices: list[ChatCompletionResponseStreamChoice] = [ChatCompletionResponseStreamChoice()]


class ChatCompletionResponseChoice(BaseModel):
    index: int = 0
    message: ChatMessage = ChatMessage(role="assistant", content="")
    finish_reason: Optional[Literal["stop", "length"]] = None


class ChatCompletionResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}")
    object: str = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str = "codegeex4"
    choices: list[ChatCompletionResponseChoice] = [ChatCompletionResponseChoice()]
    # usage: UsageInfo
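
For illustration, a quick round-trip through these models might look like the following sketch (field names and defaults follow the definitions above):

```python
from protocols.openai_api import ChatCompletionRequest, ChatCompletionStreamResponse

# Parse an incoming OpenAI-style payload; omitted fields take the defaults above.
req = ChatCompletionRequest.model_validate({
    "messages": [{"role": "user", "content": "hello"}],
    "stream": False,
})
assert req.model == "codegeex4" and req.max_tokens == 1024

# Build one streaming chunk the way the server does in models/codegeex.py.
chunk = ChatCompletionStreamResponse()
chunk.choices[0].delta.content = "hi"
print(chunk.model_dump_json())
```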
13  local_mode/requirements.txt  Normal file

@@ -0,0 +1,13 @@
accelerate==0.31.0
fastapi==0.111.0
openai==1.35.12
pydantic==2.8.2
regex==2024.5.15
requests==2.32.3
shortuuid==1.0.13
sseclient==0.0.27
starlette==0.37.2
tiktoken==0.7.0
torch==2.3.1
transformers==4.39.0
uvicorn==0.30.1
BIN  local_mode/resources/demo.gif  Normal file  (binary file not shown; 4.2 MiB)

BIN  local_mode/resources/demo_zh.gif  Normal file  (binary file not shown; 4.3 MiB)

BIN  local_mode/resources/pic1.png  Normal file  (binary file not shown; 88 KiB)
22  local_mode/services/chat.py  Normal file

@@ -0,0 +1,22 @@
"""
coding : utf-8
@Date    : 2024/7/10
@Author  : Shaobo
@Describe: Holds the module-level model instance shared by the FastAPI handlers.
"""
from models.codegeex import CodegeexChatModel

# Set by init_model() at startup; accessing it before then raises NameError.
model: CodegeexChatModel


def stream_chat_with_codegeex(request):
    yield from model.stream_chat(request)


def chat_with_codegeex(request):
    return model.chat(request)


def init_model(args):
    global model
    model = CodegeexChatModel(args)