Adding cell list to neighborlist.py #3
…h cell list neighbor list
…es to make it more clear.
Hi @suadou, many thanks for the contribution! This is a feature I've been meaning to get around to for ages and I'm very happy that someone has tackled it. I'll take a closer look at the code (and the errors) as soon as I can. The Out of curiosity, did you benchmark this against the naive implementation? I'm really curious when the crossover occurs between "stupid O(N^2) but very simple" and "clever O(N) but complicated".
Hi @sirmarcel, thank you for your message and for taking the time to review my contribution. I’m glad to hear that you’re interested in the feature I’ve been working on. As for the benchmarking against the naive implementation, I did run some preliminary tests. Here are the results: From these initial results, it appears that the new cell list implementation begins to outperform the old one at around 8000 atoms, even without a comprehensive benchmark. Again, thank you for your feedback.
Very cool!! Thank you. Very nice. The memory usage for the quadratic list should be easy to mitigate by replacing the I've started working on the review, I'll hopefully finish it today.
Okay, here's my first pass at a review! Overall, a very nice contribution, thanks again. 👍
However, there's a few things we need to sort out before this can be merged. Roughly in order of priority (and the order we should try to solve them):
Correctness/errors:
- I believe there's a correctness issue with the invalidation of the `cell_list` (see below).
- This broke the "standard" neighborlist in some subtle way that makes `test_lj.py` fail -- this is not an unrelated error, this only fails in this branch. From the error message, something in the overflow logic has been broken somehow.
Design/architecture considerations:
- I see why it makes sense to merge the `cell_list` into the `quadratic_neighbor_list` function. There's a lot of shared logic that comes after the candidate function. But we at least need to rename that function as a result, and currently I find it very hard to follow. I've left some comments in the code to start fixing this.
- I don't really understand the design of the `cell_list`. I don't need to know the details, but I want to have a high-level understanding of how it works at least. What assumptions are built into this? Which shapes are fixed? How are the particle IDs stored? What exactly can be done in a JIT and what can't?
- Closely related, I don't understand what happens during the jittable update of the `cell_list` and why there is some mysterious `pure_callback` ...
Code formatting etc:
- Code should be formatted with `black`. (This is on me, I didn't expect contributions and didn't write any guidelines.)
- There needs to be a small but non-zero amount of explanation in the code for future debugging/reading.
- Some other minor things have been highlighted in the review itself.
I don't think any of these are super terrible, and they can be solved with a relatively small amount of work. If you don't have the bandwidth to work on this, I can also take a look -- it'll take me some time though, and we should coordinate so we don't do the same work twice. I'd suggest that we proceed by first fixing the "correctness" issues and then work our way down the list to the more minor considerations. Once everything looks good, I'll go in myself ... 🤓
sq_distances = square_distances(positions[centers], positions[others])
mask = sq_distances <= (cutoff**2)
mask = mask * ((centers != others) & (others<N)) # remove self-interactions and neighbors repetitions
Would you mind explaining a bit why the `others<N` filter is needed? Do you pad the candidates with `N`? I assume there needs to be padding because of the usual fixed shape struggle. I can't immediately find the code where `N` is set as default value, so I'm a bit confused here ... ;)
In the initial allocation, the script assigns memory to each bin based on the sum of `bin_capacity` and a tolerance value (`extra_capacity`). This means that a bin may have more index slots than atoms to fill them. The indices for the particles range from 0 to N-1, where N is the total number of atoms. The index N is used as a placeholder for an empty slot. This placeholder must be masked to prevent duplicate indices.
This is the line in which the placeholders are introduced:
bin_id = N * jnp.ones((bin_count * bin_capacity, 1), dtype=jnp.int32)
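To make the role of the placeholder concrete, here is a minimal NumPy sketch of the masking step. It is not the PR's code: the candidate arrays are made up by hand, and the extra dummy row appended to the positions is an assumption of this sketch (NumPy raises on an out-of-range index, so the placeholder `N` needs something to point at).

```python
import numpy as np

N = 4          # number of real atoms; valid indices are 0..N-1
cutoff = 1.5

positions = np.array([[0.0, 0.0, 0.0],
                      [1.0, 0.0, 0.0],
                      [0.0, 1.0, 0.0],
                      [5.0, 5.0, 5.0]])
# Assumption of this sketch: one dummy row so that index N can be gathered.
padded_positions = np.concatenate([positions, np.zeros((1, 3))], axis=0)

# Hand-written candidate pairs; empty slots in a bin carry the placeholder N.
centers = np.array([0, 0, 0, 1, 1, 1])
others = np.array([1, 2, N, 0, 1, N])

diff = padded_positions[centers] - padded_positions[others]
sq_distances = np.sum(diff**2, axis=-1)

mask = sq_distances <= cutoff**2
# Drop self-interactions and every pair whose partner is the placeholder.
mask = mask & (centers != others) & (others < N)

pairs = list(zip(centers[mask].tolist(), others[mask].tolist()))
```

Without the `others < N` term, the pairs `(0, N)` and `(1, N)` would pass the distance test here, because the dummy row happens to sit within the cutoff of both atoms.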
def cell_list(cell, cutoff, buffer_size_multiplier=1.25, bin_size_multiplier=1):
    """Implementation of cell list neighborlist in pbc."""
I'm not the world's biggest fan of verbose docstrings or comments, but somewhere around here, I'd really appreciate a brief high-level exposition on how this all works. Even just explaining what you mean by `bin_id`, `hash`, etc would be extremely helpful to get a gist of how this works.
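For orientation, one common meaning of a per-particle "hash" in a cell list is the flat index of the bin a particle falls into. The sketch below illustrates that idea only; the function name `assign_bins` is hypothetical, it assumes the rows of `cell` are the lattice vectors, and it may differ from what the PR actually does.

```python
import numpy as np

def assign_bins(positions, cell, bins_per_side):
    """Return a flat bin index for every particle (hypothetical helper).

    Assumes the rows of `cell` are the lattice vectors, so that
    positions = fractional @ cell.
    """
    bins_per_side = np.asarray(bins_per_side)
    frac = positions @ np.linalg.inv(cell)   # fractional coordinates
    frac = frac % 1.0                        # wrap into [0, 1) for periodic cells
    idx3 = np.floor(frac * bins_per_side).astype(np.int64)
    # Guard against a fractional coordinate that rounds up to exactly 1.0.
    idx3 = np.minimum(idx3, bins_per_side - 1)
    nx, ny, nz = bins_per_side
    # Row-major flattening of the 3D bin index.
    return (idx3[:, 0] * ny + idx3[:, 1]) * nz + idx3[:, 2]

cell = 9.0 * np.eye(3)
positions = np.array([[0.5, 0.5, 0.5],
                      [4.0, 0.5, 0.5],
                      [8.9, 8.9, 8.9],
                      [-0.5, 0.5, 0.5]])   # wraps to the last bin along x
hashes = assign_bins(positions, cell, (3, 3, 3))
```

With three bins per side, each bin is 3 units wide, so the four particles land in bins 0, 9, 26 and 18 respectively.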
@@ -8,7 +8,7 @@
from ase.build import bulk

-from glp.neighborlist import quadratic_neighbor_list, neighbor_list
+from glp.neighborlist import quadratic_neighbor_list, neighbor_list, check_reallocation_cell_list
Seems to be an unused import?
@@ -43,10 +43,11 @@ def get_distances(graph):


class TestNeighborList(TestCase):
    # Checking cell list
We should probably check both `cell_list` and the regular implementation and loop over `use_cell_list=True` and `False`. I can do this later, though, no worries.
allowed_movement = (skin * cast(0.5)) ** cast(2.0)

if cell is not None and use_cell_list:
I think this block can be avoided entirely... we should (a) move the check for a "too small" cell into `cell_too_small` and (b) remove the ability to just "turn off" the use of `cell_list` dynamically in `quadratic_neighbor_list`.
I think if `use_cell_list=True` is passed, we should always use it. I don't like the inline `if cl_allocate is not None else None` checks below -- the code is already pretty complicated to read, and having even more on-the-fly stuff that can happen should be avoided! 🙏
The decision whether to use the `cell_list` needs to be made either by the user ahead of time or (more likely) at a higher-level interface. I think it'd be best to decide automatically in `neighbor_list` at the top based on the number of atoms, whether we're in a periodic case, and whether the cell is large enough.
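A rough sketch of what such an up-front decision could look like, using the three criteria named above. Everything here is an assumption for illustration: the names `should_use_cell_list` and `cell_heights` are hypothetical, the `min_atoms=8000` default is just the crossover reported in the preliminary benchmark, and requiring three bins along every direction is stricter than the check currently in the PR.

```python
import numpy as np

def cell_heights(cell):
    """Perpendicular distances between opposite faces of the cell.

    Assumes the rows of `cell` are the lattice vectors.
    """
    # Columns of the inverse are the reciprocal vectors (without 2*pi);
    # the height along lattice vector i is 1 / |reciprocal vector i|.
    return 1.0 / np.linalg.norm(np.linalg.inv(cell), axis=0)

def should_use_cell_list(n_atoms, cell, cutoff, skin, min_atoms=8000, min_bins=3):
    """Decide once, before allocation, whether a cell list is worthwhile."""
    if cell is None:                 # non-periodic: cell list not supported
        return False
    if n_atoms < min_atoms:          # small systems: quadratic list is fine
        return False
    # Assumption: need at least `min_bins` bins of width cutoff + skin per direction.
    return bool(np.all(cell_heights(cell) >= min_bins * (cutoff + skin)))

big_cell = 40.0 * np.eye(3)
small_cell = 10.0 * np.eye(3)
use_big = should_use_cell_list(10_000, big_cell, cutoff=5.0, skin=0.5)
use_small = should_use_cell_list(10_000, small_cell, cutoff=5.0, skin=0.5)
```

Making the choice in one place like this would keep `use_cell_list` a plain static flag further down, which is the simplification asked for above.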
I agree, I'll implement it.
@@ -65,14 +78,15 @@ def _update(system, neighbors):
    return neighbors, _update


-def quadratic_neighbor_list(cell, cutoff, skin, capacity_multiplier=1.25, debug=False):
+def quadratic_neighbor_list(cell, cutoff, skin, capacity_multiplier=1.25, use_cell_list=False, debug=False):
I guess this function needs to be renamed to something else 😁 ... I can do that later though.
bin_id = N * jnp.ones((old_cell_list.id.reshape(-1, 1).shape), dtype=jnp.int32)
bin_id = bin_id.at[sorted_bin_id].set(sorted_id)

# This is not jitable, we're changing shapes. It's a fix for the unflatten_bin_buffer.
I don't understand what happens here -- shouldn't this just be a case where we have to recompile? If shapes are changing...?! I'm not sure it's a good idea to allow changing the number of bins dynamically... I think this is what's happening here? Not sure about the performance impact of doing `pure_callback` either.
It is something that is based on jax-md. The problem is that in order to get the `bin_id`, the indices of the particles in each bin, the easiest way is to sort a list of all the values, so that they are sorted according to the bin they belong to. However, in order to use the `cell_list_candidate_fn` function, we need a 3-dimensional array to search the neighbours.
I have tried to avoid having to use `unflatten_bin_buffer`, which is the reshape to get the final `bin_id` array. However, to keep the shape, I would have to use at and set to update the values or some other kind of approach, but I was not able to find a way. Using `pure_callback` (a non-jittable function inside a jitted function) I managed to get a workaround without compromising performance. If you can think of another way, that would be great.
However, the shape of the final `bin_id` is kept according to `old_cell_list.id`.
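As a point of reference for the discussion, the sort-then-scatter idea can be written so that the output shape depends only on `bin_count` and `bin_capacity`, never on the data. The following is a NumPy sketch under that assumption, with a hypothetical `fill_bins` helper; it is not the PR's implementation, and whether it carries over to the JAX code without `pure_callback` is not settled here.

```python
import numpy as np

def fill_bins(hashes, bin_count, bin_capacity):
    """Scatter particle indices into a fixed (bin_count, bin_capacity) array.

    Empty slots hold the placeholder N = len(hashes). Returns the array and
    a flag telling whether some bin received more particles than it can hold.
    """
    N = hashes.shape[0]
    order = np.argsort(hashes, kind="stable")   # particle indices grouped by bin
    sorted_hash = hashes[order]
    # Start of each bin's run in the sorted array, repeated for every member.
    first = np.searchsorted(sorted_hash, sorted_hash, side="left")
    slot = np.arange(N) - first                 # position of a particle inside its bin
    overflow = bool(np.any(slot >= bin_capacity))
    keep = slot < bin_capacity                  # drop particles that do not fit
    bin_id = np.full(bin_count * bin_capacity, N, dtype=np.int64)
    bin_id[sorted_hash[keep] * bin_capacity + slot[keep]] = order[keep]
    return bin_id.reshape(bin_count, bin_capacity), overflow

hashes = np.array([2, 0, 2, 1, 0])
bins, overflow = fill_bins(hashes, bin_count=3, bin_capacity=2)
```

Here bin 0 holds particles 1 and 4, bin 1 holds particle 3 plus one placeholder (5), and bin 2 holds particles 0 and 2. The `overflow` flag is the signal that a reallocation with a larger capacity is needed, which has to happen outside of any compiled function because it changes shapes.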
# So, after each update we need to check if we need to reallocate
N = positions.shape[0]

bin_size = jnp.where(old_cell_list.bins_per_side != 0, cell / old_cell_list.bins_per_side, cell)
Are you sure that this will work with non-orthorhombic cells? This looks like it assumes a diagonal `cell`?
Yes, you are right. Solution is:
bin_size = jnp.where(old_cell_list.bins_per_side != 0, (cell.T / old_cell_list.bins_per_side).T, cell)
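The difference between the two expressions comes down to broadcasting. A small NumPy check, assuming the rows of `cell` are the lattice vectors (the cell and bin counts below are made up):

```python
import numpy as np

# A sheared (non-orthorhombic) cell; rows are assumed to be the lattice vectors.
cell = np.array([[10.0, 0.0, 0.0],
                 [5.0, 10.0, 0.0],
                 [0.0, 0.0, 12.0]])
bins_per_side = np.array([2, 4, 3])

# Plain broadcasting divides column j by bins_per_side[j] ...
column_wise = cell / bins_per_side
# ... while the transposed form divides row i (lattice vector i) by bins_per_side[i].
row_wise = (cell.T / bins_per_side).T
# Equivalent spelling with an explicit broadcast axis.
row_wise_alt = cell / bins_per_side[:, None]

# For a diagonal cell both forms agree, which hides the problem in simple tests.
diag_cell = np.diag([10.0, 10.0, 12.0])
diag_agrees = np.allclose(diag_cell / bins_per_side,
                          (diag_cell.T / bins_per_side).T)
```

For the second lattice vector, the column-wise form gives `[2.5, 2.5, 0.0]`, whereas dividing the whole vector by its four bins gives `[1.25, 2.5, 0.0]`.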
        cl_allocate = cl_update = None
else:
    cl_allocate = cl_update = None

def need_update_fn(neighbors, new_positions, new_cell):
I think this function needs some changes for `cell_list`, at least if we want to allow the `cell` to change in that case: If the cell gets bigger, each bin becomes larger and more "movement" can be allowed. If it becomes smaller, less is allowed, and in the extreme case no further movement is allowed and we have to reallocate. (Or fail entirely if we can't fit enough boxes in.)
You are right about that, but it is not handled in these lines of the code.
if cell is not None and use_cell_list:
    if jnp.all(jnp.any(cutoff < get_heights(cell) / 3., axis=0)):
        cl_allocate, cl_update = cell_list(cell, cutoff)
    else:
        cl_allocate = cl_update = None
else:
    cl_allocate = cl_update = None
This block is essentially checking if it’s possible to divide the cell into at least 3 bins along any of the cell vectors. If it’s not possible, it doesn’t make sense to use the cell list. It might be beneficial to use a cutoff plus some tolerance (it is now `cutoff+skin`). The need for reallocation of the cell list arises not when the cell gets larger, but when it gets smaller such that the bin is not large enough to guarantee that it can search for each neighbor (i.e., the height in one or more directions is less than the cutoff).
def update_fn(positions, old_cell_list, new_cell):
    # this is jittable,
    # CellList tells us all the shapes we need
    # If bin size is lower than cutoff, we need to reallocate cell list
    # Reallocating is not jitable, we're changing shapes
    # So, after each update we need to check if we need to reallocate
    N = positions.shape[0]
    bin_size = jnp.where(old_cell_list.bins_per_side != 0, (cell.T / old_cell_list.bins_per_side).T, cell)
    max_occupancy = estimate_bin_capacity(positions, new_cell, bin_size, 1)
    # Checking if update or reallocate
    reallocate = jnp.all(get_heights(bin_size).any() >= cutoff) & (max_occupancy <= old_cell_list.capacity)
Here, the code updates the `bin_size` according to the new cell, keeping the number of bins per side. Afterwards, it checks the occupancy (if there are more atoms in one bin than bin_capacity) and whether the height is equal to or greater than the cutoff.
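The two conditions described here (bins that have become too thin, bins that have become too full) can be spelled out in a few lines. This is a NumPy sketch with a hypothetical `needs_reallocation` helper, assuming the rows of the cell are the lattice vectors and that per-particle flat bin indices are already available; it is not a drop-in replacement for the `reallocate` expression above.

```python
import numpy as np

def needs_reallocation(new_cell, bins_per_side, hashes, capacity, cutoff):
    """Return True if the existing bins can no longer be reused (hypothetical)."""
    bins_per_side = np.asarray(bins_per_side)
    # Keep the number of bins fixed and rescale the bin with the new cell.
    bin_cell = new_cell / bins_per_side[:, None]
    # Perpendicular heights of one bin: 1 / |columns of the inverse|.
    bin_heights = 1.0 / np.linalg.norm(np.linalg.inv(bin_cell), axis=0)
    too_thin = bool(np.any(bin_heights < cutoff))
    # Largest number of particles that share a bin.
    counts = np.bincount(hashes, minlength=int(np.prod(bins_per_side)))
    too_full = int(counts.max()) > capacity
    return too_thin or too_full

hashes = np.array([0, 0, 5, 26])
ok = needs_reallocation(9.0 * np.eye(3), (3, 3, 3), hashes, capacity=2, cutoff=2.5)
shrunk = needs_reallocation(7.2 * np.eye(3), (3, 3, 3), hashes, capacity=2, cutoff=2.5)
crowded = needs_reallocation(9.0 * np.eye(3), (3, 3, 3), hashes, capacity=1, cutoff=2.5)
```

In the first case the bins are 3 units high and hold at most two particles, so nothing needs to change. Shrinking the cell to 7.2 makes each bin 2.4 units high, below the cutoff, and lowering the capacity to 1 overfills bin 0.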
    return arr

# I had to include it for reallocate cell list, calling cell_list allocate not jitable
def check_reallocation_cell_list(cell_list, allocate_fn, positions):
This function seems to be unused?
You're right. I'll delete it.
EDIT: Okay, sorry, I've thought about this more and I think what I said is largely wrong. Let me start from the top. I think there's some issues we need to discuss around invalidation of the For the neighbour list itself, the current
def test_edge_case_cell_change(self):
    # testing whether we catch *only* the cell boundary
    # changing such that a new neighbor appears
    # let | denote the boundary and c the cutoff
    #
    # | c c |
    # x . . . . y . . |
    #
    # to:
    # | c c . |
    # x . . . . y . |
    #
    # since y hasn't moved explicitly, this is missed by a naive check
    from ase import Atoms
    cutoff = 1.0
    skin = 0.1
    L = 10.0
    epsilon = 1e-6 # distance by which we're just outside skin+cutoff initially
    total_cutoff = skin + cutoff
    atoms = Atoms(
        positions=np.array([[0, 0, 0], [L - total_cutoff - epsilon, 0, 0]]),
        cell=np.array([[L, 0, 0], [0, 100, 0], [0, 0, 100]]),
        pbc=True
    )
    system = atoms_to_system(atoms)
    allocate, update, need_update = quadratic_neighbor_list(
        system.cell,
        cutoff,
        skin,
        debug=True,
        capacity_multiplier=1.5,
        use_cell_list=True,
    )
    neighbors = allocate(system.R, new_cell=system.cell)
    # no neighbors -- just padding
    assert neighbors.centers[0] == 2
    assert len(neighbors.centers) == 1
    # now we shift the boundary left
    atoms.set_cell(np.array([[L - skin, 0, 0], [0, 100, 0], [0, 0, 100]]), scale_atoms=False)
    system = atoms_to_system(atoms)
    assert need_update(neighbors, system.R, system.cell)
I think this can be fixed by adding the norm of the total change in In addition to this, there's the question of when to recompute the Sorry for the wall of text, but I think we should try to get this right, or (which would also be fine) just explicitly refuse to deal with the case of a changing cell with the
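The numbers in this test can be checked by hand with a one-dimensional minimum-image distance. The sketch below only reproduces the geometry of the edge case with the same `cutoff`, `skin`, `L` and `epsilon`; the helper `mic_distance_1d` is hypothetical and the displacement check is a simplified stand-in for a criterion based on `(skin / 2) ** 2`, not the library's `need_update`.

```python
def mic_distance_1d(x1, x2, length):
    """Minimum-image distance along one periodic axis of the given length."""
    d = x2 - x1
    d -= length * round(d / length)
    return abs(d)

cutoff, skin, L, epsilon = 1.0, 0.1, 10.0, 1e-6
total_cutoff = cutoff + skin

x = 0.0
y = L - total_cutoff - epsilon   # just outside cutoff + skin through the boundary

before = mic_distance_1d(x, y, L)          # roughly 1.100001
after = mic_distance_1d(x, y, L - skin)    # roughly 1.000001

# Neither atom moved, so a check that only looks at atomic displacements
# against (skin / 2) ** 2 has nothing to react to.
max_sq_displacement = 0.0
displacement_check_fires = max_sq_displacement > (0.5 * skin) ** 2
```

Before the cell change the pair is just outside `cutoff + skin`; afterwards it is inside, even though no displacement was registered. That is the case an invalidation criterion has to cover if the cell is allowed to change.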
Okay, I've thought about it a bit more: We just need to figure out the worst-case movement of two atoms, which is For the Anyhow, sorry for derailing this PR with this. I would suggest we set this invalidation business aside for later. For now, what I'd like to see is (a) Then, we can tackle the other stuff together!
This pull request introduces a new method for computing the cell list to identify neighboring atoms. Here's how it works:
This approach is jittable, making it efficient for large systems. However, the initial allocation is not jittable. The implementation follows the cell list method used in `jax-md`.
The existing function `quadratic_neighbor_list` has been updated with a new parameter, `use_cell_list=False`. To use the new cell list method, set `use_cell_list=True`. Please note that for this to work, the `cutoff` should be at least three times smaller than the `cell` size in one or more of the components (x, y, z).
While the cell-list implementation is more efficient for large systems, it is more complex and requires the simulation cell to be periodic. On the other hand, the N-square implementation is simpler and can be used for non-periodic systems.
Important Note: If the cell size changes and the bins become smaller than the cutoff, the cell list must be reallocated. The `CellList` class has a `reallocated` attribute. If it's set to true, the cell list must be recomputed.
Test: The new code was tested using unit tests. It passes `neighborlist_test.py` (updated to use a larger cell so that the cell-size requirement is met). However, it fails `test_lj.py` and `test_mlff.py` due to unrelated errors: