mirror of
https://github.com/JasonYANG170/CodeGeeX4.git
synced 2024-11-23 12:16:33 +00:00
add a function call demo
This commit is contained in:
parent
be3b3bbaf9
commit
b06faf7347
74
function_call_demo/main.py
Normal file
74
function_call_demo/main.py
Normal file
|
@ -0,0 +1,74 @@
|
|||
import json
|
||||
import re
|
||||
from json import JSONDecodeError
|
||||
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
|
||||
def main():
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model_name_or_path = "THUDM/codegeex4-all-9b"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name_or_path,
|
||||
torch_dtype=torch.bfloat16,
|
||||
trust_remote_code=True
|
||||
).to(device).eval()
|
||||
|
||||
tool_content = {
|
||||
"function": [
|
||||
{
|
||||
"name": "weather",
|
||||
"description": "Use for searching weather at a specific location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"description": "the location need to check the weather",
|
||||
"tyoe": "str",
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"location"
|
||||
]
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
response, _ = model.chat(
|
||||
tokenizer,
|
||||
query="Tell me about the weather in Beijing",
|
||||
history=[{"role": "tool", "content": tool_content}],
|
||||
max_new_tokens=1024,
|
||||
temperature=0.1
|
||||
)
|
||||
|
||||
# support parallel calls, thus the result is a list
|
||||
functions = post_process(response)
|
||||
try:
|
||||
return [json.loads(func) for func in functions if func]
|
||||
# get rid of some possible invalid formats
|
||||
except JSONDecodeError:
|
||||
try:
|
||||
return [json.loads(func.replace('(', '[').replace(')', ']')) for func in functions if func]
|
||||
except JSONDecodeError:
|
||||
try:
|
||||
return [json.loads(func.replace("'", '"')) for func in functions if func]
|
||||
except JSONDecodeError as e:
|
||||
return [{"answer": response, "errors": e}]
|
||||
|
||||
|
||||
def post_process(text: str) -> list[str]:
|
||||
"""
|
||||
Process model's response.
|
||||
In case there are parallel calls, each call is warpped with ```json```.
|
||||
"""
|
||||
pattern = r'```json(.*?)```'
|
||||
matches = re.findall(pattern, text, re.DOTALL)
|
||||
return matches
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
output = main()
|
||||
print(output) # [{"name": "weather", "arguments": {"location": "Beijing"}}]
|
Loading…
Reference in New Issue
Block a user