Source code for dae.gene_profile.generate_gene_profile

#!/usr/bin/env python
import argparse
import logging
import os
import time
from collections.abc import Iterable
from typing import Any, Dict, Set, cast

from box import Box

from dae.effect_annotation.effect import expand_effect_types
from dae.gene_sets.gene_sets_db import GeneSet
from dae.gene_profile.db import GeneProfileDB
from dae.gene_profile.statistic import GPStatistic
from dae.gpf_instance.gpf_instance import GPFInstance
from dae.person_sets import PSCQuery
from dae.utils.verbosity_configuration import VerbosityConfiguration
from dae.variants.attributes import Role
from dae.variants.family_variant import FamilyAllele, FamilyVariant
from dae.variants.variant import allele_type_from_name

# Module-level logger shared by the functions in this file.
# NOTE(review): the logger is named after the file path (__file__) rather
# than the dotted module name (__name__), so it is not a child of the
# package's logger hierarchy -- confirm this is intentional.
logger = logging.getLogger(__file__)


def generate_gp(
    gpf_instance: GPFInstance,
    gene_symbol: str,
    collections_gene_sets: list[tuple[str, GeneSet]],
) -> tuple[str, GPStatistic]:
    """Build the gene profile statistic for a single gene symbol.

    Collects the identifiers (``"<collection_id>_<gene set name>"``) of every
    supplied gene set whose ``"syms"`` contain the gene, and looks up the
    gene's value for every score listed under ``genomic_scores`` in the
    instance's gene profile configuration (``None`` when the score has no
    entry for the gene).

    Returns the gene symbol together with a ``GPStatistic`` created with an
    empty variant counts mapping.
    """
    # pylint: disable=protected-access, invalid-name, too-many-locals
    scores_db = gpf_instance.gene_scores_db
    gp_config = gpf_instance._gene_profile_config
    assert gp_config is not None

    member_of = [
        f"{collection_id}_{gene_set['name']}"
        for collection_id, gene_set in collections_gene_sets
        if gene_symbol in gene_set["syms"]
    ]

    scores_by_category: dict[str, Any] = {}
    for category in gp_config.genomic_scores:
        category_scores: dict[str, Any] = {}
        scores_by_category[category["category"]] = category_scores
        for score in category["scores"]:
            score_id = score["score_name"]
            score_desc = scores_db.get_score_desc(score_id)
            gene_score = scores_db.get_gene_score(score_desc.resource_id)
            # Genes without an entry for this score are recorded as None.
            category_scores[score_id] = (
                gene_score.get_gene_value(score_id, gene_symbol)
                if gene_symbol in gene_score.get_genes(score_id)
                else None
            )

    empty_variant_counts: dict[str, Any] = {}
    statistic = GPStatistic(
        gene_symbol, member_of,
        scores_by_category, empty_variant_counts,
    )
    return gene_symbol, statistic
def add_variant_count(
    variant: FamilyVariant,
    variant_counts: dict[str, Any],
    person_set: str,
    statistic_id: str,
    effect_types: set[str] | None,
) -> None:
    """Add one to the counter of every tracked gene hit by the variant.

    For each of the variant's effect gene symbols that is a key of
    ``variant_counts``, ``variant_counts[gene][person_set][statistic_id]``
    is incremented in place.  When ``effect_types`` is not ``None`` a gene
    is counted only if at least one alternative allele has an effect on
    that gene which belongs to ``effect_types``.
    """
    # pylint: disable=invalid-name
    for gene_symbol in variant.effect_gene_symbols:
        if gene_symbol not in variant_counts:
            continue

        if effect_types is not None:
            matched = False
            for allele in variant.alt_alleles:
                # Effects this allele has on the current gene; entries
                # without a symbol or without an effect are ignored.
                effects_on_gene = {
                    gene_effect.effect
                    for gene_effect in allele.effect_genes
                    if gene_effect.symbol is not None
                    and gene_effect.effect is not None
                    and gene_effect.symbol == gene_symbol
                }
                if effects_on_gene \
                        and effects_on_gene.intersection(effect_types):
                    matched = True
                    break
            if not matched:
                continue

        variant_counts[gene_symbol][person_set][statistic_id] += 1
def calculate_table_values(
    instance: GPFInstance,
    variant_counts: dict[str, Any],
    dataset_id: str,
    filters: Box,
) -> dict[str, Any]:
    """Calculate GP variant counts and return a SQLite update mapping.

    Args:
        instance: provides ``get_genotype_data(dataset_id)``.
        variant_counts: ``{gene: {set_name: {statistic_id: count}}}``.
        dataset_id: dataset the counts belong to; used as column prefix.
        filters: the dataset section of the gene profile configuration;
            its ``person_sets`` and ``statistics`` are read.

    Returns:
        ``{gene: {column: value}}`` holding, for every person set and
        statistic, a ``"<dataset>_<person set id>_<statistic id>"`` count
        column and a matching ``"..._rate"`` column.  The rate is the count
        per 1000 children of the person set; when the person set has no
        children both columns are set to 0.
    """
    # pylint: disable=invalid-name
    if not variant_counts:
        # Nothing to compute; skip the genotype data lookups entirely.
        return {}

    # The person set details and statistic ids do not depend on the gene,
    # so resolve them once instead of once per gene.
    genotype_data = instance.get_genotype_data(dataset_id)
    person_set_info: list[tuple[str, Any, int]] = []
    for ps in filters.person_sets:
        psc = genotype_data.get_person_set_collection(
            ps.collection_name,
        )
        assert psc is not None
        person_set = psc.person_sets[ps.set_name]
        person_set_info.append(
            (ps.set_name, person_set.id, person_set.get_children_count()),
        )
    stat_ids = [statistic["id"] for statistic in filters.statistics]

    table_values: dict[str, Any] = {}
    for gs, counts in variant_counts.items():
        row: dict[str, Any] = {}
        for set_name, person_set_id, children_count in person_set_info:
            for stat_id in stat_ids:
                count_col = f"{dataset_id}_{person_set_id}_{stat_id}"
                rate_col = f"{count_col}_rate"
                count = counts[set_name][stat_id]
                if children_count > 0:
                    row[count_col] = count
                    row[rate_col] = (count / children_count) * 1000
                else:
                    # No children in the person set: no meaningful rate.
                    row[count_col] = 0
                    row[rate_col] = 0
        table_values[gs] = row
    return table_values
def _score_in_range(value: Any, lower: Any, upper: Any) -> bool:
    """Check a score value against optional inclusive bounds.

    A missing value (``None``) never passes.  A bound of ``None`` means
    "unbounded"; a bound of ``0`` is a real bound.
    """
    if value is None:
        return False
    if lower is not None and not value >= lower:
        return False
    return upper is None or bool(value <= upper)


