ragflow code analysis

version = "0.11.0"

Top-level directory

.
├── Dockerfile
├── Dockerfile.arm
├── Dockerfile.cuda
├── Dockerfile.scratch
├── Dockerfile.scratch.oc9
├── LICENSE
├── README.md
├── README_ja.md
├── README_ko.md
├── README_zh.md
├── SECURITY.md
├── agent
├── api
├── conf
├── deepdoc
├── docker
├── docs
├── download_deps.sh
├── graphrag
├── poetry.lock
├── poetry.toml
├── printEnvironment.sh
├── pyproject.toml
├── rag
├── sdk
├── ubuntu.sources
└── web

rag

.
├── __init__.py
├── app
├── benchmark.py
├── llm
├── nlp
├── raptor.py
├── res
├── settings.py
├── svr
└── utils

svr/task_executor.py

Task executor

Service entry point

if __name__ == "__main__":
    # ....  set up logging
    exe = ThreadPoolExecutor(max_workers=1)
    exe.submit(report_status)
    # Create a thread pool with a single worker and submit report_status to run asynchronously.

    while True:  # Call main() in a loop
        main()
        if PAYLOAD:  # Message payload fetched from the Redis queue
            PAYLOAD.ack()
            PAYLOAD = None

Main processing function

def main():
    rows = collect() # Fetch tasks from the Redis queue
    if len(rows) == 0:
        return

    for _, r in rows.iterrows():
        callback = partial(set_progress, r["id"], r["from_page"], r["to_page"])
        # Build the progress-reporting callback
        try:
            embd_mdl = LLMBundle(r["tenant_id"], LLMType.EMBEDDING, llm_name=r["embd_id"], lang=r["language"])
            # Create the embedding model instance
        except Exception as e:
            callback(-1, msg=str(e))
            cron_logger.error(str(e))
            continue

        if r.get("task_type", "") == "raptor":  # Handle raptor-type tasks
            try:
                chat_mdl = LLMBundle(r["tenant_id"], LLMType.CHAT, llm_name=r["llm_id"], lang=r["language"])
                # Create the chat model instance
                cks, tk_count = run_raptor(r, chat_mdl, embd_mdl, callback)
            except Exception as e:
                callback(-1, msg=str(e))
                cron_logger.error(str(e))
                continue
        else: # Handle all other task types
            st = timer()
            cks = build(r)
            cron_logger.info("Build chunks({}): {}".format(r["name"], timer() - st))
            if cks is None:
                continue
            if not cks:
                callback(1., "No chunk! Done!")
                continue
            # TODO: exception handler
            ## set_progress(r["did"], -1, "ERROR: ")
            callback(
                msg="Finished slicing files(%d). Start to embedding the content." %
                    len(cks))
            st = timer()
            try:
                tk_count = embedding(cks, embd_mdl, r["parser_config"], callback)
            except Exception as e:
                callback(-1, "Embedding error:{}".format(str(e)))
                cron_logger.error(str(e))
                tk_count = 0
            cron_logger.info("Embedding elapsed({}): {:.2f}".format(r["name"], timer() - st))
            callback(msg="Finished embedding({:.2f})! Start to build index!".format(timer() - st))

        init_kb(r) # Initialize the knowledge base index
        chunk_count = len(set([c["_id"] for c in cks]))
        # Count the chunks as the number of unique _id values,
        # so duplicated chunks are only counted once

        st = timer()
        es_r = ""
        es_bulk_size = 4
        for b in range(0, len(cks), es_bulk_size):
            es_r = ELASTICSEARCH.bulk(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]))
            if b % 128 == 0:
                callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")
        # Insert the chunks into Elasticsearch in bulk to improve throughput.
        # Each bulk request holds es_bulk_size chunks (4 by default); adjust as needed.
        # Every 128 chunks the progress callback is updated so the user can follow along.

        cron_logger.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
        if es_r: # If an error occurred, delete the data already inserted and log it
            callback(-1, f"Insert chunk error, detail info please check ragflow-logs/api/cron_logger.log. Please also check ES status!")
            ELASTICSEARCH.deleteByQuery(
                Q("match", doc_id=r["doc_id"]), idxnm=search.index_name(r["tenant_id"]))
            cron_logger.error(str(es_r))
        else:
            if TaskService.do_cancel(r["id"]): # If the task was canceled, delete the inserted data
                ELASTICSEARCH.deleteByQuery(
                    Q("match", doc_id=r["doc_id"]), idxnm=search.index_name(r["tenant_id"]))
                continue
            callback(1., "Done!")
            # On success, update the document's chunk count and log the result
            DocumentService.increment_chunk_num(
                r["doc_id"], r["kb_id"], tk_count, chunk_count, 0)
            cron_logger.info(
                "Chunk doc({}), token({}), chunks({}), elapsed:{:.2f}".format(
                    r["id"], tk_count, len(cks), timer() - st))

Processing documents with raptor

def run_raptor(row, chat_mdl, embd_mdl, callback=None):
    vts, _ = embd_mdl.encode(["ok"])
    vctr_nm = "q_%d_vec" % len(vts[0]) # Probe the vector dimension by encoding the string "ok", then name the vector field after that dimension
    chunks = []
    for d in retrievaler.chunk_list(row["doc_id"], row["tenant_id"], fields=["content_with_weight", vctr_nm]):
        chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
    # Fetch the document's existing chunks for the given doc_id and tenant_id; each entry is the weighted content plus its embedding vector.

    raptor = Raptor(
        row["parser_config"]["raptor"].get("max_cluster", 64),
        chat_mdl,
        embd_mdl,
        row["parser_config"]["raptor"]["prompt"],
        row["parser_config"]["raptor"]["max_token"],
        row["parser_config"]["raptor"]["threshold"]
    )
    # Initialize the Raptor object from the parser config: maximum cluster count, chat model, embedding model, prompt, token limit and threshold
    original_length = len(chunks)
    raptor(chunks, row["parser_config"]["raptor"]["random_seed"], callback)
    # Feed the chunks to the Raptor object together with the configured random seed
    doc = {
        "doc_id": row["doc_id"],
        "kb_id": [str(row["kb_id"])],
        "docnm_kwd": row["name"],
        "title_tks": rag_tokenizer.tokenize(row["name"])
    }
    res = []
    tk_count = 0
    for content, vctr in chunks[original_length:]:
        d = copy.deepcopy(doc)
        md5 = hashlib.md5()
        md5.update((content + str(d["doc_id"])).encode("utf-8"))
        d["_id"] = md5.hexdigest()
        d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
        d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
        d[vctr_nm] = vctr.tolist()
        d["content_with_weight"] = content
        d["content_ltks"] = rag_tokenizer.tokenize(content)
        d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
        res.append(d)
        tk_count += num_tokens_from_string(content)
    # For every newly generated chunk, build a record with a unique ID, creation timestamps, the embedding vector and the content.
    return res, tk_count # Return the new chunk documents and the total token count
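
The trick of probing the embedding model with a throwaway string to derive the vector field name is easy to see in isolation. A minimal, self-contained sketch (FakeEmbeddingModel is made up for illustration; the real model is wrapped by LLMBundle):

import numpy as np

class FakeEmbeddingModel:
    # Stand-in for the real embedding model: encode() returns (vectors, token_count)
    def encode(self, texts):
        return np.random.rand(len(texts), 1024), len(texts)

mdl = FakeEmbeddingModel()
vts, _ = mdl.encode(["ok"])          # probe with a throwaway string
vctr_nm = "q_%d_vec" % len(vts[0])   # -> "q_1024_vec", the dynamic field name
print(vctr_nm)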

Updating task progress

def set_progress(task_id, from_page=0, to_page=-1,
                 prog=None, msg="Processing..."):
    global PAYLOAD
    if prog is not None and prog < 0:
        # A negative progress value marks the message as an error
        msg = "[ERROR]" + msg

    cancel = TaskService.do_cancel(task_id)
    # Check whether the task has been canceled; if so, adjust the message and the progress value
    if cancel:
        msg += " [Canceled]"
        prog = -1

    if to_page > 0:
        if msg:
            msg = f"Page({from_page + 1}~{to_page + 1}): " + msg
    d = {"progress_msg": msg}
    # Prefix the message with the page range and build the progress payload
    if prog is not None:
        d["progress"] = prog
    try:
        TaskService.update_progress(task_id, d)
        # Persist the task progress
    except Exception as e: # On failure, just log the error
        cron_logger.error("set_progress:({}), {}".format(task_id, str(e)))

    close_connection() # Close the database connection
    if cancel: # If the task was canceled, ack the payload and exit
        if PAYLOAD:
            PAYLOAD.ack()
            PAYLOAD = None
        os._exit(0) # ?? Does exiting here shut down the whole service?
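
The callback passed around in task_executor.py is simply set_progress with its first three arguments pre-bound via functools.partial. A small standalone sketch (set_progress_demo is an illustrative stand-in):

from functools import partial

def set_progress_demo(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
    print(task_id, from_page, to_page, prog, msg)

callback = partial(set_progress_demo, "task-1", 0, 10)  # pre-bind task_id/from_page/to_page
callback(prog=0.5, msg="halfway")                       # only progress and message vary per call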

Fetching task events from the Redis queue

def collect():
    global CONSUMEER_NAME, PAYLOAD
    try:
        PAYLOAD = REDIS_CONN.get_unacked_for(CONSUMEER_NAME, SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
        # Fetch this consumer's unacknowledged message, if any, via REDIS_CONN
        if not PAYLOAD: # If PAYLOAD is empty, try to consume a new message from the queue
            PAYLOAD = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMEER_NAME)
        if not PAYLOAD: # Still no message: sleep one second and return an empty DataFrame
            time.sleep(1)
            return pd.DataFrame()
    except Exception as e:
        cron_logger.error("Get task event from queue exception:" + str(e))
        return pd.DataFrame()

    msg = PAYLOAD.get_message() 
    if not msg: # If the message is empty, return an empty DataFrame
        return pd.DataFrame()

    if TaskService.do_cancel(msg["id"]):
        # If the task has been canceled, log it and return an empty DataFrame
        cron_logger.info("Task {} has been canceled.".format(msg["id"]))
        return pd.DataFrame()
    tasks = TaskService.get_tasks(msg["id"])
    if not tasks: # If no task details are found, log a warning and return an empty list
        cron_logger.warn("{} empty task!".format(msg["id"]))
        return []

    tasks = pd.DataFrame(tasks) # Convert the tasks to a DataFrame
    if msg.get("type", "") == "raptor":
        tasks["task_type"] = "raptor" # Tag the task type based on the message type
    return tasks # Return the collected tasks

Building the chunk documents

def get_storage_binary(bucket, name):
    return STORAGE_IMPL.get(bucket, name)

def build(row):
    if row["size"] > DOC_MAXIMUM_SIZE:
        set_progress(row["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
                                             (int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
        return []
    """
    Check the file size: if the "size" field of the input row exceeds DOC_MAXIMUM_SIZE,
    set_progress is called with -1 plus an error message, and an empty list [] is returned.
    """    

    callback = partial(
        set_progress,
        row["id"],
        row["from_page"],
        row["to_page"]) # Partial function callback used to report progress
    chunker = FACTORY[row["parser_id"].lower()]
    # Pick the chunker matching row["parser_id"].lower()
    try:
        st = timer()
        bucket, name = File2DocumentService.get_storage_address(doc_id=row["doc_id"])
        # Resolve the storage bucket and object name

        binary = get_storage_binary(bucket, name)
        # Fetch the file's binary content
        cron_logger.info(
            "From minio({}) {}/{}".format(timer() - st, row["location"], row["name"]))

    # Exception handling for the fetch
    except TimeoutError as e:
        callback(-1, f"Internal server error: Fetch file from minio timeout. Could you try it again.")
        cron_logger.error(
            "Minio {}/{}: Fetch file from minio timeout.".format(row["location"], row["name"]))
        return
    except Exception as e:
        if re.search("(No such file|not found)", str(e)):
            callback(-1, "Can not find file <%s> from minio. Could you try it again?" % row["name"])
        else:
            callback(-1, f"Get file from minio: %s" %
                     str(e).replace("'", ""))
        traceback.print_exc()
        return

    try:
        cks = chunker.chunk(row["name"], binary=binary, from_page=row["from_page"],
                            to_page=row["to_page"], lang=row["language"], callback=callback,
                            kb_id=row["kb_id"], parser_config=row["parser_config"], tenant_id=row["tenant_id"])
        # Split the file into chunks
        cron_logger.info(
            "Chunking({}) {}/{}".format(timer() - st, row["location"], row["name"]))
    except Exception as e: # If chunking fails, log an internal server error and return
        callback(-1, f"Internal server error while chunking: %s" %
                     str(e).replace("'", ""))
        cron_logger.error(
            "Chunking {}/{}: {}".format(row["location"], row["name"], str(e)))
        traceback.print_exc()
        return

    docs = []
    doc = {
        "doc_id": row["doc_id"],
        "kb_id": [str(row["kb_id"])]
    }
    el = 0
    for ck in cks: # Walk over the chunks and build the document records
        d = copy.deepcopy(doc) # Copy the base document fields
        d.update(ck) # Merge the chunk data into the record
        md5 = hashlib.md5()
        md5.update((ck["content_with_weight"] +
                    str(d["doc_id"])).encode("utf-8"))
        d["_id"] = md5.hexdigest() # Unique chunk ID derived from an MD5 of content + doc_id
        d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
        d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
        # Set the creation time and timestamp
        if not d.get("image"):
            docs.append(d) # Chunks without an image are appended directly
            continue

        try:
            output_buffer = BytesIO()
            if isinstance(d["image"], bytes):
                output_buffer = BytesIO(d["image"])
            else:
                d["image"].save(output_buffer, format='JPEG')
                

            st = timer()
            STORAGE_IMPL.put(row["kb_id"], d["_id"], output_buffer.getvalue())
            # Save the image data to the storage backend
            el += timer() - st
        except Exception as e:
            cron_logger.error(str(e))
            traceback.print_exc()

        d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"])
        # Record the image ID
        del d["image"] # Drop the raw image data
        docs.append(d) # Append the processed record
    cron_logger.info("MINIO PUT({}):{}".format(row["name"], el))

    return docs # Return the processed chunk documents

Initializing the knowledge base index

def init_kb(row):
    """
    Initialize (create) a configured Elasticsearch index for the given tenant
    """
    idxnm = search.index_name(row["tenant_id"])
    # Derive the index name from the tenant_id via search.index_name

    if ELASTICSEARCH.indexExist(idxnm):
        return # The index already exists, nothing to do
    return ELASTICSEARCH.createIdx(idxnm, json.load(
        open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))
        # Locate conf/mapping.json under the project base directory
        # and pass its JSON content to ELASTICSEARCH.createIdx to create the new index.

Embedding the documents

def embedding(docs, mdl, parser_config={}, callback=None):
    """
    docs: a list of documents; each is a dict with at least the keys title_tks and content_with_weight.
    mdl: the model object used for text encoding.
    parser_config: the parser configuration dict, empty by default.
    callback: the progress callback, None by default
    """
    batch_size = 32
    tts, cnts = [rmSpace(d["title_tks"]) for d in docs if d.get("title_tks")], [
        re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", d["content_with_weight"]) for d in docs]
    # Preprocessing:
    # rmSpace strips redundant spaces from the titles (utils/__init__.py)
    # The regex removes HTML table markup from the content, replacing each tag with a space:
    # opening tags: <table>, <td>, <caption>, <tr>, <th>
    # closing tags: </table>, </td>, </caption>, </tr>, </th>

    tk_count = 0
    if len(tts) == len(cnts): # Encode the titles in batches
        tts_ = np.array([])
        for i in range(0, len(tts), batch_size): # Walk over the titles, 32 at a time
            vts, c = mdl.encode(tts[i: i + batch_size])
            # Encode the batch
            if len(tts_) == 0:
                tts_ = vts
            else:
                tts_ = np.concatenate((tts_, vts), axis=0)
            tk_count += c
            callback(prog=0.6 + 0.1 * (i + 1) / len(tts), msg="")
            # Update progress (from 60% to 70%)
        tts = tts_

    cnts_ = np.array([]) # Encode the contents in batches
    for i in range(0, len(cnts), batch_size):
        vts, c = mdl.encode(cnts[i: i + batch_size])
        if len(cnts_) == 0:
            cnts_ = vts
        else:
            cnts_ = np.concatenate((cnts_, vts), axis=0)
        tk_count += c
        callback(prog=0.7 + 0.2 * (i + 1) / len(cnts), msg="")
        # Update progress (from 70% to 90%)
    cnts = cnts_

    title_w = float(parser_config.get("filename_embd_weight", 0.1))
    vects = (title_w * tts + (1 - title_w) *
             cnts) if len(tts) == len(cnts) else cnts
    # filename_embd_weight in parser_config sets the weight of the title vectors (default 0.1).
    # If titles and contents have the same length, merge the two sets of vectors by weight; otherwise use the content vectors only.

    assert len(vects) == len(docs) # The number of vectors must match the number of docs
    for i, d in enumerate(docs): 
        v = vects[i].tolist()
        d["q_%d_vec" % len(v)] = v # Name the key dynamically after the vector length
    return tk_count
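
The weighted merge of title and content vectors is plain numpy arithmetic. A small self-contained sketch with made-up shapes (3 documents, 4-dimensional vectors):

import numpy as np

title_vecs = np.random.rand(3, 4)     # one title vector per document
content_vecs = np.random.rand(3, 4)   # one content vector per document
title_w = 0.1                         # filename_embd_weight default

merged = title_w * title_vecs + (1 - title_w) * content_vecs
assert merged.shape == content_vecs.shape   # still one vector per document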

Status reporting

def report_status():
    global CONSUMEER_NAME
    while True:
        try:
            obj = REDIS_CONN.get("TASKEXE")
            if not obj: obj = {}
            else: obj = json.loads(obj)
            # Read the current heartbeat data from Redis
            if CONSUMEER_NAME not in obj: obj[CONSUMEER_NAME] = []
            obj[CONSUMEER_NAME].append(timer())
            # Append the current timestamp to this consumer's list
            obj[CONSUMEER_NAME] = obj[CONSUMEER_NAME][-60:]
            # Keep only the most recent 60 timestamps
            REDIS_CONN.set_obj("TASKEXE", obj, 60*2)
            # Write the updated dict back to Redis with a 120-second expiry
        except Exception as e:
            print("[Exception]:", str(e))
        time.sleep(30)
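
The value stored under the Redis key "TASKEXE" is just a JSON object mapping consumer names to their recent heartbeat timestamps. A self-contained sketch of what reading it back might look like (the consumer name and the fake dict are illustrative):

from timeit import default_timer as timer

heartbeats = {"task_consumer_0": [timer() - 90, timer() - 60, timer() - 30]}
latest = heartbeats["task_consumer_0"][-1]
print("seconds since last heartbeat:", timer() - latest)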

svr/jina_server.py

A text-generation service built on the Jina framework. Jina is an open-source framework for building and deploying neural-network-powered search applications. It lets developers quickly assemble scalable search solutions for many kinds of data, such as text, images and video.

Jina's main characteristics include:

  • Flexibility: supports multi-modal data, i.e. different data types (such as text + images) can be handled together.
  • Ease of use: a concise API lets even beginners quickly build complex search systems.
  • Efficiency: uses modern neural-network techniques for indexing and retrieval to deliver high-performance search.
  • Distributed: scales seamlessly from a local deployment to Docker, Kubernetes or Jina AI Cloud, which suits large distributed systems.
  • Compatibility: works with different deep-learning frameworks, so developers can pick the stack that fits best.

In practice, Jina can be used to build search engines, recommendation systems and other applications that need to retrieve large amounts of data efficiently. It offers a Python API, and the community is active with rich documentation and support resources.

Service entry point

if __name__ == "__main__":
    parser = argparse.ArgumentParser() # argparse is the standard-library module for parsing command-line arguments
    parser.add_argument("--model_name", type=str, help="Model name or path")
    parser.add_argument("--port", default=12345, type=int, help="Jina serving port")
    # Two command-line arguments are defined:
    # --model_name: string, the model name or path.
    # --port: integer, default 12345, the port the Jina service listens on.
    args = parser.parse_args() # Parse the command line and store the result in args
    model_name = args.model_name
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    # Load the tokenizer for the given model
    with Deployment(
        uses=TokenStreamingExecutor, port=args.port, protocol="grpc"
        # Create a Deployment that:
        # uses the TokenStreamingExecutor executor,
        # listens on the port taken from the command line,
        # and serves over gRPC.
    ) as dep:
        dep.block() # Keep the service running until it is stopped manually
   

The executor class

class TokenStreamingExecutor(Executor):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Load the pretrained language model
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="auto",  # Automatically place the model on a suitable device (CPU or GPU)
            torch_dtype="auto"  # Automatically pick the dtype for that device (e.g. FP16 on GPU)
        )

    @requests(on="/chat")  # This handler serves requests to the /chat endpoint
    async def generate(self, doc: Prompt, **kwargs) -> Generation:
        # One-shot generation: encode the input, generate, and return the decoded text
        text = tokenizer.apply_chat_template(
            doc.message,
            tokenize=False,
        )  # Format doc.message with the chat template, without tokenizing
        inputs = tokenizer([text], return_tensors="pt")
        # Encode the text into PyTorch tensors
        generation_config = GenerationConfig(
            **doc.gen_conf,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id
        )  # Build the generation config from doc.gen_conf, setting the EOS and padding tokens
        generated_ids = self.model.generate(
            inputs.input_ids, generation_config=generation_config
        )  # Generate token IDs with the pretrained model
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(inputs.input_ids, generated_ids)  # Strip the prompt tokens, keeping only the newly generated ones
        ]

        response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        # Decode the generated token IDs, skipping special tokens
        yield Generation(text=response)  # Yield the decoded text

    @requests(on="/stream")  # This handler serves requests to the /stream endpoint
    async def task(self, doc: Prompt, **kwargs) -> Generation:
        # Streaming generation: emit one token at a time until max_new_tokens is reached
        # or the EOS token appears, yielding the newly generated text at each step.
        text = tokenizer.apply_chat_template(
            doc.message,
            tokenize=False,
        )
        input = tokenizer([text], return_tensors="pt")
        input_len = input["input_ids"].shape[1]
        max_new_tokens = 512  # Default maximum number of new tokens
        if "max_new_tokens" in doc.gen_conf:
            max_new_tokens = doc.gen_conf.pop("max_new_tokens")
        # Override the default if doc.gen_conf supplies max_new_tokens
        generation_config = GenerationConfig(
            **doc.gen_conf,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id
        )
        for _ in range(max_new_tokens):  # Generate one new token per iteration
            output = self.model.generate(
                **input, max_new_tokens=1, generation_config=generation_config
            )
            if output[0][-1] == tokenizer.eos_token_id:
                break  # Stop when the last generated token is the EOS token
            yield Generation(
                text=tokenizer.decode(output[0][input_len:], skip_special_tokens=True)
            )
            input = {  # Update the model input
                "input_ids": output,
                # input_ids is what the model generates from.
                # After each step the current output becomes the new input_ids,
                # so the next step continues from the latest generation.
                "attention_mask": torch.ones(1, len(output[0])),
                # attention_mask marks which positions are valid.
                # torch.ones(1, len(output[0])) marks every position as valid,
                # so the model attends to the whole sequence while generating.
            }
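
For completeness, a rough client-side sketch of how one might call the /chat endpoint over gRPC. This is not taken from the ragflow repo: the import path of Prompt/Generation is assumed, and the exact Client streaming API can differ between Jina versions.

import asyncio
from jina import Client
from rag.svr.jina_server import Prompt, Generation  # assumed import path

async def ask(question: str):
    client = Client(port=12345, protocol="grpc", asyncio=True)
    prompt = Prompt(message=[{"role": "user", "content": question}], gen_conf={})
    # /chat is a streaming (generator) endpoint, so iterate over the returned docs
    async for doc in client.stream_doc(on="/chat", inputs=prompt, return_type=Generation):
        print(doc.text)

asyncio.run(ask("Hello"))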

api

.
├── __init__.py
├── apps
├── contants.py
├── db
├── ragflow_server.py
├── settings.py
├── utils
└── versions.py

api/ragflow_server.py

Starting the API service

if __name__ == '__main__':
    print(r"""
    ____                 ______ __               
   / __ \ ____ _ ____ _ / ____// /____  _      __
  / /_/ // __ `// __ `// /_   / // __ \| | /| / /
 / _, _// /_/ // /_/ // __/  / // /_/ /| |/ |/ / 
/_/ |_| \__,_/ \__, //_/    /_/ \____/ |__/|__/  
              /____/                             

    """, flush=True) # flush=True makes the banner print immediately
    stat_logger.info(
        f'project base: {utils.file_utils.get_project_base_directory()}'
    )
    # Log the project base directory

    # init db
    init_web_db() # Create the database tables.
    init_web_data() # Seed the database with its initial data
    # init runtime config
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--version', default=False, help="rag flow version", action='store_true')
    parser.add_argument('--debug', default=False, help="debug mode", action='store_true')
    # Two command-line flags: --version and --debug
    args = parser.parse_args()
    if args.version: # If --version is given, print the version info and exit
        print(get_versions())
        sys.exit(0)

    RuntimeConfig.DEBUG = args.debug
    if RuntimeConfig.DEBUG:
        stat_logger.info("run on debug mode")
    # Set the DEBUG flag on the runtime config; log it when debug mode is on

    RuntimeConfig.init_env() # Initialize the runtime environment
    RuntimeConfig.init_config(JOB_SERVER_HOST=HOST, HTTP_PORT=HTTP_PORT)
    # Initialize the runtime config with the host and HTTP port

    peewee_logger = logging.getLogger('peewee')
    peewee_logger.propagate = False
    # rag_arch.common.log.ROpenHandler
    peewee_logger.addHandler(database_logger.handlers[0])
    peewee_logger.setLevel(database_logger.level)
    # Get the peewee logger.
    # Stop its records from propagating to the root logger.
    # Attach the database log handler.
    # Align the log level.

    thr = ThreadPoolExecutor(max_workers=1)
    thr.submit(update_progress)
    # Create a single-worker thread pool
    # and submit update_progress to run in it.

    # start http server
    try:
        stat_logger.info("RAG Flow http server start...")
        werkzeug_logger = logging.getLogger("werkzeug")
        for h in access_logger.handlers:
            werkzeug_logger.addHandler(h)
        run_simple(hostname=HOST, port=HTTP_PORT, application=app, threaded=True, use_reloader=RuntimeConfig.DEBUG, use_debugger=RuntimeConfig.DEBUG)
        # Start the server with werkzeug's run_simple; app is a Flask application
    except Exception:
        traceback.print_exc()
        os.kill(os.getpid(), signal.SIGKILL)

api/settings.py

Application configuration

# Logger
LoggerFactory.set_directory(
    os.path.join(
        get_project_base_directory(),
        "logs",
        "api"))
# Log files are written under the project's logs/api directory
# {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0}
LoggerFactory.LEVEL = 30

stat_logger = getLogger("stat")
access_logger = getLogger("access")
database_logger = getLogger("database")
chat_logger = getLogger("chat")
# Several loggers are created: statistics (stat), access, database and chat

# Base project settings: API version, service name, server module, temp directory path, etc.
API_VERSION = "v1"
RAG_FLOW_SERVICE_NAME = "ragflow"
SERVER_MODULE = "rag_flow_server.py"
TEMP_DIRECTORY = os.path.join(get_project_base_directory(), "temp")
RAG_FLOW_CONF_PATH = os.path.join(get_project_base_directory(), "conf")
LIGHTEN = os.environ.get('LIGHTEN')
# Lightweight-mode switch

SUBPROCESS_STD_LOG_NAME = "std.log"

ERROR_REPORT = True
ERROR_REPORT_WITH_PATH = False

MAX_TIMESTAMP_INTERVAL = 60
SESSION_VALID_PERIOD = 7 * 24 * 60 * 60

REQUEST_TRY_TIMES = 3
REQUEST_WAIT_SEC = 2
REQUEST_MAX_WAIT_SEC = 300

USE_REGISTRY = get_base_config("use_registry")

LLM = get_base_config("user_default_llm", {})
LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen")
LLM_BASE_URL = LLM.get("base_url")

if not LIGHTEN:
    default_llm = {
        "Tongyi-Qianwen": {
            "chat_model": "qwen-plus",
            "embedding_model": "text-embedding-v2",
            "image2text_model": "qwen-vl-max",
            "asr_model": "paraformer-realtime-8k-v1",
        },
        "OpenAI": {
            "chat_model": "gpt-3.5-turbo",
            "embedding_model": "text-embedding-ada-002",
            "image2text_model": "gpt-4-vision-preview",
            "asr_model": "whisper-1",
        },
        "Azure-OpenAI": {
            "chat_model": "azure-gpt-35-turbo",
            "embedding_model": "azure-text-embedding-ada-002",
            "image2text_model": "azure-gpt-4-vision-preview",
            "asr_model": "azure-whisper-1",
        },
        "ZHIPU-AI": {
            "chat_model": "glm-3-turbo",
            "embedding_model": "embedding-2",
            "image2text_model": "glm-4v",
            "asr_model": "",
        },
        "Ollama": {
            "chat_model": "qwen-14B-chat",
            "embedding_model": "flag-embedding",
            "image2text_model": "",
            "asr_model": "",
        },
        "Moonshot": {
            "chat_model": "moonshot-v1-8k",
            "embedding_model": "",
            "image2text_model": "",
            "asr_model": "",
        },
        "DeepSeek": {
            "chat_model": "deepseek-chat",
            "embedding_model": "",
            "image2text_model": "",
            "asr_model": "",
        },
        "VolcEngine": {
            "chat_model": "",
            "embedding_model": "",
            "image2text_model": "",
            "asr_model": "",
        },
        "BAAI": {
            "chat_model": "",
            "embedding_model": "BAAI/bge-large-zh-v1.5",
            "image2text_model": "",
            "asr_model": "",
            "rerank_model": "BAAI/bge-reranker-v2-m3",
        }
    }

    CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"]
    EMBEDDING_MDL = default_llm["BAAI"]["embedding_model"]
    RERANK_MDL = default_llm["BAAI"]["rerank_model"] if not LIGHTEN else ""
    ASR_MDL = default_llm[LLM_FACTORY]["asr_model"]
    IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"]
    # Pick the default LLM factory from LLM_FACTORY in the config and load its sub-models (chat, embedding, etc.)
else:
    CHAT_MDL = EMBEDDING_MDL = RERANK_MDL = ASR_MDL = IMAGE2TEXT_MDL = ""
    # In lightweight mode no concrete models are configured -- lazy loading, or something else??

API_KEY = LLM.get("api_key", "") # LLM API key
PARSERS = LLM.get(
    "parsers",
    "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email")

# distribution
DEPENDENT_DISTRIBUTION = get_base_config("dependent_distribution", False)
# Dependency-distribution flag
RAG_FLOW_UPDATE_CHECK = False

HOST = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
HTTP_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")

SECRET_KEY = get_base_config(
    RAG_FLOW_SERVICE_NAME,
    {}).get(
        "secret_key",
    "infiniflow") # Secret key
TOKEN_EXPIRE_IN = get_base_config(
    RAG_FLOW_SERVICE_NAME, {}).get(
        "token_expires_in", 3600) # Token expiry time

NGINX_HOST = get_base_config(
    RAG_FLOW_SERVICE_NAME, {}).get(
        "nginx", {}).get("host") or HOST
NGINX_HTTP_PORT = get_base_config(
    RAG_FLOW_SERVICE_NAME, {}).get(
        "nginx", {}).get("http_port") or HTTP_PORT
# Nginx reverse-proxy settings

RANDOM_INSTANCE_ID = get_base_config(
    RAG_FLOW_SERVICE_NAME, {}).get(
        "random_instance_id", False)

PROXY = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("proxy")
PROXY_PROTOCOL = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("protocol")
# Proxy settings

DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql')
DATABASE = decrypt_database_config(name=DATABASE_TYPE)
# Pick the database type from the DB_TYPE environment variable and decrypt its configuration

# Switch
# upload
UPLOAD_DATA_FROM_CLIENT = True
# Switch for uploading data from the client

# authentication
AUTHENTICATION_CONF = get_base_config("authentication", {})

# client
CLIENT_AUTHENTICATION = AUTHENTICATION_CONF.get(
    "client", {}).get(
        "switch", False)
HTTP_APP_KEY = AUTHENTICATION_CONF.get("client", {}).get("http_app_key")
GITHUB_OAUTH = get_base_config("oauth", {}).get("github")
FEISHU_OAUTH = get_base_config("oauth", {}).get("feishu")
WECHAT_OAUTH = get_base_config("oauth", {}).get("wechat")

# site
SITE_AUTHENTICATION = AUTHENTICATION_CONF.get("site", {}).get("switch", False)

# permission
PERMISSION_CONF = get_base_config("permission", {})
PERMISSION_SWITCH = PERMISSION_CONF.get("switch")
COMPONENT_PERMISSION = PERMISSION_CONF.get("component")
DATASET_PERMISSION = PERMISSION_CONF.get("dataset")

HOOK_MODULE = get_base_config("hook_module")
HOOK_SERVER_NAME = get_base_config("hook_server_name")

ENABLE_MODEL_STORE = get_base_config('enable_model_store', False)
# authentication
USE_AUTHENTICATION = False
USE_DATA_AUTHENTICATION = False
AUTOMATIC_AUTHORIZATION_OUTPUT_DATA = True
USE_DEFAULT_TIMEOUT = False
AUTHENTICATION_DEFAULT_TIMEOUT = 7 * 24 * 60 * 60  # s
PRIVILEGE_COMMAND_WHITELIST = []
CHECK_NODES_IDENTITY = False

retrievaler = search.Dealer(ELASTICSEARCH)
kg_retrievaler = kg_search.KGSearch(ELASTICSEARCH)
# Initialize the Elasticsearch retriever (retrievaler) and the knowledge-graph retriever (kg_retrievaler)

class CustomEnum(Enum):
    """
    A generic enum base class that can validate values and list member values/names
    """
    @classmethod
    def valid(cls, value):
        try:
            cls(value)
            return True
        except BaseException:
            return False

    @classmethod
    def values(cls):
        return [member.value for member in cls.__members__.values()]

    @classmethod
    def names(cls):
        return [member.name for member in cls.__members__.values()]


class PythonDependenceName(CustomEnum):
    Rag_Source_Code = "python"
    Python_Env = "miniconda"


class ModelStorage(CustomEnum):
    REDIS = "redis"
    MYSQL = "mysql"

# PythonDependenceName and ModelStorage are concrete enums built on CustomEnum


class RetCode(IntEnum, CustomEnum):
    """
    A status-code enum used to represent success, invalid input, exceptions and other outcomes
    """
    SUCCESS = 0
    NOT_EFFECTIVE = 10
    EXCEPTION_ERROR = 100
    ARGUMENT_ERROR = 101
    DATA_ERROR = 102
    OPERATING_ERROR = 103
    CONNECTION_ERROR = 105
    RUNNING = 106
    PERMISSION_ERROR = 108
    AUTHENTICATION_ERROR = 109
    UNAUTHORIZED = 401
    SERVER_ERROR = 500
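
A quick usage sketch of the helpers these enums inherit from CustomEnum (the printed values follow directly from the definitions above):

print(RetCode.valid(0))        # True  -> 0 is RetCode.SUCCESS
print(RetCode.valid(999))      # False -> no member has value 999
print(ModelStorage.values())   # ['redis', 'mysql']
print(ModelStorage.names())    # ['REDIS', 'MYSQL']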

api/apps/__init__.py

Application initialization


__all__ = ['app']


logger = logging.getLogger('flask.app')
for h in access_logger.handlers:
    logger.addHandler(h)
# Set up the 'flask.app' logger and attach access_logger's handlers to it

Request.json = property(lambda self: self.get_json(force=True, silent=True))
# Extend the Request class so the request's JSON body is available as a property, simplifying access to JSON data

app = Flask(__name__) # Create the Flask application instance
CORS(app, supports_credentials=True, max_age=2592000)
# Enable cross-origin resource sharing (CORS) so cross-domain requests are allowed
app.url_map.strict_slashes = False
# Disable strict slashes, so a trailing slash on the URL does not affect route matching.

app.json_encoder = CustomJSONEncoder
# Use the custom CustomJSONEncoder as the JSON encoder.
app.errorhandler(Exception)(server_error_response)
# Register server_error_response as the global handler for uncaught exceptions.

## convince for dev and debug
#app.config["LOGIN_DISABLED"] = True
app.config["SESSION_PERMANENT"] = False # Session persistence option
app.config["SESSION_TYPE"] = "filesystem"
# Configure Flask-Session to store sessions on the filesystem
app.config['MAX_CONTENT_LENGTH'] = int(os.environ.get("MAX_CONTENT_LENGTH", 128 * 1024 * 1024))
# Limit the maximum request body size, 128MB by default

Session(app)

login_manager = LoginManager()
login_manager.init_app(app)
# Use Flask-Login for authentication management and initialize the login manager

commands.register_commands(app)
# Register CLI commands on the Flask app, mainly for development and debugging

def search_pages_path(pages_dir):
    """
    Find the application page and API file paths under the given directory
    """
    app_path_list = [path for path in pages_dir.glob('*_app.py') if not path.name.startswith('.')]
    api_path_list = [path for path in pages_dir.glob('*sdk/*.py') if not path.name.startswith('.')]
    app_path_list.extend(api_path_list)
    return app_path_list


def register_page(page_path):
    """
    Dynamically register a page module with the Flask application.
    Each page becomes a Blueprint with its own URL prefix;
    every discovered page path is registered this way.
    """
    path = f'{page_path}'
    # page_path is a Path object pointing at the page or API module file;
    # convert it to its string form

    page_name = page_path.stem.rstrip('_app')
    # Take the file name (without extension) and strip the trailing '_app' suffix if present
    module_name = '.'.join(page_path.parts[page_path.parts.index('api'):-1] + (page_name,))
    # Build the module name: join the path parts starting at 'api' with the page name into a dotted module path

    spec = spec_from_file_location(module_name, page_path)
    # Create a ModuleSpec from the module name and file path
    page = module_from_spec(spec)
    # Create the module object from the spec
    page.app = app
    # Hand the Flask app instance to the module via its app attribute
    page.manager = Blueprint(page_name, module_name)
    # Create a Blueprint and expose it to the module as its manager attribute
    sys.modules[module_name] = page
    # Register the module in sys.modules so it can be referenced later
    spec.loader.exec_module(page)
    # Execute the module's code so its definitions take effect
    page_name = getattr(page, 'page_name', page_name)
    # Prefer the module's own page_name attribute if it defines one
    url_prefix = f'/api/{API_VERSION}/{page_name}' if "/sdk/" in path else f'/{API_VERSION}/{page_name}'
    # The URL prefix depends on whether the file path contains /sdk/

    app.register_blueprint(page.manager, url_prefix=url_prefix)
    # Register the Blueprint on the Flask app under that URL prefix
    return url_prefix # Return the final URL prefix
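
# Illustrative only (hypothetical file names): how register_page derives the
# blueprint name and URL prefix from a module path, per the logic above.
from pathlib import Path

for p in [Path("api/apps/dialog_app.py"), Path("api/apps/sdk/dataset.py")]:
    name = p.stem.rstrip('_app')
    prefix = f'/api/v1/{name}' if "/sdk/" in str(p) else f'/v1/{name}'
    print(p, "->", prefix)   # /v1/dialog and /api/v1/dataset respectively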


pages_dir = [
    Path(__file__).parent,
    Path(__file__).parent.parent / 'api' / 'apps',
    Path(__file__).parent.parent / 'api' / 'apps' / 'sdk',
]

client_urls_prefix = [
    register_page(path)
    for dir in pages_dir
    for path in search_pages_path(dir)
]
# Call register_page for every discovered path and collect the resulting URL prefixes

@login_manager.request_loader
# Decorator that registers the function below as the request loader.
# It is called for every incoming request to load the user from the request itself
def load_user(web_request):
    jwt = Serializer(secret_key=SECRET_KEY)
    # Build the serializer from SECRET_KEY
    authorization = web_request.headers.get("Authorization")
    # web_request.headers is a dict-like object holding all request headers;
    # get("Authorization") reads the Authorization header value.

    if authorization:
        try:
            access_token = str(jwt.loads(authorization))
            # Deserialize the access token
            user = UserService.query(access_token=access_token, status=StatusEnum.VALID.value)
            # UserService is the service class for querying user records.
            # query() takes access_token and status and returns matching users.
            # StatusEnum.VALID.value means only users in the valid state match
            if user:
                return user[0]
            else:
                return None
        except Exception as e:
            stat_logger.exception(e)
            return None
    else:
        return None


@app.teardown_request
def _db_close(exc): # Close the database connection once the request is finished
    close_connection()

api/apps/api_app.py

Handling API requests

Generating a token

def generate_confirmation_token(tenent_id):
    # Typo in the source: tenent_id should be tenant_id
    serializer = URLSafeTimedSerializer(tenent_id)
    return "ragflow-" + serializer.dumps(get_uuid(), salt=tenent_id)[2:34]
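
# Self-contained sketch of the token construction above, using itsdangerous
# directly (the tenant id and uuid4().hex stand in for real values / get_uuid):
from itsdangerous import URLSafeTimedSerializer
from uuid import uuid4

tenant_id = "tenant-123"
serializer = URLSafeTimedSerializer(tenant_id)
token = "ragflow-" + serializer.dumps(uuid4().hex, salt=tenant_id)[2:34]
print(token)   # e.g. "ragflow-..." with a 32-character body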

@manager.route('/new_token', methods=['POST'])
@login_required
def new_token():
    req = request.json
    try:
        tenants = UserTenantService.query(user_id=current_user.id)
        # Look up the current user's tenants
        if not tenants:
            return get_data_error_result(retmsg="Tenant not found!")

        tenant_id = tenants[0].tenant_id
        obj = {"tenant_id": tenant_id, "token": generate_confirmation_token(tenant_id),
               "create_time": current_timestamp(),
               "create_date": datetime_format(datetime.now()),
               "update_time": None,
               "update_date": None
               }
        # Assemble the token record

        if req.get("canvas_id"):
            obj["dialog_id"] = req["canvas_id"]
            obj["source"] = "agent"
        else:
            obj["dialog_id"] = req["dialog_id"]
        # Set the dialog ID and, for agent canvases, the source

        if not APITokenService.save(**obj): # Persist the token record
            return get_data_error_result(retmsg="Fail to new a dialog!")

        return get_json_result(data=obj) # Return the new token record
    except Exception as e:
        return server_error_response(e)

Listing token objects

@manager.route('/token_list', methods=['GET'])
@login_required
def token_list():
    try:
        tenants = UserTenantService.query(user_id=current_user.id)
        if not tenants:
            return get_data_error_result(retmsg="Tenant not found!")

        id = request.args["dialog_id"] if "dialog_id" in request.args else request.args["canvas_id"]
        objs = APITokenService.query(tenant_id=tenants[0].tenant_id, dialog_id=id)
        # Query the list of API token objects
        return get_json_result(data=[o.to_dict() for o in objs])
    except Exception as e:
        return server_error_response(e)

Deleting tokens

@manager.route('/rm', methods=['POST'])
@validate_request("tokens", "tenant_id")
@login_required
def rm():
    req = request.json
    try:
        for token in req["tokens"]:
            APITokenService.filter_delete(
                [APIToken.tenant_id == req["tenant_id"], APIToken.token == token])
        # Walk over the token list and delete each one
        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)

Querying conversation statistics

@manager.route('/stats', methods=['GET'])
@login_required
def stats():
    try:
        # Look up the tenant
        tenants = UserTenantService.query(user_id=current_user.id)
        if not tenants:
            return get_data_error_result(retmsg="Tenant not found!")

        # Read the date-range parameters
        from_date = request.args.get(
            "from_date",
            (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d 00:00:00"))
        to_date = request.args.get(
            "to_date",
            datetime.now().strftime("%Y-%m-%d %H:%M:%S"))

        # Read the other parameters
        agent = "agent" if "canvas_id" in request.args else None

        # Query the statistics
        objs = API4ConversationService.stats(
            tenants[0].tenant_id,
            from_date,
            to_date,
            agent)

        # Build the result
        res = {
            "pv": [(o["dt"], o["pv"]) for o in objs],
            "uv": [(o["dt"], o["uv"]) for o in objs],
            "speed": [(o["dt"], float(o["tokens"]) / (float(o["duration"] + 0.1))) for o in objs],
            "tokens": [(o["dt"], float(o["tokens"]) / 1000.) for o in objs],
            "round": [(o["dt"], o["round"]) for o in objs],
            "thumb_up": [(o["dt"], o["thumb_up"]) for o in objs]
        }

        # Return the JSON result
        return get_json_result(data=res)
    except Exception as e:
        return server_error_response(e)

Creating a conversation

@manager.route('/new_conversation', methods=['GET'])
def set_conversation():
    token = request.headers.get('Authorization').split()[1]
    objs = APIToken.query(token=token) # Validate the token
    if not objs:
        return get_json_result(
            data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR)
    req = request.json
    try:
        if objs[0].source == "agent": # Handle the agent source
            e, cvs = UserCanvasService.get_by_id(objs[0].dialog_id)
            if not e:
                return server_error_response("canvas not found.")
            if not isinstance(cvs.dsl, str):
                cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
                # If cvs.dsl is not a string, serialize it to one
            canvas = Canvas(cvs.dsl, objs[0].tenant_id)
            # Build the Canvas object
            conv = {
                "id": get_uuid(),
                "dialog_id": cvs.id,
                "user_id": request.args.get("user_id", ""),
                "message": [{"role": "assistant", "content": canvas.get_prologue()}],
                "source": "agent"
            }
            # Build the conversation object conv: ID, dialog ID, user ID, messages and source.
            API4ConversationService.save(**conv)
            # Persist the conversation object
            return get_json_result(data=conv)
        else:
            e, dia = DialogService.get_by_id(objs[0].dialog_id)
            if not e:
                return get_data_error_result(retmsg="Dialog not found")
            conv = {
                "id": get_uuid(),
                "dialog_id": dia.id,
                "user_id": request.args.get("user_id", ""),
                "message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}]
            }
            API4ConversationService.save(**conv)
            return get_json_result(data=conv)
    except Exception as e:
        return server_error_response(e)

Message completion

@manager.route('/completion', methods=['POST'])
@validate_request("conversation_id", "messages")
def completion():
    token = request.headers.get('Authorization').split()[1]
    objs = APIToken.query(token=token)
    if not objs:
        return get_json_result(
            data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR)
    req = request.json
    e, conv = API4ConversationService.get_by_id(req["conversation_id"])
    if not e:
        return get_data_error_result(retmsg="Conversation not found!")
    if "quote" not in req: req["quote"] = False
    # Default the quote parameter to False when it is absent

    msg = []
    for m in req["messages"]:
        if m["role"] == "system": # Skip messages whose role is "system"
            continue
        if m["role"] == "assistant" and not msg:
            continue 
        # Skip a leading assistant message (i.e. when msg is still empty)
        msg.append(m)
    if not msg[-1].get("id"): msg[-1]["id"] = get_uuid()
    message_id = msg[-1]["id"]
    # If the last message has no ID, generate a UUID for it and remember that ID.

    def fillin_conv(ans): # Fill the answer back into the conversation record
        nonlocal conv, message_id
        if not conv.reference:
            conv.reference.append(ans["reference"])
        else:
            conv.reference[-1] = ans["reference"]
        conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id}
        ans["id"] = message_id

    def rename_field(ans):
        reference = ans['reference']
        if not isinstance(reference, dict):
            return
        for chunk_i in reference.get('chunks', []):
            if 'docnm_kwd' in chunk_i:
                chunk_i['doc_name'] = chunk_i['docnm_kwd']
                chunk_i.pop('docnm_kwd')

    try:
        if conv.source == "agent":
            stream = req.get("stream", True)
            conv.message.append(msg[-1])
            e, cvs = UserCanvasService.get_by_id(conv.dialog_id)
            if not e:
                return server_error_response("canvas not found.")
            del req["conversation_id"]
            del req["messages"]

            if not isinstance(cvs.dsl, str):
                cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)

            if not conv.reference:
                conv.reference = []
            conv.message.append({"role": "assistant", "content": "", "id": message_id})
            conv.reference.append({"chunks": [], "doc_aggs": []})

            final_ans = {"reference": [], "content": ""}
            canvas = Canvas(cvs.dsl, objs[0].tenant_id)

            canvas.messages.append(msg[-1])
            canvas.add_user_input(msg[-1]["content"])
            answer = canvas.run(stream=stream)

            assert answer is not None, "Nothing. Is it over?"

            if stream:
                assert isinstance(answer, partial), "Nothing. Is it over?"

                def sse():
                    nonlocal answer, cvs, conv
                    try:
                        for ans in answer():
                            for k in ans.keys():
                                final_ans[k] = ans[k]
                            ans = {"answer": ans["content"], "reference": ans.get("reference", [])}
                            fillin_conv(ans)
                            rename_field(ans)
                            yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans},
                                                       ensure_ascii=False) + "\n\n"

                        canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
                        if final_ans.get("reference"):
                            canvas.reference.append(final_ans["reference"])
                        cvs.dsl = json.loads(str(canvas))
                        API4ConversationService.append_message(conv.id, conv.to_dict())
                    except Exception as e:
                        yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
                                                    "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
                                                   ensure_ascii=False) + "\n\n"
                    yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n"

                resp = Response(sse(), mimetype="text/event-stream")
                resp.headers.add_header("Cache-control", "no-cache")
                resp.headers.add_header("Connection", "keep-alive")
                resp.headers.add_header("X-Accel-Buffering", "no")
                resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
                return resp

            final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else ""
            canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
            if final_ans.get("reference"):
                canvas.reference.append(final_ans["reference"])
            cvs.dsl = json.loads(str(canvas))

            result = {"answer": final_ans["content"], "reference": final_ans.get("reference", [])}
            fillin_conv(result)
            API4ConversationService.append_message(conv.id, conv.to_dict())
            rename_field(result)
            return get_json_result(data=result)
        
        #******************For dialog******************
        conv.message.append(msg[-1])
        e, dia = DialogService.get_by_id(conv.dialog_id)
        if not e:
            return get_data_error_result(retmsg="Dialog not found!")
        del req["conversation_id"]
        del req["messages"]

        if not conv.reference:
            conv.reference = []
        conv.message.append({"role": "assistant", "content": "", "id": message_id})
        conv.reference.append({"chunks": [], "doc_aggs": []})

        def stream():
            nonlocal dia, msg, req, conv
            try:
                for ans in chat(dia, msg, True, **req):
                    fillin_conv(ans)
                    rename_field(ans)
                    yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans},
                                               ensure_ascii=False) + "\n\n"
                API4ConversationService.append_message(conv.id, conv.to_dict())
            except Exception as e:
                yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
                                            "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
                                           ensure_ascii=False) + "\n\n"
            yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n"

        if req.get("stream", True):
            resp = Response(stream(), mimetype="text/event-stream")
            resp.headers.add_header("Cache-control", "no-cache")
            resp.headers.add_header("Connection", "keep-alive")
            resp.headers.add_header("X-Accel-Buffering", "no")
            resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
            return resp
            
        answer = None
        for ans in chat(dia, msg, **req):
            answer = ans
            fillin_conv(ans)
            API4ConversationService.append_message(conv.id, conv.to_dict())
            break
        rename_field(answer)
        return get_json_result(data=answer)

    except Exception as e:
        return server_error_response(e)
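
A hedged client-side sketch of consuming the event-stream that /completion returns. The base URL, port and IDs below are placeholders; the path follows from the /v1/api prefix that register_page assigns to api_app.py.

import json
import requests

resp = requests.post(
    "http://127.0.0.1:9380/v1/api/completion",          # placeholder host/port
    headers={"Authorization": "Bearer <api-token>"},     # token from /new_token
    json={"conversation_id": "<conversation-id>",
          "messages": [{"role": "user", "content": "hi"}],
          "stream": True},
    stream=True,
)
for line in resp.iter_lines(decode_unicode=True):
    if line and line.startswith("data:"):
        print(json.loads(line[len("data:"):]))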

sdk

web

deepdoc

graphrag

Reference URLs

https://ragflow.io/docs/dev/faq
https://infiniflow.cn/docs/v0.7.0/
https://infiniflow.cn/docs/dev/
https://github.com/infiniflow/ragflow