from concurrent.futures import ThreadPoolExecutor
from tornado.concurrent import run_on_executor
import asyncio
class ChatHandler(tornado_factory("rest")):
    """Streaming Q&A handler.

    Runs the blocking ``predict`` generator on a thread pool and streams each
    produced chunk to the client as a server-sent-event-style response.
    """

    # Shared thread pool for running the blocking predict() generator.
    executor = ThreadPoolExecutor(max_workers=4)

    async def post(self):
        data = self.get_json()
        self.set_header("Content-Type", "text/event-stream")
        self.set_header("Cache-Control", "no-cache")
        # Queue bridges the worker thread (producer) and this coroutine (consumer).
        queue = asyncio.Queue()
        asyncio.create_task(self.run_prediction(data, queue))
        # Drain the queue and flush each chunk until the None sentinel arrives
        # or the client disconnects.
        while not self.request.connection.stream.closed():
            result = await queue.get()
            if result is None:
                break
            self.write(result + '\n')
            await self.flush()

    async def run_prediction(self, data, queue):
        """Run the synchronous ``predict`` generator in the thread pool,
        forwarding each generated chunk into *queue*.

        Always terminates the stream with a ``None`` sentinel, even if the
        generator raises.
        """
        loop = asyncio.get_running_loop()

        def blocking_predict():
            # predict() is a synchronous generator yielding chunks one by one.
            for result in predict(data):
                # Thread-safe hand-off of each chunk into the event loop's queue.
                asyncio.run_coroutine_threadsafe(queue.put(result), loop)

        try:
            await loop.run_in_executor(self.executor, blocking_predict)
        finally:
            # BUG FIX: without this finally, an exception raised inside
            # blocking_predict() would skip the sentinel and leave post()
            # awaiting queue.get() forever (hung connection).
            await queue.put(None)
class ChatNostreamHandler(tornado_factory("rest")):
    """Non-streaming Q&A handler.

    Consumes the full ``predict`` output on a thread pool before sending a
    single JSON response.
    """

    # Shared thread pool for running the blocking predict() generator.
    executor = ThreadPoolExecutor(max_workers=4)

    async def post(self):
        data = self.get_json()
        await self.run_prediction(data)

    @run_on_executor
    def run_prediction(self, data):
        """Run ``predict(data)`` to completion on the thread pool and build the
        JSON response; returns a 500 response on any prediction error."""
        try:
            # str.join avoids quadratic += concatenation for long outputs.
            result = ''.join(predict(data))
            return self.json_response(status="OK", result=result)
        except Exception:
            # BUG FIX: was a bare `except:`, which also caught SystemExit and
            # KeyboardInterrupt; narrowed to Exception.
            return self.json_response(status="Failed", code=500)