import gzip
import logging
import operator
from datetime import datetime
from io import StringIO
from typing import IO
from deprecation import deprecated
from dae.utils.regions import BedRegion, difference, total_length
from .gene_models import (
Exon,
GeneModels,
TranscriptModel,
)
logger = logging.getLogger(__name__)

# Rank of every supported GTF feature type. The rank is the last element
# of a record index and therefore acts as the tie-breaker (or, for the
# canonical ordering, as the primary key) when records are sorted.
GTF_FEATURE_ORDER: dict[str, int] = {
    "gene": 0,
    "transcript": 1,
    # exon-level features all share the same rank
    **dict.fromkeys(("exon", "CDS", "start_codon", "stop_codon"), 2),
    "UTR": 3,
}

# Sort index of a record: (chromosome, start, negated stop, feature rank).
GTFRecordIndex = tuple[str, int, int, int]
# A sort index paired with the formatted GTF line it belongs to.
GTFRecord = tuple[GTFRecordIndex, str]
[docs]
def gtf_canonical_index(index: GTFRecordIndex) -> tuple:
    """Return a sort key that compares the feature rank first.

    The fourth element of ``index`` (the feature rank) is moved to the
    front and followed by the first three elements in their original
    order. Sorting records by this key gives the GTF-canonical ordering.
    """
    feature_rank = index[3]
    position = index[:3]
    return (feature_rank, *position)
[docs]
def gene_models_to_gtf(
    gene_models: GeneModels, *,
    sort_by_position: bool = True,
) -> StringIO:
    """Serialize gene models into an in-memory GTF document.

    For every gene a ``gene`` record is generated, followed by the records
    of each of its transcripts. The records are then sorted and written
    after four ``##`` header lines, one record per line.

    Args:
        gene_models: the gene models to serialize.
        sort_by_position: when true, records are ordered by their index
            (chromosome, start, negated stop, feature rank); otherwise the
            feature rank is compared first (see ``gtf_canonical_index``).

    Returns:
        A ``StringIO`` holding the document. It is empty when
        ``gene_models`` has no genes, in which case a warning is logged.
    """
    if not gene_models.gene_models:
        logger.warning("Serializing empty (probably not loaded) gene models!")
        return StringIO()

    records: list[GTFRecord] = []
    for gene_name, transcripts in gene_models.gene_models.items():
        # Gene-level properties are taken from the gene's first transcript;
        # the gene span is the union of all transcript spans.
        first = transcripts[0]
        chrom = first.chrom
        gene_start = min(tm.tx[0] for tm in transcripts)
        gene_stop = max(tm.tx[1] for tm in transcripts)
        strand = first.strand

        gene_attributes = {
            "gene_id": gene_name,
            "gene_version": first.attributes.get("gene_version", "."),
            "gene_name": gene_name,
            "gene_source": first.attributes.get("gene_source", "."),
            "gene_biotype": first.attributes.get("gene_biotype", "."),
        }
        source = gene_attributes["gene_source"]
        attrs = ";".join(
            f'{key} "{value}"' for key, value in gene_attributes.items()
        )

        gene_line = (
            f"{chrom}\t{source}\tgene\t{gene_start}\t{gene_stop}"
            f"\t.\t{strand}\t.\t{attrs};"
        )
        # the stop is negated so that longer features sort first
        gene_index = (chrom, gene_start, -gene_stop, GTF_FEATURE_ORDER["gene"])
        records.append((gene_index, gene_line))

        for transcript in transcripts:
            records.extend(transcript_to_gtf(transcript))

    if sort_by_position:
        ordered = sorted(records, key=operator.itemgetter(0))
    else:
        ordered = sorted(
            records, key=lambda rec: gtf_canonical_index(rec[0]))

    joined_records = "\n".join(line for _, line in ordered)
    document = f"""##description: GTF format dump for gene models "{gene_models.resource.resource_id or '?'}"
##provider: GPF
##format: gtf
##date: {datetime.today().strftime('%Y-%m-%d')}
{joined_records}
"""  # noqa: E501
    return StringIO(document)
