version = "0.11.0"
Top-level directory
.
├── Dockerfile
├── Dockerfile.arm
├── Dockerfile.cuda
├── Dockerfile.scratch
├── Dockerfile.scratch.oc9
├── LICENSE
├── README.md
├── README_ja.md
├── README_ko.md
├── README_zh.md
├── SECURITY.md
├── agent
├── api
├── conf
├── deepdoc
├── docker
├── docs
├── download_deps.sh
├── graphrag
├── poetry.lock
├── poetry.toml
├── printEnvironment.sh
├── pyproject.toml
├── rag
├── sdk
├── ubuntu.sources
└── web
rag
.
├── __init__.py
├── app
├── benchmark.py
├── llm
├── nlp
├── raptor.py
├── res
├── settings.py
├── svr
└── utils
svr/task_executor.py
Task executor
Service entry point
if __name__ == "__main__":
# .... set up logging
exe = ThreadPoolExecutor(max_workers=1)
exe.submit(report_status)
# Create a single-worker thread pool and submit report_status to run asynchronously.
while True: # call main() in a loop
main()
if PAYLOAD: # message payload fetched from the Redis queue
PAYLOAD.ack()
PAYLOAD = None
Main processing function
def main():
rows = collect() # fetch tasks from the Redis queue
if len(rows) == 0:
return
for _, r in rows.iterrows():
callback = partial(set_progress, r["id"], r["from_page"], r["to_page"])
# bind the progress callback
try:
embd_mdl = LLMBundle(r["tenant_id"], LLMType.EMBEDDING, llm_name=r["embd_id"], lang=r["language"]) # get the embedding model
# create the embedding model instance
except Exception as e:
callback(-1, msg=str(e))
cron_logger.error(str(e))
continue
if r.get("task_type", "") == "raptor": # raptor类型任务处理
try:
chat_mdl = LLMBundle(r["tenant_id"], LLMType.CHAT, llm_name=r["llm_id"], lang=r["language"])
# get the chat model
cks, tk_count = run_raptor(r, chat_mdl, embd_mdl, callback)
except Exception as e:
callback(-1, msg=str(e))
cron_logger.error(str(e))
continue
else: # handle all other task types
st = timer()
cks = build(r)
cron_logger.info("Build chunks({}): {}".format(r["name"], timer() - st))
if cks is None:
continue
if not cks:
callback(1., "No chunk! Done!")
continue
# TODO: exception handler
## set_progress(r["did"], -1, "ERROR: ")
callback(
msg="Finished slicing files(%d). Start to embedding the content." %
len(cks))
st = timer()
try:
tk_count = embedding(cks, embd_mdl, r["parser_config"], callback)
except Exception as e:
callback(-1, "Embedding error:{}".format(str(e)))
cron_logger.error(str(e))
tk_count = 0
cron_logger.info("Embedding elapsed({}): {:.2f}".format(r["name"], timer() - st))
callback(msg="Finished embedding({:.2f})! Start to build index!".format(timer() - st))
init_kb(r) # initialize the knowledge base index
chunk_count = len(set([c["_id"] for c in cks]))
# Count the unique _id values of the chunks so each chunk is counted only once
st = timer()
es_r = ""
es_bulk_size = 4
for b in range(0, len(cks), es_bulk_size):
es_r = ELASTICSEARCH.bulk(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]))
if b % 128 == 0:
callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")
# Insert the chunks into Elasticsearch in bulk for efficiency.
# Each bulk request carries es_bulk_size chunks (4 by default), which can be tuned as needed.
# Every 128 chunks the progress callback is invoked so the user can follow along.
cron_logger.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
if es_r: # on error, delete the already inserted data and log the error
callback(-1, f"Insert chunk error, detail info please check ragflow-logs/api/cron_logger.log. Please also check ES status!")
ELASTICSEARCH.deleteByQuery(
Q("match", doc_id=r["doc_id"]), idxnm=search.index_name(r["tenant_id"]))
cron_logger.error(str(es_r))
else:
if TaskService.do_cancel(r["id"]): # if the task was canceled, delete the inserted data
ELASTICSEARCH.deleteByQuery(
Q("match", doc_id=r["doc_id"]), idxnm=search.index_name(r["tenant_id"]))
continue
callback(1., "Done!")
# On success, update the document's chunk count and log the stats
DocumentService.increment_chunk_num(
r["doc_id"], r["kb_id"], tk_count, chunk_count, 0)
cron_logger.info(
"Chunk doc({}), token({}), chunks({}), elapsed:{:.2f}".format(
r["id"], tk_count, len(cks), timer() - st))
Processing documents with RAPTOR
def run_raptor(row, chat_mdl, embd_mdl, callback=None):
vts, _ = embd_mdl.encode(["ok"])
vctr_nm = "q_%d_vec"%len(vts[0]) # 通过编码一个字符串“ok”来获取向量维度大小,并根据维度大小生成向量名称
chunks = []
for d in retrievaler.chunk_list(row["doc_id"], row["tenant_id"], fields=["content_with_weight", vctr_nm]):
chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
# Fetch the existing chunks for the given doc_id and tenant_id; each chunk carries its weighted content and its vector.
raptor = Raptor(
row["parser_config"]["raptor"].get("max_cluster", 64),
chat_mdl,
embd_mdl,
row["parser_config"]["raptor"]["prompt"],
row["parser_config"]["raptor"]["max_token"],
row["parser_config"]["raptor"]["threshold"]
)
# Initialize the Raptor object from the parser config: max cluster count, chat model, embedding model, prompt, token budget, and threshold
original_length = len(chunks)
raptor(chunks, row["parser_config"]["raptor"]["random_seed"], callback)
# Run the chunks through Raptor with the configured random seed
doc = {
"doc_id": row["doc_id"],
"kb_id": [str(row["kb_id"])],
"docnm_kwd": row["name"],
"title_tks": rag_tokenizer.tokenize(row["name"])
}
res = []
tk_count = 0
for content, vctr in chunks[original_length:]:
d = copy.deepcopy(doc)
md5 = hashlib.md5()
md5.update((content + str(d["doc_id"])).encode("utf-8"))
d["_id"] = md5.hexdigest()
d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
d[vctr_nm] = vctr.tolist()
d["content_with_weight"] = content
d["content_ltks"] = rag_tokenizer.tokenize(content)
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
res.append(d)
tk_count += num_tokens_from_string(content)
# For each newly generated chunk, build the chunk document: unique _id, creation time and timestamp, vector, content, and tokenized fields.
return res, tk_count # return the chunk documents and the total token count
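run_raptor only fires when the task row's parser_config carries a raptor section. A minimal sketch of that section, using only the keys read above; the concrete values (and the use_raptor flag) are illustrative assumptions, not project defaults:
# Illustrative parser_config["raptor"] section; values are assumptions.
parser_config = {
    "raptor": {
        "use_raptor": True,      # assumed switch that turns the task into a "raptor" task
        "max_cluster": 64,       # upper bound on clusters handed to Raptor
        "prompt": "Summarize the following passages concisely:",
        "max_token": 256,        # token budget per generated summary
        "threshold": 0.1,        # clustering threshold
        "random_seed": 0,        # seed passed to raptor(chunks, seed, callback)
    }
}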
Updating task progress
def set_progress(task_id, from_page=0, to_page=-1,
prog=None, msg="Processing..."):
global PAYLOAD
if prog is not None and prog < 0:
# a negative progress value marks the message as an error
msg = "[ERROR]" + msg
cancel = TaskService.do_cancel(task_id)
# check whether the task has been canceled; if so, adjust the message and progress
if cancel:
msg += " [Canceled]"
prog = -1
if to_page > 0:
if msg:
msg = f"Page({from_page + 1}~{to_page + 1}): " + msg
d = {"progress_msg": msg}
# prepend the page range to the progress message
if prog is not None:
d["progress"] = prog
try:
TaskService.update_progress(task_id, d)
# persist the progress update
except Exception as e: # on failure, just log it
cron_logger.error("set_progress:({}), {}".format(task_id, str(e)))
close_connection() # release the DB connection
if cancel: # if the task was canceled, ack the payload and exit
if PAYLOAD:
PAYLOAD.ack()
PAYLOAD = None
os._exit(0) # note: os._exit(0) terminates the entire task-executor process immediately, not just the current task
Fetching task events from the Redis queue
def collect():
global CONSUMEER_NAME, PAYLOAD
try:
PAYLOAD = REDIS_CONN.get_unacked_for(CONSUMEER_NAME, SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
# fetch any unacknowledged message for this consumer via REDIS_CONN
if not PAYLOAD: # if there is none, try to consume a new message from the queue
PAYLOAD = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMEER_NAME)
if not PAYLOAD: # still nothing: sleep for 1 second and return an empty DataFrame
time.sleep(1)
return pd.DataFrame()
except Exception as e:
cron_logger.error("Get task event from queue exception:" + str(e))
return pd.DataFrame()
msg = PAYLOAD.get_message()
if not msg: # empty message: return an empty DataFrame
return pd.DataFrame()
if TaskService.do_cancel(msg["id"]):
# if the task has been canceled, log it and return an empty DataFrame
cron_logger.info("Task {} has been canceled.".format(msg["id"]))
return pd.DataFrame()
tasks = TaskService.get_tasks(msg["id"])
if not tasks: # no task details found: log a warning and return an empty list
cron_logger.warn("{} empty task!".format(msg["id"]))
return []
tasks = pd.DataFrame(tasks) # convert the task list to a DataFrame
if msg.get("type", "") == "raptor":
tasks["task_type"] = "raptor" # propagate the task type from the message
return tasks # return the prepared task data
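REDIS_CONN wraps a Redis Stream consumer group: get_unacked_for re-reads a message that was delivered to this consumer but never acknowledged, and queue_consumer pulls a new one. A minimal redis-py sketch of the same consume/ack pattern, as an illustration only (not the project's wrapper); the stream, group, and consumer names and the message field are assumptions:
import json
import redis

r = redis.Redis(decode_responses=True)
stream, group, consumer = "rag_flow_svr_queue", "rag_flow_svr_task_broker", "task_consumer_0"

try:
    r.xgroup_create(stream, group, id="0", mkstream=True)  # create the consumer group once
except redis.ResponseError:
    pass  # the group already exists

# read one new message for this consumer, blocking for up to 1 second
resp = r.xreadgroup(group, consumer, {stream: ">"}, count=1, block=1000)
if resp:
    _, entries = resp[0]
    msg_id, fields = entries[0]
    task = json.loads(fields["message"])  # assumed field carrying the task JSON
    # ... process the task, then acknowledge it so it is not redelivered
    r.xack(stream, group, msg_id)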
Building the chunk documents
def get_storage_binary(bucket, name):
return STORAGE_IMPL.get(bucket, name)
def build(row):
if row["size"] > DOC_MAXIMUM_SIZE:
set_progress(row["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
(int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
return []
"""
检查文件大小, 输入行数据 row 中的 "size" 字段是否超过了 DOC_MAXIMUM_SIZE.
如果文件大小超过限制,调用set_progress函数更新进度状态为-1并附带错误信息。返回空列表 []
"""
callback = partial(
set_progress,
row["id"],
row["from_page"],
row["to_page"]) # 定义一个偏函数callback,用于更新进度状态
chunker = FACTORY[row["parser_id"].lower()]
# 根据row["parser_id"].lower() 选择对应的分块器chunker
try:
st = timer()
bucket, name = File2DocumentService.get_storage_address(doc_id=row["doc_id"])
# resolve the storage bucket and object name
binary = get_storage_binary(bucket, name)
# fetch the raw file bytes
cron_logger.info(
"From minio({}) {}/{}".format(timer() - st, row["location"], row["name"]))
# exception handling below
except TimeoutError as e:
callback(-1, f"Internal server error: Fetch file from minio timeout. Could you try it again.")
cron_logger.error(
"Minio {}/{}: Fetch file from minio timeout.".format(row["location"], row["name"]))
return
except Exception as e:
if re.search("(No such file|not found)", str(e)):
callback(-1, "Can not find file <%s> from minio. Could you try it again?" % row["name"])
else:
callback(-1, f"Get file from minio: %s" %
str(e).replace("'", ""))
traceback.print_exc()
return
try:
cks = chunker.chunk(row["name"], binary=binary, from_page=row["from_page"],
to_page=row["to_page"], lang=row["language"], callback=callback,
kb_id=row["kb_id"], parser_config=row["parser_config"], tenant_id=row["tenant_id"])
# split the file into chunks
cron_logger.info(
"Chunking({}) {}/{}".format(timer() - st, row["location"], row["name"]))
except Exception as e: # on chunking failure, report an internal server error and return
callback(-1, f"Internal server error while chunking: %s" %
str(e).replace("'", ""))
cron_logger.error(
"Chunking {}/{}: {}".format(row["location"], row["name"], str(e)))
traceback.print_exc()
return
docs = []
doc = {
"doc_id": row["doc_id"],
"kb_id": [str(row["kb_id"])]
}
el = 0
for ck in cks: # iterate over the chunks and build chunk documents
d = copy.deepcopy(doc) # copy the base document
d.update(ck) # merge the chunk data into the document
md5 = hashlib.md5()
md5.update((ck["content_with_weight"] +
str(d["doc_id"])).encode("utf-8"))
d["_id"] = md5.hexdigest() # 生成文档唯一标识符 _id
d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
# set the creation time and timestamp
if not d.get("image"):
docs.append(d) # chunks without an image are appended directly
continue
try:
output_buffer = BytesIO()
if isinstance(d["image"], bytes):
output_buffer = BytesIO(d["image"])
else:
d["image"].save(output_buffer, format='JPEG')
st = timer()
STORAGE_IMPL.put(row["kb_id"], d["_id"], output_buffer.getvalue())
# persist the image bytes to storage
el += timer() - st
except Exception as e:
cron_logger.error(str(e))
traceback.print_exc()
d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"])
# record the image id
del d["image"] # drop the raw image data
docs.append(d) # append the finished chunk document
cron_logger.info("MINIO PUT({}):{}".format(row["name"], el))
return docs # return the chunk documents
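FACTORY, defined near the top of task_executor.py, maps each parser_id to a chunker module under rag/app, and every chunker exposes the chunk(...) signature used above. A simplified sketch of that dispatch covering just three of the parsers listed in api/settings.py (the real table covers them all):
from rag.app import naive, qa, paper  # chunker modules shipped under rag/app

# simplified illustration of the parser_id -> chunker dispatch used by build()
FACTORY = {
    "naive": naive,  # the "General" parser
    "qa": qa,
    "paper": paper,
}

def chunk_file(row, binary, callback):
    chunker = FACTORY[row["parser_id"].lower()]
    return chunker.chunk(row["name"], binary=binary, from_page=row["from_page"],
                         to_page=row["to_page"], lang=row["language"], callback=callback,
                         kb_id=row["kb_id"], parser_config=row["parser_config"],
                         tenant_id=row["tenant_id"])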
Initializing the knowledge base index
def init_kb(row):
"""
在Elasticsearch中为特定租户初始化或创建一个配置好的索引
"""
idxnm = search.index_name(row["tenant_id"])
# derive the index name from the tenant_id via search.index_name
if ELASTICSEARCH.indexExist(idxnm):
return # nothing to do if the index already exists
return ELASTICSEARCH.createIdx(idxnm, json.load(
open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))
# Locate conf/mapping.json under the project base directory
# and pass its JSON contents to ELASTICSEARCH.createIdx to create the new index.
Embedding the chunks
def embedding(docs, mdl, parser_config={}, callback=None):
"""
docs:一个列表,包含多个文档,每个文档是一个字典,至少包含title_tks和content_with_weight两个键。
mdl:一个模型对象,用于文本编码。
parser_config:解析配置字典,默认为空。
callback:回调函数,默认为None,用于报告进度
"""
batch_size = 32
tts, cnts = [rmSpace(d["title_tks"]) for d in docs if d.get("title_tks")], [ # preprocessing
re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", d["content_with_weight"]) for d in docs]
# rmSpace strips redundant whitespace from the titles (utils/__init__.py)
# The regex replaces HTML table tags with a single space; it matches
# opening tags: <table>, <td>, <caption>, <tr>, <th>
# and closing tags: </table>, </td>, </caption>, </tr>, </th>
tk_count = 0
if len(tts) == len(cnts): # encode the titles in batches
tts_ = np.array([])
for i in range(0, len(tts), batch_size): # walk the titles 32 at a time
vts, c = mdl.encode(tts[i: i + batch_size])
# encode the batch
if len(tts_) == 0:
tts_ = vts
else:
tts_ = np.concatenate((tts_, vts), axis=0)
tk_count += c
callback(prog=0.6 + 0.1 * (i + 1) / len(tts), msg="")
# update progress (from 60% to 70%)
tts = tts_
cnts_ = np.array([]) # encode the contents in batches
for i in range(0, len(cnts), batch_size):
vts, c = mdl.encode(cnts[i: i + batch_size])
if len(cnts_) == 0:
cnts_ = vts
else:
cnts_ = np.concatenate((cnts_, vts), axis=0)
tk_count += c
callback(prog=0.7 + 0.2 * (i + 1) / len(cnts), msg="")
# update progress (from 70% to 90%)
cnts = cnts_
title_w = float(parser_config.get("filename_embd_weight", 0.1))
vects = (title_w * tts + (1 - title_w) *
cnts) if len(tts) == len(cnts) else cnts
# The title weight comes from filename_embd_weight in parser_config (0.1 by default).
# If titles and contents have the same length, merge the vectors by that weight; otherwise use the content vectors only.
assert len(vects) == len(docs) # the number of vectors must match the number of input docs
for i, d in enumerate(docs):
v = vects[i].tolist()
d["q_%d_vec" % len(v)] = v # 根据每个文档的向量长度动态地命名键
return tk_count
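The title/content mix at the end is a per-chunk convex combination. A tiny NumPy illustration with a made-up 4-dimensional embedding and the default filename_embd_weight of 0.1:
import numpy as np

title_w = 0.1
tts = np.array([[1.0, 0.0, 0.0, 0.0]])   # title vector of one chunk (made up)
cnts = np.array([[0.0, 1.0, 0.0, 0.0]])  # content vector of the same chunk

vects = title_w * tts + (1 - title_w) * cnts
print(vects)           # [[0.1 0.9 0.  0. ]]
print(vects.shape[1])  # 4 -> stored on the chunk under the field "q_4_vec"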
Status reporting
def report_status():
global CONSUMEER_NAME
while True:
try:
obj = REDIS_CONN.get("TASKEXE")
if not obj: obj = {}
else: obj = json.loads(obj)
# read the current status object from Redis
if CONSUMEER_NAME not in obj: obj[CONSUMEER_NAME] = []
obj[CONSUMEER_NAME].append(timer())
# append the current timestamp to this consumer's list
obj[CONSUMEER_NAME] = obj[CONSUMEER_NAME][-60:]
# keep only the most recent 60 timestamps
REDIS_CONN.set_obj("TASKEXE", obj, 60*2)
# write the updated dict back to Redis with a 2-minute expiry
except Exception as e:
print("[Exception]:", str(e))
time.sleep(30)
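The TASKEXE key therefore holds, per consumer, its last 60 heartbeat timestamps (one every 30 seconds, with a 2-minute TTL). A small sketch of how those heartbeats could be inspected, assuming REDIS_CONN.get returns the JSON string written by set_obj; the values come from timer() inside the executor process, so only the differences between them are meaningful:
import json

def heartbeat_intervals(redis_conn):
    """Return, per consumer, the gaps in seconds between its recorded heartbeats."""
    raw = redis_conn.get("TASKEXE")
    status = json.loads(raw) if raw else {}
    return {name: [round(b - a, 1) for a, b in zip(beats, beats[1:])]
            for name, beats in status.items()}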
svr/jina_server.py
A text generation service built on the Jina framework. Jina is an open-source toolkit for building and deploying neural-search applications. It lets developers quickly assemble scalable search solutions over many kinds of data, such as text, images, and video.
Key characteristics of Jina:
- Flexibility: supports multimodal data, i.e. different data types (such as text plus images) can be handled together.
- Ease of use: a concise API lets even beginners quickly build complex search systems.
- Efficiency: uses neural-network techniques for indexing and retrieval to deliver high-performance search.
- Distribution: scales seamlessly from a local deployment to Docker, Kubernetes, or Jina AI Cloud, making it suitable for large distributed systems.
- Compatibility: works with different deep-learning frameworks, so developers can pick the stack that fits their needs.
In practice, Jina can be used to build search engines, recommendation systems, and other applications that need to retrieve large amounts of data efficiently. It exposes a Python API and has an active community with rich documentation and support resources.
Service entry point
if __name__ == "__main__":
parser = argparse.ArgumentParser() # argparse is a Python standard-library module for parsing command-line arguments
parser.add_argument("--model_name", type=str, help="Model name or path")
parser.add_argument("--port", default=12345, type=int, help="Jina serving port")
# Two command-line arguments are defined:
# --model_name: a string giving the model name or path.
# --port: an integer, default 12345, the port the Jina service listens on.
args = parser.parse_args() # parse the command line into args
model_name = args.model_name
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
# load the tokenizer for the given model
with Deployment(
uses=TokenStreamingExecutor, port=args.port, protocol="grpc"
# Create a Deployment that:
# uses the TokenStreamingExecutor executor,
# listens on the port taken from the command line,
# and speaks gRPC.
) as dep:
dep.block() # keep the service running until it is stopped manually
The executor
class TokenStreamingExecutor(Executor):
def __init__(self, **kwargs):
super().__init__(**kwargs)
# load the pretrained language model
self.model = AutoModelForCausalLM.from_pretrained(
model_name, device_map="auto", # 自动分配模型到合适的设备(如CPU 或 GPU)
torch_dtype="auto" # 根据设备自动选择数据类型(如在GPU上使用FP16 等)
)
@requests(on="/chat") # 指定该函数处理/chat路径的请求
async def generate(self, doc: Prompt, **kwargs) -> Generation:
# 处理一次性生成请求,将输入文本编码并生成输出文本,返回解码后的结果
text = tokenizer.apply_chat_template(
doc.message,
tokenize=False,
) # format doc.message with the chat template, without tokenizing
inputs = tokenizer([text], return_tensors="pt")
# tokenize the text into PyTorch tensors
generation_config = GenerationConfig(
**doc.gen_conf,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.eos_token_id
) # build the generation config from doc.gen_conf, using EOS as both the end and the padding token
generated_ids = self.model.generate(
inputs.input_ids, generation_config=generation_config
) # run generation with the pretrained model and get the produced token ids
generated_ids = [
output_ids[len(input_ids) :]
for input_ids, output_ids in zip(inputs.input_ids, generated_ids) # strip the prompt tokens, keeping only the newly generated ones
]
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
# decode the generated ids, skipping special tokens
yield Generation(text=response) # yield the decoded text
@requests(on="/stream") # this handler serves requests on the /stream endpoint
async def task(self, doc: Prompt, **kwargs) -> Generation:
# Handle a streaming request: generate one token at a time until max_new_tokens or EOS is reached, yielding the newly produced text at every step.
text = tokenizer.apply_chat_template(
doc.message,
tokenize=False,
)
input = tokenizer([text], return_tensors="pt")
input_len = input["input_ids"].shape[1]
max_new_tokens = 512 # default cap of 512 new tokens
if "max_new_tokens" in doc.gen_conf:
max_new_tokens = doc.gen_conf.pop("max_new_tokens")
# override the default if doc.gen_conf provides max_new_tokens
generation_config = GenerationConfig(
**doc.gen_conf,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.eos_token_id
)
for _ in range(max_new_tokens): # generate new tokens one at a time
output = self.model.generate(
**input, max_new_tokens=1, generation_config=generation_config
)
if output[0][-1] == tokenizer.eos_token_id:
break # stop once the model emits the EOS token
yield Generation(
text=tokenizer.decode(output[0][input_len:], skip_special_tokens=True)
)
input = { # update the rolling input state
"input_ids": output,
# input_ids is the context the model continues from.
# After each step the full output becomes the new input_ids,
# so the next step continues from the latest generated sequence.
"attention_mask": torch.ones(1, len(output[0])),
# attention_mask marks which positions are valid.
# torch.ones(1, len(output[0])) marks every position as valid,
# so the model attends to the whole sequence on the next step.
}
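The service is launched with something like python rag/svr/jina_server.py --model_name <hf model> --port 12345. A minimal client sketch for the /stream endpoint, assuming Prompt and Generation are the docarray document classes defined at the top of jina_server.py (Prompt carrying message and gen_conf, Generation carrying text) and that the installed Jina version provides Client.stream_doc:
import asyncio
from jina import Client
from rag.svr.jina_server import Prompt, Generation  # document schemas defined by the service

async def main():
    client = Client(port=12345, protocol="grpc", asyncio=True)
    prompt = Prompt(message=[{"role": "user", "content": "Hello!"}],
                    gen_conf={"max_new_tokens": 64})
    # /stream yields one Generation per newly produced text fragment
    async for doc in client.stream_doc(on="/stream", inputs=prompt, return_type=Generation):
        print(doc.text, flush=True)

asyncio.run(main())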
api
.
├── __init__.py
├── apps
├── contants.py
├── db
├── ragflow_server.py
├── settings.py
├── utils
└── versions.py
api/ragflow_server.py
Starting the API server
if __name__ == '__main__':
print(r"""
____ ______ __
/ __ \ ____ _ ____ _ / ____// /____ _ __
/ /_/ // __ `// __ `// /_ / // __ \| | /| / /
/ _, _// /_/ // /_/ // __/ / // /_/ /| |/ |/ /
/_/ |_| \__,_/ \__, //_/ /_/ \____/ |__/|__/
/____/
""", flush=True) # flush=True 确保立即输出
stat_logger.info(
f'project base: {utils.file_utils.get_project_base_directory()}'
)
# log the project base directory
# init db
init_web_db() # create the database tables
init_web_data() # seed the database with initial data
# init runtime config
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--version', default=False, help="rag flow version", action='store_true')
parser.add_argument('--debug', default=False, help="debug mode", action='store_true')
# two command-line flags: --version and --debug
args = parser.parse_args()
if args.version: # if --version was given, print the version info and exit
print(get_versions())
sys.exit(0)
RuntimeConfig.DEBUG = args.debug
if RuntimeConfig.DEBUG:
stat_logger.info("run on debug mode")
# Set the DEBUG flag on the runtime config and log it when debug mode is on
RuntimeConfig.init_env() # initialize the runtime environment
RuntimeConfig.init_config(JOB_SERVER_HOST=HOST, HTTP_PORT=HTTP_PORT)
# initialize the runtime config with the host and HTTP port
peewee_logger = logging.getLogger('peewee')
peewee_logger.propagate = False
# rag_arch.common.log.ROpenHandler
peewee_logger.addHandler(database_logger.handlers[0])
peewee_logger.setLevel(database_logger.level)
# Get the peewee logger,
# stop it from propagating to the root logger,
# attach the database log handler,
# and set its log level.
thr = ThreadPoolExecutor(max_workers=1)
thr.submit(update_progress)
# Create a single-worker thread pool
# and submit the update_progress function to it.
# start http server
try:
stat_logger.info("RAG Flow http server start...")
werkzeug_logger = logging.getLogger("werkzeug")
for h in access_logger.handlers:
werkzeug_logger.addHandler(h)
run_simple(hostname=HOST, port=HTTP_PORT, application=app, threaded=True, use_reloader=RuntimeConfig.DEBUG, use_debugger=RuntimeConfig.DEBUG)
# start the server with run_simple; app is a Flask application
except Exception:
traceback.print_exc()
os.kill(os.getpid(), signal.SIGKILL)
api/settings.py
Application settings
# Logger
LoggerFactory.set_directory(
os.path.join(
get_project_base_directory(),
"logs",
"api"))
# store log files under the project's logs/api directory
# {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0}
LoggerFactory.LEVEL = 30
stat_logger = getLogger("stat")
access_logger = getLogger("access")
database_logger = getLogger("database")
chat_logger = getLogger("chat")
# Create separate loggers for statistics (stat), access, database, and chat logs
# Project base configuration: API version, service name, server module, temp directory path, etc.
API_VERSION = "v1"
RAG_FLOW_SERVICE_NAME = "ragflow"
SERVER_MODULE = "rag_flow_server.py"
TEMP_DIRECTORY = os.path.join(get_project_base_directory(), "temp")
RAG_FLOW_CONF_PATH = os.path.join(get_project_base_directory(), "conf")
LIGHTEN = os.environ.get('LIGHTEN')
# lightweight-mode switch
SUBPROCESS_STD_LOG_NAME = "std.log"
ERROR_REPORT = True
ERROR_REPORT_WITH_PATH = False
MAX_TIMESTAMP_INTERVAL = 60
SESSION_VALID_PERIOD = 7 * 24 * 60 * 60
REQUEST_TRY_TIMES = 3
REQUEST_WAIT_SEC = 2
REQUEST_MAX_WAIT_SEC = 300
USE_REGISTRY = get_base_config("use_registry")
LLM = get_base_config("user_default_llm", {})
LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen")
LLM_BASE_URL = LLM.get("base_url")
if not LIGHTEN:
default_llm = {
"Tongyi-Qianwen": {
"chat_model": "qwen-plus",
"embedding_model": "text-embedding-v2",
"image2text_model": "qwen-vl-max",
"asr_model": "paraformer-realtime-8k-v1",
},
"OpenAI": {
"chat_model": "gpt-3.5-turbo",
"embedding_model": "text-embedding-ada-002",
"image2text_model": "gpt-4-vision-preview",
"asr_model": "whisper-1",
},
"Azure-OpenAI": {
"chat_model": "azure-gpt-35-turbo",
"embedding_model": "azure-text-embedding-ada-002",
"image2text_model": "azure-gpt-4-vision-preview",
"asr_model": "azure-whisper-1",
},
"ZHIPU-AI": {
"chat_model": "glm-3-turbo",
"embedding_model": "embedding-2",
"image2text_model": "glm-4v",
"asr_model": "",
},
"Ollama": {
"chat_model": "qwen-14B-chat",
"embedding_model": "flag-embedding",
"image2text_model": "",
"asr_model": "",
},
"Moonshot": {
"chat_model": "moonshot-v1-8k",
"embedding_model": "",
"image2text_model": "",
"asr_model": "",
},
"DeepSeek": {
"chat_model": "deepseek-chat",
"embedding_model": "",
"image2text_model": "",
"asr_model": "",
},
"VolcEngine": {
"chat_model": "",
"embedding_model": "",
"image2text_model": "",
"asr_model": "",
},
"BAAI": {
"chat_model": "",
"embedding_model": "BAAI/bge-large-zh-v1.5",
"image2text_model": "",
"asr_model": "",
"rerank_model": "BAAI/bge-reranker-v2-m3",
}
}
CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"]
EMBEDDING_MDL = default_llm["BAAI"]["embedding_model"]
RERANK_MDL = default_llm["BAAI"]["rerank_model"] if not LIGHTEN else ""
ASR_MDL = default_llm[LLM_FACTORY]["asr_model"]
IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"]
# Pick the default LLM factory from the configuration and load its default sub-models (chat, embedding, image2text, ASR)
else:
CHAT_MDL = EMBEDDING_MDL = RERANK_MDL = ASR_MDL = IMAGE2TEXT_MDL = ""
# In lighten mode no concrete default models are set here; presumably they are configured later per tenant
API_KEY = LLM.get("api_key", "") # 大模型API密钥
PARSERS = LLM.get(
"parsers",
"naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email")
# distribution
DEPENDENT_DISTRIBUTION = get_base_config("dependent_distribution", False)
# dependent-distribution flag
RAG_FLOW_UPDATE_CHECK = False
HOST = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
HTTP_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")
SECRET_KEY = get_base_config(
RAG_FLOW_SERVICE_NAME,
{}).get(
"secret_key",
"infiniflow") # 秘钥
TOKEN_EXPIRE_IN = get_base_config(
RAG_FLOW_SERVICE_NAME, {}).get(
"token_expires_in", 3600) # Token过期时间
NGINX_HOST = get_base_config(
RAG_FLOW_SERVICE_NAME, {}).get(
"nginx", {}).get("host") or HOST
NGINX_HTTP_PORT = get_base_config(
RAG_FLOW_SERVICE_NAME, {}).get(
"nginx", {}).get("http_port") or HTTP_PORT
# parameters of the Nginx reverse proxy
RANDOM_INSTANCE_ID = get_base_config(
RAG_FLOW_SERVICE_NAME, {}).get(
"random_instance_id", False)
PROXY = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("proxy")
PROXY_PROTOCOL = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("protocol")
# proxy settings
DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql')
DATABASE = decrypt_database_config(name=DATABASE_TYPE)
# Determine the database type from the environment and decrypt its configuration
# Switch
# upload
UPLOAD_DATA_FROM_CLIENT = True
# switch for uploading data from the client
# authentication
AUTHENTICATION_CONF = get_base_config("authentication", {})
# client
CLIENT_AUTHENTICATION = AUTHENTICATION_CONF.get(
"client", {}).get(
"switch", False)
HTTP_APP_KEY = AUTHENTICATION_CONF.get("client", {}).get("http_app_key")
GITHUB_OAUTH = get_base_config("oauth", {}).get("github")
FEISHU_OAUTH = get_base_config("oauth", {}).get("feishu")
WECHAT_OAUTH = get_base_config("oauth", {}).get("wechat")
# site
SITE_AUTHENTICATION = AUTHENTICATION_CONF.get("site", {}).get("switch", False)
# permission
PERMISSION_CONF = get_base_config("permission", {})
PERMISSION_SWITCH = PERMISSION_CONF.get("switch")
COMPONENT_PERMISSION = PERMISSION_CONF.get("component")
DATASET_PERMISSION = PERMISSION_CONF.get("dataset")
HOOK_MODULE = get_base_config("hook_module")
HOOK_SERVER_NAME = get_base_config("hook_server_name")
ENABLE_MODEL_STORE = get_base_config('enable_model_store', False)
# authentication
USE_AUTHENTICATION = False
USE_DATA_AUTHENTICATION = False
AUTOMATIC_AUTHORIZATION_OUTPUT_DATA = True
USE_DEFAULT_TIMEOUT = False
AUTHENTICATION_DEFAULT_TIMEOUT = 7 * 24 * 60 * 60 # s
PRIVILEGE_COMMAND_WHITELIST = []
CHECK_NODES_IDENTITY = False
retrievaler = search.Dealer(ELASTICSEARCH)
kg_retrievaler = kg_search.KGSearch(ELASTICSEARCH)
# Initialize the Elasticsearch retriever (retrievaler) and the knowledge-graph retriever (kg_retrievaler)
class CustomEnum(Enum):
"""
定义了一个通用枚举类CustomEnum,提供了验证值有效性和获取成员值/名称的方法
"""
@classmethod
def valid(cls, value):
try:
cls(value)
return True
except BaseException:
return False
@classmethod
def values(cls):
return [member.value for member in cls.__members__.values()]
@classmethod
def names(cls):
return [member.name for member in cls.__members__.values()]
class PythonDependenceName(CustomEnum):
Rag_Source_Code = "python"
Python_Env = "miniconda"
class ModelStorage(CustomEnum):
REDIS = "redis"
MYSQL = "mysql"
# PythonDependenceName and ModelStorage are concrete enums built on CustomEnum
class RetCode(IntEnum, CustomEnum):
"""
定义了一个带状态码的枚举类RetCode,用于表示成功、无效、异常等多种操作结果状态
"""
SUCCESS = 0
NOT_EFFECTIVE = 10
EXCEPTION_ERROR = 100
ARGUMENT_ERROR = 101
DATA_ERROR = 102
OPERATING_ERROR = 103
CONNECTION_ERROR = 105
RUNNING = 106
PERMISSION_ERROR = 108
AUTHENTICATION_ERROR = 109
UNAUTHORIZED = 401
SERVER_ERROR = 500
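A quick illustration of the helpers CustomEnum adds, using RetCode:
RetCode.valid(0)      # True  -> 0 maps to RetCode.SUCCESS
RetCode.valid(999)    # False -> no member has that value
RetCode.values()[:3]  # [0, 10, 100]
RetCode.names()[:3]   # ['SUCCESS', 'NOT_EFFECTIVE', 'EXCEPTION_ERROR']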
api/apps/__init__.py
Application initialization
__all__ = ['app']
logger = logging.getLogger('flask.app')
for h in access_logger.handlers:
logger.addHandler(h)
# Attach the access_logger handlers to the logger named flask.app
Request.json = property(lambda self: self.get_json(force=True, silent=True))
# Extend the Request class so request JSON is available as a property, simplifying access to JSON bodies
app = Flask(__name__) # create the Flask application instance
CORS(app, supports_credentials=True, max_age=2592000)
# enable cross-origin resource sharing (CORS) so cross-origin requests are allowed
app.url_map.strict_slashes = False
# disable strict slashes, so a trailing slash on the URL does not affect route matching
app.json_encoder = CustomJSONEncoder
# use the custom CustomJSONEncoder for JSON responses
app.errorhandler(Exception)(server_error_response)
# register server_error_response as the global handler for uncaught exceptions
## convenience for dev and debug
#app.config["LOGIN_DISABLED"] = True
app.config["SESSION_PERMANENT"] = False # session persistence option
app.config["SESSION_TYPE"] = "filesystem"
# configure Flask-Session to store sessions on the filesystem
app.config['MAX_CONTENT_LENGTH'] = int(os.environ.get("MAX_CONTENT_LENGTH", 128 * 1024 * 1024))
# cap the request body size, 128 MB by default
Session(app)
login_manager = LoginManager()
login_manager.init_app(app)
# use Flask-Login for user authentication and initialize the login manager
commands.register_commands(app)
# register CLI commands with the Flask app, mainly for development and debugging
def search_pages_path(pages_dir):
"""
搜索指定目录下的应用页面和API文件路径
"""
app_path_list = [path for path in pages_dir.glob('*_app.py') if not path.name.startswith('.')]
api_path_list = [path for path in pages_dir.glob('*sdk/*.py') if not path.name.startswith('.')]
app_path_list.extend(api_path_list)
return app_path_list
def register_page(page_path):
"""
定义了一个函数register_page,用于动态注册页面到Flask应用中。
将每个页面注册为一个Blueprint,并设置相应的URL前缀。
动态注册所有找到的页面路径
"""
path = f'{page_path}'
# page_path is a Path object pointing at the page or API module to register;
# convert it to a string
page_name = page_path.stem.rstrip('_app')
# take the file stem and strip the trailing _app suffix (if present)
module_name = '.'.join(page_path.parts[page_path.parts.index('api'):-1] + (page_name,))
# build the dotted module name from the path parts starting at 'api'
spec = spec_from_file_location(module_name, page_path)
# create a ModuleSpec from the module name and file path
page = module_from_spec(spec)
# create the module object from the spec
page.app = app
# expose the Flask app to the module as page.app
page.manager = Blueprint(page_name, module_name)
# create a Blueprint and expose it to the module as page.manager
sys.modules[module_name] = page
# register the module in sys.modules so it can be referenced later
spec.loader.exec_module(page)
# execute the module file so its definitions (routes, etc.) take effect
page_name = getattr(page, 'page_name', page_name)
# if the module defines a page_name attribute, use it; otherwise keep the computed page_name
url_prefix = f'/api/{API_VERSION}/{page_name}' if "/sdk/" in path else f'/{API_VERSION}/{page_name}'
# SDK modules get the /api/{API_VERSION} prefix, regular pages get /{API_VERSION}
app.register_blueprint(page.manager, url_prefix=url_prefix)
# register the Blueprint with the Flask app under that URL prefix
return url_prefix # return the final URL prefix
pages_dir = [
Path(__file__).parent,
Path(__file__).parent.parent / 'api' / 'apps',
Path(__file__).parent.parent / 'api' / 'apps' / 'sdk',
]
client_urls_prefix = [
register_page(path)
for dir in pages_dir
for path in search_pages_path(dir)
]
# call register_page for every discovered path and collect the resulting URL prefixes
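To make the mechanism concrete, here is a hypothetical page module (say api/apps/demo_app.py, not part of the repository). register_page injects app and manager into the module before executing it, so the module only declares routes on manager, and they end up served under /v1/demo:
# api/apps/demo_app.py -- hypothetical example of a page module
from api.utils.api_utils import get_json_result

@manager.route('/ping', methods=['GET'])  # manager is the Blueprint injected by register_page
def ping():
    return get_json_result(data="pong")   # served at GET /v1/demo/ping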
@login_manager.request_loader
# Decorator registering this function as the request loader:
# it is called on every request to load the user from the request.
def load_user(web_request):
jwt = Serializer(secret_key=SECRET_KEY)
# build the serializer from SECRET_KEY
authorization = web_request.headers.get("Authorization")
# web_request.headers holds all request headers;
# get("Authorization") fetches the Authorization header value.
if authorization:
try:
access_token = str(jwt.loads(authorization))
# deserialize the token
user = UserService.query(access_token=access_token, status=StatusEnum.VALID.value)
# UserService.query looks the user up by access_token,
# restricted to valid users (StatusEnum.VALID.value).
if user:
return user[0]
else:
return None
except Exception as e:
stat_logger.exception(e)
return None
else:
return None
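A round-trip sketch of the token handling, assuming Serializer is itsdangerous' URLSafeTimedSerializer (as imported in this module) and that the login flow hands the serialized access_token back to the client, which echoes it in the Authorization header:
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer

jwt = Serializer(secret_key=SECRET_KEY)
authorization = jwt.dumps("some-access-token")  # what the client later sends as the Authorization header
print(str(jwt.loads(authorization)))            # -> "some-access-token", then looked up via UserService.query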
@app.teardown_request
def _db_close(exc): # close the database connection once the request is finished
close_connection()
api/apps/api_app.py
Handling API requests
Generating a token
def generate_confirmation_token(tenent_id):
# note: "tenent_id" is a typo for tenant_id, kept as in the source
serializer = URLSafeTimedSerializer(tenent_id)
return "ragflow-" + serializer.dumps(get_uuid(), salt=tenent_id)[2:34]
@manager.route('/new_token', methods=['POST'])
@login_required
def new_token():
req = request.json
try:
tenants = UserTenantService.query(user_id=current_user.id)
# look up the user's tenants
if not tenants:
return get_data_error_result(retmsg="Tenant not found!")
tenant_id = tenants[0].tenant_id
obj = {"tenant_id": tenant_id, "token": generate_confirmation_token(tenant_id),
"create_time": current_timestamp(),
"create_date": datetime_format(datetime.now()),
"update_time": None,
"update_date": None
}
# assemble the token record
if req.get("canvas_id"):
obj["dialog_id"] = req["canvas_id"]
obj["source"] = "agent"
else:
obj["dialog_id"] = req["dialog_id"]
# set the dialog id and source
if not APITokenService.save(**obj): # persist the token
return get_data_error_result(retmsg="Fail to new a dialog!")
return get_json_result(data=obj) # return the generated token record
except Exception as e:
return server_error_response(e)
Listing tokens
@manager.route('/token_list', methods=['GET'])
@login_required
def token_list():
try:
tenants = UserTenantService.query(user_id=current_user.id)
if not tenants:
return get_data_error_result(retmsg="Tenant not found!")
id = request.args["dialog_id"] if "dialog_id" in request.args else request.args["canvas_id"]
objs = APITokenService.query(tenant_id=tenants[0].tenant_id, dialog_id=id)
# query the API token records
return get_json_result(data=[o.to_dict() for o in objs])
except Exception as e:
return server_error_response(e)
Deleting tokens
@manager.route('/rm', methods=['POST'])
@validate_request("tokens", "tenant_id")
@login_required
def rm():
req = request.json
try:
for token in req["tokens"]:
APITokenService.filter_delete(
[APIToken.tenant_id == req["tenant_id"], APIToken.token == token])
# for each token in the list, perform the delete
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
Querying conversation statistics
@manager.route('/stats', methods=['GET'])
@login_required
def stats():
try:
# look up the user's tenants
tenants = UserTenantService.query(user_id=current_user.id)
if not tenants:
return get_data_error_result(retmsg="Tenant not found!")
# read the date-range parameters
from_date = request.args.get(
"from_date",
(datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d 00:00:00"))
to_date = request.args.get(
"to_date",
datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
# read the remaining parameters
agent = "agent" if "canvas_id" in request.args else None
# query the statistics
objs = API4ConversationService.stats(
tenants[0].tenant_id,
from_date,
to_date,
agent)
# build the result
res = {
"pv": [(o["dt"], o["pv"]) for o in objs],
"uv": [(o["dt"], o["uv"]) for o in objs],
"speed": [(o["dt"], float(o["tokens"]) / (float(o["duration"] + 0.1))) for o in objs],
"tokens": [(o["dt"], float(o["tokens"]) / 1000.) for o in objs],
"round": [(o["dt"], o["round"]) for o in objs],
"thumb_up": [(o["dt"], o["thumb_up"]) for o in objs]
}
# return the JSON result
return get_json_result(data=res)
except Exception as e:
return server_error_response(e)
Creating a conversation
@manager.route('/new_conversation', methods=['GET'])
def set_conversation():
token = request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token) # validate the token
if not objs:
return get_json_result(
data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR)
req = request.json
try:
if objs[0].source == "agent": # 处理 Agent 源
e, cvs = UserCanvasService.get_by_id(objs[0].dialog_id)
if not e:
return server_error_response("canvas not found.")
if not isinstance(cvs.dsl, str):
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
# serialize cvs.dsl to a string if it is not one already
canvas = Canvas(cvs.dsl, objs[0].tenant_id)
# build the Canvas object
conv = {
"id": get_uuid(),
"dialog_id": cvs.id,
"user_id": request.args.get("user_id", ""),
"message": [{"role": "assistant", "content": canvas.get_prologue()}],
"source": "agent"
}
# build the conversation object conv: id, dialog id, user id, opening message, and source
API4ConversationService.save(**conv)
# save the conversation
return get_json_result(data=conv)
else:
e, dia = DialogService.get_by_id(objs[0].dialog_id)
if not e:
return get_data_error_result(retmsg="Dialog not found")
conv = {
"id": get_uuid(),
"dialog_id": dia.id,
"user_id": request.args.get("user_id", ""),
"message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}]
}
API4ConversationService.save(**conv)
return get_json_result(data=conv)
except Exception as e:
return server_error_response(e)
Generating a completion
@manager.route('/completion', methods=['POST'])
@validate_request("conversation_id", "messages")
def completion():
token = request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR)
req = request.json
e, conv = API4ConversationService.get_by_id(req["conversation_id"])
if not e:
return get_data_error_result(retmsg="Conversation not found!")
if "quote" not in req: req["quote"] = False
# default quote to False when the request omits it
msg = []
for m in req["messages"]:
if m["role"] == "system": # 过滤掉 role 为 "system" 的消息
continue
if m["role"] == "assistant" and not msg:
continue
# skip an "assistant" message when it would be the first one kept
msg.append(m)
if not msg[-1].get("id"): msg[-1]["id"] = get_uuid()
message_id = msg[-1]["id"]
# give the last message a UUID if it has no id, and remember it as message_id
def fillin_conv(ans): # write the answer back into the conversation
nonlocal conv, message_id
if not conv.reference:
conv.reference.append(ans["reference"])
else:
conv.reference[-1] = ans["reference"]
conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id}
ans["id"] = message_id
def rename_field(ans):
reference = ans['reference']
if not isinstance(reference, dict):
return
for chunk_i in reference.get('chunks', []):
if 'docnm_kwd' in chunk_i:
chunk_i['doc_name'] = chunk_i['docnm_kwd']
chunk_i.pop('docnm_kwd')
try:
if conv.source == "agent":
stream = req.get("stream", True)
conv.message.append(msg[-1])
e, cvs = UserCanvasService.get_by_id(conv.dialog_id)
if not e:
return server_error_response("canvas not found.")
del req["conversation_id"]
del req["messages"]
if not isinstance(cvs.dsl, str):
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
if not conv.reference:
conv.reference = []
conv.message.append({"role": "assistant", "content": "", "id": message_id})
conv.reference.append({"chunks": [], "doc_aggs": []})
final_ans = {"reference": [], "content": ""}
canvas = Canvas(cvs.dsl, objs[0].tenant_id)
canvas.messages.append(msg[-1])
canvas.add_user_input(msg[-1]["content"])
answer = canvas.run(stream=stream)
assert answer is not None, "Nothing. Is it over?"
if stream:
assert isinstance(answer, partial), "Nothing. Is it over?"
def sse():
nonlocal answer, cvs, conv
try:
for ans in answer():
for k in ans.keys():
final_ans[k] = ans[k]
ans = {"answer": ans["content"], "reference": ans.get("reference", [])}
fillin_conv(ans)
rename_field(ans)
yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans},
ensure_ascii=False) + "\n\n"
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
if final_ans.get("reference"):
canvas.reference.append(final_ans["reference"])
cvs.dsl = json.loads(str(canvas))
API4ConversationService.append_message(conv.id, conv.to_dict())
except Exception as e:
yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n"
resp = Response(sse(), mimetype="text/event-stream")
resp.headers.add_header("Cache-control", "no-cache")
resp.headers.add_header("Connection", "keep-alive")
resp.headers.add_header("X-Accel-Buffering", "no")
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
return resp
final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else ""
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
if final_ans.get("reference"):
canvas.reference.append(final_ans["reference"])
cvs.dsl = json.loads(str(canvas))
result = {"answer": final_ans["content"], "reference": final_ans.get("reference", [])}
fillin_conv(result)
API4ConversationService.append_message(conv.id, conv.to_dict())
rename_field(result)
return get_json_result(data=result)
#******************For dialog******************
conv.message.append(msg[-1])
e, dia = DialogService.get_by_id(conv.dialog_id)
if not e:
return get_data_error_result(retmsg="Dialog not found!")
del req["conversation_id"]
del req["messages"]
if not conv.reference:
conv.reference = []
conv.message.append({"role": "assistant", "content": "", "id": message_id})
conv.reference.append({"chunks": [], "doc_aggs": []})
def stream():
nonlocal dia, msg, req, conv
try:
for ans in chat(dia, msg, True, **req):
fillin_conv(ans)
rename_field(ans)
yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans},
ensure_ascii=False) + "\n\n"
API4ConversationService.append_message(conv.id, conv.to_dict())
except Exception as e:
yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n"
if req.get("stream", True):
resp = Response(stream(), mimetype="text/event-stream")
resp.headers.add_header("Cache-control", "no-cache")
resp.headers.add_header("Connection", "keep-alive")
resp.headers.add_header("X-Accel-Buffering", "no")
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
return resp
answer = None
for ans in chat(dia, msg, **req):
answer = ans
fillin_conv(ans)
API4ConversationService.append_message(conv.id, conv.to_dict())
break
rename_field(answer)
return get_json_result(data=answer)
except Exception as e:
return server_error_response(e)
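A small client sketch exercising the two endpoints above. The base URL, port, and token are assumptions: the server is taken to listen on http://localhost:9380, this blueprint to be mounted at /v1/api (see register_page), and the Bearer token to come from /new_token:
import json
import requests

BASE = "http://localhost:9380/v1/api"               # assumed host and prefix
HEADERS = {"Authorization": "Bearer <token from /new_token>"}

# 1. create a conversation
conv = requests.get(f"{BASE}/new_conversation", headers=HEADERS).json()["data"]

# 2. stream a completion over SSE ("data:"-prefixed JSON lines)
payload = {"conversation_id": conv["id"],
           "messages": [{"role": "user", "content": "What is RAGFlow?"}],
           "stream": True}
with requests.post(f"{BASE}/completion", headers=HEADERS, json=payload, stream=True) as resp:
    for line in resp.iter_lines():
        if line.startswith(b"data:"):
            print(json.loads(line[len(b"data:"):])["data"])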
sdk
web
deepdoc
graphrag
Reference URLs
https://ragflow.io/docs/dev/faq
https://infiniflow.cn/docs/v0.7.0/
https://infiniflow.cn/docs/dev/
https://github.com/infiniflow/ragflow