from .CAST_Mark import *
from .CAST_Projection import *
from .CAST_Stack import *
from .model.model_GCNII import CCA_SSG, Args
from .utils import *
from .visualize import *
def CAST_MARK(
    coords_raw_t,
    exp_dict_t,
    output_path_t,
    task_name_t=None,
    gpu_t=None,
    args=None,
    epoch_t=None,
    if_plot=True,
    graph_strategy="convex",
    device="cuda:0",
):
    """Run the CAST Mark step: build a Delaunay graph per sample, train a
    CCA-SSG graph encoder over all samples, and save the learned embeddings.

    Parameters
    ----------
    coords_raw_t : dict
        Sample name -> spatial coordinate array for that sample.
    exp_dict_t : dict
        Sample name -> expression matrix used as node features; the feature
        dimension of the first sample sets the model input size.
    output_path_t : str
        Directory where plots, logs, and result files are written.
    task_name_t : str, optional
        Dataset name used in the training log; defaults to ``"task1"``.
    gpu_t : int, optional
        GPU id forwarded to ``Args`` (only used when ``args`` is None).
    args : Args, optional
        Pre-built training configuration; when None a default ``Args`` is
        constructed below and ``gpu_t``/``task_name_t`` are folded into it.
    epoch_t : int, optional
        Overrides ``args.epochs`` when given.
    if_plot : bool
        Whether to plot the constructed Delaunay graphs.
    graph_strategy : str
        Graph-construction strategy passed to ``delaunay_dgl``.
    device : str
        Torch device string on which graphs and node features are placed.

    Returns
    -------
    dict
        ``embed_dict`` mapping sample name -> learned graph embedding.

    Raises
    ------
    ImportError
        If the optional ``dgl`` dependency is not installed.
    """
    ### dgl is an optional heavy dependency; fail early with install guidance
    try:
        import dgl  # noqa: F401 -- imported only to verify availability
    except ImportError as err:
        print("Maybe you need to use `pip install dgl -f https://data.dgl.ai/wheels/torch-2.4/cu121/repo.html`")
        raise ImportError("Please install the dgl package from https://www.dgl.ai/pages/start.html") from err
    # gpu_t = 0 if torch.cuda.is_available() and gpu_t is None else -1
    # device = 'cuda:0' if gpu_t == 0 else 'cpu'
    samples = list(exp_dict_t.keys())
    task_name_t = task_name_t if task_name_t is not None else "task1"
    inputs = []
    ### construct delaunay graphs and input data
    print(f"Constructing delaunay graphs for {len(samples)} samples...")
    for sample_t in samples:
        graph_dgl_t = delaunay_dgl(
            sample_t, coords_raw_t[sample_t], output_path_t, if_plot=if_plot, strategy_t=graph_strategy
        ).to(device)
        feat_torch_t = torch.tensor(exp_dict_t[sample_t], dtype=torch.float32, device=device)
        inputs.append((sample_t, graph_dgl_t, feat_torch_t))
    ### parameters setting
    if args is None:
        args = Args(
            dataname=task_name_t,  # name of the dataset, used to save the log file
            gpu=gpu_t,  # gpu id, set to zero for single-GPU nodes
            epochs=400,  # number of epochs for training
            lr1=1e-3,  # learning rate
            wd1=0,  # weight decay
            lambd=1e-3,  # lambda in the loss function, refer to online methods
            n_layers=9,  # number of GCNII layers, more layers mean a deeper model, larger reception field, at a cost of VRAM usage and computation time
            der=0.5,  # edge dropout rate in CCA-SSG
            dfr=0.3,  # feature dropout rate in CCA-SSG
            use_encoder=True,  # perform a single-layer dimension reduction before the GNNs, helps save VRAM and computation time if the gene panel is large
            encoder_dim=512,  # encoder dimension, ignore if `use_encoder` set to `False`
        )
    args.epochs = epoch_t if epoch_t is not None else args.epochs
    ### Initialize the model -- input dim is the feature width of the first sample
    in_dim = inputs[0][-1].size(-1)
    model = CCA_SSG(
        in_dim=in_dim, encoder_dim=args.encoder_dim, n_layers=args.n_layers, use_encoder=args.use_encoder
    ).to(args.device)
    ### Training
    print(f"Training on {args.device}...")
    embed_dict, loss_log, model = train_seq(
        graphs=inputs, args=args, dump_epoch_list=[], out_prefix=f"{output_path_t}/{task_name_t}_seq_train", model=model
    )
    ### Saving the results
    torch.save(embed_dict, f"{output_path_t}/demo_embed_dict.pt")
    torch.save(loss_log, f"{output_path_t}/demo_loss_log.pt")
    torch.save(model, f"{output_path_t}/demo_model_trained.pt")
    print("Finished.")
    print(f"The embedding, log, model files were saved to {output_path_t}")
    return embed_dict
