from __future__ import annotations
import copy
import logging
from collections import defaultdict
from collections.abc import Iterator, Sequence
from typing import Any
import networkx as nx
import pysam
logger = logging.getLogger(__name__)
MAX_POSITION = 3_000_000_000
[docs]
def coalesce(v1: int | None, v2: int) -> int:
"""Return first non-None value."""
if v1 is not None:
return v1
return v2
[docs]
def bedfile2regions(bed_filename: str) -> list[BedRegion]:
"""Transform BED file into list of regions."""
with open(bed_filename) as infile:
regions = []
for line in infile:
if line[0] == "#":
continue
chrom, sbeg, send = line.strip().split("\t")
beg = int(sbeg)
end = int(send)
regions.append(BedRegion(chrom, beg + 1, end))
return regions
[docs]
def regions2bedfile(regions: list[BedRegion], bed_filename: str) -> None:
"""Save list of regions into a BED file."""
with open(bed_filename, "w") as outfile:
for reg in regions:
outfile.write(
f"{reg.chrom}\t{reg.start - 1}\t{reg.stop}\n")
[docs]
def calc_bin_begin(bin_len: int, bin_idx: int) -> int:
"""
Calculates the 1-based start position of the <bin_idx>-th bin
of length <bin_len>.
n 2n 3n 4n
|_______|_______|_______|
bin_len \
\
bin_begin
"""
return (bin_len * bin_idx) + 1
[docs]
def calc_bin_end(bin_len: int, bin_idx: int) -> int:
"""
Calculates the 1-based end position of the <bin_idx>-th bin
of length <bin_len>.
n 2n 3n 4n
|_______|_______|_______|
bin_len \
\
bin_end
"""
return bin_len * (bin_idx + 1)
[docs]
def calc_bin_index(bin_len: int, pos: int) -> int:
"""
Calculates the index of the <bin_len>-long bin the given 1-based
position <pos> falls into.
n 2n 3n 4n
|_______|_______|_______|
(bin 0) (bin 1) (bin 2)
"""
return (pos - 1) // bin_len
[docs]
def split_into_regions(
chrom: str, chrom_length: int,
region_size: int, start: int = 1) -> list[Region]:
"""Return a list of regions for a chrom with a given length."""
region_size = min(region_size, chrom_length)
return [
Region(chrom,
calc_bin_begin(region_size, i),
calc_bin_end(region_size, i))
for i in range(calc_bin_index(region_size, start),
calc_bin_index(region_size, chrom_length) + 1)
]
[docs]
def get_chromosome_length_tabix(
tabix_file: pysam.TabixFile | pysam.VariantFile, chrom: str,
step: int = 100_000_000, precision: int = 5_000_000,
) -> int | None:
"""
Return the length of a chromosome (or contig).
Returned value is guarnteed to be larger than the actual contig length.
"""
def any_records(riter: Iterator) -> bool:
try:
next(riter)
except StopIteration:
return False
except ValueError:
return False
return True
try:
# First we find any region that includes the last record i.e.
# the length of the chromosome
left, right = None, None
pos = step
while left is None or right is None:
region = Region(chrom, pos, None)
if any_records(tabix_file.fetch(str(region))):
left = pos
pos = pos * 2
else:
if pos == 0: # stop infinite loop if any_records is never True
raise ValueError # noqa: TRY301
right = pos
pos = pos // 2
# Second we use binary search to narrow the region until we find the
# index of the last element (in left) and the length (in right)
while (right - left) > precision:
pos = (left + right) // 2
region = Region(chrom, pos, None)
if any_records(tabix_file.fetch(str(region))):
left = pos
else:
right = pos
return right # noqa: TRY300
except ValueError as ex:
logger.warning(
"unable to find length of contig %s: %s", chrom, ex)
return None
[docs]
class Region:
"""Class representing a genomic region."""
def __init__(
self, chrom: str,
start: int | None = None, stop: int | None = None,
):
if start is not None and not isinstance(start, int):
raise TypeError(f"Invalid type for start position - {type(start)}")
if stop is not None and not isinstance(stop, int):
raise TypeError(f"Invalid type for stop position - {type(stop)}")
if start is not None and stop is not None:
assert start <= stop
self.chrom = chrom
self._start = start
self._stop = stop
@property
def start(self) -> int | None:
return self._start
@property
def stop(self) -> int | None:
return self._stop
@property
def begin(self) -> int | None:
return self.start
@property
def end(self) -> int | None:
return self.stop
def __repr__(self) -> str:
if self.start is None:
return self.chrom
if self.end is None:
return f"{self.chrom}:{self.start}"
return f"{self.chrom}:{self.start}-{self.stop}"
def __hash__(self) -> int:
return str(self).__hash__()
def __eq__(self, other: Any) -> bool:
if not isinstance(other, Region):
return False
return bool(
self.chrom == other.chrom
and self.start == other.start
and self.stop == other.stop,
)
def __ne__(self, other: Any) -> bool:
return not self.__eq__(other)
[docs]
def isin(self, chrom: str, pos: int) -> bool:
"""Check if a genomic position is insde of the region."""
if chrom != self.chrom:
return False
if self.start and pos < self.start:
return False
return not (self.stop and pos > self.stop)
@staticmethod
def _min(a: int | None, b: int | None) -> int | None:
if a is None:
return b
if b is None:
return a
return min(a, b)
@staticmethod
def _max(a: int | None, b: int | None) -> int | None:
if a is None:
return b
if b is None:
return a
return max(a, b)
@staticmethod
def _make_region(
chrom: str, start: int | None, stop: int | None,
) -> Region | None:
if chrom is None:
return None
if start is not None and stop is not None and start > stop:
return None
return Region(chrom, start, stop)
[docs]
def intersection(self, other: Region) -> Region | None:
"""Return intersection of the region with other region."""
if self.chrom != other.chrom:
return None
if self.start is None and self.stop is None:
return Region(other.chrom, other.start, other.stop)
if other.start is None and other.stop is None:
return Region(self.chrom, self.start, self.stop)
if self.start is None:
assert self.stop is not None
return self._make_region(
self.chrom, other.start, self._min(self.stop, other.stop))
if self.stop is None:
assert self.start is not None
return self._make_region(
self.chrom, self._max(self.start, other.start), other.stop)
assert self.start is not None
assert self.stop is not None
return self._make_region(
self.chrom,
self._max(self.start, other.start),
self._min(self.stop, other.stop),
)
[docs]
def contains(self, other: Region) -> bool:
"""Check if the region contains other region."""
if self.chrom != other.chrom:
return False
if self.start is None and self.stop is None:
return True
if self.start is None:
assert self.stop is not None
if other.stop is not None:
assert other.stop is not None
return bool(other.stop <= self.stop)
return False
if self.stop is None:
assert self.start is not None
if other.start is not None:
return other.start >= self.start
return False
assert self.start is not None
assert self.stop is not None
if other.stop is None or other.start is None:
return False
return self.start <= other.start \
and other.stop <= self.stop
[docs]
def intersects(self, other: Region) -> bool:
"""Check if the region intersects another."""
if self.chrom != other.chrom:
return False
if self.start is not None and self.stop is not None:
if other.start is None or other.stop is None:
return other.intersects(self)
return not (self.stop < other.start or self.start > other.stop)
if self.stop is not None:
return other.start is None or self.stop >= other.start
if self.start is not None:
return other.stop is None or self.start <= other.stop
return True
[docs]
@staticmethod
def from_str(region: str) -> Region:
"""Parse string representation of a region."""
parts = [p.strip() for p in region.rsplit(":", maxsplit=1)]
if len(parts) == 1:
return Region(parts[0], None, None)
if len(parts) == 2:
chrom = parts[0]
parts = [p.strip() for p in parts[1].split("-")]
start = int(parts[0].replace(",", ""))
if len(parts) == 1:
return BedRegion(chrom, start, start)
if len(parts) == 2:
stop = int(parts[1].replace(",", ""))
return BedRegion(chrom, start, stop)
raise ValueError(f"unexpeced format for region {region}")
[docs]
class BedRegion(Region):
"""Represents proper bed regions."""
def __init__(
self, chrom: str,
start: int, stop: int,
):
assert start is not None
assert stop is not None
assert stop >= start
super().__init__(chrom, start, stop)
[docs]
@staticmethod
def from_str(region: str) -> BedRegion:
"""Parse string representation of a region."""
result = Region.from_str(region)
if result.start is None or result.stop is None:
raise ValueError(f"Invalid region format: {region}")
return BedRegion(result.chrom, result.start, result.stop)
@property
def start(self) -> int:
assert self._start is not None
return self._start
@property
def stop(self) -> int:
assert self._stop is not None
return self._stop
@property
def begin(self) -> int:
return self.start
@property
def end(self) -> int:
return self.stop
def __len__(self) -> int:
return self.stop - self.start + 1
[docs]
def all_regions_from_chrom(regions: list[Region], chrom: str) -> list[Region]:
"""Subset of regions in R that are from chr."""
return [r for r in regions if r.chrom == chrom]
[docs]
def unique_regions(regions: list[Region]) -> list[Region]:
"""Remove duplicated regions."""
return list(set(regions))
[docs]
def connected_component(regions: list[BedRegion]) -> Any:
"""Return connected component of regions.
This might be the same as collapse.
"""
graph = nx.Graph()
graph.add_nodes_from(regions)
regions_by_chrom = defaultdict(list)
for reg in regions:
regions_by_chrom[reg.chrom].append(reg)
for nds in regions_by_chrom.values():
nds.sort(key=lambda x: x.stop)
for k in range(1, len(nds)):
for j in range(k - 1, -1, -1):
if nds[k].start <= nds[j].stop:
graph.add_edge(nds[k], nds[j])
else:
break
return nx.connected_components(graph)
[docs]
def collapse(
source: Sequence[Region], *,
is_sorted: bool = False,
) -> list[Region]:
"""Collapse list of regions."""
if not source:
return list(source)
regions = copy.deepcopy(list(source))
if not is_sorted:
regions.sort(key=lambda x: x.start if x.start is not None else -1)
collapsed: dict[str, list[Region]] = defaultdict(list)
collapsed[regions[0].chrom].append(regions[0])
for reg in regions[1:]:
chrom_collapsed = collapsed.get(reg.chrom)
if not chrom_collapsed:
collapsed[reg.chrom].append(reg)
continue
prev_reg = chrom_collapsed[-1]
if coalesce(reg.start, 1) <= coalesce(prev_reg.stop, 1):
if coalesce(reg.stop, 1) > coalesce(prev_reg.stop, 1):
last = collapsed[reg.chrom][-1]
collapsed[reg.chrom][-1] = \
Region(last.chrom, last.start, reg.stop)
continue
collapsed[reg.chrom].append(reg)
result = []
for v in list(collapsed.values()):
result.extend(v)
return result
[docs]
def collapse_no_chrom(
source: list[BedRegion], *,
is_sorted: bool = False,
) -> list[BedRegion]:
"""Collapse by ignoring the chromosome.
Useful when the caller knows
that all the regions are from the same chromosome.
"""
if not source:
return source
regions = copy.copy(source)
if len(regions) == 1:
return regions
if not is_sorted:
regions.sort(key=lambda x: x.start)
collapsed = [regions[0]]
for reg in regions[1:]:
prev_reg = collapsed[-1]
if reg.start <= prev_reg.stop:
if reg.stop > prev_reg.stop:
prev_reg = BedRegion(prev_reg.chrom, prev_reg.start, reg.stop)
collapsed[-1] = prev_reg
continue
collapsed.append(reg)
return collapsed
[docs]
def total_length(regions: list[BedRegion]) -> int:
return sum(len(x) for x in regions)
[docs]
def intersection(
regions1: list[Region], regions2: list[Region],
) -> list[Region]:
"""Compute intersection of two list of regions.
First collapses each for lists of regions s1 and s2 and then find
the intersection.
"""
s1_c = collapse(regions1)
s2_c = collapse(regions2)
s1_c.sort(key=lambda x: (x.chrom, x.start))
s2_c.sort(key=lambda x: (x.chrom, x.start))
intersect = []
k = 0
for i in s2_c:
while k < len(s1_c):
if i.chrom != s1_c[k].chrom:
if i.chrom > s1_c[k].chrom:
k += 1
continue
break
if coalesce(i.stop, 1) < coalesce(s1_c[k].start, 1):
break
if coalesce(i.start, 1) > coalesce(s1_c[k].stop, MAX_POSITION):
k += 1
continue
if coalesce(i.start, 1) <= coalesce(s1_c[k].start, 1):
if coalesce(i.stop, MAX_POSITION) >= \
coalesce(s1_c[k].stop, MAX_POSITION):
intersect.append(s1_c[k])
k += 1
continue
new_i = Region(i.chrom, s1_c[k].start, i.stop)
intersect.append(new_i)
break
if coalesce(i.start, 1) > coalesce(s1_c[k].start, 1):
if coalesce(i.stop, MAX_POSITION) <= \
coalesce(s1_c[k].stop, MAX_POSITION):
intersect.append(i)
break
new_i = Region(i.chrom, i.start, s1_c[k].stop)
intersect.append(new_i)
k += 1
continue
return intersect
[docs]
def union(*r: list[Region]) -> list[Region]:
"""Collapse many lists of regions."""
r_sum = [el for rlist in r for el in rlist]
return collapse(r_sum)
def _diff(
regions_a: list[Region], regions_b: list[Region],
) -> list[Region]:
result = []
k = 0
for reg_a in regions_a:
if k >= len(regions_b):
result.append(reg_a)
continue
if reg_a.chrom < regions_b[k].chrom:
result.append(reg_a)
continue
if coalesce(reg_a.stop, MAX_POSITION) < \
coalesce(regions_b[k].start, 1):
result.append(reg_a)
continue
prev = coalesce(reg_a.start, 1)
while k < len(regions_b) \
and coalesce(regions_b[k].stop, MAX_POSITION) <= \
coalesce(reg_a.stop, MAX_POSITION) \
and regions_b[k].chrom == reg_a.chrom:
if prev < coalesce(regions_b[k].start, 1):
new_a = Region(
reg_a.chrom, prev,
coalesce(regions_b[k].start, 1) - 1)
result.append(new_a)
prev = coalesce(regions_b[k].stop, 1) + 1
k += 1
if k < len(regions_b) and regions_b[k].chrom != reg_a.chrom:
continue
if prev <= coalesce(reg_a.stop, MAX_POSITION):
result.append(Region(reg_a.chrom, prev, reg_a.stop))
return result
[docs]
def difference(
regions1: list[Region],
regions2: list[Region], *,
symmetric: bool = False,
) -> list[Region]:
"""Compute difference between two list of regions."""
if not symmetric:
left = collapse(regions1)
left.sort(key=lambda x: (x.chrom, x.start))
else:
left = union(regions1, regions2)
left.sort(key=lambda x: (x.chrom, x.start))
right = intersection(regions1, regions2)
return _diff(left, right)