compute a KL divergence between a Gaussian Mixture model prior and a normal distribution posterior · Issue #22 · AdamCobb/hamiltorch
Open
@neuronphysics

Description

Hi,

I am trying to compute the KL divergence between a Gaussian mixture model prior and a normal distribution posterior. It is analytically intractable without some approximation, but it can be estimated via Monte Carlo sampling. How would you suggest implementing this with your library? My model is below:

import torch
import torch.nn as nn
from torch.distributions import Categorical, MixtureSameFamily, Normal
from torch.distributions.independent import Independent
class VGMM(nn.Module):
    def __init__(self,
                 u_dim,
                 h_dim,
                 z_dim,
                 n_mixtures,
                 device,
                 batch_norm=False,
                 ):
        super(VGMM, self).__init__()
        self.n_mixtures = n_mixtures
        self.u_dim = u_dim
        self.h_dim = h_dim
        self.z_dim = z_dim
        self.device = device
        self.batch_norm = batch_norm

        # Encoder network producing the posterior (Normal) parameters.
        encoder_layers = [nn.Linear(self.u_dim, self.h_dim)]
        if self.batch_norm:
            encoder_layers.append(nn.BatchNorm1d(self.h_dim))
        encoder_layers += [
            nn.ReLU(),
            nn.Linear(self.h_dim, self.h_dim),
        ]
        if self.batch_norm:
            encoder_layers.append(nn.BatchNorm1d(self.h_dim))
        encoder_layers.append(nn.ReLU())

        self.enc = nn.Sequential(*encoder_layers)
        self.enc_mean = nn.Linear(self.h_dim, self.z_dim)
        self.enc_logvar = nn.Linear(self.h_dim, self.z_dim)

        self.dist = MixtureSameFamily
        self.comp = Normal
        self.mix = Categorical

        # Prior network producing the Gaussian mixture parameters.
        layers_prior = [nn.Linear(self.u_dim, self.h_dim)]
        if self.batch_norm:
            layers_prior.append(nn.BatchNorm1d(self.h_dim))
        layers_prior.append(nn.ReLU())

        self.prior = nn.Sequential(*layers_prior)

        self.prior_mean = nn.ModuleList(
            [nn.Linear(self.h_dim, self.z_dim) for _ in range(n_mixtures)]
        )
        self.prior_logvar = nn.ModuleList(
            [nn.Linear(self.h_dim, self.z_dim) for _ in range(n_mixtures)]
        )
        self.prior_weights = nn.Linear(self.h_dim, n_mixtures)
    def forward(self, u):
        encoder_input = self.enc(u)
        enc_mean = self.enc_mean(encoder_input)
        # Softplus keeps the "log-variance" positive before exponentiation.
        enc_logvar = nn.Softplus()(self.enc_logvar(encoder_input))

        prior_input = self.prior(u)
        prior_mean = torch.cat(
            [self.prior_mean[n](prior_input).unsqueeze(1) for n in range(self.n_mixtures)],
            dim=1,
        )
        prior_logvar = torch.cat(
            [self.prior_logvar[n](prior_input).unsqueeze(1) for n in range(self.n_mixtures)],
            dim=1,
        )
        prior_w = self.prior_weights(prior_input)
        prior_sigma = prior_logvar.exp().sqrt()

        # Mixture-of-Gaussians prior and factorized Normal posterior.
        prior_dist = self.dist(self.mix(logits=prior_w),
                               Independent(self.comp(prior_mean, prior_sigma), 1))
        post_dist = self.comp(enc_mean, enc_logvar.exp().sqrt())
        z_t = self.reparametrization(enc_mean, enc_logvar)
        return prior_dist, post_dist, z_t
    def reparametrization(self, mu, log_var):
        # exp(0.5 * log_var) is the standard deviation.
        std = torch.exp(0.5 * log_var)
        # randn_like replaces the deprecated Variable/FloatTensor pattern.
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu).add_(1e-7)
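
For concreteness, this is how I construct the model and call forward (the dimensions below are arbitrary examples):

model = VGMM(u_dim=16, h_dim=64, z_dim=8, n_mixtures=5, device='cpu')
u = torch.randn(32, 16)                                # a batch of 32 conditioning inputs
prior_dist, post_dist, z_t = model(u)
print(prior_dist.batch_shape, prior_dist.event_shape)  # torch.Size([32]) torch.Size([8])
print(post_dist.loc.shape, z_t.shape)                  # torch.Size([32, 8]) torch.Size([32, 8])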

How do you suggest I use the library to compute this KL term? Thanks in advance.
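
For reference, here is a minimal sketch of the Monte Carlo estimator I have in mind, written with plain torch.distributions (the helper name mc_kl_divergence is just illustrative; it takes the prior_dist and post_dist objects returned by forward above):

def mc_kl_divergence(post_dist, prior_dist, n_samples=100):
    # KL(q || p) = E_q[log q(z) - log p(z)], estimated by averaging over
    # reparameterized samples from the posterior q. (MixtureSameFamily has
    # no rsample, so drawing from q rather than p is also the practical choice.)
    z = post_dist.rsample((n_samples,))    # shape (n_samples, batch, z_dim)
    log_q = post_dist.log_prob(z).sum(-1)  # factorized Normal: sum over z_dim
    log_p = prior_dist.log_prob(z)         # mixture log-density, already summed over the event dim
    return (log_q - log_p).mean(0)         # per-example KL estimate, shape (batch,)

Because z comes from rsample, the estimate stays differentiable with respect to the encoder and prior parameters, so it could be dropped into a variational loss.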