def count_variant(
    v: FamilyVariant,
    dataset_id: str,
    variant_counts: dict[str, Any],
    config: Box,
    person_ids: dict[str, Any],
    denovo_flag: bool,
) -> None:
    """Count one family variant towards every statistic it satisfies.

    For each person set and statistic configured under
    ``config.datasets[dataset_id]`` the variant is counted (through
    ``add_variant_count``) when all of the following hold:

    * the statistic's category matches the pass being run: ``"denovo"``
      statistics are only counted when ``denovo_flag`` is true, ``"rare"``
      statistics only when it is false;
    * at least one person carrying an alternative allele is listed in
      ``person_ids[<set name>]``;
    * every configured score filter accepts the variant's score value;
    * for ``"rare"`` statistics, some alternative allele has
      ``af_allele_freq`` that is not ``None`` and is at most 1.0;
    * the configured variant types and roles, when present, overlap the
      variant's.

    Args:
        v: the family variant to count.
        dataset_id: key of the dataset section in ``config.datasets``.
        variant_counts: ``{gene: {set_name: {statistic_id: count}}}``,
            updated in place.
        config: the gene profile configuration.
        person_ids: maps person set names to the person ids to consider.
        denovo_flag: ``True`` when counting de novo variants, ``False``
            when counting rare variants.
    """
    # pylint: disable=invalid-name, too-many-locals, too-many-branches
    filters = config.datasets[dataset_id]

    # Everyone in the family who carries at least one alternative allele.
    # The cast target is quoted: it only informs the type checker.
    members = {
        member
        for aa in v.alt_alleles
        for member in cast("FamilyAllele", aa).variant_in_members
        if member is not None
    }

    for ps in filters.person_sets:
        # Membership depends only on the person set, so test it once
        # instead of once per statistic.
        if members.isdisjoint(person_ids[ps.set_name]):
            continue

        for statistic in filters.statistics:
            if statistic.category == "denovo" and not denovo_flag:
                continue
            if statistic.category == "rare" and denovo_flag:
                continue

            if statistic.get("scores") and not all(
                _score_in_range(
                    v.get_attribute(score["name"])[0],
                    score.get("min"),
                    score.get("max"),
                )
                for score in statistic.scores
            ):
                continue

            if statistic.get("category") == "rare":
                # Same upper bound as the frequency filter main() passes
                # when querying rare variants.
                freqs = (
                    aa.get_attribute("af_allele_freq")
                    for aa in v.alt_alleles
                )
                if not any(
                    freq is not None and freq <= 1.0 for freq in freqs
                ):
                    continue

            if statistic.get("variant_types"):
                variant_types = {
                    allele_type_from_name(t)
                    for t in statistic.variant_types
                }
                if variant_types.isdisjoint(v.variant_types):
                    continue

            if statistic.get("roles"):
                roles = {
                    Role.from_name(r)
                    for r in statistic.roles
                }
                # NOTE(review): only the first alternative allele's roles
                # are consulted -- confirm this is intended for variants
                # with several alternative alleles.
                v_roles = set(
                    cast("FamilyAllele", v.alt_alleles[0]).variant_in_roles,
                )
                if v_roles.isdisjoint(roles):
                    continue

            ets = None
            if statistic.get("effects"):
                ets = set(expand_effect_types(statistic.effects))

            add_variant_count(
                v, variant_counts, ps.set_name, statistic.id, ets,
            )
def collect_variant_counts(
    variant_counts: dict[str, Any],
    variants: Iterable[FamilyVariant],
    dataset_id: str,
    config: Box,
    person_ids: dict[str, Any],
    denovo_flag: bool,
) -> None:
    """Feed every variant of a dataset through ``count_variant``.

    ``variant_counts`` is updated in place.  A progress message is logged
    at every 1000th variant, just before that variant is counted.
    """
    processed = 0
    for variant in variants:
        processed += 1
        if not processed % 1000:
            logger.info(
                "%s: Counted %s variants from %s",
                dataset_id, processed, dataset_id,
            )
        count_variant(
            variant,
            dataset_id,
            variant_counts,
            config,
            person_ids,
            denovo_flag,
        )
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command line parser of the generate GP tool."""
    description = "Generate gene profile statistics tool"
    parser = argparse.ArgumentParser(description=description)
    VerbosityConfiguration.set_arguments(parser)
    default_dbfile = os.path.join(os.getenv("DAE_DB_DIR", "./"), "gpdb")
    parser.add_argument("--dbfile", default=default_dbfile)
    parser.add_argument(
        "--gene-sets-genes",
        action="store_true",
        help="Generate GPs only for genes contained in the config's gene sets",
    )
    parser.add_argument(
        "--genes",
        help="Comma separated list of genes to generate statistics for",
    )
    # NOTE(review): --drop is accepted but never read by main(); the
    # database is always opened with clear=True.
    parser.add_argument("--drop", action="store_true")
    return parser


def _as_name_list(value: Any) -> list[str]:
    """Return a config value holding one name or several names as a list."""
    if isinstance(value, str):
        return [value]
    return list(value)


def _rare_query_kwargs(statistic: Box) -> dict[str, Any]:
    """Translate a rare statistic's filters into ``query_variants`` kwargs."""
    kwargs: dict[str, Any] = {"roles": "prb or sib or child"}

    if statistic.effects is not None:
        kwargs["effect_types"] = expand_effect_types(statistic.effects)

    if statistic.variant_types:
        # Convert every configured name on its own, the same way
        # count_variant() iterates statistic.variant_types.
        # NOTE(review): `.repr()` is kept exactly as the original called
        # it -- confirm the allele type really exposes such a method.
        kwargs["variant_type"] = " or ".join(
            # pylint: disable=no-member
            allele_type_from_name(name).repr()  # type: ignore
            for name in _as_name_list(statistic.variant_types)
        )

    if statistic.scores:
        kwargs["real_attr_filter"] = [
            (score.name, (score.min, score.max))
            for score in statistic.scores
        ]

    if statistic.roles:
        # Overrides the default roles; one entry per configured role name,
        # matching how count_variant() iterates statistic.roles.
        kwargs["roles"] = " or ".join(
            repr(Role.from_name(name))
            for name in _as_name_list(statistic.roles)
        )

    return kwargs


