[Feature] Molecule Media Object3D Support by xj63 · Pull Request #920 · SwanHubX/SwanLab · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

[Feature] Molecule Media Object3D Support #920

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Apr 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion requirements-media.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
soundfile
pillow
matplotlib
numpy
numpy
rdkit
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
swankit==0.1.6
swankit==0.1.7
urllib3>=1.26.0
requests>=2.25.0
setuptools
Expand Down
1 change: 1 addition & 0 deletions swanlab/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"Audio",
"Image",
"Object3D",
"Molecule",
"Text",
"Run",
"State",
Expand Down
1 change: 1 addition & 0 deletions swanlab/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Image,
Text,
Object3D,
Molecule,
)
from .run import (
SwanLabRun as Run,
Expand Down
3 changes: 2 additions & 1 deletion swanlab/data/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .audio import Audio
from .image import Image
from .line import FloatConvertible, Line
from .object3d import Model3D, Object3D, PointCloud
from .object3d import Model3D, Object3D, PointCloud, Molecule
from .text import Text
from .wrapper import DataWrapper

Expand All @@ -26,4 +26,5 @@
"Object3D",
"PointCloud",
"Model3D",
"Molecule",
]
3 changes: 3 additions & 0 deletions swanlab/data/modules/object3d/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Object3D: Main dispatcher class for handling different types of 3D data
PointCloud: Class for handling point cloud data with XYZ, XYZC, and XYZRGB formats
Model3D: Class for handling 3D model files like GLB
Molecule: Class for handling molecule data from various formats by RDKit

