"""Provides base class for annotators."""
from __future__ import annotations
import abc
import os
from collections.abc import Sequence
from pathlib import Path
from typing import Any
from gain.annotation.annotatable import Annotatable
from gain.annotation.annotation_config import (
AnnotatorInfo,
Attribute,
AttributeConfig,
ParamsUsageMonitor,
)
from gain.annotation.annotation_pipeline import (
AnnotationPipeline,
Annotator,
AttributeSpec,
)
from gain.genomic_resources.aggregators import (
Aggregator,
validate_aggregator,
)
[docs]
class AnnotatorBase(Annotator):
"""Base implementation of the `Annotator` class."""
def __init__(
self, pipeline: AnnotationPipeline | None,
info: AnnotatorInfo,
):
self.attribute_specs: dict[str, AttributeSpec] = {}
for source, spec in self.get_attribute_specs().items():
if isinstance(spec, AttributeSpec):
self.attribute_specs[source] = spec
else:
raise TypeError(
f"Invalid attribute spec for source '{source}'"
f" in annotator {info.type}")
if not info.attributes:
for source, spec in self.attribute_specs.items():
if spec.is_default:
defaults = self.get_attribute_defaults(spec)
info.attributes.append(AttributeConfig(
name=source,
source=source,
internal=None,
aggregator=defaults.get("aggregator"),
parameters={
k: v for k, v in defaults.items()
if k != "aggregator"
},
))
self._attributes: list[Attribute] = []
for attr_config in info.attributes:
if attr_config.source not in self.attribute_specs:
raise ValueError(
f"The attribute source '{attr_config.source}'"
" is not supported for the annotator"
f" {info.type}")
spec = self.attribute_specs[attr_config.source]
internal = (
attr_config.internal
if attr_config.internal is not None
else spec.internal_default
)
defaults = self.get_attribute_defaults(spec)
default_aggregator = defaults.get("aggregator")
parameters = ParamsUsageMonitor({
**{k: v for k, v in defaults.items() if k != "aggregator"},
**attr_config.parameters,
})
aggregator = (
attr_config.aggregator
if attr_config.aggregator is not None
else default_aggregator
)
attr = Attribute(
name=attr_config.name,
source=attr_config.source,
internal=internal,
aggregator=aggregator,
spec=spec,
parameters=parameters,
)
if aggregator is not None:
if spec is not None and not spec.supports_aggregation:
raise ValueError(
f"Attribute '{attr.source}' in annotator"
f" {info.type} does not support aggregation.")
validate_aggregator(
aggregator,
self._aggregator_value_type(attr),
)
attr.aggregator_instance = Aggregator.build(aggregator)
self._attributes.append(attr)
work_dir = info.parameters.get("work_dir")
if work_dir is None:
raise ValueError(
f"Missing a 'work_dir' parameter in annotator {info}.")
self.work_dir: Path = Path(work_dir)
super().__init__(pipeline, info)
@property
def attributes(self) -> list[Attribute]:
return self._attributes
def _aggregator_value_type(self, attr: Attribute) -> str | None:
return attr.spec.value_type if attr.spec else None
[docs]
def get_attribute_defaults(
self, spec: AttributeSpec, # noqa: ARG002
) -> dict[str, Any]:
return {}
[docs]
def open(self) -> Annotator:
super().open()
os.makedirs(self.work_dir, exist_ok=True)
return self
@abc.abstractmethod
def _do_annotate(self, annotatable: Annotatable, context: dict[str, Any]) \
-> dict[str, Any]:
"""Annotate the annotatable.
Internal abstract method used for annotation. It should produce
a source-keyed dict, one entry per configured attribute.
"""
def _apply_aggregators(
self, values: dict[str, Any],
) -> dict[str, Any]:
result = {}
for attr in self._attributes:
value = values.get(attr.source)
if attr.aggregator_instance is not None and isinstance(value, list):
result[attr.name] = attr.aggregator_instance.aggregate(value)
else:
result[attr.name] = value
return result
[docs]
def annotate(
self, annotatable: Annotatable | None, context: dict[str, Any],
) -> dict[str, Any]:
if annotatable is None:
values = self._empty_result()
else:
values = self._do_annotate(annotatable, context)
return self._apply_aggregators(values)
def _do_batch_annotate(
self,
annotatables: Sequence[Annotatable | None],
contexts: list[dict[str, Any]],
batch_work_dir: str | None = None, # noqa: ARG002
) -> list[dict[str, Any]]:
"""Annotate a batch of annotatables."""
return [
self._empty_result() if annotatable is None
else self._do_annotate(annotatable, context)
for annotatable, context in zip(annotatables, contexts, strict=True)
]
[docs]
def batch_annotate(
self,
annotatables: Sequence[Annotatable | None],
contexts: list[dict[str, Any]],
batch_work_dir: str | None = None,
) -> list[dict[str, Any]]:
inner_output = self._do_batch_annotate(
annotatables, contexts, batch_work_dir=batch_work_dir,
)
return [self._apply_aggregators(result) for result in inner_output]