"""
This file implements the LiveWire segmentation algorithm. The code is ported from:
1. https://github.com/pdyban/livewire: the LiveWireSegmentation class and the compute_shortest_path function.
2. https://github.com/Usama3627/live-wire: the _compute_graph and compute_shortest_path functions.
"""
import math
from itertools import cycle
from typing import List, Optional, Tuple
import numpy as np
from ..logging import logger_manager as lm
class LiveWireSegmentation(object):
    """LiveWire ("intelligent scissors") segmentation on a 2-D image.

    Assigning :attr:`image` (in the constructor or later) optionally smooths the image with a bilateral filter,
    computes a Scharr gradient image, optionally binarizes it with Otsu's threshold, and builds an undirected
    4-connected pixel graph whose edge weights are low when stepping along strong edges.
    :meth:`compute_shortest_path` then searches that graph with ``dijkstar``.

    Args:
        image: image to segment. The graph construction only supports 2-D (single-channel) images.
        smooth_image: whether to smooth the image with a bilateral filter before computing gradients.
        threshold_gradient_image: whether to binarize the gradient image with Otsu's threshold before building
            the graph.
    """

    def __init__(
        self,
        image: Optional[np.ndarray] = None,
        smooth_image: bool = False,
        threshold_gradient_image: bool = False,
    ):
        super(LiveWireSegmentation, self).__init__()
        # container for the input image (replaced by its smoothed version when smooth_image is set)
        self._image = None
        # container for the gradient image (optionally thresholded)
        self.edges = None
        # undirected dijkstar.Graph over (row, col) pixels used for shortest-path search
        self.G = None
        # should smooth the original image using bilateral smoothing filter
        self.smooth_image = smooth_image
        # should use the thresholded gradient image for shortest path computation
        self.threshold_gradient_image = threshold_gradient_image
        # Assigning through the property runs the whole preprocessing pipeline (or clears it for None).
        self.image = image

    @property
    def image(self):
        """The current image; assigning it recomputes the gradient image and the graph."""
        return self._image

    @image.setter
    def image(self, value):
        self._image = value
        if self._image is None:
            self.edges = None
            self.G = None
            return
        if self.smooth_image:
            self._smooth_image()
        self._compute_gradient_image()
        if self.threshold_gradient_image:
            self._threshold_gradient_image()
        self._compute_graph()

    def _smooth_image(self):
        """Replace the stored image by ``skimage.restoration.denoise_bilateral`` applied to it."""
        from skimage import restoration

        self._image = restoration.denoise_bilateral(self._image)

    def _compute_gradient_image(self):
        """Store the Scharr edge image (``skimage.filters.scharr``) of the stored image in ``self.edges``."""
        from skimage import filters

        self.edges = filters.scharr(self._image)

    def _threshold_gradient_image(self):
        """Binarize ``self.edges`` with Otsu's threshold; the result is stored as floats (0.0 / 1.0)."""
        from skimage.filters import threshold_otsu

        threshold = threshold_otsu(self.edges)
        self.edges = (self.edges > threshold).astype(float)

    @staticmethod
    def _edge_weights(edges, eps: float = 0.00001) -> Tuple[np.ndarray, np.ndarray]:
        """Return the weights of the rightward and downward pixel-graph edges.

        For every pixel the forward differences ``G_x = center - right`` and ``G_y = center - bottom`` are
        taken; the image is edge-padded, so the last column/row gets a zero difference. Rotating the gradient
        direction by 90 degrees (to align it with the edge) and projecting the gradient magnitude back onto the
        axes gives ``|G_y|`` for a horizontal step and ``|G_x|`` for a vertical step, so a step *along* a strong
        edge is cheap. The weight is the reciprocal ``1 / (|difference| + eps)``.

        Args:
            edges: 2-D gradient image.
            eps: small offset keeping the weights finite where the difference is zero.

        Returns:
            ``(w_right, w_down)`` where ``w_right[i, j]`` weights the edge (i, j) -> (i, j + 1) and has shape
            ``(rows, cols - 1)``, and ``w_down[i, j]`` weights the edge (i, j) -> (i + 1, j) and has shape
            ``(rows - 1, cols)``.

        Raises:
            ValueError: if ``edges`` is not 2-D.
        """
        edges = np.asarray(edges, dtype=float)
        if edges.ndim != 2:
            raise ValueError(f"LiveWire expects a 2-D (single-channel) image, got shape {edges.shape}.")
        padded = np.pad(edges, ((0, 1), (0, 1)), mode="edge")
        g_x = edges - padded[:-1, 1:]  # center - right
        g_y = edges - padded[1:, :-1]  # center - bottom
        # Closed form of |G * cos(theta + pi/2)| and |G * sin(theta + pi/2)| with theta = atan(G_y / G_x).
        # It is also correct where G_x == 0, where the old `theta = 0` fallback swapped the two weights.
        w_right = 1.0 / (np.abs(g_y) + eps)
        w_down = 1.0 / (np.abs(g_x) + eps)
        return w_right[:, :-1], w_down[:-1, :]

    def _compute_graph(self):
        """Build ``self.G``: an undirected 4-connected ``dijkstar.Graph`` whose nodes are (row, col) pixels.

        Every pixel is linked to its right and bottom neighbour using :meth:`_edge_weights`, so pixels on the
        image border are valid path endpoints as well.

        Raises:
            ImportError: if ``dijkstar`` is not installed.
        """
        try:
            from dijkstar import Graph
        except ImportError as err:
            raise ImportError(
                "You need to install the package `dijkstar`." "\nInstall dijkstar via `pip install --upgrade dijkstar`"
            ) from err

        w_right, w_down = self._edge_weights(self.edges)
        n_rows, n_cols = np.shape(self.edges)
        graph = Graph(undirected=True)
        for i in range(n_rows):
            for j in range(n_cols):
                if j + 1 < n_cols:
                    graph.add_edge((i, j), (i, j + 1), float(w_right[i, j]))
                if i + 1 < n_rows:
                    graph.add_edge((i, j), (i + 1, j), float(w_down[i, j]))
        self.G = graph

    def compute_shortest_path(self, startPt, endPt):
        """Return the lowest-cost pixel path between two points of the current image.

        Args:
            startPt: start pixel as an integer (row, col) tuple.
            endPt: end pixel as an integer (row, col) tuple.

        Returns:
            The node list (first field) of the ``dijkstar.find_path`` result: the (row, col) pixels of the path.

        Raises:
            AttributeError: if no image has been set.
            ImportError: if ``dijkstar`` is not installed.
        """
        # Check the precondition first so callers get the documented error even without dijkstar installed.
        if self.image is None:
            raise AttributeError("Load an image first!")
        try:
            from dijkstar import find_path
        except ImportError as err:
            raise ImportError(
                "You need to install the package `dijkstar`." "\nInstall dijkstar via `pip install --upgrade dijkstar`"
            ) from err
        return find_path(self.G, startPt, endPt)[0]
