Source code for dae.pheno.common

from __future__ import annotations

import enum
import time
from typing import Any

import duckdb
import sqlglot
from pydantic import BaseModel, ConfigDict
from sqlglot import insert, select
from sqlglot.expressions import Table, table_

from dae.utils.sql_utils import to_duckdb_transpile

# sqlglot table expression naming the "import_metadata" table.
# NOTE(review): presumably the table handed to the ImportManifest helpers
# defined in this module — confirm against callers.
IMPORT_METADATA_TABLE = table_("import_metadata")


class RankRange(BaseModel):
    """Pair of optional rank bounds.

    Both bounds default to ``None``. Unknown fields are rejected because the
    model is configured with ``extra="forbid"``.
    """

    model_config = ConfigDict(extra="forbid")

    # Lower bound of the range; None when not set.
    min_rank: int | None = None
    # Upper bound of the range; None when not set.
    max_rank: int | None = None
class InferenceConfig(BaseModel):
    """Classification inference configuration class.

    Unknown fields are rejected (``extra="forbid"``); every field has a
    default, so an empty configuration is valid.
    """

    model_config = ConfigDict(extra="forbid")

    # NOTE(review): the three scalar thresholds below presumably gate which
    # measures get classified at all — confirm against the inference code.
    min_individuals: int = 1
    non_numeric_cutoff: float = 0.06
    value_max_len: int = 32

    # Default rank ranges, one per measure type name.
    # NOTE(review): presumably bounds on the number of unique values a
    # measure may have to be assigned the given type — confirm.
    continuous: RankRange = RankRange(min_rank=10)
    ordinal: RankRange = RankRange(min_rank=1)
    categorical: RankRange = RankRange(min_rank=1, max_rank=15)

    # Explicit overrides; all are disabled/unset by default.
    skip: bool = False
    value_type: str | None = None
    histogram_type: str | None = None
class MeasureHistogramConfigs(BaseModel):
    """Classification histogram configuration class.

    Holds two free-form dictionaries, both empty by default. Unknown
    top-level fields are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    # NOTE(review): presumably histogram settings for numeric measures —
    # the dictionary schema is not visible here; confirm.
    number_config: dict = {}
    # NOTE(review): presumably histogram settings for categorical measures —
    # confirm.
    categorical_config: dict = {}
class DataDictionaryConfig(BaseModel):
    """Pydantic model for data dictionary config entries.

    Only ``path`` is required; the remaining fields default to a
    tab-delimited layout with ``instrumentName``, ``measureName`` and
    ``description`` columns. Unknown fields are rejected.
    """

    model_config = ConfigDict(extra="forbid")

    # Location of the data dictionary file (required).
    path: str
    # Optional instrument name; None when not provided.
    instrument: str | None = None
    # Field separator used in the file; tab by default.
    delimiter: str = "\t"

    # Names of the columns to read from the file.
    instrument_column: str = "instrumentName"
    measure_column: str = "measureName"
    description_column: str = "description"
class MeasureDescriptionsConfig(BaseModel):
    """Sources of measure descriptions.

    Descriptions can be supplied as a list of data dictionary file entries,
    as an inline nested mapping, or both; each source defaults to ``None``.
    Unknown fields are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    # Data dictionary files to load descriptions from.
    files: list[DataDictionaryConfig] | None = None

    # {Instrument -> {Measure -> Description}}
    dictionary: dict[str, dict[str, str]] | None = None
class RegressionMeasure(BaseModel):
    """Description of a single regression measure.

    All four fields are required and unknown fields are rejected
    (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    # Instrument the regression measure belongs to.
    instrument_name: str
    # One or more measure names within that instrument.
    measure_names: list[str]
    # NOTE(review): presumably the jitter applied when plotting — confirm.
    jitter: float
    # Human-readable label for the regression.
    display_name: str
class StudyConfig(BaseModel):
    """Optional study-level settings attached to a pheno import.

    Unlike the sibling config models in this module, no ``extra="forbid"``
    is configured here, so pydantic's default handling of unknown keys
    applies. Every field defaults to ``None``.
    """

    # Either a string or a mapping of regression name to its definition.
    # NOTE(review): the string form is presumably a path to a regressions
    # file — confirm.
    regressions: str | dict[str, RegressionMeasure] | None = None

    # Free-form sections; their schema is not validated by this model.
    common_report: dict[str, Any] | None = None
    person_set_collections: dict[str, Any] | None = None
class GPFInstanceConfig(BaseModel):
    """Reference to a GPF instance given by a single required ``path``.

    Unknown fields are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    # NOTE(review): presumably the path to the GPF instance configuration —
    # confirm.
    path: str
