import os
import pandas as pd
import re
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import seaborn as sns
import networkx as nx
import graphviz
from adjustText import adjust_text
import copy
import distutils
from distutils import util
import sys
import tfcomb
from tfcomb.utils import check_columns, check_type, check_string, check_value, random_string
from tfcomb.logging import TFcombLogger, InputError
import tobias
# Fix graphviz "'dot' not found" errors: when running inside a conda environment,
# add the environment's bin directory to PATH (only if it is not already there).
# https://stackoverflow.com/a/51267131
if os.path.exists(os.path.join(sys.prefix, 'conda-meta')):
    _conda_bin = os.path.join(sys.prefix, 'bin')
    # PATH may be unset; use .get so module import cannot fail with KeyError
    _current_path = os.environ.get("PATH", "")
    if _conda_bin not in _current_path.split(os.pathsep):
        os.environ["PATH"] = _current_path + os.pathsep + _conda_bin if _current_path else _conda_bin
def bubble(rules_table, yaxis="confidence", size_by="TF1_TF2_support", color_by="lift", figsize=(7,4), save=None):
    """
    Plot bubble plot with TF1-TF2 pairs on the x-axis and a choice of measure on the y-axis, as well as color and size of bubbles.

    Parameters
    ----------
    rules_table : pandas.DataFrame
        Dataframe containing data to plot. The index is used as the x-axis labels (one tick per row).
    yaxis : str, optional
        Column containing yaxis information. Default: "confidence".
    size_by : str, optional
        Column used to scale the size of the bubbles. Default: "TF1_TF2_support".
    color_by : str, optional
        Column used to color the bubbles. Default: "lift".
    figsize : tuple, optional
        Size of the figure. Default: (7,4).
    save : str, optional
        Save the plot to the file given in 'save'. Default: None.

    Returns
    --------
    ax
        The matplotlib axes containing the plot.
    """
    check_columns(rules_table, [yaxis, color_by, size_by])
    check_type(figsize, tuple, "figsize")

    fig, ax = plt.subplots(figsize=figsize)
    with sns.axes_style("whitegrid"):
        ax = sns.scatterplot(
            data=rules_table,
            ax=ax,
            x=rules_table.index,
            y=yaxis,
            hue=color_by,
            size=size_by,
            palette="PuBu",
            edgecolor=".7",
        )

    # Place legend outside of the plot, to the right
    sns.move_legend(ax, "center left", bbox_to_anchor=(1.02, 0.5), borderaxespad=0)

    # Tweak the figure to finalize
    labels = list(rules_table.index)
    ax.set_ylabel(yaxis, fontsize=12)
    ax.set_xlabel("Co-occurring pairs", fontsize=12)
    ax.set_xticks(range(len(labels)))  # explicitly set xticks to prevent matplotlib error
    ax.set_xticklabels(labels, rotation=45, ha="right")
    ax.grid(color="0.9")  # very light grey
    ax.set_axisbelow(True)  # prevent grid from plotting above points
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

    if save is not None:
        fig.savefig(save, dpi=600, bbox_inches="tight")

    return ax
def heatmap(rules_table, columns="TF1", rows="TF2", color_by="cosine", figsize=(7,7), save=None):
    """
    Plot heatmap with TF1 and TF2 on rows and columns respectively. Heatmap colormap is chosen by .color_by.

    Row/column pairs without a value in rules_table are masked and shown in light grey.

    Parameters
    ----------
    rules_table : pandas.DataFrame
        The <CombObj>.rules table calculated by market basket analysis
    columns : str, optional
        The name of the column in rules_table to use as heatmap column names. Default: TF1.
    rows : str, optional
        The name of the column in rules_table to use as heatmap row names. Default: TF2.
    color_by : str, optional
        The name of the column in rules_table to use as heatmap colors. Default: "cosine".
    figsize : tuple
        The size of the output heatmap. Default: (7,7)
    save : str
        Save the plot to the file given in 'save'. Default: None.

    Returns
    --------
    seaborn.matrix.ClusterGrid
        The object returned by sns.clustermap.
    """
    # Test input format
    check_columns(rules_table, [columns, rows, color_by])

    # Create rows x columns table for the heatmap
    pivot_table = rules_table.pivot(index=rows, columns=columns, values=color_by)

    # Mask cells without a value. The mask must be taken *before* filling NaN,
    # otherwise no cell would ever be masked.
    mask = pivot_table.isna().to_numpy()

    # Replace missing values with 0 so that clustering only receives finite values
    pivot_table = pivot_table.fillna(0)

    # Choose cmap based on values of 'color_by' column
    if np.min(rules_table[color_by]) < 0:
        cmap = "bwr"  # divergent colormap
        center = 0
    else:
        cmap = "PuBu"
        center = None

    # Only cluster a dimension if there is more than one row/column to cluster
    row_cluster = pivot_table.shape[0] > 1
    col_cluster = pivot_table.shape[1] > 1

    # Plot heatmap
    h = sns.clustermap(pivot_table,
                       mask=mask,
                       cbar=True,
                       cmap=cmap,
                       center=center,
                       row_cluster=row_cluster,
                       col_cluster=col_cluster,
                       cbar_kws={'label': color_by},
                       xticklabels=True,
                       yticklabels=True,
                       figsize=figsize
                       )

    heatmap_ax = h.ax_heatmap.axes
    heatmap_ax.set_xticklabels(heatmap_ax.get_xticklabels(), rotation=45, ha="right")
    heatmap_ax.set_yticklabels(heatmap_ax.get_yticklabels(), rotation=0)
    heatmap_ax.set_facecolor('lightgrey')  # color of masked (missing) values

    if save is not None:
        plt.savefig(save, dpi=600)

    return h
def scatter(table, x, y,
            x_threshold=None,
            y_threshold=None,
            label=None,
            label_fontsize=9,
            label_color="red",
            title=None,
            save=None,
            **kwargs):
    """
    Plot scatter-plot of x/y values within table. Can also set thresholds and label values within plot.

    Parameters
    -----------
    table : pd.DataFrame
        A table containing the columns given in 'x' and 'y'. The table index is used as point labels.
    x : str
        Name of column in table containing values to map on the x-axis.
    y : str
        Name of column in table containing values to map on the y-axis.
    x_threshold : float, tuple of floats or None, optional
        Gives the option to visualize an x-axis threshold within plot. A single value is treated as a lower
        bound (points with x >= value are selected); a tuple (low, high) selects points with x <= low or x >= high.
        If None, no x-axis threshold is set (all x-values pass). Default: None.
    y_threshold : float, tuple of floats or None, optional
        Same as 'x_threshold' for the y-axis. A point is selected if it passes both the x- and y-criteria. Default: None.
    label : str or list, optional
        Which points to label: None (no labels), "selection" (the points selected by the thresholds; requires
        'x_threshold' and/or 'y_threshold'), "all" (all points) or a list of index values of table. Default: None.
    label_fontsize : float, optional
        Size of labels. Default: 9.
    label_color : str, optional
        Color of labels. Default: 'red'.
    title : str, optional
        Title of plot. Default: None.
    save : str, optional
        Save the plot to the file given in 'save'. Default: None.
    kwargs : arguments
        Any additional arguments are passed to sns.jointplot.

    Returns
    --------
    seaborn.JointGrid
        The grid returned by sns.jointplot.

    Raises
    -------
    InputError
        If 'label' is not one of the supported options, or if label="selection" is given without any threshold.
    """
    check_columns(table, [x, y])

    # Validate 'label' before anything is drawn
    if label is not None and not isinstance(label, list):
        if label not in ("selection", "all"):
            raise InputError("'label' must be None, 'selection', 'all' or a list of index values. Got: {0}".format(label))
        if label == "selection" and x_threshold is None and y_threshold is None:
            raise InputError("label='selection' requires 'x_threshold' and/or 'y_threshold' to be set.")

    # Handle thresholds being either float or tuple
    if x_threshold is not None:
        x_threshold = x_threshold if isinstance(x_threshold, tuple) else (x_threshold,)
        for threshold in x_threshold:
            check_value(threshold, name="x_threshold")
    if y_threshold is not None:
        y_threshold = y_threshold if isinstance(y_threshold, tuple) else (y_threshold,)
        for threshold in y_threshold:
            check_value(threshold, name="y_threshold")

    # Plot all data (infinite values cannot be placed on the axes)
    x_finite = table[x][~np.isinf(table[x].astype(float))]
    y_finite = table[y][~np.isinf(table[y].astype(float))]
    g = sns.jointplot(x=x_finite, y=y_finite, space=0, **kwargs)

    # Plot thresholds
    if x_threshold is not None:
        for threshold in x_threshold:
            g.ax_joint.axvline(threshold, linestyle="--", color="grey")
            g.ax_marg_x.axvline(threshold, linestyle="--", color="grey")
    if y_threshold is not None:
        for threshold in y_threshold:
            g.ax_joint.axhline(threshold, linestyle="--", color="grey")
            g.ax_marg_y.axhline(threshold, linestyle="--", color="grey")

    # Mark selection of points outside of the thresholds in red
    selection = None
    if x_threshold is not None or y_threshold is not None:

        # Bounds are (low, high); a point passes an axis if value <= low or value >= high.
        # A single threshold t becomes (-inf, t), i.e. a lower bound.
        # An unset threshold becomes (-inf, -inf), which every non-NaN value passes.
        x_bounds = (-np.inf, -np.inf) if x_threshold is None else x_threshold
        if len(x_bounds) == 1:
            x_bounds = (-np.inf, x_bounds[0])
        y_bounds = (-np.inf, -np.inf) if y_threshold is None else y_threshold
        if len(y_bounds) == 1:
            y_bounds = (-np.inf, y_bounds[0])

        selection = table[((table[x] <= x_bounds[0]) | (table[x] >= x_bounds[1])) &
                          ((table[y] <= y_bounds[0]) | (table[y] >= y_bounds[1]))]
        n_selected = len(selection)  # including any non-finite values

        # Mark chosen pairs in red (only finite values can be drawn)
        xvals = selection[x]
        xvals_finite = xvals[~np.isinf(xvals.astype(float))]
        yvals = selection[y]
        yvals_finite = yvals[~np.isinf(yvals.astype(float))]
        sns.scatterplot(x=xvals_finite, y=yvals_finite, ax=g.ax_joint, color="red", linewidth=0,
                        label="Selection (n={0})".format(n_selected))

    # Label given indices ('label' was validated above)
    if label is not None:
        if isinstance(label, list):
            to_label = table.loc[label, :]  # raises KeyError if labels are not within table index
        elif label == "selection":
            to_label = selection
        else:  # "all"
            to_label = table
        txts = _add_labels(to_label, x, y, g.ax_joint, color=label_color, label_fontsize=label_fontsize)

        # Adjust positions of labels
        adjust_text(txts,
                    x=table[x].tolist(),
                    y=table[y].tolist(),
                    ax=g.ax_joint,
                    text_from_points=True,
                    arrowprops=dict(arrowstyle='-', color='black', lw=0.5),
                    expand_points=(1.2, 1.2),
                    expand_text=(1.2, 1.2)
                    )

    if title is not None:
        g.ax_marg_x.set_title(title)

    # Save plot to file
    if save is not None:
        plt.savefig(save, dpi=600, bbox_inches="tight")

    return g
