"""Provides annotation pipeline class."""
from __future__ import annotations
import abc
import itertools
import logging
from collections.abc import Callable, Iterable, Sequence
from types import TracebackType
from typing import Any
from dae.annotation.annotatable import Annotatable
from dae.annotation.annotation_config import (
AnnotationPreamble,
AnnotatorInfo,
AttributeInfo,
RawPipelineConfig,
)
from dae.genomic_resources.genomic_context import (
GenomicContext,
PriorityGenomicContext,
SimpleGenomicContext,
get_genomic_context,
)
from dae.genomic_resources.reference_genome import (
build_reference_genome_from_resource,
)
from dae.genomic_resources.repository import (
GenomicResource,
GenomicResourceRepo,
)
# Keys of genomic-context entries; the GRR and reference-genome keys are
# filled in by AnnotationPipeline.build_pipeline_genomic_context.
GC_GRR_KEY = "genomic_resources_repository"
GC_REFERENCE_GENOME_KEY = "reference_genome"
GC_GENE_MODELS_KEY = "gene_models"

# Annotator info -> list of (upstream annotator info, attribute it supplies).
_AnnotationDependencyGraph = dict[
    AnnotatorInfo,
    list[tuple[AnnotatorInfo, AttributeInfo]],
]

logger = logging.getLogger(__name__)
def _build_dependency_graph(
pipeline: AnnotationPipeline,
) -> _AnnotationDependencyGraph:
"""Make dependency graph for an annotation pipeline."""
graph: _AnnotationDependencyGraph = {}
for annotator in pipeline.annotators:
annotator_info = annotator.get_info()
graph[annotator_info] = _get_dependencies_for(annotator, pipeline)
return graph
def _get_dependencies_for(
annotator: Annotator,
pipeline: AnnotationPipeline,
) -> list[tuple[AnnotatorInfo, AttributeInfo]]:
"""Get all dependencies for a given annotator."""
result: list[tuple[AnnotatorInfo, AttributeInfo]] = []
used_attrs = annotator.used_context_attributes
for attr in used_attrs:
attr_info = pipeline.get_attribute_info(attr)
assert attr_info is not None
upstream_annotator = \
pipeline.get_annotator_by_attribute_info(attr_info)
assert upstream_annotator is not None
result.append((upstream_annotator.get_info(), attr_info))
if upstream_annotator.used_context_attributes:
result.extend(_get_dependencies_for(upstream_annotator, pipeline))
return result
def _get_rerun_annotators(
    pipeline: AnnotationPipeline,
    annotators_new: Iterable[AnnotatorInfo],
) -> set[AnnotatorInfo]:
    """Get all annotators that must be re-run for reannotation.

    Uses the transitive dependency lists from ``_build_dependency_graph``:

    * for an annotator listed in ``annotators_new``, every upstream
      annotator that supplies it an *internal* attribute is selected
      (presumably because internal attributes are not present in the
      previous annotation output);
    * an annotator not listed in ``annotators_new`` is selected when any
      of its (transitive) dependencies is new.
    """
    # Materialize once: the argument may be a one-shot iterable such as a
    # generator, and it is probed with ``in`` many times below.
    new_infos = set(annotators_new)
    result: set[AnnotatorInfo] = set()
    dependency_graph = _build_dependency_graph(pipeline)
    for dependent, dependencies in dependency_graph.items():
        if dependent in new_infos:
            result.update(
                dependency for dependency, dep_attr in dependencies
                if dep_attr.internal
            )
        elif any(dependency in new_infos for dependency, _ in dependencies):
            result.add(dependent)
    return result
