Source code for dae.parquet.schema2.serializers

from __future__ import annotations

import abc
import functools
import json
import logging
import operator
from itertools import starmap
from typing import Any, cast

import pyarrow as pa

from dae.annotation.annotation_pipeline import AttributeInfo
from dae.variants.attributes import (
    Inheritance,
    Role,
    Sex,
    Status,
    TransmissionType,
)
from dae.variants.core import Allele
from dae.variants.family_variant import FamilyAllele, FamilyVariant
from dae.variants.variant import SummaryAllele, SummaryVariant

logger = logging.getLogger(__name__)

# Base columns of the summary allele parquet table: column name -> pyarrow
# type. ``build_summary_schema`` turns these items into fields in this exact
# order (annotation columns are appended after them), and
# ``SummaryAlleleParquetSerializer.build_allele_record_dict`` fills one record
# entry per key. Reordering keys therefore reorders the table columns.
SUMMARY_ALLELE_BASE_SCHEMA: dict[str, Any] = {
    "bucket_index": pa.int32(),
    "summary_index": pa.int32(),
    "allele_index": pa.int32(),
    "sj_index": pa.int64(),
    "chromosome": pa.string(),
    "position": pa.int32(),
    "end_position": pa.int32(),
    # List of structs; the record builder pairs ``allele.effect_types`` with
    # ``allele.effect_gene_symbols`` element by element (or emits a single
    # all-null struct when the allele has no effect types).
    "effect_gene": pa.list_(
        pa.struct([
            pa.field("effect_gene_symbols", pa.string()),
            pa.field("effect_types", pa.string()),
        ]),
    ),
    # Listed in ENUM_PROPERTIES: stored as the enum member's ``value``.
    "variant_type": pa.int8(),
    "transmission_type": pa.int8(),
    "reference": pa.string(),
    "af_allele_count": pa.int32(),
    "af_allele_freq": pa.float32(),
    "af_parents_called_count": pa.int32(),
    "af_parents_called_percent": pa.float32(),
    "seen_as_denovo": pa.bool_(),
    "seen_in_status": pa.int8(),
    "family_variants_count": pa.int32(),
    "family_alleles_count": pa.int32(),
}

# Base columns of the family allele parquet table: column name -> pyarrow
# type. ``build_family_schema`` turns these items into fields in this order
# (followed by the ``family_variant_data`` blob column), and
# ``FamilyAlleleParquetSerializer.build_allele_record_dict`` fills one record
# entry per key.
FAMILY_ALLELE_BASE_SCHEMA: dict[str, Any] = {
    "bucket_index": pa.int32(),
    "summary_index": pa.int32(),
    "allele_index": pa.int32(),
    "sj_index": pa.int64(),
    "family_index": pa.int32(),
    "family_id": pa.string(),
    "is_denovo": pa.int8(),
    # The next four keys are listed in ENUM_PROPERTIES: a single enum member
    # is stored as its ``value``; a list of members is stored as the bitwise
    # OR of their values.
    "allele_in_sexes": pa.int8(),
    "allele_in_statuses": pa.int8(),
    "allele_in_roles": pa.int32(),
    "inheritance_in_members": pa.int16(),
    # Not in ENUM_PROPERTIES: written as read from the allele, so presumably
    # the allele already exposes these as integers — TODO confirm.
    "zygosity_in_status": pa.int8(),
    "zygosity_in_roles": pa.int64(),
    "zygosity_in_sexes": pa.int16(),
    # Member IDs; ``None`` entries are dropped by the record builder.
    "allele_in_members": pa.list_(pa.string()),
}

# Allele properties whose values are enum members (or lists of them) and must
# be converted to integers before being written to parquet. The conversion in
# ``AlleleParquetSerializer._get_searchable_prop_value`` only checks key
# membership; the mapped enum classes are not used by the conversion code.
ENUM_PROPERTIES: dict[str, Any] = {
    "variant_type": Allele.Type,
    "transmission_type": TransmissionType,
    "allele_in_sexes": Sex,
    "allele_in_roles": Role,
    "allele_in_statuses": Status,
    "inheritance_in_members": Inheritance,
}


def build_summary_schema(
    annotation_schema: list[AttributeInfo] | None,
) -> pa.Schema:
    """Build the schema for the summary alleles.

    The schema consists of:

    * the ``SUMMARY_ALLELE_BASE_SCHEMA`` columns, in declaration order;
    * the ``summary_variant_data`` binary column;
    * one column per non-internal annotation attribute whose type is
      ``float``, ``int`` or ``str`` (mapped to float32, int32 and string).

    Annotation attributes of any other type get no column of their own.

    Args:
        annotation_schema: annotation attributes to add as columns. ``None``
            is accepted and yields only the base columns.

    Returns:
        The assembled ``pa.Schema``.
    """
    fields = list(starmap(pa.field, SUMMARY_ALLELE_BASE_SCHEMA.items()))
    fields.append(pa.field("summary_variant_data", pa.binary()))

    annotation_type_to_pa_type = {
        "float": pa.float32(),
        "int": pa.int32(),
        "str": pa.string(),
    }
    for attr in annotation_schema or []:
        if attr.internal:
            continue
        pa_type = annotation_type_to_pa_type.get(attr.type)
        if pa_type is not None:
            fields.append(pa.field(attr.name, pa_type))
    return pa.schema(fields)
