import argparse
import logging
import os
import pathlib
import sys
import time
from collections import defaultdict
from collections.abc import Iterable, Sequence
from itertools import batched
from typing import Any, cast
from box import Box
from dae.effect_annotation.effect import expand_effect_types
from dae.gene_profile.db import GeneProfileDBWriter
from dae.gene_profile.statistic import GPStatistic
from dae.gene_sets.gene_sets_db import GeneSet
from dae.genomic_resources.gene_models.gene_models import GeneModels
from dae.genomic_resources.reference_genome import ReferenceGenome
from dae.genomic_resources.repository_factory import (
build_genomic_resource_repository,
)
from dae.gpf_instance.gpf_instance import GPFInstance
from dae.person_sets import PSCQuery
from dae.task_graph.cli_tools import (
TaskGraphCli,
task_graph_run_with_results,
)
from dae.task_graph.graph import TaskGraph
from dae.utils.regions import Region
from dae.utils.verbosity_configuration import VerbosityConfiguration
from dae.variants.attributes import Role
from dae.variants.family_variant import FamilyAllele, FamilyVariant
from dae.variants.variant import allele_type_from_name
logger = logging.getLogger("generate_gene_profile")
[docs]
def generate_gp(
    gpf_instance: GPFInstance,
    gene_symbol: str,
    collections_gene_sets: list[tuple[str, GeneSet]],
) -> tuple[str, GPStatistic]:
    """Generate GP."""
    # pylint: disable=protected-access, invalid-name, too-many-locals
    config = gpf_instance._gene_profile_config  # noqa: SLF001
    assert config is not None
    scores_db = gpf_instance.gene_scores_db

    # gene-set membership columns: "<collection>_<set name>" for every
    # configured gene set that contains this gene symbol
    sets_in = [
        f'{collection_id}_{gene_set["name"]}'
        for collection_id, gene_set in collections_gene_sets
        if gene_symbol in gene_set["syms"]
    ]

    # per-category gene score values for this gene
    scores: dict[str, Any] = {}
    for category in config.gene_scores:
        category_values: dict[str, Any] = {}
        for score in category["scores"]:
            score_name = score["score_name"]
            score_desc = scores_db.get_score_desc(score_name)
            gene_score = scores_db.get_gene_score(score_desc.resource_id)
            category_values[score_name] = gene_score.get_gene_value(
                score_name, gene_symbol)
        scores[category["category"]] = category_values

    # zero-initialized count and rate columns for every
    # dataset / person set / statistic combination
    variant_counts: dict[str, Any] = {}
    for dataset_id, dataset_config in config["datasets"].items():
        for person_set in dataset_config["person_sets"]:
            for statistic in dataset_config["statistics"]:
                col = f'{dataset_id}_{person_set["set_name"]}_{statistic["id"]}'
                variant_counts[col] = 0
                variant_counts[f"{col}_rate"] = 0

    return gene_symbol, GPStatistic(
        gene_symbol, sets_in, scores, variant_counts,
    )
[docs]
def add_variant_count(
    variant: FamilyVariant,
    variant_counts: dict[str, Any],
    person_set: str,
    statistic_id: str,
    statistic_effect_types: set[str] | None,
) -> None:
    """Increment count for specific variant."""
    # pylint: disable=invalid-name

    def matches_effect_filter(symbol: str) -> bool:
        # No effect filter configured: every gene symbol passes.
        if statistic_effect_types is None:
            return True
        # Otherwise at least one alt allele must annotate `symbol` with
        # an effect from the statistic's effect types.
        for allele in variant.alt_alleles:
            symbol_effects = {
                eg.effect
                for eg in allele.effect_genes
                if eg.symbol is not None
                and eg.effect is not None
                and eg.symbol == symbol
            }
            if symbol_effects & statistic_effect_types:
                return True
        return False

    for symbol in variant.effect_gene_symbols:
        # only genes tracked in the counts structure are counted
        if symbol not in variant_counts:
            continue
        if not matches_effect_filter(symbol):
            continue
        variant_counts[symbol][person_set][statistic_id].add(variant.fvuid)
# Maximum allele frequency for a variant to count as "rare"; used both in
# queries and in per-variant frequency checks. Presumably percent — TODO
# confirm units of `af_allele_freq`.
RARE_FREQUENCY_THRESHOLD = 1.0


def build_rare_query(
    statistic: Box,
) -> dict[str, Any]:
    """Build rare variant query.

    Translates a "rare"-category GP statistic configuration into
    ``query_variants`` keyword arguments: a low-frequency filter,
    transmitted-only inheritance, and the statistic's optional effect,
    variant-type, role and genomic-score filters.
    """
    assert statistic.get("category") == "rare"
    query: dict[str, Any] = {
        "frequency_filter": [
            ("af_allele_freq", (None, RARE_FREQUENCY_THRESHOLD))],
        # transmitted variants only: exclude anything de novo-like
        "inheritance": [
            ("not denovo and "
             "not possible_denovo and not possible_omission"),
            "any([mendelian,unknown])",
        ],
    }
    if statistic.effects is not None:
        query["effect_types"] = list(
            expand_effect_types(statistic.effects))
    if statistic.variant_types:
        # BUG FIX: previously the whole `variant_types` list was passed to
        # allele_type_from_name for every element; each name `t` must be
        # converted individually.
        query["variant_type"] = " or ".join(
            allele_type_from_name(t).repr()  # type: ignore
            for t in statistic.variant_types)
    if statistic.roles:
        query["roles"] = " or ".join(
            repr(Role.from_name(r)) for r in statistic.roles)
    else:
        # default role filter: variant seen in a child and in a parent
        query["roles"] = "(prb or sib or child) and (mom or dad)"
    if statistic.genomic_scores:
        real_attr_query = []
        for score in statistic.genomic_scores:
            score_name = score["name"]
            score_min = score.get("min")
            score_max = score.get("max")
            # a score filter must bound at least one side of the range
            assert score_min is not None or score_max is not None
            real_attr_query.append(
                (score_name, (score_min, score_max)))
        query["real_attr_filter"] = real_attr_query
    return query
[docs]
def process_region(
    regions: list[Region] | None,
    gene_symbols: set[str],
    person_ids: dict[str, Any],
    *,
    gpf_config: str | None = None,
    grr_definition: dict[str, Any] | None = None,
) -> dict[str, dict[str, Any]]:
    """Process list of regions to collect variant counts.

    Runs as a task-graph worker: it builds its own GPF instance from
    *gpf_config* / *grr_definition* and returns a fresh
    ``dataset -> gene -> person set -> statistic -> set of fvuids``
    structure covering only the given *regions*.

    regions: regions to restrict the variant queries to; ``None`` means
        no region restriction.
    gene_symbols: gene symbols to collect variant counts for.
    person_ids: per-dataset mapping of person-set name to person ids.
    gpf_config: path to the GPF instance configuration.
    grr_definition: genomic resource repository definition, if any.
    """
    if grr_definition is not None:
        grr = build_genomic_resource_repository(grr_definition)
    else:
        grr = None
    # each worker builds its own GPF instance; instances are not shared
    # across worker processes
    gpf_instance = GPFInstance.build(gpf_config, grr=grr)
    gene_profiles_config = gpf_instance._gene_profile_config  # noqa: SLF001
    assert gene_profiles_config is not None
    # pass an explicit gene list to the queries only when it is small;
    # NOTE(review): the 20-gene cutoff looks like a query heuristic —
    # confirm before changing
    query_genes = list(gene_symbols) if len(gene_symbols) <= 20 else None
    variant_counts = _init_variant_counts(gene_profiles_config, gene_symbols)
    for dataset_id, filters in gene_profiles_config.datasets.items():
        # query a dataset only for the statistic categories its
        # configuration actually declares
        has_denovo = any(
            stats.category == "denovo" for stats in filters.statistics)
        has_rare = any(
            stats.category == "rare" for stats in filters.statistics)
        if has_denovo:
            logger.debug("collecting denovo variants for %s", dataset_id)
            genotype_data = gpf_instance.get_genotype_data(dataset_id)
            assert genotype_data is not None, dataset_id
            denovo_variants = \
                genotype_data.query_variants(
                    regions=regions,
                    genes=query_genes,
                    inheritance="denovo",
                )
            logger.debug("done collecting denovo variants for %s", dataset_id)
            logger.debug("collecting denovo variant counts for %s", dataset_id)
            collect_variant_counts(
                variant_counts[dataset_id],
                denovo_variants,
                dataset_id,
                gene_profiles_config,
                person_ids[dataset_id],
                denovo_flag=True,
            )
            logger.debug(
                "done collecting denovo variant counts for %s", dataset_id,
            )
        if has_rare:
            logger.debug("counting rare variants for %s", dataset_id)
            genotype_data = gpf_instance.get_genotype_data(dataset_id)
            assert genotype_data is not None, dataset_id
            # each "rare" statistic carries its own query filters, so
            # every statistic issues a separate query
            for statistic in filters.statistics:
                if statistic.category != "rare":
                    continue
                query_kwargs = build_rare_query(statistic)
                rare_variants = \
                    genotype_data.query_variants(
                        regions=regions,
                        genes=query_genes,
                        **query_kwargs)
                logger.debug(
                    "counting rare variants for dataset %s", dataset_id)
                collect_variant_counts(
                    variant_counts[dataset_id],
                    rare_variants,
                    dataset_id,
                    gene_profiles_config,
                    person_ids[dataset_id],
                    denovo_flag=False,
                )
                logger.debug(
                    "done counting rare variants for dataset %s", dataset_id)
    return variant_counts
