| """
|
| Copyright (c) All Rights Reserved
|
| by bowen
|
| """
|
|
|
| import json
|
| import math
|
| import os
|
| import sys
|
| import pathlib
|
| from typing import Iterable, List
|
| import random
|
| import itertools
|
|
|
| import numpy as np
|
| import pandas as pd
|
| import tqdm
|
| import torch
|
| import torch.amp
|
| from PIL import Image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| from skimage.measure import label
|
|
|
| import rdkit
|
| from rdkit import Chem
|
| from rdkit.Chem import Draw, AllChem
|
| from rdkit.Chem import rdchem, RWMol, CombineMols
|
| from rdkit import Chem
|
| from rdkit.Chem import rdFMCS
|
| import copy
|
| from paddleocr import PaddleOCR
|
| import re
|
| from rdkit import DataStructs
|
| import matplotlib.pyplot as plt
|
| from matplotlib.patches import Rectangle, Circle
|
| from scipy.spatial import cKDTree, KDTree
|
| from rdkit.Geometry import Point3D
|
| import multiprocessing
|
|
|
|
|
|
|
def select_longest_smiles(smiles):
    """Return the longest '.'-separated component of a SMILES string.

    Length is measured in characters. When several components share the
    maximum length, the one that appears first is returned.
    """
    fragments = smiles.split('.')
    best = fragments[0]
    for fragment in fragments[1:]:
        # Strictly greater, so the earliest longest fragment is kept.
        if len(fragment) > len(best):
            best = fragment
    return best
|
|
|
def MCS_mol(mcs):
    """Build a query molecule with 2D coordinates from an MCS result.

    `mcs` must expose a `smartsString` attribute; it is parsed as SMARTS
    and 2D coordinates are computed on the resulting molecule in place.
    """
    query_mol = Chem.MolFromSmarts(mcs.smartsString)
    AllChem.Compute2DCoords(query_mol)
    return query_mol
|
|
|
def g_atompair_matches(pair, mcs):
    """Pair up atom indices of two molecules through their common substructure.

    Args:
        pair: Two-item sequence of molecules; each must provide
            `GetSubstructMatches`.
        mcs: MCS result passed to `MCS_mol` to obtain the query molecule.

    Returns:
        A list with one entry per substructure match. Each entry is a list
        of `(atom_idx_in_pair0, atom_idx_in_pair1)` tuples. Returns `[]`
        when neither molecule matches and `[[]]` when only one of them does.
    """
    mcs_mol = MCS_mol(mcs)
    match_options = dict(
        useQueryQueryMatches=True,
        uniquify=False,
        maxMatches=1000,
        useChirality=False,
    )
    matches0 = list(pair[0].GetSubstructMatches(mcs_mol, **match_options))
    matches1 = list(pair[1].GetSubstructMatches(mcs_mol, **match_options))

    if len(matches0) != len(matches1):
        print( " g_atompair_matches noted: matcher not equal !!")
        shorter, longer = sorted((matches0, matches1), key=len)
        # Pad the shorter list with its first match until both lists have the
        # same length. An empty list cannot be padded; the previous
        # implementation raised ZeroDivisionError here when only the second
        # molecule matched.
        if shorter:
            shorter.extend([shorter[0]] * (len(longer) - len(shorter)))

    if len(matches0) != len(matches1):
        # One molecule produced no match at all: nothing can be paired.
        return [[]]
    return [list(zip(m0, m1)) for m0, m1 in zip(matches0, matches1)]
|
|
|
|
|
class CustomError(Exception):
    """Custom exception type used to signal specific errors."""
|
|
|
# RDKit bond directions keyed by their member name.
BOND_DIRS = {
    'NONE': Chem.rdchem.BondDir.NONE,
    'ENDUPRIGHT': Chem.rdchem.BondDir.ENDUPRIGHT,
    'BEGINWEDGE': Chem.rdchem.BondDir.BEGINWEDGE,
    'BEGINDASH': Chem.rdchem.BondDir.BEGINDASH,
    'ENDDOWNRIGHT': Chem.rdchem.BondDir.ENDDOWNRIGHT,
}

# Lower-case name with the same content, kept as an independent dict object
# so that mutating one mapping never affects the other.
bond_dirs = dict(BOND_DIRS)

# Names of every BOND_DIRS entry except 'NONE'.
BONDDIRECT = ['ENDUPRIGHT', 'BEGINWEDGE', 'BEGINDASH', 'ENDDOWNRIGHT']

# RDKit bond types keyed by their member name (defined once; the file used
# to assign this mapping twice with identical values).
BONDTYPE = {
    'SINGLE': Chem.rdchem.BondType.SINGLE,
    'DOUBLE': Chem.rdchem.BondType.DOUBLE,
    'TRIPLE': Chem.rdchem.BondType.TRIPLE,
    'AROMATIC': Chem.rdchem.BondType.AROMATIC,
}

# Numeric bond order for each bond type; the string keys 'wdge' and 'dash'
# are given order 1.
BONDTYPE2ORD = {
    'wdge': 1,
    'dash': 1,
    Chem.rdchem.BondType.SINGLE: 1,
    Chem.rdchem.BondType.DOUBLE: 2,
    Chem.rdchem.BondType.TRIPLE: 3,
    Chem.rdchem.BondType.AROMATIC: 1.5,
}
|
|
|
# Valence values listed per element symbol; '*' is the wildcard atom.
VALENCES = {
    "H": [1],
    "Li": [1],
    "Be": [2],
    "B": [3],
    "C": [4],
    "N": [3, 5],
    "O": [2],
    "F": [1],
    "Na": [1],
    "Mg": [2],
    "Al": [3],
    "Si": [4],
    "P": [5, 3],
    "S": [6, 2, 4],
    "Cl": [1],
    "K": [1],
    "Ca": [2],
    "Br": [1],
    "I": [1],
    "*": [3, 4, 5, 6],
}
|
|
|
# Element symbols recognised as plain atoms. The list is not the complete
# periodic table and its order is not strictly by atomic number.
ELEMENTS = (
    "H He Li Be B C N O F Ne "
    "Na Mg Al Si P S Cl Ar K Ca "
    "Sc Ti Ru Rh Rn Rf Cr Mn Fe Co Ni Cu Zn "
    "Ga Ge As Se Br Kr Sr Zr "
    "Nb Mo Tc Pd Ag Cd In Sn "
    "Sb Te I Xe Cs Ba La Ce Pr Nd "
    "Pm Sm Eu Gd Tb Dy Ho Er Tm Yb "
    "Lu Hf Ta W Os Ir Pt Au Hg "
    "Tl Pb Bi Po At Fr Ac Th "
    "Pa Np Pu Am Cm Bk Cf Es Fm "
    "Md No Lr Db Sg Bh Hs Mt Ds "
    "Cn Nh Fl Mc Lv Og"
).split()
|
|
|
|
|
|
|
# Labels treated as R-group placeholders rather than as element symbols.
# A comma was missing between "R'" and 'R1', which made Python concatenate
# them into the single bogus entry "R'R1".
RGROUP_SYMBOLS = [
    'R', "R'", 'R1', 'R2', 'R3', 'R4', 'R5', 'R6', 'R7', 'R8', 'R9', 'R10', 'R11', 'R12',
    'Ra', 'Rb', 'Rc', 'Rd', 'Re', 'Rg', 'X', 'Y', 'Z', 'Q', 'A', 'E', 'Ar',
    "V", "Y", "U", 'M', 'G', 'L',
    'Nr', 'Tt', 'Uu', 'Vv', 'Ww',
    'D',
]
|
|
|
# Single-letter colour codes mapped to comma-separated "r,g,b" strings.
COLORS = {
    'c': '0.0,0.75,0.75',
    'b': '0.0,0.0,1.0',
    'g': '0.0,0.5,0.0',
    'y': '0.75,0.75,0',
    'k': '0.0,0.0,0.0',
    'r': '1.0,0.0,0.0',
    'm': '0.75,0,0.75',
}
|
|
|
class Substitution(object):
    '''Define common substitutions for chemical shorthand'''

    def __init__(self, abbrvs, smarts, smiles, probability):
        """Store one shorthand definition.

        Args:
            abbrvs: List of abbreviation strings that all denote this group.
            smarts: SMARTS pattern string for the group.
            smiles: SMILES string for the group.
            probability: Numeric weight stored unchanged on the instance.

        Raises:
            TypeError: If `abbrvs` is not a list.
        """
        # An explicit check instead of `assert`, which is removed under -O.
        if not isinstance(abbrvs, list):
            raise TypeError(
                f"abbrvs must be a list, got {type(abbrvs).__name__}"
            )
        self.abbrvs = abbrvs
        self.smarts = smarts
        self.smiles = smiles
        self.probability = probability

    def __repr__(self):
        return (
            f"{type(self).__name__}(abbrvs={self.abbrvs!r}, "
            f"smarts={self.smarts!r}, smiles={self.smiles!r}, "
            f"probability={self.probability!r})"
        )
|
|
|
# Shorthand groups known to the module. Each entry lists the abbreviations,
# a SMARTS string, the SMILES used for expansion, and a probability value.
SUBSTITUTIONS: List[Substitution] = [
    Substitution(['CH2CH2NSO2CH3'], '[CH2][CH]', '[CH2]CNS(=O)(C)=O', 0.5),
    Substitution(['NHNHCOCF3'], 'NHNHCOCF3', '[NH]NC(=O)C(F)(F)(F)', 0.5),
    Substitution(['CO2CysPr'], 'CO2CysPr', '[C](=O)ON[C@H](C(CCC)=O)CS', 0.5),
    Substitution(['OCH2CHOHCH2'], 'OCH2CHOHCH2', '[O]CC(O)C', 0.5),
    Substitution(['OCH2CHOHCH2OH'], 'OCH2CHOHCH2', '[O]CC(O)CO', 0.5),

    Substitution(['SO2(CH2)3SO2NHCH2CHCH2OH'], 'OCH2CHOHCH2', '[S](=O)(=O)CCCS(=O)(=O)NC[C]CO', 0.5),

    Substitution(['NO2', 'O2N'], '[N+](=O)[O-]', "[N+](=O)[O-]", 0.5),

    Substitution(['CO2Et', 'COOEt'], 'C(=O)[OH0;D2][CH2;D2][CH3]', "[C](=O)OCC", 0.5),

    Substitution(['OAc','AcO'], '[OH0;X2]C(=O)[CH3]', "[O]C(=O)C", 0.7),
    Substitution(['NHAc'], '[NH1;D2]C(=O)[CH3]', "[NH]C(=O)C", 0.7),
    Substitution(['Ac'], 'C(=O)[CH3]', "[C](=O)C", 0.1),

    Substitution(['OBz','BzO'], '[OH0;D2]C(=O)[cH0]1[cH][cH][cH][cH][cH]1', "[O]C(=O)c1ccccc1", 0.7),
    Substitution(['Bz'], 'C(=O)[cH0]1[cH][cH][cH][cH][cH]1', "[C](=O)c1ccccc1", 0.2),

    Substitution(['COOBn','BnO2C'], '[OH0;D2][CH2;D2][cH0]1[cH][cH][cH][cH][cH]1', "[C](=O)OCc1ccccc1", 0.7),
    Substitution(['OBn','BnO'], '[OH0;D2][CH2;D2][cH0]1[cH][cH][cH][cH][cH]1', "[O]Cc1ccccc1", 0.7),
    Substitution(['Bn'], '[CH2;D2][cH0]1[cH][cH][cH][cH][cH]1', "[CH2]c1ccccc1", 0.2),
    Substitution(['NHBn'], '[NH]Cc1ccccc1', "[NH]Cc1ccccc1", 0.2),
    Substitution(['NBn2'], '[NH]Cc1ccccc1', "[N](Cc1ccccc1)Cc1ccccc1", 0.2),

    Substitution(['NHBoc','BocHN',"BOCHN"], '[NH1;D2]C(=O)OC([CH3])([CH3])[CH3]', "[NH]C(=O)OC(C)(C)C", 0.6),
    Substitution(['NBoc'], '[NH0;D3]C(=O)OC([CH3])([CH3])[CH3]', "[NH1]C(=O)OC(C)(C)C", 0.6),
    Substitution(['Boc','BOc'], 'C(=O)OC([CH3])([CH3])[CH3]', "[C](=O)OC(C)(C)C", 0.2),

    Substitution(['Cbm'], 'C(=O)[NH2;D1]', "[C](=O)N", 0.2),
    Substitution(['Cbz'], 'C(=O)OC[cH]1[cH][cH][cH1][cH][cH]1', "[C](=O)OCc1ccccc1", 0.4),
    Substitution(['NHCbz'], 'C(=O)OC[cH]1[cH][cH][cH1][cH][cH]1', "[NH]C(=O)OCc1ccccc1", 0.4),
    Substitution(['Cy'], '[CH1;X3]1[CH2][CH2][CH2][CH2][CH2]1', "[CH1]1CCCCC1", 0.3),
    Substitution(['Fmoc'], 'C(=O)O[CH2][CH1]1c([cH1][cH1][cH1][cH1]2)c2c3c1[cH1][cH1][cH1][cH1]3',
                 "[C](=O)OCC1c(cccc2)c2c3c1cccc3", 0.6),
    Substitution(['FmocHN','FmOcHN', 'NHFmoc'], 'C(=O)O[CH2][CH1]1c([cH1][cH1][cH1][cH1]2)c2c3c1[cH1][cH1][cH1][cH1]3',
                 "[NH]C(=O)OCC1c(cccc2)c2c3c1cccc3", 0.6),
    Substitution(['Mes'], '[cH0]1c([CH3])cc([CH3])cc([CH3])1', "[c]1c(C)cc(C)cc(C)1", 0.5),
    Substitution(['OMs','MsO'], '[OH0;D2]S(=O)(=O)[CH3]', "[O]S(=O)(=O)C", 0.7),
    Substitution(['Ms'], 'S(=O)(=O)[CH3]', "[S](=O)(=O)C", 0.2),
    Substitution(['Ph'], '[cH0]1[cH][cH][cH1][cH][cH]1', "[c]1ccccc1", 0.5),

    Substitution(['PMB'], '[CH2;D2][cH0]1[cH1][cH1][cH0](O[CH3])[cH1][cH1]1', "[CH2]c1ccc(OC)cc1", 0.2),
    Substitution(['PMBN'], '[CH2;D2][cH0]1[cH1][cH1][cH0](O[CH3])[cH1][cH1]1', "[N]Cc1ccc(OC)cc1", 0.2),
    Substitution(['Py'], '[cH0]1[n;+0][cH1][cH1][cH1][cH1]1', "[c]1ncccc1", 0.1),

    Substitution(['SEM','MES'], '[CH2;D2][O][CH2][CH2][Si]([CH3])([CH3])[CH3]', "[CH2]OCC[Si](C)(C)C", 0.2),

    Substitution(['Suc'], 'C(=O)[CH2][CH2]C(=O)[OH]', "[C](=O)CCC(=O)O", 0.2),
    Substitution(['TBS'], '[Si]([CH3])([CH3])C([CH3])([CH3])[CH3]', "[Si](C)(C)C(C)(C)C", 0.5),
    Substitution(['TBZ'], 'C(=S)[cH]1[cH][cH][cH1][cH][cH]1', "[C](=S)c1ccccc1", 0.2),
    Substitution(['OTf'], '[OH0;D2]S(=O)(=O)C(F)(F)F', "[O]S(=O)(=O)C(F)(F)F", 0.7),
    Substitution(['Tf'], 'S(=O)(=O)C(F)(F)F', "[S](=O)(=O)C(F)(F)F", 0.2),
    Substitution(['TFA'], 'C(=O)C(F)(F)F', "[C](=O)C(F)(F)F", 0.3),
    Substitution(['TFAH2N'], 'C(=O)C(F)(F)F', "[NH]C(=O)C(F)(F)F", 0.3),
    Substitution(['TMS'], '[Si]([CH3])([CH3])[CH3]', "[Si](C)(C)C", 0.5),
    Substitution(['Ts'], 'S(=O)(=O)c1[cH1][cH1][cH0]([CH3])[cH1][cH1]1', "[S](=O)(=O)c1ccc(C)cc1", 0.6),
    Substitution(['TsO','OTs'], '[O]S(C1=CC=C(C=C1)C)(=O)=O', "[O]S(C1=CC=C(C=C1)C)(=O)=O", 0.6),

    Substitution(['COCH3'], '[OH0;D2][CH3;D1]', "[C](=O)C", 0.3),

    Substitution(['OMe', 'MeO','H;CO', 'CH3O','OCH3', 'H3CO'], '[OH0;D2][CH3;D1]', "[O]C", 0.3),
    Substitution(['SMe', 'MeS'], '[SH0;D2][CH3;D1]', "[S]C", 0.3),
    Substitution(['NMe', 'MeN'], '[N;X3][CH3;D1]', "[N]C", 0.3),
    Substitution(['NMe2', 'Me2N'], '[N;X3](C)[CH3;D1]', "[N](C)C", 0.3),

    Substitution(['Me'], '[CH3;D1]', "[CH3]", 0.1),
    Substitution(['OEt', 'EtO','C2H5O','OC2H5'], '[OH0;D2][CH2;D2][CH3]', "[O]CC", 0.5),
    Substitution(['MeOH2C','CH2OMe'], '[CH2;D2]O[CH3]', "[CH2]OC", 0.5),
    Substitution(['Et', 'CH2CH3','CH3CH2'], '[CH2;D2][CH3]', "[CH2]C", 0.3),

    Substitution(['Pr', 'nPr', 'n-Pr'], '[CH2;D2][CH2;D2][CH3]', "[CH2]CC", 0.3),
    Substitution(['Bu', 'nBu', 'n-Bu'], '[CH2;D2][CH2;D2][CH2;D2][CH3]', "[CH2]CCC", 0.3),

    Substitution(['iPr', 'i-Pr'], '[CH1;D3]([CH3])[CH3]', "[CH1](C)C", 0.2),
    Substitution(['iBu', 'i-Bu'], '[CH2;D2][CH1;D3]([CH3])[CH3]', "[CH2]C(C)C", 0.2),
    Substitution(['OiBu'], '[OH0;D2][CH2;D2][CH1;D3]([CH3])[CH3]', "[O]CC(C)C", 0.2),
    Substitution(['OtBu','tBuO'], '[OH0;D2][CH0]([CH3])([CH3])[CH3]', "[O]C(C)(C)C", 0.6),
    Substitution(['tBu', 't-Bu'], '[CH0]([CH3])([CH3])[CH3]', "[C](C)(C)C", 0.3),

    Substitution(['CF3', 'F3C'], '[CH0;D4](F)(F)F', "[C](F)(F)F", 0.5),
    Substitution(['NCF3', 'F3CN'], '[N;X3][CH0;D4](F)(F)F', "[NH]C(F)(F)F", 0.5),
    Substitution(['OCF3', 'F3CO'], '[OH0;X2][CH0;D4](F)(F)F', "[O]C(F)(F)F", 0.5),
    Substitution(['OCCl3', 'Cl3CO'], '[OH0;X2][CH0;D4](Cl)(Cl)Cl', "[O]C(Cl)(Cl)Cl", 0.5),
    Substitution(['SCF3', 'F3CS'], '[SH0;X2][CH0;D4](F)(F)F', "[S]C(F)(F)F", 0.5),
    Substitution(['CCl3'], '[CH0;D4](Cl)(Cl)Cl', "[C](Cl)(Cl)Cl", 0.5),
    Substitution(['CO2H', 'HO2C', 'COOH'], 'C(=O)[OH]', "[C](=O)O", 0.5),
    Substitution(['CO2NH4','COONH4','H4NOOC','H4NO2C'], 'C(=O)[OH]', "[C](=O)ON", 0.5),
    Substitution([ 'COO-','CO2-'], 'C(=O)[OH]', "[C](=O)[O-]", 0.5),

    Substitution(['CN', 'NC'], 'C#[ND1]', "[C]#N", 0.5),

    Substitution(['N3'], '[N]=[N+]=[N-]', "[N]=[N+]=[N-]", 0.4),

    Substitution(['N2+Cl-','Cl-N2+'], '[N+]#[N].[Cl-]', "[N+]#[N].[Cl-]", 0.4),
    Substitution(['N2'], '[N]=[N-]', "[N]=[N-]", 0.4),
    Substitution(['N2H'], '[N]=[N-]', "[N]=[NH]", 0.4),
    Substitution(['NO','N=O','O=N','ON'], '[N]=[O]', "[N]=O", 0.4),
    Substitution(['NCH3'], '[N]C', "[NH]C", 0.4),
    Substitution(['NOMe'], '[N]OC', "[N]OC", 0.4),
    Substitution(['OCH2'], '[O]C', "[O]C", 0.4),
    Substitution(['C=O','O=C'], '[C]=[O]', "[C]=O", 0.4),
    Substitution(['NPh','PhN'], 'NC1=CC=CC=C1', "[N]C1=CC=CC=C1", 0.4),
    Substitution(['NHPh','PhNH','PhHN'], 'NC1=CC=CC=C1', "[NH]C1=CC=CC=C1", 0.4),
    Substitution(['TMSO','OSMT'], 'O[Si](C)(C)C', "[O][Si](C)(C)C", 0.5),
    Substitution(['SPh','PhS'], 'SC1=CC=CC=C1', "[S]C1=CC=CC=C1", 0.4),
    Substitution(['SO3H'], 'S(=O)(=O)[OH]', "[S](=O)(=O)O", 0.4),
    Substitution(['SO3NH2','SO3NH4','H4NO3S'], 'S(=O)(=O)[OH]', "[S](=O)(=O)ON", 0.4),
    Substitution(['SO3'], 'S(=O)(=O)[OH]', "[S](=O)(=O)[O-]", 0.4),
    Substitution(['SO2CF3'], '[S](=O)(=O)C(F)(F)F', "[S](=O)(=O)C(F)(F)F", 0.5),
    Substitution(['SO2Cl'], '[S](=O)(=O)Cl', "[S](=O)(=O)Cl", 0.5),
    Substitution(['SO2F'], '[S](=O)(=O)F', "[S](=O)(=O)F", 0.5),
    Substitution(['SO2'], '[S](=O)(=O)', "[S](=O)(=O)", 0.5),
    Substitution(['SO2NH'], '[S](=O)(=O)[N]', "[S](=O)(=O)[N]", 0.5),
    Substitution(['SO2NH2'], '[S](=O)(=O)[NH2]', "[S](=O)(=O)[NH2]", 0.5),
    Substitution(['SO2Me','SO2CH3'], '[S](=O)(=O)C', "[S](=O)(=O)C", 0.5),
    Substitution(['NHO2S'], '[S](=O)(=O)[N]', "[N][S](=O)(=O)", 0.5),
    Substitution(['OSO2Me'], '[O]S(=O)(=O)C', "[O]S(=O)(=O)C", 0.5),
    Substitution(['NHSO2Me'], '[NH]S(=O)(=O)C', "[NH]S(=O)(=O)C", 0.5),
    Substitution(['SOCH3','SOMe'], '[S](=O)(=O)', "[S](=O)C", 0.5),

    Substitution(['P+Ph3Br-'], '[P+](C1=CC=CC=C1)(C2=CC=CC=C2)C3=CC=CC=C3', "[P+](C1=CC=CC=C1)(C2=CC=CC=C2)C3=CC=CC=C3", 0.5),
    Substitution(['N+Ph3Br-'], '[N+](C1=CC=CC=C1)(C2=CC=CC=C2)C3=CC=CC=C3', "[N+](C1=CC=CC=C1)(C2=CC=CC=C2)C3=CC=CC=C3", 0.5),
    Substitution(['PPh2'], "[P](C1=CC=CC=C1)C2=CC=CC=C2", "[P](C1=CC=CC=C1)C2=CC=CC=C2", 0.5),

    Substitution(['CO2Me', 'COOMe'], 'C(=O)[OH0;D2][CH3]', "[C](=O)OC", 0.5),
    Substitution(['ONa', 'NaO'], '[O][Na]', "[O][Na]", 0.5),
    Substitution(['OTBDMS', 'TBDMSO'], "[O][Si](C)(C)C(C)(C)C", "[O][Si](C)(C)C(C)(C)C", 0.5),
    Substitution(['CONH2'], '[C](O)(N)', "[C](=O)[NH2]", 0.5),
    Substitution(['NHNH2'], '[NH2;D1]', "[NH]N", 0.1),
    Substitution(['CONH'], 'CONH', '[C](=O)N', 0.5),
    Substitution(['CH3CONH'], '[NH]C(=O)C', '[NH]C(=O)C', 0.5),
    Substitution(['NH3Cl'], '[NH]Cl', '[NH]Cl', 0.5),

    Substitution(['SAc','AcS'], '[S]C(C)=O', "[S]C(C)=O", 0.5),
    Substitution(['OAll'], '[O]CC=C', '[O]CC=C', 0.5),

    Substitution(['Tos','TOs'], '[Si](C)(C)C', '[S](=O)(=O)C(C=C1)=CC=C1C', 0.5),
    Substitution(['OTos','OTOs','soTO'], '[Si](C)(C)C', '[O]S(=O)(=O)C(C=C1)=CC=C1C', 0.5),
    Substitution(['TsN'], '[N]S(C1=CC=C(C=C1)C)(=O)=O', '[N]S(C1=CC=C(C=C1)C)(=O)=O', 0.5),
    Substitution(['Ts'], '[S](C1=CC=C(C=C1)C)(=O)=O', '[S](C1=CC=C(C=C1)C)(=O)=O', 0.5),
    Substitution(['COCF3'], '[C](=O)C(F)(F)(F)', '[C](=O)C(F)(F)(F)', 0.5),
    Substitution(['CF2', 'F2C'], '[C;D4](F)(F)', "[C](F)(F)", 0.5),
    Substitution(['PMB'], '[CH2]C1=CC=C(C=C1)OC', '[CH2]C1=CC=C(C=C1)OC', 0.5),
    Substitution(['NHCOtBu'], '[NH]C(C(C)(C)C)=O','[NH]C(C(C)(C)C)=O', 0.5),
    Substitution(['OCN'], '[N]=C=O', "[N]=C=O", 0.5),
    Substitution(['Me3Si'], '[Si](C)(C)(C)', "[Si](C)(C)(C)", 0.5),
    Substitution(['PhO','OPh'], '[O]C1=CC=CC=C1', "[O]C1=CC=CC=C1", 0.5),
    Substitution(['Allyl'], '[CH2]C=C', '[CH2]C=C', 0.5),
    Substitution(['C7H3'], '[C]#CC#CC#CC', '[C]#CC#CC#CC', 0.5),
    Substitution(['C5H11'], '[CH2]CCCC', '[CH2]CCCC', 0.5),
    Substitution(['R1R2N'], "[N]([*])[*]", "[N]([*])[*]", 0.5),
    Substitution(['CO2R'], '[C](=O)O*', '[C](=O)O*', 0.5),
    Substitution(['CCl3CH2O2C'], '[C](=O)OCC(Cl)(Cl)Cl', '[C](=O)OCC(Cl)(Cl)Cl', 0.5),
    Substitution(['NHOH'], '[NH]O', '[NH]O', 0.5),
    Substitution(['CO2'], '[C](=O)[O]', '[C](=O)[O]', 0.5),
    Substitution(['O2C'], '[C](=O)[O]', '[O][C](=O)', 0.5),

    Substitution(['PPh3'], '[P](C1=CC=CC=C1)(C2=CC=CC=C2)C3=CC=CC=C3', '[P](C1=CC=CC=C1)(C2=CC=CC=C2)C3=CC=CC=C3', 0.5),
    Substitution(['TfO'], '[C](=O)[O]', '[O]S(=O)(C(F)(F)F)=O', 0.5),
    Substitution(['OCH2Ph'], '[O]CC1=CC=CC=C1', '[O]CC1=CC=CC=C1', 0.5),
    Substitution(['OCH2CF3'], '[O]CC(F)(F)(F)', '[O]CC(F)(F)(F)', 0.5),
    Substitution(['COOCH2Ph'], '[C](=O)OCC1=CC=CC=C1', '[C](=O)OCC1=CC=CC=C1', 0.5),
    Substitution(['OCH2OC2H5'], '[C](=O)C(C)(C)C', '[O]COCC', 0.5),

    Substitution(['Trt'], '[C](C1=CC=CC=C1)(C2=CC=CC=C2)C3=CC=CC=C3', '[C](C1=CC=CC=C1)(C2=CC=CC=C2)C3=CC=CC=C3', 0.5),
    Substitution(['SF5'], '[S](F)(F)(F)(F)F', '[S](F)(F)(F)(F)F', 0.5),

    Substitution(['S*'], '[S]*', '[S]*', 0.5),
    # Two separate abbreviations; they used to be fused into the single
    # string 'N*, NH*', so neither 'N*' nor 'NH*' could be looked up.
    Substitution(['N*', 'NH*'], '[NH]*', '[NH]*', 0.5),
    Substitution(['C*','CH2*'], '[C]*', '[CH2]*', 0.5),
    Substitution(['P*',"PH*"], '[P]*', '[PH]*', 0.5),
    Substitution(['O*'], '[O]*', '[O]*', 0.5),

    Substitution(['N(CH3)2'], '[N](C)(C)', "[N](C)(C)", 0.5),
    Substitution(['(C2H5)2N','Et2N'], '[N](C)(C)', "[N](CC)(CC)", 0.5),
    Substitution(['B(OH)2'], '[B](O)O', "[B](O)O", 0.5),
    # tert-butyl ester: the ester oxygen was missing, which described a
    # pivaloyl group instead (compare the 'Boc' entry).
    Substitution(['CO2C(CH3)3'], '[C](=O)OC(C)(C)C', '[C](=O)OC(C)(C)C', 0.5),
    Substitution(['P(O)(OEt)2', 'P(OEt)2(O)'], "[P](OCC)(=O)CCO", "[P](OCC)(=O)OCC", 0.5),
    Substitution(['(CH2)16Me'], '[CH2]CCCCCCCCCCCCCCCC', "[CH2]CCCCCCCCCCCCCCCC", 0.3),
    Substitution(['(CH2)11Me'], '[CH2]CCCCCCCCCCC', "[CH2]CCCCCCCCCCC", 0.3),
    Substitution(['N(H)Et','Et(H)N'], '[NH]CC', '[NH]CC', 0.5),
    Substitution(['N(H)Me','Me(H)N'], '[NH]C', '[NH]C', 0.5),
]

