Source code for dae.query_variants.sql.schema2.sql_query_builder

from __future__ import annotations

import logging
from collections.abc import Iterable, Sequence
from dataclasses import dataclass
from typing import Any, cast

import duckdb
from sqlglot import condition, exp, or_, parse_one
from sqlglot.expressions import (
    Condition,
    Select,
    replace_placeholders,
)
from sqlglot.schema import Schema, ensure_schema

from dae.genomic_resources.gene_models import (
    GeneModels,
    create_regions_from_genes,
)
from dae.genomic_resources.reference_genome import ReferenceGenome
from dae.parquet.partition_descriptor import PartitionDescriptor
from dae.pedigrees.families_data import FamiliesData
from dae.query_variants.attributes_query import (
    QueryTreeToSQLBitwiseTransformer,
    role_query,
    sex_query,
)
from dae.query_variants.attributes_query import (
    variant_type_query as VARIANT_TYPE_PARSER,
)
from dae.query_variants.attributes_query_inheritance import (
    InheritanceTransformer,
    inheritance_parser,
)
from dae.utils.regions import Region
from dae.variants.attributes import Inheritance, Role

logger = logging.getLogger(__name__)


# A type describing a schema as expected by the query builders
TableSchema = dict[str, str]
RealAttrFilterType = list[tuple[str, tuple[float | None, float | None]]]


# family_variant_table & summary_allele_table are mandatory
# - no reliance on a variants table as in impala
[docs] @dataclass(frozen=True) class Db2Layout: """Genotype data layout in the database.""" db: str | None study: str pedigree: str summary: str | None family: str | None meta: str
@dataclass(frozen=True)
class QueryHeuristics:
    """Heuristics for a query.

    Each field lists the values of one partition bin column the query may be
    restricted to; an empty list means that bin column is not restricted.
    """

    region_bins: list[str]
    coding_bins: list[str]
    frequency_bins: list[str]
    family_bins: list[str]

    def is_empty(self) -> bool:
        """Return True when none of the four bin lists has any entry."""
        every_bin_list = (
            self.region_bins,
            self.coding_bins,
            self.frequency_bins,
            self.family_bins,
        )
        return not any(every_bin_list)
class QueryBuilderBase:
    """Base class for building SQL queries."""

    # Passed unchanged to ``create_regions_from_genes`` by
    # ``build_gene_regions``.
    GENE_REGIONS_HEURISTIC_CUTOFF = 20
    GENE_REGIONS_HEURISTIC_EXTEND = 20000
    # Above this many matching region bins the region-bin heuristic is
    # abandoned (see ``calc_region_bins``).
    REGION_BINS_HEURISTIC_CUTOFF = 20

    def __init__(
        self,
        schema: Schema,
        families: FamiliesData,
        partition_descriptor: PartitionDescriptor | None,
        gene_models: GeneModels,
        reference_genome: ReferenceGenome,
    ):
        """Store the collaborators needed to compute queries and heuristics.

        Raises:
            ValueError: if ``gene_models`` or ``reference_genome`` is None.
        """
        # Validate everything first so a failed construction never leaves a
        # half-initialised instance behind.
        if gene_models is None:
            raise ValueError("gene_models are required")
        if reference_genome is None:
            raise ValueError("reference genome is required")

        self.gene_models = gene_models
        self.reference_genome = reference_genome
        self.families = families
        self.schema = schema
        self.partition_descriptor = partition_descriptor
def build_gene_regions(
    self,
    genes: list[str],
    regions: list[Region] | None,
) -> list[Region] | None:
    """Build a list of regions based on genes.

    Delegates to ``create_regions_from_genes``, passing the builder's gene
    models together with the class-level cutoff and extension constants.
    """
    gene_models = self.gene_models
    assert gene_models is not None
    cutoff = self.GENE_REGIONS_HEURISTIC_CUTOFF
    extend = self.GENE_REGIONS_HEURISTIC_EXTEND
    return create_regions_from_genes(
        gene_models, genes, regions, cutoff, extend)
[docs] def calc_coding_bins( self, effect_types: Sequence[str] | None, ) -> list[str]: """Calculate applicable coding bins for a query.""" if self.partition_descriptor is None: return [] if effect_types is None: return [] if "coding_bin" not in self.schema.column_names("summary_table"): return [] assert "coding_bin" in self.schema.column_names("summary_table") assert "coding_bin" in self.schema.column_names("family_table") assert effect_types is not None query_effect_types = set(effect_types) intersection = query_effect_types & set( self.partition_descriptor.coding_effect_types, ) coding_bins = [] if intersection == query_effect_types: coding_bins.append("1") return coding_bins
def calc_region_bins(
    self,
    regions: list[Region] | None,
) -> list[str]:
    """Calculate applicable region bins for a query.

    Returns the union of the region bins the partition descriptor reports
    for each region. An empty list (no region-bin restriction) is returned
    when there is no partition descriptor, no regions, no region-bin
    partitioning, no matching bin, or more than
    ``REGION_BINS_HEURISTIC_CUTOFF`` matching bins. Bins are wrapped in
    single quotes unless the descriptor uses integer region bins.
    """
    descriptor = self.partition_descriptor
    if descriptor is None:
        return []
    if not regions or not descriptor.has_region_bins():
        return []

    # The chromosome lengths do not depend on the region; fetch them once
    # instead of once per region.
    chrom_lengths = self.reference_genome.get_all_chrom_lengths()
    region_bins: set[str] = set()
    for region in regions:
        region_bins.update(
            descriptor.region_to_region_bins(region, chrom_lengths))

    # No matching bin means there is nothing to prune with; fall back to an
    # unrestricted query rather than failing.
    if not region_bins:
        return []
    if len(region_bins) > self.REGION_BINS_HEURISTIC_CUTOFF:
        return []
    if not descriptor.integer_region_bins:
        return [f"'{rb}'" for rb in region_bins]
    return list(region_bins)
