#!/usr/bin/env python
"""
Script to quantize the huihui-ai/Huihui-Fara-7B-abliterated model with Qwen2.5-VL architecture support
Uses sequential onloading for memory efficiency.
"""
import base64
from io import BytesIO
import torch
from datasets import load_dataset
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier
from llmcompressor.modifiers.awq import AWQModifier, AWQMapping
from llmcompressor.utils import dispatch_for_generation
def create_qwen2_5_vl_data_collator():
"""Create a data collator for Qwen2.5-VL models that handles multimodal inputs."""
def data_collator(batch):
assert len(batch) == 1
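        # Tensorize list- and scalar-valued features (e.g. input_ids, attention_mask,
        # pixel_values); pass any other values through unchanged.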
return {key: torch.tensor(value) if isinstance(value, (list, int, float)) else value
for key, value in batch[0].items()}
return data_collator
def create_qwen2_5_vl_preprocessing_fn(processor, max_sequence_length: int = 2048):
"""Create a preprocessing function for Qwen2.5-VL datasets."""
def preprocess_and_tokenize(example):
# Handle different image formats
if 'image' in example:
# Process image
if hasattr(example['image'], 'save'):
# PIL Image object
buffered = BytesIO()
example["image"].save(buffered, format="PNG")
encoded_image = base64.b64encode(buffered.getvalue())
encoded_image_text = encoded_image.decode("utf-8")
base64_qwen = f"data:image;base64,{encoded_image_text}"
else:
# Already a string or other format
base64_qwen = str(example["image"])
else:
# If there's no image field, try 'img' or similar
img_key = None
for key in example.keys():
if 'image' in key.lower() or 'img' in key.lower():
img_key = key
break
if img_key:
if hasattr(example[img_key], 'save'):
buffered = BytesIO()
example[img_key].save(buffered, format="PNG")
encoded_image = base64.b64encode(buffered.getvalue())
encoded_image_text = encoded_image.decode("utf-8")
base64_qwen = f"data:image;base64,{encoded_image_text}"
else:
base64_qwen = str(example[img_key])
else:
# If no image, create a simple text-only example
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": example.get('text', example.get('content', 'What can you tell me about this?'))},
],
}
]
text = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
return processor(
text=[text],
padding=False,
max_length=max_sequence_length,
truncation=True,
)
# Create message with image
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": base64_qwen},
{"type": "text", "text": "What does the image show?"},
],
}
]
text = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
# tokenize
return processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=False,
max_length=max_sequence_length,
truncation=True,
)
return preprocess_and_tokenize
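# Typical use (a sketch mirroring quantize_huihui_fara_model() below; the dataset and
# split are illustrative):
#   processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
#   preprocess = create_qwen2_5_vl_preprocessing_fn(processor, max_sequence_length=2048)
#   ds = load_dataset("lmms-lab/flickr30k", split="test[:64]")
#   ds = ds.map(preprocess, remove_columns=ds.column_names)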
def get_qwen2_5_vl_quantization_recipe(method: str, scheme: str = "W4A16"):
"""
Creates the appropriate quantization recipe for Qwen2.5-VL models.
Args:
method: Quantization method ("GPTQ", "AWQ", or "FP8")
scheme: Quantization scheme (e.g., "W4A16", "W8A8", "FP8")
Returns:
List of modifiers for the quantization recipe
"""
if method == "GPTQ":
return [
GPTQModifier(
targets="Linear",
scheme=scheme,
ignore=["lm_head", "re:visual.*", "re:model.visual.*"],
sequential_targets=["Qwen2_5_VLDecoderLayer"], # This enables sequential onloading
),
]
elif method == "AWQ":
# Create AWQ mappings for Qwen2.5-VL architecture
mappings = [
AWQMapping(
"re:.*input_layernorm", ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"]
),
AWQMapping("re:.*v_proj", ["re:.*o_proj"]),
AWQMapping(
"re:.*post_attention_layernorm", ["re:.*gate_proj", "re:.*up_proj"]
),
AWQMapping("re:.*up_proj", ["re:.*down_proj"]),
]
return [
AWQModifier(
ignore=["lm_head", "re:visual.*", "re:model.visual.*"],
scheme="W4A16_ASYM" if scheme == "W4A16" else scheme,
targets=["Linear"],
mappings=mappings,
sequential_targets=["Qwen2_5_VLDecoderLayer"], # Sequential onloading for memory efficiency
),
]
elif method == "FP8":
return [
QuantizationModifier(
scheme="FP8",
targets="Linear",
ignore=["lm_head", "re:visual.*", "re:model.visual.*"]
)
]
else:
raise ValueError(f"Unsupported quantization method: {method}")
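# Example (illustrative): a 4-bit weight-only GPTQ recipe that keeps lm_head and the
# vision tower in full precision, as used by quantize_huihui_fara_model() below:
#   recipe = get_qwen2_5_vl_quantization_recipe("GPTQ", scheme="W4A16")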
def quantize_huihui_fara_model(
model_id: str = "huihui-ai/Huihui-Fara-7B-abliterated",
quantization_method: str = "GPTQ",
output_dir: str = None,
dataset_id: str = "wikitext",
dataset_config: str = "wikitext-2-raw-v1",
dataset_split: str = "train[:1%]",
num_calibration_samples: int = 64,
max_sequence_length: int = 512,
scheme: str = "W4A16",
trust_remote_code: bool = True,
):
"""
Quantizes the huihui-ai/Huihui-Fara-7B-abliterated model with proper Qwen2.5-VL architecture support.
Args:
model_id: Hugging Face model ID to quantize
quantization_method: Method to use ("GPTQ", "AWQ", or "FP8")
output_dir: Directory to save the quantized model
dataset_id: Dataset ID for calibration
dataset_config: Dataset config for calibration
dataset_split: Dataset split for calibration
num_calibration_samples: Number of samples to use for calibration
max_sequence_length: Maximum sequence length for processing
scheme: Quantization scheme (e.g., "W4A16", "W8A8")
trust_remote_code: Whether to trust remote code in model loading
Returns:
Quantized model
"""
print(f"Loading model: {model_id}")
# Handle different device scenarios properly
if torch.cuda.is_available():
try:
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
model_id,
torch_dtype=torch.float16, # Use float16 to save memory
device_map="auto", # Auto device mapping for memory efficiency
trust_remote_code=trust_remote_code
)
except RuntimeError as e:
if "out of memory" in str(e).lower() or "offload_dir" in str(e):
print(f"Memory issue detected, using offloading: {e}")
import tempfile
with tempfile.TemporaryDirectory() as temp_dir:
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
model_id,
torch_dtype=torch.float16,
device_map="auto",
offload_folder=temp_dir,
max_memory={0: "24GB", "cpu": "48GB"},
trust_remote_code=trust_remote_code
)
else:
raise
else:
# CPU only
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
model_id,
torch_dtype=torch.float32, # Use float32 on CPU
device_map="cpu",
trust_remote_code=trust_remote_code
)
print(f"Loading processor for: {model_id}")
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=trust_remote_code)
# If output directory not specified, create one based on model and method
if not output_dir:
model_name = model_id.rstrip("/").split("/")[-1]
output_dir = f"{model_name}-{scheme.replace(':', '-')}-{quantization_method}"
print(f"Output directory: {output_dir}")
# Load dataset and preprocess
print(f"Loading dataset: {dataset_id}")
try:
# Try to load a multimodal dataset first
ds = load_dataset("lmms-lab/flickr30k", split="test[:64]")
print("Using multimodal dataset for calibration")
preprocess_fn = create_qwen2_5_vl_preprocessing_fn(processor, max_sequence_length)
ds = ds.map(preprocess_fn, remove_columns=ds.column_names)
except Exception as e:
print(f"Failed to load multimodal dataset: {e}, falling back to text-only dataset")
# If multimodal dataset fails, use text-only
ds = load_dataset(dataset_id, dataset_config, split=dataset_split)
ds = ds.shuffle(seed=42)
# Text-only preprocessing
def text_only_preprocess(example):
text = example.get('text', example.get('content', str(example)))
if not isinstance(text, str):
text = str(text)
# Limit text length to avoid exceeding max sequence length
            text = (text[:500] + "...") if len(text) > 500 else text
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": text},
],
}
]
prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
return processor(text=[prompt], padding=False, max_length=max_sequence_length, truncation=True)
ds = ds.map(text_only_preprocess, remove_columns=ds.column_names)
# Define data collator
data_collator = create_qwen2_5_vl_data_collator()
# Create recipe
recipe = get_qwen2_5_vl_quantization_recipe(quantization_method, scheme)
print(f"Starting quantization with method: {quantization_method}")
print(f"Using recipe: {recipe}")
print(f"Using sequential targets: {[mod.sequential_targets if hasattr(mod, 'sequential_targets') else 'N/A' for mod in recipe]}")
# Perform oneshot quantization with sequential onloading for memory efficiency
oneshot(
model=model,
tokenizer=processor, # Use processor as tokenizer for Qwen2.5-VL
dataset=ds,
recipe=recipe,
max_seq_length=max_sequence_length,
num_calibration_samples=num_calibration_samples,
trust_remote_code_model=trust_remote_code,
data_collator=data_collator,
save_compressed=True,
output_dir=output_dir,
)
print(f"Quantization completed! Model saved to: {output_dir}")
# Save the processor as well
processor.save_pretrained(output_dir)
return model
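# Example (programmatic use; argument values are illustrative):
#   model = quantize_huihui_fara_model(
#       quantization_method="AWQ",
#       scheme="W4A16",
#       num_calibration_samples=128,
#   )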
def test_quantized_model(model, processor, max_sequence_length: int = 2048):
"""
Tests the quantized model with a sample generation.
"""
print("========== SAMPLE GENERATION ==============")
try:
dispatch_for_generation(model)
# Simple text-only test first
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": "Hello, how are you today?"},
],
}
]
        prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(
text=[prompt],
padding=False,
max_length=max_sequence_length,
truncation=True,
return_tensors="pt",
).to(model.device)
output = model.generate(**inputs, max_new_tokens=50)
result = processor.decode(output[0], skip_special_tokens=True)
print(result)
print("==========================================")
return result
except Exception as e:
print(f"Test generation failed: {e}")
import traceback
traceback.print_exc()
return None
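# A multimodal smoke test could reuse the image message format from
# create_qwen2_5_vl_preprocessing_fn() above (sketch; the image source is illustrative):
#   messages = [{"role": "user", "content": [
#       {"type": "image", "image": "data:image;base64,<encoded image>"},
#       {"type": "text", "text": "What does the image show?"},
#   ]}]
#   image_inputs, video_inputs = process_vision_info(messages)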
def main():
"""
Main function to quantize the Huihui-Fara model.
"""
import argparse
parser = argparse.ArgumentParser(description="Quantize huihui-ai/Huihui-Fara-7B-abliterated model")
parser.add_argument("--model_id", type=str, default="huihui-ai/Huihui-Fara-7B-abliterated",
help="Model ID to quantize")
parser.add_argument("--method", type=str, choices=["GPTQ", "AWQ", "FP8"],
default="GPTQ", help="Quantization method to use")
parser.add_argument("--output_dir", type=str, default=None,
help="Output directory for quantized model")
parser.add_argument("--dataset_id", type=str, default="wikitext",
help="Dataset for calibration (default: wikitext)")
parser.add_argument("--scheme", type=str, default="W4A16",
help="Quantization scheme (e.g., W4A16, W8A8)")
parser.add_argument("--num_samples", type=int, default=64,
help="Number of calibration samples")
args = parser.parse_args()
print(f"Starting quantization of {args.model_id} using {args.method}")
print("Note: This may take a while and will use sequential onloading for memory efficiency...")
try:
# Quantize the model
quantized_model = quantize_huihui_fara_model(
model_id=args.model_id,
quantization_method=args.method,
output_dir=args.output_dir,
dataset_id=args.dataset_id,
num_calibration_samples=args.num_samples,
scheme=args.scheme
)
# Test the model
processor = AutoProcessor.from_pretrained(args.model_id, trust_remote_code=True)
test_quantized_model(quantized_model, processor)
print(f"✅ Successfully quantized {args.model_id} with {args.method}")
print(f"Model saved to: {args.output_dir or args.model_id.split('/')[-1] + f'-{args.scheme}-{args.method}'}")
except Exception as e:
print(f"❌ Quantization failed: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()