"""SQL query builders for schema2 genotype storage.

Module: dae.query_variants.sql.schema2.sql_query_builder
"""

from __future__ import annotations

import itertools
import logging
from collections.abc import Iterable, Sequence
from dataclasses import dataclass
from typing import Any, cast

import duckdb
from pydantic import BaseModel
from sqlglot import Expression, column, condition, exp, or_, parse_one
from sqlglot.expressions import (
    Condition,
    Select,
    Table,
    replace_placeholders,
    table_,
)
from sqlglot.schema import Schema, ensure_schema

from dae.genomic_resources.gene_models import (
    GeneModels,
    create_regions_from_genes,
)
from dae.genomic_resources.reference_genome import ReferenceGenome
from dae.parquet.partition_descriptor import PartitionDescriptor
from dae.pedigrees.families_data import FamiliesData
from dae.query_variants.attributes_query import (
    QueryTreeToSQLBitwiseTransformer,
    affected_status_query,
    role_query,
    sex_query,
)
from dae.query_variants.attributes_query import (
    variant_type_query as VARIANT_TYPE_PARSER,
)
from dae.query_variants.attributes_query_inheritance import (
    InheritanceTransformer,
    inheritance_parser,
)
from dae.utils.regions import Region
from dae.variants.attributes import Inheritance, Role, Status, Zygosity

logger = logging.getLogger(__name__)


# A type describing a schema as expected by the query builders
TableSchema = dict[str, str]
RealAttrFilterType = list[tuple[str, tuple[float | None, float | None]]]
CategoricalAttrFilterType = list[tuple[str, list[str] | list[int] | None]]


# The family variants table and the summary alleles table are mandatory;
# unlike the Impala layout, there is no reliance on a separate variants table.
@dataclass(frozen=True)
class Db2Layout:
    """Genotype data layout in the database.

    An immutable record naming the database and the study's tables.
    Field meanings below are inferred from the names; confirm against
    the code that builds a layout.
    """

    db: str | None  # database name; None presumably means "default db"
    study: str  # study identifier
    pedigree: str  # presumably the pedigree table name
    summary: str | None  # presumably the summary alleles table, if any
    family: str | None  # presumably the family variants table, if any
    meta: str  # presumably the metadata table name
@dataclass(frozen=True)
class QueryHeuristics:
    """Heuristics for a query.

    Every field lists the partition-bin values a query may be restricted
    to; an empty list means that bin imposes no restriction.
    """

    region_bins: list[str]
    coding_bins: list[str]
    frequency_bins: list[str]
    family_bins: list[str]

    def is_empty(self) -> bool:
        """Return True when none of the bin lists contains a value."""
        all_bins = (
            self.region_bins,
            self.coding_bins,
            self.frequency_bins,
            self.family_bins,
        )
        return all(len(bins) == 0 for bins in all_bins)
class TagsQuery(BaseModel):
    """Family-tag filter for family queries.

    Tags listed as selected must be "True" and tags listed as deselected
    must be "False"; the individual comparisons are combined with OR when
    ``tags_or_mode`` is set, otherwise with AND.
    """

    # Tags that a family must carry (compared against "True").
    selected_family_tags: list[str] | None = None
    # Tags that a family must not carry (compared against "False").
    deselected_family_tags: list[str] | None = None
    # Combine tag comparisons with OR instead of AND.
    tags_or_mode: bool = False
class ZygosityQuery(BaseModel):
    """Zygosity constraints for family queries.

    The integer fields are bit masks tested against the family table's
    ``zygosity_in_*`` columns.
    """

    # Zygosity constraint for affected statuses; turned into a mask
    # elsewhere (calc_zygosity_status_value).
    status_zygosity: str | None = None
    # Mask tested against zygosity_in_roles for parent roles.
    parents_zygosity: int | None = None
    # Mask tested against zygosity_in_roles for child roles.
    children_zygosity: int | None = None
    # Mask tested against zygosity_in_sexes.
    sex_zygosity: int | None = None
