"""Source code for spateo.plotting.static.networks: plotly-based network graph plotting."""

from typing import Dict, List, Optional, Tuple, Union

import networkx
import numpy as np
import plotly.figure_factory as ff
import plotly.graph_objects as go
from plotly import callbacks

from ...logging import logger_manager as lm


class PlotNetwork:
    """Build plotly traces and an interactive figure for a networkx graph.

    Node coordinates live on the graph itself as the ``"pos"`` node attribute. They are either computed
    by one of the supported layouts (and written back onto ``G``) or taken from positions the caller
    already stored there. Every trace generator reads positions from that attribute.
    """

    # Supported layout names -> networkx function returning a {node: position} dict.
    _LAYOUT_FUNCTIONS = {
        "random": networkx.random_layout,
        "circular": networkx.circular_layout,
        "kamada": networkx.kamada_kawai_layout,
        "planar": networkx.planar_layout,
        "spring": networkx.spring_layout,
        "spectral": networkx.spectral_layout,
        "spiral": networkx.spiral_layout,
    }

    # Line styles assigned, in first-seen order, to distinct values of the linestyle edge attribute.
    _EDGE_STYLES = (
        dict(color="#888", dash="solid"),
        dict(color="#555", dash="dash"),
        dict(color="#222", dash="dot"),
        dict(color="#000", dash="dashdot"),
    )

    # Style for edges with no linestyle value, or a value beyond the styled ones.
    _DEFAULT_EDGE_STYLE = {"color": "#888", "dash": "solid"}

    def __init__(self, G: Union[networkx.Graph, networkx.DiGraph], layout: str):
        """Sets up and configures nodes and edges to plot a network graph.

        Args:
            G: Networkx graph object. Computed positions are stored on it as the "pos" node attribute.
            layout: Controls shape of the plot. Options:
                - random: Position nodes uniformly at random in the unit square. For every node, a position
                    is generated by choosing each of dim coordinates uniformly at random on the interval
                    [0.0, 1.0).
                - circular: Position nodes on a circle
                - kamada: Position nodes using Kamada-Kawai path-length cost-function
                - planar: Position nodes without edges intersecting (only if possible)
                - spring: Position nodes using Fruchterman-Reingold force-directed algorithm
                - spectral: Position nodes using eigenvectors of the graph Laplacian
                - spiral: Position nodes in a spiral layout
                If None or unrecognized, existing "pos" attributes are used when every node has one;
                otherwise the spring layout is applied.
        """
        self.G = G
        self.layout = layout
        self.logger = lm.get_main_logger()
        self.pos_dict = self._resolve_positions(G, layout)
        # Node traces are emitted in G.nodes() order, so a hovered point index maps directly to a node
        # (and a node maps back to its point index).
        self.node_order = list(G.nodes())
        self.node_index = {node: i for i, node in enumerate(self.node_order)}
        # Kept for backward compatibility: position tuple -> node.
        self.inverse_pos_dict = {(v[0], v[1]): k for k, v in self.pos_dict.items()}

    def _resolve_positions(self, G, layout) -> dict:
        """Return node positions: a named layout if valid, else complete existing positions, else spring."""
        if layout in self._LAYOUT_FUNCTIONS:
            return self._apply_layout(G, layout)
        if layout:
            self.logger.info(f"Invalid layout '{layout}' specified; ignoring it.")

        existing = networkx.get_node_attributes(G, "pos")
        # Partial positions would make trace generation fail on the nodes without "pos".
        if existing and len(existing) == G.number_of_nodes():
            self.logger.info("Layout information already present in graph.")
            return existing

        self.logger.info("No valid layout specified, defaulting to spring layout.")
        return self._apply_layout(G, "spring")

    def generate_node_traces(
        self,
        colorscale: str,
        colorbar_title: str,
        color_method: Union[str, List[str]],
        node_label: str,
        node_text: List[str],
        node_label_size: int,
        node_label_position: str,
        node_opacity: float,
        size_method: Union[str, List[str]],
        show_colorbar: bool = True,
    ) -> go.Scatter:
        """Formatting for nodes.

        Args:
            colorscale: Colormap to use for nodes
            colorbar_title: Title for the colorbar
            color_method: Either label of node property, "degree", a single color applied to every node, or
                list containing the color of each node
            node_label: Node property to be used as label
            node_text: List containing properties to be displayed when hovering over nodes
            node_label_size: Font size of node text
            node_label_position: Position of node labels. Options: 'top left', 'top center', 'top right',
                'middle left', 'middle center', 'middle right', 'bottom left', 'bottom center', 'bottom right'
            node_opacity: Transparency of nodes
            size_method: Either label of node property, "degree" (degree + 12), "static" (28), or list
                containing the size of each node
            show_colorbar: Set True to include colorbar, False to remove from plotting window

        Returns:
            node_trace: Plotly graph objects scatter plot, with one point per node in G.nodes() order
        """
        xs, ys, labels, hover_texts, sizes, colors = [], [], [], [], [], []

        # Collect per-node values in plain lists; growing plotly properties one element at a time is quadratic.
        for node in self.G.nodes():
            attrs = self.G.nodes[node]
            degree = self.G.degree(node)

            x, y = attrs["pos"]
            xs.append(x)
            ys.append(y)

            if node_label:
                labels.append(attrs[node_label])

            # Hover text default: name, degree, then any requested properties
            text = f"Node: {node}<br>Degree: {degree}"
            for prop in node_text or ():
                text += f"<br>{prop}: {attrs[prop]}"
            hover_texts.append(text.strip())

            if not isinstance(size_method, list):
                if size_method == "degree":
                    sizes.append(degree + 12)
                elif size_method == "static":
                    sizes.append(28)
                else:
                    sizes.append(attrs[size_method])

            if not isinstance(color_method, list):
                if color_method == "degree":
                    colors.append(degree)
                elif color_method in attrs:
                    colors.append(attrs[color_method])
                else:
                    # Not a node property: treat it as a literal color for this node.
                    colors.append(color_method)

        if isinstance(size_method, list):
            sizes = size_method
        if isinstance(color_method, list):
            colors = color_method

        return go.Scatter(
            x=xs,
            y=ys,
            mode="markers+text" if node_label else "markers",
            text=labels,
            hovertext=hover_texts,
            hoverinfo="text",
            textposition=node_label_position,
            textfont=dict(size=node_label_size, color="black"),
            showlegend=False,
            marker=dict(
                showscale=show_colorbar,
                colorscale=colorscale,
                reversescale=True,
                color=colors,
                size=sizes,
                colorbar=dict(
                    thickness=15,
                    title=dict(text=colorbar_title, side="right"),
                    xanchor="left",
                ),
                line_width=0,
                opacity=node_opacity,
            ),
        )

    def generate_edge_traces(
        self,
        edge_label: str,
        edge_label_size: int,
        edge_label_position: str,
        edge_text: List[str],
        edge_attribute_for_linestyle: Optional[str] = None,
        edge_attribute_for_thickness: Optional[str] = None,
        add_text: bool = False,
    ) -> Tuple[List[go.Scatter], go.Scatter]:
        """Formatting for edges.

        Args:
            edge_label: Edge property to be used as label
            edge_label_size: Font size of edge text
            edge_label_position: Position of edge labels. Options: 'top left', 'top center', 'top right',
                'middle left', 'middle center', 'middle right', 'bottom left', 'bottom center', 'bottom right'
            edge_text: List containing properties to be displayed when hovering over edges
            edge_attribute_for_linestyle: Optional edge property to use for linestyle. If not given, will default
                to property given to 'edge_label'. Up to four distinct values get distinct styles.
            edge_attribute_for_thickness: Optional edge property to use for thickness. If not given, all edges
                will have thickness 1.
            add_text: If True, will add text corresponding to edge_label onto the edges rather than only adding
                different line styles. Ignored when 'edge_label' is None.

        Returns:
            edge_traces: Plotly graph objects scatter plots, one per edge
            middle_node_trace: Labels are created by adding invisible nodes to the middle of each edge. This
                trace contains one point per distinct (source, target) pair.
        """
        if edge_attribute_for_linestyle is None:
            edge_attribute_for_linestyle = edge_label
        show_text = add_text and edge_label is not None

        # Distinct truthy linestyle values in first-seen order (deterministic, unlike iterating a set).
        unique_values = list(
            dict.fromkeys(
                data.get(edge_attribute_for_linestyle)
                for _, _, data in self.G.edges(data=True)
                if data.get(edge_attribute_for_linestyle)
            )
        )
        if len(unique_values) > len(self._EDGE_STYLES):
            self.logger.info("More than four unique labels detected. Using the first four.")
            unique_values = unique_values[: len(self._EDGE_STYLES)]
        # zip tolerates fewer than four values (including none at all).
        styles = dict(zip(unique_values, self._EDGE_STYLES))

        # Lines connecting nodes
        edge_traces = []
        created_styles = set()
        # (source, target) -> {hover property: [values from each parallel edge]}
        edge_properties = {}
        # (source, target) -> edge_label values drawn as text at the midpoint
        edge_labels = {}
        mid_x, mid_y = [], []

        for u, v, data in self.G.edges(data=True):
            x0, y0 = self.G.nodes[u]["pos"]
            x1, y1 = self.G.nodes[v]["pos"]

            thickness_value = data.get(edge_attribute_for_thickness) if edge_attribute_for_thickness is not None else None
            thickness = (thickness_value * 2) ** 2 if thickness_value else 1

            linestyle_value = data.get(edge_attribute_for_linestyle) if edge_attribute_for_linestyle is not None else None
            style = styles.get(linestyle_value, self._DEFAULT_EDGE_STYLE) if linestyle_value else self._DEFAULT_EDGE_STYLE
            style_key = (style["color"], style["dash"])

            edge_traces.append(
                go.Scatter(
                    x=(x0, x1, None),
                    y=(y0, y1, None),
                    line=dict(width=thickness, color=style["color"], dash=style["dash"]),
                    hoverinfo="text",
                    mode="lines",
                    name=str(data.get(edge_attribute_for_linestyle, "Unknown Linestyle")),
                    # Only the first edge of each style contributes a legend entry
                    showlegend=style_key not in created_styles,
                )
            )
            created_styles.add(style_key)

            if edge_text or edge_label:
                edge_pair = (u, v)
                # One invisible midpoint per distinct pair keeps midpoints aligned with hover/label entries.
                if edge_pair not in edge_properties:
                    edge_properties[edge_pair] = {}
                    edge_labels[edge_pair] = []
                    mid_x.append((x0 + x1) / 2)
                    mid_y.append((y0 + y1) / 2)
                for prop in edge_text or ():
                    edge_properties[edge_pair].setdefault(prop, []).append(data[prop])
                if show_text:
                    edge_labels[edge_pair].append(data[edge_label])

        mid_text = [", ".join(map(str, values)) for values in edge_labels.values()] if show_text else []
        hover_text = None
        if edge_text:
            # Plotly renders "<br>" (not "\n") as a line break in hover labels.
            hover_text = [
                "<br>".join(f"{prop}: {', '.join(map(str, values))}" for prop, values in props.items())
                for props in edge_properties.values()
            ]

        middle_node_trace = go.Scatter(
            x=mid_x,
            y=mid_y,
            text=mid_text,
            hovertext=hover_text,
            mode="markers+text" if mid_text else "markers",
            hoverinfo="text",
            textposition=edge_label_position,
            textfont=dict(size=edge_label_size, color="black"),
            marker=dict(opacity=0),
            showlegend=False,
        )
        return edge_traces, middle_node_trace

    def _arrow_traces(self, arrow_size: float) -> List[go.Scatter]:
        """Quiver traces with one arrow per directed edge; zero-length edges (e.g. self-loops) are skipped."""
        segments = []
        for u, v in self.G.edges():
            start = np.asarray(self.G.nodes[u]["pos"], dtype=float)
            direction = np.asarray(self.G.nodes[v]["pos"], dtype=float) - start
            length = np.linalg.norm(direction)
            # A zero-length direction cannot be normalized and would produce NaN arrows.
            if length > 0:
                segments.append((start, direction, length))
        if not segments:
            return []

        median_length = np.median([length for _, _, length in segments])
        x_vals, y_vals, u_vals, v_vals = [], [], [], []
        for start, direction, length in segments:
            # Place the quiver's base halfway along shorter-than-median edges and 90% along longer ones.
            scale_factor = 0.5 if length <= median_length else 0.9
            base = start + scale_factor * direction
            unit = direction / length
            x_vals.append(base[0])
            y_vals.append(base[1])
            u_vals.append(unit[0])
            v_vals.append(unit[1])

        quiver = ff.create_quiver(
            x_vals, y_vals, u_vals, v_vals, scale=0.1, arrow_scale=arrow_size, line=dict(width=2), showlegend=False
        )
        return list(quiver.data)

    def generate_figure(
        self,
        node_trace: go.Scatter,
        edge_traces: List[go.Scatter],
        middle_node_trace: go.Scatter,
        title: str,
        title_font_size: int,
        arrow_size: float,
        transparent_background: bool,
        highlight_neighbors_on_hover: bool,
        upper_margin: float = 40,
        lower_margin: float = 20,
        left_margin: float = 50,
        right_margin: float = 50,
    ) -> go.FigureWidget:
        """Generate figure for graph.

        Trace order in the figure is: node trace, edge-midpoint trace, edge traces, then (for directed graphs)
        arrow traces. When 'highlight_neighbors_on_hover' is set, hover callbacks are attached to the node trace.

        Returns:
            The plotly FigureWidget, also stored as ``self.f``.
        """
        annotations = []
        data = [node_trace, middle_node_trace, *edge_traces]

        if isinstance(self.G, networkx.DiGraph):
            # Text-less, arrow-less annotations anchored along each edge (no visible content).
            annotations.extend(
                dict(
                    ax=self.G.nodes[edge[0]]["pos"][0],
                    ay=self.G.nodes[edge[0]]["pos"][1],
                    axref="x",
                    ayref="y",
                    x=self.G.nodes[edge[1]]["pos"][0] * 0.85 + self.G.nodes[edge[0]]["pos"][0] * 0.15,
                    y=self.G.nodes[edge[1]]["pos"][1] * 0.85 + self.G.nodes[edge[0]]["pos"][1] * 0.15,
                    xref="x",
                    yref="y",
                    showarrow=False,
                    text="",
                )
                for edge in self.G.edges()
            )
            data.extend(self._arrow_traces(arrow_size))

        self.f = go.FigureWidget(
            data=data,
            layout=go.Layout(
                title=dict(text=title, font=dict(size=title_font_size)),
                showlegend=True,
                hovermode="closest",
                margin=dict(b=lower_margin, l=left_margin, r=right_margin, t=upper_margin),
                annotations=annotations,
                xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                plot_bgcolor="rgba(0,0,0,0)" if transparent_background else "#fff",
                autosize=True,
            ),
        )
        if transparent_background:
            self.f.update_layout(paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(0,0,0,0)")

        if highlight_neighbors_on_hover:
            self.original_node_trace = node_trace
            # The node trace is data[0]; data[1] is the invisible edge-midpoint trace.
            node_widget_trace = self.f.data[0]
            node_widget_trace.on_hover(self.on_hover)
            node_widget_trace.on_unhover(self.on_unhover)
        return self.f

    def _apply_layout(self, G, layout):
        """Compute positions with the named layout, store them on G as the "pos" node attribute, and return them.

        Raises:
            KeyError: If 'layout' is not one of the supported layout names.
        """
        pos_dict = self._LAYOUT_FUNCTIONS[layout](G)
        networkx.set_node_attributes(G, pos_dict, "pos")
        return pos_dict

    def on_hover(self, trace: go.Scatter, points: callbacks.Points, state: callbacks.InputDeviceState):
        """Callback function for when a node is hovered over: grey out every node except it and its neighbors.

        Args:
            trace: Figure trace for the nodes
            points: Points that are hovered over
            state: Input device state (unused)
        """
        if not points.point_inds:
            return
        hovered_index = points.point_inds[0]
        # Point indices follow G.nodes() order, so no float-coordinate lookup is needed.
        node = self.node_order[hovered_index]

        original_colors = list(trace.marker.color)
        new_colors = ["#E4E4E4"] * len(original_colors)
        keep = [hovered_index, *(self.node_index[neighbor] for neighbor in self.G.neighbors(node))]
        for idx in keep:
            new_colors[idx] = original_colors[idx]

        with self.f.batch_update():
            trace.marker.color = new_colors

    def on_unhover(self, trace: go.Scatter, points: callbacks.Points, state: callbacks.InputDeviceState):
        """Callback function for when a node is unhovered: restore the original node colors and sizes.

        Args:
            trace: Figure trace for the nodes
            points: Points that are hovered over (unused)
            state: Input device state (unused)
        """
        with self.f.batch_update():
            trace.marker.color = self.original_node_trace.marker.color
            trace.marker.size = self.original_node_trace.marker.size
def plot_network(
    G: Union[networkx.Graph, networkx.DiGraph],
    title: str,
    size_method: Union[str, List[float]],
    color_method: Union[str, List[str]],
    layout: Optional[str] = None,
    node_label: Optional[str] = None,
    node_label_position: str = "top center",
    node_text: Optional[List[str]] = None,
    nodefont_size: int = 8,
    edge_label: Optional[str] = None,
    edge_thickness_attr: Optional[str] = None,
    edge_label_position: str = "middle center",
    edge_text: Optional[List[str]] = None,
    edgefont_size: int = 8,
    titlefont_size: int = 16,
    show_colorbar: bool = True,
    colorscale: str = "YlGnBu",
    colorbar_title: Optional[str] = None,
    node_opacity: float = 0.8,
    arrow_size: float = 2,
    transparent_background: bool = True,
    highlight_neighbors_on_hover: bool = True,
    upper_margin: float = 40,
    lower_margin: float = 20,
    left_margin: float = 50,
    right_margin: float = 50,
    add_edge_text: bool = False,
) -> go.FigureWidget:
    """Network graph using plotly, used to plot intercellular GRN as inferred by Spateo.

    Args:
        G: Networkx graph object
        title: Title of the plot
        size_method: Either label of node property, "degree", "static", or list containing the size of each node
        color_method: Either label of node property, "degree", a single color, or list containing the color of
            each node
        layout: Controls shape of the plot. Options:
            - random: Position nodes uniformly at random in the unit square. For every node, a position is
                generated by choosing each of dim coordinates uniformly at random on the interval [0.0, 1.0).
            - circular: Position nodes on a circle
            - kamada: Position nodes using Kamada-Kawai path-length cost-function
            - planar: Position nodes without edges intersecting (only if possible)
            - spring: Position nodes using Fruchterman-Reingold force-directed algorithm
            - spectral: Position nodes using eigenvectors of the graph Laplacian
            - spiral: Position nodes in a spiral layout
            Default None: use the graph's existing "pos" node attributes if present, otherwise fall back to
            the spring layout.
        node_label: Node property to be used as label
        node_label_position: Position of node labels. Options: 'top left', 'top center', 'top right',
            'middle left', 'middle center', 'middle right', 'bottom left', 'bottom center', 'bottom right'
        node_text: List containing properties to be displayed when hovering over nodes
        nodefont_size: Size of 'node_label'
        edge_label: Edge property to be used as label; also determines edge line styles
        edge_thickness_attr: Edge property to be used for determining edge thickness
        edge_label_position: Position of edge labels. Options: 'top left', 'top center', 'top right',
            'middle left', 'middle center', 'middle right', 'bottom left', 'bottom center', 'bottom right'
        edge_text: List containing properties to be displayed when hovering over edges
        edgefont_size: Size of 'edge_label'
        titlefont_size: Size of title
        show_colorbar: Set True to display colorbar
        colorscale: Colormap used for the colorbar. Options: 'Greys', 'YlGnBu', 'Greens', 'YlOrRd', 'Bluered',
            'RdBu', 'Reds', 'Blues', 'Picnic', 'Rainbow', 'Portland', 'Jet', 'Hot', 'Blackbody', 'Earth',
            'Electric', 'Viridis'
        colorbar_title: Colorbar title
        node_opacity: Node transparency, from 0 to 1, where 0 is completely transparent
        arrow_size: Size of the arrow for directed graphs, by default 2
        transparent_background: Set True for transparent background
        highlight_neighbors_on_hover: Set True to highlight neighbors of a node (by name) when hovering over it
        upper_margin: Margin between top of the plot and top of the figure
        lower_margin: Margin between bottom of the plot and bottom of the figure
        left_margin: Margin between left of the plot and left of the figure
        right_margin: Margin between right of the plot and right of the figure
        add_edge_text: If True (and 'edge_label' is given), draw the 'edge_label' value as text at the middle of
            each edge in addition to using it for line styles.

    Returns:
        fig: Plotly figure widget object
    """
    plot = PlotNetwork(G, layout)

    # Below the title, always include explanation of line thickness:
    if edge_thickness_attr is not None:
        title = title + f"<br>Line thickness: {edge_thickness_attr}, L:R line thickness always = 1"

    node_trace = plot.generate_node_traces(
        colorscale=colorscale,
        colorbar_title=colorbar_title,
        color_method=color_method,
        node_label=node_label,
        node_text=node_text,
        node_label_size=nodefont_size,
        node_label_position=node_label_position,
        node_opacity=node_opacity,
        size_method=size_method,
        show_colorbar=show_colorbar,
    )
    edge_traces, middle_node_trace = plot.generate_edge_traces(
        edge_label=edge_label,
        edge_label_size=edgefont_size,
        edge_label_position=edge_label_position,
        edge_text=edge_text,
        edge_attribute_for_linestyle=edge_label,
        edge_attribute_for_thickness=edge_thickness_attr,
        add_text=add_edge_text,
    )
    return plot.generate_figure(
        node_trace=node_trace,
        edge_traces=edge_traces,
        middle_node_trace=middle_node_trace,
        title=title,
        title_font_size=titlefont_size,
        arrow_size=arrow_size,
        transparent_background=transparent_background,
        highlight_neighbors_on_hover=highlight_neighbors_on_hover,
        upper_margin=upper_margin,
        lower_margin=lower_margin,
        left_margin=left_margin,
        right_margin=right_margin,
    )