# Add labels to ax
def _add_labels(table, x, y, ax, color="black", label_col=None, label_fontsize=9):
    """ Utility to add a text label at the x/y coordinates of each row in table.

    Parameters
    ----------
    table : pandas.DataFrame
        A dataframe containing coordinates and labels to plot.
    x : str
        The name of a column in table containing x coordinates.
    y : str
        The name of a column in table containing y coordinates.
    ax : plt axes
        Axes to plot texts on.
    color : str, optional
        Color of label text. Default: "black".
    label_col : str, optional
        Name of column containing labels to plot. Default: None (label is table index)
    label_fontsize : float, optional
        Size of labels. Default: 9.

    Returns
    --------
    list of matplotlib.text.Text
        The text objects added to ax (in row order), e.g. for passing on to adjust_text.
    """
    # Check if columns are in table
    check_columns(table, [x, y, label_col])  # label_col is not checked if it is None

    txts = []
    for label, row in table.iterrows():
        if label_col is not None:
            label = row[label_col]
        txts.append(ax.text(row[x], row[y], label, fontsize=label_fontsize, color=color))
    return txts
def go_bubble(table, aspect="MF", n_terms=20, threshold=0.05, save=None):
    """
    Plot a bubble-style plot of GO-enrichment results.

    Parameters
    --------------
    table : pandas.DataFrame
        The output of tfcomb.analysis.go_enrichment. Must contain the columns "NS", "name",
        "p_fdr_bh", "p_uncorrected" and "study_count".
    aspect : str, optional
        One of ["MF", "BP", "CC"]. Default: "MF".
    n_terms : int, optional
        Maximum number of terms to show in graph. Default: 20
    threshold : float between 0-1, optional
        The p-value-threshold to show in plot (drawn as a red vertical line at -log(threshold)). Default: 0.05.
    save : str, optional
        Save the plot to the file given in 'save'. Default: None.

    Returns
    ----------
    ax
        The matplotlib axes containing the plot.
    """
    # aspect has to be one of {'BP', 'CC', 'MF'}
    check_string(aspect, ["BP", "CC", "MF"], "aspect")

    # Subset to the chosen aspect. .copy() makes aspect_table independent of 'table', so adding
    # columns below does not trigger pandas' SettingWithCopyWarning.
    aspect_table = table[table["NS"] == aspect].copy()
    aspect_table["-log(p-value)"] = -np.log(aspect_table["p_fdr_bh"])  # natural log, same as the threshold line
    aspect_table["n_genes"] = aspect_table["study_count"]

    # Sort with the most significant terms first
    # NOTE(review): ascending=False also applies to the tie-breaker 'p_uncorrected', so tied terms are
    # ordered by decreasing uncorrected p-value - confirm this is intended.
    aspect_table = aspect_table.sort_values(["-log(p-value)", "p_uncorrected"], ascending=False)
    aspect_table = aspect_table.iloc[:n_terms, :]  # first n rows

    # Plot enriched terms
    # TODO: size of plot depending on number of terms to show
    ax = sns.scatterplot(x="-log(p-value)",
                         y="name",
                         size="n_genes",
                         hue="-log(p-value)",
                         data=aspect_table,
                         )

    ax.set_title(aspect)
    ax.axvline(-np.log(threshold), color="red")
    ax.set_ylabel(aspect)
    ax.legend(bbox_to_anchor=(1.01, 1), borderaxespad=0)
    ax.grid()

    if save is not None:
        plt.savefig(save, dpi=600)

    return ax
def _truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100):
    """
    Create a colormap with only a subset of the original range.
    Source: https://stackoverflow.com/a/18926541

    Parameters
    ----------
    cmap : matplotlib.colors.Colormap
        The colormap to truncate.
    minval, maxval : float, optional
        Fractions (0-1) of the original colormap range to keep. Default: 0.0 and 1.0.
    n : int, optional
        Number of colors sampled from the original colormap. Default: 100.

    Returns
    --------
    matplotlib.colors.LinearSegmentedColormap
        The new, truncated colormap (callers set the 'bad' color themselves).
    """
    name = 'trunc({n},{a:.2f},{b:.2f})'.format(n=cmap.name, a=minval, b=maxval)
    sampled_colors = cmap(np.linspace(minval, maxval, n))
    return colors.LinearSegmentedColormap.from_list(name, sampled_colors)
def _rgb_to_hex(rgb):
    """ Convert three 0-255 integer channels (R, G, B) to a hex color string, e.g. (255, 0, 16) -> '#ff0010'.

    Any sequence of three integer-like values is accepted (tuple, list or numpy array),
    since the values are converted to a tuple before %-formatting.
    """
    return '#%02x%02x%02x' % tuple(rgb)
