"""Command-line tools for creating and executing task graphs (dae.task_graph.cli_tools)."""

import argparse
import logging
import textwrap
from typing import Any

import yaml
from box import Box

from dae.dask.named_cluster import setup_client, setup_client_from_config
from dae.task_graph.cache import NoTaskCache, TaskCache
from dae.task_graph.executor import (
    DaskExecutor,
    SequentialExecutor,
    TaskGraphExecutor,
    task_graph_all_done,
    task_graph_run,
    task_graph_status,
)
from dae.task_graph.graph import TaskGraph

logger = logging.getLogger(__name__)


[docs] class TaskGraphCli: """Takes care of creating a task graph executor and executing a graph."""
[docs] @staticmethod def add_arguments( parser: argparse.ArgumentParser, force_mode: str = "optional", default_task_status_dir: str | None = ".", *, use_commands: bool = True, ) -> None: """Add arguments needed to execute a task graph.""" executor_group = parser.add_argument_group(title="Task Graph Executor") # cluster_name # cluster_config_file executor_group.add_argument( "-j", "--jobs", type=int, default=None, help="Number of jobs to run in parallel. Defaults to the number " "of processors on the machine") executor_group.add_argument( "-N", "--dask-cluster-name", "--dcn", dest="dask_cluster_name", type=str, default=None, help="The named of the named dask cluster", ) executor_group.add_argument( "-c", "--dccf", "--dask-cluster-config-file", dest="dask_cluster_config_file", type=str, default=None, help="dask cluster config file", ) executor_group.add_argument( "--tasks-log-dir", dest="log_dir", type=str, default="./.task-log", help="Path to directory where to store tasks' logs", ) # task_cache execution_mode_group = parser.add_argument_group( title="Execution Mode") if use_commands: execution_mode_group.add_argument( "command", choices=["run", "list", "status"], default="run", nargs="?", help=textwrap.dedent("""\ Command to execute on the import configuration. 
run - runs the import process list - lists the tasks to be executed but doesn't run them status - synonym for list """), ) execution_mode_group.add_argument( "-t", "--task-ids", dest="task_ids", type=str, nargs="+") execution_mode_group.add_argument( "--keep-going", default=False, action="store_true", help="Whether or not to keep executing in case of an error", ) if force_mode == "optional": execution_mode_group.add_argument( "--force", "-f", default=False, action="store_true", help="Ignore precomputed state and always rerun all tasks.", ) execution_mode_group.add_argument( "-d", "--task-status-dir", "--tsd", default=default_task_status_dir, type=str, help="Directory to store the task progress.", ) else: assert force_mode == "always"
[docs] @staticmethod def create_executor( task_cache: TaskCache | None = None, **kwargs: Any) -> TaskGraphExecutor: """Create a task graph executor according to the args specified.""" args = Box(kwargs) if task_cache is None: task_cache = NoTaskCache() if args.jobs == 1: assert args.dask_cluster_name is None assert args.dask_cluster_config_file is None return SequentialExecutor(task_cache=task_cache, **kwargs) assert args.dask_cluster_name is None or \ args.dask_cluster_config_file is None if args.dask_cluster_config_file is not None: dask_cluster_config_file = args.dask_cluster_config_file assert dask_cluster_config_file is not None with open(dask_cluster_config_file) as conf_file: dask_cluster_config = yaml.safe_load(conf_file) print("THE CLUSTER CONFIG IS:", dask_cluster_config, "loaded from", args.dask_cluster_config_file) client, _ = setup_client_from_config( dask_cluster_config, number_of_threads_param=args.jobs, ) else: client, _ = setup_client(args.dask_cluster_name, number_of_threads=args.jobs) print("Working with client:", client) return DaskExecutor(client, task_cache=task_cache, **kwargs)
[docs] @staticmethod def process_graph( task_graph: TaskGraph, force_mode: str = "optional", **kwargs: Any) -> bool: """Process task_graph in according with the arguments in args. Return true if the graph get's successfully processed. """ args = Box(kwargs) if args.task_ids: task_graph = task_graph.prune(ids_to_keep=args.task_ids) force = None if force_mode == "always" else args.get("force") task_cache = TaskCache.create( force=force, cache_dir=args.get("task_status_dir")) if args.command is None or args.command == "run": if task_graph_all_done(task_graph, task_cache): logger.warning( "All tasks are already COMPUTED; nothing to compute") return True with TaskGraphCli.create_executor(task_cache, **kwargs) as xtor: return task_graph_run( task_graph, xtor, keep_going=args.keep_going) if args.command in {"list", "status"}: return task_graph_status(task_graph, task_cache, args.verbose) raise ValueError(f"Unknown command {args.command}")