# =========================
# CAMEL-DOC-OCR (HF Spaces SAFE)
# Single-file – NO CUDA init at global scope
# =========================
import os
import gc
import torch
import fitz
import gradio as gr
import spaces
from PIL import Image
from transformers import (
    AutoProcessor,
    BitsAndBytesConfig,
    Qwen2_5_VLForConditionalGeneration,  # top-level export in transformers releases with Qwen2.5-VL support
)
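# Assumed runtime dependencies (not pinned anywhere in this file): gradio,
# spaces, transformers, accelerate, bitsandbytes, pymupdf (imported as
# `fitz`), and pillow.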
# =========================
# CONFIG
# =========================
MODEL_ID = "prithivMLmods/Camel-Doc-OCR-062825"
DPI = 150
MAX_IMAGE_SIZE = 2048
# =========================
# TORCH FLAGS (SAFE FOR SPACES)
# =========================
torch.set_grad_enabled(False)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
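# TF32 trades a few mantissa bits for faster matmuls on Ampere+ GPUs, and
# gradients are disabled globally because this app only ever runs inference.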
# =========================
# LOAD MODEL (NO CUDA INIT HERE)
# =========================
bnb = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
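# The NF4 scheme above stores weights in 4 bits (double quantization also
# compresses the quantization constants) while computing in fp16, cutting
# weight VRAM roughly 4x versus fp16 at a small accuracy cost.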
processor = AutoProcessor.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    quantization_config=bnb,
    device_map="auto",  # HF Spaces injects the GPU here
    torch_dtype=torch.float16,
    trust_remote_code=True,
).eval()
processor.tokenizer.pad_token_id = processor.tokenizer.eos_token_id
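# Some Qwen checkpoints ship without a dedicated pad token; aliasing it to
# EOS above lets `generate()` pad batched inputs without warnings.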
# =========================
# PDF → IMAGE (FAST & SAFE)
# =========================
def pdf_to_images(pdf_bytes):
    """Render each PDF page to an RGB PIL image at DPI, capped at MAX_IMAGE_SIZE."""
    doc = fitz.open(stream=pdf_bytes, filetype="pdf")
    images = []
    scale = DPI / 72.0  # PDF user space is 72 units per inch
    mat = fitz.Matrix(scale, scale)
    for page in doc:
        pix = page.get_pixmap(matrix=mat, colorspace=fitz.csRGB)
        img = Image.frombytes("RGB", (pix.width, pix.height), pix.samples)
        if max(img.size) > MAX_IMAGE_SIZE:
            img.thumbnail((MAX_IMAGE_SIZE, MAX_IMAGE_SIZE), Image.Resampling.LANCZOS)
        images.append(img)
    doc.close()  # release the document handle once all pages are rendered
    return images
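# At DPI = 150 the render scale is 150 / 72 ≈ 2.08, so an A4 page comes out
# around 1240x1754 px before the MAX_IMAGE_SIZE clamp kicks in.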
# =========================
# OCR INFERENCE (CUDA ONLY HERE)
# =========================
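# @spaces.GPU requests a GPU for the duration of each call on ZeroGPU Spaces,
# which is why all CUDA work is confined to this function.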
@spaces.GPU
def run_inference(image, prompt, max_new_tokens):
    """Run one Qwen2.5-VL generation pass over a single image."""
    if image.mode != "RGB":
        image = image.convert("RGB")
    messages = [{
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": prompt},
        ],
    }]
    text_prompt = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = processor(
        text=[text_prompt],
        images=[image],
        return_tensors="pt",
        truncation=False,   # MANDATORY: never truncate away vision tokens
        padding="longest",  # MANDATORY: pad to the longest sequence
    ).to(model.device)
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding: deterministic output suits OCR
            use_cache=True,
            eos_token_id=processor.tokenizer.eos_token_id,
        )
    # Strip the echoed prompt tokens so only newly generated text is decoded.
    outputs = outputs[:, inputs["input_ids"].shape[1]:]
    return processor.tokenizer.decode(
        outputs[0],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    ).strip()
# =========================
# FILE HANDLER
# =========================
def handle_file(file, prompt, max_new_tokens, progress=gr.Progress()):
    if file is None:  # the button can fire before anything is uploaded
        return "Please upload an image or PDF first."
    file_path = file.name
    ext = file_path.lower().split(".")[-1]
    prompt = prompt.strip()
    if ext == "pdf":
        with open(file_path, "rb") as f:
            images = pdf_to_images(f.read())
        results = []
        for i, img in enumerate(images):
            text = run_inference(img, prompt, max_new_tokens)
            results.append(text)
            progress((i + 1) / len(images), desc=f"Page {i + 1}/{len(images)}")
        return "\n\n--- PAGE BREAK ---\n\n".join(results)
    else:
        img = Image.open(file_path)
        return run_inference(img, prompt, max_new_tokens)
# =========================
# DEFAULT PROMPT (CAMEL OCR)
# =========================
DEFAULT_PROMPT = """
You are an OCR + Information Extraction engine.
Extract data strictly from the document.
Return JSON ONLY. NO explanation.
OUTPUT FORMAT:
{
  "price": "",
  "vat": "",
  "invoiceNo": "",
  "invoiceDate": "",
  "billingToTaxCode": "",
  "accountingObjectTaxCode": "",
  "description": ""
}
""".strip()
# =========================
# GRADIO UI
# =========================
with gr.Blocks(title="Camel-Doc-OCR") as demo:
    gr.Markdown("## 🧾 Camel-Doc-OCR (Qwen2.5-VL – 4-bit, HF Spaces safe)")
    with gr.Row():
        with gr.Column(scale=1):
            file_input = gr.File(
                label="Upload Image / PDF",
                file_types=[".jpg", ".jpeg", ".png", ".pdf"],
            )
            prompt_input = gr.Textbox(
                label="Prompt",
                value=DEFAULT_PROMPT,
                lines=10,
            )
            max_tokens = gr.Radio(
                [256, 512, 1024, 2048],
                value=512,
                label="Max new tokens",
            )
            run_btn = gr.Button("🚀 Run OCR", variant="primary")
        with gr.Column(scale=1):
            output = gr.Textbox(
                label="Result",
                lines=20,
            )
    run_btn.click(
        fn=handle_file,
        inputs=[file_input, prompt_input, max_tokens],
        outputs=output,
    )
# =========================
# CLEANUP
# =========================
def cleanup():
    torch.cuda.empty_cache()
    gc.collect()
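# cleanup() is never invoked in the original file. One hedged option is to
# call it at the end of handle_file; another (Gradio 4.x+) is to register it
# inside the Blocks context with `demo.unload(cleanup)` so it runs whenever a
# client disconnects.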
# =========================
# LAUNCH
# =========================
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,  # ignored on HF Spaces; only relevant when running locally
    )