def compute_shortest_path(
    image: np.ndarray,
    startPt: Tuple[int, int],
    endPt: Tuple[int, int],
    *,
    smooth_image: bool = False,
    threshold_gradient_image: bool = False,
) -> List:
    """Inline function for easier computation of shortest_path in an image.

    This function will create a new instance of LiveWireSegmentation class every time it is called, calling for a
    recomputation of the gradient image and the shortest path graph. If you need to compute the shortest path in one
    image more than once, use the class-form initialization instead.

    Args:
        image: 2-D image on which the shortest path should be computed.
        startPt: starting pixel for path computation, as an integer (row, col) tuple.
        endPt: target pixel for path computation, as an integer (row, col) tuple.
        smooth_image: Whether to smooth the image with a bilateral filter first (forwarded to
            LiveWireSegmentation).
        threshold_gradient_image: Whether to binarize the gradient image with Otsu's method (forwarded to
            LiveWireSegmentation).

    Returns:
        path: shortest path as a list of (row, col) tuples, including startPt and endPt.
    """
    lm.main_info("Build LiveWireSegmentation object")
    algorithm = LiveWireSegmentation(
        image, smooth_image=smooth_image, threshold_gradient_image=threshold_gradient_image
    )
    lm.main_info("run compute_shortest_path to identify the shortest path")
    path = algorithm.compute_shortest_path(startPt, endPt)
    lm.main_finish_progress("compute_shortest_path")
    return path
