Source code for spateo.tools.cluster_lasso

# -*- coding: utf-8 -*-
"""
@File    :   cluster_lasso.py
@Time    :   2022/06/29 10:43:55
@Author  :   LuluZuo XueWang
@Version :   1.0
@Desc    :   spatial cluster lasso
"""

from typing import Optional

import pandas as pd
import plotly.graph_objs as go
import plotly.offline as py
from ipywidgets import VBox


[docs]class Lasso: """Lasso an region of interest (ROI) based on spatial cluster. Examples: L = st.tl.Lasso(adata) L.vi_plot(group='group', group_color='group_color') """
[docs] __sub_inde = []
[docs] sub_adata = None
def __init__(self, adata): """ Args: adata: An Anndata object. """ self.adata = adata
[docs] def vi_plot( self, key="spatial", group: Optional[str] = None, group_color: Optional[str] = None, ): """Plot spatial cluster result and lasso ROI. Args: key: The column key in .obsm, default to be 'spatial'. group: The column key/name that identifies the grouping information (for example, clusters that correspond to different cell types) of buckets. group_color: The key in .uns, corresponds to a dictionary that map group names to group colors. Returns: sub_adata: subset of adata. """ if group and group_color: df = pd.DataFrame() df["group_ID"] = self.adata.obs_names df["labels"] = self.adata.obs[group].values df["spatial_0"] = self.adata.obsm[key][:, 0] df["spatial_1"] = self.adata.obsm[key][:, 1] df["color"] = df.labels.map(self.adata.uns[group_color]) py.init_notebook_mode() f = go.FigureWidget( [go.Scatter(x=df["spatial_0"], y=df["spatial_1"], mode="markers", marker_color=df["color"])] ) scatter = f.data[0] scatter.marker.opacity = 0.5 f.layout.plot_bgcolor = "rgb(255,255,255)" f.layout.autosize = True axis_dict = dict( showticklabels=False, autorange=True, ) f.layout.yaxis = axis_dict f.layout.xaxis = axis_dict # Create a table FigureWidget that updates on selection from points in the scatter plot of f t = go.FigureWidget( [ go.Table( header=dict( values=["group_ID", "labels", "spatial_0", "spatial_1"], fill=dict(color="#C2D4FF"), align=["left"] * 5, ), cells=dict( values=[df[col] for col in ["group_ID", "labels", "spatial_0", "spatial_1"]], fill=dict(color="#F5F8FF"), align=["left"] * 5, ), ) ] ) def selection_fn(trace, points, selector): t.data[0].cells.values = [ df.loc[points.point_inds][col] for col in ["group_ID", "labels", "spatial_0", "spatial_1"] ] Lasso.__sub_index = t.data[0].cells.values[0] Lasso.sub_adata = self.adata[ Lasso.__sub_index, ] scatter.on_selection(selection_fn) # Put everything together return VBox((f, t)) else: raise ValueError(f"adata.obsm doesn't have {group} or {group_color} is not in adata.uns")