diff --git a/repodemo/llm/api/codegeex4.py b/repodemo/llm/api/codegeex4.py index 2ca9f7c..a92e9bf 100644 --- a/repodemo/llm/api/codegeex4.py +++ b/repodemo/llm/api/codegeex4.py @@ -1,41 +1,39 @@ import requests import json -URL = "" #the url you deploy codegeex service +URL = "" # the url you deploy codegeex service + + def codegeex4(prompt, temperature=0.8, top_p=0.8): url = URL - headers = { - 'Content-Type': 'application/json' - } + headers = {"Content-Type": "application/json"} data = { - 'inputs': prompt, - 'parameters': { - 'best_of':1, - 'do_sample': True, - 'max_new_tokens': 4012, - 'temperature': temperature, - 'top_p': top_p, - 'stop': ["<|endoftext|>", "<|user|>", "<|observation|>", "<|assistant|>"], - } + "inputs": prompt, + "parameters": { + "best_of": 1, + "do_sample": True, + "max_new_tokens": 4012, + "temperature": temperature, + "top_p": top_p, + "stop": ["<|endoftext|>", "<|user|>", "<|observation|>", "<|assistant|>"], + }, } response = requests.post(url, json=data, headers=headers, verify=False, stream=True) if response.status_code == 200: for line in response.iter_lines(): if line: - decoded_line = line.decode('utf-8').replace('data:', '').strip() + decoded_line = line.decode("utf-8").replace("data:", "").strip() if decoded_line: try: content = json.loads(decoded_line) - - token_text = content.get('token', {}).get('text', '') - if '<|endoftext|>' in token_text: - break + + token_text = content.get("token", {}).get("text", "") + if "<|endoftext|>" in token_text: + break yield token_text except json.JSONDecodeError: continue else: - print('请求失败:', response.status_code) - - + print("请求失败:", response.status_code) diff --git a/repodemo/llm/local/codegeex4.py b/repodemo/llm/local/codegeex4.py index 53df2f4..cd922a7 100644 --- a/repodemo/llm/local/codegeex4.py +++ b/repodemo/llm/local/codegeex4.py @@ -3,45 +3,49 @@ from transformers import AutoModel, AutoTokenizer from typing import Iterator import torch -class CodegeexChatModel(): + +class CodegeexChatModel: device: str = Field(description="device to load the model") tokenizer = Field(description="model's tokenizer") model = Field(description="Codegeex model") temperature: float = Field(description="temperature to use for the model.") - def __init__(self,model_name_or_path): + def __init__(self, model_name_or_path): super().__init__() self.device = "cuda" if torch.cuda.is_available() else "cpu" - self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True) - self.model = AutoModel.from_pretrained( - model_name_or_path, - trust_remote_code=True - ).to(self.device).eval() + self.tokenizer = AutoTokenizer.from_pretrained( + model_name_or_path, trust_remote_code=True + ) + self.model = ( + AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True) + .to(self.device) + .eval() + ) print("Model has been initialized.") - def chat(self, prompt,temperature=0.2,top_p=0.95): + def chat(self, prompt, temperature=0.2, top_p=0.95): try: response, _ = self.model.chat( self.tokenizer, query=prompt, max_length=120000, temperature=temperature, - top_p=top_p + top_p=top_p, ) return response except Exception as e: return f"error:{e}" - def stream_chat(self,prompt,temperature=0.2,top_p=0.95): + def stream_chat(self, prompt, temperature=0.2, top_p=0.95): try: for response, _ in self.model.stream_chat( - self.tokenizer, - query=prompt, - max_length=120000, - temperature=temperature, - top_p=top_p + self.tokenizer, + query=prompt, + max_length=120000, + temperature=temperature, + top_p=top_p, ): yield 
response except Exception as e: - yield f'error: {e}' \ No newline at end of file + yield f"error: {e}" diff --git a/repodemo/prompts/base_prompt.py b/repodemo/prompts/base_prompt.py index 023391a..eb969b7 100644 --- a/repodemo/prompts/base_prompt.py +++ b/repodemo/prompts/base_prompt.py @@ -4,7 +4,7 @@ repo_system_prompt = """<|system|>\n你是一位智能编程助手,你叫CodeG judge_task_prompt = """<|system|>\n你是一位任务分类专家,请你对用户的输入进行分类(问答/修改/正常),如果用户的输入是对项目进行提问则只需要输出问答两个字,如果用户的输入是对项目进行修改或增加则只需要输出修改两个字,如果用户输入的是一个与项目无关的问题则只需要输出正常两个字。<|user|>\n{user_input}<|assistant|>\n""" -web_judge_task_prompt ="""<|system|>\n你是一位智能编程助手,你叫CodeGeeX。你会为用户回答关于编程、代码、计算机方面的任何问题,并提供格式规范、可以执行、准确安全的代码,并在必要时提供详细的解释。<|user|>\n{user_input}\n这个问题需要进行联网来回答吗?仅回答“是”或者“否”。<|assistant|>\n""" +web_judge_task_prompt = """<|system|>\n你是一位智能编程助手,你叫CodeGeeX。你会为用户回答关于编程、代码、计算机方面的任何问题,并提供格式规范、可以执行、准确安全的代码,并在必要时提供详细的解释。<|user|>\n{user_input}\n这个问题需要进行联网来回答吗?仅回答“是”或者“否”。<|assistant|>\n""" # judge_task_prompt = """<|system|>\n你是一位任务分类专家,请你对用户的输入进行分类(问答/修改),如果用户的输入是对项目进行提问则只需要输出问答两个字,如果用户的输入是对项目进行修改或增加则只需要输出修改两个字。<|user|>\n{user_input}<|assistant|>\n""" web_search_prompy = """ @@ -19,24 +19,27 @@ web_search_prompy = """ - 除了代码和特定的名称和引用外,您的答案必须使用与问题相同的语言来撰写。 """.lstrip() -def get_cur_base_user_prompt(message_history,index_prompt = None,judge_context = ""): + +def get_cur_base_user_prompt(message_history, index_prompt=None, judge_context=""): user_prompt_tmp = """<|user|>\n{user_input}""" assistant_prompt_tmp = """<|assistant|>\n{assistant_input}""" history_prompt = "" - for i,message in enumerate(message_history): - if message['role'] == 'user': - if i==0 and index_prompt is not None: - history_prompt += "<|user|>\n"+index_prompt+message['content'] + for i, message in enumerate(message_history): + if message["role"] == "user": + if i == 0 and index_prompt is not None: + history_prompt += "<|user|>\n" + index_prompt + message["content"] else: - history_prompt += user_prompt_tmp.format(user_input=message['content']) - elif message['role'] == 'assistant': - history_prompt += assistant_prompt_tmp.format(assistant_input=message['content']) - + history_prompt += user_prompt_tmp.format(user_input=message["content"]) + elif message["role"] == "assistant": + history_prompt += assistant_prompt_tmp.format( + assistant_input=message["content"] + ) + # print("修改" not in judge_context) # print(judge_context) if "修改" not in judge_context: - result = base_system_prompt+history_prompt+"""<|assistant|>\n""" + result = base_system_prompt + history_prompt + """<|assistant|>\n""" else: - result = repo_system_prompt+history_prompt+"""<|assistant|>\n""" + result = repo_system_prompt + history_prompt + """<|assistant|>\n""" print(result) - return result \ No newline at end of file + return result diff --git a/repodemo/run.py b/repodemo/run.py index 10065ec..eda8a2f 100644 --- a/repodemo/run.py +++ b/repodemo/run.py @@ -1,8 +1,12 @@ import chainlit as cl from chainlit.input_widget import Slider from llm.api.codegeex4 import codegeex4 -from prompts.base_prompt import judge_task_prompt,get_cur_base_user_prompt,web_judge_task_prompt -from utils.tools import unzip_file,get_project_files_with_content +from prompts.base_prompt import ( + judge_task_prompt, + get_cur_base_user_prompt, + web_judge_task_prompt, +) +from utils.tools import unzip_file, get_project_files_with_content from utils.bingsearch import bing_search_prompt @@ -12,44 +16,36 @@ async def chat_profile(): cl.ChatProfile( name="chat聊天", markdown_description="聊天demo:支持多轮对话。", - starters = [ + starters=[ cl.Starter( - 
label="请你用python写一个快速排序。", - message="请你用python写一个快速排序。", - + label="请你用python写一个快速排序。", + message="请你用python写一个快速排序。", ), - - cl.Starter( - label="请你介绍一下自己。", - message="请你介绍一下自己。", - + cl.Starter( + label="请你介绍一下自己。", + message="请你介绍一下自己。", ), - cl.Starter( - label="用 Python 编写一个脚本来自动发送每日电子邮件报告,并指导我如何进行设置。", - message="用 Python 编写一个脚本来自动发送每日电子邮件报告,并指导我如何进行设置。", - + cl.Starter( + label="用 Python 编写一个脚本来自动发送每日电子邮件报告,并指导我如何进行设置。", + message="用 Python 编写一个脚本来自动发送每日电子邮件报告,并指导我如何进行设置。", ), - cl.Starter( - label="我是一个python初学者,请你告诉我怎么才能学好python。", - message="我是一个python初学者,请你告诉我怎么才能学好python。", - - ) - ] - + cl.Starter( + label="我是一个python初学者,请你告诉我怎么才能学好python。", + message="我是一个python初学者,请你告诉我怎么才能学好python。", + ), + ], ), cl.ChatProfile( name="联网问答", - markdown_description="联网能力dome:支持联网回答用户问题。", - + markdown_description="联网能力demo:支持联网回答用户问题。", ), cl.ChatProfile( name="上传本地项目", - markdown_description="项目级能力dome:支持上传本地zip压缩包项目,可以进行项目问答和对项目进行修改。", - - ) + markdown_description="项目级能力demo:支持上传本地zip压缩包项目,可以进行项目问答和对项目进行修改。", + ), ] - + @cl.on_chat_start async def start(): settings = await cl.ChatSettings( @@ -74,33 +70,31 @@ async def start(): ).send() temperature = settings["temperature"] top_p = settings["top_p"] - cl.user_session.set('temperature',temperature) - cl.user_session.set('top_p',top_p) - cl.user_session.set( - "message_history", - [] - ) + cl.user_session.set("temperature", temperature) + cl.user_session.set("top_p", top_p) + cl.user_session.set("message_history", []) chat_profile = cl.user_session.get("chat_profile") - extract_dir = 'repodata' + extract_dir = "repodata" if chat_profile == "chat聊天": pass - elif chat_profile =="上传本地项目": + elif chat_profile == "上传本地项目": files = None while files == None: files = await cl.AskFileMessage( - content="请上传项目zip压缩文件!", accept={"application/zip": [".zip"]},max_size_mb=50 + content="请上传项目zip压缩文件!", + accept={"application/zip": [".zip"]}, + max_size_mb=50, ).send() text_file = files[0] - extracted_path = unzip_file(text_file.path,extract_dir) + extracted_path = unzip_file(text_file.path, extract_dir) files_list = get_project_files_with_content(extracted_path) - cl.user_session.set("project_index",files_list) - if len(files_list)>0: + cl.user_session.set("project_index", files_list) + if len(files_list) > 0: await cl.Message( content=f"已成功上传,您可以开始对项目进行提问!", ).send() - - + @cl.on_message async def main(message: cl.Message): @@ -109,43 +103,57 @@ async def main(message: cl.Message): message_history.append({"role": "user", "content": message.content}) if chat_profile == "chat聊天": prompt_content = get_cur_base_user_prompt(message_history=message_history) - - elif chat_profile=="联网问答": - judge_tmp = codegeex4(web_judge_task_prompt.format(user_input=message.content),temperature=0.2,top_p = 0.95) - judge_context = '\n'.join(judge_tmp) + + elif chat_profile == "联网问答": + judge_tmp = codegeex4( + web_judge_task_prompt.format(user_input=message.content), + temperature=0.2, + top_p=0.95, + ) + judge_context = "\n".join(judge_tmp) print(judge_context) message_history.pop() - if '是' in judge_context: + if "是" in judge_context: prompt_tmp = bing_search_prompt(message.content) message_history.append({"role": "user", "content": prompt_tmp}) else: message_history.append({"role": "user", "content": message.content}) prompt_content = get_cur_base_user_prompt(message_history=message_history) - elif chat_profile =="上传本地项目" : - judge_tmp = codegeex4(judge_task_prompt.format(user_input=message.content),temperature=0.2,top_p = 0.95) - judge_context = '' + elif chat_profile == 
"上传本地项目": + judge_tmp = codegeex4( + judge_task_prompt.format(user_input=message.content), + temperature=0.2, + top_p=0.95, + ) + judge_context = "" for part in judge_tmp: - judge_context+=part - + judge_context += part + project_index = cl.user_session.get("project_index") index_prompt = "" index_tmp = """###PATH:{path}\n{code}\n""" for index in project_index: - index_prompt+=index_tmp.format(path=index['path'],code=index['content']) + index_prompt += index_tmp.format(path=index["path"], code=index["content"]) print(judge_context) - prompt_content = get_cur_base_user_prompt(message_history=message_history,index_prompt=index_prompt,judge_context=judge_context) if '正常' not in judge_context else get_cur_base_user_prompt(message_history=message_history) - - + prompt_content = ( + get_cur_base_user_prompt( + message_history=message_history, + index_prompt=index_prompt, + judge_context=judge_context, + ) + if "正常" not in judge_context + else get_cur_base_user_prompt(message_history=message_history) + ) msg = cl.Message(content="") await msg.send() temperature = cl.user_session.get("temperature") - top_p = cl.user_session.get('top_p') - - if len(prompt_content)/4<120000: - stream = codegeex4(prompt_content,temperature=temperature,top_p = top_p) + top_p = cl.user_session.get("top_p") + + if len(prompt_content) / 4 < 120000: + stream = codegeex4(prompt_content, temperature=temperature, top_p=top_p) for part in stream: if token := (part or " "): @@ -154,4 +162,4 @@ async def main(message: cl.Message): await msg.stream_token("项目太大了,请换小一点的项目。") message_history.append({"role": "assistant", "content": msg.content}) - await msg.update() \ No newline at end of file + await msg.update() diff --git a/repodemo/run_local.py b/repodemo/run_local.py index f82a190..c4c4ad9 100644 --- a/repodemo/run_local.py +++ b/repodemo/run_local.py @@ -1,66 +1,65 @@ import chainlit as cl from chainlit.input_widget import Slider from llm.api.codegeex4 import codegeex4 -from prompts.base_prompt import judge_task_prompt,get_cur_base_user_prompt,web_judge_task_prompt -from utils.tools import unzip_file,get_project_files_with_content +from prompts.base_prompt import ( + judge_task_prompt, + get_cur_base_user_prompt, + web_judge_task_prompt, +) +from utils.tools import unzip_file, get_project_files_with_content from utils.bingsearch import bing_search_prompt from llm.local.codegeex4 import CodegeexChatModel -local_model_path = '' + +local_model_path = "" llm = CodegeexChatModel(local_model_path) + class StreamProcessor: def __init__(self): self.previous_str = "" def get_new_part(self, new_str): - new_part = new_str[len(self.previous_str):] + new_part = new_str[len(self.previous_str) :] self.previous_str = new_str return new_part + @cl.set_chat_profiles async def chat_profile(): return [ cl.ChatProfile( name="chat聊天", markdown_description="聊天demo:支持多轮对话。", - starters = [ + starters=[ cl.Starter( - label="请你用python写一个快速排序。", - message="请你用python写一个快速排序。", - + label="请你用python写一个快速排序。", + message="请你用python写一个快速排序。", ), - - cl.Starter( - label="请你介绍一下自己。", - message="请你介绍一下自己。", - + cl.Starter( + label="请你介绍一下自己。", + message="请你介绍一下自己。", ), - cl.Starter( - label="用 Python 编写一个脚本来自动发送每日电子邮件报告,并指导我如何进行设置。", - message="用 Python 编写一个脚本来自动发送每日电子邮件报告,并指导我如何进行设置。", - + cl.Starter( + label="用 Python 编写一个脚本来自动发送每日电子邮件报告,并指导我如何进行设置。", + message="用 Python 编写一个脚本来自动发送每日电子邮件报告,并指导我如何进行设置。", ), - cl.Starter( - label="我是一个python初学者,请你告诉我怎么才能学好python。", - message="我是一个python初学者,请你告诉我怎么才能学好python。", - - ) - ] - + cl.Starter( + 
label="我是一个python初学者,请你告诉我怎么才能学好python。", + message="我是一个python初学者,请你告诉我怎么才能学好python。", + ), + ], ), cl.ChatProfile( name="联网问答", - markdown_description="联网能力dome:支持联网回答用户问题。", - + markdown_description="联网能力demo:支持联网回答用户问题。", ), cl.ChatProfile( name="上传本地项目", - markdown_description="项目级能力dome:支持上传本地zip压缩包项目,可以进行项目问答和对项目进行修改。", - - ) + markdown_description="项目级能力demo:支持上传本地zip压缩包项目,可以进行项目问答和对项目进行修改。", + ), ] - + @cl.on_chat_start async def start(): settings = await cl.ChatSettings( @@ -85,33 +84,31 @@ async def start(): ).send() temperature = settings["temperature"] top_p = settings["top_p"] - cl.user_session.set('temperature',temperature) - cl.user_session.set('top_p',top_p) - cl.user_session.set( - "message_history", - [] - ) + cl.user_session.set("temperature", temperature) + cl.user_session.set("top_p", top_p) + cl.user_session.set("message_history", []) chat_profile = cl.user_session.get("chat_profile") - extract_dir = 'repodata' + extract_dir = "repodata" if chat_profile == "chat聊天": pass - elif chat_profile =="上传本地项目": + elif chat_profile == "上传本地项目": files = None while files == None: files = await cl.AskFileMessage( - content="请上传项目zip压缩文件!", accept={"application/zip": [".zip"]},max_size_mb=50 + content="请上传项目zip压缩文件!", + accept={"application/zip": [".zip"]}, + max_size_mb=50, ).send() text_file = files[0] - extracted_path = unzip_file(text_file.path,extract_dir) + extracted_path = unzip_file(text_file.path, extract_dir) files_list = get_project_files_with_content(extracted_path) - cl.user_session.set("project_index",files_list) - if len(files_list)>0: + cl.user_session.set("project_index", files_list) + if len(files_list) > 0: await cl.Message( content=f"已成功上传,您可以开始对项目进行提问!", ).send() - - + @cl.on_message async def main(message: cl.Message): @@ -120,50 +117,59 @@ async def main(message: cl.Message): message_history.append({"role": "user", "content": message.content}) if chat_profile == "chat聊天": prompt_content = get_cur_base_user_prompt(message_history=message_history) - - elif chat_profile=="联网问答": - judge_context = llm.chat(web_judge_task_prompt.format(user_input=message.content),temperature=0.2) + + elif chat_profile == "联网问答": + judge_context = llm.chat( + web_judge_task_prompt.format(user_input=message.content), temperature=0.2 + ) print(judge_context) message_history.pop() - if '是' in judge_context: + if "是" in judge_context: prompt_tmp = bing_search_prompt(message.content) message_history.append({"role": "user", "content": prompt_tmp}) else: message_history.append({"role": "user", "content": message.content}) prompt_content = get_cur_base_user_prompt(message_history=message_history) - elif chat_profile =="上传本地项目" : - judge_context = llm.chat(judge_task_prompt.format(user_input=message.content),temperature=0.2) - - + elif chat_profile == "上传本地项目": + judge_context = llm.chat( + judge_task_prompt.format(user_input=message.content), temperature=0.2 + ) + project_index = cl.user_session.get("project_index") index_prompt = "" index_tmp = """###PATH:{path}\n{code}\n""" for index in project_index: - index_prompt+=index_tmp.format(path=index['path'],code=index['content']) + index_prompt += index_tmp.format(path=index["path"], code=index["content"]) print(judge_context) - prompt_content = get_cur_base_user_prompt(message_history=message_history,index_prompt=index_prompt,judge_context=judge_context) if '正常' not in judge_context else get_cur_base_user_prompt(message_history=message_history) - - + prompt_content = ( + get_cur_base_user_prompt( + message_history=message_history, + 
index_prompt=index_prompt, + judge_context=judge_context, + ) + if "正常" not in judge_context + else get_cur_base_user_prompt(message_history=message_history) + ) msg = cl.Message(content="") await msg.send() temperature = cl.user_session.get("temperature") - top_p = cl.user_session.get('top_p') - - if len(prompt_content)/4<120000: - stream = llm.stream_chat(prompt_content,temperature=temperature,top_p = top_p) + top_p = cl.user_session.get("top_p") + + if len(prompt_content) / 4 < 120000: + stream = llm.stream_chat(prompt_content, temperature=temperature, top_p=top_p) stream_processor = StreamProcessor() for part in stream: if isinstance(part, str): text = stream_processor.get_new_part(part) elif isinstance(part, dict): - text = stream_processor.get_new_part(part['name']+part['content']) + text = stream_processor.get_new_part(part["name"] + part["content"]) if token := (text or " "): await msg.stream_token(token) else: await msg.stream_token("项目太大了,请换小一点的项目。") message_history.append({"role": "assistant", "content": msg.content}) - await msg.update() \ No newline at end of file + await msg.update() diff --git a/repodemo/utils/bingsearch.py b/repodemo/utils/bingsearch.py index 999b13a..6c22c89 100644 --- a/repodemo/utils/bingsearch.py +++ b/repodemo/utils/bingsearch.py @@ -2,7 +2,9 @@ import requests from bs4 import BeautifulSoup as BS4 import requests -BING_API_KEY = '' +BING_API_KEY = "" + + def search_with_bing(query: str, search_timeout=30, top_k=6) -> list[dict]: """ Search with bing and return the contexts. @@ -13,9 +15,9 @@ def search_with_bing(query: str, search_timeout=30, top_k=6) -> list[dict]: headers={"Ocp-Apim-Subscription-Key": BING_API_KEY}, params={ "q": query, - "responseFilter": ['webpages'], - "freshness": 'month', - "mkt": 'zh-CN' + "responseFilter": ["webpages"], + "freshness": "month", + "mkt": "zh-CN", }, timeout=search_timeout, ) @@ -23,25 +25,29 @@ def search_with_bing(query: str, search_timeout=30, top_k=6) -> list[dict]: json_content = response.json() # print(json_content) contexts = json_content["webPages"]["value"][:top_k] - #logger.info("Web搜索完成") + # logger.info("Web搜索完成") return contexts except Exception as e: - #logger.error(f"搜索失败,错误原因: {e}") + # logger.error(f"搜索失败,错误原因: {e}") print(f"搜索失败,错误原因: {e}") return [] + def fetch_url(url): response = requests.get(url) - #use beautifulsoup4 to parse html - soup = BS4(response.text, 'html.parser') + # use beautifulsoup4 to parse html + soup = BS4(response.text, "html.parser") plain_text = soup.get_text() return plain_text + def bing_search_prompt(input): contents = search_with_bing(input, search_timeout=5, top_k=6) citations = "\n\n".join( - [f"[[citation:{i + 1}]]\n```markdown\n{item['snippet']}\n```" for i, item in enumerate(contents)] + [ + f"[[citation:{i + 1}]]\n```markdown\n{item['snippet']}\n```" + for i, item in enumerate(contents) + ] ) prompt = f"[引用]\n{citations}\n问:{input}\n" return prompt - diff --git a/repodemo/utils/tools.py b/repodemo/utils/tools.py index 695313f..8f85dbf 100644 --- a/repodemo/utils/tools.py +++ b/repodemo/utils/tools.py @@ -2,14 +2,15 @@ import zipfile import os import json + def unzip_file(zip_path, extract_dir): """ 解压zip文件到指定目录,并在指定目录下创建一个新的目录存放解压后的文件 - + 参数: zip_path (str): zip压缩包的地址 extract_dir (str): 指定解压的目录 - + 返回: str: 解压后的路径 """ @@ -19,11 +20,11 @@ def unzip_file(zip_path, extract_dir): base_name = os.path.basename(zip_path) dir_name = os.path.splitext(base_name)[0] new_extract_dir = os.path.join(extract_dir, dir_name) - + if not os.path.exists(new_extract_dir): 
os.makedirs(new_extract_dir) - with zipfile.ZipFile(zip_path, 'r') as zip_ref: + with zipfile.ZipFile(zip_path, "r") as zip_ref: zip_ref.extractall(new_extract_dir) return new_extract_dir @@ -32,15 +33,15 @@ def unzip_file(zip_path, extract_dir): def get_project_files_with_content(project_dir): """ 获取项目目录下所有文件的相对路径和内容 - + 参数: project_dir (str): 项目目录地址 - + 返回: list: 包含字典的列表,每个字典包含文件的相对路径和内容 """ files_list = [] - + for root, dirs, files in os.walk(project_dir): for file in files: if filter_data(file): @@ -48,83 +49,85 @@ def get_project_files_with_content(project_dir): relative_path = os.path.relpath(file_path, project_dir) if "__MACOSX" in relative_path: continue - with open(file_path, 'r', encoding='utf-8', errors='ignore') as f: + with open(file_path, "r", encoding="utf-8", errors="ignore") as f: content = f.read() - files_list.append({'path': relative_path, 'content': content}) + files_list.append({"path": relative_path, "content": content}) else: continue - + return files_list + def filter_data(obj): LANGUAGE_TAG = { - "c++" : "// C++", - "cpp" : "// C++", - "c" : "// C", - "c#" : "// C#", - "c-sharp" : "// C#", - "css" : "/* CSS */", - "cuda" : "// Cuda", - "fortran" : "! Fortran", - "go" : "// Go", - "html" : "", - "java" : "// Java", - "js" : "// JavaScript", - "javascript" : "// JavaScript", - "kotlin" : "// Kotlin", - "lean" : "-- Lean", - "lua" : "-- Lua", - "objectivec" : "// Objective-C", - "objective-c" : "// Objective-C", - "objective-c++": "// Objective-C++", - "pascal" : "// Pascal", - "php" : "// PHP", - "python" : "# Python", - "r" : "# R", - "rust" : "// Rust", - "ruby" : "# Ruby", - "scala" : "// Scala", - "shell" : "# Shell", - "sql" : "-- SQL", - "tex" : f"% TeX", - "typescript" : "// TypeScript", - "vue" : "", - - "assembly" : "; Assembly", - "dart" : "// Dart", - "perl" : "# Perl", - "prolog" : f"% Prolog", - "swift" : "// swift", - "lisp" : "; Lisp", - "vb" : "' Visual Basic", - "visual basic" : "' Visual Basic", - "matlab" : f"% Matlab", - "delphi" : "{ Delphi }", - "scheme" : "; Scheme", - "basic" : "' Basic", - "assembly" : "; Assembly", - "groovy" : "// Groovy", - "abap" : "* Abap", - "gdscript" : "# GDScript", - "haskell" : "-- Haskell", - "julia" : "# Julia", - "elixir" : "# Elixir", - "excel" : "' Excel", - "clojure" : "; Clojure", - "actionscript" : "// ActionScript", - "solidity" : "// Solidity", - "powershell" : "# PowerShell", - "erlang" : f"% Erlang", - "cobol" : "// Cobol", - "batchfile" : ":: Batch file", - "makefile" : "# Makefile", - "dockerfile" : "# Dockerfile", - "markdown" : "", - "cmake" : "# CMake", - "dockerfile" : "# Dockerfile", + "c++": "// C++", + "cpp": "// C++", + "c": "// C", + "c#": "// C#", + "c-sharp": "// C#", + "css": "/* CSS */", + "cuda": "// Cuda", + "fortran": "! 
Fortran", + "go": "// Go", + "html": "", + "java": "// Java", + "js": "// JavaScript", + "javascript": "// JavaScript", + "kotlin": "// Kotlin", + "lean": "-- Lean", + "lua": "-- Lua", + "objectivec": "// Objective-C", + "objective-c": "// Objective-C", + "objective-c++": "// Objective-C++", + "pascal": "// Pascal", + "php": "// PHP", + "python": "# Python", + "r": "# R", + "rust": "// Rust", + "ruby": "# Ruby", + "scala": "// Scala", + "shell": "# Shell", + "sql": "-- SQL", + "tex": f"% TeX", + "typescript": "// TypeScript", + "vue": "", + "assembly": "; Assembly", + "dart": "// Dart", + "perl": "# Perl", + "prolog": f"% Prolog", + "swift": "// swift", + "lisp": "; Lisp", + "vb": "' Visual Basic", + "visual basic": "' Visual Basic", + "matlab": f"% Matlab", + "delphi": "{ Delphi }", + "scheme": "; Scheme", + "basic": "' Basic", + "assembly": "; Assembly", + "groovy": "// Groovy", + "abap": "* Abap", + "gdscript": "# GDScript", + "haskell": "-- Haskell", + "julia": "# Julia", + "elixir": "# Elixir", + "excel": "' Excel", + "clojure": "; Clojure", + "actionscript": "// ActionScript", + "solidity": "// Solidity", + "powershell": "# PowerShell", + "erlang": f"% Erlang", + "cobol": "// Cobol", + "batchfile": ":: Batch file", + "makefile": "# Makefile", + "dockerfile": "# Dockerfile", + "markdown": "", + "cmake": "# CMake", + "dockerfile": "# Dockerfile", } - programming_languages_to_file_extensions = json.load(open('utils/programming-languages-to-file-extensions.json')) + programming_languages_to_file_extensions = json.load( + open("utils/programming-languages-to-file-extensions.json") + ) need2del = [] for key in programming_languages_to_file_extensions.keys(): if key.lower() not in LANGUAGE_TAG: @@ -140,15 +143,14 @@ def filter_data(obj): ext_to_programming_languages[item] = key want_languages.append(item) - ext = '.'+obj.split('.')[-1] - with open('utils/keep.txt', 'r') as f: + ext = "." + obj.split(".")[-1] + with open("utils/keep.txt", "r") as f: keep_files = f.readlines() keep_files = [l.strip() for l in keep_files] - #print(ext) + # print(ext) if ext not in want_languages: if obj in keep_files: return True return False else: return True -