| | import cv2 |
| | import numpy as np |
| | from torch.nn import functional as F |
| | import torch |
| |
|
class ActivationsAndGradients:
    """Record forward activations and backward gradients of chosen layers.

    A forward hook and a backward hook are attached to every layer in
    ``target_layers`` when the object is constructed.  The hooks stay
    attached until :meth:`release` is called.
    """

    def __init__(self, model, target_layers, reshape_transform):
        """Attach the capture hooks.

        :param model: Callable invoked as ``model(x, y)`` by ``__call__``.
        :param target_layers: Iterable of modules to attach hooks to.
        :param reshape_transform: Optional callable applied to every
            captured activation and gradient before it is stored;
            ``None`` stores them unchanged.
        """
        self.model = model
        self.gradients = []
        self.activations = []
        self.reshape_transform = reshape_transform
        self.handles = []
        for layer in target_layers:
            forward_handle = layer.register_forward_hook(self.save_activation)
            self.handles.append(forward_handle)

            # Use the "full" backward hook when the layer exposes it and
            # fall back to the older registration method otherwise.
            if hasattr(layer, 'register_full_backward_hook'):
                register_backward = layer.register_full_backward_hook
            else:
                register_backward = layer.register_backward_hook
            self.handles.append(register_backward(self.save_gradient))

    def save_activation(self, module, input, output):
        """Forward hook: store a detached CPU copy of the layer output."""
        transform = self.reshape_transform
        captured = output if transform is None else transform(output)
        self.activations.append(captured.cpu().detach())

    def save_gradient(self, module, grad_input, grad_output):
        """Backward hook: store a detached CPU copy of the first output grad.

        The new gradient is placed at the FRONT of ``self.gradients``.
        NOTE(review): presumably this is so the gradient order matches the
        activation order, since backward visits layers in reverse - confirm.
        """
        transform = self.reshape_transform
        captured = grad_output[0]
        if transform is not None:
            captured = transform(captured)
        self.gradients = [captured.cpu().detach(), *self.gradients]

    def __call__(self, x, y):
        """Clear previously captured data and run ``self.model(x, y)``."""
        self.gradients, self.activations = [], []
        return self.model(x, y)

    def release(self):
        """Remove every hook registered by the constructor."""
        for hook_handle in self.handles:
            hook_handle.remove()