[docs]
def count_variant(
    v: FamilyVariant,
    dataset_id: str,
    variant_counts: dict[str, Any],
    gene_profiles_config: Box,
    person_ids: dict[str, Any], *,
    denovo_flag: bool,
) -> None:
    """Count variant.

    For every configured person set and statistic of *dataset_id*,
    applies the statistic's filters (category, genomic scores, rare
    frequency, variant types, roles, effects) to *v* and, when all
    pass, records the variant's ``fvuid`` in *variant_counts*.
    """
    # pylint: disable=invalid-name, too-many-locals, too-many-branches
    filters = gene_profiles_config.datasets[dataset_id]
    # all family members carrying any alt allele of this variant
    members = set()
    for aa in v.alt_alleles:
        for member in cast(FamilyAllele, aa).variant_in_members:
            if member is not None:
                members.add(member)
    for ps in filters.person_sets:
        pids = set(person_ids[ps.set_name])
        for statistic in filters.statistics:
            # a statistic only counts variants of its own category:
            # "denovo" statistics during denovo queries, "rare" otherwise
            if statistic.category == "denovo" and not denovo_flag:
                continue
            if statistic.category == "rare" and denovo_flag:
                continue
            stat_id = statistic.id
            # the variant must be carried by at least one member of
            # this person set
            in_members = pids.intersection(members)
            if not in_members:
                continue
            if statistic.get("genomic_scores"):
                do_count = _check_variant_genomic_scores(v, statistic)
                if not do_count:
                    continue
            # rare statistics additionally require at least one alt
            # allele at or below the rare frequency threshold
            if statistic.get("category") == "rare":
                match = False
                for aa in v.alt_alleles:
                    freq = aa.get_attribute("af_allele_freq")
                    if freq is not None and freq <= RARE_FREQUENCY_THRESHOLD:
                        match = True
                if not match:
                    continue
            if statistic.get("variant_types"):
                variant_types = {
                    allele_type_from_name(t)
                    for t in statistic.variant_types
                }
                if not len(variant_types.intersection(v.variant_types)) > 0:
                    continue
            if statistic.get("roles"):
                roles = {
                    Role.from_name(r)
                    for r in statistic.roles
                }
                # NOTE(review): only the first alt allele's roles are
                # inspected here — confirm this is intended for
                # multi-allelic variants
                v_roles = set(
                    cast(FamilyAllele, v.alt_alleles[0]).variant_in_roles,
                )
                if not len(v_roles.intersection(roles)) > 0:
                    continue
            statistic_effect_types = None
            if statistic.get("effects"):
                statistic_effect_types = set(
                    expand_effect_types(statistic.effects))
            # record the variant per effect gene (effect filtering is
            # done inside add_variant_count)
            add_variant_count(
                v, variant_counts, ps.set_name, stat_id, statistic_effect_types,
            )
