Source code for dae.task_graph.executor

from __future__ import annotations

import logging
import multiprocessing as mp
import os
import pickle  # noqa: S403
import sys
import time
import traceback
from abc import abstractmethod
from collections import defaultdict, deque
from collections.abc import Callable, Generator, Iterator, Sequence
from copy import copy
from types import TracebackType
from typing import Any, cast

import fsspec
import networkx
import psutil
from dask.distributed import Client, Future

from dae.task_graph.cache import CacheRecordType, NoTaskCache, TaskCache
from dae.task_graph.graph import Task, TaskGraph
from dae.task_graph.logging import (
    configure_task_logging,
    ensure_log_dir,
    safe_task_id,
)

logger = logging.getLogger(__name__)


[docs] class TaskGraphExecutor: """Class that executes a task graph."""
[docs] @abstractmethod def execute(self, task_graph: TaskGraph) -> Iterator[tuple[Task, Any]]: """Start executing the graph. Return an iterator that yields the task in the graph after they are executed. This is not nessessarily in DFS or BFS order. This is not even the order in which these tasks are executed. The only garantee is that when a task is returned its executions is already finished. """
def __enter__(self) -> TaskGraphExecutor: return self def __exit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, exc_tb: TracebackType | None, ) -> bool: self.close() return exc_type is None
[docs] @abstractmethod def close(self) -> None: """Clean-up any resources used by the executor."""
NO_TASK_CACHE = NoTaskCache()
[docs] class AbstractTaskGraphExecutor(TaskGraphExecutor): """Executor that walks the graph in order that satisfies dependancies.""" def __init__(self, task_cache: TaskCache = NO_TASK_CACHE, **kwargs: Any): super().__init__() self._task_cache = task_cache self._executing = False self._task_queue: dict[str, Task] = {} self._task_dependants: dict[str, set[str]] = defaultdict(set) log_dir = ensure_log_dir(**kwargs) self._params = copy(kwargs) self._params["task_log_dir"] = log_dir def _queue_size(self) -> int: return 1 @staticmethod def _ready_to_run(task: Task) -> bool: return not task.deps def _select_tasks_to_run( self, limit: int, ) -> list[Task]: started = time.time() selected_tasks: list[Task] = [] di_graph = AbstractTaskGraphExecutor._build_di_graph(self._task_queue) for task_id in networkx.topological_sort(di_graph): if task_id not in self._task_queue: break task = self._task_queue[task_id] elapsed = time.time() - started if elapsed > 1.0: logger.debug( "selecting tasks to run took too long (%.2f sec), " "stopping search", elapsed) break if not self._ready_to_run(task): break selected_tasks.append(task) if len(selected_tasks) >= limit: break elapsed = time.time() - started logger.debug( "selecting %s tasks to run took %.2f sec", len(selected_tasks), elapsed) return selected_tasks def _prune_dependants(self, task_id: str) -> None: started = time.time() dependents = self._task_dependants.get(task_id, set()) for dep_id in dependents: logger.debug("pruning dependent task %s", dep_id) if dep_id not in self._task_queue: continue del self._task_queue[dep_id] elapsed = time.time() - started logger.debug( "pruning dependents of task %s took %.2f sec", task_id, elapsed) @staticmethod def _exec_internal( task_func: Callable, args: list, _deps: list, params: dict[str, Any], ) -> Any: verbose = params.get("verbose") if verbose is None: # Dont use .get default in case of a Box verbose = 0 task_id = params["task_id"] log_dir = params.get("task_log_dir", ".") root_logger = logging.getLogger() handler = configure_task_logging(log_dir, task_id, verbose) root_logger.addHandler(handler) task_logger = logging.getLogger("task_executor") task_logger.warning("task logging verbosity level: %d", verbose) task_logger.info("task <%s> started", task_id) start = time.time() process = psutil.Process(os.getpid()) start_memory_mb = process.memory_info().rss / (1024 * 1024) task_logger.info( "worker process memory usage: %.2f MB", start_memory_mb) result = task_func(*args) elapsed = time.time() - start task_logger.info("task <%s> finished in %0.2fsec", task_id, elapsed) finish_memory_mb = process.memory_info().rss / (1024 * 1024) task_logger.info( "worker process memory usage: %.2f MB; change: %+0.2f MB", finish_memory_mb, finish_memory_mb - start_memory_mb) root_logger.removeHandler(handler) handler.close() return result @staticmethod def _exec_forked( task_func: Callable, args: list, deps: list, params: dict[str, Any], ) -> None: result_fn = AbstractTaskGraphExecutor._result_fn(params) try: result = AbstractTaskGraphExecutor._exec_internal( task_func, args, deps, params, ) except Exception as exp: # noqa: BLE001 # pylint: disable=broad-except result = exp try: with fsspec.open(result_fn, "wb") as out: pickle.dump(result, out) # pyright: ignore except Exception: # pylint: disable=broad-except logger.exception( "cannot write result for task %s. Ignoring and continuing.", result_fn, ) @staticmethod def _result_fn(params: dict[str, Any]) -> str: task_id = params["task_id"] status_dir = params.get("task_status_dir", ".") return os.path.join(status_dir, f"{task_id}.result") @staticmethod def _exec( task_func: Callable, args: list, _deps: list, params: dict[str, Any], ) -> Any: fork_tasks = params.get("fork_tasks", False) if not fork_tasks: return AbstractTaskGraphExecutor._exec_internal( task_func, args, _deps, params, ) mp.current_process()._config[ # type: ignore # noqa: SLF001 "daemon"] = False p = mp.Process( target=AbstractTaskGraphExecutor._exec_forked, args=(task_func, args, _deps, params), ) p.start() p.join() result_fn = AbstractTaskGraphExecutor._result_fn(params) try: with fsspec.open(result_fn, "rb") as infile: result = pickle.load(infile) # pyright: ignore except Exception: # pylint: disable=broad-except logger.exception( "cannot write result for task %s. Ignoring and continuing.", result_fn, ) result = None return result
[docs] def execute(self, task_graph: TaskGraph) -> Iterator[tuple[Task, Any]]: assert not self._executing, \ "Cannot execute a new graph while an old one is still running." self._check_for_cyclic_deps(task_graph) self._executing = True completed_tasks: dict[Task, Any] = {} for task in self._in_exec_order(task_graph): self._queue_task(task) for task, record in self._task_cache.load(task_graph): if record.type == CacheRecordType.COMPUTED: result = record.result completed_tasks[task] = result task_dependants = self._task_dependants.get(task.task_id, set()) for dep_id in task_dependants: if dep_id not in self._task_queue: continue dep_task = self._task_queue[dep_id] dep_task = self._reconfigure_task_node_deps( dep_task, task, result) self._task_queue[dep_id] = dep_task del self._task_queue[task.task_id] return self._yield_task_results(completed_tasks)
def _yield_task_results( self, already_computed_tasks: dict[Task, Any], ) -> Iterator[tuple[Task, Any]]: for task_node, result in already_computed_tasks.items(): yield task_node, result for task_node, result in self._await_tasks(): is_error = isinstance(result, BaseException) self._task_cache.cache( task_node, is_error=is_error, result=result, ) yield task_node, result self._executing = False def _collect_deps( self, task_node: Task, result: set[str] | None = None, ) -> set[str]: if result is None: result = set() for dep in task_node.deps: result.add(dep.task_id) self._collect_deps(dep, result) return result def _queue_task(self, task_node: Task) -> None: self._task_queue[task_node.task_id] = task_node for dep_id in self._collect_deps(task_node): self._task_dependants[dep_id].add(task_node.task_id) @abstractmethod def _await_tasks(self) -> Iterator[tuple[Task, Any]]: """Yield enqueued tasks as soon as they finish.""" @staticmethod def _in_exec_order( task_graph: TaskGraph, ) -> Sequence[Task]: return list(AbstractTaskGraphExecutor._toplogical_order(task_graph)) @staticmethod def _build_di_graph( tasks: dict[str, Task], ) -> networkx.DiGraph: di_graph = networkx.DiGraph() nodes: list[str] = list(tasks.keys()) di_graph.add_nodes_from(sorted(nodes)) for task in tasks.values(): for dep in task.deps: di_graph.add_edge(dep.task_id, task.task_id) return di_graph @staticmethod def _toplogical_order( task_graph: TaskGraph, ) -> Generator[Task, None, None]: queue: dict[str, Task] = {} for task in task_graph.tasks: if task.task_id not in queue: queue[task.task_id] = task graph = AbstractTaskGraphExecutor._build_di_graph(queue) for task_id in networkx.topological_sort(graph): task = queue.pop(task_id) yield task @staticmethod def _check_for_cyclic_deps(task_graph: TaskGraph) -> None: visited: set[Task] = set() stack: list[Task] = [] for node in task_graph.tasks: if node not in visited: cycle = AbstractTaskGraphExecutor._find_cycle( node, visited, stack) if cycle is not None: raise ValueError(f"Cyclic dependency {cycle}") @staticmethod def _find_cycle( node: Task, visited: set[Task], stack: list[Task], ) -> list[Task] | None: visited.add(node) stack.append(node) for dep in node.deps: if dep not in visited: return AbstractTaskGraphExecutor._find_cycle( dep, visited, stack) if dep in stack: return copy(stack) stack.pop() return None
[docs] def close(self) -> None: pass
def _reconfigure_task_node_deps( self, task: Task, completed_task: Task, completed_result: Any, ) -> Task: args = [] for arg in task.args: if isinstance(arg, Task) and arg.task_id == completed_task.task_id: arg = completed_result args.append(arg) deps = [] for dep in task.deps: if dep.task_id == completed_task.task_id: continue deps.append(dep) return Task( task.task_id, task.func, args, deps, task.input_files, ) def _process_completed_task( self, task: Task, result: Any, ) -> None: if isinstance(result, BaseException): self._prune_dependants(task.task_id) return start = time.time() task_dependants = self._task_dependants.get(task.task_id, set()) logger.debug("processing %d dependants", len(task_dependants)) for dep_id in task_dependants: if dep_id not in self._task_queue: continue dep_task = self._task_queue[dep_id] dep_task = self._reconfigure_task_node_deps(dep_task, task, result) self._task_queue[dep_id] = dep_task if task.task_id in self._task_dependants: del self._task_dependants[task.task_id] logger.debug("task_dependents size: %d", len(self._task_dependants)) elapsed = time.time() - start logger.debug("processing completed task took %.2f sec", elapsed)
[docs] class SequentialExecutor(AbstractTaskGraphExecutor): """A Task Graph Executor that executes task in sequential order.""" def _await_tasks(self) -> Generator[tuple[Task, Any], None, None]: finished_tasks = 0 initial_task_count = len(self._task_queue) while self._task_queue: selected_tasks = self._select_tasks_to_run(1) for task in selected_tasks: # handle tasks that use the output of other tasks params = copy(self._params) task_id = safe_task_id(task.task_id) params["task_id"] = task_id try: result = self._exec(task.func, task.args, [], params) except Exception as exp: # noqa: BLE001 # pylint: disable=broad-except result = exp self._process_completed_task(task, result) finished_tasks += 1 logger.debug("clean up task %s", task) logger.info( "finished %s/%s", finished_tasks, initial_task_count) del self._task_queue[task.task_id] yield task, result # all tasks have already executed. Let's clean the state. assert len(self._task_queue) == 0
[docs] class DaskExecutor(AbstractTaskGraphExecutor): """Execute tasks in parallel using Dask to do the heavy lifting.""" def __init__( self, client: Client, task_cache: TaskCache = NO_TASK_CACHE, **kwargs: Any, ): super().__init__(task_cache, **kwargs) self._client = client self._task2future: dict[Task, Future] = {} self._future2task: dict[str, Task] = {} def _submit_task(self, task: Task) -> Future: assert len(task.deps) == 0 assert not any(isinstance(arg, Task) for arg in task.args), \ "Task has no dependencies to wait for." params = copy(self._params) task_id = safe_task_id(task.task_id) params["task_id"] = task_id future = self._client.submit( self._exec, task.func, task.args, [], params, key=task_id, pure=False, ) if future is None: raise ValueError( f"unexpected dask executor return None: {task}, {task.args}, " f"{params}") assert future is not None self._task2future[task] = future self._future2task[str(future.key)] = task return future MIN_QUEUE_SIZE = 400 def _queue_size(self) -> int: n_workers = cast( int, sum(self._client.ncores().values())) return max(self.MIN_QUEUE_SIZE, 4 * n_workers) def _available_slots(self, currently_running: int) -> int: return self._queue_size() - currently_running def _schedule_tasks(self, currently_running: set[Future]) -> set[Future]: start = time.time() tasks_to_run = self._select_tasks_to_run( self._available_slots(len(currently_running))) logger.debug("scheduling %d tasks", len(tasks_to_run)) scheduled = 0 for task in tasks_to_run: del self._task_queue[task.task_id] future = self._submit_task(task) currently_running.add(future) elapsed = time.time() - start scheduled += 1 if elapsed > 2.0: logger.debug( "scheduling took too long (%.2f sec), stopping", elapsed) break elapsed = time.time() - start logger.debug( "scheduling %d tasks took %.2f sec", scheduled, elapsed) logger.debug("currently running %d tasks", len(currently_running)) return currently_running def _await_tasks(self) -> Generator[tuple[Task, Any], None, None]: # pylint: disable=import-outside-toplevel from dask.distributed import wait not_completed: set[Future] = set() completed: deque[Future] = deque() initial_task_count = len(self._task_queue) finished_tasks = 0 process = psutil.Process(os.getpid()) current_memory_mb = process.memory_info().rss / (1024 * 1024) logger.info( "executor memory usage: %.2f MB", current_memory_mb) not_completed = self._schedule_tasks(not_completed) while not_completed or self._task_queue: if len(completed) < self._queue_size() // 2: not_completed = self._schedule_tasks(not_completed) else: logger.debug( "too many completed tasks (%d), processing them first", len(completed)) if not_completed: try: new_completed, not_completed = wait( not_completed, return_when="FIRST_COMPLETED", timeout=0.1, ) except TimeoutError: new_completed = [] else: new_completed = set() for future in new_completed: completed.append(future) start = time.time() processed: list[tuple[Task, Any]] = [] logger.debug("going to process %d completed tasks", len(completed)) while completed: future = completed.popleft() task = self._future2task[str(future.key)] result_start = time.time() try: result = future.result() except Exception as ex: # noqa: BLE001 # pylint: disable=broad-except result = ex result_elapsed = time.time() - result_start logger.debug( "retrieving result for task %s took %.2f sec", task.task_id, result_elapsed) self._process_completed_task(task, result) finished_tasks += 1 processed.append((task, result)) logger.info( "finished %s/%s", finished_tasks, initial_task_count) # del ref to future in order to make dask gc its resources del self._task2future[task] del self._future2task[future.key] elapsed = time.time() - start if elapsed > 2.0: logger.debug( "processing completed tasks took too long " "(%.2f sec), stopping", elapsed) break logger.debug("processed %d completed tasks", len(processed)) start = time.time() yield from processed elapsed = time.time() - start logger.debug( "yielding %d completed tasks took %.2f sec", len(processed), elapsed) # clean up assert len(self._task2future) == 0, \ "[BUG] Dask Executor's future queue is not empty." assert len(self._task_queue) == 0, \ "[BUG] Dask Executor's task queue is not empty." self._task2future = {} self._future2task = {} assert len(self._task_queue) == 0
[docs] def close(self) -> None: self._client.close()
[docs] def task_graph_run( task_graph: TaskGraph, executor: TaskGraphExecutor | None = None, *, keep_going: bool = False, ) -> bool: """Execute (runs) the task_graph with the given executor.""" no_errors = True for result_or_error in task_graph_run_with_results( task_graph, executor, keep_going=keep_going): if isinstance(result_or_error, Exception): no_errors = False return no_errors
[docs] def task_graph_run_with_results( task_graph: TaskGraph, executor: TaskGraphExecutor | None = None, *, keep_going: bool = False, ) -> Generator[Any, None, None]: """Run a task graph, yielding the results from each task.""" if executor is None: executor = SequentialExecutor() tasks_iter = executor.execute(task_graph) for task, result_or_error in tasks_iter: if isinstance(result_or_error, Exception): if keep_going: print(f"Task {task.task_id} failed with:", file=sys.stderr) traceback.print_exception( None, value=result_or_error, tb=result_or_error.__traceback__, file=sys.stdout, ) else: raise result_or_error yield result_or_error
[docs] def task_graph_all_done(task_graph: TaskGraph, task_cache: TaskCache) -> bool: """Check if the task graph is fully executed. When all tasks are already computed, the function returns True. If there are tasks, that need to run, the function returns False. """ # pylint: disable=protected-access AbstractTaskGraphExecutor._check_for_cyclic_deps( # noqa: SLF001 task_graph) already_computed_tasks = {} for task_node, record in task_cache.load(task_graph): if record.type == CacheRecordType.COMPUTED: already_computed_tasks[task_node] = record.result for task_node in AbstractTaskGraphExecutor._in_exec_order( # noqa: SLF001 task_graph): if task_node not in already_computed_tasks: return False return True
[docs] def task_graph_status( task_graph: TaskGraph, task_cache: TaskCache, verbose: int | None) -> bool: """Show the status of each task from the task graph.""" id_col_len = max(len(t.task_id) for t in task_graph.tasks) id_col_len = min(120, max(50, id_col_len)) columns = ["TaskID", "Status"] print(f"{columns[0]:{id_col_len}s} {columns[1]}") task2record = dict(task_cache.load(task_graph)) for task in task_graph.tasks: record = task2record[task] status = record.type.name msg = f"{task.task_id:{id_col_len}s} {status}" is_error = record.type == CacheRecordType.ERROR if is_error and not verbose: msg += " (-v to see exception)" print(msg) if is_error and verbose: traceback.print_exception( None, value=record.error, tb=record.error.__traceback__, file=sys.stdout, ) return True