Examples:
# Create point cloud from XYZ coordinates
Expand Down Expand Up @@ -33,11 +34,13 @@
"""

from .model3d import Model3D
from .molecule import Molecule
from .object3d import Object3D
from .point_cloud import PointCloud

__all__ = [
'Object3D',
'PointCloud',
'Model3D',
'Molecule',
]
186 changes: 186 additions & 0 deletions swanlab/data/modules/object3d/molecule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Optional, Tuple, Union

from swankit.core.data import DataSuite as D
from swankit.core.data import MediaBuffer, MediaType

# Attempt to import RDKit; if unavailable, set Chem and Mol to None
try:
from rdkit import Chem
from rdkit.Chem import AllChem, Mol

_has_rdkit = True
except ImportError:
Chem = None
Mol = None
AllChem = None
_has_rdkit = False


@dataclass()
class Molecule(MediaType):
    """A molecule media object, stored as PDB-format text.

    Instances can be created directly from PDB text, or through the ``from_*``
    constructors, which read PDB files or convert RDKit molecules (from a ``Mol``,
    a SMILES string, an SDF file, or a Mol file) into PDB text.

    Attributes:
        pdb_data: The molecule as a PDB block. Only its type is checked here;
            the content is not validated as PDB.
        caption: Optional descriptive text attached to the molecule.
    """

    pdb_data: str
    caption: Optional[str] = None

    def __post_init__(self):
        """Validate input data after initialization.

        Raises:
            TypeError: If ``pdb_data`` is not a ``str``.
        """
        if not isinstance(self.pdb_data, str):
            raise TypeError("pdb_data must be a string, use RDKit.Chem.MolToPDBBlock to convert.")

    @staticmethod
    def check_is_available():
        """Ensure RDKit was importable.

        Raises:
            ImportError: If RDKit could not be imported when this module was loaded.
        """
        if not _has_rdkit:
            raise ImportError("RDKit is not available. You can install it by running 'pip install rdkit'.")

    @classmethod
    def from_mol(cls, mol: Mol, *, caption: Optional[str] = None, **kwargs) -> "Molecule":
        """Create a Molecule from an RDKit ``Mol`` object.

        If ``mol`` has no conformer, 3D coordinates are generated on a copy;
        the caller's ``mol`` is not modified.

        Args:
            mol: The RDKit Mol object.
            caption: Optional descriptive text.
            **kwargs: Forwarded to the ``Molecule`` constructor.

        Returns:
            Molecule: A new Molecule instance.

        Raises:
            ImportError: If RDKit is not available.
            ValueError: If ``mol`` is None or 3D coordinates cannot be generated.

        Examples:
            >>> from rdkit import Chem
            >>> mol = Chem.MolFromSmiles("CCO")
            >>> molecule = Molecule.from_mol(mol, caption="Ethanol")
        """
        cls.check_is_available()
        if mol is None:
            # RDKit parsers (MolFromSmiles, MolFromMolFile, ...) return None on
            # failure; report that clearly instead of an AttributeError later.
            raise ValueError("mol must be an RDKit Mol object, got None.")
        pdb_block = cls._convert_to_pdb_block(mol)
        return cls(pdb_block, caption=caption, **kwargs)

    @staticmethod
    def _convert_to_pdb_block(mol: Mol) -> str:
        """Convert an RDKit molecule to a PDB string.

        When ``mol`` has no conformer, 3D coordinates are generated with
        ``AllChem.EmbedMolecule`` on a copy, so the input object is left untouched.

        Raises:
            ImportError: If RDKit is not available.
            ValueError: If coordinate generation fails.
        """
        Molecule.check_is_available()
        if mol.GetNumConformers() == 0:
            # Embed on a copy so the caller's Mol does not silently gain a conformer.
            mol = Chem.Mol(mol)
            # NOTE(review): RDKit's docs recommend adding explicit hydrogens
            # (Chem.AddHs) before embedding for better geometry; not done here,
            # which keeps the output atom set identical to the input.
            # EmbedMolecule returns the new conformer id, or -1 on failure.
            if AllChem.EmbedMolecule(mol) == -1:
                raise ValueError("Could not generate 3D coordinates for the molecule.")
        return Chem.MolToPDBBlock(mol)

    @classmethod
    def from_pdb_file(cls, pdb_file: Union[Path, str], *, caption: Optional[str] = None, **kwargs) -> "Molecule":
        """Create a Molecule by reading a PDB file's text directly.

        RDKit is not needed: the file content is used as-is, without parsing it
        into a ``Mol`` object.

        Args:
            pdb_file: Path to the PDB file.
            caption: Optional descriptive text.
            **kwargs: Forwarded to the ``Molecule`` constructor.

        Returns:
            Molecule: A new Molecule instance.

        Raises:
            FileNotFoundError: If the file does not exist.
            ValueError: If the file cannot be read or decoded.
        """
        try:
            # PDB files are plain ASCII text; pin the encoding instead of relying
            # on the platform's locale default.
            with open(pdb_file, encoding="utf-8") as f:
                pdb_data = f.read()
        except FileNotFoundError as e:
            raise FileNotFoundError(f"PDB file not found: {pdb_file}") from e
        except (OSError, UnicodeDecodeError) as e:
            raise ValueError(f"Could not read PDB file: {pdb_file}. Error: {e}") from e

        return cls(pdb_data, caption=caption, **kwargs)

    @classmethod
    def from_sdf_file(cls, sdf_file: Union[Path, str], *, caption: Optional[str] = None, **kwargs) -> "Molecule":
        """Create a Molecule from the first record of an SDF file.

        Args:
            sdf_file: Path to the SDF file.
            caption: Optional descriptive text.
            **kwargs: Forwarded to the ``Molecule`` constructor.

        Returns:
            Molecule: A new Molecule instance.

        Raises:
            ImportError: If RDKit is not available.
            ValueError: If the file has no records or its first record cannot be parsed.
        """
        cls.check_is_available()
        # NOTE(review): SDMolSupplier removes explicit hydrogens by default
        # (removeHs=True); confirm that is the desired behavior for visualization.
        supplier = Chem.SDMolSupplier(str(sdf_file))
        # Only the first record is used. An empty file yields None here rather
        # than leaking a StopIteration; an unparsable record is also None.
        mol = next(supplier, None)
        if mol is None:
            raise ValueError(f"Could not read molecule from SDF file: {sdf_file}")
        return cls.from_mol(mol, caption=caption, **kwargs)

    @classmethod
    def from_smiles(cls, smiles: str, *, caption: Optional[str] = None, **kwargs) -> "Molecule":
        """Create a Molecule from a SMILES string; 3D coordinates are generated.

        Args:
            smiles: The SMILES string.
            caption: Optional descriptive text.
            **kwargs: Forwarded to the ``Molecule`` constructor.

        Returns:
            Molecule: A new Molecule instance.

        Raises:
            ImportError: If RDKit is not available.
            ValueError: If the SMILES string cannot be parsed.
        """
        cls.check_is_available()
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            raise ValueError(f"Could not read molecule from SMILES string: {smiles}")
        return cls.from_mol(mol, caption=caption, **kwargs)

    @classmethod
    def from_mol_file(cls, mol_file: Union[Path, str], *, caption: Optional[str] = None, **kwargs) -> "Molecule":
        """Create a Molecule from a Mol file.

        Args:
            mol_file: Path to the Mol file.
            caption: Optional descriptive text.
            **kwargs: Forwarded to the ``Molecule`` constructor.

        Returns:
            Molecule: A new Molecule instance.

        Raises:
            ImportError: If RDKit is not available.
            ValueError: If the file cannot be read or parsed.
        """
        cls.check_is_available()
        # NOTE(review): MolFromMolFile also removes explicit hydrogens by default
        # (removeHs=True); confirm that is intended.
        mol = Chem.MolFromMolFile(str(mol_file))
        if mol is None:
            raise ValueError(f"Could not read molecule from Mol file: {mol_file}")
        return cls.from_mol(mol, caption=caption, **kwargs)

    # ---------------------------------- override ----------------------------------

    def parse(self) -> Tuple[str, MediaBuffer]:
        """Convert the PDB text to a buffer for transmission.

        Returns:
            Tuple containing:
                - File name of the form ``molecule-step{step}-{hash}.pdb``, where
                  ``hash`` is the first 16 characters of the content hash.
                - MediaBuffer containing the encoded PDB data.
        """
        data = self.pdb_data.encode()

        buffer = MediaBuffer()
        buffer.write(data)

        hash_name = D.get_hash_by_bytes(data)[:16]
        # ``self.step`` is presumably set by the MediaType base class - not shown here.
        save_name = f"molecule-step{self.step}-{hash_name}.pdb"

        return save_name, buffer

    def get_chart(self) -> MediaType.Chart:
        """Return the chart type used to visualize molecules."""
        return MediaType.Chart.MOLECULE

    def get_section(self) -> str:
        """Return the section name used to group molecule media."""
        return "Molecule"

    def get_more(self) -> Optional[Dict[str, str]]:
        """Return the caption as extra metadata, or None when the caption is empty."""
        return {"caption": self.caption} if self.caption else None
24 changes: 23 additions & 1 deletion swanlab/data/modules/object3d/object3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from swankit.core.data import MediaType

from .model3d import Model3D
from .molecule import Molecule
from .point_cloud import Box, PointCloud

try:
Expand All @@ -12,6 +13,11 @@
except ImportError:
np = None

try:
from rdkit.Chem import Mol
except ImportError:
Mol = None


class Object3D:
"""A dispatcher class that converts different types of 3D data to MediaType objects.
Expand Down Expand Up @@ -58,10 +64,17 @@ class Object3D:
... {"points": points_xyz, "boxes": list(Box)}
... )

