Virtual Class Enhanced Discriminative Embedding Learning. NIPS. 2018.
Virtual Softmax

1. Method
이 논문에서 제안하는 Virtual Softmax는 Softmax에서 파생된 classifier입니다. 먼저 Softmax의 수식은 아래와 같습니다.
$$ \frac{\mathcal{e}^{W^{T}_{y_i}X_i}}{\sum^C_{j=1}\mathcal{e}^{W^{T}_{y_j}X_i}} $$
Softmax는 수식에서 볼 수 있듯이 weight와 입력의 feature의 inner product \(Wx\)를 기반으로 연산됩니다. 학습 중에는 \( W^T_{y_i}X_i > \max_{j \in c, j \neq y_i}(W^T_jX_i)\)를 충족하도록 penalty를 부여함으로써 입력이 올바른 class로 분류되도록 합니다. 하지만, Softmax는 일정 수준까지 학습되면 더 이상 학습이 진행되지 않는 단점이 있습니다.
먼저, \( \|W_j\| = l \)이라고 가정합니다. 주어진 가정에 맞추어 앞 수식을 다시 전개해봅니다.
$$ W^T_{y_i}X_i > \max_{j \in c, j \neq y_i}(W^T_jX_i) $$
$$ = l \|X_i\| \cos{\theta_{y_i}} > \max_{j \in c, j \neq y_i}{ l \|X_i\|\cos{\theta_j}}$$
$$ = \cos{\theta_{y_i}} > \max_{j\in c, j \neq y_i}\cos{\theta_j} $$
위 식에서 \(\theta_j\)는 feature vector \(X_i\)와 class vector \(W_j\)가 이루는 각을 의미합니다. 이제 class \(y_i\)의 decision boundary \(\theta\)는 아래와 같이 표현할 수 있습니다.
$$ \theta = \theta_{y_i} = \theta_{\text{argmax}_{j, j \neq y_i}(W^T_jX)} = \frac{\boldsymbol{\Phi}}{2} $$
위 식에서 \(\boldsymbol{\phi}\)는 feature space 상에서 class vector가 고르게 분포했을 때, 모든 class vector가 이루는 same vectorial angle을 의미하며, \(\boldsymbol{\Phi} = \frac{2\phi}{c}\)의 값을 가집니다. 결과적으로 앞서 가정한 상황을 고려한다면 Softmax는 class 개수에 의존적인 vectorial angle보다 좁아지는 방향으로 학습이 이루어지지 않습니다.
본 논문에서 제안하는 Virtual Softmax는 아래와 같은 수식을 가집니다.
$$ \frac{\mathcal{e}^{W^{T}_{y_i}X_i}}{\sum^C_{j=1}\mathcal{e}^{W^{T}_{y_j}X_i} + \mathcal{e}^{W^T_{virt}X_i}} $$
식에서 확인할 수 있듯이, Virtual Softmax는 가상의 class인 \( virt \)를 함께 고려하여 총 \(C+1\)개의 class를 사용합니다. 수식에서 \( W^{T}_{virt}X_i = \| W_{y_i} \| \| X_i \| \)이기 때문에, \( W^T_{y_i}X_i \ge \max_{j \in C+1}{W^{T}_{j}X_i} = W^{T}_{virt}X_i \)가 됩니다. 즉, 이는 Virtual Softmax는 가상의 class를 통해 \( \theta_{y_i}s = 0 \)인 최적점을 갖게 됩니다. 이는 Softmax의 최적점인 \( \theta = \frac{\boldsymbol{\Phi}}{2}\)와 대비되는 특성입니다. 이러한 특성은 아래와 그림과 같이 학습하는 model이 매우 discriminative한 feature를 갖도록 유도합니다.

실제 PyTorch로는 아래와 같이 구현될 수 있습니다.
$$ \text{Virtual Softmax} \rightarrow \frac{\mathcal{e}^{W^{T}_{y_i}X_i}}{\sum^C_{j=1}\mathcal{e}^{W^{T}_{y_j}X_i} + \mathcal{e}^{W^T_{virt}X_i}} $$
from torch.nn import Parameter
import torch.nn.functional as F
import torch.nn as nn
import math
import torch
class VirtualSoftmax(nn.Module):
def __init__(self, feat_dim, num_classes):
super(VirtualSoftmax, self).__init__()
self.weight = nn.Parameter(torch.ones(feat_dim, num_classes), requires_grad=True)
def forward(self, x, y):
''' x's shape : [bs, feat_dim] '''
Wy = self.weight[:, y]
Wy_norm = torch.norm(Wy, dim=1, p=2, keepdim=True)
Wy_norm_T = Wy_norm.T
x_norm = torch.norm(x, dim=1, p=2, keepdim=True)
Wx_vert = torch.mul(x_norm, Wy_norm_T) # [bs, feat_dim]
Wx_vert = torch.clip(Wx_vert, min=1e-10, max=15.0)
Wx = torch.mm(x, self.weight) # [bs, feat_dim]
Wx_new = torch.cat([Wx, Wx_vert], dim=1)
return Wx_new