Contrastive Test-Time Adaptation (CoTTA)

 

들어가며 ..

DNN은 training과 test의 data의 distribution은 따라서 학습이 되어지지만 unseen data에 대해서는 domain shift가 되어지는 단점을 가지고 있다. Dsitribution shift는 labeled source 데이터가 새로운 target domain에 대해서 transfer knowlege를 하는것이 목적이다.

최근에는 test-time, source free라는 조건으로 Domain adpataion(DA)를 접근을 하고 있으며 source data는 test data에 대해서는 더이상 사용하지 않는 개념을 가지고 있다.

TTA는 source model에 대해서만 접근을 할 수 있으면 된다.

 

TTA는 2가지의 조건이 있는데

1. 어떻게 target domain representation을 GT annoitation없이 어떻게 학습을 할 것인가?

2. 어떻게 traget domian calssifier를 적절한 soure domain classifiter만 가지고 만들어야되는가?

이러한 해법을 위해서 image/featuer generation, class prototypes, entropy minimization, SSL, Pesudo lable, SSL with auxiliary task가 있다.

 

Genteration의 경우는 capaicity가 많이 필요하며

Entropy minimization (EM)의 경우는 competitve하지만 entropy의 disrupt를 야기하는 단점이 있다.

Pesudo labeling은 promise한 result지만 noisy한 label를 만듬.

SSL의 방법은 auxiliary rotate prediction task를 source와 target에 training을 하게 되어지며 이 방법의 단점은 source 에 traninig protocol이 추가적으로 들어각 ㅔ되는점이다.

Contrastiive learning의 paradigm을 적용하여 학습 하는 방안의 경우는 transferalbe rotate의 비교하여 하는 방법으로 최근에 pre-training stage에서 ssl를 학습을 하는 방법을 제안한다.

저자는 이 이전의 방법의 경우 SSL을 최대한 활용하지 못한다고 말한다. 특히나 adaptation stage에 대해서 ...

 

이번 저자는 TTA를 SSL 사용하며 더 좋은 representation을 학습을 하게 pesudo labeling을 사용하여 향상 시켰으며 auxiliary contrastive learnin을 통해서 online pesudo label를 만들면서 정확도를 올렸다.

이를 활용하기 위해서 MOCO에서 사용했던 memory bank르 사용하였따. 기존의 방식보다 hyperparemter의 tuning의 폭이 커졌다.

 

 

 

TENT SHOT의 단점은

the entropy minimization objective does not model the relation among different samples and more importantly, disrupts the model calibration on target data due to direct entropy optimization.

the pseudo labels are updated only on a per-epoch basis, which fails to reflect the most recent model improvement during an epoch.

 

 

Method

Source 모델을 얻기 위해서 Stand cross entropy를 사용하여서 softmax를 이용하여서 학습을 하였으며

Online pseudo label refinement방법을 사용하였음.

target data에 대해서 pesudo label을 만들기 위해서 Source model의 weight를 가져왔었으며 한 epoch이 끝날때마다 peusudo label를 update를 하였으며 , refinement nearest -heighbor soft voting으로 update되어진다.

soft k nearest neighbors : https://github.com/DianCh/AdaContrast/blob/c3c8b880131f2658d6fd0d5ed14d71f326174d57/target.py#L123

위의 첨부한 그림을 기반으로 설명을 하자면 $x_t$라는 Test case에 weak augmenation ($t_w$)를 적용을 하며 Random distribution $T_w$ 를 이러한 Feature들을 vector화를 시켜 $w=F_t(t_w(x_t))$ 로 뽑아내게 되어지게 되어지면 feature들에서 target feature space와 유사한 nearest neighbors가 나오게 되며 argmax를 사용하여 peusdo label를 뽑아지게 되어진다. 이때 나온 값을 $\hat{y}$라고 한다.

code link : https://github.com/DianCh/AdaContrast/blob/c3c8b880131f2658d6fd0d5ed14d71f326174d57/target.py#L240 

 

