Joseph Pollack committed on
Commit 6b4a0c8 · unverified · 1 Parent(s): a0c936d

attempts to add an annotated image component with bounding boxes

Files changed (1)
  1. app.py +160 -12
app.py CHANGED

@@ -1,6 +1,6 @@
 import gradio as gr
 import torch
-from PIL import Image
+from PIL import Image, ImageDraw, ImageFont
 import json
 import os
 from transformers import AutoProcessor, AutoModelForImageTextToText
@@ -23,6 +23,107 @@ if not HF_TOKEN:
     logger.warning("HF_TOKEN not found in environment variables. Model access may be restricted.")
     logger.warning("Please set HF_TOKEN in your environment variables or Spaces secrets.")

+def create_annotated_image(image: Image.Image, x: int, y: int, action_type: str = "click") -> Image.Image:
+    """Create an image with a bounding box around the specified coordinates"""
+    try:
+        # Create a copy of the original image
+        annotated_image = image.copy()
+        draw = ImageDraw.Draw(annotated_image)
+
+        # Define bounding box parameters - make it generous as requested
+        box_size = 120  # Increased size for more generous bounding box
+        box_color = (255, 0, 0)  # Red color
+        line_width = 4  # Thicker line for better visibility
+
+        # Calculate bounding box coordinates
+        left = max(0, x - box_size // 2)
+        top = max(0, y - box_size // 2)
+        right = min(image.width, x + box_size // 2)
+        bottom = min(image.height, y + box_size // 2)
+
+        # Draw the bounding box with rounded corners effect
+        draw.rectangle([left, top, right, bottom], outline=box_color, width=line_width)
+
+        # Draw corner indicators for better visibility
+        corner_size = 15
+        # Top-left corner
+        draw.line([left, top, left + corner_size, top], fill=box_color, width=line_width)
+        draw.line([left, top, left, top + corner_size], fill=box_color, width=line_width)
+        # Top-right corner
+        draw.line([right - corner_size, top, right, top], fill=box_color, width=line_width)
+        draw.line([right, top, right, top + corner_size], fill=box_color, width=line_width)
+        # Bottom-left corner
+        draw.line([left, bottom - corner_size, left, bottom], fill=box_color, width=line_width)
+        draw.line([left, bottom, left + corner_size, bottom], fill=box_color, width=line_width)
+        # Bottom-right corner
+        draw.line([right - corner_size, bottom, right, bottom], fill=box_color, width=line_width)
+        draw.line([right, bottom - corner_size, right, bottom], fill=box_color, width=line_width)
+
+        # Draw a crosshair at the exact point
+        crosshair_size = 15
+        crosshair_color = (255, 255, 0)  # Yellow crosshair for contrast
+        draw.line([x - crosshair_size, y, x + crosshair_size, y], fill=crosshair_color, width=3)
+        draw.line([x, y - crosshair_size, x, y + crosshair_size], fill=crosshair_color, width=3)
+
+        # Add a small circle at the center
+        circle_radius = 4
+        draw.ellipse([x - circle_radius, y - circle_radius, x + circle_radius, y + circle_radius],
+                     fill=crosshair_color, outline=box_color, width=2)
+
+        # Add text label with better positioning
+        try:
+            font = ImageFont.load_default()
+        except:
+            font = ImageFont.load_default()
+
+        label_text = f"{action_type.upper()}: ({x}, {y})"
+        text_bbox = draw.textbbox((0, 0), label_text, font=font)
+        text_width = text_bbox[2] - text_bbox[0]
+        text_height = text_bbox[3] - text_bbox[1]
+
+        # Position text above the bounding box, but ensure it's visible
+        text_x = max(5, left)
+        text_y = max(5, top - text_height - 10)
+
+        # If text would go off the top, position it below the box
+        if text_y < 5:
+            text_y = min(image.height - text_height - 5, bottom + 10)
+
+        # Draw text background with better contrast
+        draw.rectangle([text_x - 4, text_y - 4, text_x + text_width + 4, text_y + text_height + 4],
+                       fill=(0, 0, 0, 180))
+
+        # Draw text
+        draw.text((text_x, text_y), label_text, fill=(255, 255, 255), font=font)
+
+        return annotated_image
+
+    except Exception as e:
+        logger.error(f"Error creating annotated image: {str(e)}")
+        return image  # Return original image if annotation fails
+
+def parse_action_response(response: str) -> tuple:
+    """Parse the action response and extract coordinates if present"""
+    try:
+        # Try to parse as JSON
+        if response.strip().startswith('{'):
+            action_data = json.loads(response)
+
+            # Check if it's a click action with coordinates
+            if (action_data.get('action_type') == 'click' and
+                    'x' in action_data and 'y' in action_data):
+                return action_data, True
+            else:
+                return action_data, False
+        else:
+            return response, False
+
+    except json.JSONDecodeError:
+        return response, False
+    except Exception as e:
+        logger.error(f"Error parsing action response: {str(e)}")
+        return response, False
+
 class LOperatorDemo:
     def __init__(self):
         self.model = None
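For a quick sanity check outside the Space, the two new helpers can be exercised directly. A minimal sketch, assuming `create_annotated_image` and `parse_action_response` (plus their `json` and `logger` dependencies) have been pulled into scope; the screenshot size and coordinates below are invented:

```python
from PIL import Image

# Blank stand-in for an Android screenshot (1080x2400 is a common phone resolution)
screenshot = Image.new("RGB", (1080, 2400), color=(240, 240, 240))

# Simulated model response; only JSON click actions carrying both "x" and "y"
# come back as (dict, True)
response = '{"action_type": "click", "x": 540, "y": 1200}'
action, has_coords = parse_action_response(response)

if has_coords:
    # Draws the red 120px box, corner markers, yellow crosshair, and label
    annotated = create_annotated_image(screenshot, action["x"], action["y"],
                                       action["action_type"])
    annotated.save("annotated_screenshot.png")
```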
@@ -160,16 +261,16 @@ demo_instance = LOperatorDemo()
 def process_input(image, goal, step_instructions):
     """Process the input and generate action"""
     if image is None:
-        return "❌ Please upload an Android screenshot image."
+        return "❌ Please upload an Android screenshot image.", None

     if not goal.strip():
-        return "❌ Please provide a goal."
+        return "❌ Please provide a goal.", None

     if not step_instructions.strip():
-        return "❌ Please provide step instructions."
+        return "❌ Please provide step instructions.", None

     if not demo_instance.is_loaded:
-        return "❌ Model not loaded. Please wait for it to load automatically."
+        return "❌ Model not loaded. Please wait for it to load automatically.", None

     try:
         # Handle different image formats
@@ -183,10 +284,10 @@ def process_input(image, goal, step_instructions):
             # Handle Gradio file object
             pil_image = Image.open(image.name)
         else:
-            return "❌ Invalid image format. Please upload a valid image."
+            return "❌ Invalid image format. Please upload a valid image.", None

         if pil_image is None:
-            return "❌ Failed to process image. Please try again."
+            return "❌ Failed to process image. Please try again.", None

         # Convert image to RGB if needed
         if pil_image.mode != "RGB":
@@ -194,11 +295,33 @@ def process_input(image, goal, step_instructions):

         # Generate action using goal and step instructions
         response = demo_instance.generate_action(pil_image, goal, step_instructions)
-        return response
+
+        # Parse the response to check for coordinates
+        action_data, has_coordinates = parse_action_response(response)
+
+        # If coordinates are found, create annotated image
+        annotated_image = None
+        if has_coordinates and isinstance(action_data, dict):
+            x = action_data.get('x')
+            y = action_data.get('y')
+            action_type = action_data.get('action_type', 'click')
+
+            if x is not None and y is not None:
+                annotated_image = create_annotated_image(pil_image, x, y, action_type)
+                logger.info(f"Created annotated image for coordinates ({x}, {y})")
+
+        return response, annotated_image

     except Exception as e:
         logger.error(f"Error processing input: {str(e)}")
-        return f"❌ Error: {str(e)}"
+        return f"❌ Error: {str(e)}", None
+
+def update_annotated_image_visibility(response, annotated_image):
+    """Update the visibility of the annotated image based on whether coordinates are present"""
+    if annotated_image is not None:
+        return gr.update(visible=True, value=annotated_image)
+    else:
+        return gr.update(visible=False, value=None)


 def load_example_episodes():
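The `update_annotated_image_visibility` helper relies on a Gradio handler returning `gr.update(...)` instead of a raw value, which lets it change component properties such as visibility along with the content. The same pattern in a standalone sketch; the component names here are illustrative, not from app.py:

```python
import gradio as gr

def show_if_present(value):
    # gr.update lets a handler set properties (here: visibility) as well as the value
    if value:
        return gr.update(visible=True, value=value.upper())
    return gr.update(visible=False, value=None)

with gr.Blocks() as demo:
    text_in = gr.Textbox(label="Input")
    text_out = gr.Textbox(label="Output", visible=False)
    gr.Button("Run").click(fn=show_if_present, inputs=text_in, outputs=text_out)

demo.launch()
```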
@@ -281,6 +404,12 @@ def create_demo():
         .output-container {
             min-height: 200px;
         }
+        .annotated-image-container {
+            border: 2px solid #e0e0e0;
+            border-radius: 8px;
+            padding: 10px;
+            margin-top: 10px;
+        }
         """
     ) as demo:

@@ -303,7 +432,7 @@ def create_demo():
         The model generates JSON actions in the following format:
         ```json
         {
-            "action_type": "tap",
+            "action_type": "click",
             "x": 540,
             "y": 1200,
             "text": "Settings",
@@ -312,6 +441,8 @@ def create_demo():
         }
         ```

+        **🎯 Visual Feedback**: When the model returns coordinates (x, y), an annotated screenshot will be displayed showing the exact click location with a red bounding box and crosshair.
+
         ---
         """)

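The output textbox receives this JSON as a plain string, so annotation happens only when the string parses as a click action carrying both coordinates. A rough illustration mirroring the parser's branch logic; the sample responses are made up:

```python
import json

samples = [
    '{"action_type": "click", "x": 540, "y": 1200, "text": "Settings"}',  # annotated
    '{"action_type": "scroll", "direction": "down"}',                     # JSON, but no box
    "I could not locate the Settings icon.",                              # free-form text
]

for raw in samples:
    try:
        data = json.loads(raw) if raw.strip().startswith("{") else raw
    except json.JSONDecodeError:
        data = raw
    # Same test the demo applies before drawing the bounding box
    is_click = (isinstance(data, dict) and data.get("action_type") == "click"
                and "x" in data and "y" in data)
    print(is_click)  # True, False, False
```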
@@ -342,6 +473,16 @@ def create_demo():
                 process_btn = gr.Button("🚀 Generate Action", variant="primary", size="lg")

             with gr.Column(scale=1):
+
+                gr.Markdown("### 🎯 Annotated Screenshot")
+                annotated_image_output = gr.Image(
+                    label="Click Location Highlighted",
+                    height=400,
+                    visible=False,
+                    interactive=False,
+                    elem_classes=["annotated-image-container"]
+                )
+
                 gr.Markdown("### 📊 Generated Action")
                 output_text = gr.Textbox(
                     label="JSON Action Output",
@@ -350,12 +491,16 @@ def create_demo():
                     interactive=False,
                     elem_classes=["output-container"]
                 )
-
+
         # Connect the process button
         process_btn.click(
             fn=process_input,
             inputs=[image_input, goal_input, step_instructions_input],
-            outputs=output_text
+            outputs=[output_text, annotated_image_output]
+        ).then(
+            fn=update_annotated_image_visibility,
+            inputs=[output_text, annotated_image_output],
+            outputs=annotated_image_output
         )

         # Load examples
@@ -395,6 +540,9 @@ def create_demo():
                     fn=lambda img, g, s: (img, g, s),
                     inputs=[example_image, example_goal, example_step_instruction],
                     outputs=[image_input, goal_input, step_instructions_input]
+                ).then(
+                    fn=lambda: (None, gr.update(visible=False)),
+                    outputs=[output_text, annotated_image_output]
                 )
         except Exception as e:
             logger.warning(f"Failed to load examples: {str(e)}")
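The `.then()` chaining used in this wiring runs the follow-up handler only after the first handler's outputs have been written, which is what lets the visibility update see the freshly returned image. A minimal standalone sketch of the same chaining; the components and functions are illustrative, not from app.py:

```python
import gradio as gr

def compute(x):
    return x * 2

def reveal(value):
    # Runs after compute() finishes, so `value` is the updated output
    return gr.update(visible=value > 0)

with gr.Blocks() as demo:
    num_in = gr.Number(value=1)
    num_out = gr.Number(visible=False)
    gr.Button("Go").click(
        fn=compute, inputs=num_in, outputs=num_out
    ).then(
        fn=reveal, inputs=num_out, outputs=num_out
    )

demo.launch()
```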