PromptIR: Prompting for All-in-One Blind Image Restoration

本文最后更新于:2023年9月19日 晚上

PromptIR: Prompting for All-in-One Blind Image Restoration

导读

图像恢复是从其受损版本中恢复高质量清晰图像的过程。基于深度学习的方法显著提升了图像恢复性能,然而,它们在不同类型和级别的退化上的泛化能力有限。这限制了它们在实际应用中的使用,因为需要针对每种具体的退化进行单独训练模型,并了解输入图像的退化类型才能应用相应的模型。本文介绍了一种基于提示的学习方法,称为PromptIR,用于全能图像恢复,可以有效地从各种类型和级别的退化中恢复图像。具体而言,本文方法使用提示(Prompt)来编码退化特定信息,并动态引导恢复网络。 这使得本文的方法能够推广到不同类型和级别的退化,并在图像去噪、去雨和去雾方面取得了最先进的结果。总的来说,PromptIR提供了一个通用且高效的插件模块,只需少量轻量级提示即可用于恢复各种类型和级别的受损图像,无需事先了解图像中存在的损坏信息。

引言

最近的一种方法AirNet,采用了基于对比学习的方法来提取特征,其中同一张退化图片中截取出来的patch作为正样本(退化相同),而不同的退化图片中的patch作为负样本(退化不同),通过对比学习的方法学习表示,这样学得的表示就能包含图像的退化信息。这涉及训练一个额外的编码器来区分各种类型的图像退化。尽管AirNet取得了最先进的结果,但它在建模不同污染类型的完全解耦表示方面存在困难。此外,使用附加的编码器进行对比学习会导致更高的训练负担,因为需要两阶段的训练方法。

PromptIR提出了一个即插即用的提示模块,该模块隐式预测与退化条件相关的提示,以指导未知退化的输入图像的恢复过程。来自提示的引导被注入到网络的多个解码阶段,具有少量的可学习参数。这样可以学习一个一体化的统一模型,该模型可以很好地执行多个图像恢复任务(例如,排水、去雾和去噪)。

即插即用的提示模块

下图中显示了PromptIR和最先进的AirNet中使用的退化嵌入的tSNE图。不同的颜色表示不同的退化类型。在PromptIR的例子中,每个任务的嵌入被更好地聚在一起,显示了提示标记学习具有区分退化上下文的有效性,从而有助于恢复过程。

tSNE

本文的主要工作:

  • 本文提出了一个基于提示的一体化(blind)恢复框架PromptIR,它仅依赖输入图像来恢复一个清晰的图像,而不需要对图像中出现的退化有任何先验知识。
  • 本文的提示块是一个插件模块,可以很容易地集成到任何现有的恢复网络。它包括提示生成模块PGM (prompt generation module)和提示交互模块PIM (prompt interaction module)。提示块的目标是生成具有输入条件的提示(通过PGM),这些提示配备了有用的上下文信息,以指导恢复网络(使用PIM)有效地从输入图像中删除损坏。
  • 本文的综合实验展示了PromptIR的动态适应行为,通过在各种图像恢复任务中实现最先进的性能,包括使用统一的PromptIR模型进行图像去噪、去雾。

方法

在“All-in-one”图像恢复中,本文的目标是学习单个模型,从退化的图像恢复到清晰的图像,同时没有关于退化的先验信息。通过提供关于退化类型的隐含上下文信息,可以增强其恢复清晰图像的性能。本文提出了基于提示学习的图像恢复框架PromptIR,如下图所示。PromptIR的关键元素是提示块(Prompt Block),这些提示块首先生成可学习的提示参数,然后在恢复过程中使用这些提示来指导模型。接下来详细描述PromptIR框架及其组件的总体流程。

PromptIR方法概述

本文使用U-Net风格的网络,在编码和解码阶段使用Transformer块。该框架的主要组件,即提示块,由两个模块组成,提示生成模块(PGM)和提示交互模块(PIM)。提示生成模块使用输入特性$F_l$和提示组件生成具有输入条件的提示符$P$。然后,提示交互模块通过Transformer块使用生成的提示符动态地调整输入特性。提示与解码器特征在多个级别交互,以丰富特定于退化的上下文信息。

PromptIR