def build_family_schema() -> pa.Schema:
    """Return the parquet schema used for family allele records.

    The fixed ``FAMILY_ALLELE_BASE_SCHEMA`` columns come first, in
    declaration order. They are followed by the ``family_variant_data``
    binary column that carries the serialized family variant.
    """
    columns = [
        pa.field(column_name, column_type)
        for column_name, column_type in FAMILY_ALLELE_BASE_SCHEMA.items()
    ]
    columns.append(pa.field("family_variant_data", pa.binary()))
    return pa.schema(columns)
class AlleleParquetSerializer(abc.ABC):
    """Base class for serializing alleles to parquet format.

    Subclasses supply the parquet schema, the name of the column holding
    the serialized variant blob, and the per-allele record builder.
    """

    def __init__(
        self,
        annotation_schema: list[AttributeInfo] | None,
        extra_attributes: list[str] | None = None,
    ) -> None:
        self.annotation_schema = annotation_schema
        # Filled lazily by the subclasses' ``schema()`` implementations.
        self._schema: pa.Schema | None = None
        # Take a copy so later changes to the caller's list do not leak in.
        self.extra_attributes = (
            extra_attributes[:] if extra_attributes else []
        )

    def _get_searchable_prop_value(
        self,
        allele: SummaryAllele | FamilyAllele,
        spr: str,
    ) -> Any:
        """Return property ``spr`` of ``allele`` in a storable form.

        The property is read as an attribute of the allele. If that yields
        ``None``, ``allele.get_attribute(spr)`` is used instead.

        For properties listed in ``ENUM_PROPERTIES`` the enum value(s) are
        turned into integers:

        * a list of members becomes the bitwise OR of their ``value``s
          (``None`` entries are skipped, and an empty list gives ``0``);
        * a single truthy member becomes its ``value``.

        All other values are returned unchanged.
        """
        prop_value = getattr(allele, spr, None)
        if prop_value is None:
            prop_value = allele.get_attribute(spr)
        if prop_value is None or spr not in ENUM_PROPERTIES:
            return prop_value
        if isinstance(prop_value, list):
            # Reduce even an empty list: the target columns are integers,
            # so ``[]`` must become 0 (no bits set) rather than a list.
            return functools.reduce(
                operator.or_,
                [member.value for member in prop_value if member is not None],
                0,
            )
        if prop_value:
            prop_value = prop_value.value
        return prop_value

    @abc.abstractmethod
    def schema(self) -> pa.Schema:
        """Lazily construct and return the parquet schema for the alleles."""

    @abc.abstractmethod
    def blob_column(self) -> str:
        """Return the name of the column that contains the variant blob."""

    @abc.abstractmethod
    def build_allele_record_dict(
        self,
        allele: SummaryAllele | FamilyAllele,
        variant_blob: bytes,
    ) -> dict[str, Any]:
        """Build a record from allele data in the form of a dict."""
class SummaryAlleleParquetSerializer(AlleleParquetSerializer):
    """Serialize summary alleles for parquet storage."""

    def schema(self) -> pa.Schema:
        """Lazy construct and return the schema for the summary alleles."""
        if self._schema is None:
            self._schema = self.build_schema(
                self.annotation_schema,
            )
        return self._schema

    def blob_column(self) -> str:
        """Return the name of the summary variant blob column."""
        return "summary_variant_data"

    @classmethod
    def build_schema(
        cls,
        annotation_schema: list[AttributeInfo] | None,
    ) -> pa.Schema:
        """Build the schema for the summary alleles.

        Delegates to :func:`build_summary_schema`.
        """
        return build_summary_schema(annotation_schema)

    @classmethod
    def build_blob_schema(
        cls,
        annotation_schema: list[AttributeInfo] | None,
    ) -> dict[str, str]:
        """Map summary schema column names to their arrow type strings.

        The ``effect_gene``, ``summary_variant_data`` and ``chromosome``
        columns are excluded from the result.
        """
        excluded = {
            "effect_gene",
            "summary_variant_data",
            "chromosome",
        }
        schema_summary = cls.build_schema(annotation_schema)
        return {
            f.name: str(f.type)
            for f in schema_summary
            if f.name not in excluded
        }

    def build_allele_record_dict(
        self,
        allele: SummaryAllele,
        variant_blob: bytes,
    ) -> dict[str, Any]:
        """Build a record of summary allele data in the form of a dict.

        The record contains:

        * ``variant_blob`` under ``summary_variant_data``;
        * one entry per ``SUMMARY_ALLELE_BASE_SCHEMA`` column;
        * one entry per non-internal annotation attribute, read with
          ``allele.get_attribute``.
        """
        # Annotated explicitly: the record mixes values of many types, and
        # letting the checker infer ``dict[str, bytes]`` from the first entry
        # previously required a ``type: ignore``.
        allele_data: dict[str, Any] = {"summary_variant_data": variant_blob}
        for spr in SUMMARY_ALLELE_BASE_SCHEMA:
            if spr != "effect_gene":
                prop_value = self._get_searchable_prop_value(allele, spr)
            elif allele.effect_types is None:
                # No effect annotation: emit a single all-null struct.
                assert allele.effect_gene_symbols is None
                prop_value = [
                    {"effect_types": None, "effect_gene_symbols": None},
                ]
            else:
                # Pair effect types with gene symbols; ``strict=True`` raises
                # ValueError if the two sequences differ in length.
                prop_value = [
                    {"effect_types": etype, "effect_gene_symbols": symbol}
                    for etype, symbol in zip(
                        allele.effect_types,
                        allele.effect_gene_symbols,
                        strict=True,
                    )
                ]
            allele_data[spr] = prop_value

        if self.annotation_schema is not None:
            for attr in self.annotation_schema:
                if attr.internal:
                    continue
                allele_data[attr.name] = allele.get_attribute(attr.name)
        return allele_data