[docs]
def build_gtf_record(
    transcript: TranscriptModel,
    feature: str,
    start: int, stop: int,
    attrs: str,
) -> GTFRecord:
    """Format one GTF line for a transcript feature and pair it with its index.

    Args:
        transcript: transcript the feature belongs to; supplies chromosome,
            strand and the ``gene_source`` attribute (``"."`` if missing).
        feature: feature type; must be a key of ``GTF_FEATURE_ORDER``.
        start: first position of the feature.
        stop: last position of the feature.
        attrs: pre-formatted attribute string placed in the last column.

    Returns:
        ``(index, line)`` where ``index`` is
        ``(chrom, start, -stop, feature rank)``.

    Raises:
        KeyError: if ``feature`` is not a key of ``GTF_FEATURE_ORDER``.
    """
    source = transcript.attributes.get("gene_source", ".")
    has_exon_number = feature in ("exon", "CDS", "start_codon", "stop_codon")
    has_phase = feature in ("CDS", "start_codon", "stop_codon")

    if has_exon_number:
        exon_number = transcript.get_exon_number_for(start, stop)
    else:
        exon_number = -1

    if has_phase:
        frame = calc_frame_for_gtf_cds_feature(
            transcript, BedRegion(transcript.chrom, start, stop))
        # the value written to the file is derived from the frame
        phase = str((3 - frame) % 3)
    else:
        phase = "."

    line = (f"{transcript.chrom}\t{source}\t{feature}\t{start}"
            f"\t{stop}\t.\t{transcript.strand}\t{phase}\t{attrs};")
    if has_exon_number:
        line = f'{line}exon_number "{exon_number}";'

    # add stop as negative to sort it in descending order
    return (
        (transcript.chrom, start, -stop, GTF_FEATURE_ORDER[feature]),
        line,
    )
[docs]
@deprecated("This function was split into multiple specialized functions.")
def collect_cds_regions(
    transcript: TranscriptModel,
) -> tuple[list[BedRegion], list[BedRegion], list[BedRegion]]:
    """Split a transcript's coding sequence into codon and CDS regions.

    Returns a tuple of start codon regions, normal coding regions
    and stop codon regions for a given transcript. All three lists are
    empty for a non-coding transcript.

    Three bases are consumed from the transcription start of the coding
    sequence for the start codon and three from its end for the stop
    codon; a codon that straddles several CDS regions yields one region
    per piece. The bases of the stop codon are removed from the returned
    coding regions, while a region only partially covered by the start
    codon is returned whole.

    NOTE(review): a CDS region that is *entirely* consumed by the start
    codon is not put back and is therefore missing from the returned
    coding regions - confirm whether callers rely on this.

    Raises:
        IndexError: if the coding regions run out before both codons are
            complete (popping from an empty list).
    """
    if not transcript.is_coding():
        return ([], [], [])

    reverse = transcript.strand == "-"

    start_codons: list[BedRegion] = []
    # Work on a copy: regions are popped and re-inserted below and the
    # list handed out by the transcript must not be modified.
    cds_regions: list[BedRegion] = list(transcript.cds_regions())
    stop_codons: list[BedRegion] = []

    start_bases_remaining, stop_bases_remaining = 3, 3

    while start_bases_remaining > 0:
        # the start codon sits at the high end on the reverse strand
        cds = cds_regions.pop(-1 if reverse else 0)
        cds_len = cds.stop - cds.start + 1
        bases_to_write = min(start_bases_remaining, cds_len)
        if reverse:
            codon_start = cds.stop - (bases_to_write - 1)
            codon_stop = cds.stop
        else:
            codon_start = cds.start
            codon_stop = codon_start + (bases_to_write - 1)
        start_codons.append(BedRegion(cds.chrom, codon_start, codon_stop))
        if cds_len - bases_to_write > 0:
            # Put the region back where it was taken from.
            # ``insert(-1, x)`` would place it *before* the last element,
            # so the tail position needs ``append``.
            if reverse:
                cds_regions.append(cds)
            else:
                cds_regions.insert(0, cds)
        start_bases_remaining -= bases_to_write

    while stop_bases_remaining > 0:
        # the stop codon sits at the low end on the reverse strand
        cds = cds_regions.pop(0 if reverse else -1)
        cds_len = cds.stop - cds.start + 1
        bases_to_write = min(stop_bases_remaining, cds_len)
        if reverse:
            codon_start = cds.start
            codon_stop = codon_start + (bases_to_write - 1)
        else:
            codon_start = cds.stop - (bases_to_write - 1)
            codon_stop = cds.stop
        stop_codons.append(BedRegion(cds.chrom, codon_start, codon_stop))
        if cds_len - bases_to_write > 0:
            # Only the part not covered by the stop codon is kept, and it
            # goes back to the end it was popped from.
            if reverse:
                cds_regions.insert(
                    0, BedRegion(cds.chrom, codon_stop + 1, cds.stop))
            else:
                cds_regions.append(
                    BedRegion(cds.chrom, cds.start, codon_start - 1))
        stop_bases_remaining -= bases_to_write

    return start_codons, cds_regions, stop_codons
