"""Provides annotation pipeline class."""
from __future__ import annotations
import abc
import itertools
import logging
from collections.abc import Callable, Iterable, Sequence
from dataclasses import dataclass, field
from types import TracebackType
from typing import Any
from dae.annotation.annotatable import Annotatable
from dae.annotation.annotation_config import (
AnnotationPreamble,
AnnotatorInfo,
AttributeInfo,
RawPipelineConfig,
)
from dae.genomic_resources.repository import (
GenomicResource,
GenomicResourceRepo,
)
logger = logging.getLogger(name=__name__)

# Dependency graph of an annotation pipeline: maps each annotator's info to
# the list of (upstream annotator info, consumed attribute info) pairs that
# annotator depends on, directly or transitively.
_AnnotationDependencyGraph = dict[
    AnnotatorInfo,
    list[tuple[AnnotatorInfo, AttributeInfo]],
]
def _build_dependency_graph(
pipeline: AnnotationPipeline,
) -> _AnnotationDependencyGraph:
"""Make dependency graph for an annotation pipeline."""
graph: _AnnotationDependencyGraph = {}
for annotator in pipeline.annotators:
annotator_info = annotator.get_info()
graph[annotator_info] = _get_dependencies_for(annotator, pipeline)
return graph
def _get_dependencies_for(
annotator: Annotator,
pipeline: AnnotationPipeline,
) -> list[tuple[AnnotatorInfo, AttributeInfo]]:
"""Get all dependencies for a given annotator."""
result: list[tuple[AnnotatorInfo, AttributeInfo]] = []
used_attrs = annotator.used_context_attributes
for attr in used_attrs:
attr_info = pipeline.get_attribute_info(attr)
assert attr_info is not None
upstream_annotator = \
pipeline.get_annotator_by_attribute_info(attr_info)
assert upstream_annotator is not None
result.append((upstream_annotator.get_info(), attr_info))
if upstream_annotator.used_context_attributes:
result.extend(_get_dependencies_for(upstream_annotator, pipeline))
return result
def _get_rerun_annotators(
    pipeline: AnnotationPipeline,
    annotators_new: Iterable[AnnotatorInfo],
) -> set[AnnotatorInfo]:
    """Get all annotators that must be re-run for reannotation.

    Two rules are applied over the pipeline's dependency graph:

    * An annotator that is not new is re-run when any of its (transitive)
      upstream dependencies is new, since its inputs have changed.
    * Internal attributes are not kept in the previously annotated data, so
      every annotator producing an internal attribute consumed by an
      annotator that is going to run (new or re-run) must run again too.

    Args:
        pipeline: The new annotation pipeline.
        annotators_new: Infos of annotators not present in the previous
            pipeline. Any iterable is accepted; it is consumed once.

    Returns:
        Infos of the (non-new) annotators that must be re-run. New annotators
        only appear here if they produce an internal attribute needed by
        another running annotator.
    """
    # Materialise once: the membership tests below would silently exhaust a
    # one-shot iterator passed as ``annotators_new``.
    new_infos = set(annotators_new)
    result: set[AnnotatorInfo] = set()
    dependency_graph = _build_dependency_graph(pipeline)
    for dependent, dependencies in dependency_graph.items():
        will_run = dependent in new_infos
        if not will_run and any(
            dependency in new_infos for dependency, _ in dependencies
        ):
            result.add(dependent)
            will_run = True
        if will_run:
            # Internal inputs are not available from the previous output,
            # so their producers have to run again as well.
            result.update(
                dependency for dependency, dep_attr in dependencies
                if dep_attr.internal
            )
    return result
def _get_deleted_attributes(
pipeline_current: AnnotationPipeline,
pipeline_previous: AnnotationPipeline,
*,
full_reannotation: bool = False,
) -> list[str]:
"""Get a list of attributes that are deleted in the new annotation."""
infos_new = pipeline_current.get_info()
infos_old = pipeline_previous.get_info()
if full_reannotation is True:
return [attr.name for info in infos_old for attr in info.attributes]
result: list[str] = []
for deleted_info in [i for i in infos_old if i not in infos_new]:
result.extend([attr.name for attr in deleted_info.attributes
if not attr.internal])
return result
@dataclass
class AttributeDesc:
    """Holds default attribute configuration for annotators.

    Each instance describes one attribute an annotator can provide, as
    returned by ``Annotator.get_all_attribute_descriptions``.
    """

    # Attribute name.
    name: str
    # Type of the attribute's values, given as a string.
    type: str
    # Human-readable description of the attribute.
    description: str
    # Presumably whether the attribute is produced when the annotator's
    # attributes are not configured explicitly -- TODO confirm.
    default: bool = True
    # Presumably marks attributes used only inside the pipeline and not
    # written to the output -- TODO confirm.
    internal: bool = False
    # Extra attribute parameters; every instance gets its own fresh dict.
    params: dict[str, Any] = field(default_factory=dict)
