import copy
import os
import random
from dataclasses import dataclass, field
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from sklearn.cluster import KMeans
from tqdm import trange
from .visualize import add_scale_bar
#################### Registration ####################
# Parameters class
@dataclass
class reg_params:
    gpu: int = 0  # GPU index; -1 forces CPU (read by __post_init__)
    ### affine
    d_list: list[float] = field(default_factory=list)
    translation_params: list[float] | None = None
    mirror_t: list[float] | None = None
    alpha_basis: list[float] = field(default_factory=list)
    dist_penalty1: float = 0
    attention_params: list[float] = field(default_factory=list)
    ### BS
    mesh_trans_list: list[float] = field(default_factory=list)
    attention_region: list[float] = field(default_factory=list)
    attention_params_bs: list[float] = field(default_factory=list)
    mesh_weight: list[float] = field(default_factory=list)
    iterations_bs: list[float] = field(default_factory=list)
    alpha_basis_bs: list[float] = field(default_factory=list)
    meshsize: list[float] = field(default_factory=list)
    img_size_bs: list[float] = field(default_factory=list)  # max_xy
    dist_penalty2: list[float] = field(default_factory=list)
    PaddingRate_bs: float = 0
    ### common
    device: str = field(init=False)

    def __post_init__(self):
        if self.gpu != -1 and torch.cuda.is_available():
            self.device = "cuda:{}".format(self.gpu)
        else:
            self.device = "cpu"
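# Usage sketch (hedged; the field values below are illustrative assumptions,
# not defaults recommended by this module):
#
#   params = reg_params(
#       gpu=0,                            # -1 forces CPU
#       d_list=[1, 2, 3],                 # scaling candidates for prelocate
#       alpha_basis=[0, 0, 1 / 5, 2, 2],  # per-parameter step sizes for Affine_GD
#   )
#   print(params.device)                  # "cuda:0" if CUDA is available, else "cpu"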
def get_range(sp_coords):
yrng = max(sp_coords, key=lambda x: x[1])[1] - min(sp_coords, key=lambda x: x[1])[1]
xrng = max(sp_coords, key=lambda x: x[0])[0] - min(sp_coords, key=lambda x: x[0])[0]
return xrng, yrng
def prelocate(
coords_q,
coords_r,
cov_anchor_it,
bleeding,
output_path,
d_list=[1, 2, 3],
prefix="test",
ifplot=True,
index_list=None,
translation_params=None,
mirror_t=None,
):
idx_q = np.ones(coords_q.shape[0], dtype=bool) if index_list is None else index_list[0]
idx_r = np.ones(coords_r.shape[0], dtype=bool) if index_list is None else index_list[1]
mirror_t = [1, -1] if mirror_t is None else mirror_t
theta_t = []
J_t = []
if translation_params is None:
translation_x = [0]
translation_y = [0]
else:
xrng, yrng = get_range(coords_r.detach().cpu())
dx_ratio_max, dy_ratio_max, xy_steps = translation_params
dx_max = dx_ratio_max * xrng
dy_max = dy_ratio_max * yrng
translation_x = np.linspace(-dx_max, dx_max, num=int(xy_steps)) # dx
translation_y = np.linspace(-dy_max, dy_max, num=int(xy_steps)) # dy
for mirror in mirror_t:
for dx in translation_x:
for dy in translation_y:
                for d in d_list:
                    for phi in [0, 90, 180, 270]:
                        a = d
                        d_m = d * mirror  # keep the loop variable intact: mirroring must not compound across phi iterations
                        theta = torch.Tensor([a, d_m, phi, dx, dy]).reshape(5, 1).to(coords_q.device)
coords_query_it = affine_trans_t(theta, coords_q)
try:
J_t.append(
J_cal(coords_query_it[idx_q], coords_r[idx_r], cov_anchor_it, bleeding).sum().item()
)
                        except Exception:  # e.g. J_cal fails when no reference points fall inside the bleeding window
                            continue
theta_t.append(theta)
if ifplot:
prelocate_loss_plot(J_t, output_path, prefix)
return theta_t[np.argmin(J_t)]
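# Search-space arithmetic (illustrative): with the defaults mirror_t=[1, -1],
# d_list=[1, 2, 3], no translation grid, and the four rotations, prelocate
# scores 2 * 1 * 1 * 3 * 4 = 24 candidate thetas of the form [a, d, phi, dx, dy]
# and returns the one with the lowest summed J_cal loss.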
def Affine_GD(
coords_query_it_raw,
coords_ref_it,
cov_anchor_it,
output_path,
bleeding=500,
dist_penalty=0,
diff_step=50,
alpha_basis=np.reshape(np.array([0, 0, 1 / 5, 2, 2]), [5, 1]),
iterations=50,
prefix="test",
attention_params=[None, 3, 1, 0],
scale_t=1,
coords_log=False,
index_list=None,
mid_visual=False,
early_stop_thres=1,
ifrigid=False,
):
idx_q = np.ones(coords_query_it_raw.shape[0], dtype=bool) if index_list is None else index_list[0]
idx_r = np.ones(coords_ref_it.shape[0], dtype=bool) if index_list is None else index_list[1]
dev = coords_query_it_raw.device
theta = torch.Tensor([1, 1, 0, 0, 0]).reshape(5, 1).to(dev) # initial theta, [a,d,phi,t1,t2]
coords_query_it = coords_query_it_raw.clone()
    if mid_visual:
        plot_mid(
            coords_query_it.cpu() * scale_t, coords_ref_it.cpu() * scale_t, output_path, prefix + "_init", scale_bar_t=None
        )
similarity_score = [
J_cal(coords_query_it[idx_q], coords_ref_it[idx_r], cov_anchor_it, bleeding, dist_penalty, attention_params)
.sum()
.cpu()
.item()
]
it_J = []
it_theta = []
coords_q_log = []
delta_similarity_score = [np.inf] * 5
t = trange(iterations, desc="", leave=True)
for it in t:
alpha = alpha_init(alpha_basis, it, dev)
## de_sscore
dJ_dxy_mat = dJ_dt_cal(
coords_query_it[idx_q],
coords_ref_it[idx_r],
diff_step,
dev,
cov_anchor_it,
bleeding,
dist_penalty,
attention_params,
)
dJ_dtheta = dJ_dtheta_cal(
coords_query_it[idx_q, 0], coords_query_it[idx_q, 1], dJ_dxy_mat, theta, dev, ifrigid=ifrigid
)
theta = theta_renew(theta, dJ_dtheta, alpha, ifrigid=ifrigid)
coords_query_it = affine_trans_t(theta, coords_query_it_raw)
it_J.append(dJ_dtheta)
it_theta.append(theta)
if coords_log:
coords_q_log.append(coords_query_it.detach().cpu().numpy())
sscore_t = (
J_cal(coords_query_it[idx_q], coords_ref_it[idx_r], cov_anchor_it, bleeding, dist_penalty, attention_params)
.sum()
.cpu()
.item()
)
# print(f'Loss: {sscore_t}')
t.set_description(f"Loss: {sscore_t:.3f}")
t.refresh()
similarity_score.append(sscore_t)
        if mid_visual and it % 20 == 0:  # it == 0 is already covered by it % 20 == 0
plot_mid(
coords_query_it.cpu() * scale_t,
coords_ref_it.cpu() * scale_t,
output_path,
prefix + str(int(it / 10 + 0.5)),
scale_bar_t=None,
)
if early_stop_thres is not None and it > 200:
delta_similarity_score.append(similarity_score[-2] - similarity_score[-1])
if np.all(np.array(delta_similarity_score[-5:]) < early_stop_thres):
print(f"Early stop at {it}th iteration.")
break
return [similarity_score, it_J, it_theta, coords_q_log]
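# Usage sketch (hedged; the input tensors are assumed to be prepared upstream):
#
#   similarity_score, it_J, it_theta, coords_q_log = Affine_GD(
#       coords_q, coords_r, cov_anchor, "output_dir", iterations=300)
#   theta_final = it_theta[-1]                        # [a, d, phi, t1, t2]
#   coords_q_aligned = affine_trans_t(theta_final, coords_q)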
def BSpline_GD(
coords_q,
coords_r,
cov_anchor_it,
iterations,
output_path,
bleeding,
dist_penalty=0,
alpha_basis=1000,
diff_step=50,
mesh_size=5,
prefix="test",
mesh_weight=None,
attention_params=[None, 3, 1, 0],
scale_t=1,
coords_log=False,
index_list=None,
mid_visual=False,
max_xy=None,
renew_mesh_trans=True,
restriction_t=0.5,
):
idx_q = np.ones(coords_q.shape[0], dtype=bool) if index_list is None else index_list[0]
idx_r = np.ones(coords_r.shape[0], dtype=bool) if index_list is None else index_list[1]
dev = coords_q.device
    if mid_visual:
        plot_mid(
            coords_q.cpu() * scale_t,
            coords_r.cpu() * scale_t,
            output_path,
            prefix + "_FFD_initial_" + str(iterations),
            scale_bar_t=None,
        )
max_xy = coords_q.max(0)[0].cpu() if max_xy is None else max_xy
mesh, mesh_weight, kls, dxy_ffd_all, delta = BSpline_GD_preparation(max_xy, mesh_size, dev, mesh_weight)
coords_query_it = coords_q.clone()
similarity_score = [
J_cal(coords_query_it[idx_q], coords_r[idx_r], cov_anchor_it, bleeding, dist_penalty, attention_params)
.sum()
.cpu()
.item()
]
mesh_trans_list = []
coords_q_log = []
mesh_trans = mesh.clone()
max_movement = (max_xy / (mesh_size - 1.0) * restriction_t).to(mesh.device).unsqueeze(-1).unsqueeze(-1)
t = trange(iterations, desc="", leave=True)
for it in t:
dJ_dxy_mat = dJ_dt_cal(
coords_query_it[idx_q],
coords_r[idx_r],
diff_step,
dev,
cov_anchor_it,
bleeding,
dist_penalty,
attention_params,
)
if renew_mesh_trans or it == 0:
uv_raw, ij_raw = BSpline_GD_uv_ij_calculate(coords_query_it, delta, dev)
uv = uv_raw[:, idx_q] # 2 * N[idx]
ij = ij_raw[:, idx_q] # 2 * N[idx]
result_B_t = B_matrix(uv, kls) ## 16 * N[idx]
dxy_ffd = get_dxy_ffd(ij, result_B_t, mesh, dJ_dxy_mat, mesh_weight, alpha_basis)
if renew_mesh_trans:
mesh_trans = mesh + dxy_ffd
else:
mesh_trans = mesh + torch.clamp(mesh_trans + dxy_ffd - mesh, min=-max_movement, max=max_movement)
mesh_trans_list.append(mesh_trans)
coords_query_it = BSpline_renew_coords(uv_raw, kls, ij_raw, mesh_trans)
if coords_log:
coords_q_log.append(coords_query_it.detach().cpu().numpy())
sscore_t = (
J_cal(coords_query_it[idx_q], coords_r[idx_r], cov_anchor_it, bleeding, dist_penalty, attention_params)
.sum()
.cpu()
.item()
)
# print(f'Loss: {sscore_t}')
t.set_description(f"Loss: {sscore_t:.3f}")
t.refresh()
similarity_score.append(sscore_t)
        if mid_visual and it % 20 == 0:  # it == 0 is already covered by it % 20 == 0
plot_mid(
coords_query_it.cpu() * scale_t,
coords_r.cpu() * scale_t,
output_path,
prefix + "_FFD_it_" + str(it),
scale_bar_t=None,
)
mesh_plot(mesh.cpu(), coords_q_t=coords_query_it.cpu(), mesh_trans_t=mesh_trans.cpu())
plt.savefig(f"{output_path}/{prefix}_mesh_plot_it_{it}.pdf")
plt.clf()
### visualization
plt.figure(figsize=[20, 10])
plt.subplot(1, 2, 1)
plt.scatter(
np.array(coords_q.cpu()[:, 0].tolist()) * scale_t,
np.array(coords_q.cpu()[:, 1].tolist()) * scale_t,
s=2,
edgecolors="none",
alpha=0.5,
rasterized=True,
c="blue",
label="Before",
)
plt.scatter(
np.array(coords_query_it.cpu()[:, 0].tolist()) * scale_t,
np.array(coords_query_it.cpu()[:, 1].tolist()) * scale_t,
s=2,
edgecolors="none",
alpha=0.7,
rasterized=True,
c="#ef233c",
label="After",
)
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.legend(fontsize=15)
plt.axis("equal")
plt.subplot(1, 2, 2)
titles = "loss = " + format(similarity_score[-1], ".1f")
plt.scatter(list(range(0, len(similarity_score))), similarity_score, s=5)
plt.title(titles, fontsize=20)
plt.savefig(os.path.join(output_path, prefix + "_after_Bspline_" + str(iterations) + ".pdf"))
return [coords_query_it, mesh_trans_list, dxy_ffd_all, similarity_score, coords_q_log]
def J_cal(coords_q, coords_r, cov_mat, bleeding=10, dist_penalty=0, attention_params=[None, 3, 1, 0]):
attention_region, double_penalty, penalty_inc_all, penalty_inc_both = attention_params
bleeding_x = coords_q[:, 0].min() - bleeding, coords_q[:, 0].max() + bleeding
bleeding_y = coords_q[:, 1].min() - bleeding, coords_q[:, 1].max() + bleeding
sub_ind = (
(coords_r[:, 0] > bleeding_x[0])
& (coords_r[:, 0] < bleeding_x[1])
& (coords_r[:, 1] > bleeding_y[0])
& (coords_r[:, 1] < bleeding_y[1])
)
cov_mat_t = cov_mat[:, sub_ind]
dist = torch.cdist(coords_q, coords_r[sub_ind, :])
min_dist_values, close_idx = torch.min(dist, dim=1)
tmp1 = torch.stack((torch.arange(coords_q.shape[0], device=coords_q.device), close_idx)).T
s_score_mat = cov_mat_t[tmp1[:, 0], tmp1[:, 1]]
if dist_penalty != 0:
penalty_tres = torch.sqrt(
(coords_r[:, 0].max() - coords_r[:, 0].min())
* (coords_r[:, 1].max() - coords_r[:, 1].min())
/ coords_r.shape[0]
)
dist_d = min_dist_values / penalty_tres
        if isinstance(attention_region, np.ndarray):
            attention_region = torch.tensor(attention_region, device=coords_q.device)
dist_d[attention_region] = min_dist_values[attention_region] / (penalty_tres / double_penalty)
dist_d[dist_d < 1] = 1
dist_d[dist_d > 1] *= dist_penalty
dist_d[attention_region] *= penalty_inc_all
dist_d[(dist_d > 1) & attention_region] *= penalty_inc_both / dist_penalty + 1
else:
dist_d[dist_d < 1] = 1
dist_d[dist_d > 1] *= dist_penalty
return s_score_mat * dist_d
return s_score_mat
def alpha_init(alpha_basis, it, dev):
return 5 / torch.pow(torch.Tensor([it / 40 + 1]).to(dev), 0.6) * alpha_basis
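# Decay schedule illustration: alpha(it) = 5 * alpha_basis / (it / 40 + 1)**0.6,
# so (values rounded)
#   it = 0   -> 5.00 * alpha_basis
#   it = 40  -> ~3.30 * alpha_basis
#   it = 200 -> ~1.71 * alpha_basis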
def dJ_dt_cal(coords_q, coords_r, diff_step, dev, cov_anchor_it, bleeding, dist_penalty, attention_params):
dJ_dy = (
J_cal(
coords_q + torch.tensor([0, diff_step], device=dev),
coords_r,
cov_anchor_it,
bleeding,
dist_penalty,
attention_params,
)
- J_cal(
coords_q + torch.tensor([0, -diff_step], device=dev),
coords_r,
cov_anchor_it,
bleeding,
dist_penalty,
attention_params,
)
) / (2 * diff_step)
dJ_dx = (
J_cal(
coords_q + torch.tensor([diff_step, 0], device=dev),
coords_r,
cov_anchor_it,
bleeding,
dist_penalty,
attention_params,
)
- J_cal(
coords_q + torch.tensor([-diff_step, 0], device=dev),
coords_r,
cov_anchor_it,
bleeding,
dist_penalty,
attention_params,
)
) / (2 * diff_step)
dJ_dxy_mat = torch.vstack((dJ_dx, dJ_dy)) # [dJ_{i}/dx_{i},dJ_{i}/dy_{i}] (2 * N)
return dJ_dxy_mat
def dJ_dtheta_cal(xi, yi, dJ_dxy_mat, theta, dev, ifrigid=False):
    """
    Analytic derivatives of the affine map, with theta = [a, d, phi, t1, t2]:
        dxy/da   = ( x * cos(phi),                         x * sin(phi) )
        dxy/dd   = (-y * sin(phi),                         y * cos(phi) )
        dxy/dphi = (-d * y * cos(phi) - a * x * sin(phi),  a * x * cos(phi) - d * y * sin(phi))
        dxy/dt1  = (1, 0)
        dxy/dt2  = (0, 1)
    When we set d = a (ifrigid=True, similarity transform):
        dxy/da   = (x * cos(phi) - y * sin(phi),           y * cos(phi) + x * sin(phi))
        dxy/dd     is set to the same value as dxy/da (unused)
        dxy/dphi = (-a * y * cos(phi) - a * x * sin(phi),  a * x * cos(phi) - a * y * sin(phi))
    """
N = xi.shape[0]
rad_phi = theta[2, 0].deg2rad()
cos_rad_phi = rad_phi.cos()
sin_rad_phi = rad_phi.sin()
ones = torch.ones(N, device=dev)
zeros = torch.zeros(N, device=dev)
if ifrigid:
        #### let d = a: only scaling, rotation and translation are allowed (similarity transformation).
        #### For a pure rigid transformation, set `alpha_basis` to `[0, 0, x, x, x]`; theta[0] then stays fixed at 1.
        dxy_dtheta = torch.stack(
            [
                torch.stack(
                    [
                        xi * cos_rad_phi - yi * sin_rad_phi,  # dxy_da (rigid)
                        xi * cos_rad_phi - yi * sin_rad_phi,  # dxy_dd - unused (rigid)
                        -theta[0] * cos_rad_phi * yi - theta[0] * xi * sin_rad_phi,  # dxy_dphi
                        ones,  # dxy_dt1
                        zeros,  # dxy_dt2
                    ]
                ),
                torch.stack(
                    [
                        yi * cos_rad_phi + xi * sin_rad_phi,  # dxy_da (rigid)
                        yi * cos_rad_phi + xi * sin_rad_phi,  # dxy_dd - unused (rigid)
                        theta[0] * xi * cos_rad_phi - theta[0] * yi * sin_rad_phi,  # dxy_dphi
                        zeros,  # dxy_dt1
                        ones,  # dxy_dt2
                    ]
                ),
            ]
        )
else:
        dxy_dtheta = torch.stack(
            [
                torch.stack(
                    [
                        xi * cos_rad_phi,  # dxy_da
                        -yi * sin_rad_phi,  # dxy_dd
                        -theta[1] * cos_rad_phi * yi - theta[0] * xi * sin_rad_phi,  # dxy_dphi
                        ones,  # dxy_dt1
                        zeros,  # dxy_dt2
                    ]
                ),
                torch.stack(
                    [
                        xi * sin_rad_phi,  # dxy_da
                        yi * cos_rad_phi,  # dxy_dd
                        theta[0] * xi * cos_rad_phi - theta[1] * yi * sin_rad_phi,  # dxy_dphi
                        zeros,  # dxy_dt1
                        ones,  # dxy_dt2
                    ]
                ),
            ]
        )
    dJ_dtheta = torch.bmm(
        dxy_dtheta.permute(2, 1, 0),  # [N, 5, 2]
        dJ_dxy_mat.transpose(0, 1).unsqueeze(-1),  # [N, 2, 1]
    ).squeeze(2)  # [dJ_{i}/dtheta_{k}] (N * 5)
dJ_dtheta = dJ_dtheta.sum(0)
return dJ_dtheta
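# Chain rule being applied (restated for clarity): with the per-point gradients
# g_i = [dJ_i/dx_i, dJ_i/dy_i] from dJ_dt_cal,
#   dJ/dtheta_k = sum_i g_i . [dx_i/dtheta_k, dy_i/dtheta_k]
# computed above as a batched matmul [N, 5, 2] x [N, 2, 1] -> [N, 5], then
# summed over the N points.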
def theta_renew(theta, dJ_dtheta, alpha, ifrigid=False):
    alpha_dJ = alpha * dJ_dtheta.reshape(5, 1)
    alpha_dJ[0:3] = alpha_dJ[0:3] / 1000  # damp the a, d, phi updates, which would otherwise move x and y too much
    if ifrigid and (theta[0] == -theta[1]):
        # Only when the rigid (similarity) transformation is enabled do we check whether a and d are mirrored.
        # If they are (prelocate sets `d = d * mirror`), flip the sign of alpha_dJ[1] to match.
        alpha_dJ[1] = -alpha_dJ[1]
theta_new = theta - alpha_dJ
return theta_new
def affine_trans_t(theta, coords_t):
rad_phi = theta[2, 0].deg2rad()
cos_rad_phi = rad_phi.cos()
sin_rad_phi = rad_phi.sin()
A = torch.Tensor(
[
[theta[0, 0] * cos_rad_phi, -theta[1, 0] * sin_rad_phi],
[theta[0, 0] * sin_rad_phi, theta[1, 0] * cos_rad_phi],
]
).to(theta.device)
t_vec = theta[3:5, :]
coords_t1 = torch.mm(A, coords_t.T) + t_vec
coords_t1 = coords_t1.T
return coords_t1
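# Worked example (hedged sketch; would run as-is on CPU):
#
#   theta = torch.Tensor([1, 1, 90, 0, 0]).reshape(5, 1)  # a = d = 1, phi = 90, no shift
#   affine_trans_t(theta, torch.Tensor([[1.0, 0.0]]))     # ~[[0., 1.]]: a 90-degree rotation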
def torch_Bspline(uv, kl):
return torch.where(
kl == 0,
(1 - uv) ** 3 / 6,
torch.where(
kl == 1,
uv**3 / 2 - uv**2 + 2 / 3,
torch.where(
kl == 2,
(-3 * uv**3 + 3 * uv**2 + 3 * uv + 1) / 6,
torch.where(kl == 3, uv**3 / 6, torch.zeros_like(uv)),
),
),
)
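# Sanity check sketch (hedged; illustrative only): the four cubic B-spline
# basis functions form a partition of unity, so for any u in [0, 1):
#
#   u = torch.rand(2, 7)
#   total = sum(torch_Bspline(u, torch.full((2, 1), kl)) for kl in range(4))
#   assert torch.allclose(total, torch.ones_like(u))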
def BSpline_GD_preparation(max_xy, mesh_size, dev, mesh_weight):
delta = max_xy / (mesh_size - 1.0)
mesh = np.ones((2, mesh_size + 3, mesh_size + 3)) ## 2 * (mesh_size + 3) * (mesh_size + 3)
for i in range(mesh_size + 3):
for j in range(mesh_size + 3):
mesh[:, i, j] = [
(i - 1) * delta[0],
(j - 1) * delta[1],
] ## 0 - -delta, 1 - 0, 2 - delta, ..., 6 - delta * 5, 7 - delta * 6 (last row)
mesh = torch.tensor(mesh).to(dev)
mesh_weight = torch.tensor(mesh_weight).to(dev) if type(mesh_weight) is np.ndarray else 1
    kls = torch.stack(torch.meshgrid(torch.arange(4), torch.arange(4), indexing="ij")).flatten(1).to(dev)  ## 2 * 16
dxy_ffd_all = torch.zeros(mesh.shape, device=dev) ## 2 * (mesh_size + 3) * (mesh_size + 3)
return mesh, mesh_weight, kls, dxy_ffd_all, delta
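# Mesh layout example (mesh_size = 5): delta = max_xy / 4 and the control grid
# is (5 + 3) x (5 + 3) = 8 x 8 points at coordinates -delta, 0, ..., 6 * delta
# per axis, padding the data range [0, 4 * delta] so every query point keeps a
# full 4 x 4 B-spline support.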
def BSpline_GD_uv_ij_calculate(coords_query_it, delta, dev):
pos_reg = coords_query_it.T / delta.reshape(2, 1).to(dev) # 2 * N
pos_floor = pos_reg.floor().long() # 2 * N
uv_raw = pos_reg - pos_floor # 2 * N
ij_raw = pos_floor - 1 # 2 * N
return uv_raw, ij_raw
def B_matrix(uv_t, kls_t):
result_B_list = []
for kl in kls_t.T:
B = torch_Bspline(uv_t, kl.view(2, 1)) # 2 * N[idx]
        result_B_list.append(B.prod(0, keepdim=True))  # 1 * N[idx]; product of the x- and y-basis values for each point
return torch.cat(result_B_list, 0) # 16 * N[idx]
def get_dxy_ffd(ij, result_B_t, mesh, dJ_dxy_mat, mesh_weight, alpha_basis):
dxy_ffd_t = torch.zeros(mesh.shape, device=result_B_t.device)
ij_0 = ij[0] + 1
ij_1 = ij[1] + 1
for k in range(dxy_ffd_t.shape[1]):
for l in range(dxy_ffd_t.shape[2]):
mask = (ij_0 <= k) & (k <= ij_0 + 3) & (ij_1 <= l) & (l <= ij_1 + 3)
if mask.any(): # check if there is any True in the mask
idx_kl = mask.nonzero().flatten()
ij_t = torch.tensor([k, l], device=ij.device) - (ij[:, idx_kl].T + 1)
keys = ij_t[:, 0] * 4 + ij_t[:, 1]
t33 = result_B_t[keys, idx_kl]
dxy_ffd_t[:, k, l] -= torch.matmul(dJ_dxy_mat[:, idx_kl], t33.unsqueeze(1).float()).squeeze(1)
dxy_ffd_t *= mesh_weight
dxy_ffd_t = dxy_ffd_t * alpha_basis
return dxy_ffd_t
def BSpline_renew_coords(uv_t, kls_t, ij_t, mesh_trans):
result_tt = torch.zeros_like(uv_t, dtype=torch.float32)
for kl in kls_t.T:
B = torch_Bspline(uv_t, kl.view(2, 1))
pivots = (ij_t + 1 + kl.view(2, 1)).clamp(0, mesh_trans.size(-1) - 1)
mesh_t = mesh_trans[:, pivots[0], pivots[1]]
result_tt += B.prod(0, keepdim=True) * mesh_t
return result_tt.T
def reg_total_t(coords_q, coords_r, params_dist):
dev = params_dist.device
mean_q = coords_q.mean(0)
mean_r = coords_r.mean(0)
coords_q_t = torch.tensor(np.array(coords_q) - mean_q).float().to(dev) ## Initial location
    coords_q_r1 = affine_trans_t(params_dist.theta_r1, coords_q_t)  ## Prelocation (1st affine)
    coords_q_r2 = affine_trans_t(params_dist.theta_r2, coords_q_r1)  ## Affine transformation (2nd affine)
if params_dist.mesh_trans_list != [] and params_dist.mesh_trans_list != [[]]:
coords_q_r3 = coords_q_r2.clone()
for round_t in range(len(params_dist.mesh_trans_list)):
coords_q_r3 = coords_q_r3.clone() - params_dist.min_qr2[round_t]
coords_q_r3 = FFD_Bspline_apply_t(coords_q_r3.clone(), params_dist, round_t)
coords_q_r3 = coords_q_r3.clone() + params_dist.min_qr2[round_t]
coords_q_f = coords_q_r3.clone()
else:
coords_q_f = coords_q_r2
coords_q_reconstruct = coords_q_f + torch.tensor(mean_r).to(dev)
coords_q_reconstruct = coords_q_reconstruct.float()
return coords_q_f, coords_q_reconstruct
def FFD_Bspline_apply_t(coords_q, params_dist, round_t=0):
mesh_trans_list = params_dist.mesh_trans_list[round_t]
dev = coords_q.device
img_size = params_dist.img_size_bs[round_t]
mesh_size = mesh_trans_list[0].shape[2] - 3
delta = img_size / (mesh_size - 1.0)
coords_query_it = copy.deepcopy(coords_q)
for it in trange(len(mesh_trans_list), desc="", leave=True):
mesh_trans = mesh_trans_list[it]
pos_reg = coords_query_it.T / delta.reshape(2, 1).to(dev)
pos_floor = pos_reg.floor().long()
uv = pos_reg - pos_floor
ij = pos_floor - 1
        kls = torch.stack(torch.meshgrid(torch.arange(4), torch.arange(4), indexing="ij")).flatten(1).to(dev)
result_tt = torch.zeros_like(uv).float()
for kl in kls.T:
B = torch_Bspline(uv, kl.view(2, 1))
pivots = (ij + 1 + kl.view(2, 1)).clamp(0, mesh_trans.size(-1) - 1)
mesh_t = mesh_trans[:, pivots[0], pivots[1]]
result_tt += B.prod(0, keepdim=True) * mesh_t
coords_query_it = result_tt.T
return coords_query_it
def rescale_coords(coords_raw, graph_list, rescale=False):
rescale_factor = 1
if rescale:
coords_raw = coords_raw.copy()
for sample_t in graph_list:
rescale_factor_t = 22340 / np.abs(coords_raw[sample_t]).max()
coords_raw[sample_t] = coords_raw[sample_t].copy() * rescale_factor_t
            if sample_t == graph_list[1]:  # report the factor of the second sample in graph_list
                rescale_factor = rescale_factor_t
return coords_raw, rescale_factor
#################### Visualization ####################
def mesh_plot(mesh_t, coords_q_t, mesh_trans_t=None):
    mesh_np = mesh_t.numpy()
    plt.figure(figsize=[10, 10])
    plt.plot(mesh_np[0], mesh_np[1], "blue")
    plt.plot(mesh_np.T[..., 0], mesh_np.T[..., 1], "blue")
    if mesh_trans_t is not None:
        mesh_trans_np = mesh_trans_t.numpy()
        plt.plot(mesh_trans_np[0], mesh_trans_np[1], "orange")
        plt.plot(mesh_trans_np.T[..., 0], mesh_trans_np.T[..., 1], "orange")
plt.scatter(coords_q_t.T[0, :], coords_q_t.T[1, :], c="blue", s=0.5, alpha=0.5, rasterized=True)
def plot_mid(coords_q, coords_r, output_path="", filename=None, title_t=["ref", "query"], s_t=8, scale_bar_t=None):
plt.rcParams.update({"font.size": 30, "axes.titlesize": 30, "pdf.fonttype": 42, "legend.markerscale": 5})
plt.figure(figsize=[10, 12])
plt.scatter(
np.array(coords_r)[:, 0].tolist(),
np.array(coords_r)[:, 1].tolist(),
s=s_t,
edgecolors="none",
alpha=0.5,
rasterized=True,
c="#9295CA",
label=title_t[0],
)
plt.scatter(
np.array(coords_q)[:, 0].tolist(),
np.array(coords_q)[:, 1].tolist(),
s=s_t,
edgecolors="none",
alpha=0.5,
rasterized=True,
c="#E66665",
label=title_t[1],
)
plt.legend(fontsize=15)
plt.axis("equal")
    if scale_bar_t is not None:
        add_scale_bar(scale_bar_t[0], scale_bar_t[1])
    if filename is not None:
        plt.savefig(os.path.join(output_path, filename + ".pdf"), dpi=100)
def corr_heat(
coords_q, coords_r, corr, output_path, title_t=["Corr in ref", "Anchor in query"], filename=None, scale_bar_t=None
):
plt.rcParams.update({"font.size": 20, "axes.titlesize": 20, "pdf.fonttype": 42})
random.seed(2)
sampled_points = np.sort(random.sample(list(range(0, coords_q.shape[0])), 20))
    plt.figure(figsize=(40, 25))
for t in range(0, len(sampled_points)):
plt_ind = t * 2
ins_cell_idx = sampled_points[t]
col_value = corr[ins_cell_idx, :]
col_value_bg = [0] * coords_q.shape[0]
col_value_bg[ins_cell_idx] = 1
size_value_bg = [5] * coords_q.shape[0]
size_value_bg[ins_cell_idx] = 30
plt.subplot(5, 8, plt_ind + 1)
plt.scatter(
np.array(coords_r[:, 0]),
np.array(coords_r[:, 1]),
s=5,
edgecolors="none",
c=col_value,
cmap="vlag",
vmin=-1,
vmax=1,
rasterized=True,
)
plt.title(title_t[0])
plt.axis("equal")
        if scale_bar_t is not None:
add_scale_bar(scale_bar_t[0], scale_bar_t[1])
plt.subplot(5, 8, plt_ind + 2)
plt.scatter(
np.array(coords_q[:, 0]),
np.array(coords_q[:, 1]),
s=size_value_bg,
edgecolors="none",
c=col_value_bg,
cmap="vlag",
vmin=-1,
vmax=1,
rasterized=True,
)
plt.scatter(
np.array(coords_q[ins_cell_idx, 0]),
np.array(coords_q[ins_cell_idx, 1]),
s=size_value_bg[ins_cell_idx],
edgecolors="none",
c=col_value_bg[ins_cell_idx],
cmap="vlag",
vmin=-1,
vmax=1,
rasterized=True,
)
plt.title(title_t[1])
plt.axis("equal")
        if scale_bar_t is not None:
add_scale_bar(scale_bar_t[0], scale_bar_t[1])
plt.tight_layout()
plt.colorbar()
    if filename is not None:
plt.savefig(os.path.join(output_path, filename + ".pdf"), dpi=100, transparent=True)
def prelocate_loss_plot(J_t, output_path, prefix="test"):
plt.rcParams.update({"font.size": 15})
plt.figure(figsize=[5, 5])
plt.scatter(x=list(range(0, len(J_t))), y=J_t)
plt.savefig(f"{output_path}/{prefix}_prelocate_loss.pdf")
def register_result(
coords_q, coords_r, cov_anchor_t, bleeding, embed_stack, output_path, k=8, prefix="test", scale_t=1, index_list=None
):
idx_q = np.ones(coords_q.shape[0], dtype=bool) if index_list is None else index_list[0]
idx_r = np.ones(coords_r.shape[0], dtype=bool) if index_list is None else index_list[1]
coords_q = coords_q * scale_t
coords_r = coords_r * scale_t
kmeans = KMeans(n_clusters=k, random_state=0).fit(embed_stack)
cell_label = kmeans.labels_
cluster_pl = sns.color_palette("tab20", len(np.unique(cell_label)))
### panel 1 ###
plot_mid(coords_q[idx_q], coords_r[idx_r], output_path, f"{prefix}_Results_1", scale_bar_t=None)
### panel 2 ###
plt.figure(figsize=[10, 12])
plt.rcParams.update({"font.size": 10, "axes.titlesize": 20, "pdf.fonttype": 42})
col = coords_q[idx_q, 0]
row = coords_q[idx_q, 1]
cell_type_t = cell_label[0 : coords_q[idx_q].shape[0]]
for i in set(cell_type_t):
plt.scatter(
np.array(col)[cell_type_t == i],
np.array(row)[cell_type_t == i],
s=12,
edgecolors="none",
alpha=0.5,
rasterized=True,
c=np.array(cluster_pl)[cell_type_t[cell_type_t == i]],
label=str(i),
)
col = coords_r[idx_r, 0]
row = coords_r[idx_r, 1]
cell_type_t = cell_label[coords_q[idx_q].shape[0] :]
for i in set(cell_type_t):
plt.scatter(
np.array(col)[cell_type_t == i],
np.array(row)[cell_type_t == i],
s=12,
edgecolors="none",
alpha=0.5,
rasterized=True,
c=np.array(cluster_pl)[cell_type_t[cell_type_t == i]],
label=str(i),
)
plt.axis("equal")
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.title("K means (k = " + str(k) + ")", fontsize=30)
add_scale_bar(200, "200 µm")
plt.savefig(f"{output_path}/{prefix}_Results_2.pdf", dpi=300)
### panel 3 ###
plt.figure(figsize=[20, 12])
plt.subplot(1, 2, 1)
t_score = J_cal(torch.from_numpy(coords_q[idx_q]), torch.from_numpy(coords_r[idx_r]), cov_anchor_t, bleeding)
plt.scatter(
coords_q[idx_q, 0],
coords_q[idx_q, 1],
c=1 - t_score,
cmap="vlag",
vmin=-1,
vmax=1,
s=15,
edgecolors="none",
alpha=0.5,
rasterized=True,
)
add_scale_bar(200, "200 µm")
plt.subplot(1, 2, 2)
plt.scatter(coords_q[0, 0], coords_q[0, 1], c=[0], cmap="vlag", vmin=-1, vmax=1, s=15, alpha=0.5)
plt.axis("off")
plt.colorbar()
plt.savefig(f"{output_path}/{prefix}_Results_3.pdf", dpi=300)
def affine_reg_params(it_theta, similarity_score, iterations, output_path, prefix="test"):
plt.rcParams.update({"font.size": 15, "axes.titlesize": 15, "pdf.fonttype": 42})
similarity_score_t = copy.deepcopy(similarity_score)
titles = ["a", "d", "φ", "t1", "t2", "loss = " + format(similarity_score[-1], ".1f")]
plt.figure(figsize=[15, 8])
for i in range(0, 6):
plt.subplot(2, 4, i + 1)
if i == 5:
plt.scatter(list(range(0, len(similarity_score_t))), similarity_score_t, s=5)
else:
# plt.scatter(x = range(0,iterations),y=np.array(it_theta)[:,i,0],s = 5)
plt.scatter(x=range(1, len(similarity_score_t)), y=np.array(it_theta)[:, i, 0], s=5)
plt.title(titles[i], fontsize=20)
plt.savefig(os.path.join(output_path, prefix + "_params_Affine_GD_" + str(iterations) + "its.pdf"))
def CAST_STACK_rough(coords_raw_list, ifsquare=True, if_max_xy=True, percentile=None):
    """
    coords_raw_list: list of numpy arrays; each array holds the coordinates of one layer.
    ifsquare: if True, the coordinates are scaled into a square.
    if_max_xy: if True, the coordinates are scaled to `max_range_x` and `max_range_y` respectively
        (when ifsquare is False), or to the max of [max_range_x, max_range_y] (when ifsquare is True).
    percentile: if not None, the min and max are computed at the given percentile of the coordinates of each slice.
    """
# Convert list of arrays to a single numpy array for easier processing
all_coords = np.concatenate(coords_raw_list)
# Finding the global min and max for both x and y
if percentile is None:
min_x, min_y = np.min(all_coords, axis=0)
max_x, max_y = np.max(all_coords, axis=0)
else:
min_x_list, min_y_list, max_x_list, max_y_list = [], [], [], []
for coords_t in coords_raw_list:
min_x_list.append(np.percentile(coords_t[:, 0], percentile))
min_y_list.append(np.percentile(coords_t[:, 1], percentile))
max_x_list.append(np.percentile(coords_t[:, 0], 100 - percentile))
max_y_list.append(np.percentile(coords_t[:, 1], 100 - percentile))
min_x, min_y = np.min(min_x_list), np.min(min_y_list)
max_x, max_y = np.max(max_x_list), np.max(max_y_list)
max_xy = np.array([max_x - min_x, max_y - min_y])
scaled_coords_list = []
for coords_t in coords_raw_list:
coords_t2 = (coords_t - coords_t.min(axis=0)) / np.ptp(coords_t, axis=0)
if if_max_xy:
max_xy_scale = max_xy
else:
max_xy_scale = max_xy / np.max(max_xy)
scaled_coords = coords_t2 * np.max(max_xy_scale) if ifsquare else coords_t2 * max_xy_scale
scaled_coords_list.append(scaled_coords)
return scaled_coords_list
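# Usage sketch (hedged; the random inputs are assumptions for illustration):
#
#   s1 = np.random.rand(100, 2) * [500, 300]
#   s2 = np.random.rand(120, 2) * [800, 400]
#   s1_scaled, s2_scaled = CAST_STACK_rough([s1, s2], ifsquare=True)
#   # both slices now span the same square bounding box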
#################### Calculation ####################
def coords_minus_mean(coord_t):
return np.array(coord_t) - np.mean(np.array(coord_t), axis=0)
def coords_minus_min(coord_t):
return np.array(coord_t) - np.min(np.array(coord_t), axis=0)
def max_minus_value(corr):
return np.max(corr) - corr
def coords_minus_min_t(coord_t):
return coord_t - coord_t.min(0)[0]
def max_minus_value_t(corr):
return corr.max() - corr
def corr_dist(query_np, ref_np, nan_as="min"):
from sklearn.metrics import pairwise_distances_chunked
def chunked_callback(dist_matrix, start):
return 1 - dist_matrix
chunks = pairwise_distances_chunked(
query_np, ref_np, metric="correlation", n_jobs=-1, working_memory=1024, reduce_func=chunked_callback
)
corr_q_r = np.vstack(list(chunks))
if nan_as == "min":
corr_q_r[np.isnan(corr_q_r)] = np.nanmin(corr_q_r)
return corr_q_r
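# Usage sketch (hedged): corr_dist returns 1 minus the "correlation" distance,
# i.e. Pearson correlations in [-1, 1].
#
#   q = np.random.rand(50, 30)   # 50 query cells x 30 features
#   r = np.random.rand(80, 30)   # 80 reference cells
#   corr = corr_dist(q, r)       # shape (50, 80)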
def region_detect(embed_dict_t, coords0, k=20):
plot_row = int(np.floor((k + 1) / 4) + 1)
kmeans = KMeans(n_clusters=k, random_state=0).fit(embed_dict_t)
cell_label = kmeans.labels_
cluster_pl = sns.color_palette("tab20", len(np.unique(cell_label)))
    plt.figure(figsize=(20, 5 * plot_row))
plt.subplot(plot_row, 4, 1)
cell_label_idx = 0
col = coords0[:, 0].tolist()
row = coords0[:, 1].tolist()
cell_type_t = cell_label[cell_label_idx : (cell_label_idx + coords0.shape[0])]
cell_label_idx += coords0.shape[0]
for i in set(cell_type_t):
plt.scatter(
np.array(col)[cell_type_t == i],
np.array(row)[cell_type_t == i],
s=5,
edgecolors="none",
c=np.array(cluster_pl)[cell_type_t[cell_type_t == i]],
label=str(i),
)
plt.title(" (KMeans, k = " + str(k) + ")", fontsize=20)
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.axis("equal")
for j, i in enumerate(set(cell_type_t)):
plt.subplot(plot_row, 4, j + 2)
plt.scatter(np.array(col), np.array(row), s=3, c="#DDDDDD")
plt.scatter(
np.array(col)[cell_type_t == i],
np.array(row)[cell_type_t == i],
s=5,
edgecolors="none",
c=np.array(cluster_pl)[cell_type_t[cell_type_t == i]],
label=str(i),
)
plt.title(str(i), fontsize=20)
plt.axis("equal")
return cell_label