from typing import Dict, List, Optional, Tuple, Union
import networkx
import numpy as np
import plotly.figure_factory as ff
import plotly.graph_objects as go
from plotly import callbacks
from ...logging import logger_manager as lm
class PlotNetwork:
    """Builds the Plotly traces used to draw a NetworkX graph as a network plot.

    Node coordinates are read from, and (when a layout is applied) written to, the ``"pos"`` node
    attribute of the graph.

    NOTE(review): :meth:`on_hover` and :meth:`on_unhover` read ``self.f`` and
    ``self.original_node_trace``, which are not assigned anywhere in the code visible here; they are
    presumably set by the figure-building step before the callbacks are registered -- confirm.
    """

    # Style used for edges whose linestyle attribute is missing or has no dedicated style.
    _DEFAULT_EDGE_STYLE = {"color": "#888", "dash": "solid"}
    # At most four distinguishable line styles, handed out in order of first appearance.
    _EDGE_STYLES = (
        {"color": "#888", "dash": "solid"},
        {"color": "#555", "dash": "dash"},
        {"color": "#222", "dash": "dot"},
        {"color": "#000", "dash": "dashdot"},
    )

    def __init__(self, G: Union[networkx.Graph, networkx.DiGraph], layout: Optional[str]):
        """Sets up and configures nodes and edges to plot a network graph.

        Args:
            G: Networkx graph object
            layout: Controls shape of the plot. If empty/None, positions already stored in the "pos" node
                attribute of the graph are used; if there are none, the spring layout is applied. An
                unrecognized name also falls back to the spring layout. Options:
                    - random: Position nodes uniformly at random in the unit square.
                        For every node, a position is generated by choosing each of dim coordinates uniformly
                        at random on the interval [0.0, 1.0).
                    - circular: Position nodes on a circle
                    - kamada: Position nodes using Kamada-Kawai path-length cost-function
                    - planar: Position nodes without edges intersecting (only if possible)
                    - spring: Position nodes using Fruchterman-Reingold force-directed algorithm
                    - spectral: Position nodes using eigenvectors of the graph Laplacian
                    - spiral: Position nodes in a spiral layout
        """
        self.G = G
        self.layout = layout
        self.logger = lm.get_main_logger()

        if layout:
            self.pos_dict = self._apply_layout(G, layout)
        else:
            existing_positions = networkx.get_node_attributes(G, "pos")
            if existing_positions:
                self.logger.info("Layout information already present in graph.")
                self.pos_dict = existing_positions
            else:
                self.logger.info("No layout specified, defaulting to spring layout.")
                self.pos_dict = self._apply_layout(G, "spring")

        # Reverse lookup (x, y) -> node, used by the hover callback to identify the hovered node.
        self.inverse_pos_dict = {(v[0], v[1]): k for k, v in self.pos_dict.items()}

    def generate_node_traces(
        self,
        colorscale: str,
        colorbar_title: str,
        color_method: Union[str, List[str]],
        node_label: str,
        node_text: List[str],
        node_label_size: int,
        node_label_position: str,
        node_opacity: float,
        size_method: Union[str, List[str]],
        show_colorbar: bool = True,
    ) -> go.Scatter:
        """Formatting for nodes.

        Args:
            colorscale: Colormap to use for nodes
            colorbar_title: Title for the colorbar
            color_method: Either "degree", the label of a node property, a single color value applied to
                every node (used when it is not a property of the node), or a list containing the color of
                each node
            node_label: Node property to be used as label
            node_text: List containing properties to be displayed when hovering over nodes
            node_label_size: Font size of node text
            node_label_position: Position of node labels. Options: 'top left', 'top center', 'top right',
                'middle left', 'middle center', 'middle right', 'bottom left', 'bottom center', 'bottom right'
            node_opacity: Transparency of nodes
            size_method: Either "degree" (degree + 12), "static" (28), the label of a node property, or a
                list containing the size of each node
            show_colorbar: Set True to include colorbar, False to remove from plotting window

        Returns:
            node_trace: Plotly graph objects scatter plot

        Raises:
            KeyError: If a node lacks the "pos" attribute, or a property named by `node_label`,
                `node_text` or `size_method`.
        """
        size_is_list = isinstance(size_method, list)
        color_is_list = isinstance(color_method, list)

        # Collect everything in plain lists and build the trace once; growing the trace's tuples node
        # by node re-validates and copies them on every iteration.
        xs, ys, labels, hover_texts = [], [], [], []
        per_node_sizes, per_node_colors = [], []

        for node, attrs in self.G.nodes(data=True):
            degree = self.G.degree(node)
            x, y = attrs["pos"]
            xs.append(x)
            ys.append(y)

            if node_label:
                labels.append(attrs[node_label])

            # Hover text default: name, degree
            hover = f"Node: {node}<br>Degree: {degree}"
            if node_text:
                hover += "".join(f"<br></br>{prop}: {attrs[prop]}" for prop in node_text)
            hover_texts.append(hover.strip())

            if not size_is_list:
                if size_method == "degree":
                    per_node_sizes.append(degree + 12)
                elif size_method == "static":
                    per_node_sizes.append(28)
                else:
                    per_node_sizes.append(attrs[size_method])

            if not color_is_list:
                if color_method == "degree":
                    per_node_colors.append(degree)
                else:
                    # A value that is not a node property is taken to be a literal color.
                    per_node_colors.append(attrs[color_method] if color_method in attrs else color_method)

        # NOTE(review): `titleside` may be a legacy alias of `title.side` in newer Plotly releases --
        # confirm against the Plotly version this project pins.
        return go.Scatter(
            x=xs,
            y=ys,
            mode="markers+text" if node_label else "markers",
            text=labels,
            hovertext=hover_texts,
            hoverinfo="text",
            textposition=node_label_position,
            textfont=dict(size=node_label_size, color="black"),
            showlegend=False,
            marker=dict(
                showscale=show_colorbar,
                colorscale=colorscale,
                reversescale=True,
                color=color_method if color_is_list else per_node_colors,
                size=size_method if size_is_list else per_node_sizes,
                colorbar=dict(
                    thickness=15,
                    title=colorbar_title,
                    xanchor="left",
                    titleside="right",
                ),
                line_width=0,
                opacity=node_opacity,
            ),
        )

    def generate_edge_traces(
        self,
        edge_label: str,
        edge_label_size: int,
        edge_label_position: str,
        edge_text: List[str],
        edge_attribute_for_linestyle: Optional[str] = None,
        edge_attribute_for_thickness: Optional[str] = None,
        add_text: bool = False,
    ) -> Tuple[List[go.Scatter], go.Scatter]:
        """Formatting for edges

        Args:
            edge_label: Edge property to be used as label
            edge_label_size: Font size of edge text
            edge_label_position: Position of edge labels. Options: 'top left', 'top center', 'top right',
                'middle left', 'middle center', 'middle right', 'bottom left', 'bottom center', 'bottom right'
            edge_text: List containing properties to be displayed when hovering over edges
            edge_attribute_for_linestyle: Optional edge property to use for linestyle. If not given, will default to
                property given to 'edge_label'. Up to four distinct values receive their own line style (in
                order of first appearance); all other edges use the default solid style.
            edge_attribute_for_thickness: Optional edge property to use for thickness; an edge with a truthy
                value v is drawn with width (2 * v) ** 2. If not given, all edges will have thickness 1.
            add_text: If True, will add text corresponding to edge_label onto the edges rather than only
                adding different line styles. Ignored when no 'edge_label' is given.

        Returns:
            edge_traces: Plotly graph objects scatter plots, one per edge
            middle_node_trace: Labels are created by adding invisible nodes to the middle of each edge. This
                trace contains information for these invisible nodes.

        Raises:
            KeyError: If a node lacks the "pos" attribute, or an edge lacks a property named in `edge_text`
                (or `edge_label` when `add_text` is set).
        """
        if edge_attribute_for_linestyle is None:
            edge_attribute_for_linestyle = edge_label

        # Distinct truthy linestyle values in order of first appearance, so that the value -> style
        # assignment is reproducible between runs (a set would order them arbitrarily).
        unique_values = list(
            dict.fromkeys(
                data[edge_attribute_for_linestyle]
                for _, _, data in self.G.edges(data=True)
                if data.get(edge_attribute_for_linestyle)
            )
        )
        # Can accommodate up to four line styles
        if len(unique_values) > len(self._EDGE_STYLES):
            self.logger.info("More than four unique labels detected. Using the first four.")
            unique_values = unique_values[: len(self._EDGE_STYLES)]
        # zip() stops at the shorter input, so graphs with fewer than four values (or none) are fine.
        styles = dict(zip(unique_values, self._EDGE_STYLES))

        # One trace per edge -- these are the lines connecting nodes.
        edge_traces = []
        created_styles = set()
        # Hover properties per (source, target) pair: {pair: {property: [values...]}}.
        edge_properties = {}
        mid_x, mid_y, mid_text = [], [], []
        mid_mode = "markers"

        for source, target, data in self.G.edges(data=True):
            x0, y0 = self.G.nodes[source]["pos"]
            x1, y1 = self.G.nodes[target]["pos"]

            thickness = 1
            if edge_attribute_for_thickness is not None:
                thickness_value = data.get(edge_attribute_for_thickness)
                if thickness_value:
                    thickness = (thickness_value * 2) ** 2

            style = self._DEFAULT_EDGE_STYLE
            if edge_attribute_for_linestyle is not None:
                linestyle_value = data.get(edge_attribute_for_linestyle)
                if linestyle_value:
                    style = styles.get(linestyle_value, self._DEFAULT_EDGE_STYLE)

            # Only the first edge drawn with a given style contributes a legend entry.
            style_key = (style["color"], style["dash"])
            edge_traces.append(
                go.Scatter(
                    x=(x0, x1, None),
                    y=(y0, y1, None),
                    line=dict(width=thickness, color=style["color"], dash=style["dash"]),
                    hoverinfo="text",
                    mode="lines",
                    name=data.get(edge_attribute_for_linestyle, "Unknown Linestyle"),
                    showlegend=style_key not in created_styles,
                )
            )
            created_styles.add(style_key)

            if not (edge_text or edge_label):
                continue

            edge_pair = (source, target)
            if edge_pair not in edge_properties:
                # First edge between this pair: place the invisible label node at its midpoint.
                edge_properties[edge_pair] = {}
                mid_x.append((x0 + x1) / 2)
                mid_y.append((y0 + y1) / 2)

            if edge_text:
                pair_properties = edge_properties[edge_pair]
                for prop in edge_text:
                    # Keyed by property name so values of parallel edges accumulate.
                    pair_properties.setdefault(prop, []).append(data[prop])

            if add_text and edge_label:
                mid_text.append(data[edge_label])
                mid_mode = "markers+text"

        middle_node_trace = go.Scatter(
            x=mid_x,
            y=mid_y,
            text=mid_text,
            mode=mid_mode,
            hoverinfo="text",
            textposition=edge_label_position,
            textfont=dict(size=edge_label_size, color="black"),
            marker=dict(opacity=0),
            showlegend=False,
        )
        if edge_text:
            middle_node_trace["hovertext"] = [
                "\n".join(f"{k}: {v}" for k, v in vals.items()) for vals in edge_properties.values()
            ]

        return edge_traces, middle_node_trace

    def _apply_layout(self, G, layout):
        """Computes node positions with the named layout and stores them on the graph.

        Args:
            G: Networkx graph object; its nodes receive a "pos" attribute
            layout: Name of the layout ("random", "circular", "kamada", "planar", "spring", "spectral" or
                "spiral"). Any other name falls back to the spring layout.

        Returns:
            pos_dict: Dictionary mapping each node to its computed position
        """
        layout_functions = {
            "random": networkx.random_layout,
            "circular": networkx.circular_layout,
            "kamada": networkx.kamada_kawai_layout,
            "planar": networkx.planar_layout,
            "spring": networkx.spring_layout,
            "spectral": networkx.spectral_layout,
            "spiral": networkx.spiral_layout,
        }
        layout_function = layout_functions.get(layout)
        if layout_function is None:
            self.logger.info("Invalid layout specified, defaulting to spring layout.")
            layout_function = networkx.spring_layout

        pos_dict = layout_function(G)
        networkx.set_node_attributes(G, pos_dict, "pos")
        return pos_dict

    def on_hover(self, trace: go.Scatter, points: callbacks.Points, state: callbacks.InputDeviceState):
        """Callback function for when a node is hovered over

        Greys out every node except the hovered node and its neighbours.

        Args:
            trace: Figure trace for the node
            points: Points that are hovered over
            state: Input device state supplied with the callback (unused)
        """
        if not points.point_inds:
            return

        hovered_index = points.point_inds[0]
        hovered_node = self.inverse_pos_dict[(points.xs[0], points.ys[0])]

        current_colors = list(trace.marker.color)
        new_colors = ["#E4E4E4"] * len(current_colors)
        new_colors[hovered_index] = current_colors[hovered_index]

        # Map node -> position once instead of a linear search per neighbour.
        position_of = {node: position for position, node in enumerate(self.pos_dict)}
        for neighbour in self.G.neighbors(hovered_node):
            trace_position = position_of[neighbour]
            new_colors[trace_position] = current_colors[trace_position]

        with self.f.batch_update():
            trace.marker.color = new_colors

    def on_unhover(self, trace: go.Scatter, points: callbacks.Points, state: callbacks.InputDeviceState):
        """Callback function for when a node is unhovered over.

        Restores the marker colors and sizes stored on `self.original_node_trace`.

        Args:
            trace: Figure trace for the node
            points: Points that are hovered over
            state: Input device state supplied with the callback (unused)
        """
        original_marker = self.original_node_trace.marker
        with self.f.batch_update():
            trace.marker.color = original_marker.color
            trace.marker.size = original_marker.size