[docs]
class Annotator(abc.ABC):
"""Annotator provides a set of attrubutes for a given Annotatable."""
def __init__(self, pipeline: AnnotationPipeline | None,
info: AnnotatorInfo):
self.pipeline = pipeline
self._info = info
self._is_open = False
[docs]
def get_info(self) -> AnnotatorInfo:
return self._info
[docs]
@abc.abstractmethod
def annotate(
self, annotatable: Annotatable | None, context: dict[str, Any],
) -> dict[str, Any]:
"""Produce annotation attributes for an annotatable."""
[docs]
def batch_annotate(
self, annotatables: Sequence[Annotatable | None],
contexts: list[dict[str, Any]],
batch_work_dir: str | None = None, # noqa: ARG002
) -> Iterable[dict[str, Any]]:
return itertools.starmap(
self.annotate, zip(annotatables, contexts, strict=True),
)
[docs]
def close(self) -> None:
self._is_open = False
[docs]
def open(self) -> Annotator:
self._is_open = True
return self
[docs]
def is_open(self) -> bool:
return self._is_open
@property
def resources(self) -> list[GenomicResource]:
return self._info.resources
@property
def resource_ids(self) -> set[str]:
return {resource.get_id() for resource in self._info.resources}
@property
def attributes(self) -> list[AttributeInfo]:
return self._info.attributes
@property
def used_context_attributes(self) -> tuple[str, ...]:
return ()
def _empty_result(self) -> dict[str, Any]:
return {attribute_info.name: None
for attribute_info in self._info.attributes}
[docs]
@abc.abstractmethod
def get_all_attribute_descriptions(self) -> dict[str, AttributeDesc]:
"""Get descriptions of all attributes provided by the annotator."""
[docs]
class AnnotationPipeline:
"""Provides annotation pipeline abstraction."""
def __init__(self, repository: GenomicResourceRepo):
self.repository: GenomicResourceRepo = repository
self.annotators: list[Annotator] = []
self.preamble: AnnotationPreamble | None = None
self.raw: RawPipelineConfig = []
self._is_open = False
[docs]
def get_info(self) -> list[AnnotatorInfo]:
return [annotator.get_info() for annotator in self.annotators]
[docs]
def get_attributes(self) -> list[AttributeInfo]:
return [attribute_info for annotator in self.annotators for
attribute_info in annotator.attributes]
[docs]
def get_attribute_info(
self, attribute_name: str) -> AttributeInfo | None:
for annotator in self.annotators:
for attribute_info in annotator.get_info().attributes:
if attribute_info.name == attribute_name:
return attribute_info
return None
[docs]
def get_resource_ids(self) -> set[str]:
return {r_id for annotator in self.annotators
for r_id in annotator.resource_ids}
[docs]
def get_annotator_by_attribute_info(
self, attribute_info: AttributeInfo,
) -> Annotator | None:
for annotator in self.annotators:
if attribute_info in annotator.attributes:
return annotator
return None
[docs]
def add_annotator(self, annotator: Annotator) -> None:
assert isinstance(annotator, Annotator)
self.annotators.append(annotator)
[docs]
def annotate(
self, annotatable: Annotatable | None,
context: dict | None = None,
) -> dict:
"""Apply all annotators to an annotatable."""
if not self._is_open:
self.open()
if context is None:
context = {}
for annotator in self.annotators:
attributes = annotator.annotate(annotatable, context)
context.update(attributes)
return context
[docs]
def batch_annotate(
self, annotatables: Sequence[Annotatable | None],
contexts: list[dict] | None = None,
batch_work_dir: str | None = None,
) -> list[dict]:
"""Apply all annotators to a list of annotatables."""
if not self._is_open:
self.open()
if contexts is None:
contexts = [{} for _ in annotatables]
for annotator in self.annotators:
attributes_list = annotator.batch_annotate(
annotatables, contexts,
batch_work_dir=batch_work_dir,
)
for context, attributes in zip(
contexts, attributes_list, strict=True,
):
context.update(attributes)
return contexts
[docs]
def open(self) -> AnnotationPipeline:
"""Open all annotators in the pipeline and mark it as open."""
if self._is_open:
logger.warning("annotation pipeline is already open")
return self
assert not self._is_open
for annotator in self.annotators:
annotator.open()
self._is_open = True
return self
[docs]
def close(self) -> None:
"""Close the annotation pipeline."""
logger.info("closing annotation pipeline")
for annotator in self.annotators:
try:
annotator.close()
except Exception: # pylint: disable=broad-except
logger.exception(
"exception while closing annotator %s",
annotator.get_info())
self._is_open = False
[docs]
def print(self) -> None:
"""Print the annotation pipeline."""
print("NEW ATTRIBUTES -")
for anno in self.annotators:
for attr in anno.attributes:
print(" +", attr.name)
def __enter__(self) -> AnnotationPipeline:
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
exc_tb: TracebackType | None) -> bool:
if exc_type is not None:
logger.error(
"exception during annotation: %s, %s, %s",
exc_type, exc_value, exc_tb)
self.close()
return exc_type is None
class ReannotationPipeline(AnnotationPipeline):
    """Provides functionality for reannotation.

    Keeps only those annotators of ``pipeline_new`` that must run when
    re-annotating data produced by ``pipeline_previous``. These are the
    annotators whose info does not occur in the previous pipeline, plus the
    ones selected by ``_get_rerun_annotators``. Names of attributes to drop
    from the previous data are stored in ``deleted_attributes``.
    """

    def __init__(
        self,
        pipeline_new: AnnotationPipeline,
        pipeline_previous: AnnotationPipeline,
        *,
        full_reannotation: bool = False,
    ):
        # The base initializer already sets ``self.annotators`` to ``[]``.
        # NOTE(review): ``preamble`` and ``raw`` stay at the base-class
        # defaults rather than being copied from ``pipeline_new`` -- confirm.
        super().__init__(pipeline_new.repository)
        self.pipeline_new = pipeline_new
        infos_current = pipeline_new.get_info()
        infos_previous = pipeline_previous.get_info()
        # An annotator is "new" when no equal info exists in the previous
        # pipeline (i.e. it was added or its configuration changed).
        infos_new: set[AnnotatorInfo] = {
            i for i in infos_current
            if i not in infos_previous
        }
        infos_rerun = _get_rerun_annotators(pipeline_new, infos_new)
        # Keep the selected annotators in ``pipeline_new`` order.
        # NOTE(review): ``full_reannotation`` only affects
        # ``deleted_attributes``; unchanged annotators are not selected
        # here even then -- confirm callers handle that case.
        for annotator in pipeline_new.annotators:
            info = annotator.get_info()
            if info in infos_new or info in infos_rerun:
                self.annotators.append(annotator)
        self.deleted_attributes = _get_deleted_attributes(
            pipeline_new, pipeline_previous,
            full_reannotation=full_reannotation)

    def get_attributes(self) -> list[AttributeInfo]:
        """Return all attributes of the full new pipeline.

        This includes attributes of annotators that are not re-run.
        """
        return self.pipeline_new.get_attributes()
