Source code for dae.parquet.schema2.processing_pipeline

from __future__ import annotations

import itertools
import logging
from collections.abc import Iterable, Sequence
from types import TracebackType
from typing import cast

from dae.annotation.annotation_pipeline import (
    AnnotationPipeline,
)
from dae.annotation.processing_pipeline import (
    Annotation,
    AnnotationPipelineAnnotatablesBatchFilter,
    AnnotationPipelineAnnotatablesFilter,
    AnnotationsWithSource,
)
from dae.effect_annotation.effect import AlleleEffects
from dae.parquet.schema2.loader import ParquetLoader
from dae.utils.processing_pipeline import Filter, Source
from dae.utils.regions import Region
from dae.variants.variant import SummaryAllele, SummaryVariant
from dae.variants_loaders.raw.loader import (
    FullVariant,
    VariantsGenotypesLoader,
)

logger = logging.getLogger(__name__)


[docs] class AnnotationPipelineVariantsFilterMixin: """Mixin for annotation pipeline filters.""" # pylint: disable=too-few-public-methods def __init__(self, annotation_pipeline: AnnotationPipeline) -> None: self.annotation_pipeline = annotation_pipeline self._annotation_internal_attributes = { attribute.name for attribute in self.annotation_pipeline.get_attributes() if attribute.internal } def _apply_annotation_to_allele( self, summary_allele: SummaryAllele, annotation: Annotation, ) -> None: if "allele_effects" in annotation.context: allele_effects = annotation.context["allele_effects"] assert isinstance(allele_effects, AlleleEffects) # pylint: disable=protected-access summary_allele._effects = allele_effects # noqa: SLF001 del annotation.context["allele_effects"] public_attributes = { key: value for key, value in annotation.context.items() if key not in self._annotation_internal_attributes } summary_allele.update_attributes(public_attributes)
[docs] class DeleteAttributesFromVariantFilter(Filter): """Filter to remove items from variants. Works in-place.""" def __init__(self, attributes_to_remove: Sequence[str]) -> None: self.to_remove = set(attributes_to_remove)
[docs] def filter( self, data: FullVariant, ) -> FullVariant: """Remove specified attributes from a variant.""" for allele in data.summary_variant.alt_alleles: for attr in self.to_remove: if attr in allele.attributes: del allele.attributes[attr] return data
[docs] class DeleteAttributesFromVariantsBatchFilter(Filter): """Filter to remove items from variant batches. Works in-place.""" def __init__(self, attributes_to_remove: Sequence[str]) -> None: self._delete_filter = DeleteAttributesFromVariantFilter( attributes_to_remove)
[docs] def filter( self, data: Sequence[FullVariant], ) -> Sequence[FullVariant]: """Remove specified attributes from batches of variants.""" for variant in data: self._delete_filter.filter(variant) return data
[docs] class AnnotationPipelineVariantsFilter( Filter, AnnotationPipelineVariantsFilterMixin, ): """Annotation pipeline batched variants filter.""" def __init__(self, annotation_pipeline: AnnotationPipeline) -> None: super().__init__(annotation_pipeline) self.annotatables_filter = AnnotationPipelineAnnotatablesFilter( annotation_pipeline) def __exit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, exc_tb: TracebackType | None, ) -> bool: self.annotation_pipeline.close() if exc_type is not None: logger.error( "exception during annotation: %s, %s, %s", exc_type, exc_value, exc_tb) return exc_type is None
[docs] def filter(self, data: FullVariant) -> FullVariant: variant = AnnotationsWithSource( source=data, annotations=[ Annotation(sa.get_annotatable(), dict(sa.attributes)) for sa in data.summary_variant.alt_alleles ], ) aws = self.annotatables_filter.filter(variant) full_variant = cast(FullVariant, aws.source) assert isinstance(full_variant.summary_variant, SummaryVariant) assert len(aws.annotations) == \ len(full_variant.summary_variant.alt_alleles) for summary_allele, annotation in zip( full_variant.summary_variant.alt_alleles, aws.annotations, strict=True): assert isinstance(summary_allele, SummaryAllele) self._apply_annotation_to_allele(summary_allele, annotation) return full_variant
[docs] class AnnotationPipelineVariantsBatchFilter( Filter, AnnotationPipelineVariantsFilterMixin, ): """Annotation pipeline batched variants filter.""" def __init__(self, annotation_pipeline: AnnotationPipeline) -> None: super().__init__(annotation_pipeline) self.annotatables_filter = AnnotationPipelineAnnotatablesBatchFilter( annotation_pipeline) def __exit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, exc_tb: TracebackType | None, ) -> bool: self.annotation_pipeline.close() if exc_type is not None: logger.error( "exception during annotation: %s, %s, %s", exc_type, exc_value, exc_tb) return exc_type is None
[docs] def filter( self, data: Sequence[FullVariant], ) -> Sequence[FullVariant]: """Filter variants in batches.""" variants = [ AnnotationsWithSource( source=v, annotations=[ Annotation(sa.get_annotatable(), dict(sa.attributes)) for sa in v.summary_variant.alt_alleles ], ) for v in data ] aws_batch = \ self.annotatables_filter.filter(variants) result: list[FullVariant] = [] for aws in aws_batch: full_variant = cast(FullVariant, aws.source) assert isinstance(full_variant.summary_variant, SummaryVariant) assert len(aws.annotations) == \ len(full_variant.summary_variant.alt_alleles) for summary_allele, annotation in zip( full_variant.summary_variant.alt_alleles, aws.annotations, strict=True): assert isinstance(summary_allele, SummaryAllele) self._apply_annotation_to_allele(summary_allele, annotation) result.append(full_variant) return result
[docs] class VariantsLoaderSource(Source): """A source that can fetch variants from a loader.""" def __init__(self, loader: VariantsGenotypesLoader) -> None: self.loader = loader
[docs] def fetch(self, region: Region | None = None) -> Iterable[FullVariant]: """Fetch full variants from a variant loader.""" yield from self.loader.fetch(region)
[docs] class VariantsLoaderBatchSource(Source): """A source that can fetch variants in batches from a loader.""" def __init__( self, loader: VariantsGenotypesLoader, batch_size: int = 500, ) -> None: self.loader = loader self.batch_size = batch_size
[docs] def fetch( self, region: Region | None = None, ) -> Iterable[Sequence[FullVariant]]: """Fetch full variants from a variant loader in batches.""" variants = self.loader.fetch(region) while batch := tuple(itertools.islice(variants, self.batch_size)): yield batch
[docs] class Schema2SummaryVariantsSource(Source): """Producer for summary variants from a Parquet dataset.""" def __init__(self, loader: ParquetLoader): self.loader = loader
[docs] def fetch( self, region: Region | None = None, ) -> Iterable[FullVariant]: assert self.loader is not None for sv in self.loader.fetch_summary_variants(region=region): yield FullVariant(sv, ())
[docs] class Schema2SummaryVariantsBatchSource(Source): """Producer for summary variants from a Parquet dataset.""" def __init__( self, loader: ParquetLoader, batch_size: int = 500, ) -> None: self.loader = loader self.batch_size = batch_size
[docs] def fetch( self, region: Region | None = None, ) -> Iterable[Sequence[FullVariant]]: """Fetch full variants from a variant loader in batches.""" variants = self.loader.fetch_summary_variants(region=region) while batch := tuple(itertools.islice(variants, self.batch_size)): yield [FullVariant(sv, ()) for sv in batch]