Source code for dae.parquet.schema2.processing_pipeline

from __future__ import annotations

import abc
import itertools
import logging
from collections.abc import Iterable, Sequence
from contextlib import AbstractContextManager
from dataclasses import dataclass
from types import TracebackType
from typing import Any, cast

from dae.annotation.annotatable import Annotatable
from dae.annotation.annotation_pipeline import (
    AnnotationPipeline,
)
from dae.effect_annotation.effect import AlleleEffects
from dae.utils.regions import Region
from dae.variants.variant import SummaryAllele, SummaryVariant
from dae.variants_loaders.raw.loader import (
    FullVariant,
    VariantsGenotypesLoader,
)

logger = logging.getLogger(__name__)


class VariantsSource(AbstractContextManager):
    """A context-managed source of variants."""

    @abc.abstractmethod
    def fetch(self, region: Region | None = None) -> Iterable[FullVariant]:
        """Fetch variants.

        `region` is optional and defaults to ``None``; how it restricts the
        returned variants is defined by the implementing class.
        """
class VariantsConsumer(AbstractContextManager):
    """A terminator for variant processing pipelines."""

    @abc.abstractmethod
    def consume_one(self, full_variant: FullVariant) -> None:
        """Consume a single variant."""

    def consume(self, variants: Iterable[FullVariant]) -> None:
        """Consume every variant of `variants`, one `consume_one` call each."""
        for item in variants:
            self.consume_one(item)
class VariantsFilter(AbstractContextManager):
    """A filter that can be used to filter variants."""

    @abc.abstractmethod
    def filter_one(
        self, full_variant: FullVariant,
    ) -> FullVariant:
        """Filter a single variant."""

    def filter(self, variants: Iterable[FullVariant]) -> Iterable[FullVariant]:
        """Lazily yield the `filter_one` result for each input variant."""
        for item in variants:
            yield self.filter_one(item)
class VariantsBatchSource(AbstractContextManager):
    """A source that can fetch variants in batches."""

    @abc.abstractmethod
    def fetch_batches(
        self, region: Region | None = None,
    ) -> Iterable[Sequence[FullVariant]]:
        """Fetch variants in batches.

        `region` is optional and defaults to ``None``; its interpretation is
        left to the implementing class.
        """
class VariantsBatchConsumer(AbstractContextManager):
    """A sink that can write variants in batches."""

    @abc.abstractmethod
    def consume_batch(
        self, batch: Sequence[FullVariant],
    ) -> None:
        """Consume a single batch of variants."""

    def consume_batches(
        self, batches: Iterable[Sequence[FullVariant]],
    ) -> None:
        """Consume every batch of `batches`, one `consume_batch` call each."""
        for chunk in batches:
            self.consume_batch(chunk)
class VariantsBatchFilter(AbstractContextManager):
    """A filter that can filter variants in batches."""

    @abc.abstractmethod
    def filter_batch(
        self, batch: Sequence[FullVariant],
    ) -> Sequence[FullVariant]:
        """Filter variants in a single batch."""

    def filter_batches(
        self, batches: Iterable[Sequence[FullVariant]],
    ) -> Iterable[Sequence[FullVariant]]:
        """Lazily yield the `filter_batch` result for each input batch."""
        for chunk in batches:
            yield self.filter_batch(chunk)
@dataclass
class Annotation:
    """An annotatable with annotations."""

    # The annotatable the annotations belong to; may be None.
    annotatable: Annotatable | None
    # Annotation values keyed by name.
    annotations: dict[str, Any]
@dataclass
class AnnotatablesWithContext:
    """A list of annotatables paired with an arbitrary context object."""

    # Annotatables to be processed; individual entries may be None.
    annotatables: list[Annotatable | None]
    # Opaque value carried along unchanged next to the annotatables.
    context: Any
@dataclass
class AnnotationsWithContext:
    """A list of annotations paired with an arbitrary context object."""

    # Annotations produced for a group of annotatables.
    annotations: list[Annotation]
    # Opaque value carried along unchanged next to the annotations.
    context: Any
