import sys
import os
from turtle import end_fill
import pandas as pd
import numpy as np
import math
import copy
from copy import deepcopy
import time
import datetime
import scipy.stats
from scipy.signal import find_peaks
import random
import string
import multiprocessing as mp
import importlib
import pysam
import tfcomb
from tfcomb.logging import TFcombLogger, InputError
from tfcomb.counting import count_co_occurrence
from tobias.utils.regions import OneRegion, RegionList
from tobias.utils.motifs import MotifList
from tobias.utils.signals import fast_rolling_math
import pathlib
import pyBigWig
import matplotlib.pyplot as plt
import matplotlib
import matplotlib.animation
import seaborn as sns
import tqdm
from IPython import display
import warnings
#----------------- Minimal TFBS class based on the TOBIAS 'OneRegion' class -----------------#
class OneTFBS(list):
    """ Minimal container for the location of one single TFBS (a list whose leading entries are also available as BED-style attributes) """

    def __init__(self, lst=[]):
        """
        Parameters
        -----------
        lst : list-like, optional
            Values in the order chrom, start, end, name, score, strand. Attributes without a value are set to None.
            If 'lst' has an instance '__dict__', all attributes found there are copied to the new object as well.
        """

        super(OneTFBS, self).__init__(lst)

        bed_fmt = ["chrom", "start", "end", "name", "score", "strand"]

        # Make sure every BED attribute exists, even for inputs with fewer than six entries
        for column in bed_fmt:
            setattr(self, column, None)

        # Map positional values to their names (zip stops at the shorter of both sequences)
        for column, value in zip(bed_fmt, self):
            setattr(self, column, value)

        # Attributes carried by the input object win over the positional values
        if hasattr(lst, "__dict__"):
            for column, value in lst.__dict__.items():
                setattr(self, column, value)

        # NOTE(review): assigned on the instance, so this is presumably not bound to 'self' - confirm how callers invoke it
        self.get_width = OneRegion.get_width

    def __str__(self):
        """ Tab-separated representation of all BED attributes which are not None """

        columns = (self.chrom, self.start, self.end, self.name, self.score, self.strand)
        return "\t".join(str(column) for column in columns if column is not None)

    def __repr__(self):
        return self.__str__()
class TFBSPair():
    """ Collects information about a co-occurring pair of TFBS """

    def __init__(self, TFBS1, TFBS2, anchor="inner", simplify=False):
        """
        Parameters
        -----------
        TFBS1 : OneTFBS
            OneTFBS object for the first TFBS.
        TFBS2 : OneTFBS
            OneTFBS object for the second TFBS.
        anchor : str
            Anchor is used to calculate distances. Must be one of "inner", "outer" or "center". Default: "inner".
        simplify : boolean
            Whether to simplify the orientation to "opposite" (convergent/divergent) and "same" (TF1-TF2/TF2-TF1). Default: False.

        Raises
        --------
        ValueError
            If 'anchor' is not one of "inner", "outer" or "center".
        """

        self.site1 = TFBS1 #OneTFBS object
        self.site2 = TFBS2 #OneTFBS object

        #Calculate distance depending on anchor. All anchors are measured from site1 to site2,
        #so the sign of the distance is the same regardless of the chosen anchor.
        if anchor == "inner":
            distance = self.site2.start - self.site1.end
        elif anchor == "outer":
            distance = self.site2.end - self.site1.start
        elif anchor == "center":
            TF1_anchor = (self.site1.start + self.site1.end) / 2
            TF2_anchor = (self.site2.start + self.site2.end) / 2
            distance = int(TF2_anchor - TF1_anchor)
        else:
            raise ValueError(f"Parameter 'anchor' has to be one of ['inner', 'outer', 'center']. Got '{anchor}'.")

        self.distance = distance
        self._set_orientation(simplify=simplify)

    def _set_orientation(self, simplify=True):
        """ Sets the .orientation attribute of pair dependent on the relative orientation of sites in TFBSPair.

        Parameters
        -------------
        simplify : bool
            Whether to simplify the orientation of sites to "opposite" (convergent/divergent) and "same" (TF1-TF2/TF2-TF1). Default: True.
        """

        strand1 = self.site1.strand
        strand2 = self.site2.strand
        stranded = ["+", "-"]

        #Fallback for every combination without a defined orientation (e.g. both strands "." or None);
        #this guarantees that .orientation is always set.
        orientation = "NA"

        if strand1 == strand2:
            if strand1 == "+":
                orientation = "TF1-TF2"
            elif strand1 == "-":
                orientation = "TF2-TF1"
        elif strand1 not in stranded:
            #Only site2 can be stranded
            if strand2 == "+":
                orientation = "away"
            elif strand2 == "-":
                orientation = "towards"
        elif strand2 not in stranded:
            #Only site1 is stranded
            orientation = "towards" if strand1 == "+" else "away"
        else:
            #Both sites are stranded and the strands differ
            orientation = "convergent" if strand1 == "+" else "divergent"

        #Simplify orientation if chosen
        if simplify:
            translation_dict = {"convergent": "opposite",
                                "divergent": "opposite",
                                "TF1-TF2": "same",
                                "TF2-TF1": "same"}
            orientation = translation_dict.get(orientation, orientation) #translate if possible

        self.orientation = orientation

    def __str__(self):
        columns = ["chrom", "start", "end", "name", "score", "strand"]
        TFBS1 = ",".join([str(getattr(self.site1, col)) for col in columns])
        TFBS2 = ",".join([str(getattr(self.site2, col)) for col in columns])

        s = f"<TFBSPair | TFBS1: ({TFBS1}) | TFBS2: ({TFBS2}) | distance: {self.distance} | orientation: {self.orientation} >"
        return(s)

    def __repr__(self):
        return(self.__str__())
[docs]class TFBSPairList(list):
""" Class for collecting and analyzing a list of TFBSPair objects """
# init attributes
# These are class-level defaults; assigning through 'self' creates the per-instance value.
# Path to the signal bigwig file, written by the 'bigwig_path' setter.
_bigwig_path = None
# Cached (pairs, scores) tuple written by comp_plotting_tables(); the list-modifying methods reset it to None.
_plotting_tables = None
# 'flank' and 'align' used for the cached tables, both written by comp_plotting_tables().
_last_flank = None
_last_align = None
[docs] def as_table(self):
""" Table representation of the pairs in the list """
table = []
for p in self:
attributes = getAllAttr(p)
# add site prefixes/ get attributes for both TFs
site1 = {f"site1_{key}": value for key, value in getAllAttr(attributes.pop("site1")).items()}
site2 = {f"site2_{key}": value for key, value in getAllAttr(attributes.pop("site2")).items()}
attributes = {f"site_{key}": value for key, value in attributes.items()}
attributes.update(site1)
attributes.update(site2)
table.append(attributes)
table = pd.DataFrame(table)
# convert types to best possible
# https://stackoverflow.com/a/65915289
table = table.apply(pd.to_numeric, errors='ignore').convert_dtypes()
#Sort columns
col_order = ["site1_chrom", "site1_start", "site1_end", "site1_name", "site1_score", "site1_strand",
"site2_chrom", "site2_start", "site2_end", "site2_name", "site2_score", "site2_strand"]
order_dict = {col: i for i, col in enumerate(col_order)}
columns = sorted(table.columns, key= lambda x: order_dict.get(x, 10**10))
table = table[columns]
return table
def write_bed(self, outfile, fmt="bed", merge=False):
    """
    Write the locations of (TF1, TF2) pairs to a bed-file.

    Parameters
    ------------
    outfile : str
        The path which the pair locations should be written to.
    fmt : str, optional
        The format of the output file. Must be one of "bed" or "bedpe". If "bed", the TF1/TF2 sites are written individually (see merge option to merge sites). If "bedpe", the sites are written in BEDPE format. Default: "bed".
    merge : bool, optional
        If fmt="bed", 'merge' controls how the locations are written out. If True, will be written as one region spanning TF1.start-TF2.end. If False, TF1/TF2 sites are written individually. Default: False.

    Raises
    --------
    InputError
        If 'outfile' cannot be opened for writing.
    """

    tfcomb.utils.check_string(fmt, ["bed", "bedpe"], "fmt")
    tfcomb.utils.check_type(merge, bool, "merge")

    #Open output file
    try:
        f = open(outfile, "w")
    except Exception as e:
        raise InputError("Error opening '{0}' for writing. Error message was: {1}".format(outfile, e)) from e

    #The context manager closes the file even if formatting one of the pairs fails
    with f:

        #Collect locations in format 'fmt'; each entry is one line of the output
        lines = []
        for l in self:
            s1, s2 = l.site1, l.site2

            if fmt == "bed":
                if merge:
                    start = min(s1.start, s2.start)
                    end = max(s1.end, s2.end)
                    lines.append("\t".join([s1.chrom, str(start), str(end), s1.name + "-" + s2.name, str(l.distance), "."]))
                else:
                    lines.append("\t".join([s1.chrom, str(s1.start), str(s1.end), s1.name, ".", s1.strand]))
                    lines.append("\t".join([s2.chrom, str(s2.start), str(s2.end), s2.name, ".", s2.strand]))

            elif fmt == "bedpe":
                lines.append("\t".join([s1.chrom, str(s1.start), str(s1.end),
                                        s2.chrom, str(s2.start), str(s2.end),
                                        s1.name + "-" + s2.name, str(l.distance), s1.strand, s2.strand]))

        #An empty list results in an empty file instead of a file containing a single blank line
        f.writelines(line + "\n" for line in lines)
