Source code for gain.annotation.annotate_vcf

from __future__ import annotations

import argparse
import gc
import itertools
import logging
import os
import sys
import traceback
from collections.abc import Iterable, Sequence
from contextlib import chdir, closing
from dataclasses import dataclass
from pathlib import Path
from types import TracebackType
from typing import Any, cast

from pysam import (
    TabixFile,
    VariantFile,
    VariantHeader,
    VariantRecord,
    tabix_compress,
    tabix_index,
)

from gain import __version__
from gain.annotation.annotatable import VCFAllele
from gain.annotation.annotate_utils import (
    add_common_annotation_arguments,
    add_input_files_to_task_graph,
    build_cli_genomic_context,
    cache_pipeline_resources,
    check_resource_locality,
    emit_annotation_plan,
    get_grr_from_context,
    get_pipeline_from_context,
    handle_default_args,
    maybe_remove_work_dir,
    maybe_wrap_reannotation,
    produce_partfile_paths,
    produce_regions,
    stringify,
)
from gain.annotation.annotation_config import (
    Attribute,
    RawAnnotatorsConfig,
    RawPipelineConfig,
)
from gain.annotation.annotation_factory import (
    build_annotation_pipeline,
)
from gain.annotation.annotation_pipeline import (
    AnnotationPipeline,
    ReannotationPipeline,
)
from gain.annotation.processing_pipeline import (
    Annotation,
    AnnotationPipelineAnnotatablesBatchFilter,
    AnnotationPipelineAnnotatablesFilter,
    AnnotationsWithSource,
)
from gain.genomic_resources.repository_factory import (
    build_genomic_resource_repository,
)
from gain.task_graph.cli_tools import TaskGraphCli
from gain.task_graph.graph import TaskGraph
from gain.utils.fs_utils import (
    is_compressed_filename,
    strip_compression_suffix,
    tabix_index_filename,
)
from gain.utils.processing_pipeline import Filter, PipelineProcessor, Source
from gain.utils.regions import Region
from gain.utils.verbosity_configuration import VerbosityConfiguration

logger = logging.getLogger("annotate_vcf")


@dataclass
class _InfoField:
    name: str
    number: str | None
    type: str | None
    description: str | None


class _VCFSource(Source):
    """Source for reading from VCF files."""

    def __init__(self, path: str):
        self.path = path
        self.vcf: VariantFile
        with VariantFile(self.path, "r") as infile:
            self.header = infile.header
        self.info = {
            k: _InfoField(
                name=v.name,
                number=v.number,
                type=v.type,
                description=v.description,
            )
            for k, v in self.header.info.items()
        }

    def __enter__(self) -> _VCFSource:
        self.vcf = VariantFile(self.path, "r")
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool:
        if exc_type is not None:
            logger.error(
                "exception during annotation: %s, %s, %s",
                exc_type, exc_value, traceback.format_tb(exc_tb))

        self.vcf.close()

        return exc_type is None

    def _convert_info(self, alt_idx: int, variant: VariantRecord) -> dict:
        result = {}
        for k in variant.info:
            if self.info[k].number == "A":
                result[k] = variant.info[k][alt_idx]
            elif self.info[k].number == "1":
                result[k] = variant.info[k]
            elif self.info[k].number == ".":
                result[k] = None
            else:
                result[k] = variant.info[k]
        return result

    def _convert(
        self, variant: VariantRecord,
    ) -> AnnotationsWithSource:
        annotations = [
            Annotation(
                VCFAllele(variant.chrom,
                          variant.pos,
                          variant.ref,  # type: ignore
                          alt),
                self._convert_info(idx, variant),
            )
            for idx, alt in enumerate(variant.alts)  # type: ignore
        ]
        return AnnotationsWithSource(variant, annotations)

    def fetch(
        self, region: Region | None = None,
    ) -> Iterable[AnnotationsWithSource]:
        if region is None:
            in_file_iter = self.vcf.fetch()
            reg_start = 1
        else:
            assert region.start is not None
            in_file_iter = self.vcf.fetch(region.chrom,
                                          region.start - 1,
                                          region.stop)
            reg_start = region.start

        for vcf_var in in_file_iter:
            if vcf_var.pos < reg_start:
                continue
            if vcf_var.ref is None:
                logger.warning(
                    "vcf variant without reference: %s %s",
                    vcf_var.chrom, vcf_var.pos,
                )
                continue

            if vcf_var.alts is None:
                logger.info(
                    "vcf variant without alternatives: %s %s",
                    vcf_var.chrom, vcf_var.pos,
                )
                continue

            yield self._convert(vcf_var)


