Adding cell list to neighborlist.py by suadou · Pull Request #3 · sirmarcel/glp · GitHub

Adding cell list to neighborlist.py #3

Open
wants to merge 6 commits into
base: main
292 changes: 251 additions & 41 deletions glp/neighborlist.py
@@ -2,11 +2,21 @@

The role of a neighborlist implementation is to reduce the amount of pairs
of atoms to consider from N*N to something linear in N by removing pairs
that are further away than a given cutoff radius.
that are farther away than a given cutoff radius.
Owner
Thx :)


This file implements this in a naive way: We first generate all N*N combinations,
and then trim down these candidates into a fixed number of pairs. Once that number
is decided, this procedure is jittable. The initial allocation is not.
This file implements this in two ways:
- N-square: We first generate all N*N combinations, and then trim down these
candidates into a fixed number of pairs. Once that number is
decided, this procedure is jittable. The initial allocation is not.

- Cell-list: We first divide the simulation cell into a grid of bins, and
assign each atom to a bin. Then, for each atom, we only consider
the atoms in the same bin and its neighbors. This is jittable.
The initial allocation is not.

The cell-list implementation is more efficient for large systems, but it is
more complex and requires the simulation cell to be periodic. The N-square
implementation is simpler and can be used for non-periodic systems.

The general data format consists of two index arrays, such that the indices in
`centers` contains the atom from which atom-pair vectors originate, while the
@@ -17,20 +27,23 @@
"""

from collections import namedtuple
from jax import jit, vmap
import jax
from jax import vmap, ops
import jax.numpy as jnp
from jax.lax import stop_gradient, cond
from functools import partial
from typing import Callable
from jax.lax import stop_gradient, cond, iota
import numpy as np

from glp import comms
from .periodic import displacement, get_heights
from .periodic import displacement, get_heights, wrap, inverse
from .utils import boolean_mask_1d, cast, squared_distance

Neighbors = namedtuple(
"Neighbors", ("centers", "others", "overflow", "reference_positions")
"Neighbors", ("centers", "others", "overflow", "reference_positions", "cell_list")
)

CellList = namedtuple(
"CellList", ("id", "reallocate", "capacity", "size", "bins_per_side")
)

def neighbor_list(system, cutoff, skin, capacity_multiplier=1.25):
# convenience interface: we often don't need explicit access to an allocate_fn, since
@@ -65,14 +78,15 @@ def _update(system, neighbors):
return neighbors, _update


def quadratic_neighbor_list(cell, cutoff, skin, capacity_multiplier=1.25, debug=False):
def quadratic_neighbor_list(cell, cutoff, skin, capacity_multiplier=1.25, use_cell_list=False, debug=False):
Owner
I guess this function needs to be renamed to something else 😁 ... I can do that later though.

"""Implementation of neighborlist in pbc using cell list."""

assert capacity_multiplier >= 1.0

cell = stop_gradient(cell)

cutoff = cast(stop_gradient(cutoff))
skin = cast(stop_gradient(skin))

if cell is not None:

def cell_too_small(new_cell):
@@ -91,11 +105,23 @@ def cell_too_small(new_cell):
)
comms.warn("this will yield incorrect results!")

squared_cutoff = (cutoff + skin) ** cast(2.0)
cutoff = (cutoff + skin)

allowed_movement = (skin * cast(0.5)) ** cast(2.0)

if cell is not None and use_cell_list:
Owner
I think this block can be avoided entirely... we should (a) move the check for a "too small" cell into cell_too_small and (b) remove the ability to just "turn off" the use of cell_list dynamically in quadratic_neighbor_list.

I think if use_cell_list=True is passed, we should always use it. I don't like the inline if cl_allocate is not None else None checks below -- the code is already pretty complicated to read, and having even more on-the-fly stuff that can happen should be avoided! 🙏

The decision whether to use the cell_list needs to be made either by the user ahead of time or (more likely) at a higher-level interface. I think it'd be best to decide automatically in neighbor_list at the top based on the number of atoms, whether we're in a periodic case, and whether the cell is large enough.

Author
I agree, I'll implement it.
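
The up-front decision suggested in this thread could live in a small helper called from neighbor_list. A minimal plain-Python sketch, assuming a hypothetical helper name should_use_cell_list, illustrative thresholds min_atoms and min_bins, and cell heights that are already known (for example from get_heights). None of these names or defaults are part of glp:

```python
def should_use_cell_list(n_atoms, heights, cutoff, min_atoms=1000, min_bins=3):
    """Hypothetical helper: decide once, ahead of time, whether to use a cell list.

    heights: perpendicular cell heights, or None for a non-periodic system.
    min_atoms and min_bins are illustrative defaults, not values from glp.
    """
    if heights is None:
        return False  # non-periodic: fall back to the N-square path
    if n_atoms < min_atoms:
        return False  # small system: N*N candidates are cheap enough
    # every direction must fit at least `min_bins` bins of width >= cutoff
    return all(h / cutoff >= min_bins for h in heights)
```

With the decision made once, quadratic_neighbor_list would no longer need the inline "is not None" checks, since the chosen path stays fixed for the lifetime of the neighbor list.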

if jnp.all(jnp.any(cutoff < get_heights(cell) / 3., axis=0)):
cl_allocate, cl_update = cell_list(cell, cutoff)
else:
cl_allocate = cl_update = None
else:
cl_allocate = cl_update = None

def need_update_fn(neighbors, new_positions, new_cell):
Owner
I think this function needs some changes for cell_list, at least if we want to allow the cell to change in that case: If the cell gets bigger, each bin becomes larger and more "movement" can be allowed. If it becomes smaller, less is allowed, and in the extreme case no further movement is allowed and we have to reallocate. (Or fail entirely if we can't fit enough boxes in.)
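
To make the point about a changing cell concrete: if the number of bins per side is fixed at allocation time, the bins scale with the cell. A minimal sketch, assuming an orthorhombic cell given by its three edge lengths (bin_heights and must_reallocate are illustrative names, not functions from this PR):

```python
def bin_heights(cell_lengths, bins_per_side):
    """Bin edge lengths for an orthorhombic cell with a fixed number of bins."""
    return [length / n for length, n in zip(cell_lengths, bins_per_side)]


def must_reallocate(cell_lengths, bins_per_side, cutoff):
    """True if any bin is now thinner than the cutoff, so neighbors could be missed."""
    return any(h < cutoff for h in bin_heights(cell_lengths, bins_per_side))


bins = [4, 4, 4]  # fixed at allocation time for a 20 x 20 x 20 cell, cutoff 5
print(must_reallocate([20.0, 20.0, 20.0], bins, 5.0))  # False
print(must_reallocate([24.0, 24.0, 24.0], bins, 5.0))  # False (cell grew, bins got larger)
print(must_reallocate([20.0, 18.0, 20.0], bins, 5.0))  # True (18 / 4 = 4.5 < 5)
```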

Author
@suadou Apr 17, 2024
You are right about that, but it is not handled in these lines of the code.

    if cell is not None and use_cell_list:
        if jnp.all(jnp.any(cutoff < get_heights(cell) / 3., axis=0)):
            cl_allocate, cl_update = cell_list(cell, cutoff)
        else:
            cl_allocate = cl_update = None
    else:
        cl_allocate = cl_update = None

This block essentially checks whether it is possible to divide the cell into at least 3 bins along any of the cell vectors. If it is not possible, it doesn't make sense to use the cell list. It might be beneficial to use the cutoff plus some tolerance (currently it is cutoff + skin). The cell list needs to be reallocated not when the cell gets larger, but when it gets so small that a bin is no longer large enough to guarantee that every neighbor is found (i.e., the bin height in one or more directions is less than the cutoff).

    def update_fn(positions, old_cell_list, new_cell):
        # this is jittable,
        # CellList tells us all the shapes we need
        # If bin size is lower than cutoff, we need to reallocate cell list
        # Rellocate is not jitable, we're changing shapes
        # So, after each update we need to check if we need to reallocate
        N = positions.shape[0]

        bin_size = jnp.where(old_cell_list.bins_per_side != 0, (cell.T / old_cell_list.bins_per_side).T, cell)
        max_occupancy = estimate_bin_capacity(positions, new_cell, bin_size, 1)
        # Checking if update or reallocate
        reallocate = jnp.all(get_heights(bin_size).any() >= cutoff) & (max_occupancy <= old_cell_list.capacity)

Here, the code recomputes bin_size for the new cell while keeping the number of bins per side fixed. It then checks the occupancy (whether any bin holds more atoms than bin_capacity) and whether the bin height is still equal to or greater than the cutoff.
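
The "at least 3 bins" check depends on the perpendicular heights of the cell, not on the lengths of the cell vectors. A minimal plain-Python sketch of that computation, assuming lattice vectors are stored as rows of a 3x3 cell (the helper names are illustrative; in the PR, get_heights plays the role of cell_heights). This sketch requires every direction to fit min_bins bins, which may be stricter than the check in the diff:

```python
import math


def cross(u, v):
    return [u[1] * v[2] - u[2] * v[1],
            u[2] * v[0] - u[0] * v[2],
            u[0] * v[1] - u[1] * v[0]]


def dot(u, v):
    return sum(x * y for x, y in zip(u, v))


def cell_heights(cell):
    """Distance between opposite faces, for lattice vectors stored as rows."""
    a, b, c = cell
    volume = abs(dot(a, cross(b, c)))
    face_normals = [cross(b, c), cross(c, a), cross(a, b)]
    return [volume / math.sqrt(dot(n, n)) for n in face_normals]


def bins_per_side(cell, cutoff):
    """Number of bins of width >= cutoff that fit along each direction."""
    return [int(math.floor(h / cutoff)) for h in cell_heights(cell)]


def cell_list_is_worthwhile(cell, cutoff, min_bins=3):
    return all(n >= min_bins for n in bins_per_side(cell, cutoff))


cubic = [[15.0, 0.0, 0.0], [0.0, 15.0, 0.0], [0.0, 0.0, 15.0]]
sheared = [[15.0, 0.0, 0.0], [7.5, 15.0, 0.0], [0.0, 0.0, 15.0]]
print(bins_per_side(cubic, 5.0))    # [3, 3, 3]
print(bins_per_side(sheared, 5.0))  # [2, 3, 3]
```

In the sheared cell the first lattice vector is still 15 long, but the height along it drops to about 13.4, so only two bins of width 5 fit in that direction.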

# question: how to deal with changes in new_cell?
# we will invalidate if atoms move too much, but not if
# the new cell is too small for the cutoff...

movement = make_squared_distance(new_cell)(
neighbors.reference_positions, new_positions
)
@@ -104,7 +130,7 @@ def need_update_fn(neighbors, new_positions, new_cell):
# if we have an overflow, we don't need to update -- results are
# already invalid. instead, we will simply return previous invalid
# neighbors and hope someone up the chain catches it!
should_update = (max_movement > allowed_movement) & (~neighbors.overflow)
should_update = (max_movement > allowed_movement) & (~neighbors.overflow)
Owner
Whitespace! 🤡 (don't worry about it, this will be fixed by black at the end anyhow.)


# if the cell is now *too small*, we need to update for sure to set overflow flag
return should_update | cell_too_small(new_cell)
@@ -119,52 +145,49 @@ def allocate_fn(positions, new_cell=None, padding_mask=None):
new_cell = stop_gradient(new_cell)

N = positions.shape[0]

# Check if we are using cell list
cl = cl_allocate(positions) if cl_allocate is not None else None
centers, others, sq_distances, mask, hits = get_neighbors(
positions,
new_cell,
make_squared_distance(new_cell),
squared_cutoff,
cutoff,
padding_mask=padding_mask,
cl=cl
)

size = int(hits.item() * capacity_multiplier + 1)
centers, _ = boolean_mask_1d(centers, mask, size, N)
others, overflow = boolean_mask_1d(others, mask, size, N)

overflow = overflow | cell_too_small(new_cell)

# print("done with neighbors=None branch")
return Neighbors(centers, others, overflow, positions)

def update_fn(
positions, neighbors, new_cell=None, padding_mask=None, force_update=False
):
return Neighbors(centers, others, overflow, positions, cl)

def update_fn(positions, neighbors, new_cell=None, padding_mask=None, force_update=False):
# this is jittable,
# neighbors tells us all the shapes we need

if new_cell is None:
new_cell = cell
else:
new_cell = stop_gradient(new_cell)



N = positions.shape[0]
dim = positions.shape[1]
size = neighbors.centers.shape[0]

def update(positions, cell, padding_mask):
def update(positions, cell, padding_mask, cl=neighbors.cell_list):
# Check if we are using cell list
cl = cl_update(positions, cl, new_cell) if cl_update is not None else None
centers, others, sq_distances, mask, hits = get_neighbors(
positions,
cell,
make_squared_distance(cell),
squared_cutoff,
cutoff,
padding_mask=padding_mask,
cl = cl
)
centers, _ = boolean_mask_1d(centers, mask, size, N)
others, overflow = boolean_mask_1d(others, mask, size, N)

overflow = overflow | cell_too_small(cell)

return Neighbors(centers, others, overflow, positions)
return Neighbors(centers, others, overflow, positions, cl)

# if we need an update, call update(), else do a no-op and return input
return cond(
@@ -189,23 +212,210 @@ def candidates_fn(n):
others = jnp.reshape(square, (-1,))
return centers, others

def cell_list_candidate_fn(bin_id, N, dim):
"""Get the candidates for the cell list neighbor list. It is implemented following the jax-md implementation."""
idx = bin_id
bin_idx = [idx] * (dim**3)
# Get the neighboring bins from the current bin
for i, dindex in enumerate(neighboring_bins(dim)):
bin_idx[i] = shift_array(idx, dindex)

bin_idx = jnp.concatenate(bin_idx, axis=-2)
bin_idx = bin_idx[..., jnp.newaxis, :, :]
bin_idx = jnp.broadcast_to(bin_idx, idx.shape[:-1] + bin_idx.shape[-2:])
def copy_values_from_bin(value, bin_value, bin_id):
scatter_indices = jnp.reshape(bin_id, (-1,))
bin_value = jnp.reshape(bin_value, (-1,) + bin_value.shape[-2:])
return value.at[scatter_indices].set(bin_value)
neighbor_idx = jnp.zeros((N + 1,) + bin_idx.shape[-2:], jnp.int32)
neighbor_idx = copy_values_from_bin(neighbor_idx, bin_idx, idx)
# Reshape the neighbor_idx to get the centers and others
others = jnp.reshape(neighbor_idx[:-1, :, 0], (-1,))
centers = jnp.repeat(jnp.arange(0, N), others.shape[0] // N)
return centers, others

def get_neighbors(positions, cell, square_distances, cutoff, padding_mask=None):
centers, others = candidates_fn(positions.shape[0])
sq_distances = square_distances(positions[centers], positions[others])

mask = sq_distances <= cutoff
mask = mask * (centers != others) # remove self-interactions
def get_neighbors(positions, square_distances, cutoff, padding_mask=None, cl=None):
N, dim = positions.shape
if cl is not None:
centers, others = cell_list_candidate_fn(cl.id, N, dim)
else:
centers, others = candidates_fn(N)

# mask out fake positions

sq_distances = square_distances(positions[centers], positions[others])
mask = sq_distances <= (cutoff**2)
Owner
For consistency with the rest of the code, this should use cast(2).

mask = mask * ((centers != others) & (others<N)) # remove self-interactions and neighbors repetitions
Owner
Would you mind explaining a bit why the others<N filter is needed? Do you pad the candidates with N? I assume there needs to be padding because of the usual fixed shape struggle. I can't immediately find the code where N is set as default value, so I'm a bit confused here ... ;)

Author
In the initial allocation, each bin is given a fixed number of slots: bin_capacity plus a tolerance value (extra_capacity). This means that a bin may have more slots than it has atoms. The particle indices range from 0 to N-1, where N is the total number of atoms, and the index N is used as a placeholder for an empty slot. This placeholder must be masked to prevent duplicate indices.
This is the line in which the placeholders are introduced:
bin_id = N * jnp.ones((bin_count * bin_capacity, 1), dtype=jnp.int32)
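
The placeholder scheme described above can be illustrated in a few lines of plain Python. This is a toy sketch with made-up bin contents, not the PR's JAX code; it filters both members of a pair, whereas the diff only needs to filter others because centers are always real atoms:

```python
N = 4          # atoms are indexed 0..N-1, so N itself marks an empty slot
capacity = 3   # fixed number of slots per bin
bins = {0: [0, 2], 1: [1, 3]}  # bin -> atoms actually inside it (toy data)

# fixed-shape buffer: every slot starts out as the placeholder N
buffer = [[N] * capacity for _ in bins]
for b, atoms in bins.items():
    buffer[b][:len(atoms)] = atoms

# candidate pairs inside bin 0 (neighboring bins omitted for brevity)
candidates = [(i, j) for i in buffer[0] for j in buffer[0]]

# keep real, distinct atoms only: drops self-pairs and placeholder entries
pairs = [(i, j) for i, j in candidates if i != j and i < N and j < N]
print(buffer)  # [[0, 2, 4], [1, 3, 4]]
print(pairs)   # [(0, 2), (2, 0)]
```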

if padding_mask is not None:
mask = mask * padding_mask[centers]
mask = mask * padding_mask[others]

hits = jnp.sum(mask)

return centers, others, sq_distances, mask, hits


def make_squared_distance(cell):
return vmap(lambda Ra, Rb: squared_distance(displacement(cell, Ra, Rb)))


def cell_list(cell, cutoff, buffer_size_multiplier=1.25, bin_size_multiplier=1):
"""Implementation of cell list neighborlist in pbc."""
Owner
I'm not the world's biggest fan of verbose docstrings or comments, but somewhere around here, I'd really appreciate a brief high-level exposition on how this all works. Even just explaining what you mean by bin_id, hash, etc would be extremely helpful to get a gist of how this works.
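
For readers of this thread, here is one reading of the terms used in the diff: the "hash" of an atom is the flat number of the bin it falls into, computed from its integer bin index and a set of strides, and bin_id is the buffer that lists the atom indices stored in each bin. A minimal plain-Python sketch of the hashing step, mirroring compute_hash_constants as it appears in this diff (hash_constants and bin_hash are illustrative names, not glp functions):

```python
from itertools import accumulate
from operator import mul


def hash_constants(bins_per_side):
    """Strides (1, nx, nx*ny, ...) that flatten a bin index (ix, iy, iz)."""
    return list(accumulate([1] + list(bins_per_side[:-1]), mul))


def bin_hash(bin_index, bins_per_side):
    """Flat bin number: sum of index * stride, unique for every bin."""
    strides = hash_constants(bins_per_side)
    return sum(i * s for i, s in zip(bin_index, strides))


print(hash_constants((3, 4, 5)))       # [1, 3, 12]
print(bin_hash((2, 3, 4), (3, 4, 5)))  # 59, the last of 3 * 4 * 5 = 60 bins
```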


cell = jnp.array(cell)
cutoff *= bin_size_multiplier
def allocate_fn(positions, extra_capacity=0):
# This function is not jittable, we're determining shapes
N = positions.shape[0]
_, bin_size, bins_per_side, bin_count = bin_dimensions(cell, cutoff)
bin_capacity = estimate_bin_capacity(positions, cell, bin_size,
buffer_size_multiplier)
# Computing the maximum number of atoms per bin (capacity)
# extra_capacity is used to increase the capacity of the bins to avoid reallocation
bin_capacity += extra_capacity

overflow = False
bin_id = N * jnp.ones((bin_count * bin_capacity, 1), dtype=jnp.int32)

# Compute the hash of each particle to assign it to a bin
hash_multipliers = compute_hash_constants(bins_per_side)
particle_id = iota(jnp.int32, N)
indices = jnp.array(jnp.floor(positions @ inverse(bin_size).T), dtype=jnp.int32)
# Some particles are in the edge and might have negative indices or larger than bins_per_side
Owner
Wouldn't it make more sense to wrap positions before the whole procedure to avoid dealing with this here?

Author
I tried wrapping the positions instead of the indices. However, the wrapped positions still contained negative values (probably due to rounding), which produced negative indices and broke the implementation. So the simplest way was to wrap the indices instead, using bins_per_side as the 'cell vectors'.
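
One way rounding can defeat position wrapping is shown below in plain Python, for a single fractional coordinate. This is an illustration of the general hazard, not necessarily the exact failure the author ran into: wrapping a coordinate just below zero yields exactly 1.0 in floating point, which maps to a bin index one past the end, while wrapping the integer index is exact.

```python
import math

n_bins = 4
x_frac = -1e-17  # fractional coordinate a hair below zero

# wrapping the position first: the float result rounds up to exactly 1.0
wrapped = x_frac % 1.0
index_from_wrapped_position = math.floor(wrapped * n_bins)

# flooring first, then wrapping the integer index: exact integer arithmetic
index_from_wrapped_index = math.floor(x_frac * n_bins) % n_bins

print(wrapped)                      # 1.0
print(index_from_wrapped_position)  # 4, outside the valid range 0..3
print(index_from_wrapped_index)     # 3
```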

# We need to correct them wrapping into bin per side vector
indices = wrap(jnp.diag(bins_per_side), indices).astype(jnp.int32)
hashes = jnp.sum(indices * hash_multipliers, axis=1, dtype=jnp.int32)

sort_map = jnp.argsort(hashes)
sorted_hash = hashes[sort_map]
sorted_id = particle_id[sort_map]

sorted_bin_id = jnp.mod(iota(jnp.int32, N), bin_capacity)
sorted_bin_id = sorted_hash * bin_capacity + sorted_bin_id
sorted_id = jnp.reshape(sorted_id, (N, 1))
bin_id = bin_id.at[sorted_bin_id].set(sorted_id)
bin_id = unflatten_bin_buffer(bin_id, bins_per_side)
occupancy = ops.segment_sum(jnp.ones_like(hashes), hashes, bin_count)
max_occupancy = jnp.max(occupancy)
overflow = overflow | (max_occupancy > bin_capacity)
return CellList(bin_id, overflow, bin_capacity, bin_size, bins_per_side)

def update_fn(positions, old_cell_list, new_cell):
# this is jittable,
# CellList tells us all the shapes we need
# If bin size is lower than cutoff, we need to reallocate cell list
# Rellocate is not jitable, we're changing shapes
# So, after each update we need to check if we need to reallocate
N = positions.shape[0]

bin_size = jnp.where(old_cell_list.bins_per_side != 0, cell / old_cell_list.bins_per_side, cell)
Owner
Are you sure that this will work with non-orthorhombic cells? This looks like it assumes a diagonal cell.

Author
Yes, you are right. The solution is:
bin_size = jnp.where(old_cell_list.bins_per_side != 0, (cell.T / old_cell_list.bins_per_side).T, cell)
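
The difference between the two expressions only shows up for non-diagonal cells. A minimal plain-Python sketch that spells out what each broadcasted division does element by element (divide_columns and divide_rows are illustrative names). Which variant is correct depends on whether lattice vectors are stored as rows or as columns of the cell, which this sketch does not settle:

```python
def divide_columns(cell, bins):
    """What broadcasting `cell / bins` does: column j is divided by bins[j]."""
    return [[cell[i][j] / bins[j] for j in range(3)] for i in range(3)]


def divide_rows(cell, bins):
    """What `(cell.T / bins).T` does: row i is divided by bins[i]."""
    return [[cell[i][j] / bins[i] for j in range(3)] for i in range(3)]


bins = [2, 4, 5]
diagonal = [[10.0, 0.0, 0.0], [0.0, 20.0, 0.0], [0.0, 0.0, 30.0]]
sheared = [[10.0, 0.0, 0.0], [5.0, 20.0, 0.0], [0.0, 0.0, 30.0]]

# for a diagonal cell both variants agree
print(divide_columns(diagonal, bins) == divide_rows(diagonal, bins))  # True
# for a sheared cell the second row differs
print(divide_columns(sheared, bins)[1])  # [2.5, 5.0, 0.0]
print(divide_rows(sheared, bins)[1])     # [1.25, 5.0, 0.0]
```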

max_occupancy = estimate_bin_capacity(positions, new_cell, bin_size, 1)
# Checking if update or reallocate
reallocate = jnp.all(get_heights(bin_size).any() >= cutoff) & (max_occupancy <= old_cell_list.capacity)

def update(positions, old_cell_list):
hash_multipliers = compute_hash_constants(old_cell_list.bins_per_side)
indices = jnp.array(jnp.floor(positions @ inverse(bin_size).T), dtype=jnp.int32)
# Some particles are in the edge and might have negative indices or larger than bins_per_side
# We need to correct them wrapping into bin per side vector
indices = wrap(jnp.diag(old_cell_list.bins_per_side), indices).astype(jnp.int32)

hashes = jnp.sum(indices * hash_multipliers, axis=1, dtype=jnp.int32)
particle_id = iota(jnp.int32, N)
sort_map = jnp.argsort(hashes)
sorted_hash = hashes[sort_map]
sorted_id = particle_id[sort_map]

sorted_bin_id = jnp.mod(iota(jnp.int32, N), old_cell_list.capacity)
sorted_bin_id = sorted_hash * old_cell_list.capacity + sorted_bin_id
sorted_id = jnp.reshape(sorted_id, (N, 1))
bin_id = N * jnp.ones((old_cell_list.id.reshape(-1, 1).shape), dtype=jnp.int32)
bin_id = bin_id.at[sorted_bin_id].set(sorted_id)

# This is not jitable, we're changing shapes. It's a fix for the unflatten_bin_buffer.
Owner
I don't understand what happens here -- shouldn't this just be a case where we have to recompile? If shapes are changing...?! I'm not sure it's a good idea to allow changing the number of bins dynamically... I think this is what's happening here? Not sure about the performance impact of doing pure_callback either.

Author
It is something that is based on jax-md. The problem is that in order to get the bin_id, the indices of the particles in each bin, the easiest way is to sort a list of all the values, so that they are sorted according to the bin they belong to. However, in order to use the cell_list_candidate_fn function, we need a 3-dimensional array to search the neighbours.

I tried to avoid unflatten_bin_buffer, which is the reshape that produces the final bin_id array. However, to keep the shape I would have to use at and set to update the values, or some other kind of approach, and I was not able to find a way. Using pure_callback (a non-jittable function called inside a jitted function) I managed to get a workaround without compromising performance. If you can think of another way, that would be great.

However, the shape of final bin_id is kept according to old_cell_list.id.
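
The sort-based fill described in this reply can be sketched in plain Python. The slot formula follows the diff (hash * capacity + rank mod capacity); fill_bins and the toy hashes are illustrative, and the sketch assumes no bin holds more than `capacity` atoms, since otherwise entries would overwrite each other:

```python
def fill_bins(hashes, bin_count, capacity):
    """Place atom ids into a flat buffer of bin_count * capacity slots.

    Empty slots hold the placeholder N = len(hashes).
    """
    n = len(hashes)
    order = sorted(range(n), key=lambda i: hashes[i])  # atoms grouped by bin
    flat = [n] * (bin_count * capacity)
    for rank, atom in enumerate(order):
        # atoms of one bin have consecutive ranks, so rank % capacity
        # gives each of them a distinct slot inside that bin
        slot = hashes[atom] * capacity + rank % capacity
        flat[slot] = atom
    # "unflatten": one row per bin, always the same shape
    return [flat[b * capacity:(b + 1) * capacity] for b in range(bin_count)]


# five atoms in three bins: bin 0 holds atoms 1 and 4, bin 1 holds atom 3,
# bin 2 holds atoms 0 and 2; the value 5 marks the one empty slot
print(fill_bins([2, 0, 2, 1, 0], bin_count=3, capacity=2))
# [[1, 4], [3, 5], [2, 0]]
```

The output shape depends only on bin_count and capacity, not on where the atoms are, which is why the final reshape can keep the shape of the previously allocated buffer.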


bin_id = jax.pure_callback(unflatten_bin_buffer, old_cell_list.id, bin_id, old_cell_list.bins_per_side)
return CellList(bin_id, old_cell_list.reallocate, old_cell_list.capacity, bin_size, old_cell_list.bins_per_side)

# In case bin size is lower than cutoff, we need to reallocate
def need_reallocate(_ ,old_cell_list):
return CellList(old_cell_list.id, True, old_cell_list.capacity, bin_size, old_cell_list.bins_per_side)
return cond(reallocate, need_reallocate, update, positions, old_cell_list)
return allocate_fn, update_fn


def estimate_bin_capacity(positions, cell, bin_size, buffer_size_multiplier):
minimum_bin_size = jnp.min(get_heights(bin_size))
bin_capacity = jnp.max(count_bin_filling(positions, cell, minimum_bin_size))
return (bin_capacity * buffer_size_multiplier).astype(jnp.int32)

def bin_dimensions(cell, cutoff):
"""Compute the number of bins-per-side and total number of bins in a box."""
# Considering cell is 3x3 array
# Transform into reciprocal space to get the cell size whatever cell is
face_dist = get_heights(cell)
bins_per_side = jnp.floor(face_dist / cutoff).astype(jnp.int32)
bin_size = jnp.where(bins_per_side != 0, cell / bins_per_side, cell)
bin_count = jnp.prod(bins_per_side)
return cell, bin_size, bins_per_side, bin_count.astype(jnp.int32)


def count_bin_filling(position, cell, minimum_bin_size):
"""
Counts the number of particles per-bin in a spatial partitioning scheme.
"""
cell, bin_size, bins_per_side, _ = bin_dimensions(cell, minimum_bin_size)
hash_multipliers = compute_hash_constants(bins_per_side)
particle_index = jnp.array(jnp.floor(jnp.dot(position, inverse(bin_size.T).T)), dtype=jnp.int32)
particle_hash = jnp.sum(particle_index * hash_multipliers, axis=1)
filling = jnp.zeros_like(particle_hash, dtype=jnp.int32)
filling = filling.at[particle_hash].add(1)
return filling

def compute_hash_constants(bins_per_side):
"""Compute the hash constants for a given number of bins per side."""
one = jnp.array([1], dtype=jnp.int32)
bins_per_side = jnp.concatenate((one, bins_per_side[:-1]), axis=0)
return jnp.array(jnp.cumprod(bins_per_side), dtype=jnp.int32)

def unflatten_bin_buffer(arr, bins_per_side):
"""Unflatten the bin buffer to get the bin_id."""
bins_per_side = tuple(bins_per_side)
return jnp.reshape(arr, bins_per_side + (-1,) + arr.shape[1:])

def neighboring_bins(dimension):
"""Generate the indices of the neighboring bins of a given bin"""
for dindex in np.ndindex(*([3] * dimension)):
yield jnp.array(dindex) - 1



def shift_array(arr, dindex):
"""Shift an array by a given index. It is used for the bin's neighbor list."""
dx, dy, dz = tuple(dindex) + (0,) * (3 - len(dindex))
arr = cond(dx < 0,
lambda x: jnp.concatenate((x[1:], x[:1])),
lambda x: cond(dx > 0,
lambda x: jnp.concatenate((x[-1:], x[:-1])),
lambda x: x, arr),
arr)

arr = cond(dy < 0,
lambda x: jnp.concatenate((x[:, 1:], x[:, :1]), axis=1),
lambda x: cond(dy > 0,
lambda x: jnp.concatenate((x[:, -1:], x[:, :-1]), axis=1),
lambda x: x, arr),
arr)

arr = cond(dz < 0,
lambda x: jnp.concatenate((x[:, :, 1:], x[:, :, :1]), axis=2),
lambda x: cond(dz > 0,
lambda x: jnp.concatenate((x[:, :, -1:], x[:, :, :-1]), axis=2),
lambda x: x, arr),
arr)

return arr

# I had to include it for reallocate cell list, calling cell_list allocate not jitable
def check_reallocation_cell_list(cell_list, allocate_fn, positions):
Owner
This function seems to be unused?

Author
You're right. I'll delete it.

if cell_list.reallocate:
return allocate_fn(positions)
return cell_list