"""Annotate VCF files with a GPF annotation pipeline (dae.annotation.annotate_vcf)."""

from __future__ import annotations

import argparse
import itertools
import logging
import os
import sys
from collections.abc import Iterable, Sequence
from contextlib import closing
from dataclasses import dataclass
from pathlib import Path
from types import TracebackType
from typing import Any

from pysam import (
    TabixFile,
    VariantFile,
    VariantHeader,
    VariantRecord,
    tabix_index,
)

from dae.annotation.annotatable import VCFAllele
from dae.annotation.annotate_utils import (
    add_input_files_to_task_graph,
    build_output_path,
    cache_pipeline_resources,
    get_stuff_from_context,
    produce_partfile_paths,
    produce_regions,
    stringify,
)
from dae.annotation.annotation_config import (
    AttributeInfo,
    RawAnnotatorsConfig,
    RawPipelineConfig,
)
from dae.annotation.annotation_factory import (
    build_annotation_pipeline,
    load_pipeline_from_file,
)
from dae.annotation.annotation_pipeline import (
    ReannotationPipeline,
)
from dae.annotation.genomic_context import CLIAnnotationContextProvider
from dae.annotation.processing_pipeline import (
    Annotation,
    AnnotationPipelineAnnotatablesBatchFilter,
    AnnotationPipelineAnnotatablesFilter,
    AnnotationsWithSource,
)
from dae.genomic_resources.repository_factory import (
    build_genomic_resource_repository,
)
from dae.task_graph import TaskGraphCli
from dae.task_graph.graph import TaskGraph
from dae.utils.fs_utils import tabix_index_filename
from dae.utils.processing_pipeline import Filter, PipelineProcessor, Source
from dae.utils.regions import Region
from dae.utils.verbosity_configuration import VerbosityConfiguration

logger = logging.getLogger("annotate_vcf")


@dataclass
class _ProcessingArgs:
    input: str
    reannotate: str | None
    work_dir: str
    batch_size: int
    region_size: int
    allow_repeated_attributes: bool
    full_reannotation: bool


class _VCFSource(Source):
    """Source for reading from VCF files.

    The header is read eagerly in ``__init__`` (the file is opened and closed
    again); the file used for iteration is opened in ``__enter__`` and closed
    in ``__exit__``.
    """

    def __init__(self, path: str):
        self.path = path
        self.vcf: VariantFile
        with VariantFile(self.path, "r") as infile:
            self.header = infile.header

    def __enter__(self) -> _VCFSource:
        self.vcf = VariantFile(self.path, "r")
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool:
        if exc_type is not None:
            logger.error(
                "exception during annotation: %s, %s, %s",
                exc_type, exc_value, exc_tb)

        self.vcf.close()

        # False when an exception is in flight, so it is not suppressed.
        return exc_type is None

    @staticmethod
    def _allele_info_value(value: Any, idx: int) -> Any:
        """Select the part of an INFO value that belongs to alt allele *idx*.

        Multi-valued INFO fields (tuples/lists, e.g. the ``Number=A``
        attributes this tool writes) are indexed per alternative allele; a
        sequence too short for *idx* yields ``None``. Scalar values (flags,
        ``Number=1`` fields) apply to the whole record and are returned
        unchanged: indexing them would raise ``TypeError`` (ints, floats,
        bools) or silently pick a single character (strings).
        """
        if isinstance(value, (tuple, list)):
            return value[idx] if idx < len(value) else None
        return value

    @staticmethod
    def _convert(variant: VariantRecord) -> AnnotationsWithSource:
        """Wrap *variant* with one ``Annotation`` per alternative allele.

        Each annotation's context holds the record's INFO values that apply
        to that allele.
        """
        # Decode the INFO fields once per record, not once per allele.
        info_items = list(variant.info.items())
        annotations = [
            Annotation(
                VCFAllele(variant.chrom,
                          variant.pos,
                          variant.ref,  # type: ignore
                          alt),
                {
                    key: _VCFSource._allele_info_value(value, idx)
                    for key, value in info_items
                },
            )
            for idx, alt in enumerate(variant.alts)  # type: ignore
        ]
        return AnnotationsWithSource(variant, annotations)

    def fetch(
        self, region: Region | None = None,
    ) -> Iterable[AnnotationsWithSource]:
        """Yield converted variants for *region* (whole file when None).

        Records without a reference or without alternative alleles are
        logged and skipped.
        """
        if region is None:
            in_file_iter = self.vcf.fetch()
        else:
            in_file_iter = self.vcf.fetch(region.chrom,
                                          region.start,
                                          region.stop)

        for vcf_var in in_file_iter:
            if vcf_var.ref is None:
                logger.warning(
                    "vcf variant without reference: %s %s",
                    vcf_var.chrom, vcf_var.pos,
                )
                continue

            if vcf_var.alts is None:
                logger.info(
                    "vcf variant without alternatives: %s %s",
                    vcf_var.chrom, vcf_var.pos,
                )
                continue

            yield _VCFSource._convert(vcf_var)


