""" **CoLLie** 中的异步数据提供器,为模型生成过程提供更加广泛的数据采集渠道
"""
__all__ = [
'BaseProvider',
'GradioProvider',
'DashProvider',
'_GenerationStreamer'
]
import os
import shutil
import time
import pandas as pd
from torch.multiprocessing import Process, Queue
from transformers.generation.streamers import BaseStreamer
from transformers import PreTrainedTokenizer, GenerationConfig
import torch
from collie.driver.io.petrel import no_proxy
[文档]class BaseProvider:
""" BaseProvider 为异步数据提供器的基类,提供了一些基本的接口
"""
def __init__(self,
stream: bool = False,
generation_config: GenerationConfig = GenerationConfig()) -> None:
self.data = Queue()
self.feedback = Queue()
self.stream = stream
self.provider_started = False
self.generation_config = generation_config
[文档] def provider_handler(self):
""" provider_handler 为异步数据提供器的主要逻辑,需要被子类重写,主要功能为异步地收集数据并放入队列 `self.data` 中
"""
while True:
self.data.put('Hello World')
time.sleep(1)
[文档] def start_provider(self):
""" start_provider 为异步数据提供器的启动函数,会在一个新的进程中启动 `provider_handler` 函数
"""
if not self.provider_started:
with no_proxy():
process = Process(target=self.provider_handler)
process.daemon = True
process.start()
self.provider_started = True
[文档] def get_data(self):
""" get_data 为异步数据提供器的数据获取函数,会从队列 `self.data` 中获取数据
"""
if self.data.empty():
return None
else:
return self.data.get()
[文档] def get_feedback(self):
""" get_feedback 为异步数据提供器的反馈获取函数,会从队列 `self.feedback` 中获取反馈,主要指模型生成的结果
"""
if self.feedback.empty():
return None
else:
return self.feedback.get()
[文档] def put_feedback(self, feedback):
""" put_feedback 为异步数据提供器的反馈放入函数,会将反馈放入队列 `self.feedback` 中,该函数由 **CoLLie** 自动调用,将模型生成的结果放入该队列中
"""
self.feedback.put(feedback)
[文档]class GradioProvider(BaseProvider):
""" 基于 Gradio 的异步数据提供器,会在本地启动一个 Gradio 服务,将用户输入的文本作为模型的输入
"""
def __init__(self,
tokenizer: PreTrainedTokenizer,
port: int = 7878,
stream: bool = False,
generation_config: GenerationConfig = GenerationConfig()) -> None:
super().__init__(stream=stream, generation_config=generation_config)
self.tokenizer = tokenizer
self.port = port
def provider_handler(self):
import gradio as gr
output_cache = []
def submit(text):
output_cache.clear()
self.data.put(self.tokenizer(text, return_tensors='pt')["input_ids"])
while True:
feedback = self.get_feedback()
if feedback is not None:
if feedback == 'END_OF_STREAM':
break
output_cache.extend(torch.flatten(feedback).cpu().numpy().tolist())
yield self.tokenizer.decode(output_cache)
if not self.stream:
break
interface = gr.Interface(fn=submit, inputs="textbox", outputs="text")
interface.queue()
interface.launch(server_name="0.0.0.0", server_port=self.port, share=True)
[文档]class DashProvider(BaseProvider):
""" 基于 Dash 的异步数据提供器,会在本地启动一个 Dash 服务,将用户输入的文本作为模型的输入
"""
def __init__(self,
tokenizer: PreTrainedTokenizer,
port: int = 7878,
stream: bool = False,
generation_config: GenerationConfig = GenerationConfig()) -> None:
super().__init__(stream=stream, generation_config=generation_config)
self.tokenizer = tokenizer
self.port = port
def provider_handler(self):
import dash
from dash import Dash, html, Input, Output, State, dcc, dash_table
from dash.long_callback import DiskcacheLongCallbackManager
import dash_bootstrap_components as dbc
## Diskcache
import diskcache
# 文件上传
import dash_uploader as du
import logging
logging.getLogger('werkzeug').setLevel(logging.ERROR)
CACHE_PATH = "./.cache"
if os.path.exists(CACHE_PATH):
shutil.rmtree(CACHE_PATH)
os.mkdir(CACHE_PATH)
disk_cache = diskcache.Cache(os.path.join(CACHE_PATH, "cache"))
long_callback_manager = DiskcacheLongCallbackManager(disk_cache)
app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP], long_callback_manager=long_callback_manager, suppress_callback_exceptions=True)
# 配置上传文件夹
du.configure_upload(app, folder=os.path.join(CACHE_PATH, "TEMP"))
# 链接的上传
link_upload = dbc.Card(
dbc.CardBody(
[
dbc.Input(placeholder="please enter the absolute file link...", className="mb-3", id="link_uploader"),
# html.P("该文件路径不存在")
dbc.Alert("该文件路径不存在", color="warning", is_open=False, dismissable=True, duration=2000, id="link_uploader_warn"),
]
),
)
# 文件上传的组件
file_upload = dbc.Card(
dbc.CardBody(
[
du.Upload(
id='upload_file',
text='点击或拖动文件到此进行上传!',
text_completed='已完成上传文件:',
cancel_button=True,
pause_button=True,
filetypes=["txt"],
default_style={
'background-color': '#fafafa',
'font-weight': 'bold'
},
upload_id='myupload'
)
]
),
)
# 文件或者链接上传并且生成
upload_app = dbc.Modal(
[
dbc.ModalHeader(dbc.ModalTitle("Upload", style={"width": "100%", "text-align": "center"})),
dbc.ModalBody(
dbc.Tabs([
dbc.Tab(file_upload, label="文件上传", tab_id="tab1"),
dbc.Tab(link_upload, label="链接上传", tab_id="tab2"),
])
),
dbc.ModalFooter(
[
dbc.Button(
"submit", id="submit_file_link", className="ms-4", n_clicks=0, style={"margin-right": "10px"}),
dbc.Button(
"Close", id="close", n_clicks=0)
],
style={"display": "flex", "justify-content": "center"}
),
],
id="modal",
is_open=False,
backdrop="static",
size="lg"
)
# 顶栏组件
navbar = dbc.Navbar(
dbc.Container(
[
html.A(
# Use row and col to control vertical alignment of logo / brand
dbc.Row(
[
dbc.Col(dbc.NavbarBrand("Collie for Generation", className="ms-0"),
),
],
align="center",
),
# href="https://plotly.com",
style={"textDecoration": "none", },
),
],
fluid=True
),
color="dark",
dark=True,
)
# Text Inpt 卡片
top_card = dbc.Card(
[
dbc.CardHeader(html.H4("Text Input", style={"width": "100%", "text-align": "center"})),
dbc.CardBody(
dbc.Textarea(
size="lg", placeholder="This is input text, please edit there", id="text_input", style={"width": "100%", "height": "150px", "resize": "none"})
),
dbc.CardFooter([
dbc.Button(
"Upload", id="upload-button", className="me-3", n_clicks=0, color="dark", outline=True, style={"margin-right": "20px"}, size="lg"
),
dbc.Button(
"Clear-all", id="clear-button", className="me-3", n_clicks=0, color="primary", outline=True, style={"margin-right": "20px"}, size="lg"
),dbc.Button(
"Submit", id="submit-button", className="me-2", n_clicks=0, color="success", outline=True, size="lg"
)
],style={"display": "flex", "justify-content": "center"}),
],
style={"width": "auto", "min-height": "250px","height":"100%"},
color="secondary",
outline=True
)
# Generated Output 卡片
bottom_card = dbc.Card(
[
html.Div(id="hidden-div", style={"display":"none"}),
dbc.Toast(
"正在生成中,请不要重复提交",
id="positioned-toast",
header="Tips",
is_open=False,
dismissable=True,
icon="danger",
duration=4000,
# top: 66 positions the toast below the navbar
style={"position": "fixed", "top": 66, "right": 10, "width": 350},
),
dbc.Toast(
"生成结束",
id="tip-end",
header="Tips",
is_open=False,
dismissable=True,
icon="danger",
duration=4000,
# top: 66 positions the toast below the navbar
style={"position": "fixed", "top": 66, "right": 10, "width": 350},
),
upload_app,
# 刷新时候清除
dcc.Store(id='memory'),
dbc.CardHeader(html.H4("Generated Output", style={"width": "100%", "text-align": "center"})),
dbc.CardBody(dbc.Textarea(
size="lg", placeholder="This is generated text", style={"width": "100%", "height": "150px", "resize": "none"}, id="gen_output")
),
dbc.CardFooter(
dbc.Button(
"Download", id="download-button", className="me-2", n_clicks=0, outline=True, color="danger", size="lg"
),style={"display": "flex", "justify-content": "center"}),
dcc.Download(id="download-dataframe-csv")
],
style={"width": "auto", "min-height": "250px"},
color="success",
outline=True
)
# 展示生成内容的卡片
history_generation = dbc.Card(
[
# 不展示, 用来做中介以便更新表格参数
dbc.Button("test_button", id="clock_assistance", n_clicks=0, style={"display":"none"}),
dbc.CardHeader(html.H4("History Records", style={"width": "100%", "text-align": "center"})),
dbc.CardBody(
dash_table.DataTable(id='live-update-table',
style_header={'backgroundColor':'#305D91','padding':'10px','color':'#FFFFFF'},
style_table={'overflowX':'auto', 'overflowY': 'auto', 'height': "150px"},
style_cell_conditional=[{'if': {'column_id': 'input_text'}, 'width': '35%'}, {'if': {'column_id': 'gen_text'}, 'width': '65%'}],
style_cell={'overflow':'hidden', 'textOverflow': 'ellipsis', 'whiteSpace': 'normal', 'textAlign':'center'},
data=[{}],
columns=[{"name": i, "id": i} for i in ['input_text', 'gen_text']])
),
dbc.CardFooter(
dbc.Button(
"FreeRec", id="freerec-button", className="me-2", n_clicks=0, outline=True, color="warning", size="lg"
),style={"display": "flex", "justify-content": "center"}),
],
style={"width": "auto", "min-height": "250px"},
color="success",
outline=True
)
# 展示进度条的卡片
progress_card = dbc.Card(
[
# 刷新时候清除
dcc.Store(id='memory_file_data'),
dbc.CardHeader(html.H4("Prcocess Time", style={"width": "100%", "text-align": "center"})),
dbc.CardBody(
[dbc.Progress(value=0, id="animated-progress", animated=True, striped=True, style={"margin-bottom": "20px", "margin-top": "10px"}),
dcc.Interval(id='timer_progress', interval=1000),
dash_table.DataTable(id='process-table',
style_header={'backgroundColor':'#305D91','padding':'10px','color':'#FFFFFF'},
style_table={'overflowX':'auto', 'overflowY': 'auto', 'height': "100px"},
style_cell_conditional=[{'if': {'column_id': 'CurProcess'}, 'width': '25%'},
{'if': {'column_id': 'CurTime'}, 'width': '25%'},
{'if': {'column_id': 'TotalProcess'}, 'width': '25%'},
{'if': {'column_id': 'TotalTime'}, 'width': '25%'}],
style_cell={'overflow':'hidden', 'textOverflow': 'ellipsis', 'whiteSpace': 'normal', 'textAlign':'center'},
data=[{'CurProcess': 0, 'CurTime': 0, 'TotalProcess':0 ,'TotalTime':0}],
columns=[{"name": i, "id": i} for i in ['CurProcess', 'CurTime', 'TotalProcess' ,'TotalTime']])
],
style={"width": "100%", "height": "180px", "resize": "none"}
),
dbc.CardFooter(
dbc.Button(
"Reset", id="reset-button", className="me-2", n_clicks=0, outline=True, color="info", size="lg"
),style={"display": "flex", "justify-content": "center"}),
],
style={"width": "auto", "min-height": "250px"},
color="success",
outline=True
)
app.layout = dbc.Container([
dbc.Row([
navbar
]),
dbc.Row([
dbc.Col(top_card, width=6),
dbc.Col(bottom_card, width=6),
], style={"margin-top": "20px", }),
dbc.Row([
dbc.Col(history_generation, width=6),
dbc.Col(progress_card, width=6)
], style={"margin-top": "20px", })
], fluid=True)
@app.long_callback(
output=[Output("memory", 'data', allow_duplicate=True), Output("gen_output", "value", allow_duplicate=True),
Output("submit-button", "n_clicks", allow_duplicate=True), Output("clock_assistance", "n_clicks", allow_duplicate=True),
Output("tip-end", "is_open", allow_duplicate=True), Output("process-table", "data", allow_duplicate=True)],
inputs=[State("memory", 'data'), State("text_input", "value"), State("process-table", "data"),
Input("submit-button", "n_clicks")],
running=[
(Output("submit-button", "disabled"), True, False),
(Output("positioned-toast", "is_open"), True, False)
],
prevent_initial_call=True
)
def submit_click(session_data, text_input, process_data,n):
session_data = session_data or []
if n == 1 and text_input is not None:
# 进度条文件的写入
progress_file = os.path.join(CACHE_PATH, 'progress.txt')
progress_pt = open(progress_file, 'w')
precent = round(0)
progress_pt.write(f"{precent}%\n")
progress_pt.close()
self.data.put(self.tokenizer(text_input, return_tensors='pt')["input_ids"])
n = 0
output_cache = []
res = []
start = time.time()
while True:
feedback = self.get_feedback()
if feedback is not None:
if feedback == 'END_OF_STREAM':
break
output_cache.extend(torch.flatten(feedback).cpu().numpy().tolist())
res.append(self.tokenizer.decode(output_cache))
if not self.stream:
break
progress_pt = open(progress_file, 'w')
precent = precent + round((100-precent)/2)
progress_pt.write(f"{precent}%\n")
progress_pt.close()
end = time.time()
process_data[0]['CurTime'] = round((end - start) / 60, 2)
process_data[0]['TotalTime'] += round((end - start) / 60, 2)
process_data[0]['TotalProcess'] += 1
process_data[0]['CurProcess'] = 1
progress_pt = open(progress_file, 'w')
precent = 100
progress_pt.write(f"{precent}%\n")
progress_pt.close()
n = 0
session_data.append({'input_text': text_input, 'gen_text': res[-1]})
return session_data, res[-1], n, 1, True, process_data
else:
n = 0
return session_data, "", n, 1, False, process_data
@app.callback(
[Output("text_input", "value"), Output("gen_output", "value"), Output("clear-button", "n_clicks")],
[State("text_input", "value"), State("gen_output", "value"), Input("clear-button", "n_clicks")],
prevent_initial_call=True
)
def input_text(value, value1, n_clicks):
if n_clicks > 0:
value = ''
value1 = ''
n_clicks = 0
return value, value1, n_clicks
# 打开 upload 的回调函数
@app.callback(
Output("modal", "is_open", allow_duplicate=True),
[Input("upload-button", "n_clicks"), Input("close", "n_clicks")],
[State("modal", "is_open")],
prevent_initial_call=True
)
def upload_open(n1, n2, is_open):
if n1 or n2:
return not is_open
return is_open
# 文件上传的回调函数 upload_file
@app.callback(
Output('hidden-div', 'children'),
Input('upload_file', 'isCompleted'),
State('upload_file', 'fileNames')
)
def upload_file_fn(isCompleted, fileNames):
if isCompleted:
fileTemp = os.path.join(CACHE_PATH, "TEMP", "myupload")
all_files = os.listdir(fileTemp)
print(fileNames, all_files)
all_files.remove(fileNames[0])
for file_name in all_files:
os.remove(os.path.join(fileTemp, file_name))
return dash.no_update
# 链接上传的回调函数 link_uploader
@app.callback(
[Output('link_uploader_warn', 'is_open'), Output('link_uploader', 'valid')],
Input('link_uploader', 'value'),
prevent_initial_call=True
)
def link_uploader_fn(link):
if len(link) == 0:
return False, False
if os.path.isfile(link):
return False, True
else:
return True, False
# 提交上传文件或者链接后关闭页面然后上传
@app.callback(
[Output("modal", "is_open"), Output("memory_file_data", "data")],
[Input("submit_file_link", "n_clicks"), State("link_uploader", "value")],
prevent_initial_call=True
)
def submit_file_link(n1, value):
if n1 > 0:
filenames = os.listdir(os.path.join(CACHE_PATH, "TEMP", "myupload"))
if len(filenames) == 1:
filepath = os.path.join(CACHE_PATH, "TEMP", "myupload", filenames[0])
else:
# 先检测链接中是绝对路径,还是为相对路径
curPath = os.getcwd()
if not os.path.isabs(value):
value = os.path.join(curPath, value)
if not os.path.isfile(value):
return True
filepath = value
data_list = []
with open(filepath, "r") as fp:
for line in fp:
line = line.strip("\n").strip()
if len(line) > 0:
data_list.append(line)
return False, data_list
return True, []
# 更新表格参数
@app.callback([Output("live-update-table", "data"), Output("live-update-table", "columns")],
[Input("clock_assistance", "n_clicks"), State("memory", "data")],
)
def update_table(n, memery_data):
df = pd.DataFrame(memery_data)
return df.to_dict('records'), [{"name": i, "id": i} for i in ['input_text', 'gen_text']]
# 清除记录数据
@app.callback(
[Output("live-update-table", "data", allow_duplicate=True), Output("live-update-table", "columns", allow_duplicate=True),
Output("memory", "data", allow_duplicate=True), Output("freerec-button", "n_clicks")],
Input("freerec-button", "n_clicks"),
prevent_initial_call=True
)
def freeRecord(n):
return [{}], [{"name": i, "id": i} for i in ['input_text', 'gen_text']], [], 0
# 文件批量生成
@app.long_callback(
output=[Output("tip-end", "is_open"), Output("memory", 'data'), Output("clock_assistance", "n_clicks"),
Output("process-table", "data")],
inputs=[Input("memory_file_data", "data"), State("memory", 'data'), State("process-table", "data")],
running=[
(Output("submit-button", "disabled"), True, False),
(Output("positioned-toast", "is_open"), True, False)
],
prevent_initial_call=True
)
def file_generate(data_list, mem_data, process_data):
all_pair_gen = []
progress_file = os.path.join(CACHE_PATH, 'progress.txt')
mem_data = mem_data or []
for idx, text_input in enumerate(data_list):
self.data.put(self.tokenizer(text_input, return_tensors='pt')["input_ids"])
output_cache = []
res = []
start = time.time()
while True:
feedback = self.get_feedback()
if feedback is not None:
if feedback == 'END_OF_STREAM':
break
output_cache.extend(torch.flatten(feedback).cpu().numpy().tolist())
res.append(self.tokenizer.decode(output_cache))
if not self.stream:
break
end = time.time()
progress_pt = open(progress_file, 'w')
precent = round((idx+1)/len(data_list) * 100)
progress_pt.write(f"{precent}%\n")
progress_pt.close()
all_pair_gen.append({'input_text': text_input, 'gen_text': res[-1]})
mem_data.append({'input_text': text_input, 'gen_text': res[-1]})
process_data[0]['CurTime'] += round((end - start) / 60, 2)
process_data[0]['TotalTime'] += round((end - start) / 60, 2)
process_data[0]['TotalProcess'] += 1
process_data[0]['CurProcess'] = len(data_list)
return True, mem_data, 1, process_data
# 进度条更新展示
@app.callback(
output=[Output('animated-progress', 'value'), Output('animated-progress', 'label')],
inputs=Input('timer_progress', "n_intervals"),
# progress_default=0,
prevent_initial_call=True,
)
def progress_callback(n_intervals):
try:
with open(os.path.join(CACHE_PATH, 'progress.txt'), 'r') as file:
str_raw = file.read()
last_line = list(filter(None, str_raw.split('\n')))[-1]
percent = float(last_line.split('%')[0])
except:
percent = 0
finally:
text = f'{percent:.0f}%'
return percent, text
# 清除 process的数据
@app.callback(
output=[Output('animated-progress', 'value', allow_duplicate=True), Output('animated-progress', 'label', allow_duplicate=True),
Output("process-table", "data", allow_duplicate=True), Output("reset-button", "n_clicks")],
inputs=[Input("reset-button", "n_clicks")],
prevent_initial_call=True,
)
def clear_process_data(n_clicks):
if n_clicks > 0:
progress_file = os.path.join(CACHE_PATH, 'progress.txt')
if os.path.exists(progress_file):
os.remove(progress_file)
return 0, "0%", [{'CurProcess': 0, 'CurTime': 0, 'TotalProcess':0 ,'TotalTime':0}], 0
# 下载数据
@app.callback(
Output("download-dataframe-csv", "data"),
Input("download-button", "n_clicks"),
State("memory", "data"),
prevent_initial_call=True
)
def download_json_file(n, data):
data = data or []
return dcc.send_data_frame(pd.DataFrame(data).to_csv, "data.csv")
app.run_server(port=self.port, host="0.0.0.0")
class _GenerationStreamer(BaseStreamer):
""" 重写 `transformers` 的 `BaseStreamer` 类以兼容 **CoLLie** 的异步数据提供器
"""
def __init__(self, server: BaseProvider) -> None:
self.server = server
self.stop_signal = 'END_OF_STREAM'
def put(self, value):
if len(value.shape) > 1 and value.shape[0] > 1:
raise ValueError("_GenerationStreamer only supports batch size 1")
elif len(value.shape) > 1:
value = value[0]
self.server.put_feedback(value)
def end(self):
self.server.put_feedback(self.stop_signal)