Source code for dae.annotation.processing_pipeline

from __future__ import annotations

import abc
import itertools
import logging
from collections.abc import Sequence
from contextlib import AbstractContextManager
from dataclasses import dataclass, field
from types import TracebackType
from typing import Any

from dae.annotation.annotatable import Annotatable
from dae.annotation.annotation_pipeline import (
    AnnotationPipeline,
)
from dae.utils.processing_pipeline import Filter

# Module logger, named after this module; AnnotationPipelineContextManager
# reports exceptions raised during annotation through it.
logger = logging.getLogger(__name__)


@dataclass
class Annotation:
    """An annotatable paired with the context that accompanies it.

    ``context`` is a free-form mapping of key/value pairs associated with
    the annotatable; annotators typically store their results in it.
    """

    # The object being annotated; None is an accepted value.
    annotatable: Annotatable | None
    # A fresh, independent empty dict is created for every instance.
    context: dict[str, Any] = field(default_factory=dict)
@dataclass
class AnnotationsWithSource:
    """A list of Annotation objects kept together with their source.

    The source is typically a variant read from some input format, and the
    entries of ``annotations`` correspond to that variant's alleles.
    """

    # The object the annotations were derived from (typically a variant).
    source: Any = field()
    # Annotations corresponding to the source's alleles.
    annotations: list[Annotation] = field()
class AnnotationsWithSourceFilter(Filter):
    """Base class for filters operating on one AnnotationsWithSource at a time.

    Subclasses implement ``_filter_annotation``. ``filter`` applies it to
    every annotation in order and wraps the results together with the
    unchanged source in a new AnnotationsWithSource.
    """

    @abc.abstractmethod
    def _filter_annotation(
        self,
        annotation: Annotation,
    ) -> Annotation:
        """Return the filtered counterpart of a single annotation."""

    def filter(
        self,
        data: AnnotationsWithSource,
    ) -> AnnotationsWithSource:
        """Filter a single AnnotationsWithSource object.

        Returns a new AnnotationsWithSource that shares ``data.source`` and
        holds one filtered annotation per input annotation.
        """
        filtered = list(map(self._filter_annotation, data.annotations))
        return AnnotationsWithSource(
            source=data.source,
            annotations=filtered,
        )
class AnnotationsWithSourceBatchFilter(Filter):
    """Base class for filters that work on AnnotationsWithSource batches.

    The annotations of every item in the batch are flattened into a single
    list and passed to ``_filter_annotation_batch`` in one call. The
    returned annotations are then redistributed, in order, back to the item
    each one came from.
    """

    @abc.abstractmethod
    def _filter_annotation_batch(
        self,
        batch: Sequence[Annotation],
    ) -> Sequence[Annotation]:
        """Return exactly one filtered annotation per input, in input order."""

    def filter(
        self,
        data: Sequence[AnnotationsWithSource],
    ) -> Sequence[AnnotationsWithSource]:
        """Filter a batch of AnnotationsWithSource objects.

        Args:
            data: the batch; each item's source is carried over unchanged.

        Returns:
            A list with one new AnnotationsWithSource per input item.

        Raises:
            ValueError: if ``_filter_annotation_batch`` returns a different
                number of annotations than it was given.
        """
        annotations_batch = list(itertools.chain.from_iterable(
            aws.annotations for aws in data
        ))
        new_annotations = self._filter_annotation_batch(annotations_batch)
        # Explicit check rather than ``assert``: it must survive ``python -O``.
        # Otherwise a short result would leak a StopIteration from the
        # redistribution step and a long one would silently drop annotations.
        if len(new_annotations) != len(annotations_batch):
            raise ValueError(
                f"batch filter returned {len(new_annotations)} annotations "
                f"for {len(annotations_batch)} inputs")

        result: list[AnnotationsWithSource] = []
        start = 0
        for aws in data:
            end = start + len(aws.annotations)
            result.append(AnnotationsWithSource(
                annotations=list(new_annotations[start:end]),
                source=aws.source,
            ))
            start = end
        return result
