Adding cell list to neighborlist.py by suadou · Pull Request #3 · sirmarcel/glp · GitHub

Adding cell list to neighborlist.py #3

Open · suadou wants to merge 6 commits into main

@suadou commented Apr 9, 2024

This pull request introduces a new method for computing the cell list to identify neighboring atoms. Here's how it works:

  1. The simulation cell is divided into a grid of bins.
  2. Each atom is assigned to a bin.
  3. For each atom, only the atoms in the same bin and its neighboring bins are considered.

This approach is jittable, making it efficient for large systems. However, the initial allocation is not jittable. The implementation follows the cell list method used in jax-md.
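For readers who want the idea without the JAX machinery, here is a minimal NumPy sketch of the three steps above. It is not the PR's implementation: it stores bins in Python dicts instead of the fixed-capacity, padded buffers a jittable version needs, the names `bin_atoms`, `cell_list_pairs` and `brute_force_pairs` are hypothetical, the lattice vectors are assumed to be the rows of `cell` (ASE convention), and it simply requires at least 3 bins along every direction so that the 27 surrounding bins are all distinct.

```python
import itertools

import numpy as np


def bin_atoms(positions, cell, cutoff):
    """Steps 1-2: split the cell into bins and assign each atom to one.

    Assumes the lattice vectors are the rows of `cell` (ASE convention).
    """
    frac = (positions @ np.linalg.inv(cell)) % 1.0  # fractional coordinates in [0, 1)
    volume = abs(np.linalg.det(cell))
    # height along lattice direction i is V / |a_j x a_k|
    heights = np.array(
        [volume / np.linalg.norm(np.cross(cell[(i + 1) % 3], cell[(i + 2) % 3])) for i in range(3)]
    )
    dims = tuple(int(k) for k in np.floor(heights / cutoff))  # bins per side, each >= cutoff high
    if min(dims) < 3:
        raise ValueError("this sketch needs at least 3 bins along every direction")
    # integer bin coordinates; the clip guards against frac * n rounding up to n
    idx = np.minimum(np.floor(frac * dims).astype(int), np.array(dims) - 1)
    # "hash": the flat index of each atom's bin
    hashes = np.ravel_multi_index(tuple(idx.T), dims)
    bins = {}
    for atom, h in enumerate(hashes):
        bins.setdefault(int(h), []).append(atom)
    return frac @ cell, idx, dims, bins  # wrapped positions, bin coords, grid shape, bins


def cell_list_pairs(positions, cell, cutoff):
    """Step 3: only look for neighbors in the 27 bins around each atom's bin."""
    wrapped, idx, dims, bins = bin_atoms(positions, cell, cutoff)
    n = np.array(dims)
    pairs = set()
    for i in range(len(positions)):
        for offset in itertools.product((-1, 0, 1), repeat=3):
            target = idx[i] + np.array(offset)
            shift = np.floor_divide(target, n)  # which periodic image this bin belongs to
            h = int(np.ravel_multi_index(tuple(target - shift * n), dims))
            for j in bins.get(h, []):
                if j == i:
                    continue
                d = wrapped[j] + shift @ cell - wrapped[i]
                if d @ d <= cutoff**2:
                    pairs.add((i, j))
    return pairs


def brute_force_pairs(positions, cell, cutoff):
    """O(N^2) reference that checks the 27 nearest periodic images of every pair."""
    wrapped = ((positions @ np.linalg.inv(cell)) % 1.0) @ cell
    shifts = np.array(list(itertools.product((-1, 0, 1), repeat=3))) @ cell
    pairs = set()
    for i in range(len(positions)):
        for j in range(len(positions)):
            if i == j:
                continue
            d = wrapped[j] + shifts - wrapped[i]  # (27, 3)
            if np.min(np.sum(d**2, axis=1)) <= cutoff**2:
                pairs.add((i, j))
    return pairs
```

Because every bin is at least `cutoff` high, any neighbor within the cutoff must sit in one of the 27 surrounding bins; the wrap-around of the bin index (`shift`) then identifies the periodic image to use, so the sketch also works for skewed cells. For such cells, `cell_list_pairs(positions, cell, cutoff)` should return the same set of ordered pairs as `brute_force_pairs`.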

The existing function quadratic_neighbor_list has been updated with a new parameter, use_cell_list=False. To use the new cell list method, set use_cell_list=True. Please note that for this to work, the cutoff must be less than one third of the cell height along at least one of the cell directions (x, y, z).

While the cell-list implementation is more efficient for large systems, it is more complex and requires the simulation cell to be periodic. The O(N^2) implementation, on the other hand, is simpler and can also be used for non-periodic systems.

Important note: if the cell size changes and the bins become smaller than the cutoff, the cell list must be reallocated. The CellList class has a reallocated attribute; if it is set to True, the cell list must be recomputed.

Test: The new code was tested using the unit tests. It passes neighborlist_test.py (updated to increase the cell size so that it satisfies the cell-size requirement). However, it fails test_lj.py and test_mlff.py due to unrelated errors:

=========================== short test summary info ============================
FAILED test_lj.py::TestCalculatorWithLJ::test_mega - RuntimeError
FAILED test_mlff.py::TestMLFF::test_mega - ValueError: max() arg is an empty sequence

@sirmarcel (Owner)

Hi @suadou, many thanks for the contribution! This is a feature I've been meaning to get around to for ages and I'm very happy that someone has tackled it. I'll take a closer look at the code (and the errors) as soon as I can. The mlff test can be safely ignored for now, but test_lj.py should not be failing.

Out of curiosity, did you benchmark this against the naive implementation? I'm really curious when the crossover occurs between "stupid O(N^2) but very simple" and "clever O(N) but complicated".

@suadou (Author) commented Apr 16, 2024

Hi @sirmarcel , thank you for your message and for taking the time to review my contribution. I’m glad to hear that you’re interested in the feature I’ve been working on.

As for the benchmarking against the naive implementation, I did run some preliminary tests. Here are the results:

[Two images with the preliminary benchmark results (not preserved)]

From these initial results, it appears that the new cell list implementation begins to outperform the old one at around 8000 atoms, even without a comprehensive benchmark.

Again, thank you for your feedback.

@sirmarcel (Owner)

Very cool!! Thank you. Very nice. The memory usage for the quadratic list should be easy to mitigate by replacing the vmap over the N**2 distances with a scan, but I never bothered to implement it... 😁
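The point of replacing the vmap with a scan is to trade the N x N intermediate for a loop over smaller blocks. A NumPy sketch of just that memory argument (the function name is hypothetical, and periodic images are ignored): in JAX, the Python loop would become `jax.lax.scan` or `jax.lax.map` over fixed-size chunks, and since shapes must be static under `jax.jit`, the per-chunk `np.nonzero` would need a fixed output size (for example `jnp.nonzero(..., size=..., fill_value=...)`), which is where the existing capacity/overflow logic would come in.

```python
import numpy as np


def pairs_within_cutoff_chunked(positions, cutoff, chunk_size=256):
    """Brute-force pair search that never builds the full N x N distance matrix.

    Non-periodic for brevity. Each step only holds a (chunk_size, N) block, so the
    peak extra memory is O(chunk_size * N) rather than O(N^2).
    """
    n = len(positions)
    centers, others = [], []
    for start in range(0, n, chunk_size):
        block = positions[start : start + chunk_size]  # (c, 3)
        d2 = np.sum((block[:, None, :] - positions[None, :, :]) ** 2, axis=-1)  # (c, N)
        i, j = np.nonzero(d2 <= cutoff**2)
        i = i + start  # row index within the chunk -> atom index
        keep = i != j  # drop self-pairs
        centers.append(i[keep])
        others.append(j[keep])
    return np.concatenate(centers), np.concatenate(others)
```

The result is identical to thresholding the full distance matrix; only the peak memory changes, at the cost of a sequential loop over chunks.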

I've started working on the review, I'll hopefully finish it today.

@sirmarcel (Owner) left a comment

Okay, here's my first pass at a review! Overall, a very nice contribution, thanks again. 👍

However, there are a few things we need to sort out before this can be merged. Roughly in order of priority (and the order in which we should try to solve them):

Correctness/errors:

  • I believe there's a correctness issue with the invalidation of the cell_list (see below).
  • This broke the "standard" neighborlist in some subtle way that makes test_lj.py fail -- this is not an unrelated error, this only fails in this branch. From the error message, something in the overflow logic has been broken somehow.

Design/architecture considerations:

  • I see why it makes sense to merge the cell_list into the quadratic_neighbor_list function. There's a lot of shared logic that comes after the candidate function. But we at least need to rename that function as a result, and currently I find it very hard to follow. I've left some comments in the code to start fixing this.
  • I don't really understand the design of the cell_list. I don't need to know the details, but I want to have a high-level understanding of how it works at least. What assumptions are built into this? Which shapes are fixed? How are the particle IDs stored? What exactly can be done in a JIT and what can't?
  • Closely related, I don't understand what happens during the jittable update of the cell_list and why there is some mysterious pure_callback...

Code formatting etc:

  • Code should be formatted with black. (This is on me, I didn't expect contributions and didn't write any guidelines.)
  • There needs to be a small but non-zero amount of explanation in the code for future debugging/reading.
  • Some other minor things have been highlighted in the review itself.

I don't think any of these are super terrible, and they can be solved with a relatively small amount of work. If you don't have the bandwidth to work on this, I can also take a look -- it'll take me some time though, and we should coordinate so we don't do the same work twice. I'd suggest that we proceed by first fixing the "correctness" issues and then working our way down the list to the more minor considerations. Once everything looks good, I'll go in myself ... 🤓


sq_distances = square_distances(positions[centers], positions[others])
mask = sq_distances <= (cutoff**2)
mask = mask * ((centers != others) & (others<N)) # remove self-interactions and neighbors repetitions
Owner

Would you mind explaining a bit why the others<N filter is needed? Do you pad the candidates with N? I assume there needs to be padding because of the usual fixed shape struggle. I can't immediately find the code where N is set as default value, so I'm a bit confused here ... ;)

Author

In the initial allocation, the script assigns memory to each bin based on the sum of bin_capacity and a tolerance value (extra_capacity). This means that some bins have more slots than atoms. The particle indices range from 0 to N-1, where N is the total number of atoms, and the index N is used as a placeholder for an empty slot. This placeholder must be masked to prevent duplicate indices.
This is the line in which the placeholders are introduced:
bin_id = N * jnp.ones((bin_count * bin_capacity, 1), dtype=jnp.int32)
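A toy NumPy illustration of why the `others < N` term is needed once empty slots hold the placeholder N (the function name and the numbers are made up for illustration). NumPy would raise an IndexError on `positions[N]`, so the sketch pads `positions` with one dummy row; JAX, by default, clamps out-of-bounds indices when gathering, so there a placeholder silently reads the last atom. Either way the placeholder yields a bogus "distance" that has to be masked out.

```python
import numpy as np


def masked_pairs(positions, centers, others, cutoff):
    """Keep candidate pairs within the cutoff, dropping self-pairs and placeholder slots.

    `others` may contain the placeholder index N = len(positions) for empty bin slots.
    """
    n = len(positions)
    # Dummy row at the origin so that positions[N] is a valid gather in NumPy.
    padded = np.concatenate([positions, np.zeros((1, 3))])
    sq_distances = np.sum((padded[centers] - padded[others]) ** 2, axis=1)
    mask = sq_distances <= cutoff**2
    mask &= (centers != others) & (others < n)  # remove self-interactions and placeholders
    return [(int(i), int(j)) for i, j in zip(centers[mask], others[mask])]


positions = np.array([[0.0, 0.0, 0.0], [0.9, 0.0, 0.0], [3.0, 0.0, 0.0], [3.5, 0.0, 0.0]])
N = len(positions)
centers = np.array([0, 0, 0, 1, 2, 2, 3, 3])
others = np.array([0, 1, N, 0, 3, N, 2, N])
result = masked_pairs(positions, centers, others, cutoff=1.0)
```

Without the `others < n` term, the candidate (0, N) would survive here, because the dummy row sits on top of atom 0.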



def cell_list(cell, cutoff, buffer_size_multiplier=1.25, bin_size_multiplier=1):
"""Implementation of cell list neighborlist in pbc."""
Owner

I'm not the world's biggest fan of verbose docstrings or comments, but somewhere around here, I'd really appreciate a brief high-level exposition on how this all works. Even just explaining what you mean by bin_id, hash, etc would be extremely helpful to get a gist of how this works.

@@ -8,7 +8,7 @@
from ase.build import bulk


-from glp.neighborlist import quadratic_neighbor_list, neighbor_list
+from glp.neighborlist import quadratic_neighbor_list, neighbor_list, check_reallocation_cell_list
Owner

Seems to be an unused import?

@@ -43,10 +43,11 @@ def get_distances(graph):


class TestNeighborList(TestCase):
# Checking cell list
Owner

We should probably check both cell_list and the regular implementation and loop over use_cell_list=True and False. I can do this later, though, no worries.


allowed_movement = (skin * cast(0.5)) ** cast(2.0)

if cell is not None and use_cell_list:
Owner

I think this block can be avoided entirely... we should (a) move the check for a "too small" cell into cell_too_small and (b) remove the ability to just "turn off" the use of cell_list dynamically in quadratic_neighbor_list.

I think if use_cell_list=True is passed, we should always use it. I don't like the inline if cl_allocate is not None else None checks below -- the code is already pretty complicated to read, and having even more on-the-fly stuff that can happen should be avoided! 🙏

The decision whether to use the cell_list needs to be made either by the user ahead of time or (more likely) at a higher-level interface. I think it'd be best to decide automatically in neighbor_list at the top based on the number of atoms, whether we're in a periodic case, and whether the cell is large enough.

Author

I agree, I'll implement it.

@@ -65,14 +78,15 @@ def _update(system, neighbors):
return neighbors, _update


-def quadratic_neighbor_list(cell, cutoff, skin, capacity_multiplier=1.25, debug=False):
+def quadratic_neighbor_list(cell, cutoff, skin, capacity_multiplier=1.25, use_cell_list=False, debug=False):
Owner

I guess this function needs to be renamed to something else 😁 ... I can do that later though.

bin_id = N * jnp.ones((old_cell_list.id.reshape(-1, 1).shape), dtype=jnp.int32)
bin_id = bin_id.at[sorted_bin_id].set(sorted_id)

# This is not jittable, since we're changing shapes. It's a fix for unflatten_bin_buffer.
Owner

I don't understand what happens here -- shouldn't this just be a case where we have to recompile? If shapes are changing...?! I'm not sure it's a good idea to allow changing the number of bins dynamically... I think this is what's happening here? I'm also not sure about the performance impact of using pure_callback.

Author

This is based on jax-md. The problem is that the easiest way to get bin_id (the indices of the particles in each bin) is to sort the particle indices by the bin they belong to. However, the cell_list_candidate_fn function needs a 3-dimensional array to search for neighbours.

I have tried to avoid using unflatten_bin_buffer, which reshapes the sorted buffer into the final bin_id array. However, to keep the shape fixed, I would have to update the values with at and set, or use some other approach, and I was not able to find a way. Using pure_callback (which calls a non-jittable function from inside a jitted function), I managed to get a workaround without compromising performance. If you can think of another way, that would be great.

Note, however, that the final bin_id keeps the shape of old_cell_list.id.
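One way to do the sort-and-scatter step with shapes that depend only on the number of atoms, the number of bins and the bin capacity, sketched in NumPy (the function name is hypothetical, and the slot trick assumes no bin exceeds its capacity). Because atoms of the same bin are contiguous after sorting, "position mod capacity" already gives each of them a distinct slot, so no per-bin counting is needed. In JAX, the scatter would be `flat.at[...].set(order)` and the rest maps onto `jnp.argsort` and `reshape`. As long as the bins per side and the capacity are plain Python integers (static metadata rather than traced arrays), the final reshape has a static target shape, so one would expect it to trace under `jax.jit` without `pure_callback`. If `bins_per_side` is stored as a JAX array, as the `jnp.where(old_cell_list.bins_per_side != 0, ...)` line suggests, it cannot serve as a shape inside `jit`, which might be what forced the callback. This is a hypothesis to check against the PR code, not a verified diagnosis.

```python
import numpy as np


def fill_bins(hashes, dims, capacity):
    """Scatter atom ids into a padded (*dims, capacity) buffer; empty slots hold N.

    `hashes[i]` is the flat bin index of atom i. Assumes no bin holds more than
    `capacity` atoms (otherwise two atoms would be written to the same slot).
    """
    hashes = np.asarray(hashes)
    n_atoms = len(hashes)
    n_bins = int(np.prod(dims))
    order = np.argsort(hashes, kind="stable")  # atom ids grouped by bin
    sorted_hash = hashes[order]
    # Atoms of one bin are contiguous after sorting, so position mod capacity gives
    # distinct slots within each bin as long as the bin holds <= capacity atoms.
    slot = np.arange(n_atoms) % capacity
    flat = np.full(n_bins * capacity, n_atoms, dtype=np.int32)  # N marks an empty slot
    flat[sorted_hash * capacity + slot] = order  # JAX: flat.at[...].set(order)
    # The target shape depends only on dims and capacity, never on the positions.
    return flat.reshape(*dims, capacity)
```

Within a bin the atoms may not start at slot 0 (the slots can wrap around), but every atom lands in its own bin exactly once and all remaining slots hold the placeholder N.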

# So, after each update we need to check if we need to reallocate
N = positions.shape[0]

bin_size = jnp.where(old_cell_list.bins_per_side != 0, cell / old_cell_list.bins_per_side, cell)
Owner

are you sure that this will work with non-orthorhombic cells? this looks like it assumes a diagonal cell?

Author

Yes, you are right. The fix is:
bin_size = jnp.where(old_cell_list.bins_per_side != 0, (cell.T / old_cell_list.bins_per_side).T, cell)
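Whether `cell / bins_per_side` or `(cell.T / bins_per_side).T` is correct depends on how the lattice vectors are stored. NumPy/JAX broadcasting divides along the last axis, so `cell / n` scales column i by 1/n_i, while `(cell.T / n).T` scales row i. The proposed fix is therefore right if the lattice vectors are the rows of `cell` (ASE's convention); if glp's `System.cell` stores them as columns (worth checking in `atoms_to_system`; this is an assumption here), the original expression was already correct. The same broadcasting rule applies to the condition in `jnp.where(old_cell_list.bins_per_side != 0, ...)`, which selects per column. A sketch with a hypothetical helper that makes the convention explicit and sidesteps the mask by dividing by 1 where a direction has no bins:

```python
import numpy as np


def bin_vectors(cell, bins_per_side, vectors_as="rows"):
    """Edge vectors of one bin: lattice vector i divided by bins_per_side[i].

    `vectors_as` says whether the lattice vectors are the rows or the columns of `cell`.
    Directions with 0 bins keep the full lattice vector, as in the PR's jnp.where.
    """
    n = np.asarray(bins_per_side, dtype=float)
    safe_n = np.where(n != 0, n, 1.0)  # avoid dividing by zero in unused directions
    if vectors_as == "rows":
        return cell / safe_n[:, None]  # scale row i by 1/n_i
    if vectors_as == "columns":
        return cell / safe_n[None, :]  # scale column i by 1/n_i
    raise ValueError(f"unknown convention: {vectors_as}")
```

For a diagonal cell both expressions agree, which is why the difference only shows up for non-orthorhombic cells.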

cl_allocate = cl_update = None
else:
cl_allocate = cl_update = None

def need_update_fn(neighbors, new_positions, new_cell):
Owner

I think this function needs some changes for cell_list, at least if we want to allow the cell to change in that case: If the cell gets bigger, each bin becomes larger and more "movement" can be allowed. If it becomes smaller, less is allowed, and in the extreme case no further movement is allowed and we have to reallocate. (Or fail entirely if we can't fit enough boxes in.)

@suadou (Author), Apr 17, 2024

You're right about this case, but it isn't handled by these lines of the code.

    if cell is not None and use_cell_list:
        if jnp.all(jnp.any(cutoff < get_heights(cell) / 3., axis=0)):
            cl_allocate, cl_update = cell_list(cell, cutoff)
        else:
            cl_allocate = cl_update = None
    else:
        cl_allocate = cl_update = None

This block essentially checks whether it’s possible to divide the cell into at least 3 bins along any of the cell vectors. If it’s not possible, it doesn’t make sense to use the cell list. It might be beneficial to use the cutoff plus some tolerance (it is now cutoff + skin). The cell list needs to be reallocated not when the cell gets larger, but when it gets smaller, such that the bins are no longer large enough to guarantee that every neighbor is found (i.e., the bin height in one or more directions is less than the cutoff).

    def update_fn(positions, old_cell_list, new_cell):
        # This is jittable:
        # CellList tells us all the shapes we need.
        # If the bin size is smaller than the cutoff, we need to reallocate the cell list.
        # Reallocation is not jittable, since we're changing shapes,
        # so after each update we need to check whether we need to reallocate.
        N = positions.shape[0]

        bin_size = jnp.where(old_cell_list.bins_per_side != 0, (cell.T / old_cell_list.bins_per_side).T, cell)
        max_occupancy = estimate_bin_capacity(positions, new_cell, bin_size, 1)
        # Checking if update or reallocate
        reallocate = jnp.all(get_heights(bin_size).any() >= cutoff) & (max_occupancy <= old_cell_list.capacity)

Here, the code updates bin_size for the new cell while keeping the number of bins per side fixed. It then checks the occupancy (whether any bin holds more atoms than bin_capacity) and whether the bin height is still equal to or greater than the cutoff.
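A standalone NumPy sketch of the check described above, written so that the result is True exactly when the cell list has to be reallocated. Two details in the quoted snippet may be worth double-checking: `bin_size` is computed from `cell` (the variable from the enclosing `cell_list` scope) rather than the `new_cell` argument, and `get_heights(bin_size).any()` reduces the heights to a single boolean before it is compared with `cutoff`. The sketch avoids both by working directly from the new cell's heights. The function name and arguments are hypothetical; `new_heights` is assumed to be the perpendicular heights of `new_cell`, presumably what `get_heights(new_cell)` returns.

```python
import numpy as np


def needs_reallocation(positions, new_cell, new_heights, bins_per_side, cutoff, capacity):
    """True iff the existing bin grid cannot be reused for `new_cell`.

    Assumes the lattice vectors are the rows of `new_cell` and that `new_heights`
    are its perpendicular heights (distance between opposite faces).
    """
    n = np.asarray(bins_per_side)
    # The number of bins stays fixed; each bin is the cell scaled by 1/n_i along
    # direction i, so its height is new_heights[i] / n[i].
    bins_too_small = np.any(np.asarray(new_heights) / n < cutoff)
    # Re-bin the atoms in the new cell and look at the fullest bin.
    frac = (positions @ np.linalg.inv(new_cell)) % 1.0
    idx = np.minimum(np.floor(frac * n).astype(int), n - 1)
    hashes = np.ravel_multi_index(tuple(idx.T), tuple(int(k) for k in n))
    max_occupancy = np.bincount(hashes, minlength=int(np.prod(n))).max()
    overflow = max_occupancy > capacity
    return bool(bins_too_small or overflow)
```

Since reallocation changes shapes, this flag would be evaluated after each (jitted) update and acted on outside `jit`.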

return arr

# Needed to reallocate the cell list, since calling the cell_list allocate function is not jittable
def check_reallocation_cell_list(cell_list, allocate_fn, positions):
Owner

This function seems to be unused?

Author

You're right. I'll delete it.

@sirmarcel (Owner) commented Apr 16, 2024

EDIT: Okay, sorry, I've thought about this more and I think what I said is largely wrong. Let me start from the top.

I think there are some issues we need to discuss around invalidating the neighbor_list and the cell_list if the cell has changed since they were initialised.

For the neighbour list itself, the current need_update_fn is actually not quite sufficient: it doesn't catch the (admittedly somewhat obscure) case of the cell moving without the positions changing. This test (which could be added to test_neighborlist.py, for example) currently fails:

    def test_edge_case_cell_change(self):
        # testing whether we catch *only* the cell boundary
        # changing such that a new neighbor appears
        # let | denote the boundary and c the cutoff
        #
        # |   c       c   |
        # x . . . . y . . |
        #
        # to:
        # |   c     c . |
        # x . . . . y . |
        #
        # since y hasn't moved explicitly, this is missed by a naive check


        from ase import Atoms

        cutoff = 1.0
        skin = 0.1
        L = 10.0
        epsilon = 1e-6  # distance by which we're just outside skin+cutoff initially
        total_cutoff = skin + cutoff

        atoms = Atoms(
            positions=np.array([[0, 0, 0], [L - total_cutoff - epsilon, 0, 0]]),
            cell=np.array([[L, 0, 0], [0, 100, 0], [0, 0, 100]]),
            pbc=True
        )
        system = atoms_to_system(atoms)

        allocate, update, need_update = quadratic_neighbor_list(
            system.cell,
            cutoff,
            skin,
            debug=True,
            capacity_multiplier=1.5,
            use_cell_list=True,
        )
        neighbors = allocate(system.R, new_cell=system.cell)

        # no neighbors -- just padding
        assert neighbors.centers[0] == 2
        assert len(neighbors.centers) == 1

        # now we shift the boundary left
        atoms.set_cell(np.array([[L - skin, 0, 0], [0, 100, 0], [0, 0, 100]]), scale_atoms=False)
        system = atoms_to_system(atoms)

        assert need_update(neighbors, system.R, system.cell)

I think this can be fixed by adding the norm of the total change in cell vectors to the max_movement, but I'll think about it a bit more, maybe there's something even more obscure I'm overlooking.

In addition to this, there's the question of when to recompute the CellList, which as far as I understand stores the assignment of each atom to a spatial bin. I think here, there's a bit of a subtle issue: since you have to divide the box into bins evenly, the bins are typically bigger than cutoff+skin and so any potential bugs may be masked by the additional "bonus tolerance". On the plus side, we can maybe also exploit this tolerance to reduce recomputations. I'm not yet quite sure what the "invalidation criterion" should be for the CellList.

Sorry for the wall of text, but I think we should try to get this right, or (which would also be fine) just explicitly refuse to deal with the case of a changing cell with the CellList/fall back on just re-assigning to bins if the cell changes.

@sirmarcel (Owner)

Okay, I've thought about it a bit more: We just need to figure out the worst-case movement of two atoms, which is max_movement for one atom in the cell and max_movement + np.sum(np.linalg.norm(cell.T, axis=1)) for the other, so the total is 2*max_movement + (cell stuff). If this is above skin, the Neighbors need to be recomputed.
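A NumPy sketch of this criterion (hypothetical function name). It assumes the lattice vectors are the rows of the cell, so with a column convention the norm would be taken over axis=0 instead, and that the relevant periodic images differ by at most one lattice vector per direction. Like any displacement-based check on wrapped positions, an atom that crosses the boundary shows up as a large jump and simply triggers a conservative rebuild.

```python
import numpy as np


def need_update(old_positions, new_positions, old_cell, new_cell, skin):
    """Rebuild the neighbor list if two atoms may have approached by more than `skin`.

    Assumes the lattice vectors are the rows of the cell and that the relevant
    periodic images differ by at most one lattice vector in each direction.
    """
    # Largest displacement of any single atom since the list was built.
    max_movement = np.max(np.linalg.norm(new_positions - old_positions, axis=1))
    # A periodic image moves by at most the summed change of the lattice vectors.
    cell_change = np.sum(np.linalg.norm(new_cell - old_cell, axis=1))
    # Worst case: one atom moves by max_movement, the other by max_movement + cell_change.
    return bool(2 * max_movement + cell_change > skin)
```

With an unchanged cell this reduces to the usual "max displacement above skin / 2" rule.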

For the CellList, it's exactly the same, except that the role of skin is played by min(heights / bins_per_side) - cutoff, i.e. by how much the thinnest bin exceeds the cutoff. I'm not sure this layer of caching is even needed: if sorting atoms into their bins is jittable and reasonably fast, it's probably not necessary. I think at the moment the only check is whether the bins are too small for the cutoff, which provokes something like an overflow?


Anyhow, sorry for derailing this PR with this. I would suggest we set this invalidation business aside for later. For now, what I'd like to see is

(a) test_lj passing with and without the new cell_list,
(b) test_neighborlist passing with and without,
(c) and enough high-level explanations that I can actually discuss the rest of the code!

Then, we can tackle the other stuff together!