def _get_deleted_attributes(
pipeline_current: AnnotationPipeline,
pipeline_previous: AnnotationPipeline,
*,
full_reannotation: bool = False,
) -> list[str]:
"""Get a list of attributes that are deleted in the new annotation."""
infos_new = pipeline_current.get_info()
infos_old = pipeline_previous.get_info()
if full_reannotation is True:
return [attr.name for info in infos_old for attr in info.attributes]
result: list[str] = []
for deleted_info in [i for i in infos_old if i not in infos_new]:
result.extend([attr.name for attr in deleted_info.attributes
if not attr.internal])
return result
[docs]
class Annotator(abc.ABC):
"""Annotator provides a set of attrubutes for a given Annotatable."""
def __init__(self, pipeline: AnnotationPipeline | None,
info: AnnotatorInfo):
self.pipeline = pipeline
self._info = info
self._is_open = False
[docs]
def get_info(self) -> AnnotatorInfo:
return self._info
[docs]
@abc.abstractmethod
def annotate(
self, annotatable: Annotatable | None, context: dict[str, Any],
) -> dict[str, Any]:
"""Produce annotation attributes for an annotatable."""
[docs]
def batch_annotate(
self, annotatables: Sequence[Annotatable | None],
contexts: list[dict[str, Any]],
batch_work_dir: str | None = None, # noqa: ARG002
) -> Iterable[dict[str, Any]]:
return itertools.starmap(
self.annotate, zip(annotatables, contexts, strict=True),
)
[docs]
def close(self) -> None:
self._is_open = False
[docs]
def open(self) -> Annotator:
self._is_open = True
return self
[docs]
def is_open(self) -> bool:
return self._is_open
@property
def resources(self) -> list[GenomicResource]:
return self._info.resources
@property
def resource_ids(self) -> set[str]:
return {resource.get_id() for resource in self._info.resources}
@property
def attributes(self) -> list[AttributeInfo]:
return self._info.attributes
@property
def used_context_attributes(self) -> tuple[str, ...]:
return ()
def _empty_result(self) -> dict[str, Any]:
return {attribute_info.name: None
for attribute_info in self._info.attributes}
[docs]
class AnnotationPipeline:
"""Provides annotation pipeline abstraction."""
def __init__(self, repository: GenomicResourceRepo):
self.repository: GenomicResourceRepo = repository
self.annotators: list[Annotator] = []
self.preamble: AnnotationPreamble | None = None
self.raw: RawPipelineConfig = []
self._is_open = False
[docs]
def build_pipeline_genomic_context(self) -> GenomicContext:
"""Create a genomic context from the pipeline parameters."""
registered_context = get_genomic_context()
genome = None
if self.preamble is not None:
genome_res = self.preamble.input_reference_genome_res
if genome_res is not None:
genome = build_reference_genome_from_resource(genome_res)
pipeline_context = SimpleGenomicContext({
GC_GRR_KEY: self.repository,
GC_REFERENCE_GENOME_KEY: genome,
}, ("pipeline_context",))
return PriorityGenomicContext([pipeline_context, registered_context])
[docs]
def get_info(self) -> list[AnnotatorInfo]:
return [annotator.get_info() for annotator in self.annotators]
[docs]
def get_attributes(self) -> list[AttributeInfo]:
return [attribute_info for annotator in self.annotators for
attribute_info in annotator.attributes]
[docs]
def get_attribute_info(
self, attribute_name: str) -> AttributeInfo | None:
for annotator in self.annotators:
for attribute_info in annotator.get_info().attributes:
if attribute_info.name == attribute_name:
return attribute_info
return None
[docs]
def get_resource_ids(self) -> set[str]:
return {r_id for annotator in self.annotators
for r_id in annotator.resource_ids}
[docs]
def get_annotator_by_attribute_info(
self, attribute_info: AttributeInfo,
) -> Annotator | None:
for annotator in self.annotators:
if attribute_info in annotator.attributes:
return annotator
return None
[docs]
def add_annotator(self, annotator: Annotator) -> None:
assert isinstance(annotator, Annotator)
self.annotators.append(annotator)
[docs]
def annotate(
self, annotatable: Annotatable | None,
context: dict | None = None,
) -> dict:
"""Apply all annotators to an annotatable."""
if not self._is_open:
self.open()
if context is None:
context = {}
for annotator in self.annotators:
attributes = annotator.annotate(annotatable, context)
context.update(attributes)
return context
[docs]
def batch_annotate(
self, annotatables: Sequence[Annotatable | None],
contexts: list[dict] | None = None,
batch_work_dir: str | None = None,
) -> list[dict]:
"""Apply all annotators to a list of annotatables."""
if not self._is_open:
self.open()
if contexts is None:
contexts = [{} for _ in annotatables]
for annotator in self.annotators:
attributes_list = annotator.batch_annotate(
annotatables, contexts,
batch_work_dir=batch_work_dir,
)
for context, attributes in zip(
contexts, attributes_list, strict=True,
):
context.update(attributes)
return contexts
[docs]
def open(self) -> AnnotationPipeline:
"""Open all annotators in the pipeline and mark it as open."""
if self._is_open:
logger.warning("annotation pipeline is already open")
return self
assert not self._is_open
for annotator in self.annotators:
annotator.open()
self._is_open = True
return self
[docs]
def close(self) -> None:
"""Close the annotation pipeline."""
logger.info("closing annotation pipeline")
for annotator in self.annotators:
try:
annotator.close()
except Exception: # pylint: disable=broad-except
logger.exception(
"exception while closing annotator %s",
annotator.get_info())
self.repository = None # type: ignore
self._is_open = False
[docs]
def print(self) -> None:
"""Print the annotation pipeline."""
print("NEW ATTRIBUTES -")
for anno in self.annotators:
for attr in anno.attributes:
print(" +", attr.name)
def __enter__(self) -> AnnotationPipeline:
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
exc_tb: TracebackType | None) -> bool:
if exc_type is not None:
logger.error(
"exception during annotation: %s, %s, %s",
exc_type, exc_value, exc_tb)
self.close()
return exc_type is None
class ReannotationPipeline(AnnotationPipeline):
    """Provides functionality for reannotation.

    Built from the new pipeline and the pipeline that produced the existing
    annotation. ``annotators`` is the subset of ``pipeline_new``'s
    annotators that has to run: those whose info does not appear in
    ``pipeline_previous`` plus those chosen by ``_get_rerun_annotators``.
    With ``full_reannotation`` it is every annotator. ``deleted_attributes``
    holds the names of previously produced attributes that are dropped.
    """

    def __init__(
        self,
        pipeline_new: AnnotationPipeline,
        pipeline_previous: AnnotationPipeline,
        *,
        full_reannotation: bool = False,
    ):
        super().__init__(pipeline_new.repository)
        self.pipeline_new = pipeline_new
        # Carry over the new pipeline's configuration; otherwise e.g.
        # build_pipeline_genomic_context() would see no preamble and hence
        # no input reference genome.
        self.preamble = pipeline_new.preamble
        self.raw = pipeline_new.raw

        if full_reannotation:
            # Every previously produced attribute is reported as deleted
            # below, so every annotator of the new pipeline has to run.
            self.annotators = list(pipeline_new.annotators)
        else:
            infos_previous = pipeline_previous.get_info()
            infos_new: set[AnnotatorInfo] = {
                info for info in pipeline_new.get_info()
                if info not in infos_previous
            }
            infos_selected = infos_new | _get_rerun_annotators(
                pipeline_new, infos_new)
            self.annotators = [
                annotator for annotator in pipeline_new.annotators
                if annotator.get_info() in infos_selected
            ]

        self.deleted_attributes = _get_deleted_attributes(
            pipeline_new, pipeline_previous,
            full_reannotation=full_reannotation)

    def get_attributes(self) -> list[AttributeInfo]:
        """Return the attributes of the whole new pipeline.

        Unlike the base class, this is not limited to the selected
        annotators (presumably because the output keeps the attributes of
        annotators that are not re-run).
        """
        return self.pipeline_new.get_attributes()
class AnnotatorDecorator(Annotator):
    """Defines annotator decorator base class.

    Wraps a ``child`` annotator and shares its pipeline and info. The
    open/close/is_open calls are forwarded to the child, and any attribute
    not found on the decorator itself is looked up on the child.
    """

    def __init__(self, child: Annotator):
        super().__init__(child.pipeline, child.get_info())
        self.child = child

    def close(self) -> None:
        """Close the wrapped annotator."""
        self.child.close()

    def open(self) -> Annotator:
        """Open the wrapped annotator and return this decorator.

        Returning ``self`` rather than the child keeps
        ``annotator = annotator.open()`` from silently dropping the
        decoration.
        """
        self.child.open()
        return self

    def is_open(self) -> bool:
        """Tell whether the wrapped annotator is open."""
        return self.child.is_open()

    def __getattr__(self, name: str) -> Any:
        # Only reached when normal lookup fails. If ``child`` itself is
        # missing (e.g. an instance being rebuilt by copy/pickle before its
        # __dict__ is restored), delegating would recurse until
        # RecursionError; report a plain missing attribute instead.
        if name == "child":
            raise AttributeError(name)
        return getattr(self.child, name)