mirror of https://github.com/JasonYANG170/CodeGeeX4.git
synced 2024-11-23 12:16:33 +00:00

Add 'local mode' tutorials

This commit is contained in:
parent 8be6c39820
commit 81ff1bea82
local_mode/README.md (new file)
@@ -0,0 +1,37 @@
![](../resources/logo.jpeg)

[English](README.md) | [中文](README_zh.md)

## Local Mode

The new version of the CodeGeeX plugin **supports offline mode**: you can use an offline-deployed model for code completion and simple chat.
## Usage Tutorial

### 1. Install Dependencies

```bash
cd local_mode
pip install -r requirements.txt
```
### 2. Run the Project

```bash
python main.py --model_name_or_path THUDM/codegeex4-all-9b --device cuda --bf16 true
>>> Running on local URL: http://127.0.0.1:8080
```

This starts a local, OpenAI-compatible chat-completions server on `http://127.0.0.1:8080` (see `main.py` below).
### 3. Set the API Address and Model Name

As shown in the figure below, open the plugin, switch to local mode, and enter the API address and model name in the settings.

![](resources/pic1.png)

### 4. Start Using

Click 'Connect' to test the connection, or click 'Ask CodeGeeX' to start using it.
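
You can also exercise the server directly, without the plugin. Below is a minimal sketch using `requests`; the endpoint, model name, and request fields come from `main.py` and `protocols/openai_api.py` in this commit, while the prompt itself is just an illustration:

```python
import requests

# Non-streaming request to the local OpenAI-compatible endpoint started in step 2.
resp = requests.post(
    "http://127.0.0.1:8080/v1/chat/completions",
    json={
        "model": "codegeex4",
        "messages": [{"role": "user", "content": "Write a quicksort in Python."}],
        "stream": False,
        # The server forwards presence_penalty to the model as repetition_penalty.
        "presence_penalty": 1.0,
    },
    timeout=300,
)
print(resp.json()["choices"][0]["message"]["content"])
```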
## Demo

![](resources/demo.gif)
local_mode/README_zh.md (new file)
@@ -0,0 +1,37 @@
![](../resources/logo.jpeg)

[English](README.md) | [中文](README_zh.md)

## Local Mode

The new version of the CodeGeeX plugin **supports offline mode**: you can use an offline-deployed model for code completion and simple chat.

## Usage Tutorial

### 1. Install Dependencies

```bash
cd local_mode
pip install -r requirements.txt
```

### 2. Run the Project

```bash
python main.py --model_name_or_path THUDM/codegeex4-all-9b --device cuda --bf16 true
>>> Running on local URL: http://127.0.0.1:8080
```

### 3. Set the API Address and Model Name

As shown in the figure below, open the plugin, switch to local mode, and enter the API address and model name in the settings.

![](resources/pic1.png)

### 4. Start Using

Click 'Connect' to test the connection, or click 'Ask CodeGeeX' to start using it.

## Demo

![](resources/demo_zh.gif)
local_mode/main.py (new file)
@@ -0,0 +1,51 @@
```python
"""
coding : utf-8
@Date : 2024/7/10
@Author : Shaobo
@Describe:
"""
import argparse

import torch
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from starlette.responses import StreamingResponse

from protocols.openai_api import ChatCompletionRequest
from services.chat import init_model, chat_with_codegeex, stream_chat_with_codegeex

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_or_path", type=str, default="THUDM/codegeex4-all-9b")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    # argparse's type=bool turns any non-empty string (including "false") into True,
    # so parse the value explicitly.
    parser.add_argument("--bf16", type=lambda s: s.lower() in ("true", "1", "yes"), default=False)
    return parser.parse_args()


@app.post("/v1/chat/completions")
async def chat(request: ChatCompletionRequest):
    try:
        if request.stream:
            return StreamingResponse(stream_chat_with_codegeex(request), media_type="text/event-stream")
        return JSONResponse(chat_with_codegeex(request))
    except Exception as e:
        # Exception objects are not JSON-serializable; return the message instead.
        return JSONResponse({"error": str(e)}, status_code=500)


if __name__ == "__main__":
    args = parse_arguments()
    init_model(args)
    uvicorn.run(app, host="127.0.0.1", port=8080)
```
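When `stream` is true, the endpoint responds with Server-Sent Events whose `data:` payloads are `ChatCompletionStreamResponse` chunks (see `protocols/openai_api.py` below). A minimal client sketch that parses the SSE lines by hand, assuming the server is running with its defaults:

```python
import json
import requests

# Stream a completion from the local server; each SSE "data:" line is one JSON chunk.
with requests.post(
    "http://127.0.0.1:8080/v1/chat/completions",
    json={
        "model": "codegeex4",
        "messages": [{"role": "user", "content": "Explain Python decorators."}],
        "stream": True,
        "presence_penalty": 1.0,
    },
    stream=True,
    timeout=300,
) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data:"):
            continue  # skip SSE "id:"/"event:" lines and blank keep-alives
        chunk = json.loads(line[len("data:"):].strip())
        choice = chunk["choices"][0]
        if choice.get("finish_reason") == "stop":
            break
        print(choice["delta"]["content"], end="", flush=True)
```

Note that the error path in `models/codegeex.py` emits a plain-text `data:` payload rather than JSON, so a production client would guard the `json.loads` call.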
local_mode/models/codegeex.py (new file)
@@ -0,0 +1,78 @@
```python
"""
coding : utf-8
@Date : 2024/7/10
@Author : Shaobo
@Describe:
"""

import torch
from protocols.openai_api import ChatCompletionRequest, ChatCompletionStreamResponse, ChatCompletionResponse
from sseclient import Event
from transformers import AutoTokenizer, AutoModel

# System prompt (Chinese): "You are an intelligent programming assistant named CodeGeeX.
# You answer any question about programming, code, and computers, and provide
# well-formatted, executable, accurate, and safe code, with detailed explanations
# when necessary."
SYS_PROMPT = "你是一位智能编程助手,你叫CodeGeeX。你会为用户回答关于编程、代码、计算机方面的任何问题,并提供格式规范、可以执行、准确安全的代码,并在必要时提供详细的解释。"


class CodegeexChatModel:
    def __init__(self, args):
        self.tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
        if args.bf16:
            self.model = AutoModel.from_pretrained(
                args.model_name_or_path,
                trust_remote_code=True,
                torch_dtype=torch.bfloat16,
            ).to(args.device).eval()
        else:
            self.model = AutoModel.from_pretrained(
                args.model_name_or_path,
                trust_remote_code=True,
            ).to(args.device).eval()
        print("Model is initialized.")

    def stream_chat(self, request: ChatCompletionRequest):
        try:
            length = 0
            for i, (response, _) in enumerate(self.model.stream_chat(
                    self.tokenizer,
                    query=request.messages[-1].content,
                    history=[msg.model_dump() for msg in request.messages[:-1]],
                    max_new_tokens=request.max_tokens,
                    temperature=request.temperature,
                    top_p=request.top_p,
                    # repetition_penalty must be a positive float; fall back to 1.0
                    # when the client omits presence_penalty.
                    repetition_penalty=request.presence_penalty or 1.0,
            )):
                # stream_chat yields the full response so far; emit only the new suffix.
                resp = ChatCompletionStreamResponse()
                resp.choices[0].index = i
                resp.choices[0].delta.content = response[length:]
                event = Event(id=resp.id, data=resp.model_dump_json(), event='message')
                yield event.dump()
                length = len(response)
            resp = ChatCompletionStreamResponse()
            resp.choices[0].finish_reason = 'stop'
            event = Event(id=resp.id, data=resp.model_dump_json(), event='message')
            yield event.dump()
        except Exception as e:
            resp = ChatCompletionStreamResponse()
            resp.choices[0].finish_reason = 'stop'
            event = Event(id=resp.id, data=f"Request failed: {e}", event='message')
            yield event.dump()

    def chat(self, request: ChatCompletionRequest):
        try:
            response, _ = self.model.chat(
                self.tokenizer,
                # The last message is the current user query; everything before it is
                # history (mirrors how stream_chat above splits the messages).
                query=request.messages[-1].content,
                history=[msg.model_dump() for msg in request.messages[:-1]],
                max_new_tokens=request.max_tokens,
                temperature=request.temperature,
                top_p=request.top_p,
                repetition_penalty=request.presence_penalty or 1.0,
            )
            resp = ChatCompletionResponse()
            resp.choices[0].message.content = response
            resp.choices[0].finish_reason = 'stop'
            return resp.model_dump()
        except Exception as e:
            return f"Request failed: {e}"
```
local_mode/protocols/openai_api.py (new file)
@@ -0,0 +1,61 @@
```python
"""
coding : utf-8
@Date : 2024/7/11
@Author : Shaobo
@Describe:
"""
import time
from typing import Literal, Optional

import shortuuid
from pydantic import BaseModel, Field


class ChatMessage(BaseModel):
    role: str
    content: str


class ChatCompletionRequest(BaseModel):
    model: str = "codegeex4"
    messages: list[ChatMessage]
    temperature: float = 0.2
    top_p: float = 1.0
    max_tokens: int = 1024
    stop: list[str] = ['<|user|>', '<|assistant|>', '<|observation|>', '<|endoftext|>']
    stream: bool = True
    presence_penalty: Optional[float] = None


class DeltaMessage(BaseModel):
    role: str
    content: str


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int = 0
    delta: DeltaMessage = DeltaMessage(role='assistant', content='')
    finish_reason: Optional[Literal["stop", "length"]] = None


class ChatCompletionStreamResponse(BaseModel):
    # default_factory re-evaluates per response; a plain default would stamp every
    # chunk with one id and timestamp computed once at import time.
    id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}")
    object: str = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str = "codegeex4"
    choices: list[ChatCompletionResponseStreamChoice] = [ChatCompletionResponseStreamChoice()]


class ChatCompletionResponseChoice(BaseModel):
    index: int = 0
    message: ChatMessage = ChatMessage(role="assistant", content="")
    finish_reason: Optional[Literal["stop", "length"]] = None


class ChatCompletionResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}")
    object: str = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str = "codegeex4"
    choices: list[ChatCompletionResponseChoice] = [ChatCompletionResponseChoice()]
    # usage: UsageInfo
```
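A quick, hypothetical snippet (not part of the commit) showing how these models behave: omitted request fields fall back to the defaults above, and each stream chunk gets a fresh `id`/`created` via `default_factory`:

```python
from protocols.openai_api import ChatCompletionRequest, ChatCompletionStreamResponse

# Validate an incoming OpenAI-style payload; omitted fields take the defaults.
req = ChatCompletionRequest.model_validate(
    {"messages": [{"role": "user", "content": "hi"}]}
)
print(req.model, req.temperature, req.stream)  # -> codegeex4 0.2 True

# Each chunk is independently stamped, then serialized for the SSE data field.
chunk = ChatCompletionStreamResponse()
chunk.choices[0].delta.content = "Hello"
print(chunk.model_dump_json())
```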
local_mode/requirements.txt (new file)
@@ -0,0 +1,13 @@
```
accelerate==0.31.0
fastapi==0.111.0
openai==1.35.12
pydantic==2.8.2
regex==2024.5.15
requests==2.32.3
shortuuid==1.0.13
sseclient==0.0.27
starlette==0.37.2
tiktoken==0.7.0
torch==2.3.1
transformers==4.39.0
uvicorn==0.30.1
```
local_mode/resources/demo.gif (new binary file, 4.2 MiB; not shown)
local_mode/resources/demo_zh.gif (new binary file, 4.3 MiB; not shown)
local_mode/resources/pic1.png (new binary file, 88 KiB; not shown)
local_mode/services/chat.py (new file)
@@ -0,0 +1,22 @@
```python
"""
coding : utf-8
@Date : 2024/7/10
@Author : Shaobo
@Describe:
"""
from models.codegeex import CodegeexChatModel

# Module-level handle, assigned once by init_model() at startup; calling the
# chat helpers before init_model() raises a NameError.
model: CodegeexChatModel


def stream_chat_with_codegeex(request):
    yield from model.stream_chat(request)


def chat_with_codegeex(request):
    return model.chat(request)


def init_model(args):
    global model
    model = CodegeexChatModel(args)
```