from __future__ import annotations
from collections.abc import Generator, Iterator
from functools import reduce
from pathlib import Path
from typing import Any, cast
import duckdb
import pandas as pd
import sqlglot
from duckdb import ConstraintException
from sqlglot import column, expressions, select, table
from sqlglot.expressions import Count, delete, insert, update, values
from dae.pheno.common import MeasureType
from dae.utils.sql_utils import glot_and, to_duckdb_transpile
from dae.variants.attributes import Role, Sex, Status
[docs]
class PhenoDb:  # pylint: disable=too-many-instance-attributes
    """Provide access to a phenotype database stored in a DuckDB file."""

    # Number of rows requested per page by the paginated measure queries.
    PAGE_SIZE = 1001

    def __init__(
        self, dbfile: str, *, read_only: bool = True,
    ) -> None:
        """Open `dbfile` and prepare handles for the tables it contains.

        The DuckDB connection is opened read-only unless `read_only` is
        passed as False.
        """
        self.dbfile = dbfile
        self.connection = duckdb.connect(f"{dbfile}", read_only=read_only)

        # sqlglot table expressions reused by the query builders below.
        self.family = table("family")
        self.person = table("person")
        self.instrument = table("instrument")
        self.measure = table("measure")
        self.variable_browser = table("variable_browser")
        self.regressions = table("regression")
        self.regression_values = table("regression_values")

        # Needs both the open connection and the `instrument` table handle,
        # so it has to be computed last.
        self.instrument_values_tables = self.find_instrument_values_tables()
[docs]
@staticmethod
def verify_pheno_folder(folder: Path) -> None:
"""Verify integrity of a pheno db folder."""
parquet_folder = folder / "parquet"
assert parquet_folder.exists()
family_file = parquet_folder / "family.parquet"
assert family_file.exists()
person_file = parquet_folder / "person.parquet"
assert person_file.exists()
instrument_file = parquet_folder / "instrument.parquet"
assert instrument_file.exists()
measure_file = parquet_folder / "measure.parquet"
assert measure_file.exists()
instruments_dir = parquet_folder / "instruments"
assert instruments_dir.exists()
assert instruments_dir.is_dir()
assert len(list(instruments_dir.glob("*"))) > 0
[docs]
def find_instrument_values_tables(self) -> dict[str, expressions.Table]:
    """
    Create instrument values tables.

    Reads the `instrument_name` and `table_name` columns of the
    `instrument` table and returns a mapping from instrument name to a
    table expression for the named values table.

    Each row of such a table is basically a list of every measure value
    in the instrument for a certain person.
    """
    listing = select("instrument_name", "table_name").from_(self.instrument)
    sql = to_duckdb_transpile(listing)
    with self.connection.cursor() as cursor:
        name_pairs = cursor.execute(sql).fetchall()

    values_tables: dict[str, expressions.Table] = {}
    for instrument_name, table_name in name_pairs:
        values_tables[instrument_name] = table(table_name)
    return values_tables
def _split_measures_into_groups(
self, measure_ids: list[str], group_size: int = 60,
) -> list[list[str]]:
groups_count = int(len(measure_ids) / group_size) + 1
if (groups_count) == 1:
return [measure_ids]
measure_groups = []
for i in range(groups_count):
begin = i * group_size
end = (i + 1) * group_size
group = measure_ids[begin:end]
if len(group) > 0:
measure_groups.append(group)
return measure_groups
[docs]
def save(self, v: dict[str, Any]) -> None:
    """Save measure values into the database.

    `v` maps `variable_browser` column names to the values of one row.
    When the insert violates a constraint, the existing row with the same
    `measure_id` is deleted and the insert is executed again, so `v` must
    contain a `measure_id` key for that fallback to work.
    """
    query = to_duckdb_transpile(insert(
        values([tuple(v.values())]),
        self.variable_browser,
        columns=list(v.keys()),
    ))
    try:
        with self.connection.cursor() as cursor:
            cursor.execute(query)
    except ConstraintException:
        # Build the filter as a real condition inside the DELETE. Chaining
        # `.eq()` after `.where("measure_id")` wraps the whole statement in
        # an EQ node and only happens to render as valid SQL.
        delete_query = to_duckdb_transpile(
            delete(self.variable_browser).where(
                column("measure_id").eq(v["measure_id"]),
            ),
        )
        with self.connection.cursor() as cursor:
            cursor.execute(delete_query)
            cursor.execute(query)
