From 32bddfb0f3e827667ee786c9ae96bc610e764c72 Mon Sep 17 00:00:00 2001 From: suadou Date: Thu, 21 Mar 2024 13:31:40 +0100 Subject: [PATCH 1/6] Adding cell list to neighbor list --- glp/neighborlist.py | 279 ++++++++++++++++++++++++++++++++++++++------ 1 file changed, 241 insertions(+), 38 deletions(-) diff --git a/glp/neighborlist.py b/glp/neighborlist.py index cb57c5b..2398fd4 100644 --- a/glp/neighborlist.py +++ b/glp/neighborlist.py @@ -2,11 +2,21 @@ The role of a neighborlist implementation is to reduce the amount of pairs of atoms to consider from N*N to something linear in N by removing pairs -that are further away than a given cutoff radius. +that are farther away than a given cutoff radius. -This file implements this in a naive way: We first generate all N*N combinations, -and then trim down these candidates into a fixed number of pairs. Once that number -is decided, this procedure is jittable. The initial allocation is not. +This file implements this in two ways: +- N-square: We first generate all N*N combinations, and then trim down these + candidates into a fixed number of pairs. Once that number is + decided, this procedure is jittable. The initial allocation is not. + +- Cell-list: We first divide the simulation cell into a grid of cells, and + assign each atom to a cell. Then, for each atom, we only consider + the atoms in the same cell and its neighbors. This is jittable. + The initial allocation is not. + +The cell-list implementation is more efficient for large systems, but it is +more complex and requires the simulation cell to be periodic. The N-square +implementation is simpler and can be used for non-periodic systems. The general data format consists of two index arrays, such that the indices in `centers` contains the atom from which atom-pair vectors originate, while the @@ -17,20 +27,23 @@ """ from collections import namedtuple -from jax import jit, vmap +import jax +from jax import vmap, ops import jax.numpy as jnp -from jax.lax import stop_gradient, cond -from functools import partial -from typing import Callable +from jax.lax import stop_gradient, cond, iota +import numpy as np from glp import comms from .periodic import displacement, get_heights from .utils import boolean_mask_1d, cast, squared_distance Neighbors = namedtuple( - "Neighbors", ("centers", "others", "overflow", "reference_positions") + "Neighbors", ("centers", "others", "overflow", "reference_positions", "cell_list") ) +CellList = namedtuple( + "CellList", ("id", "overflow", "capacity", "size", "cells_per_side") +) def neighbor_list(system, cutoff, skin, capacity_multiplier=1.25): # convenience interface: we often don't need explicit access to an allocate_fn, since @@ -65,14 +78,15 @@ def _update(system, neighbors): return neighbors, _update -def quadratic_neighbor_list(cell, cutoff, skin, capacity_multiplier=1.25, debug=False): +def quadratic_neighbor_list(cell, cutoff, skin, capacity_multiplier=1.25, use_cell_list=False, debug=False): + """Implementation of neighborlist in pbc using cell list.""" + assert capacity_multiplier >= 1.0 cell = stop_gradient(cell) cutoff = cast(stop_gradient(cutoff)) skin = cast(stop_gradient(skin)) - if cell is not None: def cell_too_small(new_cell): @@ -91,11 +105,23 @@ def cell_too_small(new_cell): ) comms.warn("this will yield incorrect results!") - squared_cutoff = (cutoff + skin) ** cast(2.0) + cutoff = (cutoff + skin) allowed_movement = (skin * cast(0.5)) ** cast(2.0) + if cell is not None and use_cell_list: + if jnp.all(jnp.any(cutoff < cell / 3., axis=1)): + cl_allocate, cl_update = cell_list(cell, cutoff) + else: + cl_allocate = cl_update = None + else: + cl_allocate = cl_update = None + def need_update_fn(neighbors, new_positions, new_cell): + # question: how to deal with changes in new_cell? + # we will invalidate if atoms move too much, but not if + # the new cell is too small for the cutoff... + movement = make_squared_distance(new_cell)( neighbors.reference_positions, new_positions ) @@ -104,7 +130,7 @@ def need_update_fn(neighbors, new_positions, new_cell): # if we have an overflow, we don't need to update -- results are # already invalid. instead, we will simply return previous invalid # neighbors and hope someone up the chain catches it! - should_update = (max_movement > allowed_movement) & (~neighbors.overflow) + should_update = (max_movement > allowed_movement) & (~neighbors.overflow) # if the cell is now *too small*, we need to update for sure to set overflow flag return should_update | cell_too_small(new_cell) @@ -120,51 +146,52 @@ def allocate_fn(positions, new_cell=None, padding_mask=None): N = positions.shape[0] + cl = cl_allocate(positions) if cl_allocate is not None else None + centers, others, sq_distances, mask, hits = get_neighbors( positions, - new_cell, make_squared_distance(new_cell), - squared_cutoff, + cutoff, padding_mask=padding_mask, + cl=cl ) - + # Hits estan mal. el numero es demasiado grande size = int(hits.item() * capacity_multiplier + 1) centers, _ = boolean_mask_1d(centers, mask, size, N) others, overflow = boolean_mask_1d(others, mask, size, N) - overflow = overflow | cell_too_small(new_cell) - # print("done with neighbors=None branch") - return Neighbors(centers, others, overflow, positions) - - def update_fn( - positions, neighbors, new_cell=None, padding_mask=None, force_update=False - ): + return Neighbors(centers, others, overflow, positions, cl) + + def update_fn(positions, neighbors, new_cell=None, padding_mask=None, force_update=False): # this is jittable, # neighbors tells us all the shapes we need - if new_cell is None: new_cell = cell else: new_cell = stop_gradient(new_cell) - + + N = positions.shape[0] + dim = positions.shape[1] size = neighbors.centers.shape[0] - - def update(positions, cell, padding_mask): + def update(positions, cell, padding_mask, cl=neighbors.cell_list): + # if cl: + # cell_position, cell_id, overflow, cell_capacity, cell_size, cells_per_side = cl + cl = cl_update(positions, cl, cell) if cl_update is not None else None centers, others, sq_distances, mask, hits = get_neighbors( positions, - cell, make_squared_distance(cell), - squared_cutoff, + cutoff, padding_mask=padding_mask, + cl = cl ) centers, _ = boolean_mask_1d(centers, mask, size, N) others, overflow = boolean_mask_1d(others, mask, size, N) overflow = overflow | cell_too_small(cell) - return Neighbors(centers, others, overflow, positions) + return Neighbors(centers, others, overflow, positions, cl) # if we need an update, call update(), else do a no-op and return input return cond( @@ -189,23 +216,199 @@ def candidates_fn(n): others = jnp.reshape(square, (-1,)) return centers, others +def cell_list_candidate_fn(cell_id, N, dim, _): + idx = cell_id + + cell_idx = [idx] * (dim**3) + for i, dindex in enumerate(neighboring_cells(dim)): + cell_idx[i] = shift_array(idx, dindex) + + cell_idx = jnp.concatenate(cell_idx, axis=-2) + cell_idx = cell_idx[..., jnp.newaxis, :, :] + cell_idx = jnp.broadcast_to(cell_idx, idx.shape[:-1] + cell_idx.shape[-2:]) + def copy_values_from_cell(value, cell_value, cell_id): + scatter_indices = jnp.reshape(cell_id, (-1,)) + cell_value = jnp.reshape(cell_value, (-1,) + cell_value.shape[-2:]) + return value.at[scatter_indices].set(cell_value) + neighbor_idx = jnp.zeros((N + 1,) + cell_idx.shape[-2:], jnp.int32) + neighbor_idx = copy_values_from_cell(neighbor_idx, cell_idx, idx) + others = jnp.reshape(neighbor_idx[:-1, :, 0], (-1,)) + centers = jnp.repeat(jnp.arange(0, N), others.shape[0] // N) + return centers, others -def get_neighbors(positions, cell, square_distances, cutoff, padding_mask=None): - centers, others = candidates_fn(positions.shape[0]) +def get_neighbors(positions, square_distances, cutoff, padding_mask=None, cl=None): + N, dim = positions.shape + if cl is not None: + centers, others = cell_list_candidate_fn(cl.id, N, dim, cl.cells_per_side) + else: + centers, others = candidates_fn(N) + sq_distances = square_distances(positions[centers], positions[others]) - - mask = sq_distances <= cutoff - mask = mask * (centers != others) # remove self-interactions - - # mask out fake positions + mask = sq_distances <= (cutoff**2) + mask = mask * ((centers != others) & (others cell_capacity) + + return CellList(cell_id, overflow, cell_capacity, cell_size, cells_per_side) + + def update_fn(positions, old_cell_list, new_cell): + # this is jittable, + # CellList tells us all the shapes we need + N = positions.shape[0] + dim = positions.shape[1] + + _, cell_size, cells_per_side, _ = cell_dimensions(new_cell, cutoff) + max_occupancy = estimate_cell_capacity(positions, new_cell, cell_size, 1) + def update(positions, old_cell_list, new_cell): + cell_capacity = old_cell_list.capacity + overflow = old_cell_list.overflow + hash_multipliers = compute_hash_constants(cells_per_side) + indices = jnp.array(jnp.floor(jnp.dot(move_to_cell(positions, new_cell), jnp.linalg.inv(cell_size).T)), dtype=jnp.int32) + hashes = jnp.sum(indices * hash_multipliers, axis=1, dtype=jnp.int32) + particle_id = iota(jnp.int32, N) + sort_map = jnp.argsort(hashes) + sorted_hash = hashes[sort_map] + sorted_id = particle_id[sort_map] + + sorted_cell_id = jnp.mod(iota(jnp.int32, N), cell_capacity) + sorted_cell_id = sorted_hash * cell_capacity + sorted_cell_id + sorted_id = jnp.reshape(sorted_id, (N, 1)) + cell_id = N * jnp.ones((old_cell_list.id.reshape(-1, 1).shape), dtype=jnp.int32) + cell_id = cell_id.at[sorted_cell_id].set(sorted_id) + + # This is not jitable, we're changing shapes. It's a fix for the unflatten_cell_buffer. + def get_cell_id(input): + cell_id, cells_per_side, dim = input + return unflatten_cell_buffer(cell_id, cells_per_side, dim) + + cell_id = jax.pure_callback(get_cell_id, old_cell_list.id, [cell_id, cells_per_side, dim]) + overflow = jnp.array(False) + + return CellList(cell_id, overflow, cell_capacity, cell_size, cells_per_side) + + # In case cell size is lower than cutoff, we need to reallocate + def need_reallocate(_, old_cell_list, __): + return CellList(old_cell_list.id, True, old_cell_list.capacity, old_cell_list.size, old_cell_list.cells_per_side) + return cond(((jnp.all(cells_per_side == old_cell_list.cells_per_side)) & (max_occupancy <= old_cell_list.capacity)), + update, + need_reallocate, + positions, old_cell_list, new_cell) + return allocate_fn, update_fn + + +def estimate_cell_capacity(positions, cell, cell_size, buffer_size_multiplier): + cell_capacity = jnp.max(count_cell_filling(positions, cell, cell_size)) + return (cell_capacity * buffer_size_multiplier).astype(jnp.int32) + +def cell_dimensions(cell, cutoff): + """Compute the number of cells-per-side and total number of cells in a box.""" + # Considering cell is 3x3 array + cells_per_side = jnp.floor(cell / cutoff) + cells_per_side = jnp.nan_to_num(cells_per_side) + cell_size = jnp.where(cells_per_side != 0,cell / cells_per_side, cell) + cells_per_side = jnp.amax(cells_per_side, axis=1, keepdims=True) + flat_cells_per_side = jnp.reshape(cells_per_side, (-1,)) + cell_count = 1 + for cells in flat_cells_per_side: + cell_count *= cells + return cell, cell_size, cells_per_side, cell_count.astype(jnp.int32) + +def count_cell_filling(position, cell, minimum_cell_size): + """ + Counts the number of particles per-cell in a spatial partitioning scheme. + """ + cell, cell_size, cells_per_side, _ = cell_dimensions(cell, minimum_cell_size) + hash_multipliers = compute_hash_constants(cells_per_side) + particle_index = jnp.array(jnp.floor(jnp.dot(position, jnp.linalg.inv(cell_size.T).T)), dtype=jnp.int32) + particle_hash = jnp.sum(particle_index * hash_multipliers, axis=1) + filling = jnp.zeros_like(particle_hash, dtype=jnp.int32) # jnp.zeros((cell_count,), dtype=jnp.int32) + filling = filling.at[particle_hash].add(1) + return filling + +def compute_hash_constants(cells_per_side): + one = jnp.array([[1]]) + cells_per_side = jnp.concatenate((one, cells_per_side[:-1].T), axis=1) + return jnp.array(jnp.cumprod(cells_per_side), dtype=jnp.int32) + +def unflatten_cell_buffer(arr, cells_per_side, dim): + if not cells_per_side.shape: + cells_per_side = (int(cells_per_side),) * dim + elif len(cells_per_side.shape) <= 2: + cells_per_side = tuple([x.astype(jnp.int32)for x in cells_per_side.flatten()]) + return jnp.reshape(arr, cells_per_side + (-1,) + arr.shape[1:]) + +def neighboring_cells(dimension): + for dindex in np.ndindex(*([3] * dimension)): + yield jnp.array(dindex) - 1 + + +def shift_array(arr, dindex): + dx, dy, dz = tuple(dindex) + (0,) * (3 - len(dindex)) + arr = cond(dx < 0, + lambda x: jnp.concatenate((x[1:], x[:1])), + lambda x: cond(dx > 0, + lambda x: jnp.concatenate((x[-1:], x[:-1])), + lambda x: x, arr), + arr) + + arr = cond(dy < 0, + lambda x: jnp.concatenate((x[:, 1:], x[:, :1]), axis=1), + lambda x: cond(dy > 0, + lambda x: jnp.concatenate((x[:, -1:], x[:, :-1]), axis=1), + lambda x: x, arr), + arr) + + arr = cond(dz < 0, + lambda x: jnp.concatenate((x[:, :, 1:], x[:, :, :1]), axis=2), + lambda x: cond(dz > 0, + lambda x: jnp.concatenate((x[:, :, -1:], x[:, :, :-1]), axis=2), + lambda x: x, arr), + arr) + + return arr + +def move_to_cell(positions, cell): + return jnp.dot(jnp.mod(jnp.dot(positions, jnp.linalg.inv(cell).T), 1.0), cell) \ No newline at end of file From 8ad80eb7040f775112ba46c584e841fd558e7e9b Mon Sep 17 00:00:00 2001 From: suadou Date: Fri, 22 Mar 2024 17:47:10 +0100 Subject: [PATCH 2/6] Improving neighbor list to work with glp periodic functions --- glp/neighborlist.py | 86 ++++++++++++++++++++++++--------------------- 1 file changed, 45 insertions(+), 41 deletions(-) diff --git a/glp/neighborlist.py b/glp/neighborlist.py index 2398fd4..4287c57 100644 --- a/glp/neighborlist.py +++ b/glp/neighborlist.py @@ -34,7 +34,7 @@ import numpy as np from glp import comms -from .periodic import displacement, get_heights +from .periodic import displacement, get_heights, wrap, inverse from .utils import boolean_mask_1d, cast, squared_distance Neighbors = namedtuple( @@ -42,7 +42,7 @@ ) CellList = namedtuple( - "CellList", ("id", "overflow", "capacity", "size", "cells_per_side") + "CellList", ("id", "reallocate", "capacity", "size", "cells_per_side") ) def neighbor_list(system, cutoff, skin, capacity_multiplier=1.25): @@ -110,7 +110,7 @@ def cell_too_small(new_cell): allowed_movement = (skin * cast(0.5)) ** cast(2.0) if cell is not None and use_cell_list: - if jnp.all(jnp.any(cutoff < cell / 3., axis=1)): + if jnp.all(jnp.any(cutoff < get_heights(cell) / 3., axis=0)): cl_allocate, cl_update = cell_list(cell, cutoff) else: cl_allocate = cl_update = None @@ -176,8 +176,6 @@ def update_fn(positions, neighbors, new_cell=None, padding_mask=None, force_upda dim = positions.shape[1] size = neighbors.centers.shape[0] def update(positions, cell, padding_mask, cl=neighbors.cell_list): - # if cl: - # cell_position, cell_id, overflow, cell_capacity, cell_size, cells_per_side = cl cl = cl_update(positions, cl, cell) if cl_update is not None else None centers, others, sq_distances, mask, hits = get_neighbors( positions, @@ -216,7 +214,7 @@ def candidates_fn(n): others = jnp.reshape(square, (-1,)) return centers, others -def cell_list_candidate_fn(cell_id, N, dim, _): +def cell_list_candidate_fn(cell_id, N, dim): idx = cell_id cell_idx = [idx] * (dim**3) @@ -234,14 +232,16 @@ def copy_values_from_cell(value, cell_value, cell_id): neighbor_idx = copy_values_from_cell(neighbor_idx, cell_idx, idx) others = jnp.reshape(neighbor_idx[:-1, :, 0], (-1,)) centers = jnp.repeat(jnp.arange(0, N), others.shape[0] // N) + print(others.shape, centers.shape) return centers, others def get_neighbors(positions, square_distances, cutoff, padding_mask=None, cl=None): N, dim = positions.shape if cl is not None: - centers, others = cell_list_candidate_fn(cl.id, N, dim, cl.cells_per_side) + centers, others = cell_list_candidate_fn(cl.id, N, dim) else: centers, others = candidates_fn(N) + sq_distances = square_distances(positions[centers], positions[others]) mask = sq_distances <= (cutoff**2) @@ -258,9 +258,10 @@ def make_squared_distance(cell): return vmap(lambda Ra, Rb: squared_distance(displacement(cell, Ra, Rb))) -def cell_list(cell, cutoff, buffer_size_multiplier=1.25): +def cell_list(cell, cutoff, buffer_size_multiplier=1.25, bin_size_multiplier=1): cell = jnp.array(cell) + cutoff *= bin_size_multiplier def allocate_fn(positions, extra_capacity=0): # This function is not jittable, we're determining shapes N = positions.shape[0] @@ -276,7 +277,7 @@ def allocate_fn(positions, extra_capacity=0): hash_multipliers = compute_hash_constants(cells_per_side) particle_id = iota(jnp.int32, N) - indices = jnp.array(jnp.floor(jnp.dot(move_to_cell(positions, cell), jnp.linalg.inv(cell_size).T)), dtype=jnp.int32) + indices = jnp.array(jnp.floor(jnp.dot(wrap(cell, positions), inverse(cell_size).T)), dtype=jnp.int32) hashes = jnp.sum(indices * hash_multipliers, axis=1, dtype=jnp.int32) sort_map = jnp.argsort(hashes) @@ -297,24 +298,28 @@ def allocate_fn(positions, extra_capacity=0): def update_fn(positions, old_cell_list, new_cell): # this is jittable, # CellList tells us all the shapes we need + # If cell size is lower than cutoff, we need to reallocate cell list + # Rellocate is not jitable, we're changing shapes + # So, after each update we need to check if we need to reallocate N = positions.shape[0] dim = positions.shape[1] - - _, cell_size, cells_per_side, _ = cell_dimensions(new_cell, cutoff) + cell_size = jnp.divide(new_cell, cutoff) max_occupancy = estimate_cell_capacity(positions, new_cell, cell_size, 1) + + # Checking if update or reallocate + reallocate = jnp.all(get_heights(cell_size).any() >= cutoff) & (max_occupancy <= old_cell_list.capacity) + def update(positions, old_cell_list, new_cell): - cell_capacity = old_cell_list.capacity - overflow = old_cell_list.overflow - hash_multipliers = compute_hash_constants(cells_per_side) - indices = jnp.array(jnp.floor(jnp.dot(move_to_cell(positions, new_cell), jnp.linalg.inv(cell_size).T)), dtype=jnp.int32) + hash_multipliers = compute_hash_constants(old_cell_list.cells_per_side) + indices = jnp.array(jnp.floor(jnp.dot(wrap(new_cell, positions), inverse(cell_size).T)), dtype=jnp.int32) hashes = jnp.sum(indices * hash_multipliers, axis=1, dtype=jnp.int32) particle_id = iota(jnp.int32, N) sort_map = jnp.argsort(hashes) sorted_hash = hashes[sort_map] sorted_id = particle_id[sort_map] - sorted_cell_id = jnp.mod(iota(jnp.int32, N), cell_capacity) - sorted_cell_id = sorted_hash * cell_capacity + sorted_cell_id + sorted_cell_id = jnp.mod(iota(jnp.int32, N), old_cell_list.capacity) + sorted_cell_id = sorted_hash * old_cell_list.capacity + sorted_cell_id sorted_id = jnp.reshape(sorted_id, (N, 1)) cell_id = N * jnp.ones((old_cell_list.id.reshape(-1, 1).shape), dtype=jnp.int32) cell_id = cell_id.at[sorted_cell_id].set(sorted_id) @@ -324,37 +329,32 @@ def get_cell_id(input): cell_id, cells_per_side, dim = input return unflatten_cell_buffer(cell_id, cells_per_side, dim) - cell_id = jax.pure_callback(get_cell_id, old_cell_list.id, [cell_id, cells_per_side, dim]) - overflow = jnp.array(False) + cell_id = jax.pure_callback(get_cell_id, old_cell_list.id, [cell_id, old_cell_list.cells_per_side, dim]) - return CellList(cell_id, overflow, cell_capacity, cell_size, cells_per_side) + return CellList(cell_id, False, old_cell_list.capacity, cell_size, old_cell_list.cells_per_side) # In case cell size is lower than cutoff, we need to reallocate - def need_reallocate(_, old_cell_list, __): - return CellList(old_cell_list.id, True, old_cell_list.capacity, old_cell_list.size, old_cell_list.cells_per_side) - return cond(((jnp.all(cells_per_side == old_cell_list.cells_per_side)) & (max_occupancy <= old_cell_list.capacity)), - update, - need_reallocate, - positions, old_cell_list, new_cell) + def need_reallocate(_ ,old_cell_list, __): + return CellList(old_cell_list.id, True, old_cell_list.capacity, cell_size, old_cell_list.cells_per_side) + return cond(reallocate, need_reallocate, update, positions, old_cell_list, new_cell) return allocate_fn, update_fn def estimate_cell_capacity(positions, cell, cell_size, buffer_size_multiplier): - cell_capacity = jnp.max(count_cell_filling(positions, cell, cell_size)) + minimum_cell_size = jnp.min(get_heights(cell_size)) + cell_capacity = jnp.max(count_cell_filling(positions, cell, minimum_cell_size)) return (cell_capacity * buffer_size_multiplier).astype(jnp.int32) def cell_dimensions(cell, cutoff): """Compute the number of cells-per-side and total number of cells in a box.""" # Considering cell is 3x3 array - cells_per_side = jnp.floor(cell / cutoff) - cells_per_side = jnp.nan_to_num(cells_per_side) - cell_size = jnp.where(cells_per_side != 0,cell / cells_per_side, cell) - cells_per_side = jnp.amax(cells_per_side, axis=1, keepdims=True) - flat_cells_per_side = jnp.reshape(cells_per_side, (-1,)) - cell_count = 1 - for cells in flat_cells_per_side: - cell_count *= cells + # Transform into reciprocal space to get the cell size whatever cell is + face_dist = get_heights(cell) + cells_per_side = jnp.floor(face_dist / cutoff).astype(jnp.int32) + cell_size = jnp.where(cells_per_side != 0, cell / cells_per_side, cell) + cell_count = jnp.prod(cells_per_side) return cell, cell_size, cells_per_side, cell_count.astype(jnp.int32) + def count_cell_filling(position, cell, minimum_cell_size): """ @@ -362,22 +362,22 @@ def count_cell_filling(position, cell, minimum_cell_size): """ cell, cell_size, cells_per_side, _ = cell_dimensions(cell, minimum_cell_size) hash_multipliers = compute_hash_constants(cells_per_side) - particle_index = jnp.array(jnp.floor(jnp.dot(position, jnp.linalg.inv(cell_size.T).T)), dtype=jnp.int32) + particle_index = jnp.array(jnp.floor(jnp.dot(position, inverse(cell_size.T).T)), dtype=jnp.int32) particle_hash = jnp.sum(particle_index * hash_multipliers, axis=1) filling = jnp.zeros_like(particle_hash, dtype=jnp.int32) # jnp.zeros((cell_count,), dtype=jnp.int32) filling = filling.at[particle_hash].add(1) return filling def compute_hash_constants(cells_per_side): - one = jnp.array([[1]]) - cells_per_side = jnp.concatenate((one, cells_per_side[:-1].T), axis=1) + one = jnp.array([1]) + cells_per_side = jnp.concatenate((one, cells_per_side[:-1]), axis=0) return jnp.array(jnp.cumprod(cells_per_side), dtype=jnp.int32) def unflatten_cell_buffer(arr, cells_per_side, dim): if not cells_per_side.shape: cells_per_side = (int(cells_per_side),) * dim elif len(cells_per_side.shape) <= 2: - cells_per_side = tuple([x.astype(jnp.int32)for x in cells_per_side.flatten()]) + cells_per_side = tuple(cells_per_side) return jnp.reshape(arr, cells_per_side + (-1,) + arr.shape[1:]) def neighboring_cells(dimension): @@ -410,5 +410,9 @@ def shift_array(arr, dindex): return arr -def move_to_cell(positions, cell): - return jnp.dot(jnp.mod(jnp.dot(positions, jnp.linalg.inv(cell).T), 1.0), cell) \ No newline at end of file + +# I had to include it for reallocate cell list, calling cell_list allocate not jitable +def check_reallocation_cell_list(cell_list, allocate_fn, positions): + if cell_list.reallocate: + return allocate_fn(positions) + return cell_list \ No newline at end of file From 38b82cbf33c5bab466d694a824154761066fb916 Mon Sep 17 00:00:00 2001 From: suadou Date: Sat, 23 Mar 2024 16:56:27 +0100 Subject: [PATCH 3/6] Improving cell list for non-cubic cells. Including neighbors test with cell list neighbor list --- glp/neighborlist.py | 32 +++++++++++++++++--------------- tests/unit/test_neighborlist.py | 16 ++++++++-------- 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/glp/neighborlist.py b/glp/neighborlist.py index 4287c57..fe39447 100644 --- a/glp/neighborlist.py +++ b/glp/neighborlist.py @@ -145,9 +145,8 @@ def allocate_fn(positions, new_cell=None, padding_mask=None): new_cell = stop_gradient(new_cell) N = positions.shape[0] - - cl = cl_allocate(positions) if cl_allocate is not None else None - + positions = wrap(new_cell, positions) if new_cell is not None else positions + cl = cl_allocate(positions) if cl_allocate is not None else None centers, others, sq_distances, mask, hits = get_neighbors( positions, make_squared_distance(new_cell), @@ -155,7 +154,6 @@ def allocate_fn(positions, new_cell=None, padding_mask=None): padding_mask=padding_mask, cl=cl ) - # Hits estan mal. el numero es demasiado grande size = int(hits.item() * capacity_multiplier + 1) centers, _ = boolean_mask_1d(centers, mask, size, N) others, overflow = boolean_mask_1d(others, mask, size, N) @@ -174,9 +172,10 @@ def update_fn(positions, neighbors, new_cell=None, padding_mask=None, force_upda N = positions.shape[0] dim = positions.shape[1] + positions = wrap(new_cell, positions) if new_cell is not None else positions size = neighbors.centers.shape[0] def update(positions, cell, padding_mask, cl=neighbors.cell_list): - cl = cl_update(positions, cl, cell) if cl_update is not None else None + cl = cl_update(positions, cl, new_cell) if cl_update is not None else None centers, others, sq_distances, mask, hits = get_neighbors( positions, make_squared_distance(cell), @@ -186,7 +185,6 @@ def update(positions, cell, padding_mask, cl=neighbors.cell_list): ) centers, _ = boolean_mask_1d(centers, mask, size, N) others, overflow = boolean_mask_1d(others, mask, size, N) - overflow = overflow | cell_too_small(cell) return Neighbors(centers, others, overflow, positions, cl) @@ -216,7 +214,6 @@ def candidates_fn(n): def cell_list_candidate_fn(cell_id, N, dim): idx = cell_id - cell_idx = [idx] * (dim**3) for i, dindex in enumerate(neighboring_cells(dim)): cell_idx[i] = shift_array(idx, dindex) @@ -232,7 +229,6 @@ def copy_values_from_cell(value, cell_value, cell_id): neighbor_idx = copy_values_from_cell(neighbor_idx, cell_idx, idx) others = jnp.reshape(neighbor_idx[:-1, :, 0], (-1,)) centers = jnp.repeat(jnp.arange(0, N), others.shape[0] // N) - print(others.shape, centers.shape) return centers, others def get_neighbors(positions, square_distances, cutoff, padding_mask=None, cl=None): @@ -277,7 +273,10 @@ def allocate_fn(positions, extra_capacity=0): hash_multipliers = compute_hash_constants(cells_per_side) particle_id = iota(jnp.int32, N) - indices = jnp.array(jnp.floor(jnp.dot(wrap(cell, positions), inverse(cell_size).T)), dtype=jnp.int32) + indices = jnp.array(jnp.floor(positions @ inverse(cell_size).T), dtype=jnp.int32) + # Some particles are in the edge and might have negative indices or larger than cells_per_side + # We need to correct them wrapping into cell per side vector + indices = wrap(jnp.diag(cells_per_side), indices) hashes = jnp.sum(indices * hash_multipliers, axis=1, dtype=jnp.int32) sort_map = jnp.argsort(hashes) @@ -309,9 +308,12 @@ def update_fn(positions, old_cell_list, new_cell): # Checking if update or reallocate reallocate = jnp.all(get_heights(cell_size).any() >= cutoff) & (max_occupancy <= old_cell_list.capacity) - def update(positions, old_cell_list, new_cell): + def update(positions, old_cell_list): hash_multipliers = compute_hash_constants(old_cell_list.cells_per_side) - indices = jnp.array(jnp.floor(jnp.dot(wrap(new_cell, positions), inverse(cell_size).T)), dtype=jnp.int32) + indices = jnp.array(jnp.floor(positions @ inverse(cell_size).T), dtype=jnp.int32) + # Some particles are in the edge and might have negative indices or larger than cells_per_side + # We need to correct them wrapping into cell per side vector + indices = wrap(jnp.diag(old_cell_list.cells_per_side), indices) hashes = jnp.sum(indices * hash_multipliers, axis=1, dtype=jnp.int32) particle_id = iota(jnp.int32, N) sort_map = jnp.argsort(hashes) @@ -331,12 +333,12 @@ def get_cell_id(input): cell_id = jax.pure_callback(get_cell_id, old_cell_list.id, [cell_id, old_cell_list.cells_per_side, dim]) - return CellList(cell_id, False, old_cell_list.capacity, cell_size, old_cell_list.cells_per_side) + return CellList(cell_id, old_cell_list.reallocate, old_cell_list.capacity, cell_size, old_cell_list.cells_per_side) # In case cell size is lower than cutoff, we need to reallocate - def need_reallocate(_ ,old_cell_list, __): + def need_reallocate(_ ,old_cell_list): return CellList(old_cell_list.id, True, old_cell_list.capacity, cell_size, old_cell_list.cells_per_side) - return cond(reallocate, need_reallocate, update, positions, old_cell_list, new_cell) + return cond(reallocate, need_reallocate, update, positions, old_cell_list) return allocate_fn, update_fn @@ -377,7 +379,7 @@ def unflatten_cell_buffer(arr, cells_per_side, dim): if not cells_per_side.shape: cells_per_side = (int(cells_per_side),) * dim elif len(cells_per_side.shape) <= 2: - cells_per_side = tuple(cells_per_side) + cells_per_side = tuple([x for x in cells_per_side]) return jnp.reshape(arr, cells_per_side + (-1,) + arr.shape[1:]) def neighboring_cells(dimension): diff --git a/tests/unit/test_neighborlist.py b/tests/unit/test_neighborlist.py index 730f50c..f73981a 100644 --- a/tests/unit/test_neighborlist.py +++ b/tests/unit/test_neighborlist.py @@ -8,7 +8,7 @@ from ase.build import bulk -from glp.neighborlist import quadratic_neighbor_list, neighbor_list +from glp.neighborlist import quadratic_neighbor_list, neighbor_list, check_reallocation_cell_list from glp.system import atoms_to_system, unfold_system from glp.graph import system_to_graph from glp.unfold import unfolder @@ -43,10 +43,11 @@ def get_distances(graph): class TestNeighborList(TestCase): + # Checking cell list def test_basic(self): cutoff = 5.0 skin = 0.5 - atoms = bulk("Ar", cubic=False) * [5, 5, 5] + atoms = bulk("Ar", cubic=False) * [7, 7, 7] system = atoms_to_system(atoms) @@ -134,12 +135,11 @@ def test_variable_cell(self): cutoff = 5.0 skin = 0.5 spread = 0.5 - atoms = bulk("Ar", cubic=False) * [5, 5, 5] - + atoms = bulk("Ar", cubic=False) * [7, 7, 7] system = atoms_to_system(atoms) allocate, update, need_update = quadratic_neighbor_list( - system.cell, cutoff, skin, debug=True, capacity_multiplier=1.5 + system.cell, cutoff, skin, debug=True, capacity_multiplier=1.5, use_cell_list=True ) neighbors = allocate(system.R) neighbors = update(system.R, neighbors) @@ -172,7 +172,7 @@ def test_variable_cell(self): def test_jit_high_level_interface(self): cutoff = 5.0 skin = 0.5 - atoms = bulk("Ar", cubic=True) * [5, 5, 5] + atoms = bulk("Ar", cubic=True) * [7, 7, 7] system = atoms_to_system(atoms) @@ -209,7 +209,7 @@ def test_unfolding(self): cutoff = 5.0 skin = 0.5 - atoms = bulk("Ar", cubic=True) * [5, 5, 5] + atoms = bulk("Ar", cubic=True) * [7, 7, 7] system = atoms_to_system(atoms) @@ -240,4 +240,4 @@ def test_unfolding(self): for i in range(len(atoms)): compare_distances( ase_distances[i], distances[neighbors.centers == i], cutoff, atol=1e-5 - ) + ) \ No newline at end of file From d3d307cac7e50ac62c0ba4c94840ac4ff0d47776 Mon Sep 17 00:00:00 2001 From: suadou Date: Sat, 23 Mar 2024 18:12:57 +0100 Subject: [PATCH 4/6] Removing unnecessary wrap function --- glp/neighborlist.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/glp/neighborlist.py b/glp/neighborlist.py index fe39447..1e9f67e 100644 --- a/glp/neighborlist.py +++ b/glp/neighborlist.py @@ -145,7 +145,6 @@ def allocate_fn(positions, new_cell=None, padding_mask=None): new_cell = stop_gradient(new_cell) N = positions.shape[0] - positions = wrap(new_cell, positions) if new_cell is not None else positions cl = cl_allocate(positions) if cl_allocate is not None else None centers, others, sq_distances, mask, hits = get_neighbors( positions, @@ -172,7 +171,6 @@ def update_fn(positions, neighbors, new_cell=None, padding_mask=None, force_upda N = positions.shape[0] dim = positions.shape[1] - positions = wrap(new_cell, positions) if new_cell is not None else positions size = neighbors.centers.shape[0] def update(positions, cell, padding_mask, cl=neighbors.cell_list): cl = cl_update(positions, cl, new_cell) if cl_update is not None else None From cf247c910105207782d04cfc4e19a9a8cbd0b124 Mon Sep 17 00:00:00 2001 From: suadou Date: Sat, 23 Mar 2024 19:49:51 +0100 Subject: [PATCH 5/6] Updating cell list update --- glp/neighborlist.py | 32 ++++++++++++-------------------- 1 file changed, 12 insertions(+), 20 deletions(-) diff --git a/glp/neighborlist.py b/glp/neighborlist.py index 1e9f67e..0085ab6 100644 --- a/glp/neighborlist.py +++ b/glp/neighborlist.py @@ -259,7 +259,6 @@ def cell_list(cell, cutoff, buffer_size_multiplier=1.25, bin_size_multiplier=1): def allocate_fn(positions, extra_capacity=0): # This function is not jittable, we're determining shapes N = positions.shape[0] - dim = positions.shape[1] _, cell_size, cells_per_side, cell_count = cell_dimensions(cell, cutoff) cell_capacity = estimate_cell_capacity(positions, cell, cell_size, buffer_size_multiplier) @@ -274,7 +273,7 @@ def allocate_fn(positions, extra_capacity=0): indices = jnp.array(jnp.floor(positions @ inverse(cell_size).T), dtype=jnp.int32) # Some particles are in the edge and might have negative indices or larger than cells_per_side # We need to correct them wrapping into cell per side vector - indices = wrap(jnp.diag(cells_per_side), indices) + indices = wrap(jnp.diag(cells_per_side), indices).astype(jnp.int32) hashes = jnp.sum(indices * hash_multipliers, axis=1, dtype=jnp.int32) sort_map = jnp.argsort(hashes) @@ -285,11 +284,10 @@ def allocate_fn(positions, extra_capacity=0): sorted_cell_id = sorted_hash * cell_capacity + sorted_cell_id sorted_id = jnp.reshape(sorted_id, (N, 1)) cell_id = cell_id.at[sorted_cell_id].set(sorted_id) - cell_id = unflatten_cell_buffer(cell_id, cells_per_side, dim) + cell_id = unflatten_cell_buffer(cell_id, cells_per_side) occupancy = ops.segment_sum(jnp.ones_like(hashes), hashes, cell_count) max_occupancy = jnp.max(occupancy) overflow = overflow | (max_occupancy > cell_capacity) - return CellList(cell_id, overflow, cell_capacity, cell_size, cells_per_side) def update_fn(positions, old_cell_list, new_cell): @@ -299,10 +297,9 @@ def update_fn(positions, old_cell_list, new_cell): # Rellocate is not jitable, we're changing shapes # So, after each update we need to check if we need to reallocate N = positions.shape[0] - dim = positions.shape[1] - cell_size = jnp.divide(new_cell, cutoff) - max_occupancy = estimate_cell_capacity(positions, new_cell, cell_size, 1) + cell_size = jnp.where(old_cell_list.cells_per_side != 0, cell / old_cell_list.cells_per_side, cell) + max_occupancy = estimate_cell_capacity(positions, new_cell, cell_size, 1) # Checking if update or reallocate reallocate = jnp.all(get_heights(cell_size).any() >= cutoff) & (max_occupancy <= old_cell_list.capacity) @@ -311,7 +308,8 @@ def update(positions, old_cell_list): indices = jnp.array(jnp.floor(positions @ inverse(cell_size).T), dtype=jnp.int32) # Some particles are in the edge and might have negative indices or larger than cells_per_side # We need to correct them wrapping into cell per side vector - indices = wrap(jnp.diag(old_cell_list.cells_per_side), indices) + indices = wrap(jnp.diag(old_cell_list.cells_per_side), indices).astype(jnp.int32) + hashes = jnp.sum(indices * hash_multipliers, axis=1, dtype=jnp.int32) particle_id = iota(jnp.int32, N) sort_map = jnp.argsort(hashes) @@ -325,12 +323,8 @@ def update(positions, old_cell_list): cell_id = cell_id.at[sorted_cell_id].set(sorted_id) # This is not jitable, we're changing shapes. It's a fix for the unflatten_cell_buffer. - def get_cell_id(input): - cell_id, cells_per_side, dim = input - return unflatten_cell_buffer(cell_id, cells_per_side, dim) - - cell_id = jax.pure_callback(get_cell_id, old_cell_list.id, [cell_id, old_cell_list.cells_per_side, dim]) - + + cell_id = jax.pure_callback(unflatten_cell_buffer, old_cell_list.id, cell_id, old_cell_list.cells_per_side) return CellList(cell_id, old_cell_list.reallocate, old_cell_list.capacity, cell_size, old_cell_list.cells_per_side) # In case cell size is lower than cutoff, we need to reallocate @@ -373,18 +367,16 @@ def compute_hash_constants(cells_per_side): cells_per_side = jnp.concatenate((one, cells_per_side[:-1]), axis=0) return jnp.array(jnp.cumprod(cells_per_side), dtype=jnp.int32) -def unflatten_cell_buffer(arr, cells_per_side, dim): - if not cells_per_side.shape: - cells_per_side = (int(cells_per_side),) * dim - elif len(cells_per_side.shape) <= 2: - cells_per_side = tuple([x for x in cells_per_side]) +def unflatten_cell_buffer(arr, cells_per_side): + cells_per_side = tuple(cells_per_side) return jnp.reshape(arr, cells_per_side + (-1,) + arr.shape[1:]) def neighboring_cells(dimension): for dindex in np.ndindex(*([3] * dimension)): yield jnp.array(dindex) - 1 - + + def shift_array(arr, dindex): dx, dy, dz = tuple(dindex) + (0,) * (3 - len(dindex)) arr = cond(dx < 0, From d5bcbcac20528be57b4ee51516c475baf19f736a Mon Sep 17 00:00:00 2001 From: suadou Date: Mon, 8 Apr 2024 16:41:50 +0200 Subject: [PATCH 6/6] Updating neighborlist.py. Adding comments and change names of variables to make it more clear. --- glp/neighborlist.py | 167 +++++++++++++++++++++++--------------------- 1 file changed, 89 insertions(+), 78 deletions(-) diff --git a/glp/neighborlist.py b/glp/neighborlist.py index 0085ab6..f8cd193 100644 --- a/glp/neighborlist.py +++ b/glp/neighborlist.py @@ -9,9 +9,9 @@ candidates into a fixed number of pairs. Once that number is decided, this procedure is jittable. The initial allocation is not. -- Cell-list: We first divide the simulation cell into a grid of cells, and - assign each atom to a cell. Then, for each atom, we only consider - the atoms in the same cell and its neighbors. This is jittable. +- Cell-list: We first divide the simulation cell into a grid of bins, and + assign each atom to a bin. Then, for each atom, we only consider + the atoms in the same bin and its neighbors. This is jittable. The initial allocation is not. The cell-list implementation is more efficient for large systems, but it is @@ -42,7 +42,7 @@ ) CellList = namedtuple( - "CellList", ("id", "reallocate", "capacity", "size", "cells_per_side") + "CellList", ("id", "reallocate", "capacity", "size", "bins_per_side") ) def neighbor_list(system, cutoff, skin, capacity_multiplier=1.25): @@ -145,6 +145,7 @@ def allocate_fn(positions, new_cell=None, padding_mask=None): new_cell = stop_gradient(new_cell) N = positions.shape[0] + # Check if we are using cell list cl = cl_allocate(positions) if cl_allocate is not None else None centers, others, sq_distances, mask, hits = get_neighbors( positions, @@ -173,6 +174,7 @@ def update_fn(positions, neighbors, new_cell=None, padding_mask=None, force_upda dim = positions.shape[1] size = neighbors.centers.shape[0] def update(positions, cell, padding_mask, cl=neighbors.cell_list): + # Check if we are using cell list cl = cl_update(positions, cl, new_cell) if cl_update is not None else None centers, others, sq_distances, mask, hits = get_neighbors( positions, @@ -210,21 +212,24 @@ def candidates_fn(n): others = jnp.reshape(square, (-1,)) return centers, others -def cell_list_candidate_fn(cell_id, N, dim): - idx = cell_id - cell_idx = [idx] * (dim**3) - for i, dindex in enumerate(neighboring_cells(dim)): - cell_idx[i] = shift_array(idx, dindex) - - cell_idx = jnp.concatenate(cell_idx, axis=-2) - cell_idx = cell_idx[..., jnp.newaxis, :, :] - cell_idx = jnp.broadcast_to(cell_idx, idx.shape[:-1] + cell_idx.shape[-2:]) - def copy_values_from_cell(value, cell_value, cell_id): - scatter_indices = jnp.reshape(cell_id, (-1,)) - cell_value = jnp.reshape(cell_value, (-1,) + cell_value.shape[-2:]) - return value.at[scatter_indices].set(cell_value) - neighbor_idx = jnp.zeros((N + 1,) + cell_idx.shape[-2:], jnp.int32) - neighbor_idx = copy_values_from_cell(neighbor_idx, cell_idx, idx) +def cell_list_candidate_fn(bin_id, N, dim): + """Get the candidates for the cell list neighbor list. It is implemented following the jax-md implementation.""" + idx = bin_id + bin_idx = [idx] * (dim**3) + # Get the neighboring bins from the current bin + for i, dindex in enumerate(neighboring_bins(dim)): + bin_idx[i] = shift_array(idx, dindex) + + bin_idx = jnp.concatenate(bin_idx, axis=-2) + bin_idx = bin_idx[..., jnp.newaxis, :, :] + bin_idx = jnp.broadcast_to(bin_idx, idx.shape[:-1] + bin_idx.shape[-2:]) + def copy_values_from_bin(value, bin_value, bin_id): + scatter_indices = jnp.reshape(bin_id, (-1,)) + bin_value = jnp.reshape(bin_value, (-1,) + bin_value.shape[-2:]) + return value.at[scatter_indices].set(bin_value) + neighbor_idx = jnp.zeros((N + 1,) + bin_idx.shape[-2:], jnp.int32) + neighbor_idx = copy_values_from_bin(neighbor_idx, bin_idx, idx) + # Reshape the neighbor_idx to get the centers and others others = jnp.reshape(neighbor_idx[:-1, :, 0], (-1,)) centers = jnp.repeat(jnp.arange(0, N), others.shape[0] // N) return centers, others @@ -253,62 +258,65 @@ def make_squared_distance(cell): def cell_list(cell, cutoff, buffer_size_multiplier=1.25, bin_size_multiplier=1): + """Implementation of cell list neighborlist in pbc.""" cell = jnp.array(cell) cutoff *= bin_size_multiplier def allocate_fn(positions, extra_capacity=0): # This function is not jittable, we're determining shapes N = positions.shape[0] - _, cell_size, cells_per_side, cell_count = cell_dimensions(cell, cutoff) - cell_capacity = estimate_cell_capacity(positions, cell, cell_size, + _, bin_size, bins_per_side, bin_count = bin_dimensions(cell, cutoff) + bin_capacity = estimate_bin_capacity(positions, cell, bin_size, buffer_size_multiplier) - cell_capacity += extra_capacity + # Computing the maximum number of atoms per bin (capacity) + # extra_capacity is used to increase the capacity of the bins to avoid reallocation + bin_capacity += extra_capacity overflow = False - cell_id = N * jnp.ones((cell_count * cell_capacity, 1), dtype=jnp.int32) + bin_id = N * jnp.ones((bin_count * bin_capacity, 1), dtype=jnp.int32) - - hash_multipliers = compute_hash_constants(cells_per_side) + # Compute the hash of each particle to assign it to a bin + hash_multipliers = compute_hash_constants(bins_per_side) particle_id = iota(jnp.int32, N) - indices = jnp.array(jnp.floor(positions @ inverse(cell_size).T), dtype=jnp.int32) - # Some particles are in the edge and might have negative indices or larger than cells_per_side - # We need to correct them wrapping into cell per side vector - indices = wrap(jnp.diag(cells_per_side), indices).astype(jnp.int32) + indices = jnp.array(jnp.floor(positions @ inverse(bin_size).T), dtype=jnp.int32) + # Some particles are in the edge and might have negative indices or larger than bins_per_side + # We need to correct them wrapping into bin per side vector + indices = wrap(jnp.diag(bins_per_side), indices).astype(jnp.int32) hashes = jnp.sum(indices * hash_multipliers, axis=1, dtype=jnp.int32) sort_map = jnp.argsort(hashes) sorted_hash = hashes[sort_map] sorted_id = particle_id[sort_map] - sorted_cell_id = jnp.mod(iota(jnp.int32, N), cell_capacity) - sorted_cell_id = sorted_hash * cell_capacity + sorted_cell_id + sorted_bin_id = jnp.mod(iota(jnp.int32, N), bin_capacity) + sorted_bin_id = sorted_hash * bin_capacity + sorted_bin_id sorted_id = jnp.reshape(sorted_id, (N, 1)) - cell_id = cell_id.at[sorted_cell_id].set(sorted_id) - cell_id = unflatten_cell_buffer(cell_id, cells_per_side) - occupancy = ops.segment_sum(jnp.ones_like(hashes), hashes, cell_count) + bin_id = bin_id.at[sorted_bin_id].set(sorted_id) + bin_id = unflatten_bin_buffer(bin_id, bins_per_side) + occupancy = ops.segment_sum(jnp.ones_like(hashes), hashes, bin_count) max_occupancy = jnp.max(occupancy) - overflow = overflow | (max_occupancy > cell_capacity) - return CellList(cell_id, overflow, cell_capacity, cell_size, cells_per_side) + overflow = overflow | (max_occupancy > bin_capacity) + return CellList(bin_id, overflow, bin_capacity, bin_size, bins_per_side) def update_fn(positions, old_cell_list, new_cell): # this is jittable, # CellList tells us all the shapes we need - # If cell size is lower than cutoff, we need to reallocate cell list + # If bin size is lower than cutoff, we need to reallocate cell list # Rellocate is not jitable, we're changing shapes # So, after each update we need to check if we need to reallocate N = positions.shape[0] - cell_size = jnp.where(old_cell_list.cells_per_side != 0, cell / old_cell_list.cells_per_side, cell) - max_occupancy = estimate_cell_capacity(positions, new_cell, cell_size, 1) + bin_size = jnp.where(old_cell_list.bins_per_side != 0, cell / old_cell_list.bins_per_side, cell) + max_occupancy = estimate_bin_capacity(positions, new_cell, bin_size, 1) # Checking if update or reallocate - reallocate = jnp.all(get_heights(cell_size).any() >= cutoff) & (max_occupancy <= old_cell_list.capacity) + reallocate = jnp.all(get_heights(bin_size).any() >= cutoff) & (max_occupancy <= old_cell_list.capacity) def update(positions, old_cell_list): - hash_multipliers = compute_hash_constants(old_cell_list.cells_per_side) - indices = jnp.array(jnp.floor(positions @ inverse(cell_size).T), dtype=jnp.int32) - # Some particles are in the edge and might have negative indices or larger than cells_per_side - # We need to correct them wrapping into cell per side vector - indices = wrap(jnp.diag(old_cell_list.cells_per_side), indices).astype(jnp.int32) + hash_multipliers = compute_hash_constants(old_cell_list.bins_per_side) + indices = jnp.array(jnp.floor(positions @ inverse(bin_size).T), dtype=jnp.int32) + # Some particles are in the edge and might have negative indices or larger than bins_per_side + # We need to correct them wrapping into bin per side vector + indices = wrap(jnp.diag(old_cell_list.bins_per_side), indices).astype(jnp.int32) hashes = jnp.sum(indices * hash_multipliers, axis=1, dtype=jnp.int32) particle_id = iota(jnp.int32, N) @@ -316,68 +324,72 @@ def update(positions, old_cell_list): sorted_hash = hashes[sort_map] sorted_id = particle_id[sort_map] - sorted_cell_id = jnp.mod(iota(jnp.int32, N), old_cell_list.capacity) - sorted_cell_id = sorted_hash * old_cell_list.capacity + sorted_cell_id + sorted_bin_id = jnp.mod(iota(jnp.int32, N), old_cell_list.capacity) + sorted_bin_id = sorted_hash * old_cell_list.capacity + sorted_bin_id sorted_id = jnp.reshape(sorted_id, (N, 1)) - cell_id = N * jnp.ones((old_cell_list.id.reshape(-1, 1).shape), dtype=jnp.int32) - cell_id = cell_id.at[sorted_cell_id].set(sorted_id) + bin_id = N * jnp.ones((old_cell_list.id.reshape(-1, 1).shape), dtype=jnp.int32) + bin_id = bin_id.at[sorted_bin_id].set(sorted_id) - # This is not jitable, we're changing shapes. It's a fix for the unflatten_cell_buffer. + # This is not jitable, we're changing shapes. It's a fix for the unflatten_bin_buffer. - cell_id = jax.pure_callback(unflatten_cell_buffer, old_cell_list.id, cell_id, old_cell_list.cells_per_side) - return CellList(cell_id, old_cell_list.reallocate, old_cell_list.capacity, cell_size, old_cell_list.cells_per_side) + bin_id = jax.pure_callback(unflatten_bin_buffer, old_cell_list.id, bin_id, old_cell_list.bins_per_side) + return CellList(bin_id, old_cell_list.reallocate, old_cell_list.capacity, bin_size, old_cell_list.bins_per_side) - # In case cell size is lower than cutoff, we need to reallocate + # In case bin size is lower than cutoff, we need to reallocate def need_reallocate(_ ,old_cell_list): - return CellList(old_cell_list.id, True, old_cell_list.capacity, cell_size, old_cell_list.cells_per_side) + return CellList(old_cell_list.id, True, old_cell_list.capacity, bin_size, old_cell_list.bins_per_side) return cond(reallocate, need_reallocate, update, positions, old_cell_list) return allocate_fn, update_fn -def estimate_cell_capacity(positions, cell, cell_size, buffer_size_multiplier): - minimum_cell_size = jnp.min(get_heights(cell_size)) - cell_capacity = jnp.max(count_cell_filling(positions, cell, minimum_cell_size)) - return (cell_capacity * buffer_size_multiplier).astype(jnp.int32) +def estimate_bin_capacity(positions, cell, bin_size, buffer_size_multiplier): + minimum_bin_size = jnp.min(get_heights(bin_size)) + bin_capacity = jnp.max(count_bin_filling(positions, cell, minimum_bin_size)) + return (bin_capacity * buffer_size_multiplier).astype(jnp.int32) -def cell_dimensions(cell, cutoff): - """Compute the number of cells-per-side and total number of cells in a box.""" +def bin_dimensions(cell, cutoff): + """Compute the number of bins-per-side and total number of bins in a box.""" # Considering cell is 3x3 array # Transform into reciprocal space to get the cell size whatever cell is face_dist = get_heights(cell) - cells_per_side = jnp.floor(face_dist / cutoff).astype(jnp.int32) - cell_size = jnp.where(cells_per_side != 0, cell / cells_per_side, cell) - cell_count = jnp.prod(cells_per_side) - return cell, cell_size, cells_per_side, cell_count.astype(jnp.int32) + bins_per_side = jnp.floor(face_dist / cutoff).astype(jnp.int32) + bin_size = jnp.where(bins_per_side != 0, cell / bins_per_side, cell) + bin_count = jnp.prod(bins_per_side) + return cell, bin_size, bins_per_side, bin_count.astype(jnp.int32) -def count_cell_filling(position, cell, minimum_cell_size): +def count_bin_filling(position, cell, minimum_bin_size): """ - Counts the number of particles per-cell in a spatial partitioning scheme. + Counts the number of particles per-bin in a spatial partitioning scheme. """ - cell, cell_size, cells_per_side, _ = cell_dimensions(cell, minimum_cell_size) - hash_multipliers = compute_hash_constants(cells_per_side) - particle_index = jnp.array(jnp.floor(jnp.dot(position, inverse(cell_size.T).T)), dtype=jnp.int32) + cell, bin_size, bins_per_side, _ = bin_dimensions(cell, minimum_bin_size) + hash_multipliers = compute_hash_constants(bins_per_side) + particle_index = jnp.array(jnp.floor(jnp.dot(position, inverse(bin_size.T).T)), dtype=jnp.int32) particle_hash = jnp.sum(particle_index * hash_multipliers, axis=1) - filling = jnp.zeros_like(particle_hash, dtype=jnp.int32) # jnp.zeros((cell_count,), dtype=jnp.int32) + filling = jnp.zeros_like(particle_hash, dtype=jnp.int32) filling = filling.at[particle_hash].add(1) return filling -def compute_hash_constants(cells_per_side): +def compute_hash_constants(bins_per_side): + """Compute the hash constants for a given number of bins per side.""" one = jnp.array([1]) - cells_per_side = jnp.concatenate((one, cells_per_side[:-1]), axis=0) - return jnp.array(jnp.cumprod(cells_per_side), dtype=jnp.int32) + bins_per_side = jnp.concatenate((one, bins_per_side[:-1]), axis=0) + return jnp.array(jnp.cumprod(bins_per_side), dtype=jnp.int32) -def unflatten_cell_buffer(arr, cells_per_side): - cells_per_side = tuple(cells_per_side) - return jnp.reshape(arr, cells_per_side + (-1,) + arr.shape[1:]) +def unflatten_bin_buffer(arr, bins_per_side): + """Unflatten the bin buffer to get the bin_id.""" + bins_per_side = tuple(bins_per_side) + return jnp.reshape(arr, bins_per_side + (-1,) + arr.shape[1:]) -def neighboring_cells(dimension): +def neighboring_bins(dimension): + """Generate the indices of the neighboring bins of a given bin""" for dindex in np.ndindex(*([3] * dimension)): yield jnp.array(dindex) - 1 def shift_array(arr, dindex): + """Shift an array by a given index. It is used for the bin's neighbor list.""" dx, dy, dz = tuple(dindex) + (0,) * (3 - len(dindex)) arr = cond(dx < 0, lambda x: jnp.concatenate((x[1:], x[:1])), @@ -402,7 +414,6 @@ def shift_array(arr, dindex): return arr - # I had to include it for reallocate cell list, calling cell_list allocate not jitable def check_reallocation_cell_list(cell_list, allocate_fn, positions): if cell_list.reallocate: