import math
import random
import anndata as ad
import numba
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from scipy.sparse import issparse
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from ...logging import logger_manager as lm
def calculate_adj_matrix(x, y, x_pixel=None, y_pixel=None, image=None, beta=49, alpha=1, histology=True):
    """(Part of spagcn algorithm) Calculate the pairwise Euclidean distance matrix between spots.

    Without histology each spot is the 2-D point ``(x, y)``. With histology a third
    coordinate ``z`` is appended: the mean colour of a square window around the spot's
    pixel is reduced to one variance-weighted value, standardised, and rescaled to
    ``max(std(x), std(y)) * alpha``.

    Args:
        x (list): x-coordinates of the spots.
        y (list): y-coordinates of the spots.
        x_pixel (list, optional): x-pixels of the spots in the histology image; indexes the
            first axis of `image`. Required when `histology` is True. Defaults to None.
        y_pixel (list, optional): y-pixels of the spots in the histology image; indexes the
            second axis of `image`. Required when `histology` is True. Defaults to None.
        image (class: `numpy.ndarray`, optional): the image; its first three channels are
            used. Required when `histology` is True. Defaults to None.
        beta (int, optional): side length of the window used to average colour around a
            spot (the window spans ``round(beta / 2)`` pixels either side). Defaults to 49.
        alpha (int, optional): weight of the colour coordinate relative to the spatial
            coordinates. Defaults to 1.
        histology (bool, optional): whether to use the image. Defaults to True.
    Returns:
        class: `numpy.ndarray`: float32 matrix of shape ``(n_spots, n_spots)`` holding the
        Euclidean distance between every pair of spots.
    Raises:
        AssertionError: if `histology` is True and a pixel list or the image is missing, or
            the coordinate and pixel lists differ in length.
    """
    if histology:
        # The second operand used to test `x_pixel` again, so a missing `y_pixel`
        # slipped through and only failed later with an unrelated TypeError.
        assert (x_pixel is not None) and (y_pixel is not None) and (image is not None)
        assert (len(x) == len(x_pixel)) and (len(y) == len(y_pixel))
        lm.main_info("Calculateing adj matrix using histology image...")
        beta_half = round(beta / 2)
        max_x, max_y = image.shape[0], image.shape[1]

        # Mean colour (per channel) of the window centred on each spot, clipped to the image.
        g = []
        for px, py in zip(x_pixel, y_pixel):
            nbs = image[
                max(0, px - beta_half) : min(max_x, px + beta_half + 1),
                max(0, py - beta_half) : min(max_y, py + beta_half + 1),
            ]
            g.append(np.mean(np.mean(nbs, axis=0), axis=0))
        g = np.array(g)
        c0, c1, c2 = g[:, 0], g[:, 1], g[:, 2]

        var0, var1, var2 = np.var(c0), np.var(c1), np.var(c2)
        lm.main_info(f"Var of c0,c1,c2 = {var0}, {var1}, {var2}")
        # Collapse the three channels into one value, weighting each channel by its variance.
        c3 = (c0 * var0 + c1 * var1 + c2 * var2) / (var0 + var1 + var2)
        c4 = (c3 - np.mean(c3)) / np.std(c3)
        z_scale = np.max([np.std(x), np.std(y)]) * alpha
        z = (c4 * z_scale).tolist()
        lm.main_info(f"Var of x,y,z = {np.var(x)}, {np.var(y)}, {np.var(z)}")
        X = np.array([x, y, z]).T.astype(np.float32)
    else:
        lm.main_info("Calculateing adj matrix using xy only...")
        X = np.array([x, y]).T.astype(np.float32)

    n = X.shape[0]
    adj = np.empty((n, n), dtype=np.float32)
    # One vectorised row at a time: avoids both the n*n Python-level loop and an
    # (n, n, d) temporary. `numba.prange` had no effect here because this function
    # is not jit-compiled.
    for i in range(n):
        adj[i] = np.sqrt(np.sum((X - X[i]) ** 2, axis=1))
    return adj
def calculate_p(adj, l):
    """Compute the average neighbourhood weight `p` obtained with bandwidth `l`.

    Each entry ``d`` of `adj` is turned into the Gaussian weight
    ``exp(-d**2 / (2 * l**2))``. The weights are summed along axis 1, the sums are
    averaged, and 1 is subtracted from the average.

    Args:
        adj (class: `numpy.ndarray`): matrix of pairwise distances.
        l (float): bandwidth of the Gaussian kernel.
    Returns:
        float: mean row sum of the kernel matrix, minus 1.
    """
    two_l_squared = 2 * (l**2)
    kernel = np.exp(-1 * (adj**2) / two_l_squared)
    row_totals = np.sum(kernel, 1)
    return np.mean(row_totals) - 1
