Source code for dae.genomic_resources.gene_models.gene_models

from __future__ import annotations

import logging
from collections import defaultdict
from collections.abc import Generator
from threading import Lock
from typing import Any, cast

from intervaltree import (  # type: ignore
    Interval,
    IntervalTree,
)

from dae.genomic_resources.repository import (
    GenomicResource,
)
from dae.genomic_resources.resource_implementation import (
    ResourceConfigValidationMixin,
    get_base_resource_schema,
)
from dae.utils.regions import (
    Region,
    collapse,
)

from .parsers import load_transcript_models
from .transcript_models import TranscriptModel

logger = logging.getLogger(__name__)


[docs] class GeneModels( ResourceConfigValidationMixin, ): """Manage and query gene model data from genomic resources. This class provides access to gene models loaded from various file formats (GTF, refFlat, refSeq, CCDS, etc.) and offers efficient querying by gene name or genomic location. The class maintains three internal data structures: - transcript_models: Dict mapping transcript IDs to TranscriptModel objects - gene_models: Dict mapping gene names to lists of TranscriptModel objects - _tx_index: IntervalTree index for fast location-based queries Attributes: resource (GenomicResource): The genomic resource containing gene models. config (dict): Validated configuration from the resource. reference_genome_id (str | None): ID of the reference genome. gene_models (dict[str, list[TranscriptModel]]): Gene name to transcript models mapping. transcript_models (dict[str, TranscriptModel]): Transcript ID to transcript model mapping. Example: >>> from dae.genomic_resources.gene_models.gene_models_factory import \\ ... build_gene_models_from_file >>> gene_models = build_gene_models_from_file("genes.gtf") >>> gene_models.load() >>> # Query by gene name >>> tp53_transcripts = gene_models.gene_models_by_gene_name("TP53") >>> # Query by location >>> transcripts = gene_models.gene_models_by_location("chr17", 7676592) Note: The gene models must be loaded using the load() method before queries can be performed. The class is thread-safe for concurrent access. """ def __init__(self, resource: GenomicResource): self._is_loaded = False if resource.get_type() != "gene_models": raise ValueError( f"wrong type of resource passed: {resource.get_type()}") self.resource = resource self.config = self.validate_and_normalize_schema( resource.get_config(), resource, ) self.reference_genome_id: str | None = \ self.config["meta"]["labels"].get("reference_genome") \ if (self.config.get("meta") is not None and self.config["meta"].get("labels") is not None) \ else None self.gene_models: dict[str, list[TranscriptModel]] = defaultdict(list) self._tx_index: dict[str, IntervalTree] = defaultdict(IntervalTree) self.transcript_models: dict[str, Any] = {} self._reset() self.__lock = Lock() @property def resource_id(self) -> str: return self.resource.resource_id
[docs] def close(self) -> None: pass
def _reset(self) -> None: """Reset gene models.""" self._is_loaded = False self.transcript_models = {} self.gene_models = defaultdict(list) self._tx_index = defaultdict(IntervalTree) def _add_to_utr_index(self, tm: TranscriptModel) -> None: self._tx_index[tm.chrom].add(Interval(tm.tx[0], tm.tx[1] + 1, tm))
[docs] def chrom_gene_models(self) -> Generator[ tuple[tuple[str, str], list[TranscriptModel]], None, None]: """Generate chromosome and gene name keys with transcript models.""" for chrom, interval_tree in self._tx_index.items(): gene_models: dict[ tuple[str, str], list[TranscriptModel]] = defaultdict(list) for interval in interval_tree: tm = cast(TranscriptModel, interval.data) assert chrom == tm.chrom gene_models[tm.chrom, tm.gene].append(tm) yield from gene_models.items()
def _update_indexes(self) -> None: """Update internal indexes.""" self.gene_models = defaultdict(list) self._tx_index = defaultdict(IntervalTree) for transcript in self.transcript_models.values(): self.gene_models[transcript.gene].append(transcript) self._add_to_utr_index(transcript)
[docs] def gene_names(self) -> list[str]: """Get list of all gene names in the loaded gene models. Returns: list[str]: List of gene names (symbols). Example: >>> gene_models.load() >>> genes = gene_models.gene_names() >>> print(f"Loaded {len(genes)} genes") """ if self.gene_models is None: logger.warning( "gene models %s are empty", self.resource.resource_id) return [] return list(self.gene_models.keys())
[docs] def gene_models_by_gene_name( self, name: str, ) -> list[TranscriptModel] | None: """Retrieve all transcript models for a specific gene. Args: name (str): The gene name/symbol to search for. Returns: list[TranscriptModel] | None: List of transcript models for the gene, or None if the gene is not found. Example: >>> transcripts = gene_models.gene_models_by_gene_name("BRCA1") >>> if transcripts: ... print(f"BRCA1 has {len(transcripts)} transcript variants") """ return self.gene_models.get(name, None)
[docs] def has_chromosome(self, chrom: str) -> bool: """Check if a chromosome has any gene models. Args: chrom (str): The chromosome name to check. Returns: bool: True if the chromosome has gene models, False otherwise. Example: >>> if gene_models.has_chromosome("chr1"): ... print("Chromosome 1 has gene annotations") """ return chrom in self._tx_index
[docs] def gene_models_by_location( self, chrom: str, pos_begin: int, pos_end: int | None = None, ) -> list[TranscriptModel]: """Retrieve transcripts overlapping a genomic position or region. This method uses an interval tree index for efficient querying of transcripts by genomic coordinates. Args: chrom (str): The chromosome name (e.g., "chr1", "17"). pos_begin (int): The start position (1-based, inclusive). pos_end (int | None): The end position (1-based, inclusive). If None, queries a single position. Returns: list[TranscriptModel]: List of TranscriptModel objects whose transcript regions overlap the query position/region. Returns empty list if no overlaps found. Example: >>> # Query single position >>> models = gene_models.gene_models_by_location("chr17", 7676592) >>> # Query region >>> models = gene_models.gene_models_by_location( ... "chr17", 7661779, 7687550 ... ) >>> for tm in models: ... print(f"{tm.gene}: {tm.tr_id}") Note: Positions are swapped automatically if pos_end < pos_begin. """ if chrom not in self._tx_index: return [] if pos_end is None: pos_end = pos_begin if pos_end < pos_begin: pos_begin, pos_end = pos_end, pos_begin tms_interval = self._tx_index[chrom] result = tms_interval.overlap(pos_begin, pos_end + 1) return [r.data for r in result]
[docs] def relabel_chromosomes( self, relabel: dict[str, str] | None = None, map_file: str | None = None, ) -> None: """Relabel chromosome names in all transcript models. This method is useful for converting between different chromosome naming conventions (e.g., "chr1" <-> "1"). Args: relabel (dict[str, str] | None): Mapping from old to new chromosome names. Either this or map_file must be provided. map_file (str | None): Path to file with chromosome mappings, one mapping per line (old_name new_name). Example: >>> # Using dict >>> gene_models.relabel_chromosomes({"1": "chr1", "2": "chr2"}) >>> # Using file >>> gene_models.relabel_chromosomes(map_file="chrom_map.txt") Note: Transcripts on chromosomes not in the mapping are removed. Internal indexes are rebuilt after relabeling. """ assert relabel or map_file if not relabel: assert map_file is not None with open(map_file) as infile: relabel = dict( line.strip("\n\r").split()[:2]for line in infile ) self.transcript_models = { tid: tm for tid, tm in self.transcript_models.items() if tm.chrom in relabel } for transcript_model in self.transcript_models.values(): transcript_model.chrom = relabel[transcript_model.chrom] self._update_indexes()
[docs] @staticmethod def get_schema() -> dict[str, Any]: return { **get_base_resource_schema(), "filename": {"type": "string"}, "format": {"type": "string"}, "gene_mapping": {"type": "string"}, }
[docs] def load(self) -> GeneModels: """Load gene models from the genomic resource. This method parses the gene model file and builds internal indexes for efficient querying. It is thread-safe and will only load once. Returns: GeneModels: Self, for method chaining. Example: >>> gene_models = build_gene_models_from_file("genes.gtf") >>> gene_models.load() >>> num_transcripts = len(gene_models.transcript_models) >>> print(f"Loaded {num_transcripts} transcripts") Note: Calling load() multiple times is safe - subsequent calls return immediately if already loaded. """ with self.__lock: if self._is_loaded: return self self._reset() self.transcript_models = load_transcript_models(self.resource) self._update_indexes() self._is_loaded = True return self
[docs] def is_loaded(self) -> bool: """Check if gene models have been loaded. Returns: bool: True if load() has been called and completed, False otherwise. Example: >>> if not gene_models.is_loaded(): ... gene_models.load() """ with self.__lock: return self._is_loaded
[docs] @staticmethod def join_gene_models(*gene_models: GeneModels) -> GeneModels: """Merge multiple gene models into a single GeneModels object. This combines transcript models from multiple sources into one unified gene models object. Args: *gene_models (GeneModels): Two or more GeneModels objects to merge. Returns: GeneModels: New GeneModels object containing all transcripts. Raises: ValueError: If fewer than 2 gene models provided. Example: >>> gm1 = build_gene_models_from_file("genes1.gtf") >>> gm2 = build_gene_models_from_file("genes2.gtf") >>> merged = GeneModels.join_gene_models(gm1, gm2) Note: Transcript IDs should be unique across all input gene models. """ if len(gene_models) < 2: raise ValueError("The function needs at least 2 arguments!") gm = GeneModels(gene_models[0].resource) gm._reset() gm.transcript_models = gene_models[0].transcript_models.copy() for i in gene_models[1:]: gm.transcript_models.update(i.transcript_models) gm._update_indexes() return gm
[docs] def create_regions_from_genes( gene_models: GeneModels, genes: list[str], regions: list[Region] | None, gene_regions_heuristic_cutoff: int = 20, gene_regions_heuristic_extend: int = 20000, ) -> list[Region] | None: """Produce a list of regions from given gene symbols. If given a list of regions, will merge the newly-created regions from the genes with the provided ones. """ assert genes is not None assert gene_models is not None if len(genes) == 0 or len(genes) > gene_regions_heuristic_cutoff: return regions gene_regions = [] for gene_name in genes: gene_model = gene_models.gene_models_by_gene_name(gene_name) if gene_model is None: logger.warning("gene model for %s not found", gene_name) continue for gm in gene_model: gene_regions.append( # noqa: PERF401 Region( gm.chrom, max(1, gm.tx[0] - 1 - gene_regions_heuristic_extend), gm.tx[1] + 1 + gene_regions_heuristic_extend, ), ) gene_regions = collapse(gene_regions) if not regions: regions = gene_regions or None else: result = [] for gene_region in gene_regions: for region in regions: intersection = gene_region.intersection(region) if intersection: result.append(intersection) result = collapse(result) logger.info("original regions: %s; result: %s", regions, result) regions = result return regions