import logging
import time
from collections.abc import Iterable
from copy import copy
from typing import Any
import duckdb
import numpy as np
from box import Box
from sqlglot import column, select
from sqlglot.expressions import (
ColumnConstraint,
ColumnDef,
Create,
DataType,
PrimaryKeyColumnConstraint,
Schema,
delete,
insert,
table_,
update,
values,
)
from dae.gene_profile.statistic import GPStatistic
from dae.utils.sql_utils import to_duckdb_transpile
logger = logging.getLogger(__name__)
class GeneProfileDB:
    """
    Class for managing the gene profile database (read side).

    Uses DuckDB for DB management (the ``sqlite`` extension is loaded so
    legacy SQLite files keep working) and supports loading and storing
    to filesystem.

    Has to be supplied a configuration and a path to which to read/write
    the DB file.
    """

    # Number of rows returned per page by the paginated queries.
    PAGE_SIZE = 50

    def __init__(
        self,
        configuration: Box | dict | None,
        dbfile: str,
    ):
        # Support legacy sqlite gpdb
        duckdb.execute("INSTALL sqlite;")
        duckdb.execute("LOAD sqlite;")
        self.dbfile = dbfile
        self.configuration = \
            GeneProfileDBWriter.build_configuration(configuration)
        self.table = table_("gene_profile")
        # Maps "<collection_id>_<set_id>" to the gene-set category name.
        self.gene_sets_categories: dict[str, str] = {}
        if self.configuration:
            for category in self.configuration["gene_sets"]:
                category_name = category["category"]
                for gene_set in category["sets"]:
                    collection_id = gene_set["collection_id"]
                    set_id = gene_set["set_id"]
                    full_gene_set_id = f"{collection_id}_{set_id}"
                    self.gene_sets_categories[full_gene_set_id] = \
                        category_name

    def get_gp(self, gene_symbol: str) -> GPStatistic | None:
        """
        Query a GP by gene_symbol and return the row as statistic.

        Returns None if gene_symbol is not found within the DB.
        Raises ValueError if more than one row matches the symbol.
        """
        # Build the predicate with sqlglot (same style as
        # update_gps_with_values) so the symbol is emitted as a properly
        # escaped literal instead of being f-string-interpolated into
        # raw SQL, which was injectable and broke on quotes in symbols.
        query = select("*") \
            .from_(self.table) \
            .where(
                column(
                    "symbol_name",
                    table=self.table.alias_or_name,
                ).eq(gene_symbol),
            )
        with duckdb.connect(f"{self.dbfile}", read_only=True) as connection:
            rows = connection.execute(
                to_duckdb_transpile(query),
            ).df().replace([np.nan], [None]).to_dict("records")
        if len(rows) == 0:
            return None
        if len(rows) > 1:
            raise ValueError(
                f"More than one gene profile with symbol name {gene_symbol}",
            )
        return self.gp_from_table_row_single_view(rows[0])

    def gp_from_table_row(self, row: dict) -> dict:
        # pylint: disable=too-many-locals
        """Build an GPStatistic from internal DB row."""
        config = self.configuration
        result = {}
        result["geneSymbol"] = row["symbol_name"]
        for gs_category in config["gene_scores"]:
            category_name = gs_category["category"]
            for score in gs_category["scores"]:
                score_name = score["score_name"]
                value = row[f"{category_name}_{score_name}"]
                # NOTE(review): falsy values (0, 0.0, "") collapse to
                # None here as well — confirm that is intended.
                result[".".join([category_name, score_name])] = \
                    value or None
        for gs_category in config["gene_sets"]:
            category_name = gs_category["category"]
            for gene_set in gs_category["sets"]:
                set_id = gene_set["set_id"]
                collection_id = gene_set["collection_id"]
                full_gs_id = f"{collection_id}_{set_id}"
                # Check mark (U+2713) when the gene is in the set.
                result[".".join([f"{category_name}_rank", set_id])] = \
                    "\u2713" if row[full_gs_id] else None
        for dataset_id, filters in config["datasets"].items():
            for person_set in filters["person_sets"]:
                set_name = person_set["set_name"]
                for statistic in filters["statistics"]:
                    statistic_id = statistic["id"]
                    count = row[
                        f"{dataset_id}_{set_name}_{statistic_id}"
                    ]
                    rate = row[
                        f"{dataset_id}_{set_name}_{statistic_id}_rate"
                    ]
                    result[".".join([dataset_id, set_name, statistic_id])] = \
                        f"{count} ({round(rate, 2)})" if count else None
        return result

    def gp_from_table_row_single_view(self, row: dict) -> GPStatistic:
        """Create an GPStatistic from single view row."""
        # pylint: disable=too-many-locals
        config = self.configuration
        gene_symbol = row["symbol_name"]
        gene_scores: dict[str, dict] = {}
        for gs_category in config["gene_scores"]:
            category_name = gs_category["category"]
            gene_scores[category_name] = {}
            for score in gs_category["scores"]:
                score_name = score["score_name"]
                full_score_id = f"{category_name}_{score_name}"
                gene_scores[category_name][score_name] = {
                    "value": row[full_score_id],
                    "format": score["format"],
                }
        gene_sets_categories = config["gene_sets"]
        gene_sets = []
        for gs_category in gene_sets_categories:
            category_name = gs_category["category"]
            for gene_set in gs_category["sets"]:
                set_id = gene_set["set_id"]
                collection_id = gene_set["collection_id"]
                full_gs_id = f"{collection_id}_{set_id}"
                # Membership columns store 1/0 flags.
                if row[full_gs_id] == 1:
                    gene_sets.append(full_gs_id)
        variant_counts = {}
        for dataset_id, filters in config["datasets"].items():
            current_counts: dict[str, dict] = {}
            for person_set in filters["person_sets"]:
                set_name = person_set["set_name"]
                for statistic in filters["statistics"]:
                    statistic_id = statistic["id"]
                    counts = current_counts.setdefault(set_name, {})
                    count = row[
                        f"{dataset_id}_{set_name}_{statistic_id}"
                    ]
                    rate = row[
                        f"{dataset_id}_{set_name}_{statistic_id}_rate"
                    ]
                    counts[statistic_id] = {
                        "count": count,
                        "rate": rate,
                    }
            variant_counts[dataset_id] = current_counts
        return GPStatistic(
            gene_symbol, gene_sets,
            gene_scores, variant_counts,
        )

    def _transform_sort_by(self, sort_by: str) -> str:
        """Map a UI sort key to the matching DB column name.

        ``gene_set_`` prefixes are stripped and ``<category>_rank.<set>``
        keys are rewritten to the ``<collection>_<set>`` column.
        """
        sort_by_tokens = sort_by.split(".")
        if sort_by.startswith("gene_set_"):
            sort_by = sort_by.replace("gene_set_", "", 1)
        if "_rank" in sort_by_tokens[0]:
            collection_id = ""
            category = sort_by_tokens[0].replace("_rank", "")
            if len(sort_by_tokens) == 2:
                # Find the collection the requested set belongs to.
                for gs_category in self.configuration["gene_sets"]:
                    if gs_category["category"] != category:
                        continue
                    for gene_set in gs_category["sets"]:
                        if gene_set["set_id"] == sort_by_tokens[1]:
                            collection_id = gene_set["collection_id"]
                sort_by = ".".join((collection_id, sort_by_tokens[1]))
        return sort_by.replace(".", "_")

    def query_gps(
        self,
        page: int,
        symbol_like: str | None = None,
        sort_by: str | None = None,
        order: str | None = None,
    ) -> list:
        """
        Perform paginated query and return list of GPs.

        Parameters:
            page - Which page to fetch.
            symbol_like - Which gene symbol to search for, supports
                incomplete search
            sort_by - Column to sort by
            order - "asc" or "desc"
        """
        query = select("*").from_(self.table)
        if symbol_like:
            query = query.where(
                column(
                    "symbol_name",
                    table=self.table.alias_or_name,
                ).ilike(f"%{symbol_like}%"),
            )
        if sort_by is not None:
            if order is None:
                order = "desc"
            sort_by = self._transform_sort_by(sort_by)
            # Anything other than "desc" falls back to ascending.
            order = "DESC" if order == "desc" else "ASC"
            # Escape embedded double quotes so the quoted identifier
            # cannot break out of its quoting.
            sort_by = sort_by.replace('"', '""')
            query = query.order_by(f'"{sort_by}" {order}')
        if page is not None:
            query = query.limit(self.PAGE_SIZE).offset(
                self.PAGE_SIZE * (page - 1),
            )
        # Can't have multiple connections with sqlite db alive when
        # one of those does a 'write' action. That's why connect here:
        with duckdb.connect(f"{self.dbfile}", read_only=True) as connection:
            return connection.execute(
                to_duckdb_transpile(query),
            ).df().replace([np.nan], [None]).to_dict("records")

    def list_symbols(
        self, page: int, symbol_like: str | None = None,
    ) -> list[str]:
        """
        Perform paginated query and return list of gene symbols.

        Parameters:
            page - Which page to fetch.
            symbol_like - Which gene symbol to search for, supports
                incomplete search
        """
        query = select(
            column(
                "symbol_name",
                table=self.table.alias_or_name,
            ),
        ).from_(self.table)
        if symbol_like:
            # Prefix search only (no leading wildcard), unlike query_gps.
            query = query.where(
                column(
                    "symbol_name",
                    table=self.table.alias_or_name,
                ).ilike(f"{symbol_like}%"),
            )
        query = query.order_by("symbol_name ASC")
        if page is not None:
            query = query.limit(self.PAGE_SIZE).offset(
                self.PAGE_SIZE * (page - 1),
            )
        with duckdb.connect(f"{self.dbfile}", read_only=True) as connection:
            return [
                row["symbol_name"] for row in connection.execute(
                    to_duckdb_transpile(query),
                ).df().replace([np.nan], [None]).to_dict("records")
            ]