[docs]
def save_regression(self, reg: dict[str, str]) -> None:
    """Save regressions into the database.

    `reg` maps `regression` column names to the values of one row. When
    the insert violates a constraint, the row with the same
    `regression_id` is updated with the remaining values instead. `reg`
    itself is never modified.
    """
    query = to_duckdb_transpile(insert(
        values([tuple(reg.values())]),
        self.regressions,
        columns=list(reg.keys()),
    ))
    try:
        with self.connection.cursor() as cursor:
            cursor.execute(query)
    except ConstraintException:
        # Work on a filtered copy so the caller's dict keeps its key.
        new_values = {
            name: value
            for name, value in reg.items()
            if name != "regression_id"
        }
        if not new_values:
            # Nothing besides the key was given: there is nothing to set.
            return
        # The id is compared as a proper literal (quoted and escaped), and
        # the statement is transpiled to SQL text before execution like
        # every other query in this class.
        update_query = to_duckdb_transpile(update(
            self.regressions,
            new_values,
            where=column("regression_id").eq(reg["regression_id"]),
        ))
        with self.connection.cursor() as cursor:
            cursor.execute(update_query)
[docs]
def save_regression_values(self, reg: dict[str, str]) -> None:
    """Save regression values into the databases.

    `reg` maps `regression_values` column names to the values of one row.
    When the insert violates a constraint, the row identified by the same
    `regression_id` and `measure_id` is updated with the remaining values
    instead. `reg` itself is never modified.
    """
    query = to_duckdb_transpile(insert(
        values([tuple(reg.values())]),
        self.regression_values,
        columns=list(reg.keys()),
    ))
    try:
        with self.connection.cursor() as cursor:
            cursor.execute(query)
    except ConstraintException:
        key_columns = ("regression_id", "measure_id")
        # Work on a filtered copy so the caller's dict keeps its keys.
        new_values = {
            name: value
            for name, value in reg.items()
            if name not in key_columns
        }
        if not new_values:
            # Nothing besides the keys was given: there is nothing to set.
            return
        # Both ids are compared as proper literals (quoted and escaped),
        # and the statement is transpiled to SQL text before execution.
        row_filter = column("regression_id").eq(
            reg["regression_id"],
        ).and_(
            column("measure_id").eq(reg["measure_id"]),
        )
        update_query = to_duckdb_transpile(update(
            self.regression_values,
            new_values,
            where=row_filter,
        ))
        with self.connection.cursor() as cursor:
            cursor.execute(update_query)
[docs]
def get_browser_measure(self, measure_id: str) -> dict | None:
    """Get measure description from phenotype browser database.

    Returns the first `variable_browser` row whose `measure_id` matches,
    as a dict keyed by column name, or None when there is no such row.
    """
    # The comparison is built as a condition inside WHERE. Chaining
    # `.eq()` after `.where("measure_id")` wraps the whole SELECT in an EQ
    # node and only happens to render as the intended SQL.
    query = to_duckdb_transpile(
        select("variable_browser.*")
        .from_(self.variable_browser)
        .where(column("measure_id").eq(measure_id)),
    )
    with self.connection.cursor() as cursor:
        rows = cursor.execute(query).df()
    if rows.empty:
        return None
    return rows.to_dict("records")[0]
def _build_ilike(
    self,
    keyword: str, col: expressions.Column,
) -> expressions.Escape:
    """Build `col ILIKE keyword` with `/` declared as the escape char."""
    ilike_match = col.ilike(keyword)
    return expressions.Escape(this=ilike_match, expression="'/'")
