from __future__ import annotations
import textwrap
from collections.abc import Iterator
from functools import reduce
from typing import Any
import duckdb
import sqlglot
from duckdb import (
ConstraintException,
)
from sqlglot import column, expressions, select
from sqlglot.expressions import (
Count,
Null,
Table,
alias_,
delete,
insert,
table_,
update,
values,
)
from dae.pheno.common import MeasureType
from dae.utils.sql_utils import to_duckdb_transpile
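# Example usage (illustrative sketch; the database path and search keyword
# below are hypothetical):
#
#     browser = PhenoBrowser("browser.db", read_only=True)
#     for measure in browser.search_measures(keyword="height", page=1):
#         print(measure["measure_id"], measure["measure_type"])
#     browser.close()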
class PhenoBrowser:
"""Class for handling saving and loading of phenotype browser data."""
PAGE_SIZE = 1001
def __init__(
self, dbfile: str, *, read_only: bool = True,
):
self.dbfile = dbfile
        self.connection = duckdb.connect(
            dbfile, read_only=read_only)
if not read_only:
PhenoBrowser.create_browser_tables(self.connection)
self.variable_browser = table_("variable_browser")
self.regressions = table_("regression")
self.regression_values = table_("regression_values")
self.instrument_descriptions = table_("instrument_descriptions")
self.measure_descriptions = table_("measure_descriptions")
self.is_legacy = self._is_browser_legacy()
self._closed = False
def _is_browser_legacy(self) -> bool:
"""Handle legacy databases."""
with self.connection.cursor() as cursor:
result = cursor.execute("SHOW TABLES").fetchall()
tables = [row[0] for row in result]
return bool(
"instrument_descriptions" not in tables
or "measure_descriptions" not in tables)
def close(self) -> None:
"""Close the connection to the database."""
if self.connection is not None and not self._closed:
self.connection.close()
self._closed = True
@staticmethod
def create_browser_tables(conn: duckdb.DuckDBPyConnection) -> None:
"""Create tables for the browser DB."""
create_variable_browser = sqlglot.parse(textwrap.dedent(
"""
CREATE TABLE IF NOT EXISTS variable_browser(
measure_id VARCHAR NOT NULL UNIQUE PRIMARY KEY,
instrument_name VARCHAR NOT NULL,
measure_name VARCHAR NOT NULL,
measure_type INT NOT NULL,
values_domain VARCHAR,
figure_distribution_small VARCHAR,
figure_distribution VARCHAR
);
CREATE UNIQUE INDEX IF NOT EXISTS variable_browser_measure_id_idx
ON variable_browser (measure_id);
CREATE INDEX IF NOT EXISTS variable_browser_instrument_name_idx
ON variable_browser (instrument_name);
CREATE INDEX IF NOT EXISTS variable_browser_measure_name_idx
ON variable_browser (measure_name);
""",
))
create_regression = sqlglot.parse(textwrap.dedent(
"""
CREATE TABLE IF NOT EXISTS regression(
regression_id VARCHAR NOT NULL UNIQUE PRIMARY KEY,
instrument_name VARCHAR,
measure_name VARCHAR NOT NULL,
                display_name VARCHAR
);
CREATE UNIQUE INDEX IF NOT EXISTS regression_regression_id_idx
ON regression (regression_id);
""",
))
create_regression_values = sqlglot.parse(textwrap.dedent(
"""
CREATE TABLE IF NOT EXISTS regression_values(
regression_id VARCHAR NOT NULL,
measure_id VARCHAR NOT NULL,
figure_regression VARCHAR,
figure_regression_small VARCHAR,
pvalue_regression_male DOUBLE,
pvalue_regression_female DOUBLE,
PRIMARY KEY (regression_id, measure_id)
);
CREATE INDEX IF NOT EXISTS regression_values_regression_id_idx
ON regression_values (regression_id);
CREATE INDEX IF NOT EXISTS regression_values_measure_id_idx
ON regression_values (measure_id);
""",
))
create_instrument_descriptions = sqlglot.parse(textwrap.dedent(
"""
CREATE TABLE IF NOT EXISTS instrument_descriptions(
instrument_name VARCHAR NOT NULL UNIQUE PRIMARY KEY,
                description VARCHAR
);
CREATE UNIQUE INDEX IF NOT EXISTS
instrument_descriptions_instrument_name_idx
ON instrument_descriptions (instrument_name);
""",
))
create_measure_descriptions = sqlglot.parse(textwrap.dedent(
"""
CREATE TABLE IF NOT EXISTS measure_descriptions(
measure_id VARCHAR NOT NULL UNIQUE PRIMARY KEY,
                description VARCHAR
);
CREATE UNIQUE INDEX IF NOT EXISTS
measure_descriptions_measure_id_idx
ON measure_descriptions (measure_id);
""",
))
queries = [
*create_variable_browser,
*create_regression,
*create_regression_values,
*create_instrument_descriptions,
*create_measure_descriptions,
]
with conn.cursor() as cursor:
for query in queries:
cursor.execute(to_duckdb_transpile(query))
def save(self, v: dict[str, Any]) -> None:
"""Save measure values into the database."""
with self.connection.cursor() as cursor:
if not self.is_legacy:
instrument_desc_query = to_duckdb_transpile(insert(
values([(
v["instrument_name"],
v.get("instrument_description", ""),
)]),
self.instrument_descriptions,
columns=["instrument_name", "description"],
))
v.pop("instrument_description", None)
try:
cursor.execute(instrument_desc_query)
except ConstraintException:
delete_instrument_desc_query = to_duckdb_transpile(delete(
self.instrument_descriptions,
).where("instrument_name").eq(v["instrument_name"]))
cursor.execute(delete_instrument_desc_query)
cursor.execute(instrument_desc_query)
measure_desc_query = to_duckdb_transpile(insert(
values([(
v["measure_id"],
v.get("description", ""),
)]),
self.measure_descriptions,
columns=["measure_id", "description"],
))
v.pop("description", None)
try:
cursor.execute(measure_desc_query)
except ConstraintException:
delete_measure_desc_query = to_duckdb_transpile(delete(
self.measure_descriptions,
).where("measure_id").eq(v["measure_id"]))
cursor.execute(delete_measure_desc_query)
cursor.execute(measure_desc_query)
measure_query = to_duckdb_transpile(insert(
values([(*v.values(),)]),
self.variable_browser,
columns=[*v.keys()],
))
try:
cursor.execute(measure_query)
except ConstraintException:
delete_measure_query = to_duckdb_transpile(delete(
self.variable_browser,
).where("measure_id").eq(v["measure_id"]))
cursor.execute(delete_measure_query)
cursor.execute(measure_query)
def save_regression(self, reg: dict[str, str]) -> None:
"""Save regressions into the database."""
query = to_duckdb_transpile(insert(
values([(*reg.values(),)]),
self.regressions,
columns=[*reg.keys()],
))
try:
with self.connection.cursor() as cursor:
cursor.execute(query)
        except ConstraintException:
regression_id = reg["regression_id"]
del reg["regression_id"]
update_query = to_duckdb_transpile(update(
self.regressions, reg,
).where("regression_id").eq(regression_id))
with self.connection.cursor() as cursor:
cursor.execute(update_query)
def save_regression_values(self, reg: dict[str, str]) -> None:
"""Save regression values into the databases."""
query = insert(
values([(*reg.values(),)]),
self.regression_values,
columns=[*reg.keys()],
)
try:
with self.connection.cursor() as cursor:
cursor.execute(to_duckdb_transpile(query))
        except ConstraintException:
regression_id = reg["regression_id"]
measure_id = reg["measure_id"]
del reg["regression_id"]
del reg["measure_id"]
update_query = to_duckdb_transpile(update(
self.regression_values, reg,
where=(
f"regression_id = '{regression_id}' AND "
f"measure_id = '{measure_id}'"
),
))
with self.connection.cursor() as cursor:
cursor.execute(update_query)
def _build_ilike(
self,
keyword: str, col: expressions.Column) -> expressions.Escape:
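        # Build `col ILIKE keyword ESCAPE '/'`; the keyword is pre-escaped
        # with '/' in _measures_query_by_keyword.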
return expressions.Escape(this=col.ilike(keyword), expression="'/'")
def _build_measures_query(
self,
instrument_name: str | None = None,
keyword: str | None = None,
sort_by: str | None = None,
order_by: str | None = None,
) -> tuple[expressions.Select, list[expressions.Alias]]:
"""Find measures by keyword search."""
joined_tables = {}
regression_ids = self.regression_ids
reg_cols = []
columns = [
column("measure_id", self.variable_browser.alias_or_name),
column("instrument_name", self.variable_browser.alias_or_name),
column("measure_name", self.variable_browser.alias_or_name),
column("measure_type", self.variable_browser.alias_or_name),
column("values_domain", self.variable_browser.alias_or_name),
column("figure_distribution_small",
self.variable_browser.alias_or_name),
column("figure_distribution", self.variable_browser.alias_or_name),
]
query = select(*columns).from_(self.variable_browser)
variable_browser_instrument_name_col = column(
"instrument_name", self.variable_browser.alias_or_name)
variable_browser_measure_id_col = column(
"measure_id", self.variable_browser.alias_or_name)
if not self.is_legacy:
instrument_descriptions_instrument_name_col = column(
"instrument_name", self.instrument_descriptions.alias_or_name)
instrument_descriptions_description_col = column(
"description", self.instrument_descriptions.alias_or_name)
query = query.join(
self.instrument_descriptions,
on=sqlglot.condition(
variable_browser_instrument_name_col.eq(
instrument_descriptions_instrument_name_col,
),
),
join_type="LEFT OUTER",
)
query = query.select(
instrument_descriptions_description_col.as_(
"instrument_description",
))
measure_descriptions_measure_id_col = column(
"measure_id", self.measure_descriptions.alias_or_name)
measure_descriptions_description_col = column(
"description", self.measure_descriptions.alias_or_name)
query = query.join(
self.measure_descriptions,
on=sqlglot.condition(
variable_browser_measure_id_col.eq(
measure_descriptions_measure_id_col,
),
),
join_type="LEFT OUTER",
)
query = query.select(measure_descriptions_description_col)
else:
query = query.select(alias_(Null(), "instrument_description"))
query = query.select(alias_(Null(), "description"))
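        # Left-join regression_values once per regression id, each time
        # under its own alias, so every regression contributes separate
        # figure and p-value columns to the result.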
for regression_id in regression_ids:
reg_table = self.regression_values.as_(regression_id)
reg_m_id = column("measure_id", reg_table.alias)
reg_id_col = column("regression_id", reg_table.alias)
query = query.join(
reg_table,
on=sqlglot.condition(
variable_browser_measure_id_col.eq(reg_m_id)
.and_(reg_id_col.eq(regression_id)),
),
join_type="LEFT OUTER",
)
joined_tables[regression_id] = reg_table
cols = [
column(
"figure_regression", table=reg_table.alias_or_name,
).as_(f"{regression_id}_figure_regression"),
column(
"figure_regression_small", table=reg_table.alias_or_name,
).as_(f"{regression_id}_figure_regression_small"),
column(
"pvalue_regression_male", table=reg_table.alias_or_name,
).as_(f"{regression_id}_pvalue_regression_male"),
column(
"pvalue_regression_female", table=reg_table.alias_or_name,
).as_(f"{regression_id}_pvalue_regression_female"),
]
reg_cols.extend(cols)
query = query.select(*cols)
query = query.distinct()
if keyword:
query = self._measures_query_by_keyword(
query, keyword, instrument_name,
)
if instrument_name:
query = query.where(
f"variable_browser.instrument_name = '{instrument_name}'",
)
if sort_by:
column_to_sort: Any
match sort_by:
case "instrument":
column_to_sort = column(
"measure_id", self.variable_browser.alias_or_name)
case "measure":
column_to_sort = column(
"measure_name", self.variable_browser.alias_or_name)
case "measure_type":
column_to_sort = column(
"measure_type", self.variable_browser.alias_or_name)
case _:
regression = sort_by.split(".")
if len(regression) != 2:
raise ValueError(
f"{sort_by} is an invalid sort column",
)
regression_id, sex = regression
if sex not in ("male", "female"):
raise ValueError(
f"{sort_by} is an invalid sort column",
)
column_to_sort = column(
f"{regression_id}_pvalue_regression_{sex}",
)
if order_by == "desc":
query = query.order_by(f"{column_to_sort} DESC")
else:
query = query.order_by(f"{column_to_sort} ASC")
else:
query = query.order_by(
"variable_browser.measure_id ASC",
)
return query, reg_cols
def _build_measures_count_query(
self,
instrument_name: str | None = None,
keyword: str | None = None,
) -> expressions.Select:
"""Count measures by keyword search."""
count = Count(this="*")
query = select(count).from_(self.variable_browser)
query = query.distinct()
if keyword:
query = self._measures_query_by_keyword(
query, keyword, instrument_name,
)
if instrument_name:
query = query.where(
f"variable_browser.instrument_name = '{instrument_name}'",
)
return query
def _measures_query_by_keyword(
self,
query: expressions.Select,
keyword: str,
instrument_name: str | None = None,
) -> expressions.Select:
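        # Escape '/', '%' and '_' in the user keyword ('/' is the ESCAPE
        # character in _build_ilike), then wrap it for substring matching.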
column_filters = []
keyword = keyword.replace("/", "//")\
.replace("%", r"/%").replace("_", r"/_")
keyword = f"%{keyword}%"
if not instrument_name:
column_filters.append(
self._build_ilike(
keyword,
column("instrument_name", table="variable_browser"),
),
)
column_filters.extend((
self._build_ilike(
keyword,
column("measure_id", table="variable_browser"),
),
self._build_ilike(
keyword,
column("measure_name", table="variable_browser"),
),
))
return query.where(reduce(
lambda left, right: left.or_(right), # type: ignore
column_filters,
))
def search_measures(
self,
instrument_name: str | None = None,
keyword: str | None = None,
page: int | None = None,
sort_by: str | None = None,
order_by: str | None = None,
) -> Iterator[dict[str, Any]]:
"""Find measures by keyword search."""
query, reg_cols = self._build_measures_query(
instrument_name,
keyword,
sort_by,
order_by,
)
reg_col_names = [reg_col.alias for reg_col in reg_cols]
if page is None:
page = 1
query = query.limit(self.PAGE_SIZE).offset(
self.PAGE_SIZE * (page - 1),
)
query_str = to_duckdb_transpile(query)
with self.connection.cursor() as cursor:
rows = cursor.execute(query_str).fetchall()
for row in rows:
yield {
"measure_id": row[0],
"instrument_name": row[1],
"measure_name": row[2],
"measure_type": MeasureType(row[3]),
"values_domain": row[4],
"figure_distribution_small": row[5],
"figure_distribution": row[6],
"instrument_description": row[7],
"description": row[8],
**dict(zip(reg_col_names, row[9:], strict=True)),
}
def count_measures(
self,
instrument_name: str | None = None,
keyword: str | None = None,
page: int | None = None,
) -> int:
"""Find measures by keyword search."""
query = self._build_measures_count_query(
instrument_name,
keyword,
)
if page is None:
page = 1
query = query.limit(self.PAGE_SIZE).offset(
self.PAGE_SIZE * (page - 1),
)
query_str = to_duckdb_transpile(query)
with self.connection.cursor() as cursor:
rows = cursor.execute(query_str).fetchall()
return int(rows[0][0]) if rows else 0
def save_descriptions(
self,
table: Table,
descriptions: dict[str, str],
) -> None:
"""Save instrument or measure descriptions."""
descriptions_table = table.alias_or_name
with self.connection.cursor() as cursor:
delete_rows = delete(table.alias_or_name)
query = insert(
values([tuple(i) for i in descriptions.items()]),
descriptions_table,
)
cursor.execute(to_duckdb_transpile(delete_rows))
cursor.execute(to_duckdb_transpile(query))
@property
def regression_ids(self) -> list[str]:
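        """Return all regression IDs stored in the browser database."""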
query = to_duckdb_transpile(select(
column("regression_id", self.regressions.alias_or_name),
).from_(self.regressions))
with self.connection.cursor() as cursor:
return [
x[0] for x in cursor.execute(query).fetchall()
]
@property
def regression_display_names(self) -> dict[str, str]:
"""Return regressions display name."""
res = {}
query = to_duckdb_transpile(select(
column("regression_id", self.regressions.alias_or_name),
column("display_name", self.regressions.alias_or_name),
).from_(self.regressions))
with self.connection.cursor() as cursor:
for row in cursor.execute(query).fetchall():
res[row[0]] = row[1]
return res
@property
def regression_display_names_with_ids(self) -> dict[str, Any]:
"""Return regression display names with measure IDs."""
res = {}
query = to_duckdb_transpile(select(
column("regression_id", self.regressions.alias_or_name),
column("display_name", self.regressions.alias_or_name),
column("instrument_name", self.regressions.alias_or_name),
column("measure_name", self.regressions.alias_or_name),
).from_(self.regressions))
with self.connection.cursor() as cursor:
for row in cursor.execute(query).fetchall():
res[row[0]] = {
"display_name": row[1],
"instrument_name": row[2],
"measure_name": row[3],
}
return res
@property
def has_instrument_descriptions(self) -> bool:
"""Check if the database has instrument description data."""
if self.is_legacy:
return False
query = to_duckdb_transpile(select("COUNT(*)").from_(
self.instrument_descriptions,
).where("description IS NOT NULL"))
with self.connection.cursor() as cursor:
row = cursor.execute(query).fetchone()
if row is None:
return False
return bool(row[0])
@property
def has_measure_descriptions(self) -> bool:
"""Check if the database has measure description data."""
if self.is_legacy:
return False
query = to_duckdb_transpile(select("COUNT(*)").from_(
self.measure_descriptions,
).where("description IS NOT NULL"))
with self.connection.cursor() as cursor:
row = cursor.execute(query).fetchone()
if row is None:
return False
return bool(row[0])