Grad-CAM (Gradient-weighted Class Activation Mapping) 是一种可视化深度神经网络中哪些部分对于预测结果贡献最大的技术。它能够定位到特定的图像区域,从而使得神经网络的决策过程更加可解释和可视化。

Grad-CAM 的基本思想是,最后一个卷积层的输出特征图既保留了空间位置信息,又包含了高层语义信息,对分类结果的影响最大。因此,我们可以计算目标类别得分对最后一个卷积层输出特征图的梯度,并对该梯度进行全局平均池化,得到每个通道的权重。用这些权重对特征图的各个通道加权并在通道维度上合并,再经过 ReLU,就得到一个 Class Activation Map (CAM),其中每个像素都代表了对应图像区域对于分类结果的重要性。


传统的 CAM 方法要求网络以全局平均池化层加全连接层结尾,而 Grad-CAM 不需要修改网络结构或使用特定的层结构,因此可以用于各种卷积神经网络。此外,Grad-CAM 还可以用于对特征的可视化,以及对网络中的一些特定层或单元进行分析。

在Pytorch中,我们可以使用钩子 (hook) 技术,在网络中注册前向钩子和反向钩子。前向钩子用于记录目标层的输出特征图,反向钩子用于记录目标层的梯度。在本篇文章中,我们将详细介绍如何在Pytorch中实现Grad-CAM。



model_path = "your/model/path/"

# Instantiate the model.
model = XRayClassifier()

# Load the trained weights. Loading on CPU is enough here since we're not
# going to do large amounts of inference.
# NOTE(security): torch.load unpickles the file. weights_only=True restricts
# unpickling to tensors and plain containers, which is all a state_dict needs,
# so a tampered checkpoint cannot execute arbitrary code.
state_dict = torch.load(
    model_path,
    map_location=torch.device("cpu"),
    weights_only=True,
)
model.load_state_dict(state_dict)

# Put the model in evaluation mode for inference.
model.eval()


import torch
import torch.nn as nn
import torch.nn.functional as F

# hyperparameters
nc = 3  # number of channels
nf = 64  # number of features to begin with
dropout = 0.2
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class ResNetBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm stages plus a shortcut.

    The shortcut is the identity unless the block changes the spatial
    resolution (``stride != 1``) or the channel count, in which case it is a
    1x1 convolution followed by batch norm so both branches can be added.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the block.
        stride: Stride of the first convolution (and of the shortcut conv).
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Empty Sequential == identity shortcut.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class XRayClassifier(nn.Module):
    """Five stride-2 ResNet blocks followed by a conv + sigmoid classifier.

    Every block halves the spatial size, and the classifier convolution uses
    an 8x8 kernel without padding, so a 256x256 input yields one value per
    sample.

    Args:
        nc: Number of input channels.
        nf: Number of feature channels produced by the first block; each
            following block doubles it.
        dropout: Dropout probability used in the classifier head.
    """

    def __init__(self, nc=nc, nf=nf, dropout=dropout):
        super().__init__()

        # Shape comments are the original author's example for a batch of 64
        # images of size 256x256.
        self.resnet_blocks = nn.Sequential(
            ResNetBlock(nc,   nf,    stride=2),  # (B, C, H, W) -> (B, NF, H/2, W/2), i.e., (64,64,128,128)
            ResNetBlock(nf,   nf*2,  stride=2),  # (64,128,64,64)
            ResNetBlock(nf*2, nf*4,  stride=2),  # (64,256,32,32)
            ResNetBlock(nf*4, nf*8,  stride=2),  # (64,512,16,16)
            ResNetBlock(nf*8, nf*16, stride=2),  # (64,1024,8,8)
        )

        self.classifier = nn.Sequential(
            nn.Conv2d(nf*16, 1, 8, 1, 0, bias=False),
            nn.Dropout(p=dropout),
            nn.Sigmoid(),
        )

    def forward(self, input):
        """Run the blocks and the classifier head on ``input``.

        The input is moved to the device the model's weights live on. Using
        the module-level ``device`` here would break whenever the weights are
        kept elsewhere (e.g. a checkpoint loaded on CPU on a CUDA machine).
        """
        weights_device = next(self.parameters()).device
        output = self.resnet_blocks(input.to(weights_device))
        output = self.classifier(output)
        return output

模型接收 3 通道、256x256 的图片,期望的输入形状为 [batch size, 3, 256, 256]。每个 ResNet 块都以一个 ReLU 激活函数结束。为了实现 Grad-CAM,我们需要选择最后一个 ResNet 块作为目标层。

XRayClassifier(  (resnet_blocks): Sequential(    (0): ResNetBlock(      (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)      (shortcut): Sequential(        (0): Conv2d(3, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)      )    )    (1): ResNetBlock(      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)      (shortcut): Sequential(        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)      )    )    (2): ResNetBlock(      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)      (shortcut): Sequential(        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)      )    )    (3): ResNetBlock(      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)      (bn1): BatchNorm2d(512, 
eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)      (shortcut): Sequential(        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)      )    )    (4): ResNetBlock(      (conv1): Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)      (bn1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)      (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)      (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)      (shortcut): Sequential(        (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)        (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)      )    )  )  (classifier): Sequential(    (0): Conv2d(1024, 1, kernel_size=(8, 8), stride=(1, 1), bias=False)    (1): Dropout(p=0.2, inplace=False)    (2): Sigmoid()  ) )


# Select the last ResNet block of the model. Its printed representation:
model.resnet_blocks[-1]
# ResNetBlock(
#   (conv1): Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
#   (bn1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#   (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
#   (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#   (shortcut): Sequential(
#     (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
#     (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#   )
# )



register_full_backward_hook(hook, prepend=False)