[docs]
def build_measures_query(
    self,
    instrument_name: str | None = None,
    keyword: str | None = None,
    sort_by: str | None = None,
    order_by: str | None = None,
) -> tuple[expressions.Select, list[expressions.Alias]]:
    """Build the query that lists measures from `variable_browser`.

    Every known regression is LEFT OUTER joined from `regression_values`
    and contributes four aliased columns named
    `<regression_id>_figure_regression`,
    `<regression_id>_figure_regression_small`,
    `<regression_id>_pvalue_regression_male` and
    `<regression_id>_pvalue_regression_female`.

    `keyword` restricts the result through a keyword search and
    `instrument_name` to a single instrument. `sort_by` is one of
    `instrument`, `measure`, `measure_type`, `description` or
    `<regression_id>.<male|female>`; `order_by` equal to `"desc"` sorts
    descending, anything else ascending. Without `sort_by` the rows are
    ordered by measure id.

    Returns the query together with the aliased regression columns.
    Raises ValueError when `sort_by` is not a valid sort column.
    """
    browser = self.variable_browser.alias_or_name
    regression_ids = self.regression_ids
    reg_cols: list[expressions.Alias] = []
    query = select(f"{browser}.*").from_(self.variable_browser)

    regression_fields = (
        "figure_regression",
        "figure_regression_small",
        "pvalue_regression_male",
        "pvalue_regression_female",
    )
    for regression_id in regression_ids:
        # Each regression gets its own alias of `regression_values`.
        reg_table = self.regression_values.as_(regression_id)
        same_measure = column("measure_id", browser).eq(
            column("measure_id", reg_table.alias),
        )
        same_regression = column(
            "regression_id", reg_table.alias,
        ).eq(regression_id)
        query = query.join(
            reg_table,
            on=sqlglot.condition(same_measure.and_(same_regression)),
            join_type="LEFT OUTER",
        )
        cols = [
            column(
                field, table=reg_table.alias_or_name,
            ).as_(f"{regression_id}_{field}")
            for field in regression_fields
        ]
        reg_cols.extend(cols)
        query = query.select(*cols)

    query = query.distinct()
    if keyword:
        query = self._measures_query_by_keyword(
            query, keyword, instrument_name,
        )
    if instrument_name:
        # Compare against a literal built by sqlglot so that quotes in the
        # name are escaped instead of being pasted into the SQL text.
        query = query.where(
            column("instrument_name", table="variable_browser")
            .eq(instrument_name),
        )

    if not sort_by:
        query = query.order_by("variable_browser.measure_id ASC")
        return query, reg_cols

    # Sort keys that map directly onto a `variable_browser` column.
    browser_sort_columns = {
        "instrument": "measure_id",
        "measure": "measure_name",
        "measure_type": "measure_type",
        "description": "description",
    }
    column_to_sort: Any
    if sort_by in browser_sort_columns:
        column_to_sort = column(browser_sort_columns[sort_by], browser)
    else:
        # Otherwise the key must be "<regression_id>.<sex>" and refers to
        # one of the aliased p-value columns selected above.
        parts = sort_by.split(".")
        if (
            len(parts) != 2
            or parts[0] not in regression_ids
            or parts[1] not in ("male", "female")
        ):
            raise ValueError(
                f"{sort_by} is an invalid sort column",
            )
        regression_id, sex = parts
        column_to_sort = column(f"{regression_id}_pvalue_regression_{sex}")

    direction = "DESC" if order_by == "desc" else "ASC"
    query = query.order_by(f"{column_to_sort} {direction}")
    return query, reg_cols
[docs]
def build_measures_count_query(
    self,
    instrument_name: str | None = None,
    keyword: str | None = None,
) -> expressions.Select:
    """Build the query counting `variable_browser` measures.

    `keyword` restricts the count through a keyword search and
    `instrument_name` to the measures of a single instrument.
    """
    query = select(Count(this="*")).from_(self.variable_browser)
    query = query.distinct()
    if keyword:
        query = self._measures_query_by_keyword(
            query, keyword, instrument_name,
        )
    if instrument_name:
        # Compare against a literal built by sqlglot so that quotes in the
        # name are escaped instead of being pasted into the SQL text.
        query = query.where(
            column("instrument_name", table="variable_browser")
            .eq(instrument_name),
        )
    return query
def _measures_query_by_keyword(
    self,
    query: expressions.Select,
    keyword: str,
    instrument_name: str | None = None,
) -> expressions.Select:
    """Restrict `query` to rows where some text column contains `keyword`.

    The measure id, measure name and description columns of
    `variable_browser` are searched; the instrument name column is
    searched as well when no `instrument_name` is given. The per-column
    matches are combined with OR.
    """
    # "/" is the escape character of the ILIKE filters, so it is doubled
    # first and then used to neutralise the "%" and "_" wildcards.
    escaped = keyword.replace("/", "//")
    escaped = escaped.replace("%", r"/%")
    escaped = escaped.replace("_", r"/_")
    pattern = f"%{escaped}%"

    searched_columns = ["measure_id", "measure_name", "description"]
    if not instrument_name:
        searched_columns.insert(0, "instrument_name")

    any_match = None
    for column_name in searched_columns:
        match_filter = self._build_ilike(
            pattern,
            column(column_name, table="variable_browser"),
        )
        if any_match is None:
            any_match = match_filter
        else:
            any_match = any_match.or_(match_filter)
    return query.where(any_match)
