from __future__ import annotations
from collections.abc import Generator
from functools import reduce
from typing import Any, cast
import duckdb
import pandas as pd
from sqlglot import column, expressions, select
from sqlglot.expressions import table_
from dae.pheno.common import MeasureType
from dae.utils.sql_utils import glot_and, to_duckdb_transpile
from dae.variants.attributes import Role, Sex, Status
[docs]
class PhenoDb:  # pylint: disable=too-many-instance-attributes
    """Class that manages access to phenotype databases."""

    def __init__(
        self, dbfile: str, *, read_only: bool = True,
    ) -> None:
        """Open a DuckDB connection to ``dbfile``.

        `read_only` is passed straight to ``duckdb.connect``. The mapping of
        instrument names to their values tables is loaded eagerly from the
        ``instrument`` table.
        """
        self.dbfile = dbfile
        self.connection = duckdb.connect(
            f"{dbfile}", read_only=read_only)
        self.family = table_("family")
        self.person = table_("person")
        self.measure = table_("measure")
        self.instrument = table_("instrument")
        self.instrument_values_tables = self.find_instrument_values_tables()

    def find_instrument_values_tables(self) -> dict[str, expressions.Table]:
        """
        Return a mapping of instrument name to its values table.

        The mapping is read from the ``instrument_name`` and ``table_name``
        columns of the ``instrument`` table. Each row of an instrument values
        table is basically a list of every measure value in the instrument
        for a certain person.
        """
        query = to_duckdb_transpile(select(
            "instrument_name",
            "table_name",
        ).from_(self.instrument))
        with self.connection.cursor() as cursor:
            results = cursor.execute(query).fetchall()
        return {i_name: table_(t_name) for i_name, t_name in results}

    def get_measures_df(
        self,
        instrument: str | None = None,
        measure_type: MeasureType | None = None,
    ) -> pd.DataFrame:
        """
        Return data frame containing measures information.

        `instrument` -- an instrument name whose measures should be
        returned. If not specified, measures from all instruments are
        returned.

        `measure_type` -- a `MeasureType`; only measures whose stored
        `measure_type` equals ``measure_type.value`` are returned. If not
        specified, measures of all types are returned.

        Measures with a NULL `measure_type` are always excluded.

        Each row in the returned data frame represents a single measure.
        Columns in the returned data frame are, in order: `measure_id`,
        `measure_name`, `instrument_name`, `description`, `individuals`,
        `measure_type`, `default_filter`, `values_domain`, `min_value`,
        `max_value`.
        """
        table_name = self.measure.alias_or_name
        selected_names = [
            "measure_id",
            "instrument_name",
            "measure_name",
            "description",
            "measure_type",
            "individuals",
            "default_filter",
            "values_domain",
            "min_value",
            "max_value",
        ]
        columns = {name: column(name, table_name) for name in selected_names}

        query: Any = select(*columns.values()).from_(
            self.measure,
        ).where(f"{columns['measure_type'].sql()} IS NOT NULL")
        # The equality must be built on the column, not on the Select:
        # ``where(col).eq(v)`` would add the bare column as a condition and
        # compare the whole query to ``v``.
        if instrument is not None:
            query = query.where(columns["instrument_name"].eq(instrument))
        if measure_type is not None:
            query = query.where(
                columns["measure_type"].eq(measure_type.value))

        with self.connection.cursor() as cursor:
            df = cursor.execute(to_duckdb_transpile(query)).df()

        df_columns = [
            "measure_id",
            "measure_name",
            "instrument_name",
            "description",
            "individuals",
            "measure_type",
            "default_filter",
            "values_domain",
            "min_value",
            "max_value",
        ]
        return df[df_columns]

    def _get_measure_values_query(
        self,
        measure_ids: list[str],
        person_ids: list[str] | None = None,
        family_ids: list[str] | None = None,
        roles: list[Role] | None = None,
    ) -> tuple[str, list[expressions.Column]]:
        """Build the DuckDB query that fetches measure values per person.

        `measure_ids` are expected as ``"<instrument>.<measure>"``. They are
        split on the first dot; the measure column read from the instrument
        values table is ``safe_db_name(<measure>)``, aliased to the full
        measure id.

        The people are the union of (person_id, family_id, role, status, sex)
        over every referenced instrument values table. That union is
        FULL OUTER joined, once per instrument, with its values table on
        ``person_id``. The result is ordered by ``person_id``.

        `person_ids`, `family_ids` and `roles` restrict the rows with
        ``IN (...)``: ``None`` means no restriction, while an empty list
        yields an empty result.

        Returns the transpiled SQL string and the output column expressions,
        in select order: person_id, family_id, role, status, sex, then one
        aliased column per measure id.
        """
        assert isinstance(measure_ids, list)
        assert len(measure_ids) >= 1
        assert len(self.instrument_values_tables) > 0

        # One values table per distinct instrument (first-seen order).
        # NOTE(review): the table name is derived from the instrument name
        # rather than looked up in self.instrument_values_tables; confirm
        # that the two always agree.
        instrument_tables: dict[str, expressions.Table] = {}
        for measure_id in measure_ids:
            instrument, _ = measure_id.split(".", maxsplit=1)
            instrument_tables[instrument] = table_(
                generate_instrument_table_name(instrument))

        people_queries = [
            select(
                column("person_id", table.alias_or_name),
                column("family_id", table.alias_or_name),
                column("role", table.alias_or_name),
                column("status", table.alias_or_name),
                column("sex", table.alias_or_name),
            ).from_(table)
            for table in instrument_tables.values()
        ]
        instrument_people = reduce(
            lambda left, right: left.union(right),  # type: ignore
            people_queries,
        ).subquery(alias="instrument_people")
        people_alias = instrument_people.alias_or_name

        person_id_col = column("person_id", people_alias)
        output_cols = [
            person_id_col,
            column("family_id", people_alias),
            column("role", people_alias),
            column("status", people_alias),
            column("sex", people_alias),
        ]
        query = select(*output_cols).from_(instrument_people)

        joined: set[str] = set()
        for measure_id in measure_ids:
            instrument, measure = measure_id.split(".", maxsplit=1)
            instrument_table = instrument_tables[instrument]
            measure_col = column(
                safe_db_name(measure),
                instrument_table.alias_or_name,
            ).as_(measure_id)
            query = query.select(measure_col)
            # Join each instrument values table only once, however many of
            # its measures are requested.
            if instrument not in joined:
                left_col = person_id_col.sql()
                right_col = column(
                    "person_id", instrument_table.alias_or_name,
                ).sql()
                query = query.join(
                    instrument_table,
                    on=f"{left_col} = {right_col}",
                    join_type="FULL OUTER",
                )
                joined.add(instrument)
            output_cols.append(cast(expressions.Column, measure_col))

        filters = {
            "person_id": person_ids,
            "family_id": family_ids,
            "role": None if roles is None else [r.value for r in roles],
        }
        conditions = []
        empty_result = False
        for col_name, values in filters.items():
            if values is None:
                continue
            if not values:
                # An explicitly empty filter matches nobody.
                empty_result = True
                continue
            conditions.append(column(col_name, people_alias).isin(*values))

        query = query.order_by(person_id_col)
        if conditions:
            query = query.where(reduce(glot_and, conditions))
        if empty_result:
            query = query.where("1=2")
        return (
            to_duckdb_transpile(query),
            output_cols,
        )

    def get_people_measure_values(
        self,
        measure_ids: list[str],
        person_ids: list[str] | None = None,
        family_ids: list[str] | None = None,
        roles: list[Role] | None = None,
    ) -> Generator[dict[str, Any], None, None]:
        """Yield one dict per person with the requested measure values.

        Keys are the output column names: ``person_id``, ``family_id``,
        ``role``, ``status``, ``sex`` and each measure id. The ``role``,
        ``status`` and ``sex`` values are converted with ``Role.to_name``,
        ``Status.to_name`` and ``Sex.to_name``.

        All rows are fetched and the cursor is closed before the first
        dict is yielded. An abandoned generator therefore does not hold a
        cursor open.
        """
        query, output_cols = self._get_measure_values_query(
            measure_ids,
            person_ids=person_ids,
            family_ids=family_ids,
            roles=roles,
        )
        names = [col.alias_or_name for col in output_cols]
        with self.connection.cursor() as cursor:
            rows = cursor.execute(query).fetchall()
        for row in rows:
            output = dict(zip(names, row))
            output["role"] = Role.to_name(output["role"])
            output["status"] = Status.to_name(output["status"])
            output["sex"] = Sex.to_name(output["sex"])
            yield output

    def get_people_measure_values_df(
        self,
        measure_ids: list[str],
        person_ids: list[str] | None = None,
        family_ids: list[str] | None = None,
        roles: list[Role] | None = None,
    ) -> pd.DataFrame:
        """Return the requested measure values as a data frame.

        Takes the same arguments as `get_people_measure_values`. The ``sex``,
        ``status`` and ``role`` columns are converted with
        ``Sex.from_value``, ``Status.from_value`` and ``Role.from_value``.
        """
        query, _ = self._get_measure_values_query(
            measure_ids,
            person_ids=person_ids,
            family_ids=family_ids,
            roles=roles,
        )
        with self.connection.cursor() as cursor:
            df = cursor.execute(query).df()
        df["sex"] = df["sex"].transform(Sex.from_value)
        df["status"] = df["status"].transform(Status.from_value)
        df["role"] = df["role"].transform(Role.from_value)
        return df
[docs]
def safe_db_name(name: str) -> str:
    """Convert a string to a db-friendly identifier.

    Dots, dashes, spaces and slashes are replaced with underscores, and the
    result is lower-cased and stripped of surrounding whitespace. Spaces are
    replaced before stripping, so leading or trailing spaces become
    underscores. An underscore is prepended if the result starts with a
    digit.

    Raises:
        ValueError: if `name` is empty, or if nothing is left after
            normalization (e.g. a name made only of tabs or newlines).
    """
    if name == "":
        raise ValueError("The name cannot be empty")
    cleaned = name.translate(str.maketrans(".- /", "____")).lower().strip()
    # Without this check, a whitespace-only name such as "\t" would crash
    # below with an IndexError instead of a clear error.
    if not cleaned:
        raise ValueError(f"The name {name!r} is blank after normalization")
    if cleaned[0].isdigit():
        cleaned = f"_{cleaned}"
    return cleaned
[docs]
def generate_instrument_table_name(instrument_name: str) -> str:
    """Return the measure values table name for an instrument.

    The instrument name is normalized with `safe_db_name`, and the suffix
    ``_measure_values`` is appended.
    """
    safe_name = safe_db_name(instrument_name)
    return safe_name + "_measure_values"