# ---------------------- override list inherited functions ----------------------
# This is done so that plotting_tables are reset when changes in list are detected.
[docs] def append(self, element):
self._plotting_tables = None
super().append(element)
[docs] def extend(self, l):
self._plotting_tables = None
super().extend(l)
def __add__(self, x):
return TFBSPairList(super().__add__(x))
[docs] def insert(self, index, object):
self._plotting_tables = None
super().insert(index, object)
[docs] def remove(self, value):
self._plotting_tables = None
super().remove(value)
[docs] def pop(self, index=-1):
self._plotting_table = None
return super().pop(index)
[docs] def clear(self) -> None:
self._plotting_tables = None
super().clear()
# slicing functions
def __getitem__(self, key):
new = super().__getitem__(key)
if isinstance(new, list):
return TFBSPairList(new)
else:
return new
def __setitem__(self, index, value):
self._plotting_tables = None
super().__setitem__(index, value)
# display object function
def __repr__(self):
return(f"TFBSPairList({super().__repr__()})")
# ---------------------- plotting related functions ----------------------
@property
def bigwig_path(self):
"""
Get path to bigwig file.
"""
if not self._bigwig_path:
raise Exception("No path to signal bigwig found. Please set with 'object.bigwig_path = \"path/to/file.bw\"'.")
return self._bigwig_path
@bigwig_path.setter
def bigwig_path(self, path):
"""
Set path to bigwig file. Checks for existence.
"""
if not os.path.exists(path):
raise Exception(f"Could not find file! {path}")
self._plotting_tables = None
self._bigwig_path = path
@property
def plotting_tables(self):
"""
Getter for plotting_tables. Will compute if necessary.
"""
if not self._plotting_tables:
self.comp_plotting_tables()
return self._plotting_tables
[docs] def comp_plotting_tables(self, flank=100, align="center"):
"""
Prepare pair and score tables for plotting.
Parameters
------------
flank : int or tuple, default 100
Window size of TFBSpair. Adds given amount of bases in both directions counted from alignment anchor (see align) between binding sites. Use a tuple of ints to set left and right flank independently.
align : str, default 'center'
Position from which the flanking regions are added. Must be one of 'center', 'left', 'right'.
'center': Midpoint between binding positions (rounded down if uneven).
'left': End of first binding position in pair.
'right': Start of second binding position in pair.
"""
if align not in ["center", "left", "right"]:
raise ValueError(f"Parameter 'align' has to be one of ['center', 'left', 'right']. Got '{align}'.")
# flank tuple
if not isinstance(flank, tuple):
flank = (flank, flank)
# load bigwig file
signal_bigwig = pyBigWig.open(self.bigwig_path)
# get pairs as table & sort by distance
pairs = self.as_table().sort_values(by='site_distance')
scores = []
sorted_pairs = []
for (index, row) in pairs.iterrows():
# switch pair to always start with same TF
if row["site1_name"] < row["site2_name"]:
site1 = ["site1_chrom", "site1_end", "site1_start", "site1_name", "site1_strand"]
site2 = ["site2_chrom", "site2_start", "site2_end", "site2_name", "site2_strand"]
row[site1], row[site2] = row[site2].values, row[site1].values
# get anchor point (center) to which flanks are added
if align == "center":
anchor = row["site1_end"] - row["site_distance"] // 2
elif align == "left":
anchor = row["site1_end"]
elif align == "right":
anchor = row["site2_start"]
# compute window from anchor point
row["window_start"] = anchor + flank[0]
row["window_end"] = anchor - flank[1]
# compute relative positions
row["site1_rel_start"] = row["window_start"] - row["site1_start"]
row["site1_rel_end"] = row["window_start"] - row["site1_end"]
row["site2_rel_start"] = row["window_start"] - row["site2_start"]
row["site2_rel_end"] = row["window_start"] - row["site2_end"]
# fetch scores
values = signal_bigwig.values(row["site1_chrom"], row["window_end"], row["window_start"])
values.reverse()
else:
# get anchor point (center) to which flanks are added
if align == "center":
anchor = row["site1_end"] + row["site_distance"] // 2
elif align == "left":
anchor = row["site1_end"]
elif align == "right":
anchor = row["site2_start"]
# compute window from anchor point
row["window_start"] = anchor - flank[0]
row["window_end"] = anchor + flank[1]
# compute relative positions
row["site1_rel_start"] = row["site1_start"] - row["window_start"]
row["site1_rel_end"] = row["site1_end"] - row["window_start"]
row["site2_rel_start"] = row["site2_start"] - row["window_start"]
row["site2_rel_end"] = row["site2_end"] - row["window_start"]
values = signal_bigwig.values(row["site1_chrom"], row["window_start"], row["window_end"])
sorted_pairs.append(row)
scores.append(values)
sorted_pairs = pd.DataFrame(sorted_pairs).reset_index(drop=True)
scores = pd.DataFrame(scores).fillna(value=0)
#sort by distance, then relative TFBS start
sorted_pairs.sort_values(by=["site_distance", "site1_rel_start"], inplace=True)
scores = scores.loc[sorted_pairs.index]
# warn if the binding site length is not always the same
if (len(set(sorted_pairs["site1_rel_end"] - sorted_pairs["site1_rel_start"])) > 1 or
len(set(sorted_pairs["site2_rel_end"] - sorted_pairs["site2_rel_start"])) > 1):
warnings.warn("Differences in binding site length detected! This can have undesired effects when plotting. Refer to 'CombObj.TFBS_from_motifs(resolve_overlapping)' to solve.")
self._last_flank = flank
self._last_align = align
self._plotting_tables = (sorted_pairs, scores)
[docs] def pairMap(self,
logNorm_cbar=None,
show_binding=True,
flank_plot="strand",
figsize=(7, 14),
output=None,
flank=None,
align=None,
alpha=0.7,
cmap="seismic",
show_diagonal=True,
legend_name_score="Bigwig Score",
xtick_num=10,
log=np.log1p,
dpi=300):
"""
Create a heatmap of TF binding pairs sorted for distance.
Parameters
-----------
logNorm_cbar : str, default None
[None, "centerLogNorm", "SymLogNorm"]
Choose a type of normalization for the colorbar. \n
SymLogNorm:
Use matplotlib.colors.SymLogNorm. This does not center to 0 \n
centerLogNorm:
Use custom matplotlib.colors.SymLogNorm from stackoverflow. Note this creates a weird colorbar.
show_binding : bool, default True
Shows the TF binding positions as a grey background.
flank_plot : str, default 'strand'
["strand", "orientation", None]
Decide if the plots flanking the heatmap should be colored for strand, strand-orientation or disabled.
figsize : int tuple, default (7, 14)
Figure dimensions.
output : str, default None
Save plot to given file.
flank : int or int tuple, default None
Bases added to both sides counted from center. Forwarded to comp_plotting_tables().
align : str, default None
Alignment of pairs. One of ['left', 'right', 'center']. Forwarded to comp_plotting_tables().
alpha : float, default 0.7
Alpha value for diagonal lines, TF binding positions and center line.
cmap : matplotlib colormap name or object, or list of colors, default 'seismic'
Color palette used in the main heatmap. Forwarded to seaborn.heatmap(cmap)
show_diagonal : boolean, default True
Shows diagonal lines for identifying preference in binding distance.
legend_name_score : str, default 'Bigwig Score'
Name of the score legend (upper legend).
xtick_num : int, default 10
Number of ticks shown on the x-axis. Disable ticks with None or values < 0.
log : function, default numpy.log1p
Function applied to each row of scores. Excludes 0 and will use absolute value for negative numbers adding the sign afterwards.
Use any of the numpy.log functions. For example numpy.log, numpy.log10 or numpy.log1p. None to disable.
dpi : float, default 300
The resolution of the figure in dots-per-inch.
Returns
----------
matplotlib.gridspec.GridSpec:
Object containing the finished pairMap.
"""
# check parameter values
if not logNorm_cbar in [None, "centerLogNorm", "SymLogNorm"]:
raise ValueError(f"Parameter 'logNorm_cbar' has to be one of [None, \"centerLogNorm\", \"SymLogNorm\"]. But found {logNorm_cbar}.")
if not flank_plot in ["strand", "orientation", None]:
raise ValueError(f"Parameter 'flank_plot' hat to be one of [\"strand\", \"orientation\", None]. But found {flank_plot}")
# fixes FloatingPointError: underflow encountered in multiply
# https://stackoverflow.com/a/61756043
# should be set by default but for some reason it isn't for matplotlib 3.5.1/ numpy 1.21.5
np.seterr(under='ignore')
fig = plt.figure(figsize=figsize, dpi=dpi)
# compute plotting tables with custom flank
if not flank is None and flank != self._last_flank or not align is None and align != self._last_align:
params = {}
if flank: params["flank"] = flank
if align: params["align"] = align
self.comp_plotting_tables(**params)
# load data
pairs, scores = self.plotting_tables
# log scores
if log:
def log_row(row):
out = np.zeros(len(row))
out[row > 0] = log(row[row > 0])
out[row < 0] = -log(np.abs(row[row < 0]))
return pd.Series(out)
scores = scores.apply(axis=1, func=log_row)
# load custom colorbar normalizaition class
# https://stackoverflow.com/a/65260996
if logNorm_cbar == "centerLogNorm":
class MidpointLogNorm(matplotlib.colors.SymLogNorm):
"""
Normalise the colorbar so that diverging bars work there way either side from a prescribed midpoint value)
e.g. im=ax1.imshow(array, norm=MidpointNormalize(midpoint=0.,vmin=-100, vmax=100))
All arguments are the same as SymLogNorm, except for midpoint
"""
def __init__(self, lin_thres, lin_scale, midpoint=None, vmin=None, vmax=None):
self.midpoint = midpoint
self.lin_thres = lin_thres
self.lin_scale = lin_scale
#fraction of the cmap that the linear component occupies
self.linear_proportion = (lin_scale / (lin_scale + 1)) * 0.5
matplotlib.colors.SymLogNorm.__init__(self, lin_thres, lin_scale, vmin, vmax)
def __get_value__(self, v, log_val, clip=None):
if v < -self.lin_thres or v > self.lin_thres:
return log_val
x = [-self.lin_thres, self.midpoint, self.lin_thres]
y = [0.5 - self.linear_proportion, 0.5, 0.5 + self.linear_proportion]
interpol = np.interp(v, x, y)
return interpol
def __call__(self, value, clip=None):
log_val = matplotlib.colors.SymLogNorm.__call__(self, value)
out = [0] * len(value)
for i, v in enumerate(value):
out[i] = self.__get_value__(v, log_val[i])
return np.ma.masked_array(out)
########## define grid ##########
# https://matplotlib.org/stable/gallery/subplots_axes_and_figures/gridspec_nested.html
# grid: [legend_area, plot_area]
grid = matplotlib.gridspec.GridSpec(nrows=1,
ncols=2,
figure=fig,
width_ratios=[1, 10],
wspace=0.3)
# plot area
plot_grid = matplotlib.gridspec.GridSpecFromSubplotSpec(ncols=3,
nrows=1,
width_ratios=[0.5, 10, 0.5],
subplot_spec=grid[1],
wspace=0)
heatmap = fig.add_subplot(plot_grid[1])
# add flanks if needed
if flank_plot is not None:
strand1 = fig.add_subplot(plot_grid[0])
strand2 = fig.add_subplot(plot_grid[2])
# legend area
legend_grid = matplotlib.gridspec.GridSpecFromSubplotSpec(ncols=1, nrows=2, subplot_spec=grid[0])
# heatmap legend
heatmap_led = fig.add_subplot(legend_grid[0])
heatmap_led.set_title(legend_name_score)
# strand legend
if flank_plot is not None:
strand_led = fig.add_subplot(legend_grid[1])
strand_led.set_title("TF Binding Strand")
###### define plots ######
### main heatmap ###
# get normalization function
if logNorm_cbar == None:
norm = None
elif logNorm_cbar == "centerLogNorm":
norm = MidpointLogNorm(lin_thres=1, lin_scale=1, midpoint=0)
elif logNorm_cbar == "SymLogNorm":
norm = matplotlib.colors.SymLogNorm(linthresh=1, linscale=1)
# get center line position
if self._last_align == "center":
pos = pairs["site1_rel_end"][0] + pairs["site_distance"][0] // 2
elif self._last_align == "left":
pos = pairs["site1_rel_end"][0]
elif self._last_align == "right":
pos = pairs["site2_rel_start"][0]
plot = sns.heatmap(scores,
yticklabels=False,
xticklabels=False,
cmap=cmap,
center=None if logNorm_cbar else 0,
cbar=True,
cbar_ax=heatmap_led,
norm=norm,
ax=heatmap,
rasterized=True)
# set title
name1, name2 = set(pairs["site1_name"]).pop(), set(pairs["site2_name"]).pop()
heatmap.set_title(f"{name1} <-> {name2}")
# show evenly spaced xticks
if xtick_num and xtick_num > 0:
xtickpositions = np.linspace(0, len(scores.columns), xtick_num, dtype=int)
xticklabels = np.linspace(-pos, len(scores.columns) - pos, xtick_num, dtype=int)
# replace nearest tick to 0 with 0
tick_index = np.absolute(xticklabels).argmin()
xticklabels[tick_index], xtickpositions[tick_index] = 0, pos
plot.set_xticks(xtickpositions)
plot.set_xticklabels(xticklabels)
# center line
plot.vlines(x=pos,
ymin=0, ymax=len(scores),
linestyles="dashed",
color="black",
alpha=alpha,
linewidth=1)
# binding sites
if show_binding:
for y, (_, row) in enumerate(pairs.iterrows()):
# left sites
plot.add_patch(matplotlib.patches.Rectangle(
xy=(row["site1_rel_start"], y),
width=row["site1_rel_end"] - row["site1_rel_start"],
height=1,
alpha=alpha,
color="gray",
edgecolor=None,
lw=0,
rasterized=True
))
# right sites
plot.add_patch(matplotlib.patches.Rectangle(
xy=(row["site2_rel_start"], y),
width=row["site2_rel_end"] - row["site2_rel_start"],
height=1,
alpha=alpha,
color="gray",
edgecolor=None,
lw=0,
rasterized=True
))
# diagonal binding lines
if show_diagonal:
linecolor="black"
linestyle="solid"
# left
# start
plot.axline(
xy1=(pairs["site1_rel_start"].max(), 0),
xy2=(pairs["site1_rel_start"].min(), len(scores)),
color=linecolor,
alpha=alpha,
linestyle=linestyle)
# end
plot.axline(
xy1=(pairs["site1_rel_end"].max(), 0),
xy2=(pairs["site1_rel_end"].min(), len(scores)),
color=linecolor,
alpha=alpha,
linestyle=linestyle)
# right
# start
plot.axline(
xy1=(pairs["site2_rel_start"].min(), 0),
xy2=(pairs["site2_rel_start"].max(), len(scores)),
color=linecolor,
alpha=alpha,
linestyle=linestyle)
# end
plot.axline(
xy1=(pairs["site2_rel_end"].min(), 0),
xy2=(pairs["site2_rel_end"].max(), len(scores)),
color=linecolor,
alpha=alpha,
linestyle=linestyle)
# https://moonbooks.org/Articles/How-to-add-a-frame-to-a-seaborn-heatmap-figure-in-python-/
# make frame visible
for _, spine in plot.spines.items():
spine.set_visible(True)
if flank_plot is not None:
### strand left ###
#https://stackoverflow.com/a/57994641
col = ["site1_strand"] if flank_plot == "strand" else ["site_orientation"]
str_to_int = {j:i for i, j in enumerate(pd.unique(pairs[col[0]].values))}
cmap = sns.color_palette("tab10", len(str_to_int))
plot = sns.heatmap(pairs[col].replace(str_to_int),
yticklabels=False,
xticklabels=False,
cmap=cmap,
cbar=True,
cbar_ax=strand_led,
ax=strand1,
rasterized=True)
colorbar = plot.collections[0].colorbar
r = colorbar.vmax - colorbar.vmin
colorbar.set_ticks([colorbar.vmin + r / len(str_to_int) * (0.5 + i) for i in range(len(str_to_int))])
colorbar.set_ticklabels(list(str_to_int.keys()))
### strand right ###
col = ["site2_strand"] if flank_plot == "strand" else ["site_orientation"]
sns.heatmap(pairs[col].replace(str_to_int),
yticklabels=False,
xticklabels=False,
cmap=cmap,
cbar=False,
ax=strand2,
rasterized=True)
# save plot
if output:
plt.savefig(output)
# show plot
plt.show()
# close figure
plt.close()
return grid
[docs] def pairTrack(self, dist=None, start=None, end=None, ymin=None, ymax=None, ylabel="Bigwig signal", output=None, flank=None, align=None, figsize=(6, 4), dpi=70, _ret_param=False):
"""
Create an aggregated footprint track on the paired binding sites.
Either aggregate all sites for a specific distance or give a range of sites that should be aggregated.
If the second approach spans multiple distances the binding locations are shown as a range as well.
Parameters
----------
dist : int or int list, default None
Show track for one or more distances between binding sites.
start : int, default None
Define start of range of sites that should be aggregated. If set will ignore 'dist'.
end : int, default None
Define end of range of sites that should be aggregated. If set will ignore 'dist'.
ymin : int, default None
Y-axis minimum limit.
ymax : int, default None
Y-axis maximum limit.
ylabel : str, default 'Bigwig signal'
Label for the y-axis.
output : str, default None
Save plot to given file.
flank : int or int tuple, default None
Bases added to both sides counted from center. Forwarded to comp_plotting_tables().
align : str, default None
Alignment of pairs. One of ['left', 'right', 'center']. Forwarded to comp_plotting_tables().
figsize : int tuple, default (3, 3)
Figure dimensions.
dpi : float, default 70
The resolution of the figure in dots-per-inch.
_ret_param : bool, default False
Intended for internal animation use!
If True will cause the function to return a dict of function call parameters used to create plot.
Returns
----------
matplotlib.axes._subplots.AxesSubplot or dict:
Return axes object of the plot.
"""
# parameter check
if dist is None and start is None and end is None:
raise ValueError("Either set dist or start and end parameter!")
# compute plotting tables with custom flank
if not flank is None and flank != self._last_flank or not align is None and align != self._last_align:
params = {}
if flank: params["flank"] = flank
if align: params["align"] = align
self.comp_plotting_tables(**params)
# load data
pairs, scores = self.plotting_tables
# dict holding dict of all function calls to create plot
parameter = dict()
# get scores
if not start is None and not end is None:
if end > len(pairs):
raise ValueError(f"Out of range! Given {start}:{end} valid range is 0:{len(pairs)}")
tmp_pairs = pairs[start:end]
tmp_scores = scores[start:end]
dist = f"{tmp_pairs['site_distance'].min()} - {tmp_pairs['site_distance'].max()} | Selected: {start}:{end}"
else:
tmp_pairs = pairs[pairs["site_distance"].isin(dist if isinstance(dist, list) else [dist])]
tmp_scores = scores.loc[tmp_pairs.index]
# get pair names
# TODO sanity check for len 1
lname = set(tmp_pairs["site1_name"]).pop()
rname = set(tmp_pairs["site2_name"]).pop()
# compute x axis range (0 centered)
if self._last_align == "center":
center = pairs["site1_rel_end"][0] + pairs["site_distance"][0] // 2
elif self._last_align == "left":
center = pairs["site1_rel_end"][0]
elif self._last_align == "right":
center = pairs["site2_rel_start"][0]
xmin, xmax = -self._last_flank[0], self._last_flank[1]
# compute y axis range + 10% padding
points = tmp_scores.mean()
if ymin is None:
ymin = points.min()
ymin += ymin * 0.1
if ymax is None:
ymax = points.max()
ymax += ymax * 0.1
##### plot #####
fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
# add aggregated line
parameter["set"] = {"ylim": (ymin, ymax),
"xlim": (xmin, xmax),
"ylabel": ylabel,
"xlabel": "Basepair",
"title": f"Distance: {dist} | Sites aggregated: {len(tmp_pairs)}\n{lname} <--> {rname}"}
ax.set(**parameter["set"])
parameter["plot"] = [list(range(xmin, xmax)), points]
ax.plot(*parameter["plot"])
parameter["vlines"] = {"x": 0,
"ymin": ymin,
"ymax": ymax,
"linestyles": "dashed",
"color": "black",
"alpha": 1,
"linewidth": 1}
# add center line
ax.vlines(**parameter["vlines"])
# add binding site locations
start = tmp_pairs["site1_rel_start"].min() - center
end = tmp_pairs["site1_rel_end"].max() - center
parameter["patches.Rectangle"] = [{"xy": (start, ymin),
"width": end-start,
"height": np.abs(ymin-ymax),
"color": 'tab:red',
"alpha": 0.2}]
ax.add_patch(matplotlib.patches.Rectangle(**parameter["patches.Rectangle"][0]))
start = tmp_pairs["site2_rel_start"].min() - center
end = tmp_pairs["site2_rel_end"].max() - center
parameter["patches.Rectangle"].append({"xy": (start, ymin),
"width": end-start,
"height": np.abs(ymin-ymax),
"color": 'tab:green',
"alpha": 0.2})
ax.add_patch(matplotlib.patches.Rectangle(**parameter["patches.Rectangle"][1]))
# save plot
if output:
plt.savefig(output)
# show plot if not in parameter mode
if not _ret_param:
plt.show()
# close figure
plt.close()
if _ret_param:
return parameter
else:
return ax
def pairTrackAnimation(self, site_num=None, step=10, ymin=None, ymax=None, ylabel="Bigwig signal", interval=50, repeat_delay=0, repeat=False, output=None, flank=None, align=None, figsize=(6, 4), dpi=70):
    """
    Combine a set of pairTrack plots to a .gif.

    Note
    --------
    The memory limit can be increased with the following if necessary. Default is 20 MB.
    matplotlib.rcParams['animation.embed_limit'] = 100 # in MB

    Parameters
    ----------
    site_num : int, default None
        Number of sites to aggregate for every step. If None will aggregate by distance between binding pair.
    step : int, default 10
        Step size between aggregations. Will be ignored if site_num=None.
    ymin : int, default None
        Y-axis minimum limit. If None, the minimum over all frames with 10% padding is used.
    ymax : int, default None
        Y-axis maximum limit. If None, the maximum over all frames with 10% padding is used.
    ylabel : str, default 'Bigwig signal'
        Label for the y-axis.
    interval : int, default 50
        Delay between frames in milliseconds
    repeat_delay : int, default 0
        The delay in milliseconds between consecutive animation runs, if repeat is True.
    repeat : boolean, default False
        Whether the animation repeats when the sequence of frames is completed.
    output : str, default None
        Save plot to given file.
    flank : int or int tuple, default None
        Bases added to both sides counted from center. Forwarded to comp_plotting_tables().
    align : str, default None
        Alignment of pairs. One of ['left', 'right', 'center']. Forwarded to comp_plotting_tables().
    figsize : int tuple, default (6, 4)
        Figure dimensions.
    dpi : float, default 70
        The resolution of the figure in dots-per-inch.

    Returns
    ----------
    IPython.core.display.HTML:
        Gif object ready to display in a jupyter notebook.
    """

    # compute plotting tables with custom flank
    # comp_plotting_tables() stores the flank as tuple, so an int has to be compared in tuple form.
    # Otherwise every call with an int flank would read the bigwig again.
    flank_key = flank if flank is None or isinstance(flank, tuple) else (flank, flank)
    flank_changed = flank is not None and flank_key != self._last_flank
    align_changed = align is not None and align != self._last_align

    # also recompute if the cache was reset, so the given flank/align are not replaced by the defaults
    if not self._plotting_tables or flank_changed or align_changed:
        params = {}
        if flank: params["flank"] = flank
        if align: params["align"] = align

        self.comp_plotting_tables(**params)

    # load data
    pairs, scores = self.plotting_tables

    # prepare plots for animation
    parameter_list = list()
    if site_num:
        # (start, end) of every aggregation window; the last window is cut at the end of the table
        windows = [(s, min(s + site_num, len(pairs))) for s in range(0, len(pairs), step)]

        # compute animation y-range + 10% padding
        # abs() makes sure the padding always widens the range, independent of the sign of the limit
        if ymin is None:
            ymin = np.min([scores[w_start:w_end].mean().min() for w_start, w_end in windows])
            ymin -= abs(ymin) * 0.1
        if ymax is None:
            ymax = np.max([scores[w_start:w_end].mean().max() for w_start, w_end in windows])
            ymax += abs(ymax) * 0.1

        pbar = tqdm.tqdm(total=len(windows) + 1)
        pbar.set_description("Create frames")
        for w_start, w_end in windows:
            pbar.update()

            parameter_list.append(self.pairTrack(start=w_start,
                                                 end=w_end,
                                                 ymin=ymin,
                                                 ymax=ymax,
                                                 ylabel=ylabel,
                                                 _ret_param=True
                                                 )
                                  )
    else:
        # sorted, so that the frames are shown with increasing distance (a set has no defined order)
        distances = sorted(set(pairs["site_distance"]))

        # compute animation y-range + 10% padding
        # abs() makes sure the padding always widens the range, independent of the sign of the limit
        if ymin is None:
            ymin = np.min([scores.loc[pairs[pairs["site_distance"] == d].index].mean().min() for d in distances])
            ymin -= abs(ymin) * 0.1
        if ymax is None:
            ymax = np.max([scores.loc[pairs[pairs["site_distance"] == d].index].mean().max() for d in distances])
            ymax += abs(ymax) * 0.1

        pbar = tqdm.tqdm(total=len(distances) + 1)
        pbar.set_description("Create frames")
        for d in distances:
            pbar.update()

            parameter_list.append(self.pairTrack(dist=d,
                                                 ymin=ymin,
                                                 ymax=ymax,
                                                 ylabel=ylabel,
                                                 _ret_param=True
                                                 )
                                  )

    # update to be at 100% then close progress bar
    pbar.update()
    pbar.close()

    ##### Setup animation #####
    # setup figure to draw on
    fig, axes = plt.subplots(figsize=figsize, dpi=dpi)
    line, = axes.plot([])
    axes.add_patch(matplotlib.patches.Rectangle(xy=(0, 0), width=0, height=0))
    axes.add_patch(matplotlib.patches.Rectangle(xy=(0, 0), width=0, height=0))

    # progress bar
    pbar = tqdm.tqdm(total=len(parameter_list) + 1)
    pbar.set_description("Render animation")

    # create animation function (this decides what is drawn in every frame)
    def animate(i):
        """
        Draw all elements for a given frame index (i).
        """
        # update progress bar
        pbar.update()

        calls = parameter_list[i]

        for key, params in calls.items():
            if key == "plot":
                line.set_xdata(params[0])
                line.set_ydata(params[1])
            elif key == "patches.Rectangle":
                # remove old patches
                # Artist.remove() is used because newer matplotlib versions do not allow modifying axes.patches directly
                for old_patch in list(axes.patches):
                    old_patch.remove()

                # add new patches
                for p in params:
                    axes.add_patch(matplotlib.patches.Rectangle(**p))
            elif isinstance(params, dict):
                # key is the name of an axes method (e.g. "set", "vlines"), params its keyword arguments
                getattr(axes, key)(**params)

    ##### Run animation #####
    anim_created = matplotlib.animation.FuncAnimation(fig,
                                                      animate,
                                                      frames=len(parameter_list),
                                                      interval=interval,
                                                      repeat_delay=repeat_delay,
                                                      repeat=repeat
                                                      )

    # save animation
    if output:
        anim_created.save(output, dpi=dpi)
        # saving rendered all frames once; restart the progress bar for the html rendering
        pbar.reset()

    # prepare output
    video = anim_created.to_jshtml()#to_html5_video()
    html = display.HTML(video)

    # close figure & progress bar
    plt.close()
    pbar.close()

    return html