def _values_to_cmap(values, plt_cmap=None):
    """ Map values onto a cmap function taking value and returning hex color.

    Parameters
    ------------
    values : list-like object
        An object containing values to be mapped to colors.
    plt_cmap : str, optional
        Name of a matplotlib colormap to use. Default: None (colors are automatically chosen).

    Returns
    --------
    tuple (typ, color_func)
        typ is "string", "bool" or "continuous" depending on the values; color_func maps a single
        value to a hex color string.
    """
    # A user-given colormap is copied so that set_bad() below does not alter the registered colormap.
    # plt.get_cmap is used because matplotlib.cm.get_cmap was deprecated in 3.7 and removed in 3.9.
    custom_cmap = copy.copy(plt.get_cmap(plt_cmap)) if plt_cmap is not None else None

    # First, convert values to bool if possible
    values = _convert_boolean(values)

    if any(isinstance(s, str) for s in values):  # values are strings; cmap should be discrete
        cmap = custom_cmap if custom_cmap is not None else _truncate_colormap(plt.cm.jet, minval=0.3, maxval=0.7)
        cmap.set_bad(color="grey")

        # Sort the categories so each one gets the same color in every run
        # (iteration order of a set of strings changes between interpreter runs)
        categories = sorted(set(values), key=str)
        name2val = dict(zip(categories, np.linspace(0, 1, len(categories))))  # map strings to cmap values

        def color_func(string):
            return _rgb_to_hex(cmap(name2val[string], bytes=True)[:3])
        typ = "string"

    elif any(isinstance(s, bool) for s in values):
        cmap = custom_cmap if custom_cmap is not None else colors.ListedColormap(['lightblue', 'blue'])
        cmap.set_bad(color="grey")  # color for NaN

        # float() maps False/True to the two ends of the colormap (0.0 / 1.0) and NaN to the 'bad' color.
        # int() would index the first two lookup entries only and fails on NaN.
        def color_func(value):
            return _rgb_to_hex(cmap(float(value), bytes=True)[:3])
        typ = "bool"

    else:  # values are int/float
        values_array = np.array(values)
        clean_values = values_array[~np.isnan(values_array)]  # ignore NaN for the value range
        vmin, vmax = np.min(clean_values), np.max(clean_values)

        if custom_cmap is not None:
            cmap = custom_cmap
        elif vmin >= 0:  # all values are non-negative
            cmap = _truncate_colormap(plt.cm.Reds, minval=0.3, maxval=0.7)
        elif vmax <= 0:  # all values are non-positive
            cmap = _truncate_colormap(plt.cm.Blues_r, minval=0.3, maxval=0.7)
        else:  # values span 0
            cmap = _truncate_colormap(plt.cm.bwr, minval=0.1, maxval=0.9)
            max_abs = max(abs(vmin), abs(vmax))
            vmin, vmax = -max_abs, max_abs  # make sure that divergent maps are centered at 0

        norm_func = plt.Normalize(vmin=vmin, vmax=vmax)
        cmap.set_bad(color="grey")  # set color for np.nan

        def color_func(value):
            return _rgb_to_hex(cmap(norm_func(value), bytes=True)[:3])
        typ = "continuous"

    return typ, color_func
def _isnan(num):
    """ Return True if num is NaN.

    NaN is the only value that compares unequal to itself, so this check also works for
    non-numeric input such as strings or None (which are never NaN).
    """
    unequal_to_itself = num != num
    return unequal_to_itself
def _convert_boolean(values):
    """ Converts a list of boolean/string/nan values into boolean values - but only if all values could be converted.

    Strings are matched case-insensitively: "y", "yes", "t", "true", "on", "1" become True and
    "n", "no", "f", "false", "off", "0" become False. NaN values are kept and ignored in the check.

    Returns
    --------
    list
        The converted values if every non-NaN value is (or was converted to) a bool; otherwise the
        original values unchanged.
    """
    # Same vocabulary as distutils.util.strtobool, which is no longer available
    # (distutils was removed from the standard library in Python 3.12)
    str_to_bool = {"y": True, "yes": True, "t": True, "true": True, "on": True, "1": True,
                   "n": False, "no": False, "f": False, "false": False, "off": False, "0": False}

    converted = [str_to_bool.get(val.lower(), val) if isinstance(val, str) else val for val in values]

    # Check if the non-NaN values contain only bool
    clean = [val for val in converted if not _isnan(val)]
    if all(isinstance(val, bool) for val in clean):
        return converted  # all values could be converted - these are boolean values
    return values  # not all values could be converted - these are not boolean values
def _get_html_colormap(colormap, vmin, vmax, n):
    """ Function to create a text-colormap for network legend.

    Parameters
    ----------
    colormap : callable
        Function mapping a value to a color string.
    vmin, vmax : float
        First and last value shown.
    n : int
        Number of evenly spaced colored blocks from vmin to vmax (inclusive).

    Returns
    --------
    str
        Concatenated graphviz HTML <FONT> elements, one colored block character per step.
    """
    blocks = (f'<FONT COLOR="{colormap(value)}">█</FONT>' for value in np.linspace(vmin, vmax, n))
    return "".join(blocks)
