import sys
import numpy as np
import time
import operator
import warnings
import networkx as nx
from .MergeTree import MergeTree
from .TopologicalObject import TopologicalObject
[docs]class ContourTree(TopologicalObject):
""" A class for computing a contour tree from two merge trees
Parameters
----------
graph : nglpy.Graph
A graph object used for determining neighborhoods in gradient estimation
gradient : str
An optional string specifying the type of gradient estimator to use.
Currently the only available option is 'steepest'.
normalization : str
An optional string specifying whether the inputs/output should be
scaled before computing. Currently, two modes are supported 'zscore'
and 'feature'. 'zscore' will ensure the data has a mean of zero and a
standard deviation of 1 by subtracting the mean and dividing by the
standard deviation. 'feature' scales the data into the unit hypercube.
aggregator : str
An optional string that specifies what type of aggregation to do when
duplicates are found in the domain space. Default value is None meaning
the code will error if duplicates are identified.
debug : bool
An optional boolean flag for whether debugging output should be enabled.
short_circuit : bool
An optional boolean flag for whether the contour tree should be short
circuited. Enabling this will speed up the processing by bypassing the
fully augmented search and only focusing on partially augmented split
and join trees
"""
def __init__(
    self,
    graph=None,
    gradient="steepest",
    normalization=None,
    aggregator=None,
    debug=False,
    short_circuit=True,
):
    """ Initializes the parent object and stores the short circuit flag

    Every argument except ``short_circuit`` is forwarded by keyword,
    unchanged, to the parent class initializer.
    """
    # Options owned by the parent class; this class only adds
    # short_circuit on top of them.
    inherited_options = dict(
        graph=graph,
        gradient=gradient,
        normalization=normalization,
        aggregator=aggregator,
        debug=debug,
    )
    super().__init__(**inherited_options)
    self.short_circuit = short_circuit
def reset(self):
    """ Clears every container this object fills while building

    The parent class is reset first, then the contour tree specific
    storage is replaced with fresh, empty containers.

    Returns
    -------
    None
    """
    super().reset()

    # List-valued storage
    self.edges, self.sortedNodes = [], []
    self.superNodes, self.superArcs = [], []

    # Keyed / unordered storage
    self.augmentedEdges = {}
    self.branches = set()
def build(self, X, Y, w=None):
    """ Stores the data on this object and constructs the Contour Tree

    A join tree and a split tree are built from this object and then
    merged into the contour tree. ``w`` is handed to the parent class
    ``build``; nothing else in this method reads it.

    Parameters
    ----------
    X : np.ndarray
        An m-by-n array of values specifying m n-dimensional samples
    Y : np.array
        An m vector of values specifying the output responses
        corresponding to the m samples specified by X
    w : np.array
        An optional m vector of values specifying the weights associated
        to each of the m samples used. Default of None means all points
        will be equally weighted

    Returns
    -------
    None
    """
    super().build(X, Y, w)

    # The two merge trees that get combined into the contour tree
    join_merge_tree = MergeTree(debug=self.debug)
    split_merge_tree = MergeTree(debug=self.debug)
    join_merge_tree._build_for_contour_tree(self, True)
    split_merge_tree._build_for_contour_tree(self, False)

    # Entries of the split tree win when both trees share a key
    combined = dict(join_merge_tree.augmentedEdges)
    combined.update(dict(split_merge_tree.augmentedEdges))
    self.augmentedEdges = combined

    # When short circuiting, each tree is only augmented with the nodes
    # of the opposite tree; otherwise each is fully augmented.
    if self.short_circuit:
        join_partner, split_partner = split_merge_tree, join_merge_tree
    else:
        join_partner, split_partner = None, None
    join_graph = self._construct_nx_tree(join_merge_tree, join_partner)
    split_graph = self._construct_nx_tree(split_merge_tree, split_partner)

    self._process_tree(join_graph, split_graph)
    self._process_tree(split_graph, join_graph)

    # The contour tree now lives in self.edges; what follows only
    # prepares convenience structures for later queries.
    self._identifyBranches()
    self._identifySuperGraph()

    if self.debug:
        sys.stdout.write("Sorting Nodes: ")
        start = time.perf_counter()

    # (index, value) pairs ordered by ascending value
    ranked = list(enumerate(self.Y))
    ranked.sort(key=operator.itemgetter(1))
    self.sortedNodes = ranked

    if self.debug:
        end = time.perf_counter()
        sys.stdout.write("%f s\n" % (end - start))
