教程:自定义Stable Diffusion扩展(以ControlNet为例)

Implementation Pipeline of Stable Diffusion with ControlNet

@zilla0717

本文梳理了用ControlNet控制Stable Diffusion输出的实现思路。

分析对象

StableDiffusion WebUI
ControlNet Extension for StableDiffusion WebUI
ControlNet作为StableDiffusion WebUI的扩展,遵照其扩展开发规则。

参考资料

【StableDiffusion WebUI源码分析 — 知乎】
1. Gradio的基本用法
2. txt2img的实现
3. 模型加载的过程
4. 启动流程
5. 多语言的实现方式
6. 脚本的实现方式
7. 扩展的实现方式
8. Lora功能的实现方式
StableDiffusion WebUI的Wiki
gradio UI component

1. 实现扩展的一般流程

插件目录下,各文件、子目录作用如下:

  1. install.py:若有则自动执行,用于完成依赖库的安装。
  2. 子目录scripts放py脚本,插件目录会被追加到sys.path建议脚本中用scripts.basedir()来获取当前插件目录,因为用户可能重命名插件。
  3. style.css和子目录javascript中的js文件会被加载到页面上。
  4. preload.py:若有,则在程序解析命令之前加载。在该文件里的preload函数中追加与该扩展有关的命令行参数。如:
def preload(parser):
    parser.add_argument("--wildcards-dir", type=str, default=None)

下面说明如何编写一个py脚本,以“旋转生成的图片”这一脚本为例(分析见注释)。

  1. import必要的包和函数(这部分不需要改动)
import modules.scripts as scripts
import gradio as gr
import os

from modules import images
from modules.processing import process_images, Processed
from modules.processing import Processed
from modules.shared import opts, cmd_opts, state
  1. 定义Script类,后续的title()show()ui()run()都是该类的函数
class Script(scripts.Script)
  1. title():定义脚本名称(显示在该插件的下拉菜单里)
    def title(self):
        return "Rotate Output"
  1. show():其返回值控制该选项何时出现在下拉菜单
    def show(self, is_img2img):
        # 只有在img2img 界面才在下拉菜单显示该功能
        return is_img2img
  1. ui():定义这个脚本在UI上怎么展示,其返回值被用作参数
    多数UI组件返回的是boolean。
    def ui(self, is_img2img):
        angle = gr.Slider(minimum=0.0, maximum=360.0, step=1, value=0,
        label="Angle")
        overwrite = gr.Checkbox(False, label="Overwrite existing files")
        return [angle, overwrite]
  1. run():获取UI传回的参数,做额外的计算过程
    该函数在这个脚本在下拉菜单中被选中时被调用,它必须进行所有处理并返回带有结果的Processed对象(与processing.process_images()返回的结果相同)。
    通常处理过程是调用process_images()完成的。
    • 入参
      1. p(类型为StableDiffusionProcessing的对象实例)
        StableDiffusionProcessing定义参见module/processing.py,定义了它以及子类StableDiffusionProcessingTxt2ImgStableDiffusionProcessingImg2Img
      2. ui()返回的参数
    • run()内部可以自定义函数和引入额外的包。
    • 对图片执行运算的函数以process_images()返回的Processed对象procui()获取的参数 为入参,原始图片在proc.images,返回处理后的proc
    def run(self, p, angle, overwrite):

        def rotate(im, angle):
            from PIL import Image
            raf = im
            if angle != 0:
                raf = raf.rotate(angle, expand=True)
            return raf

        basename = ""
        if(not overwrite):
            if angle != 0:
                basename += "rotated_" + str(angle)
        else:
            p.do_not_save_samples = True

        proc = process_images(p)
        for i in range(len(proc.images)):
            proc.images[i] = rotate(proc.images[i], angle)
            images.save_image(proc.images[i], p.outpath_samples, basename, proc.seed + i, proc.prompt, opts.samples_format, info= proc.info, p=p)
        return proc
  1. process():获取UI传回的参数,做额外的计算过程
    该函数类似run(),区别是它在开始执行总是可见的脚本前被调用,即在图像处理前被调用

before_process_batch()process_batch()postprocess_batch()等函数的作用见modules/scripts.py

2. ControlNet扩展的UI实现和回调方法

controlnet.py的写法类似上面的例子,其ui()实现如下:

    def ui(self, is_img2img):
        self.infotext_fields = []
        self.paste_field_names = []
        controls = ()
        max_models = shared.opts.data.get("control_net_max_models_num", 1)
        elem_id_tabname = ("img2img" if is_img2img else "txt2img") + "_controlnet"
        with gr.Group(elem_id=elem_id_tabname):
            with gr.Accordion(f"ControlNet {controlnet_version.version_flag}", open = False, elem_id="controlnet"):
                if max_models > 1:
                    with gr.Tabs(elem_id=f"{elem_id_tabname}_tabs"):
                        for i in range(max_models):
                            with gr.Tab(f"ControlNet Unit {i}", 
                                        elem_classes=['cnet-unit-tab']):
                                controls += (self.uigroup(f"ControlNet-{i}", is_img2img, elem_id_tabname),)
                else:
                    with gr.Column():
                        controls += (self.uigroup(f"ControlNet", is_img2img, elem_id_tabname),)

        if shared.opts.data.get("control_net_sync_field_args", False):
            for _, field_name in self.infotext_fields:
                self.paste_field_names.append(field_name)

        return controls

