Contrastive Test-Time Adaptation (CoTTA)
들어가며 ..
DNN은 training과 test의 data의 distribution은 따라서 학습이 되어지지만 unseen data에 대해서는 domain shift가 되어지는 단점을 가지고 있다. Dsitribution shift는 labeled source 데이터가 새로운 target domain에 대해서 transfer knowlege를 하는것이 목적이다.
최근에는 test-time, source free라는 조건으로 Domain adpataion(DA)를 접근을 하고 있으며 source data는 test data에 대해서는 더이상 사용하지 않는 개념을 가지고 있다.
TTA는 source model에 대해서만 접근을 할 수 있으면 된다.
TTA는 2가지의 조건이 있는데
1. 어떻게 target domain representation을 GT annoitation없이 어떻게 학습을 할 것인가?
2. 어떻게 traget domian calssifier를 적절한 soure domain classifiter만 가지고 만들어야되는가?
이러한 해법을 위해서 image/featuer generation, class prototypes, entropy minimization, SSL, Pesudo lable, SSL with auxiliary task가 있다.
Genteration의 경우는 capaicity가 많이 필요하며
Entropy minimization (EM)의 경우는 competitve하지만 entropy의 disrupt를 야기하는 단점이 있다.
Pesudo labeling은 promise한 result지만 noisy한 label를 만듬.
SSL의 방법은 auxiliary rotate prediction task를 source와 target에 training을 하게 되어지며 이 방법의 단점은 source 에 traninig protocol이 추가적으로 들어각 ㅔ되는점이다.
Contrastiive learning의 paradigm을 적용하여 학습 하는 방안의 경우는 transferalbe rotate의 비교하여 하는 방법으로 최근에 pre-training stage에서 ssl를 학습을 하는 방법을 제안한다.
저자는 이 이전의 방법의 경우 SSL을 최대한 활용하지 못한다고 말한다. 특히나 adaptation stage에 대해서 ...
이번 저자는 TTA를 SSL 사용하며 더 좋은 representation을 학습을 하게 pesudo labeling을 사용하여 향상 시켰으며 auxiliary contrastive learnin을 통해서 online pesudo label를 만들면서 정확도를 올렸다.
이를 활용하기 위해서 MOCO에서 사용했던 memory bank르 사용하였따. 기존의 방식보다 hyperparemter의 tuning의 폭이 커졌다.
TENT SHOT의 단점은
the entropy minimization objective does not model the relation among different samples and more importantly, disrupts the model calibration on target data due to direct entropy optimization.
the pseudo labels are updated only on a per-epoch basis, which fails to reflect the most recent model improvement during an epoch.
Method
Source 모델을 얻기 위해서 Stand cross entropy를 사용하여서 softmax를 이용하여서 학습을 하였으며
Online pseudo label refinement방법을 사용하였음.
target data에 대해서 pesudo label을 만들기 위해서 Source model의 weight를 가져왔었으며 한 epoch이 끝날때마다 peusudo label를 update를 하였으며 , refinement nearest -heighbor soft voting으로 update되어진다.
soft k nearest neighbors : https://github.com/DianCh/AdaContrast/blob/c3c8b880131f2658d6fd0d5ed14d71f326174d57/target.py#L123
위의 첨부한 그림을 기반으로 설명을 하자면 $x_t$라는 Test case에 weak augmenation ($t_w$)를 적용을 하며 Random distribution $T_w$ 를 이러한 Feature들을 vector화를 시켜 $w=F_t(t_w(x_t))$ 로 뽑아내게 되어지게 되어지면 feature들에서 target feature space와 유사한 nearest neighbors가 나오게 되며 argmax를 사용하여 peusdo label를 뽑아지게 되어진다. 이때 나온 값을 $\hat{y}$라고 한다.
code link : https://github.com/DianCh/AdaContrast/blob/c3c8b880131f2658d6fd0d5ed14d71f326174d57/target.py#L240
class AdaMoCo(nn.Module):
"""
Build a MoCo model with: a query encoder, a key encoder, and a memory bank
https://arxiv.org/abs/1911.05722
"""
def __init__(
self,
src_model,
momentum_model,
K=16384,
m=0.999,
T_moco=0.07,
checkpoint_path=None,
):
"""
dim: feature dimension (default: 128)
K: buffer size; number of keys
m: moco momentum of updating key encoder (default: 0.999)
T: softmax temperature (default: 0.07)
"""
super(AdaMoCo, self).__init__()
self.K = K
self.m = m
self.T_moco = T_moco
self.queue_ptr = 0
# create the encoders
self.src_model = src_model
self.momentum_model = momentum_model
# create the fc heads
feature_dim = src_model.output_dim
# freeze key model
self.momentum_model.requires_grad_(False)
# create the memory bank
self.register_buffer("mem_feat", torch.randn(feature_dim, K))
self.register_buffer(
"mem_labels", torch.randint(0, src_model.num_classes, (K,))
)
self.mem_feat = F.normalize(self.mem_feat, dim=0)
@torch.no_grad()
def _momentum_update_key_encoder(self):
"""
Momentum update of the key encoder
"""
# encoder_q -> encoder_k
for param_q, param_k in zip(
self.src_model.parameters(), self.momentum_model.parameters()
):
param_k.data = param_k.data * self.m + param_q.data * (1.0 - self.m)
그 이후에서는
Memory Queue
memory quete : $$ Q_w $$ 를 통해서
weak augmented target sample과 currnet mini batc에서 같은 source weight에서 momentum을 추가하여준다.
Joint self supervised contrastirve learning
contrastive learning 을 적용하기 위해서 pair-wise 하게 접근을 하는것이 기반이 되어야 한다. 저자는 self supervised랑 test data간의 contrastive learning을 합쳐서 적용하기 위해서 같은 영상에 다른 view끼리 positivie pair로 하게 하면서 다른 이미지간에서는 밀어내어 학습을 하게 한다.
이를 위해서 target image는 $x_t$로 정의를 하며 strong augmentation의 경우는 $t_x, t'_x$로 정의가 되어지며 최종적으로 $t_x(s_t), t'_s(x_t)$간의 contrastive로 학습을 하게 되어진다.
encoder initialization by source
MOCO의 모델과 유사하게 target encoder $f_t$를 source model wieght로 적용을 하며 momentum encoder $f'_t$로 initialize하게 된다. 이 momentum encoder로 $Q_w$를 updating하면서 학습을 하게 되어진다.
Exclusion of same-class negative pairs
앞전에 init하였던 encoder들의 바탕으로 augmentation을 달리하여 q와 k라는 feature를 뽑아내어 infoNCE loss를 사용하여 cosine distnace를 minimize를 하여준다.
이 loss를 바탕으로 postivie간에는 push하지 않고 다른 image간에는 negtaive로 밀어내게 되어진다.
Additional regularization
FixMatch에서 영감을 받아서 psudo label를 만들어내었으며 weakly-augmentation 과 strong augmentationd의 몇가지 중요한 distinctions를 하여 만들었다.
1. 절대 GT를 접근을 하지 않으며, 이전에 사용했던 Pesudo label를 refine하며, confidence한 threshold를 만들지 않으며 source initialization을 한 바탕으로 시작을 하는것을 전제로 하였다. 이를 정규화를 하기 위해서 Cross enotropy를 사용한다.