"""Factory for creation of annotation pipeline."""
import logging
from collections import Counter
from collections.abc import Callable
from pathlib import Path
import yaml
from dae.annotation.annotation_config import (
AnnotationConfigParser,
AnnotationConfigurationError,
AnnotatorInfo,
RawPipelineConfig,
)
from dae.annotation.annotation_pipeline import (
AnnotationPipeline,
Annotator,
InputAnnotableAnnotatorDecorator,
ReannotationPipeline,
ValueTransformAnnotatorDecorator,
)
from dae.genomic_resources.repository import (
GenomicResourceRepo,
)
# Module-level logger for this factory module.
logger = logging.getLogger(__name__)

# Registry mapping an annotator type name to the factory callable that
# builds an annotator of that type for a given pipeline and annotator info.
_ANNOTATOR_FACTORY_REGISTRY: dict[
    str, Callable[[AnnotationPipeline, AnnotatorInfo], Annotator]] = {}

# Flipped to True once the entry-point plugins have been loaded into the
# registry, so that the discovery is performed only once.
_EXTENTIONS_LOADED = False
def _load_annotator_factory_plugins() -> None:
    """Fill the annotator factory registry from entry points, only once.

    Every entry point found in the ``dae.annotation.annotators`` group is
    loaded and stored in the registry under the entry point name. When a
    name is already present in the registry, a warning is logged and the
    previous factory is replaced. Calls made after the first successful
    load return immediately.
    """
    # pylint: disable=global-statement
    global _EXTENTIONS_LOADED
    if _EXTENTIONS_LOADED:
        return

    # pylint: disable=import-outside-toplevel
    from importlib_metadata import entry_points

    for entry_point in entry_points(group="dae.annotation.annotators"):
        type_name = entry_point.name
        loaded_factory = entry_point.load()
        already_known = type_name in _ANNOTATOR_FACTORY_REGISTRY
        if already_known:
            logger.warning(
                "overwriting annotator type: %s", type_name)
        _ANNOTATOR_FACTORY_REGISTRY[type_name] = loaded_factory

    _EXTENTIONS_LOADED = True
[docs]
def get_annotator_factory(
    annotator_type: str,
) -> Callable[[AnnotationPipeline, AnnotatorInfo], Annotator]:
    """Look up the factory function registered for an annotator type.

    The entry-point plugins are loaded (if not loaded yet) before the
    registry is consulted.

    :param annotator_type: name of the annotator type to look up.
    :return: the annotator factory registered for ``annotator_type``.
    :raises ValueError: when no annotator factory is registered for the
        specified annotator type.
    """
    _load_annotator_factory_plugins()

    if annotator_type in _ANNOTATOR_FACTORY_REGISTRY:
        return _ANNOTATOR_FACTORY_REGISTRY[annotator_type]

    raise ValueError(f"unsupported annotator type: {annotator_type}")
[docs]
def get_available_annotator_types() -> list[str]:
    """Return the names of all registered annotator factory types.

    The entry-point plugins are loaded (if not loaded yet) before the
    registry keys are collected into a new list.
    """
    _load_annotator_factory_plugins()
    registered_types = list(_ANNOTATOR_FACTORY_REGISTRY)
    return registered_types
[docs]
def register_annotator_factory(
    annotator_type: str,
    factory: Callable[[AnnotationPipeline, AnnotatorInfo], Annotator],
) -> None:
    """Register an additional annotator factory.

    By default annotator factories are discovered through the
    ``dae.annotation.annotators`` entry point group and are loaded
    automatically. Use this function when you want to bypass the entry
    point mechanism and register an annotator factory programmatically.

    The entry-point plugins are loaded first, so the registration made
    here is applied after them. If ``annotator_type`` is already
    registered, a warning is logged and the existing factory is replaced.

    :param annotator_type: name under which the factory is registered.
    :param factory: callable that builds an annotator from a pipeline and
        an annotator info.
    """
    _load_annotator_factory_plugins()
    if annotator_type in _ANNOTATOR_FACTORY_REGISTRY:
        logger.warning("overwriting annotator type: %s", annotator_type)
    _ANNOTATOR_FACTORY_REGISTRY[annotator_type] = factory
[docs]
def load_pipeline_from_file(
    raw_path: str, grr: GenomicResourceRepo, *,
    allow_repeated_attributes: bool = False,
    work_dir: Path | None = None,
) -> AnnotationPipeline:
    """Load an annotation pipeline from a configuration file.

    :param raw_path: path to the pipeline configuration file; only files
        with the ``.yaml`` suffix are supported.
    :param grr: genomic resource repository passed on to the pipeline
        builder.
    :param allow_repeated_attributes: forwarded unchanged to
        ``load_pipeline_from_yaml``.
    :param work_dir: forwarded unchanged to ``load_pipeline_from_yaml``.
    :return: the pipeline built from the file content.
    :raises OSError: when ``raw_path`` does not exist.
    :raises ValueError: when the file suffix is not ``.yaml``.
    """
    path = Path(raw_path)
    if not path.exists():
        raise OSError(f"{raw_path} does not exist!")
    if path.suffix != ".yaml":
        raise ValueError(
            f"Unsupported annotation config format {path.suffix}")

    # Decode explicitly as UTF-8 instead of relying on the platform locale
    # encoding, so the same configuration file is read identically on
    # every system.
    content = path.read_text(encoding="utf-8")
    return load_pipeline_from_yaml(
        content, grr,
        allow_repeated_attributes=allow_repeated_attributes,
        work_dir=work_dir,
    )
