File size: 5,459 Bytes
359fa44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
from typing_extensions import override
import torch

from comfy_api.latest import ComfyExtension, io


class RenormCFG(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="RenormCFG",
            category="advanced/model",
            inputs=[
                io.Model.Input("model"),
                io.Float.Input("cfg_trunc", default=100, min=0.0, max=100.0, step=0.01),
                io.Float.Input("renorm_cfg", default=1.0, min=0.0, max=100.0, step=0.01),
            ],
            outputs=[
                io.Model.Output(),
            ],
        )

    @classmethod
    def execute(cls, model, cfg_trunc, renorm_cfg) -> io.NodeOutput:
        def renorm_cfg_func(args):
            cond_denoised = args["cond_denoised"]
            uncond_denoised = args["uncond_denoised"]
            cond_scale = args["cond_scale"]
            timestep = args["timestep"]
            x_orig = args["input"]
            in_channels = model.model.diffusion_model.in_channels

            if timestep[0] < cfg_trunc:
                cond_eps, uncond_eps = cond_denoised[:, :in_channels], uncond_denoised[:, :in_channels]
                cond_rest, _ = cond_denoised[:, in_channels:], uncond_denoised[:, in_channels:]
                half_eps = uncond_eps + cond_scale * (cond_eps - uncond_eps)
                half_rest = cond_rest

                if float(renorm_cfg) > 0.0:
                    ori_pos_norm = torch.linalg.vector_norm(cond_eps
                            , dim=tuple(range(1, len(cond_eps.shape))), keepdim=True
                    )
                    max_new_norm = ori_pos_norm * float(renorm_cfg)
                    new_pos_norm = torch.linalg.vector_norm(
                            half_eps, dim=tuple(range(1, len(half_eps.shape))), keepdim=True
                        )
                    if new_pos_norm >= max_new_norm:
                        half_eps = half_eps * (max_new_norm / new_pos_norm)
            else:
                cond_eps, uncond_eps = cond_denoised[:, :in_channels], uncond_denoised[:, :in_channels]
                cond_rest, _ = cond_denoised[:, in_channels:], uncond_denoised[:, in_channels:]
                half_eps = cond_eps
                half_rest = cond_rest

            cfg_result = torch.cat([half_eps, half_rest], dim=1)

            # cfg_result = uncond_denoised + (cond_denoised - uncond_denoised) * cond_scale

            return x_orig - cfg_result

        m = model.clone()
        m.set_model_sampler_cfg_function(renorm_cfg_func)
        return io.NodeOutput(m)


class CLIPTextEncodeLumina2(io.ComfyNode):
    SYSTEM_PROMPT = {
        "superior": "You are an assistant designed to generate superior images with the superior "\
            "degree of image-text alignment based on textual prompts or user prompts.",
        "alignment": "You are an assistant designed to generate high-quality images with the "\
            "highest degree of image-text alignment based on textual prompts."
    }
    SYSTEM_PROMPT_TIP = "Lumina2 provide two types of system prompts:" \
        "Superior: You are an assistant designed to generate superior images with the superior "\
        "degree of image-text alignment based on textual prompts or user prompts. "\
        "Alignment: You are an assistant designed to generate high-quality images with the highest "\
        "degree of image-text alignment based on textual prompts."
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="CLIPTextEncodeLumina2",
            display_name="CLIP Text Encode for Lumina2",
            category="conditioning",
            description="Encodes a system prompt and a user prompt using a CLIP model into an embedding "
                        "that can be used to guide the diffusion model towards generating specific images.",
            inputs=[
                io.Combo.Input(
                    "system_prompt",
                    options=list(cls.SYSTEM_PROMPT.keys()),
                    tooltip=cls.SYSTEM_PROMPT_TIP,
                ),
                io.String.Input(
                    "user_prompt",
                    multiline=True,
                    dynamic_prompts=True,
                    tooltip="The text to be encoded.",
                ),
                io.Clip.Input("clip", tooltip="The CLIP model used for encoding the text."),
            ],
            outputs=[
                io.Conditioning.Output(
                    tooltip="A conditioning containing the embedded text used to guide the diffusion model.",
                ),
            ],
        )

    @classmethod
    def execute(cls, clip, user_prompt, system_prompt) -> io.NodeOutput:
        if clip is None:
            raise RuntimeError("ERROR: clip input is invalid: None\n\nIf the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model.")
        system_prompt = cls.SYSTEM_PROMPT[system_prompt]
        prompt = f'{system_prompt} <Prompt Start> {user_prompt}'
        tokens = clip.tokenize(prompt)
        return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))


class Lumina2Extension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            CLIPTextEncodeLumina2,
            RenormCFG,
        ]


async def comfy_entrypoint() -> Lumina2Extension:
    return Lumina2Extension()