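"""Text embedders for Kandinsky 5.

Combines a frozen Qwen2.5-VL language model (per-token text embeddings) with a
frozen CLIP text encoder (one pooled embedding per prompt).
"""
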
import torch
from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor,
    CLIPTextModel,
    CLIPTokenizer,
)

from .utils import freeze


class ClipTextEmbedder:
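    """Frozen CLIP text encoder that returns one pooled embedding per prompt."""
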
    def __init__(self, conf, device):
        self.model = CLIPTextModel.from_pretrained(conf.checkpoint_path).to(device)
        self.model = freeze(self.model)
        self.tokenizer = CLIPTokenizer.from_pretrained(conf.checkpoint_path)
        self.max_length = conf.max_length

    def __call__(self, texts):
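        """Tokenize `texts` and return CLIP's pooled output, shape (batch, hidden)."""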
        inputs = self.tokenizer(
            texts,
            max_length=self.max_length,
            truncation=True,
            add_special_tokens=True,
            padding="max_length",
            return_tensors="pt",
        ).to(self.model.device)
        with torch.no_grad():
            pooled_embed = self.model(**inputs)["pooler_output"]
        return pooled_embed


class Qwen2_5_VLTextEmbedder:
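    """Frozen Qwen2.5-VL language model used as a per-token text encoder.

    Each prompt is wrapped in a task-specific chat template before encoding,
    and the hidden states of the template prefix are cropped away afterwards.
    The template strings are kept byte-for-byte as shipped (including the
    "promt" and "scren" typos): rewording them would change the tokenized
    prefix length and desynchronize the hard-coded `crop_start` offsets.
    """
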
    PROMPT_TEMPLATE = {
        "template": {
            "video": (
                "<|im_start|>system\nYou are a promt engineer. Describe the video in detail.",
                "Describe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.",
                "Describe the location of the video, main characters or objects and their action.",
                "Describe the dynamism of the video and presented actions.",
                "Name the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or scren content.",
                "Describe the visual effects, postprocessing and transitions if they are presented in the video.",
                "Pay attention to the order of key actions shown in the scene.<|im_end|>",
                "<|im_start|>user\n{}<|im_end|>",
            ),
            "image": (
                "<|im_start|>system\nYou are a promt engineer. Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>",
                "<|im_start|>user\n{}<|im_end|>",
            ),
        },
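        # Tokenized length of each rendered template prefix; hidden states
        # before this offset are discarded in __call__.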
"crop_start": {"video": 129, "image": 41},
}

    def __init__(self, conf, device):
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            conf.checkpoint_path,
            dtype=torch.bfloat16,
            device_map=device,
        )
        self.model = freeze(self.model)
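        # dynamic=True compiles with symbolic shapes, avoiding recompilation
        # when the padded sequence length changes between batches.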
        self.model = torch.compile(self.model, dynamic=True)
        self.processor = AutoProcessor.from_pretrained(conf.checkpoint_path, use_fast=True)
        self.max_length = conf.max_length

    def __call__(self, texts, type_of_content="video"):
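        """Return packed caption-token embeddings and their cu_seqlens offsets."""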
        prompt_template = "\n".join(self.PROMPT_TEMPLATE["template"][type_of_content])
        crop_start = self.PROMPT_TEMPLATE["crop_start"][type_of_content]
        full_texts = [prompt_template.format(text) for text in texts]
        max_length = self.max_length + crop_start
        inputs = self.processor(
            text=full_texts,
            images=None,
            videos=None,
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
            padding=True,
        ).to(self.model.device)
        with torch.no_grad():
            embeds = self.model(
                input_ids=inputs["input_ids"],
                return_dict=True,
                output_hidden_states=True,
            )["hidden_states"][-1][:, crop_start:]
        attention_mask = inputs["attention_mask"][:, crop_start:]
        embeds = embeds[attention_mask.bool()]
        cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0)
        cu_seqlens = torch.cat([torch.zeros_like(cu_seqlens)[:1], cu_seqlens]).to(
            dtype=torch.int32
        )
        return embeds, cu_seqlens


class Kandinsky5TextEmbedder:
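    """Joint wrapper: Qwen2.5-VL token embeddings plus a CLIP pooled embedding."""
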
def __init__(self, conf, device="cpu"):
self.embedder = Qwen2_5_VLTextEmbedder(conf.qwen, device)
self.clip_embedder = ClipTextEmbedder(conf.clip, device)
self.conf = conf

    def encode(self, texts, type_of_content="image"):
        text_embeds, cu_seqlens = self.embedder(texts, type_of_content=type_of_content)
        pooled_embed = self.clip_embedder(texts)
        return {"text_embeds": text_embeds, "pooled_embed": pooled_embed}, cu_seqlens

    def to(self, device):
        self.embedder.model = self.embedder.model.to(device)
        self.clip_embedder.model = self.clip_embedder.model.to(device)
        return self


def get_text_embedder(conf, device="cpu"):
    return Kandinsky5TextEmbedder(conf, device)
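

# Minimal usage sketch (hypothetical config: assumes `conf.qwen` and `conf.clip`
# each provide `checkpoint_path` and `max_length`, matching the constructors
# above; the exact schema comes from the surrounding project):
#
#   embedder = get_text_embedder(conf, device="cuda")
#   embeds, cu_seqlens = embedder.encode(["a cat playing piano"], type_of_content="video")
#   embeds["text_embeds"]   # packed (total_tokens, dim) Qwen2.5-VL features
#   embeds["pooled_embed"]  # (batch, dim) pooled CLIP features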