class AnnotatablesFilter(AbstractContextManager):
    """A filter that can filter annotatables."""

    @abc.abstractmethod
    def filter_one(
        self, annotatable: Annotatable | None,
    ) -> Annotation:
        """Filter a single annotatable."""

    def filter(
        self, annotatables: Iterable[Annotatable | None],
    ) -> Iterable[Annotation]:
        """Lazily yield the `filter_one` result for each annotatable."""
        for item in annotatables:
            yield self.filter_one(item)

    def filter_one_with_context(
        self, annotatables: AnnotatablesWithContext,
    ) -> AnnotationsWithContext:
        """Filter one group of annotatables and keep its context attached."""
        filtered = self.filter(annotatables.annotatables)
        return AnnotationsWithContext(
            annotations=list(filtered),
            context=annotatables.context,
        )

    def filter_with_context(
        self, annotatables_with_context: Iterable[AnnotatablesWithContext],
    ) -> Iterable[AnnotationsWithContext]:
        """Filter annotatables with context."""
        for group in annotatables_with_context:
            yield self.filter_one_with_context(group)
class AnnotatablesBatchFilter(AbstractContextManager):
    """A filter that can filter annotatables in batches."""

    @abc.abstractmethod
    def filter_batch(
        self, batch: Sequence[Annotatable | None],
    ) -> Sequence[Annotation]:
        """Filter annotatables in a single batch."""

    def filter_batches(
        self, batches: Iterable[Sequence[Annotatable | None]],
    ) -> Iterable[Sequence[Annotation]]:
        """Filter annotatables in batches.

        Each `filter_batch` result is materialized into a list before it is
        yielded.
        """
        for batch in batches:
            yield list(self.filter_batch(batch))

    def filter_batch_with_context(
        self, batch_with_context: Sequence[AnnotatablesWithContext],
    ) -> Sequence[AnnotationsWithContext]:
        """Filter a single batch of annotatables with context.

        The annotatables of all entries are flattened into one batch, passed
        to `filter_batch` in a single call, and the resulting annotations are
        then split back so that every entry receives as many annotations as
        it had annotatables, together with its original context.

        Raises:
            ValueError: if `filter_batch` does not return exactly one
                annotation per annotatable.
        """
        annotatables_batch = list(itertools.chain.from_iterable(
            awc.annotatables for awc in batch_with_context
        ))
        annotations = self.filter_batch(annotatables_batch)
        # An explicit check (not an assert) so that it survives `python -O`;
        # a mismatch would otherwise misalign annotations and contexts.
        if len(annotations) != len(annotatables_batch):
            raise ValueError(
                f"filter_batch returned {len(annotations)} annotations "
                f"for {len(annotatables_batch)} annotatables")

        annotations_iter = iter(annotations)
        result: list[AnnotationsWithContext] = []
        for awc in batch_with_context:
            # Take this entry's share of the flat annotations sequence.
            annos: list[Annotation] = list(
                itertools.islice(annotations_iter, len(awc.annotatables)))
            result.append(
                AnnotationsWithContext(
                    annotations=annos, context=awc.context))
        return result

    def filter_batches_with_context(
        self, batches_with_context: Iterable[Sequence[AnnotatablesWithContext]],
    ) -> Iterable[Sequence[AnnotationsWithContext]]:
        """Filter annotatables with context in batches."""
        for batch_with_context in batches_with_context:
            yield self.filter_batch_with_context(batch_with_context)
