Source code for dae.genomic_resources.gene_models.gene_models

from __future__ import annotations

import logging
from collections import defaultdict
from typing import Any, cast

from dae.genomic_resources import GenomicResource
from dae.genomic_resources.fsspec_protocol import build_local_resource
from dae.genomic_resources.repository import GenomicResourceRepo
from dae.genomic_resources.repository_factory import (
    build_genomic_resource_repository,
)
from dae.genomic_resources.resource_implementation import (
    ResourceConfigValidationMixin,
    get_base_resource_schema,
)
from dae.utils.regions import (
    BedRegion,
    Region,
    collapse,
)

logger = logging.getLogger(__name__)


class Exon:
    """Provides exon model."""

    def __init__(
        self,
        start: int,
        stop: int,
        frame: int | None = None,
    ):
        """Initialize exon model.

        Args:
            start: The genomic start position of the exon (1-based).
            stop: The genomic stop position of the exon (1-based, closed).
            frame: The frame of the exon; ``None`` when not yet assigned.
        """
        self.start, self.stop = start, stop
        # The frame is only meaningful in relation to the transcript CDS.
        self.frame = frame

    def __repr__(self) -> str:
        return f"Exon(start={self.start}; stop={self.stop})"

    def contains(self, region: tuple[int, int]) -> bool:
        """Check whether the ``(start, stop)`` region lies inside the exon.

        Both bounds are inclusive: a region equal to the exon is contained.
        """
        region_start, region_stop = region
        return self.start <= region_start and region_stop <= self.stop
