import argparse
import logging
import sys
import textwrap
import traceback
from collections.abc import Generator
from typing import Any
import yaml
from box import Box
from dae.task_graph.cache import CacheRecordType, NoTaskCache, TaskCache
from dae.task_graph.executor import (
TaskGraphExecutor,
)
from dae.task_graph.graph import TaskGraph
from dae.task_graph.process_pool_executor import ProcessPoolTaskExecutor
from dae.task_graph.sequential_executor import SequentialExecutor
logger = logging.getLogger(__name__)
class TaskGraphCli:
"""Takes care of creating a task graph executor and executing a graph."""
@staticmethod
def add_arguments(
parser: argparse.ArgumentParser, *,
task_progress_mode: bool = True,
default_task_status_dir: str | None = "./.task-progress",
use_commands: bool = True,
) -> None:
"""Add arguments needed to execute a task graph."""
executor_group = parser.add_argument_group(title="Task Graph Executor")
# cluster_name
# cluster_config_file
executor_group.add_argument(
"-j", "--jobs", type=int, default=None,
help="Number of jobs to run in parallel. Defaults to the number "
"of processors on the machine")
executor_group.add_argument(
"--process-pool", "--pp",
dest="use_process_pool", action="store_true",
help="Use a process pool executor with the specified number of "
"processes instead of a dask distributed executor.",
)
executor_group.add_argument(
"-N", "--dask-cluster-name", "--dcn",
dest="dask_cluster_name",
type=str, default=None,
help="The named of the named dask cluster",
)
executor_group.add_argument(
"-c", "--dccf", "--dask-cluster-config-file",
dest="dask_cluster_config_file",
type=str, default=None,
help="dask cluster config file",
)
executor_group.add_argument(
"--task-log-dir", dest="task_log_dir", type=str,
default=None,
help="Path to directory where to store tasks' logs",
)
# task_cache
execution_mode_group = parser.add_argument_group(
title="Execution Mode")
if use_commands:
execution_mode_group.add_argument(
"command",
choices=["run", "list", "status"],
default="run", nargs="?",
help=textwrap.dedent("""\
Command to execute on the import configuration.
run - runs the import process
list - lists the tasks to be executed but doesn't run them
status - synonym for list
"""),
)
execution_mode_group.add_argument(
"-t", "--task-ids", dest="task_ids", type=str, nargs="+")
execution_mode_group.add_argument(
"--keep-going", default=False, action="store_true",
help="Whether or not to keep executing in case of an error",
)
if task_progress_mode:
execution_mode_group.add_argument(
"--force", "-f", default=False, action="store_true",
help="Ignore precomputed state and always rerun all tasks.",
)
execution_mode_group.add_argument(
"-d", "--task-status-dir", "--tsd",
default=default_task_status_dir,
type=str, help="Directory to store the task progress.",
)
execution_mode_group.add_argument(
"--fork-tasks", "--fork-task", "--fork",
dest="fork_tasks", action="store_true",
help="Whether to fork a new worker process for each task",
)
else:
assert not task_progress_mode, \
"task_progress_mode must be False if no cache is used"
@staticmethod
def _create_dask_executor(
task_cache: TaskCache,
**kwargs: Any,
) -> TaskGraphExecutor:
"""Create a task graph executor according to the args specified."""
# pylint: disable=import-outside-toplevel
from dae.dask.named_cluster import (
setup_client,
setup_client_from_config,
)
from dae.task_graph.dask_executor import DaskExecutor
args = Box(kwargs)
assert args.dask_cluster_name is None or \
args.dask_cluster_config_file is None
if args.dask_cluster_config_file is not None:
dask_cluster_config_file = args.dask_cluster_config_file
assert dask_cluster_config_file is not None
with open(dask_cluster_config_file) as conf_file:
dask_cluster_config = yaml.safe_load(conf_file)
logger.info(
"THE CLUSTER CONFIG IS: %s; loaded from: %s",
dask_cluster_config,
args.dask_cluster_config_file)
client, _ = setup_client_from_config(
dask_cluster_config,
number_of_workers=args.jobs,
)
else:
client, _ = setup_client(
args.dask_cluster_name,
number_of_workers=args.jobs)
logger.info("Working with client: %s", client)
return DaskExecutor(client, task_cache=task_cache, **kwargs)
@staticmethod
def create_executor(
task_cache: TaskCache | None = None,
**kwargs: Any) -> TaskGraphExecutor:
"""Create a task graph executor according to the args specified."""
args = Box(kwargs)
if task_cache is None:
task_cache = NoTaskCache()
if args.jobs == 1:
assert args.dask_cluster_name is None
assert args.dask_cluster_config_file is None
return SequentialExecutor(task_cache=task_cache, **kwargs)
if args.use_process_pool:
assert args.dask_cluster_name is None
assert args.dask_cluster_config_file is None
return ProcessPoolTaskExecutor(
max_workers=args.jobs,
task_cache=task_cache,
**kwargs)
return TaskGraphCli._create_dask_executor(
task_cache=task_cache, **kwargs)
@staticmethod
def process_graph(
task_graph: TaskGraph, *,
task_progress_mode: bool = True,
**kwargs: Any,
) -> bool:
"""Process task_graph in according with the arguments in args.
Return true if the graph get's successfully processed.
"""
args = Box(kwargs)
if args.task_ids:
task_graph.prune(tasks_to_keep=args.task_ids)
force = args.get("force", False)
task_cache = TaskCache.create(
task_progress_mode=task_progress_mode,
force=force,
cache_dir=args.get("task_status_dir"),
)
if args.command is None or args.command == "run":
if task_graph_all_done(task_graph, task_cache):
logger.warning(
"All tasks are already COMPUTED; nothing to compute")
return True
with TaskGraphCli.create_executor(task_cache, **kwargs) as xtor:
return task_graph_run(
task_graph, xtor, keep_going=args.keep_going)
if args.command in {"list", "status"}:
return task_graph_status(task_graph, task_cache, args.verbose)
raise ValueError(f"Unknown command {args.command}")
def task_graph_run(
task_graph: TaskGraph,
executor: TaskGraphExecutor | None = None,
*,
keep_going: bool = False,
) -> bool:
"""Execute (runs) the task_graph with the given executor."""
no_errors = True
for result_or_error in task_graph_run_with_results(
task_graph, executor, keep_going=keep_going):
if isinstance(result_or_error, Exception):
no_errors = False
return no_errors
def task_graph_run_with_results(
task_graph: TaskGraph, executor: TaskGraphExecutor | None = None,
*,
keep_going: bool = False,
) -> Generator[Any, None, None]:
"""Run a task graph, yielding the results from each task."""
if executor is None:
executor = SequentialExecutor()
tasks_iter = executor.execute(task_graph)
for task, result_or_error in tasks_iter:
if isinstance(result_or_error, Exception):
if keep_going:
print(f"Task {task.task_id} failed with:",
file=sys.stderr)
traceback.print_exception(
None, value=result_or_error,
tb=result_or_error.__traceback__,
file=sys.stdout,
)
else:
raise result_or_error
yield result_or_error
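# Example (a sketch): consume the generator while tolerating failures; each
# yielded value is either a task's result or the Exception it raised.
#
#     graph = TaskGraph()  # assuming a no-argument constructor
#     ...                  # populate the graph with tasks
#     for result in task_graph_run_with_results(graph, keep_going=True):
#         if isinstance(result, Exception):
#             logger.warning("a task failed: %s", result)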
def task_graph_all_done(task_graph: TaskGraph, task_cache: TaskCache) -> bool:
"""Check if the task graph is fully executed.
When all tasks are already computed, the function returns True.
If there are tasks, that need to run, the function returns False.
"""
# pylint: disable=protected-access
already_computed_tasks = {}
for task_node, record in task_cache.load(task_graph):
if record.type == CacheRecordType.COMPUTED:
already_computed_tasks[task_node] = record.result
for task_node in task_graph.topological_order():
if task_node not in already_computed_tasks:
return False
return True
def task_graph_status(
task_graph: TaskGraph, task_cache: TaskCache,
verbose: int | None) -> bool:
"""Show the status of each task from the task graph."""
id_col_len = max(len(t.task_id) for t in task_graph.tasks)
id_col_len = min(120, max(50, id_col_len))
columns = ["TaskID", "Status"]
print(f"{columns[0]:{id_col_len}s} {columns[1]}")
task2record = dict(task_cache.load(task_graph))
for task in task_graph.tasks:
record = task2record[task]
status = record.type.name
msg = f"{task.task_id:{id_col_len}s} {status}"
is_error = record.type == CacheRecordType.ERROR
if is_error and not verbose:
msg += " (-v to see exception)"
print(msg)
if is_error and verbose:
traceback.print_exception(
None, value=record.error,
tb=record.error.__traceback__,
file=sys.stdout,
)
return True
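# Illustrative output of task_graph_status() (task ids are hypothetical;
# the column width depends on the longest task id, and the status names
# come from CacheRecordType):
#
#     TaskID                                             Status
#     import_pedigree                                    COMPUTED
#     import_variants                                    ERROR (-v to see exception)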