class AnnotationPipelineContextManager(AbstractContextManager):
    """A context manager for annotation pipelines.

    Opens the wrapped annotation pipeline on enter and closes it on exit.
    """

    def __init__(self, annotation_pipeline: AnnotationPipeline) -> None:
        self.annotation_pipeline = annotation_pipeline

    def __enter__(self) -> AnnotationPipelineContextManager:
        """Enter the context manager."""
        self.annotation_pipeline.open()
        return self

    def __exit__(
            self,
            exc_type: type[BaseException] | None,
            exc_value: BaseException | None,
            exc_tb: TracebackType | None) -> bool:
        """Close the annotation pipeline and log any in-flight exception.

        Always returns ``False`` so that an exception raised inside the
        ``with`` body propagates to the caller instead of being swallowed.
        """
        self.annotation_pipeline.close()
        if exc_type is not None:
            logger.error(
                "exception during annotation: %s, %s, %s",
                exc_type, exc_value, exc_tb)
        # A truthy return value would suppress the exception (including
        # KeyboardInterrupt/SystemExit) and hide the failure from callers.
        return False
class AnnotationPipelineAnnotatablesFilter(
        AnnotatablesFilter, AnnotationPipelineContextManager):
    """A filter that annotates annotatables one at a time."""

    def filter_one(
        self, annotatable: Annotatable | None,
    ) -> Annotation:
        """Annotate `annotatable` with the pipeline and wrap the result."""
        return Annotation(
            annotatable=annotatable,
            annotations=self.annotation_pipeline.annotate(annotatable),
        )
class AnnotationPipelineAnnotatablesBatchFilter(
        AnnotatablesBatchFilter, AnnotationPipelineContextManager):
    """A filter that can filter annotatables in batches."""

    def filter_batch(
        self, batch: Sequence[Annotatable | None],
    ) -> Sequence[Annotation]:
        """Annotate a whole batch with one `batch_annotate` call.

        The annotatables and the returned annotations are paired strictly,
        so a length mismatch raises ``ValueError`` from ``zip``.
        """
        batch_annotations = self.annotation_pipeline.batch_annotate(batch)
        paired = zip(batch, batch_annotations, strict=True)
        return [
            Annotation(annotatable=item, annotations=values)
            for item, values in paired
        ]
class AnnotationPipelineVariantsFilterMixin:
    """Mixin for annotation pipeline filters."""

    # pylint: disable=too-few-public-methods

    def __init__(self, annotation_pipeline: AnnotationPipeline) -> None:
        self.annotation_pipeline = annotation_pipeline
        # Names of the pipeline attributes flagged as internal; these are
        # never copied onto summary alleles.
        self._annotation_internal_attributes = {
            attr.name
            for attr in self.annotation_pipeline.get_attributes()
            if attr.internal
        }

    def _apply_annotation_to_allele(
        self, summary_allele: SummaryAllele,
        annotation: Annotation,
    ) -> None:
        """Copy the annotation values onto `summary_allele`.

        An ``"allele_effects"`` entry, when present, is stored as the
        allele's effects and removed from ``annotation.annotations`` (the
        dictionary is modified in place). The remaining non-internal
        entries are passed to ``summary_allele.update_attributes``.
        """
        values = annotation.annotations
        if "allele_effects" in values:
            effects = values["allele_effects"]
            assert isinstance(effects, AlleleEffects)
            # pylint: disable=protected-access
            summary_allele._effects = effects  # noqa: SLF001
            del values["allele_effects"]

        internal_names = self._annotation_internal_attributes
        summary_allele.update_attributes({
            name: value
            for name, value in values.items()
            if name not in internal_names
        })