def network(network,
            color_node_by=None,
            color_edge_by=None,
            size_node_by=None,
            size_edge_by=None,
            engine="sfdp",
            size="8,8",
            min_edge_size=2,
            max_edge_size=8,
            min_node_size=14,
            max_node_size=20,
            legend_size='auto',
            node_border=False,
            node_cmap=None,
            edge_cmap=None,
            node_attributes=None,
            save=None,
            verbosity=1,
            ):
    """
    Plot network of a networkx object using Graphviz for python.

    Parameters
    -----------
    network : networkx.Graph
        A networkx Graph/DiGraph object containing the network to plot.
    color_node_by : str, optional
        The name of a node attribute used to color the nodes. Default: None.
    color_edge_by : str, optional
        The name of an edge attribute used to color the edges. Default: None.
    size_node_by : str, optional
        The name of a numeric node attribute used to scale the node label size. Default: None.
    size_edge_by : str, optional
        The name of a numeric edge attribute used to scale the edge widths. Default: None.
    engine : str, optional
        The graphviz engine to use for finding network layout. Default: "sfdp".
    size : str, optional
        Size of the output figure. Default: "8,8".
    min_edge_size : float, optional
        Minimum edge width; also used for all edges if size_edge_by is None. Default: 2.
    max_edge_size : float, optional
        Maximum edge width. Default: 8.
    min_node_size : float, optional
        Minimum node fontsize when size_node_by is given. Default: 14.
    max_node_size : float, optional
        Maximum node fontsize when size_node_by is given. Default: 20.
    legend_size : int or 'auto', optional
        Fontsize for legend explaining color_node_by/color_edge_by/size_node_by/size_edge_by. Set to 0 to hide legend.
        If 'auto', the size is estimated from the number of nodes. Default: 'auto'.
    node_border : bool, optional
        Whether to plot border on nodes. Can be useful if the node colors are very light. Default: False.
    node_cmap : str, optional
        Name of colormap for node coloring. Default: None (colors are automatically chosen).
    edge_cmap : str, optional
        Name of colormap for edge coloring. Default: None (colors are automatically chosen).
    node_attributes : dict, optional
        Additional graphviz node attributes to apply to all nodes (values are converted to str). Default: None (no additional attributes).
    save : str, optional
        Path to save network figure to. Format is inferred from the filename - if not valid, the default format is '.pdf'.
    verbosity : int, optional
        verbosity of the logging. Default: 1.

    Returns
    --------
    graphviz.Graph or graphviz.Digraph
        The graphviz object (a Digraph if network is directed).

    Raises
    -------
    TypeError
        If network is not a networkx.Graph object
    InputError
        If any of 'color_node_by', 'size_node_by', 'color_edge_by' or 'size_edge_by' is not in node/edge attributes.
    ValueError
        If 'engine' is not a valid graphviz engine.
    """
    # Setup logger
    logger = TFcombLogger(verbosity)

    # None replaces the former mutable {} default; it behaves like an empty dict
    node_attributes = {} if node_attributes is None else node_attributes

    ############### Test input ###############
    check_type(network, [nx.Graph, nx.DiGraph, nx.MultiGraph, nx.MultiDiGraph])
    if legend_size != 'auto':  # '!=' instead of 'is not': identity of str objects is not guaranteed
        check_value(legend_size, vmin=0, integer=True, name="legend_size")

    # Read nodes and edges from graph and check attributes
    node_view = network.nodes(data=True)
    edge_view = network.edges(data=True)

    # Available attributes are read from the first node/edge; a graph without nodes/edges has none
    node_attributes_list = list(next(iter(node_view))[-1].keys()) if len(node_view) > 0 else []
    edge_attributes_list = list(next(iter(edge_view))[-1].keys()) if len(edge_view) > 0 else []

    for att in [color_node_by, size_node_by]:
        if att is not None and att not in node_attributes_list:
            raise InputError("Attribute '{0}' is not available in the network node attributes. Available attributes are: {1}".format(att, node_attributes_list))
    for att in [color_edge_by, size_edge_by]:
        if att is not None and att not in edge_attributes_list:
            raise InputError("Attribute '{0}' is not available in the network edge attributes. Available attributes are: {1}".format(att, edge_attributes_list))

    # Check if engine is within graphviz
    if engine not in graphviz.ENGINES:
        raise ValueError("The given engine '{0}' is not in graphviz available engines: {1}".format(engine, graphviz.ENGINES))

    # Check number of edges
    n_edges = len(network.edges)
    if n_edges > 10000:
        logger.warning(f"Detected more than 10.000 edges ({n_edges}). This can result in issues when using jupyter.")

    ############### Initialize graph ###############
    if nx.is_directed(network):
        dot = graphviz.Digraph(engine=engine)
    else:
        dot = graphviz.Graph(engine=engine)
    dot.attr(size=size)
    dot.attr(outputorder="edgesfirst")
    dot.attr(overlap="false")

    ############ Setup colormaps/sizemaps ############
    map_value = {}
    map_type = {}

    def _size_scaler(values, min_size, max_size):
        # Linearly map the range of 'values' onto [min_size, max_size], rounded to 2 decimals.
        # If all values are identical, every value maps to min_size (avoids a 0/0 division).
        vmin, vmax = np.min(values), np.max(values)
        if vmax == vmin:
            return lambda value: np.round(min_size, 2)
        return lambda value: np.round((value - vmin) / (vmax - vmin) * (max_size - min_size) + min_size, 2)

    # Node color
    if color_node_by is not None:
        all_values = [node[-1][color_node_by] for node in node_view]
        map_type["node_color"], map_value["node_color"] = _values_to_cmap(all_values, node_cmap)

    # Edge color
    if color_edge_by is not None:
        all_values = [edge[-1][color_edge_by] for edge in edge_view]
        map_type["edge_color"], map_value["edge_color"] = _values_to_cmap(all_values, edge_cmap)

    # Node size (must be continuous)
    if size_node_by is not None:
        all_values = [node[-1][size_node_by] for node in node_view]
        map_value["node_size"] = _size_scaler(all_values, min_node_size, max_node_size)

    # Edge size (must be continuous)
    if size_edge_by is not None:
        all_values = [edge[-1][size_edge_by] for edge in edge_view]
        map_value["edge_size"] = _size_scaler(all_values, min_edge_size, max_edge_size)

    ############### Add nodes to network ##############
    logger.debug("Adding nodes to dot network")
    for node_name, node_att in node_view:

        attributes = {  # attributes for dot
            "style": "filled",
            "width": "0.25",  # minimum width; will expand to fit label
            "height": "0.25",  # minimum height; will expand to fit label
            "fixedsize": "false",  # automatically adjust node sizes
            "fillcolor": "lightgrey",
        }

        # Set color of node border (default: black)
        if not node_border:
            attributes["color"] = "none"

        # Set node color
        if color_node_by is not None:
            attributes["fillcolor"] = map_value["node_color"](node_att[color_node_by])

            # Adjust label color based on darkness of fill
            R, G, B = matplotlib.colors.to_rgb(attributes["fillcolor"])  # from hex to rgb
            luminance = 0.2126*R + 0.7152*G + 0.0722*B
            if luminance < 0.5:  # if fill is dark, the font should be white
                attributes["fontcolor"] = "white"

        # Set node size
        if size_node_by is not None:
            attributes["fontsize"] = str(map_value["node_size"](node_att[size_node_by]))

        # Apply any additional attributes
        for key, value in node_attributes.items():
            attributes[key] = str(value)

        logger.spam("Adding node {0}".format(node_name))
        # graphviz expects string identifiers; networkx nodes may be e.g. ints
        dot.node(str(node_name), _attributes=attributes)

    ############### Add edges to network ###############
    logger.debug("Adding edges to dot network")
    for edge in edge_view:
        node1, node2 = edge[:2]
        edge_att = edge[-1]

        attributes = {"penwidth": str(min_edge_size)}  # default size; can be overwritten by size_edge_by

        # Set edge color
        if color_edge_by is not None:
            attributes["color"] = map_value["edge_color"](edge_att[color_edge_by])

        # Set edge size
        if size_edge_by is not None:
            attributes["penwidth"] = str(map_value["edge_size"](edge_att[size_edge_by]))

        dot.edge(str(node1), str(node2), _attributes=attributes)

    ############### Plot legend to dot object ###############
    if legend_size == 'auto':
        n_nodes = len(node_view)
        legend_size = int(10 + n_nodes*0.1)  # incrementally increasing size
        logger.debug("legend_size is estimated at: {0}".format(legend_size))

    if legend_size > 0:
        h = int(legend_size/3)
        spacer = f'<TR><TD HEIGHT="{h}"></TD></TR>'  # spacer between rows

        # Start building legend
        html_legend = f'<<FONT POINT-SIZE="{legend_size}" FACE="ARIAL">'
        # TODO: add a bit of space between nodes and legend position
        html_legend += '<TABLE ALIGN="LEFT" BORDER="1" CELLBORDER="0" CELLSPACING="0" VALIGN="MIDDLE">'
        html_legend += spacer

        if color_node_by is not None:
            html_legend += f'<TR><TD ALIGN="LEFT" > <b>Nodes colored by:</b> </TD><TD ALIGN="LEFT"> {color_node_by} </TD>'

            # Only continuous values get a colormap bar
            if map_type["node_color"] == "continuous":
                all_values = [node[-1][color_node_by] for node in node_view]
                min_val = round(np.min(all_values), 2)
                max_val = round(np.max(all_values), 2)
                html_colormap = _get_html_colormap(map_value["node_color"], min_val, max_val, 10)
                html_legend += f'<TD ALIGN="RIGHT"><i>{min_val}</i> </TD><TD>{html_colormap}</TD><TD ALIGN="left"><i>{max_val}</i> </TD>'
            html_legend += '</TR>' + spacer

        if color_edge_by is not None:
            html_legend += f'<TR><TD ALIGN="LEFT"> <b>Edges colored by:</b> </TD><TD ALIGN="LEFT"> {color_edge_by} </TD>'

            # Only continuous values get a colormap bar
            if map_type["edge_color"] == "continuous":
                all_values = [edge[-1][color_edge_by] for edge in edge_view]
                min_val = round(np.min(all_values), 2)
                max_val = round(np.max(all_values), 2)
                html_colormap = _get_html_colormap(map_value["edge_color"], min_val, max_val, 10)
                html_legend += f'<TD ALIGN="RIGHT"><i>{min_val}</i> </TD><TD>{html_colormap}</TD><TD ALIGN="left"><i>{max_val}</i> </TD>'
            html_legend += "</TR>" + spacer

        if size_node_by is not None:
            html_legend += f'<TR><TD ALIGN="LEFT"> <b>Nodes sized by:</b> </TD><TD ALIGN="LEFT"> {size_node_by} </TD>'
            all_values = [node[-1][size_node_by] for node in node_view]
            min_val = round(np.min(all_values), 2)
            max_val = round(np.max(all_values), 2)
            html_legend += f'<TD></TD><TD ALIGN="CENTER"> <i>{min_val} </i> ● ⬤<i> {max_val}</i> </TD><TD></TD>'
            html_legend += '</TR>' + spacer

        if size_edge_by is not None:
            html_legend += f'<TR><TD ALIGN="LEFT"> <b>Edges sized by:</b> </TD><TD ALIGN="LEFT"> {size_edge_by} </TD>'
            all_values = [edge[-1][size_edge_by] for edge in edge_view]
            min_val = round(np.min(all_values), 2)
            max_val = round(np.max(all_values), 2)
            html_legend += f'<TD></TD><TD ALIGN="CENTER"> <i>{min_val}</i> ◄ <i>{max_val}</i> </TD><TD></TD>'
            html_legend += '</TR>' + spacer

        # Finalize legend
        html_legend += '</TABLE>'
        html_legend += '</FONT>>'

        # Add legend and location to dot obj
        dot.attr(label=html_legend)
        dot.attr(labelloc="b")
        dot.attr(labeljust="r")

    ############### Save to file ###############
    if save is not None:
        # Render a copy so that the dpi setting does not affect the returned (notebook-visualized) object
        dot_render = copy.deepcopy(dot)

        file_prefix, extension = os.path.splitext(save)
        fmt = extension.replace(".", "")
        if fmt not in graphviz.FORMATS:
            logger.warning("File ending .{0} is not supported by graphviz/dot. Network will be saved as .pdf.".format(fmt))
            fmt = "pdf"

        # fmt has no leading dot, so compare against "pdf" (comparing to ".pdf" was always True)
        if fmt != "pdf":
            dot_render.attr(dpi="600")  # for .png's to ensure quality

        dot_render.render(filename=file_prefix, format=fmt, cleanup=True)

    return dot