@staticmethod
def build_roles_query(roles_query: str, attr: str) -> str:
    """Construct a roles query.

    The query string is parsed with ``role_query`` and rendered as SQL
    against ``attr`` by ``QueryTreeToSQLBitwiseTransformer`` (created with
    ``use_bit_and_function=False``).
    """
    tree = role_query.transform_query_string_to_tree(roles_query)
    to_sql = QueryTreeToSQLBitwiseTransformer(
        attr, use_bit_and_function=False)
    return cast(str, to_sql.transform(tree))
@staticmethod
def check_roles_query_value(roles_query: str, value: int) -> bool:
    """Check if value satisfies a given roles query.

    The roles expression is built with the literal ``value`` in place of a
    column name and evaluated on an in-memory DuckDB connection.
    """
    expression = QueryBuilderBase.build_roles_query(roles_query, str(value))
    with duckdb.connect(":memory:") as connection:
        rows = connection.execute(f"SELECT {expression}").fetchall()
        # A bare SELECT of one expression yields exactly one cell.
        assert len(rows) == 1
        (row,) = rows
        assert len(row) == 1
        return cast(bool, row[0])
@staticmethod
def build_inheritance_query(
    inheritance_query: Sequence[str],
    attr: str,
) -> str:
    """Construct an inheritance query.

    Every query string is parsed with ``inheritance_parser`` and rendered
    against ``attr``; the rendered clauses are joined with ``AND``. An empty
    sequence produces an empty string.
    """
    to_sql = InheritanceTransformer(attr, use_bit_and_function=False)
    clauses = [
        str(to_sql.transform(inheritance_parser.parse(single_query)))
        for single_query in inheritance_query
    ]
    return " AND ".join(clauses)
@staticmethod
def check_inheritance_query_value(
    inheritance_query: Sequence[str],
    value: int,
) -> bool:
    """Check if value satisfies a given inheritance query.

    The inheritance expression is built with the literal ``value`` in place
    of a column name and evaluated on an in-memory DuckDB connection.
    """
    expression = QueryBuilderBase.build_inheritance_query(
        inheritance_query, str(value))
    with duckdb.connect(":memory:") as connection:
        rows = connection.execute(f"SELECT {expression}").fetchall()
        # A bare SELECT of one expression yields exactly one cell.
        assert len(rows) == 1
        (row,) = rows
        assert len(row) == 1
        return cast(bool, row[0])
@staticmethod
def check_roles_denovo_only(roles_query: str) -> bool:
    """Check if roles query is de novo only.

    The query qualifies when it accepts the ``prb | sib`` role mask but
    rejects the same mask with ``dad`` and ``mom`` added.
    """
    children_mask = Role.prb.value | Role.sib.value
    children_and_parents_mask = (
        children_mask | Role.dad.value | Role.mom.value
    )
    check = QueryBuilderBase.check_roles_query_value
    return check(roles_query, children_mask) and \
        not check(roles_query, children_and_parents_mask)
@staticmethod
def check_inheritance_denovo_only(
    inheritance_query: Sequence[str],
) -> bool:
    """Check if inheritance query is de novo only.

    The query qualifies when it rejects every one of the ``mendelian``,
    ``possible_denovo``, ``possible_omission`` and ``missing`` values.
    """
    excluded_kinds = (
        Inheritance.mendelian,
        Inheritance.possible_denovo,
        Inheritance.possible_omission,
        Inheritance.missing,
    )
    # any() stops at the first accepted value, in the order listed above.
    return not any(
        QueryBuilderBase.check_inheritance_query_value(
            inheritance_query, kind.value)
        for kind in excluded_kinds
    )
@staticmethod
def build_sexes_query(sexes_query: str, attr: str) -> str:
    """Build sexes query.

    The query string is parsed with ``sex_query`` and rendered as SQL
    against ``attr`` by ``QueryTreeToSQLBitwiseTransformer`` (created with
    ``use_bit_and_function=False``).
    """
    tree = sex_query.transform_query_string_to_tree(sexes_query)
    to_sql = QueryTreeToSQLBitwiseTransformer(
        attr, use_bit_and_function=False)
    return cast(str, to_sql.transform(tree))
@staticmethod
def check_sexes_query_value(sexes_query: str, value: int) -> bool:
    """Check if value matches a given sexes query.

    The sexes expression is built with the literal ``value`` in place of a
    column name and evaluated on an in-memory DuckDB connection.
    """
    expression = QueryBuilderBase.build_sexes_query(sexes_query, str(value))
    with duckdb.connect(":memory:") as connection:
        rows = connection.execute(f"SELECT {expression}").fetchall()
        # A bare SELECT of one expression yields exactly one cell.
        assert len(rows) == 1
        (row,) = rows
        assert len(row) == 1
        return cast(bool, row[0])
@staticmethod
def build_variant_types_query(
    variant_types_query: str,
    attr: str,
) -> str:
    """Build a variant types query.

    The query string is parsed with ``VARIANT_TYPE_PARSER`` and rendered as
    SQL against ``attr`` by ``QueryTreeToSQLBitwiseTransformer`` (created
    with ``use_bit_and_function=False``).
    """
    tree = VARIANT_TYPE_PARSER.transform_query_string_to_tree(
        variant_types_query)
    to_sql = QueryTreeToSQLBitwiseTransformer(
        attr, use_bit_and_function=False)
    return cast(str, to_sql.transform(tree))
