compute a KL divergence between a Gaussian Mixture model prior and a normal distribution posterior · Issue #22 · AdamCobb/hamiltorch
Open
neuronphysics opened this issue Mar 3, 2023 · 0 comments

Comments

neuronphysics commented Mar 3, 2023

Hi,

I am trying to compute the KL divergence between a Gaussian mixture model prior and a normal distribution posterior. This is analytically intractable without some approximation, but it can be estimated via Monte Carlo sampling. How would you suggest implementing it with your library?

import torch
import torch.nn as nn
from torch.distributions import Categorical, MixtureSameFamily
from torch.distributions.independent import Independent
from torch.distributions.normal import Normal


class VGMM(nn.Module):
    def __init__(self,
                 u_dim,
                 h_dim,
                 z_dim,
                 n_mixtures,
                 device,
                 batch_norm=False,
                 ):
        super(VGMM, self).__init__()
        self.n_mixtures = n_mixtures
        self.u_dim = u_dim
        self.h_dim = h_dim
        self.z_dim = z_dim
        self.device = device
        self.batch_norm = batch_norm

        # Encoder: u -> hidden features for the diagonal-Gaussian posterior q(z | u)
        encoder_layers = [nn.Linear(self.u_dim, self.h_dim)]
        if self.batch_norm:
            encoder_layers.append(nn.BatchNorm1d(self.h_dim))
        encoder_layers += [
            nn.ReLU(),
            nn.Linear(self.h_dim, self.h_dim),
        ]
        if self.batch_norm:
            encoder_layers.append(nn.BatchNorm1d(self.h_dim))
        encoder_layers.append(nn.ReLU())

        self.enc = nn.Sequential(*encoder_layers)
        self.enc_mean = nn.Linear(self.h_dim, self.z_dim)
        self.enc_logvar = nn.Linear(self.h_dim, self.z_dim)

        # Distribution constructors used in forward()
        self.dist = MixtureSameFamily
        self.comp = Normal
        self.mix = Categorical

        # Prior network: u -> parameters of a Gaussian mixture p(z | u)
        layers_prior = [nn.Linear(self.u_dim, self.h_dim)]
        if self.batch_norm:
            layers_prior.append(nn.BatchNorm1d(self.h_dim))
        layers_prior.append(nn.ReLU())

        self.prior = nn.Sequential(*layers_prior)

        # One mean / log-variance head per mixture component
        self.prior_mean = nn.ModuleList(
            [nn.Linear(self.h_dim, self.z_dim) for _ in range(n_mixtures)]
        )
        self.prior_logvar = nn.ModuleList(
            [nn.Linear(self.h_dim, self.z_dim) for _ in range(n_mixtures)]
        )
        self.prior_weights = nn.Linear(self.h_dim, n_mixtures)

    def forward(self, u):
        # Posterior parameters
        encoder_input = self.enc(u)
        enc_mean = self.enc_mean(encoder_input)
        enc_logvar = nn.Softplus()(self.enc_logvar(encoder_input))

        # Prior parameters: stack the per-component heads along dim=1
        prior_input = self.prior(u)
        prior_mean = torch.cat(
            [self.prior_mean[n](prior_input).unsqueeze(1) for n in range(self.n_mixtures)],
            dim=1,
        )
        prior_logvar = torch.cat(
            [self.prior_logvar[n](prior_input).unsqueeze(1) for n in range(self.n_mixtures)],
            dim=1,
        )
        prior_w = self.prior_weights(prior_input)
        prior_sigma = prior_logvar.exp().sqrt()

        # Mixture-of-Gaussians prior and diagonal-Gaussian posterior
        prior_dist = self.dist(
            self.mix(logits=prior_w),
            Independent(self.comp(prior_mean, prior_sigma), 1),
        )
        post_dist = self.comp(enc_mean, enc_logvar.exp().sqrt())

        z_t = self.reparametrization(enc_mean, enc_logvar)
        return prior_dist, post_dist, z_t

    def reparametrization(self, mu, log_var):
        # z = mu + sigma * eps, with eps ~ N(0, I)
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mu + 1e-7

How would you suggest using the library to compute the KL term? Thanks in advance.
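
For concreteness, here is a minimal sketch of the Monte Carlo estimate I have in mind, using the distributions returned by VGMM.forward. The helper name mc_kl_divergence and the sample count are placeholders of mine, not part of hamiltorch:

def mc_kl_divergence(post_dist, prior_dist, n_samples=32):
    # Draw reparameterized samples from the posterior: (n_samples, batch, z_dim)
    z = post_dist.rsample((n_samples,))
    # log q(z): sum the per-dimension Normal log-probs over z_dim
    log_q = post_dist.log_prob(z).sum(-1)
    # log p(z): the MixtureSameFamily over Independent(Normal, 1) already
    # reduces over z_dim, giving shape (n_samples, batch)
    log_p = prior_dist.log_prob(z)
    # Average over samples -> per-example KL(q || p) estimate of shape (batch,)
    return (log_q - log_p).mean(0)

# Usage (assuming model = VGMM(...) and u is a batch of inputs):
# prior_dist, post_dist, z_t = model(u)
# kl = mc_kl_divergence(post_dist, prior_dist).mean()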
