from __future__ import annotations
import logging
from collections.abc import Iterable
from copy import copy
from dataclasses import dataclass
from typing import Any, Callable
logger = logging.getLogger(__name__)
@dataclass(eq=False, frozen=True)
class Task:
    """A single node of a TaskGraph, bundled with the tasks it depends on.

    The dataclass is frozen, so fields cannot be reassigned after creation.
    Because ``eq=False`` no ``__eq__`` is generated: two tasks compare equal
    only when they are the same object.
    """

    # Identifier of the task.
    task_id: str
    # Callable attached to the task.
    func: Callable
    # Arguments stored alongside ``func``.
    args: list[Any]
    # Tasks this task depends on.
    deps: list[Task]
    # File names associated with this task.
    input_files: list[str]

    def __repr__(self) -> str:
        """Return a compact summary: the id, dependency ids and input files."""
        dep_ids = [dep.task_id for dep in self.deps]
        return (
            f"Task(id={self.task_id}, deps={','.join(dep_ids)}, "
            f"in_files={','.join(self.input_files)})"
        )
class TaskGraph:
    """A collection of uniquely identified tasks linked by dependencies."""

    def __init__(self) -> None:
        """Create an empty graph."""
        self.tasks: list[Task] = []
        self.input_files: list[str] = []
        # Ids of every task in ``self.tasks``, kept for O(1) duplicate checks.
        self._task_ids: set[str] = set()

    def create_task(self, task_id: str, func: Callable[..., Any], args: list,
                    deps: list[Task],
                    input_files: list[str] | None = None) -> Task:
        """Create a new task and add it to the graph.

        :param task_id: Unique id of the task. Ids longer than 200 characters
            are truncated to 200 (a warning is logged) before the uniqueness
            check.
        :param func: Function to execute
        :param args: Arguments to that function. Every argument which is
            itself a Task is also recorded as a dependency.
        :param deps: Tasks on which the new task depends. The list is copied,
            the caller's list is left untouched.
        :param input_files: Input files stored on the created task
            (defaults to an empty list)
        :return: The newly created task
        :raises ValueError: if the graph already contains a task with the
            same (possibly truncated) id
        """
        if len(task_id) > 200:
            logger.warning("task id is too long %s: %s", len(task_id), task_id)
            logger.warning("truncating task id to 200 symbols")
            task_id = task_id[:200]
        if task_id in self._task_ids:
            raise ValueError(f"Task with id='{task_id}' already in graph")

        # Tasks that use the output of other tasks as input should have
        # those other tasks as dependencies.
        deps = copy(deps)
        for arg in args:
            if isinstance(arg, Task):
                deps.append(arg)

        node = Task(task_id, func, args, deps, input_files or [])
        self.tasks.append(node)
        self._task_ids.add(task_id)
        return node

    def prune(self, ids_to_keep: Iterable[str]) -> TaskGraph:
        """Return a new graph holding only the requested tasks and their deps.

        The current graph is not modified. Kept tasks appear in the same
        relative order as in this graph, and the task objects themselves are
        shared between both graphs.

        :param ids_to_keep: Ids of the tasks to keep. Their direct and
            transitive dependencies are kept as well.
        :return: The pruned graph
        :raises KeyError: if any id in ids_to_keep is not in the graph
        """
        wanted = set(ids_to_keep)
        ids_not_found = wanted - self._task_ids
        if ids_not_found:
            raise KeyError(ids_not_found)

        tasks_to_keep: set[str] = set()
        for task in self.tasks:
            if task.task_id in wanted:
                tasks_to_keep.add(task.task_id)
                self._add_task_deps(task, tasks_to_keep)

        res = TaskGraph()
        res.tasks = [t for t in self.tasks if t.task_id in tasks_to_keep]
        # Copy so that later changes to one graph's input files do not leak
        # into the other graph.
        res.input_files = list(self.input_files)
        res._task_ids |= tasks_to_keep  # noqa: SLF001
        return res

    @staticmethod
    def _add_task_deps(task: Task, task_set: set[str]) -> None:
        """Add the ids of all direct and transitive deps of task to task_set.

        A dependency whose id is already in task_set is not visited again.
        The walk uses an explicit stack rather than recursion so that very
        deep dependency chains cannot exceed the interpreter recursion limit.
        """
        pending = [task]
        while pending:
            for dep in pending.pop().deps:
                if dep.task_id not in task_set:
                    task_set.add(dep.task_id)
                    pending.append(dep)