@staticmethod
def check_variant_types_value(
    variant_types_query: str,
    value: int,
) -> bool:
    """Check if value satisfies a given variant types query.

    The variant types expression is built with the literal ``value`` in
    place of a column name and evaluated on an in-memory DuckDB connection.
    """
    expression = QueryBuilderBase.build_variant_types_query(
        variant_types_query, str(value))
    with duckdb.connect(":memory:") as connection:
        rows = connection.execute(f"SELECT {expression}").fetchall()
        # A bare SELECT of one expression yields exactly one cell.
        assert len(rows) == 1
        (row,) = rows
        assert len(row) == 1
        return cast(bool, row[0])
def calc_frequency_bins(
    self,
    *,
    inheritance: Sequence[str] | None = None,
    roles: str | None = None,
    ultra_rare: bool | None = None,
    frequency_filter: RealAttrFilterType | None = None,
) -> list[str]:
    """Calculate applicable frequency bins for a query.

    Bin ``0`` holds de novo variants and is always searched; ``ultra_rare``
    adds bin ``1``; an upper bound on ``af_allele_freq`` that does not
    exceed the partition's rare boundary adds bin ``2``. All bins from
    ``0`` up to the highest selected one are returned as strings.

    An empty list means "do not restrict by frequency bin". It is returned
    when the data is not partitioned by frequency, when the
    ``af_allele_freq`` filter has no upper bound or one above the rare
    boundary, and when nothing in the query constrains the frequency at
    all -- in particular when ``frequency_filter`` is empty or only filters
    other attributes. The last case matters for correctness: restricting
    such a query to bin ``0`` would drop every non de novo variant.
    """
    descriptor = self.partition_descriptor
    if descriptor is None:
        return []
    if "frequency_bin" not in self.schema.column_names("summary_table"):
        return []
    # The family table is expected to be partitioned the same way.
    assert "frequency_bin" in self.schema.column_names("family_table")

    if roles and self.check_roles_denovo_only(roles):
        return ["0"]
    if inheritance and self.check_inheritance_denovo_only(inheritance):
        return ["0"]

    highest_bin = 0  # always search de novo variants
    constrained = False
    if ultra_rare:
        highest_bin = 1
        constrained = True

    for attr, (_, right) in frequency_filter or []:
        if attr != "af_allele_freq":
            # Filters on other attributes say nothing about the bin.
            continue
        if right is None or not right <= descriptor.rare_boundary:
            # Unbounded, or reaching into common variants: no pruning.
            return []
        highest_bin = max(highest_bin, 2)
        constrained = True

    if not constrained:
        return []
    return [str(fb) for fb in range(highest_bin + 1)]
[docs] def calc_family_bins( self, family_ids: Iterable[str] | None, person_ids: Iterable[str] | None, ) -> list[str]: """Calculate family bins for a query.""" if self.partition_descriptor is None: return [] if not self.partition_descriptor.has_family_bins(): return [] if "family_bin" not in self.schema.column_names("family_table"): return [] if family_ids is None and person_ids is None: return [] family_bins: set[str] = set() if family_ids is not None: assert family_ids is not None family_ids = set(family_ids) family_bins.update( str(self.partition_descriptor.make_family_bin(family_id)) for family_id in family_ids) if person_ids is not None: assert person_ids is not None person_ids = { pid for pid in person_ids if pid in self.families.persons_by_person_id } family_ids = { self.families.persons_by_person_id[person_id][0].family_id for person_id in person_ids } family_bins.update( str(self.partition_descriptor.make_family_bin(family_id)) for family_id in family_ids) if len(family_bins) >= self.partition_descriptor.family_bin_size // 2: return [] return list(family_bins)
def all_region_bins(self) -> list[str]:
    """Return all region bins.

    The bins come from the partition descriptor, computed from the
    reference genome's chromosome lengths. They are wrapped in single
    quotes unless the descriptor uses integer region bins. An empty list is
    returned when the data has no region-bin partitioning.
    """
    descriptor = self.partition_descriptor
    if descriptor is None or not descriptor.has_region_bins():
        return []

    chrom_lens = dict(self.reference_genome.get_all_chrom_lengths())
    bins = descriptor.make_all_region_bins(chrom_lens)
    if descriptor.integer_region_bins:
        return bins
    return [f"'{rb}'" for rb in bins]
def calc_heuristics(
    self,
    *,
    regions: list[Region] | None = None,
    genes: list[str] | None = None,
    effect_types: list[str] | None = None,
    inheritance: Sequence[str] | None = None,
    roles: str | None = None,
    ultra_rare: bool | None = None,
    frequency_filter: RealAttrFilterType | None = None,
    family_ids: Iterable[str] | None = None,
    person_ids: Iterable[str] | None = None,
) -> QueryHeuristics:
    """Calculate heuristic bins for a query.

    When genes are given, the regions are first rebuilt through
    ``build_gene_regions`` before the region bins are computed. Each of the
    four bin lists falls back to an empty list when its calculator returns
    nothing.
    """
    if genes is not None:
        regions = self.build_gene_regions(genes, regions)

    region_bins = self.calc_region_bins(regions) or []
    coding_bins = self.calc_coding_bins(effect_types) or []
    frequency_bins = self.calc_frequency_bins(
        inheritance=inheritance,
        roles=roles,
        ultra_rare=ultra_rare,
        frequency_filter=frequency_filter,
    ) or []
    family_bins = self.calc_family_bins(family_ids, person_ids) or []

    return QueryHeuristics(
        region_bins=region_bins,
        coding_bins=coding_bins,
        frequency_bins=frequency_bins,
        family_bins=family_bins,
    )
