import logging
from copy import copy
from typing import Any, Dict
from sqlalchemy import ( # type: ignore
Column,
Float,
Integer,
MetaData,
String,
Table,
create_engine,
inspect,
nullslast,
)
from sqlalchemy.sql import asc, desc, insert, select
from dae.gene_profile.statistic import GPStatistic
# Module-level logger, named after this module; used for insert progress.
logger = logging.getLogger(__name__)
[docs]
class GeneProfileDB:
    """
    Class for managing the gene profile database.

    Uses SQLite for DB management and supports loading
    and storing to filesystem.
    Has to be supplied a configuration and a path to which to read/write
    the SQLite DB.
    """

    # Number of rows returned per page by the paginated queries.
    PAGE_SIZE = 50

    def __init__(self, configuration, dbfile, clear=False):
        """
        Open the SQLite file and make sure the gene profile table exists.

        Parameters:
            configuration - GP configuration mapping or None; it is passed
                            through build_configuration before use
            dbfile - path of the SQLite database file
            clear - when true, all rows of the GP table are deleted
        """
        self.dbfile = dbfile
        self.engine = create_engine(f"sqlite:///{dbfile}")
        self.metadata = MetaData()
        self.configuration = \
            GeneProfileDB.build_configuration(configuration)
        self._create_gp_table()
        # Maps "<collection_id>_<set_id>" to the name of its category.
        self.gene_sets_categories = {}
        if clear:
            self._clear_gp_table()
        if self.configuration:
            for category in self.configuration["gene_sets"]:
                category_name = category["category"]
                for gene_set in category["sets"]:
                    collection_id = gene_set["collection_id"]
                    set_id = gene_set["set_id"]
                    full_gene_set_id = f"{collection_id}_{set_id}"
                    self.gene_sets_categories[full_gene_set_id] = \
                        category_name

    @classmethod
    def build_configuration(cls, configuration):
        """
        Perform a transformation on a given configuration.

        The configuration is transformed to an internal version with more
        specific information on order and ranks.

        Returns an empty dict when configuration is None, otherwise a
        shallow copy of the configuration whose "order" entry is set.
        The mapping passed in is not modified.

        Raises AssertionError when a supplied "order" contains unknown
        entries or does not have one entry per category.
        """
        if configuration is None:
            return {}

        # Every key that may legally appear in "order", in default order:
        # gene set ranks, then genomic score categories, then datasets.
        order_keys = [
            gene_set["category"] + "_rank"
            for gene_set in configuration["gene_sets"]
        ]
        order_keys.extend(
            genomic_score["category"]
            for genomic_score in configuration["genomic_scores"]
        )
        order_keys.extend(configuration["datasets"])

        # Copy first so that filling in the default order does not leak
        # into the caller's configuration.
        result = copy(configuration)
        order = result.get("order")
        if order is None:
            result["order"] = order_keys
            return result

        # AssertionError is kept as the exception type for compatibility,
        # but it is raised explicitly so that the validation is not
        # stripped when Python runs with -O.
        if not all(entry in order_keys for entry in order):
            raise AssertionError("Given GP order has invalid entries")
        if len(order) != len(order_keys):
            raise AssertionError("Given GP order is missing items")
        return result

    def get_gp(self, gene_symbol) -> GPStatistic | None:
        """
        Query a GP by gene_symbol and return the row as statistic.

        Returns None if gene_symbol is not found within the DB.
        """
        table = self.gp_table
        # NOTE(review): ilike gives a case-insensitive match, but it also
        # interprets "%" and "_" inside gene_symbol as LIKE wildcards.
        query = table.select().where(
            table.c.symbol_name.ilike(gene_symbol),
        )
        with self.engine.begin() as connection:
            row = connection.execute(query).fetchone()
        if not row:
            return None
        return self.gp_from_table_row_single_view(row)

    # FIXME: Too many locals, refactor?
    def gp_from_table_row(self, row) -> dict:
        # pylint: disable=too-many-locals
        """
        Build a flat dict for the table view from an internal DB row.

        Keys are "geneSymbol" plus dot-joined identifiers:
        "<category>.<score_name>" for genomic scores,
        "<category>_rank.<set_id>" for gene sets (a check mark when the
        gene is in the set, otherwise None) and
        "<dataset_id>.<set_name>.<statistic_id>" for variant counts
        (formatted as "count (rate)", or None when the count is falsy).
        """
        config = self.configuration
        row = row._mapping  # pylint: disable=protected-access
        result = {"geneSymbol": row["symbol_name"]}

        for gs_category in config["genomic_scores"]:
            category_name = gs_category["category"]
            for score in gs_category["scores"]:
                score_name = score["score_name"]
                value = row[f"{category_name}_{score_name}"]
                result[".".join([category_name, score_name])] = value

        for gs_category in config["gene_sets"]:
            category_name = gs_category["category"]
            for gene_set in gs_category["sets"]:
                set_id = gene_set["set_id"]
                collection_id = gene_set["collection_id"]
                full_gs_id = f"{collection_id}_{set_id}"
                key = ".".join([f"{category_name}_rank", set_id])
                result[key] = "\u2713" if row[full_gs_id] else None

        for dataset_id, filters in config["datasets"].items():
            for person_set in filters["person_sets"]:
                set_name = person_set["set_name"]
                for statistic in filters["statistics"]:
                    statistic_id = statistic["id"]
                    column = f"{dataset_id}_{set_name}_{statistic_id}"
                    count = row[column]
                    rate = row[f"{column}_rate"]
                    key = ".".join([dataset_id, set_name, statistic_id])
                    result[key] = \
                        f"{count} ({round(rate, 2)})" if count else None

        return result

    # FIXME: Too many locals, refactor?
    def gp_from_table_row_single_view(self, row) -> GPStatistic:
        """
        Create a GPStatistic from a DB row for the single gene view.

        Genomic scores are returned per category as
        {score_name: {"value": ..., "format": ...}}, gene sets as the list
        of "<collection_id>_<set_id>" columns whose value is 1, and
        variant counts per dataset and person set as
        {statistic_id: {"count": ..., "rate": ...}}.
        """
        # pylint: disable=too-many-locals
        config = self.configuration
        row = row._mapping  # pylint: disable=protected-access
        gene_symbol = row["symbol_name"]

        genomic_scores: Dict[str, Dict] = {}
        for gs_category in config["genomic_scores"]:
            category_name = gs_category["category"]
            genomic_scores[category_name] = {}
            for score in gs_category["scores"]:
                score_name = score["score_name"]
                full_score_id = f"{category_name}_{score_name}"
                genomic_scores[category_name][score_name] = {
                    "value": row[full_score_id],
                    "format": score["format"],
                }

        gene_sets = []
        for gs_category in config["gene_sets"]:
            for gene_set in gs_category["sets"]:
                set_id = gene_set["set_id"]
                collection_id = gene_set["collection_id"]
                full_gs_id = f"{collection_id}_{set_id}"
                if row[full_gs_id] == 1:
                    gene_sets.append(full_gs_id)

        variant_counts = {}
        for dataset_id, filters in config["datasets"].items():
            current_counts: Dict[str, Dict] = {}
            for person_set in filters["person_sets"]:
                set_name = person_set["set_name"]
                for statistic in filters["statistics"]:
                    statistic_id = statistic["id"]
                    # Created lazily, so a person set only shows up when
                    # at least one statistic is configured.
                    counts = current_counts.setdefault(set_name, {})
                    column = f"{dataset_id}_{set_name}_{statistic_id}"
                    counts[statistic_id] = {
                        "count": row[column],
                        "rate": row[f"{column}_rate"],
                    }
            variant_counts[dataset_id] = current_counts

        return GPStatistic(
            gene_symbol, gene_sets,
            genomic_scores, variant_counts,
        )

    def _transform_sort_by(self, sort_by):
        """
        Translate a dotted sort key into a GP table column name.

        A leading "gene_set_" prefix is removed and dots are replaced by
        underscores. For a two-part key whose first part contains "_rank"
        ("<category>_rank.<set_id>"), the first part is replaced by the
        collection id of the matching configured gene set, which yields
        the "<collection_id>_<set_id>" column.
        """
        # NOTE(review): the tokens are taken before the "gene_set_" prefix
        # is stripped, so a prefixed two-part key is looked up under the
        # category name that still carries the prefix - confirm intended.
        sort_by_tokens = sort_by.split(".")
        if sort_by.startswith("gene_set_"):
            sort_by = sort_by.replace("gene_set_", "", 1)
        if "_rank" in sort_by_tokens[0]:
            # Stays empty when no configured gene set matches.
            collection_id = ""
            category = sort_by_tokens[0].replace("_rank", "")
            if len(sort_by_tokens) == 2:
                for gs_category in self.configuration["gene_sets"]:
                    if gs_category["category"] != category:
                        continue
                    for gene_set in gs_category["sets"]:
                        if gene_set["set_id"] == sort_by_tokens[1]:
                            collection_id = gene_set["collection_id"]
                sort_by = ".".join((collection_id, sort_by_tokens[1]))
        return sort_by.replace(".", "_")

    def query_gps(self, page, symbol_like=None, sort_by=None, order=None):
        """
        Perform paginated query and return list of GPs.

        Parameters:
            page - Which page to fetch (1-based); None disables paging
            symbol_like - Which gene symbol to search for, supports
                          incomplete search (matches anywhere in the
                          symbol)
            sort_by - Column to sort by
            order - "asc" or "desc"; defaults to "desc", any other value
                    is treated as "asc"

        Returns the matching rows of the GP table.
        """
        table = self.gp_table
        query = table.select()
        if symbol_like:
            query = query.where(
                table.c.symbol_name.like(f"%{symbol_like}%"),
            )
        if sort_by is not None:
            if order is None:
                order = "desc"
            sort_by = self._transform_sort_by(sort_by)
            query_order_func = desc if order == "desc" else asc
            # NULL values are always placed after the non-NULL ones.
            query = query.order_by(nullslast(query_order_func(sort_by)))
        if page is not None:
            query = query.limit(self.PAGE_SIZE).offset(
                self.PAGE_SIZE * (page - 1),
            )
        with self.engine.begin() as connection:
            return connection.execute(query).fetchall()

    def list_symbols(
        self, page: int | None, symbol_like: str | None = None,
    ) -> list[str]:
        """
        Perform paginated query and return list of gene symbols.

        The symbols are returned in ascending order.

        Parameters:
            page - Which page to fetch (1-based); None disables paging
            symbol_like - Which gene symbol to search for, supports
                          incomplete search (matches the start of the
                          symbol)
        """
        table = self.gp_table
        query = select(table.c.symbol_name)
        if symbol_like:
            query = query.where(
                table.c.symbol_name.like(f"{symbol_like}%"),
            )
        query = query.order_by(table.c.symbol_name.asc())
        if page is not None:
            query = query.limit(self.PAGE_SIZE).offset(
                self.PAGE_SIZE * (page - 1),
            )
        with self.engine.begin() as connection:
            return [row[0] for row in connection.execute(query).fetchall()]

    def drop_gp_table(self):
        """Drop the gene_profile table if it exists."""
        # A plain SQL string is not accepted by Connection.execute in
        # SQLAlchemy 2.0, so the drop is issued through the Table object.
        # The transaction is committed when the block exits.
        with self.engine.begin() as connection:
            self.gp_table.drop(connection, checkfirst=True)

    def gp_table_exists(self):
        """Return True if the gene_profile table exists in the DB."""
        return inspect(self.engine).has_table("gene_profile")

    # FIXME: Too many locals, refactor?
    def _gp_table_columns(self):  # pylint: disable=too-many-locals
        """
        Build the columns of the GP table from the configuration.

        Returns a dict of column name to Column: the symbol_name primary
        key, then for each gene set category a "<category>_rank" column
        and one column per gene set, one column per genomic score and a
        count and "_rate" column per dataset, person set and statistic.
        With an empty configuration only symbol_name is returned.
        """
        columns = {}
        columns["symbol_name"] = \
            Column("symbol_name", String(64), primary_key=True)
        if not self.configuration:
            return columns

        for category in self.configuration["gene_sets"]:
            category_name = category["category"]
            rank_col = f"{category_name}_rank"
            columns[rank_col] = Column(rank_col, Integer())
            for gene_set in category["sets"]:
                set_id = gene_set["set_id"]
                collection_id = gene_set["collection_id"]
                full_set_id = f"{collection_id}_{set_id}"
                columns[full_set_id] = Column(full_set_id, Integer())

        for category in self.configuration["genomic_scores"]:
            category_name = category["category"]
            for score in category["scores"]:
                score_name = score["score_name"]
                col = f"{category_name}_{score_name}"
                columns[col] = Column(col, Float())

        datasets = self.configuration["datasets"]
        for dataset_id, config_section in datasets.items():
            for person_set in config_section["person_sets"]:
                set_name = person_set["set_name"]
                for stat in config_section["statistics"]:
                    stat_id = stat["id"]
                    column_name = f"{dataset_id}_{set_name}_{stat_id}"
                    columns[column_name] = Column(column_name, Float())
                    rate_col_name = f"{column_name}_rate"
                    columns[rate_col_name] = \
                        Column(rate_col_name, Float())

        return columns

    def _create_gp_table(self):
        """Define the gene_profile table and create it if missing."""
        columns = self._gp_table_columns().values()
        self.gp_table = Table(
            "gene_profile",
            self.metadata,
            *columns,
        )
        self.metadata.create_all(self.engine)

    def _clear_gp_table(self, connection=None):
        """
        Delete all rows of the GP table.

        When a connection is given the delete runs on it and committing
        is left to the caller; otherwise a transaction of its own is
        used.
        """
        query = self.gp_table.delete()
        if connection is not None:
            connection.execute(query)
            return
        # engine.begin() commits when the block exits without an error.
        with self.engine.begin() as conn:
            conn.execute(query)

    def insert_gp(self, gp, connection=None):
        """
        Insert a GP into the DB.

        When a connection is given the insert runs on it and committing
        is left to the caller; otherwise a transaction of its own is
        used.
        """
        statement = insert(self.gp_table).values(
            **self._create_insert_map(gp),
        )
        if connection is not None:
            connection.execute(statement)
            return
        # engine.begin() commits when the block exits without an error.
        with self.engine.begin() as conn:
            conn.execute(statement)

    # FIXME: Too many locals, refactor?
    def _create_insert_map(self, gp):  # pylint: disable=too-many-locals
        """
        Build the column name to value mapping used to insert a GP.

        Each configured gene set column is 1 when its
        "<collection_id>_<set_id>" id is in gp.gene_sets and 0 otherwise;
        "<category>_rank" is the number of sets of that category which
        contain the gene. Genomic scores are stored as given.
        """
        insert_map = {
            "symbol_name": gp.gene_symbol,
        }

        for gsc in self.configuration["gene_sets"]:
            category = gsc["category"]
            rank_col = f"{category}_rank"
            insert_map.setdefault(rank_col, 0)
            for gene_set in gsc["sets"]:
                collection_id = gene_set["collection_id"]
                gs_id = gene_set["set_id"]
                set_col = f"{collection_id}_{gs_id}"
                if set_col in gp.gene_sets:
                    insert_map[set_col] = 1
                    insert_map[rank_col] += 1
                else:
                    insert_map[set_col] = 0

        for category, scores in gp.genomic_scores.items():
            for score_id, score in scores.items():
                insert_map[f"{category}_{score_id}"] = score

        # NOTE(review): counts and rates are written as 0 and the values
        # held by the GP are not used; presumably they are filled in
        # afterwards through update_gps_with_values - confirm.
        for study_id, ps_counts in gp.variant_counts.items():
            for person_set_id, eff_type_counts in ps_counts.items():
                for eff_type in eff_type_counts:
                    count_col = f"{study_id}_{person_set_id}_{eff_type}"
                    insert_map[count_col] = 0
                    insert_map[f"{count_col}_rate"] = 0

        return insert_map

    def insert_gps(self, gps):
        """
        Insert multiple GPStatistics into the DB.

        The table is cleared and all GPs are inserted within a single
        transaction. gps has to support len().
        """
        # engine.begin() commits when the block exits without an error.
        with self.engine.begin() as connection:
            self._clear_gp_table(connection)
            gp_count = len(gps)
            for idx, gp in enumerate(gps, 1):
                self.insert_gp(gp, connection)
                if idx % 1000 == 0:
                    logger.info(
                        "Inserted %s/%s GPs into DB", idx, gp_count)
        logger.info("Done!")

    def update_gps_with_values(self, gs_values: dict[str, Any]) -> None:
        """
        Update existing GP rows with new column values.

        Parameters:
            gs_values - maps a gene symbol to a dict of column name to
                        new value for the row of that symbol

        All updates are performed within a single transaction.
        """
        # engine.begin() commits when the block exits without an error.
        with self.engine.begin() as connection:
            for gs, values in gs_values.items():
                update = self.gp_table.update().values(**values).where(
                    self.gp_table.c.symbol_name == gs,
                )
                connection.execute(update)