def search_l(p, adj, start=0.01, end=1000, tol=0.01, max_run=100):
    """Bisect for the `l` value whose `calculate_p(adj, l)` is within `tol` of `p`.

    Args:
        p (float): target value of parameter `p` in spagcn algorithm. See `SpaGCN` for details.
        adj (class: `numpy.ndarray`): the calculated adjacent matrix in spagcn algorithm.
        start (float, optional): lower boundary of search. Defaults to 0.01.
        end (int, optional): upper boundary of search. Defaults to 1000.
        tol (float, optional): accepted absolute difference between the achieved and the
            requested `p`. Defaults to 0.01.
        max_run (int, optional): maximum number of bisection steps. Defaults to 100.
    Returns:
        float: the `l` value, or None when the target lies outside ``[start, end]`` or
        the iteration budget is exhausted.
    """
    lo, hi = start, end
    p_lo = calculate_p(adj, lo)
    p_hi = calculate_p(adj, hi)

    # Guard clauses: target outside the bracket, or already met at one of its ends.
    if p_lo > p + tol:
        lm.main_info("l not found, try smaller start point.")
        return None
    if p_hi < p - tol:
        lm.main_info("l not found, try bigger end point.")
        return None
    if np.abs(p_lo - p) <= tol:
        lm.main_info(f"recommended l = {lo!s}.")
        return lo
    if np.abs(p_hi - p) <= tol:
        lm.main_info(f"recommended l = {hi!s}.")
        return hi

    iteration = 0
    while (p_lo + tol) < p < (p_hi - tol):
        iteration += 1
        # `!s` forces str(): NumPy scalars may format differently without it.
        lm.main_info(f"Run {iteration!s}: l [{lo!s}, {hi!s}], p [{p_lo!s}, {p_hi!s}]")
        if iteration > max_run:
            lm.main_info(
                f"Exact l not found, closest values are:\nl={lo!s}: p={p_lo!s}\nl={hi!s}: p={p_hi!s}"
            )
            return None

        mid = (lo + hi) / 2
        p_mid = calculate_p(adj, mid)
        if np.abs(p_mid - p) <= tol:
            lm.main_info(f"recommended l = {mid!s}")
            return mid

        # Keep the half of the bracket that still contains the target.
        if p_mid <= p:
            lo, p_lo = mid, p_mid
        else:
            hi, p_hi = mid, p_mid
def get_cluster_num(
    adata,
    adj,
    res,
    tol,
    lr,
    max_epochs,
    l,
    r_seed=100,
    t_seed=100,
    n_seed=100,
):
    """Train a fresh SpaGCN model with louvain initialisation and count the predicted clusters.

    Args:
        adata, adj, res, tol, lr, max_epochs: further passed to SpaGCN.train(), see `SpaGCN.train`.
        l (float): parameter `l` in spagcn algorithm, see `SpaGCN` for details.
        r_seed, t_seed, n_seed (int, optional): Global seed for `random`, `torch`, `numpy`. Defaults to 100.
    Returns:
        int: number of distinct labels in the model's prediction.
    """
    # Re-seed every RNG so repeated calls with the same arguments are comparable.
    random.seed(r_seed)
    torch.manual_seed(t_seed)
    np.random.seed(n_seed)

    model = SpaGCN()
    model.set_l(l)
    train_options = dict(
        init_spa=True,
        init="louvain",
        res=res,
        tol=tol,
        lr=lr,
        max_epochs=max_epochs,
    )
    model.train(adata, adj, **train_options)

    labels = model.predict()[0]
    distinct_labels = set(labels)
    return len(distinct_labels)
