Source code for dae.task_graph.dask_executor

import logging
import threading
import time
from collections.abc import Iterator
from copy import copy
from typing import Any

from dask.distributed import Client, Future, wait

from dae.task_graph.base_executor import TaskGraphExecutorBase
from dae.task_graph.cache import NoTaskCache, TaskCache
from dae.task_graph.graph import Task, TaskDesc, TaskGraph
from dae.task_graph.logging import (
    ensure_log_dir,
    safe_task_id,
)

NO_TASK_CACHE = NoTaskCache()
logger = logging.getLogger(__name__)


class DaskExecutor(TaskGraphExecutorBase):
    """Dask-based task graph executor."""

    def __init__(
        self,
        dask_client: Client,
        task_cache: TaskCache = NO_TASK_CACHE,
        **kwargs: Any,
    ) -> None:
        """Initialize the Dask executor.

        Args:
            dask_client: Dask client to use for task execution.
        """
        super().__init__(task_cache=task_cache, **kwargs)
        self._executing = False
        self._dask_client = dask_client

        log_dir = ensure_log_dir(**kwargs)
        self._params = copy(kwargs)
        self._params["task_log_dir"] = log_dir

    def _submit_worker_func(
        self,
        submit_queue: list[TaskDesc | None],
        submit_condition: threading.Condition,
        running: dict[Future, Task],
        running_lock: threading.Lock,
    ) -> None:
        start = time.time()
        submit_count = 0
        while True:
            tasks: list[TaskDesc | None] = []
            with submit_condition:
                if not submit_queue:
                    submit_condition.wait()
                tasks = copy(submit_queue)
                submit_queue.clear()

            if any(t is None for t in tasks):
                logger.warning(
                    "submit worker received shutdown signal; "
                    "skipping %s tasks...", len(tasks) - 1)
                return

            assert all(isinstance(t, TaskDesc) for t in tasks)
            task_ids = [
                safe_task_id(task.task.task_id)
                for task in tasks
                if task is not None
            ]
            futures = self._dask_client.map(
                self._exec,
                tasks,
                key=task_ids,
                pure=False,
                params=self._params,
            )
            with running_lock:
                for future, task in zip(futures, tasks, strict=True):
                    assert task is not None
                    submit_count += 1
                    running[future] = task.task
                    if submit_count % 100 == 0:
                        elapsed = time.time() - start
                        logger.info(
                            "submitted %s tasks in %.2f seconds; %.2f tasks/s",
                            submit_count, elapsed, submit_count / elapsed)
                        logger.info(
                            "total running tasks: %s", len(running))

    def _results_worker_func(
        self,
        completed_queue: list[tuple[Future, Task] | None],
        completed_condition: threading.Condition,
        results_queue: list[tuple[Task, Any]],
        results_lock: threading.Lock,
    ) -> None:
        result_count = 0
        start = time.time()
        with completed_condition:
            while True:
                while completed_queue:
                    item = completed_queue.pop()
                    if item is None:
                        logger.warning(
                            "results worker received shutdown signal.")
                        return

                    future, task = item
                    logger.debug("processing completed task %s", task.task_id)
                    try:
                        result = future.result()
                    except Exception as ex:  # noqa: BLE001  # pylint: disable=broad-except
                        result = ex

                    result_count += 1
                    elapsed = time.time() - start
                    if result_count % 100 == 0:
                        logger.info(
                            "processed %s results in %.2f seconds "
                            "(%.2f results/s)",
                            result_count, elapsed, result_count / elapsed)

                    with results_lock:
                        results_queue.append((task, result))
                    completed_condition.wait(timeout=0.2)
                completed_condition.wait()

    def _execute(
        self,
        graph: TaskGraph,
    ) -> Iterator[tuple[Task, Any]]:
        self._executing = True

        submit_queue: list[TaskDesc | None] = []
        submit_condition: threading.Condition = threading.Condition()
        running_lock: threading.Lock = threading.Lock()
        running: dict[Future, Task] = {}

        completed_queue: list[tuple[Future, Task] | None] = []
        completed_condition: threading.Condition = threading.Condition()

        results_queue: list[tuple[Task, Any]] = []
        results_lock: threading.Lock = threading.Lock()

        submit_worker = threading.Thread(
            target=self._submit_worker_func,
            args=(
                submit_queue, submit_condition,
                running, running_lock),
            daemon=True)
        submit_worker.start()

        results_worker = threading.Thread(
            target=self._results_worker_func,
            args=(
                completed_queue, completed_condition,
                results_queue, results_lock),
            daemon=True)
        results_worker.start()

        not_completed: set[Future] = set()
        is_done: bool = graph.empty()
        finished_tasks = 0
        initial_task_count = len(graph)

        while not is_done:
            ready_tasks = graph.extract_tasks(graph.ready_tasks())
            with submit_condition:
                if ready_tasks:
                    submit_queue.extend(ready_tasks)
                    submit_condition.notify_all()

            with running_lock:
                not_completed = set(running.keys())

            if not not_completed:
                time.sleep(0.05)
                completed = set()
            else:
                try:
                    completed, not_completed = wait(
                        not_completed,
                        return_when="FIRST_COMPLETED",
                        timeout=0.05,
                    )
                except TimeoutError:
                    completed = set()

            with running_lock, completed_condition:
                for future in completed:
                    task = running[future]
                    del running[future]
                    completed_queue.append((future, task))
                completed_condition.notify_all()

            with results_lock:
                while results_queue:
                    item = results_queue.pop(0)
                    task, result = item
                    graph.process_completed_tasks([(task, result)])
                    finished_tasks += 1
                    logger.info(
                        "finished %s/%s", finished_tasks, initial_task_count)
                    yield task, result

            is_done = graph.empty()
            with results_lock:
                is_done = is_done and not results_queue
            with submit_condition:
                is_done = is_done and not submit_queue
            with running_lock:
                is_done = is_done and not running
            with completed_condition:
                is_done = is_done and not completed_queue

        with submit_condition:
            submit_queue.append(None)
            submit_condition.notify_all()
        with completed_condition:
            completed_queue.append(None)
            completed_condition.notify_all()

        results_worker.join()
        submit_worker.join()
        self._executing = False
    def close(self) -> None:
        """Close the Dask executor."""
        logger.info("closing Dask executor")
        self._dask_client.retire_workers(close_workers=True)
        self._dask_client.shutdown()
        self._dask_client.close()
        logger.info("Dask executor closed")
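
A minimal usage sketch, not part of the module source: it assumes that TaskGraphExecutorBase exposes a public execute() entry point that drives _execute(), that the executor accepts the default task-logging configuration (no extra keyword arguments), and that build_graph() is a hypothetical helper returning a populated TaskGraph.

from dask.distributed import Client

from dae.task_graph.dask_executor import DaskExecutor

# Hypothetical helper producing a TaskGraph; replace with real graph construction.
graph = build_graph()

# Start a local Dask cluster with a couple of worker processes.
client = Client(n_workers=2, threads_per_worker=1)
executor = DaskExecutor(client)
try:
    # Assumes the base class exposes execute(); results stream back as
    # (task, result) pairs while the graph is being processed.
    for task, result in executor.execute(graph):
        print(task.task_id, result)
finally:
    executor.close()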