class GeneProfileDBWriter:
    """
    Class for managing the gene profile database (write side).

    Uses DuckDB for DB management (the ``sqlite`` extension is loaded so
    legacy SQLite files keep working) and supports loading and storing
    to filesystem.

    Has to be supplied a configuration and a path to which to read/write
    the DB file.
    """

    def __init__(
        self,
        configuration: Box | dict | None,
        dbfile: str,
    ):
        # Support legacy sqlite gpdb
        duckdb.execute("INSTALL sqlite;")
        duckdb.execute("LOAD sqlite;")
        self.dbfile = dbfile
        self.configuration = \
            GeneProfileDBWriter.build_configuration(configuration)
        self._create_gp_table()
        # Maps "<collection_id>_<set_id>" to the gene-set category name.
        self.gene_sets_categories: dict[str, str] = {}
        self._clear_gp_table()
        if self.configuration:
            for category in self.configuration["gene_sets"]:
                category_name = category["category"]
                for gene_set in category["sets"]:
                    collection_id = gene_set["collection_id"]
                    set_id = gene_set["set_id"]
                    full_gene_set_id = f"{collection_id}_{set_id}"
                    self.gene_sets_categories[full_gene_set_id] = \
                        category_name

    @classmethod
    def build_configuration(cls, configuration: Box | dict | None) -> dict:
        """
        Perform a transformation on a given configuration.

        The configuration is transformed to an internal version with more
        specific information on order and ranks.  When no explicit
        "order" is given, one is derived from the gene sets, gene scores
        and datasets; otherwise the given order is validated.
        """
        if configuration is None:
            return {}
        order = configuration.get("order")
        order_keys = [
            gene_set["category"] + "_rank"
            for gene_set in configuration["gene_sets"]
        ]
        order_keys.extend(
            gene_score["category"]
            for gene_score in configuration["gene_scores"]
        )
        order_keys.extend(configuration["datasets"].keys())
        if order is None:
            # NOTE: mutates the passed-in configuration in place.
            configuration["order"] = order_keys
        else:
            total_categories_count = \
                len(configuration["gene_sets"]) + \
                len(configuration["gene_scores"]) + \
                len(configuration["datasets"])
            assert all(x in order_keys for x in order), \
                "Given GP order has invalid entries"
            assert len(order) == total_categories_count, \
                "Given GP order is missing items"
        # Shallow copy: nested sections are shared with the input.
        return copy(configuration)

    def drop_gp_table(self) -> None:
        """Drop the gene_profile table if it exists."""
        with duckdb.connect(f"{self.dbfile}") as connection:
            connection.execute("DROP TABLE IF EXISTS gene_profile")
            connection.commit()

    def gp_table_exists(self) -> bool:
        """Checks if gp table exists"""
        duckdb_tables = "duckdb_tables"
        query = select(
            column(
                "table_name",
                table=duckdb_tables,
            ),
        ).from_(duckdb_tables).where(
            column(
                "table_name",
                table=duckdb_tables,
            ).like("gene_profile"),
        ).limit(1)
        with duckdb.connect(f"{self.dbfile}", read_only=True) as connection:
            rows = connection.execute(
                to_duckdb_transpile(query),
            ).df().to_dict("records")
            return len(rows) == 1

    def _gp_table_columns(
        self, *,
        with_types: bool = True,
    ) -> list[ColumnDef]:  # pylint: disable=too-many-locals
        """Build the column definitions for the gene_profile table.

        With ``with_types=False`` only the column names are produced
        (useful for inserts); otherwise types and the primary-key
        constraint on ``symbol_name`` are included.
        """
        columns = []
        constraints = [ColumnConstraint(kind=PrimaryKeyColumnConstraint())] \
            if with_types else []
        columns.append(
            ColumnDef(
                this="symbol_name",
                kind=DataType(
                    this=DataType.Type.VARCHAR,
                ) if with_types else None,
                constraints=constraints,
            ),
        )
        if len(self.configuration) == 0:
            return columns
        # One rank column per gene-set category plus a 1/0 membership
        # column per gene set.
        for category in self.configuration["gene_sets"]:
            category_name = category["category"]
            rank_col = f"{category_name}_rank"
            columns.append(
                ColumnDef(
                    this=f'"{rank_col}"',
                    kind=DataType(
                        this=DataType.Type.INT,
                    ) if with_types else None,
                ),
            )
            for gene_set in category["sets"]:
                set_id = gene_set["set_id"]
                collection_id = gene_set["collection_id"]
                full_set_id = f"{collection_id}_{set_id}"
                columns.append(
                    ColumnDef(
                        this=f'"{full_set_id}"',
                        kind=DataType(
                            this=DataType.Type.INT,
                        ) if with_types else None,
                    ),
                )
        # One float column per gene score.
        for category in self.configuration["gene_scores"]:
            category_name = category["category"]
            for score in category["scores"]:
                score_name = score["score_name"]
                col = f"{category_name}_{score_name}"
                columns.append(
                    ColumnDef(
                        this=f'"{col}"',
                        kind=DataType(
                            this=DataType.Type.FLOAT,
                        ) if with_types else None,
                    ),
                )
        # Count and rate columns per dataset/person-set/statistic.
        for dataset_id in self.configuration["datasets"]:
            config_section = self.configuration["datasets"][dataset_id]
            for person_set in config_section["person_sets"]:
                set_name = person_set["set_name"]
                for stat in config_section["statistics"]:
                    stat_id = stat["id"]
                    column_name = f"{dataset_id}_{set_name}_{stat_id}"
                    columns.append(
                        ColumnDef(
                            this=f'"{column_name}"',
                            kind=DataType(
                                this=DataType.Type.FLOAT,
                            ) if with_types else None,
                        ),
                    )
                    rate_col_name = f"{column_name}_rate"
                    columns.append(
                        ColumnDef(
                            this=f'"{rate_col_name}"',
                            kind=DataType(
                                this=DataType.Type.FLOAT,
                            ) if with_types else None,
                        ),
                    )
        return columns

    def _create_gp_table(self) -> None:
        """Create the gene_profile table (CREATE TABLE IF NOT EXISTS)."""
        self.table = table_("gene_profile")
        self.schema = Schema(
            this=self.table,
            expressions=self._gp_table_columns(),
        )
        query = Create(this=self.schema, kind="TABLE", exists=True)
        with duckdb.connect(f"{self.dbfile}") as connection:
            connection.execute(to_duckdb_transpile(query))

    def _clear_gp_table(
        self,
        connection: duckdb.DuckDBPyConnection | None = None,
    ) -> None:
        """Delete all rows, reusing ``connection`` if one is given."""
        query = delete(self.table)
        if connection is not None:
            connection.execute(to_duckdb_transpile(query))
            return
        with duckdb.connect(f"{self.dbfile}") as conn:
            conn.execute(to_duckdb_transpile(query))

    def insert_gp(
        self,
        gp: GPStatistic,
        connection: duckdb.DuckDBPyConnection | None = None,
    ) -> None:
        """Insert a GP into the DB."""
        insert_map = self._create_insert_map(gp)
        query = insert(
            values([tuple(insert_map.values())]),
            self.table,
            columns=list(insert_map.keys()),
        )
        if connection is not None:
            connection.execute(to_duckdb_transpile(query))
            return
        with duckdb.connect(f"{self.dbfile}") as conn:
            conn.execute(to_duckdb_transpile(query))

    def _create_insert_map(
        self,
        gp: GPStatistic,
    ) -> dict:  # pylint: disable=too-many-locals
        """Build a column-name -> value mapping for inserting ``gp``."""
        insert_map: dict[str, str | int] = {
            "symbol_name": gp.gene_symbol,
        }
        # Per-category rank = number of this gene's sets in the category.
        gs_categories_count = {
            c["category"]: 0
            for c in self.configuration["gene_sets"]
        }
        for gsc in self.configuration["gene_sets"]:
            category = gsc["category"]
            for gene_set in gsc["sets"]:
                collection_id = gene_set["collection_id"]
                gs_id = gene_set["set_id"]
                set_col = f"{collection_id}_{gs_id}"
                if set_col in gp.gene_sets:
                    insert_map[set_col] = 1
                    gs_categories_count[category] += 1
                else:
                    insert_map[set_col] = 0
        for category, count in gs_categories_count.items():
            insert_map[f"{category}_rank"] = count
        for category, scores in gp.gene_scores.items():
            for score_id, score in scores.items():
                insert_map[f"{category}_{score_id}"] = score
        # dict.update accepts a mapping directly; no need for .items().
        insert_map.update(gp.variant_counts)
        return insert_map

    def insert_gps(
        self,
        gps: Iterable[GPStatistic],
    ) -> None:
        """Insert multiple GPStatistics into the DB.

        Clears the table first; an empty iterable just leaves the table
        empty instead of producing an invalid INSERT.
        """
        with duckdb.connect(f"{self.dbfile}") as connection:
            self._clear_gp_table(connection)
            cols = None
            vals = []
            for gp in gps:
                insert_map = self._create_insert_map(gp)
                if cols is None:
                    cols = list(insert_map.keys())
                vals.append(tuple(insert_map.values()))
            if vals:
                query = insert(
                    values(vals),
                    self.table,
                    columns=cols,
                )
                connection.execute(to_duckdb_transpile(query))
            logger.info("Done!")
            connection.commit()

    def update_gps_with_values(self, gs_values: dict[str, Any]) -> None:
        """Update gp statistic with values"""
        with duckdb.connect(f"{self.dbfile}") as connection:
            started = time.time()
            for idx, (gs, vals) in enumerate(gs_values.items(), 1):
                query = update(
                    self.table,
                    vals,
                    where=column(
                        "symbol_name",
                        table=self.table.alias_or_name,
                    ).eq(gs),
                )
                connection.execute(to_duckdb_transpile(query))
                # Periodic progress logging for large batches.
                if idx % 1000 == 0:
                    elapsed = time.time() - started
                    logger.info(
                        "Updated %s/%s GP statistics in %.2f seconds",
                        idx, len(gs_values), elapsed)
            connection.commit()