def plot_network(
    G: Union[networkx.Graph, networkx.DiGraph],
    title: str,
    size_method: Union[str, List[float]],
    color_method: Union[str, List[str]],
    layout: Optional[str] = None,
    node_label: Optional[str] = None,
    node_label_position: str = "top center",
    node_text: Optional[List[str]] = None,
    nodefont_size: int = 8,
    edge_label: Optional[str] = None,
    edge_thickness_attr: Optional[str] = None,
    edge_label_position: str = "middle center",
    edge_text: Optional[List[str]] = None,
    edgefont_size: int = 8,
    titlefont_size: int = 16,
    show_colorbar: bool = True,
    colorscale: str = "YlGnBu",
    colorbar_title: Optional[str] = None,
    node_opacity: float = 0.8,
    arrow_size: float = 2,
    transparent_background: bool = True,
    highlight_neighbors_on_hover: bool = True,
    upper_margin: float = 40,
    lower_margin: float = 20,
    left_margin: float = 50,
    right_margin: float = 50,
) -> go.FigureWidget:
    """Network graph using plotly, used to plot intercellular GRN as inferred by Spateo.

    Args:
        G: Networkx graph object
        title: Title of the plot
        size_method: Either label of node property or list containing the size of each node
        color_method: Either label of node property or list containing the color of each node
        layout: Controls shape of the plot. If not given, the layout handling is delegated to
            :class:`PlotNetwork` (see its constructor for what happens without a layout). Options:
                - random: Position nodes uniformly at random in the unit square.
                    For every node, a position is generated by choosing each of dim coordinates uniformly at
                    random on the interval [0.0, 1.0).
                - circular: Position nodes on a circle
                - kamada: Position nodes using Kamada-Kawai path-length cost-function
                - planar: Position nodes without edges intersecting (only if possible)
                - spring: Position nodes using Fruchterman-Reingold force-directed algorithm
                - spectral: Position nodes using eigenvectors of the graph Laplacian
                - spiral: Position nodes in a spiral layout
        node_label: Node property to be used as label
        node_label_position: Position of node labels. Options: 'top left', 'top center', 'top right', 'middle left',
            'middle center', 'middle right', 'bottom left', 'bottom center', 'bottom right'
        node_text: List containing properties to be displayed when hovering over nodes
        nodefont_size: Size of 'node_label'
        edge_label: Edge property to be used as label; also used as the property that selects the edge
            line style
        edge_thickness_attr: Edge property to be used for determining edge thickness. When given, a note
            explaining the line thickness is appended to the title.
        edge_label_position: Position of edge labels. Options: 'top left', 'top center', 'top right', 'middle left',
            'middle center', 'middle right', 'bottom left', 'bottom center', 'bottom right'
        edge_text: List containing properties to be displayed when hovering over edges
        edgefont_size: Size of 'edge_label'
        titlefont_size: Size of title
        show_colorbar: Set True to display colorbar
        colorscale: Colormap used for the colorbar. Options: 'Greys', 'YlGnBu', 'Greens', 'YlOrRd', 'Bluered', 'RdBu',
            'Reds', 'Blues', 'Picnic', 'Rainbow', 'Portland', 'Jet', 'Hot', 'Blackbody', 'Earth', 'Electric', 'Viridis'
        colorbar_title: Colorbar title
        node_opacity: Node transparency, from 0 to 1, where 0 is completely transparent
        arrow_size: Size of the arrow for directed graphs, by default 2
        transparent_background: Set True for transparent background
        highlight_neighbors_on_hover: Set True to highlight neighbors of a node (by name) when hovering over it
        upper_margin: Margin between top of the plot and top of the figure
        lower_margin: Margin between bottom of the plot and bottom of the figure
        left_margin: Margin between left of the plot and left of the figure
        right_margin: Margin between right of the plot and right of the figure

    Returns:
        fig: Plotly figure widget object
    """
    plot = PlotNetwork(G, layout)

    # When edge thickness encodes a property, explain it below the title.
    if edge_thickness_attr is not None:
        title += f"<br>Line thickness: {edge_thickness_attr}, L:R line thickness always = 1"

    node_trace = plot.generate_node_traces(
        colorscale=colorscale,
        colorbar_title=colorbar_title,
        color_method=color_method,
        node_label=node_label,
        node_text=node_text,
        node_label_size=nodefont_size,
        node_label_position=node_label_position,
        node_opacity=node_opacity,
        size_method=size_method,
        show_colorbar=show_colorbar,
    )

    edge_traces, middle_node_trace = plot.generate_edge_traces(
        edge_label=edge_label,
        edge_label_size=edgefont_size,
        edge_label_position=edge_label_position,
        edge_text=edge_text,
        edge_attribute_for_linestyle=edge_label,
        edge_attribute_for_thickness=edge_thickness_attr,
    )

    # NOTE(review): `generate_figure` is expected to be a PlotNetwork method defined outside the code
    # visible in this review -- confirm it exists on the class.
    return plot.generate_figure(
        node_trace=node_trace,
        edge_traces=edge_traces,
        middle_node_trace=middle_node_trace,
        title=title,
        title_font_size=titlefont_size,
        arrow_size=arrow_size,
        transparent_background=transparent_background,
        highlight_neighbors_on_hover=highlight_neighbors_on_hover,
        upper_margin=upper_margin,
        lower_margin=lower_margin,
        left_margin=left_margin,
        right_margin=right_margin,
    )