def pairLines(self, x, y, figsize=(6, 4), dpi=70, output=None):
    """
    Compare miscellaneous values between TF-pair.

    Parameters
    ----------
    x : string
        Data to show on the x-axis. Set None to get a list of options.
    y : string
        Data to show on the y-axis. Set None to get a list of options.
    figsize : int tuple, default (6, 4)
        Figure dimensions.
    dpi : float, default 70
        The resolution of the figure in dots-per-inch.
    output : str, default None
        Save plot to given file.

    Returns
    ----------
    matplotlib.axes._subplots.AxesSubplot:
        Return axes object of the plot. None is returned if only the
        available x/y options were printed.

    Raises
    ----------
    Exception
        If no column of the table ends with '_<x>', '_<y>' or '_name'.
    """

    # TODO expose as parameter
    # would need checks for datatype!
    hue = "name"

    table = self.as_table()

    def _options():
        # column names are '<site>_<value>'; the part after the first '_' is what can be plotted
        return set(sl[1] for sl in table.columns.str.split('_', n=1).to_list())

    # print x, y options
    if not x or not y:
        print(f"x, y options: {_options()}")
        return

    # sort table to always start with the same TF
    if len(set(table["site1_name"])) > 1:
        # positional access: the first row defines the reference TF regardless of the index labels
        tf = table["site1_name"].iloc[0]
        rf = table["site1_name"] == tf

        site1 = table.columns[table.columns.str.startswith("site1")]
        site2 = table.columns[table.columns.str.startswith("site2")]

        # swap the site1/site2 columns for the selected rows
        table.loc[rf, site1.append(site2)] = table.loc[rf, site2.append(site1)].values

    # fetch column names (doubled so that index 1 exists even if only one column matches)
    x_names = list(table.columns[table.columns.str.endswith("_" + x)]) * 2
    y_names = list(table.columns[table.columns.str.endswith("_" + y)]) * 2
    hue_names = list(table.columns[table.columns.str.endswith("_" + hue)]) * 2
    tmp_postfix = len(set(hue_names)) >= 2

    if not x_names:
        raise Exception(f"Could not find x='{x}'. Available are {_options()}.")
    if not y_names:
        raise Exception(f"Could not find y='{y}'. Available are {_options()}.")
    if not hue_names:
        raise Exception(f"Could not find hue='{hue}'. Available are {_options()}.")

    # collect data
    x1, x2 = table[x_names[0]], table[x_names[1]]
    y1, y2 = table[y_names[0]], table[y_names[1]]
    hue1, hue2 = table[hue_names[0]], table[hue_names[1]]

    # temporary postfix keeps the two sites apart even if both carry the same name
    if tmp_postfix:
        hue1, hue2 = hue1 + "_1", hue2 + "_2"

    # stack both sites below each other (pd.concat replaces the removed Series.append)
    x_all = pd.concat([x1, x2])
    y_all = pd.concat([y1, y2])
    hue_all = pd.concat([hue1, hue2])

    ##### plotting #####
    plt.figure(figsize=figsize, dpi=dpi)

    plot = sns.lineplot(x=x_all.values,
                        y=y_all.values,
                        hue=hue_all.values)

    # rotate x-ticks for non numeric data
    if not pd.api.types.is_numeric_dtype(x_all):
        plt.xticks(rotation=90)

    # remove postfix in legend
    if tmp_postfix:
        handles, labels = plot.get_legend_handles_labels()
        plot.legend(handles=handles, labels=[l[:-2] for l in labels])

    # set axis labels
    plot.set(ylabel=y, xlabel=x)

    # save plot
    if output:
        plt.savefig(output)

    # show then close figure
    plt.show()
    plt.close()

    return plot