[docs]
def get_measures_df(
    self,
    instrument: str | None = None,
    measure_type: MeasureType | None = None,
) -> pd.DataFrame:
    """
    Return data frame containing measures information.

    `instrument` -- an instrument name which measures should be
    returned. If not specified measures of all instruments are returned.

    `measure_type` -- a measure type; only measures of that type are
    returned. If not specified all types of measures are returned.

    Measures whose `measure_type` is NULL are never returned.

    Each row in the returned data frame represents given measure.

    Columns in the returned data frame are: `measure_id`, `measure_name`,
    `instrument_name`, `description`, `individuals`, `measure_type`,
    `default_filter`, `values_domain`, `min_value`, `max_value`.
    """
    measure_table = self.measure
    table_name = measure_table.alias_or_name
    selected_names = (
        "measure_id",
        "instrument_name",
        "measure_name",
        "description",
        "measure_type",
        "individuals",
        "default_filter",
        "values_domain",
        "min_value",
        "max_value",
    )
    columns = {name: column(name, table_name) for name in selected_names}

    query = select(*columns.values()).from_(
        measure_table,
    ).where(f"{columns['measure_type'].sql()} IS NOT NULL")
    # Each filter is passed to `where()` as a complete condition, so the
    # query stays a Select and both filters can be combined. Calling
    # `.eq()` on the result of `where()` turns the query into an EQ node,
    # which has no `where()` for the second filter.
    if instrument is not None:
        query = query.where(columns["instrument_name"].eq(instrument))
    if measure_type is not None:
        query = query.where(
            columns["measure_type"].eq(measure_type.value),
        )

    with self.connection.cursor() as cursor:
        df = cursor.execute(to_duckdb_transpile(query)).df()

    # Column order of the returned frame differs from the SELECT order.
    df_columns = [
        "measure_id",
        "measure_name",
        "instrument_name",
        "description",
        "individuals",
        "measure_type",
        "default_filter",
        "values_domain",
        "min_value",
        "max_value",
    ]
    return df[df_columns]
[docs]
def search_measures(
    self,
    instrument_name: str | None = None,
    keyword: str | None = None,
    page: int | None = None,
    sort_by: str | None = None,
    order_by: str | None = None,
) -> Iterator[dict[str, Any]]:
    """Yield one page of measures matching the search arguments.

    Pages are 1-based and hold `PAGE_SIZE` rows; a missing `page` means
    the first one. Every yielded dict describes one measure and also
    carries one entry per aliased regression column of the query.
    """
    query, reg_cols = self.build_measures_query(
        instrument_name,
        keyword,
        sort_by,
        order_by,
    )
    reg_col_names = [reg_col.alias for reg_col in reg_cols]

    page_number = 1 if page is None else page
    offset = self.PAGE_SIZE * (page_number - 1)
    query_str = to_duckdb_transpile(
        query.limit(self.PAGE_SIZE).offset(offset),
    )

    with self.connection.cursor() as cursor:
        rows = cursor.execute(query_str).fetchall()
        for row in rows:
            # NOTE(review): the positions below assume this column order
            # for `variable_browser.*` -- confirm against the table schema.
            record: dict[str, Any] = {
                "measure_id": row[0],
                "instrument_name": row[1],
                "measure_name": row[2],
                "measure_type": MeasureType(row[3]),
                "description": row[4],
                "values_domain": row[5],
                "figure_distribution_small": row[6],
                "figure_distribution": row[7],
            }
            # Everything after the first eight values belongs to the
            # regression columns, in the order they were selected.
            record.update(zip(reg_col_names, row[8:], strict=True))
            yield record
[docs]
def count_measures(  # pylint: disable=unused-argument
    self,
    instrument_name: str | None = None,
    keyword: str | None = None,
    page: int | None = None,
) -> int:
    """Count the measures matching the given instrument and keyword.

    `page` is accepted for backward compatibility only and is ignored.
    The count query returns a single row, so applying the page offset to
    it (as was done before) skipped that row for every page after the
    first and made the lookup of the count fail.
    """
    query_str = to_duckdb_transpile(
        self.build_measures_count_query(instrument_name, keyword),
    )
    with self.connection.cursor() as cursor:
        row = cursor.execute(query_str).fetchone()
    if row is None:
        return 0
    return int(row[0])
