论文笔记 Segment Anything

本文最后更新于:2023年7月19日 上午

论文笔记 Segment Anything

论文链接:Segment Anything (arxiv.org)

代码链接:facebookresearch/segment-anything (github.com)

Demo:Segment Anything | Meta AI (segment-anything.com)

Segment Anything是Meta AI发布的非常火的图像分割相关论文,提出了Segment Anything Model(SAM)模型,开启了图像分割领域的新范式。

SAM的贡献为上图所示的三点:TaskModelData

  1. Task:该文章提出了promptable segmentation任务,即输入模型图像和prompt,输出一个和提示相关的有效mask。其中,prompt可以是point、box、mask、text,“有效mask”指的是当prompt的语义较为模糊,无法确定进行segment的粒度,有多种正确的segment mask时,任意一个mask都是有效的。
  2. Model:SAM模型采用了简洁且灵活的设计,其包含一个较大的图像编码器和一个轻量的解码器和prompt编码器。prompt编码器可以接纳多种prompt格式的输入。而使用“头重脚轻”的架构可以只提取一次特征来进行多次不同prompt的解码任务。
  3. Data:SAM的训练策略是先训练、再利用训练好的模型辅助标注、再利用标注的数据进行训练。这种循环的训练策略可以在节约人力的情况下,获得高质量的标注。目前有许多方法利用了这种训练策略,关键词为Bootstrap或者self-training。最终SAM通过这种方式构建了一个具有11M图像、1B mask的分割数据集(SA-1B)。

SAM的效果图就不放了,大家都知道它效果多好,接下来介绍其技术细节。

Promptable Segmentation Task

image-20230718103517959

如图,对于一个模糊的point prompt,SAM生成3种粒度的mask,在训练时,模型的输出与最接近的粒度计算损失并回传。

Segment Anything Model

Image Encoder

图像编码器选用了MAE-Huge/16,输入1024×10241024\times 1024分辨率的图片,得到64×6464\times64个patch embedding,并通过1×11\times13×33\times3的卷积降低到256通道维度以免数据量过大。

Prompt Encoder

Prompt分为了密集的(mask)和稀疏的(point、box、text)。

密集的mask prompt先通过卷积降低分辨率并匹配通道维度为256,然后再与patch embedding相加。假如没有mask,则一个特殊的[NoMask]embedding会与patch embedding相加。

point prompt分为坐标和前景/背景两部分,坐标通过position encoding得到256维的向量,然后前景/背景则分别设置一个可学习的embedding相加。

box prompt表示为两个point prompt,为左上角和右下角。

text prompt是CLIP提取的token。

然而,这篇文章没有过多的关注text prompt,重点在于geometric prompt。

Mask Decoder

Mask Decoder将image embeddings和prompt embeddings映射到多个粒度的输出mask。

