| | from torch.utils.data import dataset |
| | from tqdm import tqdm |
| | import network |
| | import utils |
| | import os |
| | import random |
| | import argparse |
| | import numpy as np |
| |
|
| | from torch.utils import data |
| | from datasets import VOCSegmentation, Cityscapes, cityscapes |
| | from torchvision import transforms as T |
| | from metrics import StreamSegMetrics |
| |
|
| | import torch |
| | import torch.nn as nn |
| |
|
| | from PIL import Image |
| | import matplotlib |
| | import matplotlib.pyplot as plt |
| | from glob import glob |
| |
|
def get_argparser():
    """Build the command-line parser for single-image / directory prediction.

    Returns:
        argparse.ArgumentParser: parser with dataset, model, validation and
        checkpoint options. ``--input`` (image file or directory) is required.
    """
    parser = argparse.ArgumentParser()

    # Dataset options
    parser.add_argument("--input", type=str, required=True,
                        help="path to a single image or image directory")
    parser.add_argument("--dataset", type=str, default='voc',
                        choices=['voc', 'cityscapes'], help='Name of training set')

    # Model options: every lowercase, public, callable attribute of
    # network.modeling is a model factory. name.startswith('_') already
    # covers dunder names, so a separate "__" check is unnecessary.
    available_models = sorted(
        name for name in network.modeling.__dict__
        if name.islower()
        and not name.startswith('_')
        and callable(network.modeling.__dict__[name])
    )

    parser.add_argument("--model", type=str, default='deeplabv3plus_mobilenet',
                        choices=available_models, help='model name')
    parser.add_argument("--separable_conv", action='store_true', default=False,
                        help="apply separable conv to decoder and aspp")
    parser.add_argument("--output_stride", type=int, default=16, choices=[8, 16])

    # Output options
    parser.add_argument("--save_val_results_to", default=None,
                        help="save segmentation results to the specified dir")

    parser.add_argument("--crop_val", action='store_true', default=False,
                        help='crop validation (default: False)')
    parser.add_argument("--val_batch_size", type=int, default=4,
                        help='batch size for validation (default: 4)')
    parser.add_argument("--crop_size", type=int, default=513)

    # Checkpoint / device options
    parser.add_argument("--ckpt", default=None, type=str,
                        help="resume from checkpoint")
    parser.add_argument("--gpu_id", type=str, default='0',
                        help="GPU ID")
    return parser
| |
|
def main():
    """Run segmentation inference on one image or every image under a directory.

    Side effects: sets CUDA_VISIBLE_DEVICES from ``--gpu_id``; when
    ``--save_val_results_to`` is given, writes one colorized PNG per input.
    """
    opts = get_argparser().parse_args()
    # --dataset choices restrict us to exactly these two branches.
    if opts.dataset.lower() == 'voc':
        opts.num_classes = 21
        decode_fn = VOCSegmentation.decode_target
    elif opts.dataset.lower() == 'cityscapes':
        opts.num_classes = 19
        decode_fn = Cityscapes.decode_target

    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Collect input image paths. sorted(set(...)) dedupes files matched by
    # more than one extension pattern (e.g. *.jpg and *.JPEG on a
    # case-insensitive filesystem) and makes processing order deterministic.
    image_files = []
    if os.path.isdir(opts.input):
        for ext in ['png', 'jpeg', 'jpg', 'JPEG']:
            image_files.extend(
                glob(os.path.join(opts.input, '**/*.%s' % ext), recursive=True))
        image_files = sorted(set(image_files))
    elif os.path.isfile(opts.input):
        image_files.append(opts.input)
    if not image_files:
        # Previously the script exited silently with nothing to do.
        print("[!] No input images found at %s" % opts.input)
        return

    # Set up model; convert decoder/ASPP convs only for "plus" variants.
    model = network.modeling.__dict__[opts.model](num_classes=opts.num_classes,
                                                  output_stride=opts.output_stride)
    if opts.separable_conv and 'plus' in opts.model:
        network.convert_to_separable_conv(model.classifier)
    utils.set_bn_momentum(model.backbone, momentum=0.01)

    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        # Load on CPU first so checkpoints saved on GPU restore on any device.
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint["model_state"])
        model = nn.DataParallel(model)
        model.to(device)
        print("Resume model from %s" % opts.ckpt)
        del checkpoint
    else:
        print("[!] Retrain")
        model = nn.DataParallel(model)
        model.to(device)

    # ImageNet normalization; optional resize+center-crop for --crop_val.
    if opts.crop_val:
        transform = T.Compose([
            T.Resize(opts.crop_size),
            T.CenterCrop(opts.crop_size),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])
    else:
        transform = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])
    if opts.save_val_results_to is not None:
        os.makedirs(opts.save_val_results_to, exist_ok=True)
    with torch.no_grad():
        model = model.eval()
        for img_path in tqdm(image_files):
            # splitext is robust for names containing extra dots.
            img_name = os.path.splitext(os.path.basename(img_path))[0]
            img = Image.open(img_path).convert('RGB')
            img = transform(img).unsqueeze(0)  # add batch dimension: (1, C, H, W)
            img = img.to(device)

            # argmax over the class channel -> (H, W) label map.
            pred = model(img).max(1)[1].cpu().numpy()[0]
            colorized_preds = decode_fn(pred).astype('uint8')
            colorized_preds = Image.fromarray(colorized_preds)
            if opts.save_val_results_to:
                colorized_preds.save(
                    os.path.join(opts.save_val_results_to, img_name + '.png'))
| |
|
# Script entry point: run inference when executed directly.
if __name__ == '__main__':
    main()
| |
|