# Abbreviation -> Substitution lookup. When an abbreviation appears in more
# than one entry (e.g. 'Ts', 'PMB'), the entry listed last wins.
ABBREVIATIONS = {abbrv: sub for sub in SUBSTITUTIONS for abbrv in sub.abbrvs}
|
|
|
|
|
def extract_abbreviation_key(item):
    """Unwrap nested lists by repeatedly taking the first element.

    Non-list values are returned unchanged. An empty list at any nesting
    level raises IndexError.
    """
    current = item
    while isinstance(current, list):
        current = current[0]
    return current
|
|
|
|
|
def clean_unpaired_brackets(text):
    """Remove closing brackets that have no matching opener.

    Handles '(' / ')' and '[' / ']'. A closing bracket is kept only when the
    most recent still-open bracket is its counterpart; otherwise it is
    dropped. Opening brackets are always kept, even if never closed.
    """
    closer_to_opener = {')': '(', ']': '['}
    openers = set(closer_to_opener.values())
    open_stack = []
    kept = []

    for ch in text:
        if ch in openers:
            open_stack.append(ch)
        elif ch in closer_to_opener:
            if not open_stack or open_stack[-1] != closer_to_opener[ch]:
                # Unmatched closer: skip it and leave the stack untouched.
                continue
            open_stack.pop()
        kept.append(ch)
    return ''.join(kept)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def replace_c1(text):
    """Rewrite every 'C1' as 'Cl' unless the '1' is followed by a digit."""
    c1_not_before_digit = re.compile(r'C1(?!\d)')
    return c1_not_before_digit.sub('Cl', text)
|
def transform_formula(formula):
    """Replace an 'Hg' token after 'C<n>' with 'H<2n+1>'.

    The formula must start with 'C' followed by digits and contain 'Hg'
    later on; otherwise it is returned unchanged. Text between the carbon
    count and 'Hg', and text after 'Hg', is preserved.
    """
    parsed = re.match(r'C(\d+)(.*?)Hg(.*)', formula)
    if parsed is None:
        return formula

    carbon_digits, middle, tail = parsed.groups()
    carbons = int(carbon_digits)
    hydrogens = carbons * 2 + 1
    return f"C{carbons}{middle}H{hydrogens}{tail}"
|
def Cg_transform_formula(formula):
    """Replace a leading 'Cg' with 'C<(n-1)//2>' based on the 'H<n>' count.

    The formula must start with 'CgH' followed by digits; otherwise it is
    returned unchanged. Any text after the hydrogen count is preserved.
    """
    # Greedy `(.*)` captures the remainder. The former lazy `(.*?)` at the
    # end of the pattern always matched the empty string, so everything after
    # the hydrogen count was silently discarded.
    match = re.match(r'CgH(\d+)(.*)', formula)
    if not match:
        return formula

    n = int(match.group(1))
    suffix = match.group(2)
    g_new = (n - 1) // 2
    return f"C{g_new}H{n}{suffix}"
|
|
|
# Ordered (canonical form, accepted spellings) rows for normalize_ocr_text.
# Rows are consulted top to bottom: a spelling that occurs in more than one
# row resolves to the first row that lists it.
_OCR_CANONICAL_FORMS = (
    ('NO2', ['OzN','O2N', 'O,N', 'NOz','NO2', 'NO,', '0;N','02N','N20']),
    ('CH3CH', ['CHzCH','CH,CH']),
    ('NHCH2COOH', ["NHCHzCOOH","NHCH2COOH"]),
    ('COCl', ['CIOC','COCE','ClOC','COCI']),
    ('CH2COOC2H5', ['CHCOOCHs','CH2COOCH5']),
    ('t-Bu', ['(t-Bu)','t-Bu']),
    ('SiMe2', ['SiMe2','Me2Si']),
    ('Ar2P(O)', ['ArzP(O)','Ar2P(O)']),
    ('P(O)(OEt)2', ['P(O)(0Et)2','P(O)(OEt)2']),
    ('PhO2S', ['PhOzS','PhO2S']),
    ('CH3O', ['CH3O','CHzO']),
    ('NH2.HCl', ['NH.HCI','NH,.Hcl']),
    ('Z', ['2','Z']),
    ('CH2', ['(CH2)m','(CH2)q','(CH2)s']),
    ('Ar', ['Arl','Ari','Ar2','Ar1']),
    ('S[O]a', ['"0ls','"ols','S[0]a']),
    ('NHR8', ['NHR%','NHR*']),
    ('Vv', ['Vv','Vy']),
    ('N3', ['N3','NY','Ny']),
    ('(C2H5)2N', ['C2H52N','N(CH,CH3)2','C;H52N','(C;H5)2N','N(C;Hs)2','N(C;H5)2','(CHzCH2)2N','N(CHCH3)2','(CH3CH2)2N','(C2H52N', '(CHzCH)2N','(C2H5)2N','Et2N']),
    ('Me2N', ['(CH3)2N','Me2NH','Me,N','Me2N']),
    ('(C2H4O)H', ['(C;H4O)H','(C2H4O)H']),
    ('(C2H4O)4CH3', ['(C;H4O)4CH3','(C2H4O)4CH3']),
    ('(CH2)16Me', ['(CH2)16Me']),
    ('(CH2)11Me', ['(CH2)11Me']),
    ('COOCH2Ph', ['CO2CH2Ph','COOCH2Ph','COOCH,Ph']),
    ('CO2C(CH3)3', ['CO2C(CH3)3','(CH3)3CO2C']),
    ('OCH2Ph', ['OCH2Ph','OCH,Ph','OCHAPH','OCH;Ph']),
    ('(CF2)8H', ['(CF2)8H','(CF2)gH','(CF2)sH','CF2sH', 'CF:H','CF2)sH','CF):H']),
    ('NHSO2Bu', ['NHSO,Bu','NHSO2Bu']),
    ('NHSO2CH3', ['NHSO,CH3','NHSO2CH3','NHSO2Me']),
    ('l23I', ['1231','1231','23T', 'l23I']),
    ('F3C', ['CF3','CFs', 'CF,', '13','CF 3','F;C', 'F:C', 'F sC', 'CF', 'CF;', 'CFa', 'FzC', 'CFz']),
    ('OCCl3', ['OCCl3','Cl3CO']),
    ('CCl3', ['CCl3','Cl3C']),
    ('F3CN', ['F;CN', 'NCF;']),
    ('NCH3', ['NCH3','NHCH3', 'NCH;','CH3N','MeN','MeNH']),
    ('NOMe', ['NOMe']),
    ('R1R2N', ['R,R,N']),
    ('Me', ['HzC','HyC','CHy','CHE','H3C.','1;.C', '1;C', 'M e','Mé', 'CH 3', 'CH:', 'HsC', 'HaC', 'H3C', 'CH3', 'CHa', 'H;C', 'CH,', 'CHs', 'CH;']),
    ('Ph3Br', ['PhzBr']),
    ('PPh3', ['PPh3', 'PPha']),
    ('CH2CH3', ['Et', 'CH,CH3','Catls','Cafls','CH2CH3','H3CH2C','C:H5','HzCH2C','H3CH2C', 'C,H5', 'CzH5','C2H5','C2Hs']),
    ('OMe', ['Ovle', 'HzCO','OCH', 'OCH:','H2CO', 'CH3O', 'CH,O', 'HsCO','OMe','AME', 'AMe','H3CO', 'MeO']),
    ('OMe', ['OCHa','HgCO', 'OCH','HaCO', 'OCH:','H2CO', 'CH3O', 'CH,O', 'OMe','AME', 'AMe', 'MeO']),
    ('SO2Cl', ['SO2Cl', 'SOzCl']),
    ('SO2F', ['SO2F', 'SOzF']),
    ('SO2NH', ['SONH', 'HNOS','SON', 'SO2NH']),
    ('NHO2S', ['HNO2S','NHO2S']),
    ('SO2NH2', ['SO2NH2', 'SO,NH', 'SO:NH2', 'SONH2']),
    ('SO2CF3', ['SOzCF3', 'SO2CF3', 'CF3SO2']),
    ('SO2', ['SOz','O2S', '$02', 'S02','SO,', '62','O:S','SO2']),
    ('SO2CH3', ['H3CO2S','SO2CH3']),
    ('SO3H', ['SO3H','SOsH','SOaH', 'HO3S','SOzH','HOzS']),
    ('OSO2Me', ['MeO2SO','OSO2CH3','OSO2Me']),
    ('NHSO2Me', ['MeO2SHN','NHSO2Me']),
    ('PMB', ['PIME', 'PMB']),
    ('t-Bu', ['1-BU', '-BU', '-Bu', 't-BU','t-Bu']),
    ('TsN', ['NTS', 'NTs', 'TsN']),
    ('OTs', ['TsO', 'OTs']),
    ('N2+Cl-', ['Nz* Cl', "N2+Cl-"]),
    ('NH3Cl', ['NH3Cl', 'NHzCl','NH;Cl']),
    ('B(OH)2', ['B(OH)2']),
    ('NHAc', ['NHAC', 'NHAc']),
    ('OCN', ['1CO', 'NCO', 'OCN', 'OON']),
    ('COCF3', ['COCFs','COCF3', 'COCF s']),
    ('OCF3', ['OCF3', 'OCF 3','OCE', 'OCE:','OCEE', 'F3CO', 'OCF', 'OCF:']),
    ('SCF3', ['SCF3', 'SCE', 'SCEE', 'F3CS', 'SCF', 'SCF:']),
    ('SMe', ['HzCS', 'SCH3', 'SMe','MeS','H3SC']),
    ('OEt', ['CHzCHzO', 'CH3CH2O','H5C2O','OC2H5']),
    ('CO2Et', ['CO,Et','COzEt', 'CO2Et','H3CH2COOC','CO2C2H5']),
    ('OTBDMS', ['OTBS', 'TBSO', 'OTBDMS']),
    ('PhO', ['PhO', 'Pho']),
    ('Cl', ['CI', 'C1']),
    ('Ph', ['P h', 'Ph']),
    ('TFAH2N', ['FAHN', 'TFAH,N','TFAH2N']),
    ('Me3Si', ['MeaSi', 'Me3Si']),
    ('PH3C', ['PHzC','PH;C', 'PH3C']),
    ('CO2H', ['COOH','OOOH','1OOC', 'HOOO','HOOC', 'DOOH', 'CO:H','HO,C','CO,H','CO2H']),
    ('CO2R', ['CO2R','RO2C', 'RO,C','CO2*', "COzR'"]),
    ('CO2', ['CO2', 'COO','OOC', "COz"]),
    ('O2C', ['O2C', '02C']),
    ('C2H5', ['CaH;', 'CHS', 'C2H5']),
    ("NHBoc", ['NHBoc','NHBOc', 'BocHN','BOcHN', "BOCHN"]),
    ('C7H3', ['C7H', 'C7H3']),
    ('C5H11', ['CsH11', 'C5H11']),
    ('CCl3CH2O2C', ['CC3CH2O2C', 'CCl3CH2O2C']),
    ('CH2OMe', ['CH2OMe','MeOH,C','CH,0Me', 'CH,OMe','MeOH2C']),
    ('*', ['R', "R'"]),
    ('U', ['U', 'U.']),
    ('O*', ['RO']),
    ('OAc', ['OAc', 'OAC']),
    ('R9', ['Rg', 'R9']),
    ('OMe', ['OQ', '00', '0Q','OCH3']),
    ('N', ['NH', 'HN', "NH2", 'H2N', 'H,N']),
    ('O', ['OH', 'HO', 'OH2', '0']),
    ('N(H)Et', ['N(H)Et','Et(H)N']),
    ('N(H)Me', ['N(H)Me','Me(H)N']),
    ('CONH', ['HNOC','CONH']),
    ('CH3CONH', ['HNOCCH3','CH,CONH','CH3CONH']),
    ('PPh2', ['PPh2','Ph,P','Ph2P']),
    ('SF5', ['SF5','F5S']),
    ('OCH2CF3', ['OCH2CF3','F3CH2CO']),
    ('NHCbz', ['NHCbz','CbzHN']),
    ('NHNH2', ['NHNH2','H2NHN']),
    ('(CH3CH2)2N', ['CHzCH22N','N2(CH2CH3)','(CH3CH2)2N']),
    ('CH2CH2CH2CH', ['CHCHCH2CH-3','CH2CH2CH2CH']),
    ('HCH2CH2CH2C', ['HCH2CH2CH2C','HCH2CH2CH2C']),
    ('(H3C)2HC', ['(HzC)2HC','(H3C)2HC']),
    ('CH2CH2NSO2CH3', ['13CO2SHNH2CH2C','H3CO2SHNH2CH2C','CH2CH2NSO2CH3']),
    ('C9H19', ['CgH19','C9H19']),
    ('(CF2)8H', ['(CF2):H','(CF2)8H']),
    ('CO2Me', ['COOCH3','HzCO2C', 'CO,Me','H3CO2C','CO2CH3','MeOOC','CO2Me','COzMe','MeO2C','MeO,C']),
    ('CH2CH2O', ['(CHCHO)','CH2CH2O']),
    ('CO2CysPr', ['CO,CysPr','CO2CysPr']),
    ('CH2CH2C(O)OCH2CH3', ['CH2CH2C(O)OCHCH3','CH;CH2C(O)OCHCH3']),
    ('H4NO3S', ['H4NOzS','H4NO3S']),
    ('C10H21', ['C1OH21','C1oH21','CloH21', 'C10H21']),
    ('CF2', ['']),
)


def _build_ocr_lookup(rows):
    """Flatten the ordered rows into a spelling -> canonical-form dict."""
    lookup = {}
    for canonical, spellings in rows:
        for spelling in spellings:
            # setdefault keeps the first row that mentions a spelling.
            lookup.setdefault(spelling, canonical)
    return lookup


_OCR_LOOKUP = _build_ocr_lookup(_OCR_CANONICAL_FORMS)


def normalize_ocr_text(text, replacement_map):
    """Normalize OCR text using the predefined mapping rules.

    Steps, in order: character-level repairs ('C1', 'Hg', 'Cg', 'Q'), early
    return for element symbols, space removal, reconciliation of '0' / 'o' /
    'O', removal of unpaired closing brackets, and finally a lookup of the
    cleaned text in the built-in table. `replacement_map` is consulted only
    when the built-in table has no entry; if neither knows the text it is
    returned as cleaned.
    """
    if 'C1'in text:
        text = replace_c1(text)
    if 'Hg' in text:
        text = transform_formula(text)
    if 'Cg' in text:
        text = Cg_transform_formula(text)
    if 'Q' in text:
        # A 'Q' followed by an upper-case letter and more word characters
        # is rewritten with a leading 'O'.
        text = re.sub(r'Q([A-Z])(\w+)', r'O\1\2', text)

    if text in ELEMENTS:
        return text

    if ' ' in text:
        text = text.replace(" ", "")

    if any(c in text for c in '0oO'):
        text = re.sub(r'(?<![CF,CH]\d)[oO](?=[a-zA-GI-Z]|$)', '0', text)
        if '00' in text:
            text = re.sub(r'00', 'OO', text)

        if text in ['R20']:
            return text

        text = re.sub(r'(?<=[a-zA-Z])0(?=([a-zA-Z]|$))', 'O', text)
        text = re.sub( r'^(0)|(?<=[a-zA-Z][?\d])0(?=[a-zA-Z0-9]*$|[a-zA-Z])', 'O', text)

    text = clean_unpaired_brackets(text)

    canonical = _OCR_LOOKUP.get(text)
    if canonical is not None:
        return canonical
    if text in replacement_map:
        return replacement_map[text]
    return text
|
|
|
|
|
|
|
|
|
|
|
def C_H_affixExpand(group):
    """
    Expands CnHm or HmCn chemical group notation into SMILES format.

    Supports bare forms (C6H13, H13C6) and forms with a purely alphabetic
    affix before or after the formula (NHC6H13, H13C6HN). The hydrogen count
    must satisfy H = 2C + 1. The affix is always placed in front of the
    carbon chain, whether it appeared before or after the formula in the
    input; known affixes are translated to SMILES, unknown ones are used
    verbatim.

    Returns the SMILES string, or False if the text has none of the
    supported shapes or the hydrogen count does not fit.
    """
    # SMILES fragments for the affixes this function knows how to translate.
    affix_smiles = {
        'NH': '[NH]', 'HNOC': '[C](=O)[NH]',
        'CONH': '[C](=O)[NH]', 'HN': '[NH]', 'HNO': '[NH]O', 'NO': '[N]=O',
        'COO': '[C](=O)O',
        'CO2': '[C](=O)O',
    }

    # (pattern, carbon-count group, hydrogen-count group, affix group or None),
    # tried in this order; the first pattern that matches decides the result.
    forms = (
        (r'^C(\d+)H(\d+)$', 1, 2, None),
        (r'^H(\d+)C(\d+)$', 2, 1, None),
        (r'^([A-Za-z]+)(C(\d+)H(\d+))$', 3, 4, 1),
        (r'^(C(\d+)H(\d+))([A-Za-z]+)$', 2, 3, 4),
        (r'^([A-Za-z]+)(H(\d+)C(\d+))$', 4, 3, 1),
        (r'^(H(\d+)C(\d+))([A-Za-z]+)$', 3, 2, 4),
    )

    def validate_and_expand(c_count, h_count, affix=None):
        """Helper to validate CnHm/HmCn and generate SMILES."""
        if h_count != 2 * c_count + 1:
            return False
        smiles = '[CH2]' + 'C' * (c_count - 1)
        if affix:
            smiles = affix_smiles.get(affix, affix) + smiles
        return smiles

    for pattern, c_idx, h_idx, affix_idx in forms:
        match = re.match(pattern, group)
        if match:
            affix = match.group(affix_idx) if affix_idx is not None else None
            return validate_and_expand(
                int(match.group(c_idx)), int(match.group(h_idx)), affix
            )

    return False
|
|
|
def N_C_H_expand(group):
    """Expand text starting with 'NHCnHm' or 'NCnHm' into '[N]' plus n carbons.

    Returns the SMILES string when m == 2n + 1, otherwise False. Text after
    the hydrogen count is ignored because the pattern is anchored only at the
    start.
    """
    # 'NH?C' covers both accepted spellings, 'NHC...' and 'NC...'.
    match = re.match(r'NH?C(\d+)H(\d+)', group)
    if not match:
        return False

    C_count = int(match.group(1))
    H_count = int(match.group(2))
    if H_count != C_count * 2 + 1:
        # Explicit False: this case used to fall off the end and return
        # None, unlike the no-match case above.
        return False
    return '[N]' + 'C' * C_count
|
|
|
def C_F_expand(group):
    """Expand a perfluoroalkyl shorthand into a SMILES chain.

    Accepted inputs:
      * text starting with 'CnFm' (e.g. 'C2F5'),
      * text starting with 'FmCn' (e.g. 'F5C2'),
      * repeated 'CF2' units ending in 'CF3' (e.g. 'CF2CF3').

    For the counted forms m must equal 2n + 1. The first carbon of a chain
    with more than one carbon is written as the bracket atom '[C]'.

    Returns the SMILES string, or False when the text is not recognised, the
    counts do not fit, or the carbon count is zero.
    """
    match_cnfm = re.match(r'C(\d+)F(\d+)', group)
    match_fmcn = re.match(r'F(\d+)C(\d+)', group)

    if match_cnfm:
        C_count = int(match_cnfm.group(1))
        F_count = int(match_cnfm.group(2))
    elif match_fmcn:
        # The reversed spelling was matched before but the match was never
        # used, so inputs such as 'F5C2' always came back as False.
        F_count = int(match_fmcn.group(1))
        C_count = int(match_fmcn.group(2))
    elif re.match(r'(CF2)*CF3$', group):
        # Explicit chain: one carbon per CF2 unit plus the terminal CF3,
        # which satisfies F = 2C + 1 by construction.
        C_count = group.count('CF2') + 1
        F_count = 2 * C_count + 1
    else:
        return False

    if C_count < 1 or F_count != 2 * C_count + 1:
        return False

    units = ['C(F)(F)'] * (C_count - 1) + ['C(F)(F)(F)']
    if C_count > 1:
        units[0] = '[C](F)(F)'
    return ''.join(units)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| import re
|
|
|
def C_H_expand(group):
    """
    Expand a saturated alkyl label (H = 2C + 1) into SMILES.

    Handled spellings:
      * ``C3H7`` or ``H7C3``          -> ``[CH2]CC``
      * ``(H7C3)2N`` or ``N(C3H7)2``  -> ``[N](CCC)(CCC)``
        (the hydrogen count of the bracketed form is not validated)

    Labels that wrap the alkyl part in an alphabetic affix (for example
    ``CONHC3H7`` or ``C18H37HNOC``) are NOT expanded here and yield False;
    ``_expand_abbreviation`` tries ``C_H_affixExpand`` after this function.

    Returns SMILES string or False if the label is not recognised.
    """
    # Plain CnHm / HmCn.
    plain = re.match(r'^C(\d+)H(\d+)$', group)
    if plain:
        c_count, h_count = int(plain.group(1)), int(plain.group(2))
    else:
        plain = re.match(r'^H(\d+)C(\d+)$', group)
        if plain:
            h_count, c_count = int(plain.group(1)), int(plain.group(2))
    if plain:
        if h_count != 2 * c_count + 1:
            return False
        return '[CH2]' + 'C' * (c_count - 1)

    # Bracketed, repeated alkyl groups hanging off a central atom.
    bracketed = re.match(r'^\((H(\d+)C(\d+))\)(\d+)([A-Za-z]+)$', group)
    if bracketed:
        h_count, c_count = int(bracketed.group(2)), int(bracketed.group(3))
        prefix = bracketed.group(5)
        prefix_n = int(bracketed.group(4))
    else:
        bracketed = re.match(r'^([A-Za-z]+)\((C(\d+)H(\d+))\)(\d+)$', group)
        if bracketed:
            c_count, h_count = int(bracketed.group(3)), int(bracketed.group(4))
            prefix = bracketed.group(1)
            prefix_n = int(bracketed.group(5))
    if not bracketed:
        return False

    print("h_count, c_count,prefix",[h_count, c_count,prefix])
    unit_smi = 'C' * c_count
    # One parenthesised branch per repeat, all attached to the central atom.
    return f"[{prefix}]" + f"({unit_smi})" * prefix_n