class QueryBuilderBase:
    """Base class for building SQL queries.

    Holds the study context (schema, families, partition descriptor, gene
    models, reference genome) and computes partition-bin heuristics, i.e.
    lists of allowed ``region_bin``/``coding_bin``/``frequency_bin``/
    ``family_bin`` values for a query.
    """

    GENE_REGIONS_HEURISTIC_CUTOFF = 20
    GENE_REGIONS_HEURISTIC_EXTEND = 20000
    REGION_BINS_HEURISTIC_CUTOFF = 20

    def __init__(
        self,
        schema: Schema,
        families: FamiliesData,
        partition_descriptor: PartitionDescriptor | None,
        gene_models: GeneModels,
        reference_genome: ReferenceGenome,
    ):
        """Store the query context.

        Raises:
            ValueError: if ``gene_models`` or ``reference_genome`` is None.
        """
        if gene_models is None:
            raise ValueError("gene_models are required")
        self.gene_models = gene_models
        if reference_genome is None:
            raise ValueError("reference genome is required")
        self.reference_genome = reference_genome
        self.families = families
        self.schema = schema
        self.partition_descriptor = partition_descriptor

    def build_gene_regions(
        self,
        genes: list[str],
        regions: list[Region] | None,
    ) -> list[Region] | None:
        """Build a list of regions based on genes.

        Delegates to ``create_regions_from_genes`` with the class-level
        gene-count cutoff and region extension.
        """
        assert self.gene_models is not None
        return create_regions_from_genes(
            self.gene_models,
            genes,
            regions,
            self.GENE_REGIONS_HEURISTIC_CUTOFF,
            self.GENE_REGIONS_HEURISTIC_EXTEND,
        )

    def calc_coding_bins(
        self,
        effect_types: Sequence[str] | None,
    ) -> list[str]:
        """Calculate applicable coding bins for a query.

        Returns ``["1"]`` when every requested effect type is one of the
        partition descriptor's coding effect types; otherwise ``[]``
        (no coding-bin restriction).
        """
        if self.partition_descriptor is None:
            return []
        if effect_types is None:
            return []
        if "coding_bin" not in self.schema.column_names("summary_table"):
            return []
        assert "coding_bin" in self.schema.column_names("family_table")

        query_effect_types = set(effect_types)
        coding_effect_types = set(
            self.partition_descriptor.coding_effect_types)
        # (Q & C) == Q  <=>  Q is a subset of C
        if query_effect_types <= coding_effect_types:
            return ["1"]
        return []

    def calc_region_bins(
        self,
        regions: list[Region] | None,
    ) -> list[str]:
        """Calculate applicable region bins for a query.

        Returns ``[]`` (no restriction) when there is no partitioning by
        region, no regions were given, or the regions touch more than
        ``REGION_BINS_HEURISTIC_CUTOFF`` bins. Non-integer bins are returned
        as quoted SQL string literals.
        """
        if self.partition_descriptor is None:
            return []
        if not regions or not self.partition_descriptor.has_region_bins():
            return []

        # The chromosome lengths do not depend on the region; fetch once
        # instead of once per region.
        chrom_lengths = self.reference_genome.get_all_chrom_lengths()
        region_bins: set[str] = set()
        for region in regions:
            region_bins.update(
                self.partition_descriptor.region_to_region_bins(
                    region, chrom_lengths,
                ),
            )
        assert len(region_bins) > 0
        if len(region_bins) > self.REGION_BINS_HEURISTIC_CUTOFF:
            return []
        if not self.partition_descriptor.integer_region_bins:
            return [f"'{rb}'" for rb in region_bins]
        return list(region_bins)

    @staticmethod
    def _eval_sql_bool(expression: str) -> bool:
        """Evaluate a scalar boolean SQL expression in in-memory DuckDB."""
        with duckdb.connect(":memory:") as con:
            res = con.execute(f"SELECT {expression}").fetchall()
        assert len(res) == 1
        assert len(res[0]) == 1
        return cast(bool, res[0][0])

    @staticmethod
    def build_roles_query(roles_query: str, attr: str) -> str:
        """Construct a roles query (bitwise SQL over ``attr``)."""
        parsed = role_query.transform_query_string_to_tree(roles_query)
        transformer = QueryTreeToSQLBitwiseTransformer(
            attr, use_bit_and_function=False)
        return cast(str, transformer.transform(parsed))

    @staticmethod
    def check_roles_query_value(roles_query: str, value: int) -> bool:
        """Check if value satisfies a given roles query."""
        return QueryBuilderBase._eval_sql_bool(
            QueryBuilderBase.build_roles_query(roles_query, str(value)))

    @staticmethod
    def build_inheritance_query(
        inheritance_query: Sequence[str],
        attr: str,
    ) -> str:
        """Construct an inheritance query.

        Each element of ``inheritance_query`` is parsed separately and the
        resulting SQL fragments are joined with AND; an empty input yields
        an empty string.
        """
        transformer = InheritanceTransformer(attr, use_bit_and_function=False)
        result = [
            str(transformer.transform(inheritance_parser.parse(query)))
            for query in inheritance_query
        ]
        if not result:
            return ""
        return " AND ".join(result)

    @staticmethod
    def check_inheritance_query_value(
        inheritance_query: Sequence[str],
        value: int,
    ) -> bool:
        """Check if value satisfies a given inheritance query."""
        return QueryBuilderBase._eval_sql_bool(
            QueryBuilderBase.build_inheritance_query(
                inheritance_query, str(value)))

    @staticmethod
    def check_roles_denovo_only(roles_query: str) -> bool:
        """Check if roles query is de novo only.

        True when the query matches children (prb|sib) but not children
        together with parents (prb|sib|dad|mom).
        """
        children = Role.prb.value | Role.sib.value
        return QueryBuilderBase.check_roles_query_value(
            roles_query, children) and \
            not QueryBuilderBase.check_roles_query_value(
                roles_query,
                children | Role.dad.value | Role.mom.value)

    @staticmethod
    def check_inheritance_denovo_only(
        inheritance_query: Sequence[str],
    ) -> bool:
        """Check if inheritance query is de novo only.

        True when none of mendelian, possible_denovo, possible_omission
        or missing satisfies the query.
        """
        return not QueryBuilderBase.check_inheritance_query_value(
            inheritance_query, Inheritance.mendelian.value) \
            and not QueryBuilderBase.check_inheritance_query_value(
                inheritance_query, Inheritance.possible_denovo.value) \
            and not QueryBuilderBase.check_inheritance_query_value(
                inheritance_query, Inheritance.possible_omission.value) \
            and not QueryBuilderBase.check_inheritance_query_value(
                inheritance_query, Inheritance.missing.value)

    @staticmethod
    def build_sexes_query(sexes_query: str, attr: str) -> str:
        """Build sexes query (bitwise SQL over ``attr``)."""
        parsed = sex_query.transform_query_string_to_tree(sexes_query)
        transformer = QueryTreeToSQLBitwiseTransformer(
            attr, use_bit_and_function=False)
        return cast(str, transformer.transform(parsed))

    @staticmethod
    def check_sexes_query_value(sexes_query: str, value: int) -> bool:
        """Check if value matches a given sexes query."""
        return QueryBuilderBase._eval_sql_bool(
            QueryBuilderBase.build_sexes_query(sexes_query, str(value)))

    @staticmethod
    def build_statuses_query(statuses_query: str, attr: str) -> str:
        """Build affected status query (bitwise SQL over ``attr``)."""
        parsed = affected_status_query.transform_query_string_to_tree(
            statuses_query)
        transformer = QueryTreeToSQLBitwiseTransformer(
            attr, use_bit_and_function=False)
        return cast(str, transformer.transform(parsed))

    @staticmethod
    def check_statuses_query_value(statuses_query: str, value: int) -> bool:
        """Check if value matches a given affected statuses query."""
        return QueryBuilderBase._eval_sql_bool(
            QueryBuilderBase.build_statuses_query(
                statuses_query, str(value)))

    @staticmethod
    def build_variant_types_query(
        variant_types_query: str,
        attr: str,
    ) -> str:
        """Build a variant types query (bitwise SQL over ``attr``)."""
        parsed = VARIANT_TYPE_PARSER.transform_query_string_to_tree(
            variant_types_query)
        transformer = QueryTreeToSQLBitwiseTransformer(
            attr, use_bit_and_function=False)
        return cast(str, transformer.transform(parsed))

    @staticmethod
    def check_variant_types_value(
        variant_types_query: str,
        value: int,
    ) -> bool:
        """Check if value satisfies a given variant types query."""
        return QueryBuilderBase._eval_sql_bool(
            QueryBuilderBase.build_variant_types_query(
                variant_types_query, str(value)))

    def calc_frequency_bins(
        self,
        *,
        inheritance: Sequence[str] | None = None,
        roles: str | None = None,
        ultra_rare: bool | None = None,
        frequency_filter: RealAttrFilterType | None = None,
    ) -> list[str]:
        """Calculate applicable frequency bins for a query.

        Returns ``["0"]`` for de novo only queries. Otherwise it returns
        bins ``0..k`` when the query is restricted to ultra-rare (k=1) or
        to ``af_allele_freq`` at or below the rare boundary (k=2), and
        ``[]`` (no restriction) in every other case.
        """
        if self.partition_descriptor is None:
            return []
        if "frequency_bin" not in self.schema.column_names("summary_table"):
            return []
        assert "frequency_bin" in self.schema.column_names("family_table")

        if roles and self.check_roles_denovo_only(roles):
            return ["0"]
        if inheritance and self.check_inheritance_denovo_only(inheritance):
            return ["0"]
        if not ultra_rare and frequency_filter is None:
            return []

        frequency_bins: set[int] = set()
        if ultra_rare:
            frequency_bins.add(1)
        if frequency_filter is not None:
            rare_boundary = self.partition_descriptor.rare_boundary
            for attr, (_, right) in frequency_filter:
                if attr != "af_allele_freq":
                    continue
                if right is None:
                    return []
                if right <= rare_boundary:
                    frequency_bins.add(2)
                elif right > rare_boundary:
                    return []
        if not frequency_bins:
            return []

        frequency_bins.add(0)  # always search de novo variants
        if len(frequency_bins) < 4:
            return [str(fb) for fb in range(max(frequency_bins) + 1)]
        return []

    def calc_family_bins(
        self,
        family_ids: Iterable[str] | None,
        person_ids: Iterable[str] | None,
    ) -> list[str]:
        """Calculate family bins for a query.

        Collects the bins of the requested families and of the families of
        the requested (known) persons. Returns ``[]`` (no restriction) when
        that covers at least half of ``family_bin_size``.
        """
        if self.partition_descriptor is None:
            return []
        if not self.partition_descriptor.has_family_bins():
            return []
        if "family_bin" not in self.schema.column_names("family_table"):
            return []
        if family_ids is None and person_ids is None:
            return []

        descriptor = self.partition_descriptor
        family_bins: set[str] = set()
        if family_ids is not None:
            family_bins.update(
                str(descriptor.make_family_bin(family_id))
                for family_id in set(family_ids))
        if person_ids is not None:
            persons = self.families.persons_by_person_id
            person_family_ids = {
                persons[person_id][0].family_id
                for person_id in person_ids
                if person_id in persons
            }
            family_bins.update(
                str(descriptor.make_family_bin(family_id))
                for family_id in person_family_ids)

        if len(family_bins) >= descriptor.family_bin_size // 2:
            return []
        return list(family_bins)

    def all_region_bins(self) -> list[str]:
        """Return all region bins (quoted when bins are not integers)."""
        if self.partition_descriptor is None:
            return []
        if not self.partition_descriptor.has_region_bins():
            return []
        chrom_lens = dict(self.reference_genome.get_all_chrom_lengths())
        all_region_bins = self.partition_descriptor.make_all_region_bins(
            chrom_lens,
        )
        if not self.partition_descriptor.integer_region_bins:
            all_region_bins = [f"'{rb}'" for rb in all_region_bins]
        return all_region_bins

    def calc_heuristics(
        self,
        *,
        regions: list[Region] | None = None,
        genes: list[str] | None = None,
        effect_types: list[str] | None = None,
        inheritance: Sequence[str] | None = None,
        roles: str | None = None,
        ultra_rare: bool | None = None,
        frequency_filter: RealAttrFilterType | None = None,
        family_ids: Iterable[str] | None = None,
        person_ids: Iterable[str] | None = None,
    ) -> QueryHeuristics:
        """Calculate heuristic bins for a query.

        When ``genes`` are given, regions are first derived from them.
        """
        if genes is not None:
            regions = self.build_gene_regions(genes, regions)
        return QueryHeuristics(
            region_bins=self.calc_region_bins(regions) or [],
            coding_bins=self.calc_coding_bins(effect_types) or [],
            frequency_bins=self.calc_frequency_bins(
                inheritance=inheritance,
                roles=roles,
                ultra_rare=ultra_rare,
                frequency_filter=frequency_filter,
            ) or [],
            family_bins=self.calc_family_bins(family_ids, person_ids) or [],
        )

    def calc_batched_heuristics(
        self,
        *,
        regions: list[Region] | None = None,
        genes: list[str] | None = None,
        effect_types: list[str] | None = None,
        inheritance: Sequence[str] | None = None,
        roles: str | None = None,
        ultra_rare: bool | None = None,
        frequency_filter: RealAttrFilterType | None = None,
        family_ids: Iterable[str] | None = None,
        person_ids: Iterable[str] | None = None,
    ) -> list[QueryHeuristics]:
        """Calculate heuristics batches for a query.

        Returns a single batch when the query is already narrow (region
        bins, de novo only, or rare coding). Otherwise, when the data is
        partitioned by region, it returns one batch per region bin.
        """
        heuristics = self.calc_heuristics(
            regions=regions,
            genes=genes,
            effect_types=effect_types,
            inheritance=inheritance,
            roles=roles,
            ultra_rare=ultra_rare,
            frequency_filter=frequency_filter,
            family_ids=family_ids,
            person_ids=person_ids,
        )
        if heuristics.region_bins:
            # single batch if we have region bins in heuristics
            return [heuristics]
        if heuristics.frequency_bins:
            # single batch if we don't search for rare and common variants
            rare_and_common_bins = {"1", "2", "3"}
            if not rare_and_common_bins & set(heuristics.frequency_bins):
                return [heuristics]
        if heuristics.coding_bins and heuristics.frequency_bins:
            # single batch if we search for rare coding variants
            noncoding_bin = "0"
            common_bin = "3"
            if noncoding_bin not in heuristics.coding_bins and \
                    common_bin not in heuristics.frequency_bins:
                return [heuristics]
        if self.partition_descriptor and \
                self.partition_descriptor.has_region_bins():
            return [
                QueryHeuristics(
                    region_bins=[f"{rb}"],
                    coding_bins=heuristics.coding_bins,
                    frequency_bins=heuristics.frequency_bins,
                    family_bins=heuristics.family_bins,
                )
                for rb in self.all_region_bins()
            ]
        return [heuristics]

    @staticmethod
    def build_schema(
        summary_schema: dict[str, str],
        family_schema: dict[str, str],
        pedigree_schema: dict[str, str],
    ) -> Schema:
        """Build a sqlglot schema for the summary/family/pedigree tables.

        The tables are registered as ``summary_table``, ``family_table``
        and ``pedigree_table``, the names the query builders use.
        """
        return ensure_schema(
            {
                "summary_table": summary_schema,
                "family_table": family_schema,
                "pedigree_table": pedigree_schema,
            },
        )