hook(module, grad_input, grad_output) -> tuple(Tensor) or None


register_forward_hook(hook, *, prepend=False, with_kwargs=False)


hook(module, args, output) -> None or modified output




# Module-level slots that the two hooks below fill in.
activations = None
gradients = None


def forward_hook(module, args, output):
    """Forward hook: remember the hooked module's output.

    Stores ``output`` in the global ``activations`` and prints its size.
    Returns None, so the module's output is left unchanged.
    """
    global activations  # write to the module-level variable
    print("Forward hook running...")
    activations = output
    # In this case, we expect it to be torch.Size([batch size, 1024, 8, 8])
    print(f"Activations size: {activations.size()}")


def backward_hook(module, grad_input, grad_output):
    """Backward hook: remember the gradients w.r.t. the module's output.

    Stores the ``grad_output`` tuple in the global ``gradients`` and prints
    the size of its first element. Returns None.
    """
    global gradients  # write to the module-level variable
    print("Backward hook running...")
    gradients = grad_output
    # The gradient tensor comes wrapped in a one-element tuple, hence the
    # index 0. In this case, we expect torch.Size([batch size, 1024, 8, 8]).
    print(f"Gradients size: {gradients[0].size()}")


# Grad-CAM target layer: the last ResNet block of the model.
target_layer = model.resnet_blocks[-1]

# Each register_* call returns a handle that is later used to remove the hook.
# NOTE: the handles are bound to the same names as the hook functions, so the
# functions are no longer reachable by name afterwards; re-running this
# snippet requires re-defining them first.
backward_hook = target_layer.register_full_backward_hook(backward_hook, prepend=False)
forward_hook = target_layer.register_forward_hook(forward_hook, prepend=False)


from PIL import Image

img_path = "/your/image/path/"

# Image.open keeps the underlying file open until the image is closed.
# convert() returns a new, fully loaded image, so the source file can be
# closed as soon as the conversion is done.
with Image.open(img_path) as source_image:
    image = source_image.convert("RGB")


from torchvision import transforms
from torchvision.transforms import ToTensor

image_size = 256

# Per-channel statistics handed to Normalize (subtract mean, divide by std).
channel_mean = (0.5, 0.5, 0.5)
channel_std = (0.5, 0.5, 0.5)

# Preprocessing pipeline: resize, crop to a square, convert to a tensor and
# normalize.
preprocessing_steps = [
    transforms.Resize(image_size, antialias=True),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
]
transform = transforms.Compose(preprocessing_steps)

# Tensor that represents the image.
img_tensor = transform(image)




Forward hook running... Activations size: torch.Size([1, 1024, 8, 8]) Backward hook running... Gradients size: torch.Size([1, 1024, 8, 8])




import torch.nn.functional as F
import matplotlib.pyplot as plt

# One weight per channel: average the captured gradients over the batch and
# both spatial dimensions.
pooled_gradients = torch.mean(gradients[0], dim=[0, 2, 3])

# Weight the channels by the corresponding gradients. This works on a
# detached copy with broadcasting instead of multiplying `activations` in
# place, so re-running this snippet does not apply the weights twice.
weighted_activations = activations.detach() * pooled_gradients.view(1, -1, 1, 1)

# Average the channels of the weighted activations.
heatmap = torch.mean(weighted_activations, dim=1).squeeze()

# ReLU on top of the heatmap.
heatmap = F.relu(heatmap)

# Normalize the heatmap to [0, 1]. If nothing survives the ReLU the maximum
# is 0, and dividing by it would fill the heatmap with NaNs.
peak = torch.max(heatmap)
if peak > 0:
    heatmap = heatmap / peak

# Draw the heatmap (moved to CPU so matplotlib can convert it).
plt.matshow(heatmap.detach().cpu())






from torchvision.transforms.functional import to_pil_image
from matplotlib import colormaps
import matplotlib.pyplot as plt
import numpy as np
import PIL
import PIL.Image  # make sure the PIL.Image submodule is loaded

# Create a figure and plot the first image.
fig, ax = plt.subplots()
ax.axis("off")  # removes the axis markers

# First plot the original image. img_tensor was normalized with mean 0.5 and
# std 0.5 per channel, so undo that to get back to the [0, 1] range that
# to_pil_image expects for float tensors.
display_tensor = (img_tensor * 0.5 + 0.5).clamp(0, 1)
ax.imshow(to_pil_image(display_tensor, mode="RGB"))

# Resize the heatmap to the same size as the input image, using bicubic
# resampling to increase the resolution. heatmap.detach() is needed because a
# tensor that requires gradients can't be converted to a numpy array.
height, width = img_tensor.shape[-2:]
overlay = to_pil_image(heatmap.detach(), mode="F").resize(
    (width, height), resample=PIL.Image.BICUBIC
)

# Apply any colormap you want; keep only the RGB channels of the RGBA result.
cmap = colormaps["jet"]
overlay = (255 * cmap(np.asarray(overlay) ** 2)[:, :, :3]).astype(np.uint8)

# Plot the heatmap on the same axes, but with alpha < 1 (this defines the
# transparency of the heatmap). The overlay has the same pixel size as the
# image, so no explicit extent is needed to align them.
ax.imshow(overlay, alpha=0.4, interpolation="nearest")

# Show the plot.
plt.show()





# Detach both hooks from the model once the heatmap has been produced.
# At this point both names hold the handles returned at registration time.
for hook_handle in (backward_hook, forward_hook):
    hook_handle.remove()

这篇文章可以帮助你理清Grad-CAM 是如何工作的,以及如何用Pytorch实现它。因为Pytorch包含了强大的钩子函数,所以我们可以在任何模型中使用本文的代码。