|
|
|
|
|
def C_H_expand2(group):
    """
    Expand CnHm or HmCn labels with H = 2C - 1 into SMILES.

    Accepted spellings: ``C6H11``, ``H11C6``, and either of those with an
    alphabetic affix in front (``NHC6H11``) or behind (``H11C6HN``).
    The carbon count must be even; the carbons are emitted as a chain of
    ``C#C`` units whose first atom is written ``[C]``.

    NOTE(review): the emitted ``C#C`` chain carries fewer hydrogens than the
    matched formula states — confirm this skeleton is the intended expansion.

    Returns SMILES string or False if invalid.
    """
    ABBREVIATIONS2 = {
        'NH': '[NH]', 'CONH': '[C](=O)[NH]', 'HN': '[NH]', 'HNO': '[NH]O', 'NO': '[N]=O'
    }

    def validate_and_expand(c_count, h_count, prefix=None, suffix=None):
        """Validate the counts and build the SMILES, adding any affix."""
        if h_count != 2 * c_count - 1:
            return False
        if c_count % 2 != 0:
            print(f"C#C , c_count have to be 2n!!!")
            return False

        # c_count is even and >= 2 here: one leading "[C]#C" unit plus
        # (c_count / 2 - 1) further "C#C" units.
        smiles = '[C]#C' + 'C#C' * (c_count // 2 - 1)
        if prefix:
            smiles = ABBREVIATIONS2.get(prefix, prefix) + smiles
        if suffix:
            smiles += ABBREVIATIONS2.get(suffix, suffix)
        return smiles

    match_cn_hm = re.match(r'^C(\d+)H(\d+)$', group)
    if match_cn_hm:
        return validate_and_expand(int(match_cn_hm.group(1)), int(match_cn_hm.group(2)))

    match_hm_cn = re.match(r'^H(\d+)C(\d+)$', group)
    if match_hm_cn:
        return validate_and_expand(int(match_hm_cn.group(2)), int(match_hm_cn.group(1)))

    # (pattern, carbon group, hydrogen group, affix group, affix is a prefix)
    affixed_forms = (
        (r'^([A-Za-z]+)C(\d+)H(\d+)$', 2, 3, 1, True),
        (r'^C(\d+)H(\d+)([A-Za-z]+)$', 1, 2, 3, False),
        (r'^([A-Za-z]+)H(\d+)C(\d+)$', 3, 2, 1, True),
        (r'^H(\d+)C(\d+)([A-Za-z]+)$', 2, 1, 3, False),
    )
    for pattern, c_idx, h_idx, affix_idx, is_prefix in affixed_forms:
        match = re.match(pattern, group)
        if match:
            affix = match.group(affix_idx)
            return validate_and_expand(
                int(match.group(c_idx)), int(match.group(h_idx)),
                prefix=affix if is_prefix else None,
                suffix=None if is_prefix else affix,
            )

    return False
|
|
|
|
|
def H_C_expand(group):
    """Expand an alkyl label written hydrogen-first (``H<m>C<n>``) into SMILES.

    Accepted spellings, all requiring ``m == 2 * n + 1``:

    * ``H7C3``            -> ``'[C]'`` followed by ``n - 1`` carbons;
    * ``H7C3<suffix>``    -> suffix SMILES followed by ``n`` carbons;
    * ``<prefix>H7C3``    -> prefix SMILES followed by ``n`` carbons.

    An affix is first passed through ``normalize_ocr_text`` and then looked
    up in ``ABBREVIATIONS`` (using the entry's ``.smiles``); an affix that is
    not in the table is used verbatim.

    Returns:
        The SMILES string, or ``False`` when the label is not recognised or
        the hydrogen count does not fit.
    """
    elements = ['S', 'N', 'P', 'C', 'O']
    replacement_map = {
        f"{element}{marker}": f"{element}*"
        for element in elements
        for marker in ['R"', "R'", "R", "*"]
    }

    def build(h_count, c_count, affix=None):
        """Validate the counts and assemble the SMILES for one match."""
        if h_count != 2 * c_count + 1:
            return False
        if affix:
            affix = normalize_ocr_text(affix, replacement_map)
            affix_smi = ABBREVIATIONS[affix].smiles if affix in ABBREVIATIONS else affix
            return affix_smi + 'C' * c_count
        return '[C]' + 'C' * (c_count - 1)

    # In every pattern the first number belongs to H and the second to C.
    plain = re.match(r'H(\d+)C(\d+)$', group)
    if plain:
        return build(int(plain.group(1)), int(plain.group(2)))

    # The suffix must start with a non-digit so a multi-digit carbon count
    # (e.g. "H23C11") is never split into "C1" plus a suffix of "1".
    with_suffix = re.match(r'H(\d+)C(\d+)(\D.*)', group)
    if with_suffix:
        return build(int(with_suffix.group(1)), int(with_suffix.group(2)),
                     with_suffix.group(3))

    with_prefix = re.match(r'(.+)H(\d+)C(\d+)$', group)
    if with_prefix:
        return build(int(with_prefix.group(2)), int(with_prefix.group(3)),
                     with_prefix.group(1))

    return False
|
|
|
def C_F_expand(group):
    """Expand a perfluoroalkyl label into SMILES.

    Accepts either ``C<n>F<m>`` (prefix match, requires ``m == 2n + 1``) or a
    literal run of zero or more ``CF2`` ending in ``CF3``.  The first carbon
    is written ``[C](F)(F)``, inner carbons ``C(F)(F)`` and the terminal
    carbon ``[C](F)(F)(F)``.

    Returns the SMILES string, or False when the label is not recognised.
    """
    counted = re.match(r'C(\d+)F(\d+)', group)
    if counted is not None:
        carbons, fluorines = (int(value) for value in counted.groups())
        if fluorines != 2 * carbons + 1:
            return False
    elif re.match(r'(CF2)*CF3$', group) is None:
        return False
    else:
        # One carbon per CF2 unit plus the closing CF3.
        carbons = group.count('CF2') + 1

    chain = ['C(F)(F)'] * carbons
    if chain:
        chain[0] = '[C](F)(F)'
        # Assigned last so that a single-carbon chain is just the terminal unit.
        chain[-1] = '[C](F)(F)(F)'
    return ''.join(chain)
|
|
|
|
|
|
|
# Tokenizer for condensed formulas: every ABBREVIATIONS key (in table order),
# then R-group labels, element symbols, integers and parentheses.
original_str = '|'.join(list(ABBREVIATIONS.keys()))

# Escape each key individually so that any regex metacharacter in a key
# (not only '*', '(' and ')') is matched literally.
escaped_str = '|'.join(re.escape(key) for key in ABBREVIATIONS.keys())

FORMULA_REGEX_str = '(' + escaped_str + r'|R[0-9]*|[A-Z][a-z]+|[A-Z]|[0-9]+|\(|\))'

FORMULA_REGEX = re.compile(FORMULA_REGEX_str)
|
|
|
| def _parse_tokens(tokens: list):
|
| """
|
| Parse tokens of condensed formula into list of pairs `(elt, num)`
|
| where `num` is the multiplicity of the atom (or nested condensed formula) `elt`
|
| Used by `_parse_formula`, which does the same thing but takes a formula in string form as input
|
| """
|
| elements = []
|
| i = 0
|
| j = 0
|
| while i < len(tokens):
|
| if tokens[i] == '(':
|
| while j < len(tokens) and tokens[j] != ')':
|
| j += 1
|
| elt = _parse_tokens(tokens[i + 1:j])
|
| else:
|
| elt = tokens[i]
|
| j += 1
|
| if j < len(tokens) and tokens[j].isnumeric():
|
| num = int(tokens[j])
|
| j += 1
|
| else:
|
| num = 1
|
| elements.append((elt, num))
|
| i = j
|
| return elements
|
|
|
|
|
def _parse_formula(formula: str):
    """
    Tokenize a condensed formula with `FORMULA_REGEX` and parse the tokens
    into a list of pairs `(elt, num)`, where `num` is the subscript of the
    atom (or nested condensed formula) `elt`.

    Example: "C2H4O" -> [('C', 2), ('H', 4), ('O', 1)]
    """
    return _parse_tokens(FORMULA_REGEX.findall(formula))
|
|
|
|
|
| def _expand_carbon(elements: list):
|
| """
|
| Given list of pairs `(elt, num)`, output single list of all atoms in order,
|
| expanding carbon sequences (CaXb where a > 1 and X is halogen) if necessary
|
| Example: [('C', 2), ('H', 4), ('O', 1)] -> ['C', 'H', 'H', 'C', 'H', 'H', 'O'])
|
| """
|
| expanded = []
|
| i = 0
|
| while i < len(elements):
|
| elt, num = elements[i]
|
|
|
| if num > 100000:
|
| i += 1; continue
|
|
|
| if elt == 'C' and num > 1 and i + 1 < len(elements):
|
| next_elt, next_num = elements[i + 1]
|
| if next_num > 100000:
|
| i += 1; continue
|
| quotient, remainder = next_num // num, next_num % num
|
| for _ in range(num):
|
| expanded.append('C')
|
| for _ in range(quotient):
|
| expanded.append(next_elt)
|
| for _ in range(remainder):
|
| expanded.append(next_elt)
|
| i += 2
|
|
|
| elif isinstance(elt, list):
|
| new_elt = _expand_carbon(elt)
|
| for _ in range(num):
|
| expanded.append(new_elt)
|
| i += 1
|
|
|
| else:
|
| for _ in range(num):
|
| expanded.append(elt)
|
| i += 1
|
| if expanded==[]:
|
| return False
|
| else:
|
| return expanded
|
|
|
def replace_bracket(match):
    """Regex replacement callback for one bracketed token.

    `match.group(1)` is the text between the brackets.  The brackets are kept
    when that text contains a digit, ``+`` or ``-``, or is exactly ``H``.
    Otherwise the brackets are dropped, and for text of two or more
    characters every ``H`` is removed as well.
    """
    inner = match.group(1)
    if re.search(r'\d|\+|-', inner) or inner == 'H':
        return f'[{inner}]'
    if len(inner) >= 2 and 'H' in inner:
        return inner.replace('H', '')
    return inner
|
|
|
|
|
|
|
def formula_regex(abbrev):
    """Tokenize `abbrev` with `FORMULA_REGEX`, parse and expand it.

    Returns the value produced by `_expand_carbon`, except that an empty
    list is reported as False.
    """
    expanded = _expand_carbon(_parse_tokens(FORMULA_REGEX.findall(abbrev)))
    if expanded != []:
        return expanded
    return False
|
|
|
def _expand_abbreviationMS(abbrev):
    """
    Return the SMILES fragment for a single formula token.

    Lookup order: the `ABBREVIATIONS` table; then R-group symbols, where a
    numbered symbol such as ``R3`` becomes ``[3*]`` and an unnumbered one
    ``*``; anything else is wrapped in brackets unchanged.

    Used in `_condensed_formula_list_to_smiles` when writing each atom.
    """
    if abbrev in ABBREVIATIONS:
        return ABBREVIATIONS[abbrev].smiles

    index_part = abbrev[1:]
    is_rgroup = abbrev in RGROUP_SYMBOLS or (
        abbrev[0] in RGROUP_SYMBOLS and index_part.isdigit()
    )
    if not is_rgroup:
        return f'[{abbrev}]'
    return f'[{index_part}*]' if index_part.isdigit() else '*'
|
|
|
|
|
| def _get_bond_symb(bond_num):
|
| """
|
| Get SMILES symbol for a bond given bond order
|
| Used in `_condensed_formula_list_to_smiles` while writing the SMILES string
|
| """
|
| if bond_num == 0:
|
| return '.'
|
| elif bond_num == 1:
|
| return ''
|
| elif bond_num == 2:
|
| return '='
|
| elif bond_num == 3:
|
| return '#'
|
| else:
|
| print(f"check this val {bond_num} !!!" )
|
|
|
| return ''
|
def _condensed_formula_list_to_smiles(formula_list, start_bond, end_bond=None, direction=None):
    """
    Converts condensed formula (in the form of a list of symbols) to smiles
    Input:
    `formula_list`: e.g. ['C', 'H', 'H', 'N', ['C', 'H', 'H', 'H'], ['C', 'H', 'H', 'H']] for CH2N(CH3)2
    `start_bond`: # bonds attached to beginning of formula
    `end_bond`: # bonds attached to end of formula (deduce automatically if None)
    `direction` (1, -1, or None): direction in which to process the list (1: left to right; -1: right to left; None: try left to right first, then right to left)
    Returns:
    `smiles`: smiles corresponding to input condensed formula (None if both directions fail)
    `bonds_left`: bonds remaining at the end of the formula (for connecting back to main molecule); should equal `end_bond` if specified
    `num_trials`: number of trials
    `success` (bool): whether conversion was successful

    NOTE(review): inside `dfs`, any element that is not a `str` is rejected
    before the nested-list branch is reached, so nested formulas such as the
    example above currently fail. Confirm whether that is intended.
    """
    # No direction given: try both and return the first one that succeeds.
    if direction is None:
        num_trials = 1
        for dir_choice in [1, -1]:
            smiles, bonds_left, trials, success = _condensed_formula_list_to_smiles(formula_list, start_bond, end_bond, dir_choice)
            num_trials += trials
            if success:
                return smiles, bonds_left, num_trials, success
        return None, None, num_trials, False
    assert direction == 1 or direction == -1

    def dfs(smiles, bonds_left, cur_idx, add_idx):
        """
        `smiles`: SMILES string so far
        `bonds_left`: bonds remaining on current atom for subsequent atoms to be attached to
        `cur_idx`: index (in `formula_list`) of current atom (i.e. atom to which subsequent atoms are being attached)
        `add_idx`: index (in `formula_list`) of atom to be attached to current atom
        Note: "atom" could refer to nested condensed formula (e.g. CH3 in CH2N(CH3)2)
        Returns the same 4-tuple as the enclosing function.
        """
        num_trials = 1

        # Ran off the end of the list in the chosen direction: every atom has
        # been placed, so succeed unless a required `end_bond` is not met.
        if (direction == 1 and add_idx == len(formula_list)) or (direction == -1 and add_idx == -1):
            if end_bond is not None and end_bond != bonds_left:
                return smiles, bonds_left, num_trials, False
            return smiles, bonds_left, num_trials, True

        # Atoms remain but the current atom has no free bond to attach them.
        if bonds_left <= 0:
            return smiles, bonds_left, num_trials, False
        to_add = formula_list[add_idx]
        # NOTE(review): this guard also rejects lists, which makes the
        # `isinstance(to_add, list)` branch below unreachable.
        if not isinstance(to_add, str):
            return smiles, bonds_left, num_trials, False
        if isinstance(to_add, list):
            if bonds_left > 1:
                # The nested group does not use up the current atom's bonds:
                # write it as a parenthesised branch and stay on the current atom.
                add_str, val, trials, success = _condensed_formula_list_to_smiles(to_add, 1, None, direction)
                if val > 0:
                    add_str = _get_bond_symb(val + 1) + add_str
                num_trials += trials
                if not success:
                    return smiles, bonds_left, num_trials, False

                result = dfs(smiles + f'({add_str})', bonds_left - 1, cur_idx, add_idx + direction)
            else:
                # The nested group takes the last free bond: append it inline
                # and make it the new current atom.
                add_str, bonds_left, trials, success = _condensed_formula_list_to_smiles(to_add, 1, None, direction)
                num_trials += trials
                if not success:
                    return smiles, bonds_left, num_trials, False

                result = dfs(smiles + add_str, bonds_left, add_idx, add_idx + direction)
            smiles, bonds_left, trials, success = result
            num_trials += trials
            return smiles, bonds_left, num_trials, success

        # Try each valence listed for this symbol (1 when the symbol has no entry).
        for val in VALENCES.get(to_add, [1]):
            add_str = _expand_abbreviationMS(to_add)
            if bonds_left > val:
                # The new atom is fully consumed by the current atom: write it
                # as a branch and keep attaching to the current atom.
                if cur_idx >= 0:
                    add_str = _get_bond_symb(val) + add_str
                result = dfs(smiles + f'({add_str})', bonds_left - val, cur_idx, add_idx + direction)
            else:
                # The new atom uses all remaining bonds of the current atom
                # and becomes the current atom with `val - bonds_left` bonds free.
                if cur_idx >= 0:
                    add_str = _get_bond_symb(bonds_left) + add_str
                result = dfs(smiles + add_str, val - bonds_left, add_idx, add_idx + direction)
            trials, success = result[2:]
            num_trials += trials
            if success:
                return result[0], result[1], num_trials, success
            # Give up on further valences once the search has become too large.
            if num_trials > 10000:
                break
        return smiles, bonds_left, num_trials, False

    # Start just outside the list on the side the traversal begins from.
    cur_idx = -1 if direction == 1 else len(formula_list)
    add_idx = 0 if direction == 1 else len(formula_list) - 1
    return dfs('', start_bond, cur_idx, add_idx)
|
|
|
def swap_paren_bracket(text):
    """Move a leading parenthesised branch behind the bracket atom after it.

    ``(branch)[atom]rest`` becomes ``[atom](branch)rest``.  Text that does not
    start with ``(`` or does not have that shape is returned unchanged.
    """
    if not text.startswith('('):
        return text

    match = re.match(r'^\((.*?)\)\[(.*?)\]', text)
    if not match:
        return text

    # Keep whatever follows the swapped head; dropping it would truncate the
    # SMILES string.
    return f'[{match.group(2)}]({match.group(1)})' + text[match.end():]
|
|
|
def convert_ch2_string(s):
    """Expand a label of the exact form ``(CH2)<n>`` into a carbon chain.

    For a numeric ``n`` the result is ``[CH2]`` followed by ``n - 1`` ``C``
    characters.  When the multiplier is alphabetic (e.g. ``(CH2)n``) it is
    printed together with the label and the label is returned unchanged, as
    is any string that does not have this form.
    """
    found = re.fullmatch(r'\(CH2\)(\d+|[a-zA-Z]+)', s)
    if found is None:
        return s

    multiplier = found.group(1)
    if not multiplier.isdigit():
        print(multiplier,s)
        return s
    return '[CH2]' + 'C' * (int(multiplier) - 1)
|
|
|
|
|
def process_string_joinused(s):
    """Strip hydrogens from the leading bracket atom of a SMILES fragment.

    When `s` starts with ``[...]`` whose content is longer than one character
    and contains ``H``, every ``H`` together with the digits directly after
    it is removed from that bracket; the rest of the string is kept.
    Any other input is returned unchanged.
    """
    parsed = re.match(r'^\[([^\]]*)\](.*)$', s)
    if parsed is None:
        return s

    inside, tail = parsed.groups()
    if len(inside) <= 1 or 'H' not in inside:
        return s
    return '[' + re.sub(r'H\d*', '', inside) + ']' + tail
|
|
|
def all_elements_in_dict(lst, dictionary):
    """
    Recursively check that every element of a (possibly nested) list is a
    key of the dictionary.

    :param lst: list to check; may contain nested lists
    :param dictionary: dictionary whose keys are tested
    :return: True if every element is a key of the dictionary, otherwise False
    """
    return all(
        all_elements_in_dict(item, dictionary) if isinstance(item, list)
        else item in dictionary
        for item in lst
    )
|
|
|
def expand_cf2_to_smiles(input_string):
    """Expand a ``(CF2)<n>[tail]`` label into a perfluorinated chain.

    The result is ``[C](F)(F)`` followed by ``n - 1`` ``C(F)(F)`` units and,
    when a tail is present, the tail text wrapped in brackets.  With
    ``n == 0`` only the bracketed tail is returned.  Input that does not
    start with ``(CF2)<digits>`` is returned unchanged.
    """
    # The tail is optional and must start with a letter, so the digits of a
    # multi-digit repeat count can never be split off into the tail.
    match = re.match(r'\(CF2\)(\d+)([A-Za-z][A-Za-z0-9]*)?', input_string)
    if not match:
        return input_string

    n = int(match.group(1))
    tail = match.group(2)
    tail_group = f"[{tail}]" if tail else ''
    if n <= 0:
        return tail_group

    cf2_unit = 'C(F)(F)'
    return '[C](F)(F)' + cf2_unit * (n - 1) + tail_group
|
|
|
def find_repeating_unit_and_smiles(s):
    """Detect a label made only of one unit repeated two or more times.

    The shortest repeating unit is used.  ``CH2``, ``CF2`` and ``SO2`` have
    SMILES templates; for any other unit a reminder is printed and the
    SMILES is the empty string.

    Returns:
        ``(smiles, repeat_count, unit)``, or ``(None, 0, None)`` when `s` is
        not a pure repetition.
    """
    repeated = re.fullmatch(r'(.+?)(?:\1)+', s)
    if repeated is None:
        return None, 0, None

    unit = repeated.group(1)
    repeat_count = len(s) // len(unit)

    # unit -> (SMILES of the first occurrence, SMILES of each later one)
    templates = {
        "CH2": ("[CH2]", "C"),
        "CF2": ("[C](F)(F)", "C(F)(F)"),
        "SO2": ("[S](=O)(=O)", "S(=O)(=O)"),
    }
    if unit in templates:
        first, following = templates[unit]
    else:
        first, following = '', ''
        print(f'please add the repateat patter here !!! for: {s}')

    return first + following * (repeat_count - 1), repeat_count, unit
|
|
|
# Hand-curated expansions for labels (including OCR misreads) that the generic
# parsers do not handle.  A value of None marks a label that is listed but has
# no expansion defined, so the lookup reports failure for it.
_SPECIAL_SYMBOL_SMILES = {
    'CH2CH': '[CH2][CH]',
    'CHCH2': '[CH2][CH]',
    'CH2CH2': '[CH2][CH2]',
    'CH2CH2CH': '[CH2][CH2][CH]',
    'H2CH2CHC': '[CH2][CH2][CH]',
    'CHCH2CH2': '[CH2][CH2][CH]',
    '(CH2)10': None,
    'H2C': '[CH2]',
    'CH2': '[CH2]',
    'OCH2CHOHCH2NH': '[O]CC(O)CN',
    'OCH2CHOHCH2': None,
    'CF2O': '[C](F)(F)[O]',
    'OF2C': '[O][C](F)(F)',
    'EtO2CHN': '[N]C(=O)OCC',
    'EtO2C': 'C(=O)OCC',
    'CH2CH2C(O)0CH2CH3': '[CH2]CC(=O)OCC',
    'CH2CH2C(O)OCH2CH3': '[CH2]CC(=O)OCC',
    'l23I': '[I]',
    'OCH2CH2OH': '[O]CCO',
    'OCH2CHCH2CCH3': '[O]C[CH]C[C]C',
    'CH2O': '[CH2][O]',
    '(H4NO)2': '[O]NON',
    'SO2NHCH2CH': '[S](=O)(=O)NC[CH]',
    'OCH2CH': '[O]C[CH]',
    'OCF2H': '[O]C(F)(F)',
    'COCOOCH2CH3': 'C(=O)C(=O)OCC',
    'CH2CH2CH2CH': '[CH2]CC[CH]',
    'HCH2CH2CH2C': '[CH]CC[CH2]',
    'CF3CF2CF2CF2SO3': '[S](=O)(=O)([O-])C(F)(F)C(F)(F)C(F)(F)C(F)(F)(F)',
    '(CF2)8H': '[C](F)(F)C(F)(F)C(F)(F)C(F)(F)C(F)(F)C(F)(F)C(F)(F)C(F)(F)',
    'PH3C': '[CH2]P',
    'CO': '[C](=O)',
    'OC': '[C](=O)',
    'CF2CF2H': '[C](F)(F)C(F)(F)',
    'NHSO2CH3': '[NH]S(=O)(=O)C',
    'CH2CH2C': '[CH2]C[C]',
    'CH;CH2C(O)0CHCH3': '[CH2]CC(=O)OCC',
    'CH2CH2C(O)OCHCH3': '[CH2]CC(=O)OCC',
    'NH2': '[NH2]',
    'H2N': '[NH2]',
    'CHO': '[CH](=O)',
    'OHC': '[CH](=O)',
    'N(SO2CH3)2': '[N](S(=O)(=O)C)(S(=O)(=O)C)',
    'CH2CH2O': '[CH2]CO',
    'Ar2P(O)': '[P](*)(*)(=O)',
    'PhO2S': '[S](=O)(=O)c1ccccc1',
    'NHP(O)Ph2': '[NH]P(=O)(c1ccccc1)c1ccccc1',
    'P*Ph3': '[P+](c1ccccc1)(c1ccccc1)(c1ccccc1)',
    'P+Ph3': '[P+](c1ccccc1)(c1ccccc1)(c1ccccc1)',
    'NH2.HCl': '[NH2].Cl',
    'S[O]a': '[S](=O)',
    '(C3H6O)7CH3': '[O]CCC' + 'OCCC' * 6 + 'C',
    'HC': '[CH]',
    '(HC': '[CH]',
    '(CH2CH2CH2CH-)': '[CH2]CC[CH]',
    '3(CHCHCHCH272': '[CH]CC[CH2]',
    'NHzBrH': '[NH2].Br',
    'NH2BrH': '[NH2].Br',
    '(co)': '[C](=O)',
    '(CO)': '[C](=O)',
    'CH3CH': '[CH]C',
    'CH3CCH3': '[C](C)(C)',
    'CH3CO': '[C](=O)C',
    'CH3OCH2': '[CH2]OC',
    'CO2C': '[C](=O)O[C]',
    'CH2CO2CH3': '[CH2]C(=O)OC',
    'COCl': '[C](=O)Cl',
}


def get_smiles_from_symbol(symbol, mol, bonds):
    """
    Convert symbol (abbrev. or condensed formula) to smiles.

    Strategies are tried in this order and the first that applies wins:
      1. the `ABBREVIATIONS` table;
      2. R-group symbols (``R3`` -> ``[3*]``, bare symbol -> ``*``);
      3. ``(CH2)n`` via `convert_ch2_string`;
      4. ``(CF2)...`` via `expand_cf2_to_smiles`;
      5. a known repeated unit via `find_repeating_unit_and_smiles`;
      6. the hand-curated `_SPECIAL_SYMBOL_SMILES` table;
      7. generic condensed-formula parsing, with the summed order of `bonds`
         as the number of bonds at the start of the formula;
      8. if parsing fails but every token is an `ABBREVIATIONS` key, the
         concatenation of the tokens' SMILES.

    `mol` is accepted for interface compatibility and is not read here.

    Returns the SMILES string, or None when no strategy applies (symbols
    longer than 20 characters are rejected after step 2).
    """
    if symbol in ABBREVIATIONS:
        return ABBREVIATIONS[symbol].smiles
    if not symbol:
        # Nothing to expand; also avoids indexing symbol[0] below.
        return None
    if symbol in RGROUP_SYMBOLS or (symbol[0] in RGROUP_SYMBOLS and symbol[1:].isdigit()):
        if symbol[1:].isdigit():
            return f'[{symbol[1:]}*]'
        return '*'

    if len(symbol) > 20:
        return None

    smiles = convert_ch2_string(symbol)
    if smiles != symbol:
        return smiles

    if '(CF2)' in symbol:
        return expand_cf2_to_smiles(symbol)

    smiles, repeat_count, unit = find_repeating_unit_and_smiles(symbol)
    # An unknown repeated unit yields an empty string; in that case keep going
    # so the table and the formula parser still get a chance.
    if repeat_count > 0 and smiles:
        return smiles

    if symbol in _SPECIAL_SYMBOL_SMILES:
        return _SPECIAL_SYMBOL_SMILES[symbol]

    total_bonds = int(sum(bond.GetBondTypeAsDouble() for bond in bonds))
    formula_list = _expand_carbon(_parse_formula(symbol))
    if not formula_list:
        # No recognisable tokens at all.
        return None

    all_in_dict = all_elements_in_dict(formula_list, ABBREVIATIONS)

    smiles, bonds_left, num_trails, success = _condensed_formula_list_to_smiles(formula_list, total_bonds, None)

    print(f'{[formula_list, total_bonds]} use _condensed_formula_list_to_smiles {success} <<-------\n {smiles}')
    if success:
        return swap_paren_bracket(smiles)

    if all_in_dict:
        # Every token is a known abbreviation: join their SMILES, stripping
        # the hydrogens from the leading bracket atom of all but the first.
        key = extract_abbreviation_key(formula_list[0])
        if key in ABBREVIATIONS:
            smiles = ABBREVIATIONS[key].smiles
        else:
            print(f"Abbreviation {key} not found in ABBREVIATIONS.")
            smiles = ''
        for token in formula_list[1:]:
            smiles += process_string_joinused(ABBREVIATIONS[token].smiles)
        return smiles

    return None
|
|
|
def abbrev2smile(abbrev,abbrev_exp,mol,idx):
    """Expand the label of atom `idx` in `mol` via `get_smiles_from_symbol`.

    The bonds of that atom are passed along with the label.  `abbrev_exp` is
    only used in the failure message.

    Returns the SMILES string, or ``'[*]'`` (after printing a message) when
    the expansion is empty or None.
    """
    attached_bonds = mol.GetAtomWithIdx(idx).GetBonds()
    sub_smi = get_smiles_from_symbol(abbrev, mol, attached_bonds)
    if not sub_smi:
        print(f"failed expanding {abbrev},{abbrev_exp}\n{sub_smi}\t{idx}")
        return '[*]'
    return sub_smi
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def replace_cg_notation(astr):
    """Rewrite every ``CgH<m>`` in `astr` as ``C<n>H<m>`` with ``n = (m - 1) // 2``.

    Text outside such matches is left untouched.
    """
    def _with_carbon_count(found):
        hydrogens = int(found.group(1))
        carbons = (hydrogens - 1) // 2
        return 'C' + str(carbons) + 'H' + str(hydrogens)

    return re.compile(r'CgH(\d+)').sub(_with_carbon_count, astr)
|
|
|
|
|
def _expand_abbreviation(abbrev, mol,idx):
    """
    Expand abbreviation into its SMILES; also converts [Rn] to [n*].

    Lookup order: the `ABBREVIATIONS` table, the pattern expanders
    (`N_C_H_expand`, `C_F_expand`, `C_H_expand2`, `C_H_expand`,
    `C_H_affixExpand`), numbered R-groups, plain element symbols, and the
    generic condensed-formula path (`formula_regex` + `abbrev2smile`).
    Whatever is left is treated as ``<digits><element>`` (returned in
    brackets) or ``<digits>...`` (returned as ``[<digits>*]``).

    Returns a SMILES fragment; ``'[*]'`` when nothing applies.
    """
    if abbrev in ABBREVIATIONS:
        return ABBREVIATIONS[abbrev].smiles

    # Each expander returns a SMILES string or a falsy value.  Call each one
    # once and reuse the result.
    for expander in (N_C_H_expand, C_F_expand, C_H_expand2, C_H_expand, C_H_affixExpand):
        expanded = expander(abbrev)
        if expanded:
            return expanded

    # `abbrev and ...` keeps an empty label from raising on abbrev[0].
    if abbrev in RGROUP_SYMBOLS or (abbrev and abbrev[0] in RGROUP_SYMBOLS and abbrev[1:].isdigit()):
        if abbrev[1:].isdigit():
            return f'[{abbrev[1:]}*]'
        # A bare R-group symbol falls through to the generic handling below.
    elif abbrev in ELEMENTS:
        return f'[{abbrev}]'
    else:
        abbrev_exp = formula_regex(abbrev)
        if abbrev_exp:
            return abbrev2smile(abbrev,abbrev_exp,mol,idx)

    # This pattern matches any string: an optional leading number followed by
    # the rest of the text (up to a newline).
    numeric_part, remaining_part = re.match(r'^(\d+)?(.*)', abbrev).groups()
    if remaining_part in ELEMENTS:
        return f'[{abbrev}]'
    if numeric_part:
        return f'[{numeric_part}*]'

    return '[*]'
|
|
|
def count_current_bonds(mol, atom_idx):
    """Return the summed bond order of all bonds on atom `atom_idx` of `mol`."""
    total = 0
    for bond in mol.GetAtomWithIdx(atom_idx).GetBonds():
        total += bond.GetBondTypeAsDouble()
    return total
|
|
|
# When True, a group whose SMILES cannot be turned into a molecule is replaced
# by a '[*]' placeholder; when False the conversion error is re-raised.
debug_not=True


def expandABB(mol, ABBREVIATIONS, placeholder_atoms):
    """Replace placeholder atoms of `mol` with their expanded groups.

    Args:
        mol: RDKit molecule that contains the placeholder atoms.
        ABBREVIATIONS: not referenced in the body of this function.
        placeholder_atoms: mapping ``atom index -> label text``.

    For every placeholder (highest index first) the label is expanded with
    `_expand_abbreviation`, the placeholder atom is removed, the group is
    appended to the molecule and the placeholder's former bonds are
    re-attached to the group's anchor atoms.  Anchor atoms are atom 0 of the
    group plus every atom that carries radical electrons.

    Returns:
        ``(molecule, smiles)`` for the molecule after all replacements.
    """
    for idx in sorted(placeholder_atoms.keys(), reverse=True):
        group = placeholder_atoms[idx]
        group_smiles = _expand_abbreviation(group, mol, idx)
        submol = Chem.MolFromSmiles(group_smiles)
        try:
            submol_rw = Chem.RWMol(submol)
        except Exception as e:
            print(f"abbver: {group}")
            print(f'try to convert {group_smiles} to sub_mol')
            print(e)
            if debug_not:
                print(f"Failed to convert {group_smiles} to sub_mol, using placeholder [*] instead.")
                submol = Chem.MolFromSmiles('[*]')
                submol_rw = Chem.RWMol(submol)
            else:
                raise

        # Atom 0 is always an anchor; radical-bearing atoms are additional ones.
        anchor_atoms = [0]
        for atom in submol_rw.GetAtoms():
            if atom.GetNumRadicalElectrons() > 0 and atom.GetIdx() not in anchor_atoms:
                anchor_atoms.append(atom.GetIdx())

        new_mol = Chem.RWMol(mol)
        placeholder_idx = idx

        # Remember every bond of the placeholder before it is removed.
        bonds_info = []
        for bond in new_mol.GetBonds():
            if bond.GetBeginAtomIdx() == placeholder_idx:
                bonds_info.append({
                    "neighbor": bond.GetEndAtomIdx(),
                    "bond_type": bond.GetBondType()
                })
            elif bond.GetEndAtomIdx() == placeholder_idx:
                bonds_info.append({
                    "neighbor": bond.GetBeginAtomIdx(),
                    "bond_type": bond.GetBondType()
                })

        for bond_info in bonds_info:
            new_mol.RemoveBond(placeholder_idx, bond_info["neighbor"])

        new_mol.RemoveAtom(placeholder_idx)

        # Neighbours that sat behind the removed atom are referenced one
        # index lower from here on.
        adjusted_bonds_info = []
        for bond_info in bonds_info:
            neighbor = bond_info["neighbor"]
            if neighbor < placeholder_idx:
                adjusted_neighbor = neighbor
            else:
                adjusted_neighbor = neighbor - 1
            adjusted_bonds_info.append({
                "neighbor": adjusted_neighbor,
                "bond_type": bond_info["bond_type"]
            })

        new_mol = Chem.RWMol(Chem.CombineMols(new_mol, submol_rw))

        # The group's atoms were appended at the end of the combined molecule.
        submol_atom_offset = new_mol.GetNumAtoms() - submol_rw.GetNumAtoms()
        new_anchor_indices = [submol_atom_offset + anchor_idx for anchor_idx in anchor_atoms]

        if len(new_anchor_indices) == 1:
            # Single anchor: every former bond goes to it.
            anchor_idx = new_anchor_indices[0]
            for bond_info in adjusted_bonds_info:
                neighbor = bond_info["neighbor"]
                bond_type = bond_info["bond_type"]
                new_mol.AddBond(neighbor, anchor_idx, bond_type)

                a1 = new_mol.GetAtomWithIdx(neighbor)
                a2 = new_mol.GetAtomWithIdx(anchor_idx)
                a1.SetNumRadicalElectrons(0)
                a2.SetNumRadicalElectrons(0)
        else:
            if len(adjusted_bonds_info) > len(new_anchor_indices):
                print(adjusted_bonds_info,' <---adjusted_bonds_info')
                print(new_anchor_indices,'<---new_anchor_indices')

            anchor_bond_counts = {
                anchor: new_mol.GetAtomWithIdx(anchor).GetTotalValence()
                for anchor in new_anchor_indices
            }
            print(anchor_bond_counts,'<---anchor_bond_counts')

            adjusted_bonds_info = sorted(adjusted_bonds_info, key=lambda x: x['neighbor'])
            # Debug output only; skipped when the placeholder had no bonds so
            # indexing the first/last entry cannot fail.
            if adjusted_bonds_info and mol.GetNumConformers() > 0:
                pos_0 = mol.GetConformer().GetAtomPosition(adjusted_bonds_info[0]['neighbor'])
                pos_1 = mol.GetConformer().GetAtomPosition(adjusted_bonds_info[-1]['neighbor'])
                print(pos_0.x,pos_1.x,"xxx",adjusted_bonds_info)

            # Several anchors: give each bond to the first anchor that still
            # has room for it under the largest valence listed in VALENCES.
            for bond_info in adjusted_bonds_info:
                neighbor = bond_info["neighbor"]
                bond_type = bond_info["bond_type"]
                bond_valence = {Chem.BondType.SINGLE: 1, Chem.BondType.DOUBLE: 2, Chem.BondType.TRIPLE: 3}.get(bond_type, 1)

                selected_anchor_idx = None
                for anchor_idx in new_anchor_indices:
                    atom = new_mol.GetAtomWithIdx(anchor_idx)
                    current_valence = anchor_bond_counts[anchor_idx]
                    max_allowed = max(VALENCES.get( atom.GetSymbol(), [1]))
                    if current_valence + bond_valence <= max_allowed:
                        selected_anchor_idx = anchor_idx
                        break
                if selected_anchor_idx is None:
                    # No anchor can take this bond; it is dropped.
                    continue

                new_mol.AddBond(neighbor, selected_anchor_idx, bond_type)
                anchor_bond_counts[selected_anchor_idx] += bond_valence

                a1 = new_mol.GetAtomWithIdx(neighbor)
                a2 = new_mol.GetAtomWithIdx(selected_anchor_idx)
                a1.SetNumRadicalElectrons(0)
                a2.SetNumRadicalElectrons(0)

        mol = new_mol

    modified_smiles = Chem.MolToSmiles(mol)
    return mol, modified_smiles
|
|
|
|
|
def is_valid_chem_text(text):
    """Check that a chemical expression contains only ASCII letters, digits
    and parentheses, and that its parentheses are balanced.

    Empty text and text made up only of digits are rejected.
    """
    if not text or text.isdigit():
        return False
    if re.match(r'^[A-Za-z0-9()]+$', text) is None:
        return False

    depth = 0
    for ch in text:
        if ch == '(':
            depth += 1
        elif ch == ')':
            if depth == 0:
                # Closing parenthesis without a matching opener.
                return False
            depth -= 1
    return depth == 0
|
|
|
def select_chem_expression(orig_text, orig_score, scaled_text, scaled_score, cropped_img_orig, cropped_img_scaled):
    """Pick the more plausible of two candidate readings of one label.

    Two readings of the same label are compared: one labelled "orig" and
    one labelled "scaled" (presumably text recognised on the original and
    on a rescaled crop -- confirm with the caller).

    Decision order:
      1. If one text contains the other and is a known abbreviation, it wins.
      2. If the scores differ by more than 0.1, validity
         (``is_valid_chem_text``) decides first and the score second.
      3. Otherwise validity, membership in ``ABBREVIATIONS``, whether the
         text is expandable and text length decide.

    Args:
        orig_text, orig_score, cropped_img_orig: first candidate.
        scaled_text, scaled_score, cropped_img_scaled: second candidate.

    Returns:
        tuple: ``(text, score, image)`` of the chosen candidate, or
        ``(None, None, None)`` when no branch selects a candidate and
        ``scaled_text`` is empty.
    """
    orig_choice = (orig_text, orig_score, cropped_img_orig)
    scaled_choice = (scaled_text, scaled_score, cropped_img_scaled)

    def _expandable(text):
        # Same helper order as before so short-circuiting is unchanged.
        return N_C_H_expand(text) or C_F_expand(text) or C_H_expand2(text) or C_H_expand(text)

    score_diff = abs(orig_score - scaled_score)
    if scaled_text in orig_text and orig_text in ABBREVIATIONS:
        print(f'use orig_text as include the sacled and in ABBREVIATIONS {orig_text}')
        return orig_choice
    elif orig_text in scaled_text and scaled_text in ABBREVIATIONS:
        print(f'use scaled_text as include the orig_text and in ABBREVIATIONS {scaled_text}')
        return scaled_choice

    orig_valid = is_valid_chem_text(orig_text)
    scaled_valid = is_valid_chem_text(scaled_text)

    if score_diff > 0.1:
        # Scores clearly differ: validity first, then the higher score.
        if orig_valid and scaled_valid:
            if orig_score >= scaled_score and orig_text:
                return orig_choice
            elif scaled_text:
                return scaled_choice
        elif orig_valid and not scaled_valid:
            return orig_choice
        elif scaled_valid and not orig_valid:
            return scaled_choice
        else:
            print(f"Both texts are invalid: orig_text='{orig_text}', scaled_text='{scaled_text}'")
            if orig_score >= scaled_score:
                return orig_choice
            else:
                return scaled_choice
    else:
        # Scores are close: fall back to structural criteria.
        if orig_valid and not scaled_valid:
            return orig_choice
        elif scaled_valid and not orig_valid:
            return scaled_choice
        elif orig_valid and scaled_valid:
            if orig_text in ABBREVIATIONS and scaled_text not in ABBREVIATIONS:
                # Prefer the abbreviation unless the other text is
                # expandable and strictly longer.
                if _expandable(scaled_text):
                    if len(scaled_text) > len(orig_text):
                        return scaled_choice
                return orig_choice
            elif orig_text not in ABBREVIATIONS and scaled_text in ABBREVIATIONS:
                if _expandable(orig_text):
                    if len(orig_text) > len(scaled_text):
                        return orig_choice
                return scaled_choice
            elif orig_text not in ABBREVIATIONS and scaled_text not in ABBREVIATIONS:
                # Neither is known: longer text wins, score breaks ties.
                if len(orig_text) > len(scaled_text):
                    return orig_choice
                else:
                    if len(orig_text) == len(scaled_text):
                        if orig_score >= scaled_score:
                            return orig_choice
                        else:
                            return scaled_choice
                    return scaled_choice
            elif orig_text in ABBREVIATIONS and scaled_text in ABBREVIATIONS:
                if len(orig_text) >= len(scaled_text):
                    return orig_choice
                else:
                    return scaled_choice
        elif orig_text:
            return orig_choice
        elif scaled_text:
            return scaled_choice

    # The conditional must cover the whole tuple. Without the parentheses
    # only the image slot was conditional, so an empty result came back as
    # (scaled_text, scaled_score, (None, None, None)).
    return scaled_choice if scaled_text else (None, None, None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def boxes_overlap(box1, box2):
    """Tell whether two axis-aligned boxes ``(x1, y1, x2, y2)`` intersect.

    Boxes that only touch along an edge or at a corner count as
    overlapping, because only strict separation is treated as a miss.
    """
    left_a, top_a, right_a, bottom_a = box1
    left_b, top_b, right_b, bottom_b = box2
    apart_in_x = right_a < left_b or left_a > right_b
    apart_in_y = bottom_a < top_b or top_a > bottom_b
    return not (apart_in_x or apart_in_y)
|
|
|
def boxes_overlap2(atombonx, bondbox):
    """Locate a point near the end of a bond box that is away from an atom box.

    Only the two corners ``(bx1, by1)`` and ``(bx2, by2)`` of the bond box
    are considered as "ends". The returned point lies on the segment from
    the bond-box centre to the chosen corner, 70% of the way out.

    Args:
        atombonx: ``(x1, y1, x2, y2)`` of the atom box.
        bondbox: ``(bx1, by1, bx2, by2)`` of the bond box.

    Returns:
        tuple: ``(x, y)``.
          * boxes disjoint: towards the corner farther from the atom centre;
          * bond box fully inside the atom box: the bond-box centre itself;
          * exactly one examined corner inside the atom box: towards the
            other corner (first corner is checked first);
          * otherwise: towards the corner farther from the atom centre.
    """
    ax1, ay1, ax2, ay2 = atombonx
    bx1, by1, bx2, by2 = bondbox

    mid_x = (bx1 + bx2) / 2
    mid_y = (by1 + by2) / 2

    def _toward(corner_x, corner_y):
        # 70% of the way from the bond centre to the given corner.
        return mid_x + 0.7 * (corner_x - mid_x), mid_y + 0.7 * (corner_y - mid_y)

    def _farther_corner():
        atom_mid_x = (ax1 + ax2) / 2
        atom_mid_y = (ay1 + ay2) / 2
        d_first = ((bx1 - atom_mid_x) ** 2 + (by1 - atom_mid_y) ** 2) ** 0.5
        d_second = ((bx2 - atom_mid_x) ** 2 + (by2 - atom_mid_y) ** 2) ** 0.5
        # Ties go to the first corner.
        return (bx2, by2) if d_second > d_first else (bx1, by1)

    disjoint = bx2 < ax1 or bx1 > ax2 or by2 < ay1 or by1 > ay2
    if not disjoint:
        if bx1 >= ax1 and bx2 <= ax2 and by1 >= ay1 and by2 <= ay2:
            return mid_x, mid_y
        if ax1 <= bx1 <= ax2 and ay1 <= by1 <= ay2:
            return _toward(bx2, by2)
        if ax1 <= bx2 <= ax2 and ay1 <= by2 <= ay2:
            return _toward(bx1, by1)

    return _toward(*_farther_corner())
|
|
|
|
|
# Class ids grouped under the name "charge_labels". outputbox_update()
# excludes the ids in its `charge_labels` argument when selecting atom boxes.
# NOTE(review): presumably these are the detector's formal-charge classes --
# confirm against the label map.
charge_labels = [19,20,21,22,23]
|
def outputbox_update(output,charge_labels,bond_labels,lab2idx):
    """Add a synthetic 'H' atom box at the free end of every dangling bond.

    A bond box that overlaps exactly one atom box is treated as having an
    unlabelled end. A small square box of class ``lab2idx['H']`` with score
    1.0 is placed at the point returned by ``boxes_overlap2`` for that
    atom/bond pair.

    Args:
        output: dict of NumPy arrays with at least the keys ``'bbox'``,
            ``'bbox_centers'``, ``'scores'`` and ``'pred_classes'``. It is
            not modified; a deep copy is extended instead.
        charge_labels: class ids that are excluded from the atom set.
        bond_labels: class ids that denote bonds.
        lab2idx: mapping from label string to class id; must contain 'H'.

    Returns:
        tuple: ``(extended_output, single_odd_b2a)``. ``single_odd_b2a``
        maps the index of each dangling bond (within the bond subset) to a
        one-element list holding the index of the overlapping atom (within
        the atom subset).
    """
    pred_classes = output['pred_classes']
    # dtype=bool keeps the masks valid boolean indices even when there are
    # no detections; an empty default (float) array is rejected by NumPy.
    bonds_mask = np.array([ins in bond_labels for ins in pred_classes], dtype=bool)
    atoms_mask = np.array(
        [ins not in bond_labels and ins not in charge_labels for ins in pred_classes],
        dtype=bool)
    bond_bbox = output['bbox'][bonds_mask]
    atom_bbox = output['bbox'][atoms_mask]

    b_len = 3  # half side length of the synthetic hydrogen box
    new_atoms = []
    single_odd_b2a = {}
    for bi, bb in enumerate(bond_bbox):
        overlapped_atoms = [ai for ai, aa in enumerate(atom_bbox) if boxes_overlap(bb, aa)]
        if len(overlapped_atoms) != 1:
            continue
        single_odd_b2a[bi] = overlapped_atoms

        end_x, end_y = boxes_overlap2(atom_bbox[overlapped_atoms[0]], bb)
        new_atoms.append({
            'bbox': np.array([end_x - b_len,
                              end_y - b_len,
                              end_x + b_len,
                              end_y + b_len]).reshape(-1, 4),
            'bbox_centers': np.array([end_x, end_y]).reshape(-1, 2),
            'scores': np.array([1.0]),
            'pred_classes': np.array([lab2idx['H']]),
        })

    output2_ = copy.deepcopy(output)
    for boxout in new_atoms:
        for k, arr in boxout.items():
            value_or_row = output2_[k]
            if arr.ndim == 1:
                output2_[k] = np.append(value_or_row, arr)
            elif arr.ndim >= 2:
                output2_[k] = np.concatenate([value_or_row, arr], axis=0)
            else:
                print('errprs, unkown conditions !!!@')
    return output2_, single_odd_b2a
|
|
|
|
|
def remove_unconnected_hydrogens(mol):
    """Drop every hydrogen that is not bonded to a heavy atom.

    This covers isolated hydrogens as well as hydrogens whose only
    neighbours are other hydrogens.

    Args:
        mol: RDKit Mol object; it is copied, not modified.

    Returns:
        RWMol: editable copy of *mol* with those hydrogens removed.
    """
    editable = Chem.RWMol(mol)
    orphan_ids = [
        atom.GetIdx()
        for atom in editable.GetAtoms()
        if atom.GetSymbol() == 'H'
        and all(nbr.GetSymbol() == 'H' for nbr in atom.GetNeighbors())
    ]
    # Delete from the highest index down so earlier removals do not shift
    # the indices that are still pending.
    for idx in sorted(orphan_ids, reverse=True):
        editable.RemoveAtom(idx)
    return editable
|
|
|
| from rdkit import Chem
|
| from rdkit.Chem import AllChem
|
|
|
def remove_unconnected_hydrogens2(mol):
    """Drop hydrogens without a heavy-atom neighbour and report where they were.

    This covers isolated hydrogens as well as hydrogens whose only
    neighbours are other hydrogens. Positions are read from the molecule's
    conformer, so the molecule must have one.

    Args:
        mol: RDKit Mol object; it is copied, not modified.

    Returns:
        tuple: ``(rw_mol, removed_h_coords)`` where ``rw_mol`` is the
        editable copy after removal and ``removed_h_coords`` is a list of
        ``(x, y, z)`` tuples, one per removed hydrogen, in atom-iteration
        order.
    """
    editable = Chem.RWMol(mol)
    conf = editable.GetConformer()

    doomed = []
    removed_h_coords = []
    for atom in editable.GetAtoms():
        if atom.GetSymbol() != 'H':
            continue
        if any(nbr.GetSymbol() != 'H' for nbr in atom.GetNeighbors()):
            continue
        idx = atom.GetIdx()
        doomed.append(idx)
        point = conf.GetAtomPosition(idx)
        removed_h_coords.append((point.x, point.y, point.z))

    # Highest index first so pending indices stay valid while deleting.
    for idx in sorted(doomed, reverse=True):
        editable.RemoveAtom(idx)

    return editable, removed_h_coords
|
|
|
def detect_unconnected_hydrogens(mol):
    """Return the indices of hydrogens that have no heavy-atom neighbour.

    This covers isolated hydrogens as well as hydrogens whose only
    neighbours are other hydrogens. Nothing is removed and *mol* is not
    modified.

    The previous version also copied the molecule, required a conformer and
    collected atom coordinates that were never used or returned; that dead
    work is dropped, so molecules without a conformer are now accepted.

    Args:
        mol: RDKit Mol object.

    Returns:
        list[int]: atom indices sorted in descending order.
    """
    orphan_ids = [
        atom.GetIdx()
        for atom in mol.GetAtoms()
        if atom.GetSymbol() == 'H'
        and all(nbr.GetSymbol() == 'H' for nbr in atom.GetNeighbors())
    ]
    orphan_ids.sort(reverse=True)
    return orphan_ids
|
|
|
def view_box_center2(bond_bbox, bond_centers, bond_scores, bond_classes,overlap_dist_thresh=5.0,
                     max_centers_per_box=5,
                     plot_view=False,
                     ):
    """Filter bond detections by centre proximity and box crowding.

    Two passes are applied:
      1. Of any two centres closer than ``overlap_dist_thresh`` only the
         higher-scoring one is kept (on a tie the later one is kept).
      2. Boxes containing more than ``max_centers_per_box`` of the surviving
         centres are dropped.

    Args:
        bond_bbox: NumPy array of ``[x1, y1, x2, y2]`` boxes.
        bond_centers: NumPy array of ``[x, y]`` centres.
        bond_scores: NumPy array of scores.
        bond_classes: NumPy array of class ids, filtered alongside the rest.
        overlap_dist_thresh: distance below which two centres are duplicates.
        max_centers_per_box: maximum number of centres allowed inside a box.
        plot_view: when True, also draw the result with matplotlib.

    Returns:
        tuple: ``(bboxes, centers, scores, classes, fig)``; ``fig`` is the
        matplotlib figure when ``plot_view`` is True, otherwise None.
    """
    assert len(bond_bbox) == len(bond_centers) == len(bond_scores), "Input arrays must have equal length"
    n = len(bond_bbox)

    keep_centers = np.ones(n, dtype=bool)
    for i in range(n):
        if not keep_centers[i]:
            continue
        for j in range(i + 1, n):
            if not keep_centers[j]:
                continue
            dist = np.sqrt(np.sum((bond_centers[i] - bond_centers[j]) ** 2))
            if dist < overlap_dist_thresh:
                if bond_scores[i] > bond_scores[j]:
                    keep_centers[j] = False
                else:
                    keep_centers[i] = False
                    # Centre i is gone; it must not go on to suppress
                    # further centres that the surviving one does not reach.
                    break

    bond_bbox = bond_bbox[keep_centers]
    bond_centers = bond_centers[keep_centers]
    bond_scores = bond_scores[keep_centers]
    bond_classes = bond_classes[keep_centers]
    n = len(bond_bbox)

    keep_boxes = np.ones(n, dtype=bool)
    for i in range(n):
        x1, y1, x2, y2 = bond_bbox[i]
        centers_in_box = np.sum((bond_centers[:, 0] >= x1) & (bond_centers[:, 0] <= x2) &
                                (bond_centers[:, 1] >= y1) & (bond_centers[:, 1] <= y2))
        if centers_in_box > max_centers_per_box:
            keep_boxes[i] = False

    final_bond_bbox = bond_bbox[keep_boxes]
    final_bond_centers = bond_centers[keep_boxes]
    final_bond_scores = bond_scores[keep_boxes]
    final_bond_classes = bond_classes[keep_boxes]

    fig = None
    if plot_view:
        fig, ax = plt.subplots(figsize=(10, 10))
        for box in final_bond_bbox:
            x1, y1, x2, y2 = box
            rect = Rectangle((x1, y1), x2 - x1, y2 - y1,
                             linewidth=1, edgecolor='blue', facecolor='none')
            ax.add_patch(rect)
        for center in final_bond_centers:
            circle = Circle(center, radius=5, edgecolor='red', facecolor='none', linewidth=1)
            ax.add_patch(circle)

        # min()/max() of an empty array raise, so only fit the axis limits
        # when something survived the filtering.
        if len(final_bond_bbox) > 0:
            x_min = min(final_bond_bbox[:, 0].min(), final_bond_centers[:, 0].min()) - 10
            x_max = max(final_bond_bbox[:, 2].max(), final_bond_centers[:, 0].max()) + 10
            y_min = min(final_bond_bbox[:, 1].min(), final_bond_centers[:, 1].min()) - 10
            y_max = max(final_bond_bbox[:, 3].max(), final_bond_centers[:, 1].max()) + 10
            ax.set_xlim(x_min, x_max)
            ax.set_ylim(y_min, y_max)

        ax.set_title("Filtered Boxes and Centers")
        ax.set_xlabel("X")
        ax.set_ylabel("Y")
        plt.gca().set_aspect('equal', adjustable='box')
        plt.grid(True, linestyle='--', alpha=0.7)

    return final_bond_bbox, final_bond_centers, final_bond_scores,final_bond_classes,fig
|
|
|
def calculate_iou(box1, box2):
    """Compute the Intersection over Union (IoU) of two boxes.

    NOTE(review): a function with this same name is defined again further
    down in this module; the later definition is the one in effect after
    import.

    Args:
        box1, box2: boxes given as ``[x1, y1, x2, y2]``.

    Returns:
        The IoU value, or 0 when the union area is not positive.
    """
    inter_left = max(box1[0], box2[0])
    inter_top = max(box1[1], box2[1])
    inter_right = min(box1[2], box2[2])
    inter_bottom = min(box1[3], box2[3])

    # Negative extents mean "no overlap" and are clamped to zero.
    inter_w = max(0, inter_right - inter_left)
    inter_h = max(0, inter_bottom - inter_top)
    intersection = inter_w * inter_h

    first_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
    second_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union = first_area + second_area - intersection

    if union > 0:
        return intersection / union
    return 0
|
|
|
def nms_per_class(labels, boxes, scores, iou_thresh=0.5):
    """Apply non-maximum suppression separately within each class.

    Within a class, boxes are visited from highest to lowest score; a box
    is dropped when its IoU with an already kept box of the same class is
    not below ``iou_thresh``.

    Args:
        labels: array of class labels.
        boxes: array of ``[x1, y1, x2, y2]`` boxes.
        scores: array of scores.
        iou_thresh: boxes with IoU >= this value against a kept box are
            suppressed.

    Returns:
        dict: ``{'labels', 'boxes', 'scores'}`` restricted to the kept
        detections, grouped by ascending class label. Empty input yields
        empty arrays.
    """
    labels = np.asarray(labels)
    boxes = np.asarray(boxes)
    scores = np.asarray(scores)

    kept_indices = []
    # "class_id" rather than "label": the module imports a function called
    # ``label`` and the loop variable should not shadow it.
    for class_id in np.unique(labels):
        class_indices = np.where(labels == class_id)[0]
        order = np.argsort(scores[class_indices])[::-1]
        class_indices = class_indices[order]
        class_boxes = boxes[class_indices]

        while len(class_indices) > 0:
            kept_indices.append(class_indices[0])
            if len(class_indices) == 1:
                break
            ious = np.array([calculate_iou(class_boxes[0], box) for box in class_boxes[1:]])
            survivors = ious < iou_thresh
            class_boxes = class_boxes[1:][survivors]
            class_indices = class_indices[1:][survivors]

    # An explicit integer dtype keeps the (possibly empty) array usable as
    # an index; a default empty array is float and cannot index.
    kept_indices = np.array(kept_indices, dtype=int)
    return {
        'labels': labels[kept_indices],
        'boxes': boxes[kept_indices],
        'scores': scores[kept_indices]
    }
|
|
|
|
|
|
|
|
|
| import numpy as np
|
def get_overlap_region(box1, box2):
    """Return the rectangle shared by two boxes.

    Args:
        box1, box2: boxes given as ``[x_min, y_min, x_max, y_max]``.

    Returns:
        tuple: ``(x_min, y_min, x_max, y_max)`` of the shared region, or
        None when the boxes share no area (touching edges do not count).
    """
    left = max(box1[0], box2[0])
    top = max(box1[1], box2[1])
    right = min(box1[2], box2[2])
    bottom = min(box1[3], box2[3])

    no_area = right <= left or bottom <= top
    return None if no_area else (left, top, right, bottom)
|
|
|
def are_bond_connected(box1, box2, bond_bboxes, bond_iou_threshold=0.1):
    """Decide whether a bond box links two overlapping atom boxes.

    The atom boxes must share some area. A bond box qualifies when its IoU
    with each atom box exceeds ``bond_iou_threshold`` and its centre lies
    inside (or on the border of) the area shared by the two atom boxes.

    Args:
        box1, box2: atom boxes to test.
        bond_bboxes: iterable of bond boxes.
        bond_iou_threshold: minimum IoU between the bond box and each atom box.

    Returns:
        bool: True if at least one bond box qualifies, else False.
    """
    shared = get_overlap_region(box1, box2)
    if shared is None:
        return False
    left, top, right, bottom = shared

    for bond in bond_bboxes:
        touches_both = (calculate_iou(box1, bond) > bond_iou_threshold
                        and calculate_iou(box2, bond) > bond_iou_threshold)
        if not touches_both:
            continue

        center_x = (bond[0] + bond[2]) / 2
        center_y = (bond[1] + bond[3]) / 2
        if left <= center_x <= right and top <= center_y <= bottom:
            return True
    return False
|
|
|
def calculate_iou(box1, box2):
    """Compute the IoU of two bounding boxes.

    Args:
        box1, box2: boxes given as ``[x_min, y_min, x_max, y_max]``.

    Returns:
        Intersection area divided by union area, or 0 when the union area
        is not positive.
    """
    def _area(box):
        return (box[2] - box[0]) * (box[3] - box[1])

    # Width and height of the intersection, clamped at zero for disjoint boxes.
    overlap_w = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0]))
    overlap_h = max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
    intersection = overlap_w * overlap_h

    union = _area(box1) + _area(box2) - intersection
    if union > 0:
        return intersection / union
    return 0
