import logging
import threading
import time
from collections.abc import Iterator
from copy import copy
from typing import Any
from dask.distributed import Client, Future, wait
from dae.task_graph.base_executor import TaskGraphExecutorBase
from dae.task_graph.cache import NoTaskCache, TaskCache
from dae.task_graph.graph import Task, TaskDesc, TaskGraph
from dae.task_graph.logging import (
ensure_log_dir,
safe_task_id,
)
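# Module-level default: a cache that performs no caching (judging by the
# NoTaskCache name), shared by all executors that are created without an
# explicit task cache.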
NO_TASK_CACHE = NoTaskCache()
logger = logging.getLogger(__name__)
class DaskExecutor(TaskGraphExecutorBase):
"""Dask-based task graph executor."""
def __init__(
self, dask_client: Client,
task_cache: TaskCache = NO_TASK_CACHE, **kwargs: Any,
) -> None:
"""Initialize the Dask executor.
Args:
dask_client: Dask client to use for task execution.
"""
super().__init__(task_cache=task_cache, **kwargs)
self._executing = False
self._dask_client = dask_client
log_dir = ensure_log_dir(**kwargs)
self._params = copy(kwargs)
self._params["task_log_dir"] = log_dir
def _submit_worker_func(
self,
submit_queue: list[TaskDesc | None],
submit_condition: threading.Condition,
running: dict[Future, Task],
running_lock: threading.Lock,
) -> None:
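        # Producer loop: sleep on ``submit_condition`` until the main loop
        # queues ready tasks, then submit the whole batch in a single
        # ``Client.map`` call and register the resulting futures in
        # ``running`` under ``running_lock``.  A ``None`` entry in the
        # queue signals shutdown.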
start = time.time()
submit_count = 0
while True:
tasks: list[TaskDesc | None] = []
with submit_condition:
if not submit_queue:
submit_condition.wait()
tasks = copy(submit_queue)
submit_queue.clear()
if any(t is None for t in tasks):
logger.warning(
"submit worker received shutdown signal; "
"skipping %s tasks...",
len(tasks) - 1)
return
assert all(isinstance(t, TaskDesc) for t in tasks)
task_ids = [
safe_task_id(task.task.task_id)
for task in tasks
if task is not None
]
futures = self._dask_client.map(
self._exec, tasks,
key=task_ids,
pure=False,
params=self._params,
)
with running_lock:
for future, task in zip(futures, tasks, strict=True):
assert task is not None
submit_count += 1
running[future] = task.task
if submit_count % 100 == 0:
elapsed = time.time() - start
logger.info(
"submitted %s tasks in %.2f seconds; %.2f tasks/s",
submit_count, elapsed, submit_count / elapsed)
logger.info(
"total running tasks: %s", len(running))
def _results_worker_func(
self,
completed_queue: list[tuple[Future, Task] | None],
completed_condition: threading.Condition,
results_queue: list[tuple[Task, Any]],
results_lock: threading.Lock,
) -> None:
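        # Consumer loop: the main loop hands over (future, task) pairs via
        # ``completed_queue``; for each pair the future's result is fetched
        # (an exception is captured and passed on in place of the result)
        # and (task, result) is pushed onto ``results_queue``.  A ``None``
        # entry signals shutdown.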
result_count = 0
start = time.time()
with completed_condition:
while True:
while completed_queue:
item = completed_queue.pop()
if item is None:
logger.warning(
"results worker received shutdown signal.")
return
future, task = item
logger.debug("processing completed task %s", task.task_id)
try:
result = future.result()
except Exception as ex: # noqa: BLE001
# pylint: disable=broad-except
result = ex
result_count += 1
elapsed = time.time() - start
if result_count % 100 == 0:
logger.info(
"processed %s results in %.2f seconds "
"(%.2f results/s)",
result_count, elapsed,
result_count / elapsed)
with results_lock:
results_queue.append((task, result))
                # Wait for more completed tasks (or the shutdown signal);
                # the timeout guards against a missed notification.
                completed_condition.wait(timeout=0.2)
def _execute(
self, graph: TaskGraph,
) -> Iterator[tuple[Task, Any]]:
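        # Orchestration: ready tasks are handed to the submit worker,
        # finished futures (detected with ``dask.distributed.wait``) are
        # handed to the results worker, and processed results are yielded
        # here as (task, result) pairs until the graph is drained and every
        # queue and the running set are empty.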
self._executing = True
submit_queue: list[TaskDesc | None] = []
submit_condition: threading.Condition = threading.Condition()
running_lock: threading.Lock = threading.Lock()
running: dict[Future, Task] = {}
completed_queue: list[tuple[Future, Task] | None] = []
completed_condition: threading.Condition = threading.Condition()
results_queue: list[tuple[Task, Any]] = []
results_lock: threading.Lock = threading.Lock()
submit_worker = threading.Thread(
target=self._submit_worker_func,
args=(
submit_queue, submit_condition,
running, running_lock),
daemon=True)
submit_worker.start()
results_worker = threading.Thread(
target=self._results_worker_func,
args=(
completed_queue, completed_condition,
results_queue, results_lock),
daemon=True)
results_worker.start()
not_completed: set[Future] = set()
is_done: bool = graph.empty()
finished_tasks = 0
initial_task_count = len(graph)
while not is_done:
ready_tasks = graph.extract_tasks(graph.ready_tasks())
with submit_condition:
if ready_tasks:
submit_queue.extend(ready_tasks)
submit_condition.notify_all()
with running_lock:
not_completed = set(running.keys())
if not not_completed:
time.sleep(0.05)
completed = set()
else:
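                    # Poll the cluster: return as soon as at least one
                    # future finishes, or give up after 0.05s so the loop
                    # stays responsive.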
try:
completed, not_completed = wait(
not_completed,
return_when="FIRST_COMPLETED",
timeout=0.05,
)
except TimeoutError:
completed = set()
with running_lock, completed_condition:
for future in completed:
task = running[future]
del running[future]
completed_queue.append((future, task))
completed_condition.notify_all()
with results_lock:
while results_queue:
item = results_queue.pop(0)
task, result = item
graph.process_completed_tasks([(task, result)])
finished_tasks += 1
logger.info(
"finished %s/%s", finished_tasks,
initial_task_count)
yield task, result
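            # The graph may be empty while results are still in flight;
            # only stop once the queues and the running set are drained too.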
is_done = graph.empty()
with results_lock:
is_done = is_done and not results_queue
with submit_condition:
is_done = is_done and not submit_queue
with running_lock:
is_done = is_done and not running
with completed_condition:
is_done = is_done and not completed_queue
with submit_condition:
submit_queue.append(None)
submit_condition.notify_all()
with completed_condition:
completed_queue.append(None)
completed_condition.notify_all()
results_worker.join()
submit_worker.join()
self._executing = False
def close(self) -> None:
"""Close the Dask executor."""
logger.info("closing Dask executor")
self._dask_client.retire_workers(close_workers=True)
self._dask_client.shutdown()
self._dask_client.close()
logger.info("Dask executor closed")