def CAST_STACK(
    coords_raw,
    embed_dict,
    output_path,
    graph_list,
    params_dist=None,
    tmp1_f1_idx=None,
    mid_visual=False,
    sub_node_idxs=None,
    rescale=False,
    corr_q_r=None,
    if_embed_sub=False,
    early_stop_thres=None,
    renew_mesh_trans=True,
):
    """Run the CAST Stack step: align a query sample onto a reference sample.

    The pipeline is: (1) build/accept an embedding-correlation matrix,
    (2) pre-locate the query with a coarse rotation/translation search,
    (3) refine with affine gradient descent, (4) optionally refine further
    with a B-Spline free-form deformation (FFD), and (5) save parameters,
    logs, and final coordinates.

    Parameters
    ----------
    coords_raw : dict
        Sample name -> raw coordinate array.
    embed_dict : dict
        Sample name -> graph embedding (e.g. from ``CAST_MARK``).
    output_path : str
        Directory for plots and saved result files.
    graph_list : sequence
        Two sample names: ``[query_sample, ref_sample]``.
    params_dist : reg_params, optional
        Registration hyper-parameters; defaults are built when None.
    tmp1_f1_idx : array-like, optional
        Index used by the default FFD attention parameters.
    mid_visual : bool
        Whether to emit intermediate diagnostic plots.
    sub_node_idxs : dict, optional
        Sample name -> boolean mask selecting nodes used for alignment;
        defaults to all nodes of both samples.
    rescale : bool
        Whether to rescale coordinates before alignment (``rescale_coords``).
    corr_q_r : array-like, optional
        Precomputed query/reference correlation matrix; computed from
        ``embed_dict`` when None.
    if_embed_sub : bool
        Whether the embeddings are also subset by ``sub_node_idxs`` when
        computing the correlation matrix and the stacked embedding.
    early_stop_thres : float, optional
        Early-stopping threshold forwarded to the affine optimization.
    renew_mesh_trans : bool
        If True keep the full FFD mesh-transform history, otherwise only
        the final mesh transform.

    Returns
    -------
    dict
        ``coords_final`` mapping both sample names to aligned coordinates,
        rescaled back to the original scale.
    """
    ### setting parameters
    query_sample = graph_list[0]
    ref_sample = graph_list[1]
    prefix_t = f"{query_sample}_align_to_{ref_sample}"
    result_log = dict()
    coords_raw, result_log["ref_rescale_factor"] = rescale_coords(coords_raw, graph_list, rescale=rescale)
    if sub_node_idxs is None:
        # default: use every node of both samples
        sub_node_idxs = {
            query_sample: np.ones(coords_raw[query_sample].shape[0], dtype=bool),
            ref_sample: np.ones(coords_raw[ref_sample].shape[0], dtype=bool),
        }
    if params_dist is None:
        params_dist = reg_params(
            dataname=query_sample,
            gpu=0,
            #### Affine parameters
            iterations=500,
            dist_penalty1=0,
            bleeding=500,
            d_list=[3, 2, 1, 1 / 2, 1 / 3],
            attention_params=[None, 3, 1, 0],
            #### FFD parameters
            dist_penalty2=[0],
            alpha_basis_bs=[500],
            meshsize=[8],
            iterations_bs=[400],
            attention_params_bs=[[tmp1_f1_idx, 3, 1, 0]],
            mesh_weight=[None],
        )
    if params_dist.alpha_basis == []:
        # default per-parameter learning-rate basis for the affine step
        params_dist.alpha_basis = torch.Tensor([1 / 3000, 1 / 3000, 1 / 100, 5, 5]).reshape(5, 1).to(params_dist.device)
    round_t = 0
    plt.rcParams.update({"pdf.fonttype": 42})  # keep text editable in PDF output
    plt.rcParams["axes.grid"] = False
    ### Generate correlation matrix of the graph embedding (skip when provided)
    if corr_q_r is None:
        if if_embed_sub:
            corr_q_r = corr_dist(
                embed_dict[query_sample].cpu()[sub_node_idxs[query_sample]],
                embed_dict[ref_sample].cpu()[sub_node_idxs[ref_sample]],
            )
        else:
            corr_q_r = corr_dist(embed_dict[query_sample].cpu(), embed_dict[ref_sample].cpu())
    # Plot initial coordinates
    if mid_visual:
        kmeans_plot_multiple(embed_dict, graph_list, coords_raw, prefix_t, output_path, k=15, dot_size=10)
        corr_heat(
            coords_raw[query_sample][sub_node_idxs[query_sample]],
            coords_raw[ref_sample][sub_node_idxs[ref_sample]],
            corr_q_r,
            output_path,
            filename=prefix_t + "_corr",
        )
    plot_mid(coords_raw[query_sample], coords_raw[ref_sample], output_path, f"{prefix_t}_raw")
    ### Initialize the coordinates and tensor (both centered on their means)
    corr_q_r = torch.Tensor(corr_q_r).to(params_dist.device)
    params_dist.mean_q = coords_raw[query_sample].mean(0)
    params_dist.mean_r = coords_raw[ref_sample].mean(0)
    coords_query = torch.Tensor(coords_minus_mean(coords_raw[query_sample])).to(params_dist.device)
    coords_ref = torch.Tensor(coords_minus_mean(coords_raw[ref_sample])).to(params_dist.device)
    ### Pre-location: coarse search over rotations/translations (and mirror)
    theta_r1_t = prelocate(
        coords_query,
        coords_ref,
        max_minus_value_t(corr_q_r),
        params_dist.bleeding,
        output_path,
        d_list=params_dist.d_list,
        prefix=prefix_t,
        index_list=[sub_node_idxs[k_t] for k_t in graph_list],
        translation_params=params_dist.translation_params,
        mirror_t=params_dist.mirror_t,
    )
    params_dist.theta_r1 = theta_r1_t
    coords_query_r1 = affine_trans_t(params_dist.theta_r1, coords_query)
    if mid_visual:
        ### consistent scale with ref coords
        plot_mid(coords_query_r1.cpu(), coords_ref.cpu(), output_path, prefix_t + "_prelocation")
    ### Affine refinement by gradient descent
    output_list = Affine_GD(
        coords_query_r1,
        coords_ref,
        max_minus_value_t(corr_q_r),
        output_path,
        params_dist.bleeding,
        params_dist.dist_penalty1,
        alpha_basis=params_dist.alpha_basis,
        iterations=params_dist.iterations,
        prefix=prefix_t,
        attention_params=params_dist.attention_params,
        coords_log=True,
        index_list=[sub_node_idxs[k_t] for k_t in graph_list],
        mid_visual=mid_visual,
        early_stop_thres=early_stop_thres,
        ifrigid=params_dist.ifrigid,
    )
    similarity_score, it_J, it_theta, coords_log = output_list
    params_dist.theta_r2 = it_theta[-1]  # final affine parameters
    result_log["affine_J"] = similarity_score
    result_log["affine_it_theta"] = it_theta
    result_log["affine_coords_log"] = coords_log
    result_log["coords_ref"] = coords_ref
    # Affine results
    affine_reg_params(
        [i.cpu().numpy() for i in it_theta], similarity_score, params_dist.iterations, output_path, prefix=prefix_t
    )  # if mid_visual else None
    # Stack query/ref embeddings for the k-means coloring in `register_result`
    if if_embed_sub:
        embed_stack_t = np.vstack(
            (
                embed_dict[query_sample].cpu().detach().numpy()[sub_node_idxs[query_sample]],
                embed_dict[ref_sample].cpu().detach().numpy()[sub_node_idxs[ref_sample]],
            )
        )
    else:
        embed_stack_t = np.vstack(
            (embed_dict[query_sample].cpu().detach().numpy(), embed_dict[ref_sample].cpu().detach().numpy())
        )
    coords_query_r2 = affine_trans_t(params_dist.theta_r2, coords_query_r1)
    register_result(
        coords_query_r2.cpu().detach().numpy(),
        coords_ref.cpu().detach().numpy(),
        max_minus_value_t(corr_q_r).cpu(),
        params_dist.bleeding,
        embed_stack_t,
        output_path,
        k=20,
        prefix=prefix_t,
        scale_t=1,
        index_list=[sub_node_idxs[k_t] for k_t in graph_list],
    )  # if mid_visual else None
    if params_dist.iterations_bs[round_t] != 0:
        ### B-Spline free-form deformation
        padding_rate = params_dist.PaddingRate_bs  # by default, 0
        coords_query_r2_min = coords_query_r2.min(0)[0]  # The x and y min of the query coords
        coords_query_r2_tmp = coords_minus_min_t(coords_query_r2)  # min of the x and y is 0
        max_xy_tmp = coords_query_r2_tmp.max(0)[0]  # max_xy without padding
        adj_min_qr2 = coords_query_r2_min - max_xy_tmp * padding_rate  # adjust the min_qr2
        setattr(params_dist, "img_size_bs", [(max_xy_tmp * (1 + padding_rate * 2)).cpu()])  # max_xy
        params_dist.min_qr2 = [adj_min_qr2]
        t1 = BSpline_GD(
            coords_query_r2 - params_dist.min_qr2[round_t],
            coords_ref - params_dist.min_qr2[round_t],
            max_minus_value_t(corr_q_r),
            params_dist.iterations_bs[round_t],
            output_path,
            params_dist.bleeding,
            params_dist.dist_penalty2[round_t],
            params_dist.alpha_basis_bs[round_t],
            params_dist.diff_step,
            params_dist.meshsize[round_t],
            prefix_t + "_" + str(round_t),
            params_dist.mesh_weight[round_t],
            params_dist.attention_params_bs[round_t],
            coords_log=True,
            index_list=[sub_node_idxs[k_t] for k_t in graph_list],
            mid_visual=mid_visual,
            max_xy=params_dist.img_size_bs[round_t],
            renew_mesh_trans=renew_mesh_trans,
        )
        # B-Spline FFD results
        register_result(
            t1[0].cpu().numpy(),
            (coords_ref - params_dist.min_qr2[round_t]).cpu().numpy(),
            max_minus_value_t(corr_q_r).cpu(),
            params_dist.bleeding,
            embed_stack_t,
            output_path,
            k=20,
            prefix=prefix_t + "_" + str(round_t) + "_BSpine_" + str(params_dist.iterations_bs[round_t]),
            index_list=[sub_node_idxs[k_t] for k_t in graph_list],
        )  # if mid_visual else None
        result_log["BS_coords_log1"] = t1[4]
        result_log["BS_J1"] = t1[3]
        # keep either the full mesh-transform history or just the final one
        if renew_mesh_trans:
            setattr(params_dist, "mesh_trans_list", [t1[1]])
        else:
            setattr(params_dist, "mesh_trans_list", [[t1[1][-1]]])
    ### Save results
    torch.save(params_dist, os.path.join(output_path, f"{prefix_t}_params.data"))
    torch.save(result_log, os.path.join(output_path, f"{prefix_t}_result_log.data"))
    coords_final = dict()
    # apply the full transform chain to the raw (uncentered) query coords
    _, coords_q_final = reg_total_t(coords_raw[query_sample], coords_raw[ref_sample], params_dist)
    coords_final[query_sample] = (
        coords_q_final.cpu() / result_log["ref_rescale_factor"]
    )  ### rescale back to the original scale
    coords_final[ref_sample] = (
        coords_raw[ref_sample] / result_log["ref_rescale_factor"]
    )  ### rescale back to the original scale
    plot_mid(coords_final[query_sample], coords_final[ref_sample], output_path, f"{prefix_t}_align")
    torch.save(coords_final, os.path.join(output_path, f"{prefix_t}_coords_final.data"))
    return coords_final
