File size: 1,899 Bytes
359fa44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from typing_extensions import override
import torch
import torch.nn.functional as F

from comfy_api.latest import ComfyExtension, io


class Mahiro(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="Mahiro",
            display_name="Mahiro is so cute that she deserves a better guidance function!! (。・ω・。)",
            category="_for_testing",
            description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.",
            inputs=[
                io.Model.Input("model"),
            ],
            outputs=[
                io.Model.Output(display_name="patched_model"),
            ],
            is_experimental=True,
        )

    @classmethod
    def execute(cls, model) -> io.NodeOutput:
        m = model.clone()
        def mahiro_normd(args):
            scale: float = args['cond_scale']
            cond_p: torch.Tensor = args['cond_denoised']
            uncond_p: torch.Tensor = args['uncond_denoised']
            #naive leap
            leap = cond_p * scale
            #sim with uncond leap
            u_leap = uncond_p * scale
            cfg = args["denoised"]
            merge = (leap + cfg) / 2
            normu = torch.sqrt(u_leap.abs()) * u_leap.sign()
            normm = torch.sqrt(merge.abs()) * merge.sign()
            sim = F.cosine_similarity(normu, normm).mean()
            simsc = 2 * (sim+1)
            wm = (simsc*cfg + (4-simsc)*leap) / 4
            return wm
        m.set_model_sampler_post_cfg_function(mahiro_normd)
        return io.NodeOutput(m)


class MahiroExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            Mahiro,
        ]


async def comfy_entrypoint() -> MahiroExtension:
    return MahiroExtension()