class _VCFBatchSource(Source):
    """Source yielding VCF variants grouped into fixed-size batches.

    All reading is delegated to a wrapped ``_VCFSource``. Every yielded batch
    is a tuple of at most ``batch_size`` converted variants; the final batch
    may be shorter.
    """

    def __init__(
        self,
        path: str,
        batch_size: int = 500,
    ):
        self.batch_size = batch_size
        self.source = _VCFSource(path)

    def __enter__(self) -> _VCFBatchSource:
        self.source.__enter__()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool:
        failed = exc_type is not None
        if failed:
            logger.error(
                "exception during annotation: %s, %s, %s",
                exc_type, exc_value, exc_tb)

        self.source.__exit__(exc_type, exc_value, exc_tb)

        # Never suppress an in-flight exception.
        return not failed

    def fetch(
        self, region: Region | None = None,
    ) -> Iterable[Sequence[AnnotationsWithSource]]:
        """Yield tuples of up to ``batch_size`` variants from *region*."""
        pending = iter(self.source.fetch(region))
        while True:
            chunk = tuple(itertools.islice(pending, self.batch_size))
            if not chunk:
                return
            yield chunk


class _VCFWriter(Filter):
    """A filter that writes variants to a VCF file.

    The output header is derived from the given header at construction time.
    INFO definitions scheduled for deletion are dropped, and one
    ``Number=A``/``String`` INFO field is declared per annotation attribute
    missing from the header. The output file is opened in ``__enter__`` and
    closed in ``__exit__``.
    """

    def __init__(
        self,
        path: str,
        header: VariantHeader,
        annotation_attributes: Sequence[AttributeInfo],
        attributes_to_delete: Sequence[str],
    ):
        self.path = path
        self.output_file: VariantFile
        # NOTE: ``header`` is modified in place by ``_update_header``.
        self.header = self._update_header(
            header, annotation_attributes, attributes_to_delete)
        self.annotation_attributes = annotation_attributes
        self.attributes_to_delete = attributes_to_delete

    @staticmethod
    def _update_header(
        header: VariantHeader,
        annotation_attributes: Sequence[AttributeInfo],
        attributes_to_delete: Sequence[str],
    ) -> VariantHeader:
        """Update a variant file's header with annotation.

        Adds a ``pipeline_annotation_tool`` meta line and removes the INFO
        definitions listed in *attributes_to_delete*, unless the pipeline
        produces an attribute with the same name. Every annotation attribute
        missing from the header is declared as a per-allele (``A``)
        ``String`` field. The header is modified in place and returned.
        """
        header.add_meta("pipeline_annotation_tool", "GPF variant annotation.")

        annotation_attr_names = {attr.name for attr in annotation_attributes}
        to_delete = set(attributes_to_delete)

        # Iterate over a snapshot of the keys: removing definitions from
        # ``header.info`` while iterating over it directly is unsafe.
        for info_key in list(header.info):
            if info_key in to_delete \
               and info_key not in annotation_attr_names:
                header.info.remove_header(info_key)

        # Decide which attributes are missing before adding any of them.
        attributes = [
            attr for attr in annotation_attributes
            if attr.name not in header.info
        ]

        for attribute in attributes:
            # Keep the header line well-formed: no raw newlines and
            # escaped double quotes inside the Description.
            description = attribute.description
            description = description.replace("\n", " ")
            description = description.replace('"', '\\"')
            header.info.add(attribute.name, "A", "String", description)

        return header

    @staticmethod
    def _convert_to_string(attr: Any) -> str:
        """Render an annotation value as a VCF-safe INFO string.

        Lists and dicts are flattened first. ``;``, ``,`` and spaces are then
        replaced because they would break the INFO column's syntax.
        """
        if isinstance(attr, list):
            attr = ";".join(stringify(a, vcf=True) for a in attr)
        elif isinstance(attr, dict):
            attr = ";".join(
                f"{k}:{v}"
                for k, v in attr.items()
            )
        return stringify(attr, vcf=True) \
            .replace(";", "|") \
            .replace(",", "|") \
            .replace(" ", "_")

    @staticmethod
    def _update_variant(
        vcf_var: VariantRecord,
        allele_annotations: list[dict],
        attributes: Sequence[AttributeInfo],
        attributes_to_delete: Sequence[str],
    ) -> None:
        """Write per-allele annotation values into *vcf_var*'s INFO fields.

        Deletes the INFO fields in *attributes_to_delete*, then collects one
        value per allele for every attribute. Values of ``String`` fields are
        converted with ``_convert_to_string``. An attribute whose collected
        values are all ``"."`` (missing) is not written.
        """
        buffers: list[list] = [[] for _ in attributes]

        for col in attributes_to_delete:
            del vcf_var.info[col]

        for annotation in allele_annotations:
            for buff, attribute in zip(buffers, attributes, strict=True):
                value = annotation.get(attribute.name)
                if vcf_var.header.info[attribute.name].type == "String":
                    value = _VCFWriter._convert_to_string(value)
                buff.append(value)

        for buff, attribute in zip(buffers, attributes, strict=True):
            # If all values for this attribute are empty (i.e. "."), there
            # is nothing to write and the attribute is skipped. Each buffer
            # is checked on its own, so attributes sharing a name cannot
            # borrow each other's values.
            if not any(value != "." for value in buff):
                continue
            vcf_var.info[attribute.name] = buff

    def __enter__(self) -> _VCFWriter:
        self.output_file = VariantFile(self.path, "w", header=self.header)
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool:
        if exc_type is not None:
            logger.error(
                "exception during writing vcf: %s, %s, %s",
                exc_type, exc_value, exc_tb)

        self.output_file.close()

        return exc_type is None

    def filter(self, data: AnnotationsWithSource) -> None:
        """Re-header, annotate and write a single source variant."""
        # Move the record onto the output header before touching its INFO.
        data.source.translate(self.header)
        _VCFWriter._update_variant(
            data.source,
            [annotation.context for annotation in data.annotations],
            self.annotation_attributes,
            self.attributes_to_delete,
        )
        self.output_file.write(data.source)