def CAST_PROJECT(
    sdata_inte,  # the integrated dataset
    source_sample,  # the source sample name
    target_sample,  # the target sample name
    coords_source,  # the coordinates of the source sample
    coords_target,  # the coordinates of the target sample
    scaled_layer="log2_norm1e4_scaled",  # the scaled layer name in `adata.layers`, which is used to be integrated
    raw_layer="raw",  # the raw layer name in `adata.layers`, which is used to be projected into target sample
    batch_key="protocol",  # the column name of the samples in `obs`
    use_highly_variable_t=True,  # if use highly variable genes
    ifplot=True,  # if plot the result
    n_components=50,  # the `n_components` parameter in `sc.pp.pca`
    umap_n_neighbors=50,  # the `n_neighbors` parameter in `sc.pp.neighbors`
    umap_n_pcs=30,  # the `n_pcs` parameter in `sc.pp.neighbors`
    min_dist=0.01,  # the `min_dist` parameter in `sc.tl.umap`
    spread_t=5,  # the `spread` parameter in `sc.tl.umap`
    k2=1,  # select k2 cells to do the projection for each cell
    source_sample_ctype_col="level_2",  # the column name of the cell type in `obs`
    output_path="",  # the output path
    umap_feature="X_umap",  # the feature used for umap
    pc_feature="X_pca_harmony",  # the feature used for the projection
    integration_strategy="Harmony",  # 'Harmony' or None (use existing integrated features)
    ave_dist_fold=3,  # the `ave_dist_fold` is used to set the distance threshold (average_distance * `ave_dist_fold`)
    save_result=True,  # if save the results
    ifcombat=True,  # if use combat when using the Harmony integration
    alignment_shift_adjustment=50,  # to adjust the small alignment shift for the distance threshold)
    color_dict=None,  # the color dict for the cell type
    adjust_shift=False,  # if adjust the alignment shift by group
    metric_t="cosine",
    working_memory_t=1000,  # the working memory for the pairwise distance calculation
):
    """Run the CAST Project step: integrate source and target samples (or
    reuse an existing integration), then project source cells onto the
    target sample in physical space via nearest neighbors in PC space.

    Returns
    -------
    tuple
        ``(sdata_ref, output_list)`` as produced by ``space_project``;
        also written to ``output_path`` when ``save_result`` is truthy.
    """
    #### integration
    if integration_strategy == "Harmony":
        sdata_inte = Harmony_integration(
            sdata_inte=sdata_inte,
            scaled_layer=scaled_layer,
            use_highly_variable_t=use_highly_variable_t,
            batch_key=batch_key,
            umap_n_neighbors=umap_n_neighbors,
            umap_n_pcs=umap_n_pcs,
            min_dist=min_dist,
            spread_t=spread_t,
            source_sample_ctype_col=source_sample_ctype_col,
            output_path=output_path,
            n_components=n_components,
            ifplot=True,  # NOTE(review): hard-coded True ignores the `ifplot` parameter -- confirm intended
            ifcombat=ifcombat,
        )
    elif integration_strategy is None:
        print(f"Using the pre-integrated data {pc_feature} and the UMAP {umap_feature}")
    #### Projection
    # split the integrated object into source/target rows by the batch column
    idx_source = sdata_inte.obs[batch_key] == source_sample
    idx_target = sdata_inte.obs[batch_key] == target_sample
    source_cell_pc_feature = sdata_inte[idx_source, :].obsm[pc_feature]
    target_cell_pc_feature = sdata_inte[idx_target, :].obsm[pc_feature]
    sdata_ref, output_list = space_project(
        sdata_inte=sdata_inte,
        idx_source=idx_source,
        idx_target=idx_target,
        raw_layer=raw_layer,
        source_sample=source_sample,
        target_sample=target_sample,
        coords_source=coords_source,
        coords_target=coords_target,
        output_path=output_path,
        source_sample_ctype_col=source_sample_ctype_col,
        target_cell_pc_feature=target_cell_pc_feature,
        source_cell_pc_feature=source_cell_pc_feature,
        k2=k2,
        ifplot=ifplot,
        umap_feature=umap_feature,
        ave_dist_fold=ave_dist_fold,
        alignment_shift_adjustment=alignment_shift_adjustment,
        color_dict=color_dict,
        metric_t=metric_t,
        adjust_shift=adjust_shift,
        working_memory_t=working_memory_t,
    )
    ### Save the results
    if save_result:
        sdata_ref.write_h5ad(f"{output_path}/sdata_ref.h5ad")
        torch.save(output_list, f"{output_path}/projection_data.pt")
    return sdata_ref, output_list