from __future__ import annotations
import json
import logging
import os
from collections import defaultdict
from collections.abc import Collection, Iterable, Sequence
from itertools import product
from typing import Any
from dae.gene_sets.denovo_gene_sets_config import (
DenovoGeneSetsConfig,
DGSCQuery,
RecurrencyCriteria,
parse_denovo_gene_sets_study_config,
parse_dgsc_query,
)
from dae.gene_sets.gene_sets_db import GeneSet
from dae.pedigrees.family import Person
from dae.person_sets import (
PersonSetCollection,
)
from dae.studies.study import GenotypeData
from dae.variants.attributes import Inheritance
logger = logging.getLogger(__name__)
[docs]
class DenovoGeneSetCollection:
"""Class representing a study's denovo gene sets."""
def __init__(
self,
study_id: str,
study_name: str,
dgsc_config: DenovoGeneSetsConfig,
pscs: dict[str, PersonSetCollection],
) -> None:
self.study_id = study_id
self.study_name = study_name
self.config: DenovoGeneSetsConfig = dgsc_config
self.recurrency_criteria = \
self.config.recurrency
self.gene_sets_ids = \
self.config.gene_sets_ids
self.pscs = pscs
self.cache: dict[str, Any] = self._build_empty_cache()
self._gene_sets_types_legend: \
dict[str, Sequence[Collection[str]]] | None = None
[docs]
def add_gene(
self,
gene_effects: list[tuple[str, str]],
persons: list[Person],
) -> None:
"""Add a gene to the cache."""
for psc_id, person_set_collection in self.pscs.items():
if psc_id not in self.cache:
self.cache[psc_id] = {}
for ps_id, person_set in person_set_collection.person_sets.items():
for person in persons:
if person.fpid not in person_set.persons:
continue
for gene_symbol, effect in gene_effects:
self._cache_update(
psc_id, ps_id, gene_symbol, effect, person,
)
def _build_empty_cache(self) -> dict[str, Any]:
cache: dict[str, Any] = {}
for psc_id in self.config.selected_person_set_collections:
cache[psc_id] = {}
for ps_id in self.pscs[psc_id].person_sets:
cache[psc_id][ps_id] = {}
ps_cache = cache[psc_id][ps_id]
for effect_criteria, sex_critera in product(
self.config.effect_types.values(),
self.config.sexes.values()):
if effect_criteria.name not in ps_cache:
ps_cache[effect_criteria.name] = {}
effect_cache = ps_cache[effect_criteria.name]
if sex_critera.name not in effect_cache:
effect_cache[sex_critera.name] = {}
return cache
def _cache_update(
self,
psc_id: str,
ps_id: str,
gene_symbol: str,
effect: str,
person: Person,
) -> None:
"""Update the cache with a gene."""
ps_cache = self.cache[psc_id][ps_id]
for effect_criteria, sex_critera in product(
self.config.effect_types.values(),
self.config.sexes.values()):
if effect not in effect_criteria.effects:
continue
if person.sex not in sex_critera.sexes:
continue
effect_cache = ps_cache[effect_criteria.name]
sex_cache = effect_cache[sex_critera.name]
if gene_symbol not in sex_cache:
sex_cache[gene_symbol] = set()
sex_cache[gene_symbol].add(person.family_id)
[docs]
@staticmethod
def create_empty_collection(
study: GenotypeData,
) -> DenovoGeneSetCollection | None:
"""Create an empty denovo gene set collection for a genotype data."""
config = study.config
assert config is not None, study.study_id
dgsc_config = parse_denovo_gene_sets_study_config(
study.config)
if dgsc_config is None:
logger.info(
"No denovo gene sets defined %s", study.study_id)
return None
person_set_collections = {
psc_id: psc
for psc_id, psc in study.person_set_collections.items()
if psc_id in dgsc_config.selected_person_set_collections
}
return DenovoGeneSetCollection(
study.study_id,
study.name,
dgsc_config,
person_set_collections,
)
[docs]
@staticmethod
def build_collection(
genotype_data: GenotypeData,
) -> DenovoGeneSetCollection | None:
"""Generate a denovo gene set collection for a study."""
dgsc = DenovoGeneSetCollection.create_empty_collection(genotype_data)
if dgsc is None:
return None
assert dgsc is not None
effect_types = [
e
for etc in dgsc.config.effect_types.values()
for e in etc.effects
]
variants = genotype_data.query_variants(
effect_types=effect_types,
inheritance=["denovo"])
for fv in variants:
for fa in fv.family_alt_alleles:
persons = []
for index, person_id in enumerate(fa.variant_in_members):
if person_id is None:
continue
inheritance = fa.inheritance_in_members[index]
if inheritance != Inheritance.denovo:
continue
person = fa.family.persons[person_id]
persons.append(person)
if not persons:
continue
effect = fa.effects
if effect is None:
continue
gene_effects = [
(gene.symbol, gene.effect)
for gene in effect.genes
if gene.symbol is not None and gene.effect is not None
]
assert all(
ge[0] is not None and ge[1] is not None
for ge in gene_effects)
assert all(p is not None for p in persons)
dgsc.add_gene(gene_effects, persons)
return dgsc
@staticmethod
def _cache_file(
psc_id: str,
cache_dir: str,
) -> str:
"""Return the path to the cache file for a person set collection."""
return os.path.join(
cache_dir,
f"denovo-cache-{psc_id}.json",
)
@classmethod
def _convert_cache_innermost_types(
cls, cache: Any,
from_type: type,
to_type: type, *,
sort_values: bool = False,
) -> Any:
"""
Coerce the types of all values in a dictionary matching a given type.
This is done recursively.
"""
if isinstance(cache, from_type):
if sort_values is True:
return sorted(to_type(cache))
return to_type(cache)
assert isinstance(
cache, dict,
), f"expected type 'dict', got '{type(cache)}'"
res = {}
for key, value in cache.items():
res[key] = cls._convert_cache_innermost_types(
value, from_type, to_type, sort_values=sort_values,
)
return res
[docs]
def save(self, cache_dir: str) -> None:
"""Save the denovo gene set collection to a cache files."""
if not os.path.exists(cache_dir):
os.mkdir(cache_dir)
for psc_id in self.config.selected_person_set_collections:
cache_file = self._cache_file(psc_id, cache_dir)
content = self.cache[psc_id]
content = self._convert_cache_innermost_types(content, set, list)
with open(cache_file, "w") as outfile:
json.dump(content, outfile)
[docs]
def load(self, cache_dir: str) -> None:
"""Load cached denovo gene set collection from a cache files."""
for psc_id in self.config.selected_person_set_collections:
cache_file = self._cache_file(psc_id, cache_dir)
if not os.path.exists(cache_file):
continue
with open(cache_file, "r") as infile:
cache = json.load(infile)
self.cache[psc_id] = self._convert_cache_innermost_types(
cache, list, set,
)
[docs]
def get_gene_sets_types_legend(self) -> dict[str, Any]:
"""Return dict with legends for each collection."""
if self._gene_sets_types_legend is None:
name = self.study_name or self.study_id
person_set_collections = [
{
"personSetCollectionId": collection_id,
"personSetCollectionName": person_set_collection.name,
"personSetCollectionLegend":
self.get_person_set_collection_legend(collection_id),
}
for collection_id, person_set_collection
in self.pscs.items()
]
self._gene_sets_types_legend = {
"datasetId": self.study_id,
"datasetName": name,
"personSetCollections": person_set_collections,
}
return self._gene_sets_types_legend
[docs]
def get_person_set_collection_legend(
self, psc_id: str,
) -> list[dict[str, Any]]:
"""Return the domain (used as a legend) of a person set collection."""
# This could probably be removed, it just takes each domain
# and returns a dict with a subset of the original keys
person_set_collection = self.pscs.get(psc_id)
if person_set_collection is not None:
return person_set_collection.legend_json()
return []
[docs]
def get_gene_set(
self,
dgsc_query: str | DGSCQuery,
) -> GeneSet | None:
"""Return a gene set from the collection."""
if isinstance(dgsc_query, str):
dgsc_query = parse_dgsc_query(dgsc_query, self.config)
assert isinstance(dgsc_query, DGSCQuery)
if dgsc_query.gene_set_id not in self.gene_sets_ids:
raise ValueError(
f"Invalid gene set id: {dgsc_query.gene_set_id}")
if dgsc_query.psc_id not in self.config.selected_person_set_collections:
raise ValueError(
f"Invalid person set collection id: {dgsc_query.psc_id}")
result: dict[str, set[str]] = defaultdict(set)
psc_cache = self.cache[dgsc_query.psc_id]
for keys in product(
dgsc_query.selected_person_sets,
dgsc_query.effects,
dgsc_query.sex,
):
innermost_cache = psc_cache[keys[0]][keys[1].name][keys[2].name]
for gene, families in innermost_cache.items():
result[gene].update(families)
if dgsc_query.recurrency is not None:
result = self._apply_recurrency(result, dgsc_query.recurrency)
return GeneSet(
name=str(dgsc_query),
desc=str(dgsc_query),
syms=list(result.keys()),
)
[docs]
@classmethod
def get_all_gene_sets(
cls, denovo_gene_sets: list[DenovoGeneSetCollection],
denovo_gene_set_spec: dict[str, dict[str, list[str]]],
) -> list[dict[str, Any]]:
"""Return all gene sets from provided denovo gene set collections."""
sets = [
cls.get_gene_set_from_collections(
name,
denovo_gene_sets,
denovo_gene_set_spec)
for name in cls._get_gene_sets_names(denovo_gene_sets)
]
return list(filter(None, sets))
[docs]
@classmethod
def get_gene_set_from_collections(
cls, gene_set_id: str,
denovo_gene_set_collections: list[DenovoGeneSetCollection],
denovo_gene_set_spec: dict[str, dict[str, list[str]]],
) -> dict[str, Any] | None:
"""Return a single set from provided denovo gene set collections."""
syms = cls._get_gene_set_syms(
gene_set_id,
denovo_gene_set_collections,
denovo_gene_set_spec)
if not syms:
return None
return {
"name": gene_set_id,
"count": len(syms),
"syms": syms,
"desc": f"{gene_set_id} "
f"({cls._format_description(denovo_gene_set_spec)})",
}
@classmethod
def _get_gene_set_syms(
cls, gene_set_id: str,
denovo_gene_set_collections: list[DenovoGeneSetCollection],
denovo_gene_set_spec: dict[str, dict[str, list[str]]],
) -> set[str]:
"""
Return symbols of all genes in a given gene set.
Collect the symbols of all genes belonging to a
given gene set from a number of denovo gene set collections,
while filtering by the supplied spec.
"""
criteria = set(gene_set_id.split("."))
recurrency_criteria = cls._get_common_recurrency_criteria(
denovo_gene_set_collections,
)
recurrency_criteria_names = criteria & set(recurrency_criteria.keys())
standard_criteria = criteria - recurrency_criteria_names
genes_families: dict[str, set[str]] = {}
for dataset_id, person_set_collection in denovo_gene_set_spec.items():
for (
person_set_collection_id,
person_set_collection_values,
) in person_set_collection.items():
for value in person_set_collection_values:
all_criteria = standard_criteria.union(
(dataset_id, person_set_collection_id, value),
)
genes_to_families = cls._get_genes_to_families(
cls._get_cache(denovo_gene_set_collections),
all_criteria,
)
for gene, families in genes_to_families.items():
genes_families.setdefault(gene, set()).update(families)
matching_genes = genes_families
if recurrency_criteria_names:
assert len(recurrency_criteria_names) == 1, gene_set_id
recurrency_criterion = recurrency_criteria[
recurrency_criteria_names.pop()
]
matching_genes = cls._apply_recurrency(
matching_genes, recurrency_criterion,
)
return set(matching_genes.keys())
@classmethod
def _get_genes_to_families(
cls, gs_cache: dict[str, Any],
criteria: Iterable[str],
) -> dict[str, set[str]]:
"""
Recursively collect genes and families by given criteria.
Collects all genes and their families which
correspond to the set of given criteria.
The input gs_cache must be nested dictionaries with
leaf nodes of type 'set'.
"""
if len(gs_cache) == 0:
return {}
gs_cache_keys = list(gs_cache.keys())
criteria = set(criteria)
next_keys = criteria.intersection(gs_cache_keys)
if len(next_keys) == 0:
result: dict[str, set[str]] = {}
if not isinstance(gs_cache[gs_cache_keys[0]], set):
# still not the end of the tree
for key in gs_cache_keys:
for gene, families in cls._get_genes_to_families(
gs_cache[key], criteria,
).items():
result.setdefault(gene, set()).update(families)
elif len(criteria) == 0:
# end of tree with satisfied criteria
result.update(gs_cache)
return result
next_key = next_keys.pop()
return cls._get_genes_to_families(
gs_cache[next_key], criteria - {next_key},
)
@staticmethod
def _get_common_recurrency_criteria(
denovo_gene_set_collections: list[DenovoGeneSetCollection],
) -> dict[str, RecurrencyCriteria]:
if len(denovo_gene_set_collections) == 0:
return {}
recurrency_criteria = \
denovo_gene_set_collections[0].config.recurrency
for collection in denovo_gene_set_collections:
common_elements = frozenset(
recurrency_criteria.keys(),
).intersection(collection.config.recurrency.keys())
new_recurrency_criteria = {}
for element in common_elements:
new_recurrency_criteria[element] = recurrency_criteria[element]
recurrency_criteria = new_recurrency_criteria
return recurrency_criteria
@staticmethod
def _apply_recurrency(
genes_to_families: dict[str, set[str]],
recurrency: RecurrencyCriteria,
) -> dict[str, set[str]]:
"""Apply a recurrency criterion to a dictionary of genes."""
if recurrency.end < 0:
def filter_lambda(item: set[str]) -> bool:
return len(item) >= recurrency.start
else:
def filter_lambda(item: set[str]) -> bool:
return recurrency.start <= len(item) < recurrency.end
return {
k: v for k, v in genes_to_families.items() if filter_lambda(v)
}
@staticmethod
def _get_cache(
denovo_gene_set_collections: list[DenovoGeneSetCollection],
) -> dict[str, Any]:
gs_cache = {}
for collection in denovo_gene_set_collections:
gs_cache[collection.study_id] = collection.cache
return gs_cache
@staticmethod
def _format_description(
denovo_gene_set_spec: dict[str, dict[str, list[str]]],
) -> str:
return ";".join(
[
f"{genotype_data}:{group_id}:{','.join(values)}"
for genotype_data, person_set_collection
in denovo_gene_set_spec.items()
for group_id, values in person_set_collection.items()
],
)
@staticmethod
def _get_gene_sets_names(
denovo_gene_set_collections: list[DenovoGeneSetCollection],
) -> list[str]:
if len(denovo_gene_set_collections) == 0:
return []
gene_sets_ids = frozenset(
denovo_gene_set_collections[0].gene_sets_ids,
)
for collection in denovo_gene_set_collections:
gene_sets_ids = gene_sets_ids.intersection(
collection.gene_sets_ids,
)
return list(gene_sets_ids)