api.py中,可以看到 在web app启动(on_app_started)时就会调用controlnet_api()方法。

try:
    import modules.script_callbacks as script_callbacks

    script_callbacks.on_app_started(controlnet_api)
except:
    pass

controlnet_api()中定义了一些异步的方法(其中获取插件模型列表、版本、设置等信息的方法由GET请求调用,detect()由POST请求调用),实现如下:

def controlnet_api(_: gr.Blocks, app: FastAPI):
    @app.get("/controlnet/version")
    async def version():
        return {"version": external_code.get_api_version()}

    @app.get("/controlnet/model_list")
    async def model_list():
        up_to_date_model_list = external_code.get_models(update=True)
        logger.debug(up_to_date_model_list)
        return {"model_list": up_to_date_model_list}

    @app.get("/controlnet/module_list")
    async def module_list(alias_names: bool = False):
        _module_list = external_code.get_modules(alias_names)
        logger.debug(_module_list)
        
        return {
            "module_list": _module_list,
            "module_detail": external_code.get_modules_detail(alias_names)
        }
    
    @app.get("/controlnet/settings")
    async def settings():
        max_models_num = external_code.get_max_models_num()
        return {"control_net_max_models_num":max_models_num}

    cached_cn_preprocessors = global_state.cache_preprocessors(global_state.cn_preprocessor_modules)
    @app.post("/controlnet/detect")
    async def detect(
        controlnet_module: str = Body("none", title='Controlnet Module'),
        controlnet_input_images: List[str] = Body([], title='Controlnet Input Images'),
        controlnet_processor_res: int = Body(512, title='Controlnet Processor Resolution'),
        controlnet_threshold_a: float = Body(64, title='Controlnet Threshold a'),
        controlnet_threshold_b: float = Body(64, title='Controlnet Threshold b')
    ):
        controlnet_module = global_state.reverse_preprocessor_aliases.get(controlnet_module, controlnet_module)

        if controlnet_module not in cached_cn_preprocessors:
            raise HTTPException(
                status_code=422, detail="Module not available")

        if len(controlnet_input_images) == 0:
            raise HTTPException(
                status_code=422, detail="No image selected")

        logger.info(f"Detecting {str(len(controlnet_input_images))} images with the {controlnet_module} module.")

        results = []

        processor_module = cached_cn_preprocessors[controlnet_module]

        for input_image in controlnet_input_images:
            img = external_code.to_base64_nparray(input_image)
            results.append(processor_module(img, res=controlnet_processor_res, thr_a=controlnet_threshold_a, thr_b=controlnet_threshold_b)[0])

        global_state.cn_preprocessor_unloadable.get(controlnet_module, lambda: None)()
        results64 = list(map(encode_to_base64, results))
        return {"images": results64, "info": "Success"}

3. ControlNet扩展的功能实现

原始的Stable Diffusion 由三个模型构成:text encoder模型(CLIPTextModel)、UNet模型和VAE 模型。ControlNet是在UNet网络上新增的旁路,用于增加额外的条件控制Stable Diffusion的输出。


controlnet.pyScript类的process()中,实现了网络结构的注入。process()在图像处理前被调用,此处unet为原先网络的结构,UnetHook为新定义的结构,通过UnetHook.hook()改变原始的UNet。

        sd_ldm = p.sd_model
        unet = sd_ldm.model.diffusion_model
        ......
        self.latest_network = UnetHook(lowvram=hook_lowvram)
        self.latest_network.hook(model=unet, sd_ldm=sd_ldm, control_params=forward_params, process=p)
        self.detected_map = detected_maps
        self.post_processors = post_processors

UnetHook.hook方法,model即是原先的网络,hook方法先将原先的模型的forward方法保存起来(model._original_forward = model.forward),然后给它重新赋值,赋值为自行实现的forward2。

  1. 文本生成图片
    text2img流程
    text_embedding = text_encoder(prompt)
    for i in steps:
    predict_noise = unet(text_embedding, timestamp,latent)
    latent_new = DDPM(latent, timestamp) # 求解器
    img = vae_decoder(latent)
  1. img2img的流程
    原始的img2img
    如图片卡通风格转换
    img_info = vae_encoder(img)
    latent_init = handle(img_info)
    其他类似text2img

unet 我们可以拆开为 uencoder和udecoder。
controlnet_information = contorlnet(controlnet_img, timestamp, latent,text_embedding )
encoder_info = uencoder(timestamp, latent,text_embedding)
信息融合:
decoder_input = controlnet_information * rate + encoder_info
predict_noise = decoder(decoder_input, timestamp, latent,text_embedding )
其他流程与text2img相同

img2paint(with mask)

要梳理什么:

  1. controlnet的pipeline具体实现,参考:onnxweb(一个repo)的diffusion 和 diffusers 的 controlnet
    需要考虑的是?
  2. controlnet的根据参数功能和实现(我有一版本,晚点发)
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 206,378评论 6 481
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 88,356评论 2 382
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 152,702评论 0 342
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 55,259评论 1 279
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 64,263评论 5 371
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 49,036评论 1 285
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,349评论 3 400
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,979评论 0 259
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 43,469评论 1 300
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,938评论 2 323
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,059评论 1 333
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,703评论 4 323
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,257评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,262评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,485评论 1 262
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 45,501评论 2 354
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,792评论 2 345

推荐阅读更多精彩内容