对于给定的退化输入图像${I} \in {R}^{H \times W \times 3}$, PromptIR首先通过卷积运算提取底层特征${F}_0 \in {R}^{H \times W \times C}$;式中,$H \times W$为空间分辨率,$C$为通道数。接下来,特征嵌入$F_0$经过4级分级编解码,转化为深层特征${F}_r \in {R}^{H \times W \times 2C}$。编码器-解码器的每一层都使用几个Transformer块,块的数量从顶层逐渐增加到底层,以保持计算效率。从高分辨率输入开始,编码器的目标是在增加信道容量的同时逐步降低空间分辨率,从而产生低分辨率的潜在表示${F}_l \in {R}^{\frac{H}{8} \times \frac{W}{8} \times 8C}$。解码器的目标是从低分辨率的潜在特征$F_l$逐步恢复高分辨率的清晰输出。为了帮助解码过程,本文在PromptIR框架中加入了提示块。提示块是适配器模块,按顺序连接解码器的每两级。在每个解码器级别上,提示块隐式地用退化类型的信息丰富输入特征,以进行引导恢复。接下来,本文详细描述了提出的提示模块及其核心构建模块。

Prompt Block

本文提出的PromptIR方法借鉴了在自然语言处理和计算机视觉任务中使用的基于提示的技术。在这些任务中,基于提示的技术已经被用于对在源任务上训练的大型固定模型进行参数高效微调,以适应目标任务。基于提示的技术之所以有效,是因为它们能够有效地将任务特定的上下文信息编码到提示组件中。在PromptIR中,提示组件是可学习的参数,与输入特征进行交互,以丰富它们的退化类型信息。提示块由两个关键组件组成:提示生成模块(PGM)和提示交互模块(PIM)。

给定N个提示组件${P}_c \in {R}^{N \times \hat{H} \times \hat{W} \times \hat{C} }$,输入特性${F}_l \in {R}^{\hat{H} \times \hat{W} \times \hat{C} }$,提示块的整体过程定义为:

$\hat{F}_l = {PIM}({PGM}({P_c,{F}_l}),{F_l})$

提示生成模块(PGM)

提示组件${P_c}$是一组可学习的参数,与输入特征交互,嵌入了退化信息。一种直接的特征-提示交互方法是直接使用学习到的提示来校准特征。然而,这种静态方法可能会产生次优结果,因为它对输入内容是无知的。因此,本文提出了提示生成模块(PGM),它从输入特征中动态预测基于注意力的权重,并将这些权重应用于提示组件,生成与输入条件相关的提示 ${P}$。此外,PGM创建了一个共享空间,促进了提示组件之间的相关知识共享。

为了从输入特征${F}_l$生成提示权重,PGM首先对空间维度进行全局平均池化(GAP),生成特征向量$v \in {R}^{\hat{C}}$。接下来将$v$通过通道缩减的卷积层,得到一个紧凑的特征向量,然后进行softmax操作,从而得到提示权重$w \in {R}^N$。最后使用这些权重对提示组件进行调整,接着应用一个$3 \times3$的卷积层。总体而言,PGM的过程可以概括为:

$ {P}={Conv}{3 \times 3}(\sum{c=1}^{N} w_{i} {P}{c}), \quad w{i}={Softmax}({Conv}{1 \times 1}({GAP}({F}{l})))$

由于在推理阶段,恢复网络需要能够处理不同分辨率的图像,不能使用具有固定尺寸的提示组件${P}_c$。因此,作者对提示组件进行双线性插值操作,将其放大到与输入特征相同的尺寸。

提示交互模块(PIM)

PIM的主要目标是实现输入特征${F}_l$和提示${P}$之间的交互,以实现有指导的恢复过程。

在PIM中,沿着通道维度将生成的提示与输入特征进行拼接。接下来将拼接后的表示通过一个Transformer块进行处理,该块利用提示中编码的退化信息来转换输入特征。

本文的主要贡献是提示块,它是一个插件模块,与具体的架构无关。因此,在提出的PromptIR框架中,作者使用了现有的Transformer块,而不是开发一个新的块。Transformer块由两个顺序连接的子模块组成:多转置卷积头转置注意力(MDTA)和门控转置卷积前馈网络(GDFN)。MDTA在通道而不是空间维度上应用自注意操作,并具有线性复杂度。GDFN的目标是以可控的方式转换特征,即抑制信息较少的特征,只允许有用的特征在网络中传播。PIM的整体过程为:

$\hatl = {Conv{3\times3}}({GDFN}({MDTA}[{F_l};{P}]))$

其中$[;]$表示拼接操作。MDTA的公式为${Y}=W_p{V} \cdot {Softmax}({K}\cdot{Q}/\alpha)+{X} $,其中${X}$和${Y}$分别表示输入和输出特征。${Q}$、${K}$和${V}$分别表示通过应用$1\times1$PW卷积后跟随$3\times3$DW卷积在层归一化的输入特征图上获得的查询、键和值的投影。 $W_p$是PW卷积,$\alpha$是可学习的缩放参数,$(\cdot)$表示点积交互。GDFN的过程定义为${Z}=W_p^0\left(\phi(W_d^1W_p^1({LN}({Y})))\odot W_d^2W_p^2({LN}({Y}))\right)$。其中,$W_d^{( \dot )}$是 3×3 的DW卷积, $\odot$表示逐元素乘法, $\phi$是GELU非线性激活函数,${LN}$是层归一化。

Transformer block

上图详细阐述了PromptIR框架中所使用的Transformer模块的细节。首先,输入特性${X} \in {R}^{H_l×W_l×C_l}$是通过MDTA模块传递的。在这个模块中,特性最初是使用Layer规范化的。然后,结合$1\times 1$卷积和$3 \times 3$深度卷积,将特征投射到$Query (Q)$、$Key (K)$和$Value (V)$张量中。MDTA模块的一个基本特征是它计算的注意力跨越通道维度,而不是空间维度。这有效地减少了计算开销。为了实现这种通道式的注意计算,在计算点积之前,将$Q$投影和$K$投影分别从$H_l \times W_l \times C_l$变换为$H_l W_l \times C_l$和$C_l \times H_l W_l$,从而得到与形状为$C_l×C_l$变换后的注意图。在这个子模块中使用了无偏差卷积。此外,注意力以多头并行方式计算。MDTA模块之后,通过GDFN模块处理特征。在GDFN模块中,输入特征通过$1 \times 1$卷积的因子$\gamma$扩展,然后通过$3 \times 3$卷积。这些操作通过两条并行路径执行,其中一条路径的输出使用GeLU非线性激活。然后直接相加将这个激活的特征图与其他路径的输出结合起来。

实验

为了证明所提出的PromptIR的有效性,我们对三个代表性的图像恢复任务进行了评估:图像去雾、图像去雾和图像去噪。我们在两种不同的实验设置下进行实验:(a) All-in-One, (b) Single-task。

All-in-one Results

对比实验

对来自SOTS数据集的图像上的一体化方法进行脱雾比较。由我们的PromptIR产生的结果的图像质量在视觉上比以前的最先进的方法AirNet更好。

对来自Rain100L数据集的图像进行一体化方法的图像删除比较。我们的方法可以有效地去除有雨的条纹,从而生成无雨的图像。

对一体化方法的去噪结果。

Single-task Results

去雾结果会导致SOTS基准数据集上的单任务设置。我们的提示IR比AirNet显著提高了8.13 dB。

在Rain100L上进行单任务设置。与AirNet算法相比,该方法提高了2.13 dB的PSNR。

在BSD68和Urban100数据集上进行单任务设置的去噪比较。对于Urban100上σ = 50具有挑战性的噪声水平,我们与AirNet相比,mritIR获得0.51 dB增益。

消融实验

表9显示随着退化类型的增加,网络恢复图像变得越来越困难,从而导致性能下降。具体来说,在合并数据集中出现模糊图像似乎会对模型产生负面影响。有趣的是,结合雨和噪声图像训练的模型获得了很好的性能,表明去噪和去噪任务之间存在正相关关系。这种现象在AirNet工作中也可以观察到。

消融实验

消融实验


PromptIR: Prompting for All-in-One Blind Image Restoration
https://jialiangz.github.io/2023/09/10/PromptIR/
作者
爱吃菠萝
发布于
2023年9月10日
更新于
2023年9月19日
许可协议