# Source code for dae.parquet.schema2.parquet_io

import functools
import logging
import os
import time
from collections.abc import Iterator
from typing import Any, cast

import fsspec
import pyarrow as pa
import pyarrow.parquet as pq

from dae.annotation.annotation_pipeline import AttributeInfo
from dae.parquet.helpers import url_to_pyarrow_fs
from dae.parquet.partition_descriptor import PartitionDescriptor
from dae.parquet.schema2.serializers import AlleleParquetSerializer
from dae.parquet.schema2.variant_serializers import (
    VariantsDataSerializer,
)
from dae.utils import fs_utils
from dae.utils.variant_utils import (
    is_all_reference_genotype,
    is_unknown_genotype,
)
from dae.variants.attributes import Inheritance
from dae.variants.family_variant import FamilyAllele, FamilyVariant
from dae.variants.variant import (
    SummaryAllele,
    SummaryVariant,
)

logger = logging.getLogger(__name__)


class ContinuousParquetFileWriter:
    """A continuous parquet writer.

    Rows are buffered in memory column by column. Every ``BATCH_ROWS``
    rows the buffer is converted into a record batch. Once the pending
    batches add up to roughly ``row_group_size`` rows they are written
    to the parquet file. Leftover data is written when the writer is
    closed.
    """

    BATCH_ROWS = 1_000
    DEFAULT_COMPRESSION = "SNAPPY"

    def __init__(
        self,
        filepath: str,
        annotation_schema: list[AttributeInfo],
        filesystem: fsspec.AbstractFileSystem | None = None,
        row_group_size: int = 10_000,
        schema: str = "schema",
        blob_column: str | None = None,
    ) -> None:
        """Open a parquet writer for ``filepath``.

        Args:
            filepath: destination parquet file (local path or URL).
            annotation_schema: annotation attributes used to build the
                allele serializer.
            filesystem: optional fsspec filesystem, passed together with
                ``filepath`` to ``url_to_pyarrow_fs``.
            row_group_size: approximate number of rows accumulated in
                pending batches before they are written to the file.
            schema: name of the ``AlleleParquetSerializer`` attribute
                holding the pyarrow schema to write.
            blob_column: when given, this column is compressed with ZSTD
                while all other columns use ``DEFAULT_COMPRESSION``.
        """
        self.filepath = filepath
        self.annotation_schema = annotation_schema
        self.serializer = AlleleParquetSerializer(
            self.annotation_schema,
        )
        self.schema = getattr(self.serializer, schema)

        # NOTE(review): the parent directory is created with ``os`` on
        # the local filesystem, even when ``filepath`` is a remote URL.
        dirname = os.path.dirname(filepath)
        if dirname and not os.path.exists(dirname):
            os.makedirs(dirname, exist_ok=True)
        self.dirname = dirname

        filesystem, filepath = url_to_pyarrow_fs(filepath, filesystem)

        compression: str | dict[str, str] = self.DEFAULT_COMPRESSION
        if blob_column is not None:
            # Per-column compression: the default everywhere except the
            # blob column.
            compression = dict.fromkeys(
                self.schema.names, self.DEFAULT_COMPRESSION)
            compression[blob_column] = "ZSTD"

        self._writer = pq.ParquetWriter(
            filepath,
            self.schema,
            compression=compression,
            filesystem=filesystem,
            use_compliant_nested_type=True,
            write_page_index=True,
        )
        self.row_group_size = row_group_size
        self._batches: list[pa.RecordBatch] = []
        self._data: dict[str, Any] | None = None
        self.data_reset()

    def data_reset(self) -> None:
        """Start a fresh, empty value list for every schema column."""
        self._data = {name: [] for name in self.schema.names}

    def size(self) -> int:
        """Return the number of buffered rows not yet turned into a batch."""
        assert self._data is not None
        return len(self._data["bucket_index"])

    def build_table(self) -> pa.Table:
        """Combine all pending record batches into a single table."""
        logger.info(
            "writing %s rows to parquet %s",
            sum(len(b) for b in self._batches), self.filepath)
        return pa.Table.from_batches(self._batches, self.schema)

    def build_batch(self) -> pa.RecordBatch:
        """Convert the buffered column values into a record batch."""
        return pa.RecordBatch.from_pydict(self._data, self.schema)

    def _write_batch(self) -> None:
        """Move buffered rows into a pending batch; flush when enough.

        Does nothing when the buffer is empty.
        """
        if self.size() == 0:
            return
        batch = self.build_batch()
        self._batches.append(batch)
        self.data_reset()
        # Flush once the pending batches amount to about one row group.
        # When row_group_size < BATCH_ROWS every batch is flushed.
        flush_threshold = max(1, self.row_group_size // self.BATCH_ROWS)
        if len(self._batches) >= flush_threshold:
            self._flush_batches()

    def _flush_batches(self) -> None:
        """Write all pending batches to the parquet file as one table."""
        if not self._batches:
            return
        logger.debug(
            "flushing %s batches", len(self._batches))
        self._writer.write_table(self.build_table())
        self._batches = []

    def append_summary_allele(
        self, allele: SummaryAllele, json_data: bytes,
    ) -> None:
        """Buffer one row for a summary allele.

        A batch is created once ``BATCH_ROWS`` rows are buffered.
        """
        assert self._data is not None
        data = self.serializer.build_summary_allele_batch_dict(
            allele, json_data,
        )
        for column, values in self._data.items():
            values.append(data[column])
        if self.size() >= self.BATCH_ROWS:
            logger.debug(
                "parquet writer %s create summary batch at len %s",
                self.filepath, self.size())
            self._write_batch()

    def append_family_allele(
        self, allele: FamilyAllele, json_data: bytes,
    ) -> None:
        """Buffer the rows built for a family allele.

        The serializer returns a list of values per column, which are
        appended to the buffer. A batch is created once ``BATCH_ROWS``
        rows are buffered.
        """
        assert self._data is not None
        data = self.serializer.build_family_allele_batch_dict(
            allele, json_data,
        )
        for column, values in self._data.items():
            values.extend(data[column])
        if self.size() >= self.BATCH_ROWS:
            logger.debug(
                "parquet writer %s create family batch at len %s",
                self.filepath, self.size())
            self._write_batch()

    def close(self) -> None:
        """Close the parquet writer and write any remaining data.

        The underlying parquet writer is closed even if writing the
        remaining data raises, so its file handle is not leaked. The
        exception still propagates to the caller.
        """
        logger.debug(
            "closing parquet writer %s with %d rows",
            self.filepath, self.size())
        try:
            self._write_batch()
            self._flush_batches()
        finally:
            self._writer.close()
class VariantsParquetWriter:
    """Provide functions for storing variants into parquet dataset.

    Summary and family alleles are routed to parquet files chosen by the
    ``PartitionDescriptor``. One ``ContinuousParquetFileWriter`` is kept
    per output file.
    """

    def __init__(
        self,
        out_dir: str,
        annotation_schema: list[AttributeInfo],
        partition_descriptor: PartitionDescriptor,
        *,
        serializer: VariantsDataSerializer | None = None,
        bucket_index: int = 1,
        row_group_size: int = 10_000,
        include_reference: bool = False,
        filesystem: fsspec.AbstractFileSystem | None = None,
    ) -> None:
        """Configure the dataset writer.

        Args:
            out_dir: root directory of the dataset; family and summary
                files go under its ``family`` and ``summary`` subfolders.
            annotation_schema: annotation attributes used by the file
                writers.
            partition_descriptor: decides the partition of each allele.
            serializer: serializer for the variant data blobs; built with
                ``VariantsDataSerializer.build_serializer()`` when omitted.
            bucket_index: index of the bucket being written; must be below
                1_000_000.
            row_group_size: forwarded to every file writer.
            include_reference: also store reference alleles and
                all-reference family variants.
            filesystem: optional fsspec filesystem forwarded to the file
                writers.
        """
        self.out_dir = out_dir
        if serializer is None:
            serializer = VariantsDataSerializer.build_serializer()
        self.serializer = serializer
        self.bucket_index = bucket_index
        assert self.bucket_index < 1_000_000, "bad bucket index"
        self.row_group_size = row_group_size
        self.filesystem = filesystem
        self.include_reference = include_reference
        self.start = time.time()
        self.data_writers: dict[str, ContinuousParquetFileWriter] = {}
        assert isinstance(partition_descriptor, PartitionDescriptor)
        self.partition_descriptor = partition_descriptor
        self.annotation_schema = annotation_schema

    def _build_family_filename(
        self, allele: FamilyAllele, *, seen_as_denovo: bool,
    ) -> str:
        """Return the family parquet file path for ``allele``."""
        partition = self.partition_descriptor.family_partition(
            allele, seen_as_denovo=seen_as_denovo)
        partition_directory = self.partition_descriptor.partition_directory(
            fs_utils.join(self.out_dir, "family"), partition)
        partition_filename = self.partition_descriptor.partition_filename(
            "family", partition, self.bucket_index)
        return fs_utils.join(partition_directory, partition_filename)

    def _build_summary_filename(
        self, allele: SummaryAllele, *, seen_as_denovo: bool,
    ) -> str:
        """Return the summary parquet file path for ``allele``."""
        partition = self.partition_descriptor.summary_partition(
            allele, seen_as_denovo=seen_as_denovo)
        partition_directory = self.partition_descriptor.partition_directory(
            fs_utils.join(self.out_dir, "summary"), partition)
        partition_filename = self.partition_descriptor.partition_filename(
            "summary", partition, self.bucket_index)
        return fs_utils.join(partition_directory, partition_filename)

    def _get_bin_writer(
        self, filename: str, *, schema: str, blob_column: str,
    ) -> ContinuousParquetFileWriter:
        """Return the writer for ``filename``, creating it on first use."""
        writer = self.data_writers.get(filename)
        if writer is None:
            writer = ContinuousParquetFileWriter(
                filename,
                self.annotation_schema,
                filesystem=self.filesystem,
                row_group_size=self.row_group_size,
                schema=schema,
                blob_column=blob_column,
            )
            self.data_writers[filename] = writer
        return writer

    def _get_bin_writer_family(
        self, allele: FamilyAllele, *, seen_as_denovo: bool,
    ) -> ContinuousParquetFileWriter:
        """Return the family-schema writer for the allele's partition."""
        filename = self._build_family_filename(
            allele, seen_as_denovo=seen_as_denovo)
        return self._get_bin_writer(
            filename,
            schema="schema_family",
            blob_column="family_variant_data",
        )

    def _get_bin_writer_summary(
        self, allele: SummaryAllele, *, seen_as_denovo: bool,
    ) -> ContinuousParquetFileWriter:
        """Return the summary-schema writer for the allele's partition."""
        filename = self._build_summary_filename(
            allele, seen_as_denovo=seen_as_denovo)
        return self._get_bin_writer(
            filename,
            schema="schema_summary",
            blob_column="summary_variant_data",
        )

    def _calc_sj_index(self, summary_index: int, allele_index: int) -> int:
        """Return the sj_index of one allele of a summary variant."""
        assert allele_index < 10_000, "too many alleles"
        return self._calc_sj_base_index(summary_index) + allele_index

    def _calc_sj_base_index(self, summary_index: int) -> int:
        """Return the sj_index base for a summary variant.

        The base is ``(bucket_index * 1e9 + summary_index) * 1e4``, which
        leaves the lowest four decimal digits for the allele index.
        """
        return (
            self.bucket_index * 1_000_000_000 + summary_index) * 10_000

    def _write_family_variant(
        self,
        fv: FamilyVariant,
        *,
        summary_index: int,
        family_index: int,
        sj_base_index: int,
        seen_in_status: list[int],
        seen_as_denovo: list[bool],
        family_variants_count: list[int],
    ) -> int:
        """Write the alleles of one family variant and return how many.

        The per-allele aggregate lists ``seen_in_status``,
        ``seen_as_denovo`` and ``family_variants_count`` (indexed by
        allele index) are updated in place for every written allele.
        """
        assert fv.gt is not None
        fv.summary_index = summary_index
        fv.family_index = family_index

        allele_indexes = set()
        for fa in fv.alleles:
            assert fa.allele_index not in allele_indexes
            allele_indexes.add(fa.allele_index)
            fa.update_attributes({
                "bucket_index": self.bucket_index,
                "family_index": family_index,
                "sj_index": sj_base_index + fa.allele_index,
            })

        # Serialized after the attribute updates above (presumably so the
        # blob reflects them).
        family_variant_data_json = self.serializer.serialize_family(fv)

        family_alleles = []
        if is_unknown_genotype(fv.gt) or \
                is_all_reference_genotype(fv.gt):
            assert fv.ref_allele.allele_index == 0
            family_alleles.append(fv.ref_allele)
        elif self.include_reference:
            family_alleles.append(fv.ref_allele)
        family_alleles.extend(fv.alt_alleles)

        for aa in family_alleles:
            fa = cast(FamilyAllele, aa)
            idx = fa.allele_index
            seen_in_status[idx] = functools.reduce(
                lambda t, s: t | s.value,
                filter(None, fa.allele_in_statuses),
                seen_in_status[idx])
            # None/unknown/missing never equal denovo, so no pre-filtering
            # of the members' inheritance values is needed.
            sad = any(
                inh == Inheritance.denovo
                for inh in fa.inheritance_in_members)
            seen_as_denovo[idx] = sad or seen_as_denovo[idx]

            family_bin_writer = self._get_bin_writer_family(
                fa, seen_as_denovo=sad)
            family_bin_writer.append_family_allele(
                fa, family_variant_data_json,
            )
            family_variants_count[idx] += 1

        return len(family_alleles)

    def write_dataset(
        self,
        full_variants_iterator: Iterator[
            tuple[SummaryVariant, list[FamilyVariant]]],
    ) -> list[str]:
        """Write variants to the partitioned parquet dataset.

        Args:
            full_variants_iterator: yields each summary variant together
                with its family variants.

        Returns:
            The paths of all parquet files written. All writers are
            closed before returning.
        """
        family_index = 0
        summary_index = 0
        for summary_index, (
            summary_variant,
            family_variants,
        ) in enumerate(full_variants_iterator):
            assert summary_index < 1_000_000_000, \
                "too many summary variants"
            allele_count = summary_variant.allele_count
            # Per-allele aggregates indexed by allele_index.
            seen_in_status = allele_count * [0]
            seen_as_denovo = allele_count * [False]
            family_variants_count = allele_count * [0]
            sj_base_index = self._calc_sj_base_index(summary_index)

            num_fam_alleles_written = 0
            for fv in family_variants:
                # Incremented before the reference filter, so skipped
                # family variants still consume a family index.
                family_index += 1
                assert fv.gt is not None
                if is_all_reference_genotype(fv.gt) and \
                        not self.include_reference:
                    continue
                num_fam_alleles_written += self._write_family_variant(
                    fv,
                    summary_index=summary_index,
                    family_index=family_index,
                    sj_base_index=sj_base_index,
                    seen_in_status=seen_in_status,
                    seen_as_denovo=seen_as_denovo,
                    family_variants_count=family_variants_count,
                )

            # don't store summary alleles without family ones
            if num_fam_alleles_written > 0:
                summary_variant.summary_index = summary_index
                summary_variant.ref_allele.update_attributes(
                    {"bucket_index": self.bucket_index})
                # Element 0 is dropped from each aggregate list.
                summary_variant.update_attributes({
                    "seen_in_status": seen_in_status[1:],
                    "seen_as_denovo": seen_as_denovo[1:],
                    "family_variants_count": family_variants_count[1:],
                    # NOTE(review): filled from the same counts as
                    # family_variants_count; confirm this is intended.
                    "family_alleles_count": family_variants_count[1:],
                    "bucket_index": [self.bucket_index],
                })
                self.write_summary_variant(
                    summary_variant,
                    sj_base_index=sj_base_index,
                )

            if summary_index % 1000 == 0 and summary_index > 0:
                elapsed = time.time() - self.start
                logger.info(
                    "progress bucked %s; "
                    "summary variants: %s; family variants: %s; "
                    "elapsed time: %0.2f sec",
                    self.bucket_index,
                    summary_index, family_index, elapsed)

        filenames = list(self.data_writers.keys())
        self.close()

        elapsed = time.time() - self.start
        logger.info(
            "finished bucked %s; summary variants: %s; family variants: %s; "
            "elapsed time: %0.2f sec",
            self.bucket_index,
            summary_index, family_index, elapsed)

        return filenames

    def close(self) -> None:
        """Close every file writer opened so far."""
        for bin_writer in self.data_writers.values():
            bin_writer.close()

    def write_summary_variant(
        self,
        summary_variant: SummaryVariant,
        attributes: dict[str, Any] | None = None,
        sj_base_index: int | None = None,
    ) -> None:
        """Write a single summary variant to the correct parquet file.

        Args:
            summary_variant: the variant to store.
            attributes: optional extra attributes applied to the variant
                before serialization.
            sj_base_index: when given, each allele's ``sj_index`` is set
                to ``sj_base_index + allele_index``.
        """
        if attributes is not None:
            summary_variant.update_attributes(attributes)
        if sj_base_index is not None:
            for summary_allele in summary_variant.alleles:
                summary_allele.update_attributes({
                    "sj_index": sj_base_index + summary_allele.allele_index,
                })

        summary_blobs_json = self.serializer.serialize_summary(summary_variant)
        if self.include_reference:
            stored_alleles = summary_variant.alleles
        else:
            stored_alleles = summary_variant.alt_alleles
        for summary_allele in stored_alleles:
            seen_as_denovo = summary_allele.get_attribute("seen_as_denovo")
            summary_writer = self._get_bin_writer_summary(
                summary_allele, seen_as_denovo=seen_as_denovo)
            summary_writer.append_summary_allele(
                summary_allele, summary_blobs_json)