在 PyTorch 中,钩子(Hook) 是一种机制,用于在模型的前向传播或反向传播过程中执行用户定义的操作。它允许我们在不改变模型结构的情况下访问中间计算结果(如特征图或梯度)或对它们进行修改。
钩子通常被应用于以下场景:
特征提取:从某些特定层获取激活值(前向传播的输出)。
梯度获取:从某些层获取反向传播时的梯度。
调试:检查中间层的值或诊断训练问题。
模型解释:如 Grad-CAM,需要使用钩子获取特定层的梯度和特征图。
钩子的类型
1. 前向钩子(Forward Hook)
在层的 前向传播完成后 执行。
常用于捕获特定层的激活值(即该层的输出)。
注册方式:register_forward_hook
示例:
def forward_hook(module, input, output):
print(f"Input: {input}")
print(f"Output: {output}")
layer = model.features[10] # 假设是某个卷积层
handle = layer.register_forward_hook(forward_hook)
2. 反向钩子(Backward Hook)
在 反向传播完成后 执行。
常用于捕获某些层的梯度信息。
注册方式