Source code for dae.query_variants.sql.schema2.base_query_builder

import logging
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import Any, cast

from dae.genomic_resources.gene_models import (
    GeneModels,
    create_regions_from_genes,
)
from dae.genomic_resources.reference_genome import ReferenceGenome
from dae.parquet.partition_descriptor import PartitionDescriptor
from dae.pedigrees.families_data import FamiliesData
from dae.query_variants.attributes_query import (
    LeafNode,
    QueryTransformerMatcher,
    QueryTreeToSQLBitwiseTransformer,
    TreeNode,
    inheritance_query,
    role_query,
    sex_query,
    variant_type_query,
)
from dae.query_variants.attributes_query_inheritance import (
    InheritanceTransformer,
    inheritance_parser,
)
from dae.utils.regions import Region
from dae.variants.attributes import Inheritance

logger = logging.getLogger(__name__)
RealAttrFilterType = list[tuple[str, tuple[float | None, float | None]]]


[docs] class Dialect: """Caries info about a SQL dialect.""" def __init__(self, namespace: str | None = None): # namespace, self.namespace = namespace
[docs] @staticmethod def use_bit_and_function() -> bool: return True
[docs] @staticmethod def add_unnest_in_join() -> bool: return False
[docs] @staticmethod def float_type() -> str: return "float"
[docs] @staticmethod def array_item_suffix() -> str: return ".item"
[docs] @staticmethod def int_type() -> str: return "int"
[docs] @staticmethod def escape_char() -> str: return "`"
[docs] @staticmethod def escape_quote_char() -> str: return "\\"
[docs] def build_table_name(self, table: str, db: str | None) -> str: return f"`{self.namespace}`.{db}.{table}" if self.namespace else \ f"{db}.{table}"
[docs] def build_array_join(self, column: str, allias: str) -> str: return f"\n JOIN\n {column} AS {allias}"
# A type describing a schema as expected by the query builders TableSchema = dict[str, str] # family_variant_table & summary_allele_table are mandatory # - no reliance on a variants table as in impala
class BaseQueryBuilder(ABC):
    """Class that abstracts away the process of building a query.

    `build_query` assembles the SQL text step by step (SELECT, FROM, JOIN,
    WHERE, GROUP BY, HAVING, LIMIT) into `self._product`.  Subclasses
    supply the clauses that depend on the concrete kind of query by
    implementing the abstract methods; the WHERE clause and the partition
    bin heuristics are implemented here.

    In the generated SQL the summary allele table is referenced as `sa`,
    the family variant table as `fa`, the effect gene struct as `eg` and
    the person id array as `pi`.
    """

    # pylint: disable=too-many-instance-attributes

    QUOTE = "'"
    WHERE = """ WHERE {where} """
    GENE_REGIONS_HEURISTIC_CUTOFF = 20
    GENE_REGIONS_HEURISTIC_EXTEND = 20000
    # Maximum number of literals placed in a single `IN ( ... )` list.
    MAX_CHILD_NUMBER = 9999

    def __init__(
        self,
        dialect: Dialect,
        db: str | None,
        family_variant_table: str | None,
        summary_allele_table: str,
        pedigree_table: str,
        family_variant_schema: TableSchema | None,
        summary_allele_schema: TableSchema,
        partition_config: dict[str, Any] | None,
        pedigree_schema: TableSchema,
        families: FamiliesData,
        gene_models: GeneModels | None = None,
        reference_genome: ReferenceGenome | None = None,
    ):
        """Store the table names and schemas the queries are built from.

        A missing family variant schema or partition config is treated as
        empty.  `_query_columns` and `_where_accessors` are invoked last,
        after every other attribute has been assigned.
        """
        # pylint: disable=too-many-arguments
        assert summary_allele_table is not None

        self.dialect = dialect
        self.db = db
        self.family_variant_table = family_variant_table
        self.summary_allele_table = summary_allele_table
        self.pedigree_table = pedigree_table
        self.partition_config = partition_config or {}
        self.partition_descriptor = PartitionDescriptor.parse_dict(
            self.partition_config)

        if not family_variant_schema:
            family_variant_schema = {}
        self.family_columns = family_variant_schema.keys()
        self.summary_columns = summary_allele_schema.keys()
        # On a name clash the summary allele column type wins.
        self.combined_columns = {
            **family_variant_schema,
            **summary_allele_schema,
        }
        self.pedigree_columns = pedigree_schema
        self.families = families
        self.has_extra_attributes = \
            "extra_attributes" in self.combined_columns
        self._product = ""
        self.gene_models = gene_models
        self.reference_genome = reference_genome
        self.query_columns = self._query_columns()
        self.where_accessors = self._where_accessors()

    def _where_accessors(self) -> dict[str, str]:
        """Map every known column to its table-qualified name.

        Columns of the summary allele table become `sa.<column>`; the
        remaining (family variant only) columns become `fa.<column>`.
        """
        summary_keys = set(self.summary_columns)
        accessors: dict[str, str] = {}
        for column in [*self.family_columns, *self.summary_columns]:
            # A column present in both tables is read from the summary
            # allele table.
            table_alias = "sa" if column in summary_keys else "fa"
            accessors[column] = f"{table_alias}.{column}"
        return accessors

    def build_query(
        self, *,
        regions: list[Region] | None = None,
        genes: list[str] | None = None,
        effect_types: list[str] | None = None,
        family_ids: Iterable[str] | None = None,
        person_ids: Iterable[str] | None = None,
        inheritance: str | list[str] | None = None,
        roles: str | None = None,
        sexes: str | None = None,
        variant_type: str | None = None,
        real_attr_filter: RealAttrFilterType | None = None,
        ultra_rare: bool | None = None,
        frequency_filter: RealAttrFilterType | None = None,
        return_reference: bool | None = None,
        return_unknown: bool | None = None,
        limit: int | None = None,
        pedigree_fields: tuple | None = None,
    ) -> str:
        # pylint: disable=too-many-arguments,too-many-locals,unused-argument
        """Build an SQL query in the correct order.

        The previously built query is discarded, then the SELECT, FROM,
        JOIN, WHERE, GROUP BY, HAVING and LIMIT parts are appended in
        that order.  Returns the assembled SQL text.
        """
        self._product = ""
        self._build_select()
        self._build_from()
        self._build_join(genes=genes, effect_types=effect_types)
        self._build_where(
            regions=regions,
            genes=genes,
            effect_types=effect_types,
            family_ids=family_ids,
            person_ids=person_ids,
            inheritance=inheritance,
            roles=roles,
            sexes=sexes,
            variant_type=variant_type,
            real_attr_filter=real_attr_filter,
            ultra_rare=ultra_rare,
            frequency_filter=frequency_filter,
            return_reference=return_reference,
            return_unknown=return_unknown,
            pedigree_fields=pedigree_fields,
        )
        self._build_group_by()
        self._build_having()
        self._build_limit(limit)
        return self._product

    def _build_select(self) -> None:
        """Append the SELECT clause listing `self.query_columns`."""
        columns = ", ".join(self.query_columns)
        self._add_to_product(f"SELECT {columns}")

    @abstractmethod
    def _build_from(self) -> None:
        """Build from clause."""

    @abstractmethod
    def _build_join(
        self,
        genes: list[str] | None,
        effect_types: list[str] | None,
    ) -> None:
        """Build join clause."""

    def _build_where_pedigree_fields(
        self, pedigree_fields: tuple | None,  # noqa: ARG002
    ) -> str:
        """Return the pedigree fields condition; empty in the base class."""
        # pylint: disable=unused-argument
        return ""

    def _build_where(
        self, *,
        regions: list[Region] | None = None,
        genes: list[str] | None = None,
        effect_types: list[str] | None = None,
        family_ids: Iterable[str] | None = None,
        person_ids: Iterable[str] | None = None,
        inheritance: str | list[str] | None = None,
        roles: str | None = None,
        sexes: str | None = None,
        variant_type: str | None = None,
        real_attr_filter: RealAttrFilterType | None = None,
        ultra_rare: bool | None = None,
        frequency_filter: RealAttrFilterType | None = None,
        return_reference: bool | None = None,
        return_unknown: bool | None = None,
        pedigree_fields: tuple | None = None,
        **kwargs: Any,  # noqa: ARG002
    ) -> None:
        """Build the WHERE clause and append it to the query."""
        # pylint: disable=too-many-arguments,too-many-locals,unused-argument
        where_clause = self._build_where_string(
            regions=regions,
            genes=genes,
            effect_types=effect_types,
            family_ids=family_ids,
            person_ids=person_ids,
            inheritance=inheritance,
            roles=roles,
            sexes=sexes,
            variant_type=variant_type,
            real_attr_filter=real_attr_filter,
            ultra_rare=ultra_rare,
            frequency_filter=frequency_filter,
            return_reference=return_reference,
            return_unknown=return_unknown,
            pedigree_fields=pedigree_fields,
        )
        self._add_to_product(where_clause)

    def _build_where_string(
        self, *,
        regions: list[Region] | None = None,
        genes: list[str] | None = None,
        effect_types: list[str] | None = None,
        family_ids: Iterable[str] | None = None,
        person_ids: Iterable[str] | None = None,
        inheritance: str | list[str] | None = None,
        roles: str | None = None,
        sexes: str | None = None,
        variant_type: str | None = None,
        real_attr_filter: RealAttrFilterType | None = None,
        ultra_rare: bool | None = None,
        frequency_filter: RealAttrFilterType | None = None,
        return_reference: bool | None = None,
        return_unknown: bool | None = None,
        pedigree_fields: tuple | None = None,
        **kwargs: Any,  # noqa: ARG002
    ) -> str:
        """Return the WHERE clause text for the given filters.

        Each filter that is not None contributes one condition; the
        partition bin heuristics and the reference allele condition are
        always evaluated.  Empty conditions are dropped and the rest are
        joined with AND.  Returns an empty string when nothing remains.
        """
        # pylint: disable=too-many-arguments,too-many-branches,unused-argument
        where = []

        if genes is not None:
            # Narrow (or create) the regions from the requested genes.
            regions = self._build_gene_regions_heuristic(genes, regions)
            if effect_types is not None:
                clause = self._build_iterable_struct_string_attr_where(
                    ["eg.effect_gene_symbols", "eg.effect_types"],
                    [genes, effect_types],
                )
                where.append(f"({clause})")
            else:
                # effect gene is a struct under affect gene in V2 schema
                where.append(
                    self._build_iterable_struct_string_attr_where(
                        ["eg.effect_gene_symbols"], [genes],
                    ),
                )
        elif effect_types is not None:
            # effect gene is a struct under affect gene in V2 schema
            where.append(
                self._build_iterable_struct_string_attr_where(
                    ["eg.effect_types"], [effect_types],
                ),
            )

        if regions is not None:
            where.append(self._build_regions_where(regions))

        if family_ids is not None:
            where.append(
                self._build_iterable_string_attr_where(
                    self.where_accessors["family_id"], family_ids,
                ),
            )

        if person_ids is not None:
            person_ids = set(person_ids)
            where.append(
                self._build_iterable_string_attr_where(
                    "pi" + self.dialect.array_item_suffix(), person_ids),
            )

        if inheritance is not None:
            where.extend(
                self._build_inheritance_where(
                    self.where_accessors["inheritance_in_members"],
                    inheritance,
                    use_bit_and_function=self.dialect.use_bit_and_function(),
                ),
            )

        if roles is not None:
            where.append(
                self._build_bitwise_attr_where(
                    self.where_accessors["allele_in_roles"],
                    roles, role_query,
                ),
            )

        if sexes is not None:
            where.append(
                self._build_bitwise_attr_where(
                    self.where_accessors["allele_in_sexes"],
                    sexes, sex_query,
                ),
            )

        if variant_type is not None:
            where.append(
                self._build_bitwise_attr_where(
                    self.where_accessors["variant_type"],
                    variant_type, variant_type_query,
                ),
            )

        if real_attr_filter is not None:
            where.append(self._build_real_attr_where(real_attr_filter))

        if frequency_filter is not None:
            where.append(
                self._build_real_attr_where(
                    frequency_filter, is_frequency=True,
                ),
            )

        if ultra_rare:
            where.append(self._build_ultra_rare_where(ultra_rare=ultra_rare))

        where.extend([
            self._build_return_reference_and_return_unknown(
                return_reference=return_reference,
                return_unknown=return_unknown,
            ),
            self._build_frequency_bin_heuristic(
                inheritance=inheritance,
                ultra_rare=ultra_rare,
                frequency_filter=frequency_filter,
            ),
            self._build_family_bin_heuristic(family_ids, person_ids),
            self._build_coding_heuristic(effect_types),
            self._build_region_bin_heuristic(regions),
            self._build_where_pedigree_fields(pedigree_fields),
        ])

        # Builders return an empty string when they have nothing to add.
        where = [w for w in where if w]
        if not where:
            return ""

        separator = " AND \n" + " " * 4
        return self.WHERE.format(
            where=separator.join([f"( {w} )" for w in where]),
        )

    @abstractmethod
    def _build_group_by(self) -> None:
        """Build group by clause."""

    def _build_limit(self, limit: int | None) -> None:
        """Append a LIMIT clause unless `limit` is None."""
        if limit is not None:
            self._add_to_product(f"LIMIT {limit}")

    @abstractmethod
    def _build_having(self, **kwargs: Any) -> None:
        """Build having clause."""

    def _add_to_product(self, query_part: str | None) -> None:
        """Append `query_part` to the query, separated by one space.

        None and empty parts are ignored.
        """
        if query_part is None or query_part == "":
            return
        if self._product == "":
            self._product += query_part
        else:
            self._product += f" {query_part}"

    @abstractmethod
    def _query_columns(self) -> list[str]:
        """Return the column expressions placed in the SELECT clause."""

    def _build_real_attr_where(
        self, real_attr_filter: RealAttrFilterType, *,
        is_frequency: bool = False,
    ) -> str:
        """Return range conditions for numeric attributes joined by AND.

        Every filter item is `(attribute, (left, right))` where a None
        boundary means unbounded.  An attribute missing from the schema
        yields `false`.  For frequency filters a null value satisfies an
        upper-bound-only range and a fully unbounded range adds no
        condition; for other filters a fully unbounded range requires the
        attribute to be not null.
        """
        query = []
        for attr_name, attr_range in real_attr_filter:
            if attr_name not in self.combined_columns:
                query.append("false")
                continue
            assert attr_name in self.combined_columns
            assert (
                self.combined_columns[attr_name] == self.dialect.float_type()
                or self.combined_columns[attr_name].startswith(
                    self.dialect.int_type())
            ), f"{attr_name} - {self.combined_columns}"

            left, right = attr_range
            column = self.where_accessors[attr_name]

            if left is None and right is None:
                if not is_frequency:
                    query.append(f"({column} is not null)")
            elif left is None:
                assert right is not None
                if is_frequency:
                    query.append(
                        f"({column} <= {right} or {column} is null)",
                    )
                else:
                    query.append(
                        f"({column} <= {right})",
                    )
            elif right is None:
                assert left is not None
                query.append(f"({column} >= {left})")
            else:
                query.append(
                    f"({column} >= {left} AND {column} <= {right})",
                )
        return " AND ".join(query)

    def _build_ultra_rare_where(self, *, ultra_rare: bool) -> str:
        """Return the condition keeping alleles with allele count <= 1."""
        assert ultra_rare
        return self._build_real_attr_where(
            real_attr_filter=[("af_allele_count", (None, 1))],
            is_frequency=True,
        )

    def _build_regions_where(self, regions: list[Region]) -> str:
        """Return the condition matching alleles in any of `regions`.

        A region without boundaries matches the whole chromosome.
        Otherwise an allele matches when its position or its end position
        lies inside the region, or when the allele spans the whole
        region.  A missing start defaults to 1 and a missing stop to
        3,000,000,000.  The per-region conditions are joined with OR.
        """
        assert isinstance(regions, list), regions

        # Both values are the same for every region.
        esc = self.dialect.escape_char()
        end_position = f"COALESCE(sa.{esc}end_position{esc}, -1)"

        where = []
        for region in regions:
            assert isinstance(region, Region)
            query = "( sa.{esc}chromosome{esc} = {q}{chrom}{q}"
            if region.start is None and region.end is None:
                query += " )"
                query = query.format(
                    q=self.QUOTE,
                    chrom=region.chrom,
                    esc=esc,
                )
            else:
                region_start = region.start or 1
                region_stop = region.stop or 3_000_000_000
                query += (
                    " AND "
                    "("
                    "(sa.{esc}position{esc} >= {start} AND "
                    "sa.{esc}position{esc} <= {stop}) "
                    "OR "
                    "({end_position} >= {start} AND "
                    "{end_position} <= {stop}) "
                    "OR "
                    "({start} >= sa.{esc}position{esc} AND "
                    "{stop} <= {end_position})"
                    "))"
                )
                query = query.format(
                    q=self.QUOTE,
                    chrom=region.chrom,
                    start=region_start,
                    stop=region_stop,
                    end_position=end_position,
                    esc=esc,
                )
            where.append(query)
        return " OR ".join(where)

    def _build_iterable_struct_string_attr_where(
        self,
        key_names: Iterable[str] | None = None,
        query_values: Iterable[Iterable[str]] | None = None,
    ) -> str:
        """Return membership conditions for several columns joined by AND.

        `key_names` and `query_values` are paired item by item and must
        have the same length.
        """
        key_names = key_names or []
        query_values = query_values or []
        inner_clauses = [
            self._build_iterable_string_attr_where(key, values)
            for key, values in zip(key_names, query_values, strict=True)
        ]
        return " AND ".join(inner_clauses)

    def _build_iterable_string_attr_where(
        self, column_name: str,
        query_values: Iterable[str],
    ) -> str:
        """Return a condition matching `column_name` against string values.

        `query_values` must be a list or a set.  An empty collection
        yields an `IS NULL` condition.  Otherwise the values are rendered
        as quoted SQL literals and split into `IN` lists of at most
        `MAX_CHILD_NUMBER` items that are joined with OR.
        """
        assert query_values is not None
        assert isinstance(query_values, (list, set)), type(query_values)

        if not query_values:
            return f" {column_name} IS NULL"

        quote = self.QUOTE
        escape = self.dialect.escape_quote_char()

        def as_literal(val: str) -> str:
            # The escape character itself has to be escaped first;
            # otherwise a value ending with it would swallow the closing
            # quote of the literal.  Dialects that escape a quote by
            # doubling it have no separate escape character to protect.
            if escape and escape != "'":
                val = val.replace(escape, escape + escape)
            val = val.replace("'", escape + "'")
            return f" {quote}{val}{quote} "

        values = [as_literal(val) for val in query_values]

        where: list[str] = []
        for i in range(0, len(values), self.MAX_CHILD_NUMBER):
            chunk_values = values[i: i + self.MAX_CHILD_NUMBER]
            in_expr = f" {column_name} in ( {','.join(chunk_values)} ) "
            where.append(in_expr)

        return " OR ".join([f"( {w} )" for w in where])

    def _build_bitwise_attr_where(
        self, column_name: str,
        query_value: str,
        query_transformer: QueryTransformerMatcher,
    ) -> str:
        """Return a bitwise condition on `column_name` for `query_value`.

        A string query is parsed into a tree by `query_transformer`; any
        other value is handed to the SQL transformer unchanged.
        """
        assert query_value is not None
        parsed: str | LeafNode | TreeNode = query_value
        if isinstance(query_value, str):
            parsed = query_transformer.transform_query_string_to_tree(
                query_value,
            )
        transformer = QueryTreeToSQLBitwiseTransformer(
            column_name, self.dialect.use_bit_and_function(),
        )
        return cast(str, transformer.transform(parsed))

    @staticmethod
    def _build_inheritance_where(
        column_name: str, query_value: str | list[str], *,
        use_bit_and_function: bool,
    ) -> list[str]:
        """Return one condition per inheritance query.

        A string or each item of a list is parsed with
        `inheritance_parser`; any other value is used as an already
        parsed tree.
        """
        if isinstance(query_value, str):
            trees = [inheritance_parser.parse(query_value)]
        elif isinstance(query_value, list):
            trees = [inheritance_parser.parse(qval) for qval in query_value]
        else:
            trees = [query_value]

        result = []
        for tree in trees:
            transformer = InheritanceTransformer(
                column_name,
                use_bit_and_function=use_bit_and_function)
            result.append(transformer.transform(tree))
        return result

    def _build_gene_regions_heuristic(
        self, genes: list[str], regions: list[Region] | None,
    ) -> list[Region] | None:
        """Return the regions produced from `genes` and `regions`.

        Delegates to `create_regions_from_genes` with the class cutoff
        and extension constants; gene models are required.
        """
        assert self.gene_models is not None
        return create_regions_from_genes(
            self.gene_models, genes, regions,
            self.GENE_REGIONS_HEURISTIC_CUTOFF,
            self.GENE_REGIONS_HEURISTIC_EXTEND,
        )

    def _build_partition_bin_heuristic_where(
        self, bin_column: str,
        bins: list[str] | set[str],
        number_of_possible_bins: int | None = None, *,
        str_bins: bool = False,
    ) -> str:
        """Return an `IN` condition restricting `bin_column` to `bins`.

        No condition is produced when `bins` is empty or holds as many
        items as `number_of_possible_bins`.  The condition is emitted for
        each of the family and summary tables that has the column, joined
        with AND; `str_bins` quotes the bin values.
        """
        if len(bins) == 0:
            return ""
        if number_of_possible_bins is not None and \
                len(bins) == number_of_possible_bins:
            return ""

        cols = []
        if bin_column in self.family_columns:
            cols.append("fa." + bin_column)
        if bin_column in self.summary_columns:
            cols.append("sa." + bin_column)

        if str_bins:
            bins_str = ",".join([f"'{rb}'" for rb in bins])
        else:
            bins_str = ",".join([f"{rb}" for rb in bins])

        parts = [f"{col} IN ({bins_str})" for col in cols]
        return " AND ".join(parts)

    def _build_frequency_bin_heuristic_compute_bins(
        self, *,
        inheritance: str | list[str] | None,
        ultra_rare: bool | None,
        frequency_filter: RealAttrFilterType | None,
        rare_boundary: float,
    ) -> set[str]:
        """Return the frequency bins that may hold matching variants.

        Bin "0" is added when an inheritance query matches denovo.  When
        there is no inheritance query, or one of them matches mendelian,
        possible denovo or possible omission, more bins are added
        depending on `ultra_rare` and on how the `af_allele_freq` range
        relates to `rare_boundary`.
        """
        frequency_bins: set[str] = set()

        matchers = []
        if inheritance is not None:
            logger.debug(
                "frequence_bin_heuristic inheritance: %s (%s)",
                inheritance, type(inheritance),
            )
            if isinstance(inheritance, str):
                inheritance = [inheritance]

            matchers = [
                inheritance_query.transform_tree_to_matcher(
                    inheritance_query.transform_query_string_to_tree(inh),
                )
                for inh in inheritance
            ]

            if any(m.match([Inheritance.denovo]) for m in matchers):
                frequency_bins.add("0")

        has_frequency_filter = False
        if frequency_filter:
            for name, _ in frequency_filter:
                if name == "af_allele_freq":
                    has_frequency_filter = True
                    break

        if inheritance is None or any(
            m.match(
                [
                    Inheritance.mendelian,
                    Inheritance.possible_denovo,
                    Inheritance.possible_omission,
                ],
            )
            for m in matchers
        ):
            if ultra_rare:
                frequency_bins |= {"0", "1"}
            elif has_frequency_filter:
                assert frequency_filter is not None
                for name, (begin, end) in frequency_filter:
                    if name == "af_allele_freq":
                        if end and end < rare_boundary:
                            frequency_bins |= {"0", "1", "2"}
                        elif (begin and begin >= rare_boundary) or \
                                (end is not None and end >= rare_boundary):
                            frequency_bins |= {"0", "1", "2", "3"}
            elif inheritance is not None:
                frequency_bins |= {"0", "1", "2", "3"}
        return frequency_bins

    def _build_frequency_bin_heuristic(
        self, *,
        inheritance: None | str | list[str],
        ultra_rare: bool | None,
        frequency_filter: RealAttrFilterType | None,
    ) -> str:
        """Return the `frequency_bin` condition, if one applies.

        Empty when the schema has no `frequency_bin` column.  Four
        frequency bins are possible, so selecting all of them adds no
        condition.
        """
        # pylint: disable=too-many-branches
        assert self.partition_config is not None
        if "frequency_bin" not in self.combined_columns:
            return ""

        rare_boundary = self.partition_config["rare_boundary"]

        frequency_bins = self._build_frequency_bin_heuristic_compute_bins(
            inheritance=inheritance,
            ultra_rare=ultra_rare,
            frequency_filter=frequency_filter,
            rare_boundary=rare_boundary)

        return self._build_partition_bin_heuristic_where(
            "frequency_bin", frequency_bins, 4)

    def _build_coding_heuristic(
        self, effect_types: None | set[str] | list[str],
    ) -> str:
        """Return the `coding_bin` condition, if one applies.

        Bin "1" is selected when every requested effect type is among the
        configured coding effect types and bin "0" when none is; a mix of
        both adds no condition.
        """
        assert self.partition_config is not None
        if effect_types is None:
            return ""
        if "coding_bin" not in self.combined_columns:
            return ""

        effect_types = set(effect_types)
        intersection = effect_types & set(
            self.partition_config["coding_effect_types"],
        )

        logger.debug(
            "coding bin heuristic: query effect types: %s; "
            "coding_effect_types: %s; => %s",
            effect_types,
            self.partition_config["coding_effect_types"],
            intersection == effect_types,
        )

        coding_bins = set()
        if intersection == effect_types:
            coding_bins.add("1")
        elif not intersection:
            coding_bins.add("0")

        return self._build_partition_bin_heuristic_where(
            "coding_bin", coding_bins, 2)

    def _build_region_bin_heuristic(
        self, regions: list[Region] | None,
    ) -> str:
        """Return the `region_bin` condition covering `regions`.

        Empty when there are no regions or the partition descriptor has
        no region bins.  A reference genome is required to obtain the
        chromosome lengths.
        """
        if not regions or not self.partition_descriptor.has_region_bins():
            return ""
        assert self.partition_descriptor.has_region_bins()
        assert self.reference_genome is not None

        chromsome_lengths = self.reference_genome.get_all_chrom_lengths()
        region_bins = set()
        for region in regions:
            region_bins.update(
                self.partition_descriptor.region_to_region_bins(
                    region, chromsome_lengths,
                ))

        return self._build_partition_bin_heuristic_where(
            "region_bin", region_bins,
            str_bins=not self.partition_descriptor.integer_region_bins)

    def _build_family_bin_heuristic(
        self, family_ids: Iterable[str] | None,
        person_ids: Iterable[str] | None,
    ) -> str:
        """Return the `family_bin` condition for the requested people.

        Collects the family bins of all persons whose family id or person
        id was requested.  Empty when the variants schema or the pedigree
        schema has no `family_bin` column.
        """
        assert self.partition_config is not None
        if "family_bin" not in self.combined_columns:
            return ""
        if "family_bin" not in self.pedigree_columns:
            return ""

        family_bins: set[str] = set()
        if family_ids:
            family_ids = set(family_ids)
            family_bins = family_bins.union(
                {
                    p.family_bin  # type: ignore
                    for p in self.families.persons.values()
                    if p.family_id in family_ids
                },
            )

        if person_ids:
            person_ids = set(person_ids)
            family_bins = family_bins.union(
                {
                    p.family_bin  # type: ignore
                    for p in self.families.persons.values()
                    if p.person_id in person_ids
                },
            )

        return self._build_partition_bin_heuristic_where(
            "family_bin", family_bins,
            self.partition_config["family_bin_size"])

    def _build_return_reference_and_return_unknown(
        self, *, return_reference: bool | None = None,
        return_unknown: bool | None = None,  # noqa: ARG002
    ) -> str:
        """Return the condition excluding reference alleles.

        Unless `return_reference` is truthy only alleles with an allele
        index above 0 are kept.
        """
        # pylint: disable=unused-argument
        allele_index_col = self.where_accessors["allele_index"]
        if not return_reference:
            return f"{allele_index_col} > 0"
        # return_unknown basically means return all so no specific where
        # expression is required
        return ""