def set_orientation(self, simplify=False):
    """ Fill the orientation of every TF pair held by this object.

    Parameters
    -----------
    simplify : bool
        Passed on unchanged to each pair's '_set_orientation'. Default: False.
    """

    for tfbs_pair in self:
        tfbs_pair._set_orientation(simplify=simplify)
def plot_distances(self, groupby="orientation", figsize=None, group_order=None):
    """ Plot the distribution of distances between TFBS-pairs.

    Parameters
    -----------
    groupby : str
        An attribute of each pair to group distances by. If None, all distances are shown without grouping. Default: "orientation".
    figsize : tuple of ints
        Set the figure size, e.g. (8,10). Default: None (default matplotlib figuresize).
    group_order : list, optional
        The groups to plot, in plotting order (top to bottom). Default: None (all groups in order of first appearance).

    Returns
    --------
    list or array of matplotlib axes
        One axes object per plotted group.

    Raises
    -------
    InputError
        If 'group_order' contains a group which was not found for any pair.
    """

    #Collect distances per group
    distances = {}
    if groupby is not None:
        for pair in self:
            distances.setdefault(getattr(pair, groupby), []).append(pair.distance)
    else:
        distances["all"] = [pair.distance for pair in self]

    #Decide which groups to plot
    if group_order is None:
        group_order = list(distances.keys())
    else:
        group_order = list(group_order)
        not_found = [group for group in group_order if group not in distances]
        if len(not_found) > 0:
            raise InputError("Groups '{0}' given in 'group_order' are not found. Available groups are: {1}".format(not_found, list(distances.keys())))

    #Setup figure with one row per plotted group
    n_groups = len(group_order)
    fig, axarr = plt.subplots(n_groups, sharex=True, sharey=True, figsize=figsize)

    #adjust for one-group plotting
    if n_groups == 1:
        axarr = [axarr] #make axarr subscriptable

    #Plot distances per group
    for i, group in enumerate(group_order):
        lst = distances[group]
        axarr[i].hist(lst, bins=int(max(lst)) + 1)
        axarr[i].set_ylabel("Count")
        axarr[i].text(1, 0.5, f"({group})",
                      horizontalalignment='left',
                      verticalalignment='center',
                      transform=axarr[i].transAxes)

    #Make final adjustments
    for ax in axarr:
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)

    _ = axarr[-1].set_xlabel("Distance")

    return axarr
