Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
c50ac1f
chore: implemented dynamic adjustment of num_workers. updated gitignore
alexey-krasnov Oct 14, 2024
266d31c
chore: bumped dependency of rdkit to 2024.3.5
alexey-krasnov Oct 25, 2024
d8346d0
chore: updated verson rdkit in requirements.txt, set num_workers equa…
alexey-krasnov Oct 30, 2024
2c0aa8d
fix: fixed version of transformers>=4.5.1,<=4.47.0
alexey-krasnov Mar 12, 2025
b77f086
fix: fixed versions of opencv-python>=4.10.0.84 and numpy>=1.19.5,<2.…
alexey-krasnov Mar 12, 2025
fcd40c7
chore: merged with remote version
alexey-krasnov Mar 12, 2025
7f457c6
Merge pull request #3 from alexey-krasnov/build
alexey-krasnov Mar 12, 2025
02f8014
feat: updated requirements. Tested updated version of dependencies. T…
alexey-krasnov Mar 17, 2025
bcae485
feat: updated requirements. Tested updated version of dependencies. T…
alexey-krasnov Mar 17, 2025
77fe487
feat: updated requirements. Tested updated versions of dependencies.
alexey-krasnov Mar 17, 2025
1fabf64
fix: fixed requirements for transformers>=4.5.1,<=4.47.0
alexey-krasnov Mar 17, 2025
8f8fc06
fix: fixed requirements for transformers>=4.5.1,<=4.47.0
alexey-krasnov Mar 17, 2025
25fb1a9
style: removed redundand commented code
alexey-krasnov Mar 17, 2025
f69a5e4
fix: fixed version of albumentations, used proper import of class fro…
alexey-krasnov Mar 18, 2025
94b07c9
chore: implemented batch_size-32 for Apple MPS. Adjusted and tested t…
alexey-krasnov Jun 20, 2025
31e4dd1
chore: bumped version to 1.2.0. Implemented fix a bug about case like…
alexey-krasnov Jul 15, 2025
65e9218
updated and tested versions of transformers>=4.47.0,<=4.52.4
alexey-krasnov Jul 15, 2025
f3b9825
chore: made changed according to CrystalEye42 fork
alexey-krasnov Feb 3, 2026
89f2fc8
chore: added R-groups from MolNexTR
alexey-krasnov Feb 4, 2026
2e98d43
chore: bumped version from 1.2.0 to 1.2.1
alexey-krasnov Feb 4, 2026
c8d33bb
chore: Removed dynamically adjusted num_workers based on the number o…
alexey-krasnov Feb 18, 2026
be9cf7e
chore: Removed dynamically adjusted num_workers based on the number o…
alexey-krasnov Feb 18, 2026
0ce05a2
chore: Removed dynamically adjusted num_workers based on the number o…
alexey-krasnov Feb 18, 2026
1c6278f
chore: Removed unnessesary AllChem.Compute2DCoords(mol) from _expand_…
alexey-krasnov Feb 18, 2026
a5dcd5b
chore: Removed unnessesary AllChem.Compute2DCoords(mol) from _expand_…
alexey-krasnov Feb 18, 2026
c7b5010
fix: added accidentally removed self.optimal_batch_size param
alexey-krasnov Feb 18, 2026
564973d
fix: added param num_workers in docstring of MolScribe class
alexey-krasnov Feb 18, 2026
784450e
fix: added param num_workers in docstring of MolScribe class
alexey-krasnov Feb 18, 2026
568b285
Merge pull request #4 from alexey-krasnov/build
alexey-krasnov Apr 27, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ tmp/
**/checkpoints/
*.png
**/.DS_Store
*.egg-info
3 changes: 1 addition & 2 deletions molscribe/augment.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import albumentations as A
from albumentations.augmentations.geometric.functional import safe_rotate_enlarged_img_size, _maybe_process_in_chunks, \
keypoint_rotate
keypoint_rotate
import cv2
import math
import random
import numpy as np


