美文网首页
教程:自定义Stable Diffusion扩展(以Contro

教程:自定义Stable Diffusion扩展(以Contro

作者: zilla | 来源:发表于2023-06-18 18:22 被阅读0次

    Implementation Pipeline of Stable Diffusion with ControlNet

    @zilla0717

    本文梳理了用ControlNet控制Stable Diffusion输出的实现思路。

    分析对象

    StableDiffusion WebUI
    ControlNet Extension for StableDiffusion WebUI
    ControlNet作为StableDiffusion WebUI的扩展,遵照其扩展开发规则。

    参考资料

    【StableDiffusion WebUI源码分析 — 知乎】
    1. Gradio的基本用法
    2. txt2img的实现
    3. 模型加载的过程
    4. 启动流程
    5. 多语言的实现方式
    6. 脚本的实现方式
    7. 扩展的实现方式
    8. Lora功能的实现方式
    StableDiffusion WebUI的Wiki
    gradio UI component

    1. 实现扩展的一般流程

    插件目录下,各文件、子目录作用如下:

    1. install.py:若有则自动执行,用于完成依赖库的安装。
    2. 子目录scripts放py脚本,插件目录会被追加到sys.path建议脚本中用scripts.basedir()来获取当前插件目录,因为用户可能重命名插件。
    3. style.css和子目录javascript中的js文件会被加载到页面上。
    4. preload.py:若有,则在程序解析命令之前加载。在该文件里的preload函数中追加与该扩展有关的命令行参数。如:
    def preload(parser):
        parser.add_argument("--wildcards-dir", type=str, default=None)
    

    下面说明如何编写一个py脚本,以“旋转生成的图片”这一脚本为例(分析见注释)。

    1. import必要的包和函数(这部分不需要改动)
    import modules.scripts as scripts
    import gradio as gr
    import os
    
    from modules import images
    from modules.processing import process_images, Processed
    from modules.processing import Processed
    from modules.shared import opts, cmd_opts, state
    
    1. 定义Script类,后续的title()show()ui()run()都是该类的函数
    class Script(scripts.Script)
    
    1. title():定义脚本名称(显示在该插件的下拉菜单里)
        def title(self):
            return "Rotate Output"
    
    1. show():其返回值控制该选项何时出现在下拉菜单
        def show(self, is_img2img):
            # 只有在img2img 界面才在下拉菜单显示该功能
            return is_img2img
    
    1. ui():定义这个脚本在UI上怎么展示,其返回值被用作参数
      多数UI组件返回的是boolean。
        def ui(self, is_img2img):
            angle = gr.Slider(minimum=0.0, maximum=360.0, step=1, value=0,
            label="Angle")
            overwrite = gr.Checkbox(False, label="Overwrite existing files")
            return [angle, overwrite]
    
    1. run():获取UI传回的参数,做额外的计算过程
      该函数在这个脚本在下拉菜单中被选中时被调用,它必须进行所有处理并返回带有结果的Processed对象(与processing.process_images()返回的结果相同)。
      通常处理过程是调用process_images()完成的。
      • 入参
        1. p(类型为StableDiffusionProcessing的对象实例)
          StableDiffusionProcessing定义参见module/processing.py,定义了它以及子类StableDiffusionProcessingTxt2ImgStableDiffusionProcessingImg2Img
        2. ui()返回的参数
      • run()内部可以自定义函数和引入额外的包。
      • 对图片执行运算的函数以process_images()返回的Processed对象procui()获取的参数 为入参,原始图片在proc.images,返回处理后的proc
        def run(self, p, angle, overwrite):
    
            def rotate(im, angle):
                from PIL import Image
                raf = im
                if angle != 0:
                    raf = raf.rotate(angle, expand=True)
                return raf
    
            basename = ""
            if(not overwrite):
                if angle != 0:
                    basename += "rotated_" + str(angle)
            else:
                p.do_not_save_samples = True
    
            proc = process_images(p)
            for i in range(len(proc.images)):
                proc.images[i] = rotate(proc.images[i], angle)
                images.save_image(proc.images[i], p.outpath_samples, basename, proc.seed + i, proc.prompt, opts.samples_format, info= proc.info, p=p)
            return proc
    
    1. process():获取UI传回的参数,做额外的计算过程
      该函数类似run(),区别是它在开始执行总是可见的脚本前被调用,即在图像处理前被调用

    before_process_batch()process_batch()postprocess_batch()等函数的作用见modules/scripts.py

    2. ControlNet扩展的UI实现和回调方法

    controlnet.py的写法类似上面的例子,其ui()实现如下:

        def ui(self, is_img2img):
            self.infotext_fields = []
            self.paste_field_names = []
            controls = ()
            max_models = shared.opts.data.get("control_net_max_models_num", 1)
            elem_id_tabname = ("img2img" if is_img2img else "txt2img") + "_controlnet"
            with gr.Group(elem_id=elem_id_tabname):
                with gr.Accordion(f"ControlNet {controlnet_version.version_flag}", open = False, elem_id="controlnet"):
                    if max_models > 1:
                        with gr.Tabs(elem_id=f"{elem_id_tabname}_tabs"):
                            for i in range(max_models):
                                with gr.Tab(f"ControlNet Unit {i}", 
                                            elem_classes=['cnet-unit-tab']):
                                    controls += (self.uigroup(f"ControlNet-{i}", is_img2img, elem_id_tabname),)
                    else:
                        with gr.Column():
                            controls += (self.uigroup(f"ControlNet", is_img2img, elem_id_tabname),)
    
            if shared.opts.data.get("control_net_sync_field_args", False):
                for _, field_name in self.infotext_fields:
                    self.paste_field_names.append(field_name)
    
            return controls
    

    api.py中,可以看到 在web app启动(on_app_started)时就会调用controlnet_api()方法。

    try:
        import modules.script_callbacks as script_callbacks
    
        script_callbacks.on_app_started(controlnet_api)
    except:
        pass
    

    controlnet_api()中定义了一些异步的方法(其中获取插件模型列表、版本、设置等信息的方法由GET请求调用,detect()由POST请求调用),实现如下:

    def controlnet_api(_: gr.Blocks, app: FastAPI):
        @app.get("/controlnet/version")
        async def version():
            return {"version": external_code.get_api_version()}
    
        @app.get("/controlnet/model_list")
        async def model_list():
            up_to_date_model_list = external_code.get_models(update=True)
            logger.debug(up_to_date_model_list)
            return {"model_list": up_to_date_model_list}
    
        @app.get("/controlnet/module_list")
        async def module_list(alias_names: bool = False):
            _module_list = external_code.get_modules(alias_names)
            logger.debug(_module_list)
            
            return {
                "module_list": _module_list,
                "module_detail": external_code.get_modules_detail(alias_names)
            }
        
        @app.get("/controlnet/settings")
        async def settings():
            max_models_num = external_code.get_max_models_num()
            return {"control_net_max_models_num":max_models_num}
    
        cached_cn_preprocessors = global_state.cache_preprocessors(global_state.cn_preprocessor_modules)
        @app.post("/controlnet/detect")
        async def detect(
            controlnet_module: str = Body("none", title='Controlnet Module'),
            controlnet_input_images: List[str] = Body([], title='Controlnet Input Images'),
            controlnet_processor_res: int = Body(512, title='Controlnet Processor Resolution'),
            controlnet_threshold_a: float = Body(64, title='Controlnet Threshold a'),
            controlnet_threshold_b: float = Body(64, title='Controlnet Threshold b')
        ):
            controlnet_module = global_state.reverse_preprocessor_aliases.get(controlnet_module, controlnet_module)
    
            if controlnet_module not in cached_cn_preprocessors:
                raise HTTPException(
                    status_code=422, detail="Module not available")
    
            if len(controlnet_input_images) == 0:
                raise HTTPException(
                    status_code=422, detail="No image selected")
    
            logger.info(f"Detecting {str(len(controlnet_input_images))} images with the {controlnet_module} module.")
    
            results = []
    
            processor_module = cached_cn_preprocessors[controlnet_module]
    
            for input_image in controlnet_input_images:
                img = external_code.to_base64_nparray(input_image)
                results.append(processor_module(img, res=controlnet_processor_res, thr_a=controlnet_threshold_a, thr_b=controlnet_threshold_b)[0])
    
            global_state.cn_preprocessor_unloadable.get(controlnet_module, lambda: None)()
            results64 = list(map(encode_to_base64, results))
            return {"images": results64, "info": "Success"}
    

    3. ControlNet扩展的功能实现

    原始的Stable Diffusion 由三个模型构成:text encoder模型(CLIPTextModel)、UNet模型和VAE 模型。ControlNet是在UNet网络上新增的旁路,用于增加额外的条件控制Stable Diffusion的输出。


    controlnet.pyScript类的process()中,实现了网络结构的注入。process()在图像处理前被调用,此处unet为原先网络的结构,UnetHook为新定义的结构,通过UnetHook.hook()改变原始的UNet。
            sd_ldm = p.sd_model
            unet = sd_ldm.model.diffusion_model
            ......
            self.latest_network = UnetHook(lowvram=hook_lowvram)
            self.latest_network.hook(model=unet, sd_ldm=sd_ldm, control_params=forward_params, process=p)
            self.detected_map = detected_maps
            self.post_processors = post_processors
    

    UnetHook.hook方法,model即是原先的网络,hook方法先将原先的模型的forward方法保存起来(model._original_forward = model.forward),然后给它重新赋值,赋值为自行实现的forward2。

    1. 文本生成图片
      text2img流程
      text_embedding = text_encoder(prompt)
      for i in steps:
      predict_noise = unet(text_embedding, timestamp,latent)
      latent_new = DDPM(latent, timestamp) # 求解器
      img = vae_decoder(latent)
    1. img2img的流程
      原始的img2img
      如图片卡通风格转换
      img_info = vae_encoder(img)
      latent_init = handle(img_info)
      其他类似text2img

    unet 我们可以拆开为 uencoder和udecoder。
    controlnet_information = contorlnet(controlnet_img, timestamp, latent,text_embedding )
    encoder_info = uencoder(timestamp, latent,text_embedding)
    信息融合:
    decoder_input = controlnet_information * rate + encoder_info
    predict_noise = decoder(decoder_input, timestamp, latent,text_embedding )
    其他流程与text2img相同

    img2paint(with mask)

    要梳理什么:

    1. controlnet的pipeline具体实现,参考:onnxweb(一个repo)的diffusion 和 diffusers 的 controlnet
      需要考虑的是?
    2. controlnet的根据参数功能和实现(我有一版本,晚点发)

    相关文章

      网友评论

          本文标题:教程:自定义Stable Diffusion扩展(以Contro

          本文链接:https://www.haomeiwen.com/subject/lsnnedtx.html