import torch
from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor,
    CLIPTextModel,
    CLIPTokenizer,
)

from .utils import freeze


class ClipTextEmbedder:
    """Frozen CLIP text encoder that returns the pooled sentence embedding."""

    def __init__(self, conf, device):
        self.model = CLIPTextModel.from_pretrained(conf.checkpoint_path).to(device)
        self.model = freeze(self.model)
        self.tokenizer = CLIPTokenizer.from_pretrained(conf.checkpoint_path)
        self.max_length = conf.max_length

    def __call__(self, texts):
        inputs = self.tokenizer(
            texts,
            max_length=self.max_length,
            truncation=True,
            add_special_tokens=True,
            padding="max_length",
            return_tensors="pt",
        ).to(self.model.device)

        with torch.no_grad():
            pooled_embed = self.model(**inputs)["pooler_output"]
        return pooled_embed


class Qwen2_5_VLTextEmbedder:
    """Frozen Qwen2.5-VL encoder that wraps each caption in a chat-style prompt
    template and returns packed last-hidden-state token embeddings together with
    cumulative sequence lengths (cu_seqlens)."""

    PROMPT_TEMPLATE = {
        "template": {
            "video": (
                "<|im_start|>system\nYou are a promt engineer. Describe the video in detail.",
                "Describe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.",
                "Describe the location of the video, main characters or objects and their action.",
                "Describe the dynamism of the video and presented actions.",
                "Name the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or scren content.",
                "Describe the visual effects, postprocessing and transitions if they are presented in the video.",
                "Pay attention to the order of key actions shown in the scene.<|im_end|>",
                "<|im_start|>user\n{}<|im_end|>",
            ),
            "image": (
                "<|im_start|>system\nYou are a promt engineer. Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>",
                "<|im_start|>user\n{}<|im_end|>",
            ),
        },
        "crop_start": {"video": 129, "image": 41},
    }

    def __init__(self, conf, device):
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            conf.checkpoint_path,
            dtype=torch.bfloat16,
            device_map=device,
        )
        self.model = freeze(self.model)
        self.model = torch.compile(self.model, dynamic=True)
        self.processor = AutoProcessor.from_pretrained(conf.checkpoint_path, use_fast=True)
        self.max_length = conf.max_length

    def __call__(self, texts, type_of_content="video"):
        prompt_template = "\n".join(self.PROMPT_TEMPLATE["template"][type_of_content])
        crop_start = self.PROMPT_TEMPLATE["crop_start"][type_of_content]
        full_texts = list(map(lambda x: prompt_template.format(x), texts))

        max_length = self.max_length + crop_start
        inputs = self.processor(
            text=full_texts,
            images=None,
            videos=None,
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
            padding=True,
        ).to(self.model.device)

        with torch.no_grad():
            # Last hidden state of the frozen LM, with the leading prompt-template
            # tokens stripped (crop_start).
            embeds = self.model(
                input_ids=inputs["input_ids"],
                return_dict=True,
                output_hidden_states=True,
            )["hidden_states"][-1][:, crop_start:]
        # Drop the template positions from the attention mask as well, then pack the
        # valid token embeddings of all samples into one flat tensor.
        attention_mask = inputs["attention_mask"][:, crop_start:]
        embeds = embeds[attention_mask.bool()]
        # Cumulative sequence lengths [0, len_0, len_0 + len_1, ...] marking each
        # sample's boundaries inside the packed tensor (varlen-attention convention).
        cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0)
        cu_seqlens = torch.cat([torch.zeros_like(cu_seqlens)[:1], cu_seqlens]).to(
            dtype=torch.int32
        )
        return embeds, cu_seqlens


class Kandinsky5TextEmbedder:
    """Combines packed Qwen2.5-VL token embeddings with a pooled CLIP text embedding."""

    def __init__(self, conf, device="cpu"):
        self.embedder = Qwen2_5_VLTextEmbedder(conf.qwen, device)
        self.clip_embedder = ClipTextEmbedder(conf.clip, device)
        self.conf = conf

    def encode(self, texts, type_of_content="image"):
        text_embeds, cu_seqlens = self.embedder(texts, type_of_content=type_of_content)
        pooled_embed = self.clip_embedder(texts)
        return {"text_embeds": text_embeds, "pooled_embed": pooled_embed}, cu_seqlens

    def to(self, device):
        self.embedder.model = self.embedder.model.to(device)
        self.clip_embedder.model = self.clip_embedder.model.to(device)
        return self


def get_text_embedder(conf, device="cpu"):
    return Kandinsky5TextEmbedder(conf, device)
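

# Usage sketch (illustrative, not part of the library API): builds a minimal config
# exposing the attributes the embedders above read (qwen/clip checkpoint_path and
# max_length). The checkpoint names and length budgets below are assumptions for
# illustration only; use the values from the project's own configs. Because of the
# relative `.utils` import, run this module as part of its package
# (`python -m <package>.<this_module>`).
if __name__ == "__main__":
    from types import SimpleNamespace

    conf = SimpleNamespace(
        qwen=SimpleNamespace(
            checkpoint_path="Qwen/Qwen2.5-VL-7B-Instruct",  # assumed checkpoint
            max_length=256,  # assumed caption token budget
        ),
        clip=SimpleNamespace(
            checkpoint_path="openai/clip-vit-large-patch14",  # assumed checkpoint
            max_length=77,  # CLIP's usual context length
        ),
    )

    embedder = get_text_embedder(conf, device="cuda")
    embeds, cu_seqlens = embedder.encode(
        ["A red fox trotting through fresh snow at sunrise"],
        type_of_content="video",
    )
    # text_embeds: packed Qwen token embeddings; pooled_embed: CLIP pooled vector.
    print(embeds["text_embeds"].shape, embeds["pooled_embed"].shape, cu_seqlens)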