Source code for spateo.tdr.models.models_individual.point_clouds
fromtypingimportOptional,Tuple,UnionimportnumpyasnpimportpyvistaaspvfromanndataimportAnnDatafrompandas.core.frameimportDataFramefrompyvistaimportPolyDatatry:fromtypingimportLiteralexceptImportError:fromtyping_extensionsimportLiteralfrom..utilitiesimportadd_model_labels################################ Construct point cloud model ################################
def _in_keys(item, keys: set) -> bool:
    """Return whether ``item`` is in ``keys``; unhashable items (e.g. lists) are never members."""
    try:
        return item in keys
    except TypeError:
        return False


def construct_pc(
    adata: AnnData,
    layer: str = "X",
    spatial_key: str = "spatial",
    groupby: Optional[Union[str, tuple, list]] = None,
    key_added: str = "groups",
    mask: Union[str, int, float, list] = None,
    colormap: Union[str, list, dict] = "rainbow",
    alphamap: Union[float, list, dict] = 1.0,
) -> Tuple[PolyData, Optional[str]]:
    """
    Construct a point cloud model from the coordinates stored in an AnnData object.

    Args:
        adata: AnnData object. It is only read, never modified.
        layer: Expression matrix used when ``groupby`` refers to genes. If ``'X'``, uses ``.X``,
            otherwise uses ``.layers[layer]``.
        spatial_key: The key in ``.obsm`` holding the coordinates of each bucket.
        groupby: What to label the points with. One of:

            * ``None``: every point gets the label ``'same'``;
            * a key of ``.obs``: the values of that column are used as labels;
            * a gene name in ``.var_names``: the expression of that gene is used;
            * a list/tuple of gene names in ``.var_names``: the per-observation sum of their
              expression is used.

            ``.obs`` keys take precedence over gene names.
        key_added: The key under which the labels are added (forwarded to ``add_model_labels``).
        mask: Value, or list of values, of ``.obs[groupby]`` to be relabelled as ``'mask'``. Only applied
            when ``groupby`` is a key of ``.obs``.
        colormap: Colors to use for plotting pc (forwarded to ``add_model_labels``).
        alphamap: Opacity of the colors to use for plotting pc (forwarded to ``add_model_labels``).

    Returns:
        pc: The point cloud. ``pc.point_data['obs_index']`` holds the ``adata.obs_names`` entry of each
            point. The label arrays are written by ``add_model_labels``, presumably under
            ``pc.point_data[key_added]`` and ``pc.point_data[f'{key_added}_rgba']`` (confirm against that
            helper).
        plot_cmap: Second value returned by ``add_model_labels``; per the original documentation, the
            recommended colormap parameter value for plotting.

    Raises:
        ValueError: If ``groupby`` is neither ``None``, a key of ``.obs``, a gene name, nor a collection of
            gene names.
    """
    # Create an initial pc. ``astype`` returns a new object, so the mesh never shares memory
    # with ``adata.obsm`` and no defensive copy of the whole AnnData is required.
    bucket_xyz = adata.obsm[spatial_key].astype(np.float64)
    if isinstance(bucket_xyz, DataFrame):
        bucket_xyz = bucket_xyz.values
    pc = pv.PolyData(bucket_xyz)

    mask_list = mask if isinstance(mask, list) else [mask]
    obs_keys = set(adata.obs_keys())
    gene_names = set(adata.var_names.tolist())

    # Resolve ``groupby`` into one label per observation.
    genes = None
    if groupby is None:
        groups = np.asarray(["same"] * adata.obs.shape[0], dtype=str)
    elif _in_keys(groupby, obs_keys):
        groups = np.asarray(adata.obs[groupby].map(lambda x: "mask" if x in mask_list else x).values)
    elif _in_keys(groupby, gene_names):
        genes = groupby
    elif not isinstance(groupby, str):
        # A collection of gene names. Lists are unhashable, so they cannot be looked up in the
        # sets above and must be handled through this subset test.
        try:
            candidates = list(groupby)
            if set(candidates) <= gene_names:
                genes = candidates
        except TypeError:
            # Not iterable, or contains unhashable items: reported as an invalid ``groupby`` below.
            genes = None

    if groupby is not None and genes is None and not _in_keys(groupby, obs_keys):
        raise ValueError(
            "`groupby` value is wrong."
            "\n`groupby` can be a string and one of adata.obs_names or adata.var_names."
            "\n`groupby` can also be a list and is a subset of adata.var_names."
        )

    if genes is not None:
        # Read the requested matrix from the gene-sliced view instead of overwriting ``adata.X``.
        gene_view = adata[:, genes]
        expression = gene_view.X if layer == "X" else gene_view.layers[layer]
        # Sparse matrices return a 2-D ``np.matrix`` from ``sum``; ``ravel`` on the plain array
        # guarantees a 1-D vector with one entry per observation.
        groups = np.asarray(expression.sum(axis=1)).ravel()

    _, plot_cmap = add_model_labels(
        model=pc,
        labels=groups,
        key_added=key_added,
        where="point_data",
        colormap=colormap,
        alphamap=alphamap,
        inplace=True,
    )

    # The obs_index of each coordinate in the original adata.
    pc.point_data["obs_index"] = np.array(adata.obs_names.tolist())

    return pc, plot_cmap