Source code for dae.gene_profile.db

import logging
from copy import copy
from typing import Any

import duckdb
import numpy as np
from box import Box
from sqlglot import column, select
from sqlglot.expressions import (
    ColumnConstraint,
    ColumnDef,
    Create,
    DataType,
    PrimaryKeyColumnConstraint,
    Schema,
    delete,
    insert,
    table_,
    update,
    values,
)

from dae.gene_profile.statistic import GPStatistic
from dae.utils.sql_utils import to_duckdb_transpile

logger = logging.getLogger(__name__)


class GeneProfileDB:
    """
    Read-only access to the gene profile database.

    Queries the ``gene_profile`` table stored in ``dbfile`` through DuckDB
    (the DuckDB ``sqlite`` extension is installed and loaded to support
    legacy sqlite gene profile databases) and converts rows into the
    structures described by the supplied gene profile configuration.
    """

    PAGE_SIZE = 50

    def __init__(
        self,
        configuration: Box | dict | None,
        dbfile: str,
    ):
        # Support legacy sqlite gpdb
        duckdb.execute("INSTALL sqlite;")
        duckdb.execute("LOAD sqlite;")
        self.dbfile = dbfile
        self.configuration = \
            GeneProfileDBWriter.build_configuration(configuration)
        self.table = table_("gene_profile")
        # Maps "<collection_id>_<set_id>" to the category listing that set.
        self.gene_sets_categories = {}
        if self.configuration:
            for category in self.configuration["gene_sets"]:
                category_name = category["category"]
                for gene_set in category["sets"]:
                    collection_id = gene_set["collection_id"]
                    set_id = gene_set["set_id"]
                    self.gene_sets_categories[
                        f"{collection_id}_{set_id}"
                    ] = category_name

    def _symbol_column(self):
        """Return the ``symbol_name`` column qualified with the GP table."""
        return column("symbol_name", table=self.table.alias_or_name)

    @staticmethod
    def _fetch_records(connection, query) -> list[dict]:
        """Run ``query`` and return its rows as dicts with NaN -> None."""
        frame = connection.execute(to_duckdb_transpile(query)).df()
        return frame.replace([np.nan], [None]).to_dict("records")

    @staticmethod
    def _pick_best_match(rows: list[dict], gene_symbol: str) -> dict:
        """
        Pick the row whose symbol equals ``gene_symbol`` ignoring case.

        Falls back to the first row when no row matches exactly.
        ``rows`` must not be empty.
        """
        wanted = gene_symbol.casefold()
        for row in rows:
            if str(row["symbol_name"]).casefold() == wanted:
                return row
        return rows[0]

    def _paginate(self, query, page: int | None):
        """Apply LIMIT/OFFSET for a 1-based ``page``; None disables paging."""
        if page is None:
            return query
        # Pages below 1 are clamped to the first page so that the OFFSET
        # can never become negative.
        offset = self.PAGE_SIZE * max(page - 1, 0)
        return query.limit(self.PAGE_SIZE).offset(offset)

    def get_gp(self, gene_symbol: str) -> GPStatistic | None:
        """
        Query a GP by gene_symbol and return the row as statistic.

        An exact match on the symbol is tried first. If there is none, the
        lookup falls back to a case-insensitive substring search and
        prefers a row whose symbol equals ``gene_symbol`` ignoring case,
        otherwise the first row returned.

        Returns None if gene_symbol is not found within the DB.
        """
        exact_query = (
            select("*")
            .from_(self.table)
            .where(self._symbol_column().eq(gene_symbol))
            .limit(1)
        )
        fuzzy_query = (
            select("*")
            .from_(self.table)
            .where(self._symbol_column().ilike(f"%{gene_symbol}%"))
            .limit(self.PAGE_SIZE)
        )

        with duckdb.connect(f"{self.dbfile}", read_only=True) as connection:
            rows = self._fetch_records(connection, exact_query)
            if not rows:
                rows = self._fetch_records(connection, fuzzy_query)

        if not rows:
            return None

        return self.gp_from_table_row_single_view(
            self._pick_best_match(rows, gene_symbol),
        )

    def gp_from_table_row(self, row: dict) -> dict:
        # pylint: disable=too-many-locals
        """
        Build a flat, display-oriented dict from an internal DB row.

        Keys are dot-joined identifiers: ``<category>.<score_name>`` for
        genomic scores, ``<category>_rank.<set_id>`` for gene sets and
        ``<dataset_id>.<set_name>.<statistic_id>`` for variant counts.
        Falsy score values and counts (including 0) are reported as None.
        """
        config = self.configuration
        result = {"geneSymbol": row["symbol_name"]}

        for gs_category in config["genomic_scores"]:
            category_name = gs_category["category"]
            for score in gs_category["scores"]:
                score_name = score["score_name"]
                value = row[f"{category_name}_{score_name}"]
                result[f"{category_name}.{score_name}"] = value or None

        for gs_category in config["gene_sets"]:
            category_name = gs_category["category"]
            for gene_set in gs_category["sets"]:
                set_id = gene_set["set_id"]
                collection_id = gene_set["collection_id"]
                # A check mark flags membership in the gene set.
                result[f"{category_name}_rank.{set_id}"] = \
                    "\u2713" if row[f"{collection_id}_{set_id}"] else None

        for dataset_id, filters in config["datasets"].items():
            for person_set in filters["person_sets"]:
                set_name = person_set["set_name"]
                for statistic in filters["statistics"]:
                    statistic_id = statistic["id"]
                    column_id = f"{dataset_id}_{set_name}_{statistic_id}"
                    count = row[column_id]
                    rate = row[f"{column_id}_rate"]
                    result[f"{dataset_id}.{set_name}.{statistic_id}"] = \
                        f"{count} ({round(rate, 2)})" if count else None

        return result

    def gp_from_table_row_single_view(self, row: dict) -> GPStatistic:
        """Create a GPStatistic from a single view row."""
        # pylint: disable=too-many-locals
        config = self.configuration
        gene_symbol = row["symbol_name"]

        genomic_scores: dict[str, dict] = {}
        for gs_category in config["genomic_scores"]:
            category_name = gs_category["category"]
            category_scores = genomic_scores[category_name] = {}
            for score in gs_category["scores"]:
                score_name = score["score_name"]
                category_scores[score_name] = {
                    "value": row[f"{category_name}_{score_name}"],
                    "format": score["format"],
                }

        # Only gene sets whose column holds 1 are reported for the gene.
        gene_sets = []
        for gs_category in config["gene_sets"]:
            for gene_set in gs_category["sets"]:
                set_id = gene_set["set_id"]
                collection_id = gene_set["collection_id"]
                full_gs_id = f"{collection_id}_{set_id}"
                if row[full_gs_id] == 1:
                    gene_sets.append(full_gs_id)

        variant_counts = {}
        for dataset_id, filters in config["datasets"].items():
            current_counts: dict[str, dict] = {}
            for person_set in filters["person_sets"]:
                set_name = person_set["set_name"]
                for statistic in filters["statistics"]:
                    statistic_id = statistic["id"]
                    column_id = f"{dataset_id}_{set_name}_{statistic_id}"
                    current_counts.setdefault(set_name, {})[statistic_id] = {
                        "count": row[column_id],
                        "rate": row[f"{column_id}_rate"],
                    }
            variant_counts[dataset_id] = current_counts

        return GPStatistic(
            gene_symbol, gene_sets,
            genomic_scores, variant_counts,
        )

    def _find_collection_id(self, category: str, set_id: str) -> str | None:
        """Return the collection id of ``set_id`` in ``category`` or None."""
        found = None
        for gs_category in self.configuration["gene_sets"]:
            if gs_category["category"] != category:
                continue
            for gene_set in gs_category["sets"]:
                if gene_set["set_id"] == set_id:
                    # No early exit: the last matching entry wins.
                    found = gene_set["collection_id"]
        return found

    def _transform_sort_by(self, sort_by: str) -> str:
        """
        Translate a dotted sort key into a ``gene_profile`` column name.

        A leading ``gene_set_`` prefix is dropped. A two-part key of the
        form ``<category>_rank.<set_id>`` is mapped to the gene set column
        ``<collection_id>_<set_id>``. Every other key just has its dots
        replaced by underscores.
        """
        prefix = "gene_set_"
        tokens = sort_by.split(".")
        if sort_by.startswith(prefix):
            sort_by = sort_by.replace(prefix, "", 1)

        if "_rank" in tokens[0] and len(tokens) == 2:
            category = tokens[0].replace("_rank", "")
            set_id = tokens[1]
            collection_id = self._find_collection_id(category, set_id)
            if collection_id is None and category.startswith(prefix):
                # The tokens were split before the prefix was stripped, so
                # retry with the bare category name; otherwise prefixed
                # keys could never resolve their collection.
                collection_id = self._find_collection_id(
                    category.replace(prefix, "", 1), set_id,
                )
            sort_by = ".".join((collection_id or "", set_id))

        return sort_by.replace(".", "_")

    def query_gps(
        self,
        page: int | None,
        symbol_like: str | None = None,
        sort_by: str | None = None,
        order: str | None = None,
    ) -> list:
        """
        Perform paginated query and return list of GPs.

        Parameters:
            page - Which page to fetch (1-based); None disables paging.
            symbol_like - Which gene symbol to search for, supports
            incomplete search
            sort_by - Column to sort by
            order - "asc" or "desc"; defaults to "desc" when sorting
        """
        query = select("*").from_(self.table)
        if symbol_like:
            query = query.where(
                self._symbol_column().ilike(f"%{symbol_like}%"),
            )

        if sort_by is not None:
            direction = "DESC" if order in (None, "desc") else "ASC"
            # Double embedded quotes so the caller-supplied key cannot
            # terminate the quoted identifier and inject SQL.
            identifier = self._transform_sort_by(sort_by).replace('"', '""')
            query = query.order_by(f'"{identifier}" {direction}')

        query = self._paginate(query, page)

        # Can't have multiple connections with sqlite db alive when
        # one of those does a 'write' action. That's why connect here:
        with duckdb.connect(f"{self.dbfile}", read_only=True) as connection:
            return self._fetch_records(connection, query)

    def list_symbols(
        self,
        page: int | None,
        symbol_like: str | None = None,
    ) -> list[str]:
        """
        Perform paginated query and return list of gene symbols.

        Symbols are returned in ascending order.

        Parameters:
            page - Which page to fetch (1-based); None disables paging.
            symbol_like - Prefix of the gene symbols to search for
            (case-insensitive)
        """
        query = select(self._symbol_column()).from_(self.table)
        if symbol_like:
            query = query.where(
                self._symbol_column().ilike(f"{symbol_like}%"),
            )

        query = query.order_by("symbol_name ASC")
        query = self._paginate(query, page)

        with duckdb.connect(f"{self.dbfile}", read_only=True) as connection:
            records = self._fetch_records(connection, query)
        return [record["symbol_name"] for record in records]
