diff --git a/local_mode/models/codegeex.py b/local_mode/models/codegeex.py index 30063cb..e9ddf66 100644 --- a/local_mode/models/codegeex.py +++ b/local_mode/models/codegeex.py @@ -39,11 +39,14 @@ class CodegeexChatModel: "max_new_tokens": request.max_tokens, "temperature": request.temperature, "top_p": request.top_p, - "repetition_penalty": request.presence_penalty + "repetition_penalty": request.presence_penalty, + "do_sample": True if request.temperature else request.temperature, } length = 0 for i, outputs in enumerate(self.model.stream_generate(**inputs, **gen_configs)): response = self.tokenizer.decode(outputs.tolist()[0][len(inputs["input_ids"][0]):-1]) + if not response or response[-1] == "�": + continue resp = ChatCompletionStreamResponse() resp.choices[0].index = i resp.choices[0].delta.content = response[length:]