#------------------------------ Notebook / script exceptions -----------------------------#
def _is_notebook():
    """ Utility to check if function is being run from a notebook or a script.

    The check is whether the name 'get_ipython' can be called in the current
    environment; a NameError is taken to mean a plain script.
    """

    try:
        get_ipython()
    except NameError:
        return(False)

    return(True)
class StopExecution(Exception):
    """ Stop execution of a notebook cell with error message """

    def _render_traceback_(self):
        """ Write only the error message to stderr (no traceback).

        NOTE(review): presumably this is the hook IPython looks up to render
        an exception - confirm against the IPython documentation.
        """

        # The message is taken from the exception itself rather than from
        # sys.exc_info(), so rendering also works outside an active 'except' block.
        sys.stderr.write(str(self))
def check_module(module):
    """ Check if <module> can be imported without error.

    Parameters
    ------------
    module : str
        Name of the module to import.

    Returns
    --------
    bool
        True if the module could be imported. If the module is not found, execution is
        stopped (StopExecution when _is_notebook() is True, otherwise sys.exit) with an error message.
        Any other error raised while importing is passed on unchanged.
    """

    found = True
    try:
        importlib.import_module(module)
    except ModuleNotFoundError:
        found = False

    if found:
        return(True)

    #Write out error if module was not found
    s = f"ERROR: Could not find the '{module}' module on path. This module is needed for this functionality. Please install this package to proceed."
    if _is_notebook():
        raise StopExecution(s) from None

    sys.exit(s)
#--------------------------------- File/type checks ---------------------------------#
def check_columns(df, columns):
    """ Utility to check whether columns are found within a pandas dataframe.

    Parameters
    ------------
    df : pandas.DataFrame
        A pandas dataframe to check.
    columns : list
        A list of column names to check for within 'df'. Entries which are None are skipped.

    Raises
    --------
    InputError
        If any of the columns are not in 'df'.
    """

    df_columns = df.columns

    #Collect every requested (non-None) column which is missing from df
    not_found = [column for column in columns
                 if column is not None and column not in df_columns]

    if not_found:
        error_str = "Columns '{0}' are not found in dataframe. Available columns are: {1}".format(not_found, df_columns)
        raise InputError(error_str)
def check_dir(dir_path, create=True):
    """ Check that a path can be used as a directory, creating it if needed.

    Parameters
    ------------
    dir_path : str
        A path to a directory. If None, nothing is checked.
    create : bool, optional
        Whether to create the directory (including parents) if it does not exist. Default: True.

    Raises
    --------
    InputError
        If dir_path exists, but is not a directory.
    """

    #don't check path given as None; assume that this is taken care of elsewhere
    if dir_path is None:
        return

    #An existing path must be a directory (and not a file)
    if os.path.exists(dir_path):
        if not os.path.isdir(dir_path):
            raise InputError("Path '{0}' is not a directory".format(dir_path))
        return

    #Path does not exist yet
    if create:
        pathlib.Path(dir_path).mkdir(parents=True, exist_ok=True)
def check_writeability(file_path):
    """ Check if a file is writeable.

    Parameters
    ------------
    file_path : str
        A path to a file.

    Raises
    --------
    InputError
        If file_path is not a string, if it exists but is not a file, or if
        it does not exist and its parent directory is not writeable.
    """

    check_type(file_path, str)

    problem = None
    if not os.path.exists(file_path):
        #File is not there yet; check writeability of parent dir
        parent = os.path.dirname(file_path) or "."  #empty dirname means the current folder
        if not os.access(parent, os.W_OK):
            problem = "Parent directory '{0}' is not writeable".format(parent)

    elif not os.path.isfile(file_path):  #path exists; is it a file or a dir?
        problem = "Path '{0}' is not a file".format(file_path)

    #If any errors were found
    if problem is not None:
        raise InputError(problem)
def check_type(obj, allowed, name=None):
    """
    Check whether given object is within a list of allowed types.

    Parameters
    ------------
    obj : object
        Object to check type on
    allowed : type or list of types
        A type or a list of object types to be allowed
    name : str, optional
        Name of object to be written in error. Default: None (the input is referred to as 'object')

    Raises
    --------
    InputError
        If object type is not within types.
    """

    #A single type is wrapped into a list
    allowed = allowed if isinstance(allowed, list) else [allowed]

    #Test the object against every allowed type
    matches = [isinstance(obj, t) for t in allowed]

    #Raise error if none of the types fit
    if not any(matches):
        name = "object" if name is None else f'\'{name}\''
        raise InputError("The {0} given has type '{1}', but must be one of: {2}".format(name, type(obj), allowed))
def check_string(astring, allowed, name=None):
    """
    Check whether given string is within a list of allowed strings.

    Parameters
    ------------
    astring : str
        A string to check.
    allowed : str or list of strings
        A string or list of allowed strings to check against 'astring'.
    name : str, optional
        The name of the string to be written in error. Default: None (the value is referred to as 'string').

    Raises
    --------
    InputError
        If 'astring' is not in 'allowed'.
    """

    #A single allowed value is wrapped into a list
    allowed = allowed if isinstance(allowed, list) else [allowed]

    #Nothing to do if astring is valid
    if astring in allowed:
        return

    name = "string" if name is None else f'\'{name}\''
    raise InputError("The {0} given ({1}) is not valid - it must be one of: {2}".format(name, astring, allowed))