class _VCFBatchSource(Source):
    """Source for reading from VCF files in batches."""

    def __init__(
        self,
        path: str,
        batch_size: int = 500,
    ):
        self.source = _VCFSource(path)
        self.batch_size = batch_size

    def __enter__(self) -> _VCFBatchSource:
        self.source.__enter__()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool:
        if exc_type is not None:
            logger.error(
                "exception during annotation: %s, %s, %s",
                exc_type, exc_value, traceback.format_tb(exc_tb))

        self.source.__exit__(exc_type, exc_value, exc_tb)

        return exc_type is None

    def fetch(
        self, region: Region | None = None,
    ) -> Iterable[Sequence[AnnotationsWithSource]]:
        records = self.source.fetch(region)
        while batch := tuple(itertools.islice(records, self.batch_size)):
            yield batch


class _VCFWriter(Filter):
    """A filter that writes variants to a VCF file."""

    def __init__(
        self,
        path: str,
        header: VariantHeader,
        annotation_attributes: Sequence[Attribute],
        attributes_to_delete: Sequence[str],
    ):
        self.path = path
        self.output_file: VariantFile
        self.header = self._update_header(
            header, annotation_attributes, attributes_to_delete)
        self.annotation_attributes = annotation_attributes
        self.attributes_to_delete = attributes_to_delete

    @staticmethod
    def _update_header(
        header: VariantHeader,
        annotation_attributes: Sequence[Attribute],
        attributes_to_delete: Sequence[str],
    ) -> VariantHeader:
        """Update a variant file's header with annotation."""
        header.add_meta("pipeline_annotation_tool", "GPF variant annotation.")

        annotation_attr_names = [attr.name for attr in annotation_attributes]

        for info_key in header.info:
            if info_key in attributes_to_delete \
               and info_key not in annotation_attr_names:
                header.info.remove_header(info_key)

        attributes = [
            attr for attr in annotation_attributes
            if attr.name not in header.info
        ]

        for attribute in attributes:
            description = attribute.spec.description \
                if attribute.spec else ""
            description = description.replace("\n", " ")
            description = description.replace('"', '\\"')
            header.info.add(attribute.name, "A", "String", description)

        return header

    @staticmethod
    def _convert_to_string(attr: Any) -> str:
        if isinstance(attr, list):
            attr = ";".join(stringify(a, vcf=True) for a in attr)
        elif isinstance(attr, dict):
            attr = ";".join(
                f"{k}:{v}"
                for k, v in attr.items()
            )
        return stringify(attr, vcf=True) \
            .replace(";", "|") \
            .replace(",", "|") \
            .replace(" ", "_")

    @staticmethod
    def _update_variant(
        vcf_var: VariantRecord,
        allele_annotations: list[dict],
        attributes: Sequence[Attribute],
        attributes_to_delete: Sequence[str],
    ) -> None:
        buffers: list[list] = [[] for _ in attributes]

        for col in attributes_to_delete:
            if col in vcf_var.info:
                del vcf_var.info[col]

        for annotation in allele_annotations:
            for buff, attribute in zip(buffers, attributes, strict=True):
                value = annotation.get(attribute.name)
                if vcf_var.header.info[attribute.name].type == "String":
                    value = _VCFWriter._convert_to_string(value)
                buff.append(value)
        # If the all values for a given attribute are
        # empty (i.e. - "."), then that attribute has no
        # values to be written and will be skipped in the output
        has_value = {
            attr.name: len(list(filter(lambda x: x != ".", buffers[idx])))
            for idx, attr in enumerate(attributes)
        }
        for buff, attribute in zip(buffers, attributes, strict=True):
            if not has_value[attribute.name]:
                continue
            vcf_var.info[attribute.name] = buff

    def __enter__(self) -> _VCFWriter:
        self.output_file = VariantFile(self.path, "w", header=self.header)
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool:
        if exc_type is not None:
            logger.error(
                "exception during writing vcf: %s, %s, %s",
                exc_type, exc_value, exc_tb)

        self.output_file.close()

        return exc_type is None

    def filter(self, data: AnnotationsWithSource) -> None:
        data.source.translate(self.header)
        _VCFWriter._update_variant(
            data.source,
            [annotation.context for annotation in data.annotations],
            self.annotation_attributes,
            self.attributes_to_delete,
        )
        self.output_file.write(data.source)


class _VCFBatchWriter(Filter):
    """A filter that writes batches of variants to a VCF file."""

    def __init__(
        self,
        path: str,
        header: VariantHeader,
        annotation_attributes: Sequence[Attribute],
        attributes_to_delete: Sequence[str],
    ):
        self.writer = _VCFWriter(
            path, header, annotation_attributes, attributes_to_delete)

    def __enter__(self) -> _VCFBatchWriter:
        self.writer.__enter__()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool:
        if exc_type is not None:
            logger.error(
                "exception during writing vcf: %s, %s, %s",
                exc_type, exc_value, traceback.format_tb(exc_tb))

        self.writer.__exit__(exc_type, exc_value, exc_tb)

        return exc_type is None

    def filter(self, data: Sequence[AnnotationsWithSource]) -> None:
        for variant in data:
            self.writer.filter(variant)