|
|
|
def nms(atom_bboxes, atom_scores, atom_classes, iou_threshold=0.5):
    """Apply class-agnostic non-maximum suppression (NMS).

    Boxes are visited from highest to lowest score. Each visited box is
    kept and every remaining box whose IoU with it exceeds
    ``iou_threshold`` is discarded.

    Args:
        atom_bboxes: sequence of boxes ``[x_min, y_min, x_max, y_max]``.
        atom_scores: sequence of confidence scores, one per box.
        atom_classes: sequence of class ids, one per box.
        iou_threshold: IoU above which a lower-scoring box is suppressed.

    Returns:
        tuple: ``(kept_bboxes, kept_classes, kept_scores)`` as NumPy arrays,
        ordered by descending score.
    """
    order = np.argsort(atom_scores)[::-1]

    keep_indices = []
    while len(order) > 0:
        best = order[0]
        keep_indices.append(best)

        rest = order[1:]
        ious = np.array([calculate_iou(atom_bboxes[best], atom_bboxes[idx]) for idx in rest])

        # Filter in place so the remaining candidates stay in score order.
        # Rebuilding the list with np.setdiff1d sorted it by index value,
        # so later iterations no longer picked the best-scoring box first.
        order = rest[~(ious > iou_threshold)]

    kept_bboxes = np.array([atom_bboxes[i] for i in keep_indices])
    kept_classes = np.array([atom_classes[i] for i in keep_indices])
    kept_scores = np.array([atom_scores[i] for i in keep_indices])

    return kept_bboxes, kept_classes, kept_scores
|
|
|
def count_bond_overlaps(box, bond_bboxes, bond_iou_threshold=0.1):
    """Count the bond boxes whose IoU with an atom box exceeds a threshold.

    NOTE(review): this function is defined a second time immediately below
    with a default threshold of 0.01; that later definition is the one in
    effect after import.

    Args:
        box: atom box ``[x_min, y_min, x_max, y_max]``.
        bond_bboxes: iterable of bond boxes.
        bond_iou_threshold: IoU a bond box must exceed to be counted.

    Returns:
        int: number of bond boxes counted.
    """
    hits = 0
    for bond_box in bond_bboxes:
        if calculate_iou(box, bond_box) > bond_iou_threshold:
            hits += 1
    return hits
|
|
|
|
|
def count_bond_overlaps(box, bond_bboxes, bond_iou_threshold=0.01):
    """Return how many bond boxes have an IoU with *box* above the threshold."""
    overlapping = [
        bond_box
        for bond_box in bond_bboxes
        if calculate_iou(box, bond_box) > bond_iou_threshold
    ]
    return len(overlapping)
|
|
|
def count_atom_overlaps(box, all_bboxes, exclude_idx, min_iou=0.01):
    """Return how many boxes in *all_bboxes* overlap *box*.

    The entry at position ``exclude_idx`` (normally *box* itself) is
    skipped; a box counts when its IoU with *box* exceeds ``min_iou``.
    """
    total = 0
    for position, candidate in enumerate(all_bboxes):
        if position == exclude_idx:
            continue
        if calculate_iou(box, candidate) > min_iou:
            total += 1
    return total