def _check_variant_genomic_scores(
v: FamilyVariant,
statistic: Box,
) -> bool:
do_count = True
for score in statistic.genomic_scores:
score_name = score["name"]
score_min = score.get("min")
score_max = score.get("max")
score_values: list[float] = list(
filter(None, v.get_attribute(score_name)))
if not score_values:
return False
if score_min is not None and score_max is not None:
if not any(score_min <= sv <= score_max for sv in score_values):
return False
elif score_min is not None:
if not any(sv >= score_min for sv in score_values):
return False
elif score_max is not None: # noqa: SIM102
if not any(sv <= score_max for sv in score_values):
return False
return do_count
[docs]
def collect_variant_counts(
    variant_counts: dict[str, Any],
    variants: Iterable[FamilyVariant],
    dataset_id: str,
    gene_profiles_config: Box,
    person_ids: dict[str, Any], *,
    denovo_flag: bool,
) -> None:
    """Collect variant gene counts for a given dataset."""
    start = time.time()
    seen = 0
    for variant in variants:
        seen += 1
        # progress heartbeat every 1000 variants
        if seen % 1000 == 0:
            logger.debug(
                "%s: counted %s variants from %s in %.2f seconds",
                dataset_id, seen, dataset_id, time.time() - start,
            )
        count_variant(
            variant, dataset_id, variant_counts, gene_profiles_config,
            person_ids, denovo_flag=denovo_flag,
        )
[docs]
def build_partitions(
    reference_genome: ReferenceGenome,
    gene_models: GeneModels,
    **kwargs: Any,
) -> Sequence[list[Region] | None]:
    """Build partitions for processing."""
    if not kwargs.get("split_by_chromosome", True):
        # single partition covering everything
        return [None]
    symbols = kwargs.get("gene_symbols")
    if symbols is None:
        all_chromosomes = set(reference_genome.chromosomes)
    else:
        # restrict to chromosomes on which the requested genes live
        all_chromosomes = set()
        for symbol in symbols:
            if symbol not in gene_models.gene_models:
                logger.warning(
                    "Gene symbol %s not found in gene models; skipping",
                    symbol,
                )
                continue
            all_chromosomes |= {
                tr.chrom for tr in gene_models.gene_models[symbol]}
        logger.debug(
            "collected chromosomes from gene symbols: %s", all_chromosomes,
        )
    # one partition per primary chromosome (chr1..chr22, chrX / X)
    ordered = [f"chr{i}" for i in range(1, 23)]
    ordered = [*ordered, "chrX", "X"]
    primary = [chrom for chrom in ordered if chrom in all_chromosomes]
    partitions = [[Region(chrom)] for chrom in primary]
    leftover = all_chromosomes - set(primary)
    if len(leftover) > 0:
        # everything else known to the gene models goes into one
        # catch-all partition
        named = [
            chrom for chrom in leftover
            if gene_models.has_chromosome(chrom)
        ]
        logger.debug(
            "adding remaining %s chromosomes to partition %s; %s",
            len(named), len(partitions), named)
        extra = [Region(chrom) for chrom in named]
        if extra:
            partitions.append(extra)
    return partitions