def check_value(value, vmin=-math.inf, vmax=math.inf, integer=False, name=None):
    """
    Check whether given 'value' is a valid value (or integer) and if it is within the bounds of vmin/vmax.

    Parameters
    -------------
    value : int or float
        The value to check.
    vmin : int or float, optional
        Minimum the value is allowed to be. Default: -infinity (no bound)
    vmax : int or float
        Maxmum the value is allowed to be. Default: +infinity (no bound)
    integer : bool, optional
        Whether value must be an integer (Python int or numpy integer). Default: False (value can be float)
    name : str, optional
        The name of the value to be written in error. Default: None (the value is referred to as 'value').

    Raises
    --------
    InputError
        If 'value' is not a valid value as given by parameters.
    """

    if vmin > vmax:
        raise InputError("vmin must be smaller than vmax")

    #Without a name, the value is referred to as 'value' in error messages (as documented)
    name = "value" if name is None else name

    error_msg = None
    if integer:
        #numpy integers (e.g. from array lengths/indexing) count as integers as well
        if not isinstance(value, (int, np.integer)):
            error_msg = "The {0} given ({1}) is not an integer, but integer is set to True.".format(name, value)
    else:
        #check if value is any value
        try:
            _ = int(value)
        except (TypeError, ValueError, OverflowError):
            error_msg = "The {0} given ({1}) is not a valid number".format(name, value)

    #If value is a number, check if it is within bounds
    if error_msg is None:
        if not (vmin <= value <= vmax):
            error_msg = "The {0} given ({1}) is not within the bounds of [{2};{3}]".format(name, value, vmin, vmax)

    #Finally, raise error if necessary:
    if error_msg is not None:
        raise InputError(error_msg)
def random_string(l=8):
    """ Get a random string of length l, drawn from the uppercase ASCII letters """

    alphabet = string.ascii_uppercase
    letters = [random.choice(alphabet) for _ in range(l)]

    return("".join(letters))
#--------------------------------- Multiprocessing ---------------------------------#
class Progress():

    def __init__(self, n_total, n_print, logger):
        """
        Utility class to monitor progress of a list of tasks.

        Parameters
        -----------
        n_total : int
            Number of total jobs
        n_print : int
            Number of times to write progress. If there are fewer jobs than this,
            progress is written once per job.
        logger : logger instance
            The logger to use for writing progress
        """

        self.n_total = n_total
        self.n_print = n_print
        self.logger = logger

        #At what number of tasks should the updates be written?
        #The step is at least one task, as a step of 0 would collapse all limits to 0.
        n_steps = max(int(n_print), 0)
        n_step = max(1, int(n_total / n_steps)) if n_steps > 0 else 1
        limits = [n_step*(i+1) for i in range(n_steps)]
        self.progress_steps = [limit for limit in limits if limit <= n_total] #limits above n_total can never be reached

        #first limit in progress_steps to write (np.inf if there is nothing to report)
        self.next = self.progress_steps[0] if len(self.progress_steps) > 0 else np.inf

    def write_progress(self, n_done):
        """ Log the progress of the current tasks.

        Parameters
        -----------
        n_done : int
            Number of tasks done (of n_total tasks)
        """

        if n_done >= self.next:
            self.logger.info("Progress: {0:.0f}%".format(n_done/self.n_total*100.0))

            #in case more than one progress step was jumped
            remaining_steps = [step for step in self.progress_steps if n_done < step] + [np.inf]

            self.next = remaining_steps[0] #this is the next idx to write (or np.inf if end of list was reached)
def log_progress(jobs, logger, n=10):
    """
    Log progress of jobs within job list. Blocks until all jobs are ready.

    Parameters
    ------------
    jobs : list
        List of multiprocessing jobs to write progress for. Each job must offer a .ready() method.
    logger : logger instance
        A logger to use for writing out progress.
    n : int, optional
        Maximum number of progress statements to show. Default: 10.

    Returns
    --------
    int
        Always 0; returned once all jobs are ready.
    """

    def count_finished():
        #number of jobs reporting themselves as ready
        return sum([task.ready() for task in jobs])

    #Setup progress obj
    n_tasks = len(jobs)
    progress = Progress(n_tasks, n, logger)

    #Poll the jobs until everything is done
    n_done = count_finished()
    while n_done != n_tasks:
        progress.write_progress(n_done)
        time.sleep(0.1)
        n_done = count_finished()

    logger.info("Finished!")

    return(0)
#--------------------------------------- Motif / TFBS scanning and processing ---------------------------------------#
def prepare_motifs(motifs_file, motif_pvalue=0.0001, motif_naming="name"):
    """ Read motifs from motifs_file and set threshold/name.

    Parameters
    ------------
    motifs_file : str
        Path to the file to read motifs from.
    motif_pvalue : float, optional
        Value passed to each motif's get_threshold(). Default: 0.0001.
    motif_naming : str, optional
        Value passed to each motif's set_prefix(). Default: "name".

    Returns
    --------
    The motif list read via MotifList().from_file().
    """

    # Read motifs
    motifs_obj = MotifList().from_file(motifs_file)

    # Set thresholds for all motifs
    for motif in motifs_obj:
        motif.get_threshold(motif_pvalue)

    # Set naming for all motifs (using naming from args)
    for motif in motifs_obj:
        motif.set_prefix(motif_naming)

    # If motif_naming is "name", but motifs do not have a name, use the motif ID as name
    if motif_naming == "name":
        for motif in motifs_obj:
            if motif.prefix == "":
                motif.prefix = motif.id

    return motifs_obj
def open_genome(genome_f):
    """ Open a fasta file as a genome object for fetching sequences.

    Parameters
    ------------
    genome_f : str
        The path to a fasta file.

    Returns
    ---------
    pysam.FastaFile
    """

    return(pysam.FastaFile(genome_f))
def open_bigwig(bigwig_f):
    """ Open a bigwig file using pyBigWig.

    Parameters
    ------------
    bigwig_f : str
        The path to a bigwig file.

    Returns
    ---------
    The object returned by pyBigWig.open().
    """

    return(pyBigWig.open(bigwig_f))
def check_boundaries(regions, genome):
    """ Utility to check whether regions are within the boundaries of genome.

    Parameters
    -----------
    regions : tobias.utils.regions.RegionList
        A RegionList() object containing regions to check.
    genome : pysam.FastaFile
        An object (e.g. from open_genome()) to use as reference.

    Raises
    -------
    InputError
        If the chromosome of a region is not in the genome, or if the region
        starts before 0 / ends after the length of its chromosome.
    """

    chromosomes = genome.references
    genome_bounds = dict(zip(chromosomes, genome.lengths))  #chromosome name -> length

    for region in regions:

        #Chromosome must be known to the genome
        if region.chrom not in chromosomes:
            raise InputError("Region '{0} {1} {2} {3}' is not present in the given genome. Available chromosomes are: {4}.".format(region.chrom, region.start, region.end, region.name, chromosomes))

        #Coordinates must be within the chromosome
        chrom_length = genome_bounds[region.chrom]
        if region.start < 0 or region.end > chrom_length:
            raise InputError("Region '{0} {1} {2} {3}' is out of bounds in the given genome. The length of the chromosome is: {4}".format(region.chrom, region.start, region.end, region.name, chrom_length))
def unique_region_names(regions):
    """
    Get a list of unique region names within regions.

    Parameters
    -----------
    regions : tobias.utils.regions.RegionList
        A RegionList() object containing regions with .name attributes.

    Returns
    --------
    list
        The list of sorted names from regions.
    """

    unique_names = {region.name for region in regions}

    return(sorted(unique_names))
def calculate_TFBS(regions, motifs, genome, resolve="merge"):
    """
    Multiprocessing-safe function to scan for motif occurrences

    Parameters
    ----------
    regions : RegionList()
        A RegionList() object of regions
    motifs : object
        Motifs to scan with; must offer a .scan_sequence(seq, region) method
        (presumably a tobias MotifList, e.g. from prepare_motifs() - verify against caller).
    genome : str or genome object
        If string, genome will be opened via open_genome() and closed again before returning.
        Otherwise, the object is used as is (must offer .fetch(chrom, start, end)) and is left open.
    resolve : str
        How to resolve overlapping sites from the same TF. Must be one of "off", "highest_score" or "merge". If "highest_score", the highest scoring overlapping site is kept.
        If "merge", the sites are merged, keeping the information of the first site. If "off", overlapping TFBS are kept. Default: "merge".

    Returns
    ----------
    List of TFBS within regions

    Raises
    --------
    InputError
        If 'resolve' is not one of the allowed strings.
    """

    check_string(resolve, ["merge", "highest_score", "off"], "resolve")

    #open the genome given
    opened_here = isinstance(genome, str)
    genome_obj = open_genome(genome) if opened_here else genome

    try:
        TFBS_list = RegionList()
        for region in regions:
            seq = genome_obj.fetch(region.chrom, region.start, region.end)
            region_TFBS = motifs.scan_sequence(seq, region)

            #Convert RegionLists to TFBS class
            TFBS_list += RegionList([OneTFBS(site) for site in region_TFBS])

    finally:
        #only close handles which were opened within this function - also if scanning failed
        if opened_here:
            genome_obj.close()

    #Sort all sites
    TFBS_list.loc_sort()

    #Resolve overlapping
    if resolve != "off":
        TFBS_list = resolve_overlaps(TFBS_list, how=resolve)

    return(TFBS_list)