class _VCFBatchWriter(Filter):
    """Filter writing whole batches of annotated variants to a VCF file.

    All of the work is delegated to a wrapped ``_VCFWriter``; this class
    only unrolls each incoming batch into single-variant writes.
    """

    def __init__(
        self,
        path: str,
        header: VariantHeader,
        annotation_attributes: Sequence[AttributeInfo],
        attributes_to_delete: Sequence[str],
    ):
        self.writer = _VCFWriter(
            path,
            header,
            annotation_attributes,
            attributes_to_delete,
        )

    def __enter__(self) -> _VCFBatchWriter:
        self.writer.__enter__()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool:
        failed = exc_type is not None
        if failed:
            logger.error(
                "exception during writing vcf: %s, %s, %s",
                exc_type, exc_value, exc_tb)

        self.writer.__exit__(exc_type, exc_value, exc_tb)

        # Never suppress an in-flight exception.
        return not failed

    def filter(self, data: Sequence[AnnotationsWithSource]) -> None:
        """Write every variant of the batch *data*, preserving order."""
        write_one = self.writer.filter
        for annotated_variant in data:
            write_one(annotated_variant)


def _annotate_vcf(
    output_path: str,
    pipeline_config: RawAnnotatorsConfig,
    grr_definition: dict,
    region: Region | None,
    args: _ProcessingArgs,
) -> None:
    """Annotate a VCF file using a processing pipeline.

    Builds the annotation pipeline from *pipeline_config*. When
    ``args.reannotate`` names a previous pipeline config, the pipeline is
    wrapped in a ``ReannotationPipeline``. The variants of ``args.input``
    fetched for *region* (the whole file when ``None``) are then streamed
    through it and written to *output_path*.

    When ``args.batch_size`` is positive, variants are read, annotated and
    written in batches of that size; otherwise they are handled one at a
    time.
    """
    grr = build_genomic_resource_repository(definition=grr_definition)

    pipeline_previous = None
    if args.reannotate:
        pipeline_previous = load_pipeline_from_file(args.reannotate, grr)

    pipeline = build_annotation_pipeline(
        pipeline_config, grr,
        allow_repeated_attributes=args.allow_repeated_attributes,
        work_dir=Path(args.work_dir),
    )

    attributes_to_delete = []

    if pipeline_previous:
        pipeline = ReannotationPipeline(
            pipeline, pipeline_previous,
            full_reannotation=args.full_reannotation)
        attributes_to_delete = pipeline.deleted_attributes

    # Internal attributes are not written to the output VCF.
    annotation_attributes = [
        attr for attr in pipeline.get_attributes()
        if not attr.internal
    ]

    source: Source
    filters: list[Filter] = []

    if args.batch_size <= 0:
        source = _VCFSource(args.input)
        header = source.header.copy()
        filters.extend([
            AnnotationPipelineAnnotatablesFilter(pipeline),
            _VCFWriter(output_path,
                       header,
                       annotation_attributes,
                       attributes_to_delete),
        ])
    else:
        # Honour the requested --batch-size instead of the batch
        # source's built-in default.
        source = _VCFBatchSource(args.input, batch_size=args.batch_size)
        header = source.source.header.copy()
        filters.extend([
            AnnotationPipelineAnnotatablesBatchFilter(pipeline),
            _VCFBatchWriter(output_path,
                            header,
                            annotation_attributes,
                            attributes_to_delete),
        ])

    with PipelineProcessor(source, filters) as processor:
        processor.process_region(region)