[docs]
def load_pipeline_from_yaml(
    raw: str, grr: GenomicResourceRepo, *,
    allow_repeated_attributes: bool = False,
    work_dir: Path | None = None,
) -> AnnotationPipeline:
    """Build an annotation pipeline out of a YAML-formatted string.

    The text is parsed with ``yaml.safe_load`` and the parsed object is
    handed to ``build_annotation_pipeline`` together with the remaining
    arguments, which are forwarded unchanged.

    :param raw: YAML text with the pipeline configuration.
    :param grr: genomic resource repository used by the pipeline builder.
    :param allow_repeated_attributes: forwarded to the pipeline builder.
    :param work_dir: forwarded to the pipeline builder.
    :return: the pipeline built from the parsed configuration.
    """
    parsed_config = yaml.safe_load(raw)
    pipeline = build_annotation_pipeline(
        parsed_config,
        grr,
        allow_repeated_attributes=allow_repeated_attributes,
        work_dir=work_dir,
    )
    return pipeline
[docs]
def build_annotation_pipeline(
    config: RawPipelineConfig, grr: GenomicResourceRepo, *,
    allow_repeated_attributes: bool = False,
    work_dir: Path | None = None,
) -> AnnotationPipeline:
    """Build an annotation pipeline from a raw pipeline configuration.

    The raw configuration is parsed with ``AnnotationConfigParser.parse_raw``
    into a preamble and a sequence of annotator configurations. For every
    annotator configuration the registered factory for its type is called,
    the resulting annotator is wrapped by the input-annotable and
    value-transform decorators, the configuration is checked for unused
    parameters and repeated attributes, and the annotator is added to the
    pipeline. Finally the whole pipeline is checked for repeated attributes.

    :param config: raw pipeline configuration; also stored on the pipeline
        as ``pipeline.raw``.
    :param grr: genomic resource repository passed to the parser and to
        the pipeline.
    :param allow_repeated_attributes: forwarded to
        ``check_for_repeated_attributes_in_pipeline``.
    :param work_dir: base directory for annotators that do not configure
        a ``work_dir`` parameter themselves; each of them gets the
        sub-directory ``A{idx}_{type}``. When ``None``, such annotators
        get ``./work``.
    :return: the constructed pipeline.
    :raises AnnotationConfigurationError: when a ``ValueError`` is raised
        while the annotators are being created and checked, or when the
        pipeline level check for repeated attributes fails.
    """
    preamble, pipeline_config = AnnotationConfigParser.parse_raw(
        config, grr=grr,
    )
    pipeline = AnnotationPipeline(grr)
    pipeline.preamble = preamble
    # Keep the raw configuration on the pipeline object.
    pipeline.raw = config
    try:
        for idx, annotator_config in enumerate(pipeline_config):
            params = annotator_config.parameters
            # Inject a default work directory only when the configuration
            # does not already provide one.
            if "work_dir" not in params:
                if work_dir is not None:
                    params._data["work_dir"] = (  # noqa: SLF001
                        work_dir / f"A{idx}_{annotator_config.type}"
                    )
                else:
                    params._data["work_dir"] = Path("./work")  # noqa: SLF001
                # Mark the injected key as used so that the unused
                # parameters check below does not reject it.
                # NOTE(review): assumed to apply only to the injected
                # value (nested under the outer `if`) - confirm.
                params._used_keys.add("work_dir")  # noqa: SLF001
            builder = get_annotator_factory(annotator_config.type)
            annotator = builder(pipeline, annotator_config)
            annotator = InputAnnotableAnnotatorDecorator.decorate(annotator)
            annotator = ValueTransformAnnotatorDecorator.decorate(annotator)
            # Both checks raise ValueError, which is converted below.
            check_for_unused_parameters(annotator_config)
            check_for_repeated_attributes_in_annotator(annotator_config)
            pipeline.add_annotator(annotator)
    except ValueError as value_error:
        # Report the failure against the annotator being processed; the
        # original error is passed as a second argument and as the cause.
        raise AnnotationConfigurationError(
            f"The {annotator_config.annotator_id} annotator"
            f" configuration is incorrect: ",
            value_error) from value_error
    check_for_repeated_attributes_in_pipeline(
        pipeline, allow_repeated_attributes=allow_repeated_attributes,
    )
    return pipeline