def resolve_overlaps(sites, how="merge", per_name=True):
    """
    Resolve overlapping sites within a list of genomic regions.

    Parameters
    ------------
    sites : RegionList
        A list of TFBS/regions with .chrom, .start, .end and .name information.
        NOTE(review): the overlap test only compares each site to the last kept site, so sites are presumably expected to be sorted by location - confirm.
    how : str
        How to resolve the overlapping site. Must be one of "highest_score", "merge". If "highest_score", the highest scoring overlapping site is kept.
        If "merge", the sites are merged, keeping the information of the first site. Default: "merge".
    per_name : bool
        Whether to resolve overlapping only per name or across all sites. If 'True' overlaps are only resolved if the name of the sites are equal.
        If 'False', overlaps are resolved across all sites. Default: True.

    Returns
    ---------
    list
        The sites left after resolving overlaps, in input order.
    """

    check_string(how, ["highest_score", "merge"], "how")

    #Result slots (one per input site); discarded sites are set to None
    kept = [None] * len(sites)
    last_kept = {}  #index of the most recent kept site per name (or across all, using the key ".")

    for idx, site in enumerate(sites):

        kept[idx] = site  #might change again during merging
        key = site.name if per_name == True else "."  #control which site to fetch as 'previous'

        #First site for this key; nothing to compare against
        if key not in last_kept:
            last_kept[key] = idx
            continue

        prev_idx = last_kept[key]
        prev = kept[prev_idx]

        #No overlap with previous; this site becomes the new reference
        overlapping = (site.chrom == prev.chrom) and (site.start < prev.end)
        if not overlapping:
            last_kept[key] = idx
            continue

        #How to deal with overlap:
        if how == "highest_score":
            if site.score >= prev.score:
                #keep current site and track it from now on
                kept[prev_idx] = None
                last_kept[key] = idx
            else:
                #keep previous site; tracking stays the same
                kept[idx] = None

        elif how == "merge":
            #merge site into the previous; keep previous name/score/strand
            kept[prev_idx] = OneTFBS([site.chrom,
                                      prev.start,
                                      max([prev.end, site.end]),
                                      prev.name,
                                      prev.score,
                                      prev.strand])
            kept[idx] = None  #tracking stays the same

    return([site for site in kept if site is not None])
def add_region_overlap(a, b, att="overlap"):
    """ Overlap regions in regionlist 'a' with regions from regionlist 'b' and add
    a boolean attribute to the regions in 'a' containing overlap status with 'b'.

    Parameters
    ------------
    a : list of OneTFBS objects
        A list of objects containing genomic locations. The attribute is set on these objects.
    b : list of OneTFBS objects
        A list of objects containing genomic locations to overlap with 'a' regions.
    att : str, optional
        The name of the attribute to add to 'a' objects. Default: "overlap".

    Returns
    ---------
    list
        The regions of 'a' sorted by (chrom, start, end). The input lists themselves are not reordered.
    """

    def location(region):
        return (region.chrom, region.start, region.end)

    #Sorted copies; the objects within are shared with the input lists
    a = sorted(a, key=location)
    b = sorted(b, key=location)

    #Start from "no overlap" for all sites, so that a value left over from
    #an earlier call can never be reported as an overlap with this 'b'
    for site_a in a:
        setattr(site_a, att, False)

    ## Find overlap yes/no
    a_n = len(a)
    b_n = len(b)
    a_i = 0
    b_i = 0
    while a_i < a_n and b_i < b_n:

        a_chrom, a_start, a_end = a[a_i].chrom, a[a_i].start, a[a_i].end
        b_chrom, b_start, b_end = b[b_i].chrom, b[b_i].start, b[b_i].end

        if a_chrom == b_chrom:
            if a_end <= b_start:      #current a is placed before current b
                a_i += 1
            elif a_start >= b_end:    #current a is placed after current b
                b_i += 1
            else:                     #a region overlaps b region
                setattr(a[a_i], att, True)
                a_i += 1              #see if next a also overlaps this b

        #Both lists are sorted by chromosome name, so names can be compared directly
        elif a_chrom > b_chrom:       #a_chrom is after current b_chrom
            b_i += 1
        else:                         #b_chrom is after current a_chrom
            a_i += 1

    return(a)
#--------------------------------- Background calculation ---------------------------------#
def shuffle_array(arr, seed=1):
    """ Shuffle the order of elements along the first axis (rows) of an array.

    Parameters
    -----------
    arr : np.array
        The array to shuffle.
    seed : int, optional
        Seed for the shuffling. The same seed always gives the same order. Default: 1.

    Returns
    --------
    np.array
        A new array containing the rows of 'arr' in shuffled order.
    """

    #A local generator gives the same permutation as seeding the global numpy
    #generator, but leaves the global random state of the caller untouched.
    rng = np.random.RandomState(seed)
    order = rng.permutation(arr.shape[0])

    return(arr[order])
def shuffle_sites(sites, seed=1):
    """ Shuffle TFBS names to existing positions and updates lengths of the new positions.

    Parameters
    -----------
    sites : np.array
        An array of sites in shape (n_sites,4), where each row is a site and columns correspond to chromosome, start, end, name.
    seed : int, optional
        Seed passed to shuffle_array(). Default: 1.

    Returns
    --------
    An array containing shuffled names with site lengths corresponding to original length of sites.
    """

    #Append the length of each site as an extra (last) column; this creates a new array
    site_lengths = sites[:,2] - sites[:,1]
    extended = np.c_[sites, site_lengths]

    #Shuffle names together with their lengths (the last two columns)
    extended[:,-2:] = shuffle_array(extended[:,-2:], seed)

    #Re-center each shuffled site on the midpoint of the position it was assigned to
    midpoints = extended[:,1] + ((extended[:,2] - extended[:,1])/2)
    extended[:,1] = midpoints - extended[:,-1]/2        #new start = midpoint - new half length
    extended[:,2] = extended[:,1] + extended[:,-1]      #new end = new start + new length

    #Drop the length column again
    return(extended[:,:-1])
def calculate_background(sites, seed=1, directional=False, **kwargs):
    """
    Wrapper to shuffle sites and count co-occurrence of the shuffled sites.

    Parameters
    ------------
    sites : np.array
        An array of sites in shape (n_sites,4), where each row is a site and columns correspond to chromosome, start, end, name.
    seed : int, optional
        Seed for shuffling sites. Default: 1.
    directional : bool
        Decide if direction of found pairs should be taken into account. Default: False.
    kwargs : arguments
        Additional arguments for count_co_occurrence

    Returns
    --------
    The pair counts of the shuffled sites (made symmetric if directional is False).
    """

    #Shuffle sites and count co-occurrence on the shuffled positions
    shuffled = shuffle_sites(sites, seed=seed)
    _, pair_counts = count_co_occurrence(shuffled, **kwargs)

    #Deal with directionality
    if directional == False:
        pair_counts = tfcomb.utils.make_symmetric(pair_counts)

    return(pair_counts)
#--------------------------------- Thresholding ---------------------------------#
def get_threshold(data, which="upper", percent=0.05, _n_max=10000, verbosity=0, plot=False):
    """
    Function to get upper/lower threshold(s) based on the distribution of data. The threshold is calculated as the probability of "percent" (upper=1-percent).

    Parameters
    ------------
    data : list or array
        An array of data to find threshold on. Infinite values are ignored.
    which : str
        Which threshold to calculate. Can be one of "upper", "lower", "both". Default: "upper".
    percent : float between 0-1
        Controls how strict the threshold should be set in comparison to the distribution. Default: 0.05.
    _n_max : int, optional
        Maximum number of datapoints to use. Larger data is randomly subset (with a fixed seed) to this size. Default: 10000.
    verbosity : int, optional
        Verbosity passed to the TFcombLogger used within the function. Default: 0.
    plot : bool, optional
        Whether to plot the histogram of data together with the best fit and threshold(s). Default: False.

    Returns
    ---------
    If which is one of "upper"/"lower", get_threshold returns a float. If "both", get_threshold returns a tuple of two float thresholds (lower, upper).

    Raises
    --------
    InputError
        If 'which' or 'percent' are not valid.
    ValueError
        If none of the distributions could be fit to the data.
    """

    distributions = [scipy.stats.norm, scipy.stats.lognorm, scipy.stats.laplace,
                     scipy.stats.expon, scipy.stats.truncnorm, scipy.stats.truncexpon, scipy.stats.wald, scipy.stats.weibull_min]

    logger = tfcomb.logging.TFcombLogger(verbosity)

    #Check input parameters
    check_string(which, ["upper", "lower", "both"], "which")
    check_value(percent, vmin=0, vmax=1, name="percent")

    #Subset data to _n_max:
    if len(data) > _n_max:
        #local generator; same subset as seeding the global generator with 0, without changing global random state
        rng = np.random.RandomState(0)
        data = rng.choice(data, size=_n_max, replace=False)

    data_finite = np.array(data)[~np.isinf(data)]

    #Fit data to each distribution
    distribution_dict = {}
    for distribution in distributions:

        #Catch any exceptions from fitting
        try:
            # temporary change numpy floating-point error handling
            with np.errstate(all="raise"):
                params = distribution.fit(data_finite)
        except Exception as e:
            logger.error("Exception ({0}) occurred while fitting data to '{1}' distribution; skipping this distribution. Error message was: {2} ".format(e.__class__.__name__, distribution.name, e))
            continue

        #Test fit using negative loglikelihood function
        mle = distribution.nnlf(params, data_finite)

        #Save info on distribution fit
        distribution_dict[distribution.name] = {"distribution": distribution,
                                                "params": params,
                                                "mle": mle}
        logger.spam("Fitted data to '{0}' with mle={1} and params: {2}".format(distribution.name, mle, params))

    if len(distribution_dict) == 0:
        raise ValueError("No distributions could be fit to the input data.")

    #Get best distribution (lowest negative loglikelihood)
    best_fit_name = min(distribution_dict, key=lambda x: distribution_dict[x]["mle"])
    logger.debug("Best fitting distribution was: {0}".format(best_fit_name))

    parameters = distribution_dict[best_fit_name]["params"]
    best_distribution = distribution_dict[best_fit_name]["distribution"]

    #Get threshold
    thresholds = best_distribution(*parameters).ppf([percent, 1-percent])

    if which == "upper":
        final = thresholds[-1]
    elif which == "lower":
        final = thresholds[0]
    elif which == "both":
        final = tuple(thresholds)

    #Plot fit and threshold
    if plot:

        #only finite values can be binned/used as axis limits
        plt.hist(data_finite, bins=30, density=True)
        xlims = plt.xlim()
        ylims = plt.ylim()

        xmin = np.min(data_finite)
        xmax = np.max(data_finite)
        x = np.linspace(xmin, xmax, 100)
        plt.plot(x, best_distribution(*parameters).pdf(x), lw=5, alpha=0.6, label=best_distribution.name)

        thresh_list = [final] if not isinstance(final, tuple) else final
        for t in thresh_list:
            plt.axvline(t, color="red")

        plt.legend()

        #reset limits to those of the histogram
        plt.xlim(xlims)
        plt.ylim(ylims)

    return(final)