首先,将多个[class]token与prompt embeddings组合,[class]`之后会被用来得到不同粒度的mask输出。

Decoder分成四步

  1. token自注意力
  2. token对image embedding交叉注意力
  3. token FFN
  4. image embedding对token交叉注意力

这种类似bi-attention的方式进行两层,最后再来一次额外的token对image embedding交叉注意力。

由于分割任务与位置关系十分密切,所以在每个attention层,都会将positional encoding加在image embedding上,同时原始的token也会加到更新后的token上。

在decoder处理之后,就要开始生成mask。首先把image embeddings通过卷积上采样4倍,然后将[class]token通过3层MLP得到一个张量,这个张量与image embeddings按元素相乘,得到mask的预测。

SAM的使用3个[class]输出3个mask(whole, part, subpart),训练的时候,选择与Ground Truth最接近的粒度计算损失并回传。

在推理的时候,3个mask会通过预测的IoU进行排序来得到最终输出的mask。IoU是预测面积和真实面积的交并比,模型会通过[class]+MLP预测得到一个IoU分数,可以认为是模型输出的对于自己预测的置信度。

Training

训练使用focal loss和dice loss以20:1的比例进行,同时模型预测IoU分数还会通过MSE与实际的IoU分数计算损失。

训练策略采用了interactive segmentation setup:

  1. 在标注的mask中,选择一个前景的point或者box,选择box时会给坐标增加一些噪声来模拟人的划分误差。
  2. 在拥有第一轮的输出后,之后的point会从之前预测错误的区域进行采样,假如预测错误的情况是false negative(该像素应该要被选中但是没有)则为前景点,假如是false positive(改像素不应该被选中但是选上了)则为背景点。同时,前一轮的预测结果将作为mask prompt进行后面轮次的训练。这样的中间轮次持续8轮。
  3. 8轮过后,剩余2轮不添加新的prompt,仅使用已有的prompt进行训练。

这种交互式的训练策略总共有11轮,论文中说后续可以尝试更多的数据和更多的轮次,因为image embedding只用提取1次,而light weight的decoder可以便捷地运行多次。

Segment Anything Dataset

数据集包含11M超高质量的图片(3300×49503300\times 4950平均分辨率,发布的分辨率降低到短边1500)。

除此以外包含1.1B的mask,mask的收集过程见下文。

如下图所示,每张图片的mask数量多、精细程度高,并且图像来自世界的各个地方,对AI的公平性有贡献。

Segment Anything Data Engine

如此厉害的数据集,标注过程分成三部分:模型辅助人工、半自动、全自动。

Assisted-manual stage

首先,作者先使用一些公开的数据集训练一个SAM出来,然后通过SAM“头重脚轻”的架构,设计了交互式的实时的标注平台。标注者被要求标注他们能够描述出来的任何东西。

在这个过程中,一边标记,用来辅助的SAM不断进化,使用标记出来的新数据进行训练,图像编码器也从ViT-B扩展到ViT-H,这样重新训练了6次,最后标注者平均14s可以标注出一张图像。

这一阶段总共收集4.3M mask和120k图像。

Semi-automatic stage

半自动阶段,旨在提升数据集的多样性。所以这个阶段作者给SAM添加了一个目标检测的head,用来检测图中所有的物体,然后让标注者去标记那些没有被框出来的物体。

同样,这一阶段也重新训练了5次,由于难度上升,标记时间为34s,但每个图像的mask数量上升。

这一阶段总共收集5.9M mask和180k图像。

Fully automatic stage

全自动阶段,旨在提升数据集的量。作者给SAM一个32×3232\times 32的网格point prompt,每个点都能预测出3个mask。这一阶段利用一些手动的算法来选择可信的mask。

  1. 模型输出的IoU分数当作置信度,通过阈值选出可信的mask。
  2. 选择出stable的mask:mask像素的概率,卡阈值为0.5δ0.5-\delta0.5+δ0.5+\delta,假如两个阈值生成的mask类似,那么这就是一个stable的mask。
  3. 使用非极大值抑制(NMS)来过滤重复的mask。

这一阶段总共收集1.1B mask和11M图像。

实验

zero-shot point mask

SAM主要与一个单点交互式分割网络RITM进行比较,上图(a)是在23个数据集上相对RITM的定量分析,大部分更好,少部分数据集更差一些。(a)中还有橙色的点,表示“oracle”模型的评价,因为SAM会输出3种mask,非oracle的情况就是选择置信度最高的mask与GT进行比较,而oracle则是将3中mask与GT进行比较,取最相似的。在oracle的情况下,SAM在所有数据集上表现都更好。

(b)则是在几个数据集上的人工评价, (c)(d)则是加入了其它baseline的比较。

zero-shot edge detection

16×1616\times16网格获得768个mask,NMS之后使用Sobel得到边缘,感觉非常不错!

zero-shot object proposal

64×6464\times64的网格+NMS来获得非常多的mask,并通过执行度和稳定度获取top 1000个proposal,在LVIS数据集上进行评价,效果也很好。

zero-shot instance segmentation

这里使用了目标检测模型ViTDet来检测框,然后用框给SAM prompt。

这里和全监督的方法对比,发现在小数据集上有一定差距,在更精细的LVIS数据集上差距减少,而在人工评测上效果更好。

zero-shot text-to-mask

这个训练方法非常有趣,SAM的常规训练没有使用text prompt,数据集也不包含文本,他们zero-shot text-to-mask的方法是对于数据集中大于100×100100\times100像素的区域,提取其CLIP image embedding作为prompt。

注意这里是image embedding而不是text embedding,因为image和text是在同一个语义空间的,在训练的时候无需额外的文本标注来训练,而预测的时候就可以直接用文本来进行推理。

上图是定性分析,发现用文本的效果不是非常好,因为有模糊语义的存在,但是结合point一起作为prompt的话就更好了。

总之,本文只是对文本进行了初步的尝试,只是展现了潜力。

消融实验

在数据上,最后使用的数据集只包含自动生成的部分,左图证明加上人工标注部分效果只有微弱增加,所以自动生成部分的数据已经很好了。

中间图证明使用10%的图片作为训练数据效果与100%相当,可以是一个practical setting。

右图说明使用小一点的编码器也可以。

总结

总之,SAM是一个为了“通用”而设计的模型,在各个细小的任务上不一定能超过专精的模型,但SAM任务和数据集的提出为更好的AI创造了可能性。