def _annotate_vcf(
    output_path: str,
    pipeline_config: RawAnnotatorsConfig,
    grr_definition: dict[str, Any] | None,
    region: Region | None,
    args: dict[str, Any],
) -> None:
    """Annotate a VCF file using a processing pipeline."""
    build_cli_genomic_context(args)
    grr = build_genomic_resource_repository(definition=grr_definition)

    pipeline = build_annotation_pipeline(
        pipeline_config, grr,
        allow_repeated_attributes=args["allow_repeated_attributes"],
        work_dir=Path(args["work_dir"]),
    )

    pipeline = maybe_wrap_reannotation(pipeline, args, grr)
    attributes_to_delete = (
        pipeline.deleted_attributes
        if isinstance(pipeline, ReannotationPipeline) else [])

    _annotate_vcf_helper(
        args["input"],
        pipeline,
        output_path,
        args,
        region=region,
        attributes_to_delete=attributes_to_delete,
    )


def _concat(
    partfile_paths: list[str],
    output_path: str,
    keep_parts: bool,  # noqa: FBT001
) -> None:
    """Concatenate multiple VCF files into a single VCF file *in order*."""
    # Get any header from the partfiles, they should all be equal
    # and usable as a final output header
    header_donor = VariantFile(partfile_paths[0], "r")
    output_file = VariantFile(
        output_path, "w",
        header=header_donor.header.copy(),
    )
    for path in partfile_paths:
        partfile = VariantFile(path, "r")
        for variant in partfile.fetch():
            output_file.write(variant)
        partfile.close()
    output_file.close()
    header_donor.close()

    if not keep_parts:
        for partfile_path in partfile_paths:
            os.remove(partfile_path)


def _build_argument_parser() -> argparse.ArgumentParser:
    """Construct and configure argument parser."""
    parser = argparse.ArgumentParser(
        description="Annotate VCF",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "-n", "--dry-run",
        help="Print the annotation/reannotation plan and exit without "
             "writing any output.",
        action="store_true",
        default=False,
    )
    add_common_annotation_arguments(parser)
    return parser


def _count_vcf_records(input_path: str, limit: int) -> int:
    """Count VCF records, short-circuiting at limit."""
    count = 0
    with VariantFile(input_path, "r") as vcf:
        for _ in vcf.fetch():
            count += 1
            if count >= limit:
                break
    return count


def _tabix_index(filepath: str) -> None:
    tabix_index(filepath, preset="vcf", force=True)


def _tabix_compress(filepath: str, output_path: str | None = None) -> None:
    if output_path is None:
        output_path = f"{filepath}.gz"
    tabix_compress(filepath, output_path, force=True)
    if os.path.exists(filepath):
        os.remove(filepath)


def _add_tasks_plaintext(
    args: dict[str, Any],
    task_graph: TaskGraph,
    output_path: str,
    pipeline_config: RawPipelineConfig,
    grr_definition: dict[str, Any],
) -> None:
    if is_compressed_filename(output_path):
        working_path = strip_compression_suffix(output_path)
        annotate_task = task_graph.create_task(
            "all_variants_annotate",
            _annotate_vcf,
            args=[
                working_path,
                pipeline_config,
                grr_definition,
                None,
                args,
            ],
            deps=[],
            intermediate_output_files=[working_path],
        )
        task_graph.create_task(
            "tabix_compress",
            _tabix_compress,
            args=[working_path, output_path],
            input_files=[working_path],
            output_files=[output_path],
            deps=[annotate_task],
        )
    else:
        task_graph.create_task(
            "all_variants_annotate",
            _annotate_vcf,
            args=[
                output_path,
                pipeline_config,
                grr_definition,
                None,
                args,
            ],
            deps=[],
            output_files=[output_path],
        )