class AnnotationPipelineVariantsFilter(
        VariantsFilter, AnnotationPipelineVariantsFilterMixin):
    """Annotation pipeline variants filter working on one variant at a time.

    NOTE(review): `__enter__` is inherited and does not call
    ``annotation_pipeline.open()``, while `__exit__` closes the pipeline;
    presumably the pipeline opens itself on first use - confirm.
    """

    def __init__(self, annotation_pipeline: AnnotationPipeline) -> None:
        super().__init__(annotation_pipeline)
        self.annotatables_filter = AnnotationPipelineAnnotatablesFilter(
            annotation_pipeline)

    def filter_one(
        self, full_variant: FullVariant,
    ) -> FullVariant:
        """Annotate the alternative alleles of `full_variant` in place.

        The annotatable of every alternative allele is annotated and the
        resulting annotation is applied to the corresponding summary
        allele. The same (mutated) variant is returned.
        """
        annotatables = AnnotatablesWithContext(
            annotatables=[
                sa.get_annotatable()
                for sa in full_variant.summary_variant.alt_alleles],
            context=full_variant,
        )
        awc = self.annotatables_filter.filter_one_with_context(
            annotatables)
        # The context carries the variant through the annotatables filter.
        full_variant = cast(FullVariant, awc.context)
        assert isinstance(full_variant.summary_variant, SummaryVariant)
        assert len(awc.annotations) == \
            len(full_variant.summary_variant.alt_alleles)
        for summary_allele, annotation in zip(
                full_variant.summary_variant.alt_alleles,
                awc.annotations, strict=True):
            assert isinstance(summary_allele, SummaryAllele)
            self._apply_annotation_to_allele(summary_allele, annotation)
        return full_variant

    def __exit__(
            self,
            exc_type: type[BaseException] | None,
            exc_value: BaseException | None,
            exc_tb: TracebackType | None) -> bool:
        """Close the annotation pipeline and log any in-flight exception.

        Always returns ``False`` so that an exception raised inside the
        ``with`` body propagates to the caller instead of being swallowed.
        """
        self.annotation_pipeline.close()
        if exc_type is not None:
            logger.error(
                "exception during annotation: %s, %s, %s",
                exc_type, exc_value, exc_tb)
        # A truthy return value would suppress the exception and let a
        # failed run look like a successful one.
        return False
class AnnotationPipelineVariantsBatchFilter(
        VariantsBatchFilter, AnnotationPipelineVariantsFilterMixin):
    """Annotation pipeline batched variants filter.

    NOTE(review): `__enter__` is inherited and does not call
    ``annotation_pipeline.open()``, while `__exit__` closes the pipeline;
    presumably the pipeline opens itself on first use - confirm.
    """

    def __init__(self, annotation_pipeline: AnnotationPipeline) -> None:
        super().__init__(annotation_pipeline)
        self.annotatables_filter = AnnotationPipelineAnnotatablesBatchFilter(
            annotation_pipeline)

    def filter_batch(
        self, batch: Sequence[FullVariant],
    ) -> Sequence[FullVariant]:
        """Annotate the alternative alleles of every variant in `batch`.

        The annotatables of all variants are annotated through the batch
        annotatables filter, then each annotation is applied to the
        corresponding summary allele. The variants are mutated in place and
        returned in a new list, in input order.
        """
        annotatables_with_context = [
            AnnotatablesWithContext(
                annotatables=[
                    sa.get_annotatable()
                    for sa in v.summary_variant.alt_alleles],
                context=v,
            )
            for v in batch
        ]
        annotations_with_context = \
            self.annotatables_filter.filter_batch_with_context(
                annotatables_with_context)

        result: list[FullVariant] = []
        for awc in annotations_with_context:
            # The context carries the variant through the annotatables
            # filter.
            full_variant = cast(FullVariant, awc.context)
            assert isinstance(full_variant.summary_variant, SummaryVariant)
            assert len(awc.annotations) == \
                len(full_variant.summary_variant.alt_alleles)
            for summary_allele, annotation in zip(
                    full_variant.summary_variant.alt_alleles,
                    awc.annotations, strict=True):
                assert isinstance(summary_allele, SummaryAllele)
                self._apply_annotation_to_allele(summary_allele, annotation)
            result.append(full_variant)
        return result

    def __exit__(
            self,
            exc_type: type[BaseException] | None,
            exc_value: BaseException | None,
            exc_tb: TracebackType | None) -> bool:
        """Close the annotation pipeline and log any in-flight exception.

        Always returns ``False`` so that an exception raised inside the
        ``with`` body propagates to the caller instead of being swallowed.
        """
        self.annotation_pipeline.close()
        if exc_type is not None:
            logger.error(
                "exception during annotation: %s, %s, %s",
                exc_type, exc_value, exc_tb)
        # A truthy return value would suppress the exception and let a
        # failed run look like a successful one.
        return False
