Add Sample acquisition function to GP bandit by copybara-service[bot] · Pull Request #1141 · google/vizier · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Add Sample acquisition function to GP bandit. #1141

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion vizier/_src/algorithms/designers/gp/acquisitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from vizier._src.jax import types
from vizier._src.jax.models import continuous_only_kernel


tfd = tfp.distributions
tfp_bo = tfp.experimental.bayesopt
tfpke = tfp.experimental.psd_kernels
Expand Down Expand Up @@ -274,6 +273,22 @@ def __call__(
)()


@struct.dataclass
class Sample(AcquisitionFunction):
  """Acquisition function that returns Monte Carlo draws from a distribution.

  Instead of collapsing the distribution into a closed-form score, this
  returns `num_samples` raw samples of it.
  """

  # How many draws are taken from the distribution on every call.
  num_samples: int = 1000

  def __call__(
      self,
      dist: tfd.Distribution,
      seed: Optional[jax.Array] = None,
  ) -> jax.Array:
    """Draws `self.num_samples` samples from `dist`.

    Args:
      dist: The distribution to sample from.
      seed: PRNG key used for sampling. When omitted, the fixed key
        `jax.random.PRNGKey(0)` is used, so repeated calls without a seed
        return the same draws.

    Returns:
      The samples, with a leading sample axis of length `self.num_samples`.
    """
    key = jax.random.PRNGKey(0) if seed is None else seed
    sample_shape = [self.num_samples]
    return dist.sample(sample_shape, seed=key)


class MaxValueEntropySearch(eqx.Module):
"""MES score function. Implements the `ScoreFunction` protocol."""

Expand Down
29 changes: 29 additions & 0 deletions vizier/_src/algorithms/designers/gp/acquisitions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,35 @@ def test_ehvi_approximation(self):
rtol=1e-1,
)

def test_ehvi_mcmc(self):
  """Checks that sample-based scalarized EHVI approximates the hypervolume.

  Monte Carlo samples are drawn from a nearly deterministic two-objective
  Normal centered at (2, 1.5). The rescaled mean of the rectified hypervolume
  scalarizations should then be close to 2 * 1.5 = 3.0.
  """
  num_obj = 2
  num_scalarizations = 1000
  # Random non-negative directions, normalized to unit length.
  weights = jnp.abs(
      jax.random.normal(
          jax.random.PRNGKey(0), shape=(num_scalarizations, num_obj)
      )
  )
  weights = weights / jnp.linalg.norm(weights, axis=1, keepdims=True)

  scalarizer = scalarization.HyperVolumeScalarization(
      weights,
  )

  # The mean of the rectified (relu) scalarizations is taken over every
  # sample and every weight vector. It approximates the hypervolume up to a
  # constant rescaling factor, which is pi/4 for num_obj=2.
  hypervolume = acquisitions.ScalarizedAcquisition(
      acquisitions.Sample(num_samples=100),
      scalarizer,
      reduction_fn=lambda x: jnp.mean(jax.nn.relu(x)),
  )
  # A tiny stddev keeps the samples close to the mean, so the expected
  # hypervolume should be close to 2 * 1.5 = 3.0. This assumes the
  # scalarizer's default reference point is the origin; confirm against
  # HyperVolumeScalarization.
  stddev = 0.01 * jnp.ones(num_obj)
  np.testing.assert_allclose(
      hypervolume(tfd.Normal(jnp.array([2, 1.5]), stddev)) * jnp.pi / 4.0,
      jnp.array([3.0]),
      rtol=1e-1,
  )

def test_pi(self):
labels = types.PaddedArray.as_padded(jnp.array([[0.2]]))
best_labels = acquisitions.get_best_labels(labels)
Expand Down
46 changes: 35 additions & 11 deletions vizier/_src/algorithms/designers/gp_bandit.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,7 @@ def from_problem(
seed: Optional[int] = None,
num_scalarizations: int = 1000,
reference_scaling: float = 0.01,
num_samples: int | None = None,
**kwargs,
) -> 'VizierGPBandit':
rng = jax.random.PRNGKey(seed or 0)
Expand All @@ -595,19 +596,42 @@ def from_problem(
)
weights = weights / jnp.linalg.norm(weights, axis=-1, keepdims=True)

def _scalarized_ucb(data: types.ModelData) -> acq_lib.AcquisitionFunction:
scalarizer = scalarization.HyperVolumeScalarization(
weights,
acq_lib.get_reference_point(data.labels, scale=reference_scaling),
)
return acq_lib.ScalarizedAcquisition(
acq_lib.UCB(),
scalarizer,
reduction_fn=lambda x: jnp.mean(x, axis=0),
)
if num_samples is None:

def _scalarized_ucb(
data: types.ModelData,
) -> acq_lib.AcquisitionFunction:
scalarizer = scalarization.HyperVolumeScalarization(
weights,
acq_lib.get_reference_point(data.labels, reference_scaling),
)
return acq_lib.ScalarizedAcquisition(
acq_lib.UCB(),
scalarizer,
reduction_fn=lambda x: jnp.mean(x, axis=0),
)

acq_fn_factory = _scalarized_ucb
else:

def _scalarized_sample_ehvi(
data: types.ModelData,
) -> acq_lib.AcquisitionFunction:
scalarizer = scalarization.HyperVolumeScalarization(
weights,
acq_lib.get_reference_point(data.labels, reference_scaling),
)
return acq_lib.ScalarizedAcquisition(
acq_lib.Sample(num_samples),
scalarizer,
# We need to reduce across the scalarization and sample axes.
reduction_fn=lambda x: jnp.mean(jax.nn.relu(x), axis=[0, 1]),
)

acq_fn_factory = _scalarized_sample_ehvi

scoring_function_factory = acq_lib.bayesian_scoring_function_factory(
_scalarized_ucb
acq_fn_factory
)
return cls(
problem,
Expand Down
10 changes: 8 additions & 2 deletions vizier/_src/algorithms/designers/gp_bandit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,11 @@ def _qei_factory(data: types.ModelData) -> acquisitions.AcquisitionFunction:
iters * n_parallel,
)

def test_multi_metrics(self):
@parameterized.parameters(
dict(num_samples=10),
dict(num_samples=None),
)
def test_multi_metrics(self, num_samples: int | None):
search_space = vz.SearchSpace()
search_space.root.add_float_param('x0', -5.0, 5.0)
problem = vz.ProblemStatement(
Expand All @@ -483,7 +487,9 @@ def test_multi_metrics(self):
)

iters = 2
designer = gp_bandit.VizierGPBandit.from_problem(problem)
designer = gp_bandit.VizierGPBandit.from_problem(
problem, num_samples=num_samples
)
self.assertLen(
test_runners.RandomMetricsRunner(
problem,
Expand Down
Loading
0