class SqlQueryBuilder(QueryBuilderBase):  # pylint: disable=too-many-public-methods
    """Build SQL queries using sqlglot."""

    def __init__(
        self,
        db_layout: Db2Layout,
        *,
        schema: Schema,
        partition_descriptor: PartitionDescriptor | None,
        families: FamiliesData,
        gene_models: GeneModels,
        reference_genome: ReferenceGenome,
    ):
        """Initialize the builder and remember the database layout.

        All context arguments are forwarded to ``QueryBuilderBase``,
        which validates ``gene_models`` and ``reference_genome``.
        """
        super().__init__(
            schema=schema,
            families=families,
            partition_descriptor=partition_descriptor,
            gene_models=gene_models,
            reference_genome=reference_genome,
        )
        self.db_layout = db_layout
[docs] @staticmethod def build( db_layout: Db2Layout, *, pedigree_schema: dict[str, str], summary_schema: dict[str, str], family_schema: dict[str, str], partition_descriptor: PartitionDescriptor | None, families: FamiliesData, gene_models: GeneModels, reference_genome: ReferenceGenome, ) -> SqlQueryBuilder: """Return a new instance of the builder.""" schema = ensure_schema( { "summary_table": summary_schema, "family_table": family_schema, "pedigree_table": pedigree_schema, }, ) return SqlQueryBuilder( db_layout=db_layout, schema=schema, partition_descriptor=partition_descriptor, families=families, gene_models=gene_models, reference_genome=reference_genome, )
[docs] @staticmethod def genes(genes: list[str]) -> Condition: """Create genes condition.""" if len(genes) == 0: return condition("eg.effect_gene_symbols IS NULL") if len(genes) == 1: return condition(f"eg.effect_gene_symbols = '{genes[0]}'") gene_set = ",".join(f"'{g}'" for g in genes) return condition(f"eg.effect_gene_symbols in ({gene_set})")
[docs] @staticmethod def effect_types(effect_types: list[str]) -> Condition: """Create effect types condition.""" effect_types = [et.replace("'", "''") for et in effect_types] if len(effect_types) == 0: return condition("eg.effect_types IS NULL") effect_set = ",".join(f"'{g}'" for g in effect_types) return condition(f"eg.effect_types in ({effect_set})")
[docs] @staticmethod def summary_base() -> Select: """Create summary base query.""" return exp.select("*").from_("summary_table as sa")
[docs] @staticmethod def family_base() -> Select: return exp.select("*").from_("family_table as fa")
[docs] def summary_variants( self, summary: Select, ) -> Select: """Construct summary variants query.""" return self._append_cte( target=Select(), source=summary, alias="summary", ).select( "sa.bucket_index", "sa.summary_index", "sa.allele_index", "sa.summary_variant_data", ).from_( "summary as sa", )
@staticmethod def _append_cte( target: Select, source: Select, alias: str, ) -> Select: if source.ctes: for cte in source.ctes: target = target.with_( cte.alias, as_=cte.this, ) else: target = target.with_( alias, as_=source, ) return target
[docs] def family_variants( self, summary: Select, family: Select, ) -> Select: """Construct family variants query.""" query = self._append_cte( Select(), summary, "summary", ) on_clause = ( "sa.bucket_index = fa.bucket_index " "and sa.summary_index = fa.summary_index " "and sa.allele_index = fa.allele_index" ) if "sj_index" in self.schema.column_names("family_table"): assert "sj_index" in self.schema.column_names("summary_table") on_clause = "sa.sj_index = fa.sj_index" return self._append_cte( query, family, "family", ).select( "fa.bucket_index", "fa.summary_index", "fa.family_index", "sa.allele_index", "sa.summary_variant_data", "fa.family_variant_data", ).from_( "summary as sa", ).join( "family as fa", on=on_clause, )
@staticmethod def _region_to_condition(reg: Region) -> Condition: if reg.start is None and reg.stop is None: return condition(f":chromosome = '{reg.chrom}'") if reg.start is None: assert reg.stop is not None return condition( f":chromosome = '{reg.chrom}'" f" AND NOT ( " f":position > {reg.stop} )", ) if reg.stop is None: assert reg.start is not None return condition( f":chromosome = '{reg.chrom}'" f" AND ( " f"COALESCE(:end_position, :position) > {reg.start} )", ) assert reg.stop is not None assert reg.start is not None return condition( f":chromosome = '{reg.chrom}'" f" AND NOT ( " f"COALESCE(:end_position, :position) < {reg.start} OR " f":position > {reg.stop} )", )
[docs] @staticmethod def regions(regions: list[Region]) -> Condition: """Create regions condition.""" assert len(regions) > 0 result = SqlQueryBuilder._region_to_condition(regions[0]) for reg in regions[1:]: result = or_( result, SqlQueryBuilder._region_to_condition(reg), ) return result
@staticmethod def _real_attr_filter( attr: str, value_range: tuple[float | None, float | None], *, is_frequency: bool = False, ) -> Condition: """Create real attribute condition.""" left, right = value_range if left is None and right is None: if is_frequency: return condition("1 = 1") return condition(f"sa.{attr} IS NOT NULL") if left is None: assert right is not None if is_frequency: return condition( f"sa.{attr} <= {right} OR sa.{attr} IS NULL", ) return condition(f"sa.{attr} <= {right}") if right is None: assert left is not None return condition(f"sa.{attr} >= {left}") return condition( f"sa.{attr} >= {left} AND sa.{attr} <= {right}", )
[docs] @staticmethod def frequency( real_attrs: RealAttrFilterType, ) -> Condition: """Build frequencies filter where condition.""" assert len(real_attrs) > 0 conditions = [ SqlQueryBuilder._real_attr_filter( attr, value_range, is_frequency=True, ) for attr, value_range in real_attrs ] if len(conditions) == 1: return conditions[0] result = conditions[0] for cond in conditions[1:]: result = result.and_(cond) return result
[docs] @staticmethod def real_attr( real_attrs: RealAttrFilterType, ) -> Condition: """Build real attributes filter where condition.""" assert len(real_attrs) > 0 conditions = [ SqlQueryBuilder._real_attr_filter( attr, value_range, is_frequency=False, ) for attr, value_range in real_attrs ] if len(conditions) == 1: return conditions[0] result = conditions[0] for cond in conditions[1:]: result = result.and_(cond) return result
@staticmethod def _categorical_attr_filter( attr: str, values: list[str] | list[int] | None, ) -> Condition: """Create real attribute condition.""" if values is None: return condition(f"sa.{attr} IS NULL") if len(values) == 0: return condition(f"sa.{attr} IS NOT NULL") if all(isinstance(v, str) for v in values): return condition(" OR ".join(f"sa.{attr} = '{v}'" for v in values)) if all(isinstance(v, int) for v in values): return condition(" OR ".join(f"sa.{attr} = {v}" for v in values)) raise TypeError(f"values must be all str or all int: {values}")
[docs] @staticmethod def categorical_attr( categorical_attrs: CategoricalAttrFilterType, ) -> Condition: """Build real attributes filter where condition.""" assert len(categorical_attrs) > 0 conditions = list(itertools.starmap( SqlQueryBuilder._categorical_attr_filter, categorical_attrs, )) if len(conditions) == 1: return conditions[0] result = conditions[0] for cond in conditions[1:]: result = result.and_(cond) return result
[docs] @staticmethod def ultra_rare() -> Condition: return SqlQueryBuilder._real_attr_filter( "af_allele_count", (None, 1), is_frequency=True)
[docs] def summary_query( self, *, regions: list[Region] | None = None, genes: list[str] | None = None, effect_types: list[str] | None = None, variant_type: str | None = None, real_attr_filter: RealAttrFilterType | None = None, categorical_attr_filter: CategoricalAttrFilterType | None = None, ultra_rare: bool | None = None, frequency_filter: RealAttrFilterType | None = None, return_reference: bool | None = None, return_unknown: bool | None = None, ) -> Select: """Build a summary variant query.""" query = self.summary_base() if genes is not None: regions = self.build_gene_regions(genes, regions) if regions is not None: clause = self.regions(regions) query = query.where(replace_placeholders( clause, chromosome=exp.to_column("sa.chromosome"), position=exp.to_column("sa.position"), end_position=exp.to_column("sa.end_position"), )) if real_attr_filter: clause = self.real_attr(real_attr_filter) query = query.where(clause) if categorical_attr_filter: clause = self.categorical_attr(categorical_attr_filter) query = query.where(clause) if frequency_filter: clause = self.frequency(frequency_filter) query = query.where(clause) if ultra_rare is not None and ultra_rare: clause = self.ultra_rare() query = query.where(clause) if variant_type is not None: query = query.where( self.build_variant_types_query( variant_type, "sa.variant_type")) if not return_reference and not return_unknown: query = query.where("sa.allele_index > 0") if genes is not None or effect_types is not None: summary_effects = parse_one( "select *, unnest(sa.effect_gene) as eg " "from summary_base as sa", ) summary = exp.select( "*", ).from_( "summary_effects", ) if genes is not None: summary = summary.where( SqlQueryBuilder.genes(genes)) if effect_types is not None: summary = summary.where( SqlQueryBuilder.effect_types(effect_types)) query = Select().with_( "summary_base", as_=query, ).with_( "summary_effects", as_=summary_effects, ).with_( "summary", as_=summary, ).select("*").from_("summary") return query
[docs] @staticmethod def roles( roles_query: str, zygosity: int | None, ) -> Condition: """ Construct a roles query condition. Can match for zygosity in roles with a precalculated mask. """ query = condition( SqlQueryBuilder.build_roles_query( roles_query, "fa.allele_in_roles")) if zygosity is not None: query = query.and_( parse_one(f"fa.zygosity_in_roles & {zygosity} = {zygosity}"), ) return query
[docs] @staticmethod def sexes( sexes_query: str, zygosity: int | None, ) -> Condition: """ Construct a sexes query condition. Can match for zygosity in sexes with a precalculated mask. """ query = condition( SqlQueryBuilder.build_sexes_query( sexes_query, "fa.allele_in_sexes")) if zygosity is not None: query = query.and_( parse_one(f"fa.zygosity_in_sexes & {zygosity} != 0"), ) return query
[docs] @staticmethod def statuses( statuses_query: str, ) -> Condition: return condition( SqlQueryBuilder.build_statuses_query( statuses_query, "fa.allele_in_statuses"))
[docs] @staticmethod def inheritance( inheritance_query: str | Sequence[str], ) -> Condition: """Build inheritance filter.""" if isinstance(inheritance_query, str): return condition( SqlQueryBuilder.build_inheritance_query( [inheritance_query], "fa.inheritance_in_members")) return condition( SqlQueryBuilder.build_inheritance_query( inheritance_query, "fa.inheritance_in_members"))
[docs] @staticmethod def family_ids( family_ids: Sequence[str], ) -> Condition: """Create family IDs filter.""" if not family_ids: return condition("fa.family_id IS NULL") if len(family_ids) == 1: return condition(f"fa.family_id = '{next(iter(family_ids))}'") fids = [f"'{fid}'" for fid in family_ids] return condition(f"fa.family_id IN ({', '.join(fids)})")
[docs] @staticmethod def person_ids( person_ids: Sequence[str], ) -> Condition: """Create person IDs filter.""" if not person_ids: return condition("fa.aim IS NULL") if len(person_ids) == 1: return condition(f"fa.aim = '{next(iter(person_ids))}'") pids = [f"'{pid}'" for pid in person_ids] return condition(f"fa.aim IN ({', '.join(pids)})")
[docs] @staticmethod def resolve_tags( tags_query: TagsQuery, pedigree_table: Table, ) -> Expression | None: """Resolve tags query to an expression to use as a condition.""" pedigree_tags = None if tags_query.selected_family_tags is not None: for tag in tags_query.selected_family_tags: comparison = column( tag, pedigree_table.alias_or_name).eq("True") if pedigree_tags is None: pedigree_tags = comparison else: if tags_query.tags_or_mode: pedigree_tags = pedigree_tags.or_(comparison) else: pedigree_tags = pedigree_tags.and_(comparison) if tags_query.deselected_family_tags is not None: for tag in tags_query.deselected_family_tags: comparison = column( tag, pedigree_table.alias_or_name).eq("False") if pedigree_tags is None: pedigree_tags = comparison else: if tags_query.tags_or_mode: pedigree_tags = pedigree_tags.or_(comparison) else: pedigree_tags = pedigree_tags.and_(comparison) return pedigree_tags
[docs] def family_query( # pylint: disable=too-many-branches self, *, family_ids: Sequence[str] | None = None, person_ids: Sequence[str] | None = None, inheritance: str | Sequence[str] | None = None, roles_in_parent: str | None = None, roles_in_child: str | None = None, sexes: str | None = None, affected_statuses: str | None = None, tags_query: TagsQuery | None = None, zygosity: ZygosityQuery | None = None, ) -> Select: """Build a family subclause query.""" if tags_query is None: tags_query = TagsQuery() if zygosity is None: zygosity = ZygosityQuery() query = self.family_base() if roles_in_parent is not None: clause = self.roles(roles_in_parent, zygosity.parents_zygosity) query = query.where(clause) if roles_in_child is not None: clause = self.roles(roles_in_child, zygosity.children_zygosity) query = query.where(clause) if inheritance is not None: clause = self.inheritance(inheritance) query = query.where(clause) if sexes is not None: clause = self.sexes(sexes, zygosity.sex_zygosity) query = query.where(clause) if affected_statuses is not None: clause = self.statuses(affected_statuses) query = query.where(clause) if zygosity.status_zygosity is not None: status_zygosity = self.calc_zygosity_status_value( affected_statuses, zygosity, ) if status_zygosity is not None: expr = parse_one( "(fa.zygosity_in_status & " f"{status_zygosity}) != 0") query = query.where(expr) if family_ids is not None or person_ids is not None: if person_ids is not None: person_ids = [ pid for pid in person_ids if pid in self.families.persons_by_person_id ] fids = { self.families.persons_by_person_id[pid][0].family_id for pid in person_ids } if family_ids is not None: fids &= set(family_ids) family_ids = list(fids) assert family_ids is not None clause = self.family_ids(family_ids) query = query.where(clause) pedigree_table = table_("pedigree_table", alias="ped") pedigree_tags = self.resolve_tags(tags_query, pedigree_table) base_table = "family_base" ctes = [["family_base", query]] if 
pedigree_tags is not None: tagged_families_query = exp.select("family_id").from_( pedigree_table, ).where( pedigree_tags, ) filtered_by_tags_query = exp.select( f"{base_table}.*", ).from_("family_base").join( "tagged", on=( column("family_id", "family_base").eq( column("family_id", "tagged"), ) ), ) ctes.extend([ ["tagged", tagged_families_query], ["filtered_by_tags", filtered_by_tags_query], ]) base_table = "filtered_by_tags" if person_ids is not None: family_members = parse_one( "select *, unnest(fa.allele_in_members) as aim " # noqa: S608 f"from {base_table} as fa", ) family_query = exp.select( "*", ).from_( "family_members as fa", ).where( SqlQueryBuilder.person_ids(person_ids), ) ctes.extend([ ["family_members", family_members], ["filtered_members", family_query], ]) base_table = "filtered_members" if len(ctes) > 1: ctes[-1][0] = "family" query = Select() for cte in ctes: query = self._append_cte(query, cte[1], cte[0]) query = query.select("*").from_("family") return query
@staticmethod
def _heuristic_bins(
    table: str, heuristic: str, bins: list[str],
) -> Condition:
    """Restrict the ``table.heuristic`` bin column to the given ``bins``.

    A single bin yields an equality test and several bins yield an ``IN``
    list. ``bins`` must not be empty.
    """
    assert bins
    target_column = f"{table}.{heuristic}"
    if len(bins) > 1:
        return condition(f"{target_column} IN ({', '.join(bins)})")
    return condition(f"{target_column} = {bins[0]}")
@staticmethod
def region_bins(table: str, region_bins: list[str]) -> Condition:
    """Return a condition limiting ``<table>.region_bin`` to ``region_bins``."""
    return SqlQueryBuilder._heuristic_bins(
        table=table,
        heuristic="region_bin",
        bins=region_bins,
    )
@staticmethod
def frequency_bins(table: str, frequency_bins: list[str]) -> Condition:
    """Return a condition limiting ``<table>.frequency_bin`` to the bins."""
    return SqlQueryBuilder._heuristic_bins(
        table=table,
        heuristic="frequency_bin",
        bins=frequency_bins,
    )
@staticmethod
def coding_bins(table: str, coding_bins: list[str]) -> Condition:
    """Return a condition limiting ``<table>.coding_bin`` to ``coding_bins``."""
    return SqlQueryBuilder._heuristic_bins(
        table=table,
        heuristic="coding_bin",
        bins=coding_bins,
    )
@staticmethod
def family_bins(table: str, family_bins: list[str]) -> Condition:
    """Return a condition limiting ``<table>.family_bin`` to ``family_bins``."""
    return SqlQueryBuilder._heuristic_bins(
        table=table,
        heuristic="family_bin",
        bins=family_bins,
    )
def apply_summary_heuristics(
    self,
    query: Select,
    heuristics: QueryHeuristics | None,
    table: str = "sa",
) -> Select:
    """Apply region/frequency/coding bin heuristics to a query.

    The bin conditions are added to the body of the first CTE when
    ``query`` has CTEs, otherwise to ``query`` itself.

    ``query`` is never modified or shared with the result. The conditions
    are applied in place to a single deep copy, so the same base query can
    safely be reused for several heuristic batches. The previous code
    assigned ``result.ctes[0].args["this"]`` directly, which bypassed
    sqlglot's parent bookkeeping. When no summary bin applied, it also
    grafted the caller's own CTE node into the copy.

    Args:
        query: query to restrict.
        heuristics: bins to restrict by; ``None`` or an empty heuristics
            object leaves ``query`` unchanged.
        table: alias of the table that owns the bin columns.

    Returns:
        A restricted copy of ``query``, or ``query`` itself when there is
        nothing to apply at this level.
    """
    if heuristics is None or heuristics.is_empty():
        return query

    conditions = []
    if heuristics.region_bins:
        conditions.append(self.region_bins(table, heuristics.region_bins))
    if heuristics.frequency_bins:
        conditions.append(
            self.frequency_bins(table, heuristics.frequency_bins))
    if heuristics.coding_bins:
        conditions.append(self.coding_bins(table, heuristics.coding_bins))
    if not conditions:
        # e.g. only family bins are set; nothing to add at this level.
        return query

    result = query.copy()
    # The first CTE holds the base scan; fall back to the query itself.
    target = cast(Select, result.ctes[0].this if result.ctes else result)
    for cond in conditions:
        # copy=False mutates ``target`` in place, i.e. inside ``result``.
        target.where(cond, copy=False)
    return result
def apply_family_heuristics(
    self,
    query: Select,
    heuristics: QueryHeuristics | None,
) -> Select:
    """Apply heuristics to the family query.

    Region/frequency/coding bins are delegated to
    ``apply_summary_heuristics`` with the ``fa`` alias. The family-bin
    condition is added in the same place: the body of the first CTE when
    ``query`` has CTEs, otherwise ``query`` itself.

    Previously the family-bin condition was added to the outer statement.
    For ``family_query`` results that statement is
    ``SELECT * FROM family``, where the ``fa`` alias is not in scope.

    Args:
        query: family query to restrict (typically from ``family_query``).
        heuristics: bins to restrict by; ``None`` or an empty heuristics
            object leaves ``query`` unchanged.

    Returns:
        The restricted query; ``query`` itself is not modified.
    """
    if heuristics is None or heuristics.is_empty():
        return query
    query = self.apply_summary_heuristics(query, heuristics, table="fa")
    if not heuristics.family_bins:
        return query

    # Copy first so no node shared with the caller's query is mutated.
    result = query.copy()
    target = cast(Select, result.ctes[0].this if result.ctes else result)
    target.where(
        self.family_bins("fa", heuristics.family_bins), copy=False)
    return result
def build_summary_variants_query(
    self,
    *,
    regions: list[Region] | None = None,
    genes: list[str] | None = None,
    effect_types: list[str] | None = None,
    variant_type: str | None = None,
    real_attr_filter: RealAttrFilterType | None = None,
    categorical_attr_filter: CategoricalAttrFilterType | None = None,
    ultra_rare: bool | None = None,
    frequency_filter: RealAttrFilterType | None = None,
    return_reference: bool | None = None,
    return_unknown: bool | None = None,
    limit: int | None = None,
) -> list[str]:
    """Build a query for summary variants.

    One summary query is built from the filters. Then one SQL string is
    rendered per heuristics batch from ``calc_batched_heuristics``, with
    the optional ``limit`` applied and the placeholder table names
    replaced by the real ones.
    """
    base_summary = self.summary_query(
        regions=regions,
        genes=genes,
        effect_types=effect_types,
        variant_type=variant_type,
        real_attr_filter=real_attr_filter,
        categorical_attr_filter=categorical_attr_filter,
        ultra_rare=ultra_rare,
        frequency_filter=frequency_filter,
        return_reference=return_reference,
        return_unknown=return_unknown,
    )
    batches = self.calc_batched_heuristics(
        regions=regions,
        genes=genes,
        effect_types=effect_types,
        ultra_rare=ultra_rare,
        frequency_filter=frequency_filter,
    )

    def render(batch) -> str:
        restricted = self.apply_summary_heuristics(base_summary, batch)
        statement = self.summary_variants(summary=restricted)
        if limit is not None:
            statement = statement.limit(limit)
        return self.replace_tables(statement).sql()

    return [render(batch) for batch in batches]
def replace_tables(self, query: Select) -> Select:
    """Replace placeholder table names in the query with real tables.

    ``pedigree_table`` is always replaced. ``summary_table`` and
    ``family_table`` are replaced only when the layout has a summary
    table, in which case it must also have a family table. When
    ``db_layout.db`` is set, every replacement is qualified as
    ``<db>.<table>``. Previously the pedigree table was left unqualified
    when the layout had no summary/family tables, unlike the other case.
    """
    layout = self.db_layout

    def qualify(name: str) -> str:
        return name if layout.db is None else f"{layout.db}.{name}"

    mapping: dict[str, str] = {}
    if layout.summary is None:
        assert layout.family is None
    else:
        assert layout.family is not None
        mapping["summary_table"] = qualify(layout.summary)
        mapping["family_table"] = qualify(layout.family)
    mapping["pedigree_table"] = qualify(layout.pedigree)
    return exp.replace_tables(query, mapping)
@staticmethod
def calc_zygosity_status_value(
    affected_statuses: str | None,
    zygosity: ZygosityQuery,
) -> int | None:
    """Extract from a query an int for filtering zygosity by status.

    Returns ``None`` when ``zygosity.status_zygosity`` is unset. Otherwise
    the value of ``Zygosity.from_name(zygosity.status_zygosity)`` is placed
    in the low bits for unaffected status and shifted left by 2 for
    affected status. With no ``affected_statuses`` both are included;
    otherwise each is included only if the status query matches it.
    """
    if zygosity.status_zygosity is None:
        return None

    if affected_statuses is None:
        want_affected = want_unaffected = True
    else:
        want_affected = QueryBuilderBase.check_statuses_query_value(
            affected_statuses, Status.affected.value)
        want_unaffected = QueryBuilderBase.check_statuses_query_value(
            affected_statuses, Status.unaffected.value)

    bits = Zygosity.from_name(zygosity.status_zygosity).value
    mask = 0
    if want_unaffected:
        mask |= bits
    if want_affected:
        mask |= bits << 2
    return mask
def build_family_variants_query(  # pylint: disable=too-many-arguments
    self,
    *,
    regions: list[Region] | None = None,
    genes: list[str] | None = None,
    effect_types: list[str] | None = None,
    family_ids: Sequence[str] | None = None,
    person_ids: Sequence[str] | None = None,
    inheritance: Sequence[str] | None = None,
    roles_in_parent: str | None = None,
    roles_in_child: str | None = None,
    sexes: str | None = None,
    affected_statuses: str | None = None,
    variant_type: str | None = None,
    real_attr_filter: RealAttrFilterType | None = None,
    categorical_attr_filter: CategoricalAttrFilterType | None = None,
    ultra_rare: bool | None = None,
    frequency_filter: RealAttrFilterType | None = None,
    return_reference: bool | None = None,
    return_unknown: bool | None = None,
    limit: int | None = None,
    tags_query: TagsQuery | None = None,
    zygosity_query: ZygosityQuery | None = None,
    **_kwargs: Any,
) -> list[str]:
    """Build a query for family variants.

    A summary query and a family query are built from the filters. Then
    one SQL string is rendered per heuristics batch from
    ``calc_batched_heuristics``. For the heuristics, parent and child role
    filters are merged into one ``(<parent>) and (<child>)`` expression
    when both are given.
    """
    base_summary = self.summary_query(
        regions=regions,
        genes=genes,
        effect_types=effect_types,
        variant_type=variant_type,
        real_attr_filter=real_attr_filter,
        categorical_attr_filter=categorical_attr_filter,
        ultra_rare=ultra_rare,
        frequency_filter=frequency_filter,
        return_reference=return_reference,
        return_unknown=return_unknown,
    )
    base_family = self.family_query(
        family_ids=family_ids,
        person_ids=person_ids,
        inheritance=inheritance,
        roles_in_parent=roles_in_parent,
        roles_in_child=roles_in_child,
        sexes=sexes,
        affected_statuses=affected_statuses,
        tags_query=tags_query,
        zygosity=zygosity_query,
    )

    if roles_in_parent and roles_in_child:
        combined_roles = f"({roles_in_parent}) and ({roles_in_child})"
    else:
        # Whichever single filter is set (child first), else None.
        combined_roles = roles_in_child or roles_in_parent or None

    batches = self.calc_batched_heuristics(
        regions=regions,
        genes=genes,
        effect_types=effect_types,
        inheritance=inheritance,
        roles=combined_roles,
        ultra_rare=ultra_rare,
        frequency_filter=frequency_filter,
        family_ids=family_ids,
    )

    def render(batch) -> str:
        statement = self.family_variants(
            summary=self.apply_summary_heuristics(base_summary, batch),
            family=self.apply_family_heuristics(base_family, batch),
        )
        if limit is not None:
            statement = statement.limit(limit)
        return self.replace_tables(statement).sql()

    return [render(batch) for batch in batches]