
Learning deep representations by mutual information estimation and maximization. ICLR. 2019

sungwool 2021. 7. 8. 11:14

Deep InfoMax


This implementation uses only the "local mutual information maximization" objective: maximizing the mutual information between each local feature vector of the encoder's intermediate feature map and the global feature vector of the same image.

1. Model

import torch.nn as nn


class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.block1 = ENB(  3,  64, 3, 1)
        self.block2 = ENB( 64,  64, 3, 1)
        self.block3 = ENB( 64, 128, 3, 1, True)   # 2x2 max pooling halves the spatial size
        self.block4 = ENB(128, 256, 3, 1)
        self.block5 = ENB(256, 512, 3, 1)
        self.block6 = ENB(512,  64, 1, 0)          # 1x1 conv -> 64-dim global code

        self.gap = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        h = self.block1(x)
        h = self.block2(h)
        features = self.block3(h)   # local feature map C(x), 128 channels

        h = self.block4(features)
        h = self.block5(h)

        h = self.gap(h)             # global average pooling -> (B, 512, 1, 1)
        encoded = self.block6(h)    # global feature E(x), shape (B, 64, 1, 1)

        return encoded, features

 

class ENB(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding, is_pooling=False):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=padding)
        self.bn   = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

        self.is_pooling = is_pooling
        if is_pooling:
            self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)

        if self.is_pooling:
            x = self.pool(x)

        return x
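
A quick shape check (a minimal sketch; the 64x64 input size is an assumption, chosen so that the pooled feature map is 32x32, matching the expand(-1, -1, 32, 32) in the loss below):

import torch

encoder = Encoder()
x = torch.randn(4, 3, 64, 64)   # dummy batch of 64x64 RGB images (assumed size)
encoded, features = encoder(x)

print(encoded.shape)    # torch.Size([4, 64, 1, 1])    -> global feature E(x)
print(features.shape)   # torch.Size([4, 128, 32, 32]) -> local feature map C(x)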

 

2. Local Discriminator

class LocalDiscriminator(nn.Module):
    def __init__(self):
        super().__init__()
        # input: 128 local-feature channels + 64 broadcast global channels = 192
        self.conv1 = nn.Conv2d(192, 512, kernel_size=1)
        self.conv2 = nn.Conv2d(512, 512, kernel_size=1)
        self.conv3 = nn.Conv2d(512,   1, kernel_size=1)

        self.relu  = nn.ReLU()

    def forward(self, x):
        h = self.conv1(x)
        h = self.relu(h)

        h = self.conv2(h)
        h = self.relu(h)

        out = self.conv3(h)   # one score per spatial location

        return out
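
Because every layer is a 1x1 convolution, the discriminator scores each spatial location independently. A quick check (a sketch; the 192-channel input assumes the 128-channel local map concatenated with the 64-channel global vector, as done in the loss below):

import torch

local_d = LocalDiscriminator()
pair = torch.randn(4, 192, 32, 32)   # concatenated (local, global) features
scores = local_d(pair)

print(scores.shape)   # torch.Size([4, 1, 32, 32]) -> one logit per location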

 

3. Loss Function

$$ \text{argmax}_{\omega, \psi}\ \frac{1}{M^2}\sum^{M^2}_{i=1}\hat{I}_{\omega, \psi}\left(C^{(i)}_{\psi}(X); \ E_{\psi}(X)\right) $$

$$ \text{where}\quad C_{\psi}(x) := \{C^{(i)}_{\psi}(x)\}^{M\times M}_{i=1},\quad E_{\psi}(x) = f_{\psi} \circ C_{\psi}(x) $$
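
Here $\hat{I}$ is the Jensen-Shannon-based estimator from the paper, where $\mathrm{sp}(z) = \log(1+e^{z})$ is the softplus and $x'$ is an input sampled independently of $x$:

$$ \hat{I}^{(\text{JSD})}_{\omega,\psi}\left(C^{(i)}_{\psi}(X);\, E_{\psi}(X)\right) := \mathbb{E}_{\mathbb{P}}\left[-\mathrm{sp}\left(-T_{\omega}\left(C^{(i)}_{\psi}(x),\, E_{\psi}(x)\right)\right)\right] - \mathbb{E}_{\mathbb{P}\times\tilde{\mathbb{P}}}\left[\mathrm{sp}\left(T_{\omega}\left(C^{(i)}_{\psi}(x'),\, E_{\psi}(x)\right)\right)\right] $$

In the code below, Ej and Em implement these two expectations, so minimizing Em - Ej maximizes the bound.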

import torch
import torch.nn as nn
import torch.nn.functional as F


class DeepInfoMaxLoss(nn.Module):
    def __init__(self, beta=1.0):
        super().__init__()
        self.local_d = LocalDiscriminator()
        self.beta = beta

    def forward(self, y, M, M_prime):
        # y: global features (B, 64, 1, 1); M: local feature maps (B, 128, 32, 32)
        # M_prime: local features of mismatched samples, e.g.
        # M_prime = torch.cat((M[1:], M[0].unsqueeze(0)), dim=0)
        y_exp = y.expand(-1, -1, 32, 32)   # broadcast the global vector to every location

        y_M = torch.cat((M, y_exp), dim=1)               # positive pairs -> 192 channels
        y_M_prime = torch.cat((M_prime, y_exp), dim=1)   # negative pairs -> 192 channels

        Ej = -F.softplus(-self.local_d(y_M)).mean()        # first (positive) expectation
        Em = F.softplus(self.local_d(y_M_prime)).mean()    # second (negative) expectation
        LOCAL = (Em - Ej) * self.beta   # negative of the JSD bound; minimize this

        return LOCAL
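
Putting it together, a minimal training sketch (the optimizer, learning rate, and the roll-by-one construction of M_prime following the commented hint above are assumptions, and loader is a hypothetical DataLoader of 64x64 image batches):

import torch

encoder = Encoder()
loss_fn = DeepInfoMaxLoss(beta=1.0)
optim = torch.optim.Adam(list(encoder.parameters()) + list(loss_fn.parameters()), lr=1e-4)

for x, _ in loader:
    y, M = encoder(x)                            # global and local features
    M_prime = torch.cat((M[1:], M[:1]), dim=0)   # pair each image with another image's locals

    loss = loss_fn(y, M, M_prime)
    optim.zero_grad()
    loss.backward()
    optim.step()

Note that the discriminator's parameters (inside DeepInfoMaxLoss) are trained together with the encoder, since the objective is maximized over both ω and ψ.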


Code from: https://github.com/DuaneNielsen/DeepInfomaxPytorch