Add accessor and comments about CausalModel's CausalEstimator cache. by drawlinson · Pull Request #1113 · py-why/dowhy · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Add accessor and comments about CausalModel's CausalEstimator cache. #1113

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 26 additions & 3 deletions dowhy/causal_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,25 @@ def __init__(
logger=self.logger,
)

def get_estimator(self, method_name):
    """
    Look up a previously created CausalEstimator by its estimation method name.

    `estimate_effect()` creates CausalEstimator objects and keeps them in a cache so they can be reused.
    The cache is keyed by method name, so different methods can have different instances.
    A cached instance may be applied again to different data via `estimate_effect(fit_estimator=False)`,
    e.g. to estimate effects on different samples of the same dataset.

    The same estimator is also reachable from the `CausalEstimate` object returned by `estimate_effect()`:

    `effect = model.estimate_effect(...)`
    `effect.estimator # returns the fitted CausalEstimator estimator object`

    :param method_name: name of the estimation method whose estimator is requested.
    :returns: The cached CausalEstimator instance for that method, or None if no such instance exists.
    """
    cache = self._estimator_cache
    if method_name in cache:
        return cache[method_name]
    # Nothing has been cached under this method name.
    return None

def init_graph(self, graph, identify_vars):
"""
Initialize self._graph using graph provided by the user.
Expand Down Expand Up @@ -320,9 +339,13 @@ def estimate_effect(

identified_estimand.set_identifier_method(identifier_name)

if not fit_estimator and method_name in self._estimator_cache:
causal_estimator = self._estimator_cache[method_name]
else:
# If not fit_estimator, attempt to retrieve existing estimator.
# Keep original behaviour to create new estimator if none found.
causal_estimator = None
if not fit_estimator:
causal_estimator = self.get_estimator(method_name)

if causal_estimator is None:
causal_estimator = causal_estimator_class(
identified_estimand,
test_significance=test_significance,
Expand Down
43 changes: 43 additions & 0 deletions tests/test_causal_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,49 @@ def test_incorrect_graph_format(self):
graph=nx.Graph([("X", "Y"), ("Y", "Z")]),
)

def test_causal_estimator_cache(self):
    """
    Tests that CausalEstimator objects can be consistently retrieved from CausalEstimate and CausalModel objects.

    Two effects are estimated with two different methods on the same model. For each method, the
    estimator referenced by the returned estimate must be the very same object that
    `CausalModel.get_estimator()` returns for that method name, and the objects for the two
    methods must be distinct.
    """
    beta = 10
    num_samples = 100
    num_treatments = 1
    num_common_causes = 5
    data = dowhy.datasets.linear_dataset(
        beta=beta,
        num_common_causes=num_common_causes,
        num_samples=num_samples,
        num_treatments=num_treatments,
        treatment_is_binary=True,
    )

    model = CausalModel(
        data=data["df"],
        treatment=data["treatment_name"],
        outcome=data["outcome_name"],
        graph=data["gml_graph"],
        proceed_when_unidentifiable=True,
        test_significance=None,
    )

    identified_estimand = model.identify_effect(proceed_when_unidentifiable=True)
    methods = [
        "backdoor.linear_regression",
        "backdoor.propensity_score_matching",
    ]
    # One estimate per method, in the same order as `methods`.
    estimates = [
        model.estimate_effect(identified_estimand, method_name=method, control_value=0, treatment_value=1)
        for method in methods
    ]

    # Identity, not equality, is what is being tested: `is` keeps checking "same object"
    # even if an `__eq__` is ever defined on the estimator classes. If same object, don't need to check type.
    assert estimates[0].estimator is model.get_estimator(methods[0])
    assert estimates[1].estimator is model.get_estimator(methods[1])
    assert estimates[0].estimator is not model.get_estimator(methods[1])  # check not same object


# Allow running this test module directly as a script: pass this file to pytest.
if __name__ == "__main__":
    pytest.main([__file__])
0