|
|
|
def merge_low_iou_boxes(kept_bboxes, kept_classes, kept_scores, bond_bboxes,
                        merge_threshold=0.5, score_threshold=0.7, bond_iou_threshold=0.01,
                        high_iou_threshold=0.8, large_score_threshold=0.5):
    """
    Merge or filter atom boxes in three passes.

    Pass 0 removes large, low-score boxes that overlap several atom and
    bond boxes. Pass 1 collapses groups of boxes whose IoU with a seed box
    exceeds ``high_iou_threshold`` to the best-scoring member. Pass 2 groups
    each remaining seed box with low-score neighbours whose IoU lies in
    ``[0.05, merge_threshold)`` and then keeps, selects from, or unions the
    group.

    Args:
        kept_bboxes: array, atom bounding boxes [x_min, y_min, x_max, y_max]
        kept_classes: array, class labels
        kept_scores: array, confidence scores
        bond_bboxes: array, bond bounding boxes
        merge_threshold: float, upper IoU bound (exclusive) for pass-2 grouping
        score_threshold: float, a pass-2 group whose best score reaches this
            value is kept or selected from instead of being unioned
        bond_iou_threshold: float, IoU threshold passed to the bond helpers
        high_iou_threshold: float, IoU above which boxes are merged in pass 1
        large_score_threshold: float, boxes scoring below this value are
            candidates for removal in pass 0 (default 0.5)

    Returns:
        tuple: (merged_bboxes, merged_classes, merged_scores). When at most
        one box is given the inputs are returned unchanged; when pass 0
        removes everything three empty arrays are returned.
    """
    # Nothing to merge with zero or one box.
    if len(kept_bboxes) <= 1:
        return kept_bboxes, kept_classes, kept_scores

    kept_bboxes = np.array(kept_bboxes)
    kept_classes = np.array(kept_classes)
    kept_scores = np.array(kept_scores)
    bond_bboxes = np.array(bond_bboxes)

    # ---- Pass 0: drop large low-score boxes ---------------------------------
    # "Large" means strictly larger than the median box area.
    areas = (kept_bboxes[:, 2] - kept_bboxes[:, 0]) * (kept_bboxes[:, 3] - kept_bboxes[:, 1])
    median_area = np.median(areas)
    keep_mask = np.ones(len(kept_bboxes), dtype=bool)

    for i in range(len(kept_bboxes)):
        if kept_scores[i] < large_score_threshold:
            atom_overlaps = count_atom_overlaps(kept_bboxes[i], kept_bboxes, i)
            bond_overlaps = count_bond_overlaps(kept_bboxes[i], bond_bboxes, bond_iou_threshold)
            is_large = areas[i] > median_area
            # Removed only when all three conditions hold together.
            if is_large and atom_overlaps >= 2 and bond_overlaps >= 3:
                keep_mask[i] = False
                print(f"Removed large low-score box idx {i}: score {kept_scores[i]}, "
                      f"area {areas[i]}, atom overlaps {atom_overlaps}, bond overlaps {bond_overlaps}")

    kept_bboxes = kept_bboxes[keep_mask]
    print(f"afterRemoved large low-score atom box::{len(kept_bboxes)} ")
    kept_classes = kept_classes[keep_mask]
    kept_scores = kept_scores[keep_mask]
    if len(kept_bboxes) == 0:
        return np.array([]), np.array([]), np.array([])

    merged_bboxes = []
    merged_classes = []
    merged_scores = []
    # Indices (into the filtered arrays) already emitted or absorbed.
    used_indices = set()

    # ---- Pass 1: collapse near-duplicate boxes ------------------------------
    # Each unused box seeds a group of all unused boxes whose IoU with the
    # seed exceeds high_iou_threshold; only the best-scoring member is kept.
    i = 0
    while i < len(kept_bboxes):
        if i in used_indices:
            i += 1
            continue

        high_iou_group = [i]
        for j in range(len(kept_bboxes)):
            if j in used_indices or j == i:
                continue
            iou = calculate_iou(kept_bboxes[i], kept_bboxes[j])
            if iou > high_iou_threshold:
                high_iou_group.append(j)

        if len(high_iou_group) > 1:
            group_scores = kept_scores[high_iou_group]
            max_score_idx = high_iou_group[np.argmax(group_scores)]
            merged_bboxes.append(kept_bboxes[max_score_idx])
            merged_classes.append(kept_classes[max_score_idx])
            merged_scores.append(kept_scores[max_score_idx])
            used_indices.update(high_iou_group)
            print(f"Merged high-IoU (> {high_iou_threshold}) boxes: {high_iou_group}, "
                  f"kept index: {max_score_idx}")
        i += 1

    # ---- Pass 2: handle partially overlapping boxes -------------------------
    i = 0
    while i < len(kept_bboxes):
        if i in used_indices:
            i += 1
            continue

        current_indices = [i]
        for j in range(len(kept_bboxes)):
            if j in used_indices or j == i:
                continue
            iou = calculate_iou(kept_bboxes[i], kept_bboxes[j])
            if 0.05 <= iou < merge_threshold:
                # NOTE(review): the literal 0.7 is independent of the
                # score_threshold parameter -- confirm this is intentional.
                if kept_scores[j]<0.7:
                    current_indices.append(j)

        group_indices = current_indices
        group_scores = kept_scores[group_indices]
        group_classes = kept_classes[group_indices]
        group_bboxes = kept_bboxes[group_indices]

        max_score = np.max(group_scores)
        # NOTE(review): max_score_idx is not read again in this pass.
        max_score_idx = group_indices[np.argmax(group_scores)]

        if max_score >= score_threshold:
            # Confident group: keep every member if consecutive members are
            # joined by a bond, otherwise keep a single representative.
            bond_connected = False
            if len(group_indices) > 1:
                for idx1, idx2 in zip(group_indices[:-1], group_indices[1:]):
                    if are_bond_connected(kept_bboxes[idx1], kept_bboxes[idx2],
                                          bond_bboxes, bond_iou_threshold):
                        bond_connected = True
                        break
            if bond_connected:
                for idx in group_indices:
                    merged_bboxes.append(kept_bboxes[idx])
                    merged_classes.append(kept_classes[idx])
                    merged_scores.append(kept_scores[idx])
                print(f"Kept all bond-connected boxes: {group_indices}")
            else:
                # Representative: most bond overlaps, ties broken by score.
                bond_overlap_counts = [count_bond_overlaps(kept_bboxes[idx], bond_bboxes,
                                                           bond_iou_threshold) for idx in group_indices]
                max_overlaps = max(bond_overlap_counts)
                candidates = [idx for idx, count in zip(group_indices, bond_overlap_counts)
                              if count == max_overlaps]
                best_idx = max(candidates, key=lambda idx: kept_scores[idx])
                merged_bboxes.append(kept_bboxes[best_idx])
                merged_classes.append(kept_classes[best_idx])
                merged_scores.append(kept_scores[best_idx])

        else:
            if len(group_indices) == 1:
                # Lone low-score box: passed through unchanged.
                merged_bboxes.append(kept_bboxes[i])
                merged_classes.append(kept_classes[i])
                merged_scores.append(kept_scores[i])
                print(f"Merged lower IOU @@ ONLY ONE box {i}")
            else:
                # Low-score group: replace it by the bounding union of its
                # members, labelled with the best member's class and score.
                new_bbox = [
                    np.min(group_bboxes[:, 0]),
                    np.min(group_bboxes[:, 1]),
                    np.max(group_bboxes[:, 2]),
                    np.max(group_bboxes[:, 3])
                ]
                merged_bboxes.append(new_bbox)
                merged_classes.append(group_classes[np.argmax(group_scores)])
                merged_scores.append(max_score)
                print(f"Merged low-score boxes: {group_indices}")
        used_indices.update(group_indices)
        i += 1

    print(f"after processs low IOU atom box::{len(merged_bboxes)} ")
    return (np.array(merged_bboxes), np.array(merged_classes), np.array(merged_scores))
|
|
|
|
|
def refine_boxes(atom_bboxes, atom_scores, atom_classes, bond_bboxes,
                 nms_iou_threshold=0.5, merge_threshold=0.5, score_threshold=0.5,
                 bond_iou_threshold=0.01, high_iou_threshold=0.8):
    """
    Alternate NMS and box merging until the box count stops shrinking.

    Each round runs ``nms`` followed by ``merge_low_iou_boxes``. At least
    one round is always executed; iteration stops after the first round
    that does not reduce the number of boxes.

    Args:
        atom_bboxes, atom_scores, atom_classes: initial atom box data
        bond_bboxes: bond box data
        nms_iou_threshold: IoU threshold handed to ``nms``
        merge_threshold, score_threshold, bond_iou_threshold,
        high_iou_threshold: forwarded to ``merge_low_iou_boxes``

    Returns:
        tuple: (final_bboxes, final_scores, final_classes) -- note that
        scores come before classes.
    """
    current_bboxes = np.array(atom_bboxes)
    current_classes = np.array(atom_classes)
    current_scores = np.array(atom_scores)

    iteration = 0
    while True:
        count_before = len(current_bboxes)
        print(f"\nIteration {iteration}: Starting with {len(current_bboxes)} boxes")

        # Step 1: suppress strongly overlapping boxes.
        kept_bboxes, kept_classes, kept_scores = nms(
            current_bboxes, current_scores, current_classes, iou_threshold=nms_iou_threshold
        )
        print(f"After NMS: {len(kept_bboxes)} boxes")

        # Step 2: merge or filter what is left.
        current_bboxes, current_classes, current_scores = merge_low_iou_boxes(
            kept_bboxes, kept_classes, kept_scores, bond_bboxes,
            merge_threshold=merge_threshold, score_threshold=score_threshold,
            bond_iou_threshold=bond_iou_threshold, high_iou_threshold=high_iou_threshold
        )
        print(f"After merge: {len(current_bboxes)} boxes")

        iteration += 1
        if len(current_bboxes) >= count_before:
            break

    print(f"Converged after {iteration} iterations with {len(current_bboxes)} boxes")
    return current_bboxes, current_scores, current_classes
|
|
|
def merge_low_iou_boxes_old(kept_bboxes, kept_classes, kept_scores, merge_threshold=0.3):
    """
    Group slightly overlapping boxes and emit one box per group.

    Each not-yet-used box seeds a group made of the later, not-yet-used
    boxes whose IoU with the seed is above 0.01 and below
    ``merge_threshold``. If the group's best score exceeds 0.5 the
    best-scoring member is emitted as is; otherwise the bounding union of
    the group is emitted with class 0 and the best score.

    With at most one box the inputs are returned unchanged.
    """
    if len(kept_bboxes) <= 1:
        return kept_bboxes, kept_classes, kept_scores

    out_bboxes, out_classes, out_scores = [], [], []
    consumed = set()
    total = len(kept_bboxes)

    for seed in range(total):
        if seed in consumed:
            continue

        # Seed first, then every later unused box in the IoU window.
        group = [seed] + [
            other
            for other in range(seed + 1, total)
            if other not in consumed
            and 0.01 < calculate_iou(kept_bboxes[seed], kept_bboxes[other]) < merge_threshold
        ]

        member_scores = kept_scores[group]
        # Evaluated as before although the result is unused; it still
        # requires kept_classes to be indexable by the group.
        kept_classes[group]
        member_bboxes = kept_bboxes[group]

        best_score = np.max(member_scores)
        best_member = group[np.argmax(member_scores)]

        if best_score > 0.5:
            out_bboxes.append(kept_bboxes[best_member])
            out_classes.append(kept_classes[best_member])
            out_scores.append(kept_scores[best_member])
        else:
            out_bboxes.append([
                np.min(member_bboxes[:, 0]),
                np.min(member_bboxes[:, 1]),
                np.max(member_bboxes[:, 2]),
                np.max(member_bboxes[:, 3])
            ])
            out_classes.append(0)
            out_scores.append(best_score)

        consumed.update(group)

    return np.array(out_bboxes), np.array(out_classes), np.array(out_scores)
|
|
|
|
|
|
|
| from SmilesPE.pretokenizer import atomwise_tokenizer
|
|
|
def canonicalize_smiles(smiles, ignore_chiral=False, ignore_cistrans=False, replace_rgroup=True):
    """Canonicalize a SMILES string, optionally discarding stereo information.

    Args:
        smiles: input SMILES. Anything that is not exactly a ``str``, or an
            empty string, yields ``('', False)``.
        ignore_chiral: passed to RDKit as ``useChiral=not ignore_chiral``.
        ignore_cistrans: strip the ``/`` and ``\\`` bond markers first.
        replace_rgroup: rewrite bracket tokens of the form ``[R<digits>]``
            to ``[<digits>*]`` and bracket tokens RDKit cannot parse as an
            atom to ``*`` before canonicalizing.

    Returns:
        tuple: ``(canon_smiles, success)``. When canonicalization fails the
        (possibly rewritten) input is returned with ``success`` False.
    """
    # Exact type check kept on purpose: str subclasses are rejected as before.
    if type(smiles) is not str or smiles == '':
        return '', False
    if ignore_cistrans:
        smiles = smiles.replace('/', '').replace('\\', '')
    if replace_rgroup:
        tokens = atomwise_tokenizer(smiles)
        for j, token in enumerate(tokens):
            if token[0] == '[' and token[-1] == ']':
                symbol = token[1:-1]
                if symbol[0] == 'R' and symbol[1:].isdigit():
                    tokens[j] = f'[{symbol[1:]}*]'
                elif Chem.AtomFromSmiles(token) is None:
                    tokens[j] = '*'
        smiles = ''.join(tokens)
    try:
        canon_smiles = Chem.CanonSmiles(smiles, useChiral=(not ignore_chiral))
        success = True
    except Exception:
        # Best effort: any RDKit failure falls back to the input string.
        # ``Exception`` (not a bare except) lets KeyboardInterrupt and
        # SystemExit propagate, which matters inside pool workers.
        canon_smiles = smiles
        success = False
    return canon_smiles, success
|
|
|
def convert_smiles_to_canonsmiles(
        smiles_list, ignore_chiral=False, ignore_cistrans=False, replace_rgroup=True, num_workers=16):
    """Canonicalize many SMILES strings in parallel.

    Args:
        smiles_list: iterable of SMILES strings.
        ignore_chiral, ignore_cistrans, replace_rgroup: forwarded to
            ``canonicalize_smiles`` for every entry.
        num_workers: size of the multiprocessing pool.

    Returns:
        tuple: ``(canon_smiles, success_rate)`` where ``canon_smiles`` is a
        list aligned with the input and ``success_rate`` is the mean of the
        per-entry success flags. For an empty input the result is
        ``([], nan)``; no worker pool is started in that case.
    """
    jobs = [(smiles, ignore_chiral, ignore_cistrans, replace_rgroup) for smiles in smiles_list]
    if not jobs:
        # zip(*[]) cannot be unpacked into two names, and the mean of
        # nothing is undefined, so answer directly.
        return [], float('nan')

    with multiprocessing.Pool(num_workers) as p:
        results = p.starmap(canonicalize_smiles, jobs, chunksize=128)
    canon_smiles, success = zip(*results)
    return list(canon_smiles), np.mean(success)
|
|
|
def tanimoto_similarity(smiles1, smiles2):
    """Return the fingerprint similarity of two SMILES strings.

    Both strings are parsed with RDKit, converted to RDKit fingerprints and
    compared with ``DataStructs.FingerprintSimilarity``.

    Args:
        smiles1, smiles2: SMILES strings to compare.

    Returns:
        The similarity value, or 0 when either string cannot be parsed or
        fingerprinting fails.
    """
    try:
        mol1 = Chem.MolFromSmiles(smiles1)
        mol2 = Chem.MolFromSmiles(smiles2)
        if mol1 is None or mol2 is None:
            # Unparsable SMILES: report "no similarity" explicitly instead
            # of relying on the fingerprint call to raise.
            return 0
        fp1 = Chem.RDKFingerprint(mol1)
        fp2 = Chem.RDKFingerprint(mol2)
        return DataStructs.FingerprintSimilarity(fp1, fp2)
    except Exception:
        # Best effort, but do not swallow KeyboardInterrupt/SystemExit the
        # way a bare ``except:`` would.
        return 0
|
|
|
|
|
def compute_tanimoto_similarities(gold_smiles, pred_smiles, num_workers=16):
    """Compute pairwise similarities of gold/predicted SMILES in parallel.

    Entries are paired positionally; surplus entries of the longer input
    are ignored. Returns a list with one similarity per pair.
    """
    with multiprocessing.Pool(num_workers) as pool:
        pairs = list(zip(gold_smiles, pred_smiles))
        scores = pool.starmap(tanimoto_similarity, pairs)
    return scores
|
|
|
class SmilesEvaluator(object):
    """Score predicted SMILES against a fixed list of gold SMILES.

    The gold list is canonicalized once at construction time, in two
    variants: ignoring cis/trans markers, and ignoring both cis/trans and
    chirality.
    """

    def __init__(self, gold_smiles, num_workers=16, tanimoto=False):
        """
        Args:
            gold_smiles: list of reference SMILES strings.
            num_workers: pool size used for every parallel step.
            tanimoto: when True, ``evaluate`` also reports the mean
                fingerprint similarity.
        """
        self.gold_smiles = gold_smiles
        self.num_workers = num_workers
        self.tanimoto = tanimoto
        self.gold_smiles_cistrans, _ = convert_smiles_to_canonsmiles(gold_smiles,
                                                                     ignore_cistrans=True,
                                                                     num_workers=num_workers)
        self.gold_smiles_chiral, _ = convert_smiles_to_canonsmiles(gold_smiles,
                                                                   ignore_chiral=True, ignore_cistrans=True,
                                                                   num_workers=num_workers)
        self.gold_smiles_cistrans = self._replace_empty(self.gold_smiles_cistrans)
        self.gold_smiles_chiral = self._replace_empty(self.gold_smiles_chiral)

    def _replace_empty(self, smiles_list):
        """Replace empty SMILES in the gold, otherwise it will be considered correct if both pred and gold is empty."""
        return [smiles if smiles is not None and type(smiles) is str and smiles != "" else "<empty>"
                for smiles in smiles_list]

    def evaluate(self, pred_smiles, include_details=False):
        """Compare predictions with the gold SMILES.

        Args:
            pred_smiles: list of predicted SMILES, aligned with the gold list.
            include_details: also return the per-item exact-match array
                under ``'canon_smiles_details'``.

        Returns:
            dict with the keys
              * ``'tanimoto'`` (only if enabled): mean similarity;
              * ``'canon_smiles'``: exact-match rate ignoring cis/trans;
              * ``'graph'``: exact-match rate ignoring cis/trans and chirality;
              * ``'chiral'``: exact-match rate restricted to gold entries
                containing ``'@'``, or -1 when there are none.
        """
        results = {}
        if self.tanimoto:
            # Honour the configured pool size here too; this call previously
            # fell back to the helper's default regardless of num_workers.
            results['tanimoto'] = np.mean(compute_tanimoto_similarities(
                self.gold_smiles, pred_smiles, num_workers=self.num_workers))

        # Exact match ignoring double-bond cis/trans markers.
        pred_smiles_cistrans, _ = convert_smiles_to_canonsmiles(pred_smiles,
                                                                ignore_cistrans=True,
                                                                num_workers=self.num_workers)
        results['canon_smiles'] = np.mean(np.array(self.gold_smiles_cistrans) == np.array(pred_smiles_cistrans))
        if include_details:
            results['canon_smiles_details'] = (np.array(self.gold_smiles_cistrans) == np.array(pred_smiles_cistrans))

        # Exact match ignoring chirality as well.
        pred_smiles_chiral, _ = convert_smiles_to_canonsmiles(pred_smiles,
                                                              ignore_chiral=True, ignore_cistrans=True,
                                                              num_workers=self.num_workers)
        results['graph'] = np.mean(np.array(self.gold_smiles_chiral) == np.array(pred_smiles_chiral))

        # Exact match restricted to gold molecules that carry an '@' marker.
        chiral = np.array([[g, p] for g, p in zip(self.gold_smiles_cistrans, pred_smiles_cistrans) if '@' in g])
        results['chiral'] = np.mean(chiral[:, 0] == chiral[:, 1]) if len(chiral) > 0 else -1
        return results