class GeneProfileDBWriter:
    """
    Create and populate the gene profile database.

    Manages the ``gene_profile`` table stored in ``dbfile`` through DuckDB
    (the DuckDB ``sqlite`` extension is installed and loaded to support
    legacy sqlite gene profile databases). Constructing a writer creates
    the table if it does not exist and then deletes all of its rows.
    """

    def __init__(
        self,
        configuration: Box | dict | None,
        dbfile: str,
    ):
        # Support legacy sqlite gpdb
        duckdb.execute("INSTALL sqlite;")
        duckdb.execute("LOAD sqlite;")
        self.dbfile = dbfile
        self.configuration = \
            GeneProfileDBWriter.build_configuration(configuration)
        self._create_gp_table()
        # Maps "<collection_id>_<set_id>" to the category listing that set.
        self.gene_sets_categories = {}
        self._clear_gp_table()
        if self.configuration:
            for category in self.configuration["gene_sets"]:
                category_name = category["category"]
                for gene_set in category["sets"]:
                    collection_id = gene_set["collection_id"]
                    set_id = gene_set["set_id"]
                    self.gene_sets_categories[
                        f"{collection_id}_{set_id}"
                    ] = category_name

    @classmethod
    def build_configuration(cls, configuration: Box | dict | None) -> dict:
        """
        Perform a transformation on a given configuration.

        The configuration is transformed to an internal version with more
        specific information on order and ranks. The result is a shallow
        copy; the ``order`` entry is written to the copy only, so the
        caller's configuration is not modified.

        Returns an empty dict when ``configuration`` is None.

        Raises AssertionError when a supplied ``order`` contains unknown
        entries or does not have one entry per configured category.
        """
        if configuration is None:
            return {}

        result = copy(configuration)
        order = result.get("order")

        order_keys = [
            gene_set["category"] + "_rank"
            for gene_set in result["gene_sets"]
        ]
        order_keys.extend(
            genomic_score["category"]
            for genomic_score in result["genomic_scores"]
        )
        order_keys.extend(result["datasets"].keys())

        if order is None:
            result["order"] = order_keys
            return result

        # Raised explicitly (rather than via ``assert``) so that the
        # validation also runs when Python is started with -O.
        if not all(entry in order_keys for entry in order):
            raise AssertionError("Given GP order has invalid entries")

        total_categories_count = (
            len(result["gene_sets"])
            + len(result["genomic_scores"])
            + len(result["datasets"])
        )
        if len(order) != total_categories_count:
            raise AssertionError("Given GP order is missing items")

        return result

    def _execute(
        self,
        query: Any,
        connection: duckdb.DuckDBPyConnection | None = None,
    ) -> None:
        """Run ``query`` on ``connection`` or on a short-lived new one."""
        statement = to_duckdb_transpile(query)
        if connection is not None:
            connection.execute(statement)
            return
        with duckdb.connect(f"{self.dbfile}") as own_connection:
            own_connection.execute(statement)

    def drop_gp_table(self) -> None:
        """Drop the gene_profile table if it exists."""
        with duckdb.connect(f"{self.dbfile}") as connection:
            connection.execute("DROP TABLE IF EXISTS gene_profile")
            connection.commit()

    def gp_table_exists(self) -> bool:
        """Check if the gene_profile table exists."""
        duckdb_tables = "duckdb_tables"
        table_name = column("table_name", table=duckdb_tables)
        query = (
            select(table_name)
            .from_(duckdb_tables)
            .where(table_name.like("gene_profile"))
            .limit(1)
        )
        with duckdb.connect(f"{self.dbfile}", read_only=True) as connection:
            rows = connection.execute(
                to_duckdb_transpile(query),
            ).df().to_dict("records")
            return len(rows) == 1

    def _gp_table_columns(
        self, *,
        with_types: bool = True,
    ) -> list[ColumnDef]:
        """
        Build the column definitions of the gene_profile table.

        The columns are, in order: ``symbol_name``; for every gene set
        category its ``<category>_rank`` column followed by one column per
        gene set; one column per genomic score; and, per dataset, a count
        and a ``_rate`` column for every person set / statistic pair.

        With ``with_types=False`` the definitions carry neither data types
        nor the primary key constraint.
        """
        # pylint: disable=too-many-locals

        def typed(dtype: Any) -> Any:
            return DataType(this=dtype) if with_types else None

        constraints = [ColumnConstraint(kind=PrimaryKeyColumnConstraint())] \
            if with_types else []
        columns = [
            ColumnDef(
                this="symbol_name",
                kind=typed(DataType.Type.VARCHAR),
                constraints=constraints,
            ),
        ]
        if not self.configuration:
            return columns

        def add(name: str, dtype: Any) -> None:
            # Configured names are wrapped in double quotes before being
            # handed to sqlglot.
            columns.append(ColumnDef(this=f'"{name}"', kind=typed(dtype)))

        for category in self.configuration["gene_sets"]:
            category_name = category["category"]
            add(f"{category_name}_rank", DataType.Type.INT)
            for gene_set in category["sets"]:
                set_id = gene_set["set_id"]
                collection_id = gene_set["collection_id"]
                add(f"{collection_id}_{set_id}", DataType.Type.INT)

        for category in self.configuration["genomic_scores"]:
            category_name = category["category"]
            for score in category["scores"]:
                score_name = score["score_name"]
                add(f"{category_name}_{score_name}", DataType.Type.FLOAT)

        for dataset_id in self.configuration["datasets"]:
            config_section = self.configuration["datasets"][dataset_id]
            for person_set in config_section["person_sets"]:
                set_name = person_set["set_name"]
                for stat in config_section["statistics"]:
                    stat_id = stat["id"]
                    column_name = f"{dataset_id}_{set_name}_{stat_id}"
                    add(column_name, DataType.Type.FLOAT)
                    add(f"{column_name}_rate", DataType.Type.FLOAT)

        return columns

    def _create_gp_table(self) -> None:
        """Create the gene_profile table unless it already exists."""
        self.table = table_("gene_profile")
        self.schema = Schema(
            this=self.table,
            expressions=self._gp_table_columns(),
        )

        query = Create(this=self.schema, kind="TABLE", exists=True)
        with duckdb.connect(f"{self.dbfile}") as connection:
            connection.execute(to_duckdb_transpile(query))

    def _clear_gp_table(
        self,
        connection: duckdb.DuckDBPyConnection | None = None,
    ) -> None:
        """Delete every row of the gene_profile table."""
        self._execute(delete(self.table), connection)

    def insert_gp(
        self,
        gp: GPStatistic,
        connection: duckdb.DuckDBPyConnection | None = None,
    ) -> None:
        """
        Insert a GP into the DB.

        Uses ``connection`` when one is given, otherwise opens a new
        connection for this single insert.
        """
        insert_map = self._create_insert_map(gp)
        query = insert(
            values([tuple(insert_map.values())]),
            self.table,
            columns=list(insert_map.keys()),
        )
        self._execute(query, connection)

    def _create_insert_map(
        self,
        gp: GPStatistic,
    ) -> dict:
        """
        Map a GP onto ``{column name: value}`` for insertion.

        Every configured gene set column gets 1 or 0 depending on whether
        it is listed in ``gp.gene_sets``; ``<category>_rank`` holds the
        number of sets of that category that contain the gene. Genomic
        scores are stored under ``<category>_<score_id>`` and the entries
        of ``gp.variant_counts`` are copied over unchanged.
        """
        # pylint: disable=too-many-locals
        insert_map: dict[str, str | int] = {
            "symbol_name": gp.gene_symbol,
        }

        gs_categories_count = dict.fromkeys(
            (c["category"] for c in self.configuration["gene_sets"]), 0,
        )

        for gsc in self.configuration["gene_sets"]:
            category = gsc["category"]
            for gene_set in gsc["sets"]:
                collection_id = gene_set["collection_id"]
                gs_id = gene_set["set_id"]
                set_col = f"{collection_id}_{gs_id}"
                is_member = set_col in gp.gene_sets
                insert_map[set_col] = 1 if is_member else 0
                if is_member:
                    gs_categories_count[category] += 1

        for category, count in gs_categories_count.items():
            insert_map[f"{category}_rank"] = count

        for category, scores in gp.genomic_scores.items():
            for score_id, score in scores.items():
                insert_map[f"{category}_{score_id}"] = score

        insert_map.update(gp.variant_counts)

        return insert_map

    def insert_gps(
        self,
        gps: dict,
    ) -> None:
        """
        Insert multiple GPStatistics into the DB.

        ``gps`` may be a dict whose values are the statistics, or any
        sized iterable of statistics. Existing rows are deleted first.
        """
        # Iterating a dict directly would yield its keys rather than the
        # statistics stored in it.
        profiles = gps.values() if isinstance(gps, dict) else gps
        gp_count = len(profiles)

        with duckdb.connect(f"{self.dbfile}") as connection:
            self._clear_gp_table(connection)
            for idx, gp in enumerate(profiles, 1):
                self.insert_gp(gp, connection)

                if idx % 1000 == 0:
                    logger.info(
                        "Inserted %s/%s GPs into DB", idx, gp_count)
            connection.commit()
        logger.info("Done!")

    def update_gps_with_values(self, gs_values: dict[str, Any]) -> None:
        """
        Update GP rows with the given values.

        ``gs_values`` maps a gene symbol to the values passed to the
        UPDATE statement for the row with that ``symbol_name``.
        """
        with duckdb.connect(f"{self.dbfile}") as connection:
            for gs, vals in gs_values.items():
                query = update(
                    self.table,
                    vals,
                    where=column(
                        "symbol_name",
                        table=self.table.alias_or_name,
                    ).eq(gs),
                )
                connection.execute(to_duckdb_transpile(query))
            connection.commit()