def _concat(partfile_paths: list[str], output_path: str) -> None:
    """Concatenate multiple VCF files into a single VCF file *in order*."""
    # Get any header from the partfiles, they should all be equal
    # and usable as a final output header
    header_donor = VariantFile(partfile_paths[0], "r")
    output_file = VariantFile(
        output_path, "w",
        header=header_donor.header.copy(),
    )
    for path in partfile_paths:
        partfile = VariantFile(path, "r")
        for variant in partfile.fetch():
            output_file.write(variant)
        partfile.close()
    output_file.close()
    header_donor.close()

    for partfile_path in partfile_paths:
        os.remove(partfile_path)


def _build_argument_parser() -> argparse.ArgumentParser:
    """Construct and configure argument parser.

    Besides the tool's own options, the parser also carries the options of
    the annotation context provider, the task graph CLI and the verbosity
    configuration.
    """
    parser = argparse.ArgumentParser(
        description="Annotate VCF",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "input", default="-", nargs="?",
        help="the input vcf file")
    parser.add_argument(
        "-r", "--region-size", default=300_000_000,
        type=int, help="region size to parallelize by")
    parser.add_argument(
        "-w", "--work-dir",
        help="Directory to store intermediate output files",
        default="annotate_vcf_output")
    parser.add_argument(
        "-o", "--output",
        help="Filename of the output VCF result",
        default=None)
    parser.add_argument(
        "--reannotate", default=None,
        help="Old pipeline config to reannotate over")
    parser.add_argument(
        "-i", "--full-reannotation",
        help="Ignore any previous annotation and run "
        "a full reannotation.",
        action="store_true",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=0,  # 0 = annotate iteratively, no batches
        help="Annotate in batches of this many variants; "
        "0 or less annotates variants one at a time without batching.",
    )

    CLIAnnotationContextProvider.add_argparser_arguments(parser)
    TaskGraphCli.add_arguments(parser)
    VerbosityConfiguration.set_arguments(parser)
    return parser


def _make_vcf_tabix(filepath: str) -> None:
    """Tabix-index the VCF at *filepath* using pysam's "vcf" preset.

    Used as the final "compress_and_tabix" step of a tabixed run.
    """
    index_preset = "vcf"
    tabix_index(filepath, preset=index_preset)