def live_wire(
    image: np.ndarray,
    smooth_image: bool = False,
    threshold_gradient_image: bool = False,
    interactive: bool = True,
) -> List[np.ndarray]:
    """Use LiveWire segmentation algorithm for image segmentation aka intelligent scissors. The general idea of the
    algorithm is to use image information for segmentation and avoid crossing object boundaries. A gradient image
    highlights the boundaries, and Dijkstra's shortest path algorithm computes a path using gradient differences as
    segment costs. Thus the line avoids strong gradients in the gradient image, which corresponds to following object
    boundaries in the original image.

    The image is displayed with matplotlib. The first click on the image sets the start point. When ``interactive`` is
    True, the suggested path from the start point to the cursor is previewed while the mouse moves. A further click
    commits the suggested segment, and its end becomes the new start point. Pressing Escape finishes the current
    segmentation; the next click starts a new one drawn in the next color.

    Args:
        image: 2-D image on which the shortest path should be computed.
        smooth_image: Whether to smooth the original image using bilateral smoothing filter.
        threshold_gradient_image: Whether to use Otsu's method to generate a thresholded gradient image for shortest
            path computation.
        interactive: Whether to preview the path interactively while the mouse moves.

    Returns:
        The committed path segments, one (n, 2) array of (row, col) pixel coordinates per committing click. Preview
        paths are not included. The list is returned once ``plt.show()`` returns (typically when the window is
        closed). Paths can be used to segment a particular spatial domain of interests.
    """
    import matplotlib.pyplot as plt

    algorithm = LiveWireSegmentation(
        image, smooth_image=smooth_image, threshold_gradient_image=threshold_gradient_image
    )

    plt.gray()

    colors = cycle("rgbyc")  # use separate colors for consecutive segmentations
    # State shared by the event handlers below; they rebind it through `nonlocal` (these are locals of
    # live_wire, not module globals).
    start_point = None
    current_color = next(colors)
    current_path = None  # artists returned by plt.plot for the current preview line
    path_list: List[np.ndarray] = []  # committed segments returned to the caller

    def _event_point(event):
        """Return the (row, col) pixel under the cursor, or None when the event is outside the axes."""
        if event.xdata is None or event.ydata is None:
            return None
        return int(event.ydata), int(event.xdata)

    def _remove_preview():
        """Remove the preview line, if any, from the plot."""
        nonlocal current_path
        if current_path is not None:
            current_path.pop(0).remove()
            current_path = None

    def button_pressed(event):
        nonlocal start_point
        point = _event_point(event)
        if point is None:
            return
        if start_point is None:
            start_point = point
            return
        # the line below is calling the segmentation algorithm
        path = np.array(algorithm.compute_shortest_path(start_point, point))
        plt.plot(path[:, 1], path[:, 0], c=current_color)
        start_point = point
        path_list.append(path)
        plt.draw()

    def mouse_moved(event):
        nonlocal current_path
        if start_point is None:
            return
        point = _event_point(event)
        if point is None:
            return
        # the line below is calling the segmentation algorithm
        path = np.array(algorithm.compute_shortest_path(start_point, point))
        _remove_preview()
        current_path = plt.plot(path[:, 1], path[:, 0], c=current_color)
        plt.draw()

    def key_pressed(event):
        nonlocal start_point, current_color
        if event.key == "escape":
            start_point = None
            current_color = next(colors)
            _remove_preview()
            plt.draw()

    plt.connect("button_release_event", button_pressed)
    if interactive:
        plt.connect("motion_notify_event", mouse_moved)
    plt.connect("key_press_event", key_pressed)

    plt.imshow(image)
    plt.autoscale(False)
    plt.title("Livewire")
    plt.show()

    return path_list