def _add_tasks_tabixed(
    args: dict[str, Any],
    task_graph: TaskGraph,
    output_path: str,
    pipeline_config: RawPipelineConfig,
    grr_definition: dict[str, Any],
) -> None:
    # output_path carries the final compression suffix (.gz/.bgz); annotate
    # into the uncompressed working file, then compress to the final name.
    # Without a suffix, working_path would equal output_path and the compress
    # task would tabix_compress(out, out, force=True), truncating it in place.
    assert is_compressed_filename(output_path), (
        f"_add_tasks_tabixed: output_path must carry a compression suffix, "
        f"got {output_path!r}")
    working_path = strip_compression_suffix(output_path)
    with closing(TabixFile(args["input"])) as pysam_file:
        regions = produce_regions(pysam_file, args["region_size"])
    file_paths = produce_partfile_paths(
        args["input"], regions, args["work_dir"])

    annotation_tasks = []
    for (region, file_path) in zip(regions, file_paths, strict=True):
        annotation_tasks.append(task_graph.create_task(
            f"part-{str(region).replace(':', '-')}",
            _annotate_vcf,
            args=[
                file_path,
                pipeline_config,
                grr_definition,
                region,
                args,
            ],
            deps=[],
            output_files=[file_path],
        ))

    concat_task = task_graph.create_task(
        "concat",
        _concat,
        args=[file_paths, working_path, args["keep_parts"]],
        input_files=file_paths,
        intermediate_output_files=[working_path],
        deps=annotation_tasks,
    )
    compress_task = task_graph.create_task(
        "tabix_compress",
        _tabix_compress,
        args=[working_path, output_path],
        input_files=[working_path],
        output_files=[output_path],
        deps=[concat_task])

    task_graph.create_task(
        "tabix_index",
        _tabix_index,
        args=[output_path],
        input_files=[output_path],
        output_files=[f"{output_path}.tbi"],
        deps=[compress_task])


[docs] def cli(argv: list[str] | None = None) -> None: """Entry point for running the VCF annotation tool.""" if not argv: argv = sys.argv[1:] arg_parser = _build_argument_parser() args = vars(arg_parser.parse_args(argv)) if args.get("version"): print(f"GAIn version: {__version__}") sys.exit(0) VerbosityConfiguration.set(args) args = handle_default_args(args) # Run inside work_dir so that intermediate files created by worker # processes (e.g. htslib downloading a remote tabix .tbi index over # http) land in work_dir instead of the launch directory. Workers # spawned by process_graph inherit this working directory. with chdir(args["work_dir"]): context = build_cli_genomic_context(args) pipeline = get_pipeline_from_context(context) grr = get_grr_from_context(context) assert grr.definition is not None check_resource_locality( pipeline, lambda limit: _count_vcf_records(args["input"], limit), allow_remote=args["allow_remote_resources"], ) cache_pipeline_resources(grr, pipeline) if args.get("reannotate") or args.get("dry_run"): emit_annotation_plan(args, pipeline, grr) if args.get("dry_run"): pipeline.close() maybe_remove_work_dir(args, result=True) return output_path = args["output"] region_size = args["region_size"] task_graph = TaskGraph() if tabix_index_filename(args["input"]) and region_size > 0: _add_tasks_tabixed( args, task_graph, output_path, pipeline.raw, grr.definition, ) else: logger.info( "input %s cannot be split into genomic regions; " "forcing sequential execution (-j 1)", args["input"]) args["jobs"] = 1 _add_tasks_plaintext( args, task_graph, output_path, pipeline.raw, grr.definition, ) add_input_files_to_task_graph(args, task_graph) result = TaskGraphCli.process_graph(task_graph, **args) pipeline.close() gc.collect() maybe_remove_work_dir(args, result=result)
def _annotate_vcf_helper( input_path: str, pipeline: AnnotationPipeline, output_path: str, args: dict[str, Any], *, region: Region | None = None, attributes_to_delete: Sequence[str] | None = None, ) -> None: """Annotate a columns file using a processing pipeline.""" annotation_attributes = [ attr for attr in pipeline.get_attributes() if not attr.internal ] attributes_to_delete = attributes_to_delete or [] batch_size = cast(int, args.get("batch_size", 0)) source: Source filters: list[Filter] = [] if batch_size <= 0: source = _VCFSource(input_path) header = source.header.copy() filters.extend([ AnnotationPipelineAnnotatablesFilter(pipeline), _VCFWriter(output_path, header, annotation_attributes, attributes_to_delete), ]) else: source = _VCFBatchSource( input_path, batch_size=batch_size) header = source.source.header.copy() filters.extend([ AnnotationPipelineAnnotatablesBatchFilter(pipeline), _VCFBatchWriter(output_path, header, annotation_attributes, attributes_to_delete), ]) with PipelineProcessor(source, filters) as processor: processor.process_region(region)
[docs] def annotate_vcf( input_path: str, pipeline: AnnotationPipeline, output_path: str, args: dict[str, Any], *, region: Region | None = None, attributes_to_delete: Sequence[str] | None = None, ) -> None: """Annotate a columns file using a processing pipeline.""" temp_output_path = output_path if is_compressed_filename(output_path): temp_output_path = strip_compression_suffix(output_path) _annotate_vcf_helper( input_path, pipeline, temp_output_path, args, region=region, attributes_to_delete=attributes_to_delete, ) if is_compressed_filename(output_path): # honor the explicit compression suffix (.gz/.bgz) _tabix_compress(temp_output_path, output_path) elif is_compressed_filename(input_path): # uncompressed output name + compressed input: default to .gz _tabix_compress(temp_output_path)