[docs]
def copy_annotation_pipeline(
    pipeline: AnnotationPipeline,
) -> AnnotationPipeline:
    """Create a fresh pipeline equivalent to the given one.

    The copy is produced by rebuilding a pipeline from the raw
    configuration and the repository stored on ``pipeline``. Only these
    two values are passed to ``build_annotation_pipeline``, so its
    keyword options take their default values for the rebuilt pipeline.

    :param pipeline: the pipeline to copy.
    :return: a newly built pipeline.
    """
    raw_config = pipeline.raw
    repository = pipeline.repository
    return build_annotation_pipeline(raw_config, repository)
[docs]
def copy_reannotation_pipeline(
    pipeline: ReannotationPipeline,
) -> ReannotationPipeline:
    """Create a fresh reannotation pipeline equivalent to the given one.

    The new and the old pipelines held by ``pipeline`` are copied, in
    this order, with ``copy_annotation_pipeline`` and combined into a
    new ``ReannotationPipeline``.

    :param pipeline: the reannotation pipeline to copy.
    :return: a newly built reannotation pipeline.
    """
    new_copy = copy_annotation_pipeline(pipeline.pipeline_new)
    old_copy = copy_annotation_pipeline(pipeline.pipeline_old)
    return ReannotationPipeline(new_copy, old_copy)
[docs]
def check_for_repeated_attributes_in_annotator(
    annotator_config: AnnotatorInfo,
) -> None:
    """Verify that an annotator configuration has no repeated attributes.

    :param annotator_config: annotator configuration whose attribute
        names are inspected.
    :raises ValueError: when an attribute name occurs more than once; the
        message lists the repeated names sorted and comma-separated.
    """
    occurrences = Counter(
        attribute.name for attribute in annotator_config.attributes
    )
    duplicated = sorted(
        name for name, count in occurrences.items() if count > 1
    )
    if not duplicated:
        return

    listed = ",".join(duplicated)
    raise ValueError(f"The annotator has repeated attributes: {listed}")
[docs]
def check_for_repeated_attributes_in_pipeline(
    pipeline: AnnotationPipeline, *, allow_repeated_attributes: bool = False,
) -> None:
    """Check for repeated attributes in pipeline configuration.

    Attribute names returned by ``pipeline.get_attributes()`` are counted.
    When no name occurs more than once, nothing happens. Otherwise the
    repeated attributes are either renamed via
    ``resolve_repeated_attributes`` or reported as an error.

    :param pipeline: the pipeline whose attributes are inspected.
    :param allow_repeated_attributes: when true, repeated attributes are
        renamed instead of being reported as an error.
    :raises AnnotationConfigurationError: when repeated attributes are
        found and ``allow_repeated_attributes`` is false; the message
        maps each repeated attribute name to the IDs of the annotators
        that define it.
    """
    name_counts = Counter(att.name for att in pipeline.get_attributes())
    repeated_attributes = {
        name for name, count in name_counts.items() if count > 1
    }
    if not repeated_attributes:
        return

    if allow_repeated_attributes:
        resolve_repeated_attributes(pipeline, repeated_attributes)
        return

    overlaps: dict[str, list[str]] = {}
    # reversed so that it follows the order of the pipeline config
    for annotator in reversed(pipeline.annotators):
        annotator_id = annotator.get_info().annotator_id
        for attr in annotator.attributes:
            if attr.name in repeated_attributes:
                overlaps.setdefault(attr.name, []).append(annotator_id)
    raise AnnotationConfigurationError(
        f"Repeated attributes in pipeline were found - {overlaps}",
    )
[docs]
def resolve_repeated_attributes(
    pipeline: AnnotationPipeline, repeated_attributes: set[str],
) -> None:
    """Rename repeated pipeline attributes so that they become distinct.

    Every attribute whose name is one of ``repeated_attributes`` gets the
    ID of its annotator appended to the name, separated by an underscore.
    The attributes are modified in place.

    :param pipeline: the pipeline whose annotators' attributes are renamed.
    :param repeated_attributes: attribute names that occur more than once.
    """
    for repeated_name in repeated_attributes:
        for annotator in pipeline.annotators:
            for attribute in annotator.attributes:
                if attribute.name != repeated_name:
                    continue
                annotator_id = annotator.get_info().annotator_id
                attribute.name = f"{repeated_name}_{annotator_id}"
[docs]
def check_for_unused_parameters(info: AnnotatorInfo) -> None:
    """Verify that every configured parameter of an annotator was used.

    The annotator level parameters are checked first, then the parameters
    of each attribute in order; the first unused set found is reported.

    :param info: annotator configuration to inspect.
    :raises ValueError: when the annotator parameters, or the parameters
        of one of its attributes, report unused keys.
    """
    if unused_keys := info.parameters.get_unused_keys():
        raise ValueError(
            f"The are unused annotator parameters: {unused_keys}")

    for attribute in info.attributes:
        leftover = attribute.parameters.get_unused_keys()
        if not leftover:
            continue
        listed = ",".join(sorted(leftover))
        raise ValueError(
            f"There are unused annotator attribute parameters: {listed}")