Bias Mitigation and Direction Methods by ArjunSubramonian · Pull Request #5130 · allenai/allennlp · GitHub
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Bias Mitigation and Direction Methods #5130

Merged May 11, 2021 · 25 commits

Changes from all commits

Commits
79c6c33
added linear and hard debiasers
Apr 13, 2021
e23057c
worked on documentation
Apr 14, 2021
fcc3d34
committing changes before branch switch
Apr 14, 2021
7d00910
committing changes before switching branch
Apr 15, 2021
668a513
finished bias direction, linear and hard debiasers, need to write tests
Apr 15, 2021
91029ef
finished bias direction test
Apr 15, 2021
396b245
Committing changes before switching branch
Apr 16, 2021
a8c22a1
finished hard and linear debiasers
Apr 16, 2021
ef6a062
finished OSCaR
Apr 17, 2021
2c873cb
bias mitigators tests and bias metrics remaining
Apr 17, 2021
d97a526
added bias mitigator tests
Apr 18, 2021
8460281
added bias mitigator tests
Apr 18, 2021
5a76922
finished tests for bias mitigation methods
Apr 19, 2021
85cb107
Merge remote-tracking branch 'origin/main' into arjuns/post-processin…
Apr 19, 2021
8e55f28
fixed gpu issues
Apr 19, 2021
b42b73a
fixed gpu issues
Apr 19, 2021
37d8e33
fixed gpu issues
Apr 20, 2021
31b1d2c
resolve issue with count_nonzero not being differentiable
Apr 20, 2021
a1f4f2a
merged main into post-processing-debiasing
Apr 21, 2021
36cebe3
added more references
Apr 21, 2021
88c083b
Merge branch 'main' of https://github.com/allenai/allennlp into arjun…
Apr 28, 2021
7269c1d
Merge branch 'main' into arjuns/post-processing-debiasing
schmmd May 6, 2021
24ce58f
Merge branch 'main' into arjuns/post-processing-debiasing
AkshitaB May 7, 2021
4495627
responded to Akshita's comments
May 9, 2021
1182b10
Merge branch 'arjuns/post-processing-debiasing' of https://github.com…
May 9, 2021
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -56,6 +56,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added a `quiet` parameter to the `MultiProcessDataLoading` that disables `Tqdm` progress bars.
- The test for distributed metrics now takes a parameter specifying how often you want to run it.
- Created the fairness module and added four fairness metrics: `Independence`, `Separation`, `Sufficiency`, and `DemographicParityWithoutGroundTruth`.
- Added four bias direction methods (`PCABiasDirection`, `PairedPCABiasDirection`, `TwoMeansBiasDirection`, `ClassificationNormalBiasDirection`) and four bias mitigation methods (`LinearBiasMitigator`, `HardBiasMitigator`, `INLPBiasMitigator`, `OSCaRBiasMitigator`).

### Changed

12 changes: 12 additions & 0 deletions allennlp/fairness/__init__.py
@@ -12,3 +12,15 @@
Sufficiency,
DemographicParityWithoutGroundTruth,
)
from allennlp.fairness.bias_direction import (
PCABiasDirection,
PairedPCABiasDirection,
ClassificationNormalBiasDirection,
TwoMeansBiasDirection,
)
from allennlp.fairness.bias_mitigators import (
LinearBiasMitigator,
HardBiasMitigator,
INLPBiasMitigator,
OSCaRBiasMitigator,
)
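
A minimal usage sketch of the new exports (illustrative only: the random tensors below stand in for real word embeddings and are not part of this diff):

import torch

from allennlp.fairness import (
    PCABiasDirection,
    PairedPCABiasDirection,
    TwoMeansBiasDirection,
    ClassificationNormalBiasDirection,
)

# Random stand-ins for seed word embeddings from two concept groups.
group1 = torch.randn(8, 50)  # e.g. embeddings for "man", "king", "brother", ...
group2 = torch.randn(8, 50)  # e.g. embeddings for "woman", "queen", "sister", ...

# Each call returns a unit vector of size (50,) spanning the concept subspace.
pca_direction = PCABiasDirection()(torch.cat([group1, group2]))
paired_pca_direction = PairedPCABiasDirection()(group1, group2)
two_means_direction = TwoMeansBiasDirection()(group1, group2)
svm_direction = ClassificationNormalBiasDirection()(group1, group2)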
301 changes: 301 additions & 0 deletions allennlp/fairness/bias_direction.py
@@ -0,0 +1,301 @@
"""
A suite of differentiable methods to compute the bias direction
or concept subspace representing binary protected variables.
"""

import torch
import sklearn.svm
import numpy as np

from allennlp.common.checks import ConfigurationError


class BiasDirection:
"""
Parent class for bias direction classes.

# Parameters

requires_grad : `bool`, optional (default=`False`)
Option to enable gradient calculation.
"""

def __init__(self, requires_grad: bool = False):
self.requires_grad = requires_grad

def _normalize_bias_direction(self, bias_direction: torch.Tensor):
return bias_direction / torch.linalg.norm(bias_direction)


class PCABiasDirection(BiasDirection):
"""
PCA-based bias direction. Computes one-dimensional subspace that is the span
of a specific concept (e.g. gender) using PCA. This subspace minimizes the sum of
squared distances from all seed word embeddings.

!!! Note
It is uncommon to utilize more than one direction to represent a concept.

Implementation and terminology based on Rathore, A., Dev, S., Phillips, J.M., Srikumar,
V., Zheng, Y., Yeh, C.M., Wang, J., Zhang, W., & Wang, B. (2021).
[VERB: Visualizing and Interpreting Bias Mitigation Techniques for
Word Representations](https://api.semanticscholar.org/CorpusID:233168618).
ArXiv, abs/2104.02797.
"""

def __call__(self, seed_embeddings: torch.Tensor):
"""

# Parameters

!!! Note
In the examples below, we treat gender identity as binary, which does not accurately
characterize gender in real life.

seed_embeddings : `torch.Tensor`
A tensor of size (batch_size, ..., dim) containing seed word embeddings related to
a concept. For example, if the concept is gender, seed_embeddings could contain embeddings
for words like "man", "king", "brother", "woman", "queen", "sister", etc.

# Returns

bias_direction : `torch.Tensor`
A unit tensor of size (dim, ) representing the concept subspace.
"""

# Some sanity checks
if seed_embeddings.ndim < 2:
raise ConfigurationError("seed_embeddings1 must have at least two dimensions.")

with torch.set_grad_enabled(self.requires_grad):
# pca_lowrank centers the embeddings by default
# There will be two dimensions when applying PCA to
# definitionally-gendered words: 1) the gender direction,
# 2) all other directions, with the gender direction being principal.
_, _, V = torch.pca_lowrank(seed_embeddings, q=2)
Contributor: Why do we set q=2?

Contributor Author: I followed the VERB implementation + paper. I think the intuition behind this is that there will be two dimensions when applying PCA to definitionally-gendered words: 1) the gender direction, 2) all other directions, with the gender direction being principal.

Contributor Author: Added a comment in the file itself.
# get top principal component
bias_direction = V[:, 0]
return self._normalize_bias_direction(bias_direction)
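
A quick sketch of the differentiability behavior with requires_grad=True (shapes are illustrative, not part of this diff):

import torch
from allennlp.fairness import PCABiasDirection

seed_embeddings = torch.randn(10, 50, requires_grad=True)
bias_direction = PCABiasDirection(requires_grad=True)(seed_embeddings)
bias_direction.sum().backward()  # gradients flow back through pca_lowrank
assert seed_embeddings.grad is not None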


class PairedPCABiasDirection(BiasDirection):
"""
Paired-PCA-based bias direction. Computes one-dimensional subspace that is the span
of a specific concept (e.g. gender) as the first principal component of the
difference vectors between seed word embedding pairs.

!!! Note
It is uncommon to utilize more than one direction to represent a concept.

Based on: T. Bolukbasi, K. W. Chang, J. Zou, V. Saligrama, and A. Kalai. [Man is to
computer programmer as woman is to homemaker? debiasing word embeddings]
(https://api.semanticscholar.org/CorpusID:1704893).
In NeurIPS, 2016.

Implementation and terminology based on Rathore, A., Dev, S., Phillips, J.M., Srikumar,
V., Zheng, Y., Yeh, C.M., Wang, J., Zhang, W., & Wang, B. (2021).
[VERB: Visualizing and Interpreting Bias Mitigation Techniques for
Word Representations](https://api.semanticscholar.org/CorpusID:233168618).
ArXiv, abs/2104.02797.
"""

def __call__(self, seed_embeddings1: torch.Tensor, seed_embeddings2: torch.Tensor):
"""

# Parameters

!!! Note
In the examples below, we treat gender identity as binary, which does not accurately
characterize gender in real life.

seed_embeddings1 : `torch.Tensor`
A tensor of size (batch_size, ..., dim) containing seed word
embeddings related to a concept group. For example, if the concept is gender,
seed_embeddings1 could contain embeddings for linguistically masculine words, e.g.
"man", "king", "brother", etc.

seed_embeddings2: `torch.Tensor`
A tensor of the same size as seed_embeddings1 containing seed word
embeddings related to a different group for the same concept. For example,
seed_embeddings2 could contain embeddings for linguistically feminine words, e.g.
"woman", "queen", "sister", etc.

!!! Note
For Paired-PCA, the embeddings at the same positions in each of seed_embeddings1 and
seed_embeddings2 are expected to form seed word pairs. For example, if the concept
is gender, the embeddings for ("man", "woman"), ("king", "queen"), ("brother", "sister"), etc.
should be at the same positions in seed_embeddings1 and seed_embeddings2.

!!! Note
All tensors are expected to be on the same device.

# Returns

bias_direction : `torch.Tensor`
A unit tensor of size (dim, ) representing the concept subspace.
"""

# Some sanity checks
if seed_embeddings1.size() != seed_embeddings2.size():
raise ConfigurationError("seed_embeddings1 and seed_embeddings2 must be the same size.")
if seed_embeddings1.ndim < 2:
raise ConfigurationError(
"seed_embeddings1 and seed_embeddings2 must have at least two dimensions."
)

with torch.set_grad_enabled(self.requires_grad):
paired_embeddings = seed_embeddings1 - seed_embeddings2
_, _, V = torch.pca_lowrank(
paired_embeddings,
q=min(paired_embeddings.size(0), paired_embeddings.size(1)) - 1,
)
bias_direction = V[:, 0]
return self._normalize_bias_direction(bias_direction)
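
A short sketch of the pairing convention described above (random stand-in embeddings, not part of this diff):

import torch
from allennlp.fairness import PairedPCABiasDirection

# Rows at the same index are assumed to form seed word pairs, e.g.
# ("man", "woman"), ("king", "queen"), ("brother", "sister").
masculine = torch.randn(3, 50)
feminine = torch.randn(3, 50)

direction = PairedPCABiasDirection()(masculine, feminine)
assert torch.isclose(torch.linalg.norm(direction), torch.tensor(1.0))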


class TwoMeansBiasDirection(BiasDirection):
"""
Two-means bias direction. Computes one-dimensional subspace that is the span
of a specific concept (e.g. gender) as the normalized difference vector of the
averages of seed word embedding sets.

!!! Note
It is uncommon to utilize more than one direction to represent a concept.

Based on: Dev, S., & Phillips, J.M. (2019). [Attenuating Bias in Word Vectors]
(https://api.semanticscholar.org/CorpusID:59158788). AISTATS.

Implementation and terminology based on Rathore, A., Dev, S., Phillips, J.M., Srikumar,
V., Zheng, Y., Yeh, C.M., Wang, J., Zhang, W., & Wang, B. (2021).
[VERB: Visualizing and Interpreting Bias Mitigation Techniques for
Word Representations](https://api.semanticscholar.org/CorpusID:233168618).
ArXiv, abs/2104.02797.
"""

def __call__(self, seed_embeddings1: torch.Tensor, seed_embeddings2: torch.Tensor):
"""

# Parameters

!!! Note
In the examples below, we treat gender identity as binary, which does not accurately
characterize gender in real life.

seed_embeddings1 : `torch.Tensor`
A tensor of size (embeddings1_batch_size, ..., dim) containing seed word
embeddings related to a specific concept group. For example, if the concept is gender,
seed_embeddings1 could contain embeddings for linguistically masculine words, e.g.
"man", "king", "brother", etc.
seed_embeddings2: `torch.Tensor`
A tensor of size (embeddings2_batch_size, ..., dim) containing seed word
embeddings related to a different group for the same concept. For example,
seed_embeddings2 could contain embeddings for linguistically feminine words, e.g.
"woman", "queen", "sister", etc.

!!! Note
seed_embeddings1 and seed_embeddings2 need NOT be the same size. Furthermore,
the embeddings at the same positions in each of seed_embeddings1 and seed_embeddings2
are NOT expected to form seed word pairs.

!!! Note
All tensors are expected to be on the same device.

# Returns

bias_direction : `torch.Tensor`
A unit tensor of size (dim, ) representing the concept subspace.
"""
# Some sanity checks
if seed_embeddings1.ndim < 2 or seed_embeddings2.ndim < 2:
raise ConfigurationError(
"seed_embeddings1 and seed_embeddings2 must have at least two dimensions."
)
if seed_embeddings1.size(-1) != seed_embeddings2.size(-1):
raise ConfigurationError("All seed embeddings must have same dimensionality.")

with torch.set_grad_enabled(self.requires_grad):
seed_embeddings1_mean = torch.mean(seed_embeddings1, dim=0)
seed_embeddings2_mean = torch.mean(seed_embeddings2, dim=0)
bias_direction = seed_embeddings1_mean - seed_embeddings2_mean
return self._normalize_bias_direction(bias_direction)
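
Because the two-means direction is simply the normalized difference of the group means, its behavior can be checked directly (random stand-ins, not part of this diff):

import torch
from allennlp.fairness import TwoMeansBiasDirection

group1 = torch.randn(7, 50)  # the two groups need not be the same size
group2 = torch.randn(4, 50)

direction = TwoMeansBiasDirection()(group1, group2)
manual = group1.mean(dim=0) - group2.mean(dim=0)
assert torch.allclose(direction, manual / torch.linalg.norm(manual))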


class ClassificationNormalBiasDirection(BiasDirection):
"""
Classification normal bias direction. Computes one-dimensional subspace that is the span
of a specific concept (e.g. gender) as the direction perpendicular to the classification
boundary of a linear support vector machine fit to classify seed word embedding sets.

!!! Note
It is uncommon to utilize more than one direction to represent a concept.

Based on: Ravfogel, S., Elazar, Y., Gonen, H., Twiton, M., & Goldberg, Y. (2020).
[Null It Out: Guarding Protected Attributes by Iterative Nullspace Projection]
(https://api.semanticscholar.org/CorpusID:215786522). ArXiv, abs/2004.07667.

Implementation and terminology based on Rathore, A., Dev, S., Phillips, J.M., Srikumar,
V., Zheng, Y., Yeh, C.M., Wang, J., Zhang, W., & Wang, B. (2021).
[VERB: Visualizing and Interpreting Bias Mitigation Techniques for
Word Representations](https://api.semanticscholar.org/CorpusID:233168618).
ArXiv, abs/2104.02797.
"""

def __init__(self):
super().__init__()

def __call__(self, seed_embeddings1: torch.Tensor, seed_embeddings2: torch.Tensor):
"""

# Parameters

!!! Note
In the examples below, we treat gender identity as binary, which does not accurately
characterize gender in real life.

seed_embeddings1 : `torch.Tensor`
A tensor of size (embeddings1_batch_size, ..., dim) containing seed word
embeddings related to a specific concept group. For example, if the concept is gender,
seed_embeddings1 could contain embeddings for linguistically masculine words, e.g.
"man", "king", "brother", etc.
seed_embeddings2: `torch.Tensor`
A tensor of size (embeddings2_batch_size, ..., dim) containing seed word
embeddings related to a different group for the same concept. For example,
seed_embeddings2 could contain embeddings for linguistically feminine words, e.g.
"woman", "queen", "sister", etc.

!!! Note
seed_embeddings1 and seed_embeddings2 need NOT be the same size. Furthermore,
the embeddings at the same positions in each of seed_embeddings1 and seed_embeddings2
are NOT expected to form seed word pairs.

!!! Note
All tensors are expected to be on the same device.

!!! Note
This bias direction method is NOT differentiable.
Contributor: If we intend to allow users to specify bias direction (and mitigator) methods in config, perhaps we should make "is_differentiable" a field, so that the list of methods which can be used can be obtained programmatically?

Contributor Author: Yes, this is part of the bias mitigators and direction wrappers PR - this PR is just the functional API.

# Returns

bias_direction : `torch.Tensor`
A unit tensor of size (dim, ) representing the concept subspace.
"""

# Some sanity checks
if seed_embeddings1.ndim < 2 or seed_embeddings2.ndim < 2:
raise ConfigurationError(
"seed_embeddings1 and seed_embeddings2 must have at least two dimensions."
)
if seed_embeddings1.size(-1) != seed_embeddings2.size(-1):
raise ConfigurationError("All seed embeddings must have same dimensionality.")

device = seed_embeddings1.device
seed_embeddings1 = seed_embeddings1.flatten(end_dim=-2).detach().cpu().numpy()
seed_embeddings2 = seed_embeddings2.flatten(end_dim=-2).detach().cpu().numpy()

X = np.vstack([seed_embeddings1, seed_embeddings2])
Y = np.concatenate([[0] * seed_embeddings1.shape[0], [1] * seed_embeddings2.shape[0]])

classifier = sklearn.svm.SVC(kernel="linear").fit(X, Y)
bias_direction = torch.Tensor(classifier.coef_[0]).to(device)

return self._normalize_bias_direction(bias_direction)
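
A final sketch exercising the SVM-based method end to end (random stand-ins, not part of this diff):

import torch
from allennlp.fairness import ClassificationNormalBiasDirection

group1 = torch.randn(20, 50)
group2 = torch.randn(20, 50)

# The SVM is fit on CPU via scikit-learn (no gradients flow);
# the resulting direction is moved back to the input device.
direction = ClassificationNormalBiasDirection()(group1, group2)
assert direction.shape == (50,)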