class AdaMoCo(nn.Module):
    """
    Build a MoCo model with: a query encoder, a key encoder, and a memory bank
    https://arxiv.org/abs/1911.05722
    """

    def __init__(
        self,
        src_model,
        momentum_model,
        K=16384,
        m=0.999,
        T_moco=0.07,
        checkpoint_path=None,
    ):
        """
        dim: feature dimension (default: 128)
        K: buffer size; number of keys
        m: moco momentum of updating key encoder (default: 0.999)
        T: softmax temperature (default: 0.07)
        """
        super(AdaMoCo, self).__init__()

        self.K = K
        self.m = m
        self.T_moco = T_moco
        self.queue_ptr = 0

        # create the encoders
        self.src_model = src_model
        self.momentum_model = momentum_model

        # create the fc heads
        feature_dim = src_model.output_dim

        # freeze key model
        self.momentum_model.requires_grad_(False)

        # create the memory bank
        self.register_buffer("mem_feat", torch.randn(feature_dim, K))
        self.register_buffer(
            "mem_labels", torch.randint(0, src_model.num_classes, (K,))
        )
        self.mem_feat = F.normalize(self.mem_feat, dim=0)
        
    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder
        """
        # encoder_q -> encoder_k
        for param_q, param_k in zip(
            self.src_model.parameters(), self.momentum_model.parameters()
        ):
            param_k.data = param_k.data * self.m + param_q.data * (1.0 - self.m)

 

그 이후에서는 

Memory Queue

 

memory quete : $$ Q_w $$ 를 통해서

weak augmented target sample과 currnet mini batc에서 같은 source weight에서 momentum을 추가하여준다.

 

Joint self supervised contrastirve learning

contrastive learning 을 적용하기 위해서 pair-wise 하게 접근을 하는것이 기반이 되어야 한다. 저자는 self supervised랑 test data간의 contrastive learning을 합쳐서 적용하기 위해서 같은 영상에 다른 view끼리 positivie pair로 하게 하면서 다른 이미지간에서는 밀어내어 학습을 하게 한다.

이를 위해서 target image는 $x_t$로 정의를 하며 strong augmentation의 경우는 $t_x, t'_x$로 정의가 되어지며 최종적으로 $t_x(s_t), t'_s(x_t)$간의 contrastive로 학습을 하게 되어진다. 

 

encoder initialization by source

MOCO의 모델과 유사하게 target encoder $f_t$를 source model wieght로 적용을 하며 momentum encoder $f'_t$로 initialize하게 된다. 이 momentum encoder로 $Q_w$를 updating하면서 학습을 하게 되어진다. 

 

Exclusion of same-class negative pairs

앞전에 init하였던 encoder들의 바탕으로 augmentation을 달리하여 q와 k라는 feature를 뽑아내어 infoNCE loss를 사용하여 cosine distnace를 minimize를 하여준다. 

이 loss를 바탕으로 postivie간에는 push하지 않고 다른 image간에는 negtaive로 밀어내게 되어진다.

 

Additional regularization

FixMatch에서 영감을 받아서 psudo label를 만들어내었으며 weakly-augmentation 과 strong augmentationd의 몇가지 중요한 distinctions를 하여 만들었다.

1. 절대 GT를 접근을 하지 않으며, 이전에 사용했던 Pesudo label를 refine하며, confidence한 threshold를 만들지 않으며 source initialization을 한 바탕으로 시작을 하는것을 전제로 하였다. 이를 정규화를 하기 위해서 Cross enotropy를 사용한다.

 

 

 

 

 

 

반응형

introduction

이번 논문에는 Neurlps 2019에 나왔던 weakly supervised instance segmentation에 관한 논문이다. 

아마 bbox annotation만을 가지고 instance segmentation을 하는 방법에서 multi instance learning(MIL)을 사용했다는 점에서 새로운 방식으로 접근을 하였다. 

논문의 코드는 https://github.com/chengchunhsu/WSIS_BBTP. 해당 링크에 있다. 

 

Method 

기본적인 instance segmentation의 pipeline의 경우는 대표적 two stage instance segmentation의 방식 중에 하나로 Mask-RCNN으로 구성이 되어있는 것을 볼 수 있다. 

밑의 있는 그림에서 나와있듯이 첮번째로는 ResNet101에서 RPN을 통해서 ROI를 추출한 다음 ROI aline을 한 후 Bbox를 추출하는 건 일반적인 object detection의 방식과 같은 형식으로 진행이 된다. 

저자가 제안한 방법은 segmentation branch의 방법으로 이때 MIL을 사용하여 진행을 하게 되어지며 이때 tightness prior를 주게 된다. 

그렇다면 MIL을 어떻게 하였을까? 

  • 일반적으로 MIL의 경우는 postive, negative  bag의 설정이 중요하다.
  • MIL의 경우는 image내에서 해당하는 class가 있을때를 postivie bag로 설정하며 없을 때는 negative bag로 설정하여 bag들을 섞어서 postivie bag일 때 해당 class를 classification을 하는 방법을 말한다. 
  • 이를 통해서 image classification에 해당 class에 localization을 잘할수 있는 장점이 있다.
  • weakly supervised object detection중 대표적인 예시로  "C-MIL: Continuation Multiple Instance Learning for Weakly Supervised Object Detection"가 있으며
  • 이 논문에서는 저자는 밑의 그림과같이 한 개의 이미지 내에서 bbox를 무작위로 추출한 다음 IOU기준으로 높은 곳을 positive box로 하며 그 외의 것을 negative로 하여 MIL을 적용하며 좀 더 개선하기 위해서 subset의 개념들을 들고 와서 문제를 풀었다 

  • 이번 논문에서는 ROI에서 추출된 bbox끼리 MIL을 적용하는 방식을 사용하게 되는데 
  • Bbox에서 나온 여러개의 bbox를 horizontal & vertical으로 sample들을 무작위로 추출하게 된다. 
  • 이때 해당 label bbox에서 나온 sample을 반듯이 instance기준으로 pixel들이 반듯이 들어가기 때문에 positve bag으로 설정하며 다른 bbox에서 나온 sample들은 해당하는 pixel들이 들어가지 않음으로 negative bag로 설정을 한다.

 

  • 이로 인해서 training set의 경우 $D = \left\{I_n, B_n \right\}^B_{n=1}$으로 되며 MIL을 적용하게 되면 $D = \left\{I_n, B_n,\tilde {B}_n \right\}^B_{n=1}$으로 되며 $tilde {B}_n$에서는 k개의 positive & negative bags들이 포함이 되어있다. 

 

  • 이를 수식적으로 풀어보면 밑의 수식으로 되어지는데 
  • segmentation branch에서는 instance score map 나오게 되며 이를 $S$로 부르며 positive & negative bags는 $\hat{B}^+ , \hat {B}^-$와 함께 MIL을 적용하게 된다. 
  • 총 2가의 term으로 loss가 구성이 되어져있있으며 전자를 unary term과 pariwise term으로 구성이 되어있다.

unary term loss

  • positive & negative bags에서 tigthness constraints를 $S$를 바탕으로 loss를 적용하게 되는데 이는 mask를 좀 더 tight 하게 만드는 효과가 있다고 제안한다. 
  • positive bags의 경우는 $S$에서 반듯이 instance pixel이 있어야 함으로 maximal prediction score를 뽑아지도록 Loss를 주게 되며 반면에 negative box의 경우에서는 $S$에서는 pixel이 없어야 하므로 모든 pixel들이 minimize 되도록 하게 하는 형식이다. 
  • 이때 $P\left( \hat{b} \right)=max_{p\in \hat {b}} S\left ( p \right )$로 되며 probability의 경우 maximum이 된 경우 positive bag로 되어지면 반대의 경우는 negative bag로 예측이 되게 된다. 

Pairwise term

  • 전체의 obejct내에서 CAM을 이용한 결과와 argumentation의 하기전을 다시 적용하여 2개의 regulization을 했다고 보면 된다. 
  • 우리가 아는 consistency regulization을 적용했다고 보면 되겠다. 

 

  • 밑의 표와 같이 다양한 논문에 비교를 하였으며 해당 loss에 대해서도 ablation study도 실시한 것을 볼 수 있다. 
  • 그림은 상당히 준수한 수준으로 나왔다. 

 

하지만 이러한 좋은 방법에도 불구하고 같은 instance에 대해서는 아쉬운 성능을 보였으며 이는 추후의 개선점으로 보인다. 

반응형

Abstract

  • One-stage instance sementation의 새로운 방식으로 제안한 방법
  • FCOS base로 구성이 되어있으며 BlendMask가 나오기 이전에 제일 좋은 성능을 내었던 방식임
  • 2가지의 문제점을 가지고 접근을 하였으며 object instance differentiation 과 pixel-wise feature alignment에 대해서 문제점을 지적하였음. 
  • Local shape와 global saliency map을 사용해서 instance간의 구별을 하였으며 기본적으로 scratch로 학습을 하였을경우 mask AP 34.5가 나온다고 하고 있음.
  • 밑의 그림처럼 2개의 feature map을 잘  ensemble하여 mask를 획득하였음.

 

Method 

  • abstract에서 말했듯이 저자는 2가지의 문제점을 지적을 하였다.
    • 첫번째로는 instance끼리 어떻게 분리를 할것인지 (기존의 여러 방식이 있지만 좋지는 않음) 
    • feature map에서 어떻게 pixel-wise location을 aligment를 할것인지에 대해서 해결책을 놓을려고 하였다. (이방법도 마찬가지로 TensorMask처럼 4D로 풀기도하였고 contour를 맞췄지만 정확하지 않음) 
  • 따라서 제안한 방법은 local shape에 대한 방법뿐만 아니라 global saliency map을 사용하여 aligment부분과 instane끼리 분리할수 있도록 제안하였다. 

Local shape prediction

  • 첫번째로 재안한 방법은 다른 위치에 있는 instance를 구별하기 위한 방법을 제안하였다. 
  • 저자는 다른 Instance까리는 shape와 size가 다르기 때문에 2개의 branch(shape,size)를 사용하여 multi task 문제로 접근하였다.
  • shape의 branch의 경우 그림4와 같이 $H X W X S^2$로 만들었으며 ($S^2$는 shape 는size에 대한 vector를 의미함 ) size의 branchr경우에서는 $H X W X 2$로 만들었다.(2가 의미하는 바는 bbox의 Width,Height 를 의미한다.)
  • 이 두 branch의 합하여 local shape를 뽑아낼수 있는데 이는 그림 4에 첨부된것처럼 shape branch에 해당하는 한개의 center point의 Pixel에는 local에 대한 shape를 가지고 있으므로 S X S로 reshape를 하며 같은 위치의 pixel은 height 와 width를 가지고 있으므로 reshape된 SxS를 HxW로 resize하여 combination을 하게 되면 local shape에 대한 정보를 가진 featuer가 나오게 되어진다. ( 의문점은 shape의 vector를 어떻게 만들었는지?? )

Global Saliency Generation 

  • 이렇게 만든 instance mask는 coarse한 문제점이 있는데 이를 어떻게 좀더 정확한 mask로 만들기 위해서  globa saliency map을 이용하여 풀었다. 
  • 저자는 이를 sigmoid로 class-agnistic하게 하여 feature map을 뽑아내어 전체를 obejctness에 해당하는 map을 얻어내었다. 

Mask Assembly

  • 최종적으로 위의 2개의 feature map을 조합하는것이 필요하며 shape branch에서 나온 결과에 해당하는 pixel에 대한 위치를 global saliency map에 해당하는 위치로 crop을 하여 combination을 하게되는데 이를 multiple하여 정확한 instance를 뽑게 된다. 
  • 이때 사용되어지는 mask loss function의 경우 objectness만 사용하여 loss를 구하게 되어진다.

  • 전체적인 학습의 pipeline은 centernet과 똑같이 진행이 되어진다. heatmap에 대한 $L_p$의 부분은 regression으로 focal loss처럼 주었으며 $L_{size}$의 경우는 L1 norm으로 주었으며 $L_{offset}$의 경우는 stride의 주어서 featuer map에 맞춰서 L1 norm을 마찬가로 주어서 학습을 진행하였다. 
  • 최종적으로는 다른 constant를 주어서 sum을 하여 loss를 적용하였다.

 

 

Result

  • 역시 ablation study를 많이 진행하였으며 shape를 변형, backbone, local shape branch의 사용 global branch의 사용등 다양한 방법을 사용하기 전후에 대한 비교를 진행을 하였으며 
  • 그림 5에 보이는것처럼 shape의 branch만 사용했을때에는 global에 대한 featuer가 없으니 instance끼리는 분리를 하더라도 coarse mask가 나오는것을 볼수 있으며 반대인 경우에서는 shape는 정확할지라도 instance에 대한 구분이 없는것을 볼수 있다. 

 

반응형

+ Recent posts