def search_res(
    adata,
    adj,
    l,
    target_num,
    start=0.4,
    step=0.1,
    tol=5e-3,
    lr=0.05,
    max_epochs=10,
    r_seed=100,
    t_seed=100,
    n_seed=100,
    max_run=10,
):
    """Function to search a proper initial louvain resolution to get desired number of clusters in spagcn algorithm.

    Starting from `start`, the resolution is moved by `step` towards the target; whenever
    a move overshoots (the cluster count crosses `target_num`) the move is discarded and
    `step` is halved.

    Args:
        adata (class:`~anndata.AnnData`): an Annadata object.
        adj (class: `numpy.ndarray`): the calculated adjacent matrix in spagcn algorithm.
        l (float): parameter `l` in spagcn algorithm, see `SpaGCN` for details.
        target_num (int): desired number of clusters.
        start (float, optional): the resolution the search starts from. Defaults to 0.4.
        step (float, optional): initial search step length. Defaults to 0.1.
        tol, lr, max_epochs: further passed to SpaGCN.train(), see `SpaGCN.train`.
        r_seed, t_seed, n_seed (int, optional): Global seed for `random`, `torch`, `numpy`. Defaults to 100.
        max_run (int, optional): max number of iteration. Defaults to 10.
    Returns:
        float: the resolution that gives `target_num` clusters, or the last accepted
        resolution when the iteration budget runs out first.
    """
    res = start
    lm.main_info(f"Start at res = {res} step = {step}")
    old_num = get_cluster_num(adata, adj, res, tol, lr, max_epochs, l, r_seed, t_seed, n_seed)
    lm.main_info(f"Res = {res} Num of clusters = {old_num}")
    run = 0
    while old_num != target_num:
        # The resolution is lowered when there are too few clusters and raised otherwise.
        # NOTE(review): this presumes the cluster count falls as the resolution grows;
        # confirm against the louvain backend in use.
        old_sign = -1 if (old_num < target_num) else 1
        candidate = res + step * old_sign
        new_num = get_cluster_num(
            adata,
            adj,
            candidate,
            tol,
            lr,
            max_epochs,
            l,
            r_seed,
            t_seed,
            n_seed,
        )
        lm.main_info(f"Res = {candidate} Num of clusters = {new_num}")
        if new_num == target_num:
            res = candidate
            lm.main_info(f"recommended res = {res}")
            return res
        new_sign = -1 if (new_num < target_num) else 1
        if new_sign == old_sign:
            # Still on the same side of the target: accept the move.
            res = candidate
            # Fixed: the message lacked its placeholder and printed the literal word "res".
            lm.main_info(f"Res changed to {res}")
            old_num = new_num
        else:
            # Overshot the target: stay put and retry with a finer step.
            step = step / 2
            lm.main_info(f"Step changed to {step}")
        if run > max_run:
            lm.main_info("Exact resolution not found")
            lm.main_info(f"Recommended res = {res}")
            return res
        run += 1
    lm.main_info(f"recommended res = {res}")
    return res
def refine(sample_id, pred, dis, shape="square"):
    """To refine(smooth) the boundary of spatial domains(clusters).

    For every sample the `num_nbs + 1` smallest entries of its row in `dis` are taken
    (6 neighbours for "hexagon", 4 for "square"; the sample itself is normally among
    them with distance 0). The sample is relabelled to the most frequent label in that
    set when its own label occurs fewer than ``num_nbs / 2`` times while the most
    frequent one occurs more than ``num_nbs / 2`` times.

    Args:
        sample_id (list): list of sample(cell, spot or bin) names.
        pred (list): list of spatial domains corresponding to the sample_id list.
        dis (class: `numpy.ndarray`): the calculated adjacent matrix in spagcn algorithm.
        shape (str, optional): Smooth the spatial domains with given spatial topology, "hexagon" for Visium data, "square" for ST data. Defaults to "square".
    Returns:
        [list]: list of refined spatial domains corresponding to the sample_id list.
    Raises:
        ValueError: if `shape` is neither "hexagon" nor "square".
    """
    if shape == "hexagon":
        num_nbs = 6
    elif shape == "square":
        num_nbs = 4
    else:
        # Previously this case was only logged and the loop below then failed on the
        # unbound `num_nbs`; fail immediately with an explicit error instead.
        raise ValueError(
            f"Shape {shape!r} not recognized, shape='hexagon' for Visium data, 'square' for ST data."
        )

    pred_df = pd.DataFrame({"pred": pred}, index=sample_id)
    dis_df = pd.DataFrame(dis, index=sample_id, columns=sample_id)
    half = num_nbs / 2

    refined_pred = []
    for index in sample_id:
        # Closest `num_nbs + 1` samples by distance (positional slice of the sorted row).
        nearest = dis_df.loc[index, :].sort_values().iloc[: num_nbs + 1]
        label_counts = pred_df.loc[nearest.index, "pred"].value_counts()
        self_pred = pred_df.loc[index, "pred"]
        if (label_counts.loc[self_pred] < half) and (np.max(label_counts) > half):
            refined_pred.append(label_counts.idxmax())
        else:
            refined_pred.append(self_pred)
    return refined_pred