def safe_rotate(
img: np.ndarray,
angle: int = 0,
Expand Down
109 changes: 69 additions & 40 deletions molscribe/chemistry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
import multiprocessing
import itertools

import rdkit
import rdkit.RDLogger as RDLogger
import rdkit.Chem as Chem
import rdkit.Chem.AllChem as AllChem

rdkit.RDLogger.DisableLog('rdApp.*')
RDLogger.DisableLog('rdApp.*')

from SmilesPE.pretokenizer import atomwise_tokenizer

Expand Down Expand Up @@ -366,7 +367,11 @@ def get_smiles_from_symbol(symbol, mol, atom, bonds):

total_bonds = int(sum([bond.GetBondTypeAsDouble() for bond in bonds]))
formula_list = _expand_carbon(_parse_formula(symbol))
smiles, bonds_left, num_trails, success = _condensed_formula_list_to_smiles(formula_list, total_bonds, None)
# smiles, bonds_left, num_trails, success = _condensed_formula_list_to_smiles(formula_list, total_bonds, None)
if len(bonds) != 2:
smiles, bonds_left, num_trails, success = _condensed_formula_list_to_smiles(formula_list, total_bonds, None)
else:
smiles, bonds_left, num_trails, success = _condensed_formula_list_to_smiles(formula_list, 1, None)
if success:
return smiles
return None
Expand Down Expand Up @@ -448,39 +453,61 @@ def _need_expand(mol, mappings):
atom.SetIsotope(0)
continue

# remove bonds connected to abbreviation/condensed formula
adjacent_indices = [bond.GetOtherAtomIdx(i) for bond in bonds]
for adjacent_idx in adjacent_indices:
mol_w.RemoveBond(i, adjacent_idx)

adjacent_atoms = [mol_w.GetAtomWithIdx(adjacent_idx) for adjacent_idx in adjacent_indices]
for adjacent_atom, bond in zip(adjacent_atoms, bonds):
adjacent_atom.SetNumRadicalElectrons(int(bond.GetBondTypeAsDouble()))

# get indices of atoms of main body that connect to substituent
bonding_atoms_w = adjacent_indices
# assume indices are concated after combine mol_w and mol_r
bonding_atoms_r = [mol_w.GetNumAtoms()]
for atm in mol_r.GetAtoms():
if atm.GetNumRadicalElectrons() and atm.GetIdx() > 0:
bonding_atoms_r.append(mol_w.GetNumAtoms() + atm.GetIdx())

# combine main body and substituent into a single molecule object
combo = Chem.CombineMols(mol_w, mol_r)

# connect substituent to main body with bonds
mol_w = Chem.RWMol(combo)
# if len(bonding_atoms_r) == 1: # substituent uses one atom to bond to main body
for atm in bonding_atoms_w:
bond_order = mol_w.GetAtomWithIdx(atm).GetNumRadicalElectrons()
mol_w.AddBond(atm, bonding_atoms_r[0], order=BOND_TYPES[bond_order])

# reset radical electrons
for atm in bonding_atoms_w:
mol_w.GetAtomWithIdx(atm).SetNumRadicalElectrons(0)
for atm in bonding_atoms_r:
mol_w.GetAtomWithIdx(atm).SetNumRadicalElectrons(0)
atoms_to_remove.append(i)
if "(" in symbol and len(bonds) == 2:
# Get connection information for the current atom
connected_info = [(neighbor.GetIdx(), mol_w.GetBondBetweenAtoms(i, neighbor.GetIdx()).GetBondType())
for neighbor in atom.GetNeighbors()]

# Create a new molecule by combining the current molecule with the expanded fragment
combined = Chem.RWMol(AllChem.CombineMols(mol_w, mol_r))

# Connect the first atom of the expanded fragment to the first neighbor
combined.AddBond(connected_info[0][0], mol_w.GetNumAtoms(), connected_info[0][1])

# Connect the last atom of the expanded fragment to the second neighbor
combined.AddBond(connected_info[1][0], mol_w.GetNumAtoms() + mol_r.GetNumAtoms() - 1,
connected_info[1][1])

# Update the working molecule
mol_w = combined

# Mark the original atom for removal
atoms_to_remove.append(i)

else:
# remove bonds connected to abbreviation/condensed formula
adjacent_indices = [bond.GetOtherAtomIdx(i) for bond in bonds]
for adjacent_idx in adjacent_indices:
mol_w.RemoveBond(i, adjacent_idx)

adjacent_atoms = [mol_w.GetAtomWithIdx(adjacent_idx) for adjacent_idx in adjacent_indices]
for adjacent_atom, bond in zip(adjacent_atoms, bonds):
adjacent_atom.SetNumRadicalElectrons(int(bond.GetBondTypeAsDouble()))

# get indices of atoms of main body that connect to substituent
bonding_atoms_w = adjacent_indices
# assume indices are concated after combine mol_w and mol_r
bonding_atoms_r = [mol_w.GetNumAtoms()]
for atm in mol_r.GetAtoms():
if atm.GetNumRadicalElectrons() and atm.GetIdx() > 0:
bonding_atoms_r.append(mol_w.GetNumAtoms() + atm.GetIdx())

# combine main body and substituent into a single molecule object
combo = Chem.CombineMols(mol_w, mol_r)

# connect substituent to main body with bonds
mol_w = Chem.RWMol(combo)
# if len(bonding_atoms_r) == 1: # substituent uses one atom to bond to main body
for atm in bonding_atoms_w:
bond_order = mol_w.GetAtomWithIdx(atm).GetNumRadicalElectrons()
mol_w.AddBond(atm, bonding_atoms_r[0], order=BOND_TYPES[bond_order])

# reset radical electrons
for atm in bonding_atoms_w:
mol_w.GetAtomWithIdx(atm).SetNumRadicalElectrons(0)
for atm in bonding_atoms_r:
mol_w.GetAtomWithIdx(atm).SetNumRadicalElectrons(0)
atoms_to_remove.append(i)

# Remove atom in the end, otherwise the id will change
# Reverse the order and remove atoms with larger id first
Expand Down Expand Up @@ -551,10 +578,12 @@ def _convert_graph_to_smiles(coords, symbols, edges, image=None, debug=False):
ratio = width / height
coords = [[x * ratio * 10, y * 10] for x, y in coords]
mol = _verify_chirality(mol, coords, symbols, edges, debug)
# molblock is obtained before expanding func groups, otherwise the expanded group won't have coordinates.
# TODO: make sure molblock has the abbreviation information
pred_molblock = Chem.MolToMolBlock(mol)
pred_smiles, mol = _expand_functional_group(mol, {}, debug)

# First expand functional groups
pred_smiles, expanded_mol = _expand_functional_group(mol, {}, debug)

# Generate molblock from the expanded molecule
pred_molblock = Chem.MolToMolBlock(expanded_mol)
success = True
except Exception as e:
if debug:
Expand Down
6 changes: 4 additions & 2 deletions molscribe/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
ORGANIC_SET = {'B', 'C', 'N', 'O', 'P', 'S', 'F', 'Cl', 'Br', 'I'}

RGROUP_SYMBOLS = ['R', 'R1', 'R2', 'R3', 'R4', 'R5', 'R6', 'R7', 'R8', 'R9', 'R10', 'R11', 'R12',
'Ra', 'Rb', 'Rc', 'Rd', 'X', 'Y', 'Z', 'Q', 'A', 'E', 'Ar']
'Ra', 'Rb', 'Rc', 'Rd', 'Rf', 'X', 'Y', 'Z', 'Q', 'A', 'E', 'Ar', 'Ar1', 'Ar2', 'Ari', 'Ar3', 'Ar4','Ar5','Ar6','Ar7',"R'",
'1*', '2*','3*', '4*','5*', '6*','7*', '8*','9*', '10*','11*', '12*','[a*]', '[b*]','[c*]', '[d*]',"EWG",'Nu']

