code review
Learning deep representations by mutual information estimation and maximization. ICLR 2019.
sungwool
2021. 7. 8. 11:14
Deep InfoMax
This implementation only utilizes the "local mutual information maximization" objective from the paper.

1. Model
import torch
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.block1 = ENB(  3,  64, 3, 1)
        self.block2 = ENB( 64,  64, 3, 1)
        self.block3 = ENB( 64, 128, 3, 1, True)  # local features are taken from here
        self.block4 = ENB(128, 256, 3, 1)
        self.block5 = ENB(256, 512, 3, 1)
        self.block6 = ENB(512,  64, 1, 0)
        self.gap = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        h = self.block1(x)
        h = self.block2(h)
        features = self.block3(h)   # local feature map C(x), 128 channels
        h = self.block4(features)
        h = self.block5(h)
        h = self.gap(h)             # global average pooling -> (N, 512, 1, 1)
        encoded = self.block6(h)    # global feature vector E(x) -> (N, 64, 1, 1)
        return encoded, features
class ENB(nn.Module):
    """Conv -> BatchNorm -> ReLU block with optional 2x2 max pooling."""
    def __init__(self, in_channels, out_channels, kernel_size, padding, is_pooling=False):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.is_pooling = is_pooling
        if is_pooling:
            self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        if self.is_pooling:
            x = self.pool(x)
        return x
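A quick shape check can make the two outputs concrete. This is a minimal sketch; the 64x64 input size is an assumption chosen so that the local feature map comes out 32x32, matching the expand(-1, -1, 32, 32) in the loss further below:

# hypothetical sanity check: feed a dummy RGB batch through the encoder
encoder = Encoder()
x = torch.randn(4, 3, 64, 64)   # assumed input size, not stated in the post
encoded, features = encoder(x)
print(encoded.shape)    # torch.Size([4, 64, 1, 1])   - global feature E(x)
print(features.shape)   # torch.Size([4, 128, 32, 32]) - local feature map C(x)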
2. Local Discriminator
class LocalDiscriminator(nn.Module):
    def __init__(self):
        super().__init__()
        # input channels: 128 (local feature map) + 64 (broadcast global feature) = 192
        self.conv1 = nn.Conv2d(192, 512, kernel_size=1)
        self.conv2 = nn.Conv2d(512, 512, kernel_size=1)
        self.conv3 = nn.Conv2d(512, 1, kernel_size=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        h = self.conv1(x)
        h = self.relu(h)
        h = self.conv2(h)
        h = self.relu(h)
        out = self.conv3(h)   # one score per spatial location
        return out
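Because every convolution is 1x1, the discriminator scores each spatial location of a global/local pairing independently. A minimal sketch, assuming the 32x32 local map from the shape check above:

# hypothetical shape check for the local discriminator
d = LocalDiscriminator()
pair = torch.cat((torch.randn(4, 128, 32, 32),        # local feature map
                  torch.randn(4, 64, 32, 32)), dim=1)  # broadcast global feature
scores = d(pair)
print(scores.shape)   # torch.Size([4, 1, 32, 32]) - one score per location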
3. Loss Function
$$ \text{argmax}_{\omega, \psi} \frac{1}{M^2}\sum^{M^2}_{i=1}\hat{I}_{\omega, \psi}(C^{(i)}_{\psi}(X); \ E_{\psi}(X)) $$
$$ \text{where}\ C_{\psi}(x) := \{C^{(i)}_{\psi}\}^{M\times M}_{i=1},\ E_{\psi}(x) = f_{\psi} \circ C_{\psi}(x) $$
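The MI estimator $\hat{I}_{\omega,\psi}$ used in the code below is the paper's Jensen-Shannon formulation, where $T_{\psi,\omega}$ is the discriminator, $\mathrm{sp}(z) = \log(1+e^{z})$, and $x'$ is a negative sample drawn from $\tilde{\mathbb{P}}$:

$$ \hat{I}^{(\mathrm{JSD})}_{\omega,\psi}(X;\ E_{\psi}(X)) := \mathbb{E}_{\mathbb{P}}\left[-\mathrm{sp}\left(-T_{\psi,\omega}(x,\ E_{\psi}(x))\right)\right] - \mathbb{E}_{\mathbb{P}\times\tilde{\mathbb{P}}}\left[\mathrm{sp}\left(T_{\psi,\omega}(x',\ E_{\psi}(x))\right)\right] $$

In the code, Ej is the first (positive-pair) expectation and Em is the second (negative-pair) expectation, so minimizing Em - Ej maximizes this estimate.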
import torch.nn.functional as F

class DeepInfoMaxLoss(nn.Module):
    def __init__(self, beta=1.0):
        super().__init__()
        self.local_d = LocalDiscriminator()
        self.beta = beta

    def forward(self, y, M, M_prime):
        # M_prime holds mismatched (negative) local maps, built by shifting the batch:
        # M_prime = torch.cat((M[1:], M[0].unsqueeze(0)), dim=0)
        # broadcast the global feature y (N, 64, 1, 1) over the local map's spatial dims
        y_exp = y.expand(-1, -1, 32, 32)
        y_M = torch.cat((M, y_exp), dim=1)              # positive pairs: (N, 192, 32, 32)
        y_M_prime = torch.cat((M_prime, y_exp), dim=1)  # negative pairs: (N, 192, 32, 32)
        # Jensen-Shannon MI estimator terms
        Ej = -F.softplus(-self.local_d(y_M)).mean()       # positive-pair term
        Em = F.softplus(self.local_d(y_M_prime)).mean()   # negative-pair term
        LOCAL = (Em - Ej) * self.beta                     # minimizing this maximizes MI
        return LOCAL
code from: https://github.com/DuaneNielsen/DeepInfomaxPytorch
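For context, here is a minimal training-loop sketch tying the pieces together. The dummy data, batch size, and optimizer settings are assumptions for illustration, not taken from the post or the repository:

import torch
import torch.optim as optim

encoder = Encoder()
loss_fn = DeepInfoMaxLoss(beta=1.0)
# optimize encoder and discriminator jointly (loss_fn owns the discriminator)
opt = optim.Adam(list(encoder.parameters()) + list(loss_fn.parameters()), lr=1e-4)

# dummy stand-in for a real DataLoader (assumed 64x64 RGB inputs)
loader = [(torch.randn(8, 3, 64, 64), None) for _ in range(10)]

for x, _ in loader:
    y, M = encoder(x)   # global feature y, local feature map M
    # negatives: pair each global feature with the next image's local map
    M_prime = torch.cat((M[1:], M[0].unsqueeze(0)), dim=0)
    loss = loss_fn(y, M, M_prime)
    opt.zero_grad()
    loss.backward()
    opt.step()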