def main(
    gpf_instance: GPFInstance | None = None,
    argv: list[str] | None = None,
) -> None:
    """Entry point for the generate GP script.

    Steps performed:

    1. Parse the command line and open the gene profile database.
    2. Resolve the gene sets named in the configuration; a missing gene
       set raises ``ValueError``.
    3. Choose the gene symbols: ``--genes`` if given, else the genes of
       the configured gene sets with ``--gene-sets-genes``, else every
       gene name of the instance's gene models.
    4. For every configured dataset collect, per person set, the ids of
       the persons that are also in ``persons_with_parents()``.
    5. Generate a ``GPStatistic`` per gene and insert them into the DB.
    6. Per dataset, count de novo and rare variants per gene, compute the
       rates and update the DB rows.

    Args:
        gpf_instance: instance to use; built with ``GPFInstance.build()``
            when ``None``.
        argv: command line arguments; ``None`` lets argparse read
            ``sys.argv``.
    """
    # flake8: noqa: C901
    # pylint: disable=too-many-locals,too-many-branches,too-many-statements
    args = _build_arg_parser().parse_args(argv)
    VerbosityConfiguration.set(args)
    start = time.time()
    if gpf_instance is None:
        gpf_instance = GPFInstance.build()

    # pylint: disable=protected-access, invalid-name
    config = gpf_instance._gene_profile_config

    assert config is not None, "No GP configuration found."

    # The database is opened before the gene sets are validated, as in
    # the original flow.
    gpdb = GeneProfileDB(
        config.to_dict(),
        args.dbfile,
        clear=True,
    )

    collections_gene_sets = []
    for gs_category in config.gene_sets:
        for gs in gs_category.sets:
            gs_id = gs["set_id"]
            collection_id = gs["collection_id"]

            gene_set = gpf_instance.gene_sets_db.get_gene_set(
                collection_id, gs_id)

            if gene_set is None:
                logger.error("missing gene set: %s, %s", collection_id, gs_id)
                raise ValueError(
                    f"missing gene set: {collection_id}: {gs_id}")

            collections_gene_sets.append((collection_id, gene_set))

    logger.info("collected gene sets: %d", len(collections_gene_sets))

    gene_symbols: Set[str] = set()
    if args.genes:
        # Drop empty names produced by stray commas ("A,,B" or "A,").
        stripped = (name.strip() for name in args.genes.split(","))
        gene_symbols = {name for name in stripped if name}
    elif args.gene_sets_genes:
        for _, gene_set in collections_gene_sets:
            gene_symbols.update(gene_set["syms"])
    else:
        gene_symbols = set(gpf_instance.gene_models.gene_names())
    gs_count = len(gene_symbols)
    logger.info("Collected %d gene symbols", gs_count)

    # Variant queries are restricted to the selected genes only when the
    # user asked for a subset; None is passed otherwise.
    genes = (
        list(gene_symbols)
        if args.gene_sets_genes or args.genes
        else None
    )

    has_denovo = False
    has_rare = False
    person_ids: Dict[str, Any] = {}
    for dataset_id, filters in config.datasets.items():
        genotype_data = gpf_instance.get_genotype_data(dataset_id)
        assert genotype_data is not None, dataset_id
        genotype_data_children = {
            p.person_id
            for p in genotype_data.families.persons_with_parents()
        }
        person_ids[dataset_id] = {}
        for ps in filters.person_sets:
            psc_query = PSCQuery(
                ps.collection_name,
                {ps.set_name},
            )
            psc = genotype_data.get_person_set_collection(psc_query.psc_id)
            assert psc is not None
            person_set = psc.query_person_ids(psc_query)
            assert person_set is not None, psc_query
            person_ids[dataset_id][ps.set_name] = \
                set(person_set) & genotype_data_children
        for stat in filters.statistics:
            if stat.category == "denovo":
                has_denovo = True
            elif stat.category == "rare":
                has_rare = True

    elapsed = time.time() - start
    logger.info("data collected: %.2f secs", elapsed)

    start = time.time()
    gps = {}
    for idx, sym in enumerate(gene_symbols, 1):
        gs, gp = generate_gp(
            gpf_instance, sym, collections_gene_sets,
        )
        gps[gs] = gp
        if idx % 1000 == 0:
            elapsed = time.time() - start
            logger.info(
                "Generated %d/%d GP statistics %.2f secs",
                idx, gs_count, elapsed)

    logger.info("Inserting statistics into DB")
    gpdb.insert_gps(gps.values())
    generate_end = time.time()
    elapsed = generate_end - start
    logger.info("Took %.2f secs", elapsed)

    for dataset_id, filters in config.datasets.items():
        # Zero-initialised counters: gene -> person set -> statistic id.
        variant_counts: dict[str, Any] = {
            gs: {
                ps.set_name: {
                    statistic.id: 0 for statistic in filters.statistics
                }
                for ps in filters.person_sets
            }
            for gs in gene_symbols
        }

        if has_denovo:
            logger.info("Collecting denovo variants for %s", dataset_id)
            genotype_data = gpf_instance.get_genotype_data(dataset_id)
            assert genotype_data is not None, dataset_id

            denovo_variants = genotype_data.query_variants(
                genes=genes, inheritance="denovo",
                unique_family_variants=True,
            )

            logger.info("Done collecting denovo variants for %s", dataset_id)
            logger.info("Collecting denovo variant counts for %s", dataset_id)
            collect_variant_counts(
                variant_counts, denovo_variants,
                dataset_id, config, person_ids[dataset_id], True,
            )
            logger.info(
                "Done collecting denovo variant counts for %s", dataset_id,
            )

        if has_rare:
            logger.info("Counting rare variants for %s", dataset_id)
            genotype_data = gpf_instance.get_genotype_data(dataset_id)
            assert genotype_data is not None, dataset_id

            # NOTE(review): each query result is passed to count_variant(),
            # which checks the variant against every rare statistic of the
            # dataset.  A variant returned by the queries of several
            # statistics may therefore be counted more than once -- confirm
            # that the configured statistics do not overlap.
            for statistic in filters.statistics:
                if statistic.category == "denovo":
                    continue

                rare_variants = genotype_data.query_variants(
                    genes=genes,
                    inheritance=[
                        "not denovo and "
                        "not possible_denovo and not possible_omission",
                        "mendelian or missing",
                    ],
                    frequency_filter=[("af_allele_freq", (None, 1.0))],
                    **_rare_query_kwargs(statistic))

                logger.info(
                    "Counting rare variants for statistic %s", statistic.id,
                )
                collect_variant_counts(
                    variant_counts, rare_variants,
                    dataset_id, config, person_ids[dataset_id], False,
                )
            logger.info("Done counting rare variants")

        logger.info("Calculating rates for %s", dataset_id)
        table_values = calculate_table_values(
            gpf_instance, variant_counts, dataset_id, filters,
        )
        logger.info("Done calculating rates for %s", dataset_id)

        logger.info("Updating GPs for %s", dataset_id)
        gpdb.update_gps_with_values(table_values)
        logger.info("Done updating GPs for %s", dataset_id)

    elapsed = time.time() - generate_end
    logger.info("Took %.2f secs", elapsed)