def calc_batched_heuristics(
    self,
    *,
    regions: list[Region] | None = None,
    genes: list[str] | None = None,
    effect_types: list[str] | None = None,
    inheritance: Sequence[str] | None = None,
    roles: str | None = None,
    ultra_rare: bool | None = None,
    frequency_filter: RealAttrFilterType | None = None,
    family_ids: Iterable[str] | None = None,
    person_ids: Iterable[str] | None = None,
) -> list[QueryHeuristics]:
    """Calculate heuristics batches for a query.

    Returns a single batch when the heuristics are selective enough;
    otherwise, if the data is partitioned by region bins, returns one batch
    per region bin, each carrying the same coding, frequency and family
    bins.
    """
    heuristics = self.calc_heuristics(
        regions=regions,
        genes=genes,
        effect_types=effect_types,
        inheritance=inheritance,
        roles=roles,
        ultra_rare=ultra_rare,
        frequency_filter=frequency_filter,
        family_ids=family_ids,
        person_ids=person_ids,
    )
    single_batch = [heuristics]
    frequency_bins = heuristics.frequency_bins
    coding_bins = heuristics.coding_bins

    # single batch if we have region bins in heuristics
    if heuristics.region_bins:
        return single_batch

    # single batch if we don't search for rare and common variants
    if frequency_bins and {"1", "2", "3"}.isdisjoint(frequency_bins):
        return single_batch

    # single batch if we search for rare coding variants, i.e. neither the
    # noncoding bin "0" nor the common bin "3" is involved
    if coding_bins and frequency_bins \
            and "0" not in coding_bins and "3" not in frequency_bins:
        return single_batch

    descriptor = self.partition_descriptor
    if not descriptor or not descriptor.has_region_bins():
        return single_batch

    return [
        QueryHeuristics(
            region_bins=[f"{rb}"],
            coding_bins=coding_bins,
            frequency_bins=frequency_bins,
            family_bins=heuristics.family_bins,
        )
        for rb in self.all_region_bins()
    ]
@staticmethod
def build_schema(
    summary_schema: dict[str, str],
    family_schema: dict[str, str],
    pedigree_schema: dict[str, str],
) -> Schema:
    """Build a sqlglot schema for the three logical tables.

    The schemas are registered under the names ``summary_table``,
    ``family_table`` and ``pedigree_table``.
    """
    tables = {
        "summary_table": summary_schema,
        "family_table": family_schema,
        "pedigree_table": pedigree_schema,
    }
    return ensure_schema(tables)
class SqlQueryBuilder(QueryBuilderBase):
    """Build SQL queries using sqlglot."""

    def __init__(
        self,
        db_layout: Db2Layout,
        *,
        schema: Schema,
        partition_descriptor: PartitionDescriptor | None,
        families: FamiliesData,
        gene_models: GeneModels,
        reference_genome: ReferenceGenome,
    ):
        """Initialise the base builder and remember the database layout.

        ``db_layout`` supplies the real table names that ``replace_tables``
        substitutes for the logical ones.
        """
        base_arguments = {
            "schema": schema,
            "families": families,
            "partition_descriptor": partition_descriptor,
            "gene_models": gene_models,
            "reference_genome": reference_genome,
        }
        super().__init__(**base_arguments)
        self.db_layout = db_layout
@staticmethod
def build(
    db_layout: Db2Layout,
    *,
    pedigree_schema: dict[str, str],
    summary_schema: dict[str, str],
    family_schema: dict[str, str],
    partition_descriptor: PartitionDescriptor | None,
    families: FamiliesData,
    gene_models: GeneModels,
    reference_genome: ReferenceGenome,
) -> SqlQueryBuilder:
    """Return a new instance of the builder.

    The three table schemas are combined into one sqlglot schema before the
    builder is constructed.
    """
    # Same mapping as the literal dict passed to ensure_schema; built
    # through the inherited helper to keep a single definition.
    schema = SqlQueryBuilder.build_schema(
        summary_schema=summary_schema,
        family_schema=family_schema,
        pedigree_schema=pedigree_schema,
    )
    return SqlQueryBuilder(
        db_layout=db_layout,
        schema=schema,
        partition_descriptor=partition_descriptor,
        families=families,
        gene_models=gene_models,
        reference_genome=reference_genome,
    )
@staticmethod
def genes(genes: list[str]) -> Condition:
    """Create genes condition.

    An empty list produces an ``IS NULL`` test, one gene an equality test,
    and several genes an ``in (...)`` test on ``eg.effect_gene_symbols``.
    Single quotes inside a gene symbol are doubled so the value stays a
    valid SQL string literal, the same way ``effect_types`` treats its
    values.
    """
    if len(genes) == 0:
        return condition("eg.effect_gene_symbols IS NULL")

    literals = ["'" + gene.replace("'", "''") + "'" for gene in genes]
    if len(literals) == 1:
        return condition(f"eg.effect_gene_symbols = {literals[0]}")
    gene_set = ",".join(literals)
    return condition(f"eg.effect_gene_symbols in ({gene_set})")
@staticmethod
def effect_types(effect_types: list[str]) -> Condition:
    """Create effect types condition.

    An empty list produces an ``IS NULL`` test; otherwise an ``in (...)``
    test on ``eg.effect_types`` with single quotes in the values doubled.
    """
    escaped = [value.replace("'", "''") for value in effect_types]
    if not escaped:
        return condition("eg.effect_types IS NULL")
    effect_set = ",".join("'" + value + "'" for value in escaped)
    return condition(f"eg.effect_types in ({effect_set})")
