import textwrap
from collections.abc import Callable
from typing import Any
from lark import Lark, Token, Tree
from gain.annotation.annotatable import Annotatable
from gain.annotation.annotation_config import (
AnnotationConfigurationError,
AnnotatorInfo,
)
from gain.annotation.annotation_pipeline import (
AnnotationPipeline,
Annotator,
AttributeSpec,
)
from gain.annotation.annotator_base import AnnotatorBase
from gain.genomic_resources.genomic_scores import CNV, CnvCollection
[docs]
def build_cnv_collection_annotator(
    pipeline: AnnotationPipeline,
    info: AnnotatorInfo,
) -> Annotator:
    """Create a CNV collection annotator from a pipeline and its config.

    Thin factory: forwards both arguments unchanged to
    ``CnvCollectionAnnotator`` and returns the new instance.
    """
    annotator: Annotator = CnvCollectionAnnotator(pipeline, info)
    return annotator
[docs]
class CnvCollectionAnnotator(AnnotatorBase):
    """CNV collection annotator class.

    Fetches CNVs from a CNV collection resource for the annotatable's
    ``chrom``/``pos``/``pos_end`` and reports their number ("count") plus,
    for every requested score, the list of per-CNV values.

    An optional ``cnv_filter`` parameter restricts which CNVs are used.
    The filter is a small expression language (see ``CNV_FILTER_GRAMMAR``)
    comparing CNV attributes with quoted words or numbers, combined with
    ``and`` / ``or``, e.g. ``status == "affected" and frequency < 0.5``.
    """

    # Lark grammar of the ``cnv_filter`` expression language.  Only a plain
    # space is ignored, so callers must normalise other whitespace first.
    CNV_FILTER_GRAMMAR = textwrap.dedent("""
        ?start: filter | and_ | or
        and_: filter "and" filter
        or: filter "or" filter
        ?filter: subject operator subject | or | and_
        ?subject: variable | value
        value: "\\"" word "\\"" | number
        variable: word
        operator: equals | greater_than | less_than | in
        equals: "=="
        greater_than: ">"
        less_than: "<"
        in: "in"
        word: /[a-zA-Z!@#$%^&*()_+]+/
        number: /[0-9\\.]+/
        %ignore " "
    """)

    def __init__(self, pipeline: AnnotationPipeline, info: AnnotatorInfo):
        """Load the CNV collection, compile the filter, document attributes.

        Raises:
            ValueError: if ``info.parameters`` has no ``resource_id``.
            AnnotationConfigurationError: if ``cnv_filter`` is given but is
                not a string or cannot be parsed.
        """
        cnv_collection_resource_id = info.parameters.get("resource_id")
        if cnv_collection_resource_id is None:
            raise ValueError(f"Can't create {info.type}: "
                             "no resource_id parameter.")
        resource = pipeline.repository.get_resource(
            cnv_collection_resource_id)
        self.cnv_collection = CnvCollection(resource)
        info.resources.append(resource)

        self.filter_parser = Lark(self.CNV_FILTER_GRAMMAR)
        self.cnv_filter = None
        cnv_filter_str = info.parameters.get("cnv_filter")
        if cnv_filter_str is not None:
            # Validate explicitly: an ``assert`` disappears under ``-O``.
            if not isinstance(cnv_filter_str, str):
                raise AnnotationConfigurationError(
                    "Error parsing cnv_filter: expected a string, got "
                    f"{type(cnv_filter_str).__name__}")
            # The grammar ignores only " ", so fold newlines/tabs into it.
            cnv_filter_str = cnv_filter_str.replace(
                "\n", " ").replace("\t", " ").strip()
            try:
                self.cnv_filter = self._build_cnv_filter_func(
                    self.filter_parser.parse(cnv_filter_str))
            except Exception as e:
                raise AnnotationConfigurationError(
                    f"Error parsing cnv_filter: {e}") from e

        # The collection and filter are set up before the base initialiser
        # runs; the attribute loop below needs ``self._attributes``.
        super().__init__(pipeline, info)

        for attr in self._attributes:
            spec = self.attribute_specs[attr.source]
            score_def = self.cnv_collection\
                .get_score_definition(attr.source)
            if score_def is not None:
                attr._documentation = (  # noqa: SLF001
                    f"\n{spec.description}\n"
                    f"small values: {score_def.small_values_desc},\n"
                    f"large_values: {score_def.large_values_desc}\n"
                    f"aggregator: {attr.aggregator}\n"
                )

    def get_attribute_specs(self) -> dict[str, AttributeSpec]:
        """Return the specs of "count" and of every score in the collection.

        "count" uses the ``AttributeSpec`` default for ``is_default``;
        score attributes are explicitly created with ``is_default=False``.
        """
        attributes: dict[str, AttributeSpec] = {
            "count": AttributeSpec(
                source="count",
                value_type="int",
                description="The number of CNVs overlapping with the "
                            "annotatable.",
            ),
        }
        for score_id, score_def in \
                self.cnv_collection.score_definitions.items():
            attributes[score_id] = AttributeSpec(
                source=score_id,
                value_type=score_def.value_type,
                description=score_def.desc,
                is_default=False,
            )
        return attributes

    def get_attribute_defaults(
        self, spec: AttributeSpec,
    ) -> dict[str, Any]:
        """Return the default aggregator for a score-backed attribute.

        Attributes without a score definition in the collection get no
        defaults (an empty dict).
        """
        score_def = self.cnv_collection.get_score_definition(spec.source)
        if score_def is not None:
            return {"aggregator": score_def.allele_aggregator}
        return {}

    @staticmethod
    def _build_subject_accessor(subject: Any) -> Callable[[CNV], Any]:
        """Turn a ``variable`` or ``value`` parse node into a value getter.

        A ``variable`` node yields a getter that looks the named attribute
        up with ``cnv.attributes.get``; a ``value`` node yields a getter
        returning the constant (converted to ``float`` for ``number``).
        """
        assert isinstance(subject, Tree)
        assert isinstance(subject.data, Token)
        node = subject.children[0]
        assert isinstance(node, Tree)
        assert isinstance(node.data, Token)

        if subject.data.value == "variable":
            assert node.data.value == "word"
            assert isinstance(node.children[0], Token)
            attribute_name = node.children[0].value
            return lambda cnv: cnv.attributes.get(attribute_name)

        assert isinstance(node.children[0], Token)
        constant: Any = node.children[0].value
        if node.data.value == "number":
            constant = float(constant)
        return lambda _cnv: constant

    @classmethod
    def _build_cnv_filter_func(
        cls, tree: Tree,
    ) -> Callable[[CNV], bool]:
        """Compile a parsed filter expression into a CNV predicate.

        ``and_`` / ``or`` nodes are compiled recursively; any other node is
        treated as a comparison with children (left, operator, right).

        Raises:
            ValueError: if the operator node is not one of ``equals``,
                ``greater_than``, ``less_than`` or ``in``.
        """
        if tree.data in ("and_", "or"):
            assert isinstance(tree.children[0], Tree)
            assert isinstance(tree.children[1], Tree)
            left_func = cls._build_cnv_filter_func(tree.children[0])
            right_func = cls._build_cnv_filter_func(tree.children[1])
            if tree.data == "and_":
                return lambda cnv: left_func(cnv) and right_func(cnv)
            return lambda cnv: left_func(cnv) or right_func(cnv)

        left_accessor = cls._build_subject_accessor(tree.children[0])

        # The ``operator`` node wraps a single child named after the
        # concrete operator rule.
        assert isinstance(tree.children[1], Tree)
        assert isinstance(tree.children[1].children[0], Tree)
        assert isinstance(tree.children[1].children[0].data, Token)
        operator = tree.children[1].children[0].data.value

        right_accessor = cls._build_subject_accessor(tree.children[2])

        if operator == "equals":
            return lambda cnv: left_accessor(cnv) == right_accessor(cnv)
        if operator == "greater_than":
            return lambda cnv: left_accessor(cnv) > right_accessor(cnv)
        if operator == "less_than":
            return lambda cnv: left_accessor(cnv) < right_accessor(cnv)
        if operator == "in":
            return lambda cnv: left_accessor(cnv) in right_accessor(cnv)
        # ``operator`` is already the rule name (a str), not a tree node.
        raise ValueError(f"Unsupported operator {operator}")

    def open(self) -> Annotator:
        """Open the CNV collection, then the base annotator; return self."""
        self.cnv_collection.open()
        super().open()
        return self

    def close(self) -> None:
        """Close the CNV collection, then the base annotator."""
        self.cnv_collection.close()
        super().close()

    def _do_annotate(
        self, annotatable: Annotatable,
        context: dict[str, Any],  # noqa: ARG002
    ) -> dict[str, Any]:
        """Collect per-CNV values for the annotatable.

        Returns a dict keyed by attribute source.  Attributes that have an
        aggregator map to the list of that attribute's values over the
        (filtered) CNVs; every other attribute maps to the number of CNVs.
        """
        cnvs = self.cnv_collection.fetch_cnvs(
            annotatable.chrom, annotatable.pos, annotatable.pos_end)
        if self.cnv_filter:
            cnvs = [cnv for cnv in cnvs if self.cnv_filter(cnv)]

        raw: dict[str, list] = {
            attr.source: []
            for attr in self._attributes
            if attr.aggregator is not None
        }
        for cnv in cnvs:
            for source, values in raw.items():
                values.append(cnv.attributes[source])

        # NOTE(review): attributes without an aggregator are presumably
        # only "count" -- confirm against the collection's definitions.
        return {
            attr.source: (
                raw[attr.source] if attr.source in raw else len(cnvs)
            )
            for attr in self._attributes
        }