[docs]
def search_measures_df(
    self, instrument_name: str | None = None,
    keyword: str | None = None,
) -> pd.DataFrame:
    """Find measures and return a dataframe with values.

    The rows come from the query built by `build_measures_query` for the
    given instrument and keyword, without sorting arguments or paging.
    """
    measures_query, _ = self.build_measures_query(instrument_name, keyword)
    query = to_duckdb_transpile(measures_query)
    # Run on a dedicated cursor, like every other query of this class,
    # instead of executing directly on the shared connection.
    with self.connection.cursor() as cursor:
        return cursor.execute(query).df()
@property
def regression_ids(self) -> list[str]:
    """Return the `regression_id` of every row in the regression table."""
    id_column = column("regression_id", self.regressions.alias_or_name)
    query = to_duckdb_transpile(select(id_column).from_(self.regressions))
    with self.connection.cursor() as cursor:
        rows = cursor.execute(query).fetchall()
        return [row[0] for row in rows]
@property
def regression_display_names(self) -> dict[str, str]:
    """Return a mapping from regression id to its display name."""
    regressions_name = self.regressions.alias_or_name
    listing = select(
        column("regression_id", regressions_name),
        column("display_name", regressions_name),
    ).from_(self.regressions)
    query = to_duckdb_transpile(listing)
    with self.connection.cursor() as cursor:
        rows = cursor.execute(query).fetchall()
    return {row[0]: row[1] for row in rows}
@property
def regression_display_names_with_ids(self) -> dict[str, Any]:
    """Return regression display names with measure IDs.

    The result maps every regression id to a dict holding the
    `display_name`, `instrument_name` and `measure_name` of that
    regression.
    """
    regressions_name = self.regressions.alias_or_name
    selected = [
        column(name, regressions_name)
        for name in (
            "regression_id",
            "display_name",
            "instrument_name",
            "measure_name",
        )
    ]
    query = to_duckdb_transpile(select(*selected).from_(self.regressions))
    with self.connection.cursor() as cursor:
        rows = cursor.execute(query).fetchall()
    return {
        row[0]: {
            "display_name": row[1],
            "instrument_name": row[2],
            "measure_name": row[3],
        }
        for row in rows
    }
@property
def has_descriptions(self) -> bool:
    """Tell whether any `variable_browser` row has a description."""
    counting = select("COUNT(*)").from_(
        self.variable_browser,
    ).where("description IS NOT NULL")
    query = to_duckdb_transpile(counting)
    with self.connection.cursor() as cursor:
        row = cursor.execute(query).fetchone()
    # No row at all is treated the same as a zero count.
    return row is not None and bool(row[0])
