Source code for dae.variants_loaders.parquet.loader

import glob
import os
import pathlib
from collections.abc import Generator, Iterable
from typing import ClassVar

import numpy as np
import yaml
from pyarrow import compute as pc
from pyarrow import dataset as ds
from pyarrow import parquet as pq

from dae.parquet.partition_descriptor import PartitionDescriptor
from dae.parquet.schema2.variant_serializers import VariantsDataSerializer
from dae.pedigrees.families_data import FamiliesData
from dae.pedigrees.loader import FamiliesLoader
from dae.schema2_storage.schema2_layout import (
    Schema2DatasetLayout,
    load_schema2_dataset_layout,
)
from dae.utils.regions import Region
from dae.variants.attributes import Inheritance, Role, Sex, Status
from dae.variants.family_variant import FamilyVariant
from dae.variants.variant import SummaryVariant, SummaryVariantFactory


class ParquetLoaderException(Exception):
    pass

class Reader:
    """
    Helper class to incrementally fetch variants.

    This class assumes variants are ordered by their bucket and summary index!
    """

    BATCH_SIZE = 5000

    def __init__(self, path: str, columns: Iterable[str]):
        if "summary_index" not in columns or "bucket_index" not in columns:
            raise ValueError
        self.pq_file = pq.ParquetFile(path)
        self.iterator = self.pq_file.iter_batches(
            columns=columns, batch_size=Reader.BATCH_SIZE)
        self.batch: list[dict] = []
        self.exhausted = False

    def __del__(self) -> None:
        self.close()

    def __iter__(self) -> "Reader":
        return self

    def __next__(self) -> list[dict]:
        """Return next batch of variants with matching indices."""
        if self.exhausted:
            raise StopIteration
        result: list[dict] = []
        initial_idx = self.current_idx
        while self.current_idx == initial_idx:
            result.append(self._pop())
        return result

    @property
    def current_idx(self) -> tuple[int, int]:
        # (-1, -1) marks an exhausted reader
        top = self._peek()
        if top is None:
            return (-1, -1)
        return int(top["bucket_index"]), int(top["summary_index"])

    def _advance(self) -> None:
        if self.exhausted:
            return
        try:
            self.batch = next(self.iterator).to_pylist()
        except StopIteration:
            self.exhausted = True

    def _peek(self) -> dict | None:
        if not self.batch:
            self._advance()
        if self.exhausted:
            return None
        return self.batch[0]

    def _pop(self) -> dict:
        if self._peek() is None:
            raise IndexError
        return self.batch.pop(0)

    def close(self) -> None:
        self.pq_file.close()

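# Usage sketch for Reader (hypothetical file path). Each iteration returns the
# list of rows that share one (bucket_index, summary_index) pair, which is why
# the input file must be sorted by those columns:
#
#     reader = Reader("summary/bucket_0.parquet",
#                     ["bucket_index", "summary_index", "summary_variant_data"])
#     for rows in reader:
#         ...  # rows: list[dict] for a single summary variant
#     reader.close()
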
class MultiReader:
    """
    Incrementally fetch variants from multiple files.

    This class assumes variants are ordered by their bucket and summary index!
    """

    def __init__(self, dirs: Iterable[str], columns: Iterable[str]):
        self.readers = tuple(Reader(path, columns) for path in dirs)

    def __del__(self) -> None:
        self.close()

    def __iter__(self) -> "MultiReader":
        return self

    def __next__(self) -> list[dict]:
        if self._exhausted:
            raise StopIteration
        result = []
        iteration_idx = self.current_idx
        for reader in self.readers:
            if not reader.exhausted:
                while reader.current_idx == iteration_idx:
                    result.extend(next(reader))
        return result

    @property
    def _exhausted(self) -> bool:
        return all(reader.exhausted for reader in self.readers)

    @property
    def current_idx(self) -> tuple[int, int]:
        if self._exhausted:
            return (-1, -1)
        return min(reader.current_idx
                   for reader in self.readers
                   if not reader.exhausted)

    def close(self) -> None:
        for reader in self.readers:
            reader.close()

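# MultiReader is effectively a k-way merge: current_idx is the smallest
# (bucket_index, summary_index) among the non-exhausted readers, and each
# __next__ drains every reader positioned at that index. A minimal sketch with
# hypothetical paths:
#
#     multi = MultiReader(
#         ["summary/part_0.parquet", "summary/part_1.parquet"],
#         ["bucket_index", "summary_index", "summary_variant_data"])
#     for rows in multi:
#         ...  # rows for one summary variant, possibly from several files
#     multi.close()
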
class ParquetLoader:
    """Variants loader implementation for the Parquet format."""

    SUMMARY_COLUMNS: ClassVar[list[str]] = [
        "bucket_index", "summary_index", "allele_index",
        "summary_variant_data", "chromosome", "position", "end_position",
    ]

    FAMILY_COLUMNS: ClassVar[list[str]] = [
        "bucket_index", "summary_index", "family_id", "family_variant_data",
    ]

    def __init__(self, layout: Schema2DatasetLayout):
        self.layout = layout

        if not os.path.exists(self.layout.pedigree):
            raise ParquetLoaderException(
                f"No pedigree file exists in {self.layout.study}!")
        self.families: FamiliesData = self._load_families(self.layout.pedigree)

        meta_file = pq.ParquetFile(self.layout.meta)
        self.meta = {row["key"]: row["value"]
                     for row in meta_file.read().to_pylist()}
        meta_file.close()

        self.has_annotation = bool(
            yaml.safe_load(self.meta.get("annotation_pipeline", "").strip()))

        self.partitioned: bool = bool(
            self.meta.get("partition_description", "").strip())
        self.partition_descriptor = PartitionDescriptor.parse_string(
            self.meta.get("partition_description", "").strip(),
        )

        variants_data_meta = None
        if "variants_data_schema" in self.meta:
            variants_data_meta = yaml.safe_load(
                self.meta["variants_data_schema"])
        self.serializer = VariantsDataSerializer.build_serializer(
            variants_data_meta,
        )

        self.files_per_region = self._scan_region_bins()

        self.contigs: dict[str, int] = {}
        if self.meta.get("contigs"):
            self.contigs = {
                contig[0]: int(contig[1])
                for contig in
                [r.split("=") for r in self.meta["contigs"].split(",")]
            }

    @staticmethod
    def load_from_dir(input_dir: str) -> "ParquetLoader":
        """Create a loader from a schema2 study directory."""
        return ParquetLoader(load_schema2_dataset_layout(input_dir))

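    # Sketch of constructing a loader (hypothetical study directory, assumed
    # to contain the schema2 layout with pedigree, meta, summary and family
    # Parquet files):
    #
    #     loader = ParquetLoader.load_from_dir("/data/studies/my_study")
    #     print(loader.contigs, loader.partitioned)
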
    def _scan_region_bins(self) -> dict[tuple[str, str], list[str]]:
        if not self.layout.summary:
            return {}
        if not self.partitioned \
                or not self.partition_descriptor.has_region_bins():
            return {}
        result: dict[tuple[str, str], list[str]] = {}
        for path in ds.dataset(f"{self.layout.summary}").files:
            summary = str(pathlib.Path(path).relative_to(self.layout.summary))
            partitions = self.partition_descriptor.path_to_partitions(summary)
            region_partition = None
            for partition in partitions:
                if partition[0] == "region_bin":
                    region_partition = partition
            if region_partition is None:
                raise ValueError
            result.setdefault(region_partition, []).append(path)
        return result

    @staticmethod
    def _extract_region_bin(path: str) -> tuple[str, int]:
        # (...)/region_bin=chr1_0/(...)
        #                  ^~~~~^
        start = path.find("region_bin=") + 11
        end = path.find("/", start)
        rbin = path[start:end].split("_")
        return rbin[0], int(rbin[1])

    @staticmethod
    def _load_families(path: str) -> FamiliesData:
        parquet_file = pq.ParquetFile(path)
        ped_df = parquet_file.read().to_pandas()
        parquet_file.close()
        ped_df.role = ped_df.role.apply(Role.from_value)
        ped_df.sex = ped_df.sex.apply(Sex.from_value)
        ped_df.status = ped_df.status.apply(Status.from_value)
        ped_df.loc[ped_df.layout.isna(), "layout"] = None
        return FamiliesLoader.build_families_data_from_pedigree(ped_df)

    def _pq_file_in_region(self, path: str, region: Region) -> bool:
        if not self.partition_descriptor.has_region_bins():
            raise ParquetLoaderException(
                f"No region bins exist in {self.layout.study}!")
        normalized_region = Region(
            (region.chrom
             if region.chrom in self.partition_descriptor.chromosomes
             else "other"),
            region.start,
            region.stop,
        )
        rbin = ParquetLoader._extract_region_bin(path)
        bin_region = Region(
            rbin[0],
            (rbin[1] * self.partition_descriptor.region_length) + 1,
            (rbin[1] + 1) * self.partition_descriptor.region_length,
        )
        return bin_region.intersects(normalized_region)

    def get_summary_pq_filepaths(
        self, region: Region | None = None,
    ) -> Generator[list[str], None, None]:
        """
        Produce paths to available Parquet files grouped by region.

        Can filter by region if region bins are configured.
        """
        if not self.layout.summary:
            return
        if not self.partitioned \
                or not self.partition_descriptor.has_region_bins():
            yield list(ds.dataset(f"{self.layout.summary}").files)
            return
        if region is None:
            region_bins = list(self.files_per_region.keys())
        else:
            region_bins = [
                ("region_bin", r)
                for r in self.partition_descriptor.region_to_region_bins(
                    region, self.contigs)
            ]
        for r_bin in region_bins:
            if r_bin in self.files_per_region:
                # check with if since some region bins may not exist
                # if no variants were written there
                yield self.files_per_region[r_bin]

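    # Sketch (hypothetical loader and region): each yielded list corresponds
    # to one region bin when partitioning is configured, otherwise a single
    # list with all summary files is produced.
    #
    #     for paths in loader.get_summary_pq_filepaths(Region("chr1", 1, 5000)):
    #         print(paths)
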
    def get_family_pq_filepaths(self, summary_path: str) -> list[str]:
        """Get all family parquet files for given summary parquet file."""
        if not self.layout.summary or not self.layout.family:
            return []
        if not os.path.exists(summary_path):
            raise ParquetLoaderException(
                f"Non-existent summary path given - {summary_path}")
        if not summary_path.startswith(self.layout.summary):
            raise ParquetLoaderException(
                f"Invalid summary path given - {summary_path}")
        bins = os.path.relpath(
            os.path.dirname(summary_path), self.layout.summary,
        )
        glob_dir = os.path.join(
            self.layout.family, bins, "**", "*.parquet",
        )
        return glob.glob(glob_dir, recursive=True)

    def _deserialize_summary_variant(self, record: bytes) -> SummaryVariant:
        return SummaryVariantFactory.summary_variant_from_records(
            self.serializer.deserialize_summary_record(record),
        )

    def _deserialize_family_variant(
        self, record: bytes, summary_variant: SummaryVariant,
    ) -> FamilyVariant:
        fv_record = self.serializer.deserialize_family_record(record)
        inheritance_in_members = {
            int(k): [Inheritance.from_value(inh) for inh in v]
            for k, v in fv_record["inheritance_in_members"].items()
        }
        return FamilyVariant(
            summary_variant,
            self.families[fv_record["family_id"]],
            np.array(fv_record["genotype"]),
            np.array(fv_record["best_state"]),
            inheritance_in_members=inheritance_in_members,
        )

    def fetch_summary_variants(
        self, region: Region | None = None,
    ) -> Generator[SummaryVariant, None, None]:
        """Iterate over summary variants."""
        region_filter = None
        if region is not None:
            region_filter = pc.field("chromosome") == region.chrom
            if region.start is not None:
                region_filter = (
                    region_filter & (pc.field("end_position") >= region.start)
                )
            if region.stop is not None:
                region_filter = (
                    region_filter & (pc.field("position") <= region.stop)
                )
        for summary_paths in self.get_summary_pq_filepaths(region):
            if not summary_paths:
                continue
            seen = set()
            for s_path in summary_paths:
                table = pq.read_table(
                    s_path, columns=self.SUMMARY_COLUMNS,
                    filters=region_filter)
                for rec in table.to_pylist():
                    v_id = (rec["bucket_index"], rec["summary_index"])
                    if v_id not in seen:
                        seen.add(v_id)
                        yield self._deserialize_summary_variant(
                            rec["summary_variant_data"],
                        )

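    # A minimal sketch (hypothetical study path and region) of iterating
    # summary variants restricted to a region:
    #
    #     loader = ParquetLoader.load_from_dir("/data/studies/my_study")
    #     for sv in loader.fetch_summary_variants(Region("chr1", 1, 1_000_000)):
    #         print(sv)
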
    def fetch_variants(
        self, region: Region | None = None,
    ) -> Generator[tuple[SummaryVariant, list[FamilyVariant]], None, None]:
        """Iterate over summary and family variants."""
        for summary_paths in self.get_summary_pq_filepaths(region):
            if not summary_paths:
                continue
            family_paths: list[str] = []
            for path in summary_paths:
                family_paths.extend(self.get_family_pq_filepaths(path))

            summary_reader = MultiReader(summary_paths, self.SUMMARY_COLUMNS)
            family_reader = MultiReader(family_paths, self.FAMILY_COLUMNS)

            for alleles in summary_reader:
                rec = alleles[0]
                if region is not None \
                        and not region.intersects(Region(rec["chromosome"],
                                                         rec["position"],
                                                         rec["end_position"])):
                    continue
                sv_idx = (rec["bucket_index"], rec["summary_index"])
                sv = self._deserialize_summary_variant(
                    rec["summary_variant_data"])

                fvs: list[dict] = []
                try:
                    # skip family records belonging to earlier summary variants
                    while sv_idx > family_reader.current_idx:
                        next(family_reader)
                    # only consume records that match this summary variant
                    if sv_idx == family_reader.current_idx:
                        fvs = next(family_reader)
                except StopIteration:
                    pass

                seen = set()
                to_yield = []
                for fv in fvs:
                    fv_id = (fv["summary_index"], fv["family_id"])
                    if fv_id not in seen:
                        seen.add(fv_id)
                        to_yield.append(self._deserialize_family_variant(
                            fv["family_variant_data"], sv))
                yield (sv, to_yield)

            summary_reader.close()
            family_reader.close()

    def fetch_family_variants(
        self, region: Region | None = None,
    ) -> Generator[FamilyVariant, None, None]:
        """Iterate over family variants."""
        for _, fvs in self.fetch_variants(region):
            yield from fvs
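
# End-to-end usage sketch. The study directory is hypothetical and assumed to
# have been produced by the schema2 Parquet writer:
#
#     loader = ParquetLoader.load_from_dir("/data/studies/my_study")
#     for sv, fvs in loader.fetch_variants(Region("chr1", 100_000, 200_000)):
#         print(sv, len(fvs))
#     for fv in loader.fetch_family_variants():
#         print(fv)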