from __future__ import annotations

import abc
import logging
import traceback
from collections.abc import Iterable, Sequence
from contextlib import AbstractContextManager, ExitStack
from types import TracebackType
from typing import Any

from gain.utils.regions import Region
logger = logging.getLogger(__name__)
class Filter(AbstractContextManager):
    """Base class for all processing pipeline filters.

    A concrete filter must implement :meth:`filter`. Because
    :class:`contextlib.AbstractContextManager` declares ``__exit__`` as
    abstract, it must implement ``__exit__`` as well. The inherited
    ``__enter__`` returns the filter itself.
    """

    @abc.abstractmethod
    def filter(self, data: Any) -> Any:
        """Transform one item of pipeline data and return the result.

        In a pipeline, the returned value becomes the input of the next
        filter.
        """
class Source(AbstractContextManager):
    """Base class for all processing pipeline sources.

    A concrete source implements :meth:`fetch`. The inherited
    ``__enter__`` returns the source itself. The ``__exit__`` defined here
    never suppresses an exception raised inside the ``with`` block.
    """

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool:
        """Leave the source context.

        Returns ``False`` when the block raised, so the exception
        propagates. Returns ``True`` only when there was no exception to
        suppress.
        """
        if exc_type is not None:
            return False
        return True

    @abc.abstractmethod
    def fetch(self, region: Region | None = None) -> Iterable[Any]:
        """Return an iterable of items to process for *region*.

        Subclasses must override this. ``region`` defaults to ``None``;
        how a ``None`` region is interpreted is up to the concrete source.
        """
class PipelineProcessor(AbstractContextManager):
    """A processor that can be used to process variants in a pipeline.

    The processor holds one :class:`Source` and an ordered sequence of
    :class:`Filter` objects. Entering the processor enters the source and
    then each filter in order. Leaving it exits them in the reverse order,
    and only the contexts that were actually entered are exited. Each item
    fetched from the source is passed through the filters in sequence; the
    output of one filter is the input of the next.
    """

    def __init__(self, source: Source, filters: Sequence[Filter]) -> None:
        self.source = source
        self.filters = filters
        # Contexts entered by __enter__; __exit__ unwinds exactly these,
        # last-in first-out.
        self._exit_stack: ExitStack | None = None

    def __enter__(self) -> PipelineProcessor:
        """Enter the source, then every filter, and return ``self``.

        If entering any of them raises, every context already entered is
        exited (in reverse order) before the exception propagates, so
        nothing is left open.
        """
        with ExitStack() as stack:
            stack.enter_context(self.source)
            for variant_filter in self.filters:
                stack.enter_context(variant_filter)
            # All entered successfully: detach the exit callbacks from this
            # temporary ``with`` so they run in __exit__ instead.
            self._exit_stack = stack.pop_all()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool:
        """Exit every filter (last first) and then the source.

        An exception raised in the ``with`` body is logged and then passed
        to the filters' and the source's ``__exit__``. It is never
        suppressed: the return value is ``True`` only when there was no
        exception.

        If a filter's or the source's ``__exit__`` raises, the remaining
        contexts are still exited and that exception propagates, chained to
        any earlier one. As with :class:`contextlib.ExitStack`, a context
        whose ``__exit__`` reports suppression causes later contexts to see
        no exception.
        """
        if exc_type is not None:
            logger.error(
                "exception during annotation: %s, %s, %s",
                exc_type, exc_value, traceback.format_tb(exc_tb))
        stack, self._exit_stack = self._exit_stack, None
        if stack is not None:
            # Whether some inner context asked to suppress the exception is
            # deliberately ignored; the processor itself never suppresses.
            stack.__exit__(exc_type, exc_value, exc_tb)
        return exc_type is None

    def process_region(self, region: Region | None = None) -> None:
        """Fetch the items for *region* and push each one through the filters.

        *region* is passed unchanged to ``self.source.fetch``. The last
        filter's return value is discarded, so any useful output presumably
        comes from the filters' side effects.
        """
        for data in self.source.fetch(region):
            for _filter in self.filters:
                data = _filter.filter(data)

    def process(self, regions: Iterable[Region] | None = None) -> None:
        """Process the pipeline one region at a time.

        When *regions* is ``None``, the source is fetched once with
        ``region=None``. Otherwise :meth:`process_region` is called for each
        region in turn, so an empty iterable processes nothing.
        """
        if regions is None:
            self.process_region(None)
            return
        for region in regions:
            self.process_region(region)