Instructions to use SaeedLab/TITAN-BBB with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use SaeedLab/TITAN-BBB with Transformers:
```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("SaeedLab/TITAN-BBB", dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| import torch | |
| import torch.nn as nn | |
| from transformers import PreTrainedTokenizer | |
| from transformers.tokenization_utils_base import BatchEncoding | |
| from transformers import AutoTokenizer, AutoModel | |
| from rdkit import Chem | |
| from rdkit.Chem import Descriptors, AllChem, MACCSkeys | |
| from rdkit.ML.Descriptors import MoleculeDescriptors | |
| from rdkit import RDLogger | |
| from rdkit.Chem import Draw | |
| import joblib | |
| import numpy as np | |
| import os | |
| from huggingface_hub import snapshot_download | |
| import warnings | |
| from sklearn.exceptions import InconsistentVersionWarning | |
| from torchvision import models, transforms | |
| from PIL import Image | |
# Silence sklearn's InconsistentVersionWarning, raised when unpickling the
# joblib scalers below if they were saved under a different sklearn release.
warnings.filterwarnings("ignore", category=InconsistentVersionWarning)
# Suppress RDKit's per-molecule console logging (parse warnings, etc.).
RDLogger.DisableLog('rdApp.*')
class BBBTokenizer(PreTrainedTokenizer):
    """Multimodal "tokenizer" for TITAN-BBB.

    Instead of producing token ids, this tokenizer converts a batch of SMILES
    strings into three fixed-size feature tensors:

    - ``"tab"``: RDKit 2D descriptors concatenated with MACCS keys, scaled by
      a fitted transformer downloaded from the ``SaeedLab/TITAN-BBB`` hub repo.
    - ``"img"``: ResNet-50 (ImageNet-pretrained, classifier head removed)
      embedding of the molecule's rendered 2D structure image, scaled likewise.
    - ``"txt"``: mean-pooled last hidden state of a frozen ChemBERTa encoder,
      scaled likewise.

    The per-task scalers (``'classification'`` or ``'regression'``) are
    downloaded lazily on the first encode call for that task.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # RDKit 2D descriptor calculator over every descriptor in descList.
        self.calc = MoleculeDescriptors.MolecularDescriptorCalculator(
            [name for name, _ in Descriptors.descList]
        )
        # Text branch: frozen ChemBERTa tokenizer + encoder.
        self.tokenizer = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-100M-MLM')
        self.chemberta = AutoModel.from_pretrained('DeepChem/ChemBERTa-100M-MLM').eval()
        # Image branch: ImageNet-pretrained ResNet-50 with the final FC layer
        # dropped, leaving the global-average-pooled feature extractor.
        self.resnet50_backbone = models.resnet50(weights="IMAGENET1K_V1")
        self.resnet = nn.Sequential(*list(self.resnet50_backbone.children())[:-1]).eval()
        self.img_preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            # Standard ImageNet normalization statistics.
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])
        # Per-task feature scalers; populated lazily by _load_feature_transformers.
        self.feature_transformer_tab = None
        self.feature_transformer_img = None
        self.feature_transformer_txt = None
        self.task = None

    def _load_feature_transformers(self, task):
        """Download and load the three per-task scalers from the hub.

        Sets ``self.feature_transformer_{tab,img,txt}`` and records
        ``self.task`` so repeated calls with the same task are skipped by the
        caller. Raises ``ValueError`` for an unknown task.
        """
        # Map task name to the filename prefix used in the hub repo.
        prefixes = {'classification': 'cls', 'regression': 'reg'}
        if task not in prefixes:
            raise ValueError('task not defined')
        paths = {}
        for modality in ('tabular', 'image', 'text'):
            fname = f"normalize_{prefixes[task]}_{modality}.joblib"
            model_dir = snapshot_download("SaeedLab/TITAN-BBB", allow_patterns=[fname])
            paths[modality] = os.path.join(model_dir, fname)
        self.feature_transformer_tab = joblib.load(paths['tabular'])
        self.feature_transformer_img = joblib.load(paths['image'])
        self.feature_transformer_txt = joblib.load(paths['text'])
        self.task = task

    def generate_tab_features(self, smiles):
        """Return the scaled tabular feature vector for one SMILES string.

        Features are the RDKit 2D descriptors followed by the 167 MACCS key
        bits, passed through the fitted tabular scaler. Unparsable SMILES map
        to an all-zero vector of the scaler's expected width.
        """
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # BUGFIX: the original returned torch.tensor(n_features_in_) — a
            # 0-d scalar whose *value* is the feature count — which breaks
            # torch.stack against the (n,) vectors produced for valid SMILES.
            # An invalid molecule must yield a zero feature vector instead.
            return torch.zeros(
                self.feature_transformer_tab.n_features_in_, dtype=torch.float32
            )
        rdkit_2d = np.array(self.calc.CalcDescriptors(mol))
        # Sanitize non-finite descriptor values to 0 before scaling.
        rdkit_2d[np.isinf(rdkit_2d)] = np.nan
        rdkit_2d = np.nan_to_num(rdkit_2d, nan=0.0, posinf=0.0, neginf=0.0)
        maccs = np.array(list(MACCSkeys.GenMACCSKeys(mol).ToBitString()), dtype=int)
        tab_input = np.concatenate([rdkit_2d, maccs])
        tab_input = self.feature_transformer_tab.transform(tab_input.reshape(1, -1))[0]
        # Clip to guard against extreme values produced by the scaler.
        tab_input = np.clip(tab_input, -1e5, 1e5)
        return torch.tensor(tab_input, dtype=torch.float32)

    def generate_img_features(self, smiles):
        """Return the scaled ResNet-50 embedding of the molecule's 2D image.

        Unparsable SMILES are rendered as a solid black 300x300 image so the
        pipeline still produces a vector of the expected width.
        """
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            img = Image.new("RGB", (300, 300), color=(0, 0, 0))
        else:
            img = Draw.MolToImage(mol, size=(300, 300))
        img = self.img_preprocess(img)
        with torch.no_grad():
            # (1, 2048, 1, 1) -> (1, 2048) after squeezing the spatial dims.
            img_input = self.resnet(img.unsqueeze(0)).squeeze(-1).squeeze(-1)
        img_input = self.feature_transformer_img.transform(img_input.reshape(1, -1))[0]
        return torch.tensor(img_input, dtype=torch.float32)

    def generate_txt_features(self, smiles):
        """Return the scaled, mean-pooled ChemBERTa embedding of the SMILES."""
        encoded = self.tokenizer(smiles, return_tensors="pt")
        with torch.no_grad():
            outputs = self.chemberta(**encoded)
        # Mean-pool over the sequence dimension of the last hidden state.
        hidden_states = outputs.last_hidden_state[0].mean(axis=0).numpy()
        txt_input = self.feature_transformer_txt.transform(hidden_states.reshape(1, -1))[0]
        return torch.tensor(txt_input, dtype=torch.float32)

    def _batch_encode_plus(
        self,
        batch_smiles: list[str],
        task: str = 'classification',
        return_tensors: str = "pt",
        **kwargs
    ):
        """Encode a batch of SMILES into stacked tab/img/txt feature tensors.

        Lazily (re)loads the per-task scalers when ``task`` differs from the
        task of the previous call. Returns a ``BatchEncoding`` with keys
        ``"tab"``, ``"img"`` and ``"txt"``. Raises ``ValueError`` for an
        unknown task.
        """
        if self.task is None or self.task != task:
            self._load_feature_transformers(task)
        tab, img, txt = [], [], []
        for smiles in batch_smiles:
            tab.append(self.generate_tab_features(smiles))
            img.append(self.generate_img_features(smiles))
            txt.append(self.generate_txt_features(smiles))
        output = {
            "tab": torch.stack(tab),
            "img": torch.stack(img),
            "txt": torch.stack(txt),
        }
        return BatchEncoding(output, tensor_type=return_tensors)

    def encode(self,
               batch_smiles: list[str],
               task: str = 'classification',
               return_tensors: str = "pt",
               **kwargs):
        """Alias for :meth:`_batch_encode_plus`; see that method for details."""
        return self._batch_encode_plus(batch_smiles, task, return_tensors, **kwargs)

    def __call__(self,
                 batch_smiles: list[str],
                 task: str = 'classification',
                 return_tensors: str = "pt",
                 **kwargs):
        """Alias for :meth:`_batch_encode_plus`; see that method for details."""
        return self._batch_encode_plus(batch_smiles, task, return_tensors, **kwargs)

    def _tokenize(self, text, **kwargs):
        # No text tokenization is performed; features come from the methods above.
        return []

    def save_vocabulary(self, save_directory, filename_prefix=None):
        # No vocabulary files to save.
        return ()

    def get_vocab(self):
        # Minimal placeholder vocab of special tokens only.
        return {"<pad>": 0, "<bos>": 1, "<eos>": 2, "<unk>": 3, "<mask>": 4}

    def vocab_size(self):
        # NOTE(review): HF convention makes vocab_size a @property; kept as a
        # plain method to preserve this class's existing call interface.
        return len(self.get_vocab())