5. Creating from Molecule:
>>> from rdkit import Chem
>>> mol = Chem.MolFromSmiles("CCO")
>>> molecule = Molecule.from_mol(mol, caption="Ethanol")
>>> obj8 = Object3D(molecule)

Args:
data: Input data, can be:
- numpy.ndarray: Point cloud data with shape (N, C) where C is 3,4 or 6
- str/Path: Path to a 3D file (.glb or .swanlab.pts.json)
- Molecule: A Molecule object
caption: Optional description text
**kwargs: Additional keyword arguments passed to specific handlers

Expand All @@ -77,7 +90,7 @@ class Object3D:

def __new__(
cls,
data: Union[np.ndarray, str, Path, Dict],
data: Union[np.ndarray, str, Path, Dict, Mol],
*,
caption: Optional[str] = None,
**kwargs,
Expand All @@ -92,6 +105,9 @@ def __new__(
if isinstance(data, (str, Path)):
return cls._handle_file(Path(data), **kwargs)

if isinstance(data, Mol):
return Molecule.from_mol(data, **kwargs)

return cls._handle_data(data, **kwargs)

@staticmethod
Expand All @@ -117,7 +133,13 @@ def _handle_ndarray(cls, data: np.ndarray, **kwargs) -> MediaType:

_FILE_HANDLERS: Dict[str, List[Callable]] = {
'.swanlab.pts.json': [PointCloud.from_swanlab_pts_json_file],

'.glb': [Model3D.from_glb_file],

'.sd': [Molecule.from_sdf_file],
'.sdf': [Molecule.from_sdf_file],
'.mol': [Molecule.from_mol_file],
'.pdb': [Molecule.from_pdb_file],
}

@classmethod
Expand Down
25 changes: 25 additions & 0 deletions test/metrics/assets/molecule.example.pdb
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
HETATM 1 C1 UNL 1 2.860 -0.100 -0.085 1.00 0.00 C
HETATM 2 C2 UNL 1 1.399 -0.404 0.007 1.00 0.00 C
HETATM 3 C3 UNL 1 0.446 0.316 -0.606 1.00 0.00 C
HETATM 4 C4 UNL 1 -1.049 0.080 -0.575 1.00 0.00 C
HETATM 5 C5 UNL 1 -1.746 1.288 0.055 1.00 0.00 C
HETATM 6 C6 UNL 1 -1.498 -1.204 0.129 1.00 0.00 C
HETATM 7 H1 UNL 1 3.392 -0.955 -0.513 1.00 0.00 H
HETATM 8 H2 UNL 1 3.063 0.776 -0.709 1.00 0.00 H
HETATM 9 H3 UNL 1 3.264 0.092 0.914 1.00 0.00 H
HETATM 10 H4 UNL 1 1.138 -1.265 0.618 1.00 0.00 H
HETATM 11 H5 UNL 1 0.756 1.170 -1.210 1.00 0.00 H
HETATM 12 H6 UNL 1 -1.379 0.009 -1.619 1.00 0.00 H
HETATM 13 H7 UNL 1 -1.493 2.213 -0.477 1.00 0.00 H
HETATM 14 H8 UNL 1 -2.834 1.173 0.017 1.00 0.00 H
HETATM 15 H9 UNL 1 -1.456 1.414 1.104 1.00 0.00 H
HETATM 16 H10 UNL 1 -2.584 -1.324 0.051 1.00 0.00 H
HETATM 17 H11 UNL 1 -1.037 -2.086 -0.328 1.00 0.00 H
HETATM 18 H12 UNL 1 -1.242 -1.193 1.194 1.00 0.00 H
CONECT 1 2 7 8 9
CONECT 2 3 3 10
CONECT 3 4 11
CONECT 4 5 6 12
CONECT 5 13 14 15
CONECT 6 16 17 18
END
30 changes: 30 additions & 0 deletions test/metrics/molecule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""
@author: xj63
@file: molecule.py
@time: 2025/3/13 13:30
@description: 测试上传分子对象

NOTE 你需要下载下面的文件放到当前文件目录的assets文件夹下,才能运行这个测试
- 3D分子pdb文件: 此文件存放于 https://github.com/SwanHubX/SwanLab/pull/477 ,可通过 https://github.com/user-attachments/files/19605387/molecule.example.pdb.zip 下载
"""

# noinspection PyPackageRequirements
from rdkit import Chem

import swanlab

swanlab.init(project="molecule", public=True)

# from rdkit.Chem.Mol
chem = Chem.MolFromSmiles("CCO")
cco = swanlab.Molecule.from_mol(chem, caption="cco")

# from file path
file = swanlab.Molecule.from_pdb_file("./assets/molecule.example.pdb", caption="file")

# from pdb data
with open("./assets/molecule.example.pdb") as f:
data = swanlab.Molecule(f.read(), caption="data")

# upload
swanlab.log({"file": file, "data": data, "example": cco})
Loading
0