Prepare branch for TFP 0.12.1 release by jburnim · Pull Request #1205 · tensorflow/probability · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Prepare branch for TFP 0.12.1 release #1205

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions tensorflow_probability/python/bijectors/bijector.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,7 @@ def __init__(self,
name = name_util.strip_invalid_chars(name)
super(Bijector, self).__init__(name=name)
self._name = name
# TODO(b/176242804): Infer `parameters` if not specified by the child class.
self._parameters = self._no_dependency(parameters)

self._graph_parents = self._no_dependency(graph_parents or [])
Expand Down Expand Up @@ -648,6 +649,8 @@ def parameters(self):
# Remove "self", "__class__", or other special variables. These can appear
# if the subclass used:
# `parameters = dict(locals())`.
if self._parameters is None:
return None
return {k: v for k, v in self._parameters.items()
if not k.startswith('__') and k != 'self'}

Expand Down Expand Up @@ -689,6 +692,11 @@ def __eq__(self, other):
return True

def _get_parameterization(self):
  """Returns the value that identifies this bijector's parameterization.

  When the subclass recorded `parameters`, that (filtered) dict is returned.
  A user-written bijector that never specified `parameters` offers nothing to
  compare on, so every such instance is assumed to be unique and is keyed by
  its own `id`.
  """
  params = self.parameters
  if params is not None:
    return params
  # No recorded `parameters`: fall back to object identity so that distinct
  # instances are never treated as equivalent.
  # TODO(b/176242804): this can be removed if we always infer `parameters`.
  return id(self)

def __call__(self, value, name=None, **kwargs):
Expand Down
30 changes: 30 additions & 0 deletions tensorflow_probability/python/bijectors/bijector_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,24 @@ def _get_parameterization(self):
return id(self)


class UnspecifiedParameters(tfb.Bijector):
  """Scalar shift bijector that never hands `parameters` to its base class.

  The base-class constructor is invoked without a `parameters` argument, so
  instances exercise the code path for bijectors whose subclass did not
  record their parameters (presumably leaving them as `None` — the base
  `__init__` signature is not visible here).
  """

  def __init__(self, loc):
    # The offset added to every input in `_forward`.
    self._loc = loc
    super(UnspecifiedParameters, self).__init__(
        forward_min_event_ndims=0,
        is_constant_jacobian=True,
        validate_args=False,
        name='unspecified_parameters')

  def _forward(self, x):
    # Shift the input by the stored offset.
    return x + self._loc

  def _forward_log_det_jacobian(self, x):
    # A pure shift has a unit Jacobian, so its log-determinant is zero.
    return tf.constant(0., dtype=x.dtype)


@test_util.test_all_tf_execution_regimes
class BijectorTestEventNdims(test_util.TestCase):

Expand Down Expand Up @@ -440,6 +458,18 @@ def testUniqueCacheKey(self):
self.assertLen(bijector_1._cache.weak_keys(direction='forward'), 1)
self.assertLen(bijector_2._cache.weak_keys(direction='forward'), 1)

def testBijectorsWithUnspecifiedParametersDoNotShareCache(self):
  """Distinct bijectors lacking `parameters` must keep separate caches."""
  shift_one = UnspecifiedParameters(loc=tf.constant(1., dtype=tf.float32))
  shift_two = UnspecifiedParameters(loc=tf.constant(2., dtype=tf.float32))

  x = tf.constant(3., dtype=tf.float32)
  y_one = shift_one.forward(x)
  y_two = shift_two.forward(x)

  # If the two instances shared a cache key, the second forward call would
  # return the first bijector's cached output for the same input.
  self.assertIsNot(y_one, y_two)
  for bij in (shift_one, shift_two):
    self.assertLen(bij._cache.weak_keys(direction='forward'), 1)

def testInstanceCache(self):
instance_cache_bijector = tfb.Exp()
instance_cache_bijector._cache = cache_util.BijectorCache(
Expand Down
2 changes: 1 addition & 1 deletion tensorflow_probability/python/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# We follow Semantic Versioning (https://semver.org/)
_MAJOR_VERSION = '0'
_MINOR_VERSION = '12'
# NOTE(review): this hunk is a rendered diff whose +/- markers were lost.
# The `'0'` line below is the deleted line and the `'1'` line is its
# replacement (1 addition, 1 deletion), so the released file contains only
# `_PATCH_VERSION = '1'`, i.e. version 0.12.1.
_PATCH_VERSION = '0'
_PATCH_VERSION = '1'

# When building releases, we can update this value on the release branch to
# reflect the current release candidate ('rc0', 'rc1') or, finally, the official
Expand Down
0