GitHub - n-kall/POTNet: Implementation of Generative Modeling with Penalized Optimal Transport Network (POTNet)
n-kall/POTNet (forked from sophial05/POTNet)


Generative Modeling for Tabular Data via Penalized Optimal Transport Network


Paper URL

Overview

POTNet is a deep generative model for generating mixed-type tabular data (continuous/discrete/categorical) based on the marginally-penalized Wasserstein loss.
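To make the loss concrete: a marginally-penalized Wasserstein loss combines a Wasserstein distance between the joint distributions of real and generated data with Wasserstein penalties on each feature's marginal distribution. The NumPy/SciPy sketch below assumes equal-size batches with uniform weights, a Euclidean ground cost, and a per-feature penalty weight `lam` (all assumptions for illustration). It is a numerical illustration only, not differentiable and not POTNet's training code, and the paper's exact weighting and estimator may differ. With equal batch sizes the optimal transport plan can be taken to be a permutation, so the joint W1 is an assignment problem, and each one-dimensional marginal W1 reduces to comparing sorted samples.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist


def joint_w1(x, y):
    """Exact W1 between two equal-size empirical distributions (uniform weights)."""
    cost = cdist(x, y)                        # pairwise Euclidean distances, shape (n, n)
    rows, cols = linear_sum_assignment(cost)  # optimal matching = optimal transport plan
    return cost[rows, cols].mean()


def marginal_w1(x, y):
    """Per-feature 1D W1: mean absolute gap between sorted samples in each column."""
    return np.abs(np.sort(x, axis=0) - np.sort(y, axis=0)).mean(axis=0)


def mpw_loss(x, y, lam=1.0):
    """Illustrative MPW-style loss: joint W1 plus weighted marginal W1 penalties.

    lam is a hypothetical penalty weight: a scalar or one weight per feature.
    """
    lam = np.broadcast_to(np.asarray(lam, dtype=float), (x.shape[1],))
    return joint_w1(x, y) + float(np.sum(lam * marginal_w1(x, y)))
```

For example, shifting every feature of a two-dimensional batch by 1 gives a joint term of √2 and a marginal term of 1 per feature, so `mpw_loss(x, x + 1.0)` returns √2 + 2 with the default `lam=1.0`. The marginal terms are what push a generator to match each column's distribution even when the joint term alone is slow to do so.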

This repository contains an implementation of the POTNet model described in the paper:

Generative Modeling for Tabular Data via Penalized Optimal Transport Network

Installation

Python version: 3.10

Install dependencies:

pip install -r requirements.txt

Usage

  • Input: Continuous and discrete features should be represented as float. Categorical columns can be represented as str, int, or float.

  • Categorical features: You need to specify which columns are categorical via the categorical_cols argument, using either column names (e.g. ['SEX', 'EDUCATION']) or column indices (e.g. [1, 2]).

  • Numeric output:

    • If all numeric features are discrete, then specify numeric_output_data_type = "integer".
    • If the numeric features include both discrete and continuous types, specify the type of each numeric feature with a dict:
        {'integer': [col1, col2],
        'continuous': [col3, col4]}
    • By default, all numeric features are assumed to be continuous. You can also specify this explicitly via numeric_output_data_type = "continuous".
  • Conditional POTNet: You can condition on a selection of numeric features by setting conditional = True when initializing the model and then passing the features you wish to condition on to POTNet via conditioning_data during fitting. For a concrete example, see the Conditional POTNet section of ./examples/demo_lfi.ipynb.

  • Missing values: The data should not contain any missing values.
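The input rules above can be checked before fitting. The sketch below is a hypothetical helper (`prepare_potnet_input` is not part of POTNet) that resolves `categorical_cols` given as names or positional indices, rejects missing values, casts the remaining numeric columns to float, and builds a `numeric_output_data_type` value in one of the three forms listed above. It assumes the data is a pandas DataFrame and that integers in `categorical_cols` are positions, so it would misread integer column labels.

```python
import numbers

import pandas as pd


def prepare_potnet_input(df, categorical_cols, integer_cols=()):
    """Hypothetical pre-fit check; returns (data, categorical names, output spec)."""
    # Accept categorical columns as names or as positional indices.
    cat_names = [df.columns[c] if isinstance(c, numbers.Integral) else c
                 for c in categorical_cols]
    unknown = [c for c in cat_names if c not in df.columns]
    if unknown:
        raise KeyError(f"unknown categorical columns: {unknown}")

    # The data must not contain any missing values.
    with_nan = [c for c in df.columns if df[c].isna().any()]
    if with_nan:
        raise ValueError(f"missing values in columns: {with_nan}")

    # Continuous and discrete (numeric) features are stored as float;
    # categorical columns keep their original str/int/float dtype.
    numeric_names = [c for c in df.columns if c not in cat_names]
    out = df.astype({c: float for c in numeric_names})

    # Build numeric_output_data_type: "continuous", "integer", or a dict.
    integer_cols = list(integer_cols)
    if not integer_cols:
        spec = "continuous"
    elif set(integer_cols) == set(numeric_names):
        spec = "integer"
    else:
        spec = {"integer": integer_cols,
                "continuous": [c for c in numeric_names if c not in integer_cols]}
    return out, cat_names, spec
```

The returned names and spec can then be passed as `categorical_cols` and `numeric_output_data_type` when constructing the model.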

Example

We provide two examples for using POTNet, both deposited in the ./examples folder:

  1. Likelihood-free inference (./examples/demo_lfi.ipynb): We generate 3,000 samples, all with continuous features, from the following model:

    • $\theta_i \sim \mathrm{Unif}[-3, 3]$ for $i = 1, \dots, 5$
    • $\mu = (\theta_1, \theta_2)$
    • $\sigma_1 = \theta_3^2$
    • $\sigma_2 = \theta_4^2$
    • $\rho = \tanh(\theta_5)$
    • $\Sigma = ((\sigma_1^2, ~\rho \sigma_1 \sigma_2 ), (\rho \sigma_1 \sigma_2, ~\sigma_2^2))$
    • $X_j \sim \mathcal{N}(\mu, \Sigma)$ for $j = 1, \dots, 4$
  2. Default of credit card clients (./examples/demo_credit.ipynb): To illustrate POTNet on mixed data types, we subsampled 1,000 samples from the credit card default dataset, deposited in ./data/credit_card.csv.
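The likelihood-free inference model in the first example can be simulated in a few lines of NumPy. This is an illustrative reimplementation, not the code in demo_lfi.ipynb: `simulate_lfi` is a hypothetical name, and the column layout (the five $\theta_i$ followed by the four 2-D draws $X_j$, 13 columns in total) is an assumption. Each draw uses the Cholesky factor of $\Sigma$, $L = ((\sigma_1, 0), (\rho\sigma_2, \sigma_2\sqrt{1-\rho^2}))$, so that $X = \mu + L z$ with $z$ standard normal.

```python
import numpy as np


def simulate_lfi(n_samples, n_draws=4, seed=0):
    """Hypothetical simulator: rows are [theta_1..theta_5, X_1, ..., X_n_draws]."""
    rng = np.random.default_rng(seed)
    theta = rng.uniform(-3.0, 3.0, size=(n_samples, 5))

    mu = theta[:, :2]                 # mu = (theta_1, theta_2)
    s1 = theta[:, 2] ** 2             # sigma_1 = theta_3^2
    s2 = theta[:, 3] ** 2             # sigma_2 = theta_4^2
    rho = np.tanh(theta[:, 4])        # rho = tanh(theta_5), strictly inside (-1, 1)

    # Independent standard normals, shape (n_samples, n_draws, 2).
    z = rng.standard_normal(size=(n_samples, n_draws, 2))

    # X = mu + L z, with L the lower Cholesky factor of Sigma.
    x1 = mu[:, None, 0] + s1[:, None] * z[..., 0]
    x2 = mu[:, None, 1] + s2[:, None] * (
        rho[:, None] * z[..., 0] + np.sqrt(1.0 - rho[:, None] ** 2) * z[..., 1]
    )

    # Interleave so each draw contributes its two coordinates side by side.
    x = np.stack([x1, x2], axis=-1).reshape(n_samples, 2 * n_draws)
    return np.hstack([theta, x])


data = simulate_lfi(3000)
```

Using the Cholesky factor avoids a per-sample call to a multivariate normal sampler, since all 3,000 covariance matrices are handled in one vectorized step.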

Below, we provide a simple template for using POTNet:

from potnet import POTNet

# load data (e.g. a pandas DataFrame with no missing values)
data = ...

# specify categorical columns by name or by index (placeholder names shown)
cat_cols = ['col1', 'col2', 'col3']

# initialize POTNet
potnet_model = POTNet(embedding_dim=data.shape[1],
                      categorical_cols=cat_cols,
                      numeric_output_data_type='continuous',  # treat all numeric features as continuous
                      epochs=500,
                      batch_size=256)

# fit POTNet
potnet_model.fit(data)

# generate 1000 synthetic samples
gen_data = potnet_model.generate(1000)

# save model
potnet_model.save('potnet_model.pt')

Reference

If you use POTNet in your research, please cite the following paper:

@article{lu2024generative,
  title={Generative Modeling for Tabular Data via Penalized Optimal Transport Network},
  author={Lu, Wenhui Sophia and Zhong, Chenyang and Wong, Wing Hung},
  journal={arXiv preprint arXiv:2402.10456},
  year={2024}
}