|
|
|
|
|
|
|
|
|
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, max_norm: float = 0, **kwargs):
    """Run one training epoch and return the averaged logged metrics.

    Args:
        model: network, called as ``model(samples, targets)``.
        criterion: loss module returning a dict of loss tensors.
        data_loader: yields ``(samples, targets)`` where ``targets`` is a
            list of dicts of tensors.
        optimizer: optimizer stepped once per batch.
        device: device the batch is moved to.
        epoch: epoch number, used only in the log header.
        max_norm: gradient-norm clipping bound; clipping is skipped when
            it is not positive.
        **kwargs: optional ``print_freq`` (default 10), ``ema`` (object with
            ``update(model)``) and ``scaler`` (gradient scaler; its presence
            enables the autocast path).

    Returns:
        dict: meter name -> global average over the epoch.

    Exits the process with status 1 if the reduced loss is not finite.
    """
    model.train()
    criterion.train()
    metric_logger = MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))

    header = 'Epoch: [{}]'.format(epoch)
    print_freq = kwargs.get('print_freq', 10)

    ema = kwargs.get('ema', None)
    scaler = kwargs.get('scaler', None)

    # autocast expects the bare device type ("cuda", "cpu"); str(device)
    # would be e.g. "cuda:0" for an indexed device, which it rejects.
    device_type = torch.device(device).type

    for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
        samples = samples.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        if scaler is not None:
            # Forward under autocast; the loss is computed with autocast off.
            with torch.autocast(device_type=device_type, cache_enabled=True):
                outputs = model(samples, targets)

            with torch.autocast(device_type=device_type, enabled=False):
                loss_dict = criterion(outputs, targets)

            loss = sum(loss_dict.values())
            scaler.scale(loss).backward()

            if max_norm > 0:
                # Gradients must be unscaled before their norm is clipped.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        else:
            outputs = model(samples, targets)
            loss_dict = criterion(outputs, targets)

            loss = sum(loss_dict.values())
            optimizer.zero_grad()
            loss.backward()

            if max_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

            optimizer.step()

        if ema is not None:
            ema.update(model)

        loss_dict_reduced = reduce_dict(loss_dict)
        loss_value = sum(loss_dict_reduced.values())

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        metric_logger.update(loss=loss_value, **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def remove_bond_directions_if_no_chiral(mol):
    """Clear the direction of single bonds when ``mol`` has no chiral centers.

    ``None`` is passed through. The molecule is modified in place and returned.
    Chiral centers are looked up with ``Chem.FindMolChiralCenters`` including
    unassigned ones; if any is found the molecule is returned untouched.
    """
    if mol is None:
        return None

    if Chem.FindMolChiralCenters(mol, includeUnassigned=True):
        return mol

    single_bonds = (b for b in mol.GetBonds()
                    if b.GetBondType() == Chem.BondType.SINGLE)
    for single_bond in single_bonds:
        single_bond.SetBondDir(Chem.BondDir.NONE)
    return mol
|
|
|
def molExpanding(mol_rebuit, placeholder_atoms, wdbs, bond_dirs, alignmol=False):
    """Expand abbreviations in a rebuilt molecule and return (smiles, mol).

    A deep copy of ``mol_rebuit`` is expanded with ``expandABB`` and given
    fresh 2D coordinates. With ``alignmol`` set, the expanded copy is aligned
    onto the unexpanded copy through their maximum common substructure, and
    the bond directions listed in ``wdbs`` are copied onto the matching bonds
    of the expanded molecule before the SMILES is written. Without it, the
    SMILES returned by ``expandABB`` is used as is.

    Args:
        mol_rebuit: Molecule to expand (left unmodified; a copy is used).
        placeholder_atoms: Passed through to ``expandABB``.
        wdbs: Iterable of bond records; items 0 and 1 are atom indices in the
            unexpanded molecule and item 3 is a key into ``bond_dirs``.
        bond_dirs: Mapping from that key to an RDKit bond direction.
        alignmol: Whether to align and transfer bond directions.

    Returns:
        Tuple ``(expanded_smiles, mol)`` where ``mol`` has had stereochemistry
        re-detected from bond directions.
    """
    reference_mol = copy.deepcopy(mol_rebuit)
    expand_mol, expand_smiles = expandABB(reference_mol, ABBREVIATIONS, placeholder_atoms)
    target_mol = copy.deepcopy(expand_mol)
    AllChem.Compute2DCoords(target_mol)

    if not alignmol:
        expandStero_smi = expand_smiles
    else:
        mcs = rdFMCS.FindMCS(
            [target_mol, reference_mol],
            atomCompare=rdFMCS.AtomCompare.CompareAny,
            ringCompare=rdFMCS.RingCompare.IgnoreRingFusion,
            matchChiralTag=False,
        )
        atomMap = g_atompair_matches([target_mol, reference_mol], mcs)[0]
        try:
            rdkit.Chem.rdMolAlign.AlignMol(prbMol=target_mol, refMol=reference_mol,
                                           atomMap=atomMap, maxIters=2000000)
        except Exception as e:
            # Best effort: a failed alignment is reported and processing continues.
            print(atomMap, "@@@@")
            print(e)

        # Map atom indices of the unexpanded molecule to the expanded one.
        ref_to_target = {pre: cur for cur, pre in atomMap}
        for bond_record in wdbs:
            begin_ref, end_ref = int(bond_record[0]), int(bond_record[1])
            if begin_ref in ref_to_target and end_ref in ref_to_target:
                matched_bond = target_mol.GetBondBetweenAtoms(
                    ref_to_target[begin_ref], ref_to_target[end_ref])
                if matched_bond:
                    matched_bond.SetBondDir(bond_dirs[bond_record[3]])
        expandStero_smi = Chem.MolToSmiles(target_mol)

    m = target_mol.GetMol()
    Chem.DetectBondStereochemistry(m)
    Chem.AssignChiralTypesFromBondDirs(m)
    Chem.AssignStereochemistry(m)

    return expandStero_smi, m
|
|
|
|
|
def remove_backslash_and_slash(input_string):
    """Return ``input_string`` with every backslash and forward slash removed."""
    without_backslash = input_string.replace("\\", "")
    return without_backslash.replace("/", "")
|
|
|
|
|
def remove_number_before_star(input_string):
    """Remove the run of digits immediately preceding each ``*``.

    A ``*`` that is the last character is left alone. The digits are also
    kept when the character just before the last digit position is a letter
    (for example ``"C1*C"`` stays unchanged).

    The previous implementation used ``continue`` without advancing the
    index, which looped forever on such input, and indexed ``result[j - 1]``
    with ``j == 0``, which wrapped around to the end of the string.

    Args:
        input_string: Text to process.

    Returns:
        The processed string.
    """
    chars = list(input_string)
    last_index = len(chars) - 1

    for i, ch in enumerate(chars):
        if ch != '*' or i == last_index:
            continue

        j = i - 1
        # Only look at the preceding character when it actually exists.
        if j >= 1 and chars[j - 1].isalpha():
            continue

        while j >= 0 and chars[j].isdigit():
            chars[j] = ''
            j -= 1

    return ''.join(chars)
|
|
|
def remove_SP(input_string):
    """Strip ``@SP1``..``@SP3`` tags and, when ``@TB`` is present, bracket tags.

    ``@SP1``/``@SP2``/``@SP3`` are always deleted. If the result still contains
    ``@TB``, every bracket expression matching the pattern below is rewritten
    to keep only the text before the first ``@``.
    """
    bracket_tag = r'\[([^@]*)@?[A-Z0-9]*\]'

    cleaned = re.sub(r'@SP[1-3]', '', input_string)
    if '@TB' not in cleaned:
        return cleaned
    return re.sub(bracket_tag, r'[\1]', cleaned)
|
|
|
def rdkit_canonicalize_smiles(smiles):
    """Canonicalize ``smiles`` with RDKit after replacing unknown atoms by ``*``.

    The string is tokenized with ``atomwise_tokenizer``. A bracket token is
    replaced by ``*`` when its content starts with a capital letter, optional
    lowercase letters and then digits (the letters and digits are printed),
    when it is listed in ``RGROUP_SYMBOLS``, or when ``Chem.AtomFromSmiles``
    cannot parse it. The rebuilt string is canonicalized with
    ``Chem.CanonSmiles(..., useChiral=False)``.

    Args:
        smiles: SMILES text, possibly containing abbreviation tokens.

    Returns:
        Tuple ``(canonical_smiles, success)``. On failure the rebuilt
        (non-canonical) string is returned with ``success`` set to False.
    """
    Aad_string = r'([A-Z][a-z]*)([0-9]+)'
    tokens = atomwise_tokenizer(smiles)
    for j, token in enumerate(tokens):
        if token[0] == '[' and token[-1] == ']':
            symbol = token[1:-1]

            matches = re.match(Aad_string, symbol)
            if matches:
                letters, numbers = matches.groups()
                print(f"{letters} {numbers}")
                tokens[j] = '*'
            elif symbol in RGROUP_SYMBOLS:
                tokens[j] = '*'
            elif Chem.AtomFromSmiles(token) is None:
                tokens[j] = '*'

    smiles = ''.join(tokens)
    try:
        canon_smiles = Chem.CanonSmiles(smiles, useChiral=False)
        success = True
    except Exception:
        # Narrowed from a bare ``except`` so that KeyboardInterrupt and
        # SystemExit are no longer swallowed.
        canon_smiles = smiles
        success = False
    return canon_smiles, success
|
|
|
def NoRadical_Smi(smi):
    """Return the SMILES of ``smi`` with radical electrons removed.

    For every atom carrying radical electrons, the radical count is set to
    zero and its explicit hydrogen count is set to the difference between its
    total and explicit valence.

    Args:
        smi: SMILES string.

    Returns:
        The rewritten SMILES, or ``smi`` unchanged when RDKit cannot parse it
        (previously this case failed with an AttributeError on ``None``).
    """
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        return smi

    for atom in mol.GetAtoms():
        if atom.GetNumRadicalElectrons() > 0:
            atom.SetNumRadicalElectrons(0)
            atom.SetNumExplicitHs(atom.GetTotalValence() - atom.GetExplicitValence())

    return Chem.MolToSmiles(mol)
|
|
|
| import logging
|
|
|
def check_and_fix_valence(smiles_or_list):
    """
    Check atom valences in a SMILES string or a list [smiles, suffix/prefix].

    Atoms of C, N, O, H and F whose total valence differs from the table below
    produce a warning. Only nitrogen is corrected: explicit hydrogens are
    added or removed so that its total valence becomes 3.

    Args:
        smiles_or_list: A SMILES string, or a two-item list
            ``[smiles, other_part]``.

    Returns:
        Tuple ``(result, warnings)``. ``result`` is the input itself when the
        SMILES is empty or cannot be parsed. Otherwise it is the corrected
        SMILES, or ``[corrected_smiles, corrected_other_part]`` when a truthy
        ``other_part`` was supplied. A part is rewritten only if a correction
        was applied and the corrected molecule sanitizes cleanly.

    NOTE(review): formal charge is never consulted, so a charged nitrogen with
    four bonds is treated as an unusual valence - confirm this is intended.
    """
    # Kept from the original: configures the root logger on every call.
    logging.basicConfig(level=logging.WARNING)
    warnings = []

    standard_valences = {
        'C': [4],
        'N': [3],
        'O': [2],
        'H': [1],
        'F': [1]
    }

    if isinstance(smiles_or_list, list):
        smiles, other_part = smiles_or_list
    else:
        smiles, other_part = smiles_or_list, None

    mol = Chem.MolFromSmiles(smiles, sanitize=False) if smiles else None
    if mol is None:
        warnings.append(f"Invalid SMILES: {smiles}")
        return smiles_or_list, warnings

    other_part_mol = None
    if other_part:
        try:
            other_part_mol = Chem.MolFromSmiles(other_part, sanitize=False)
        except Exception:
            # Best effort: an unparsable other_part is returned unchanged.
            other_part_mol = None

    def process_molecule(mol, is_other_part=False):
        """Record valence warnings for ``mol`` and fix nitrogen; return (mol, corrected)."""
        prefix = "other_part" if is_other_part else "SMILES"
        corrected = False
        nitrogen_flagged = False

        mol.UpdatePropertyCache(strict=False)

        for atom in mol.GetAtoms():
            symbol = atom.GetSymbol()
            valence = atom.GetTotalValence()
            # Elements missing from the table are never flagged.
            expected_valences = standard_valences.get(symbol, [valence])
            if valence not in expected_valences:
                warnings.append(f"Unusual valence in {prefix} for {symbol}: {valence} (expected {expected_valences})")
                if symbol == 'N':
                    nitrogen_flagged = True

        if nitrogen_flagged:
            rw_mol = Chem.RWMol(mol)
            for atom in rw_mol.GetAtoms():
                if atom.GetSymbol() != 'N':
                    continue
                valence = atom.GetTotalValence()
                if valence < 3:
                    hydrogens_needed = 3 - valence
                    atom.SetNumExplicitHs(atom.GetNumExplicitHs() + hydrogens_needed)
                    corrected = True
                elif valence > 3:
                    hydrogens_to_remove = valence - 3
                    current_hydrogens = atom.GetNumExplicitHs()
                    if current_hydrogens >= hydrogens_to_remove:
                        atom.SetNumExplicitHs(current_hydrogens - hydrogens_to_remove)
                        corrected = True
                    else:
                        warnings.append(f"Cannot reduce N valence in {prefix} to 3 without removing non-H bonds")
            if corrected:
                mol = rw_mol.GetMol()

        if not corrected:
            return mol, False

        try:
            # With catchErrors=True RDKit returns the failed operation
            # instead of raising, so the return value has to be inspected.
            failed_op = Chem.SanitizeMol(mol, catchErrors=True)
        except Exception as e:
            warnings.append(f"Failed to sanitize {prefix} after correction: {str(e)}")
            return mol, False
        if failed_op != Chem.SanitizeFlags.SANITIZE_NONE:
            warnings.append(f"Failed to sanitize {prefix} after correction: {failed_op}")
            return mol, False
        return mol, True

    mol, mol_corrected = process_molecule(mol)
    corrected_smiles = Chem.MolToSmiles(mol) if mol_corrected else smiles

    corrected_other_part = other_part
    if other_part_mol:
        other_part_mol, other_corrected = process_molecule(other_part_mol, is_other_part=True)
        corrected_other_part = Chem.MolToSmiles(other_part_mol) if other_corrected else other_part

    if other_part:
        return [corrected_smiles, corrected_other_part], warnings
    return corrected_smiles, warnings
|
|
|
def molfpsim(original_smiles, test_smiles):
    """Return two fingerprint similarities between a reference and a test SMILES.

    Both inputs are reduced to their longest '.'-separated component; the test
    SMILES additionally goes through ``check_and_fix_valence``. Slashes,
    numbered ``[n*`` labels and @SP/@TB tags are stripped, and both strings are
    canonicalized with ``rdkit_canonicalize_smiles``. The reference is passed
    through ``NoRadical_Smi`` when it canonicalized and contains ``"S"``.

    Returns:
        Tuple ``(morgan_dice, rdk_similarity)``: the Dice similarity of the
        Morgan fingerprints and ``DataStructs.FingerprintSimilarity`` of the
        RDKit fingerprints.

    NOTE(review): RDKit documents ``radius`` as a required argument of
    ``GetMorganFingerprint``; the calls below omit it - confirm they work with
    the installed RDKit version.
    """
    test_smiles = select_longest_smiles(test_smiles)
    original_smiles = select_longest_smiles(original_smiles)
    test_smiles, _valence_warnings = check_and_fix_valence(test_smiles)

    star_label = r'\[(\d+)\*'
    original_smiles = remove_SP(
        re.sub(star_label, '[*', remove_backslash_and_slash(original_smiles)))
    test_smiles = remove_SP(
        re.sub(star_label, '[*', remove_backslash_and_slash(test_smiles)))

    reference_canonical, reference_ok = rdkit_canonicalize_smiles(original_smiles)
    if "S" in reference_canonical and reference_ok:
        reference_canonical = NoRadical_Smi(reference_canonical)
    test_canonical, _test_ok = rdkit_canonicalize_smiles(test_smiles)

    mol1 = Chem.MolFromSmiles(reference_canonical)
    mol2 = Chem.MolFromSmiles(test_canonical)

    morganfps1 = AllChem.GetMorganFingerprint(mol1, useChirality=False)
    morganfps2 = AllChem.GetMorganFingerprint(mol2, useChirality=False)
    morgan_dice = DataStructs.DiceSimilarity(morganfps1, morganfps2)

    rdk_similarity = DataStructs.FingerprintSimilarity(
        Chem.RDKFingerprint(mol1), Chem.RDKFingerprint(mol2))
    return morgan_dice, rdk_similarity
|
|
|
|
|
|
|
|
|
def _normalize_smiles_for_comparison(smiles):
    """Strip slash bond marks, numbered ``[n*`` labels and @SP/@TB tags."""
    smiles = remove_backslash_and_slash(smiles)
    smiles = re.sub(r'\[(\d+)\*', '[*', smiles)
    return remove_SP(smiles)


def _smiles_pair_matches(original_smiles, test_smiles, fix_radicals=False,
                         aromatic_shortcut=False, failure_context=None):
    """Shared implementation of the three public SMILES comparison functions.

    Args:
        original_smiles: Reference SMILES.
        test_smiles: Predicted SMILES.
        fix_radicals: Pass the canonical reference through ``NoRadical_Smi``
            when it canonicalized and contains ``"S"``.
        aromatic_shortcut: When both canonicalizations succeeded, return True
            at once if the RDKit SMILES of both molecules are equal.
        failure_context: Extra object printed when an exception is caught.

    Returns:
        True when the canonical SMILES, the kekulized SMILES, or (if the test
        SMILES contains no ``*``) the InChI strings are equal; else False.
    """
    original_smiles = _normalize_smiles_for_comparison(original_smiles)
    test_smiles = _normalize_smiles_for_comparison(test_smiles)

    rd_smi_ori, success1 = rdkit_canonicalize_smiles(original_smiles)
    if fix_radicals and "S" in rd_smi_ori and success1:
        rd_smi_ori = NoRadical_Smi(rd_smi_ori)
    rd_smi, success2 = rdkit_canonicalize_smiles(test_smiles)
    original_smiles, test_smiles = rd_smi_ori, rd_smi

    # Booleans replace the former integer sentinels (1 vs 2) that were used
    # to force "not equal" after a failure.
    kekule_match = False
    inchi_match = False
    try:
        original_mol = Chem.MolFromSmiles(original_smiles)
        test_mol = Chem.MolFromSmiles(test_smiles, sanitize=False)

        if aromatic_shortcut and success2 and success1:
            RDarom_smi = Chem.MolToSmiles(original_mol)
            RDarom_smi_test = Chem.MolToSmiles(test_mol)
            if RDarom_smi == RDarom_smi_test:
                return True
            print(f'smiles ori,pred after Chem.CanonSmiles:\n{RDarom_smi}\n{RDarom_smi_test}\n')

        if original_mol:
            Chem.SanitizeMol(original_mol)
            keku_smi_ori = Chem.MolToSmiles(original_mol, kekuleSmiles=True)
        else:
            keku_smi_ori = original_smiles

        if test_mol:
            Chem.SanitizeMol(test_mol)
            keku_smi = Chem.MolToSmiles(test_mol, kekuleSmiles=True)
        else:
            keku_smi = test_smiles

        inchi_equal = False
        if '*' not in keku_smi:
            keku_inch_ori = Chem.MolToInchi(Chem.MolFromSmiles(keku_smi_ori))
            keku_inch_test = Chem.MolToInchi(Chem.MolFromSmiles(keku_smi))
            inchi_equal = keku_inch_ori == keku_inch_test

        rd_smi = Chem.MolToSmiles(test_mol)
        rd_smi_ori = Chem.MolToSmiles(original_mol)

        # Set last so that a failure anywhere above leaves both at False.
        kekule_match = keku_smi_ori == keku_smi
        inchi_match = inchi_equal
    except Exception as e:
        print(f"comparing_smiles@@@ kekulize or SanitizeMol problems")
        if failure_context is not None:
            print(failure_context)
        print(e, "!!!!!!!\n")
        kekule_match = False
        inchi_match = False

    if not success1:
        # NOTE(review): an un-canonicalizable reference makes the comparison
        # succeed unconditionally - behaviour kept from the original.
        rd_smi_ori = rd_smi

    if rd_smi_ori == rd_smi or kekule_match or inchi_match:
        return True
    return False


def comparing_smiles2(original_smiles, test_smiles):
    """Compare two SMILES, with radical clean-up and an early equality check.

    The reference is passed through ``NoRadical_Smi`` when it canonicalizes
    and contains ``"S"``. Returns True when the two structures are judged
    equal by canonical SMILES, kekulized SMILES or InChI.
    """
    return _smiles_pair_matches(original_smiles, test_smiles,
                                fix_radicals=True, aromatic_shortcut=True)


def smiles12_comparing(original_smiles, test_smiles):
    """Compare two SMILES by canonical SMILES, kekulized SMILES or InChI."""
    return _smiles_pair_matches(original_smiles, test_smiles)


def comparing_smiles(new_row, test_smiles):
    """Compare ``new_row['SMILESori']`` with ``test_smiles``.

    Same comparison as ``smiles12_comparing``; ``new_row`` is additionally
    printed when an exception is caught during the comparison.
    """
    original_smiles = new_row['SMILESori']
    return _smiles_pair_matches(original_smiles, test_smiles,
                                failure_context=new_row)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def bbox2center(bbox):
    """Return the (x, y) centers of boxes given as rows ``[x1, y1, x2, y2]``.

    ``bbox`` is indexed as a 2-D array; the result has one ``(x, y)`` row per
    input row.
    """
    # Column ``lo`` is the min coordinate, column ``lo + 2`` the matching max.
    midpoints = [(bbox[:, lo] + bbox[:, lo + 2]) / 2 for lo in (0, 1)]
    return np.stack(midpoints, axis=1)
|
|
|
| import cv2
|
# Bond-direction label names.
# NOTE(review): these look like member names of RDKit's Chem.BondDir enum, and
# the list order is presumably relied on by callers outside this chunk - confirm.
BONDDIRECT=['ENDUPRIGHT', 'BEGINWEDGE', 'BEGINDASH', 'ENDDOWNRIGHT']
|
|
|
|
|
def reorder_bond_bbox(bond_bbox, single_atom_bond):
    """Move the boxes whose index is a key of ``single_atom_bond`` to the end.

    Relative order is preserved inside both groups. Returns a new list.
    """
    deferred = set(single_atom_bond.keys())
    # The sort is stable and False sorts before True, so regular boxes come
    # first and both groups keep their original order.
    order = sorted(range(len(bond_bbox)), key=lambda idx: idx in deferred)
    return [bond_bbox[idx] for idx in order]
|
|
|
def boxes_overlap(box1, box2):
    """
    Check whether two bounding boxes overlap.

    box1, box2: [x1, y1, x2, y2]. Boxes that merely touch count as overlapping.
    """
    # Separated along x.
    if box1[2] < box2[0] or box1[0] > box2[2]:
        return False
    # Separated along y.
    if box1[3] < box2[1] or box1[1] > box2[3]:
        return False
    return True
|
def calculate_center(box):
    """
    Compute the center point of a bounding box [x1, y1, x2, y2].

    Returns a NumPy array ``[x_center, y_center]``.
    """
    center_x = (box[0] + box[2]) / 2
    center_y = (box[1] + box[3]) / 2
    return np.array([center_x, center_y])
|
def merge_boxes(box1, box2):
    """
    Merge two bounding boxes and return the enclosing box [x1, y1, x2, y2].
    """
    left = min(box1[0], box2[0])
    top = min(box1[1], box2[1])
    right = max(box1[2], box2[2])
    bottom = max(box1[3], box2[3])
    return [left, top, right, bottom]
|
|
|
|
|
def get_merged_box(boxes):
    """Return the smallest [x1, y1, x2, y2] box that contains every box in ``boxes``."""
    left = min(b[0] for b in boxes)
    top = min(b[1] for b in boxes)
    right = max(b[2] for b in boxes)
    bottom = max(b[3] for b in boxes)
    return [left, top, right, bottom]
|
|
|
def box_area(box):
    """Return width times height of a box given as [x1, y1, x2, y2]."""
    width = box[2] - box[0]
    height = box[3] - box[1]
    return width * height
|
|
|
def Newbox_(atom_bbox, bond_bbox, lab2idx):
    """Create a placeholder 'H' atom box for each bond box touching exactly one atom.

    For every bond box, the atom boxes overlapping it (per ``boxes_overlap``)
    are collected. When there is exactly one, ``boxes_overlap2`` supplies an
    (x, y) point and a square box of half-width 3 is built around it, with
    score 1.0 and class ``lab2idx['H']``.

    Returns:
        Tuple ``(new_atoms, single_odd_b2a)``: the list of new atom dicts
        (keys ``bbox``, ``bbox_centers``, ``scores``, ``pred_classes``) and a
        dict mapping each such bond index to the one-element list holding the
        overlapping atom index.
    """
    half_side = 3
    new_atoms = []
    single_odd_b2a = {}

    for bond_idx, bond_box in enumerate(bond_bbox):
        hits = [(atom_idx, atom_box)
                for atom_idx, atom_box in enumerate(atom_bbox)
                if boxes_overlap(bond_box, atom_box)]
        if len(hits) != 1:
            continue

        atom_idx, atom_box = hits[0]
        single_odd_b2a[bond_idx] = [atom_idx]

        free_x, free_y = boxes_overlap2(atom_box, bond_box)
        new_atoms.append({
            'bbox': np.array([free_x - half_side,
                              free_y - half_side,
                              free_x + half_side,
                              free_y + half_side]).reshape(-1, 4),
            'bbox_centers': np.array([free_x, free_y]).reshape(-1, 2),
            'scores': np.array([1.0]),
            'pred_classes': np.array([lab2idx['H']]),
        })

    return new_atoms, single_odd_b2a
|
|
|
|
|
def has_boxes(data):
    """Return True when ``data`` is a non-empty list of ``[box, x]`` pairs.

    Every item must be a list of length 2 whose first element is a list of
    length 4.
    """
    if not isinstance(data, list) or len(data) == 0:
        return False

    for item in data:
        if not (isinstance(item, list) and len(item) == 2):
            return False
        first = item[0]
        if not (isinstance(first, list) and len(first) == 4):
            return False
    return True
|
|
|
def AtomBox2bondBox(atom_box, bond_bbox):
    """Find the bond boxes overlapping ``atom_box``.

    Returns:
        Tuple ``(overlap, b_nei)``: ``b_nei`` lists the indices of the bond
        boxes that overlap ``atom_box`` (per ``boxes_overlap``) and
        ``overlap`` is False when that list is empty.
    """
    b_nei = [bond_idx for bond_idx, bond_box in enumerate(bond_bbox)
             if boxes_overlap(bond_box, atom_box)]
    overlap = len(b_nei) != 0
    return overlap, b_nei
|
|
|
|
|
| import torchvision.transforms.v2 as T
|
|
|
def image_to_tensor(image_path, debug=True):
    """Load an image as a float32 tensor resized to 640x640.

    Images whose mode is not RGB are converted to RGB first; with ``debug``
    set, a message naming the detected mode is printed.

    Returns:
        Tuple ``(tensor, w, h)`` where ``w`` and ``h`` are the width and
        height of the image as opened, before resizing.
    """
    image = Image.open(image_path)
    w, h = image.size

    source_mode = image.mode
    if source_mode != "RGB":
        if debug:
            if source_mode == "L":
                print("检测到灰度图像 (1 通道),转换为 RGB...")
            else:
                print(f"检测到 {source_mode} 模式,转换为 RGB...")
        image = image.convert("RGB")

    pipeline = T.Compose([
        T.Resize((640, 640)),
        T.ToTensor(),
        lambda x: x.to(torch.float32),
    ])

    return pipeline(image), w, h
|
|
|
|
|
|
|
|
|
|
|
| @torch.no_grad()
|
| def evaluate_x(model: torch.nn.Module, criterion: torch.nn.Module, postprocessors,
|
| data_loader, device,
|
| outcsv_filename=f'/home/jovyan/rt-detr/rt-detr/output/output_charge_CLEF.csv',
|
| visual_check=False,
|
| other2ppsocr=True,
|
| getacc=False,
|
| ):
|
|
|
| postprocessor2=RTDETRPostProcessor(num_classes=30, use_focal_loss=True, num_top_queries=300, remap_mscoco_category=False)
|
| output_directory = os.path.dirname(outcsv_filename)
|
| prefix_f = os.path.basename(outcsv_filename).split('.')[0]
|
| if other2ppsocr:
|
| ocr = PaddleOCR(
|
| use_angle_cls=True,
|
| lang='latin',use_space_char=True,use_debug=False,
|
| use_gpu=True if cv2.cuda.getCudaEnabledDeviceCount() > 0 else False)
|
|
|
| ocr2 = ocr2 = PaddleOCR(use_angle_cls=True,use_gpu =False,use_debug=False,
|
| rec_algorithm='SVTR_LCNet', rec_model_dir='/nfs_home/bowen/.paddleocr/whl/rec/en/en_PP-OCRv4_rec_infer',
|
| lang="en")
|
| outcsv_filename=f"{output_directory}/{prefix_f}_withOCR.csv"
|
|
|
|
|
| if visual_check:
|
| output_directory = os.path.dirname(outcsv_filename)
|
| prefix_f = os.path.basename(outcsv_filename).split('.')[0]
|
| ima_checkdir=f"{output_directory}/{prefix_f}_Boxed"
|
| os.makedirs(ima_checkdir, exist_ok=True)
|
|
|
| if getacc:
|
| acc_summary=f"{outcsv_filename}.I2Msummary.txt"
|
| flogout = open(f'{acc_summary}' , 'w')
|
| failed=[]
|
| mydiff=[]
|
| simRD=0
|
| sim=0
|
| mysum=0
|
|
|
| model.eval()
|
| criterion.eval()
|
| metric_logger = MetricLogger(delimiter=" ")
|
| header = 'Infering:'
|
| res_smiles = []
|
| idx_to_labels23={0:'other',1:'C',2:'O',3:'N',4:'Cl',5:'Br',6:'S',7:'F',8:'B',
|
| 9:'I',10:'P',11:'*',12:'Si',13:'NONE',14:'BEGINWEDGE',15:'BEGINDASH',
|
| 16:'=',17:'#',18:'-4',19:'-2',20:'-1',21:'1',22:'2',}
|
| idx_to_labels30 = {0:'other',1:'C',2:'O',3:'N',4:'Cl',5:'Br',6:'S',7:'F',8:'B',
|
| 9:'I',10:'P',11:'H',12:'Si',13:'NONE',14:'BEGINWEDGE',15:'BEGINDASH',
|
| 16:'=',17:'#',18:'-4',19:'-2',20:'-1',21:'1',22:'2',
|
| 23:'CF3',
|
| 24:'CN',
|
| 25:'Me',
|
| 26:'CO2Et',
|
| 27:'R',
|
| 28:'Ph',
|
| 29:'*',
|
| }
|
| bond_labels = [13,14,15,16,17]
|
|
|
| if postprocessors.num_classes==23:
|
|
|
| print(f'usage idx_to_labels23',idx_to_labels23)
|
| idx_to_labels = idx_to_labels23
|
| elif postprocessors.num_classes==30:
|
|
|
| print(f'usage idx_to_labels30',idx_to_labels30)
|
| idx_to_labels = idx_to_labels30
|
| else:
|
| print(f"error unkown ways@@@@@@@@@@@!!!!!!!!!!idx_to_labels::{len(idx_to_labels)}\n{idx_to_labels}")
|
| abrevie={"[23*]":'CF3',
|
| "[24*]":'CN',
|
| "[25*]":'Me',
|
| "[26*]":'CO2Et',
|
| "[27*]":'R',
|
| "[28*]":'Ph',
|
| "[29*]":'3~7UP',
|
| }
|
|
|
| lab2idx={ v:k for k,v in idx_to_labels.items() }
|
|
|
| smiles_data = pd.DataFrame({'file_name': [],
|
| 'SMILESori':[],
|
| 'SMILESpre':[],
|
| 'SMILESexp':[],
|
| }
|
| )
|
| output_dict = {}
|
| output_ori={}
|
| filtered_output_dict = {}
|
| box_thresh=0.1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| for samples, targets in metric_logger.log_every(data_loader, 10, header):
|
|
|
| for i_, ti in enumerate(targets['image_id']):
|
| output_dict[ti.item()] = [
|
| targets['img_path'][i_],
|
| targets['SMILES'][i_],
|
| ]
|
|
|
|
|
| for key, value in output_dict.items():
|
|
|
| image_path = value[0]
|
| SMILESori = value[1]
|
|
|
|
|
|
|
|
|
|
|
| image_path='/cadd_data/samba_share/from_docker/data/work_space/ori/real/acs/ol020229e-Scheme-c3-10.png'
|
|
|
| tensor,w,h = image_to_tensor(image_path)
|
| tensor=tensor.unsqueeze(0).to(device)
|
| print(tensor.size())
|
| ori_size=torch.Tensor([w,h]).long().unsqueeze(0).to(device)
|
| outputs = model(tensor)
|
| result_ = postprocessor2(outputs, ori_size)
|
|
|
| score_=result_[0]['scores']
|
| boxe_=result_[0]['boxes']
|
| label_=result_[0]['labels']
|
|
|
| selected_indices =score_ > box_thresh
|
| true_count = selected_indices.sum().item()
|
| output={
|
| 'labels': label_[selected_indices].to("cpu").numpy(),
|
| 'boxes': boxe_[selected_indices].to("cpu").numpy(),
|
| 'scores': score_[selected_indices].to("cpu").numpy()
|
| }
|
|
|
| img_ori = Image.open(image_path).convert('RGB')
|
| w_ori, h_ori = img_ori.size
|
| print(w_ori, h_ori, "orignianl vs 1000,1000")
|
|
|
| print(f"selected_indices 中 True 的数量: {true_count}")
|
| print(f"before nms_per_class, :: box 的数量:{len(output['labels'])}")
|
| output = nms_per_class(output['labels'], output['boxes'], output['scores'], iou_thresh=0.5)
|
| print(f"after nms_per_class, :: box 的数量:{len(output['labels'])}")
|
|
|
|
|
|
|
| x_center = (output["boxes"][:, 0] + output["boxes"][:, 2]) / 2
|
| y_center = (output["boxes"][:, 1] + output["boxes"][:, 3]) / 2
|
|
|
| center_coords = np.stack((x_center, y_center), axis=1)
|
|
|
|
|
| output = {'bbox': output["boxes"],
|
| 'bbox_centers': center_coords,
|
| 'scores': output["scores"],
|
| 'pred_classes': output["labels"],
|
| }
|
|
|
| img_ori = Image.open(image_path).convert('RGB')
|
| w_ori, h_ori = img_ori.size
|
| print(w_ori, h_ori, "orignianl vs 1000,1000")
|
|
|
| scale_x = 1000 / w_ori
|
| scale_y = 1000 / h_ori
|
| img_ori_1k = img_ori.resize((1000,1000))
|
| img = Image.open(image_path).convert('RGB')
|
| img = img.resize((1000,1000))
|
|
|
|
|
| print(f"from oupt socore > {box_thresh},get box {len(output['bbox'])} after nms_per_class ")
|
|
|
| charge_mask = np.array([True if ins in charge_labels and output['scores'][i]>0.1 else False for i, ins in enumerate(output['pred_classes'])])
|
| charges_bbox=output['bbox'][charge_mask]
|
| charges_centers= output['bbox_centers'][charge_mask]
|
| charges_classes= output['pred_classes'][charge_mask]
|
| charges_scores= output['scores'][charge_mask]
|
| charges_bbox, charges_centers, charges_scores,charges_classes,figc =view_box_center2(charges_bbox, charges_centers, charges_scores,charges_classes, overlap_dist_thresh=5.0, max_centers_per_box=5)
|
|
|
|
|
|
|
| bonds_mask = np.array([True if ins in bond_labels and output['scores'][i]>0.2 else False for i, ins in enumerate(output['pred_classes'])])
|
| bond_bbox=output['bbox'][bonds_mask]
|
| bond_centers= output['bbox_centers'][bonds_mask]
|
| bond_classes= output['pred_classes'][bonds_mask]
|
| bond_scores= output['scores'][bonds_mask]
|
|
|
| bond_bbox, bond_centers, bond_scores,bond_classes,fig =view_box_center2(bond_bbox, bond_centers, bond_scores,bond_classes, overlap_dist_thresh=5.0, max_centers_per_box=3)
|
| bond_bbox, bond_classes, bond_scores = nms(bond_bbox, bond_scores,bond_classes, iou_threshold=0.5)
|
|
|
| heavy_mask= np.array([True if ins not in bond_labels and ins not in charge_labels and ins != lab2idx['H'] else False for ins in output['pred_classes']])
|
| h_mask= np.array([True if ins not in bond_labels and ins not in charge_labels and ins == lab2idx['H'] else False for ins in output['pred_classes']])
|
|
|
|
|
| heavy_bbox = output['bbox'][heavy_mask]
|
| heavy_classes = output['pred_classes'][heavy_mask]
|
| heavy_centers= output['bbox_centers'][heavy_mask]
|
| heavy_scores= output['scores'][heavy_mask]
|
| heavy_bbox, heavy_centers, heavy_scores,heavy_classes,fighv =view_box_center2(heavy_bbox, heavy_centers, heavy_scores,heavy_classes, overlap_dist_thresh=5.0, max_centers_per_box=5)
|
|
|
|
|
| delt_hei=[]
|
| for hei,hebox in enumerate(heavy_bbox):
|
| he_class=idx_to_labels[heavy_classes[hei]]
|
| b_nei=[]
|
| if he_class in ['C']:
|
| for bi,bb in enumerate(bond_bbox):
|
| overlap_flag=boxes_overlap(bb, hebox)
|
| if overlap_flag:
|
| b_nei.append(bi)
|
| if len(b_nei)==0:
|
| delt_hei.append(hei)
|
| n = len(heavy_scores)
|
| keep_boxes = np.ones(n, dtype=bool)
|
| keep_boxes[delt_hei]=False
|
| heavy_bbox, heavy_centers, heavy_scores,heavy_classes=heavy_bbox[keep_boxes], heavy_centers[keep_boxes], heavy_scores[keep_boxes],heavy_classes[keep_boxes]
|
|
|
| h_bbox = output['bbox'][h_mask]
|
| h_centers= output['bbox_centers'][h_mask]
|
| h_classes= output['pred_classes'][h_mask]
|
| h_scores= output['scores'][h_mask]
|
| h_bbox, h_centers, h_scores,h_classes,figh =view_box_center2(h_bbox, h_centers, h_scores,h_classes, overlap_dist_thresh=5.0, max_centers_per_box=5)
|
|
|
|
|
|
|
|
|
|
|
|
|
| atom_bbox = np.concatenate([heavy_bbox, h_bbox], axis=0)
|
| atom_classes = np.concatenate([heavy_classes, h_classes], axis=0)
|
|
|
| atom_scores = np.concatenate([heavy_scores, h_scores], axis=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
| atom_bbox, atom_scores, atom_classes = refine_boxes(atom_bbox, atom_scores, atom_classes, bond_bbox)
|
|
|
|
|
| x_center = (atom_bbox[:, 0] + atom_bbox[:, 2]) / 2
|
| y_center = (atom_bbox[:, 1] + atom_bbox[:, 3]) / 2
|
|
|
| center_coords = np.stack((x_center, y_center), axis=1)
|
| atom_centers=center_coords
|
|
|
| print(f"before NMS :: heavy box {len(heavy_bbox)} ---- H box {len(h_bbox)}---bond box{len(bond_bbox)}")
|
| print(f"after NMS+view_box_center2 :: atom box {len(atom_bbox)} bond box {len(bond_bbox)} charge box {len(charges_bbox)} ")
|
|
|
| print(f"atom box afte NMS and merge_low_iou_boxes")
|
| print(f"get box {len(output['bbox'])} with NMS")
|
| print(f"atom score >0.1 bond score >0.2, then folllowed with NMS")
|
| print(f"bond_bbox nums::",bond_bbox.shape,len(bond_bbox))
|
| print(f" OCR will start involved ")
|
|
|
| new_atoms, single_odd_b2a= Newbox_(atom_bbox,bond_bbox, lab2idx )
|
| print(f"new_atoms number {len(new_atoms)}\n{new_atoms}")
|
| if len(new_atoms)>0:
|
| for boxout in new_atoms:
|
| for k,arr in boxout.items():
|
| value_or_row=output[k]
|
| if arr.ndim == 1:
|
| output[k]=np.append(value_or_row, arr)
|
| elif arr.ndim >= 2:
|
| output[k] = np.concatenate([value_or_row, arr], axis=0)
|
| else:
|
| print('errprs, unkown conditions !!!@')
|
|
|
|
|
| image = cv2.imread(image_path)
|
| image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
|
|
| gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
| _, thresh = cv2.threshold(gray, 150, 255, cv2.THRESH_BINARY_INV)
|
|
|
| kernel = np.ones((2, 2), np.uint8)
|
| dilated = cv2.dilate(thresh, kernel, iterations=1)
|
|
|
|
|
|
|
| result = ocr.ocr(dilated, cls=True)
|
|
|
| text_boxes = []
|
| text_contents = []
|
| confidences = []
|
| for line in result:
|
| print(line)
|
| if line:
|
| for box_info in line:
|
| box = box_info[0]
|
| x_coords = [point[0] for point in box]
|
| y_coords = [point[1] for point in box]
|
| text_box = [min(x_coords), min(y_coords), max(x_coords), max(y_coords)]
|
| text = box_info[1][0]
|
| text_boxes.append(text_box)
|
| text_contents.append(text)
|
| confidences.append(box_info[1][1])
|
| print("Detected text boxes:", text_boxes)
|
| print("Detected text contents:", text_contents)
|
| print("Confidences:", confidences)
|
|
|
|
|
| ai2text = {}
|
| ai2relplace = {}
|
| ai2rdkitlab_unknown = {}
|
| non_overlapping_texts = []
|
|
|
| tree = cKDTree(atom_centers)
|
|
|
| indices_to_delete = set()
|
|
|
| for ti, text_box in enumerate(text_boxes):
|
| text_center = calculate_center(text_box)
|
| ocr_text = text_contents[ti]
|
|
|
|
|
| if ocr_text in ['OH', 'HO']:
|
| ocr_text = 'O'
|
| elif ocr_text in ['SH', 'HS']:
|
| ocr_text = 'S'
|
| elif ocr_text in ['NH', 'HN']:
|
| ocr_text = 'N'
|
| elif ocr_text in ['CH', 'HC']:
|
| ocr_text = 'C'
|
| elif ocr_text == '0':
|
| ocr_text = 'O'
|
| elif ocr_text == 'L':
|
| ocr_text = 'Li'
|
| elif ocr_text[-1]=='-':
|
| if ocr_text[:-1] in ABBREVIATIONS:
|
| ocr_text=ocr_text[:-1]
|
|
|
|
|
| overlapping_indices = []
|
| for idx in range(len(atom_bbox)):
|
| if idx not in indices_to_delete and boxes_overlap(atom_bbox[idx], text_box):
|
| overlapping_indices.append(idx)
|
|
|
| if overlapping_indices:
|
|
|
| if len(overlapping_indices) > 1:
|
|
|
| overlapping_boxes = [atom_bbox[idx] for idx in overlapping_indices]
|
| merged_box = get_merged_box(overlapping_boxes)
|
| overlapping_indices_atomboxclass=[idx_to_labels[atom_classes[i]] for i in overlapping_indices]
|
| print(f"Merging {len(overlapping_indices)} atom boxes overlapping with OCR text: {ocr_text}")
|
| print(f" {overlapping_indices} boxes type{overlapping_indices_atomboxclass} merged as OCR text: {ocr_text}")
|
| merged_area = box_area(merged_box)
|
| text_area = box_area(text_box)
|
| final_box = merged_box if merged_area >= text_area else text_box
|
| else:
|
|
|
| final_box = text_box
|
|
|
| primary_idx = overlapping_indices[0]
|
|
|
|
|
|
|
| atom_bbox[primary_idx] = final_box
|
|
|
| if ocr_text in ABBREVIATIONS:
|
| ai2relplace[primary_idx] = ocr_text
|
| atom_classes[primary_idx] = 0
|
| if ocr_text in lab2idx:
|
| atom_classes[primary_idx] = lab2idx[ocr_text]
|
| elif ocr_text in ['H', 'C', 'O', 'N', 'Cl', 'Br', 'S', 'F', 'B', 'I', 'P', 'Si']:
|
| atom_classes[primary_idx] = lab2idx[ocr_text]
|
| elif ocr_text in RGROUP_SYMBOLS or (ocr_text[0] == 'R' and ocr_text[1:].isdigit()):
|
| atom_classes[primary_idx] = 0
|
| else:
|
| ai2rdkitlab_unknown[primary_idx] = ocr_text
|
| atom_classes[primary_idx] = 0
|
|
|
| ai2text[primary_idx] = ocr_text
|
|
|
|
|
| indices_to_delete.update(overlapping_indices[1:])
|
|
|
| else:
|
|
|
| distance, nearest_idx = tree.query(text_center)
|
| if nearest_idx not in indices_to_delete:
|
| print(f"No overlap for OCR text '{ocr_text}', nearest atom box index: {nearest_idx}")
|
| non_overlapping_texts.append({
|
| 'text': ocr_text,
|
| 'text_box': text_box,
|
| 'nearest_atom_idx': nearest_idx,
|
| 'distance': distance
|
| })
|
|
|
|
|
| atom_ocr=[]
|
| for i,ai in enumerate(atom_classes):
|
| if i in ai2text:
|
| atom_ocr.append(ai2text[i])
|
|
|
|
|
| else:
|
| atom_ocr.append(idx_to_labels[ai])
|
| print(f"atom class + ocr presented as symbols::\n{atom_ocr}")
|
| atom_ocr=np.array(atom_ocr)
|
|
|
| if indices_to_delete:
|
| indices_to_keep = np.setdiff1d(np.arange(len(atom_bbox)), list(indices_to_delete))
|
| atom_bbox = atom_bbox[indices_to_keep]
|
| atom_classes = atom_classes[indices_to_keep]
|
| atom_centers = atom_centers[indices_to_keep]
|
| atom_scores = atom_scores[indices_to_keep]
|
| atom_ocr= atom_ocr[indices_to_keep]
|
|
|
|
|
| for d in [ai2text, ai2relplace, ai2rdkitlab_unknown]:
|
| d_new = {}
|
| for old_idx, value in d.items():
|
| new_idx = np.where(indices_to_keep == old_idx)[0][0] if old_idx in indices_to_keep else None
|
| if new_idx is not None:
|
| d_new[new_idx] = value
|
| d.clear()
|
| d.update(d_new)
|
|
|
|
|
| for entry in non_overlapping_texts:
|
| old_idx = entry['nearest_atom_idx']
|
| if old_idx in indices_to_keep:
|
| entry['nearest_atom_idx'] = np.where(indices_to_keep == old_idx)[0][0]
|
| else:
|
| entry['nearest_atom_idx'] = -1
|
|
|
|
|
| tree = cKDTree(atom_centers)
|
|
|
|
|
| print("Whole img with OCR :: ai2relplace, ai2rdkitlab_unknown:", [ai2relplace, ai2rdkitlab_unknown])
|
| print(f"Adjusted ai ocr_text: {ai2text}")
|
| print(f"Atom box num: {len(atom_bbox)}:: {[idx_to_labels[i] for i in atom_classes]}")
|
| print("Non-overlapping OCR text boxes:", non_overlapping_texts)
|
|
|
|
|
| atomcorp_img = Image.open(image_path).convert('RGB')
|
| atomcorp_img1k=atomcorp_img.resize([1000,1000])
|
| text_contents_star=[]
|
| text_confidences_star=[]
|
| text_boxes_star=[]
|
| boxid2del=dict()
|
| ocr_discrepancies = {}
|
| print(f"has atom_bbox number {len(atom_bbox)}")
|
| for i,box in enumerate(atom_bbox):
|
|
|
| abox =box* [scale_x, scale_y, scale_x, scale_y]
|
| cropped_img=atomcorp_img1k.crop(abox)
|
| image_npocr = np.array(cropped_img)
|
| result_ocr= ocr2.ocr(image_npocr, det=False)
|
|
|
| if result_ocr:
|
| for line in result_ocr:
|
|
|
| if line:
|
| box_flag=has_boxes(line)
|
| for box_info in line:
|
|
|
| if not box_flag:
|
| text=box_info[0]
|
|
|
| text=re.sub(r'[^a-zA-Z0-9,\*\-\+]', '', text)
|
| score_=box_info[1]
|
| text_contents_star.append(text)
|
| text_confidences_star.append(score_)
|
| else:
|
| box = box_info[0]
|
| x_coords = [point[0] for point in box]
|
| y_coords = [point[1] for point in box]
|
| text_box = [min(x_coords), min(y_coords), max(x_coords), max(y_coords)]
|
| text = box_info[1][0]
|
| text=re.sub(r'[^a-zA-Z0-9,\*\-\+]', '', text)
|
| text_boxes_star.append(text_box)
|
| text_contents_star.append(text)
|
| score_=box_info[1][1]
|
| text_confidences_star.append(score_)
|
| if i in ai2text:
|
|
|
| if ai2text[i] != text:
|
| text=ai2text[i] if len(ai2text[i])>=len(text) else text
|
| print(f"Atom box {i}@@ OCR text: {text}, score: {score_}, AI class: {idx_to_labels[atom_classes[i]]}, AI score: {atom_scores[i]}")
|
|
|
| if text in ['OH', 'HO']:
|
| text = 'O'
|
| elif text in ['SH', 'HS']:
|
| text = 'S'
|
| elif text in ['NH', 'HN']:
|
| text = 'N'
|
| elif text in ['CH', 'HC']:
|
| text = 'C'
|
| elif text == '0':
|
| text = 'O'
|
| elif text == 'L':
|
| text = 'Li'
|
| elif '-' in text:
|
| if text[:-1] in ABBREVIATIONS:
|
| text=text[:-1]
|
|
|
|
|
| is_single_char = len(text) == 1
|
| ai_pred = idx_to_labels[atom_classes[i]]
|
|
|
| if text=='0':
|
| atom_classes[i]=lab2idx['O']
|
| elif text in ['H', 'C', 'O', 'N', 'Cl', 'Br', 'S', 'F', 'B', 'I', 'P', 'Si']:
|
| atom_classes[i]=lab2idx[text]
|
|
|
| elif is_single_char and text not in ELEMENTS and ai_pred == 'C':
|
|
|
| ocr_discrepancies[i] = {
|
| 'ocr_text': text,
|
| 'ocr_score': score_,
|
| 'ai_class': ai_pred,
|
| 'ai_score': atom_scores[i]
|
| }
|
| else:
|
| overlap, b_nei=AtomBox2bondBox(atom_bbox[i],bond_bbox)
|
| if not overlap:
|
| if text not in ELEMENTS and text not in ABBREVIATIONS:
|
|
|
|
|
| if text != idx_to_labels[atom_classes[i]]:
|
| boxid2del[i]= [text,idx_to_labels[atom_classes[i]]]
|
| else:
|
| if text != idx_to_labels[atom_classes[i]]:
|
| if atom_scores[i]<=score_:
|
| if text in RGROUP_SYMBOLS or text in ABBREVIATIONS:
|
| ai2relplace[i]=text
|
| atom_classes[i]=0
|
| if text in lab2idx and lab2idx[text] in list(range(23,29)):atom_classes[i]=lab2idx[text]
|
| elif text in ['H', 'C', 'O', 'N', 'Cl', 'Br', 'S', 'F', 'B', 'I', 'P', 'Si']:
|
| atom_classes[i]=lab2idx[text]
|
| else:
|
| ai2relplace[i]=text
|
| atom_classes[i]=0
|
|
|
|
|
| boxid2del = dict(sorted(boxid2del.items(), key=lambda item: item[0], reverse=True))
|
| print(f"considering del box",boxid2del)
|
| print("after split img OCR:: ai2relplace,ai2rdkitlab_unknown",[ai2relplace,ai2rdkitlab_unknown])
|
| print(f"considering delet atomb box :{boxid2del}")
|
| syms=[]
|
| for i in range(len(atom_classes)):
|
| if i in ai2relplace: syms.append(ai2relplace[i])
|
| elif i in ai2rdkitlab_unknown:syms.append(ai2rdkitlab_unknown[i])
|
| else:
|
| syms.append(idx_to_labels[atom_classes[i]])
|
| print(f"atombox {atom_classes}:: number {len(atom_classes)}\n",[idx_to_labels[i] for i in atom_classes])
|
| print(f" {syms}")
|
|
|
| isolated_ais = []
|
|
|
| bond_distances = []
|
| singleAtomBond=dict()
|
| for bi, bb in enumerate(bond_bbox):
|
| overlapped_atoms = []
|
| overlapped_abox = []
|
| for ai, aa in enumerate(atom_bbox):
|
| overlap_flag = boxes_overlap(bb, aa)
|
| if overlap_flag:
|
| overlapped_atoms.append(ai)
|
| overlapped_abox.append(aa)
|
|
|
|
|
|
|
|
|
| if len(overlapped_atoms) == 2:
|
| center1 = calculate_center(atom_bbox[overlapped_atoms[0]])
|
| center2 = calculate_center(atom_bbox[overlapped_atoms[1]])
|
| distance = np.linalg.norm(center1 - center2)
|
| bond_distances.append(distance)
|
|
|
| elif len(overlapped_atoms) == 1:
|
| print(f"single bond - atom still exists for bond {bi}, need porcess this !!")
|
| if bi not in singleAtomBond:
|
| singleAtomBond[bi]=overlapped_atoms
|
|
|
|
|
| distance_threshold = max(bond_distances) if bond_distances else 100.0
|
| distance_threshold_min = min(bond_distances) if bond_distances else 100.0
|
| print(f"Calculated distance_threshold center based: {distance_threshold:.2f}")
|
|
|
|
|
| a2b=dict()
|
| for ai, aa in enumerate(atom_bbox):
|
| b_nei = []
|
| for bi, bb in enumerate(bond_bbox):
|
| overlap_flag = boxes_overlap(bb, aa)
|
| if overlap_flag:
|
| b_nei.append(bi)
|
| a2b[ai] = b_nei
|
| if a2b[ai] ==[]:
|
| if ai not in isolated_ais:
|
| isolated_ais.append(ai)
|
|
|
| isolated_ais=sorted(isolated_ais,reverse=True)
|
| print(f"isolated_ais atom box {isolated_ais}\n ", [idx_to_labels[i] for i in atom_classes[isolated_ais]])
|
|
|
|
|
| updated_atom_bbox = atom_bbox.copy()
|
| updated_atom_classes = atom_classes.copy()
|
| updated_atom_scores = atom_scores.copy()
|
| print(f"atom bbox num {len(atom_bbox)}")
|
| new_bond_bbox=[]
|
| deleted_ais=[]
|
| del4boxid2del=set()
|
| for isolated_ai in isolated_ais:
|
| isolated_box = atom_bbox[isolated_ai]
|
| isolated_center = calculate_center(isolated_box)
|
| nearest_distance = float('inf')
|
| nearest_ai = -1
|
|
|
| for ai, aa in enumerate(atom_bbox):
|
| if ai not in isolated_ais and ai != isolated_ai:
|
| center = calculate_center(aa)
|
| distance = np.linalg.norm(isolated_center - center)
|
| if distance < nearest_distance:
|
| nearest_distance = distance
|
| nearest_ai = ai
|
|
|
| if nearest_ai != -1:
|
| if nearest_distance<=distance_threshold_min or (nearest_distance <=distance_threshold and nearest_distance>=distance_threshold_min):
|
| nearest_box = atom_bbox[nearest_ai]
|
| nearest_class = atom_classes[nearest_ai]
|
| nearest_center = calculate_center(nearest_box)
|
| if isolated_ai in boxid2del:
|
| textocr2del=boxid2del[isolated_ai][0]
|
| else:
|
| textocr2del=None
|
|
|
| overlap1,bondnei=AtomBox2bondBox(nearest_box,bond_bbox)
|
| if len(bondnei)==1:
|
|
|
| if textocr2del is not None and not any(c.isupper() for c in textocr2del):
|
|
|
| deleted_ais.append(isolated_ai)
|
| pass
|
| else:
|
| new_bc = (isolated_center + nearest_center)*0.5
|
| new_bondbox=np.array([new_bc[0] - nearest_distance*0.5,
|
| new_bc[1] - nearest_distance*0.5,
|
| new_bc[0] + nearest_distance*0.5,
|
| new_bc[1] + nearest_distance*0.5]
|
| )
|
| new_bond_bbox.append(new_bondbox.reshape(-1,4))
|
| print(f'add a new bond box new_bc for two atom boxes {isolated_ai} ---- {nearest_ai}::\n {idx_to_labels[atom_classes[isolated_ai]]} --- {idx_to_labels[atom_classes[nearest_ai]]}')
|
| else:
|
| try:
|
| new_box = merge_boxes(isolated_box, nearest_box)
|
| updated_atom_bbox[nearest_ai] = new_box
|
| chosed_score_ = max(atom_scores[isolated_ai], atom_scores[nearest_ai])
|
| updated_atom_scores[nearest_ai] = chosed_score_
|
| except Exception as e:
|
| print(f"file_name@: {image_path}\n SMILES in csv:\n{SMILESori}")
|
| print(e)
|
| print('nearest_ai ', nearest_ai)
|
| check2=True
|
| if check2:
|
| padding=5
|
|
|
| atombox_img=draw_objs(copy.deepcopy(img),
|
| atom_bbox* [scale_x, scale_y, scale_x, scale_y],
|
| atom_classes, atom_scores ,
|
| category_index=idx_to_labels,
|
| box_thresh=box_thresh,
|
| line_thickness=3,
|
| font='arial.ttf',
|
| font_size=10)
|
| bonbox_img=draw_objs(copy.deepcopy(img),
|
| bond_bbox* [scale_x, scale_y, scale_x, scale_y],
|
| bond_classes, bond_scores ,
|
| category_index=idx_to_labels,
|
| box_thresh=0.01,
|
| line_thickness=3,
|
| font='arial.ttf',
|
| font_size=10)
|
|
|
| atom_width, atom_height = atombox_img.size
|
| bon_width, bon_height = bonbox_img.size
|
| combined_width = atom_width + bon_width + padding * 3
|
| combined_height = max(atom_height, bon_height) + padding * 2
|
| combined_img = Image.new('RGB', (combined_width, combined_height), color=(255, 255, 255))
|
|
|
| combined_img.paste(atombox_img, (padding, padding))
|
| combined_img.paste(bonbox_img, (atom_width + padding * 2, padding))
|
| print(f"atom box afte NMS and merge_low_iou_boxes")
|
| combined_img.save(f"tttttttttttttttttttttttBoxed.png"
|
| )
|
| raise Exception("@debug this!!\n")
|
|
|
| if chosed_score_>=0.5:
|
| if chosed_score_==atom_scores[isolated_ai]:
|
| updated_atom_classes[nearest_ai] = 0
|
|
|
|
|
| updated_atom_bbox = np.delete(updated_atom_bbox, isolated_ai, axis=0)
|
|
|
| updated_atom_classes = np.delete(updated_atom_classes, isolated_ai)
|
| updated_atom_scores = np.delete(updated_atom_scores, isolated_ai)
|
| print(f"Merged atom {isolated_ai} into {nearest_ai}, new box: {new_box}")
|
| isolated_ais.remove(isolated_ai)
|
| deleted_ais.append(isolated_ai)
|
|
|
| elif atom_scores[isolated_ai] < 0.5:
|
|
|
| updated_atom_bbox = np.delete(updated_atom_bbox, isolated_ai, axis=0)
|
| updated_atom_classes = np.delete(updated_atom_classes, isolated_ai)
|
| updated_atom_scores = np.delete(updated_atom_scores, isolated_ai)
|
| print(f"DELET isolated atom {isolated_ai} with score {atom_scores[isolated_ai]}")
|
| deleted_ais.append(isolated_ai)
|
|
|
| isolated_ais = [i if i < isolated_ai else i - 1 for i in isolated_ais if i != isolated_ai]
|
| else:
|
| print(f"KEEP isolated atom {isolated_ai} with score {atom_scores[isolated_ai]} >= 0.5")
|
|
|
|
|
| else:
|
| if atom_scores[isolated_ai] < 0.5:
|
| updated_atom_bbox = np.delete(updated_atom_bbox, isolated_ai, axis=0)
|
| updated_atom_classes = np.delete(updated_atom_classes, isolated_ai)
|
| updated_atom_scores = np.delete(updated_atom_scores, isolated_ai)
|
| print(f"DELET isolated atom {isolated_ai} with score {atom_scores[isolated_ai]}")
|
| deleted_ais.append(isolated_ai)
|
| isolated_ais = [i if i < isolated_ai else i - 1 for i in isolated_ais if i != isolated_ai]
|
| else:
|
| print(f"KEEP isolated atom {isolated_ai} with score {atom_scores[isolated_ai]} >= 0.5")
|
|
|
| if len(new_bond_bbox)>0:
|
| for i,bond_box in enumerate(new_bond_bbox):
|
| bond_bbox= np.concatenate([bond_bbox,bond_box],axis=0)
|
| bond_scores= np.concatenate((bond_scores,np.array([0.9])),axis=0)
|
| bond_classes= np.concatenate([bond_classes,np.array([13])],axis=0)
|
|
|
| x_center = (bond_bbox[:, 0] + bond_bbox[:, 2]) / 2
|
| y_center = (bond_bbox[:, 1] + bond_bbox[:, 3]) / 2
|
|
|
| center_coords = np.stack((x_center, y_center), axis=1)
|
| bond_centers=center_coords
|
|
|
|
|
| if len(deleted_ais) > 0:
|
| print(f"will delete atom box with idx :: {deleted_ais}")
|
|
|
| atom_classes = np.delete(atom_classes, deleted_ais, axis=0)
|
| atom_scores = np.delete(atom_scores, deleted_ais, axis=0)
|
| atom_bbox = np.delete(atom_bbox, deleted_ais, axis=0)
|
| atom_ocr = np.delete(atom_ocr, deleted_ais, axis=0)
|
|
|
|
|
|
|
| cur_atomSymbols=[idx_to_labels[i] for i in atom_classes]
|
| ocr_wholeImg=[]
|
| for i in atom_classes:
|
| if i in ai2relplace:
|
| ocr_wholeImg.append(ai2relplace[i])
|
| elif i in ai2rdkitlab_unknown:
|
| ocr_wholeImg.append(ai2rdkitlab_unknown[i])
|
| else:
|
| ocr_wholeImg.append(idx_to_labels[i])
|
| print("ai2relplace,ai2rdkitlab_unknown",ai2relplace,ai2rdkitlab_unknown)
|
| print("cur_atomSymbols:",cur_atomSymbols)
|
| print(" atomSymbolsOCR:",ocr_wholeImg)
|
|
|
|
|
| h_indices = np.where(atom_classes == lab2idx['H'])[0]
|
| non_h_indices = np.where(atom_classes != lab2idx['H'])[0]
|
|
|
|
|
| new_order = np.concatenate((non_h_indices, h_indices)).astype(np.int64)
|
|
|
|
|
| atom_classes = atom_classes[new_order]
|
| atom_bbox = atom_bbox[new_order]
|
| atom_scores = atom_scores[new_order]
|
| x_center = (atom_bbox[:, 0] + atom_bbox[:, 2]) / 2
|
| y_center = (atom_bbox[:, 1] + atom_bbox[:, 3]) / 2
|
|
|
| center_coords = np.stack((x_center, y_center), axis=1)
|
| atom_centers=center_coords
|
|
|
| bond_bbox = reorder_bond_bbox(bond_bbox, singleAtomBond)
|
| bond_classes = reorder_bond_bbox(bond_classes, singleAtomBond)
|
| bond_scores = reorder_bond_bbox(bond_scores, singleAtomBond)
|
| bond_centers = reorder_bond_bbox(bond_centers, singleAtomBond)
|
|
|
|
|
| a2b=dict()
|
| for ai, aa in enumerate(atom_bbox):
|
| b_nei = []
|
| for bi, bb in enumerate(bond_bbox):
|
| overlap_flag = boxes_overlap(bb, aa)
|
| if overlap_flag:
|
| b_nei.append(bi)
|
| a2b[ai] = b_nei
|
| if a2b[ai] ==[]:
|
| if ai not in isolated_ais:
|
| isolated_ais.append(ai)
|
|
|
| b2a=dict()
|
| for bi,bb in enumerate(bond_bbox):
|
| overlapped_atoms = []
|
| overlapped_abox=[]
|
| for ai,aa in enumerate(atom_bbox):
|
| overlap_flag=boxes_overlap(bb, aa)
|
| if overlap_flag:
|
|
|
| overlapped_atoms.append(ai)
|
| overlapped_abox.append(aa)
|
| if bi not in b2a.keys():
|
| b2a[bi]=[ai]
|
| else:
|
|
|
| b2a[bi].append(ai)
|
| if len(overlapped_atoms) == 1:
|
| print(f"single bond -atom still exists {overlapped_atoms}")
|
|
|
|
|
|
|
| if len(charges_classes) > 0:
|
|
|
| kdt = cKDTree(atom_centers)
|
| atid_list=list(range(len(atom_centers)))
|
| used_charge_indices=set()
|
| c2a=dict()
|
| for i, (x,y) in enumerate(charges_centers):
|
| overlapped_abox=[]
|
| cc=charges_bbox[i]
|
| for ai, aa in enumerate(atom_bbox):
|
| overlap_flag=boxes_overlap(cc, aa)
|
| ac_iou=calculate_iou(cc, aa)
|
| charge_=charges_classes[i]
|
| charge_score=charges_scores[i]
|
| if overlap_flag:
|
| if i in c2a:
|
| c2a[i].append(ai)
|
| else:
|
| c2a[i]=[ai]
|
| if ai not in atid_list:
|
| print(f"Warning: ai {ai} is out of range for atom_list.")
|
| continue
|
|
|
| a2c=dict()
|
| for ci,v in c2a.items():
|
| charge_=idx_to_labels[charges_classes[ci]]
|
| if len(v)==1:
|
| a2c[v[0]]=ci
|
| else:
|
| for ai in v:
|
| ats=idx_to_labels[atom_classes[ai]]
|
| if ats=='other':
|
| ats='*'
|
| if ats in ['F','Cl','I','Br','O'] and int(charge_)<0:
|
| a2c[ai]=ci
|
| elif ats in ['N','H','P'] and int(charge_)>0:
|
| a2c[ai]=ci
|
| else:
|
| print(f'unusuaal case charge {charge_} with atom {ats}!!')
|
|
|
| print(f"all a2b b2a a2c c2a done, start mol built")
|
|
|
| output={
|
| 'bbox': np.concatenate([atom_bbox, bond_bbox,charges_bbox], axis=0),
|
| 'bbox_centers': np.concatenate([atom_centers, bond_centers,charges_centers],axis=0),
|
| 'scores': np.concatenate([atom_scores, bond_scores, charges_scores],axis=0),
|
| 'pred_classes': np.concatenate([atom_classes, bond_classes, charges_classes],axis=0),
|
| 'image_path': image_path
|
| }
|
|
|
| boxinfor={
|
| 'bbox': output['bbox'],
|
| 'scores': output['scores'],
|
| 'pred_classes': output['pred_classes'],
|
| 'image_path': image_path
|
| }
|
|
|
| charge_mask = np.array([True if ins in charge_labels else False for ins in output['pred_classes']])
|
| charges_bbox=output['bbox'][charge_mask]
|
| charges_centers=bbox2center(charges_bbox)
|
|
|
| charges_classes= output['pred_classes'][charge_mask]
|
| charges_scores= output['scores'][charge_mask]
|
| charges_bbox, charges_centers, charges_scores,charges_classes,figc =view_box_center2(charges_bbox, charges_centers, charges_scores,charges_classes, overlap_dist_thresh=5.0, max_centers_per_box=5)
|
|
|
|
|
|
|
| bonds_mask = np.array([True if ins in bond_labels and output['scores'][i]>0.2 else False for i, ins in enumerate(output['pred_classes'])])
|
| bond_bbox=output['bbox'][bonds_mask]
|
| bond_centers=bbox2center(bond_bbox)
|
|
|
| bond_classes= output['pred_classes'][bonds_mask]
|
| bond_scores= output['scores'][bonds_mask]
|
| print(f"before view_box_center2 bond nums {len(bond_scores)}")
|
|
|
| bond_bbox, bond_centers, bond_scores,bond_classes,fig =view_box_center2(bond_bbox, bond_centers, bond_scores,bond_classes, overlap_dist_thresh=5.0, max_centers_per_box=3)
|
| print(f"after view_box_center2 bond nums {len(bond_scores)}")
|
|
|
| heavy_mask= np.array([True if ins not in bond_labels and ins not in charge_labels and ins != lab2idx['H'] else False for ins in output['pred_classes']])
|
| h_mask= np.array([True if ins not in bond_labels and ins not in charge_labels and ins == lab2idx['H'] else False for ins in output['pred_classes']])
|
|
|
|
|
| heavy_bbox = output['bbox'][heavy_mask]
|
|
|
| heavy_centers=bbox2center(heavy_bbox)
|
|
|
| heavy_scores= output['scores'][heavy_mask]
|
| heavy_classes = output['pred_classes'][heavy_mask]
|
| heavy_bbox, heavy_centers, heavy_scores,heavy_classes,fighv =view_box_center2(heavy_bbox, heavy_centers, heavy_scores,heavy_classes, overlap_dist_thresh=5.0, max_centers_per_box=5)
|
|
|
| rwmol_ = Chem.RWMol()
|
| boxi2ai = {}
|
| placeholder_atoms=dict()
|
| J=0
|
| for i, (bbox, a) in enumerate(zip(atom_bboxes, atom_classes)):
|
| a2labl=False
|
| a=replace_cg_notation(a)
|
|
|
| if a in ['H', 'C', 'O', 'N', 'Cl', 'Br', 'S', 'F', 'B', 'I', 'P', 'Si']:
|
|
|
| ad = Chem.Atom(a)
|
|
|
| elif a in ELEMENTS:
|
| ad = Chem.Atom(a)
|
| elif a in ABBREVIATIONS :
|
| ad = Chem.Atom("*")
|
| placeholder_atoms[i] = a
|
| a2labl=True
|
|
|
| else:
|
| if N_C_H_expand(a):
|
| ad = Chem.Atom("*")
|
| placeholder_atoms[i] = a
|
| a2labl=True
|
| elif C_H_expand(a):
|
| ad = Chem.Atom("*")
|
| placeholder_atoms[i] = a
|
| a2labl=True
|
| elif C_H_expand2(a):
|
| ad = Chem.Atom("*")
|
| placeholder_atoms[i] = a
|
| a2labl=True
|
| elif formula_regex(a):
|
| ad = Chem.Atom("*")
|
| placeholder_atoms[i] = a
|
| a2labl=True
|
| else:
|
| ad = Chem.Atom("*")
|
| if a not in ['*',"other"]:
|
| a2labl=True
|
|
|
|
|
| rwmol_.AddAtom(ad)
|
| boxi2ai[J] = rwmol_.GetNumAtoms() - 1
|
| if a2labl: rwmol_.GetAtomWithIdx(J).SetProp("atomLabel", f"{a}")
|
| J+=1
|
|
|
|
|
| if len(charges_classes) > 0:
|
| for k,v in a2c.items():
|
| fc=int(idx_to_labels[charges_classes[v]])
|
| rwmol_.GetAtomWithIdx(k).SetFormalCharge(fc)
|
|
|
| print(f"mol with atoms number {i+1}, max atom id {i}")
|
| print(f"mol with bond box number {len(bond_classes)}")
|
| print(f"placeholder_atoms@@ {placeholder_atoms}")
|
|
|
|
|
| bonds=dict()
|
| existing_bonds = set()
|
| b2aa=dict()
|
| singleAtomBond=[]
|
| bondWithdirct=[]
|
|
|
|
|
| tree_atom = KDTree(atom_centers)
|
| if len(idx_to_labels)==30:
|
| _margin=0
|
| for bi, (bbox, idx_) in enumerate(zip(bond_bbox, bond_classes)):
|
| bond_type = idx_to_labels[idx_]
|
| if len(idx_to_labels)==23:
|
| if idx_to_labels[bond_type] in ['-','SINGLE', 'NONE', 'ENDUPRIGHT', 'BEGINWEDGE', 'BEGINDASH', 'ENDDOWNRIGHT']:
|
| _margin = 5
|
| else:
|
| _margin = 8
|
| anchor_positions = (bbox + [_margin, _margin, -_margin, -_margin]).reshape([2, -1])
|
| oposite_anchor_positions = anchor_positions.copy()
|
| oposite_anchor_positions[:, 1] = oposite_anchor_positions[:, 1][::-1]
|
|
|
|
|
| anchor_positions = np.concatenate([anchor_positions, oposite_anchor_positions])
|
|
|
| dists, neighbours = tree_atom.query(anchor_positions, k=1)
|
| if np.argmin((dists[0] + dists[1], dists[2] + dists[3])) == 0:
|
|
|
| begin_idx, end_idx = neighbours[:2]
|
| else:
|
|
|
| begin_idx, end_idx = neighbours[2:]
|
| atom1_idx = boxi2ai[begin_idx]
|
| atom2_idx = boxi2ai[end_idx]
|
| if atom1_idx == atom2_idx:
|
| print(f"attempt to add self-bond:{bi} atomIdx1 == atomIdx2 ::{[atom1_idx, atom2_idx]}")
|
| print(f"for bond bi {bi} H atom may involbed dists:",dists)
|
| print(neighbours)
|
| print("anchor_positions",anchor_positions)
|
| else:
|
| if bond_type in ['-', 'NONE', 'ENDUPRIGHT', 'BEGINWEDGE', 'BEGINDASH', 'ENDDOWNRIGHT']:
|
| if bond_type in BONDDIRECT:
|
| bonds[bi] = (atom1_idx, atom2_idx, 'SINGLE', bond_type)
|
| bondWithdirct.append(bi)
|
| else:
|
| bonds[bi] = (atom1_idx, atom2_idx, 'SINGLE', None)
|
| bond_type=BONDTYPE['SINGLE']
|
| elif bond_type == '=':
|
| bonds[bi] = (atom1_idx, atom2_idx, 'DOUBLE', None)
|
| bond_type=BONDTYPE['DOUBLE']
|
| elif bond_type == '#':
|
| bonds[bi] = (atom1_idx, atom2_idx, 'TRIPLE', None)
|
| bond_type=BONDTYPE['TRIPLE']
|
| else:
|
| print(f'unkown bond type relaced with single@@ {bond_type}')
|
| bonds[bi] = (atom1_idx, atom2_idx, 'SINGLE', None)
|
| bond_type=BONDTYPE['SINGLE']
|
|
|
| atom1 = rwmol_.GetAtomWithIdx(atom1_idx)
|
| atom2 = rwmol_.GetAtomWithIdx(atom2_idx)
|
| val1 = sum(b.GetBondTypeAsDouble() for b in atom1.GetBonds())
|
| val2 = sum(b.GetBondTypeAsDouble() for b in atom2.GetBonds())
|
| max_val1 = max(VALENCES[atom1.GetSymbol()])
|
| max_val2 = max(VALENCES[atom2.GetSymbol()])
|
|
|
| bond_order=BONDTYPE2ORD[bond_type]
|
| if val1 + bond_order <= max_val1 and val2 + bond_order <= max_val2:
|
| bond1 = rwmol_.GetBondBetweenAtoms(atom1_idx, atom2_idx)
|
| bond2 = rwmol_.GetBondBetweenAtoms(atom2_idx, atom1_idx)
|
| if bond1 or bond2:
|
|
|
| pass
|
|
|
| else:
|
|
|
| rwmol_.AddBond(atom1_idx, atom2_idx, bond_type)
|
| else:
|
| print(f"Skipping bond {bi}: Exceeds valence.")
|
| existing_bonds.add((atom1_idx, atom2_idx))
|
| b2aa[bi]=sorted([atom1_idx, atom2_idx])
|
|
|
| if len(bond_bbox)==1 and len(atom_bbox)==2:
|
| ca1='[*:0][C:2]#[C:3][*:1]'
|
| rwmol_ = Chem.RWMol()
|
| ats= ['*','*','C','C']
|
| for ia in ats:
|
| a=Chem.Atom(ia)
|
| id_=rwmol_.AddAtom(a)
|
|
|
| rwmol_.AddBond(2, 3, Chem.BondType.TRIPLE)
|
| rwmol_.AddBond(0, 2, Chem.BondType.SINGLE)
|
| rwmol_.AddBond(1, 3, Chem.BondType.SINGLE)
|
|
|
|
|
| for i in range(len(atom_classes)):
|
| atom_classes[i]=lab2idx['*']
|
| AllChem.Compute2DCoords(rwmol_)
|
| else:
|
| rwmol_=copy.deepcopy(rwmol_)
|
| print(f"placeholder_atoms {placeholder_atoms}")
|
|
|
|
|
| mol = rwmol_.GetMol()
|
| mol.RemoveAllConformers()
|
| conf = Chem.Conformer(mol.GetNumAtoms())
|
|
|
|
|
| for i, (x, y) in enumerate(atom_centers):
|
| x, y=float(x),float(y)
|
| conf.SetAtomPosition(i, (x, y, 0))
|
| mol.AddConformer(conf)
|
|
|
| Chem.AssignStereochemistryFrom3D(mol)
|
| rwmol_=Chem.RWMol(mol)
|
|
|
| skeleton_mol=copy.deepcopy(rwmol_)
|
| print(skeleton_mol.GetNumBonds())
|
| chiral_centers_aids = Chem.FindMolChiralCenters(mol, includeUnassigned=True)
|
|
|
|
|
| heavyNumber=len(heavy_centers)
|
| print(f'mol with heavy number atoms {heavyNumber}, max id {heavyNumber-1}')
|
| onlyHeayMol=copy.deepcopy(rwmol_)
|
| chiral_centers = Chem.FindMolChiralCenters(
|
| rwmol_, includeUnassigned=True, includeCIP=False, useLegacyImplementation=False)
|
| chiral_center_ids = [idx for idx, _ in chiral_centers]
|
| Hais=[]
|
| Hais_bt=[]
|
| Hbd=[]
|
|
|
| for bi, ais in b2a.items():
|
| bt=bond_classes[bi]
|
| for ai in ais:
|
| if ai>heavyNumber-1:
|
| if bt in [14,15]:
|
| Hais.append(ais)
|
| print(f"within H bond box id {bi} bond direction {idx_to_labels[bt]} atoms box id {ais} ")
|
| Hais_bt.append(idx_to_labels[bt])
|
| Hbd.append(bi)
|
|
|
|
|
| H_existing_bonds = set()
|
| ha2boxa=dict()
|
| for ais, bt in zip(Hais,Hais_bt):
|
| idx_2=ais[-1]
|
| idx_1=ais[0]
|
| hbond=rwmol_.GetBondBetweenAtoms(idx_1,idx_2)
|
| if hbond is not None:
|
| if idx_1 in chiral_center_ids:
|
| hbond.SetBondDir(BOND_DIRS[bt])
|
| else:
|
| had = Chem.Atom("H")
|
| addHatom_idx = rwmol_.AddAtom(had)
|
| ha2boxa[addHatom_idx]=idx_2
|
|
|
| atom= rwmol_.GetAtomWithIdx(idx_1)
|
| max_val=max(VALENCES[atom.GetSymbol()])
|
| val = sum(b.GetBondTypeAsDouble() for b in atom.GetBonds())
|
| if (idx_1, addHatom_idx) not in H_existing_bonds and (addHatom_idx, idx_1) not in H_existing_bonds:
|
| if val<=max_val-1:
|
|
|
| print(idx_1, addHatom_idx)
|
| rwmol_.AddBond(idx_1,addHatom_idx, Chem.BondType.SINGLE)
|
| b=rwmol_.GetBondBetweenAtoms(idx_1,addHatom_idx)
|
| if idx_1 in chiral_center_ids:
|
| b.SetBondDir(BOND_DIRS[bt])
|
| H_existing_bonds.add((idx_1,addHatom_idx))
|
| i
|
| if len(ha2boxa)>0:
|
|
|
| rwmol_.RemoveAllConformers()
|
| conf = Chem.Conformer(rwmol_.GetNumAtoms())
|
| conf.Set3D(True)
|
| coords2d=[]
|
| for i, (x, y) in enumerate(heavy_centers):
|
| position = Point3D(float(x), float(y), 0.)
|
| conf.SetAtomPosition(i, position)
|
| coords2d.append([x,y])
|
| for k,v in ha2boxa.items():
|
| x,y=atom_centers[v]
|
| position = Point3D(float(x), float(y), 0.)
|
| conf.SetAtomPosition(k, position)
|
| coords2d.append([x,y])
|
| rwmol_.AddConformer(conf)
|
|
|
| additonalH=detect_unconnected_hydrogens(rwmol_)
|
| if len(additonalH)>0:
|
| rwmol_,rmovedAtomcoords=remove_unconnected_hydrogens2(rwmol_)
|
|
|
| if len(rmovedAtomcoords)>0:
|
| delbb=[]
|
| kdt = cKDTree(atom_centers)
|
| for i, (x,y,z) in enumerate(rmovedAtomcoords):
|
| dist, idx_=kdt.query([x,y], k=1)
|
| delbb.append(idx_)
|
| mask = np.ones(len(atom_classes), dtype=bool)
|
| mask[delbb] = False
|
| atom_bbox = atom_bbox[mask]
|
| atom_classes = atom_classes[mask]
|
| atom_centers = atom_centers[mask]
|
|
|
|
|
| mol=copy.deepcopy(rwmol_)
|
| conf=mol.GetConformers()[0]
|
| mola2xy=dict()
|
| mola2d=[]
|
| for i,a in enumerate(mol.GetAtoms()):
|
| x,y,z=conf.GetAtomPosition(i)
|
| mola2xy[i]=[x,y]
|
| mola2d.append([x,y])
|
|
|
| kdt = cKDTree(mola2d)
|
| chiral_centers = Chem.FindMolChiralCenters(
|
| mol, includeUnassigned=True, includeCIP=False, useLegacyImplementation=False)
|
| chiral_center_ids = [idx for idx, _ in chiral_centers]
|
|
|
| for bi,bcent in enumerate(bond_centers):
|
| if bi in bondWithdirct :
|
| dists, a1a2 = kdt.query(bcent, k=2)
|
| a1,a2=sorted(a1a2)
|
| a1,a2=int(a1),int(a2)
|
| bt= mol.GetBondBetweenAtoms(a1, a2)
|
| if bt:
|
|
|
| current_begin = bt.GetBeginAtomIdx()
|
| current_end = bt.GetEndAtomIdx()
|
| bond_dir=bond_dirs[idx_to_labels[bond_classes[bi]]]
|
| if bond_dir == rdchem.BondDir.BEGINWEDGE:
|
| reverse_dir = rdchem.BondDir.BEGINDASH
|
| elif bond_dir == rdchem.BondDir.BEGINDASH:
|
| reverse_dir = rdchem.BondDir.BEGINWEDGE
|
|
|
|
|
| if a1 in chiral_center_ids:
|
| if current_begin == a1:
|
| bt.SetBondDir(bond_dir)
|
| print(f'a1 dir')
|
| else:
|
|
|
| bt.SetBondDir(reverse_dir)
|
| print(f'a1 reverse_dir')
|
|
|
|
|
| elif a2 in chiral_center_ids:
|
| if current_begin == a2:
|
| bt.SetBondDir(bond_dir)
|
| print(f'a2 dir {bond_dir} {reverse_dir}')
|
| else:
|
|
|
| mol.RemoveBond(current_begin, current_end)
|
| mol.AddBond(current_end, current_begin, bt.GetBondType())
|
| bond = mol.GetBondBetweenAtoms(current_end, current_begin)
|
| bond.SetBondDir(bond_dir)
|
| print(f'a2 reverse_dir {bond_dir} {reverse_dir}')
|
|
|
|
|
|
|
| else:
|
| print('bond stro not with chiral atom???, will ignore this stero bond infors')
|
| print(f"{[bi, bond_dir, current_begin,current_end]}")
|
|
|
|
|
|
|
|
|
| try:
|
| mol_rebuit=mol.GetMol()
|
| conf = mol_rebuit.GetConformer()
|
| Chem.WedgeMolBonds(mol_rebuit,conf)
|
| Chem.DetectBondStereochemistry(mol_rebuit)
|
| Chem.AssignChiralTypesFromBondDirs(mol_rebuit)
|
| Chem.AssignStereochemistry(mol_rebuit)
|
|
|
| smiH=Chem.MolToSmiles(mol_rebuit)
|
| print(F"smiH\n",smiH)
|
|
|
|
|
|
|
|
|
|
|
| mol = rdkit.Chem.RWMol(mol_rebuit)
|
| other2ppsocr=True
|
| if other2ppsocr:
|
| print()
|
| need_cut=[]
|
| ppstr=[]
|
| ppstr_score=[]
|
| crops=[]
|
| index_token=dict()
|
| expan=0
|
| for i_,(heav_c,heav_box) in enumerate(zip(atom_classes,atom_bbox)):
|
| if lab2idx['*']==heav_c or lab2idx['other']==heav_c or lab2idx['Cl']==heav_c:
|
| need_cut.append(i_)
|
| a=heav_box+np.array([-expan,-expan,expan,expan])
|
|
|
| box=a * [scale_x, scale_y, scale_x, scale_y]
|
|
|
| cropped_img = img_ori_1k.crop(box)
|
| crops.append(cropped_img)
|
| image_npocr = np.array(cropped_img)
|
| result_ocr= ocr2.ocr(image_npocr, det=False)
|
| s_, score_ =result_ocr[0][0]
|
| s_previos=atom_ocr[i_]
|
| if s_previos != "other" :
|
| s_=s_previos if len(s_previos)>=len(s_) else s_
|
| print(f'ocr::idx:{i_}',s_, score_ )
|
| if score_<=0.1:
|
|
|
| s_='*'
|
| if s_=='+' or s_=='-':
|
| s_="*"
|
| if len(s_)>1:
|
| s_=re.sub(r'[^a-zA-Z0-9,\*\-\+]', '', s_)
|
| if re.match(r'^\d+$', s_):
|
| s_=f'{s_}*'
|
|
|
| if s_=='L':s_='Li'
|
| elif s_=='0':s_='O'
|
| elif s_ in ['N,+ CI','N,+ Cl' ,'N,+Cl','N,+CI','N+CI']:s_='N2+Cl-'
|
| elif s_ in ['NO,','O,N' ]:s_='NO2'
|
|
|
|
|
| match = re.match(r'^(\d+)?(.*)', s_)
|
|
|
| if match:
|
| numeric_part, remaining_part = match.groups()
|
| fc_=mol.GetAtomWithIdx(i_).GetFormalCharge()
|
| if remaining_part in ELEMENTS:
|
| new_atom = Chem.Atom(remaining_part)
|
| mol.ReplaceAtom(i_, new_atom)
|
| print(i_, remaining_part,"@@@")
|
| elif remaining_part in ABBREVIATIONS:
|
| placeholder_atoms[i_]=s_
|
| elif remaining_part=='OH':
|
| new_atom = Chem.Atom("O")
|
| mol.ReplaceAtom(i_, new_atom)
|
| elif remaining_part=='SH':
|
| new_atom = Chem.Atom("S")
|
| mol.ReplaceAtom(i_, new_atom)
|
| elif remaining_part=='NH':
|
| new_atom = Chem.Atom("N")
|
| mol.ReplaceAtom(i_, new_atom)
|
| mol.GetAtomWithIdx(i_).SetFormalCharge(fc_)
|
| index_token[i_]=f'{s_}:{i_}'
|
| print(f"idx:{i_}, atm: <{idx_to_labels[heav_c]}> --- [{s_}:{i_}] with score:{score_} ||previousOCR:: {atom_ocr[i_]}")
|
| if s_ in ELEMENTS :
|
| new_atom = Chem.Atom(s_)
|
| mol.ReplaceAtom(i_, new_atom)
|
| mol.GetAtomWithIdx(i_).SetProp("atomLabel", f"{s_}")
|
| ppstr.append(s_)
|
| ppstr_score.append(score_)
|
| if s_ in ABBREVIATIONS.keys():
|
| placeholder_atoms[i_]=s_
|
|
|
| bond_dirs_rev={v:k for k,v in bond_dirs.items()}
|
| wdbs=[]
|
| for b in mol.GetBonds():
|
| bd=b.GetBondDir()
|
| bt=b.GetBondType()
|
|
|
| if bd ==bond_dirs['BEGINDASH'] or bd==bond_dirs['BEGINWEDGE']:
|
| a1,a2=b.GetBeginAtomIdx(), b.GetEndAtomIdx()
|
| wdbs.append([a1,a2,bt,bond_dirs_rev[bd]])
|
|
|
|
|
|
|
| cm=copy.deepcopy(mol)
|
|
|
| expand_mol, expand_smiles= expandABB(cm,ABBREVIATIONS, placeholder_atoms)
|
| SMILESpre=expand_smiles
|
| rdm=copy.deepcopy(expand_mol)
|
| target_mol, ref_mol=rdm, cm
|
| AllChem.Compute2DCoords(target_mol)
|
| pair=[target_mol, ref_mol]
|
| mcs=rdFMCS.FindMCS([target_mol, ref_mol],
|
|
|
| bondCompare=rdFMCS.BondCompare.CompareAny,
|
| ringCompare=rdFMCS.RingCompare.IgnoreRingFusion,
|
| matchChiralTag=False,
|
| )
|
| mcs_mol = Chem.MolFromSmarts(mcs.smartsString)
|
| AllChem.Compute2DCoords(mcs_mol)
|
|
|
| matches0 = pair[0].GetSubstructMatches(mcs_mol, useQueryQueryMatches=True,uniquify=False, maxMatches=1000, useChirality=False)
|
| matches1 = pair[1].GetSubstructMatches(mcs_mol, useQueryQueryMatches=True,uniquify=False, maxMatches=1000, useChirality=False)
|
| if len(matches0) != len(matches1):
|
| matches0=list(matches0)
|
| matches1=list(matches1)
|
|
|
| if len(matches0)>len(matches1):
|
| for i in range(0,len(matches0)):
|
| if i < len(matches1):
|
| pass
|
| else:
|
| ii=i % len(matches1)
|
| matches1.append(matches1[ii])
|
| else:
|
| for i in range(0,len(matches1)):
|
| if i < len(matches0):
|
| pass
|
| else:
|
| ii=i % len(matches0)
|
| matches0.append(matches0[ii])
|
| assert len(matches0) == len(matches1), "matcher not equal break!!"
|
| atommaping_pairs=[list(zip(matches0[i],matches1[i])) for i in range(0,len(matches0))]
|
| atomMap=atommaping_pairs[0]
|
| rmsd2=rdkit.Chem.rdMolAlign.AlignMol(prbMol=target_mol, refMol=ref_mol, atomMap=atomMap,maxIters=2000000)
|
| print(f"rmsd {rmsd2}")
|
|
|
| ocr_mol = copy.deepcopy(target_mol)
|
| AllChem.Compute2DCoords(ocr_mol)
|
| ocr_smi = Chem.MolToSmiles(ocr_mol)
|
| molexp=ocr_mol
|
| expandStero_smi, success= rdkit_canonicalize_smiles(ocr_smi)
|
|
|
|
|
|
|
| if visual_check:
|
| boxed_img = draw_objs(img,
|
| atom_bbox,
|
| atom_classes,
|
| atom_scores,
|
| category_index=idx_to_labels,
|
| box_thresh=0.5,
|
| line_thickness=3,
|
| font='arial.ttf',
|
| font_size=10)
|
| opts = Draw.MolDrawOptions()
|
| opts.addAtomIndices = False
|
| opts.addStereoAnnotation = False
|
| img_ori = Image.open(image_path).convert('RGB')
|
| img_ori_1k = img_ori.resize((1000,1000))
|
| if other2ppsocr:
|
| img_rebuit = Draw.MolToImage(ocr_mol, options=opts,size=(1000, 1000))
|
| else:
|
| img_rebuit = Draw.MolToImage(ocr_mol, options=opts,size=(1000, 1000))
|
| combined_img = Image.new('RGB', (img_ori_1k.width + boxed_img.width + img_rebuit.width, img_ori_1k.height))
|
| combined_img.paste(img_ori_1k, (0, 0))
|
| combined_img.paste(boxed_img, (img_ori_1k.width, 0))
|
| combined_img.paste(img_rebuit, (img_ori_1k.width + boxed_img.width, 0))
|
| imprefix=os.path.basename(image_path).split('.')[0]
|
| combined_img.save(f"{ima_checkdir}/{imprefix}Boxed.png")
|
|
|
| new_row = {'file_name':image_path, "SMILESori":SMILESori,
|
| 'SMILESpre':SMILESpre,
|
| 'SMILESexp':expandStero_smi,
|
| }
|
| smiles_data = smiles_data._append(new_row, ignore_index=True)
|
|
|
|
|
| if getacc:
|
| sameWithOutStero=comparing_smiles(new_row,SMILESpre)
|
| sameWithOutStero_exp=comparing_smiles(new_row,expandStero_smi)
|
|
|
| if (type(SMILESori)!=type('a')) or (type(SMILESpre)!=type('a')):
|
| if sameWithOutStero or sameWithOutStero_exp:
|
| mysum += 1
|
| else:
|
| print(f"smiles problems\n{SMILESori}\n{SMILESpre}\n{image_path}")
|
| failed.append([SMILESori,SMILESpre,image_path])
|
| mydiff.append([SMILESori,SMILESpre,image_path])
|
| continue
|
| mol1 = Chem.MolFromSmiles(SMILESori)
|
| if mol1 is None:
|
| rd_smi_ori, success1_=rdkit_canonicalize_smiles(SMILESori)
|
| mol1=Chem.MolFromSmiles(rd_smi_ori)
|
| if (mol_rebuit is None) or (mol1 is None):
|
| if sameWithOutStero or sameWithOutStero_exp:
|
| mysum += 1
|
| else:
|
| print(f'get rdkit mol None\n{SMILESori}\n{SMILESpre}\n{image_path}')
|
| failed.append([SMILESori,SMILESpre,image_path])
|
| mydiff.append([SMILESori,SMILESpre,image_path])
|
| continue
|
| if mol1:
|
| rdk_smi1=Chem.MolToSmiles(mol1)
|
| else:
|
| rdk_smi1=SMILESori
|
| if mol_rebuit:
|
| rdk_smi2=Chem.MolToSmiles(mol_rebuit)
|
| else:
|
| rdk_smi2=''
|
|
|
| if rdk_smi1==rdk_smi2 or rdk_smi1==expandStero_smi:
|
| mysum += 1
|
| else:
|
| if sameWithOutStero or sameWithOutStero_exp:
|
| mysum += 1
|
| else:
|
| mydiff.append([SMILESori,SMILESpre,image_path])
|
| if visual_check:
|
| combined_img.save(f"{ima_checkdir}/{imprefix}Boxed_diff{len(mydiff)}.png")
|
| try:
|
| morganfps1 = AllChem.GetMorganFingerprint(mol1, 3,useChirality=True)
|
| morganfps2 = AllChem.GetMorganFingerprint(mol_rebuit, 3,useChirality=True)
|
| morgan_tani = DataStructs.DiceSimilarity(morganfps1, morganfps2)
|
| fp1 = Chem.RDKFingerprint(mol1)
|
| fp2 = Chem.RDKFingerprint(mol_rebuit)
|
| tanimoto = DataStructs.FingerprintSimilarity(fp1, fp2)
|
| if expandStero_smi!= '':
|
| fp3 = Chem.RDKFingerprint(molexp)
|
| morganfps3 = AllChem.GetMorganFingerprint(molexp, 3,useChirality=True)
|
| morgan_tani3 = DataStructs.DiceSimilarity(morganfps1, morganfps3)
|
| tanimoto3 = DataStructs.FingerprintSimilarity(fp1, fp3)
|
| if morgan_tani3> morgan_tani or tanimoto3> tanimoto :
|
| sim+=morgan_tani3
|
| simRD+=tanimoto3
|
| else:
|
| simRD+=tanimoto
|
| sim+=morgan_tani
|
| except Exception as e:
|
| print(f"mol to fingerprint erros")
|
| simRD+=0
|
| sim+=0
|
| continue
|
| except Exception as e:
|
| print(f"file_name@: {image_path}\n SMILES in csv:\n{SMILESori}")
|
| raise Exception("@debug this!!\n")
|
|
|
| if getacc:
|
| sim_100 = 100*sim/len(smiles_data)
|
| simrd100 = 100*simRD/len(smiles_data)
|
| flogout.write(f"rdkit concanlized==smiles:{100*mysum/len(smiles_data)}%\n")
|
| flogout.write(f"failed:{len(failed)}\n totoal saved in csv : {len(smiles_data)}\n")
|
| flogout.write(f"avarage similarity morgan tanimoto: RDKFp tanimoto:: {sim_100}%, {simrd100}% \n")
|
| flogout.write(f'I2M@@:: match--{mysum},unmatch--{len(mydiff)},failed--{len(failed)},correct %{100*mysum/len(smiles_data)} \n')
|
|
|
| from src.solver.evaluate import SmilesEvaluator
|
| evaluator = SmilesEvaluator(smiles_data['SMILESori'], tanimoto=False)
|
| res_pre=evaluator.evaluate(smiles_data['SMILESpre'])
|
| res_exp=evaluator.evaluate(smiles_data['SMILESexp'])
|
| flogout.write(f'MolScribe style evaluation@SMILESpre:: {str(res_pre)} \n')
|
| flogout.write(f'MolScribe style evaluation@SMILESexp:: {str(res_exp)} \n')
|
| flogout.close()
|
| print(f"will save {len(smiles_data)} dataframe into csv")
|
| smiles_data.to_csv(outcsv_filename, index=False)
|
|
|
|
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| import torchvision
|
|
|
|
|
class RTDETRPostProcessor(nn.Module):
    """Turn raw detector outputs into per-image labels, boxes and scores.

    The module reads ``outputs['pred_logits']`` and ``outputs['pred_boxes']``,
    converts the boxes from ``cxcywh`` to ``xyxy``, scales them by the
    original image sizes and keeps at most ``num_top_queries`` detections
    per image.
    """

    __share__ = ['num_classes', 'use_focal_loss', 'num_top_queries', 'remap_mscoco_category']

    def __init__(self, classes_dict=None, use_focal_loss=True, num_top_queries=300, remap_mscoco_category=False) -> None:
        """Configure the post-processor.

        Args:
            classes_dict: Mapping ``category id -> class name``. Only its
                keys (and their order) and its length are used here. When
                ``None`` a built-in 24-entry atom/bond/charge table is used.
            use_focal_loss: If true, scores come from a sigmoid over the
                logits and the top-k is taken over all (query, class)
                pairs; otherwise a softmax over classes is used.
            num_top_queries: Maximum number of detections kept per image.
            remap_mscoco_category: If true, ``forward`` maps contiguous
                label indices back to the keys of ``classes_dict``.
        """
        super().__init__()
        self.use_focal_loss = use_focal_loss
        if classes_dict is None:
            classes_dict = {0:'other',1:'C',2:'O',3:'N',4:'Cl',5:'Br',6:'S',7:'F',8:'B',
                            9:'I',10:'P',11:'H',12:'Si',
                            13:'single',14:'wdge',15:'dash',
                            16:'=',17:'#',18:':',
                            19:'-4',20:'-2',
                            21:'-1',
                            22:'+1',
                            23:'+2',
                            }
        num_classes = len(classes_dict)
        self.num_top_queries = num_top_queries
        self.num_classes = num_classes
        self.remap_mscoco_category = remap_mscoco_category
        self.deploy_mode = False

        # Position of each key in classes_dict <-> the key itself, so that a
        # contiguous predicted label index can be mapped back to a category id.
        mscoco_category2label = {k: i for i, k in enumerate(classes_dict.keys())}
        mscoco_label2category = {v: k for k, v in mscoco_category2label.items()}
        self.mscoco_label2category = mscoco_label2category

    def extra_repr(self) -> str:
        """Return the configuration summary shown in the module's repr."""
        return f'use_focal_loss={self.use_focal_loss}, num_classes={self.num_classes}, num_top_queries={self.num_top_queries}'

    def forward(self, outputs, orig_target_sizes):
        """Select the top detections for every image in the batch.

        Args:
            outputs: Mapping with ``'pred_logits'`` and ``'pred_boxes'``.
                NOTE(review): the indexing below implies logits of shape
                (batch, queries, classes) and boxes of shape
                (batch, queries, 4) in ``cxcywh`` format; confirm against
                the model that produces them.
            orig_target_sizes: Per-image sizes used to scale the boxes.
                NOTE(review): ``repeat(1, 2)`` implies shape (batch, 2),
                presumably (width, height); confirm against the caller.

        Returns:
            In deploy mode, the tuple ``(labels, boxes, scores)`` of batched
            tensors. Otherwise a list with one dict per image holding the
            keys ``labels``, ``boxes`` and ``scores``.
        """
        logits, boxes = outputs['pred_logits'], outputs['pred_boxes']

        bbox_pred = torchvision.ops.box_convert(boxes, in_fmt='cxcywh', out_fmt='xyxy')
        # Tile the per-image size to match the four xyxy coordinates and
        # broadcast it over the query dimension.
        bbox_pred *= orig_target_sizes.repeat(1, 2).unsqueeze(1)

        if self.use_focal_loss:
            scores = torch.sigmoid(logits)
            # Top-k over the flattened (query, class) grid; the flat index
            # is then split back into a class label and a query index.
            scores, index = torch.topk(scores.flatten(1), self.num_top_queries, dim=-1)
            labels = index % self.num_classes
            index = index // self.num_classes
            boxes = bbox_pred.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, bbox_pred.shape[-1]))

        else:
            # The class axis must be given explicitly: with no ``dim`` the
            # deprecated implicit rule picks dim 0 for a 3-D tensor, which
            # would normalise across the batch instead of across classes.
            # The last class column is dropped before taking the best class.
            scores = F.softmax(logits, dim=-1)[:, :, :-1]
            scores, labels = scores.max(dim=-1)
            boxes = bbox_pred
            if scores.shape[1] > self.num_top_queries:
                scores, index = torch.topk(scores, self.num_top_queries, dim=-1)
                labels = torch.gather(labels, dim=1, index=index)
                boxes = torch.gather(boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1]))

        # Deploy mode returns plain tensors and skips the remap and the
        # per-image dict packaging below.
        if self.deploy_mode:
            return labels, boxes, scores

        if self.remap_mscoco_category:
            remapped = [self.mscoco_label2category[int(x.item())] for x in labels.flatten()]
            labels = torch.tensor(remapped).to(boxes.device).reshape(labels.shape)

        results = [
            dict(labels=lab, boxes=box, scores=sco)
            for lab, box, sco in zip(labels, boxes, scores)
        ]

        return results

    def deploy(self, ):
        """Switch to eval mode, enable tuple output from ``forward`` and return ``self``."""
        self.eval()
        self.deploy_mode = True
        return self

    @property
    def iou_types(self, ):
        """Tuple of IoU types associated with this post-processor."""
        return ('bbox', )