class AnnotationPipelineContextManager(AbstractContextManager):
    """A context manager for annotation pipelines.

    Entering calls ``annotation_pipeline.open()``; exiting always calls
    ``annotation_pipeline.close()``. An exception raised in the ``with``
    body is logged and then propagated; it is never suppressed.
    """

    def __init__(self, annotation_pipeline: AnnotationPipeline) -> None:
        self.annotation_pipeline = annotation_pipeline

    def __enter__(self) -> AnnotationPipelineContextManager:
        """Open the wrapped pipeline and return this context manager."""
        self.annotation_pipeline.open()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool:
        """Close the pipeline, logging any exception raised in the body.

        Returns:
            ``True`` only when no exception occurred, so an exception from
            the ``with`` body always propagates to the caller.
        """
        try:
            self.annotation_pipeline.close()
        finally:
            # In ``finally`` so the body's exception is still reported even
            # if close() itself raises.
            if exc_type is not None:
                # exc_info renders the real traceback; formatting the
                # traceback object with %s only printed its repr.
                logger.error(
                    "exception during annotation: %s, %s",
                    exc_type, exc_value,
                    exc_info=(exc_type, exc_value, exc_tb))
        return exc_type is None
class AnnotationPipelineAnnotatablesFilter(
    AnnotationsWithSourceFilter,
    AnnotationPipelineContextManager,
):
    """Annotate each annotation of an AnnotationsWithSource with a pipeline.

    Every annotatable is passed, together with its current context, to
    ``annotation_pipeline.annotate``. The context returned by the pipeline
    becomes the context of the resulting Annotation.
    """

    def _filter_annotation(
        self,
        annotation: Annotation,
    ) -> Annotation:
        pipeline = self.annotation_pipeline
        annotatable = annotation.annotatable
        annotated_context = pipeline.annotate(
            annotatable,
            context=annotation.context,
        )
        return Annotation(annotatable=annotatable, context=annotated_context)
class AnnotationPipelineAnnotatablesBatchFilter(
    AnnotationsWithSourceBatchFilter,
    AnnotationPipelineContextManager,
):
    """Annotate a whole batch of annotations with a single pipeline call.

    All annotatables and their contexts are handed together to
    ``annotation_pipeline.batch_annotate``. Each returned context is paired
    back with the annotatable at the same position.
    """

    def _filter_annotation_batch(
        self,
        batch: Sequence[Annotation],
    ) -> Sequence[Annotation]:
        # Materialize once so the batch is traversed a single time.
        items = list(batch)
        annotatables = [item.annotatable for item in items]
        contexts = [item.context for item in items]
        new_contexts = self.annotation_pipeline.batch_annotate(
            annotatables, contexts=contexts)
        # strict=True: a result of the wrong length raises instead of
        # being silently truncated.
        return [
            Annotation(annotatable=annotatable, context=new_context)
            for annotatable, new_context in zip(
                annotatables, new_contexts, strict=True)
        ]
class DeleteAttributesFromAWSFilter(Filter):
    """Remove a fixed set of keys from an AWS's source. Works in-place."""

    def __init__(self, attributes_to_remove: Sequence[str]) -> None:
        # Kept as a set, so a name listed twice is deleted only once.
        self.to_remove = set(attributes_to_remove)

    def filter(self, data: AnnotationsWithSource) -> AnnotationsWithSource:
        """Delete every configured key from ``data.source`` and return ``data``.

        The source object is modified in place. Whatever error
        ``del source[key]`` raises for a missing key is not caught here.
        """
        source = data.source
        for name in self.to_remove:
            del source[name]
        return data
class DeleteAttributesFromAWSBatchFilter(Filter):
    """Remove a fixed set of keys from every AWS in a batch. Works in-place."""

    def __init__(self, attributes_to_remove: Sequence[str]) -> None:
        # All per-item work is delegated to the single-item filter.
        self._delete_filter = DeleteAttributesFromAWSFilter(
            attributes_to_remove)

    def filter(
        self,
        data: Sequence[AnnotationsWithSource],
    ) -> Sequence[AnnotationsWithSource]:
        """Strip the configured keys from each item's source; return ``data``.

        The items are modified in place and the same sequence object is
        returned.
        """
        delete_from = self._delete_filter.filter
        for item in data:
            delete_from(item)
        return data