File size: 3,379 Bytes
ca7299e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from typing import Dict, Optional
import os

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import gaussian_kde
from scipy import interpolate


##################################################
# Font size applied to every axis label and panel title drawn by this module.
FONTSIZE: int = 18
##################################################
    
    
def _padded_limits(values, frac: float = 0.2):
    """Return ``(lo, hi)`` of ``values`` widened by ``frac`` of the span on each side."""
    lo = float(np.min(values))
    hi = float(np.max(values))
    # Compute the pad once from the original span so both sides get the same margin.
    pad = (hi - lo) * frac
    return lo - pad, hi + pad


def _point_density(name, points):
    """Return a per-point Gaussian-KDE density for ``points`` (rows are samples).

    Falls back to a uniform density (summing to 1) when a KDE cannot be fitted:
    fewer samples than dimensions, or a singular sample covariance.
    """
    n_samples, n_dims = points.shape
    # Float division avoids the in-place int /= float failure for integer input.
    uniform = np.ones(n_samples) / max(n_samples, 1)
    if n_samples < n_dims:
        print(f"Warning: {name} has more dimensions than samples, using uniform density.")
        return uniform
    samples = points.T  # gaussian_kde expects (n_dims, n_samples)
    try:
        return gaussian_kde(samples)(samples)
    except np.linalg.LinAlgError:
        # e.g. n_samples == n_dims, or duplicated / collinear points.
        print(f"Warning: {name} has a singular covariance, using uniform density.")
        return uniform


def scatterplot_2d(
    data_dict: Dict,
    save_to: str,
    ref_key: str = 'target',
    xlabel: str = 'tIC1',
    ylabel: str = 'tIC2',
    n_max_point: int = 1000,
    pop_ref: bool = False,
    xylim_key: Optional[str] = 'PDB_clusters',
    plot_kde: bool = False,
    density_mapping: Optional[Dict] = None
):
    """Draw one density-coloured 2D scatter panel per entry of ``data_dict`` and save it.

    Each value is an array whose rows are points. Only columns 0 and 1 are plotted.
    The default axis labels suggest tICA projections — confirm against callers.
    Points are coloured by a Gaussian KDE evaluated at each point (over all
    columns), unless ``density_mapping`` supplies precomputed values.

    Args:
        data_dict: Mapping from panel title to a 2D array of shape
            (n_points, n_dims). The caller's dict is not modified.
        save_to: Output image path passed to ``plt.savefig`` (dpi=500).
        ref_key: Key of the reference data. It sets the axis limits when no
            ``xylim_key`` entry exists. It is never subsampled. It is drawn as a
            KDE landscape when ``plot_kde`` is True.
        xlabel: Label of the x axis (every panel).
        ylabel: Label of the y axis (first panel of each row only).
        n_max_point: Non-reference panels with more points are randomly
            subsampled (without replacement) to this many points.
        pop_ref: If True, the reference panel itself is not drawn. Its data is
            still available for limits and the KDE landscape.
        xylim_key: Key of an optional entry whose points define the axis limits.
            Those points are overlaid as red open circles on every panel rather
            than drawn as their own panel. A falsy value disables this.
        plot_kde: If True, overlay a seaborn KDE of the reference data on every panel.
        density_mapping: Optional mapping from panel title to colour values used
            instead of the KDE. A per-point array matching the panel's full length
            is subsampled together with the points.

    Raises:
        KeyError: If limits must come from ``ref_key`` (no ``xylim_key`` entry),
            or ``plot_kde`` is True, and ``ref_key`` is not in ``data_dict``.
    """
    # Shallow copy: the entries removed below must not disappear from the caller's dict.
    panels = dict(data_dict)

    # Axis limits come from the xylim entry if present, otherwise from the reference data.
    if xylim_key and xylim_key in panels:
        xylim = np.asarray(panels.pop(xylim_key))
        bounds = xylim
    else:
        xylim = None
        bounds = np.asarray(panels[ref_key])
    x_min, x_max = _padded_limits(bounds[:, 0])
    y_min, y_max = _padded_limits(bounds[:, 1])

    # Keep the reference data for the KDE landscape even if its own panel is dropped.
    ref_data = panels.get(ref_key)
    if plot_kde and ref_data is None:
        raise KeyError(ref_key)
    if ref_data is not None:
        ref_data = np.asarray(ref_data)
    if pop_ref:
        panels.pop(ref_key, None)

    print(f">>> Plotting scatter in 2D space. Image save to {save_to}")

    # Roughly five panels per row; round the column count up so every panel has a slot.
    n_panels = len(panels)
    n_rows = n_panels // 5 if n_panels > 5 else 1
    n_cols = max(1, (n_panels + n_rows - 1) // n_rows)
    fig = plt.figure(figsize=(6 * n_cols, 6 * n_rows))

    for slot, (name, points) in enumerate(panels.items()):
        plt.subplot(n_rows, n_cols, slot + 1)
        points = np.asarray(points)
        n_original = points.shape[0]

        # Subsample non-reference panels to keep plotting (and the KDE) affordable.
        sub_idx = None
        if name != ref_key and n_original > n_max_point:
            sub_idx = np.random.choice(n_original, n_max_point, replace=False)
            points = points[sub_idx]

        if density_mapping and name in density_mapping:
            density = density_mapping[name]
            # A per-point density given for the full array must follow the subsampling,
            # otherwise its length no longer matches the plotted points.
            if sub_idx is not None and np.ndim(density) >= 1 and len(density) == n_original:
                density = np.asarray(density)[sub_idx]
        else:
            density = _point_density(name, points)

        plt.scatter(points[:, 0], points[:, 1], s=10, alpha=0.7, c=density,
                    cmap="mako_r", vmin=-0.05, vmax=0.40)

        if plot_kde:
            sns.kdeplot(x=ref_data[:, 0], y=ref_data[:, 1])    # landscape

        if xylim is not None:
            # Overlay the limit-defining points (cluster centers) on every panel.
            plt.scatter(xylim[:, 0], xylim[:, 1], s=40, marker="o", c="none", edgecolors="tab:red")

        plt.xlabel(xlabel, fontsize=FONTSIZE, fontfamily="sans-serif")
        if slot % n_cols == 0:  # only the first panel of each row gets a y label
            plt.ylabel(ylabel, fontsize=FONTSIZE, fontfamily="sans-serif")

        plt.xlim(x_min, x_max)
        plt.ylim(y_min, y_max)
        plt.title(name, fontsize=FONTSIZE, fontfamily="sans-serif")

    plt.tight_layout()
    plt.savefig(save_to, dpi=500)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)