| | """ |
| | Module: preprocess.py |
| | |
| | This module provides a preprocessing pipeline for single-cell RNA sequencing (scRNA-seq) data |
| | stored in AnnData format. It includes functions for loading data, filtering cells and genes, |
| | normalizing and scaling data, and saving processed results. The pipeline is designed to be |
| | configurable via hyperparameters and supports various preprocessing steps such as mitochondrial |
| | gene filtering, highly variable gene selection, and log transformation. |
| | |
| | Main Features: |
| | - Load and preprocess scRNA-seq data in AnnData format. |
| | - Filter cells and genes based on various criteria. |
| | - Normalize, scale, and log-transform data. |
| | - Save processed data and metadata to disk. |
| | - Configurable via JSON-based hyperparameters. |
| | |
| | Dependencies: |
| | - anndata, numpy, pandas, scanpy, scipy, sklearn |
| | |
| | Usage: |
| | - Run this script as a standalone program with a configuration file specifying the hyperparameters. |
| | - Import the `preprocess` function and call it with the data path, metadata path, and hyperparameters. |
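
Example (illustrative; the config keys mirror those read by `preprocess` below,
while the module import path and file paths are placeholders):

    from preprocess import preprocess

    hyperparameters = {
        "load_dir": "raw", "save_dir": "processed",
        "reference_id_only": True, "remove_assays": [],
        "min_gene_counts": 200, "max_mitochondrial_prop": 0.2,
        "hvg_method": "seurat_v3", "normalized_total": 1e4,
        "median_dict": None, "median_column": "index",
        "log1p": True, "compute_medians": False,
    }
    preprocess("raw/sample.h5ad", "raw/sample_metadata.json", hyperparameters)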
| | """ |

import gc
import json
import logging
import os
import warnings
from argparse import ArgumentParser

import anndata as ad
import numpy as np
import pandas as pd
import scanpy as sc
from anndata import ImplicitModificationWarning
import scipy.sparse as sp
from scipy.sparse import csr_matrix, issparse
from sklearn.utils import sparsefuncs, sparsefuncs_fast

from teddy.data_processing.utils.gene_mapping.gene_mapper import (
    map_mouse_human,
    map_mouse_human2,
)

_logger = logging.getLogger(__name__)

# Human mitochondrial genes by Ensembl ID: rRNAs, tRNAs, and protein-coding genes.
_HUMAN_MITO_ENSEMBL = {
| | "ENSG00000211459", "ENSG00000210082", |
| | |
| | "ENSG00000210049", "ENSG00000210077", "ENSG00000209082", |
| | "ENSG00000210100", "ENSG00000210107", "ENSG00000210112", |
| | "ENSG00000210119", "ENSG00000210122", "ENSG00000210116", |
| | "ENSG00000210117", "ENSG00000210118", "ENSG00000210124", |
| | "ENSG00000210126", "ENSG00000210134", "ENSG00000210135", |
| | "ENSG00000210142", "ENSG00000210144", "ENSG00000210148", |
| | "ENSG00000210150", "ENSG00000210155", "ENSG00000210196", |
| | "ENSG00000210151", |
| | |
| | "ENSG00000198888", "ENSG00000198763", "ENSG00000198840", |
| | "ENSG00000198886", "ENSG00000212907", "ENSG00000198786", |
| | "ENSG00000198695", "ENSG00000198804", "ENSG00000198712", |
| | "ENSG00000198938", "ENSG00000198899", "ENSG00000228253", |
| | "ENSG00000198727", |
| | } |

# Human mitochondrial genes by HGNC symbol.
_HUMAN_MITO_SYMBOLS = {
    "MT-RNR1", "MT-RNR2", "MT-TF", "MT-TV", "MT-TL1", "MT-TI", "MT-TQ",
    "MT-TM", "MT-TW", "MT-TA", "MT-TN", "MT-TC", "MT-TY", "MT-TD", "MT-TK",
    "MT-TG", "MT-TR", "MT-TH", "MT-TS2", "MT-TL2", "MT-TT", "MT-TE", "MT-TP",
    "MT-TS1", "MT-ND1", "MT-ND2", "MT-ND3", "MT-ND4", "MT-ND4L", "MT-ND5",
    "MT-ND6", "MT-CO1", "MT-CO2", "MT-CO3", "MT-ATP6", "MT-ATP8", "MT-CYB",
}


def load_data_and_metadata(data_path: str, metadata_path: str):
    """
    Load an AnnData .h5ad file (data) and a JSON file (metadata).
    """
    data = ad.read_h5ad(data_path)
    with open(metadata_path, "r") as f:
        metadata = json.load(f)
    return data, metadata


def set_raw_if_necessary(data: ad.AnnData):
    """
    If data.raw is None, check whether the first ~64 cells of
    data.layers['counts'] (if present) or data.X are integer-valued.
    If so, set data.raw accordingly; otherwise return None (skip this file).
    """
    if data.raw is not None:
        return data

    if "counts" in data.layers:
        X = data.layers["counts"]

        X_sample = X[:64].toarray() if issparse(X) else np.asarray(X[:64])

        if np.all(np.equal(np.mod(X_sample, 1), 0)):
            data.raw = ad.AnnData(X=data.layers["counts"], var=data.var.copy())
            return data

    X = data.X

    X_sample = X[:64].toarray() if issparse(X) else np.asarray(X[:64])

    if np.all(np.equal(np.mod(X_sample, 1), 0)):
        data.raw = data
        return data

    print("No integer-valued matrix found")
    return None
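
# Note: the integer check above treats float-typed but integral values
# (e.g. 3.0) as counts, since np.mod(x, 1) == 0 for such entries.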


def initialize_processed_layer(data: ad.AnnData):
    """
    If the 'processed' layer is missing, copy it from data.raw.X.
    """
    if "processed" not in data.layers:
        data.layers["processed"] = data.raw.X.astype("float32")
    return data


def filter_reference_id(data: ad.AnnData, hyperparameters: dict):
    """
    Map gene identifiers to a common human reference_id, drop genes without a
    mapping, and collapse duplicate mappings by taking the per-cell maximum.
    """
    human_map = pd.read_csv("teddy/data_processing/utils/gene_mapping/data/human_mapping.txt", sep="\t")
    mouse_map = pd.read_csv("teddy/data_processing/utils/gene_mapping/data/2407_mouse_gene_mapping.txt", sep="\t")
    orthologs = pd.read_csv(
        "teddy/data_processing/utils/gene_mapping/data/mouse_to_human_orthologs.one2one.txt", sep="\t"
    )

    if hyperparameters.get("mouse_nonorthologs", False):
        reference_id = map_mouse_human2(
            data_frame=data.var,
            query_column=None,
            human_map_db=human_map,
            mouse_map_db=mouse_map,
            orthology_db=orthologs,
        )["reference_id"]
    else:
        reference_id = map_mouse_human(
            data_frame=data.var,
            query_column=None,
            human_map_db=human_map,
            mouse_map_db=mouse_map,
            orthology_db=orthologs,
        )["reference_id"]

    # Drop genes with no reference mapping.
    valid_mask = reference_id != ""
    data = data[:, valid_mask].copy()
    reference_id = reference_id[valid_mask].reset_index(drop=True)

    if not isinstance(data.layers["processed"], np.ndarray):
        corrected = data.layers["processed"].toarray()
    else:
        corrected = data.layers["processed"]

    # Collapse duplicated reference_ids: keep the first occurrence and store
    # the per-cell maximum across all duplicate columns there.
    unique_ids = reference_id.unique()
    vars_to_keep = []
    for rid in unique_ids:
        repeated_idx = np.where(reference_id == rid)[0]
        vars_to_keep.append(repeated_idx[0])
        if len(repeated_idx) > 1:
            corrected[:, repeated_idx[0]] = corrected[:, repeated_idx].max(axis=1)

    vars_to_keep = sorted(vars_to_keep)
    corrected = corrected[:, vars_to_keep]
    data = data[:, vars_to_keep]

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=ImplicitModificationWarning)
        data.layers["processed"] = csr_matrix(corrected)
        data.var["reference_id"] = list(reference_id[vars_to_keep])

    gc.collect()
    return data


def remove_assays(data: ad.AnnData, assays_to_remove: list):
    """
    Remove observations whose 'assay' value is in assays_to_remove.
    Assumes 'assay' is a column of data.obs.
    """
    data = data[~data.obs.assay.isin(assays_to_remove)].copy()
    gc.collect()
    return data


def filter_cells_by_gene_counts(data: ad.AnnData, min_count: int):
    """
    Remove cells (observations) whose total counts in the 'processed' layer
    are below min_count.
    """
    mask = sc.pp.filter_cells(data.layers["processed"], min_counts=min_count)[0]
    data = data[np.where(mask)].copy()
    del mask
    gc.collect()
    return data


def filter_cells_by_mitochondrial_fraction(data: ad.AnnData, max_mito_prop: float):
    """
    Remove low-quality cells whose mitochondrial read fraction exceeds
    *max_mito_prop*.

    DO NOT RUN THIS IN ANY PREPROCESSING PIPELINE UNTIL YOU HAVE SET RAW COUNTS.

    Parameters
    ----------
    data
        `AnnData` object containing counts. Works with dense or sparse matrices.
    max_mito_prop
        Threshold above which cells are discarded.

    Returns
    -------
    AnnData
        A **copy** of `data` with poor-quality cells removed and two new
        columns added to ``.obs``:

        - **mito_prop** – per-cell mitochondrial fraction
        - **poor_quality_mito** – boolean flag marking dropped cells
    """
    counts = data.X
    var_index = data.var_names
    # Choose the reference set based on whether gene names look like Ensembl IDs.
    if var_index[0].startswith("ENSG"):
        ref = _HUMAN_MITO_ENSEMBL
    else:
        ref = _HUMAN_MITO_SYMBOLS
    mito_idx = np.flatnonzero(var_index.isin(ref))
    if mito_idx.size == 0:
        _logger.info("No mitochondrial genes found, returning data")
        return data
    if sp.issparse(counts):
        total = counts.sum(axis=1).A1
        mito = counts[:, mito_idx].sum(axis=1).A1
    else:
        total = counts.sum(axis=1)
        mito = counts[:, mito_idx].sum(axis=1)
    # Guard against division by zero for empty cells.
    mito_prop = mito / np.maximum(total, 1)
    data.obs["mito_prop"] = mito_prop
    data.obs["poor_quality_mito"] = mito_prop > max_mito_prop
    filtered = data[~data.obs["poor_quality_mito"]].copy()
    gc.collect()
    return filtered
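
# Illustrative: a cell with 1,000 total counts of which 150 are mitochondrial
# has mito_prop = 0.15 and is dropped whenever max_mito_prop < 0.15.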


def filter_highly_variable_genes(data: ad.AnnData, method: str):
    """
    Filter genes to those that are highly variable using scanpy.
    method must be "seurat_v3" or "cell_ranger".
    """
    if "highly_variable" not in data.var:
        sc.pp.highly_variable_genes(data, flavor=method, n_top_genes=10000)
    data = data[:, data.var["highly_variable"]]
    gc.collect()
    return data


def normalize_data_inplace(matrix_csr: csr_matrix, norm_value: float):
    """
    In-place row normalization + scale: each row is L1-normalized and then
    multiplied by norm_value, so for nonnegative data every row sums to
    norm_value afterwards. matrix_csr must be a CSR matrix.
    """
    sparsefuncs_fast.inplace_csr_row_normalize_l1(matrix_csr)

    scale_factors = np.full(matrix_csr.shape[0], norm_value)
    sparsefuncs.inplace_row_scale(matrix_csr, scale_factors)
    gc.collect()
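
# Illustrative check of the invariant above (assumes nonnegative input):
#
#     m = csr_matrix(np.array([[1.0, 3.0], [2.0, 2.0]]))
#     normalize_data_inplace(m, 1e4)
#     assert np.allclose(m.sum(axis=1).A1, 1e4)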


def scale_columns_by_median_dict(layer: csr_matrix, data: ad.AnnData, median_dict_path: str, median_column: str):
    """
    Read a JSON median_dict and scale each column by 1/median. The lookup key
    is either data.var.index or data.var[median_column].
    """
    with open(median_dict_path) as f:
        median_dict = json.load(f)

    if median_column == "index":
        median_var = data.var.index
    else:
        median_var = data.var[median_column]

    # Genes without an entry in median_dict are left unscaled (factor 1.0).
    factors = []
    for g in median_var:
        if g in median_dict:
            factors.append(1.0 / median_dict[g])
        else:
            factors.append(1.0)
    factors = np.array(factors)

    sparsefuncs.inplace_csr_column_scale(layer, factors)
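
# The median_dict JSON is expected to map gene identifiers to positive median
# values, e.g. (values illustrative):
#
#     {"ENSG00000198888": 2.5, "ENSG00000198763": 1.0}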


def log_transform_layer(data: ad.AnnData, layer_name: str = "processed"):
    """
    Apply sc.pp.log1p in place to data.layers[layer_name].
    """
    sc.pp.log1p(data, layer=layer_name, copy=False)


def compute_and_save_medians(data: ad.AnnData, data_path: str, hyperparameters: dict):
    """
    Convert zeros to NaN, compute column medians ignoring NaN, and save the
    results as JSON next to the processed data.
    """
    layer = data.layers["processed"]
    mat = layer.toarray() if issparse(layer) else np.array(layer, copy=True)
    mat[mat == 0] = np.nan

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered")
        medians = np.nanmedian(mat, axis=0)

    if hyperparameters["median_column"] == "index":
        median_var = data.var.index.copy()
        if not isinstance(median_var, pd.Series):
            median_var = pd.Series(median_var)
    else:
        median_var = data.var[hyperparameters["median_column"]].copy()

    # Keep only genes that had at least one nonzero value.
    valid_idxs = np.where(~np.isnan(medians))[0]
    median_values = {median_var.iloc[k]: medians[k].item() for k in valid_idxs}

    save_path = data_path.replace(hyperparameters["load_dir"], hyperparameters["save_dir"])
    save_path = save_path.replace(".h5ad", "_medians.json")
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    with open(save_path, "w") as f:
        json.dump(median_values, f, indent=4)


def update_metadata(metadata: dict, data: ad.AnnData, hyperparameters: dict):
    """
    Update metadata with cell_count and append the current hyperparameters to
    the processing-args history.
    """
    metadata["cell_count"] = data.n_obs
    if "processing_args" in metadata:
        previous = metadata["processing_args"]
        if not isinstance(previous, list):
            previous = [previous]
        metadata["processing_args"] = previous + [hyperparameters]
    else:
        metadata["processing_args"] = [hyperparameters]
    return metadata
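
# Example (illustrative): after two preprocessing runs, the metadata holds
# metadata["processing_args"] == [hyperparameters_run1, hyperparameters_run2].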


def save_and_cleanup(data: ad.AnnData, metadata: dict, data_path: str, metadata_path: str, hyperparameters: dict):
    """
    Write processed data and metadata to disk, then GC cleanup.
    Returns (None, None) if the filtered data has no cells left.
    """
    save_dir = hyperparameters["save_dir"]
    data_filename = os.path.basename(data_path)
    metadata_filename = os.path.basename(metadata_path)

    save_processed_path = os.path.join(save_dir, data_filename)
    save_metadata_path = os.path.join(save_dir, metadata_filename)

    os.makedirs(os.path.dirname(save_processed_path), exist_ok=True)
    os.makedirs(os.path.dirname(save_metadata_path), exist_ok=True)

    if data.n_obs == 0:
        return None, None

    # Ensure all matrices are CSR before writing. Raw.X is read-only in
    # anndata, so the raw slot is rebuilt rather than assigned in place.
    if not isinstance(data.raw.X, csr_matrix):
        data.raw = ad.AnnData(X=csr_matrix(data.raw.X), var=data.raw.var.copy())
    if not isinstance(data.X, csr_matrix):
        data.X = csr_matrix(data.X)
    if "processed" in data.layers and not isinstance(data.layers["processed"], csr_matrix):
        data.layers["processed"] = csr_matrix(data.layers["processed"])

    try:
        data.write_h5ad(save_processed_path, compression="gzip")
    except Exception:
        # h5ad writing can fail when the obs index name collides with an obs
        # column; drop the offending column and retry.
        if data.obs.index.name in data.obs.columns:
            del data.obs[data.obs.index.name]
        data.write_h5ad(save_processed_path, compression="gzip")

    del data
    gc.collect()

    with open(save_metadata_path, "w") as f:
        json.dump(metadata, f, indent=4)

    return True, True


def preprocess(data_path: str, metadata_path: str, hyperparameters: dict):
    """
    Pipeline steps:
    1. Load data & metadata
    2. Ensure data.raw is set if counts are integer-valued
    3. Initialize the 'processed' layer
    4. Filter genes by reference_id
    5. Remove assays
    6. Filter cells (min gene counts)
    7. Filter cells (max mito fraction)
    8. HVG filtering
    9. Normalize total
    10. Median-based column scaling
    11. Log transform
    12. Compute medians (optional)
    13. Update metadata and save
    """
    data, metadata = load_data_and_metadata(data_path, metadata_path)

    data = set_raw_if_necessary(data)
    if data is None:
        return None, None

    data = initialize_processed_layer(data)

    if hyperparameters["reference_id_only"]:
        data = filter_reference_id(data, hyperparameters)

    if "assay" in data.obs and hyperparameters["remove_assays"]:
        data = remove_assays(data, hyperparameters["remove_assays"])

    if hyperparameters["min_gene_counts"]:
        data = filter_cells_by_gene_counts(data, hyperparameters["min_gene_counts"])

    if hyperparameters["max_mitochondrial_prop"]:
        data = filter_cells_by_mitochondrial_fraction(
            data, hyperparameters["max_mitochondrial_prop"])

    if hyperparameters["hvg_method"] in ["seurat_v3", "cell_ranger"]:
        data = filter_highly_variable_genes(data, hyperparameters["hvg_method"])

    if hyperparameters["normalized_total"]:
        if not isinstance(data.layers["processed"], csr_matrix):
            data.layers["processed"] = csr_matrix(data.layers["processed"])
        normalize_data_inplace(data.layers["processed"], hyperparameters["normalized_total"])

    if hyperparameters["median_dict"]:
        scale_columns_by_median_dict(
            data.layers["processed"], data, hyperparameters["median_dict"], hyperparameters["median_column"]
        )

    if hyperparameters["log1p"]:
        log_transform_layer(data, "processed")

    if hyperparameters["compute_medians"]:
        compute_and_save_medians(data, data_path, hyperparameters)

    metadata = update_metadata(metadata, data, hyperparameters)
    return save_and_cleanup(data, metadata, data_path, metadata_path, hyperparameters)


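# Example CLI invocation (paths are illustrative):
#
#   python preprocess.py --data_path raw/sample.h5ad \
#       --metadata_path raw/sample_metadata.json --config_path config.json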
if __name__ == "__main__":
    parser = ArgumentParser(description="Preprocess scRNA-seq data stored in AnnData format.")
    parser.add_argument(
        "--data_path",
        type=str,
        required=True,
        help="Path to the input .h5ad file."
    )
    parser.add_argument(
        "--metadata_path",
        type=str,
        required=True,
        help="Path to the input metadata JSON file."
    )
    parser.add_argument(
        "--config_path",
        type=str,
        required=True,
        help="Path to the JSON configuration file containing hyperparameters."
    )

    args = parser.parse_args()

    with open(args.config_path, "r") as f:
        hyperparameters = json.load(f)

    success, _ = preprocess(
        data_path=args.data_path,
        metadata_path=args.metadata_path,
        hyperparameters=hyperparameters
    )

    if success:
        print("Preprocessing completed successfully.")
    else:
        print("Preprocessing returned no data (0 cells), no file saved.")
|