import os
import gc

import cv2
import torch
import numpy as np
import gradio as gr
from PIL import Image, ImageDraw
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
from transformers import OwlViTProcessor, OwlViTForObjectDetection

models = {
    'vit_b': './checkpoints/sam_vit_b_01ec64.pth',
    'vit_l': './checkpoints/sam_vit_l_0b3195.pth',
    'vit_h': './checkpoints/sam_vit_h_4b8939.pth'
}
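# Note: the SAM checkpoints above are assumed to be downloaded into ./checkpoints/
# beforehand. They are published with the official segment-anything repository,
# e.g. https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth and
# the corresponding vit_l / vit_h files.
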
image_examples = [
    [os.path.join(os.path.dirname(__file__), "./images/53960-scaled.jpg"), 0, []],
    [os.path.join(os.path.dirname(__file__), "./images/2388455-scaled.jpg"), 1, []],
    [os.path.join(os.path.dirname(__file__), "./images/1.jpg"), 2, []],
    [os.path.join(os.path.dirname(__file__), "./images/2.jpg"), 3, []],
    [os.path.join(os.path.dirname(__file__), "./images/3.jpg"), 4, []],
    [os.path.join(os.path.dirname(__file__), "./images/4.jpg"), 5, []],
    [os.path.join(os.path.dirname(__file__), "./images/5.jpg"), 6, []],
    [os.path.join(os.path.dirname(__file__), "./images/6.jpg"), 7, []],
    [os.path.join(os.path.dirname(__file__), "./images/7.jpg"), 8, []],
    [os.path.join(os.path.dirname(__file__), "./images/8.jpg"), 9, []]
]

def plot_boxes(img, boxes):
    # img is a float image in [0, 1]; convert to uint8 RGB before drawing
    img_pil = Image.fromarray(np.uint8(img * 255)).convert('RGB')
    draw = ImageDraw.Draw(img_pil)
    for box in boxes:
        color = tuple(np.random.randint(0, 255, size=3).tolist())
        x0, y0, x1, y1 = [int(v) for v in box]
        draw.rectangle([x0, y0, x1, y1], outline=color, width=6)
    return img_pil

def segment_one(img, mask_generator, seed=None):
    if seed is not None:
        np.random.seed(seed)
    masks = mask_generator.generate(img)
    sorted_anns = sorted(masks, key=(lambda x: x['area']), reverse=True)
    mask_all = np.ones((img.shape[0], img.shape[1], 3))
    for ann in sorted_anns:
        m = ann['segmentation']
        color_mask = np.random.random((1, 3)).tolist()[0]
        for i in range(3):
            mask_all[m, i] = color_mask[i]
    result = img / 255 * 0.3 + mask_all * 0.7
    return result, mask_all

def generator_inference(device, model_type, points_per_side, pred_iou_thresh, stability_score_thresh,
                        min_mask_region_area, stability_score_offset, box_nms_thresh, crop_n_layers, crop_nms_thresh,
                        input_x, progress=gr.Progress()):
    # sam model
    sam = sam_model_registry[model_type](checkpoint=models[model_type]).to(device)
    mask_generator = SamAutomaticMaskGenerator(
        sam,
        points_per_side=points_per_side,
        pred_iou_thresh=pred_iou_thresh,
        stability_score_thresh=stability_score_thresh,
        stability_score_offset=stability_score_offset,
        box_nms_thresh=box_nms_thresh,
        crop_n_layers=crop_n_layers,
        crop_nms_thresh=crop_nms_thresh,
        crop_overlap_ratio=512 / 1500,
        crop_n_points_downscale_factor=1,
        point_grids=None,
        min_mask_region_area=min_mask_region_area,
        output_mode='binary_mask'
    )
    # input is an image (numpy array)
    if isinstance(input_x, np.ndarray):
        result, mask_all = segment_one(input_x, mask_generator)
        return result, mask_all
    elif isinstance(input_x, str):  # input is a video, given as a file path
        cap = cv2.VideoCapture(input_x)  # read video
        frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        W, H = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(cap.get(cv2.CAP_PROP_FPS))
        out = cv2.VideoWriter("output.mp4", cv2.VideoWriter_fourcc('x', '2', '6', '4'), fps, (W, H), isColor=True)
        for _ in progress.tqdm(range(frames_num),
                               desc='Processing video ({} frames, size {}x{})'.format(frames_num, W, H)):
            ret, frame = cap.read()  # read a frame
            if not ret:  # the reported frame count can overshoot; stop at the real end
                break
            result, mask_all = segment_one(frame, mask_generator, seed=2023)
            result = (result * 255).astype(np.uint8)
            out.write(result)
        out.release()
        cap.release()
        return 'output.mp4'

def predictor_inference(device, model_type, input_x, input_text, selected_points, owl_vit_threshold=0.1):
    # sam model
    sam = sam_model_registry[model_type](checkpoint=models[model_type]).to(device)
    predictor = SamPredictor(sam)
    predictor.set_image(input_x)  # process the image to produce an image embedding

    if input_text != '':
        # split the comma-separated prompt into a list of text queries
        text_queries = [input_text.split(',')]
        print(text_queries)
        # OWL-ViT model for open-vocabulary box detection
        processor = OwlViTProcessor.from_pretrained('./checkpoints/models--google--owlvit-base-patch32')
        owlvit_model = OwlViTForObjectDetection.from_pretrained(
            "./checkpoints/models--google--owlvit-base-patch32").to(device)
        # get outputs
        owlvit_inputs = processor(text=text_queries, images=input_x, return_tensors="pt").to(device)
        outputs = owlvit_model(**owlvit_inputs)
        target_size = torch.Tensor([input_x.shape[:2]]).to(device)
        results = processor.post_process_object_detection(outputs=outputs, target_sizes=target_size,
                                                          threshold=owl_vit_threshold)
        # keep the detections for the first (and only) image
        i = 0
        boxes_tensor = results[i]["boxes"]
        boxes = boxes_tensor.cpu().detach().numpy()
        # map the detected boxes into SAM's input coordinate frame
        transformed_boxes = predictor.transform.apply_boxes_torch(torch.Tensor(boxes).to(device),
                                                                  input_x.shape[:2])
        print(transformed_boxes.size(), boxes.shape)
    else:
        transformed_boxes = None
    # point prompts selected in the UI
    if len(selected_points) != 0:
        points = torch.Tensor([p for p, _ in selected_points]).to(device).unsqueeze(1)
        labels = torch.Tensor([int(l) for _, l in selected_points]).to(device).unsqueeze(1)
        transformed_points = predictor.transform.apply_coords_torch(points, input_x.shape[:2])
        print(points.size(), transformed_points.size(), labels.size(), input_x.shape, points)
    else:
        transformed_points, labels = None, None
    # predict segmentation masks from the box and/or point prompts
    masks, scores, logits = predictor.predict_torch(
        point_coords=transformed_points,
        point_labels=labels,
        boxes=transformed_boxes,
        multimask_output=False,
    )
    masks = masks.cpu().detach().numpy()
    mask_all = np.ones((input_x.shape[0], input_x.shape[1], 3))
    for ann in masks:
        color_mask = np.random.random((1, 3)).tolist()[0]
        for i in range(3):
            mask_all[ann[0], i] = color_mask[i]
    img = input_x / 255 * 0.3 + mask_all * 0.7
    if input_text != '':
        img = plot_boxes(img, boxes_tensor)  # image + mask + boxes
    # free the memory used by OWL-ViT
    if input_text != '':
        owlvit_model.cpu()
        del owlvit_model
        del owlvit_inputs
    gc.collect()
    torch.cuda.empty_cache()
    return img, mask_all
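
# A hypothetical direct call (outside the Gradio UI), assuming the checkpoints
# above exist and a CUDA device is available; the "dog" prompt is only an example:
#   img = cv2.cvtColor(cv2.imread('./images/1.jpg'), cv2.COLOR_BGR2RGB)
#   overlay, masks = predictor_inference('cuda', 'vit_b', img, 'dog', [], owl_vit_threshold=0.1)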

def run_inference(device, model_type, points_per_side, pred_iou_thresh, stability_score_thresh, min_mask_region_area,
                  stability_score_offset, box_nms_thresh, crop_n_layers, crop_nms_thresh, owl_vit_threshold, input_x,
                  input_text, selected_points):
    # if input_x is an int, the image was selected from the examples list
    if isinstance(input_x, int):
        input_x = cv2.imread(image_examples[input_x][0])
        input_x = cv2.cvtColor(input_x, cv2.COLOR_BGR2RGB)
    if (input_text != '' and not isinstance(input_x, str)) or len(selected_points) != 0:  # user gave text or points
        print('use predictor_inference')
        print('prompt text: ', input_text)
        print('prompt points length: ', len(selected_points))
        return predictor_inference(device, model_type, input_x, input_text, selected_points, owl_vit_threshold)
    else:
        print('use generator_inference')
        return generator_inference(device, model_type, points_per_side, pred_iou_thresh, stability_score_thresh,
                                   min_mask_region_area, stability_score_offset, box_nms_thresh, crop_n_layers,
                                   crop_nms_thresh, input_x)
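
# A minimal usage sketch (not part of the original Space, which wires run_inference
# into a Gradio UI): run the automatic mask generator on the first example image.
# The parameter values below are SAM's usual defaults, assumed here rather than
# taken from the original app.
#   overlay, masks = run_inference(
#       device='cuda' if torch.cuda.is_available() else 'cpu', model_type='vit_b',
#       points_per_side=32, pred_iou_thresh=0.88, stability_score_thresh=0.95,
#       min_mask_region_area=0, stability_score_offset=1.0, box_nms_thresh=0.7,
#       crop_n_layers=0, crop_nms_thresh=0.7, owl_vit_threshold=0.1,
#       input_x=0, input_text='', selected_points=[])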