code review

Virtual Class Enhanced Discriminative Embedding Learning. NIPS. 2018.

sungwool 2021. 7. 28. 15:54

Virtual Softmax


Conceptual visualization of vritual softmax

 

1. Method

  이 논문에서 제안하는 Virtual Softmax는 Softmax에서 파생된 classifier입니다. 먼저 Softmax의 수식은 아래와 같습니다. 

$$ \frac{\mathcal{e}^{W^{T}_{y_i}X_i}}{\sum^C_{j=1}\mathcal{e}^{W^{T}_{y_j}X_i}} $$

Softmax는 수식에서 볼 수 있듯이 weight와 입력의 feature의 inner product \(Wx\)를 기반으로 연산됩니다. 학습 중에는 \( W^T_{y_i}X_i > \max_{j \in c, j \neq y_i}(W^T_jX_i)\)를 충족하도록 penalty를 부여함으로써 입력이 올바른 class로 분류되도록 합니다. 하지만, Softmax는 일정 수준까지 학습되면 더 이상 학습이 진행되지 않는 단점이 있습니다.

먼저, \( \|W_j\| = l \)이라고 가정합니다. 주어진 가정에 맞추어 앞 수식을 다시 전개해봅니다.

$$ W^T_{y_i}X_i > \max_{j \in c, j \neq y_i}(W^T_jX_i) $$

$$ = l \|X_i\| \cos{\theta_{y_i}} > \max_{j \in c, j \neq y_i}{ l \|X_i\|\cos{\theta_j}}$$

$$ = \cos{\theta_{y_i}} > \max_{j\in c, j \neq y_i}\cos{\theta_j} $$

위 식에서 \(\theta_j\)는 feature vector \(X_i\)와 class vector \(W_j\)가 이루는 각을 의미합니다. 이제 class \(y_i\)의 decision boundary \(\theta\)는 아래와 같이 표현할 수 있습니다.

$$ \theta = \theta_{y_i} = \theta_{\text{argmax}_{j, j \neq y_i}(W^T_jX)} = \frac{\boldsymbol{\Phi}}{2} $$

위 식에서 \(\boldsymbol{\phi}\)는 feature space 상에서 class vector가 고르게 분포했을 때, 모든 class vector가 이루는 same vectorial angle을 의미하며, \(\boldsymbol{\Phi} = \frac{2\phi}{c}\)의 값을 가집니다. 결과적으로 앞서 가정한 상황을 고려한다면 Softmax는 class 개수에 의존적인 vectorial angle보다 좁아지는 방향으로 학습이 이루어지지 않습니다.

 

본 논문에서 제안하는 Virtual Softmax는 아래와 같은 수식을 가집니다.

$$  \frac{\mathcal{e}^{W^{T}_{y_i}X_i}}{\sum^C_{j=1}\mathcal{e}^{W^{T}_{y_j}X_i} + \mathcal{e}^{W^T_{virt}X_i}}  $$

식에서 확인할 수 있듯이, Virtual Softmax는 가상의 class인 \( virt \)를 함께 고려하여 총 \(C+1\)개의 class를 사용합니다. 수식에서 \( W^{T}_{virt}X_i = \| W_{y_i} \| \| X_i \| \)이기 때문에, \( W^T_{y_i}X_i \ge \max_{j \in C+1}{W^{T}_{j}X_i} = W^{T}_{virt}X_i \)가 됩니다. 즉, 이는 Virtual Softmax는 가상의 class를 통해 \( \theta_{y_i}s = 0 \)인 최적점을 갖게 됩니다. 이는 Softmax의 최적점인 \( \theta = \frac{\boldsymbol{\Phi}}{2}\)와 대비되는 특성입니다. 이러한 특성은 아래와 그림과 같이 학습하는 model이 매우 discriminative한 feature를 갖도록 유도합니다.

Visualization of the learned features optimized by original Softmax vs. Virtual Softmax

 

 실제 PyTorch로는 아래와 같이 구현될 수 있습니다.

$$ \text{Virtual Softmax} \rightarrow \frac{\mathcal{e}^{W^{T}_{y_i}X_i}}{\sum^C_{j=1}\mathcal{e}^{W^{T}_{y_j}X_i} + \mathcal{e}^{W^T_{virt}X_i}}  $$

from torch.nn import Parameter
import torch.nn.functional as F
import torch.nn as nn
import math
import torch


class VirtualSoftmax(nn.Module):
    def __init__(self, feat_dim, num_classes):
        super(VirtualSoftmax, self).__init__()
        self.weight = nn.Parameter(torch.ones(feat_dim, num_classes), requires_grad=True)

    def forward(self, x, y):
        ''' x's shape : [bs, feat_dim] '''

        Wy        = self.weight[:, y]
        Wy_norm   = torch.norm(Wy, dim=1, p=2, keepdim=True)
        Wy_norm_T = Wy_norm.T

        x_norm    = torch.norm(x, dim=1, p=2, keepdim=True)
        
        Wx_vert   = torch.mul(x_norm, Wy_norm_T)                # [bs, feat_dim]
        Wx_vert   = torch.clip(Wx_vert, min=1e-10, max=15.0)    
        
        Wx        = torch.mm(x, self.weight)                    # [bs, feat_dim]
        Wx_new    = torch.cat([Wx, Wx_vert], dim=1)

        return Wx_new