toshi-456 committed on
Commit 61ac967 · verified · 1 Parent(s): 967c4dd

Upload 16 files

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ sample.jpg filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 SB Intuitions
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,3 +1,158 @@
- ---
- license: mit
- ---
+ ---
+ language:
+ - ja
+ - en
+ base_model:
+ - sbintuitions/sarashina2.2-3b-instruct-v0.1
+ license: mit
+ tags:
+ - multimodal
+ - vision-language
+ pipeline_tag: image-to-text
+ library_name: transformers
+ ---
+
+ # Sarashina2.2-Vision-3B
+ **Sarashina2.2-Vision-3B** is a Japanese Large Vision Language Model trained by [SB Intuitions](https://www.sbintuitions.co.jp).
+
+ This model is based on [Sarashina2.2-3B-Instruct](https://huggingface.co/sbintuitions/sarashina2.2-3b-instruct-v0.1) and the image encoder of [SigLIP](https://huggingface.co/google/siglip-so400m-patch14-384).
+
+ ## Model Performance
+ ### Japanese Performance
+
+ |Model|Params (B)|[BusinessSlide VQA](https://github.com/stockmarkteam/business-slide-questions)<sup>*1</sup>|[Heron-Bench](https://arxiv.org/abs/2404.07824)<sup>*1</sup>|[JDocQA](https://arxiv.org/abs/2403.19454)<sup>*1</sup>|[JMMMU](https://arxiv.org/abs/2410.17250)|
+ |-|-|-|-|-|-|
+ |[Sarashina2.2-Vision-3B](https://huggingface.co/sbintuitions/sarashina2.2-vision-3b)|3.8|3.932|**3.214**|<u>3.327</u>|<u>0.486</u>|
+ |[Qwen2.5-VL-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct)|3.8|3.516|2.000|3.019|0.450|
+ |[Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct)|4.4|**4.105**|2.330|**3.596**|**0.493**|
+ |[InternVL3_5-4B](https://huggingface.co/OpenGVLab/InternVL3_5-4B)|4.7|3.311|1.893|2.626|0.437|
+ |[Sarashina2-Vision-14B](https://huggingface.co/sbintuitions/sarashina2-vision-14b)|14.4|3.110|2.184|-<sup>*2</sup>|0.432|
+ |[Stockmark-2-VL-100B-beta](https://huggingface.co/stockmark/Stockmark-2-VL-100B-beta)|96.5|<u>3.973</u>|<u>2.563</u>|3.168|-<sup>*2</sup>|
+
+ *1. [gpt-oss-120b](https://huggingface.co/openai/gpt-oss-120b) was used as the LLM-as-a-Judge.
+
+ *2. The score could not be measured because some inputs exceed the model's `max_position_embeddings`.
+
+ ### English Performance
+
+ |Model|Params (B)|[DocVQA](https://arxiv.org/abs/2007.00398)|[InfoVQA](https://arxiv.org/abs/2104.12756)|[RealworldQA](https://huggingface.co/datasets/xai-org/RealworldQA)|
+ |-|-|-|-|-|
+ |[Sarashina2.2-Vision-3B](https://huggingface.co/sbintuitions/sarashina2.2-vision-3b)|3.8|0.831|0.567|<u>0.625</u>|
+ |[Qwen2.5-VL-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct)|3.8|<u>0.924</u>|<u>0.750</u>|0.586|
+ |[Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct)|4.4|**0.948**|**0.798**|**0.712**|
+ |[InternVL3_5-4B](https://huggingface.co/OpenGVLab/InternVL3_5-4B)|4.7|0.823|0.541|0.553|
+ |[Sarashina2-Vision-14B](https://huggingface.co/sbintuitions/sarashina2-vision-14b)|14.4|0.729|0.490|0.519|
+
+ ## How to use
+ ### 1. Install dependencies
+
+ ```sh
+ pip install transformers==4.57.1 torch torchvision pillow protobuf sentencepiece accelerate
+ ```
+
+ ### 2. Inference
+ The following script loads the model and runs inference.
+ ```python
+ import requests
+ from PIL import Image
+ from transformers import AutoModelForCausalLM, AutoProcessor, set_seed
+
+ # Define model path
+ model_path = "sbintuitions/sarashina2.2-vision-3b"
+
+ # Load model and processor
+ processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_path,
+     device_map="cuda",
+     torch_dtype="auto",
+     trust_remote_code=True,
+ )
+ set_seed(42)
+
+ image_url = "https://huggingface.co/sbintuitions/sarashina2.2-vision-3b/resolve/main/sample.jpg"
+ message = [
+     {
+         "role": "user",
+         "content": [
+             {
+                 "type": "image",
+                 "image": image_url,
+             },
+             {
+                 "type": "text",
+                 "text": "これはどこで撮った写真ですか?",
+             },
+         ],
+     }
+ ]
+ text_prompt = processor.apply_chat_template(message, add_generation_prompt=True)
+ """text_prompt: <|user|><|prefix|><|file|><|suffix|>これはどこで撮った写真ですか?</s><|assistant|>"""
+
+ image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB")
+ inputs = processor(
+     text=[text_prompt],
+     images=[image],
+     padding=True,
+     return_tensors="pt",
+ )
+ inputs = inputs.to(model.device)
+
+ # Inference: Generation of the output
+ output_ids = model.generate(
+     **inputs,
+     max_new_tokens=512,
+     temperature=0.7,
+     top_p=0.95,
+     repetition_penalty=1.2,
+ )
+ generated_ids = [
+     output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)
+ ]
+ output_text = processor.batch_decode(
+     generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
+ )
+ print(output_text[0])
+ """
+ この写真は、**道後温泉本館(どうごおんせんほんかん)** の入り口を夜景で撮影した写真です。
+
+ ---
+ 場所の詳細:
+ - **名称**:道後温泉本館(Dogo Onsen Honkan)
+ - **所在地**:〒790-0842 愛媛県松山市道後湯之町1丁目3番5号
+ - **アクセス**:JR松山駅から市内電車「道後温泉駅」下車すぐ
+ - **特徴**:日本最古の温泉の一つとして知られる「道後温泉」の中心的な施設。国の重要文化財にも指定されています。
+
+ ---
+ 写真の特徴から判断した理由:
+ - 建物の屋根や装飾が伝統的な和風建築で、「道後温泉」の看板が目立つ。
+ - 入口の垂れ幕には「道後」「道後」と書かれており、白い鳳凰の模様が描かれている → 道後温泉の象徴的デザイン。
+ - 夜の照明と石灯籠、提灯風の灯りが日本の温泉地らしい雰囲気を醸し出している。
+ - 看板に「道後温泉」の文字が明確に表示されている。
+
+ ---
+ 補足情報:
+ 道後温泉本館は、夏目漱石の小説『坊っちゃん』の舞台としても有名で、多くの観光客が訪れる人気スポットです。また、2020年にリニューアルされ、現代的な設備も導入されていますが、外観は伝統を残しています。
+
+ ---
+ よって、この写真は **愛媛県松山市にある「道後温泉本館」の夜景** です。
+ """
+ ```
+
+ ## Training
+ **Sarashina2.2-Vision-3B** was created through the following five-stage training process:
+
+ ### PreTrain
+ 1. Projector Warmup: To bridge the gap between the text and image embedding spaces within the LLM
+ 2. Vision Encoder Pretraining: To enhance image comprehension, especially for understanding Japan-specific images and text
+ 3. Full Model Pretraining: To enhance the model's unified understanding of images and language using interleaved data
+
+ ### PostTrain
+ 1. Supervised Fine-Tuning (SFT): To improve the model's ability to follow instructions and respond appropriately to user prompts
+ 2. Mixed Preference Optimization (MPO): To align the model's outputs with user preferences, ensuring it generates more desirable responses
+
+ ## Limitations
+ This model has undergone only limited safety training, so it may generate meaningless sequences, inaccurate statements, or biased/objectionable outputs. Before using it, developers should tune the model to reflect human preferences and safety considerations.
+
+ ## LICENSE
+ [MIT License](./LICENSE)
chat_template.json ADDED
@@ -0,0 +1,3 @@
+ {
+ "chat_template": "{% set image_count = namespace(value=0) %}{% for message in messages %}{% if message['content'] is string %}{% if message['role'] == 'user' %}{{ '<|user|>' + message['content'] + eos_token }}{% elif message['role'] == 'assistant' %}{{ '<|assistant|>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ '<|system|>' + message['content'] + eos_token }}{% endif %}{% else %}{% if message['role'] == 'user' %}{{ '<|user|>' }}{% elif message['role'] == 'assistant' %}{{ '<|assistant|>' }}{% elif message['role'] == 'system' %}{{ '<|system|>' }}{% endif %}{% for content in message['content'] %}{% if content['type'] == 'image' or content.get('image') or content.get('image_url') %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %} Picture {{ image_count.value }}: {% endif %}{{ '<|prefix|><|file|><|suffix|>' }}{% endif %}{% endfor %}{% for content in message['content'] %}{% if content.get('text') %}{{ content['text'] }}{% endif %}{% endfor %}{{ eos_token }}{% endif %}{% endfor %}{% if messages[-1]['role'] == 'user' %}{{ '<|assistant|>' }}{% endif %}"
+ }
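
For reference, a minimal sketch of what this template produces, assuming the hosted repo id `sbintuitions/sarashina2.2-vision-3b` is reachable and `trust_remote_code=True` is acceptable; the message mirrors the README example.

```python
# Render the chat template above without running the model.
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained(
    "sbintuitions/sarashina2.2-vision-3b", trust_remote_code=True
)

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "https://huggingface.co/sbintuitions/sarashina2.2-vision-3b/resolve/main/sample.jpg"},
            {"type": "text", "text": "これはどこで撮った写真ですか?"},
        ],
    }
]

# Each image entry becomes '<|prefix|><|file|><|suffix|>'; the pending user turn
# is closed with the eos token and followed by '<|assistant|>'.
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
print(prompt)
# <|user|><|prefix|><|file|><|suffix|>これはどこで撮った写真ですか?</s><|assistant|>
```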
config.json ADDED
@@ -0,0 +1,57 @@
+ {
+   "architectures": [
+     "Sarashina2VisionForCausalLM"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_sarashina2_vision.Sarashina2VisionConfig",
+     "AutoModelForCausalLM": "modeling_sarashina2_vision.Sarashina2VisionForCausalLM"
+   },
+   "end_image_token_index": 102398,
+   "hidden_act": "silu",
+   "image_token_index": 14,
+   "model_type": "sarashina2_vision",
+   "start_image_token_index": 102397,
+   "text_config": {
+     "_name_or_path": "sbintuitions/sarashina2.2-3b-instruct-v0.1",
+     "architectures": [
+       "LlamaForCausalLM"
+     ],
+     "attention_bias": false,
+     "attention_dropout": 0.0,
+     "head_dim": 160,
+     "hidden_act": "silu",
+     "hidden_size": 2560,
+     "initializer_range": 0.02,
+     "intermediate_size": 8960,
+     "max_position_embeddings": 8192,
+     "mlp_bias": false,
+     "model_type": "llama",
+     "num_attention_heads": 16,
+     "num_hidden_layers": 32,
+     "num_key_value_heads": 8,
+     "pretraining_tp": 1,
+     "rms_norm_eps": 1e-05,
+     "rope_scaling": null,
+     "rope_theta": 500000,
+     "torch_dtype": "bfloat16",
+     "use_cache": false,
+     "vocab_size": 102400
+   },
+   "torch_dtype": "bfloat16",
+   "transformers_version": "4.51.3",
+   "vision_config": {
+     "depth": 27,
+     "embed_dim": 1152,
+     "hidden_act": "gelu_pytorch_tanh",
+     "hidden_size": 2560,
+     "in_channels": 3,
+     "initializer_range": 0.02,
+     "mlp_ratio": 3.7362,
+     "model_type": "qwen2_vl",
+     "num_heads": 16,
+     "patch_size": 14,
+     "spatial_merge_size": 2,
+     "temporal_patch_size": 2
+   },
+   "vocab_size": 102400
+ }
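
A small sketch of how this composite config can be inspected once loaded; it assumes network access to the hosted repo and only reads back the values listed above.

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained(
    "sbintuitions/sarashina2.2-vision-3b", trust_remote_code=True
)

print(type(config).__name__)            # Sarashina2VisionConfig
print(config.text_config.model_type)    # "llama"  (Sarashina2.2-3B backbone)
print(config.text_config.hidden_size)   # 2560
print(config.vision_config.model_type)  # "qwen2_vl" (vision transformer config)
print(config.image_token_index)         # 14      -> <|file|> placeholder
print(config.start_image_token_index)   # 102397  -> <|prefix|>
print(config.end_image_token_index)     # 102398  -> <|suffix|>
```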
configuration_sarashina2_vision.py ADDED
@@ -0,0 +1,76 @@
+ # coding=utf-8
+ # Copyright 2025 the SB Intuitions.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Sarashina2Vision model configuration"""
+
+ from typing import Any, Optional
+
+ from transformers import LlamaConfig, PretrainedConfig
+ from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLVisionConfig
+ from transformers.utils import logging
+
+ logger = logging.get_logger(__name__)
+
+
+ class Sarashina2VisionConfig(PretrainedConfig):
+     """
+     This is the configuration class to store the configuration of a [`Sarashina2VisionModel`]. It is used to instantiate a
+     Sarashina2Vision model according to the specified arguments, defining the model architecture.
+
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PretrainedConfig`] for more information.
+
+     Args:
+         vision_config (`Dict`, *optional*):
+             The config for the visual encoder initialization.
+         text_config (`Dict`, *optional*):
+             The config for the text decoder initialization.
+         image_token_index (`int`):
+             image token id.
+         start_image_token_index (`int`):
+             start image token id.
+         end_image_token_index (`int`):
+             end image token id.
+     """
+
+     model_type = "sarashina2_vision"
+
+     def __init__(
+         self,
+         vision_config: Optional[dict[str, Any]] = None,
+         text_config: Optional[dict[str, Any]] = None,
+         image_token_index: int = 14,
+         start_image_token_index: int = 102397,
+         end_image_token_index: int = 102398,
+         **kwargs,
+     ):
+         if isinstance(text_config, dict):
+             self.text_config = LlamaConfig(**text_config)
+         elif isinstance(text_config, LlamaConfig):
+             self.text_config = text_config
+         elif text_config is None:
+             self.text_config = LlamaConfig()
+
+         if isinstance(vision_config, dict):
+             self.vision_config = Qwen2VLVisionConfig(**vision_config)
+         elif isinstance(vision_config, Qwen2VLVisionConfig):
+             self.vision_config = vision_config
+         elif vision_config is None:
+             self.vision_config = Qwen2VLVisionConfig()
+
+         self.image_token_index = image_token_index
+         self.start_image_token_index = start_image_token_index
+         self.end_image_token_index = end_image_token_index
+
+         super().__init__(**kwargs)
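
A brief sketch of the dict-promotion behaviour in `__init__` above, assuming the file is importable locally (for example from a clone of this repository); the sub-config values are illustrative only.

```python
# Assumes configuration_sarashina2_vision.py is on the Python path (local clone).
from configuration_sarashina2_vision import Sarashina2VisionConfig

config = Sarashina2VisionConfig(
    text_config={"hidden_size": 2560, "num_hidden_layers": 32, "vocab_size": 102400},
    vision_config={"patch_size": 14, "spatial_merge_size": 2},
)

# Plain dicts are promoted to full config objects by the isinstance branches;
# omitted keys fall back to LlamaConfig / Qwen2VLVisionConfig defaults.
print(type(config.text_config).__name__)    # LlamaConfig
print(type(config.vision_config).__name__)  # Qwen2VLVisionConfig
print(config.image_token_index)             # 14 (default)
```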
generation_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "_from_model_config": true,
+   "bos_token_id": 1,
+   "eos_token_id": 2,
+   "transformers_version": "4.57.1",
+   "use_cache": false
+ }
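
A one-liner sketch for reading these generation defaults without loading the model weights (repo id assumed as above).

```python
from transformers import GenerationConfig

gen_config = GenerationConfig.from_pretrained("sbintuitions/sarashina2.2-vision-3b")
print(gen_config.bos_token_id, gen_config.eos_token_id)  # 1 (<s>), 2 (</s>)
```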
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:91fdaf16a99ee00f43801b26f6167aee0d205477303dfc72a37242e660da7ac4
+ size 7603021272
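
If needed, the downloaded weights can be checked against this LFS pointer; a minimal sketch assuming `huggingface_hub` is installed and the repo is reachable.

```python
import hashlib
import os

from huggingface_hub import hf_hub_download

path = hf_hub_download("sbintuitions/sarashina2.2-vision-3b", "model.safetensors")
print(os.path.getsize(path))  # expected: 7603021272 (per the pointer above)

digest = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        digest.update(chunk)
print(digest.hexdigest())  # expected to match the oid sha256 in the pointer above
```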
modeling_sarashina2_vision.py ADDED
@@ -0,0 +1,250 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 the SB Intuitions.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from typing import List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ from torch.nn import CrossEntropyLoss
21
+ from transformers import (
22
+ AutoConfig,
23
+ AutoModelForCausalLM,
24
+ GenerationMixin,
25
+ LlamaForCausalLM,
26
+ PreTrainedModel,
27
+ )
28
+ from transformers.modeling_outputs import CausalLMOutputWithPast
29
+ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel
30
+ from transformers.utils import logging, replace_return_docstrings
31
+
32
+ from .configuration_sarashina2_vision import Sarashina2VisionConfig
33
+
34
+ logger = logging.get_logger(__name__)
35
+
36
+ _CONFIG_FOR_DOC = "Sarashina2VisionConfig"
37
+
38
+
39
+ class Sarashina2VisionPreTrainedModel(PreTrainedModel):
40
+ config_class = Sarashina2VisionConfig
41
+ base_model_prefix = "model"
42
+ _supports_flash_attn_2 = True
43
+ _supports_sdpa = True
44
+ _supports_cache_class = True
45
+ _supports_static_cache = True
46
+
47
+ def _init_weights(self, module):
48
+ std = (
49
+ self.config.initializer_range
50
+ if hasattr(self.config, "initializer_range")
51
+ else self.config.text_config.initializer_range
52
+ )
53
+
54
+ if hasattr(module, "class_embedding"):
55
+ module.class_embedding.data.normal_(mean=0.0, std=std)
56
+
57
+ if isinstance(module, (nn.Linear, nn.Conv3d)):
58
+ module.weight.data.normal_(mean=0.0, std=std)
59
+ if module.bias is not None:
60
+ module.bias.data.zero_()
61
+ elif isinstance(module, nn.Embedding):
62
+ module.weight.data.normal_(mean=0.0, std=std)
63
+ if module.padding_idx is not None:
64
+ module.weight.data[module.padding_idx].zero_()
65
+
66
+
67
+ class Sarashina2VisionForCausalLM(Sarashina2VisionPreTrainedModel, GenerationMixin):
68
+ def __init__(self, config: Sarashina2VisionConfig):
69
+ super().__init__(config)
70
+ self.visual = Qwen2VisionTransformerPretrainedModel._from_config(config.vision_config)
71
+ self.norm = nn.LayerNorm(config.text_config.hidden_size)
72
+ self.llm = LlamaForCausalLM._from_config(config.text_config)
73
+
74
+ # Initialize weights and apply final processing
75
+ self.post_init()
76
+
77
+ def get_input_embeddings(self):
78
+ return self.llm.get_input_embeddings()
79
+
80
+ def get_image_embeds(
81
+ self,
82
+ hidden_states: torch.Tensor,
83
+ grid_thw: torch.Tensor,
84
+ ) -> torch.Tensor:
85
+ rotary_pos_emb = self.visual.rot_pos_emb(grid_thw)
86
+ emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
87
+ position_embeddings = (emb.cos(), emb.sin())
88
+ hidden_states = self.visual.patch_embed(hidden_states)
89
+
90
+ cu_seqlens = torch.repeat_interleave(
91
+ grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
92
+ ).cumsum(dim=0, dtype=torch.int32)
93
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
94
+
95
+ for blk in self.visual.blocks:
96
+ hidden_states = blk(
97
+ hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, position_embeddings=position_embeddings
98
+ )
99
+ return self.norm(self.visual.merger(hidden_states))
100
+
101
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
102
+ def forward(
103
+ self,
104
+ input_ids: torch.LongTensor = None,
105
+ attention_mask: Optional[torch.Tensor] = None,
106
+ position_ids: Optional[torch.LongTensor] = None,
107
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
108
+ inputs_embeds: Optional[torch.FloatTensor] = None,
109
+ labels: Optional[torch.LongTensor] = None,
110
+ use_cache: Optional[bool] = None,
111
+ output_attentions: Optional[bool] = None,
112
+ output_hidden_states: Optional[bool] = None,
113
+ return_dict: Optional[bool] = None,
114
+ pixel_values: torch.FloatTensor = None,
115
+ image_grid_thw: Optional[torch.LongTensor] = None,
116
+ cache_position: Optional[torch.LongTensor] = None,
117
+ logits_to_keep: Union[int, torch.Tensor] = 0,
118
+ **lm_kwargs,
119
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
120
+ """
121
+ Args:
122
+ input_ids (torch.LongTensor, optional): Indices of input sequence tokens in the vocabulary. Defaults to None.
123
+ attention_mask (Optional[torch.Tensor], optional): Mask to avoid performing attention on padding token indices. Defaults to None.
124
+ position_ids (Optional[torch.LongTensor], optional): Indices of positions of each input sequence tokens in the position embeddings. Defaults to None.
125
+ past_key_values (Optional[List[torch.FloatTensor]], optional): Pre-computed key and value hidden states that can be used to speed up sequential decoding. Defaults to None.
126
+ inputs_embeds (Optional[torch.FloatTensor], optional): Instead of passing `input_ids` you can choose to directly pass an embedded representation. Defaults to None.
127
+ labels (Optional[torch.LongTensor], optional): Labels for computing the masked language modeling loss. Defaults to None.
128
+ use_cache (Optional[bool], optional): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding. Defaults to None.
129
+ output_attentions (Optional[bool], optional): Whether or not to return the attentions tensors of all attention layers. Defaults to None.
130
+ output_hidden_states (Optional[bool], optional): Whether or not to return the hidden states of all layers. Defaults to None.
131
+ return_dict (Optional[bool], optional): Whether or not to return a `CausalLMOutputWithPast` instead of a plain tuple. Defaults to None.
132
+ pixel_values (torch.FloatTensor, optional): The tensors corresponding to the input images. Defaults to None.
133
+ image_grid_thw (Optional[torch.LongTensor], optional): The temporal, height and width of feature shape of each image in LLM. Defaults to None.
134
+ cache_position (Optional[torch.LongTensor], optional): Indices depicting the position of the input sequence tokens in the sequence. Defaults to None.
135
+ logits_to_keep (Union[int, torch.Tensor]): If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
136
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
137
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
138
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
139
+ This is useful when using packed tensor format (single dimension for batch and sequence length).
140
+ Returns:
141
+ CausalLMOutputWithPast: The output of the model.
142
+ """
143
+ output_attentions = (
144
+ output_attentions if output_attentions is not None else self.config.output_attentions
145
+ )
146
+ output_hidden_states = (
147
+ output_hidden_states
148
+ if output_hidden_states is not None
149
+ else self.config.output_hidden_states
150
+ )
151
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
152
+
153
+ if inputs_embeds is None:
154
+ inputs_embeds = self.get_input_embeddings()(input_ids)
155
+ if pixel_values is not None:
156
+ pixel_values = pixel_values.type(self.visual.get_dtype())
157
+ image_embeds = self.get_image_embeds(pixel_values, image_grid_thw)
158
+ n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
159
+ n_image_features = image_embeds.shape[0]
160
+ if n_image_tokens != n_image_features:
161
+ raise ValueError(
162
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
163
+ )
164
+ image_mask = (
165
+ (input_ids == self.config.image_token_index)
166
+ .unsqueeze(-1)
167
+ .expand_as(inputs_embeds)
168
+ .to(inputs_embeds.device)
169
+ )
170
+ image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
171
+ inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
172
+
173
+ outputs = self.llm(
174
+ attention_mask=attention_mask,
175
+ position_ids=position_ids,
176
+ past_key_values=past_key_values,
177
+ inputs_embeds=inputs_embeds,
178
+ use_cache=use_cache,
179
+ output_attentions=output_attentions,
180
+ output_hidden_states=output_hidden_states,
181
+ return_dict=return_dict,
182
+ cache_position=cache_position,
183
+ logits_to_keep=logits_to_keep,
184
+ **lm_kwargs,
185
+ )
186
+
187
+ logits = outputs[0]
188
+
189
+ loss = None
190
+ if labels is not None:
191
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
192
+ logits = logits.float()
193
+ # Shift so that tokens < n predict n
194
+ shift_logits = logits[..., :-1, :].contiguous()
195
+ shift_labels = labels[..., 1:].contiguous()
196
+ # Flatten the tokens
197
+ loss_fct = CrossEntropyLoss()
198
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
199
+ shift_labels = shift_labels.view(-1)
200
+ # Enable model parallelism
201
+ shift_labels = shift_labels.to(shift_logits.device)
202
+ loss = loss_fct(shift_logits, shift_labels)
203
+
204
+ if not return_dict:
205
+ output = (logits,) + outputs[1:]
206
+ return (loss,) + output if loss is not None else output
207
+
208
+ return CausalLMOutputWithPast(
209
+ loss=loss,
210
+ logits=logits,
211
+ past_key_values=outputs.past_key_values,
212
+ hidden_states=outputs.hidden_states,
213
+ attentions=outputs.attentions,
214
+ )
215
+
216
+ def prepare_inputs_for_generation(
217
+ self,
218
+ input_ids,
219
+ past_key_values=None,
220
+ inputs_embeds=None,
221
+ pixel_values=None,
222
+ attention_mask=None,
223
+ cache_position=None,
224
+ logits_to_keep=None,
225
+ image_grid_thw=None,
226
+ **kwargs,
227
+ ):
228
+ model_inputs = self.llm.prepare_inputs_for_generation(
229
+ input_ids,
230
+ past_key_values=past_key_values,
231
+ inputs_embeds=inputs_embeds,
232
+ attention_mask=attention_mask,
233
+ cache_position=cache_position,
234
+ logits_to_keep=logits_to_keep,
235
+ **kwargs,
236
+ )
237
+
238
+ if cache_position[0] == 0:
239
+ # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
240
+ # Otherwise we need pixel values to be passed to model
241
+ model_inputs["pixel_values"] = pixel_values
242
+ model_inputs["image_grid_thw"] = image_grid_thw
243
+
244
+ return model_inputs
245
+
246
+
247
+ AutoConfig.register("sarashina2_vision", Sarashina2VisionConfig)
248
+ AutoModelForCausalLM.register(Sarashina2VisionConfig, Sarashina2VisionForCausalLM)
249
+ Sarashina2VisionConfig.register_for_auto_class()
250
+ Sarashina2VisionForCausalLM.register_for_auto_class("AutoModelForCausalLM")
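
The key multimodal step in `forward` above is the `masked_scatter` call that overwrites the embeddings of `<|file|>` placeholder tokens with the projected image features. A standalone toy sketch of that mechanism; the shapes and all token ids other than 14 / 102397 / 102398 are made up.

```python
import torch

image_token_index = 14  # <|file|>, as in config.json
hidden_size = 8         # toy width; the real model uses 2560

# Toy sequence: <|user|> <|prefix|> <|file|> x3 <|suffix|> ... </s>
input_ids = torch.tensor([[9, 102397, 14, 14, 14, 102398, 50, 2]])
inputs_embeds = torch.zeros(1, input_ids.shape[1], hidden_size)
image_embeds = torch.arange(3 * hidden_size, dtype=torch.float32).reshape(3, hidden_size)

# Same pattern as Sarashina2VisionForCausalLM.forward: build a mask over the
# placeholder positions and scatter the image rows into those slots, in order.
image_mask = (input_ids == image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

print(inputs_embeds[0, 2:5])  # the three placeholder slots now hold image rows 0..2
```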
preprocessor_config.json ADDED
@@ -0,0 +1,32 @@
+ {
+   "auto_map": {
+     "AutoProcessor": "processing_sarashina2_vision.Sarashina2VisionProcessor"
+   },
+   "do_convert_rgb": true,
+   "do_normalize": true,
+   "do_rescale": true,
+   "do_resize": true,
+   "image_mean": [
+     0.5,
+     0.5,
+     0.5
+   ],
+   "image_processor_type": "Sarashina2VisionImageProcessor",
+   "image_std": [
+     0.5,
+     0.5,
+     0.5
+   ],
+   "max_pixels": 1016064,
+   "merge_size": 2,
+   "min_pixels": 3136,
+   "patch_size": 14,
+   "processor_class": "Sarashina2VisionProcessor",
+   "resample": 2,
+   "rescale_factor": 0.00392156862745098,
+   "size": {
+     "longest_edge": 1016064,
+     "shortest_edge": 3136
+   },
+   "temporal_patch_size": 2
+ }
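
A short sketch of how `min_pixels`/`max_pixels`, `patch_size` and `merge_size` translate into an image-token budget, using the same `smart_resize` helper that `processing_sarashina2_vision.py` imports; the input resolution is a made-up example.

```python
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize

patch_size, merge_size = 14, 2
min_pixels, max_pixels = 3136, 1016064  # values from this preprocessor config

height, width = 1080, 1920  # hypothetical input image
resized_h, resized_w = smart_resize(
    height,
    width,
    factor=patch_size * merge_size,  # both sides become multiples of 28
    min_pixels=min_pixels,
    max_pixels=max_pixels,
)

# Each 2x2 block of 14x14 patches is merged into one <|file|> token.
grid_h, grid_w = resized_h // patch_size, resized_w // patch_size
num_image_tokens = (grid_h * grid_w) // (merge_size**2)
print(resized_h, resized_w, num_image_tokens)
```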
processing_sarashina2_vision.py ADDED
@@ -0,0 +1,488 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 the SB Intuitions.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Processor class for Sarashina2Vision.
17
+ """
18
+
19
+ from typing import Dict, List, Optional, Union
20
+
21
+ import numpy as np
22
+ import torch
23
+ import torch.nn.functional as F
24
+ from transformers import (
25
+ AutoImageProcessor,
26
+ BaseImageProcessor,
27
+ )
28
+ from transformers.feature_extraction_utils import BatchFeature
29
+ from transformers.image_transforms import (
30
+ convert_to_rgb,
31
+ to_channel_dimension_format,
32
+ )
33
+ from transformers.image_utils import (
34
+ OPENAI_CLIP_MEAN,
35
+ OPENAI_CLIP_STD,
36
+ ChannelDimension,
37
+ ImageInput,
38
+ get_image_size,
39
+ infer_channel_dimension_format,
40
+ is_scaled_image,
41
+ make_flat_list_of_images,
42
+ make_list_of_images,
43
+ to_numpy_array,
44
+ valid_images,
45
+ )
46
+ from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
47
+ from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
48
+ from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
49
+ from transformers.utils import TensorType, logging
50
+
51
+ logger = logging.get_logger(__name__)
52
+
53
+
54
+ class Sarashina2VisionImageProcessor(BaseImageProcessor):
55
+ r"""
56
+ Constructs a Sarashina2Vision image processor that dynamically resizes images based on the original images.
57
+
58
+ Args:
59
+ do_resize (`bool`, *optional*, defaults to `True`):
60
+ Whether to resize the image's (height, width) dimensions.
61
+ do_rescale (`bool`, *optional*, defaults to `True`):
62
+ Whether to rescale the image by the specified scale `rescale_factor`.
63
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
64
+ Scale factor to use if rescaling the image.
65
+ do_normalize (`bool`, *optional*, defaults to `True`):
66
+ Whether to normalize the image.
67
+ image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
68
+ Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
69
+ image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
70
+ Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
71
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
72
+ Whether to convert the image to RGB.
73
+ min_pixels (`int`, *optional*, defaults to `56 * 56`):
74
+ The min pixels of the image to resize the image.
75
+ max_pixels (`int`, *optional*, defaults to `28 * 28 * 1280`):
76
+ The max pixels of the image to resize the image.
77
+ patch_size (`int`, *optional*, defaults to 14):
78
+ The spatial patch size of the vision encoder.
79
+ temporal_patch_size (`int`, *optional*, defaults to 2):
80
+ The temporal patch size of the vision encoder.
81
+ merge_size (`int`, *optional*, defaults to 2):
82
+ The merge size of the vision encoder to llm encoder.
83
+ """
84
+
85
+ model_input_names = ["pixel_values", "image_grid_thw"]
86
+
87
+ def __init__(
88
+ self,
89
+ do_resize: bool = True,
90
+ do_rescale: bool = True,
91
+ rescale_factor: Union[int, float] = 1 / 255,
92
+ do_normalize: bool = True,
93
+ image_mean: Optional[Union[float, List[float]]] = None,
94
+ image_std: Optional[Union[float, List[float]]] = None,
95
+ do_convert_rgb: bool = True,
96
+ min_pixels: int = 56 * 56,
97
+ max_pixels: int = 28 * 28 * 1280,
98
+ patch_size: int = 14,
99
+ temporal_patch_size: int = 2,
100
+ merge_size: int = 2,
101
+ **kwargs,
102
+ ) -> None:
103
+ super().__init__(**kwargs)
104
+ self.do_resize = do_resize
105
+ self.do_rescale = do_rescale
106
+ self.rescale_factor = rescale_factor
107
+ self.do_normalize = do_normalize
108
+ self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
109
+ self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
110
+ self.min_pixels = min_pixels
111
+ self.max_pixels = max_pixels
112
+ self.patch_size = patch_size
113
+ self.temporal_patch_size = temporal_patch_size
114
+ self.merge_size = merge_size
115
+ self.size = {"shortest_edge": min_pixels, "longest_edge": max_pixels}
116
+ self.do_convert_rgb = do_convert_rgb
117
+
118
+ def _preprocess(
119
+ self,
120
+ images: ImageInput,
121
+ do_resize: bool = None,
122
+ do_rescale: bool = None,
123
+ rescale_factor: float = None,
124
+ do_normalize: bool = None,
125
+ image_mean: Optional[Union[float, List[float]]] = None,
126
+ image_std: Optional[Union[float, List[float]]] = None,
127
+ do_convert_rgb: bool = None,
128
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
129
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
130
+ ):
131
+ """
132
+ Preprocess an image or batch of images. Copy of the `preprocess` method from `Sarashina2Vision`.
133
+
134
+ Args:
135
+ images (`ImageInput`):
136
+ Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
137
+ vision_info (`List[Dict]`, *optional*):
138
+ Optional list of dictionaries containing additional information about vision inputs.
139
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
140
+ Whether to resize the image.
141
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
142
+ Whether to rescale the image.
143
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
144
+ Scale factor to use if rescaling the image.
145
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
146
+ Whether to normalize the image.
147
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
148
+ Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
149
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
150
+ Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
151
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
152
+ Whether to convert the image to RGB.
153
+ data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
154
+ The channel dimension format for the output image. Can be one of:
155
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
156
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
157
+ - Unset: Use the channel dimension format of the input image.
158
+ input_data_format (`ChannelDimension` or `str`, *optional*):
159
+ The channel dimension format for the input image. Can be one of:
160
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
161
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
162
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
163
+ """
164
+ images = make_list_of_images(images)
165
+
166
+ if do_convert_rgb:
167
+ images = [convert_to_rgb(image) for image in images]
168
+
169
+ # All transformations expect numpy arrays.
170
+ images = [to_numpy_array(image) for image in images]
171
+
172
+ if do_rescale and is_scaled_image(images[0]):
173
+ logger.warning_once(
174
+ "It looks like you are trying to rescale already rescaled images. If the input"
175
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
176
+ )
177
+ if input_data_format is None:
178
+ # We assume that all images have the same channel dimension format.
179
+ input_data_format = infer_channel_dimension_format(images[0])
180
+
181
+ height, width = get_image_size(images[0], channel_dim=input_data_format)
182
+ resized_height, resized_width = height, width
183
+ processed_images = []
184
+ for image in images:
185
+ if do_rescale:
186
+ image = self.rescale(
187
+ image, scale=rescale_factor, input_data_format=input_data_format
188
+ )
189
+
190
+ if do_normalize:
191
+ image = self.normalize(
192
+ image=image,
193
+ mean=image_mean,
194
+ std=image_std,
195
+ input_data_format=input_data_format,
196
+ )
197
+
198
+ image = to_channel_dimension_format(
199
+ image, data_format, input_channel_dim=input_data_format
200
+ )
201
+
202
+ if do_resize:
203
+ resized_height, resized_width = smart_resize(
204
+ height,
205
+ width,
206
+ factor=self.patch_size * self.merge_size,
207
+ min_pixels=self.min_pixels,
208
+ max_pixels=self.max_pixels,
209
+ )
210
+ image = (
211
+ F.interpolate(
212
+ torch.from_numpy(image).unsqueeze(0),
213
+ size=(resized_height, resized_width),
214
+ mode="bicubic",
215
+ )
216
+ .squeeze(0)
217
+ .numpy()
218
+ )
219
+
220
+ processed_images.append(image)
221
+
222
+ patches = np.array(processed_images)
223
+ if data_format == ChannelDimension.LAST:
224
+ patches = patches.transpose(0, 3, 1, 2)
225
+ if patches.shape[0] % self.temporal_patch_size != 0:
226
+ repeats = np.repeat(patches[-1][np.newaxis], self.temporal_patch_size - 1, axis=0)
227
+ patches = np.concatenate([patches, repeats], axis=0)
228
+ channel = patches.shape[1]
229
+ grid_t = patches.shape[0] // self.temporal_patch_size
230
+ grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
231
+ patches = patches.reshape(
232
+ grid_t,
233
+ self.temporal_patch_size,
234
+ channel,
235
+ grid_h // self.merge_size,
236
+ self.merge_size,
237
+ self.patch_size,
238
+ grid_w // self.merge_size,
239
+ self.merge_size,
240
+ self.patch_size,
241
+ )
242
+ patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
243
+ flatten_patches = patches.reshape(
244
+ grid_t * grid_h * grid_w,
245
+ channel * self.temporal_patch_size * self.patch_size * self.patch_size,
246
+ )
247
+
248
+ return flatten_patches, (grid_t, grid_h, grid_w)
249
+
250
+ def preprocess(
251
+ self,
252
+ images: ImageInput,
253
+ do_resize: bool = None,
254
+ size: Dict[str, int] = None,
255
+ do_rescale: bool = None,
256
+ rescale_factor: float = None,
257
+ do_normalize: bool = None,
258
+ image_mean: Optional[Union[float, List[float]]] = None,
259
+ image_std: Optional[Union[float, List[float]]] = None,
260
+ do_convert_rgb: bool = None,
261
+ return_tensors: Optional[Union[str, TensorType]] = None,
262
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
263
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
264
+ ):
265
+ """
266
+ Args:
267
+ images (`ImageInput`):
268
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
269
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
270
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
271
+ Whether to resize the image.
272
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
273
+ Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
274
+ the longest edge resized to keep the input aspect ratio.
275
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
276
+ Whether to rescale the image.
277
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
278
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
279
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
280
+ Whether to normalize the image.
281
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
282
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
283
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
284
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
285
+ `True`.
286
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
287
+ Whether to convert the image to RGB.
288
+ return_tensors (`str` or `TensorType`, *optional*):
289
+ The type of tensors to return. Can be one of:
290
+ - Unset: Return a list of `np.ndarray`.
291
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
292
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
293
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
294
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
295
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
296
+ The channel dimension format for the output image. Can be one of:
297
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
298
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
299
+ - Unset: Use the channel dimension format of the input image.
300
+ input_data_format (`ChannelDimension` or `str`, *optional*):
301
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
302
+ from the input image. Can be one of:
303
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
304
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
305
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
306
+
307
+ """
308
+ do_resize = do_resize if do_resize is not None else self.do_resize
309
+ size = size if size is not None else self.size
310
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
311
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
312
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
313
+ image_mean = image_mean if image_mean is not None else self.image_mean
314
+ image_std = image_std if image_std is not None else self.image_std
315
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
316
+
317
+ if images is not None:
318
+ images = make_flat_list_of_images(images)
319
+
320
+ if images is not None and not valid_images(images):
321
+ raise ValueError(
322
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
323
+ "torch.Tensor, tf.Tensor or jax.ndarray."
324
+ )
325
+
326
+ if images is not None:
327
+ pixel_values, vision_grid_thws = [], []
328
+ for image in images:
329
+ patches, image_grid_thw = self._preprocess(
330
+ image,
331
+ do_resize=do_resize,
332
+ do_rescale=do_rescale,
333
+ rescale_factor=rescale_factor,
334
+ do_normalize=do_normalize,
335
+ image_mean=image_mean,
336
+ image_std=image_std,
337
+ data_format=data_format,
338
+ do_convert_rgb=do_convert_rgb,
339
+ input_data_format=input_data_format,
340
+ )
341
+ pixel_values.extend(patches)
342
+ vision_grid_thws.append(image_grid_thw)
343
+ pixel_values = np.array(pixel_values)
344
+ vision_grid_thws = np.array(vision_grid_thws)
345
+ data = {"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws}
346
+
347
+ return BatchFeature(data=data, tensor_type=return_tensors)
348
+
349
+
350
+ class Sarashina2VisionProcessorKwargs(ProcessingKwargs, total=False):
351
+ _defaults = {
352
+ "text_kwargs": {
353
+ "padding": False,
354
+ },
355
+ }
356
+
357
+
358
+ class Sarashina2VisionProcessor(ProcessorMixin):
359
+ r"""
360
+ Constructs Sarashina2Vision processor which wraps a Sarashina2Vision image processor and a LLama tokenizer into a single processor.
361
+ [`Sarashina2VisionProcessor`] offers all the functionalities of [`Sarashina2VisionImageProcessor`] and [`LlamaTokenizerFast`]. See the
362
+ [`~Sarashina2VisionProcessor.__call__`] and [`~Sarashina2VisionProcessor.decode`] for more information.
363
+ Args:
364
+ image_processor ([`Sarashina2VisionImageProcessor`], *optional*):
365
+ The image processor is a required input.
366
+ tokenizer ([`LlamaTokenizerFast`], *optional*):
367
+ The tokenizer is a required input.
368
+ chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
369
+ in a chat into a tokenizable string.
370
+ """
371
+
372
+ attributes = ["image_processor", "tokenizer"]
373
+ valid_kwargs = ["chat_template"]
374
+ image_processor_class = "AutoImageProcessor"
375
+ tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
376
+
377
+ def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
378
+ self.image_token = (
379
+ "<|file|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
380
+ )
381
+ super().__init__(image_processor, tokenizer, chat_template=chat_template)
382
+
383
+ def __call__(
384
+ self,
385
+ images: ImageInput = None,
386
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
387
+ **kwargs: Unpack[Sarashina2VisionProcessorKwargs],
388
+ ) -> BatchFeature:
389
+ """
390
+ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
391
+ and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
392
+ the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwargs` arguments to
393
+ Sarashina2VisionImageProcessor's [`~Sarashina2VisionImageProcessor.__call__`] if `vision_infos` is not `None`.
394
+
395
+ Args:
396
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
397
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
398
+ tensor. Both channels-first and channels-last formats are supported.
399
+ text (`str`, `List[str]`, `List[List[str]]`):
400
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
401
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
402
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
403
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
404
+ If set, will return tensors of a particular framework. Acceptable values are:
405
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
406
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
407
+ - `'np'`: Return NumPy `np.ndarray` objects.
408
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
409
+
410
+ Returns:
411
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
412
+
413
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
414
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
415
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
416
+ `None`).
417
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
418
+ - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
419
+ """
420
+ output_kwargs = self._merge_kwargs(
421
+ Sarashina2VisionProcessorKwargs,
422
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
423
+ **kwargs,
424
+ )
425
+ if images is not None:
426
+ image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
427
+ image_grid_thw = image_inputs["image_grid_thw"]
428
+ else:
429
+ image_inputs = {}
430
+ image_grid_thw = None
431
+
432
+ if not isinstance(text, list):
433
+ text = [text]
434
+
435
+ if image_grid_thw is not None:
436
+ merge_length = self.image_processor.merge_size**2
437
+ index = 0
438
+ for i in range(len(text)):
439
+ while self.image_token in text[i]:
440
+ text[i] = text[i].replace(
441
+ self.image_token,
442
+ "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length),
443
+ 1,
444
+ )
445
+ index += 1
446
+ text[i] = text[i].replace("<|placeholder|>", self.image_token)
447
+
448
+ text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
449
+
450
+ return BatchFeature(data={**text_inputs, **image_inputs})
451
+
452
+ def batch_decode(self, *args, **kwargs):
453
+ """
454
+ This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`].
455
+ """
456
+ return self.tokenizer.batch_decode(*args, **kwargs)
457
+
458
+ def decode(self, *args, **kwargs):
459
+ """
460
+ This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`].
461
+ """
462
+ return self.tokenizer.decode(*args, **kwargs)
463
+
464
+ def post_process_image_text_to_text(self, generated_outputs):
465
+ """
466
+ Post-process the output of the model to decode the text.
467
+
468
+ Args:
469
+ generated_outputs (`torch.Tensor` or `np.ndarray`):
470
+ The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
471
+ or `(sequence_length,)`.
472
+
473
+ Returns:
474
+ `List[str]`: The decoded text.
475
+ """
476
+ return self.tokenizer.batch_decode(
477
+ generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
478
+ )
479
+
480
+ @property
481
+ def model_input_names(self):
482
+ tokenizer_input_names = self.tokenizer.model_input_names
483
+ image_processor_input_names = self.image_processor.model_input_names
484
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
485
+
486
+
487
+ Sarashina2VisionProcessor.register_for_auto_class("AutoProcessor")
488
+ AutoImageProcessor.register("Sarashina2VisionImageProcessor", Sarashina2VisionImageProcessor)
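
A minimal end-to-end sketch of the processor on a synthetic image, assuming the hosted repo id and Pillow; it shows the placeholder expansion and tensor shapes produced by the code above.

```python
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained(
    "sbintuitions/sarashina2.2-vision-3b", trust_remote_code=True
)

image = Image.new("RGB", (448, 448), color=(128, 128, 128))
prompt = "<|user|><|prefix|><|file|><|suffix|>Describe the image.</s><|assistant|>"

inputs = processor(text=[prompt], images=[image], return_tensors="pt")

# The single <|file|> token is expanded to image_grid_thw.prod() // merge_size**2
# placeholder tokens, one per merged vision patch.
print(inputs["pixel_values"].shape)  # (num_patches, channels * 2 * 14 * 14)
print(inputs["image_grid_thw"])      # e.g. tensor([[ 1, 32, 32]]) for a 448x448 input
print(inputs["input_ids"].shape)
```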
processor_config.json ADDED
@@ -0,0 +1,6 @@
+ {
+   "auto_map": {
+     "AutoProcessor": "processing_sarashina2_vision.Sarashina2VisionProcessor"
+   },
+   "processor_class": "Sarashina2VisionProcessor"
+ }
sample.jpg ADDED

Git LFS Details

  • SHA256: 0d59c72e12d79c7add1cfdd1d044d0f47dca95b0ce75160f740e8c9e69de5473
  • Pointer size: 131 Bytes
  • Size of remote file: 819 kB
special_tokens_map.json ADDED
@@ -0,0 +1,51 @@
+ {
+   "bos_token": {
+     "content": "<s>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "cls_token": {
+     "content": "<cls>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "eos_token": {
+     "content": "</s>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "mask_token": {
+     "content": "<mask>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": {
+     "content": "<pad>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "sep_token": {
+     "content": "<sep>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "unk_token": {
+     "content": "<unk>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:008293028e1a9d9a1038d9b63d989a2319797dfeaa03f171093a57b33a3a8277
+ size 1831879
tokenizer_config.json ADDED
@@ -0,0 +1,176 @@
1
+ {
2
+ "add_bos_token": false,
3
+ "add_dummy_prefix_space": false,
4
+ "add_eos_token": false,
5
+ "add_prefix_space": false,
6
+ "added_tokens_decoder": {
7
+ "0": {
8
+ "content": "<unk>",
9
+ "lstrip": false,
10
+ "normalized": false,
11
+ "rstrip": false,
12
+ "single_word": false,
13
+ "special": true
14
+ },
15
+ "1": {
16
+ "content": "<s>",
17
+ "lstrip": false,
18
+ "normalized": false,
19
+ "rstrip": false,
20
+ "single_word": false,
21
+ "special": true
22
+ },
23
+ "2": {
24
+ "content": "</s>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false,
29
+ "special": true
30
+ },
31
+ "3": {
32
+ "content": "<pad>",
33
+ "lstrip": false,
34
+ "normalized": false,
35
+ "rstrip": false,
36
+ "single_word": false,
37
+ "special": true
38
+ },
39
+ "4": {
40
+ "content": "<sep>",
41
+ "lstrip": false,
42
+ "normalized": false,
43
+ "rstrip": false,
44
+ "single_word": false,
45
+ "special": true
46
+ },
47
+ "5": {
48
+ "content": "<mask>",
49
+ "lstrip": false,
50
+ "normalized": false,
51
+ "rstrip": false,
52
+ "single_word": false,
53
+ "special": true
54
+ },
55
+ "6": {
56
+ "content": "<cls>",
57
+ "lstrip": false,
58
+ "normalized": false,
59
+ "rstrip": false,
60
+ "single_word": false,
61
+ "special": true
62
+ },
63
+ "7": {
64
+ "content": "<|system|>",
65
+ "lstrip": false,
66
+ "normalized": false,
67
+ "rstrip": false,
68
+ "single_word": false,
69
+ "special": false
70
+ },
71
+ "8": {
72
+ "content": "<|assistant|>",
73
+ "lstrip": false,
74
+ "normalized": false,
75
+ "rstrip": false,
76
+ "single_word": false,
77
+ "special": false
78
+ },
79
+ "9": {
80
+ "content": "<|user|>",
81
+ "lstrip": false,
82
+ "normalized": false,
83
+ "rstrip": false,
84
+ "single_word": false,
85
+ "special": false
86
+ },
87
+ "10": {
88
+ "content": "<|available_tools|>",
89
+ "lstrip": false,
90
+ "normalized": false,
91
+ "rstrip": false,
92
+ "single_word": false,
93
+ "special": false
94
+ },
95
+ "11": {
96
+ "content": "<|tool_calls|>",
97
+ "lstrip": false,
98
+ "normalized": false,
99
+ "rstrip": false,
100
+ "single_word": false,
101
+ "special": false
102
+ },
103
+ "12": {
104
+ "content": "<|tool_results|>",
105
+ "lstrip": false,
106
+ "normalized": false,
107
+ "rstrip": false,
108
+ "single_word": false,
109
+ "special": false
110
+ },
111
+ "13": {
112
+ "content": "<|code|>",
113
+ "lstrip": false,
114
+ "normalized": false,
115
+ "rstrip": false,
116
+ "single_word": false,
117
+ "special": false
118
+ },
119
+ "14": {
120
+ "content": "<|file|>",
121
+ "lstrip": false,
122
+ "normalized": false,
123
+ "rstrip": false,
124
+ "single_word": false,
125
+ "special": false
126
+ },
127
+ "102397": {
128
+ "content": "<|prefix|>",
129
+ "lstrip": false,
130
+ "normalized": false,
131
+ "rstrip": false,
132
+ "single_word": false,
133
+ "special": false
134
+ },
135
+ "102398": {
136
+ "content": "<|suffix|>",
137
+ "lstrip": false,
138
+ "normalized": false,
139
+ "rstrip": false,
140
+ "single_word": false,
141
+ "special": false
142
+ },
143
+ "102399": {
144
+ "content": "<|middle|>",
145
+ "lstrip": false,
146
+ "normalized": false,
147
+ "rstrip": false,
148
+ "single_word": false,
149
+ "special": false
150
+ }
151
+ },
152
+ "auto_map": {
153
+ "AutoProcessor": "processing_sarashina2_vision.Sarashina2VisionProcessor"
154
+ },
155
+ "bos_token": "<s>",
156
+ "chat_template": "{% set image_count = namespace(value=0) %}{% for message in messages %}{% if message['content'] is string %}{% if message['role'] == 'user' %}{{ '<|user|>' + message['content'] + eos_token }}{% elif message['role'] == 'assistant' %}{{ '<|assistant|>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ '<|system|>' + message['content'] + eos_token }}{% endif %}{% else %}{% if message['role'] == 'user' %}{{ '<|user|>' }}{% elif message['role'] == 'assistant' %}{{ '<|assistant|>' }}{% elif message['role'] == 'system' %}{{ '<|system|>' }}{% endif %}{% for content in message['content'] %}{% if content['type'] == 'image' or content.get('image') or content.get('image_url') %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %} Picture {{ image_count.value }}: {% endif %}{{ '<|prefix|><|file|><|suffix|>' }}{% endif %}{% endfor %}{% for content in message['content'] %}{% if content.get('text') %}{{ content['text'] }}{% endif %}{% endfor %}{{ eos_token }}{% endif %}{% endfor %}{% if messages[-1]['role'] == 'user' %}{{ '<|assistant|>' }}{% endif %}",
157
+ "clean_up_tokenization_spaces": false,
158
+ "cls_token": "<cls>",
159
+ "do_lower_case": false,
160
+ "eos_token": "</s>",
161
+ "extra_ids": 0,
162
+ "extra_special_tokens": {},
163
+ "keep_accents": true,
164
+ "legacy": false,
165
+ "mask_token": "<mask>",
166
+ "model_max_length": 8192,
167
+ "pad_token": "<pad>",
168
+ "padding_side": "right",
169
+ "processor_class": "Sarashina2VisionProcessor",
170
+ "sep_token": "<sep>",
171
+ "sp_model_kwargs": {},
172
+ "spaces_between_special_tokens": false,
173
+ "tokenizer_class": "LlamaTokenizer",
174
+ "unk_token": "<unk>",
175
+ "use_default_system_prompt": false
176
+ }
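
Finally, a small sketch for confirming the special-token ids that the vision stack relies on; the repo id is assumed as above, and the expected ids are the ones listed in this tokenizer config and in config.json.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("sbintuitions/sarashina2.2-vision-3b")

for token in ["<|user|>", "<|assistant|>", "<|file|>", "<|prefix|>", "<|suffix|>"]:
    print(token, tokenizer.convert_tokens_to_ids(token))
# Expected: <|user|> 9, <|assistant|> 8, <|file|> 14, <|prefix|> 102397, <|suffix|> 102398
```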