@staticmethod
def summary_base() -> Select:
    """Create summary base query.

    Selects every column of the logical ``summary_table`` under the alias
    ``sa``.
    """
    all_columns = exp.select("*")
    return all_columns.from_("summary_table as sa")
@staticmethod
def family_base() -> Select:
    """Create family base query.

    Selects every column of the logical ``family_table`` under the alias
    ``fa``.
    """
    all_columns = exp.select("*")
    return all_columns.from_("family_table as fa")
def summary_variants(
    self,
    summary: Select,
) -> Select:
    """Construct summary variants query.

    The summary query is attached as CTE(s) and the index columns plus
    ``summary_variant_data`` are selected from the ``summary`` CTE.
    """
    with_summary = self._append_cte(
        target=Select(),
        source=summary,
        alias="summary",
    )
    projected = with_summary.select(
        "sa.bucket_index",
        "sa.summary_index",
        "sa.allele_index",
        "sa.summary_variant_data",
    )
    return projected.from_("summary as sa")
@staticmethod
def _append_cte(
    target: Select,
    source: Select,
    alias: str,
) -> Select:
    """Attach ``source`` to ``target`` as common table expression(s).

    A source without CTEs is attached whole under ``alias``. A source that
    already has CTEs contributes only those CTEs, under their own aliases;
    its outer SELECT is not carried over.
    """
    if not source.ctes:
        return target.with_(alias, as_=source)

    for cte in source.ctes:
        target = target.with_(cte.alias, as_=cte.this)
    return target
def family_variants(
    self,
    summary: Select,
    family: Select,
) -> Select:
    """Construct family variants query.

    Both sub-queries are attached as CTEs and joined. The join uses
    ``sj_index`` when the family table has that column, and the
    bucket/summary/allele index triple otherwise.
    """
    if "sj_index" in self.schema.column_names("family_table"):
        assert "sj_index" in self.schema.column_names("summary_table")
        join_condition = "sa.sj_index = fa.sj_index"
    else:
        join_condition = (
            "sa.bucket_index = fa.bucket_index "
            "and sa.summary_index = fa.summary_index "
            "and sa.allele_index = fa.allele_index"
        )

    with_ctes = self._append_cte(Select(), summary, "summary")
    with_ctes = self._append_cte(with_ctes, family, "family")
    projected = with_ctes.select(
        "fa.bucket_index",
        "fa.summary_index",
        "fa.family_index",
        "sa.allele_index",
        "sa.summary_variant_data",
        "fa.family_variant_data",
    )
    return projected.from_("summary as sa").join(
        "family as fa",
        on=join_condition,
    )
@staticmethod
def _region_to_condition(reg: Region) -> Condition:
    """Build the overlap condition for a single region.

    The condition uses the placeholders ``:chromosome``, ``:position`` and
    ``:end_position``, which the caller replaces with real columns. A
    variant spans ``position`` to ``COALESCE(end_position, position)`` and
    matches when that span overlaps the region; both region ends are
    inclusive, and a missing ``start`` or ``stop`` leaves that side open.
    """
    chrom_clause = f":chromosome = '{reg.chrom}'"
    if reg.start is None and reg.stop is None:
        return condition(chrom_clause)

    if reg.start is None:
        return condition(
            f"{chrom_clause}"
            f" AND NOT ( "
            f":position > {reg.stop} )",
        )

    if reg.stop is None:
        # ``>=`` keeps the start inclusive, matching the bounded case
        # below, which only rejects variants ending before the start.
        return condition(
            f"{chrom_clause}"
            f" AND ( "
            f"COALESCE(:end_position, :position) >= {reg.start} )",
        )

    return condition(
        f"{chrom_clause}"
        f" AND NOT ( "
        f"COALESCE(:end_position, :position) < {reg.start} OR "
        f":position > {reg.stop} )",
    )
@staticmethod
def regions(regions: list[Region]) -> Condition:
    """Create regions condition.

    The per-region conditions are OR-ed together from left to right. The
    list must not be empty.
    """
    assert len(regions) > 0
    first, *remaining = regions
    combined = SqlQueryBuilder._region_to_condition(first)
    for region in remaining:
        combined = or_(
            combined,
            SqlQueryBuilder._region_to_condition(region),
        )
    return combined
@staticmethod
def _real_attr_filter(
    attr: str,
    value_range: tuple[float | None, float | None],
    *,
    is_frequency: bool = False,
) -> Condition:
    """Create real attribute condition.

    ``value_range`` is an inclusive ``(left, right)`` pair; ``None`` leaves
    that side open. For frequency attributes a NULL value also passes when
    there is no lower bound, and a fully open range matches everything. For
    other attributes a fully open range requires a non-NULL value.
    """
    left, right = value_range
    column = f"sa.{attr}"

    if left is not None and right is not None:
        return condition(f"{column} >= {left} AND {column} <= {right}")
    if left is not None:
        return condition(f"{column} >= {left}")
    if right is not None:
        if is_frequency:
            return condition(f"{column} <= {right} OR {column} IS NULL")
        return condition(f"{column} <= {right}")

    # Both bounds are open.
    if is_frequency:
        return condition("1 = 1")
    return condition(f"{column} IS NOT NULL")