class DestinationConfig(BaseModel):
    """Destination of an import, given by storage id and/or directory.

    Both fields are optional and default to ``None``; unknown fields are
    rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    # Identifier of the target storage; None when not set.
    storage_id: str | None = None
    # Directory of the target storage; None when not set.
    storage_dir: str | None = None
class InstrumentConfig(BaseModel):
    """Configuration entry for a single instrument file.

    Only ``path`` is required. Unknown fields are rejected
    (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    # Location of the instrument file (required).
    path: str

    # Optional per-file settings, all None by default.
    # NOTE(review): None presumably means "fall back to the import-level
    # value" for delimiter and person_column — confirm against the loader.
    instrument: str | None = None
    delimiter: str | None = None
    person_column: str | None = None
class PhenoImportConfig(BaseModel):
    """Pheno import config.

    Top-level description of a phenotype import. The identifier, the input
    and work directories, the instrument files, the pedigree and the person
    column are required; everything else is optional. Unknown fields are
    rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    # --- required settings ---
    id: str
    input_dir: str
    work_dir: str
    # Each entry is either a plain string or a full InstrumentConfig.
    instrument_files: list[str | InstrumentConfig]
    pedigree: str
    person_column: str

    # --- optional settings ---
    # Field separator; comma by default.
    delimiter: str = ","
    destination: DestinationConfig | None = None
    gpf_instance: GPFInstanceConfig | None = None
    skip_pedigree_measures: bool = False
    # Either a string or a mapping of names to inference configurations.
    # NOTE(review): the string form is presumably a path to a file with the
    # inference configuration — confirm.
    inference_config: str | dict[str, InferenceConfig] | None = None
    histogram_configs: MeasureHistogramConfigs | None = None
    data_dictionary: MeasureDescriptionsConfig | None = None
    study_config: StudyConfig | None = None
class MeasureType(enum.Enum):
    """Definition of measure types.

    ``continuous`` and ``ordinal`` are the numeric types. ``is_text`` is
    defined as the negation of ``is_numeric``, so every other member
    (including ``raw``, ``other`` and ``skipped``) is reported as text.
    """

    # pylint: disable=invalid-name
    continuous = 1
    ordinal = 2
    categorical = 3
    text = 4
    raw = 5
    other = 100
    skipped = 1000

    @staticmethod
    def from_str(measure_type: str) -> MeasureType:
        """Return the member whose name equals ``measure_type``.

        Raises:
            ValueError: if no member has that name.
        """
        member = MeasureType.__members__.get(measure_type)
        if member is None:
            raise ValueError("unexpected measure type", measure_type)
        return member

    @staticmethod
    def is_numeric(measure_type: MeasureType) -> bool:
        """Tell whether ``measure_type`` is continuous or ordinal."""
        numeric_types = {MeasureType.continuous, MeasureType.ordinal}
        return measure_type in numeric_types

    @staticmethod
    def is_text(measure_type: MeasureType) -> bool:
        """Tell whether ``measure_type`` is anything but a numeric type."""
        if MeasureType.is_numeric(measure_type):
            return False
        return True
class ImportManifest(BaseModel):
    """Import manifest for checking cache validity.

    A manifest pairs the time an import was recorded with the import
    configuration that was used, and can be stored in / read from a DuckDB
    table with the columns ``unix_timestamp`` and ``import_config``.
    """

    unix_timestamp: float
    import_config: PhenoImportConfig

    def is_older_than(self, other: ImportManifest) -> bool:
        """Tell whether this manifest is outdated relative to ``other``.

        Returns True when this manifest's timestamp is strictly smaller than
        ``other``'s, or when the two import configurations differ.
        """
        return (
            self.unix_timestamp < other.unix_timestamp
            or self.import_config != other.import_config
        )

    @staticmethod
    def from_row(row: tuple[str, Any, str]) -> ImportManifest:
        """Build a manifest from a database row.

        ``row[0]`` is converted to a float timestamp and ``row[1]`` is
        validated as the JSON serialization of a ``PhenoImportConfig``.
        """
        timestamp = float(row[0])
        import_config = PhenoImportConfig.model_validate_json(row[1])
        return ImportManifest(
            unix_timestamp=timestamp,
            import_config=import_config,
        )

    @staticmethod
    def from_table(
        connection: duckdb.DuckDBPyConnection, table: Table,
    ) -> list[ImportManifest]:
        """Read manifests from given table.

        Returns an empty list when the table does not exist in the database.
        """
        with connection.cursor() as cursor:
            # The table name is bound as a parameter rather than being
            # interpolated into the SQL text, so unusual names cannot break
            # or alter the lookup.
            table_row = cursor.execute(
                "SELECT * FROM information_schema.tables"
                " WHERE table_name = ?",
                [table.alias_or_name],
            ).fetchone()
            if table_row is None:
                return []
            rows = cursor.execute(select("*").from_(table).sql()).fetchall()
        return [ImportManifest.from_row(row) for row in rows]

    @staticmethod
    def create_table(
        connection: duckdb.DuckDBPyConnection, table: Table,
    ) -> None:
        """Create table for recording import manifests.

        Any existing table with the same name is dropped first.
        """
        drop = sqlglot.parse_one(
            f"DROP TABLE IF EXISTS {table.alias_or_name}").sql()
        create = sqlglot.parse_one(
            f"CREATE TABLE {table.alias_or_name} "
            "(unix_timestamp DOUBLE, import_config VARCHAR)",
        ).sql()
        with connection.cursor() as cursor:
            cursor.execute(drop)
            cursor.execute(create)

    @staticmethod
    def write_to_db(
        connection: duckdb.DuckDBPyConnection,
        table: Table,
        import_config: PhenoImportConfig,
    ) -> None:
        """Write manifest into DB on given table.

        The current time is stored as the timestamp and the configuration is
        stored as its JSON serialization.
        """
        config_json = import_config.model_dump_json()
        timestamp = time.time()
        # Use placeholders and let DuckDB bind the values: the JSON document
        # may contain single quotes (paths, descriptions), which would
        # terminate a hand-built SQL string literal.
        query = insert("VALUES (?, ?)", table)
        with connection.cursor() as cursor:
            cursor.execute(
                to_duckdb_transpile(query), [timestamp, config_json],
            ).fetchall()