Source code for dae.task_graph.process_pool_executor

from __future__ import annotations

import logging
import os
import time
from collections import deque
from collections.abc import Iterator
from concurrent.futures import (
    Future,
    ProcessPoolExecutor,
    TimeoutError,  # noqa: A004  pre-3.11: as_completed raises this, not the builtin
    as_completed,
)
from typing import Any

import psutil

from dae.task_graph.base_executor import TaskGraphExecutorBase
from dae.task_graph.graph import Task, TaskDesc, TaskGraph
from dae.task_graph.logging import (
    safe_task_id,
)

logger = logging.getLogger(__name__)


class ProcessPoolTaskExecutor(TaskGraphExecutorBase):
    """Execute tasks in parallel using a ``ProcessPoolExecutor``.

    Despite the similarity to the Dask-based executor, all heavy lifting
    here is done by :class:`concurrent.futures.ProcessPoolExecutor`.
    """

    def __init__(
        self,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        # "n_threads" is kept for interface compatibility with the other
        # executors; here it bounds the number of worker *processes*.
        max_workers = kwargs.get("n_threads", os.cpu_count() or 1)
        self._executor = ProcessPoolExecutor(max_workers=max_workers)

    def _submit_task(self, task: TaskDesc) -> Future:
        """Submit a dependency-free task to the process pool."""
        assert len(task.deps) == 0
        assert not any(isinstance(arg, Task) for arg in task.args), \
            "Task has dependencies to wait for."
        # ProcessPoolExecutor.submit never returns None (it raises
        # RuntimeError on a shut-down pool), so no None check is needed.
        return self._executor.submit(
            self._exec_internal, task, self._params,
        )

    @staticmethod
    def _exec_internal(
        task: TaskDesc,
        params: dict[str, Any],  # noqa: ARG004
    ) -> Any:
        """Run a task inside a worker process.

        Logs elapsed time and the worker process' RSS memory before and
        after the task runs. Returns whatever the task function returns.
        """
        start = time.time()
        process = psutil.Process(os.getpid())
        start_memory_mb = process.memory_info().rss / (1024 * 1024)
        task_id = safe_task_id(task.task.task_id)
        logger.info(
            "worker process memory usage: %.2f MB", start_memory_mb)

        result = task.func(*task.args)

        elapsed = time.time() - start
        logger.info("task <%s> finished in %0.2fsec", task_id, elapsed)
        finish_memory_mb = process.memory_info().rss / (1024 * 1024)
        logger.info(
            "worker process memory usage: %.2f MB; change: %+0.2f MB",
            finish_memory_mb, finish_memory_mb - start_memory_mb)
        return result

    def _schedule_tasks(
        self, graph: TaskGraph,
    ) -> dict[Future, Task]:
        """Submit all currently-ready tasks; map each future to its task."""
        ready_tasks = graph.extract_tasks(graph.ready_tasks())
        submitted_tasks: dict[Future, Task] = {}
        if ready_tasks:
            logger.debug("scheduling %d tasks", len(ready_tasks))
        for task in ready_tasks:
            future = self._submit_task(task)
            submitted_tasks[future] = task.task
        return submitted_tasks

    def _execute(self, graph: TaskGraph) -> Iterator[tuple[Task, Any]]:
        """Run the graph to completion, yielding ``(task, result)`` pairs.

        A failed task yields its exception object as the result instead of
        raising, leaving failure handling to the caller.
        """
        not_completed: set[Future] = set()
        completed: deque[Future] = deque()
        initial_task_count = len(graph)
        finished_tasks = 0
        process = psutil.Process(os.getpid())
        current_memory_mb = process.memory_info().rss / (1024 * 1024)
        logger.info(
            "executor memory usage: %.2f MB", current_memory_mb)
        submitted_tasks: dict[Future, Task] = {}
        while not_completed or not graph.empty():
            submitted_tasks.update(self._schedule_tasks(graph))
            not_completed = set(submitted_tasks.keys())
            try:
                # Poll with a short timeout so tasks unblocked by
                # just-finished dependencies get scheduled promptly.
                for future in as_completed(not_completed, timeout=0.25):
                    not_completed.remove(future)
                    completed.append(future)
            except TimeoutError:
                # On Python < 3.11 ``as_completed`` raises
                # ``concurrent.futures.TimeoutError``, which is not the
                # builtin; the explicit import at the top of the file makes
                # this handler correct on all versions.
                pass
            processed: list[tuple[Task, Any]] = []
            logger.debug(
                "going to process %d completed tasks", len(completed))
            while completed:
                future = completed.popleft()
                task = submitted_tasks[future]
                try:
                    result = future.result()
                except Exception as ex:  # noqa: BLE001
                    # pylint: disable=broad-except
                    # A task failure becomes its result; the graph and the
                    # caller decide what to do with it.
                    result = ex
                graph.process_completed_tasks([(task, result)])
                finished_tasks += 1
                processed.append((task, result))
                logger.info(
                    "finished %s/%s", finished_tasks, initial_task_count)
                # drop the reference so the executor can release the
                # future's resources
                del submitted_tasks[future]
            logger.debug("processed %d completed tasks", len(processed))
            yield from processed

        # clean up
        assert len(submitted_tasks) == 0, \
            "[BUG] Process pool executor's future queue is not empty."
        assert len(graph) == 0

    def close(self) -> None:
        """Shut down the underlying process pool."""
        self._executor.shutdown()