Spaces:
Runtime error
Shanshan Wang committed · Commit bcfef20 · 1 Parent(s): d6bfd67

Track binary files with Git LFS

Browse files:
- .gitattributes +1 -0
- app.py +69 -4
- assets/rental_application.png +3 -0
.gitattributes CHANGED

@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 assets/handwritten-note-example.jpg filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
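
The new *.png rule routes every PNG in the repository through Git LFS, so the newly added assets/rental_application.png is stored as a lightweight LFS pointer instead of a raw binary blob. A rule like this is typically produced by running git lfs track "*.png" and then committing the updated .gitattributes together with the image (this assumes Git LFS is already installed and initialized for the repository).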
app.py CHANGED

@@ -30,6 +30,29 @@ example_prompts = [
 ]
 
 
+# Function to handle task type logic
+def handle_task_type(task_type, model_name):
+    max_new_tokens = 1024  # Default value
+    if task_type == "OCR":
+        max_new_tokens = 3072  # Adjust for OCR
+    return max_new_tokens
+
+# Function to handle task type logic and default question
+def handle_task_type_and_prompt(task_type, model_name):
+    max_new_tokens = handle_task_type(task_type, model_name)
+    default_question = example_prompts[0] if task_type == "OCR" else None
+    return max_new_tokens, default_question
+
+def update_task_type_on_model_change(model_name):
+    # Set default task type and max_new_tokens based on the model
+    if '2b' in model_name.lower():
+        return "Document extractor", handle_task_type("Document extractor", model_name)
+    elif '0.8b' in model_name.lower():
+        return "OCR", handle_task_type("OCR", model_name)
+    else:
+        return "Chat", handle_task_type("Chat", model_name)
+
+
 def load_model_and_set_image_function(model_name):
     # Get the model path from the model_paths dictionary
     model_path = model_paths[model_name]

@@ -245,10 +268,34 @@ def regenerate_response(chatbot,
 def clear_all():
     return [], None, None, ""  # Clear chatbot, state, reset image_input
 
+
+title_html = """
+<h1> <span class="gradient-text" id="text">H2OVL-Mississippi</span><span class="plain-text">: Lightweight Vision Language Models for OCR and Doc AI tasks</span></h1>
+<a href="https://huggingface.co/collections/h2oai/h2ovl-mississippi-66e492da45da0a1b7ea7cf39">[π Hugging Face]</a>
+<a href="https://arxiv.org/abs/2410.13611">[π Paper]</a>
+<a href="https://huggingface.co/spaces/h2oai/h2ovl-mississippi-benchmarks">[π Benchmarks]</a>
+"""
+
+
+
 # Build the Gradio interface
 with gr.Blocks() as demo:
-    gr.
-    
+    gr.HTML(title_html)
+    gr.HTML("""
+        <style>
+        .gradient-text {
+            font-size: 36px !important;
+            font-weight: bold !important;
+        }
+        .plain-text {
+            font-size: 32px !important;
+        }
+        h1 {
+            margin-bottom: 20px !important;
+        }
+        </style>
+    """)
+
     state= gr.State()
     model_state = gr.State()
 

@@ -258,7 +305,12 @@ with gr.Blocks() as demo:
         label="Select Model",
         value="H2OVL-Mississippi-2B"
     )
-    
+
+    task_type_dropdown = gr.Dropdown(
+        choices=["OCR", "Document extractor", "Chat"],
+        label="Select Task Type",
+        value="Document extractor"
+    )
 
     with gr.Row(equal_height=True):
         # First column with image input

@@ -293,7 +345,7 @@ with gr.Blocks() as demo:
         inputs=None,
         outputs=[chatbot, state]
     )
-    
+
 
     # Reset chatbot and state when image input changes
     image_input.change(

@@ -343,6 +395,18 @@ with gr.Blocks() as demo:
             label="Tile Number (default: 6)"
         )
 
+    model_dropdown.change(
+        fn=update_task_type_on_model_change,
+        inputs=[model_dropdown],
+        outputs=[task_type_dropdown, max_new_tokens_input]
+    )
+
+    task_type_dropdown.change(
+        fn=handle_task_type_and_prompt,
+        inputs=[task_type_dropdown, model_dropdown],
+        outputs=[max_new_tokens_input, user_input]
+    )
+
     with gr.Row():
         submit_button = gr.Button("Submit")
         regenerate_button = gr.Button("Regenerate")

@@ -394,6 +458,7 @@ with gr.Blocks() as demo:
     gr.Examples(
         examples=[
             ["assets/handwritten-note-example.jpg", "Read the text on the image"],
+            ["assets/rental_application.png", "Read the text and provide word by word ocr for the document. <doc>"],
             ["assets/receipt.jpg", "Extract the text from the image."],
             ["assets/driver_license.png", "Extract the text from the image and fill the following json {'license_number':'',\n'full_name':'',\n'date_of_birth':'',\n'address':'',\n'issue_date':'',\n'expiration_date':'',\n}"],
             ["assets/invoice.png", "Please extract the following fields, and return the result in JSON format: supplier_name, supplier_address, customer_name, customer_address, invoice_number, invoice_total_amount, invoice_tax_amount"],
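
A quick sketch of how the newly added task-type helpers are expected to behave, based only on the definitions in this diff (the 0.8B model name below is illustrative; only the '2b' / '0.8b' substring check in update_task_type_on_model_change matters):

    # Model change sets the default task type and token budget:
    update_task_type_on_model_change("H2OVL-Mississippi-2B")    # -> ("Document extractor", 1024)
    update_task_type_on_model_change("H2OVL-Mississippi-0.8B")  # -> ("OCR", 3072)

    # Task change adjusts the token budget and may pre-fill the prompt:
    handle_task_type_and_prompt("OCR", "H2OVL-Mississippi-0.8B")   # -> (3072, example_prompts[0])
    handle_task_type_and_prompt("Chat", "H2OVL-Mississippi-2B")    # -> (1024, None)

These return pairs line up with the outputs of the two .change callbacks: model_dropdown.change feeds (task_type_dropdown, max_new_tokens_input), while task_type_dropdown.change feeds (max_new_tokens_input, user_input).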
assets/rental_application.png ADDED

Git LFS Details