Internal by copybara-service[bot] · Pull Request #1219 · google/vizier · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Internal #1219

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions vizier/_src/algorithms/designers/gp/acquisitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,31 @@ def __call__(
)()


def create_hv_scalarization(
    num_scalarizations: int,
    labels: types.PaddedArray,
    rng: jax.Array,
    ref_scaling: float = 0.01,
) -> scalarization.HyperVolumeScalarization:
  """Creates a HyperVolumeScalarization with random weights.

  Each weight vector is drawn from a standard normal distribution, made
  non-negative by taking the absolute value, and then rescaled to unit L2 norm
  along the last axis.

  Args:
    num_scalarizations: The number of scalarizations (weight vectors) to
      create.
    labels: The labels used to create the reference point. `labels.shape[1]`
      sets the length of each weight vector (presumably the number of
      objectives -- confirm against callers).
    rng: The random key to use for sampling the weights.
    ref_scaling: Value forwarded as `scale` to `get_reference_point`. Defaults
      to 0.01, the value that was previously hard-coded here.

  Returns:
    A HyperVolumeScalarization with random weights. Its reference point is
    `None` when `labels.shape[0]` is 0, i.e. when there are no labels to
    derive one from.
  """
  num_weights_per_vector = labels.shape[1]
  weights = jnp.abs(
      jax.random.normal(
          rng,
          shape=(num_scalarizations, num_weights_per_vector),
      )
  )
  # Normalize each row independently so every weight vector has unit length.
  weights = weights / jnp.linalg.norm(weights, axis=-1, keepdims=True)

  # A reference point can only be computed when labels are present.
  if labels.shape[0] > 0:
    ref_point = get_reference_point(labels, scale=ref_scaling)
  else:
    ref_point = None
  return scalarization.HyperVolumeScalarization(weights, ref_point)


# TODO: What do we end up jitting? If we end up directly jitting this call
# then we should make it `eqx.Module` and set
# `reduction_fn=eqx.field(static=True)` instead.
Expand Down
18 changes: 4 additions & 14 deletions vizier/_src/algorithms/designers/gp_bandit.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
from vizier import algorithms as vza
from vizier import pyvizier as vz
from vizier._src.algorithms.designers import quasi_random
from vizier._src.algorithms.designers import scalarization
from vizier._src.algorithms.designers.gp import acquisitions as acq_lib
from vizier._src.algorithms.designers.gp import gp_models
from vizier._src.algorithms.designers.gp import output_warpers
Expand Down Expand Up @@ -202,27 +201,18 @@ def __attrs_post_init__(self):
# Multi-objective overrides.
m_info = self._problem.metric_information
if not m_info.is_single_objective:
num_obj = len(m_info.of_type(vz.MetricType.OBJECTIVE))

# Create scalarization weights.
self._rng, weights_rng = jax.random.split(self._rng)
weights = jax.random.normal(
weights_rng, shape=(self._num_scalarizations, num_obj)
)
weights = jnp.abs(weights)
weights = weights / jnp.linalg.norm(weights, axis=-1, keepdims=True)

def acq_fn_factory(data: types.ModelData) -> acq_lib.AcquisitionFunction:
# Scalarized UCB.
labels_array = data.labels.padded_array
has_labels = labels_array.shape[0] > 0
ref_point = (
acq_lib.get_reference_point(data.labels, self._ref_scaling)
if has_labels
else None
scalarizer = acq_lib.create_hv_scalarization(
self._num_scalarizations, data.labels, weights_rng
)
scalarizer = scalarization.HyperVolumeScalarization(weights, ref_point)

labels_array = data.labels.padded_array
has_labels = labels_array.shape[0] > 0
max_scalarized = None
if has_labels:
max_scalarized = jnp.max(scalarizer(labels_array), axis=-1)
Expand Down
Loading
Loading
0