@staticmethod
def frequency(
    real_attrs: RealAttrFilterType,
) -> Condition:
    """Build frequencies filter where condition.

    Every ``(attribute, range)`` pair becomes a frequency-style range
    condition; the conditions are AND-ed together from left to right. The
    list must not be empty.
    """
    assert len(real_attrs) > 0
    clauses = [
        SqlQueryBuilder._real_attr_filter(
            name, bounds, is_frequency=True)
        for name, bounds in real_attrs
    ]
    combined = clauses[0]
    for clause in clauses[1:]:
        combined = combined.and_(clause)
    return combined
@staticmethod
def _real_attr(
    attr: str,
    value_range: tuple[float | None, float | None],
) -> Condition:
    """Create a range condition for a non-frequency real attribute."""
    build = SqlQueryBuilder._real_attr_filter
    return build(attr, value_range, is_frequency=False)
@staticmethod
def real_attr(
    real_attrs: RealAttrFilterType,
) -> Condition:
    """Build real attributes filter where condition.

    Every ``(attribute, range)`` pair becomes a non-frequency range
    condition; the conditions are AND-ed together from left to right. The
    list must not be empty.
    """
    assert len(real_attrs) > 0
    clauses = [
        SqlQueryBuilder._real_attr_filter(
            name, bounds, is_frequency=False)
        for name, bounds in real_attrs
    ]
    combined = clauses[0]
    for clause in clauses[1:]:
        combined = combined.and_(clause)
    return combined
@staticmethod
def ultra_rare() -> Condition:
    """Create the ultra rare condition.

    Implemented as a frequency-style filter on ``af_allele_count`` with an
    upper bound of 1 and no lower bound.
    """
    at_most_one = (None, 1)
    return SqlQueryBuilder._real_attr_filter(
        "af_allele_count", at_most_one, is_frequency=True)
def summary_query(
    self,
    *,
    regions: list[Region] | None = None,
    genes: list[str] | None = None,
    effect_types: list[str] | None = None,
    variant_type: str | None = None,
    real_attr_filter: RealAttrFilterType | None = None,
    ultra_rare: bool | None = None,
    frequency_filter: RealAttrFilterType | None = None,
    return_reference: bool | None = None,
    return_unknown: bool | None = None,
) -> Select:
    """Build a summary variant query.

    Without gene or effect type filters the result is a plain SELECT over
    the summary table. With either of them, the result is a chain of three
    CTEs -- ``summary_base`` (the filtered table), ``summary_effects`` (one
    row per unnested ``effect_gene`` entry) and ``summary`` (the gene and
    effect filtered rows) -- followed by ``SELECT * FROM summary``.
    """
    base = self.summary_base()

    if regions is not None:
        # Bind the region placeholders to the summary table columns.
        region_clause = replace_placeholders(
            self.regions(regions),
            chromosome=exp.to_column("sa.chromosome"),
            position=exp.to_column("sa.position"),
            end_position=exp.to_column("sa.end_position"),
        )
        base = base.where(region_clause)
    if real_attr_filter:
        base = base.where(self.real_attr(real_attr_filter))
    if frequency_filter:
        base = base.where(self.frequency(frequency_filter))
    if ultra_rare:
        base = base.where(self.ultra_rare())
    if variant_type is not None:
        base = base.where(
            self.build_variant_types_query(
                variant_type, "sa.variant_type"))
    if not (return_reference or return_unknown):
        base = base.where("sa.allele_index > 0")

    if genes is None and effect_types is None:
        return base

    unnested_effects = parse_one(
        "select *, unnest(sa.effect_gene) as eg "
        "from summary_base as sa",
    )
    filtered = exp.select("*").from_("summary_effects")
    if genes is not None:
        filtered = filtered.where(SqlQueryBuilder.genes(genes))
    if effect_types is not None:
        filtered = filtered.where(
            SqlQueryBuilder.effect_types(effect_types))

    with_ctes = Select().with_("summary_base", as_=base)
    with_ctes = with_ctes.with_("summary_effects", as_=unnested_effects)
    with_ctes = with_ctes.with_("summary", as_=filtered)
    return with_ctes.select("*").from_("summary")
@staticmethod
def roles(
    roles_query: str,
) -> Condition:
    """Create the roles condition on ``fa.allele_in_roles``."""
    sql_text = SqlQueryBuilder.build_roles_query(
        roles_query, "fa.allele_in_roles")
    return condition(sql_text)
@staticmethod
def sexes(
    sexes_query: str,
) -> Condition:
    """Create the sexes condition on ``fa.allele_in_sexes``."""
    sql_text = SqlQueryBuilder.build_sexes_query(
        sexes_query, "fa.allele_in_sexes")
    return condition(sql_text)
@staticmethod
def inheritance(
    inheritance_query: str | Sequence[str],
) -> Condition:
    """Build inheritance filter.

    Accepts one query string or a sequence of them; a single string is
    treated as a one-element sequence. The condition is built against
    ``fa.inheritance_in_members``.
    """
    if isinstance(inheritance_query, str):
        queries: Sequence[str] = [inheritance_query]
    else:
        queries = inheritance_query
    sql_text = SqlQueryBuilder.build_inheritance_query(
        queries, "fa.inheritance_in_members")
    return condition(sql_text)
@staticmethod
def family_ids(
    family_ids: Sequence[str],
) -> Condition:
    """Create family IDs filter.

    No IDs produces ``fa.family_id IS NULL``, one ID an equality test, and
    several IDs an ``IN (...)`` test. Single quotes inside an ID are
    doubled so the value stays a valid SQL string literal.
    """
    if not family_ids:
        return condition("fa.family_id IS NULL")

    literals = [
        "'" + str(fid).replace("'", "''") + "'" for fid in family_ids
    ]
    if len(literals) == 1:
        return condition(f"fa.family_id = {literals[0]}")
    return condition(f"fa.family_id IN ({', '.join(literals)})")