#--------------------------------- Working with TF-COMB objects ---------------------------------#
def is_symmetric(matrix):
    """ Check if a matrix is symmetric around the diagonal.

    Returns False for non-square matrices. Values are compared with np.allclose, and NaNs at mirrored positions count as equal.
    """

    n_rows, n_cols = matrix.shape[0], matrix.shape[1]
    if n_rows != n_cols:
        return(False) #not symmetric if matrix is not square

    return(np.allclose(matrix, matrix.T, equal_nan=True))
def make_symmetric(matrix):
    """
    Make a numpy matrix symmetric by merging x-y and y-x.

    The values at (x,y) and (y,x) are summed; values on the diagonal are kept as they are. The input matrix is not changed.
    """

    #Sum of matrix and its transpose; this doubles the diagonal
    merged = matrix + matrix.T

    #Put the original diagonal back (don't add up diagonal indices)
    diagonal = np.diag_indices(merged.shape[0])
    merged[diagonal] = matrix[diagonal]

    return(merged)
def set_contrast(contrast, available_contrasts):
    """ Utility function for the plotting functions of tfcomb.objects.DiffCombObj.

    Parameters
    ------------
    contrast : object or None
        The contrast to use. If None, the first of the available contrasts is used.
    available_contrasts : list
        The contrasts which are valid to choose from.

    Returns
    ---------
    The contrast to use.

    Raises
    --------
    ValueError
        If 'contrast' is given, but is not within 'available_contrasts'.
    """

    #Default to the first available contrast
    if contrast is None:
        return(available_contrasts[0])

    #A given contrast must be one of the available
    if contrast not in available_contrasts:
        raise ValueError("Contrast {0} is not valid (available contrasts are {1})".format(contrast, available_contrasts))

    return(contrast)
# ------------------------- chunk operations ---------------------------------------- #
def analyze_signal_chunks(datasource, threshold):
    """ Evaluating signal for chunks.

    Parameters
    ----------
    datasource : pd.DataFrame
        A (sub-)Dataframe with the (corrected) distance counts for the pairs. The first two columns
        are skipped (TF1/TF2 are read by name); all remaining columns are taken as distances.
        Note: The index of the dataframe is overwritten with "<TF1>-<TF2>".
    threshold : float
        Threshold for prominence and height in peak calling (see scipy.signal.find_peaks() for detailed information)

    Returns
    --------
    list
        One dictionary per pair, containing the properties from find_peaks() as well as
        the keys "TF1", "TF2", "Distance" (distances of the found peaks) and "Threshold".

    See also
    --------
    tfcomb.object.analyze_signal_all
    """

    # make sure index is correct
    datasource.index = datasource["TF1"] + "-" + datasource["TF2"]

    distance_names = datasource.columns.tolist()[2:]
    tf_pairs = zip(datasource["TF1"], datasource["TF2"])  # pair tuples, e.g. ("NFYA","NFYB")

    # distance value per column; the column named "neg" counts as -1
    distance_values = np.array([-1 if name == "neg" else name for name in distance_names])

    # Calculate peaks for each row in datasource
    results = []
    for tf_pair in tf_pairs:

        first_tf, second_tf = tf_pair
        row_label = "-".join(tf_pair)
        counts = datasource.loc[row_label, distance_names].values

        # find_peaks() will not find peaks on the first and last position without
        # another number to the left and right, so the signal is padded with zeros
        padded = [0, *counts, 0]
        peak_positions, peak_info = find_peaks(padded, prominence=threshold, height=threshold)

        # subtract the position added above (first zero)
        peak_positions = peak_positions - 1

        # translate positions to the distances of the columns
        found_distances = [distance_values[pos] for pos in peak_positions]

        # insert tf1,tf2 names number of peaks times
        peak_info["TF1"] = [first_tf] * len(found_distances)
        peak_info["TF2"] = [second_tf] * len(found_distances)
        peak_info["Distance"] = found_distances
        peak_info["Threshold"] = threshold

        results.append(peak_info)

    return results
def evaluate_noise_chunks(signals, peaks, method="median", height_multiplier=0.75):
    """
    Evaluate the noisiness of a signal for chunks (a chunk can also be the whole dataset).

    Parameters
    ---------
    signals : pd.Dataframe
        A (sub-)Dataframe containing signal data for pairs. Must contain the columns "TF1" and "TF2";
        everything after the first two columns is taken as signal.
        Note: The index of the dataframe is overwritten with "<TF1>-<TF2>".
    peaks : pd.DataFrame
        A Dataframe of peaks with the columns "TF1" and "TF2"; the peaks of each pair are selected by these columns.
    method : str, optional
        Method used to get noise measurement, either "median" or "min_max" allowed.
        Default: "median"
    height_multiplier : float, optional
        Height multiplier (percentage) to calculate cut points. Must be between 0 and 1.
        Default: 0.75

    Returns
    --------
    list
        One list per pair in the form [TF1, TF2, noise measure].

    Raises
    ------
    InputError
        If 'height_multiplier' or 'method' are not valid.
    """

    # make sure index is correct
    signals.index = signals["TF1"] + "-" + signals["TF2"]
    tf_pairs = zip(signals["TF1"], signals["TF2"])

    check_value(height_multiplier, vmin=0, vmax=1)

    # get data for each pair
    results = []
    for tf_pair in tf_pairs:

        first_tf, second_tf = tf_pair
        row_label = "-".join(tf_pair)

        # signal of the pair (without the two TF name columns)
        pair_signal = signals.loc[row_label].iloc[2:]

        # peaks for this specific pair
        is_pair = (peaks.TF1 == first_tf) & (peaks.TF2 == second_tf)
        pair_peaks = peaks[is_pair]

        noise = _get_noise_measure(pair_peaks, pair_signal, method, height_multiplier)
        results.append([first_tf, second_tf, noise])

    return results
def _get_noise_measure(peaks, signal, method, height_multiplier):
    """ Measure the noise of a signal after masking out the peaks.

    The regions around all peaks (see _get_cut_points) are set to NaN within 'signal' (in place).
    The measure is the median ("median") or the difference between max and min ("min_max") of the signal afterwards.
    Returns the measure as float.
    """

    #check method input
    check_string(method, ["median", "min_max"], "method")

    # cut all peaks out of the signal
    for left, right in _get_cut_points(peaks, height_multiplier, signal):
        signal.iloc[left:right] = np.nan

    measure = None
    if method == "median":
        measure = pd.Series(signal).median()
    elif method == "min_max":
        measure = signal.max() - signal.min()

    return float(measure)
def _get_cut_points(peaks, height_multiplier, signal):
    """ Get the (left, right) borders of every peak within signal.

    The cutoff of a peak is height_multiplier * (signal value at the peak); the borders are found by _expand_peak.
    Returns a list with one (left, right) tuple per row in 'peaks'.
    """

    def borders_of(peak_row):
        # position of the peak within signal (subtract min distance for peak offset)
        position = peak_row.Distance - int(signal.index[0])

        # determine cutoff, in common sense this should be "going ~25% down the peak size"
        height = signal.iloc[position]
        return _expand_peak(position, height_multiplier * height, signal)

    return [borders_of(peak_row) for _, peak_row in peaks.iterrows()]
def _expand_peak(start_pos, cut_off, signal):
    """ Expand a peak from start_pos to both sides until the signal drops to the cutoff.

    To the left, expansion stops at the first value <= cut_off; to the right at the first value < cut_off.
    If no such value is found, the border is the first (0) / last (len(signal) - 1) position of the signal.
    Returns a tuple of positions (left, right).
    """

    n_positions = len(signal)

    # walk to the left until the border is found
    cursor = start_pos - 1
    while True:
        if cursor <= -1:  # ran past the start of signal
            left = 0
            break
        if signal.iloc[cursor] <= cut_off:
            left = cursor + 1  # we are one too far left
            break
        cursor -= 1

    # walk to the right until the border is found
    cursor = start_pos + 1
    while True:
        if cursor == n_positions:  # ran past the end of signal
            right = n_positions - 1
            break
        if signal.iloc[cursor] < cut_off:
            right = cursor - 1  # we are one too far right
            break
        cursor += 1

    return(left, right)
def getAllAttr(object, private=False, functions=False):
    """
    Collect all attributes of an object and return as dict.

    Parameters
    ----------
    object : object
        The object to collect attributes from.
    private : boolean, default False
        If private attributes should be included. Everything with '_' prefix.
    functions : boolean, default False
        If callable attributes ie functions should be included.

    Returns
    ----------
    dictionary :
        Dict of all the objects attributes (attribute name -> value).
    """

    output = {}

    for attribute_name in dir(object):

        # Skip private names before fetching the value, so that private
        # properties are not evaluated (and cannot raise) when they are excluded anyway
        if not private and attribute_name.startswith("_"):
            continue

        attribute_value = getattr(object, attribute_name)

        if not functions and callable(attribute_value):
            continue

        output[attribute_name] = attribute_value

    return output