def _regions_id(regions: list[Region] | None) -> str:
"""Build regions id."""
if regions is None:
return "all_regions"
return "_".join(f"{r}" for r in regions)
[docs]
def main(
    gpf_instance: GPFInstance | None = None,
    argv: list[str] | None = None,
) -> None:
    """Entry point for the generate GP script.

    Parses CLI arguments, prepares the gene profile DB, then runs the
    pipeline: collect gene sets and symbols, initialize per-gene GP
    statistics, count variants per partition via the task graph, fill
    in counts/rates and insert everything into the DB.

    gpf_instance: pre-built GPF instance (built from the environment
        when ``None``).
    argv: CLI arguments (``sys.argv`` when ``None``).
    """
    # pylint: disable=too-many-locals,too-many-branches,too-many-statements
    description = "Generate gene profile statistics tool"
    parser = argparse.ArgumentParser(description=description)
    VerbosityConfiguration.set_arguments(parser)
    # default output DB lives in $DAE_DB_DIR (current directory when unset)
    default_dbfile = os.path.join(os.getenv("DAE_DB_DIR", "./"), "gpdb.duckdb")
    parser.add_argument("--dbfile", default=default_dbfile)
    parser.add_argument(
        "--gene-sets-genes",
        action="store_true",
        help="Generate GPs only for genes contained in the config's gene sets",
    )
    parser.add_argument(
        "--genes",
        help="Comma separated list of genes to generate statistics for",
    )
    parser.add_argument("--drop", action="store_true")
    # the two flags below share one dest (split_by_chromosome) and both
    # declare default=True, so splitting by chromosome is the default
    parser.add_argument(
        "--split-by-chromosome", "--split", default=True,
        action="store_true", dest="split_by_chromosome",
        help="Split processing by chromosome "
        "(default)")
    parser.add_argument(
        "--no-split-by-chromosome", "--no-split", default=True,
        action="store_false", dest="split_by_chromosome",
        help="Do not split processing by chromosome "
        "(default)")
    TaskGraphCli.add_arguments(
        parser, use_commands=False, task_progress_mode=False,
    )
    args = parser.parse_args(argv)
    VerbosityConfiguration.set(args)
    if gpf_instance is None:
        gpf_instance = GPFInstance.build()
    # pylint: disable=protected-access, invalid-name
    gene_profiles_config = gpf_instance._gene_profile_config  # noqa: SLF001
    assert gene_profiles_config is not None, "No GP configuration found."
    # refuse to overwrite an existing DB file unless --drop was given
    if pathlib.Path(args.dbfile).exists() and not args.drop:
        logger.error(
            "gene profiles DB file %s already exists; "
            "use --drop to drop and recreate the table",
            args.dbfile,
        )
        sys.exit(1)
    gpdb = GeneProfileDBWriter(
        gene_profiles_config.to_dict(),
        args.dbfile,
    )
    if args.drop:
        gpdb.drop_gp_table()
    gpdb.create_gp_table()
    # pipeline: gene sets -> gene symbols -> person ids -> empty GP
    # statistics -> per-partition variant counts -> counts/rates -> DB
    collections_gene_sets = _collect_gene_sets(
        gpf_instance, gene_profiles_config)
    gene_symbols = _collect_gene_symbols(
        gpf_instance, collections_gene_sets, **vars(args))
    gs_count = len(gene_symbols)
    logger.debug("collected %d gene symbols", gs_count)
    person_ids = _collect_person_set_collections(
        gpf_instance, gene_profiles_config)
    gene_profiles = _init_gene_profiles(
        gpf_instance, collections_gene_sets, gene_symbols)
    partitions = build_partitions(
        gpf_instance.reference_genome,
        gene_models=gpf_instance.gene_models,
        gene_symbols=gene_symbols,
        **vars(args))
    variant_counts = _calculate_variant_counts(
        gpf_instance,
        gene_profiles_config,
        gene_symbols,
        person_ids,
        partitions,
        **vars(args),
    )
    _populate_gene_profile_statistics(
        gpf_instance,
        gene_profiles_config,
        gene_profiles,
        variant_counts,
    )
    _insert_gene_profiles_into_db(
        gpdb,
        gene_profiles,
    )
# Number of GP statistics written to the DB per insert call.
_INSERT_GP_BATCH_SIZE = 1000


