| | from typing import Dict, Optional |
| | import os |
| |
|
| | import numpy as np |
| | import matplotlib.pyplot as plt |
| | import seaborn as sns |
| | from scipy.stats import gaussian_kde |
| | from scipy import interpolate |
| |
|
| |
|
| | |
# Font size applied to the axis labels and subplot titles drawn by
# scatterplot_2d.
FONTSIZE = 18
| | |
| | |
| | |
def _padded_limits(points):
    """Return ``(x_min, x_max, y_min, y_max)`` for columns 0 and 1 of *points*.

    Every bound is pushed outwards by one fifth of that axis' data range.
    """
    x_lo, x_hi = np.min(points[:, 0]), np.max(points[:, 0])
    y_lo, y_hi = np.min(points[:, 1]), np.max(points[:, 1])
    # Derive both margins from the *unpadded* range so that the lower and
    # upper margins are equal (padding one bound first and then reusing it
    # would make the upper margin larger than the lower one).
    x_pad = (x_hi - x_lo) / 5.0
    y_pad = (y_hi - y_lo) / 5.0
    return x_lo - x_pad, x_hi + x_pad, y_lo - y_pad, y_hi + y_pad


def _grid_shape(n_panels):
    """Return ``(n_rows, n_columns)`` of a subplot grid holding *n_panels*."""
    n_rows = n_panels // 5 if n_panels > 5 else 1
    # Ceiling division: the grid must offer at least ``n_panels`` cells,
    # otherwise ``plt.subplot`` rejects the trailing panel indices.
    n_columns = (n_panels + n_rows - 1) // n_rows
    return n_rows, n_columns


def _point_density(name, points):
    """Return one density value per row of *points*.

    A Gaussian KDE fitted on the points is evaluated at the points
    themselves. When there are fewer samples than dimensions a uniform
    density is returned instead and a warning is printed.
    """
    n_samples, n_dims = points.shape[0], points.shape[1]
    if n_samples < n_dims:
        print(f"Warning: {name} has more dimensions than samples, using uniform density.")
        # Built as floats explicitly so integer inputs do not break the
        # normalisation; max(..., 1) keeps the empty case well defined.
        return np.ones(n_samples, dtype=float) / max(n_samples, 1)
    samples = np.transpose(points)  # gaussian_kde expects (n_dims, n_samples)
    return gaussian_kde(samples)(samples)


def scatterplot_2d(
    data_dict: Dict,
    save_to: str,
    ref_key: str = 'target',
    xlabel: str = 'tIC1',
    ylabel: str = 'tIC2',
    n_max_point: int = 1000,
    pop_ref: bool = False,
    xylim_key: Optional[str] = 'PDB_clusters',
    plot_kde: bool = False,
    density_mapping: Optional[Dict] = None,
):
    """Draw one density-coloured 2D scatter panel per dataset and save the figure.

    Each entry of ``data_dict`` becomes one subplot showing columns 0 and 1
    of its array, coloured by ``density_mapping[name]`` when provided and by
    a Gaussian KDE of the points otherwise. All panels share the same axis
    limits. The caller's ``data_dict`` is not modified.

    Args:
        data_dict: Maps a panel title to its point array. Values are assumed
            to be 2-D NumPy arrays of shape ``(n_points, n_dims >= 2)``
            (NOTE(review): confirm against callers).
        save_to: Path the figure is written to (saved at 500 dpi).
        ref_key: Key of the reference dataset. It is never sub-sampled, it
            defines the axis limits when ``xylim_key`` is unavailable, and it
            is the data used for the optional KDE contours.
        xlabel: Label of the x axis (drawn on every panel).
        ylabel: Label of the y axis (drawn on the first column only).
        n_max_point: Non-reference datasets with more rows than this are
            randomly sub-sampled (without replacement) to this many rows.
        pop_ref: If true, the reference dataset does not get its own panel.
        xylim_key: Key of an entry whose points define the axis limits. When
            present it gets no panel of its own; its points are overlaid on
            every panel as hollow red circles. Falsy or missing means the
            limits come from the reference dataset.
        plot_kde: If true, overlay seaborn KDE contours of the reference
            dataset on every panel.
        density_mapping: Optional per-panel override of the colour values
            passed to ``plt.scatter``.

    Raises:
        KeyError: If ``ref_key`` is needed (for limits, ``pop_ref`` or
            ``plot_kde``) but absent from ``data_dict``.
        ValueError: If no dataset is left to plot.

    Note:
        The figure is left open after saving; closing it is up to the caller.
    """
    # Work on a shallow copy so the pops below never leak into the caller's dict.
    panels = dict(data_dict)

    if xylim_key and xylim_key in panels:
        xylim = panels.pop(xylim_key)
        x_min, x_max, y_min, y_max = _padded_limits(xylim)
    else:
        xylim = None
        x_min, x_max, y_min, y_max = _padded_limits(panels[ref_key])

    if plot_kde and ref_key not in panels:
        raise KeyError(ref_key)
    # Keep a handle on the reference data so the KDE overlay still works
    # after the reference panel has been dropped with ``pop_ref``.
    reference = panels.get(ref_key)
    if pop_ref:
        panels.pop(ref_key)

    if not panels:
        raise ValueError("data_dict contains no dataset to plot")

    print(f">>> Plotting scatter in 2D space. Image save to {save_to}")

    n_rows, n_columns = _grid_shape(len(panels))
    plt.figure(figsize=(6 * n_columns, n_rows * 6))

    for panel_idx, (name, points) in enumerate(panels.items()):
        plt.subplot(n_rows, n_columns, panel_idx + 1)

        n_total = points.shape[0]
        keep = None
        if name != ref_key and n_total > n_max_point:
            keep = np.random.choice(n_total, n_max_point, replace=False)
            points = points[keep]

        if density_mapping and name in density_mapping:
            # A user-supplied colouring wins; the KDE is not computed at all.
            density = density_mapping[name]
            if keep is not None:
                mapped = np.asarray(density)
                # Per-point values must follow the same sub-sampling as the
                # points, otherwise scatter() receives mismatched lengths.
                if mapped.ndim > 0 and mapped.shape[0] == n_total:
                    density = mapped[keep]
        else:
            density = _point_density(name, points)

        plt.scatter(
            points[:, 0],
            points[:, 1],
            s=10,
            alpha=0.7,
            c=density,
            cmap="mako_r",
            vmin=-0.05,
            vmax=0.40,
        )

        if plot_kde:
            sns.kdeplot(x=reference[:, 0], y=reference[:, 1])

        if xylim is not None:
            plt.scatter(xylim[:, 0], xylim[:, 1], s=40, marker="o", c="none", edgecolors="tab:red")

        plt.xlabel(xlabel, fontsize=FONTSIZE, fontfamily="sans-serif")
        if panel_idx % n_columns == 0:
            # Only the left-most panel of each row carries the y label.
            plt.ylabel(ylabel, fontsize=FONTSIZE, fontfamily="sans-serif")

        plt.xlim(x_min, x_max)
        plt.ylim(y_min, y_max)
        plt.title(name, fontsize=FONTSIZE, fontfamily="sans-serif")

    plt.tight_layout()
    plt.savefig(save_to, dpi=500)