def genome_view(TFBS,
                window_chrom=None,
                window_start=None,
                window_end=None,
                window=None,
                fasta=None,
                bigwigs=None,
                bigwigs_sharey=False,
                TFBS_track_height=4,
                title=None,
                highlight=None,
                save=None,
                figsize=None,
                verbosity=1):
    """ Plot TFBS in genome view via the 'DnaFeaturesViewer' package.

    Parameters
    --------------
    TFBS : list
        A list of OneTFBS objects or any other object containing .chrom, .start, .end and .name variables.
    window_chrom : str, optional if 'window' is given
        The chromosome of the window to show.
    window_start : int, optional if 'window' is given
        The genomic coordinates for the start of the window.
    window_end : int, optional if 'window' is given
        The genomic coordinates for the end of the window.
    window : Object with .chrom, .start, .end
        If window_chrom/window_start/window_end are not given, window can be given as an object containing .chrom, .start, .end variables.
        If no window is given at all, the first 100 sites on the chromosome of the first site in TFBS are shown.
    fasta : str, optional
        The path to a fasta file containing sequence information to show. Default: None.
    bigwigs : str or list of str, optional
        Give the paths to bigwig signals to show within graph (if a dict is given, its keys are used as paths). Default: None.
    bigwigs_sharey : bool or list, optional
        Whether bigwig signals should share y-axis range. If True, all signals will be shared.
        It is also possible to give a list of bigwig indices (starting at 0), which should share y-axis values, e.g. [0,1,3] for the 1st, 2nd and 4th bigwig to share signal.
        If list of lists, each lists correspond to a grouping, e.g. [[0,2], [1,3]]. Default: False.
    TFBS_track_height : float, optional
        Relative track height of TFBS. Default: 4.
    title : str, optional
        Title of plot. Default: None.
    highlight : list, optional
        A list of OneTFBS objects or any other object containing .chrom, .start, .end variables. Sites overlapping
        the shown window are highlighted as shaded regions across all tracks. Default: None.
    save : str, optional
        Save the plot to the file given in 'save'. Default: None.
    figsize : tuple, optional
        The size of the figure. Default: None (8, TFBS_track_height + number of bigwig tracks).
    verbosity : int, optional
        Verbosity of the logging. Default: 1.

    Returns
    --------
    list or numpy.ndarray of matplotlib axes
        The TFBS track axis followed by one axis per bigwig track.

    Raises
    -------
    InputError
        If no window is given and TFBS is empty.
    ValueError
        If figsize is not a tuple of length 2.
    """
    logger = TFcombLogger(verbosity)

    # Test if package is available
    # NOTE(review): presumably check_module raises if the module is missing; otherwise
    # GraphicFeature/GraphicRecord below would be undefined - confirm.
    if tfcomb.utils.check_module("dna_features_viewer") == True:
        from dna_features_viewer import GraphicFeature, GraphicRecord

    #----------------- Format input data ----------------#

    # Establish which region to show
    logger.debug("Subsetting TFBS to window")
    if window_chrom is not None and window_start is not None and window_end is not None:
        window = tfcomb.utils.OneTFBS([window_chrom, window_start, window_end])

    if window is not None:  # subset TFBS to the window
        TFBS = [site for site in TFBS
                if site.chrom == window.chrom and site.start >= window.start and site.end <= window.end]
    else:  # show (a subset of) all TFBS
        if len(TFBS) == 0:
            raise InputError("No window was set and TFBS is empty - there is nothing to show.")
        logger.warning("No window was set - showing the first 100 sites in .TFBS")

        # Only keep first chromosome of TFBS, and at most 100 sites
        chrom = TFBS[0].chrom
        TFBS = [site for site in TFBS if site.chrom == chrom][:100]

        # The window spans all kept sites
        window_start = min(site.start for site in TFBS)
        window_end = max(site.end for site in TFBS)
        window = tobias.utils.regions.OneRegion([chrom, window_start, window_end])

    logger.debug(window)
    window_length = window.end - window.start

    # Establish how many bigwig paths were given
    if bigwigs is None:
        bigwigs = []
    elif isinstance(bigwigs, str):
        bigwigs = [bigwigs]
    n_bigwig_tracks = len(bigwigs)
    n_tracks = 1 + n_bigwig_tracks  # TFBS track + one track per bigwig

    #------------ Create plt subplots ------------#
    height_ratios = [TFBS_track_height] + [1]*n_bigwig_tracks

    if figsize is None:
        figsize = (8, TFBS_track_height + n_bigwig_tracks)
    elif not isinstance(figsize, tuple) or len(figsize) != 2:
        raise ValueError("figsize must be a tuple of length 2")

    fig, axes = plt.subplots(n_tracks, 1,
                             sharex=True,
                             figsize=figsize,
                             constrained_layout=True,
                             gridspec_kw={"height_ratios": height_ratios}
                             )
    axes = [axes] if not isinstance(axes, np.ndarray) else axes  # for n_tracks == 1

    #------------ Add TFBS features to plot ------------#
    if len(TFBS) == 0:
        logger.warning("No TFBS to show within the given window.")

    strand_convert = {"+": 1, "-": -1}  # any other strand value is shown without direction
    features = [GraphicFeature(start=site.start, end=site.end, strand=strand_convert.get(site.strand, None), label=site.name)
                for site in TFBS]

    # Add sequence track
    if fasta is not None:
        genome_obj = tfcomb.utils.open_genome(fasta)
        sequence = genome_obj.fetch(window.chrom, window.start, window.end)
    else:
        sequence = None

    record = GraphicRecord(first_index=window.start,
                           sequence_length=window_length,
                           sequence=sequence,
                           features=features,
                           labels_spacing=20
                           )
    record.plot(ax=axes[0], with_ruler=(n_bigwig_tracks == 0))  # ruler only if there are no tracks below
    plt.xticks(rotation=45, ha="right", color="grey")

    if sequence is not None:
        record.plot_sequence(axes[0], y_offset=1)

    #------------ Add additional features -----------#

    # Add bigwig track(s)
    for i, bigwig_f in enumerate(bigwigs):
        ax = axes[i+1]

        # Open pybw and pull values
        pybw = tfcomb.utils.open_bigwig(bigwig_f)
        signal = tobias.utils.regions.OneRegion.get_signal(window, pybw)

        # Add signal to plot
        bigwig_name = os.path.splitext(os.path.basename(bigwig_f))[0]
        xvals = np.arange(window.start, window.end)
        ax.fill_between(xvals, signal, step="mid")
        ax.set_ylabel(bigwig_name, rotation=0, ha="right")
        ax.yaxis.tick_right()

        ymin, ymax = np.min(signal), np.max(signal)
        pad = (ymax - ymin)*0.2
        ax.set_ylim(ymin-pad, ymax+pad)

        # Set tick color to grey
        ax.tick_params(color='grey', labelcolor='grey')
        plt.setp([ax.get_xticklines(), ax.get_yticklines()], color="grey")

    # Establish which groups of bigwig tracks should share y-axis
    if isinstance(bigwigs_sharey, (list, tuple)):
        if len(bigwigs_sharey) == 0:
            grouping = []
        elif isinstance(bigwigs_sharey[0], (list, tuple)):
            grouping = bigwigs_sharey  # already list of lists
        else:
            grouping = [bigwigs_sharey]  # one list of indices - only one group
    elif bigwigs_sharey:  # e.g. True: share y across all bigwig tracks
        grouping = [list(range(n_bigwig_tracks))]
    else:
        grouping = []

    # Set ylim across groups
    for group in grouping:
        if len(group) == 0:  # nothing to share (e.g. sharey=True without bigwigs)
            continue
        group_ylims = list(zip(*[axes[i+1].get_ylim() for i in group]))
        ymin = min(group_ylims[0])
        ymax = max(group_ylims[1])
        for i in group:
            axes[i+1].set_ylim(ymin, ymax)

    # Highlight sites overlapping the window as shaded regions across all tracks
    if highlight is not None:
        for site in highlight:
            site_start, site_end = min(site.start, site.end), max(site.start, site.end)
            if site.chrom == window.chrom and site_end > window.start and site_start < window.end:
                # Clip to the window so the highlight cannot extend the x-axis
                span_start, span_end = max(site_start, window.start), min(site_end, window.end)
                for ax in axes:
                    ax.axvspan(span_start, span_end, color="gold", alpha=0.3, zorder=0)

    if title is not None:
        axes[0].set_title(title, y=1.05)
    plt.xlabel(window.chrom, color="grey")

    #-------------- Done with plot; show/save -----------#
    if save is not None:
        fig.savefig(save, dpi=600)

    return axes