[docs]
def collect_gtf_start_codon_regions(
    strand: str,
    cds_regions: list[BedRegion],
) -> list[BedRegion]:
    """Return the regions covering the three bases of the start codon.

    On the ``+`` strand the codon is taken from the beginning of the first
    CDS region, on the ``-`` strand from the end of the last one. When that
    region is shorter than three bases the codon continues into the
    neighbouring regions and one region per piece is returned, in the same
    relative order as in ``cds_regions``.

    Returns an empty list when all CDS regions together hold fewer than
    three bases.

    Raises:
        ValueError: if ``strand`` is neither ``"+"`` nor ``"-"``.
    """
    if strand not in ("+", "-"):
        raise ValueError("Invalid strand")

    # On "+" the codon grows from the low end of the regions upwards,
    # on "-" from the high end downwards.
    at_low_end = strand == "+"
    ordered = cds_regions if at_low_end else cds_regions[::-1]

    first = ordered[0]
    if len(first) >= 3:
        if at_low_end:
            return [BedRegion(first.chrom, first.start, first.start + 2)]
        return [BedRegion(first.chrom, first.stop - 2, first.stop)]

    pieces = [first]
    for region in ordered[1:]:
        covered = total_length(pieces)
        if covered + len(region) < 3:
            # still not enough bases - the whole region is codon
            pieces.append(region)
            continue
        # offset of the last codon base within this region
        extent = 2 - covered
        if at_low_end:
            pieces.append(BedRegion(
                region.chrom, region.start, region.start + extent))
            return pieces
        pieces.append(BedRegion(
            region.chrom, region.stop - extent, region.stop))
        return pieces[::-1]

    return []