| |
|
| |
|
class GradCAM:
    """Grad-CAM heat-map generator for a model that is called as ``model(x, y)``.

    Usable as a context manager; leaving the ``with`` block removes the
    hooks that were attached to the target layers.
    """

    def __init__(self,
                 cfg,
                 model,
                 target_layers,
                 reshape_transform=None,
                 use_cuda=False):
        """Put the model in eval mode and attach hooks to the target layers.

        :param cfg: Configuration object; only ``cfg.net`` is read (in
            ``__call__``).
        :param model: The network to explain.
        :param target_layers: Layers whose activations/gradients are used.
        :param reshape_transform: Optional callable applied to captured
            activations and gradients.
        :param use_cuda: When true, the model and the inputs are moved to
            CUDA.
        """
        self.cfg = cfg
        self.model = model.eval()
        self.target_layers = target_layers
        self.reshape_transform = reshape_transform
        self.cuda = use_cuda
        if self.cuda:
            self.model = model.cuda()
        self.activations_and_grads = ActivationsAndGradients(
            self.model, target_layers, reshape_transform)

    @staticmethod
    def get_cam_weights(grads):
        """Get a vector of weights for every channel in the target layer.

        Methods that return per-channel weights will typically only need
        to implement this function.  Here the weight of a channel is the
        mean of its gradients over axes 2 and 3, kept as singleton axes so
        the result broadcasts against the activations.
        """
        return np.mean(grads, axis=(2, 3), keepdims=True)

    @staticmethod
    def get_loss(output, target_category):
        """Sum ``output[i]`` for ``i`` in ``range(len(target_category))``.

        NOTE(review): the values inside ``target_category`` are not used,
        only its length.  This looks deliberate for dense (per-pixel)
        outputs, but confirm that per-category selection is not wanted.
        """
        loss = 0
        for sample_idx in range(len(target_category)):
            loss = loss + output[sample_idx]
        return loss

    def get_cam_image(self, activations, grads):
        """Weight the activations per channel and sum over axis 1."""
        weights = self.get_cam_weights(grads)
        weighted_activations = weights * activations
        return weighted_activations.sum(axis=1)

    @staticmethod
    def get_target_width_height(input_tensor):
        """Return ``(width, height)`` taken from the last two dimensions."""
        width, height = input_tensor.size(-1), input_tensor.size(-2)
        return width, height

    def compute_cam_per_layer(self, input_tensor):
        """Build one rescaled CAM per hooked layer.

        :param input_tensor: Tensor whose last two dimensions define the
            size every CAM is resized to.
        :returns: A list with one array per layer; each array has a new
            axis inserted at position 1 so the list can be concatenated
            along it.
        """
        activations_list = [a.cpu().data.numpy()
                            for a in self.activations_and_grads.activations]
        grads_list = [g.cpu().data.numpy()
                      for g in self.activations_and_grads.gradients]
        target_size = self.get_target_width_height(input_tensor)

        cam_per_target_layer = []
        for layer_activations, layer_grads in zip(activations_list, grads_list):
            cam = self.get_cam_image(layer_activations, layer_grads)
            cam[cam < 0] = 0  # keep only positive contributions
            scaled = self.scale_cam_image(cam, target_size)
            cam_per_target_layer.append(scaled[:, None, :])

        return cam_per_target_layer

    def aggregate_multi_layers(self, cam_per_target_layer):
        """Average the per-layer CAMs and rescale the result per sample."""
        cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1)
        cam_per_target_layer = np.maximum(cam_per_target_layer, 0)
        result = np.mean(cam_per_target_layer, axis=1)
        return self.scale_cam_image(result)

    @staticmethod
    def scale_cam_image(cam, target_size=None):
        """Min-max scale every map in ``cam`` and optionally resize it.

        :param cam: Iterable of 2-D maps (iterated along the first axis).
        :param target_size: Optional ``(width, height)`` handed to
            ``cv2.resize``; ``None`` skips resizing.
        :returns: ``np.float32`` array stacking the processed maps.
        """
        result = []
        for img in cam:
            img = img - np.min(img)
            # The small epsilon avoids division by zero for constant maps.
            img = img / (1e-7 + np.max(img))
            if target_size is not None:
                img = cv2.resize(img, target_size)
            result.append(img)
        return np.float32(result)

    def __call__(self, input_tensor, target_category=None):
        """Run the model, back-propagate, and return the aggregated CAM.

        :param input_tensor: A pair ``(x, y)``; both items are passed to
            the model.
        :param target_category: ``None``, an ``int`` (repeated for every
            sample in the batch), or a sequence with one entry per sample.
        :returns: The array produced by ``aggregate_multi_layers``.
        :raises AssertionError: If a sequence ``target_category`` does not
            have one entry per sample of ``x``.
        """
        x, y = input_tensor
        if self.cuda:
            x = x.cuda()
            y = y.cuda()

        if self.cfg.net == 'cdmask':
            # This network returns a pair; only the dict in second position
            # is used here.
            _, outputs = self.activations_and_grads(x, y)
            mask_cls_results = outputs["pred_logits"]
            mask_pred_results = outputs["pred_masks"]
            mask_pred_results = F.interpolate(
                mask_pred_results,
                scale_factor=(4, 4),
                mode="bilinear",
                align_corners=False,
            )
            # Index 0 of the last axis is dropped after the softmax
            # (presumably a background / no-object class - TODO confirm).
            mask_cls = F.softmax(mask_cls_results, dim=-1)[..., 1:]
            mask_pred = mask_pred_results.sigmoid()
            output = torch.einsum("bqc,bqhw->bchw", mask_cls, mask_pred)
        else:
            output = self.activations_and_grads(x, y)

        if isinstance(target_category, int):
            target_category = [target_category] * x.size(0)

        if target_category is None:
            target_category = np.argmax(output.cpu().data.numpy(), axis=-1)
            print(f"category id: {target_category}")
        else:
            assert (len(target_category) == x.size(0))

        self.model.zero_grad()
        loss = self.get_loss(output, target_category).sum()
        loss.backward(retain_graph=True)

        cam_per_layer = self.compute_cam_per_layer(x)
        return self.aggregate_multi_layers(cam_per_layer)

    def __del__(self):
        # __init__ can raise before ``activations_and_grads`` is assigned
        # (e.g. while moving the model to CUDA); the finalizer still runs on
        # that half-built object, so the attribute must be looked up safely.
        hooks = getattr(self, 'activations_and_grads', None)
        if hooks is not None:
            hooks.release()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        """Release the hooks; report and suppress ``IndexError`` only."""
        self.activations_and_grads.release()
        if isinstance(exc_value, IndexError):
            print(
                f"An exception occurred in CAM with block: {exc_type}. Message: {exc_value}")
            return True