[docs] class TranscriptModel: """Provides transcript model.""" def __init__( self, gene: str, tr_id: str, tr_name: str, chrom: str, strand: str, tx: tuple[int, int], # pylint: disable=invalid-name cds: tuple[int, int], exons: list[Exon] | None = None, attributes: dict[str, Any] | None = None, ): """Initialize transcript model. Args: gene (str): The gene name. tr_id (str): The transcript ID. tr_name (str): The transcript name. chrom (str): The chromosome name. strand (str): The strand of the transcript. tx (tuple[int, int]): The transcript start and end positions. (1-based, closed interval) cds (tuple[int, int]): The coding region start and end positions. The CDS region includes the start and stop codons. (1-based, closed interval) exons (Optional[list[Exon]]): The list of exons. Defaults to empty list. attributes (Optional[dict[str, Any]]): The additional attributes. Defaults to empty dictionary. """ self.gene = gene self.tr_id = tr_id self.tr_name = tr_name self.chrom = chrom self.strand = strand self.tx = tx # pylint: disable=invalid-name self.cds = cds self.exons: list[Exon] = exons if exons is not None else [] self.attributes = attributes if attributes is not None else {}
[docs] def is_coding(self) -> bool: return self.cds[0] < self.cds[1]
[docs] def cds_regions(self, ss_extend: int = 0) -> list[BedRegion]: """Compute CDS regions.""" if self.cds[0] >= self.cds[1]: return [] regions = [] k = 0 while self.exons[k].stop < self.cds[0]: k += 1 if self.cds[1] <= self.exons[k].stop: regions.append( BedRegion( chrom=self.chrom, start=self.cds[0], stop=self.cds[1]), ) return regions regions.append( BedRegion( chrom=self.chrom, start=self.cds[0], stop=self.exons[k].stop + ss_extend, ), ) k += 1 while k < len(self.exons) and self.exons[k].stop <= self.cds[1]: if self.exons[k].stop < self.cds[1]: regions.append( BedRegion( chrom=self.chrom, start=self.exons[k].start - ss_extend, stop=self.exons[k].stop + ss_extend, ), ) k += 1 else: regions.append( BedRegion( chrom=self.chrom, start=self.exons[k].start - ss_extend, stop=self.exons[k].stop, ), ) return regions if k < len(self.exons) and self.exons[k].start <= self.cds[1]: regions.append( BedRegion( chrom=self.chrom, start=self.exons[k].start - ss_extend, stop=self.cds[1], ), ) return regions
[docs] def utr5_regions(self) -> list[BedRegion]: """Build list of UTR5 regions.""" if self.cds[0] >= self.cds[1]: return [] regions = [] k = 0 if self.strand == "+": while self.exons[k].stop < self.cds[0]: regions.append( BedRegion( chrom=self.chrom, start=self.exons[k].start, stop=self.exons[k].stop, ), ) k += 1 if self.exons[k].start < self.cds[0]: regions.append( BedRegion( chrom=self.chrom, start=self.exons[k].start, stop=self.cds[0] - 1, ), ) else: while self.exons[k].stop < self.cds[1]: k += 1 if self.exons[k].stop == self.cds[1]: k += 1 else: regions.append( BedRegion( chrom=self.chrom, start=self.cds[1] + 1, stop=self.exons[k].stop, ), ) k += 1 regions.extend([ BedRegion(chrom=self.chrom, start=exon.start, stop=exon.stop) for exon in self.exons[k:] ]) return regions
[docs] def utr3_regions(self) -> list[BedRegion]: """Build and return list of UTR3 regions.""" if self.cds[0] >= self.cds[1]: return [] regions = [] k = 0 if self.strand == "-": while self.exons[k].stop < self.cds[0]: regions.append( BedRegion( chrom=self.chrom, start=self.exons[k].start, stop=self.exons[k].stop, ), ) k += 1 if self.exons[k].start < self.cds[0]: regions.append( BedRegion( chrom=self.chrom, start=self.exons[k].start, stop=self.cds[0] - 1, ), ) else: while self.exons[k].stop < self.cds[1]: k += 1 if self.exons[k].stop == self.cds[1]: k += 1 else: regions.append( BedRegion( chrom=self.chrom, start=self.cds[1] + 1, stop=self.exons[k].stop, ), ) k += 1 regions.extend([ BedRegion(chrom=self.chrom, start=exon.start, stop=exon.stop) for exon in self.exons[k:] ]) return regions
[docs] def all_regions( self, ss_extend: int = 0, prom: int = 0, ) -> list[BedRegion]: """Build and return list of regions.""" # pylint:disable=too-many-branches regions = [] if ss_extend == 0: regions.extend([ BedRegion(chrom=self.chrom, start=exon.start, stop=exon.stop) for exon in self.exons ]) else: for exon in self.exons: if exon.stop <= self.cds[0]: regions.append( BedRegion( chrom=self.chrom, start=exon.start, stop=exon.stop), ) elif exon.start <= self.cds[0]: if exon.stop >= self.cds[1]: regions.append( BedRegion( chrom=self.chrom, start=exon.start, stop=exon.stop), ) else: regions.append( BedRegion( chrom=self.chrom, start=exon.start, stop=exon.stop + ss_extend, ), ) elif exon.start > self.cds[1]: regions.append( BedRegion( chrom=self.chrom, start=exon.start, stop=exon.stop), ) else: if exon.stop >= self.cds[1]: regions.append( BedRegion( chrom=self.chrom, start=exon.start - ss_extend, stop=exon.stop, ), ) else: regions.append( BedRegion( chrom=self.chrom, start=exon.start - ss_extend, stop=exon.stop + ss_extend, ), ) if prom != 0: if self.strand == "+": regions[0] = BedRegion( chrom=regions[0].chrom, start=regions[0].start - prom, stop=regions[0].stop, ) else: regions[-1] = BedRegion( chrom=regions[-1].chrom, start=regions[-1].start, stop=regions[-1].stop + prom, ) return regions
[docs] def total_len(self) -> int: length = 0 for reg in self.exons: length += reg.stop - reg.start + 1 return length
[docs] def cds_len(self) -> int: regions = self.cds_regions() length = 0 for reg in regions: length += reg.stop - reg.start + 1 return length
[docs] def utr3_len(self) -> int: utr3 = self.utr3_regions() length = 0 for reg in utr3: length += reg.stop - reg.start + 1 return length
[docs] def utr5_len(self) -> int: utr5 = self.utr5_regions() length = 0 for reg in utr5: length += reg.stop - reg.start + 1 return length
[docs] def calc_frames(self) -> list[int]: """Calculate codon frames.""" length = len(self.exons) fms = [] if self.cds[0] > self.cds[1]: fms = [-1] * length elif self.strand == "+": k = 0 while self.exons[k].stop < self.cds[0]: fms.append(-1) k += 1 fms.append(0) if self.exons[k].stop < self.cds[1]: fms.append((self.exons[k].stop - self.cds[0] + 1) % 3) k += 1 while self.exons[k].stop < self.cds[1] and k < length: fms.append( (fms[k] + self.exons[k].stop - self.exons[k].start + 1) % 3, ) k += 1 fms += [-1] * (length - len(fms)) else: k = length - 1 while self.exons[k].start > self.cds[1]: fms.append(-1) k -= 1 fms.append(0) if self.cds[0] < self.exons[k].start: fms.append((self.cds[1] - self.exons[k].start + 1) % 3) k -= 1 while self.cds[0] < self.exons[k].start and k > -1: fms.append( (fms[-1] + self.exons[k].stop - self.exons[k].start + 1) % 3, ) k -= 1 fms += [-1] * (length - len(fms)) fms = fms[::-1] assert len(self.exons) == len(fms) return fms
[docs] def update_frames(self) -> None: """Update codon frames.""" frames = self.calc_frames() for exon, frame in zip(self.exons, frames, strict=True): exon.frame = frame
[docs] def test_frames(self) -> bool: frames = self.calc_frames() for exon, frame in zip(self.exons, frames, strict=True): if exon.frame != frame: return False return True
[docs] def get_exon_number_for(self, start: int, stop: int) -> int: for exon_number, exon in enumerate(self.exons): if not (start > exon.stop or stop < exon.start): return exon_number + 1 if self.strand == "+" \ else len(self.exons) - exon_number return 0
class GeneModels(
    ResourceConfigValidationMixin,
):
    """Provides class for gene models.

    Transcript models are stored by transcript ID in
    ``transcript_models`` and indexed twice: by gene name in
    ``gene_models`` and by chromosome and ``tx`` interval in
    ``utr_models``.
    """

    def __init__(self, resource: GenomicResource):
        """Create an empty gene models object bound to a resource.

        Raises:
            ValueError: when the resource type is not "gene_models".
        """
        if resource.get_type() != "gene_models":
            raise ValueError(
                f"wrong type of resource passed: {resource.get_type()}")

        self.resource = resource
        self.config = self.validate_and_normalize_schema(
            resource.get_config(), resource,
        )

        # The reference genome label is optional in the configuration.
        meta = self.config.get("meta")
        labels = meta.get("labels") if meta is not None else None
        self.reference_genome_id: str | None = (
            labels.get("reference_genome") if labels is not None else None
        )

        self.gene_models: dict[str, list[TranscriptModel]] = defaultdict(list)
        self.utr_models: dict[
            str, dict[tuple[int, int], list[TranscriptModel]]] = \
            defaultdict(lambda: defaultdict(list))
        self.transcript_models: dict[str, Any] = {}
        self.alternative_names: dict[str, Any] = {}

        self.reset()

    @property
    def resource_id(self) -> str:
        return self.resource.resource_id

    def reset(self) -> None:
        """Drop all transcript models, indexes and alternative names."""
        self.alternative_names = {}
        self.utr_models = defaultdict(lambda: defaultdict(list))
        self.transcript_models = {}
        self.gene_models = defaultdict(list)

    def add_transcript_model(self, transcript_model: TranscriptModel) -> None:
        """Add a transcript model to the gene models."""
        assert transcript_model.tr_id not in self.transcript_models

        self.transcript_models[transcript_model.tr_id] = transcript_model
        self.gene_models[transcript_model.gene].append(transcript_model)
        self.utr_models[transcript_model.chrom][transcript_model.tx]\
            .append(transcript_model)

    def update_indexes(self) -> None:
        """Rebuild both indexes from ``transcript_models``."""
        self.gene_models = defaultdict(list)
        self.utr_models = defaultdict(lambda: defaultdict(list))
        for transcript in self.transcript_models.values():
            self.gene_models[transcript.gene].append(transcript)
            self.utr_models[transcript.chrom][transcript.tx].append(transcript)

    def gene_names(self) -> list[str]:
        """Return the names of all genes that have transcript models."""
        if self.gene_models is None:
            logger.warning(
                "gene models %s are empty", self.resource.resource_id)
            return []

        return list(self.gene_models.keys())

    def gene_models_by_gene_name(
        self, name: str,
    ) -> list[TranscriptModel] | None:
        """Return the transcripts of a gene or None for an unknown gene."""
        return self.gene_models.get(name, None)

    def gene_models_by_location(
        self, chrom: str, pos1: int, pos2: int | None = None,
    ) -> list[TranscriptModel]:
        """Retrieve TranscriptModel objects based on genomic position(s).

        Args:
            chrom (str): The chromosome name.
            pos1 (int): The starting genomic position.
            pos2 (Optional[int]): The ending genomic position. If not
                provided, only models that contain pos1 will be returned.

        Returns:
            list[TranscriptModel]: A list of TranscriptModel objects that
                match the given location criteria.
        """
        # Use .get() so that querying an unknown chromosome neither adds
        # an empty entry to the index nor raises on a plain dict.
        chrom_index = self.utr_models.get(chrom)
        if not chrom_index:
            return []

        result: list[TranscriptModel] = []
        if pos2 is None:
            for (tx_start, tx_stop), transcripts in chrom_index.items():
                if tx_start <= pos1 <= tx_stop:
                    result.extend(transcripts)
            return result

        if pos2 < pos1:
            pos1, pos2 = pos2, pos1
        for (tx_start, tx_stop), transcripts in chrom_index.items():
            if pos1 <= tx_start <= pos2 or tx_start <= pos1 <= tx_stop:
                result.extend(transcripts)
        return result

    def relabel_chromosomes(
        self, relabel: dict[str, str] | None = None,
        map_file: str | None = None,
    ) -> None:
        """Relabel chromosomes in gene model.

        Transcripts located on chromosomes that are missing from the
        mapping are removed; both indexes are rebuilt afterwards.

        Args:
            relabel: Mapping from current to new chromosome names.
            map_file: Path to a text file whose first two
                whitespace-separated columns hold the current and the new
                chromosome name. Used only when ``relabel`` is empty.

        Raises:
            ValueError: when a non-blank line of ``map_file`` has fewer
                than two columns.
        """
        assert relabel or map_file
        if not relabel:
            assert map_file is not None
            relabel = {}
            with open(map_file) as infile:
                for line in infile:
                    columns = line.split()
                    if not columns:
                        continue
                    if len(columns) < 2:
                        raise ValueError(
                            f"bad chromosome mapping line: {line!r}")
                    relabel[columns[0]] = columns[1]

        self.transcript_models = {
            tid: tm
            for tid, tm in self.transcript_models.items()
            if tm.chrom in relabel
        }
        for transcript_model in self.transcript_models.values():
            transcript_model.chrom = relabel[transcript_model.chrom]

        # Rebuilding keeps gene_models in sync with the filtered
        # transcripts and merges chromosomes mapped to the same label.
        self.update_indexes()

    @staticmethod
    def get_schema() -> dict[str, Any]:
        """Return the configuration schema of a gene models resource."""
        return {
            **get_base_resource_schema(),
            "filename": {"type": "string"},
            "format": {"type": "string"},
            "gene_mapping": {"type": "string"},
        }

    def load(self) -> GeneModels:
        """Load gene models."""
        from .parsing import load_gene_models  # pylint: disable=C0415

        self.reset()
        return load_gene_models(self)

    def is_loaded(self) -> bool:
        """Return True when at least one transcript model is present."""
        return len(self.transcript_models) > 0
def join_gene_models(*gene_models: GeneModels) -> GeneModels:
    """Join multiple gene models into a single gene models object.

    The result is bound to the resource of the first argument. When a
    transcript ID occurs in several inputs, the transcript from the
    latest input wins.

    Raises:
        ValueError: when fewer than two gene models are passed.
    """
    if len(gene_models) < 2:
        raise ValueError("The function needs at least 2 arguments!")

    merged: dict = {}
    for models in gene_models:
        merged.update(models.transcript_models)

    joined = GeneModels(gene_models[0].resource)
    joined.transcript_models = merged
    # Both indexes are derived from the merged transcripts.
    joined.update_indexes()
    return joined
def build_gene_models_from_file(
    file_name: str,
    file_format: str | None = None,
    gene_mapping_file_name: str | None = None,
) -> GeneModels:
    """Load gene models from local filesystem.

    A resource configuration of type "gene_models" is assembled from the
    arguments; the optional entries are included only when a non-empty
    value is given.
    """
    optional_entries = (
        ("format", file_format),
        ("gene_mapping", gene_mapping_file_name),
    )
    config = {
        "type": "gene_models",
        "filename": file_name,
        **{key: value for key, value in optional_entries if value},
    }

    return build_gene_models_from_resource(build_local_resource(".", config))
def build_gene_models_from_resource(
    resource: GenomicResource | None,
) -> GeneModels:
    """Load gene models from a genomic resource.

    Raises:
        ValueError: when the resource is missing or its type is not
            "gene_models".
    """
    if resource is None:
        raise ValueError(f"missing resource {resource}")

    if resource.get_type() == "gene_models":
        return GeneModels(resource)

    logger.error(
        "trying to open a resource %s of type "
        "%s as gene models", resource.resource_id, resource.get_type())
    raise ValueError(f"wrong resource type: {resource.resource_id}")
def build_gene_models_from_resource_id(
    resource_id: str, grr: GenomicResourceRepo | None = None,
) -> GeneModels:
    """Load gene models by resource ID.

    When no repository is passed, one is created with
    ``build_genomic_resource_repository()``.
    """
    repository = (
        grr if grr is not None else build_genomic_resource_repository()
    )
    resource = repository.get_resource(resource_id)
    return build_gene_models_from_resource(resource)
def create_regions_from_genes(
    gene_models: GeneModels,
    genes: list[str],
    regions: list[Region] | None,
    gene_regions_heuristic_cutoff: int = 20,
    gene_regions_heuristic_extend: int = 20000,
) -> list[Region] | None:
    """Produce a list of regions from given gene symbols.

    Every transcript of every known gene contributes its ``tx`` interval
    widened by ``gene_regions_heuristic_extend`` plus one base on each
    side; the start is never below 1. If given a list of regions, the
    result is the intersection of the gene regions with the provided
    ones. When ``genes`` is empty or holds more than
    ``gene_regions_heuristic_cutoff`` names, ``regions`` is returned
    untouched.
    """
    assert genes is not None
    assert gene_models is not None

    if not 0 < len(genes) <= gene_regions_heuristic_cutoff:
        return regions

    padding = 1 + gene_regions_heuristic_extend
    collected = []
    for gene_name in genes:
        transcripts = gene_models.gene_models_by_gene_name(gene_name)
        if transcripts is None:
            logger.warning("gene model for %s not found", gene_name)
            continue
        collected.extend(
            Region(
                transcript.chrom,
                max(1, transcript.tx[0] - padding),
                transcript.tx[1] + padding,
            )
            for transcript in transcripts
        )
    collected = collapse(collected)

    if not regions:
        return collected

    overlaps = []
    for gene_region in collected:
        for region in regions:
            common = gene_region.intersection(region)
            if common:
                overlaps.append(common)
    overlaps = collapse(overlaps)
    logger.info("original regions: %s; result: %s", regions, overlaps)
    return overlaps