Source code for gain.annotation.simple_effect_annotator

import logging
import textwrap
from collections import defaultdict
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any

from gain.annotation.annotatable import Annotatable
from gain.annotation.annotation_config import (
    AnnotatorInfo,
)
from gain.annotation.annotation_pipeline import (
    AnnotationPipeline,
    Annotator,
    AttributeSpec,
)
from gain.annotation.annotator_base import AnnotatorBase
from gain.genomic_resources.gene_models import (
    GeneModels,
    TranscriptModel,
    build_gene_models_from_resource,
)
from gain.genomic_resources.genomic_context import get_genomic_context
from gain.utils.regions import Region

# Module-level logger, named after this module so its output can be
# filtered/configured through the standard logging hierarchy.
logger: logging.Logger = logging.getLogger(__name__)


def build_simple_effect_annotator(
    pipeline: AnnotationPipeline,
    info: AnnotatorInfo,
) -> Annotator:
    """Build a :class:`SimpleEffectAnnotator` for ``pipeline``.

    Args:
        pipeline: the annotation pipeline the annotator will belong to.
        info: the annotator configuration.

    Returns:
        A new, not yet opened, simple effect annotator.
    """
    annotator = SimpleEffectAnnotator(pipeline, info)
    return annotator
@dataclass(frozen=True)
class SimpleEffect:
    """A single effect type attributed to one transcript of a gene.

    Instances are immutable, compare by value and are hashable, so they can
    be collected in sets. (``frozen=True`` together with the default
    ``eq=True`` already generates ``__hash__`` from the fields.)
    """

    # Effect classification, one of SimpleEffectAnnotator.effect_types().
    effect_type: str
    # Identifier of the affected transcript.
    transcript_id: str
    # Gene that the affected transcript belongs to.
    gene: str
