# Source code for gain.utils.processing_pipeline

from __future__ import annotations

import abc
import logging
import traceback
from collections.abc import Iterable, Sequence
from contextlib import AbstractContextManager
from types import TracebackType
from typing import Any

from gain.utils.regions import Region

logger = logging.getLogger(__name__)


class Filter(AbstractContextManager):
    """Abstract pipeline stage that processes one item at a time.

    A filter is a context manager. Concrete subclasses must implement
    ``filter`` and also provide ``__exit__``, which stays abstract in
    ``AbstractContextManager``.
    """

    @abc.abstractmethod
    def filter(self, data: Any) -> Any:
        """Process a single item of pipeline data and return the result."""
class Source(AbstractContextManager):
    """Abstract origin of the items that flow through a processing pipeline.

    A source is a context manager. Subclasses must implement ``fetch``;
    a default ``__exit__`` is supplied here.
    """

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool:
        """Report whether the ``with`` block finished without an exception.

        Returns:
            ``True`` when no exception was raised, ``False`` otherwise, so
            an exception raised inside the block is never suppressed.
        """
        finished_cleanly = exc_type is None
        return finished_cleanly

    @abc.abstractmethod
    def fetch(self, region: Region | None = None) -> Iterable[Any]:
        """Return an iterable of items, optionally limited to *region*."""
class PipelineProcessor(AbstractContextManager):
    """Drive items from a source through an ordered chain of filters.

    The processor is a context manager that opens the source first and
    then every filter in the given order. On exit it closes them in that
    same order (source first, then the filters).

    Args:
        source: Object providing ``fetch(region)`` and the context-manager
            protocol.
        filters: Ordered stages; each provides ``filter(data)`` and the
            context-manager protocol.
    """

    def __init__(self, source: Source, filters: Sequence[Filter]) -> None:
        self.source = source
        self.filters = filters

    def __enter__(self) -> PipelineProcessor:
        """Open the source and then each filter.

        If any stage fails to open, the stages that were already opened are
        closed again before the original error is re-raised, so a partially
        initialised pipeline does not leak open resources.
        """
        entered: list[Any] = []
        try:
            for stage in (self.source, *self.filters):
                stage.__enter__()
                entered.append(stage)
        except BaseException as error:
            failures = self._exit_stages(
                entered, type(error), error, error.__traceback__)
            # The setup error is the root cause; close errors are only logged
            # so that they do not mask it.
            for failure in failures:
                logger.warning(
                    "failed to close pipeline stage after aborted setup: %r",
                    failure)
            raise
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool:
        """Close the source and every filter, even if one of them fails.

        Returns:
            ``True`` when the ``with`` block raised nothing, ``False``
            otherwise, so an exception from the block is never suppressed.

        Raises:
            Exception: the first error raised by a stage's ``__exit__``,
                re-raised only after all stages have been given the chance
                to close.
        """
        if exc_type is not None:
            logger.error(
                "exception during annotation: %s, %s, %s",
                exc_type, exc_value, traceback.format_tb(exc_tb))

        failures = self._exit_stages(
            (self.source, *self.filters), exc_type, exc_value, exc_tb)
        if failures:
            first, *others = failures
            for extra in others:
                logger.error(
                    "additional error while closing pipeline stage: %r",
                    extra)
            raise first

        return exc_type is None

    @staticmethod
    def _exit_stages(
        stages: Iterable[Any],
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> list[Exception]:
        """Call ``__exit__`` on every stage and return the errors raised."""
        failures: list[Exception] = []
        for stage in stages:
            try:
                stage.__exit__(exc_type, exc_value, exc_tb)
            except Exception as error:  # keep closing the remaining stages
                failures.append(error)
        return failures

    def process_region(self, region: Region | None = None) -> None:
        """Fetch the items for *region* and pass each through all filters.

        Each filter receives the value returned by the previous one. The
        value returned by the last filter is discarded.
        """
        for data in self.source.fetch(region):
            for _filter in self.filters:
                data = _filter.filter(data)

    def process(self, regions: Iterable[Region] | None = None) -> None:
        """Process a pipeline in batches for the given regions.

        When *regions* is ``None`` a single unrestricted pass is made by
        calling ``process_region(None)``; otherwise each region is
        processed in turn.
        """
        if regions is None:
            self.process_region(None)
            return
        for region in regions:
            self.process_region(region)