Improve the speed of creating TensorMap and accessing data inside · Issue #818 · metatensor/metatensor · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content
Improve the speed of creating TensorMap and accessing data inside #818
Open
@frostedoyster

Description

@frostedoyster

The attached script compares the performance of sphericart with and without its metatensor wrapper. It seems that wrapping and unwrapping the arrays is slow compared to the calculation of the spherical harmonics itself:

import numpy as np
import torch
import time

import sphericart
import metatensor
import sphericart.metatensor
import sphericart.torch
import metatensor.torch
import sphericart.torch.metatensor


def _is_on_cuda(data):
    """Tell whether ``data`` reports that it lives on a CUDA device."""
    if isinstance(data, torch.Tensor):
        return data.is_cuda
    # Inputs that are not plain tensors (e.g. a TensorMap wrapping torch
    # tensors) are not caught by the isinstance check above, so fall back to
    # a ``device`` attribute when the object has one.
    # NOTE(review): assumes ``device`` is a ``torch.device`` or a string such
    # as "cuda:0" -- confirm against the metatensor API.
    device = getattr(data, "device", None)
    if device is None:
        return False
    return str(getattr(device, "type", device)).startswith("cuda")


def benchmark_speed(calculator, xyz, n_warmup=100, n_iterations=1000):
    """Return the mean time, in seconds, of one ``compute_with_gradients`` call.

    :param calculator: object exposing ``compute_with_gradients(xyz)``
    :param xyz: input forwarded unchanged to every call
    :param n_warmup: number of untimed calls made before the measurement
    :param n_iterations: number of timed calls; must be positive
    :raises ValueError: if ``n_iterations`` is not positive
    """
    if n_iterations <= 0:
        raise ValueError("n_iterations must be a positive integer")

    need_to_sync = _is_on_cuda(xyz)

    # Untimed warm-up calls, so one-off start-up costs stay out of the timing.
    for _ in range(n_warmup):
        calculator.compute_with_gradients(xyz)

    # Synchronize on both sides of the timed region so pending GPU work is
    # neither leaked into nor left out of the measurement.
    if need_to_sync:
        torch.cuda.synchronize()
    # perf_counter is monotonic and high resolution, unlike time.time().
    start = time.perf_counter()
    for _ in range(n_iterations):
        calculator.compute_with_gradients(xyz)
    if need_to_sync:
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - start

    return elapsed / n_iterations


def get_tensormap_from_xyz(module, xyz):
    """Wrap a 2D ``xyz`` array in a single-block ``module.TensorMap``.

    :param module: namespace providing ``TensorMap``, ``TensorBlock`` and
        ``Labels`` (the script passes ``metatensor`` or ``metatensor.torch``)
    :param xyz: 2D array; its first axis becomes the "sample" labels and its
        second axis the "xyz" component labels
    :return: the ``TensorMap`` built by ``module``
    """
    labels = module.Labels
    n_samples, n_components = xyz.shape[0], xyz.shape[1]

    # A trailing axis of length one is appended to hold the single property.
    block = module.TensorBlock(
        values=xyz[:, :, None],
        samples=labels.range("sample", n_samples),
        components=[labels.range("xyz", n_components)],
        properties=labels.single(),
    )
    return module.TensorMap(keys=labels.single(), blocks=[block])

# One row of timings per (backend, device) pair that is actually run; each row
# holds one entry per l_max in 0..10.
all_timings_no_mts = []
all_timings_mts = []
for backend in ["numpy", "torch"]:
    use_numpy = backend == "numpy"
    if use_numpy:
        xyz = np.random.randn(10000, 3)
    else:
        xyz = torch.randn(10000, 3)

    for device in ["cpu", "cuda"]:
        # The numpy backend is only benchmarked on the CPU.
        if use_numpy and device == "cuda":
            continue

        mts_module = metatensor if use_numpy else metatensor.torch
        xyz_tensor_map = get_tensormap_from_xyz(mts_module, xyz)
        if device == "cuda":
            xyz = xyz.to(device)
            xyz_tensor_map = xyz_tensor_map.to(device)

        plain_row = []
        wrapped_row = []
        all_timings_no_mts.append(plain_row)
        all_timings_mts.append(wrapped_row)

        for l_max in range(11):
            # Bare calculator first, then its metatensor-wrapped counterpart.
            if use_numpy:
                calculator = sphericart.SphericalHarmonics(l_max)
            else:
                calculator = sphericart.torch.SphericalHarmonics(l_max)
            plain_time = benchmark_speed(calculator, xyz)
            plain_row.append(plain_time)

            if use_numpy:
                calculator_metatensor = sphericart.metatensor.SphericalHarmonics(l_max)
            else:
                calculator_metatensor = sphericart.torch.metatensor.SphericalHarmonics(l_max)
            wrapped_time = benchmark_speed(calculator_metatensor, xyz_tensor_map)
            wrapped_row.append(wrapped_time)

            print(f"{backend} {device} {l_max} {plain_time} {wrapped_time}")


import matplotlib.pyplot as plt

# Plot one colour per (backend, device) row: solid line for the bare
# calculator, dashed line for the metatensor wrapper.
l_values = range(11)
series = [
    ("numpy cpu", "blue"),
    ("torch cpu", "orange"),
    ("torch cuda", "green"),
]
for row, (label, color) in enumerate(series):
    plt.plot(l_values, all_timings_no_mts[row], label=label, color=color)
    plt.plot(
        l_values,
        all_timings_mts[row],
        label=f"{label} metatensor",
        color=color,
        linestyle="--",
    )
plt.yscale("log")
plt.legend()
plt.savefig("speed.pdf")

Metadata

Metadata

Assignees

No one assigned

    Labels

    Performance: Performance and optimization issues; Python-API: Related to the Python bindings to "core"; Torch-API: Related to the (Py)Torch interface; core: Related to the core implementation, in Rust

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions

      0