class SimpleEffectAnnotator(AnnotatorBase):
    """Simple effect annotator class.

    Classifies an annotatable against the transcripts of a gene models
    resource into one of a small set of coarse effect types (see
    :meth:`effect_types`) and reports the affected genes per effect type.
    """

    @staticmethod
    def effect_types() -> list[str]:
        """Return the supported effect types, from most to least severe.

        The order matters: :meth:`_do_annotate` reports the first present
        type as the worst effect, and :meth:`get_attribute_specs` skips the
        last one (``"intergenic"``), which never carries any genes.
        """
        return [
            "coding",
            "inter-coding_intronic",
            "peripheral",
            "noncoding",
            "intergenic",
        ]

    def get_attribute_specs(self) -> dict[str, AttributeSpec]:
        """Return the specifications of all attributes this annotator emits.

        Besides the fixed attributes, a ``<effect>_gene_list`` and a
        ``<effect>_genes`` attribute are declared for every effect type
        except ``"intergenic"``.
        """
        gene_lists: dict[str, AttributeSpec] = {}
        for effect in SimpleEffectAnnotator.effect_types()[:-1]:
            source_gl = f"{effect}_gene_list"
            source_ge = f"{effect}_genes"
            gene_lists[source_gl] = AttributeSpec(
                source=source_gl,
                value_type="object",
                description=f"list of genes with {effect} effect.",
                internal_default=False,
                is_default=False,
                attribute_type="gene_list",
            )
            gene_lists[source_ge] = AttributeSpec(
                source=source_ge,
                value_type="str",
                description=(
                    f"comma separated list of genes with {effect} effect."),
                internal_default=False,
                is_default=False,
                supports_aggregation=False,
            )
        return {
            "worst_effect": AttributeSpec(
                source="worst_effect",
                value_type="str",
                description="The worst effect.",
                internal_default=False,
                is_default=True,
                supports_aggregation=False,
            ),
            "worst_effect_genes": AttributeSpec(
                source="worst_effect_genes",
                value_type="str",
                description="comma separated list of genes with worst effect.",
                internal_default=False,
                is_default=True,
                supports_aggregation=False,
            ),
            "worst_effect_gene_list": AttributeSpec(
                source="worst_effect_gene_list",
                value_type="object",
                description="list of genes with worst effect.",
                internal_default=False,
                is_default=False,
                attribute_type="gene_list",
            ),
            "gene_list": AttributeSpec(
                source="gene_list",
                value_type="object",
                description="List of all affected genes.",
                internal_default=True,
                is_default=True,
                attribute_type="gene_list",
            ),
            "genes": AttributeSpec(
                source="genes",
                value_type="str",
                description="Comma separated list of all affected genes.",
                internal_default=False,
                is_default=False,
                supports_aggregation=False,
            ),
            "gene_effects": AttributeSpec(
                source="gene_effects",
                value_type="str",
                description="list of gene:effect pairs.",
                internal_default=False,
                is_default=False,
                supports_aggregation=False,
            ),
            "effect_details": AttributeSpec(
                source="effect_details",
                value_type="str",
                description="list of transcript:gene:effect tuples.",
                internal_default=False,
                is_default=False,
                supports_aggregation=False,
            ),
            **gene_lists,
        }

    def get_attribute_defaults(
        self, spec: AttributeSpec,  # noqa: ARG002
    ) -> dict[str, Any]:
        """Return no per-attribute defaults (always an empty dict)."""
        return {}

    def __init__(self, pipeline: AnnotationPipeline, info: AnnotatorInfo):
        """Resolve the gene models and register them with the annotator.

        The ``gene_models`` parameter of ``info`` names a gene models
        resource in the pipeline's repository; when it is absent, the gene
        models of the current genomic context are used instead.

        Raises:
            ValueError: if no gene models are configured and none are
                available from the genomic context.
        """
        gene_models_resource_id = info.parameters.get("gene_models")
        if gene_models_resource_id is None:
            gene_models = get_genomic_context().get_gene_models()
            if gene_models is None:
                raise ValueError(
                    f"Can't create {info.type}: "
                    "gene model resource are missing in config "
                    "and context")
        else:
            resource = pipeline.repository.get_resource(
                gene_models_resource_id)
            gene_models = build_gene_models_from_resource(resource)
        assert isinstance(gene_models, GeneModels)

        info.documentation += textwrap.dedent(f"""

Simple effect annotator.

<a href="{self.BASE_DOC_URL}#simple-effect-annotator" target="_blank">More info</a>

""")  # noqa
        info.resources.append(gene_models.resource)
        super().__init__(pipeline, info)
        self.gene_models = gene_models

    def open(self) -> Annotator:
        """Load the gene models, then open the base annotator."""
        self.gene_models.load()
        return super().open()

    def _do_annotate(
        self, annotatable: Annotatable,
        context: dict[str, Any],  # noqa: ARG002
    ) -> dict[str, Any]:
        """Compute the simple-effect attributes for ``annotatable``.

        Effect types are visited in :meth:`effect_types` order; the first
        one present in the annotation becomes ``worst_effect``. Gene lists
        are sorted, and ``effect_details`` is ordered by transcript id (then
        gene) within each effect type so the output is deterministic.
        """
        assert annotatable is not None
        annotation = self.run_annotate(
            annotatable.chrom,
            annotatable.position,
            annotatable.end_position)

        result: dict[str, Any] = {}
        gene_list: set[str] = set()
        gene_effects: list[tuple[str, str]] = []
        worst_effect = None
        worst_effect_gene_list: list[str] = []
        details: list[str] = []

        for effect_type in self.effect_types():
            simple_effects = annotation.get(effect_type)
            if simple_effects is None:
                continue
            genes = sorted({se.gene for se in simple_effects})
            if worst_effect is None:
                worst_effect = effect_type
                # Separate list object so it is independent of the
                # per-effect list stored below.
                worst_effect_gene_list = list(genes)
            result[f"{effect_type}_gene_list"] = list(genes)
            result[f"{effect_type}_genes"] = ",".join(genes)
            gene_list.update(genes)
            gene_effects.extend((gene, effect_type) for gene in genes)
            # simple_effects is a set: sort on every distinguishing field
            # so ties on transcript_id do not fall back to hash order.
            details.extend(
                f"{se.transcript_id}:{se.gene}:{se.effect_type}"
                for se in sorted(
                    simple_effects,
                    key=lambda eff: (eff.transcript_id, eff.gene)))

        all_genes = sorted(gene_list)
        result["gene_list"] = all_genes
        result["genes"] = ",".join(all_genes)
        result["worst_effect"] = worst_effect
        result["worst_effect_genes"] = ",".join(worst_effect_gene_list)
        result["worst_effect_gene_list"] = worst_effect_gene_list
        result["gene_effects"] = "|".join(
            f"{gene}:{effect}" for gene, effect in gene_effects)
        result["effect_details"] = "|".join(details)
        return result

    def cds_intron_regions(
        self, transcript: TranscriptModel,
    ) -> list[Region]:
        """Return the introns of ``transcript`` lying strictly inside its CDS.

        An intron is the gap between two consecutive exons. Only introns
        that start after ``cds[0]`` and end before ``cds[1]`` are returned.
        Non-coding transcripts have no such regions.
        """
        if not transcript.is_coding():
            return []
        cds_start, cds_end = transcript.cds[0], transcript.cds[1]
        regions: list[Region] = []
        for left, right in zip(transcript.exons, transcript.exons[1:]):
            beg = left.stop + 1
            end = right.start - 1
            if beg > cds_start and end < cds_end:
                regions.append(Region(transcript.chrom, beg, end))
        return regions

    def cds_regions(self, transcript: TranscriptModel) -> Sequence[Region]:
        """Return the coding regions of ``transcript``.

        Delegates to ``transcript.cds_regions()``.
        """
        return transcript.cds_regions()

    def peripheral_regions(self, transcript: TranscriptModel) -> list[Region]:
        """Return the parts of a coding transcript outside its CDS.

        These are the spans from the transcript start up to the base before
        the CDS start, and from the base after the CDS end up to the
        transcript end; each is included only when non-empty. They are
        contiguous spans, so any intronic sequence in them is included.
        Non-coding transcripts yield no regions.
        """
        regions: list[Region] = []
        if not transcript.is_coding():
            return regions

        if transcript.cds[0] > transcript.tx[0]:
            regions.append(
                Region(
                    transcript.chrom, transcript.tx[0], transcript.cds[0] - 1))

        if transcript.cds[1] < transcript.tx[1]:
            regions.append(
                Region(
                    transcript.chrom, transcript.cds[1] + 1, transcript.tx[1]))

        return regions

    def noncoding_regions(self, transcript: TranscriptModel) -> list[Region]:
        """Return the whole span of a non-coding transcript.

        Coding transcripts yield no regions.
        """
        if transcript.is_coding():
            return []
        return [
            Region(
                transcript.chrom, transcript.tx[0], transcript.tx[1])]

    def call_region(
        self,
        chrom: str,
        beg: int,
        end: int,
        tx: TranscriptModel,
        *,
        func_name: str,
        classification: str,
    ) -> SimpleEffect | None:
        """Classify ``[beg, end]`` against one region family of ``tx``.

        ``func_name`` names the method of this class that produces the
        regions (e.g. ``"cds_regions"``). If any of those regions overlaps
        ``[beg, end]``, a :class:`SimpleEffect` with ``classification`` as
        its effect type is returned; otherwise ``None``.
        """
        regions = getattr(self, func_name)(tx)
        for region in regions:
            assert region.chrom == chrom
            if region.stop >= beg and region.start <= end:
                return SimpleEffect(
                    effect_type=classification,
                    transcript_id=tx.tr_id,
                    gene=tx.gene)
        return None

    @staticmethod
    def _classification() -> list[tuple[str, str]]:
        """Return (effect type, region method name) pairs in priority order."""
        return [
            ("coding", "cds_regions"),
            ("inter-coding_intronic", "cds_intron_regions"),
            ("peripheral", "peripheral_regions"),
            ("noncoding", "noncoding_regions"),
        ]

    def run_annotate(
        self,
        chrom: str,
        beg: int,
        end: int,
    ) -> dict[str, set[SimpleEffect]]:
        """Map each effect type to the simple effects found in a region.

        When no transcript model overlaps ``chrom:beg-end`` the result is
        ``{"intergenic": set()}``. Otherwise every overlapping transcript
        contributes at most one effect: the first classification in
        :meth:`_classification` order whose regions overlap the query.

        The gene models must already be loaded (see :meth:`open`).
        """
        assert self.gene_models.is_loaded()

        tms = self.gene_models.gene_models_by_location(chrom, beg, end)
        if not tms:
            return {"intergenic": set()}

        # Sanity check: the location query returns only overlapping models.
        assert all((beg <= t.tx[1] and end >= t.tx[0]) for t in tms)

        result: dict[str, set[SimpleEffect]] = defaultdict(set)
        for tx in tms:
            for effect_type, func_name in self._classification():
                effect = self.call_region(
                    chrom, beg, end, tx,
                    func_name=func_name,
                    classification=effect_type,
                )
                if effect is not None:
                    result[effect.effect_type].add(effect)
                    break  # first (most severe) matching classification wins
        return dict(result)