def _identifyBranches(self):
""" A helper function for determining all of the branches in the
tree. This should be called after the tree has been fully
constructed and its nodes and edges are populated.
"""
if self.debug:
sys.stdout.write("Identifying branches: ")
start = time.perf_counter()
seen = set()
self.branches = set()
# Find all of the branching nodes in the tree, degree > 1
# That is, they appear in more than one edge
for e1, e2 in self.edges:
if e1 not in seen:
seen.add(e1)
else:
self.branches.add(e1)
if e2 not in seen:
seen.add(e2)
else:
self.branches.add(e2)
if self.debug:
end = time.perf_counter()
sys.stdout.write("%f s\n" % (end - start))
def _identifySuperGraph(self):
    """ A helper function for determining the condensed
        representation of the tree. That is, one that does not hold
        all of the internal nodes of the graph. The results are
        stored in ContourTree.superNodes and ContourTree.superArcs.

        When short_circuit is set, the graph built from self.edges is
        used as is and self.augmentedEdges is left untouched.
        Otherwise every node with exactly one incoming and one
        outgoing edge is collapsed into the super arc that spans it,
        and self.augmentedEdges is rebuilt so that each key
        (lower, upper) maps to the list of collapsed nodes ordered
        from the lower end of the arc to the upper end.
    """
    if self.debug:
        sys.stdout.write("Condensing Graph: ")
        start = time.perf_counter()

    G = nx.DiGraph()
    G.add_edges_from(self.edges)

    # In short circuit mode there is nothing to condense. Falling
    # through (instead of returning early) keeps the debug timing
    # line complete in both modes.
    if not self.short_circuit:
        self.augmentedEdges = {}
        N = len(self.Y)
        processed = np.zeros(N, dtype=bool)

        for node in range(N):
            # Nodes collapsed by an earlier iteration are gone from G
            # and must be skipped. A sample that never appears in any
            # edge is not in G either, and asking networkx for its
            # degree would raise.
            if processed[node] or node not in G:
                continue

            # Only internal nodes (exactly one edge in, one edge out)
            # are collapsed. From such a node, trace down and up to
            # the first non-internal node in each direction, removing
            # the internal nodes met on the way and recording them.
            if G.in_degree(node) == 1 and G.out_degree(node) == 1:
                # The nodes condensed by this super arc
                removedNodes = []

                # Trace down to a non-internal node
                lower_link = list(G.in_edges(node))[0][0]
                while (
                    G.in_degree(lower_link) == 1
                    and G.out_degree(lower_link) == 1
                ):
                    new_lower_link = list(G.in_edges(lower_link))[0][0]
                    G.add_edge(new_lower_link, node)
                    G.remove_node(lower_link)
                    removedNodes.append(lower_link)
                    lower_link = new_lower_link

                # The downward trace collected nodes top to bottom;
                # flip it so the list runs from the lower end upward.
                removedNodes.reverse()
                removedNodes.append(node)

                # Trace up to a non-internal node
                upper_link = list(G.out_edges(node))[0][1]
                while (
                    G.in_degree(upper_link) == 1
                    and G.out_degree(upper_link) == 1
                ):
                    new_upper_link = list(G.out_edges(upper_link))[0][1]
                    G.add_edge(node, new_upper_link)
                    G.remove_node(upper_link)
                    removedNodes.append(upper_link)
                    upper_link = new_upper_link

                # Replace the whole chain by one super arc
                G.add_edge(lower_link, upper_link)
                G.remove_node(node)

                self.augmentedEdges[(lower_link, upper_link)] = removedNodes

                # Skip these later: they no longer exist in G
                processed[removedNodes] = True

    self.superNodes = G.nodes()
    self.superArcs = G.edges()

    if self.debug:
        end = time.perf_counter()
        sys.stdout.write("%f s\n" % (end - start))
def get_seeds(self, threshold):
    """ Returns seed points for isosurface extraction given a threshold
    value

    Every super arc whose endpoint values bracket the threshold is
    walked from its lower endpoint upward, through the nodes recorded
    for it in ``augmentedEdges``, until the first node whose value is
    at or above the threshold is reached.

    Parameters
    ----------
    threshold : float
        The isovalue for which we want to identify seed points for
        isosurface extraction

    Returns
    -------
    list of int
        A flat list of indices into the data held by this object. Each
        crossed super arc contributes two consecutive entries: the last
        node visited before the threshold was reached (the lower
        endpoint of the arc if no node lies below the threshold)
        followed by the first node at or above the threshold. Super
        arcs that have no entry in ``augmentedEdges`` under either key
        orientation contribute nothing.
    """
    seeds = []
    for lower, upper in self.superArcs:
        # Super arcs are expected to be stored as (lower, upper); arcs
        # that do not bracket the threshold cannot be crossed by it.
        if not (self.Y[lower] <= threshold <= self.Y[upper]):
            continue

        if (lower, upper) in self.augmentedEdges:
            # Stored from the lower endpoint toward the upper one
            interior = list(self.augmentedEdges[(lower, upper)])
        elif (upper, lower) in self.augmentedEdges:
            # Stored from the upper endpoint toward the lower one, so
            # reverse it to get an ascending walk. The endpoints
            # themselves must NOT be swapped: the walk below always
            # starts at the lower endpoint and ends at the upper one.
            interior = list(reversed(self.augmentedEdges[(upper, lower)]))
        else:
            continue

        below = lower
        # The upper endpoint satisfies Y >= threshold, so this loop
        # always terminates through the break.
        for above in interior + [upper]:
            if self.Y[above] >= threshold:
                # Stop at the first point at or above the threshold
                break
            below = above

        seeds.append(below)
        seeds.append(above)
    return seeds
