"""Demo: ask CodeGeeX4 to emit a function call and parse the JSON it returns."""
import json
import re
from json import JSONDecodeError

import torch

# Captures the body of every ```json ... ``` fence; non-greedy so parallel calls
# become separate matches, DOTALL so a body may span several lines.
_JSON_BLOCK_PATTERN = re.compile(r"```json(.*?)```", re.DOTALL)

# Progressively looser rewrites tried when the model's JSON is not strictly valid.
# Each rewrite is applied to the ORIGINAL text (they are not cumulative), in order.
_JSON_REPAIRS = (
    lambda text: text,                                   # strict JSON as emitted
    lambda text: text.replace("(", "[").replace(")", "]"),  # Python tuples -> JSON arrays
    lambda text: text.replace("'", '"'),                 # Python-style single quotes
)


def main():
    """Query CodeGeeX4 with a single ``weather`` tool and return the parsed calls.

    Returns:
        The result of ``parse_function_calls`` on the model's reply: a list of
        decoded call objects, or a single ``{"answer", "errors"}`` entry when the
        reply could not be decoded.
    """
    # Imported here rather than at module level so the parsing helpers in this
    # file can be imported (and tested) without transformers installed.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_name_or_path = "THUDM/codegeex4-all-9b"
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    ).to(device).eval()

    tool_content = {
        "function": [
            {
                "name": "weather",
                "description": "Use for searching weather at a specific location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "description": "the location need to check the weather",
                            # Was misspelled "tyoe", so the type was never conveyed.
                            "type": "string",
                        }
                    },
                    "required": ["location"],
                },
            }
        ]
    }

    # The tool definitions are passed as a history turn with role "tool".
    # NOTE(review): `chat` comes from the model's remote code (trust_remote_code);
    # it is assumed to return (response_text, updated_history) — confirm upstream.
    response, _ = model.chat(
        tokenizer,
        query="Tell me about the weather in Beijing",
        history=[{"role": "tool", "content": tool_content}],
        max_new_tokens=1024,
        temperature=0.1,
    )
    # The model may issue parallel calls, so the result is a list.
    return parse_function_calls(response)


def parse_function_calls(response: str) -> list:
    """Decode every ```json``` fenced function call in a model reply.

    Each repair in ``_JSON_REPAIRS`` is tried in turn on all calls at once; the
    first repair under which every call decodes wins.

    Args:
        response: Raw text returned by the model.

    Returns:
        A list of decoded JSON values, one per non-blank fenced block (empty
        when the reply contains no blocks). If no repair makes every block
        decode, returns ``[{"answer": response, "errors": <message>}]`` where
        the message is the last ``JSONDecodeError`` as a string, so the result
        stays JSON-serializable.
    """
    # Skip blank fences: a whitespace-only body would otherwise make every
    # decode attempt fail and discard the valid calls alongside it.
    functions = [func for func in post_process(response) if func.strip()]

    last_error = None
    for repair in _JSON_REPAIRS:
        try:
            return [json.loads(repair(func)) for func in functions]
        except JSONDecodeError as err:
            last_error = err
    return [{"answer": response, "errors": str(last_error)}]


def post_process(text: str) -> list[str]:
    """Extract the raw bodies of all ```json ... ``` fenced blocks in *text*.

    Parallel calls are each wrapped in their own ```json``` fence, so one
    string per call is returned, in order of appearance, without stripping.
    """
    return _JSON_BLOCK_PATTERN.findall(text)


if __name__ == "__main__":
    output = main()
    print(output)  # [{"name": "weather", "arguments": {"location": "Beijing"}}]