import collections
import inspect
import textwrap
from typing import Any, Optional, OrderedDict
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import numpy.typing as npt
import pydot
from matplotlib.axes import Axes
from matplotlib.colors import rgb2hex
from pyvis.network import Network
from inheritance_explorer.similarity import PycodeSimilarity
class _ChildNode:
def __init__(
self,
child: Any,
child_id: int,
parent: Optional[Any] = None,
parent_id: Optional[int] = None,
color: Optional[str] = "#000000",
):
self.child = child
self.child_name: str = str(child.__name__)
self._child_id = child_id
self.parent = parent
self._parent_id = parent_id
self.parent_name = None
if parent:
self.parent_name = parent.__name__
self.color = color
self._extra_info = "comment string"
@property
def child_id(self) -> str:
return str(self._child_id)
@property
def parent_id(self) -> str | None:
if self._parent_id:
return str(self._parent_id)
return None
_similarity_container_types = PycodeSimilarity
[docs]
class ClassGraphTree:
"""
A hierarchical class graph container.
Parameters
----------
baseclass
the starting base class to begin mapping from
funcname: str
(optional) the name of a function to watch for overrides
default_color: str
(optional) the default outline color of nodes, in any graphviz string
func_override_color: str
(optional) the outline color of nodes that override funcname, in any
graphviz string
max_recursion_level: int
(optional) the max number of recrusion levels to map to. A value of 0
will show only the immediate children of the provided class. Default is
set to 500.
classes_to_exclude : List[str]
(optional) a list of class names to exclude from the mapping.
"""
def __init__(
self,
baseclass: Any,
funcname: Optional[str] = None,
default_color: str = "#000000",
func_override_color: str = "#ff0000",
similarity_cutoff: float = 0.75,
max_recursion_level: int = 500,
classes_to_exclude: Optional[list[str]] = None,
):
self.baseclass = baseclass
self.basename: str = baseclass.__name__
self.funcname = funcname
self._tracking_function = self.funcname is not None
self.max_recursion_level = max_recursion_level
self._nodenum: int = 0
self._node_list: list[_ChildNode] = [] # a list of unique ChildNodes
self._node_map: dict[int, str] = {} # map of global node index to node name
self._override_src: OrderedDict[int, str] = collections.OrderedDict()
self._override_src_files: dict[int, str] = {}
self._current_node = 1 # the current global node, must start at 1
self._default_color = default_color
self._override_color = func_override_color
self._graphviz_args_kwargs: dict[str, Any] = {}
self.similarity_container: _similarity_container_types | None = None
self.similarity_results: dict[str, npt.NDArray[Any]]
self.similarity_cutoff = similarity_cutoff
if classes_to_exclude is None:
classes_to_exclude = []
self.classes_to_exclude = classes_to_exclude
self._build()
self._node_map_r: dict[str, int] = {
v: k for k, v in self._node_map.items()
} # name to index
def _get_source_info(self, obj) -> Optional[str]:
if self.funcname is None:
raise RuntimeError("this functionality requires function tracking.")
fname: str = self.funcname
f = getattr(obj, fname)
if isinstance(f, collections.abc.Callable): # type: ignore[arg-type]
return f"{inspect.getsourcefile(f)}:{inspect.getsourcelines(f)[1]}"
return None
def _node_overrides_func(self, child, parent) -> bool:
childsrc = self._get_source_info(child)
parentsrc = self._get_source_info(parent)
if childsrc != parentsrc:
return True # it overrides!
return False
def _get_new_node_color(self, child, parent) -> str:
if self.funcname and self._node_overrides_func(child, parent):
return self._override_color
return self._default_color
def _get_baseclass_color(self) -> str:
color = self._default_color
if self.funcname:
f = getattr(self.baseclass, self.funcname)
class_where_its_defined = f.__qualname__.split(".")[0]
if self.basename == class_where_its_defined:
# then its defined here, use the override color
color = self._override_color
return color
[docs]
def check_subclasses(
self, parent, parent_id: int, node_i: int, current_recursion_level: int
) -> int:
if current_recursion_level <= self.max_recursion_level:
for child in parent.__subclasses__():
if child.__name__ not in self.classes_to_exclude:
color = self._get_new_node_color(child, parent)
new_node = _ChildNode(
child, node_i, parent=parent, parent_id=parent_id, color=color
)
self._node_list.append(new_node)
self._node_map[node_i] = new_node.child_name
if self.funcname and self._node_overrides_func(child, parent):
self._store_node_func_source(child, node_i)
node_i += 1
node_i = self.check_subclasses(
child, node_i - 1, node_i, current_recursion_level + 1
)
return node_i
def _store_node_func_source(self, clss, current_node: int):
# store the source code of funcname for the current class and node
# clss: a class
# current_node: the
if self.funcname is None:
raise RuntimeError("this functionality requires function tracking.")
fname: str = self.funcname
f = getattr(clss, fname)
if isinstance(f, collections.abc.Callable): # type: ignore[arg-type]
src = textwrap.dedent(inspect.getsource(f))
self._override_src_files[current_node] = (
f"{inspect.getsourcefile(f)}:{inspect.getsourcelines(f)[1]}"
)
self._override_src[current_node] = src
[docs]
def check_source_similarity(
self,
similarity_container_class: str = "PycodeSimilarity",
method: str = "reference",
reference: Optional[int] = None,
):
# compares all the source code of the child methods that have
# over-ridden funcname
if reference is None:
ref = 1 # use whatever the basenode is
else:
ref = reference
if similarity_container_class == "PycodeSimilarity":
SimClass = PycodeSimilarity
else:
raise ValueError(f"unexpected value, {similarity_container_class=}")
self.similarity_container = SimClass(method=method)
sim = self.similarity_container.run(self._override_src, reference=ref)
return sim
def _build(self) -> None:
# construct the first node
color = self._get_baseclass_color()
self._node_list.append(
_ChildNode(self.baseclass, self._current_node, parent=None, color=color)
)
self._node_map[self._current_node] = self._node_list[-1].child_name
if self.funcname:
self._store_node_func_source(self.baseclass, self._current_node)
# now check all the children
self._current_node += 1
_ = self.check_subclasses(
self.baseclass, self._current_node - 1, self._current_node, 0
)
# construct the full similarity matrix
s_c = PycodeSimilarity(method="permute")
_, sim_matrix, sim_axis = s_c.run(self._override_src)
assert isinstance(sim_matrix, np.ndarray)
sim_axis_array = np.array(sim_axis)
sim_axis_names = np.array([c.child_name for c in self._node_list])
self.similarity_results = {
"matrix": sim_matrix,
"axis": sim_axis_array,
"axis_names": sim_axis_names,
}
cutoff_sim = self.similarity_cutoff
similarity_sets = {} # a dict that points to other similar nodes
M = sim_matrix
for irow in range(M.shape[0]):
rowvals = M[irow, :]
indxs = np.where(rowvals >= cutoff_sim)[0]
indxs = indxs[indxs != irow] # these are matrix indeces
node_ids = sim_axis_array[indxs]
if len(node_ids) > 0:
this_child = sim_axis_array[irow]
similarity_sets[this_child] = set(node_ids.tolist())
self.similarity_sets = similarity_sets
def _build_graph(
self, *args, include_similarity: bool = True, **kwargs
) -> pydot.Dot:
"""
build a digraph from the current node list
Parameters
----------
*args:
any arg accepted by pydot.Dot
include_similarity: bool
include edges for similar code (default True)
kwargs:
any additional keyword arguments are passed to graphviz.Digraph
"""
gtype = "digraph"
if "graph_type" in kwargs:
gtype = kwargs.pop("graph_type")
dot = pydot.Dot("test_graph", *args, graph_type=gtype, **kwargs)
iset = 0
Nsets = len(self.similarity_sets)
for node in self._node_list:
new_node = pydot.Node(
node.child_id, label=node.child_name, color=node.color
)
dot.add_node(new_node)
if node.parent:
dot.add_edge(pydot.Edge(node.child_id, node.parent_id))
if include_similarity:
if int(node.child_id) in self.similarity_sets:
R = (iset + 1.0) / Nsets * 0.5 + 0.5
G = 0.5
B = 0.5
hexcolor = rgb2hex((R, G, B))
iset += 1
for similar_node_id in self.similarity_sets[int(node.child_id)]:
new_edge = pydot.Edge(
node.child_id, str(similar_node_id), color=hexcolor
)
dot.add_edge(new_edge)
return dot
_graph = None
# @property
[docs]
def graph(self, *args, include_similarity: bool = True, **kwargs) -> pydot.Dot:
"""a GraphViz dot graph of the class hierarchy using pydot"""
# if self._graph is None:
self._graph = self._build_graph(
*args, include_similarity=include_similarity, **kwargs
)
return self._graph
[docs]
def set_graphviz_args_kwargs(self, *args, **kwargs):
self._graphviz_args_kwargs = {"args": args, "kwargs": kwargs}
[docs]
def show_graph(self, *args, env: str = "notebook", format: str = "png", **kwargs):
"""display a static GraphViz graph"""
return _show_graph(self.graph(*args, **kwargs), env=env, format=format)
[docs]
def plot_similarity(
self,
above_cutoff: Optional[bool] = False,
ax: Optional[Axes] = None,
colorbar: Optional[bool] = True,
**kwargs,
) -> tuple[dict[int, str], Axes]:
"""
add the similarity plot to a matplotlib axis (or create a new one)
Parameters
----------
above_cutoff: bool
if True (default False), plots where similarity > cutoff
ax: Axes
matplotlib axis to add the plot to. A new axis handle will be
created and returned if this is not set.
colorbar: bool
adds a colorbar to ax if True (default)
kwargs
any keyword argument accepted by plt.imshow()
Returns
-------
(sim_labels, ax)
sim_labels: dictionary mapping the matrix indices to label
ax: the modified (or new) axis handle
"""
if ax is None:
_, ax = plt.subplots(1)
if above_cutoff:
M = self.similarity_results["matrix"] > self.similarity_cutoff
else:
M = self.similarity_results["matrix"]
if "cmap" not in kwargs:
if above_cutoff:
kwargs["cmap"] = "gray"
else:
kwargs["cmap"] = "magma"
im = ax.imshow(M, **kwargs)
_ = ax.set_xticks(range(M.shape[0]))
_ = ax.set_yticks(range(M.shape[0]))
if colorbar:
plt.colorbar(im, ax=ax)
sim_labels = [
self._node_list[cid - 1].child_name for cid in self._override_src.keys()
]
sim_labels_dict = {lid: label for lid, label in enumerate(sim_labels)}
return sim_labels_dict, ax
[docs]
def build_interactive_graph(
self,
include_similarity: bool = True,
node_style: dict[str, Any] | None = None,
edge_style: dict[str, Any] | None = None,
similarity_edge_style: dict[str, Any] | None = None,
override_node_color: str | tuple[float, ...] | None = None,
**kwargs,
) -> Network:
"""
build an interactive Network graph from the current node list
Parameters
----------
include_similarity: bool
include edges for similar code (default True)
node_style: dict
a dictionary of parameters to pass to pyvis.network.Network.add_node
note that these settings will be applied to **all** nodes.
edge_style: dict
a dictionary of parameters to pass to pyvis.network.Network.add_edge
note that these settings will be applied to **all** edges.
similarity_edge_style: dict
a dictionary of parameters to pass to pyvis.network.Network.add_edge
for the similarity links. Only used if include_similarity is True.
override_node_color: str or tuple
the color for nodes that over-ride the function being tracked. Only
used if the base ClassGraphTree was initialized with a ``funcname``
to track.
kwargs:
any additional keyword arguments are passed to pyvis.Network
Returns
-------
Network
the pyvis.Network representation of the class hierarchy.
"""
if node_style is None:
node_style = {}
if edge_style is None:
edge_style = {}
if similarity_edge_style is None:
similarity_edge_style = {}
sim_node_physics = similarity_edge_style.pop("physics", False)
edge_physics = edge_style.pop("physics", True)
node_color = _validate_color(node_style.get("color", None), (0.7, 0.7, 0.7))
edge_color = _validate_color(edge_style.pop("color", None), (0.7, 0.7, 0.7))
sim_edge_color = _validate_color(
similarity_edge_style.pop("color", None), (0, 0.5, 1.0)
)
override_color = _validate_color(override_node_color, (0.5, 0.5, 1.0))
bgcolor = _validate_color(kwargs.pop("bgcolor", None), (1.0, 1.0, 1.0))
font_color = _validate_color(kwargs.pop("font_color", None), (0.0, 0.0, 0.0))
grph = nx.Graph(directed=True)
iset = 0
for node in self._node_list:
if node.color == "#000000":
clr_val = node_color
else:
# this node is over-ridden, use over-ride color
clr_val = override_color
node_style["color"] = clr_val
if node.parent:
parent_info = f"({node.parent.__name__})"
else:
parent_info = ""
grph.add_node(
node.child_id,
title=f"{node.child_name}{parent_info}",
**node_style,
)
if include_similarity:
if int(node.child_id) in self.similarity_sets:
iset += 1
for similar_node_id in self.similarity_sets[int(node.child_id)]:
grph.add_edge(
node.child_id,
str(similar_node_id),
color=sim_edge_color,
physics=sim_node_physics,
**similarity_edge_style,
)
arrowsop = {"from": {"enabled": True}}
if node.parent:
grph.add_edge(
node.child_id,
node.parent_id,
color=edge_color,
physics=edge_physics,
arrows=arrowsop,
**edge_style,
)
# return the interactive pyvis Network graph
network_wrapper = Network(
notebook=True, bgcolor=bgcolor, font_color=font_color, **kwargs
)
network_wrapper.from_nx(grph)
return network_wrapper
[docs]
def get_source_code(self, node: int | str) -> str:
"""
retrieve the source code of the comparison function for a
specified node
Parameters
----------
node: int
the node to fetch the source code for
Returns
-------
str
a string containing the source code for the node.
"""
node_id: int
if not isinstance(node, int) and not isinstance(node, str):
raise TypeError("Unexpected type for node")
if isinstance(node, int) and node in self._node_map:
node_id = node
elif isinstance(node, str) and node in self._node_map_r:
node_id = self._node_map_r[node]
else:
raise ValueError(f"Could not find node for {node}")
if node_id in self._override_src:
return self._override_src[node_id]
else:
raise ValueError(f"node {node} does not override the chosen function.")
[docs]
def get_multiple_source_code(
self, node_1: int | str, *args
) -> dict[int | str, str]:
"""
Retrieve the source code for multiple nodes
Parameters
----------
node_1: Union[str, int]
the first node to fetch
*args: Union[str, int]
any additional nodes to fetch
Returns
-------
dict
dictionary where the node is the key and the source code the value.
"""
src_dict = {}
src_dict[node_1] = self.get_source_code(node_1)
for src_key in args:
src_dict[src_key] = self.get_source_code(src_key)
return src_dict
[docs]
def display_code_comparison(self, include_overrides_only: bool = True):
"""
show the code comparison widget
Parameters
----------
include_overrides_only: bool
if True (default), only displays the classes that override the function
being compared.
"""
# add a check that we are running from a notebook?
if self.funcname is not None:
from inheritance_explorer._widget_support import display_code_compare
display_code_compare(self, include_overrides_only=include_overrides_only)
def _validate_color(clr, default_rgb_tuple: tuple[float, float, float]) -> str:
if clr is None:
return str(rgb2hex(default_rgb_tuple))
elif isinstance(clr, tuple):
return str(rgb2hex(clr))
elif isinstance(clr, str):
return clr
msg = f"clr has unexpected type: {type(clr)}"
raise TypeError(msg)
def _show_graph(dot_graph: pydot.Dot, format: str = "svg", env: str = "notebook"):
# return a GraphViz dot graph in a jupyter-friendly format.
create_func = getattr(dot_graph, f"create_{format}")
graph = create_func()
if env == "notebook":
if format == "svg":
from IPython.core.display import SVG
return SVG(graph)
else:
from IPython.core.display import Image
return Image(graph, unconfined=True)
return graph