"""Written by @Jinerhal, adapted by @Xiaojieqiu.
"""
import random
from typing import Any, List, Optional, Tuple, Union
import cv2
import numpy as np
from anndata import AnnData
from skimage import morphology
from ..configuration import SKM
from ..logging import logger_manager as lm
@SKM.check_adata_is_type(SKM.ADATA_UMI_TYPE)
[docs]def gen_cluster_image(
adata: AnnData,
bin_size: Optional[int] = None,
spatial_key: str = "spatial",
cluster_key: str = "scc",
label_mapping_key: str = "cluster_img_label",
cmap: str = "tab20",
show: bool = True,
) -> np.ndarray:
"""Generate matrix/image of spatial clusters with distinct labels/colors.
Args:
adata: The adata object used to create the matrix/image for clusters.
bin_size: The size of the binning.
spatial_key: The key name of the spatial coordinates in `adata.obs`
cluster_key: The key name of the spatial cluster in `adata.obs`
label_mapping_key: The key name to store the label index values, mapped from the cluster names in `adata.obs`.
Note that background is 0 so `label_mapping_key` starts from 1.
cmap: The colormap that will be used to draw colors for the resultant cluster image.
show: Whether to visualize the cluster image.
Returns:
cluster_label_image: A numpy array that stores the image of clusters, each with a distinct color. When `show`
is True, `plt.imshow(cluster_rgb_image)` will be used to plot the clusters each with distinct labels
prepared from the designated cmap.
"""
import matplotlib as mpl
import matplotlib.pyplot as plt
if bin_size is None:
bin_size = adata.uns["bin_size"]
lm.main_info(f"Set up the color for the clusters with the {cmap} colormap.")
# TODO: what if cluster number is larger than cmap.N?
cmap = mpl.colormaps[cmap]
colors = cmap(np.arange(cmap.N))
color_ls = []
for i in range(cmap.N):
color_ls.append(tuple(np.array(colors[i][:3] * 255).astype(int)))
random.seed(1)
color_ls_sampled = random.sample(color_ls, len(np.unique(adata.obs[cluster_key])))
lm.main_info(f"Saving integer labels for clusters into adata.obs['{label_mapping_key}'].")
# background is 0, so adata.obs[label_mapping_key] starts from 1
adata.obs[label_mapping_key] = 0
cluster_list = np.unique(adata.obs[cluster_key])
for i in range(len(cluster_list)):
adata.obs[label_mapping_key][adata.obs[cluster_key] == cluster_list[i]] = i + 1
# get cluster image
lm.main_info(f"Prepare a mask image and assign each pixel to the corresponding cluster id.")
max_coords = [int(np.max(adata.obsm[spatial_key][:, 0])) + 1, int(np.max(adata.obsm[spatial_key][:, 1])) + 1]
cluster_label_image = np.zeros((max_coords[0], max_coords[1]), np.uint8)
for i in range(len(adata)):
# fill the image (mask) with the label
cv2.circle(
img=cluster_label_image,
center=(int(adata.obsm[spatial_key][i, 1]), int(adata.obsm[spatial_key][i, 0])),
radius=bin_size // 2,
color=int(adata.obs[label_mapping_key][i]),
thickness=-1,
)
if show:
lm.main_info(f"Plot the cluster image with the color(s) in the color list.")
cluster_rgb_image = np.zeros((max_coords[0], max_coords[1], 3), np.uint8)
for i in np.unique(adata.obs[label_mapping_key]):
cluster_rgb_image[cluster_label_image == i] = color_ls_sampled[i - 1]
plt.imshow(cluster_rgb_image)
return cluster_label_image
@SKM.check_adata_is_type(SKM.ADATA_UMI_TYPE, "adata_high_res")
@SKM.check_adata_is_type(SKM.ADATA_UMI_TYPE, "adata_low_res", optional=True)
[docs]def set_domains(
adata_high_res: AnnData,
adata_low_res: Optional[AnnData] = None,
spatial_key: str = "spatial",
cluster_key: str = "scc",
domain_key_prefix: str = "domain",
bin_size_high: Optional[int] = None,
bin_size_low: Optional[int] = None,
k_size: float = 2,
min_area: float = 9,
) -> None:
"""Set the domains for each bucket based on spatial clusters. Use adata object of low resolution for contour
identification but adata object of high resolution for domain assignment.
Args:
adata_high_res: The anndata object in high spatial resolution. The adata with smaller binning (or single cell
segmetnation) is more suitable to define more fine grained spatial domains.
adata_low_res: The anndata object in low spatial resolution. When using data with big binning, it can often
produce better spatial domain clustering results with the `scc` method and thus domain/domain contour
identification.
spatial_key: The key in `.obsm` of the spatial coordinate for each bucket. Should be same key in both
`adata_high_res` and `adata_low_res`.
cluster_key: The key in `.obs` (`adata_low_res`) to the spatial cluster.
domain_key_prefix: The key prefix in `.obs` (in `adata_high_res`) that will be used to store the spatial domain
for each bucket. The full key name will be set as: `domain_key_prefix` + "_" + `cluster_key`.
bin_size_low: The binning size of the `adata_high_res` object.
bin_size_low: The binning size of the `adata_low_res` object (only works when `adata_low_res` is provided).
k_size: Kernel size of the elliptic structuring element.
min_area: Minimal area threshold corresponding to the resulting contour(s).
Returns:
Nothing but update the `adata_high_res` with the `domain` in `domain_key_prefix` + "_" + `cluster_key`.
"""
domain_key = domain_key_prefix + "_" + cluster_key
if bin_size_high is None:
bin_size_high = adata_high_res.uns["bin_size"]
if adata_low_res is None:
adata_low_res = adata_high_res
bin_size_low = bin_size_high
elif bin_size_low is None:
bin_size_low = adata_low_res.uns["bin_size"]
lm.main_info(f"Generate the cluster label image with `gen_cluster_image`.")
cluster_label_image = gen_cluster_image(
adata_low_res, bin_size=bin_size_low, spatial_key=spatial_key, cluster_key=cluster_key, show=False
)
lm.main_info(f"Iterate through each cluster and identify contours with `extract_cluster_contours`.")
# TODO need a more stable mapping for ids and labels
u, count = np.unique(adata_low_res.obs[cluster_key], return_counts=True)
count_sort_ind = np.argsort(-count)
cluster_ids = u[count_sort_ind]
cluster_ids = [str(c) for c in cluster_ids]
u, count = np.unique(adata_low_res.obs["cluster_img_label"], return_counts=True)
# `cluster_img_label` is produced from `cluster_key`, so use the same count_sort_ind
cluster_labels = u[count_sort_ind]
cluster_labels = [c for c in cluster_labels]
adata_high_res.obs[domain_key] = "NA"
for i in range(len(cluster_ids)):
ctrs, _, _ = extract_cluster_contours(
cluster_label_image, cluster_labels[i], bin_size=bin_size_low, k_size=k_size, min_area=min_area, show=False
)
for j in range(len(adata_high_res)):
x = adata_high_res.obsm[spatial_key][j, 0]
y = adata_high_res.obsm[spatial_key][j, 1]
for k in range(len(ctrs)):
if cv2.pointPolygonTest(ctrs[k], (y, x), False) >= 0:
adata_high_res.obs[domain_key][j] = cluster_ids[i]
# assume one bucket to one domain mapping
break