Source code for spateo.external.CAST.visualize

import os

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.cluster import KMeans, MiniBatchKMeans


[docs]def kmeans_plot_multiple( embed_dict_t, graph_list, coords, taskname_t, output_path_t, k=20, dot_size=10, scale_bar_t=None, minibatch=True, plot_strategy="sep", axis_off=False, ): num_plot = len(graph_list) plot_row = int(np.floor(num_plot / 2) + 1) embed_stack = embed_dict_t[graph_list[0]].cpu().detach().numpy() for i in range(1, num_plot): embed_stack = np.row_stack((embed_stack, embed_dict_t[graph_list[i]].cpu().detach().numpy())) print(f"Perform KMeans clustering on {embed_stack.shape[0]} cells...") kmeans = ( KMeans(n_clusters=k, random_state=0).fit(embed_stack) if minibatch == False else MiniBatchKMeans(n_clusters=k, random_state=0).fit(embed_stack) ) cell_label = kmeans.labels_ cluster_pl = sns.color_palette("tab20", len(np.unique(cell_label))) print(f"Plotting the KMeans clustering results...") cell_label_idx = 0 if plot_strategy == "sep": plt.figure(figsize=((20, 10 * plot_row))) for j in range(num_plot): plt.subplot(plot_row, 2, j + 1) coords0 = coords[graph_list[j]] col = coords0[:, 0].tolist() row = coords0[:, 1].tolist() cell_type_t = cell_label[cell_label_idx : (cell_label_idx + coords0.shape[0])] cell_label_idx += coords0.shape[0] for i in set(cell_type_t): plt.scatter( np.array(col)[cell_type_t == i], np.array(row)[cell_type_t == i], s=dot_size, edgecolors="none", c=np.array(cluster_pl)[cell_type_t[cell_type_t == i]], label=str(i), rasterized=True, ) plt.title(graph_list[j] + " (KMeans, k = " + str(k) + ")", fontsize=20) plt.xticks(fontsize=20) plt.yticks(fontsize=20) plt.axis("equal") if axis_off: plt.xticks([]) plt.yticks([]) if type(scale_bar_t) != type(None): add_scale_bar(scale_bar_t[0], scale_bar_t[1]) else: plt.figure(figsize=[10, 12]) plt.rcParams.update({"font.size": 10, "axes.titlesize": 20, "pdf.fonttype": 42}) for j in range(num_plot): coords0 = coords[graph_list[j]] col = coords0[:, 0].tolist() row = coords0[:, 1].tolist() cell_type_t = cell_label[cell_label_idx : (cell_label_idx + coords0.shape[0])] cell_label_idx += coords0.shape[0] for i in set(cell_type_t): plt.scatter( np.array(col)[cell_type_t == i], np.array(row)[cell_type_t == i], s=dot_size, edgecolors="none", alpha=0.5, c=np.array(cluster_pl)[cell_type_t[cell_type_t == i]], label=str(i), rasterized=True, ) plt.xticks(fontsize=20) plt.yticks(fontsize=20) plt.axis("equal") if axis_off: plt.xticks([]) plt.yticks([]) plt.title("K means (k = " + str(k) + ")", fontsize=30) if type(scale_bar_t) != type(None): add_scale_bar(scale_bar_t[0], scale_bar_t[1]) plt.savefig(f"{output_path_t}/{taskname_t}_trained_k{str(k)}.pdf", dpi=100) return cell_label
[docs]def add_scale_bar(length_t, label_t): import matplotlib.font_manager as fm from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar fontprops = fm.FontProperties(size=20, family="Arial") bar = AnchoredSizeBar( plt.gca().transData, length_t, label_t, 4, pad=0.1, sep=5, borderpad=0.5, frameon=False, size_vertical=0.1, color="black", fontproperties=fontprops, ) plt.gca().add_artist(bar)
[docs]def plot_mid_v2( coords_q, coords_r=None, output_path="", filename=None, title_t=["ref", "query"], s_t=8, scale_bar_t=None ): plt.rcParams.update({"font.size": 30, "axes.titlesize": 30, "pdf.fonttype": 42, "legend.markerscale": 5}) plt.figure(figsize=[10, 12]) if coords_r is not None: plt.scatter( np.array(coords_r)[:, 0].tolist(), np.array(coords_r)[:, 1].tolist(), s=s_t, edgecolors="none", alpha=0.5, rasterized=True, c="#9295CA", label=title_t[0], ) plt.scatter( np.array(coords_q)[:, 0].tolist(), np.array(coords_q)[:, 1].tolist(), s=s_t, edgecolors="none", alpha=0.5, rasterized=True, c="#E66665", label=title_t[1], ) plt.legend(fontsize=15) plt.axis("equal") if type(scale_bar_t) != type(None): add_scale_bar(scale_bar_t[0], scale_bar_t[1]) if filename != None: plt.savefig(os.path.join(output_path, filename + ".pdf"), dpi=300)
[docs]def plot_mid( coords_q, coords_r, output_path="", filename=None, title_t=["ref", "query"], s_t=8, scale_bar_t=None, axis_off=False ): plt.rcParams.update({"font.size": 30, "axes.titlesize": 30, "pdf.fonttype": 42, "legend.markerscale": 5}) plt.figure(figsize=[10, 12]) plt.scatter( np.array(coords_r)[:, 0].tolist(), np.array(coords_r)[:, 1].tolist(), s=s_t, edgecolors="none", alpha=0.5, rasterized=True, c="#9295CA", label=title_t[0], ) plt.scatter( np.array(coords_q)[:, 0].tolist(), np.array(coords_q)[:, 1].tolist(), s=s_t, edgecolors="none", alpha=0.5, rasterized=True, c="#E66665", label=title_t[1], ) plt.legend(fontsize=15) plt.axis("equal") if axis_off: plt.xticks([]) plt.yticks([]) if type(scale_bar_t) != type(None): add_scale_bar(scale_bar_t[0], scale_bar_t[1]) if filename != None: plt.savefig(os.path.join(output_path, filename + ".pdf"), dpi=300)
[docs]def dsplot( coords0, coords_plaque_t, s_cell=10, s_plaque=40, col_cell="#999999", col_plaque="red", cmap_t="vlag", alpha=1, vmax_t=None, title=None, scale_bar_200=None, output_path_t=None, coords0_mask=None, ): if coords0_mask is not None: coords_other = coords0[~coords0_mask].copy() coords0 = coords0[coords0_mask].copy() else: coords_other = None if type(vmax_t) == type(None): vmax_t = np.abs(col_cell).max() plt.figure(figsize=(13, 13)) if title is not None: plt.title(title, fontsize=30) plt.xticks(fontsize=20) plt.yticks(fontsize=20) plt.axis("equal") if type(col_cell) != str: col_cell_i = np.array(col_cell)[coords0_mask] if coords0_mask is not None else col_cell else: col_cell_i = col_cell if coords_other is not None: plt.scatter( coords_other[:, 0], coords_other[:, 1], s=s_cell, edgecolors="none", alpha=0.2, c="#aaaaaa", cmap=cmap_t, rasterized=True, ) col = coords0[:, 0].tolist() row = coords0[:, 1].tolist() plt.scatter( np.array(col), np.array(row), s=s_cell, edgecolors="none", alpha=alpha, vmax=vmax_t, vmin=-vmax_t, c=col_cell_i, cmap=cmap_t, rasterized=True, ) plt.colorbar(ticks=[-vmax_t, 0, vmax_t]) if type(coords_plaque_t) != type(None): coords1 = coords_plaque_t if type(s_plaque) != int: s_plaque_i = np.array(s_plaque) else: s_plaque_i = s_plaque if type(col_plaque) != str: col_plaque_i = np.array(col_plaque) else: col_plaque_i = col_plaque col = coords1[:, 0].tolist() row = coords1[:, 1].tolist() plt.scatter(np.array(col), np.array(row), s=s_plaque_i, edgecolors="none", c=col_plaque_i, rasterized=True) if scale_bar_200 is not None: add_scale_bar(scale_bar_200, "200 µm") if output_path_t is not None: plt.savefig(output_path_t, dpi=300) plt.close("all")