def _insert_gene_profiles_into_db(
    gpdb: GeneProfileDBWriter,
    gene_profiles: dict[str, GPStatistic],
) -> None:
    """Insert all generated GP statistics into the DB in fixed-size batches."""
    start = time.time()
    logger.debug("inserting statistics into DB")
    # ceiling division: number of batches needed for all profiles
    total_batches = -(-len(gene_profiles) // _INSERT_GP_BATCH_SIZE)
    batches = batched(gene_profiles.values(), _INSERT_GP_BATCH_SIZE)
    for batch_no, batch in enumerate(batches, 1):
        logger.debug("inserting batch %d/%d", batch_no, total_batches)
        gpdb.insert_gps(batch)
        logger.debug(
            "done inserting batch %d/%d, took %.2f seconds",
            batch_no, total_batches, time.time() - start)
    logger.debug("done inserting GPs; took %.2f secs", time.time() - start)
def _populate_gene_profile_statistics(
    gpf_instance: GPFInstance,
    gene_profiles_config: Box,
    gene_profiles: dict[str, GPStatistic],
    variant_counts: dict[str, Any],
) -> None:
    """Fill per-gene count and rate columns from the collected variant ids."""
    for dataset_id in gene_profiles_config.datasets:
        logger.debug("populating gene profile statistics for %s", dataset_id)
        filters = gene_profiles_config.datasets[dataset_id]
        for gs, counts in variant_counts[dataset_id].items():
            gp_counts = gene_profiles[gs].variant_counts
            genotype_data = gpf_instance.get_genotype_data(dataset_id)
            for ps in filters.person_sets:
                psc = genotype_data.get_person_set_collection(
                    ps.collection_name,
                )
                assert psc is not None
                set_name = ps.set_name
                person_set = psc.person_sets[set_name]
                children_count = person_set.get_children_count()
                for statistic in filters.statistics:
                    stat_id = statistic["id"]
                    # NOTE(review): columns are keyed by `person_set.id`
                    # while counts use `ps.set_name`; presumably these
                    # coincide — verify against generate_gp column names
                    count_col = f"{dataset_id}_{person_set.id}_{stat_id}"
                    rate_col = f"{count_col}_rate"
                    # number of distinct variants (fvuids) for this cell
                    count = len(counts.get(set_name, {}).get(stat_id, set()))
                    if children_count > 0:
                        gp_counts[count_col] = count
                        # rate is per 1000 children in the person set
                        gp_counts[rate_col] = \
                            (count / children_count) * 1000
                    else:
                        gp_counts[count_col] = 0
                        gp_counts[rate_col] = 0
        logger.debug(
            "done populating gene profile statistics for %s", dataset_id)
def _calculate_variant_counts(
    gpf_instance: GPFInstance,
    gene_profiles_config: Box,
    gene_symbols: set[str],
    person_ids: dict[str, Any],
    partitions: Sequence[list[Region] | None],
    **kwargs: Any,
) -> dict[str, Any]:
    """Run one process_region task per partition and merge the results."""
    variant_counts: dict[str, Any] = _init_variant_counts(
        gene_profiles_config, gene_symbols)
    graph = TaskGraph()
    for regions in partitions:
        grr_definition = gpf_instance.grr.definition \
            if gpf_instance.grr else None
        # workers rebuild the GPF instance from its config path, so only
        # picklable arguments are passed to the task
        graph.create_task(
            f"generate_gene_profiles_{_regions_id(regions)}",
            process_region,
            args=(
                regions,
                gene_symbols,
                person_ids,
            ),
            kwargs={
                "gpf_config": str(gpf_instance.dae_config_path),
                "grr_definition": grr_definition,
            },
        )
    with TaskGraphCli.create_executor(**kwargs) as executor:
        for result_or_error in task_graph_run_with_results(graph, executor):
            # a failed region task aborts the whole run
            if isinstance(result_or_error, Exception):
                raise result_or_error
            region_variant_counts = result_or_error
            # union each region's per-statistic variant-id sets into the
            # accumulated structure
            variant_counts = _merge_variant_counts(
                gene_profiles_config,
                gene_symbols,
                variant_counts,
                region_variant_counts,
            )
    return variant_counts
def _init_variant_counts(
gene_profiles_config: Box,
gene_symbols: set[str],
) -> dict[str, Any]:
variant_counts: dict[str, Any] = {}
for dataset_id, filters in gene_profiles_config.datasets.items():
variant_counts[dataset_id] = {}
for gs in gene_symbols:
variant_counts[dataset_id][gs] = {}
for ps in filters.person_sets:
ps_statistics: dict[str, Any] = {}
for statistic in filters.statistics:
ps_statistics[statistic.id] = set()
variant_counts[dataset_id][gs][ps.set_name] = ps_statistics
return variant_counts
def _merge_variant_counts(
gene_profiles_config: Box,
gene_symbols: set[str],
variant_counts1: dict[str, Any],
variant_counts2: dict[str, Any],
) -> dict[str, Any]:
merged_counts: dict[str, Any] = {}
for dataset_id, filters in gene_profiles_config.datasets.items():
counts1 = variant_counts1[dataset_id]
counts2 = variant_counts2[dataset_id]
merged_counts[dataset_id] = {}
for gs in gene_symbols:
merged_counts[dataset_id][gs] = {}
gs_counts1 = counts1.get(gs, {})
gs_counts2 = counts2.get(gs, {})
for ps in filters.person_sets:
ps_statistics: dict[str, Any] = {}
for statistic in filters.statistics:
stats_count1 = gs_counts1.get(
ps.set_name, {}).get(statistic.id, set())
stats_count2 = gs_counts2.get(
ps.set_name, {}).get(statistic.id, set())
ps_statistics[statistic.id] = stats_count1 | stats_count2
merged_counts[dataset_id][gs][ps.set_name] = ps_statistics
return merged_counts
def _init_gene_profiles(
    gpf_instance: GPFInstance,
    collections_gene_sets: list[tuple[str, GeneSet]],
    gene_symbols: set[str],
) -> dict[str, Any]:
    """Create the initial GPStatistic for every gene symbol.

    Returns a mapping of gene symbol to its freshly generated
    GPStatistic with zeroed variant counts. Removes the redundant
    duplicate `start = time.time()` assignment and the needless
    `set(gene_symbols)` copy of an already-set parameter present in the
    original implementation.
    """
    start = time.time()
    gene_profiles: dict[str, Any] = {}
    gs_count = len(gene_symbols)
    for idx, sym in enumerate(gene_symbols, 1):
        gs, gp = generate_gp(
            gpf_instance, sym, collections_gene_sets,
        )
        gene_profiles[gs] = gp
        # periodic progress logging for long gene lists
        if idx % 1000 == 0:
            elapsed = time.time() - start
            logger.debug(
                "initializing %d/%d GP statistics %.2f secs",
                idx, gs_count, elapsed)
    return gene_profiles
def _collect_person_set_collections(
    gpf_instance: GPFInstance,
    gene_profiles_config: Box,
) -> dict[str, Any]:
    """Collect person ids per dataset and configured person set.

    Returns ``dataset_id -> set_name -> set of person ids``, later used
    to decide which variants were seen in which person set. Removes a
    duplicated `assert genotype_data is not None` statement present in
    the original implementation.
    """
    person_ids: dict[str, Any] = {}
    for dataset_id, filters in gene_profiles_config.datasets.items():
        genotype_data = gpf_instance.get_genotype_data(dataset_id)
        assert genotype_data is not None, dataset_id
        person_ids[dataset_id] = {}
        for ps in filters.person_sets:
            psc_query = PSCQuery(
                ps.collection_name,
                {ps.set_name},
            )
            psc = genotype_data.get_person_set_collection(psc_query.psc_id)
            assert psc is not None
            person_set = psc.query_person_ids(psc_query)
            assert person_set is not None, psc_query
            person_ids[dataset_id][ps.set_name] = set(person_set)
    return person_ids
def _collect_gene_sets(
    gpf_instance: GPFInstance,
    gene_profiles_config: Box,
) -> list[tuple[str, GeneSet]]:
    """Resolve every gene set referenced by the GP configuration.

    Raises ValueError when a configured gene set cannot be found in the
    instance's gene sets DB.
    """
    result: list[tuple[str, GeneSet]] = []
    for category in gene_profiles_config.gene_sets:
        for entry in category.sets:
            gs_id = entry["set_id"]
            collection_id = entry["collection_id"]
            gene_set = gpf_instance.gene_sets_db.get_gene_set(
                collection_id, gs_id)
            if gene_set is None:
                logger.error("missing gene set: %s, %s", collection_id, gs_id)
                raise ValueError(
                    f"missing gene set: {collection_id}: {gs_id}")
            result.append((collection_id, gene_set))
    logger.debug("collected gene sets: %d", len(result))
    return result
def _collect_gene_symbols(
gpf_instance: GPFInstance,
collections_gene_sets: list[tuple[str, GeneSet]],
**kwargs: Any,
) -> set[str]:
gene_symbols: set[str] = set()
if kwargs.get("genes"):
gene_symbols = {gs.strip() for gs in kwargs["genes"].split(",")}
elif kwargs.get("gene_sets_genes"):
for _, gs in collections_gene_sets:
gene_symbols = gene_symbols.union(gs["syms"])
else:
gene_models = gpf_instance.gene_models
gene_symbols = set(gene_models.gene_names())
return gene_symbols