diff --git a/.gitignore b/.gitignore index b440a5e..d71eaad 100755 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ tmp/ **/checkpoints/ *.png **/.DS_Store +*.egg-info \ No newline at end of file diff --git a/molscribe/augment.py b/molscribe/augment.py index a80ebc5..22c92b2 100755 --- a/molscribe/augment.py +++ b/molscribe/augment.py @@ -1,12 +1,11 @@ import albumentations as A from albumentations.augmentations.geometric.functional import safe_rotate_enlarged_img_size, _maybe_process_in_chunks, \ - keypoint_rotate + keypoint_rotate import cv2 import math import random import numpy as np - def safe_rotate( img: np.ndarray, angle: int = 0, diff --git a/molscribe/chemistry.py b/molscribe/chemistry.py index 5c541b8..602c5aa 100644 --- a/molscribe/chemistry.py +++ b/molscribe/chemistry.py @@ -4,10 +4,11 @@ import multiprocessing import itertools -import rdkit +import rdkit.RDLogger as RDLogger import rdkit.Chem as Chem +import rdkit.Chem.AllChem as AllChem -rdkit.RDLogger.DisableLog('rdApp.*') +RDLogger.DisableLog('rdApp.*') from SmilesPE.pretokenizer import atomwise_tokenizer @@ -366,7 +367,11 @@ def get_smiles_from_symbol(symbol, mol, atom, bonds): total_bonds = int(sum([bond.GetBondTypeAsDouble() for bond in bonds])) formula_list = _expand_carbon(_parse_formula(symbol)) - smiles, bonds_left, num_trails, success = _condensed_formula_list_to_smiles(formula_list, total_bonds, None) + # smiles, bonds_left, num_trails, success = _condensed_formula_list_to_smiles(formula_list, total_bonds, None) + if len(bonds) != 2: + smiles, bonds_left, num_trails, success = _condensed_formula_list_to_smiles(formula_list, total_bonds, None) + else: + smiles, bonds_left, num_trails, success = _condensed_formula_list_to_smiles(formula_list, 1, None) if success: return smiles return None @@ -448,39 +453,61 @@ def _need_expand(mol, mappings): atom.SetIsotope(0) continue - # remove bonds connected to abbreviation/condensed formula - adjacent_indices = [bond.GetOtherAtomIdx(i) for bond in bonds] - for adjacent_idx in adjacent_indices: - mol_w.RemoveBond(i, adjacent_idx) - - adjacent_atoms = [mol_w.GetAtomWithIdx(adjacent_idx) for adjacent_idx in adjacent_indices] - for adjacent_atom, bond in zip(adjacent_atoms, bonds): - adjacent_atom.SetNumRadicalElectrons(int(bond.GetBondTypeAsDouble())) - - # get indices of atoms of main body that connect to substituent - bonding_atoms_w = adjacent_indices - # assume indices are concated after combine mol_w and mol_r - bonding_atoms_r = [mol_w.GetNumAtoms()] - for atm in mol_r.GetAtoms(): - if atm.GetNumRadicalElectrons() and atm.GetIdx() > 0: - bonding_atoms_r.append(mol_w.GetNumAtoms() + atm.GetIdx()) - - # combine main body and substituent into a single molecule object - combo = Chem.CombineMols(mol_w, mol_r) - - # connect substituent to main body with bonds - mol_w = Chem.RWMol(combo) - # if len(bonding_atoms_r) == 1: # substituent uses one atom to bond to main body - for atm in bonding_atoms_w: - bond_order = mol_w.GetAtomWithIdx(atm).GetNumRadicalElectrons() - mol_w.AddBond(atm, bonding_atoms_r[0], order=BOND_TYPES[bond_order]) - - # reset radical electrons - for atm in bonding_atoms_w: - mol_w.GetAtomWithIdx(atm).SetNumRadicalElectrons(0) - for atm in bonding_atoms_r: - mol_w.GetAtomWithIdx(atm).SetNumRadicalElectrons(0) - atoms_to_remove.append(i) + if "(" in symbol and len(bonds) == 2: + # Get connection information for the current atom + connected_info = [(neighbor.GetIdx(), mol_w.GetBondBetweenAtoms(i, neighbor.GetIdx()).GetBondType()) + for neighbor in atom.GetNeighbors()] + + # Create a new molecule by combining the current molecule with the expanded fragment + combined = Chem.RWMol(AllChem.CombineMols(mol_w, mol_r)) + + # Connect the first atom of the expanded fragment to the first neighbor + combined.AddBond(connected_info[0][0], mol_w.GetNumAtoms(), connected_info[0][1]) + + # Connect the last atom of the expanded fragment to the second neighbor + combined.AddBond(connected_info[1][0], mol_w.GetNumAtoms() + mol_r.GetNumAtoms() - 1, + connected_info[1][1]) + + # Update the working molecule + mol_w = combined + + # Mark the original atom for removal + atoms_to_remove.append(i) + + else: + # remove bonds connected to abbreviation/condensed formula + adjacent_indices = [bond.GetOtherAtomIdx(i) for bond in bonds] + for adjacent_idx in adjacent_indices: + mol_w.RemoveBond(i, adjacent_idx) + + adjacent_atoms = [mol_w.GetAtomWithIdx(adjacent_idx) for adjacent_idx in adjacent_indices] + for adjacent_atom, bond in zip(adjacent_atoms, bonds): + adjacent_atom.SetNumRadicalElectrons(int(bond.GetBondTypeAsDouble())) + + # get indices of atoms of main body that connect to substituent + bonding_atoms_w = adjacent_indices + # assume indices are concated after combine mol_w and mol_r + bonding_atoms_r = [mol_w.GetNumAtoms()] + for atm in mol_r.GetAtoms(): + if atm.GetNumRadicalElectrons() and atm.GetIdx() > 0: + bonding_atoms_r.append(mol_w.GetNumAtoms() + atm.GetIdx()) + + # combine main body and substituent into a single molecule object + combo = Chem.CombineMols(mol_w, mol_r) + + # connect substituent to main body with bonds + mol_w = Chem.RWMol(combo) + # if len(bonding_atoms_r) == 1: # substituent uses one atom to bond to main body + for atm in bonding_atoms_w: + bond_order = mol_w.GetAtomWithIdx(atm).GetNumRadicalElectrons() + mol_w.AddBond(atm, bonding_atoms_r[0], order=BOND_TYPES[bond_order]) + + # reset radical electrons + for atm in bonding_atoms_w: + mol_w.GetAtomWithIdx(atm).SetNumRadicalElectrons(0) + for atm in bonding_atoms_r: + mol_w.GetAtomWithIdx(atm).SetNumRadicalElectrons(0) + atoms_to_remove.append(i) # Remove atom in the end, otherwise the id will change # Reverse the order and remove atoms with larger id first @@ -551,10 +578,12 @@ def _convert_graph_to_smiles(coords, symbols, edges, image=None, debug=False): ratio = width / height coords = [[x * ratio * 10, y * 10] for x, y in coords] mol = _verify_chirality(mol, coords, symbols, edges, debug) - # molblock is obtained before expanding func groups, otherwise the expanded group won't have coordinates. - # TODO: make sure molblock has the abbreviation information - pred_molblock = Chem.MolToMolBlock(mol) - pred_smiles, mol = _expand_functional_group(mol, {}, debug) + + # First expand functional groups + pred_smiles, expanded_mol = _expand_functional_group(mol, {}, debug) + + # Generate molblock from the expanded molecule + pred_molblock = Chem.MolToMolBlock(expanded_mol) success = True except Exception as e: if debug: diff --git a/molscribe/constants.py b/molscribe/constants.py index 6e470c8..d0cde7b 100644 --- a/molscribe/constants.py +++ b/molscribe/constants.py @@ -4,7 +4,8 @@ ORGANIC_SET = {'B', 'C', 'N', 'O', 'P', 'S', 'F', 'Cl', 'Br', 'I'} RGROUP_SYMBOLS = ['R', 'R1', 'R2', 'R3', 'R4', 'R5', 'R6', 'R7', 'R8', 'R9', 'R10', 'R11', 'R12', - 'Ra', 'Rb', 'Rc', 'Rd', 'X', 'Y', 'Z', 'Q', 'A', 'E', 'Ar'] + 'Ra', 'Rb', 'Rc', 'Rd', 'Rf', 'X', 'Y', 'Z', 'Q', 'A', 'E', 'Ar', 'Ar1', 'Ar2', 'Ari', 'Ar3', 'Ar4','Ar5','Ar6','Ar7',"R'", + '1*', '2*','3*', '4*','5*', '6*','7*', '8*','9*', '10*','11*', '12*','[a*]', '[b*]','[c*]', '[d*]',"EWG",'Nu'] PLACEHOLDER_ATOMS = ["Lv", "Lu", "Nd", "Yb", "At", "Fm", "Er"] @@ -21,8 +22,9 @@ def __init__(self, abbrvs, smarts, smiles, probability): SUBSTITUTIONS: List[Substitution] = [ Substitution(['NO2', 'O2N'], '[N+](=O)[O-]', "[N+](=O)[O-]", 0.5), + Substitution(['OCOCH3'], '[#8]-[#6](=[#8])-[#6]', "[O]C(=O)C]", 0.5), Substitution(['CHO', 'OHC'], '[CH1](=O)', "[CH1](=O)", 0.5), - Substitution(['CO2Et', 'COOEt'], 'C(=O)[OH0;D2][CH2;D2][CH3]', "[C](=O)OCC", 0.5), + Substitution(['CO2Et', 'COOEt', 'EtO2C'], 'C(=O)[OH0;D2][CH2;D2][CH3]', "[C](=O)OCC", 0.5), Substitution(['OAc'], '[OH0;X2]C(=O)[CH3]', "[O]C(=O)C", 0.7), Substitution(['NHAc'], '[NH1;D2]C(=O)[CH3]', "[NH]C(=O)C", 0.7), diff --git a/molscribe/interface.py b/molscribe/interface.py index 54a3a2e..055f862 100644 --- a/molscribe/interface.py +++ b/molscribe/interface.py @@ -1,4 +1,5 @@ import argparse +import os from typing import List import cv2 @@ -30,7 +31,7 @@ def __init__(self, model_path, device=None, num_workers=1): MolScribe Interface :param model_path: path of the model checkpoint. :param device: torch device, defaults to be CPU. - :param multiprocessing_enabled: uses multiprocessing to parallelize parts of the inference when enabled, defaults to False. + :param num_workers: number of workers for parallel processing, defaults to 1. """ model_states = torch.load(model_path, map_location=torch.device('cpu')) args = self._get_args(model_states['args']) @@ -41,6 +42,9 @@ def __init__(self, model_path, device=None, num_workers=1): self.encoder, self.decoder = self._get_model(args, self.tokenizer, self.device, model_states) self.transform = get_transforms(args.input_size, augment=False) self.num_workers = num_workers + # MPS-specific optimizations + self.is_mps = str(device).startswith('mps') + self.optimal_batch_size = 32 if self.is_mps else 4 def _get_args(self, args_states=None): parser = argparse.ArgumentParser() @@ -90,7 +94,13 @@ def _get_model(self, args, tokenizer, device, states): decoder.eval() return encoder, decoder - def predict_images(self, input_images: List, return_atoms_bonds=False, return_confidence=False, batch_size=16): + def predict_images(self, input_images: List, return_atoms_bonds=False, return_confidence=False, batch_size=None): + """ + Optimized version of predict_images with better MPS performance + """ + if batch_size is None: + batch_size = self.optimal_batch_size + device = self.device predictions = [] self.decoder.compute_confidence = return_confidence @@ -104,7 +114,9 @@ def predict_images(self, input_images: List, return_atoms_bonds=False, return_co batch_predictions = self.decoder.decode(features, hiddens) predictions += batch_predictions - smiles = [pred['chartok_coords']['smiles'] for pred in predictions] + return self.convert_graph_to_output(predictions, input_images, return_confidence, return_atoms_bonds) + + def convert_graph_to_output(self, predictions, input_images, return_confidence=False, return_atoms_bonds=False): node_coords = [pred['chartok_coords']['coords'] for pred in predictions] node_symbols = [pred['chartok_coords']['symbols'] for pred in predictions] edges = [pred['edges'] for pred in predictions] diff --git a/requirements.txt b/requirements.txt index fdb0fe1..d34d695 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,14 +1,15 @@ -torch +torch>=2.5.0,<=2.7.1 torchvision numpy>=1.19.5,<2.0 -pandas>=1.2.4 +pandas>=2.2.3 matplotlib>=3.5.3 -opencv-python==4.5.5.64 -transformers>=4.5.1 +opencv-python>=4.10.0.84 huggingface-hub>=0.11.0 tensorboardX SmilesPE==0.0.3 OpenNMT-py==2.2.0 -rdkit>=2022.3.3 -albumentations @ git+https://github.com/albumentations-team/albumentations@37e714fd2e326f6f88778e425f98c2de8c8d5372 -timm @ git+https://github.com/rwightman/pytorch-image-models.git@54a6cca27a9a3e092a07457f5d56709da56e3cf5 \ No newline at end of file +rdkit~=2025.3.6 +albumentations==1.1.0 +timm>=0.4.12,<=0.5.4 +# Only this version of transformers are compatible with timm +transformers>=4.47.0,<=4.52.4 \ No newline at end of file diff --git a/setup.py b/setup.py index b054c6a..daa675f 100644 --- a/setup.py +++ b/setup.py @@ -1,8 +1,12 @@ from distutils.core import setup -from pathlib import Path + +def read_requirements(): + """Read the requirements.txt file and return a list of dependencies.""" + with open("requirements.txt", "r", encoding="utf-8") as fh: + return fh.read().splitlines() setup(name='MolScribe', - version='1.1.1', + version='1.2.1', description='MolScribe', author='Yujie Qian', author_email='yujieq@csail.mit.edu', @@ -12,15 +16,5 @@ package_data={'molscribe': ['vocab/*']}, python_requires='>=3.7', setup_requires=['numpy'], - install_requires=[ - "numpy>=1.19.5,<2.0", - "torch>=1.11.0", - "pandas", - "matplotlib", - "opencv-python>=4.5.5.64", - "SmilesPE==0.0.3", - "OpenNMT-py==2.2.0", - "rdkit>=2022.3.3", - "albumentations==1.1.0", - "timm==0.4.12" - ]) + install_requires=read_requirements(), + )