[docs]
def collect_gtf_stop_codon_regions(
    strand: str,
    cds_regions: list[BedRegion],
) -> list[BedRegion]:
    """Return the regions covering the three bases of the stop codon.

    On the ``+`` strand the codon is taken from the end of the last CDS
    region, on the ``-`` strand from the beginning of the first one. When
    that region is shorter than three bases the codon continues into the
    neighbouring regions and one region per piece is returned, in the same
    relative order as in ``cds_regions``.

    Returns an empty list when all CDS regions together hold fewer than
    three bases.

    Raises:
        ValueError: if ``strand`` is neither ``"+"`` nor ``"-"``.
    """
    if strand not in ("+", "-"):
        raise ValueError("Invalid strand")

    # On "-" the codon grows from the low end of the regions upwards,
    # on "+" from the high end downwards.
    at_low_end = strand == "-"
    ordered = cds_regions if at_low_end else cds_regions[::-1]

    first = ordered[0]
    if len(first) >= 3:
        if at_low_end:
            return [BedRegion(first.chrom, first.start, first.start + 2)]
        return [BedRegion(first.chrom, first.stop - 2, first.stop)]

    pieces = [first]
    for region in ordered[1:]:
        covered = total_length(pieces)
        if covered + len(region) < 3:
            # still not enough bases - the whole region is codon
            pieces.append(region)
            continue
        # offset of the last codon base within this region
        extent = 2 - covered
        if at_low_end:
            pieces.append(BedRegion(
                region.chrom, region.start, region.start + extent))
            return pieces
        pieces.append(BedRegion(
            region.chrom, region.stop - extent, region.stop))
        return pieces[::-1]

    return []
[docs]
def collect_gtf_cds_regions(
    strand: str,
    cds_regions: list[BedRegion],
) -> list[BedRegion]:
    """Return the CDS regions to be written as GTF ``CDS`` features.

    The result is ``difference`` of the given CDS regions and the stop
    codon regions computed for the same strand, i.e. the stop codon is
    taken out of the coding regions while the start codon stays in.
    """
    return difference(  # type: ignore
        cds_regions,
        collect_gtf_stop_codon_regions(strand, cds_regions),
    )
[docs]
def find_exon_cds_region_for_gtf_cds_feature(
    transcript: TranscriptModel,
    region: BedRegion,
) -> tuple[Exon, BedRegion]:
    """Locate the exon holding a feature together with a CDS region of it.

    Exons are scanned in the order stored on the transcript. The first
    exon that contains ``region`` and also contains one of the
    transcript's CDS regions is returned along with that CDS region.

    Raises:
        ValueError: if no exon yields such a pair.
    """
    for exon in transcript.exons:
        if not exon.contains((region.start, region.stop)):
            continue
        # first CDS region lying inside this exon, if there is one
        enclosed = next(
            (
                cds for cds in transcript.cds_regions()
                if exon.contains((cds.start, cds.stop))
            ),
            None,
        )
        if enclosed is not None:
            return exon, enclosed
    raise ValueError(f"exon for region {region} not found")
[docs]
def calc_frame_for_gtf_cds_feature(
    transcript: TranscriptModel,
    region: BedRegion,
) -> int:
    """Compute the frame (0, 1 or 2) of a coding feature.

    The frame of the enclosing exon is shifted by the distance between the
    feature and the CDS region of that exon. The distance is measured
    between the start positions on the ``+`` strand and between the stop
    positions otherwise.

    Raises:
        ValueError: if no enclosing exon is found or the exon has no frame.
    """
    exon, cds_region = find_exon_cds_region_for_gtf_cds_feature(
        transcript, region)
    if exon.frame is None:
        raise ValueError(f"frame not found for exon {exon}")

    if transcript.strand == "+":
        offset = abs(cds_region.start - region.start)
    else:
        offset = abs(cds_region.stop - region.stop)
    return (exon.frame + (offset % 3)) % 3