@staticmethod
def person_ids(
    person_ids: Sequence[str],
) -> Condition:
    """Create person IDs filter.

    The filter is on ``fa.aim``. No IDs produces ``fa.aim IS NULL``, one ID
    an equality test, and several IDs an ``IN (...)`` test. Single quotes
    inside an ID are doubled so the value stays a valid SQL string literal.
    """
    if not person_ids:
        return condition("fa.aim IS NULL")

    literals = [
        "'" + str(pid).replace("'", "''") + "'" for pid in person_ids
    ]
    if len(literals) == 1:
        return condition(f"fa.aim = {literals[0]}")
    return condition(f"fa.aim IN ({', '.join(literals)})")
def family_query(
    self,
    family_ids: Sequence[str] | None = None,
    person_ids: Sequence[str] | None = None,
    inheritance: str | Sequence[str] | None = None,
    roles: str | None = None,
    sexes: str | None = None,
) -> Select:
    """Build a family subclause query.

    Without person IDs the result is a plain SELECT over the family table.
    With person IDs, unknown persons are dropped, the family filter is
    narrowed to the families of the remaining persons, and the result is a
    chain of three CTEs -- ``family_base``, ``family_members`` (one row per
    unnested ``allele_in_members`` entry) and ``family`` (rows of the
    requested persons) -- followed by ``SELECT * FROM family``.
    """
    base = self.family_base()
    if roles is not None:
        base = base.where(self.roles(roles))
    if inheritance is not None:
        base = base.where(self.inheritance(inheritance))
    if sexes is not None:
        base = base.where(self.sexes(sexes))

    if person_ids is not None:
        known_persons = self.families.persons_by_person_id
        person_ids = [pid for pid in person_ids if pid in known_persons]
        person_families = {
            known_persons[pid][0].family_id for pid in person_ids
        }
        if family_ids is not None:
            person_families &= set(family_ids)
        family_ids = list(person_families)

    if family_ids is not None:
        base = base.where(self.family_ids(family_ids))

    if person_ids is None:
        return base

    unnested_members = parse_one(
        "select *, unnest(fa.allele_in_members) as aim "
        "from family_base as fa",
    )
    persons_only = exp.select("*").from_("family_members as fa").where(
        SqlQueryBuilder.person_ids(person_ids),
    )
    with_ctes = Select().with_("family_base", as_=base)
    with_ctes = with_ctes.with_("family_members", as_=unnested_members)
    with_ctes = with_ctes.with_("family", as_=persons_only)
    return with_ctes.select("*").from_("family")
@staticmethod
def _heuristic_bins(
    table: str,
    heuristic: str,
    bins: list[str],
) -> Condition:
    """Create a condition restricting a bin column to the given bins.

    One bin produces an equality test, several an ``IN (...)`` test. The
    bin values are inserted verbatim, so string bins must already be
    quoted. The list must not be empty.
    """
    assert len(bins) > 0
    column = f"{table}.{heuristic}"
    if len(bins) > 1:
        return condition(f"{column} IN ({', '.join(bins)})")
    return condition(f"{column} = {bins[0]}")
@staticmethod
def region_bins(table: str, region_bins: list[str]) -> Condition:
    """Create region bins condition on ``<table>.region_bin``."""
    return SqlQueryBuilder._heuristic_bins(
        table=table,
        heuristic="region_bin",
        bins=region_bins,
    )
@staticmethod
def frequency_bins(table: str, frequency_bins: list[str]) -> Condition:
    """Create frequency bins condition on ``<table>.frequency_bin``."""
    return SqlQueryBuilder._heuristic_bins(
        table=table,
        heuristic="frequency_bin",
        bins=frequency_bins,
    )
@staticmethod
def coding_bins(table: str, coding_bins: list[str]) -> Condition:
    """Create coding bins condition on ``<table>.coding_bin``."""
    return SqlQueryBuilder._heuristic_bins(
        table=table,
        heuristic="coding_bin",
        bins=coding_bins,
    )
@staticmethod
def family_bins(table: str, family_bins: list[str]) -> Condition:
    """Create family bins condition on ``<table>.family_bin``."""
    return SqlQueryBuilder._heuristic_bins(
        table=table,
        heuristic="family_bin",
        bins=family_bins,
    )
def apply_summary_heuristics(
    self,
    query: Select,
    heuristics: QueryHeuristics | None,
    table: str = "sa",
) -> Select:
    """Apply heuristics to the summary query.

    Region, frequency and coding bin conditions are added to the query that
    reads the base table: the query itself when it has no CTEs, otherwise
    its first CTE. The input query is returned unchanged when there are no
    heuristics.
    """
    if heuristics is None or heuristics.is_empty():
        return query

    has_ctes = bool(query.ctes)
    base_query = cast(Select, query.ctes[0].this if has_ctes else query)

    bin_filters = (
        (heuristics.region_bins, self.region_bins),
        (heuristics.frequency_bins, self.frequency_bins),
        (heuristics.coding_bins, self.coding_bins),
    )
    for bins, make_condition in bin_filters:
        if bins:
            base_query = base_query.where(make_condition(table, bins))

    if not has_ctes:
        return base_query

    # Work on a copy so the caller's query keeps its original first CTE.
    result = cast(Select, query.copy())
    result.ctes[0].args["this"] = base_query
    return result