class AnnotatorDecorator(Annotator):
    """Defines annotator decorator base class.

    Wraps a ``child`` annotator, sharing its pipeline and info. Lifecycle
    calls (``open``/``close``/``is_open``), attribute descriptions, used
    context attributes and any attribute not found on the decorator are
    delegated to the child. Subclasses still have to implement ``annotate``.
    """

    def __init__(self, child: Annotator):
        super().__init__(child.pipeline, child.get_info())
        self.child = child

    def get_all_attribute_descriptions(self) -> dict[str, AttributeDesc]:
        """Return the attribute descriptions of the wrapped annotator."""
        return self.child.get_all_attribute_descriptions()

    @property
    def used_context_attributes(self) -> tuple[str, ...]:
        """Return the context attributes used by the wrapped annotator.

        ``Annotator`` defines this as a property, so ``__getattr__`` is
        never consulted for it. Without this override every decorator would
        report ``()`` and hide the child's pipeline dependencies.
        """
        return self.child.used_context_attributes

    def close(self) -> None:
        """Close the wrapped annotator."""
        self.child.close()

    def open(self) -> Annotator:
        """Open the wrapped annotator and return this decorator.

        Returning ``self`` (not the child) keeps the decoration in place for
        callers that chain or reassign the result of ``open()``.
        """
        self.child.open()
        return self

    def is_open(self) -> bool:
        """Return whether the wrapped annotator is open."""
        return self.child.is_open()

    def __getattr__(self, name: str) -> Any:
        """Delegate unknown attribute lookups to the wrapped annotator."""
        # Without this guard, a missing ``child`` (e.g. on an instance being
        # unpickled before ``__init__`` ran) would recurse forever.
        if name == "child":
            raise AttributeError(name)
        return getattr(self.child, name)