diff --git a/selected_homework/openai-translator_v4/README.md b/selected_homework/openai-translator_v4/README.md new file mode 100644 index 00000000..0734be0f --- /dev/null +++ b/selected_homework/openai-translator_v4/README.md @@ -0,0 +1,19 @@ +### 作业需求 +基于 ChatGLM2-6B 实现带图形化界面的 openai-translator + +### 作业总结 ++ [openai_api_demo](openai_api_demo) + +利用的chatGLM中的api_demo进行调整为server项,故没有采用ChatGLM2-6b,而是ChatGLM3-6b + +运行起来需要: +1. git clone https://www.modelscope.cn/ZhipuAI/chatglm3-6b.git +2. 确保机器有足够的资源【俺没有。。。故暂未实现。。。】 + ++ [ai_translator](ai_translator) + +将历史的 +from langchain_openai import ChatOpenAI +替换为 +from langchain.llms import ChatGLM +并针对ChatGLM做参数匹配 diff --git a/selected_homework/openai-translator_v4/ai_translator/__init__.py b/selected_homework/openai-translator_v4/ai_translator/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/selected_homework/openai-translator_v4/ai_translator/book/__init__.py b/selected_homework/openai-translator_v4/ai_translator/book/__init__.py new file mode 100644 index 00000000..5b688799 --- /dev/null +++ b/selected_homework/openai-translator_v4/ai_translator/book/__init__.py @@ -0,0 +1,3 @@ +from .book import Book +from .page import Page +from .content import ContentType, Content, TableContent \ No newline at end of file diff --git a/selected_homework/openai-translator_v4/ai_translator/book/book.py b/selected_homework/openai-translator_v4/ai_translator/book/book.py new file mode 100644 index 00000000..b079357b --- /dev/null +++ b/selected_homework/openai-translator_v4/ai_translator/book/book.py @@ -0,0 +1,9 @@ +from .page import Page + +class Book: + def __init__(self, pdf_file_path): + self.pdf_file_path = pdf_file_path + self.pages = [] + + def add_page(self, page: Page): + self.pages.append(page) \ No newline at end of file diff --git a/selected_homework/openai-translator_v4/ai_translator/book/content.py b/selected_homework/openai-translator_v4/ai_translator/book/content.py new file mode 100644 index 00000000..901c2a07 --- /dev/null +++ b/selected_homework/openai-translator_v4/ai_translator/book/content.py @@ -0,0 +1,85 @@ +import pandas as pd + +from enum import Enum, auto +from PIL import Image as PILImage +from utils import LOG +from io import StringIO + +class ContentType(Enum): + TEXT = auto() + TABLE = auto() + IMAGE = auto() + +class Content: + def __init__(self, content_type, original, translation=None): + self.content_type = content_type + self.original = original + self.translation = translation + self.status = False + + def set_translation(self, translation, status): + if not self.check_translation_type(translation): + raise ValueError(f"Invalid translation type. 
Expected {self.content_type}, but got {type(translation)}") + self.translation = translation + self.status = status + + def check_translation_type(self, translation): + if self.content_type == ContentType.TEXT and isinstance(translation, str): + return True + elif self.content_type == ContentType.TABLE and isinstance(translation, list): + return True + elif self.content_type == ContentType.IMAGE and isinstance(translation, PILImage.Image): + return True + return False + + def __str__(self): + return self.original + + +class TableContent(Content): + def __init__(self, data, translation=None): + df = pd.DataFrame(data) + + # Verify if the number of rows and columns in the data and DataFrame object match + if len(data) != len(df) or len(data[0]) != len(df.columns): + raise ValueError("The number of rows and columns in the extracted table data and DataFrame object do not match.") + + super().__init__(ContentType.TABLE, df) + + def set_translation(self, translation, status): + try: + if not isinstance(translation, str): + raise ValueError(f"Invalid translation type. Expected str, but got {type(translation)}") + + LOG.debug(f"[translation]\n{translation}") + # Extract column names from the first set of brackets + header = translation.split(']')[0][1:].split(', ') + # Extract data rows from the remaining brackets + data_rows = translation.split('] ')[1:] + # Replace Chinese punctuation and split each row into a list of values + data_rows = [row[1:-1].split(', ') for row in data_rows] + # Create a DataFrame using the extracted header and data + translated_df = pd.DataFrame(data_rows, columns=header) + LOG.debug(f"[translated_df]\n{translated_df}") + self.translation = translated_df + self.status = status + except Exception as e: + LOG.error(f"An error occurred during table translation: {e}") + self.translation = None + self.status = False + + def __str__(self): + return self.original.to_string(header=False, index=False) + + def iter_items(self, translated=False): + target_df = self.translation if translated else self.original + for row_idx, row in target_df.iterrows(): + for col_idx, item in enumerate(row): + yield (row_idx, col_idx, item) + + def update_item(self, row_idx, col_idx, new_value, translated=False): + target_df = self.translation if translated else self.original + target_df.at[row_idx, col_idx] = new_value + + def get_original_as_str(self): + return self.original.to_string(header=False, index=False) \ No newline at end of file diff --git a/selected_homework/openai-translator_v4/ai_translator/book/page.py b/selected_homework/openai-translator_v4/ai_translator/book/page.py new file mode 100644 index 00000000..df12e772 --- /dev/null +++ b/selected_homework/openai-translator_v4/ai_translator/book/page.py @@ -0,0 +1,8 @@ +from .content import Content + +class Page: + def __init__(self): + self.contents = [] + + def add_content(self, content: Content): + self.contents.append(content) diff --git a/selected_homework/openai-translator_v4/ai_translator/config.yaml b/selected_homework/openai-translator_v4/ai_translator/config.yaml new file mode 100644 index 00000000..ba1e173c --- /dev/null +++ b/selected_homework/openai-translator_v4/ai_translator/config.yaml @@ -0,0 +1,5 @@ +model_name: "chatglm2-6b" +input_file: "tests/test.pdf" +output_file_format: "markdown" +source_language: "English" +target_language: "Chinese" \ No newline at end of file diff --git a/selected_homework/openai-translator_v4/ai_translator/flask_server.py b/selected_homework/openai-translator_v4/ai_translator/flask_server.py 
new file mode 100644 index 00000000..7b5bed03 --- /dev/null +++ b/selected_homework/openai-translator_v4/ai_translator/flask_server.py @@ -0,0 +1,71 @@ +import sys +import os + +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +from flask import Flask, request, send_file, jsonify +from translator import PDFTranslator, TranslationConfig +from utils import ArgumentParser, LOG + +app = Flask(__name__) + +TEMP_FILE_DIR = "flask_temps/" + +@app.route('/translation', methods=['POST']) +def translation(): + try: + input_file = request.files['input_file'] + source_language = request.form.get('source_language', 'English') + target_language = request.form.get('target_language', 'Chinese') + + LOG.debug(f"[input_file]\n{input_file}") + LOG.debug(f"[input_file.filename]\n{input_file.filename}") + + if input_file and input_file.filename: + # # 创建临时文件 + input_file_path = TEMP_FILE_DIR+input_file.filename + LOG.debug(f"[input_file_path]\n{input_file_path}") + + input_file.save(input_file_path) + + # 调用翻译函数 + output_file_path = Translator.translate_pdf( + input_file=input_file_path, + source_language=source_language, + target_language=target_language) + + # 移除临时文件 + # os.remove(input_file_path) + + # 构造完整的文件路径 + output_file_path = os.getcwd() + "/" + output_file_path + LOG.debug(output_file_path) + + # 返回翻译后的文件 + return send_file(output_file_path, as_attachment=True) + except Exception as e: + response = { + 'status': 'error', + 'message': str(e) + } + return jsonify(response), 400 + + +def initialize_translator(): + # 解析命令行 + argument_parser = ArgumentParser() + args = argument_parser.parse_arguments() + + # 初始化配置单例 + config = TranslationConfig() + config.initialize(args) + # 实例化 PDFTranslator 类,并调用 translate_pdf() 方法 + global Translator + Translator = PDFTranslator(config.model_name) + + +if __name__ == "__main__": + # 初始化 translator + initialize_translator() + # 启动 Flask Web Server + app.run(host="0.0.0.0", port=5000, debug=True) \ No newline at end of file diff --git a/selected_homework/openai-translator_v4/ai_translator/gradio_server.py b/selected_homework/openai-translator_v4/ai_translator/gradio_server.py new file mode 100644 index 00000000..2d250dcd --- /dev/null +++ b/selected_homework/openai-translator_v4/ai_translator/gradio_server.py @@ -0,0 +1,59 @@ +import sys +import os +import gradio as gr + + +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +from utils import ArgumentParser, LOG +from translator import PDFTranslator, TranslationConfig + + +def translation(input_file, source_language, target_language, translate_style): + LOG.debug( + f"[翻译任务]\n源文件: {input_file.name}\n源语言: {source_language}\n目标语言: {target_language}\n翻译风格: {translate_style}") + + output_file_path = Translator.translate_pdf( + input_file.name, source_language=source_language, + target_language=target_language, translate_style=translate_style + ) + + return output_file_path + +def launch_gradio(): + + iface = gr.Interface( + fn=translation, + title="[Homework]OpenAI-Translator v4(接入GLM3-6b)", + inputs=[ + gr.File(label="上传PDF文件"), + gr.Textbox(label="源语言(默认:英文)", placeholder="English", value="English"), + gr.Textbox(label="目标语言(默认:中文)", placeholder="Chinese", value="Chinese"), + gr.Radio(["Normal people", "Children", "Professor"]), + ], + outputs=[ + gr.File(label="下载翻译文件") + ], + allow_flagging="never" + ) + + iface.launch(share=True, server_name="0.0.0.0") + +def initialize_translator(): + # 解析命令行 + argument_parser = ArgumentParser() + args = argument_parser.parse_arguments() + + # 初始化配置单例 + config 
= TranslationConfig() + config.initialize(args) + # 实例化 PDFTranslator 类,并调用 translate_pdf() 方法 + global Translator + Translator = PDFTranslator(config.model_name) + + +if __name__ == "__main__": + # 初始化 translator + initialize_translator() + # 启动 Gradio 服务 + launch_gradio() diff --git a/selected_homework/openai-translator_v4/ai_translator/tests/The_Old_Man_of_the_Sea.pdf b/selected_homework/openai-translator_v4/ai_translator/tests/The_Old_Man_of_the_Sea.pdf new file mode 100644 index 00000000..2ad65496 Binary files /dev/null and b/selected_homework/openai-translator_v4/ai_translator/tests/The_Old_Man_of_the_Sea.pdf differ diff --git a/selected_homework/openai-translator_v4/ai_translator/tests/test.pdf b/selected_homework/openai-translator_v4/ai_translator/tests/test.pdf new file mode 100644 index 00000000..dc3e9828 Binary files /dev/null and b/selected_homework/openai-translator_v4/ai_translator/tests/test.pdf differ diff --git a/selected_homework/openai-translator_v4/ai_translator/tests/test_translated.md b/selected_homework/openai-translator_v4/ai_translator/tests/test_translated.md new file mode 100644 index 00000000..6c230b26 --- /dev/null +++ b/selected_homework/openai-translator_v4/ai_translator/tests/test_translated.md @@ -0,0 +1,19 @@ +测试数据 +这个数据集包含了由OpenAI的AI语言模型ChatGPT提供的两个测试样本。 +这些样本包括一个Markdown表格和一个英文文本段落,可以用来测试支持文本和表格格式的英译中翻译软件。 +文本测试 +快速的棕色狐狸跳过懒狗。这个句子包含了英语字母表中的每个字母至少一次。句子是经常用来测试字体、键盘和其他与文本相关的工具的。除了英语,其他许多语言也有句子。由于语言的独特特点,有些句子更难构造。 + +| 水果 | 颜色 | 价格(美元) | +| --- | --- | --- | +| 苹果 | 红色 | 1.2 | +| 香蕉 | 黄色 | 0.5 | +| 橙子 | 橙色 | 0.8 | +| 草莓 | 红色 | 2.5 | +| 蓝莓 | 蓝色 | 3.0 | +| 猕猴桃 | 绿色 | 1.0 | +| 芒果 | 橙色 | 1.5 | +| 葡萄 | 紫色 | 2.00 | + +--- + diff --git a/selected_homework/openai-translator_v4/ai_translator/translator/__init__.py b/selected_homework/openai-translator_v4/ai_translator/translator/__init__.py new file mode 100644 index 00000000..0e3fdcca --- /dev/null +++ b/selected_homework/openai-translator_v4/ai_translator/translator/__init__.py @@ -0,0 +1,2 @@ +from .pdf_translator import PDFTranslator +from .translation_config import TranslationConfig \ No newline at end of file diff --git a/selected_homework/openai-translator_v4/ai_translator/translator/exceptions.py b/selected_homework/openai-translator_v4/ai_translator/translator/exceptions.py new file mode 100644 index 00000000..4f4c23c1 --- /dev/null +++ b/selected_homework/openai-translator_v4/ai_translator/translator/exceptions.py @@ -0,0 +1,5 @@ +class PageOutOfRangeException(Exception): + def __init__(self, book_pages, requested_pages): + self.book_pages = book_pages + self.requested_pages = requested_pages + super().__init__(f"Page out of range: Book has {book_pages} pages, but {requested_pages} pages were requested.") diff --git a/selected_homework/openai-translator_v4/ai_translator/translator/pdf_parser.py b/selected_homework/openai-translator_v4/ai_translator/translator/pdf_parser.py new file mode 100644 index 00000000..6f2f9bc3 --- /dev/null +++ b/selected_homework/openai-translator_v4/ai_translator/translator/pdf_parser.py @@ -0,0 +1,58 @@ +import pdfplumber +from typing import Optional +from book import Book, Page, Content, ContentType, TableContent +from translator.exceptions import PageOutOfRangeException +from utils import LOG + + +class PDFParser: + def __init__(self): + pass + + def parse_pdf(self, pdf_file_path: str, pages: Optional[int] = None) -> Book: + book = Book(pdf_file_path) + + with pdfplumber.open(pdf_file_path) as pdf: + if pages is not None and pages > len(pdf.pages): + raise 
PageOutOfRangeException(len(pdf.pages), pages) + + if pages is None: + pages_to_parse = pdf.pages + else: + pages_to_parse = pdf.pages[:pages] + + for pdf_page in pages_to_parse: + page = Page() + + # Store the original text content + raw_text = pdf_page.extract_text() + tables = pdf_page.extract_tables() + + # Remove each cell's content from the original text + for table_data in tables: + for row in table_data: + for cell in row: + raw_text = raw_text.replace(cell, "", 1) + + # Handling text + if raw_text: + # Remove empty lines and leading/trailing whitespaces + raw_text_lines = raw_text.splitlines() + cleaned_raw_text_lines = [line.strip() for line in raw_text_lines if line.strip()] + cleaned_raw_text = "\n".join(cleaned_raw_text_lines) + + text_content = Content(content_type=ContentType.TEXT, original=cleaned_raw_text) + page.add_content(text_content) + LOG.debug(f"[raw_text]\n {cleaned_raw_text}") + + + + # Handling tables + if tables: + table = TableContent(tables) + page.add_content(table) + LOG.debug(f"[table]\n{table}") + + book.add_page(page) + + return book diff --git a/selected_homework/openai-translator_v4/ai_translator/translator/pdf_translator.py b/selected_homework/openai-translator_v4/ai_translator/translator/pdf_translator.py new file mode 100644 index 00000000..a3eadb75 --- /dev/null +++ b/selected_homework/openai-translator_v4/ai_translator/translator/pdf_translator.py @@ -0,0 +1,30 @@ +from typing import Optional +from translator.pdf_parser import PDFParser +from translator.writer import Writer +from translator.translation_chain import TranslationChain + +class PDFTranslator: + def __init__(self, model_name: str): + self.translate_chain = TranslationChain(model_name) + self.pdf_parser = PDFParser() + self.writer = Writer() + + def translate_pdf(self, + input_file: str, + output_file_format: str = 'markdown', + source_language: str = "English", + target_language: str = 'Chinese', + translate_style: str = "Normal Style", + pages: Optional[int] = None): + + self.book = self.pdf_parser.parse_pdf(input_file, pages) + + for page_idx, page in enumerate(self.book.pages): + for content_idx, content in enumerate(page.contents): + # Translate content.original + translation, status = self.translate_chain.run( + content, source_language, target_language, translate_style) + # Update the content in self.book.pages directly + self.book.pages[page_idx].contents[content_idx].set_translation(translation, status) + + return self.writer.save_translated_book(self.book, output_file_format) diff --git a/selected_homework/openai-translator_v4/ai_translator/translator/translation_chain.py b/selected_homework/openai-translator_v4/ai_translator/translator/translation_chain.py new file mode 100644 index 00000000..18e53735 --- /dev/null +++ b/selected_homework/openai-translator_v4/ai_translator/translator/translation_chain.py @@ -0,0 +1,54 @@ +from langchain.llms import ChatGLM +from langchain.chains import LLMChain + +from langchain.prompts.chat import ( + ChatPromptTemplate, + SystemMessagePromptTemplate, + HumanMessagePromptTemplate, +) + +from utils import LOG + +class TranslationChain: + def __init__(self, model_name: str = "chatglm2-6b", verbose: bool = True): + + # 翻译任务指令始终由 System 角色承担 + template = ( + """You are a translation expert, proficient in various languages. \n + Translates {source_language} to {target_language}. 
\n + Speak like {translate_style}.""" + ) + system_message_prompt = SystemMessagePromptTemplate.from_template(template) + + # 待翻译文本由 Human 角色输入 + human_template = "{text}" + human_message_prompt = HumanMessagePromptTemplate.from_template(human_template) + + # 使用 System 和 Human 角色的提示模板构造 ChatPromptTemplate + chat_prompt_template = ChatPromptTemplate.from_messages( + [system_message_prompt, human_message_prompt] + ) + + chat = ChatGLM( + endpoint_url="http://127.0.0.1:8000", + max_token=8000 + ) + + self.chain = LLMChain(llm=chat, prompt=chat_prompt_template, verbose=verbose) + + def run(self, text: str, + source_language: str, + target_language: str, + translate_style: str) -> (str, bool): + result = "" + try: + result = self.chain.run({ + "text": text, + "source_language": source_language, + "target_language": target_language, + "translate_style": translate_style + }) + except Exception as e: + LOG.error(f"An error occurred during translation: {e}") + return result, False + return result, True diff --git a/selected_homework/openai-translator_v4/ai_translator/translator/translation_config.py b/selected_homework/openai-translator_v4/ai_translator/translator/translation_config.py new file mode 100644 index 00000000..783823ae --- /dev/null +++ b/selected_homework/openai-translator_v4/ai_translator/translator/translation_config.py @@ -0,0 +1,29 @@ +import yaml + +class TranslationConfig: + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super(TranslationConfig, cls).__new__(cls) + cls._instance._config = None + return cls._instance + + def initialize(self, args): + with open(args.config_file, "r") as f: + config = yaml.safe_load(f) + + # Use the argparse Namespace to update the configuration + overridden_values = { + key: value for key, value in vars(args).items() if key in config and value is not None + } + config.update(overridden_values) + + # Store the original config dictionary + self._instance._config = config + + def __getattr__(self, name): + # Try to get attribute from _config + if self._instance._config and name in self._instance._config: + return self._instance._config[name] + raise AttributeError(f"'TranslationConfig' object has no attribute '{name}'") \ No newline at end of file diff --git a/selected_homework/openai-translator_v4/ai_translator/translator/writer.py b/selected_homework/openai-translator_v4/ai_translator/translator/writer.py new file mode 100644 index 00000000..90b51ed5 --- /dev/null +++ b/selected_homework/openai-translator_v4/ai_translator/translator/writer.py @@ -0,0 +1,114 @@ +import os +from reportlab.lib import colors, pagesizes, units +from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle +from reportlab.pdfbase import pdfmetrics +from reportlab.pdfbase.ttfonts import TTFont +from reportlab.platypus import ( + SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle, PageBreak +) + +from book import Book, ContentType +from utils import LOG + +class Writer: + def __init__(self): + pass + + def save_translated_book(self, book: Book, ouput_file_format: str): + LOG.debug(ouput_file_format) + + if ouput_file_format.lower() == "pdf": + output_file_path = self._save_translated_book_pdf(book) + elif ouput_file_format.lower() == "markdown": + output_file_path = self._save_translated_book_markdown(book) + else: + LOG.error(f"不支持文件类型: {ouput_file_format}") + return "" + + LOG.info(f"翻译完成,文件保存至: {output_file_path}") + + return output_file_path + + + def _save_translated_book_pdf(self, book: Book, output_file_path: str = 
None): + + output_file_path = book.pdf_file_path.replace('.pdf', f'_translated.pdf') + + LOG.info(f"开始导出: {output_file_path}") + + # Register Chinese font + font_path = "../fonts/simsun.ttc" # 请将此路径替换为您的字体文件路径 + pdfmetrics.registerFont(TTFont("SimSun", font_path)) + + # Create a new ParagraphStyle with the SimSun font + simsun_style = ParagraphStyle('SimSun', fontName='SimSun', fontSize=12, leading=14) + + # Create a PDF document + doc = SimpleDocTemplate(output_file_path, pagesize=pagesizes.letter) + styles = getSampleStyleSheet() + story = [] + + # Iterate over the pages and contents + for page in book.pages: + for content in page.contents: + if content.status: + if content.content_type == ContentType.TEXT: + # Add translated text to the PDF + text = content.translation + para = Paragraph(text, simsun_style) + story.append(para) + + elif content.content_type == ContentType.TABLE: + # Add table to the PDF + table = content.translation + table_style = TableStyle([ + ('BACKGROUND', (0, 0), (-1, 0), colors.grey), + ('TEXTCOLOR', (0, 0), (-1, 0), colors.whitesmoke), + ('ALIGN', (0, 0), (-1, -1), 'CENTER'), + ('FONTNAME', (0, 0), (-1, 0), 'SimSun'), # 更改表头字体为 "SimSun" + ('FONTSIZE', (0, 0), (-1, 0), 14), + ('BOTTOMPADDING', (0, 0), (-1, 0), 12), + ('BACKGROUND', (0, 1), (-1, -1), colors.beige), + ('FONTNAME', (0, 1), (-1, -1), 'SimSun'), # 更改表格中的字体为 "SimSun" + ('GRID', (0, 0), (-1, -1), 1, colors.black) + ]) + pdf_table = Table(table.values.tolist()) + pdf_table.setStyle(table_style) + story.append(pdf_table) + # Add a page break after each page except the last one + if page != book.pages[-1]: + story.append(PageBreak()) + + # Save the translated book as a new PDF file + doc.build(story) + return output_file_path + + + def _save_translated_book_markdown(self, book: Book, output_file_path: str = None): + output_file_path = book.pdf_file_path.replace('.pdf', f'_translated.md') + + LOG.info(f"开始导出: {output_file_path}") + with open(output_file_path, 'w', encoding='utf-8') as output_file: + # Iterate over the pages and contents + for page in book.pages: + for content in page.contents: + if content.status: + if content.content_type == ContentType.TEXT: + # Add translated text to the Markdown file + text = content.translation + output_file.write(text + '\n\n') + + elif content.content_type == ContentType.TABLE: + # Add table to the Markdown file + table = content.translation + header = '| ' + ' | '.join(str(column) for column in table.columns) + ' |' + '\n' + separator = '| ' + ' | '.join(['---'] * len(table.columns)) + ' |' + '\n' + # body = '\n'.join(['| ' + ' | '.join(row) + ' |' for row in table.values.tolist()]) + '\n\n' + body = '\n'.join(['| ' + ' | '.join(str(cell) for cell in row) + ' |' for row in table.values.tolist()]) + '\n\n' + output_file.write(header + separator + body) + + # Add a page break (horizontal rule) after each page except the last one + if page != book.pages[-1]: + output_file.write('---\n\n') + + return output_file_path \ No newline at end of file diff --git a/selected_homework/openai-translator_v4/ai_translator/utils/__init__.py b/selected_homework/openai-translator_v4/ai_translator/utils/__init__.py new file mode 100644 index 00000000..09b16931 --- /dev/null +++ b/selected_homework/openai-translator_v4/ai_translator/utils/__init__.py @@ -0,0 +1,2 @@ +from .argument_parser import ArgumentParser +from .logger import LOG \ No newline at end of file diff --git a/selected_homework/openai-translator_v4/ai_translator/utils/argument_parser.py 
b/selected_homework/openai-translator_v4/ai_translator/utils/argument_parser.py new file mode 100644 index 00000000..57684d86 --- /dev/null +++ b/selected_homework/openai-translator_v4/ai_translator/utils/argument_parser.py @@ -0,0 +1,15 @@ +import argparse + +class ArgumentParser: + def __init__(self): + self.parser = argparse.ArgumentParser(description='A translation tool that supports translations in any language pair.') + self.parser.add_argument('--config_file', type=str, default='config.yaml', help='Configuration file with model and API settings.') + self.parser.add_argument('--model_name', type=str, help='Name of the Large Language Model.') + self.parser.add_argument('--input_file', type=str, help='PDF file to translate.') + self.parser.add_argument('--output_file_format', type=str, help='The file format of translated book. Now supporting PDF and Markdown') + self.parser.add_argument('--source_language', type=str, help='The language of the original book to be translated.') + self.parser.add_argument('--target_language', type=str, help='The target language for translating the original book.') + + def parse_arguments(self): + args = self.parser.parse_args() + return args diff --git a/selected_homework/openai-translator_v4/ai_translator/utils/logger.py b/selected_homework/openai-translator_v4/ai_translator/utils/logger.py new file mode 100644 index 00000000..a252b50e --- /dev/null +++ b/selected_homework/openai-translator_v4/ai_translator/utils/logger.py @@ -0,0 +1,32 @@ +from loguru import logger +import os +import sys + +LOG_FILE = "translation.log" +ROTATION_TIME = "02:00" + +class Logger: + def __init__(self, name="translation", log_dir="logs", debug=False): + if not os.path.exists(log_dir): + os.makedirs(log_dir) + log_file_path = os.path.join(log_dir, LOG_FILE) + + # Remove default loguru handler + logger.remove() + + # Add console handler with a specific log level + level = "DEBUG" if debug else "INFO" + logger.add(sys.stdout, level=level) + # Add file handler with a specific log level and timed rotation + logger.add(log_file_path, rotation=ROTATION_TIME, level="DEBUG") + self.logger = logger + +LOG = Logger(debug=True).logger + +if __name__ == "__main__": + log = Logger().logger + + log.debug("This is a debug message.") + log.info("This is an info message.") + log.warning("This is a warning message.") + log.error("This is an error message.") diff --git a/selected_homework/openai-translator_v4/openai_api_demo/api_server.py b/selected_homework/openai-translator_v4/openai_api_demo/api_server.py new file mode 100644 index 00000000..498ae05c --- /dev/null +++ b/selected_homework/openai-translator_v4/openai_api_demo/api_server.py @@ -0,0 +1,549 @@ +""" +This script implements an API for the ChatGLM3-6B model, +formatted similarly to OpenAI's API (https://platform.openai.com/docs/api-reference/chat). +It's designed to be run as a web server using FastAPI and uvicorn, +making the ChatGLM3-6B model accessible through OpenAI Client. + +Key Components and Features: +- Model and Tokenizer Setup: Configures the model and tokenizer paths and loads them. +- FastAPI Configuration: Sets up a FastAPI application with CORS middleware for handling cross-origin requests. +- API Endpoints: + - "/v1/models": Lists the available models, specifically ChatGLM3-6B. + - "/v1/chat/completions": Processes chat completion requests with options for streaming and regular responses. + - "/v1/embeddings": Processes Embedding request of a list of text inputs. 
+- Token Limit Caution: In the OpenAI API, 'max_tokens' is equivalent to HuggingFace's 'max_new_tokens', not 'max_length'. +For instance, setting 'max_tokens' to 8192 for a 6b model would result in an error due to the model's inability to output +that many tokens after accounting for the history and prompt tokens. +- Stream Handling and Custom Functions: Manages streaming responses and custom function calls within chat responses. +- Pydantic Models: Defines structured models for requests and responses, enhancing API documentation and type safety. +- Main Execution: Initializes the model and tokenizer, and starts the FastAPI app on the designated host and port. + +Note: + This script doesn't include the setup for special tokens or multi-GPU support by default. + Users need to configure their special tokens and can enable multi-GPU support as per the provided instructions. + Embedding Models only support in One GPU. + + Running this script requires 14-15GB of GPU memory. 2 GB for the embedding model and 12-13 GB for the FP16 ChatGLM3 LLM. + + +""" + +import os +import time +import tiktoken +import torch +import uvicorn +import json +from fastapi import FastAPI, HTTPException, Response +from fastapi.middleware.cors import CORSMiddleware + +from contextlib import asynccontextmanager +from typing import List, Literal, Optional, Union +from loguru import logger +from pydantic import BaseModel, Field +from transformers import AutoTokenizer, AutoModel +from utils import process_response, generate_chatglm3, generate_stream_chatglm3 +from sentence_transformers import SentenceTransformer +from tools.schema import tool_class, tool_def, tool_param_start_with +from sse_starlette.sse import EventSourceResponse + +# Set up limit request time +EventSourceResponse.DEFAULT_PING_INTERVAL = 1000 + +# set LLM path +# MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b') +MODEL_PATH = os.environ.get('MODEL_PATH', '/Users/tanshuai/GitHub/chatglm3-6b') +TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH) + +# set Embedding Model path +EMBEDDING_PATH = os.environ.get('EMBEDDING_PATH', 'BAAI/bge-m3') + + +@asynccontextmanager +async def lifespan(app: FastAPI): + yield + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + + +app = FastAPI(lifespan=lifespan) + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +class ModelCard(BaseModel): + id: str + object: str = "model" + created: int = Field(default_factory=lambda: int(time.time())) + owned_by: str = "owner" + root: Optional[str] = None + parent: Optional[str] = None + permission: Optional[list] = None + + +class ModelList(BaseModel): + object: str = "list" + data: List[ModelCard] = [] + + +class FunctionCallResponse(BaseModel): + name: Optional[str] = None + arguments: Optional[str] = None + + +class ChatMessage(BaseModel): + role: Literal["user", "assistant", "system", "function"] + content: str = None + name: Optional[str] = None + function_call: Optional[FunctionCallResponse] = None + + +class DeltaMessage(BaseModel): + role: Optional[Literal["user", "assistant", "system"]] = None + content: Optional[str] = None + function_call: Optional[FunctionCallResponse] = None + + +## for Embedding +class EmbeddingRequest(BaseModel): + input: Union[List[str], str] + model: str + + +class CompletionUsage(BaseModel): + prompt_tokens: int + completion_tokens: int + total_tokens: int + + +class EmbeddingResponse(BaseModel): + data: 
list + model: str + object: str + usage: CompletionUsage + + +# for ChatCompletionRequest + +class UsageInfo(BaseModel): + prompt_tokens: int = 0 + total_tokens: int = 0 + completion_tokens: Optional[int] = 0 + + +class ChatCompletionRequest(BaseModel): + model: str + messages: List[ChatMessage] + temperature: Optional[float] = 0.8 + top_p: Optional[float] = 0.8 + max_tokens: Optional[int] = None + stream: Optional[bool] = False + tools: Optional[Union[dict, List[dict]]] = None + repetition_penalty: Optional[float] = 1.1 + agent: Optional[bool] = False + + +class ChatCompletionResponseChoice(BaseModel): + index: int + message: ChatMessage + finish_reason: Literal["stop", "length", "function_call"] + + +class ChatCompletionResponseStreamChoice(BaseModel): + delta: DeltaMessage + finish_reason: Optional[Literal["stop", "length", "function_call"]] + index: int + + +class ChatCompletionResponse(BaseModel): + model: str + id: str + object: Literal["chat.completion", "chat.completion.chunk"] + choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]] + created: Optional[int] = Field(default_factory=lambda: int(time.time())) + usage: Optional[UsageInfo] = None + + +@app.get("/health") +async def health() -> Response: + """Health check.""" + return Response(status_code=200) + + +@app.post("/v1/embeddings", response_model=EmbeddingResponse) +async def get_embeddings(request: EmbeddingRequest): + if isinstance(request.input, str): + embeddings = [embedding_model.encode(request.input)] + else: + embeddings = [embedding_model.encode(text) for text in request.input] + embeddings = [embedding.tolist() for embedding in embeddings] + + def num_tokens_from_string(string: str) -> int: + """ + Returns the number of tokens in a text string. 
+ use cl100k_base tokenizer + """ + encoding = tiktoken.get_encoding('cl100k_base') + num_tokens = len(encoding.encode(string)) + return num_tokens + + response = { + "data": [ + { + "object": "embedding", + "embedding": embedding, + "index": index + } + for index, embedding in enumerate(embeddings) + ], + "model": request.model, + "object": "list", + "usage": CompletionUsage( + prompt_tokens=sum(len(text.split()) for text in request.input), + completion_tokens=0, + total_tokens=sum(num_tokens_from_string(text) for text in request.input), + ) + } + return response + + +@app.get("/v1/models", response_model=ModelList) +async def list_models(): + model_card = ModelCard( + id="chatglm3-6b" + ) + return ModelList( + data=[model_card] + ) + + +@app.post("/v1/chat/completions", response_model=ChatCompletionResponse) +async def create_chat_completion(request: ChatCompletionRequest): + global model, tokenizer + + if len(request.messages) < 1 or request.messages[-1].role == "assistant": + raise HTTPException(status_code=400, detail="Invalid request") + + gen_params = dict( + messages=request.messages, + temperature=request.temperature, + top_p=request.top_p, + max_tokens=request.max_tokens or 1024, + echo=False, + stream=request.stream, + repetition_penalty=request.repetition_penalty, + agent=request.agent + ) + logger.debug(f"==== request ====\n{gen_params}") + gen_params["tools"] = tool_def if gen_params["agent"] else [] + + if request.stream: + + # Use the stream mode to read the first few characters, if it is not a function call, direct stram output + predict_stream_generator = predict_stream(request.model, gen_params) + output = next(predict_stream_generator) + if not contains_custom_function(output, gen_params["tools"]): + return EventSourceResponse(predict_stream_generator, media_type="text/event-stream") + + # Obtain the result directly at one time and determine whether tools needs to be called. + logger.debug(f"First result output:\n{output}") + + function_call = None + if output and request.tools: + try: + function_call = process_response(output, use_tool=True) + except: + logger.warning("Failed to parse tool call") + + # CallFunction + if isinstance(function_call, dict): + function_call = FunctionCallResponse(**function_call) + + """ + In this demo, we did not register any tools. + You can use the tools that have been implemented in our `tools_using_demo` and implement your own streaming tool implementation here. + Similar to the following method: + """ + if tool_param_start_with in output: + tool = tool_class.get(function_call.name) + if tool: + tool_param = json.loads(function_call.arguments).get("symbol") + if tool().parameter_validation(tool_param): + observation = str(tool().run(tool_param)) + tool_response = observation + else: + tool_response = "Tool parameter values error, please tell the user about this situation." + else: + tool_response = "No available tools found, please tell the user about this situation." + else: + tool_response = "Tool parameter content error, please tell the user about this situation." 
+ + if not gen_params.get("messages"): + gen_params["messages"] = [] + + gen_params["messages"].append(ChatMessage( + role="assistant", + content=output, + )) + gen_params["messages"].append(ChatMessage( + role="function", + name=function_call.name, + content=tool_response, + )) + + # Streaming output of results after function calls + generate = predict(request.model, gen_params) + return EventSourceResponse(generate, media_type="text/event-stream") + + else: + # Handled to avoid exceptions in the above parsing function process. + generate = parse_output_text(request.model, output) + return EventSourceResponse(generate, media_type="text/event-stream") + + # Here is the handling of stream = False + response = generate_chatglm3(model, tokenizer, gen_params) + + # Remove the first newline character + if response["text"].startswith("\n"): + response["text"] = response["text"][1:] + response["text"] = response["text"].strip() + + usage = UsageInfo() + function_call, finish_reason = None, "stop" + if request.tools: + try: + function_call = process_response(response["text"], use_tool=True) + except: + logger.warning("Failed to parse tool call, maybe the response is not a tool call or have been answered.") + + if isinstance(function_call, dict): + finish_reason = "function_call" + function_call = FunctionCallResponse(**function_call) + + message = ChatMessage( + role="assistant", + content=response["text"], + function_call=function_call if isinstance(function_call, FunctionCallResponse) else None, + ) + + logger.debug(f"==== message ====\n{message}") + + choice_data = ChatCompletionResponseChoice( + index=0, + message=message, + finish_reason=finish_reason, + ) + task_usage = UsageInfo.model_validate(response["usage"]) + for usage_key, usage_value in task_usage.model_dump().items(): + setattr(usage, usage_key, getattr(usage, usage_key) + usage_value) + + return ChatCompletionResponse( + model=request.model, + id="", # for open_source model, id is empty + choices=[choice_data], + object="chat.completion", + usage=usage + ) + + +async def predict(model_id: str, params: dict): + global model, tokenizer + + choice_data = ChatCompletionResponseStreamChoice( + index=0, + delta=DeltaMessage(role="assistant"), + finish_reason=None + ) + chunk = ChatCompletionResponse(model=model_id, id="", choices=[choice_data], object="chat.completion.chunk") + yield "{}".format(chunk.model_dump_json(exclude_unset=True)) + + previous_text = "" + for new_response in generate_stream_chatglm3(model, tokenizer, params): + decoded_unicode = new_response["text"] + delta_text = decoded_unicode[len(previous_text):] + previous_text = decoded_unicode + + finish_reason = new_response["finish_reason"] + if len(delta_text) == 0 and finish_reason != "function_call": + continue + + function_call = None + if finish_reason == "function_call": + try: + function_call = process_response(decoded_unicode, use_tool=True) + except: + logger.warning( + "Failed to parse tool call, maybe the response is not a tool call or have been answered.") + + if isinstance(function_call, dict): + function_call = FunctionCallResponse(**function_call) + + delta = DeltaMessage( + content=delta_text, + role="assistant", + function_call=function_call if isinstance(function_call, FunctionCallResponse) else None, + ) + + choice_data = ChatCompletionResponseStreamChoice( + index=0, + delta=delta, + finish_reason=finish_reason + ) + chunk = ChatCompletionResponse( + model=model_id, + id="", + choices=[choice_data], + object="chat.completion.chunk" + ) + yield 
"{}".format(chunk.model_dump_json(exclude_unset=True)) + + choice_data = ChatCompletionResponseStreamChoice( + index=0, + delta=DeltaMessage(), + finish_reason="stop" + ) + chunk = ChatCompletionResponse( + model=model_id, + id="", + choices=[choice_data], + object="chat.completion.chunk" + ) + yield "{}".format(chunk.model_dump_json(exclude_unset=True)) + yield '[DONE]' + + +def predict_stream(model_id, gen_params): + """ + The function call is compatible with stream mode output. + + The first seven characters are determined. + If not a function call, the stream output is directly generated. + Otherwise, the complete character content of the function call is returned. + + :param model_id: + :param gen_params: + :return: + """ + output = "" + is_function_call = False + has_send_first_chunk = False + for new_response in generate_stream_chatglm3(model, tokenizer, gen_params): + decoded_unicode = new_response["text"] + delta_text = decoded_unicode[len(output):] + output = decoded_unicode + + # When it is not a function call and the character length is> 7, + # try to judge whether it is a function call according to the special function prefix + if not is_function_call and len(output) > 7: + + # Determine whether a function is called + is_function_call = contains_custom_function(output, gen_params["tools"]) + if is_function_call: + continue + + # Non-function call, direct stream output + finish_reason = new_response["finish_reason"] + + # Send an empty string first to avoid truncation by subsequent next() operations. + if not has_send_first_chunk: + message = DeltaMessage( + content="", + role="assistant", + function_call=None, + ) + choice_data = ChatCompletionResponseStreamChoice( + index=0, + delta=message, + finish_reason=finish_reason + ) + chunk = ChatCompletionResponse( + model=model_id, + id="", + choices=[choice_data], + created=int(time.time()), + object="chat.completion.chunk" + ) + yield "{}".format(chunk.model_dump_json(exclude_unset=True)) + + send_msg = delta_text if has_send_first_chunk else output + has_send_first_chunk = True + message = DeltaMessage( + content=send_msg, + role="assistant", + function_call=None, + ) + choice_data = ChatCompletionResponseStreamChoice( + index=0, + delta=message, + finish_reason=finish_reason + ) + chunk = ChatCompletionResponse( + model=model_id, + id="", + choices=[choice_data], + created=int(time.time()), + object="chat.completion.chunk" + ) + yield "{}".format(chunk.model_dump_json(exclude_unset=True)) + + if is_function_call: + yield output + else: + yield '[DONE]' + + +async def parse_output_text(model_id: str, value: str): + """ + Directly output the text content of value + + :param model_id: + :param value: + :return: + """ + choice_data = ChatCompletionResponseStreamChoice( + index=0, + delta=DeltaMessage(role="assistant", content=value), + finish_reason=None + ) + chunk = ChatCompletionResponse(model=model_id, id="", choices=[choice_data], object="chat.completion.chunk") + yield "{}".format(chunk.model_dump_json(exclude_unset=True)) + + choice_data = ChatCompletionResponseStreamChoice( + index=0, + delta=DeltaMessage(), + finish_reason="stop" + ) + chunk = ChatCompletionResponse(model=model_id, id="", choices=[choice_data], object="chat.completion.chunk") + yield "{}".format(chunk.model_dump_json(exclude_unset=True)) + yield '[DONE]' + + +def contains_custom_function(value: str, tools: list) -> bool: + """ + Determine whether 'function_call' according to a special function prefix. 
+ [Note] This is not a rigorous judgment method, only for reference. + + :param value: + :param tools: + :return: + """ + for tool in tools: + if value and tool["name"] in value: + return True + + +if __name__ == "__main__": + # Load LLM + tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True) + model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map="auto").eval() + + # load Embedding + embedding_model = SentenceTransformer(EMBEDDING_PATH, device="cuda") + uvicorn.run(app, host='0.0.0.0', port=8000, workers=1) diff --git a/selected_homework/openai-translator_v4/openai_api_demo/docker-compose.yml b/selected_homework/openai-translator_v4/openai_api_demo/docker-compose.yml new file mode 100644 index 00000000..ebb20279 --- /dev/null +++ b/selected_homework/openai-translator_v4/openai_api_demo/docker-compose.yml @@ -0,0 +1,44 @@ +version: "3.6" + +services: + glm3_api: + image: python:3.10.13-slim + restart: unless-stopped + working_dir: /glm3 + container_name: glm3_api + env_file: ./.env + networks: + - v_glm3 + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: 1 + capabilities: [gpu] + environment: + - MODEL_PATH=/models/chatglm3-6b + - EMBEDDING_PATH=/models/bge-large-zh-v1.5 + - TZ=Asia/Shanghai + - PYTHONDONTWRITEBYTECODE=1 + - PYTHONUNBUFFERED=1 + - DOCKER=True + ports: + - 8100:8000 + volumes: + - ./:/glm3 + - ${LOCAL_MODEL_PATH}:/models/chatglm3-6b + - ${LOCAL_EMBEDDING_MODEL_PATH}:/models/bge-large-zh-v1.5 + command: + - sh + - -c + - | + sed -i s/deb.debian.org/mirrors.tencentyun.com/g /etc/apt/sources.list + sed -i s/security.debian.org/mirrors.tencentyun.com/g /etc/apt/sources.list + apt-get update + python -m pip install -i https://mirror.sjtu.edu.cn/pypi/web/simple --upgrade pip + pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple + python api_server.py +networks: + v_glm3: + driver: bridge \ No newline at end of file diff --git a/selected_homework/openai-translator_v4/openai_api_demo/langchain_openai_api.py b/selected_homework/openai-translator_v4/openai_api_demo/langchain_openai_api.py new file mode 100644 index 00000000..b3bbb7a9 --- /dev/null +++ b/selected_homework/openai-translator_v4/openai_api_demo/langchain_openai_api.py @@ -0,0 +1,55 @@ +""" +This script is designed for interacting with a local GLM3 AI model using the `ChatGLM3` class +from the `langchain_community` library. It facilitates continuous dialogue with the GLM3 model. + +1. Start the Local Model Service: Before running this script, you need to execute the `api_server.py` script +to start the GLM3 model's service. +2. Run the Script: The script includes functionality for initializing the LLMChain object and obtaining AI responses, +allowing the user to input questions and receive AI answers. +3. This demo is not support for streaming. 
+ +""" +from langchain.chains import LLMChain +from langchain.prompts import PromptTemplate +from langchain.schema.messages import HumanMessage, SystemMessage, AIMessage +from langchain_community.llms.chatglm3 import ChatGLM3 + + +def initialize_llm_chain(messages: list): + template = "{input}" + prompt = PromptTemplate.from_template(template) + + endpoint_url = "http://127.0.0.1:8000/v1/chat/completions" + llm = ChatGLM3( + endpoint_url=endpoint_url, + max_tokens=4096, + prefix_messages=messages, + top_p=0.9 + ) + return LLMChain(prompt=prompt, llm=llm) + + +def get_ai_response(llm_chain, user_message): + ai_response = llm_chain.invoke({"input": user_message}) + return ai_response + + +def continuous_conversation(): + messages = [ + SystemMessage(content="You are an intelligent AI assistant, named ChatGLM3."), + ] + while True: + user_input = input("Human (or 'exit' to quit): ") + if user_input.lower() == 'exit': + break + llm_chain = initialize_llm_chain(messages=messages) + ai_response = get_ai_response(llm_chain, user_input) + print("ChatGLM3: ", ai_response["text"]) + messages += [ + HumanMessage(content=user_input), + AIMessage(content=ai_response["text"]), + ] + + +if __name__ == "__main__": + continuous_conversation() diff --git a/selected_homework/openai-translator_v4/openai_api_demo/openai_api_request.py b/selected_homework/openai-translator_v4/openai_api_demo/openai_api_request.py new file mode 100644 index 00000000..7270ed03 --- /dev/null +++ b/selected_homework/openai-translator_v4/openai_api_demo/openai_api_request.py @@ -0,0 +1,99 @@ +""" +This script is an example of using the OpenAI API to create various interactions with a ChatGLM3 model. +It includes functions to: + +1. Conduct a basic chat session, asking about weather conditions in multiple cities. +2. Initiate a simple chat in Chinese, asking the model to tell a short story. +3. Retrieve and print embeddings for a given text input. + +Each function demonstrates a different aspect of the API's capabilities, showcasing how to make requests +and handle responses. +""" + +from openai import OpenAI + +base_url = "http://127.0.0.1:8000/v1/" +client = OpenAI(api_key="EMPTY", base_url=base_url) + + +def function_chat(): + messages = [{"role": "user", "content": "What's the weather like in San Francisco, Tokyo, and Paris?"}] + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + }, + } + ] + + response = client.chat.completions.create( + model="chatglm3-6b", + messages=messages, + tools=tools, + tool_choice="auto", + ) + if response: + content = response.choices[0].message.content + print(content) + else: + print("Error:", response.status_code) + + +def simple_chat(use_stream=True): + messages = [ + { + "role": "system", + "content": "You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's " + "instructions carefully. 
Respond using markdown.", + }, + { + "role": "user", + "content": "你好,请你用生动的话语给我讲一个小故事吧" + } + ] + response = client.chat.completions.create( + model="chatglm3-6b", + messages=messages, + stream=use_stream, + max_tokens=256, + temperature=0.8, + presence_penalty=1.1, + top_p=0.8) + if response: + if use_stream: + for chunk in response: + print(chunk.choices[0].delta.content) + else: + content = response.choices[0].message.content + print(content) + else: + print("Error:", response.status_code) + + +def embedding(): + response = client.embeddings.create( + model="bge-large-zh-1.5", + input=["你好,给我讲一个故事,大概100字"], + ) + embeddings = response.data[0].embedding + print("嵌入完成,维度:", len(embeddings)) + + +if __name__ == "__main__": + simple_chat(use_stream=False) + simple_chat(use_stream=True) + embedding() + function_chat() diff --git a/selected_homework/openai-translator_v4/openai_api_demo/tools/schema.py b/selected_homework/openai-translator_v4/openai_api_demo/tools/schema.py new file mode 100644 index 00000000..6989aea4 --- /dev/null +++ b/selected_homework/openai-translator_v4/openai_api_demo/tools/schema.py @@ -0,0 +1,34 @@ + +""" +Description: You can customize the developed langchain tool overview information here, +just like the sample code already given in this script. +""" + + +tool_param_start_with = "```python\ntool_call" + + +""" Fill this dictionary with the mapping from tool class names to tool classes that you defined. +Like: + from tools.Calculator import Calculator + tool_class = {"Calculator": Calculator, ...} + +It is required that your customized tool class must define the format for the langchain tool +and implement the parameter verification function in the class: + parameter_validation(self, para: str) -> bool + +Tool class definition reference: ChatGLM3/langchain_demo/tools. +""" +tool_class = {} + + +""" Describe your tool names and parameters in this dictionary. +Like: + tool_def = [ + {"name": "Calculator", + "description": "数学计算器,计算数学问题", + "parameters": {"type": "object", "properties": {"symbol": {"description": "要计算的数学公式"}}, "required": []} + },... 
+ ] +""" +tool_def = [] diff --git a/selected_homework/openai-translator_v4/openai_api_demo/utils.py b/selected_homework/openai-translator_v4/openai_api_demo/utils.py new file mode 100644 index 00000000..314d927b --- /dev/null +++ b/selected_homework/openai-translator_v4/openai_api_demo/utils.py @@ -0,0 +1,192 @@ +import gc +import json +import torch +from transformers import PreTrainedModel, PreTrainedTokenizer +from transformers.generation.logits_process import LogitsProcessor +from typing import Union, Tuple + + +class InvalidScoreLogitsProcessor(LogitsProcessor): + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor + ) -> torch.FloatTensor: + if torch.isnan(scores).any() or torch.isinf(scores).any(): + scores.zero_() + scores[..., 5] = 5e4 + return scores + + +def process_response(output: str, use_tool: bool = False) -> Union[str, dict]: + content = "" + for response in output.split("<|assistant|>"): + metadata, content = response.split("\n", maxsplit=1) + if not metadata.strip(): + content = content.strip() + content = content.replace("[[训练时间]]", "2023年") + else: + if use_tool: + content = "\n".join(content.split("\n")[1:-1]) + + def tool_call(**kwargs): + return kwargs + + parameters = eval(content) + content = { + "name": metadata.strip(), + "arguments": json.dumps(parameters, ensure_ascii=False) + } + else: + content = { + "name": metadata.strip(), + "content": content + } + return content + + +@torch.inference_mode() +def generate_stream_chatglm3(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, params: dict): + messages = params["messages"] + tools = params["tools"] + temperature = float(params.get("temperature", 1.0)) + repetition_penalty = float(params.get("repetition_penalty", 1.0)) + top_p = float(params.get("top_p", 1.0)) + max_new_tokens = int(params.get("max_tokens", 256)) + echo = params.get("echo", True) + messages = process_chatglm_messages(messages, tools=tools) + query, role = messages[-1]["content"], messages[-1]["role"] + + inputs = tokenizer.build_chat_input(query, history=messages[:-1], role=role) + inputs = inputs.to(model.device) + input_echo_len = len(inputs["input_ids"][0]) + + if input_echo_len >= model.config.seq_length: + print(f"Input length larger than {model.config.seq_length}") + + eos_token_id = [ + tokenizer.eos_token_id, + tokenizer.get_command("<|user|>"), + tokenizer.get_command("<|observation|>") + ] + + gen_kwargs = { + "max_new_tokens": max_new_tokens, + "do_sample": True if temperature > 1e-5 else False, + "top_p": top_p, + "repetition_penalty": repetition_penalty, + "logits_processor": [InvalidScoreLogitsProcessor()], + } + if temperature > 1e-5: + gen_kwargs["temperature"] = temperature + + total_len = 0 + for total_ids in model.stream_generate(**inputs, eos_token_id=eos_token_id, **gen_kwargs): + total_ids = total_ids.tolist()[0] + total_len = len(total_ids) + if echo: + output_ids = total_ids[:-1] + else: + output_ids = total_ids[input_echo_len:-1] + + response = tokenizer.decode(output_ids) + if response and response[-1] != "�": + response, stop_found = apply_stopping_strings(response, ["<|observation|>"]) + + yield { + "text": response, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": total_len - input_echo_len, + "total_tokens": total_len, + }, + "finish_reason": "function_call" if stop_found else None, + } + + if stop_found: + break + + # Only last stream result contains finish_reason, we set finish_reason as stop + ret = { + "text": response, + "usage": { + "prompt_tokens": 
input_echo_len, + "completion_tokens": total_len - input_echo_len, + "total_tokens": total_len, + }, + "finish_reason": "stop", + } + yield ret + + gc.collect() + torch.cuda.empty_cache() + + +def process_chatglm_messages(messages, tools=None): + _messages = messages + messages = [] + msg_has_sys = False + if tools: + messages.append( + { + "role": "system", + "content": "Answer the following questions as best as you can. You have access to the following tools:", + "tools": tools + } + ) + msg_has_sys = True + + for m in _messages: + role, content, func_call = m.role, m.content, m.function_call + if role == "function": + messages.append( + { + "role": "observation", + "content": content + } + ) + + elif role == "assistant" and func_call is not None: + for response in content.split("<|assistant|>"): + metadata, sub_content = response.split("\n", maxsplit=1) + messages.append( + { + "role": role, + "metadata": metadata, + "content": sub_content.strip() + } + ) + else: + if role == "system" and msg_has_sys: + msg_has_sys = False + continue + messages.append({"role": role, "content": content}) + return messages + + +def generate_chatglm3(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, params: dict): + for response in generate_stream_chatglm3(model, tokenizer, params): + pass + return response + + +def apply_stopping_strings(reply, stop_strings) -> Tuple[str, bool]: + stop_found = False + for string in stop_strings: + idx = reply.find(string) + if idx != -1: + reply = reply[:idx] + stop_found = True + break + + if not stop_found: + # If something like "\nYo" is generated just before "\nYou: is completed, trim it + for string in stop_strings: + for j in range(len(string) - 1, 0, -1): + if reply[-j:] == string[:j]: + reply = reply[:-j] + break + else: + continue + + break + + return reply, stop_found diff --git a/selected_homework/openai-translator_v4/openai_api_demo/zhipu_api_request.py b/selected_homework/openai-translator_v4/openai_api_demo/zhipu_api_request.py new file mode 100644 index 00000000..537d48f4 --- /dev/null +++ b/selected_homework/openai-translator_v4/openai_api_demo/zhipu_api_request.py @@ -0,0 +1,100 @@ +""" +This script is an example of using the Zhipu API to create various interactions with a ChatGLM3 model. It includes +functions to: + +1. Conduct a basic chat session, asking about weather conditions in multiple cities. +2. Initiate a simple chat in Chinese, asking the model to tell a short story. +3. Retrieve and print embeddings for a given text input. +Each function demonstrates a different aspect of the API's capabilities, +showcasing how to make requests and handle responses. + +Note: Make sure your Zhipu API key is set as an environment +variable formate as xxx.xxx (just for check, not need a real key). +""" + +from zhipuai import ZhipuAI + +base_url = "http://127.0.0.1:8000/v1/" +client = ZhipuAI(api_key="EMP.TY", base_url=base_url) + + +def function_chat(): + messages = [{"role": "user", "content": "What's the weather like in San Francisco, Tokyo, and Paris?"}] + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. 
San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + }, + } + ] + + response = client.chat.completions.create( + model="chatglm3_6b", + messages=messages, + tools=tools, + tool_choice="auto", + ) + if response: + content = response.choices[0].message.content + print(content) + else: + print("Error:", response.status_code) + + +def simple_chat(use_stream=True): + messages = [ + { + "role": "system", + "content": "You are ChatGLM3, a large language model trained by Zhipu.AI. Follow " + "the user's instructions carefully. Respond using markdown.", + }, + { + "role": "user", + "content": "你好,请你介绍一下chatglm3-6b这个模型" + } + ] + response = client.chat.completions.create( + model="chatglm3_", + messages=messages, + stream=use_stream, + max_tokens=256, + temperature=0.8, + top_p=0.8) + if response: + if use_stream: + for chunk in response: + print(chunk.choices[0].delta.content) + else: + content = response.choices[0].message.content + print(content) + else: + print("Error:", response.status_code) + + +def embedding(): + response = client.embeddings.create( + model="bge-large-zh-1.5", + input=["ChatGLM3-6B 是一个大型的中英双语模型。"], + ) + embeddings = response.data[0].embedding + print("嵌入完成,维度:", len(embeddings)) + + +if __name__ == "__main__": + simple_chat(use_stream=False) + simple_chat(use_stream=True) + embedding() + function_chat()