# Source code for dae.dask.named_cluster

import copy
import logging
import os
from typing import Any, Dict, Tuple

import dask
from distributed.client import Client

# Module-level logger.
logger = logging.getLogger(__name__)

# Registry mapping a cluster ``type`` name to the factory function that
# builds a cluster of that kind from its ``params`` mapping.
_CLUSTER_TYPES: Dict[str, Any] = {}


def set_up_local_cluster(cluster_conf):
    """Build a ``LocalCluster`` from a cluster-params mapping.

    The mapping is shallow-copied, so the caller's dict is left untouched.
    The ``number_of_workers`` entry is stripped and every remaining entry is
    forwarded as a keyword argument to ``LocalCluster``. When a worker count
    was requested (non-None) and neither ``n_workers`` nor ``cores`` is
    present, the cluster is started with a single worker.
    """
    # pylint: disable=import-outside-toplevel
    from dask.distributed import LocalCluster

    params = copy.copy(cluster_conf)
    requested_workers = params.pop("number_of_workers", None)
    start_with_single_worker = (
        requested_workers is not None
        and "n_workers" not in params
        and "cores" not in params
    )
    if start_with_single_worker:
        # NOTE(review): presumably the caller scales the cluster to the
        # requested size afterwards -- confirm against callers.
        params["n_workers"] = 1
    return LocalCluster(**params)
def set_up_sge_cluster(cluster_conf):
    """Create an ``SGECluster`` (from ``dask_jobqueue``) from cluster params.

    Every entry of ``cluster_conf`` except ``number_of_workers`` is forwarded
    as a keyword argument to ``SGECluster``. The caller's mapping is copied,
    not modified, so a shared configuration dict is left intact.
    """
    # pylint: disable=import-outside-toplevel
    from dask_jobqueue import SGECluster

    kwargs = copy.copy(cluster_conf)
    # Strip ``number_of_workers`` so it is not forwarded to SGECluster.
    kwargs.pop("number_of_workers", None)
    return SGECluster(**kwargs)
def set_up_slurm_cluster(cluster_conf):
    """Create a ``SLURMCluster`` (from ``dask_jobqueue``) from cluster params.

    Every entry of ``cluster_conf`` except ``number_of_workers`` is forwarded
    as a keyword argument to ``SLURMCluster``. The caller's mapping is copied,
    not modified, so a shared configuration dict is left intact.
    """
    # pylint: disable=import-outside-toplevel
    from dask_jobqueue import SLURMCluster

    kwargs = copy.copy(cluster_conf)
    # Strip ``number_of_workers`` so it is not forwarded to SLURMCluster.
    kwargs.pop("number_of_workers", None)
    return SLURMCluster(**kwargs)
def set_up_kubernetes_cluster(cluster_conf):
    """Create a kubernetes cluster via the dask-kubernetes operator.

    Keys read from ``cluster_conf``:
      * ``container_image`` (required): image passed to
        ``make_cluster_spec``.
      * ``envvars`` (optional): names of environment variables whose
        current values in this process are passed as ``env`` to
        ``make_cluster_spec``.
      * ``image_pull_secrets`` (optional): secret names added as
        ``imagePullSecrets`` to both the worker and scheduler pod specs.

    Any other keys (e.g. ``number_of_workers``) are ignored, and the
    caller's mapping is not modified. The cluster is named
    ``gpf-dask-cluster`` and is created with one worker.

    Raises:
        KeyError: if ``container_image`` is missing, or a variable listed
            in ``envvars`` is not set in the environment.
    """
    # pylint: disable=import-outside-toplevel
    from dask_kubernetes.operator.kubecluster import (
        KubeCluster,
        make_cluster_spec,
    )

    # Snapshot the requested environment variables from this process.
    env = {}
    if "envvars" in cluster_conf:
        env = {name: os.environ[name] for name in cluster_conf["envvars"]}

    spec = make_cluster_spec(
        name="gpf-dask-cluster",
        image=cluster_conf["container_image"],
        env=env,
    )

    pull_secrets = cluster_conf.get("image_pull_secrets")
    if pull_secrets:
        secrets = [{"name": name} for name in pull_secrets]
        # The same secret list is needed by both worker and scheduler pods.
        spec["spec"]["worker"]["spec"]["imagePullSecrets"] = secrets
        spec["spec"]["scheduler"]["spec"]["imagePullSecrets"] = secrets

    return KubeCluster(n_workers=1, custom_cluster_spec=spec)
# Register the factory for each supported cluster ``type``.
_CLUSTER_TYPES.update(
    local=set_up_local_cluster,
    sge=set_up_sge_cluster,
    slurm=set_up_slurm_cluster,
    kubernetes=set_up_kubernetes_cluster,
)
def setup_client_from_config(
    cluster_config,
    number_of_threads_param: int | None = None,
) -> Tuple[Client, Dict[str, Any]]:
    """Create a dask client from the provided config.

    Keys read from ``cluster_config``:
      * ``type`` (required): name looked up in the cluster-type registry to
        select the factory function.
      * ``params`` (optional): mapping given to that factory. A copy is
        made and a ``number_of_workers`` entry (``number_of_threads_param``)
        is added to the copy only.
      * ``number_of_threads`` (optional): value passed to
        ``cluster.scale(n=...)``.
      * ``adapt_params`` (optional): keyword arguments for ``cluster.adapt``,
        used only when no thread count is available.

    :param cluster_config: cluster configuration mapping (not modified).
    :param number_of_threads_param: when not None, overrides the config's
        ``number_of_threads``.
    :return: the new ``Client`` and ``cluster_config`` itself.
    :raises KeyError: if ``type`` is missing or not a registered type.
    """
    logger.info("CLUSTER CONFIG: %s", cluster_config)
    cluster_type = cluster_config["type"]

    # Copy the params. The mapping may be the live object returned by
    # dask.config.get, so writing into it would leak ``number_of_workers``
    # into the global configuration. ``or {}`` tolerates ``params: null``.
    cluster_params = dict(cluster_config.get("params") or {})
    cluster_params["number_of_workers"] = number_of_threads_param
    cluster = _CLUSTER_TYPES[cluster_type](cluster_params)

    number_of_threads = cluster_config.get("number_of_threads")
    if number_of_threads_param is not None:
        number_of_threads = number_of_threads_param

    if number_of_threads is not None:
        cluster.scale(n=number_of_threads)
    elif "adapt_params" in cluster_config:
        cluster.adapt(**cluster_config["adapt_params"])

    return Client(cluster), cluster_config
def setup_client(
    cluster_name: str | None = None,
    number_of_threads: int | None = None,
) -> Tuple[Client, Dict[str, Any]]:
    """Create a dask client for a cluster selected by name.

    Cluster configurations are read from the dask config key
    ``dae_named_cluster.clusters`` (a sequence of mappings, each with a
    ``name`` entry). If several share a name, the last one wins. When
    ``cluster_name`` is None, the name stored under
    ``dae_named_cluster.default`` is used.

    :param cluster_name: name of the configured cluster to use.
    :param number_of_threads: forwarded to ``setup_client_from_config``.
    :return: whatever ``setup_client_from_config`` returns.
    :raises KeyError: if no configured cluster has the selected name.
    """
    selected = (
        dask.config.get("dae_named_cluster.default")
        if cluster_name is None
        else cluster_name
    )
    configs_by_name = {
        entry["name"]: entry
        for entry in dask.config.get("dae_named_cluster.clusters")
    }
    return setup_client_from_config(configs_by_name[selected], number_of_threads)