def _add_tasks_plaintext(
    args: _ProcessingArgs,
    task_graph: TaskGraph,
    output_path: str,
    pipeline_config: RawPipelineConfig,
    grr_definition: dict[str, Any],
) -> None:
    """Register one task that annotates the whole input file.

    Used when the input has no tabix index. The single task runs
    ``_annotate_vcf`` with region ``None`` and writes straight to
    *output_path*.
    """
    whole_file_region = None
    task_args = [
        output_path,
        pipeline_config,
        grr_definition,
        whole_file_region,
        args,
    ]
    task_graph.create_task(
        "all_variants_annotate",
        _annotate_vcf,
        args=task_args,
        deps=[],
    )


def _add_tasks_tabixed(
    args: _ProcessingArgs,
    task_graph: TaskGraph,
    output_path: str,
    pipeline_config: RawPipelineConfig,
    grr_definition: dict[str, Any],
) -> None:
    """Register region-parallel annotation tasks for a tabix-indexed input.

    The graph is:
    - one ``_annotate_vcf`` task per region, each writing its own part file;
    - a no-op sync task depending on all of them;
    - a task concatenating the part files into *output_path*;
    - a final task indexing the result with tabix.
    """
    with closing(TabixFile(args.input)) as tabix_file:
        regions = produce_regions(tabix_file, args.region_size)
    part_paths = produce_partfile_paths(
        args.input, regions, args.work_dir)

    part_tasks = [
        task_graph.create_task(
            f"part-{str(region).replace(':', '-')}",
            _annotate_vcf,
            args=[
                part_path,
                pipeline_config,
                grr_definition,
                region,
                args,
            ],
            deps=[],
        )
        for region, part_path in zip(regions, part_paths, strict=True)
    ]

    # A single barrier keeps the concat task's dependency list short.
    sync_task = task_graph.create_task(
        "sync_vcf_annotate", lambda: None,
        args=[], deps=part_tasks,
    )

    concat_task = task_graph.create_task(
        "concat",
        _concat,
        args=[part_paths, output_path],
        deps=[sync_task],
    )
    task_graph.create_task(
        "compress_and_tabix",
        _make_vcf_tabix,
        args=[output_path],
        deps=[concat_task],
    )


def cli(raw_args: list[str] | None = None) -> None:
    """Entry point for running the VCF annotation tool.

    Parses the command line, prepares the work directory, caches the
    pipeline's genomic resources and builds a task graph. The graph is
    region-parallel when the input has a tabix index and a single task
    otherwise. It is then executed through ``TaskGraphCli``.

    Args:
        raw_args: command-line arguments; ``None`` means ``sys.argv[1:]``.

    Raises:
        ValueError: if the input file does not exist.
    """
    # Only fall back to sys.argv when no argument list was given at all;
    # an explicit empty list is a legitimate (if minimal) command line.
    if raw_args is None:
        raw_args = sys.argv[1:]

    arg_parser = _build_argument_parser()
    args = vars(arg_parser.parse_args(raw_args))

    if not os.path.exists(args["input"]):
        raise ValueError(f"{args['input']} does not exist!")
    # makedirs also creates missing parent directories.
    os.makedirs(args["work_dir"], exist_ok=True)

    args["task_status_dir"] = os.path.join(args["work_dir"], ".task-status")
    args["task_log_dir"] = os.path.join(args["work_dir"], ".task-log")

    pipeline, _, grr = get_stuff_from_context(args)
    assert grr.definition is not None
    cache_pipeline_resources(grr, pipeline)

    processing_args = _ProcessingArgs(
        args["input"],
        args["reannotate"],
        args["work_dir"],
        args["batch_size"],
        args["region_size"],
        args["allow_repeated_attributes"],
        args["full_reannotation"],
    )

    output_path = build_output_path(args["input"], args["output"])

    task_graph = TaskGraph()
    if tabix_index_filename(args["input"]):
        _add_tasks_tabixed(
            processing_args,
            task_graph,
            output_path,
            pipeline.raw,
            grr.definition,
        )
    else:
        _add_tasks_plaintext(
            processing_args,
            task_graph,
            output_path,
            pipeline.raw,
            grr.definition,
        )

    add_input_files_to_task_graph(args, task_graph)
    TaskGraphCli.process_graph(task_graph, **args)