[docs]
def transcript_to_gtf(transcript: TranscriptModel) -> list[GTFRecord]:
    """Build the indexed GTF records of a single transcript.

    Records are produced in this order: the transcript itself, its exons,
    then - for coding transcripts only - start codon, CDS and stop codon
    pieces, and finally the 3' and 5' UTR regions. Every record carries
    the same ``transcript_id``/``gene_name``/``gene_id`` attributes.
    """
    attribute_pairs = (
        ("transcript_id", transcript.tr_id),
        ("gene_name", transcript.gene),
        ("gene_id", transcript.gene),
    )
    attrs = ";".join(f'{key} "{value}"' for key, value in attribute_pairs)

    def record_for(feature: str, start: int, stop: int) -> GTFRecord:
        return build_gtf_record(transcript, feature, start, stop, attrs)

    records: list[GTFRecord] = [
        record_for("transcript", transcript.tx[0], transcript.tx[1]),
    ]
    records.extend(
        record_for("exon", exon.start, exon.stop)
        for exon in transcript.exons
    )

    if transcript.is_coding():
        cds_regions = transcript.cds_regions()
        strand = transcript.strand
        records.extend(
            record_for("start_codon", codon.start, codon.stop)
            for codon in collect_gtf_start_codon_regions(strand, cds_regions)
        )
        records.extend(
            record_for("CDS", cds.start, cds.stop)
            for cds in collect_gtf_cds_regions(strand, cds_regions)
        )
        records.extend(
            record_for("stop_codon", codon.start, codon.stop)
            for codon in collect_gtf_stop_codon_regions(strand, cds_regions)
        )

    records.extend(
        record_for("UTR", utr.start, utr.stop)
        for utr in transcript.utr3_regions() + transcript.utr5_regions()
    )
    return records
def _save_as_default_gene_models(
    gene_models: GeneModels,
    outfile: IO,
) -> None:
    """Write gene models to ``outfile`` as tab-separated text.

    A header line is written first, followed by one line per transcript
    model. Exon starts, ends and frames are comma-separated lists; the
    transcript attributes are written as ``key:value`` pairs joined by
    ``;`` with every ``:`` inside a value replaced by ``_``.

    Args:
        gene_models: source of the transcript models.
        outfile: writable text stream.
    """
    # NOTE(review): "tsBeg" is kept verbatim; presumably readers of this
    # format key on that exact column name - confirm before renaming.
    header = [
        "chr",
        "trID",
        "trOrigId",
        "gene",
        "strand",
        "tsBeg",
        "txEnd",
        "cdsStart",
        "cdsEnd",
        "exonStarts",
        "exonEnds",
        "exonFrames",
        "atts",
    ]
    outfile.write("\t".join(header))
    outfile.write("\n")

    for transcript_model in gene_models.transcript_models.values():
        exons = transcript_model.exons
        exon_starts = ",".join(str(e.start) for e in exons)
        exon_ends = ",".join(str(e.stop) for e in exons)
        exon_frames = ",".join(str(e.frame) for e in exons)

        add_atts = ";".join(
            k + ":" + str(v).replace(":", "_")
            for k, v in transcript_model.attributes.items()
        )

        columns = [
            transcript_model.chrom,
            transcript_model.tr_id,
            transcript_model.tr_name,
            transcript_model.gene,
            transcript_model.strand,
            transcript_model.tx[0],
            transcript_model.tx[1],
            transcript_model.cds[0],
            transcript_model.cds[1],
            exon_starts,
            exon_ends,
            exon_frames,
            add_atts,
        ]
        # Only a missing value becomes an empty column. A plain truthiness
        # test would also blank out a legitimate 0 (e.g. a coordinate).
        outfile.write(
            "\t".join("" if x is None else str(x) for x in columns))
        outfile.write("\n")
[docs]
def save_as_default_gene_models(
    gene_models: GeneModels,
    output_filename: str, *,
    gzipped: bool = True,
) -> None:
    """Write gene models to a file in the default file format.

    Args:
        gene_models: the gene models to save.
        output_filename: path of the file to write. When ``gzipped`` is
            true and the name does not end in ``.gz``, that suffix is
            appended.
        gzipped: write a gzip-compressed text file when true, a plain
            text file otherwise.
    """
    if not gzipped:
        with open(output_filename, "wt") as plain_file:
            _save_as_default_gene_models(gene_models, plain_file)
        return

    gz_filename = output_filename
    if not gz_filename.endswith(".gz"):
        gz_filename = f"{gz_filename}.gz"
    with gzip.open(gz_filename, "wt") as gz_file:
        _save_as_default_gene_models(gene_models, gz_file)