Source code for dae.task_graph.executor

from __future__ import annotations

import logging
import sys
import time
import traceback
from abc import abstractmethod
from collections import deque
from collections.abc import Generator, Iterator
from copy import copy
from types import TracebackType
from typing import Any, Callable, cast

from dask.distributed import Client, Future

from dae.task_graph.cache import CacheRecordType, NoTaskCache, TaskCache
from dae.task_graph.graph import Task, TaskGraph
from dae.task_graph.logging import (
    configure_task_logging,
    ensure_log_dir,
    safe_task_id,
)
from dae.utils.verbosity_configuration import VerbosityConfiguration

logger = logging.getLogger(__name__)


class TaskGraphExecutor:
    """Class that executes a task graph."""

    @abstractmethod
    def execute(self, task_graph: TaskGraph) -> Iterator[tuple[Task, Any]]:
        """Start executing the graph.

        Return an iterator that yields the tasks in the graph after they are
        executed. This is not necessarily in DFS or BFS order and is not even
        the order in which the tasks are executed. The only guarantee is that
        when a task is returned, its execution has already finished.
        """

    def __enter__(self) -> TaskGraphExecutor:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.close()

    @abstractmethod
    def close(self) -> None:
        """Clean up any resources used by the executor."""

NO_TASK_CACHE = NoTaskCache()

class AbstractTaskGraphExecutor(TaskGraphExecutor):
    """Executor that walks the graph in order that satisfies dependencies."""

    def __init__(self, task_cache: TaskCache = NO_TASK_CACHE):
        super().__init__()
        self._task_cache = task_cache
        self._executing = False

    @staticmethod
    def _exec(
        task_func: Callable,
        args: list,
        _deps: list,
        params: dict[str, Any],
    ) -> Any:
        verbose = params.get("verbose")
        if verbose is None:
            # Don't use .get's default in case of a Box
            verbose = 0
        VerbosityConfiguration.set_verbosity(verbose)
        task_id = params["task_id"]
        log_dir = params.get("log_dir", ".")

        root_logger = logging.getLogger()
        handler = configure_task_logging(log_dir, task_id, verbose)
        root_logger.addHandler(handler)

        task_logger = logging.getLogger("task_executor")
        task_logger.info("task <%s> started", task_id)
        start = time.time()

        result = task_func(*args)

        elapsed = time.time() - start
        task_logger.info("task <%s> finished in %0.2fsec", task_id, elapsed)
        root_logger.removeHandler(handler)
        handler.close()
        return result

    def execute(self, task_graph: TaskGraph) -> Iterator[tuple[Task, Any]]:
        assert not self._executing, \
            "Cannot execute a new graph while an old one is still running."
        self._check_for_cyclic_deps(task_graph)
        self._executing = True

        already_computed_tasks = {}
        for task_node, record in self._task_cache.load(task_graph):
            if record.type == CacheRecordType.COMPUTED:
                already_computed_tasks[task_node] = record.result

        for task_node in self._in_exec_order(task_graph):
            if task_node in already_computed_tasks:
                task_result = already_computed_tasks[task_node]
                self._set_task_result(task_node, task_result)
            else:
                self._queue_task(task_node)

        return self._yield_task_results(already_computed_tasks)

    def _yield_task_results(
        self, already_computed_tasks: dict[Task, Any],
    ) -> Iterator[tuple[Task, Any]]:
        for task_node, result in already_computed_tasks.items():
            yield task_node, result
        for task_node, result in self._await_tasks():
            is_error = isinstance(result, BaseException)
            self._task_cache.cache(
                task_node, is_error=is_error, result=result,
            )
            yield task_node, result
        self._executing = False

    @abstractmethod
    def _queue_task(self, task_node: Task) -> None:
        """Put the task on the execution queue."""

    @abstractmethod
    def _await_tasks(self) -> Iterator[tuple[Task, Any]]:
        """Yield enqueued tasks as soon as they finish."""

    @abstractmethod
    def _set_task_result(self, task: Task, result: Any) -> None:
        """Set a precomputed result for a task."""

    @staticmethod
    def _in_exec_order(
        task_graph: TaskGraph,
    ) -> Generator[Task, None, None]:
        visited: set[Task] = set()
        for node in task_graph.tasks:
            yield from AbstractTaskGraphExecutor._node_in_exec_order(
                node, visited)

    @staticmethod
    def _node_in_exec_order(
        node: Task,
        visited: set[Task],
    ) -> Generator[Task, None, None]:
        if node in visited:
            return
        visited.add(node)
        for dep in node.deps:
            yield from AbstractTaskGraphExecutor._node_in_exec_order(
                dep, visited)
        yield node

    @staticmethod
    def _check_for_cyclic_deps(task_graph: TaskGraph) -> None:
        visited: set[Task] = set()
        stack: list[Task] = []
        for node in task_graph.tasks:
            if node not in visited:
                cycle = AbstractTaskGraphExecutor._find_cycle(
                    node, visited, stack)
                if cycle is not None:
                    raise ValueError(f"Cyclic dependency {cycle}")

    @staticmethod
    def _find_cycle(
        node: Task,
        visited: set[Task],
        stack: list[Task],
    ) -> list[Task] | None:
        visited.add(node)
        stack.append(node)

        for dep in node.deps:
            if dep not in visited:
                return AbstractTaskGraphExecutor._find_cycle(
                    dep, visited, stack)
            if dep in stack:
                return copy(stack)

        stack.pop()
        return None

    def close(self) -> None:
        pass


class SequentialExecutor(AbstractTaskGraphExecutor):
    """A task graph executor that executes tasks in sequential order."""

    def __init__(self, task_cache: TaskCache = NO_TASK_CACHE, **kwargs: Any):
        super().__init__(task_cache)
        self._task_queue: list[Task] = []
        self._task2result: dict[Task, Any] = {}

        log_dir = ensure_log_dir(**kwargs)
        self._params = copy(kwargs)
        self._params["log_dir"] = log_dir

    def _queue_task(self, task_node: Task) -> None:
        self._task_queue.append(task_node)

    def _await_tasks(self) -> Generator[tuple[Task, Any], None, None]:
        finished_tasks = 0
        initial_task_count = len(self._task_queue)
        for task_node in self._task_queue:
            all_deps_satisfied = all(
                d in self._task2result for d in task_node.deps
            )
            if not all_deps_satisfied:
                # some of the dependencies were errors and didn't run
                logger.info(
                    "Skipping execution of task(id=%s) because one or more "
                    "of its dependencies failed with an error",
                    task_node.task_id)
                continue

            # handle tasks that use the output of other tasks
            args = [
                self._task2result[arg] if isinstance(arg, Task) else arg
                for arg in task_node.args
            ]
            is_error = False
            params = copy(self._params)
            params["task_id"] = safe_task_id(task_node.task_id)
            try:
                result = self._exec(task_node.func, args, [], params)
            except Exception as exp:  # noqa: BLE001 pylint: disable=broad-except
                result = exp
                is_error = True

            finished_tasks += 1
            logger.debug("clean up task %s", task_node)
            logger.info(
                "finished %s/%s", finished_tasks, initial_task_count)

            if not is_error:
                self._task2result[task_node] = result
            yield task_node, result

        # all tasks have already executed. Let's clean the state.
        self._task_queue = []
        self._task2result = {}

    def _set_task_result(self, task: Task, result: Any) -> None:
        self._task2result[task] = result

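
# Illustrative usage sketch, not part of the original module. It assumes the
# TaskGraph.create_task(task_id, func, args, deps) signature from
# dae.task_graph.graph; task arguments that are themselves Task objects are
# replaced with the results of those tasks before the function is called
# (see _await_tasks above).
#
#     from dae.task_graph.graph import TaskGraph
#
#     graph = TaskGraph()
#     first = graph.create_task("first", lambda: 1, [], [])
#     graph.create_task("second", lambda x: x + 1, [first], [first])
#
#     with SequentialExecutor() as executor:
#         for task, result in executor.execute(graph):
#             print(task.task_id, result)
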

class DaskExecutor(AbstractTaskGraphExecutor):
    """Execute tasks in parallel using Dask to do the heavy lifting."""

    def __init__(
        self, client: Client,
        task_cache: TaskCache = NO_TASK_CACHE,
        **kwargs: Any,
    ):
        super().__init__(task_cache)
        self._client = client
        self._task2future: dict[Task, Future] = {}
        self._future_key2task: dict[str, Task] = {}
        self._task2result: dict[Task, Any] = {}
        self._task_queue: deque[Task] = deque()

        log_dir = ensure_log_dir(**kwargs)
        self._params = copy(kwargs)
        self._params["log_dir"] = log_dir

    def _queue_task(self, task_node: Task) -> None:
        self._task_queue.append(task_node)

    def _submit_task(self, task_node: Task) -> Future:
        deps = []
        for dep in task_node.deps:
            future = self._task2future.get(dep)
            if future:
                deps.append(future)
            else:
                assert dep in self._task2result

        # handle tasks that use the output of other tasks
        args = []
        for arg in task_node.args:
            if isinstance(arg, Task):
                value = self._get_future_or_result(arg)
            else:
                value = arg
            args.append(value)

        params = copy(self._params)
        params["task_id"] = safe_task_id(task_node.task_id)
        future = self._client.submit(
            self._exec, task_node.func, args, deps, params, pure=False)
        if future is None:
            raise ValueError(
                f"unexpected dask executor return None: {task_node}, {args}, "
                f"{deps}, {params}")
        self._task2future[task_node] = future
        self._future_key2task[future.key] = task_node
        return future

    def _get_future_or_result(self, task: Task) -> Any:
        future = self._task2future.get(task)
        return future or self._task2result[task]

    MIN_QUEUE_SIZE = 700

    def _queue_size(self) -> int:
        n_workers = cast(int, sum(self._client.ncores().values()))
        return max(self.MIN_QUEUE_SIZE, 2 * n_workers)

    def _schedule_tasks(self, currently_running: set[Future]) -> set[Future]:
        while self._task_queue \
                and len(currently_running) < self._queue_size():
            future = self._submit_task(self._task_queue.popleft())
            currently_running.add(future)
        return currently_running

    def _await_tasks(self) -> Generator[tuple[Task, Any], None, None]:
        # pylint: disable=import-outside-toplevel
        from dask.distributed import wait

        not_completed: set = set()
        completed = set()
        initial_task_count = len(self._task_queue)
        finished_tasks = 0

        not_completed = self._schedule_tasks(not_completed)
        while not_completed:
            completed, not_completed = \
                wait(not_completed, return_when="FIRST_COMPLETED")
            for future in completed:
                try:
                    result = future.result()
                except Exception as exp:  # noqa: BLE001 pylint: disable=broad-except
                    result = exp
                task = self._future_key2task[future.key]
                self._task2result[task] = result
                yield task, result
                finished_tasks += 1

                # del ref to future in order to make dask gc its resources
                logger.debug("clean up task %s", task)
                logger.info(
                    "finished %s/%s", finished_tasks, initial_task_count)
                del self._task2future[task]

            not_completed = self._schedule_tasks(not_completed)

        # clean up
        if len(self._task2future) > 0:
            logger.error("[BUG] Dask Executor's future q is not empty.")
        if len(self._task_queue) > 0:
            logger.error("[BUG] Dask Executor's task q is not empty.")
        self._task2future = {}
        self._future_key2task = {}
        self._task_queue = deque()
        self._task2result = {}

    def _set_task_result(self, task: Task, result: Any) -> None:
        self._task2result[task] = result

    def close(self) -> None:
        cluster = self._client.cluster
        self._client.shutdown()
        if cluster is not None:
            cluster.close()

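
# Illustrative sketch, not part of the original module: running the same graph
# on a local Dask cluster. Client and LocalCluster are standard
# dask.distributed APIs; note that DaskExecutor.close() shuts down the client
# and, if present, the cluster it was created from.
#
#     from dask.distributed import Client, LocalCluster
#
#     cluster = LocalCluster(n_workers=4)
#     client = Client(cluster)
#     with DaskExecutor(client) as executor:
#         for task, result in executor.execute(graph):
#             print(task.task_id, result)
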

def task_graph_status(
    task_graph: TaskGraph, task_cache: TaskCache,
    verbose: int | None) -> bool:
    """Show the status of each task from the task graph."""
    id_col_len = max(len(t.task_id) for t in task_graph.tasks)
    id_col_len = min(120, max(50, id_col_len))
    columns = ["TaskID", "Status"]
    print(f"{columns[0]:{id_col_len}s} {columns[1]}")

    task2record = dict(task_cache.load(task_graph))
    for task in task_graph.tasks:
        record = task2record[task]
        status = record.type.name

        msg = f"{task.task_id:{id_col_len}s} {status}"
        is_error = record.type == CacheRecordType.ERROR
        if is_error and not verbose:
            msg += " (-v to see exception)"
        print(msg)

        if is_error and verbose:
            traceback.print_exception(
                None, value=record.error, tb=record.error.__traceback__,
                file=sys.stdout,
            )
    return True


def task_graph_run(
    task_graph: TaskGraph,
    executor: TaskGraphExecutor | None = None, *,
    keep_going: bool = False,
) -> bool:
    """Execute (run) the task_graph with the given executor."""
    if executor is None:
        executor = SequentialExecutor()

    tasks_iter = executor.execute(task_graph)
    no_errors = True
    for task, result_or_error in tasks_iter:
        if isinstance(result_or_error, Exception):
            if keep_going:
                print(f"Task {task.task_id} failed with:", file=sys.stderr)
                traceback.print_exception(
                    None, value=result_or_error,
                    tb=result_or_error.__traceback__,
                    file=sys.stdout,
                )
                no_errors = False
            else:
                raise result_or_error
    return no_errors

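
# Illustrative sketch, not part of the original module: the convenience entry
# point used instead of calling execute() directly. With keep_going=True,
# failed tasks are reported and the function returns False; with the default
# keep_going=False the first error is re-raised.
#
#     success = task_graph_run(graph, SequentialExecutor(), keep_going=True)
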

def task_graph_run_with_results(
    task_graph: TaskGraph, executor: TaskGraphExecutor,
) -> Generator[Any, None, None]:
    """Run a task graph, yielding the results from each task."""
    tasks_iter = executor.execute(task_graph)
    for _, result_or_error in tasks_iter:
        if isinstance(result_or_error, Exception):
            raise result_or_error
        yield result_or_error

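
# Illustrative sketch, not part of the original module: collecting task
# results while propagating the first failure as an exception.
#
#     results = list(
#         task_graph_run_with_results(graph, SequentialExecutor()))
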

def task_graph_all_done(task_graph: TaskGraph, task_cache: TaskCache) -> bool:
    """Check if the task graph is fully executed.

    When all tasks are already computed, the function returns True.
    If there are tasks that need to run, the function returns False.
    """
    # pylint: disable=protected-access
    AbstractTaskGraphExecutor._check_for_cyclic_deps(task_graph)  # noqa: SLF001

    already_computed_tasks = {}
    for task_node, record in task_cache.load(task_graph):
        if record.type == CacheRecordType.COMPUTED:
            already_computed_tasks[task_node] = record.result

    for task_node in AbstractTaskGraphExecutor._in_exec_order(task_graph):  # noqa: SLF001
        if task_node not in already_computed_tasks:
            return False
    return True