class VariantsLoaderSource(VariantsSource):
    """A source that can fetch variants from a loader."""

    def __init__(self, loader: VariantsGenotypesLoader) -> None:
        self.loader = loader

    def fetch(self, region: Region | None = None) -> Iterable[FullVariant]:
        """Fetch full variants from a variant loader."""
        yield from self.loader.fetch(region)

    def __exit__(
            self,
            exc_type: type[BaseException] | None,
            exc_value: BaseException | None,
            exc_tb: TracebackType | None) -> bool:
        """Leave the context without suppressing exceptions.

        There is nothing to release here. ``False`` is returned so that an
        exception raised inside the ``with`` body reaches the caller rather
        than being silently discarded.
        """
        return False
class VariantsLoaderBatchSource(VariantsBatchSource):
    """A source that can fetch variants in batches from a loader."""

    def __init__(
        self, loader: VariantsGenotypesLoader,
        batch_size: int = 500,
    ) -> None:
        self.loader = loader
        self.batch_size = batch_size

    def fetch_batches(
        self, region: Region | None = None,
    ) -> Iterable[Sequence[FullVariant]]:
        """Fetch full variants from a variant loader in batches.

        Yields tuples of at most `batch_size` variants; the last batch may
        be shorter. Nothing is yielded when the loader returns no variants.
        """
        # `iter()` guarantees a single shared iterator: slicing a
        # re-iterable object (e.g. a list) would restart from the beginning
        # on every pass and yield the first batch forever.
        variants = iter(self.loader.fetch(region))
        while batch := tuple(
                itertools.islice(variants, self.batch_size)):
            yield batch

    def __exit__(
            self,
            exc_type: type[BaseException] | None,
            exc_value: BaseException | None,
            exc_tb: TracebackType | None) -> bool:
        """Log any in-flight exception without suppressing it.

        Always returns ``False`` so that an exception raised inside the
        ``with`` body propagates to the caller instead of being swallowed.
        """
        if exc_type is not None:
            logger.error(
                "exception during annotation: %s, %s, %s",
                exc_type, exc_value, exc_tb)
        return False
class VariantsPipelineProcessor(AbstractContextManager):
    """A processor that can be used to process variants in a pipeline.

    Every variant fetched from the source is passed through all filters, in
    order, and the result is handed to the consumer. Used as a context
    manager, the processor enters and exits the source, the filters and the
    consumer.
    """

    def __init__(
        self,
        source: VariantsSource,
        filters: Sequence[VariantsFilter],
        consumer: VariantsConsumer,
    ) -> None:
        self.source = source
        self.filters = filters
        self.consumer = consumer

    def process_region(self, region: Region | None = None) -> None:
        """Fetch, filter and consume the variants for a single `region`."""
        for full_variant in self.source.fetch(region):
            for variant_filter in self.filters:
                full_variant = variant_filter.filter_one(full_variant)
            self.consumer.consume_one(full_variant)

    def process(self, regions: Iterable[Region] | None = None) -> None:
        """Process variants for the given regions.

        When `regions` is ``None`` a single pass with ``region=None`` is
        made; otherwise the regions are processed one after another.
        """
        if regions is None:
            self.process_region(None)
            return
        for region in regions:
            self.process_region(region)

    def __enter__(self) -> VariantsPipelineProcessor:
        """Enter the source, the filters and the consumer, in that order.

        If one of them fails to enter, the stages entered so far are exited
        again (in reverse order) before the error is re-raised, so that no
        stage is left open.
        """
        entered: list[AbstractContextManager] = []
        try:
            for stage in (self.source, *self.filters, self.consumer):
                stage.__enter__()
                entered.append(stage)
        except BaseException as error:
            # `with` never calls __exit__ when __enter__ raises, so the
            # clean-up of the already entered stages has to happen here.
            for stage in reversed(entered):
                stage.__exit__(type(error), error, error.__traceback__)
            raise
        return self

    def __exit__(
            self,
            exc_type: type[BaseException] | None,
            exc_value: BaseException | None,
            exc_tb: TracebackType | None) -> bool:
        """Exit the source, the filters and the consumer, in that order.

        Every stage is exited even if an earlier one fails; the first such
        failure is re-raised once all stages have been exited. Returns
        ``False`` so that an exception raised inside the ``with`` body
        propagates to the caller instead of being swallowed.
        """
        if exc_type is not None:
            logger.error(
                "exception during annotation: %s, %s, %s",
                exc_type, exc_value, exc_tb)

        close_error: Exception | None = None
        for stage in (self.source, *self.filters, self.consumer):
            try:
                stage.__exit__(exc_type, exc_value, exc_tb)
            except Exception as error:  # pylint: disable=broad-except
                # Keep going so the remaining stages are still released.
                if close_error is None:
                    close_error = error
        if close_error is not None:
            raise close_error
        return False