| |
|
| |
|
def show_cam_on_image(img: np.ndarray,
                      mask: np.ndarray,
                      use_rgb: bool = False,
                      colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
    """Blend a colour-mapped CAM mask onto an image.

    The heat-map is BGR unless ``use_rgb`` is set.

    :param img: Base image (RGB or BGR) with values no greater than 1.
    :param mask: The CAM mask; it is multiplied by 255 and cast to uint8
        before the colour map is applied.
    :param use_rgb: Convert the heat-map to RGB; set this when ``img`` is
        in RGB format.
    :param colormap: The OpenCV colour map to apply.
    :returns: uint8 image: heat-map plus ``img``, divided by its maximum
        and scaled to 255.
    :raises Exception: If the maximum of ``img`` exceeds 1.
    """
    mask_u8 = np.uint8(255 * mask)
    colored = cv2.applyColorMap(mask_u8, colormap)
    if use_rgb:
        colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)
    colored = np.float32(colored) / 255

    # The range check intentionally happens after the heat-map is built,
    # matching the order in which failures surface to callers.
    if np.max(img) > 1:
        raise Exception(
            "The input image should np.float32 in the range [0, 1]")

    blended = colored + img
    normalised = blended / np.max(blended)
    return np.uint8(255 * normalised)
| |
|
| |
|
def center_crop_img(img: np.ndarray, size: int) -> np.ndarray:
    """Resize so the shorter side equals ``size``, then centre-crop a square.

    Only the first two dimensions of ``img`` are read as ``(height, width)``,
    so 2-D grayscale arrays are accepted as well as ``H x W x C`` arrays.

    :param img: Image array whose first two axes are height and width.
    :param size: Side length of the square result.
    :returns: ``img`` itself when it is already ``size`` x ``size``;
        otherwise the resized and cropped image.
    """
    h, w = img.shape[:2]

    if w == h == size:
        return img

    # Scale factor that maps the shorter side onto `size`; the longer side
    # is scaled by the same factor and truncated to an int.
    if w < h:
        ratio = size / w
        new_w, new_h = size, int(h * ratio)
    else:
        ratio = size / h
        new_h, new_w = size, int(w * ratio)

    img = cv2.resize(img, dsize=(new_w, new_h))

    if new_w == size:
        # Width already matches: trim rows equally from top and bottom.
        top = (new_h - size) // 2
        img = img[top: top + size]
    else:
        # Height already matches: trim columns equally from left and right.
        left = (new_w - size) // 2
        img = img[:, left: left + size]

    return img
| |
|