def _construct_nx_tree(self, thisTree, thatTree=None):
    """ Builds a networkx directed graph from a MergeTree so that the
        graph can be manipulated while the contour tree is assembled.
        @ In, thisTree, a MergeTree instance whose edges and
            augmentedEdges are turned into a networkx graph
        @ In, thatTree, an optional MergeTree instance. When given,
            only the augmented nodes that are also keys of
            thatTree.nodes are inserted (partial augmentation);
            when None, every augmented node is inserted
        @ Out, a networkx.DiGraph instance holding the (partially
            or fully) augmented version of thisTree
    """
    if self.debug:
        sys.stdout.write("Networkx Tree construction: ")
        start = time.perf_counter()

    graph = nx.DiGraph()
    graph.add_edges_from(thisTree.edges)

    # Nodes allowed to be inserted when only partially augmenting
    allowed = thatTree.nodes.keys() if thatTree is not None else []

    for (head, _), interior in thisTree.augmentedEdges.items():
        outgoing = list(graph.out_edges(head))
        if len(outgoing) > 1:
            warnings.warn(
                "The supernode {} should have only a single "
                "emanating edge. Merge tree is invalidly "
                "structured".format(head)
            )
        tail = outgoing[0][1]

        # Swap the direct edge for a chain running through the
        # retained interior nodes, in their stored order.
        graph.remove_edge(head, tail)
        previous = head
        for node in interior:
            if thatTree is not None and node not in allowed:
                continue
            graph.add_edge(previous, node)
            previous = node

        # Make sure this is not the root node trying to connect to
        # itself
        if previous != tail:
            graph.add_edge(previous, tail)

    if self.debug:
        end = time.perf_counter()
        sys.stdout.write("%f s\n" % (end - start))

    return graph
def _process_tree(self, thisTree, thatTree):
    """ Peels the leaves off one merge tree (join or split), appending
        the edge of every peeled leaf to self.edges and suppressing
        the same node in the opposing tree.
        @ In, thisTree, a networkx.DiGraph instance representing the
            merge tree whose leaves are moved into this CT object;
            it is modified in place
        @ In, thatTree, a networkx.DiGraph instance representing the
            opposing merge tree, which is updated in place as nodes
            of thisTree are processed
        @ Out, None
    """
    if self.debug:
        sys.stdout.write("Processing Tree: ")
        start = time.perf_counter()

    def removable_leaves():
        # Leaves of thisTree (no incoming edge) that are not branches
        # (fewer than two incoming edges) in the other tree. Nothing
        # is removable once at most one node is left.
        if len(thisTree.nodes()) <= 1:
            return set()
        return {
            candidate
            for candidate in thisTree.nodes()
            if thisTree.in_degree(candidate) == 0
            and thatTree.in_degree(candidate) < 2
        }

    pending = removable_leaves()
    while len(pending) > 0:
        v = pending.pop()

        # Take the leaf and its edge out of the input tree and place
        # it on the CT
        outgoing = list(thisTree.out_edges(v))
        if len(outgoing) != 1:
            warnings.warn(
                "The node {} should have a single emanating "
                "edge.\n".format(v)
            )
        source, target = outgoing[0][0], outgoing[0][1]

        # Store every CT edge pointing 'up', i.e. from the vertex with
        # the lower value to the vertex with the higher value.
        if self.Y[source] < self.Y[target]:
            arc = (source, target)
        else:
            arc = (target, source)
        self.edges.append(arc)

        # Removing the node also removes its edges from thisTree
        thisTree.remove_node(v)

        # If v still has an outgoing edge in the other tree it may sit
        # between two nodes there; glue them together before v is
        # removed. With no outgoing edge v is simply dropped.
        if thatTree.out_degree(v) != 0:
            incoming_there = list(thatTree.in_edges(v))
            outgoing_there = list(thatTree.out_edges(v))
            if incoming_there and outgoing_there:
                thatTree.add_edge(incoming_there[0][0], outgoing_there[0][1])
        thatTree.remove_node(v)

        # Both trees changed, so the removable leaves are recomputed
        pending = removable_leaves()

    if self.debug:
        end = time.perf_counter()
        sys.stdout.write("%f s\n" % (end - start))