class GraphConvolution(nn.Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907

    Computes ``adj @ (input @ weight)`` and adds ``bias`` when the layer has one.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Allocated uninitialised; reset_parameters() fills in the values.
        self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features))
        if not bias:
            self.register_parameter("bias", None)
        else:
            self.bias = nn.Parameter(torch.FloatTensor(out_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Draw weight (then bias) uniformly from ``[-b, b]`` with ``b = 1 / sqrt(out_features)``."""
        bound = 1.0 / math.sqrt(self.weight.size(1))
        with torch.no_grad():
            self.weight.uniform_(-bound, bound)
            if self.bias is not None:
                self.bias.uniform_(-bound, bound)

    def forward(self, input, adj):
        """Project `input` with the weight matrix, then aggregate the result with `adj`."""
        aggregated = torch.spmm(adj, torch.mm(input, self.weight))
        if self.bias is None:
            return aggregated
        return aggregated + self.bias

    def __repr__(self):
        return f"{self.__class__.__name__} ({self.in_features} -> {self.out_features})"
class simple_GC_DEC(nn.Module):
    """
    Simple NN model constructed with a GraphConvolution layer followed by a DeepEmbeddingClustering layer.
    For DEC, see https://arxiv.org/abs/1511.06335v2
    """

    def __init__(self, nfeat, nhid, alpha=0.2):
        super(simple_GC_DEC, self).__init__()
        self.gc = GraphConvolution(nfeat, nhid)
        self.nhid = nhid
        # self.mu (the cluster centres) is created by fit().
        self.alpha = alpha

    def forward(self, x, adj):
        """Embed `x` with the graph convolution and compute soft cluster assignments.

        Returns:
            tuple: ``(z, q)`` with ``z`` the embedding and ``q`` the row-normalised
            assignment of every sample to every centre in ``self.mu``.
        """
        x = self.gc(x, adj)
        sq_dist = torch.sum((x.unsqueeze(1) - self.mu) ** 2, dim=2)
        q = 1.0 / ((1.0 + sq_dist / self.alpha) + 1e-8)
        # `**` binds tighter than `/`: this is (q ** (alpha + 1)) / 2, and the factor 1/2
        # cancels in the normalisation below.
        # NOTE(review): the DEC paper uses the exponent (alpha + 1) / 2; left untouched
        # because changing it would alter clustering results — confirm the intent.
        q = q ** (self.alpha + 1.0) / 2.0
        q = q / torch.sum(q, dim=1, keepdim=True)
        return x, q

    def loss_function(self, p, q):
        """Return the KL divergence of `q` from the target `p`, averaged over rows."""

        def kld(target, pred):
            # 1e-6 keeps the ratio finite when a predicted probability is zero.
            return torch.mean(torch.sum(target * torch.log(target / (pred + 1e-6)), dim=1))

        return kld(p, q)

    def target_distribution(self, q):
        """Sharpen `q`: square it, divide by the per-cluster totals, renormalise every row."""
        p = q**2 / torch.sum(q, dim=0)
        p = p / torch.sum(p, dim=1, keepdim=True)
        return p

    def fit(
        self,
        X,
        adj,
        lr=0.001,
        max_epochs=5000,
        update_interval=3,
        trajectory_interval=50,
        weight_decay=5e-4,
        opt="sgd",
        init="louvain",
        n_neighbors=10,
        res=0.4,
        n_clusters=10,
        init_spa=True,
        tol=1e-3,
    ):
        """Initialise the cluster centres and train the model.

        Args:
            X (class: `numpy.ndarray`): feature matrix, one row per sample.
            adj (class: `numpy.ndarray`): weight matrix handed to the graph convolution.
            lr (float, optional): learning rate. Defaults to 0.001.
            max_epochs (int, optional): maximum number of epochs. Defaults to 5000.
            update_interval (int, optional): the target distribution is recomputed every
                `update_interval` epochs. Defaults to 3.
            trajectory_interval (int, optional): hard labels are appended to
                `self.trajectory` every `trajectory_interval` epochs. Defaults to 50.
            weight_decay (float, optional): weight decay of the Adam optimizer; not used
                with "sgd". Defaults to 5e-4.
            opt (str, optional): "sgd" or "adam". Defaults to "sgd".
            init (str, optional): "louvain" or "kmeans", the method giving the initial
                clusters. Defaults to "louvain".
            n_neighbors (int, optional): neighbours for the louvain graph. Defaults to 10.
            res (float, optional): louvain resolution. Defaults to 0.4.
            n_clusters (int, optional): number of clusters for kmeans. Defaults to 10.
            init_spa (bool, optional): initialise on the graph-convolved features (True)
                or on the raw `X` (False). Defaults to True.
            tol (float, optional): training stops once the fraction of samples whose label
                changed drops below `tol`. Defaults to 1e-3.
        Raises:
            ValueError: if `opt` or `init` is not one of the supported values.
        """
        # Validate first: an unknown value used to leave `optimizer` / `y_pred` unbound
        # and surface much later as an unrelated error.
        if opt not in ("sgd", "adam"):
            raise ValueError(f"Unsupported opt {opt!r}; expected 'sgd' or 'adam'.")
        if init not in ("kmeans", "louvain"):
            raise ValueError(f"Unsupported init {init!r}; expected 'kmeans' or 'louvain'.")

        self.trajectory = []
        # NOTE(review): the optimizer is built before `self.mu` is (re)created below, so
        # the cluster centres are not among the parameters it updates. Left as-is because
        # changing it would alter training results — confirm whether this is intended.
        if opt == "sgd":
            optimizer = torch.optim.SGD(self.parameters(), lr=lr, momentum=0.9)
        else:
            optimizer = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=weight_decay)

        features = self.gc(torch.FloatTensor(X), torch.FloatTensor(adj))
        embedding = features.detach().numpy()
        # ----------------------------------------------------------------
        if init == "kmeans":
            lm.main_info("Initializing cluster centers with kmeans, n_clusters known")
            self.n_clusters = n_clusters
            kmeans = KMeans(self.n_clusters, n_init=20)
            # With init_spa the clustering sees expression and spatial information,
            # otherwise only the raw expression features in X.
            y_pred = kmeans.fit_predict(embedding if init_spa else X)
        else:
            lm.main_info(f"Initializing cluster centers with louvain, resolution = {res}")
            adata = ad.AnnData(embedding) if init_spa else ad.AnnData(X)
            import dynamo as dyn

            dyn.tl.neighbors(adata, n_neighbors=n_neighbors, X_data=adata.X)
            dyn.tl.louvain(adata, resolution=res)
            y_pred = adata.obs["louvain"].astype(int).to_numpy()
            self.n_clusters = len(np.unique(y_pred))
        # ----------------------------------------------------------------
        y_pred_last = y_pred
        self.mu = nn.parameter.Parameter(torch.Tensor(self.n_clusters, self.nhid))
        X = torch.FloatTensor(X)
        adj = torch.FloatTensor(adj)
        self.trajectory.append(y_pred)

        # Initial centres: mean embedding of each initial cluster, ordered by label.
        cluster_centers = np.asarray(pd.DataFrame(embedding).groupby(np.asarray(y_pred)).mean())
        self.mu.data.copy_(torch.Tensor(cluster_centers))

        self.train()
        for epoch in range(max_epochs):
            if epoch % update_interval == 0:
                # The target is a constant for the following steps; no graph is needed.
                with torch.no_grad():
                    _, q = self(X, adj)
                    p = self.target_distribution(q)
            if epoch % 10 == 0:
                lm.main_info(f"Epoch {epoch}")
            optimizer.zero_grad()
            z, q = self(X, adj)
            loss = self.loss_function(p, q)
            loss.backward()
            optimizer.step()

            y_pred = torch.argmax(q, dim=1).data.cpu().numpy()
            if epoch % trajectory_interval == 0:
                self.trajectory.append(y_pred)

            # Check stop criterion
            delta_label = np.sum(y_pred != y_pred_last).astype(np.float32) / X.shape[0]
            y_pred_last = y_pred
            if epoch > 0 and (epoch - 1) % update_interval == 0 and delta_label < tol:
                lm.main_info(f"delta_label {delta_label} < tol {tol}")
                lm.main_info("Reach tolerance threshold. Stopping training.")
                lm.main_info(f"Total epoch: {epoch}")
                break

    def predict(self, X, adj):
        """Run a forward pass on array inputs and return ``(z, q)`` as tensors."""
        z, q = self(torch.FloatTensor(X), torch.FloatTensor(adj))
        return z, q
class SpaGCN(object):
    """
    Implementation for spagcn algorithm, see https://doi.org/10.1038/s41592-021-01255-8
    """

    def __init__(self):
        super(SpaGCN, self).__init__()
        # Kernel bandwidth; must be provided through set_l() before train().
        self.l = None

    def set_l(self, l):
        """Set the bandwidth `l` used to turn distances into graph weights."""
        self.l = l

    def train(
        self,
        adata,
        adj,
        num_pcs=50,
        lr=0.005,
        max_epochs=2000,
        weight_decay=0,
        opt="adam",
        init_spa=True,
        init="louvain",  # louvain or kmeans
        n_neighbors=10,  # for louvain
        n_clusters=None,  # for kmeans
        res=0.4,  # for louvain
        tol=1e-3,
    ):
        """train model for spagcn

        `adata.X` is reduced with PCA, `adj` is converted to Gaussian weights
        ``exp(-adj**2 / (2 * l**2))``, and a `simple_GC_DEC` model is fitted on both.

        Args:
            adata (class:`~anndata.AnnData`): an Annadata object.
            adj (class: `numpy.ndarray`): the calculated adjacent matrix in spagcn algorithm.
            num_pcs (int, optional): number of pcs(out dimension of PCA) to use. Defaults to 50.
            lr (float, optional): learning rate in neural network. Defaults to 0.005.
            max_epochs (int, optional): max epochs to train in neural network. Defaults to 2000.
            weight_decay (int, optional): weight decay passed on to the optimizer. Defaults to 0.
            opt (str, optional): the optimizer to use. Defaults to "adam".
            init_spa (bool, optional): whether the initial clustering runs on the
                graph-convolved features rather than on the PCA embedding. Defaults to True.
            init (str, optional): algorithm to use in inital clustering. Supports "louvain", "kmeans". Defaults to "louvain".
            n_neighbors (int, optional): number of neighbours for louvain. Defaults to 10.
            n_clusters (int, optional): number of clusters for kmeans. Defaults to None.
            res (float, optional): resolution for louvain. Defaults to 0.4.
            tol (float, optional): stopping tolerance passed to the model. Defaults to 1e-3.
        Raises:
            ValueError: if `l` has not been set with `set_l`.
        """
        # Fail fast: this check used to run only after the PCA had been computed.
        if self.l is None:
            raise ValueError("l should be set before fitting the model!")

        self.num_pcs = num_pcs
        self.res = res
        self.lr = lr
        self.max_epochs = max_epochs
        self.weight_decay = weight_decay
        self.opt = opt
        self.init_spa = init_spa
        self.init = init
        self.n_neighbors = n_neighbors
        self.n_clusters = n_clusters
        self.tol = tol
        assert adata.shape[0] == adj.shape[0] == adj.shape[1]

        # Densify once; `.toarray()` is available on sparse matrices and sparse arrays
        # alike, whereas the `.A` shortcut is not.
        expression = adata.X.toarray() if issparse(adata.X) else adata.X
        pca = PCA(n_components=self.num_pcs)
        pca.fit(expression)
        embed = pca.transform(expression)
        ###------------------------------------------###
        adj_exp = np.exp(-1 * (adj**2) / (2 * (self.l**2)))
        # ----------Train model----------
        self.model = simple_GC_DEC(embed.shape[1], embed.shape[1])
        self.model.fit(
            embed,
            adj_exp,
            lr=self.lr,
            max_epochs=self.max_epochs,
            weight_decay=self.weight_decay,
            opt=self.opt,
            init_spa=self.init_spa,
            init=self.init,
            n_neighbors=self.n_neighbors,
            n_clusters=self.n_clusters,
            res=self.res,
            tol=self.tol,
        )
        self.embed = embed
        self.adj_exp = adj_exp

    def predict(self):
        """Predict with the trained model on the data stored by `train`.

        Returns:
            tuple: ``(y_pred, prob)`` — the most probable cluster of every sample and the
            full assignment matrix, both as `numpy.ndarray`.
        """
        z, q = self.model.predict(self.embed, self.adj_exp)
        y_pred = torch.argmax(q, dim=1).data.cpu().numpy()
        # Max probability plot
        prob = q.detach().numpy()
        return y_pred, prob