PLACEHOLDER_ATOMS = ["Lv", "Lu", "Nd", "Yb", "At", "Fm", "Er"]

Expand All @@ -21,8 +22,9 @@ def __init__(self, abbrvs, smarts, smiles, probability):

SUBSTITUTIONS: List[Substitution] = [
Substitution(['NO2', 'O2N'], '[N+](=O)[O-]', "[N+](=O)[O-]", 0.5),
Substitution(['OCOCH3'], '[#8]-[#6](=[#8])-[#6]', "[O]C(=O)C]", 0.5),
Substitution(['CHO', 'OHC'], '[CH1](=O)', "[CH1](=O)", 0.5),
Substitution(['CO2Et', 'COOEt'], 'C(=O)[OH0;D2][CH2;D2][CH3]', "[C](=O)OCC", 0.5),
Substitution(['CO2Et', 'COOEt', 'EtO2C'], 'C(=O)[OH0;D2][CH2;D2][CH3]', "[C](=O)OCC", 0.5),

Substitution(['OAc'], '[OH0;X2]C(=O)[CH3]', "[O]C(=O)C", 0.7),
Substitution(['NHAc'], '[NH1;D2]C(=O)[CH3]', "[NH]C(=O)C", 0.7),
Expand Down
18 changes: 15 additions & 3 deletions molscribe/interface.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import os
from typing import List

import cv2
Expand Down Expand Up @@ -30,7 +31,7 @@ def __init__(self, model_path, device=None, num_workers=1):
MolScribe Interface
:param model_path: path of the model checkpoint.
:param device: torch device, defaults to be CPU.
:param multiprocessing_enabled: uses multiprocessing to parallelize parts of the inference when enabled, defaults to False.
:param num_workers: number of workers for parallel processing, defaults to 1.
"""
model_states = torch.load(model_path, map_location=torch.device('cpu'))
args = self._get_args(model_states['args'])
Expand All @@ -41,6 +42,9 @@ def __init__(self, model_path, device=None, num_workers=1):
self.encoder, self.decoder = self._get_model(args, self.tokenizer, self.device, model_states)
self.transform = get_transforms(args.input_size, augment=False)
self.num_workers = num_workers
# MPS-specific optimizations
self.is_mps = str(device).startswith('mps')
self.optimal_batch_size = 32 if self.is_mps else 4

def _get_args(self, args_states=None):
parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -90,7 +94,13 @@ def _get_model(self, args, tokenizer, device, states):
decoder.eval()
return encoder, decoder

def predict_images(self, input_images: List, return_atoms_bonds=False, return_confidence=False, batch_size=16):
def predict_images(self, input_images: List, return_atoms_bonds=False, return_confidence=False, batch_size=None):
"""
Optimized version of predict_images with better MPS performance
"""
if batch_size is None:
batch_size = self.optimal_batch_size

device = self.device
predictions = []
self.decoder.compute_confidence = return_confidence
Expand All @@ -104,7 +114,9 @@ def predict_images(self, input_images: List, return_atoms_bonds=False, return_co
batch_predictions = self.decoder.decode(features, hiddens)
predictions += batch_predictions

smiles = [pred['chartok_coords']['smiles'] for pred in predictions]
return self.convert_graph_to_output(predictions, input_images, return_confidence, return_atoms_bonds)

def convert_graph_to_output(self, predictions, input_images, return_confidence=False, return_atoms_bonds=False):
node_coords = [pred['chartok_coords']['coords'] for pred in predictions]
node_symbols = [pred['chartok_coords']['symbols'] for pred in predictions]
edges = [pred['edges'] for pred in predictions]
Expand Down
15 changes: 8 additions & 7 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
torch
torch>=2.5.0,<=2.7.1
torchvision
numpy>=1.19.5,<2.0
pandas>=1.2.4
pandas>=2.2.3
matplotlib>=3.5.3
opencv-python==4.5.5.64
transformers>=4.5.1
opencv-python>=4.10.0.84
huggingface-hub>=0.11.0
tensorboardX
SmilesPE==0.0.3
OpenNMT-py==2.2.0
rdkit>=2022.3.3
albumentations @ git+https://github.com/albumentations-team/albumentations@37e714fd2e326f6f88778e425f98c2de8c8d5372
timm @ git+https://github.com/rwightman/pytorch-image-models.git@54a6cca27a9a3e092a07457f5d56709da56e3cf5
rdkit~=2025.3.6
albumentations==1.1.0
timm>=0.4.12,<=0.5.4
# Only this version of transformers are compatible with timm
transformers>=4.47.0,<=4.52.4
22 changes: 8 additions & 14 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from distutils.core import setup
from pathlib import Path

def read_requirements():
"""Read the requirements.txt file and return a list of dependencies."""
with open("requirements.txt", "r", encoding="utf-8") as fh:
return fh.read().splitlines()

setup(name='MolScribe',
version='1.1.1',
version='1.2.1',
description='MolScribe',
author='Yujie Qian',
author_email='yujieq@csail.mit.edu',
Expand All @@ -12,15 +16,5 @@
package_data={'molscribe': ['vocab/*']},
python_requires='>=3.7',
setup_requires=['numpy'],
install_requires=[
"numpy>=1.19.5,<2.0",
"torch>=1.11.0",
"pandas",
"matplotlib",
"opencv-python>=4.5.5.64",
"SmilesPE==0.0.3",
"OpenNMT-py==2.2.0",
"rdkit>=2022.3.3",
"albumentations==1.1.0",
"timm==0.4.12"
])
install_requires=read_requirements(),
)