class VariantsBatchPipelineProcessor(AbstractContextManager):
    """A processor that can be used to process variants in a pipeline.

    Every batch fetched from the source is passed through all filters, in
    order, and the result is handed to the consumer. Used as a context
    manager, the processor enters and exits the source, the filters and the
    consumer.
    """

    def __init__(
        self,
        source: VariantsBatchSource,
        filters: Sequence[VariantsBatchFilter],
        consumer: VariantsBatchConsumer,
    ) -> None:
        self.source = source
        self.filters = filters
        self.consumer = consumer

    def process_region(self, region: Region | None = None) -> None:
        """Fetch, filter and consume the batches for a single `region`."""
        for batch in self.source.fetch_batches(region):
            for variant_filter in self.filters:
                batch = variant_filter.filter_batch(batch)
            self.consumer.consume_batch(batch)

    def process(self, regions: Iterable[Region] | None = None) -> None:
        """Process variants in batches for the given regions.

        When `regions` is ``None`` a single pass with ``region=None`` is
        made; otherwise the regions are processed one after another.
        """
        if regions is None:
            self.process_region(None)
            return
        for region in regions:
            self.process_region(region)

    def __enter__(self) -> VariantsBatchPipelineProcessor:
        """Enter the source, the filters and the consumer, in that order.

        If one of them fails to enter, the stages entered so far are exited
        again (in reverse order) before the error is re-raised, so that no
        stage is left open.
        """
        entered: list[AbstractContextManager] = []
        try:
            for stage in (self.source, *self.filters, self.consumer):
                stage.__enter__()
                entered.append(stage)
        except BaseException as error:
            # `with` never calls __exit__ when __enter__ raises, so the
            # clean-up of the already entered stages has to happen here.
            for stage in reversed(entered):
                stage.__exit__(type(error), error, error.__traceback__)
            raise
        return self

    def __exit__(
            self,
            exc_type: type[BaseException] | None,
            exc_value: BaseException | None,
            exc_tb: TracebackType | None) -> bool:
        """Exit the source, the filters and the consumer, in that order.

        Every stage is exited even if an earlier one fails; the first such
        failure is re-raised once all stages have been exited. Returns
        ``False`` so that an exception raised inside the ``with`` body
        propagates to the caller instead of being swallowed.
        """
        if exc_type is not None:
            logger.error(
                "exception during annotation: %s, %s, %s",
                exc_type, exc_value, exc_tb)

        close_error: Exception | None = None
        for stage in (self.source, *self.filters, self.consumer):
            try:
                stage.__exit__(exc_type, exc_value, exc_tb)
            except Exception as error:  # pylint: disable=broad-except
                # Keep going so the remaining stages are still released.
                if close_error is None:
                    close_error = error
        if close_error is not None:
            raise close_error
        return False