class FamilyAlleleParquetSerializer(AlleleParquetSerializer):
    """Serialize family alleles."""

    def schema(self) -> pa.Schema:
        """Lazy construct and return the schema for the family alleles."""
        if self._schema is None:
            self._schema = self.build_schema()
        return self._schema

    def blob_column(self) -> str:
        """Return the name of the family variant blob column."""
        return "family_variant_data"

    @classmethod
    def build_schema(cls) -> pa.Schema:
        """Build the schema for the family alleles."""
        return build_family_schema()

    def build_allele_record_dict(
        self,
        allele: SummaryAllele | FamilyAllele,
        variant_blob: bytes,
    ) -> dict[str, Any]:
        """Build a single record of family allele data as a dict.

        The record contains:

        * ``variant_blob`` under ``family_variant_data``;
        * one entry per ``FAMILY_ALLELE_BASE_SCHEMA`` column.

        ``None`` entries are removed from ``allele_in_members``.
        """
        allele_data: dict[str, Any] = {
            "family_variant_data": variant_blob,
        }
        for spr in FAMILY_ALLELE_BASE_SCHEMA:
            allele_data[spr] = self._get_searchable_prop_value(allele, spr)

        # A missing member list (``None``) becomes empty instead of raising
        # TypeError on iteration; ``None`` member entries are dropped.
        members = allele_data["allele_in_members"] or []
        allele_data["allele_in_members"] = [
            m for m in members if m is not None
        ]
        return allele_data
class VariantsDataSerializer(abc.ABC):
    """Abstract interface for turning variants into bytes and back.

    Family variants and summary variants are serialized separately. The
    ``deserialize_*`` methods return plain record dicts rather than variant
    objects.
    """

    @abc.abstractmethod
    def serialize_family(
        self, variant: FamilyVariant,
    ) -> bytes:
        """Encode the family part of ``variant`` as a byte string."""

    @abc.abstractmethod
    def serialize_summary(
        self, variant: SummaryVariant,
    ) -> bytes:
        """Encode a summary variant as a byte string."""

    @abc.abstractmethod
    def deserialize_family_record(
        self, data: bytes,
    ) -> dict[str, Any]:
        """Decode bytes produced by ``serialize_family`` into a record."""

    @abc.abstractmethod
    def deserialize_summary_record(
        self, data: bytes,
    ) -> list[dict[str, Any]]:
        """Decode bytes produced by ``serialize_summary`` into records.

        The result is a list of record dicts.
        """

    @staticmethod
    def build_serializer() -> VariantsDataSerializer:
        """Return the serializer implementation to use.

        This always returns a new :class:`JsonVariantsDataSerializer`.
        """
        serializer = JsonVariantsDataSerializer()
        return serializer
class JsonVariantsDataSerializer(VariantsDataSerializer):
    """JSON implementation of :class:`VariantsDataSerializer`.

    Variants are encoded as ``json.dumps(variant.to_record(),
    sort_keys=True)`` encoded with ``str.encode()``'s default UTF-8. Bytes
    are decoded with ``json.loads``.
    """

    @staticmethod
    def _dump_record(variant: FamilyVariant | SummaryVariant) -> bytes:
        """Encode ``variant.to_record()`` as key-sorted JSON bytes."""
        record = variant.to_record()
        text = json.dumps(record, sort_keys=True)
        return text.encode()

    def serialize_family(
        self, variant: FamilyVariant,
    ) -> bytes:
        return self._dump_record(variant)

    def serialize_summary(
        self, variant: SummaryVariant,
    ) -> bytes:
        return self._dump_record(variant)

    def deserialize_family_record(
        self, data: bytes,
    ) -> dict[str, Any]:
        decoded = json.loads(data)
        return cast(dict[str, Any], decoded)

    def deserialize_summary_record(
        self, data: bytes,
    ) -> list[dict[str, Any]]:
        decoded = json.loads(data)
        return cast(list[dict[str, Any]], decoded)