def apply_family_heuristics(
    self,
    query: Select,
    heuristics: QueryHeuristics | None,
) -> Select:
    """Apply heuristics to the family query.

    Region, frequency and coding bins are applied through
    ``apply_summary_heuristics`` with the ``fa`` alias; the family bin
    condition is then added to the same place -- the query itself when it
    has no CTEs, otherwise its first CTE, which is the one reading the
    family table.
    """
    if heuristics is None or heuristics.is_empty():
        return query

    query = self.apply_summary_heuristics(query, heuristics, table="fa")
    if not heuristics.family_bins:
        return query

    family_bin_clause = self.family_bins("fa", heuristics.family_bins)
    if not query.ctes:
        return query.where(family_bin_clause)

    # A query with CTEs contributes only its CTEs to the final statement
    # (see ``_append_cte``), and its outer SELECT has no ``fa`` alias, so a
    # condition placed on the outer SELECT would be lost. Put it on the
    # first CTE instead, working on a copy to leave the input untouched.
    result = cast(Select, query.copy())
    base_query = cast(Select, result.ctes[0].this)
    result.ctes[0].args["this"] = base_query.where(family_bin_clause)
    return result
def build_summary_variants_query(
    self,
    *,
    regions: list[Region] | None = None,
    genes: list[str] | None = None,
    effect_types: list[str] | None = None,
    variant_type: str | None = None,
    real_attr_filter: RealAttrFilterType | None = None,
    ultra_rare: bool | None = None,
    frequency_filter: RealAttrFilterType | None = None,
    return_reference: bool | None = None,
    return_unknown: bool | None = None,
    limit: int | None = None,
) -> list[str]:
    """Build a query for summary variants.

    Returns one SQL string per heuristics batch. Each string has the batch's
    heuristics and the optional limit applied, and the logical table names
    replaced with the real ones.
    """
    summary = self.summary_query(
        regions=regions,
        genes=genes,
        effect_types=effect_types,
        variant_type=variant_type,
        real_attr_filter=real_attr_filter,
        ultra_rare=ultra_rare,
        frequency_filter=frequency_filter,
        return_reference=return_reference,
        return_unknown=return_unknown,
    )
    batches = self.calc_batched_heuristics(
        regions=regions,
        genes=genes,
        effect_types=effect_types,
        ultra_rare=ultra_rare,
        frequency_filter=frequency_filter,
    )

    statements: list[str] = []
    for batch in batches:
        variants = self.summary_variants(
            summary=self.apply_summary_heuristics(summary, batch),
        )
        if limit is not None:
            variants = variants.limit(limit)
        statements.append(self.replace_tables(variants).sql())
    return statements
def replace_tables(self, query: Select) -> Select:
    """Replace table names in the query.

    The logical names ``summary_table``, ``family_table`` and
    ``pedigree_table`` are mapped to the names from the database layout,
    prefixed with the database name when the layout has one. When the
    layout has no summary table, only the pedigree table is mapped, and
    without a database prefix.
    """
    layout = self.db_layout
    if layout.summary is None:
        assert layout.family is None
        return exp.replace_tables(
            query,
            {"pedigree_table": layout.pedigree},
        )

    assert layout.family is not None
    prefix = "" if layout.db is None else f"{layout.db}."
    mapping = {
        "summary_table": f"{prefix}{layout.summary}",
        "family_table": f"{prefix}{layout.family}",
        "pedigree_table": f"{prefix}{layout.pedigree}",
    }
    return exp.replace_tables(query, mapping)
def build_family_variants_query(
    self,
    *,
    regions: list[Region] | None = None,
    genes: list[str] | None = None,
    effect_types: list[str] | None = None,
    family_ids: Sequence[str] | None = None,
    person_ids: Sequence[str] | None = None,
    inheritance: Sequence[str] | None = None,
    roles: str | None = None,
    sexes: str | None = None,
    variant_type: str | None = None,
    real_attr_filter: RealAttrFilterType | None = None,
    ultra_rare: bool | None = None,
    frequency_filter: RealAttrFilterType | None = None,
    return_reference: bool | None = None,
    return_unknown: bool | None = None,
    limit: int | None = None,
    **_kwargs: Any,
) -> list[str]:
    """Build a query for family variants.

    Returns one SQL string per heuristics batch. Each string joins the
    summary and family sub-queries with the batch's heuristics applied to
    both, applies the optional limit and has the logical table names
    replaced with the real ones. Extra keyword arguments are ignored.
    """
    summary = self.summary_query(
        regions=regions,
        genes=genes,
        effect_types=effect_types,
        variant_type=variant_type,
        real_attr_filter=real_attr_filter,
        ultra_rare=ultra_rare,
        frequency_filter=frequency_filter,
        return_reference=return_reference,
        return_unknown=return_unknown,
    )
    family = self.family_query(
        family_ids=family_ids,
        person_ids=person_ids,
        inheritance=inheritance,
        roles=roles,
        sexes=sexes,
    )
    # Person IDs are not passed on, so they do not contribute family bins.
    batches = self.calc_batched_heuristics(
        regions=regions,
        genes=genes,
        effect_types=effect_types,
        inheritance=inheritance,
        roles=roles,
        ultra_rare=ultra_rare,
        frequency_filter=frequency_filter,
        family_ids=family_ids,
    )

    statements: list[str] = []
    for batch in batches:
        variants = self.family_variants(
            summary=self.apply_summary_heuristics(summary, batch),
            family=self.apply_family_heuristics(family, batch),
        )
        if limit is not None:
            variants = variants.limit(limit)
        statements.append(self.replace_tables(variants).sql())
    return statements