def _get_measure_values_query(
    self,
    measure_ids: list[str],
    person_ids: list[str] | None = None,
    family_ids: list[str] | None = None,
    roles: list[Role] | None = None,
) -> tuple[str, list[expressions.Column]]:
    """Build the SQL selecting the values of `measure_ids` per person.

    Every measure id has the form `<instrument>.<measure>`. The people of
    all involved instrument tables are unioned into `instrument_people`,
    each instrument table is FULL OUTER joined to it on `person_id`, and
    one column aliased with the full measure id is selected per measure.

    `person_ids`, `family_ids` and `roles` restrict the rows when given;
    an empty list for any of them makes the query return no rows.

    Returns the transpiled SQL and the output columns in select order.
    """
    assert isinstance(measure_ids, list)
    assert len(measure_ids) >= 1
    assert len(self.instrument_values_tables) > 0

    # One values table per distinct instrument, in first-seen order.
    instrument_tables = {}
    for measure_id in measure_ids:
        instrument, _ = measure_id.split(".", maxsplit=1)
        instrument_tables[instrument] = table(
            generate_instrument_table_name(instrument),
        )

    person_fields = ("person_id", "family_id", "role", "status", "sex")
    people_selects = [
        select(*[
            column(field, values_table.alias_or_name)
            for field in person_fields
        ]).from_(values_table)
        for values_table in instrument_tables.values()
    ]

    def _union(left: Any, right: Any) -> Any:
        return left.union(right)

    instrument_people = reduce(
        _union, people_selects,
    ).subquery(alias="instrument_people")
    people_alias = instrument_people.alias_or_name

    person_id_col = column("person_id", people_alias)
    output_cols = [
        person_id_col,
        *[column(field, people_alias) for field in person_fields[1:]],
    ]
    query = select(*output_cols).from_(instrument_people)

    joined = set()
    for measure_id in measure_ids:
        instrument, measure = measure_id.split(".", maxsplit=1)
        instrument_table = instrument_tables[instrument]
        measure_col = column(
            safe_db_name(measure),
            instrument_table.alias_or_name,
        ).as_(measure_id)
        query = query.select(measure_col)
        if instrument not in joined:
            # The first measure of an instrument also brings in its table.
            left_col = person_id_col.sql()
            right_col = column(
                "person_id", instrument_table.alias_or_name,
            ).sql()
            query = query.join(
                instrument_table,
                on=f"{left_col} = {right_col}",
                join_type="FULL OUTER",
            )
            joined.add(instrument)
        output_cols.append(cast(expressions.Column, measure_col))

    # A filter given as None is not applied; an empty filter matches
    # nothing, which is handled by the "1=2" condition below.
    requested_filters = (
        (person_id_col, person_ids),
        (column("family_id", people_alias), family_ids),
        (
            column("role", people_alias),
            None if roles is None else [r.value for r in roles],
        ),
    )
    empty_result = False
    cols_in = []
    for filter_col, allowed in requested_filters:
        if allowed is None:
            continue
        if len(allowed) == 0:
            empty_result = True
        else:
            cols_in.append(filter_col.isin(*allowed))

    query = query.order_by(person_id_col)
    if cols_in:
        query = query.where(reduce(glot_and, cols_in))
    if empty_result:
        query = query.where("1=2")

    return to_duckdb_transpile(query), output_cols
[docs]
def get_people_measure_values(
    self,
    measure_ids: list[str],
    person_ids: list[str] | None = None,
    family_ids: list[str] | None = None,
    roles: list[Role] | None = None,
) -> Generator[dict[str, Any], None, None]:
    """Yield lines from measure values tables.

    Every yielded dict is keyed by output column name; its `role`,
    `status` and `sex` values are converted with the `to_name` helpers of
    `Role`, `Status` and `Sex`.
    """
    query, output_cols = self._get_measure_values_query(
        measure_ids, person_ids, family_ids, roles,
    )
    col_names = [col.alias_or_name for col in output_cols]
    with self.connection.cursor() as cursor:
        for row in cursor.execute(query).fetchall():
            record = {
                name: row[position]
                for position, name in enumerate(col_names)
            }
            record["role"] = Role.to_name(record["role"])
            record["status"] = Status.to_name(record["status"])
            record["sex"] = Sex.to_name(record["sex"])
            yield record
[docs]
def get_people_measure_values_df(
    self,
    measure_ids: list[str],
    person_ids: list[str] | None = None,
    family_ids: list[str] | None = None,
    roles: list[Role] | None = None,
) -> pd.DataFrame:
    """Return dataframe from measure values tables.

    The `sex`, `status` and `role` columns are converted with the
    `from_value` helpers of `Sex`, `Status` and `Role`.
    """
    query, _ = self._get_measure_values_query(
        measure_ids,
        person_ids=person_ids,
        family_ids=family_ids,
        roles=roles,
    )
    with self.connection.cursor() as cursor:
        df = cursor.execute(query).df()

    converters = (
        ("sex", Sex.from_value),
        ("status", Status.from_value),
        ("role", Role.from_value),
    )
    for column_name, from_value in converters:
        df[column_name] = df[column_name].transform(from_value)
    return df
[docs]
def safe_db_name(name: str) -> str:
    """Convert `name` into a lowercase name safe to use in the database.

    Dots, dashes, spaces and slashes are replaced with underscores and
    the result is lowercased. An underscore is prepended when the result
    starts with a digit. An empty name is returned unchanged.
    """
    # A single pass handles all four replaced characters.
    name = name.translate(str.maketrans(".- /", "____")).lower()
    # Slicing instead of indexing keeps an empty name from raising.
    if name[:1].isdigit():
        name = f"_{name}"
    return name
[docs]
def generate_instrument_table_name(instrument_name: str) -> str:
    """Return the name of the measure values table of an instrument."""
    safe_name = safe_db_name(instrument_name)
    return f"{safe_name}_measure_values"