Source code for mlrun.runtimes.daskjob

# Copyright 2023 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import inspect
import socket
import time
from os import environ

from deprecated import deprecated

import mlrun.common.schemas
import mlrun.errors
import mlrun.k8s_utils
import mlrun.utils
import mlrun.utils.regex
from mlrun.errors import err_to_str

from ..config import config
from ..execution import MLClientCtx
from ..model import RunObject
from ..render import ipython_display
from ..utils import logger, normalize_name, update_in
from .base import FunctionStatus
from .kubejob import KubejobRuntime
from .local import exec_from_params, load_module
from .pod import KubeResourceSpec, kube_resource_spec_to_pod_spec
from .utils import RunError, get_func_selector, get_resource_labels, log_std


def get_dask_resource():
    """Return the resource descriptor for the dask runtime.

    The descriptor is a plain dict with three entries:

    * ``scope``  - the literal string ``"function"``
    * ``start``  - the callable used to bring a cluster up (``deploy_function``)
    * ``status`` - the callable used to query cluster state (``get_obj_status``)
    """
    resource = dict(scope="function")
    resource["start"] = deploy_function
    resource["status"] = get_obj_status
    return resource


class DaskSpec(KubeResourceSpec):
    """Function spec for the dask runtime.

    Extends ``KubeResourceSpec`` with dask specific settings (replica range,
    threads per worker, scheduler timeout, service exposure) and with separate
    resource sections for the scheduler pod and the worker pods.
    """

    # dask specific fields appended to the base class field list
    # (presumably the list the base class uses for dict (de)serialization - confirm)
    _dict_fields = KubeResourceSpec._dict_fields + [
        "extra_pip",
        "remote",
        "service_type",
        "nthreads",
        "kfp_image",
        "node_port",
        "min_replicas",
        "max_replicas",
        "scheduler_timeout",
        "scheduler_resources",
        "worker_resources",
    ]

    def __init__(
        self,
        command=None,
        args=None,
        image=None,
        mode=None,
        volumes=None,
        volume_mounts=None,
        env=None,
        resources=None,
        build=None,
        default_handler=None,
        entry_points=None,
        description=None,
        replicas=None,
        image_pull_policy=None,
        service_account=None,
        image_pull_secret=None,
        extra_pip=None,
        remote=None,
        service_type=None,
        nthreads=None,
        kfp_image=None,
        node_port=None,
        min_replicas=None,
        max_replicas=None,
        scheduler_timeout=None,
        node_name=None,
        node_selector=None,
        affinity=None,
        scheduler_resources=None,
        worker_resources=None,
        priority_class_name=None,
        disable_auto_mount=False,
        pythonpath=None,
        workdir=None,
        tolerations=None,
        preemption_mode=None,
        security_context=None,
        clone_target_dir=None,
    ):
        """Create the spec.

        All generic pod/function arguments are forwarded unchanged to
        ``KubeResourceSpec.__init__``. The dask specific arguments are stored on
        the instance, with the following fallbacks applied when the given value
        is falsy (``None``, ``0``, ``""``):

        :param remote:            defaults to ``True`` (only when ``None`` is passed)
        :param min_replicas:      defaults to ``0``
        :param max_replicas:      defaults to ``16``
        :param scheduler_timeout: defaults to ``"60 minutes"``
        :param nthreads:          defaults to ``1``
        :param scheduler_resources: resources for the scheduler pod, passed through
                                    ``enrich_resources_with_default_pod_resources``
        :param worker_resources:    resources for the worker pods, passed through
                                    ``enrich_resources_with_default_pod_resources``
        """

        super().__init__(
            command=command,
            args=args,
            image=image,
            mode=mode,
            volumes=volumes,
            volume_mounts=volume_mounts,
            env=env,
            resources=resources,
            replicas=replicas,
            image_pull_policy=image_pull_policy,
            service_account=service_account,
            build=build,
            default_handler=default_handler,
            entry_points=entry_points,
            description=description,
            image_pull_secret=image_pull_secret,
            node_name=node_name,
            node_selector=node_selector,
            affinity=affinity,
            priority_class_name=priority_class_name,
            disable_auto_mount=disable_auto_mount,
            pythonpath=pythonpath,
            workdir=workdir,
            tolerations=tolerations,
            preemption_mode=preemption_mode,
            security_context=security_context,
            clone_target_dir=clone_target_dir,
        )
        # NOTE(review): args was already handed to the base constructor; it is
        # re-assigned here, presumably to keep the raw value - confirm it is needed
        self.args = args

        self.extra_pip = extra_pip
        self.remote = True if remote is None else remote  # make remote the default

        self.service_type = service_type
        self.kfp_image = kfp_image
        self.node_port = node_port
        # `or` fallbacks: any falsy value (not only None) is replaced by the default
        self.min_replicas = min_replicas or 0
        self.max_replicas = max_replicas or 16
        # supported format according to https://github.com/dask/dask/blob/master/dask/utils.py#L1402
        self.scheduler_timeout = scheduler_timeout or "60 minutes"
        self.nthreads = nthreads or 1
        # stored on the private attributes that back the two properties below
        self._scheduler_resources = self.enrich_resources_with_default_pod_resources(
            "scheduler_resources", scheduler_resources
        )
        self._worker_resources = self.enrich_resources_with_default_pod_resources(
            "worker_resources", worker_resources
        )

    @property
    def scheduler_resources(self) -> dict:
        """Resources section used for the scheduler pod."""
        return self._scheduler_resources

    @scheduler_resources.setter
    def scheduler_resources(self, resources):
        # every assignment goes through the same enrichment helper as __init__
        self._scheduler_resources = self.enrich_resources_with_default_pod_resources(
            "scheduler_resources", resources
        )

    @property
    def worker_resources(self) -> dict:
        """Resources section used for the worker pods."""
        return self._worker_resources

    @worker_resources.setter
    def worker_resources(self, resources):
        # every assignment goes through the same enrichment helper as __init__
        self._worker_resources = self.enrich_resources_with_default_pod_resources(
            "worker_resources", resources
        )


class DaskStatus(FunctionStatus):
    """Function status for the dask runtime.

    On top of the base ``state`` / ``build_pod`` it keeps the address of the
    dask scheduler, the cluster name and the node ports of the scheduler service.
    """

    def __init__(
        self,
        state=None,
        build_pod=None,
        scheduler_address=None,
        cluster_name=None,
        node_ports=None,
    ):
        # generic status fields are handled by the base class
        super().__init__(state, build_pod)

        # dask specific fields (assigned left to right, same order as before)
        self.scheduler_address, self.cluster_name, self.node_ports = (
            scheduler_address,
            cluster_name,
            node_ports,
        )


class DaskCluster(KubejobRuntime):
    """Runtime that executes handlers with a dask client attached to the context.

    When ``spec.remote`` is set the cluster is started either through the MLRun
    API (``db.remote_start``) or directly via ``deploy_function``; the resulting
    scheduler address is kept in ``status``.
    """

    kind = "dask"
    _is_nested = False
    _is_remote = False

    def __init__(self, spec=None, metadata=None):
        super().__init__(spec, metadata)
        self._cluster = None
        # outside of a kubernetes cluster the scheduler is addressed via remote host
        self.use_remote = not mlrun.k8s_utils.is_running_inside_kubernetes_cluster()
        self.spec.build.base_image = self.spec.build.base_image or "daskdev/dask:latest"

    @property
    def spec(self) -> DaskSpec:
        return self._spec

    @spec.setter
    def spec(self, spec):
        self._spec = self._verify_dict(spec, "spec", DaskSpec)

    @property
    def status(self) -> DaskStatus:
        return self._status

    @status.setter
    def status(self, status):
        self._status = self._verify_dict(status, "status", DaskStatus)

    def is_deployed(self):
        """Return True when not running remote, otherwise defer to the base class."""
        if not self.spec.remote:
            return True
        return super().is_deployed()

    @property
    def initialized(self):
        """True when a cluster object was created by this instance."""
        return bool(self._cluster)

    def _load_db_status(self):
        """Load the function status from the DB (remote API only).

        :return: True when a stored status was found and it contains a
                 ``scheduler_address`` entry, False otherwise
        """
        meta = self.metadata
        if self._is_remote_api():
            db = self._get_db()
            db_func = None
            try:
                db_func = db.get_function(meta.name, meta.project, meta.tag)
            except Exception as exc:
                # best effort - a missing/unreachable function just means "no status",
                # but leave a trace instead of swallowing the error silently
                logger.debug(
                    f"failed to load dask function status from db: {err_to_str(exc)}"
                )

            if db_func and "status" in db_func:
                self.status = db_func["status"]
                if self.kfp:
                    logger.info(f"dask status: {db_func['status']}")
                return "scheduler_address" in db_func["status"]

        return False

    def _start(self, watch=True):
        """Start the dask cluster.

        With a remote API the start is delegated to the API as a background task;
        otherwise the cluster is deployed directly from this process.

        :param watch: when using the remote API, poll the background task (every
                      5 seconds, up to 10 minutes) until it reaches a terminal state
        :raises RunError: the function image is not deployed
        :raises mlrun.errors.MLRunRuntimeError: the background task failed
        """
        if not self._is_remote_api():
            self._cluster = deploy_function(self)
            self.save(versioned=False)
            return

        self.try_auto_mount_based_on_config()
        self._fill_credentials()
        db = self._get_db()
        if not self.is_deployed():
            raise RunError(
                "function image is not built/ready, use .deploy()"
                " method first, or set base dask image (daskdev/dask:latest)"
            )

        self.save(versioned=False)
        background_task = db.remote_start(self._function_uri())
        if not watch:
            return

        # monotonic clock: the deadline is unaffected by wall clock adjustments
        deadline = time.monotonic() + 10 * 60
        while time.monotonic() < deadline:
            background_task = db.get_project_background_task(
                background_task.metadata.project, background_task.metadata.name
            )
            state = background_task.status.state
            if state in mlrun.common.schemas.BackgroundTaskState.terminal_states():
                if state == mlrun.common.schemas.BackgroundTaskState.failed:
                    raise mlrun.errors.MLRunRuntimeError(
                        "Failed bringing up dask cluster"
                    )
                function = db.get_function(
                    self.metadata.name,
                    self.metadata.project,
                    self.metadata.tag,
                )
                if function and function.get("status"):
                    self.status = function.get("status")
                return
            time.sleep(5)

        # previously the timeout was silent; keep returning (callers fall back
        # to a local client) but make the condition visible
        logger.warning(
            "timed out waiting for dask cluster to start, background task did not "
            "reach a terminal state within 10 minutes"
        )

    def close(self, running=True):
        """Shut down the default dask cluster and close its client.

        Does nothing when there is no default client. ``running`` is accepted
        for interface compatibility and is not used.
        """
        from dask.distributed import default_client

        try:
            client = default_client()
        except ValueError:
            # no default client exists - nothing to close
            return

        try:
            # shutdown the cluster first, then close the client
            client.shutdown()
            client.close()
        except ValueError:
            pass

    def get_status(self):
        """Return the cluster status, from the API when remote or from k8s otherwise."""
        meta = self.metadata
        selector = get_func_selector(meta.project, meta.name, meta.tag)
        if self._is_remote_api():
            db = self._get_db()
            return db.remote_status(meta.project, meta.name, self.kind, selector)

        status = get_obj_status(selector)
        print(status)
        return status

    def cluster(self):
        """Return the cluster object created by a local deploy (or None)."""
        return self._cluster

    def _remote_addresses(self):
        """Resolve the scheduler and dashboard addresses.

        :return: ``(addr, dash)`` - ``addr`` is ``status.scheduler_address`` unless
                 a remote host + NodePort scheduler port are available and
                 ``use_remote`` is set; ``dash`` is the dashboard ``host:port`` or
                 ``""`` when it cannot be determined
        """
        addr = self.status.scheduler_address
        dash = ""
        if config.remote_host:
            if self.spec.service_type == "NodePort":
                # node_ports may still be unset (None) if the status was not
                # populated yet - treat it as "no ports" instead of failing
                node_ports = self.status.node_ports or {}
                scheduler_port = node_ports.get("scheduler")
                dashboard_port = node_ports.get("dashboard")
                if self.use_remote and scheduler_port:
                    addr = f"{config.remote_host}:{scheduler_port}"
                if dashboard_port:
                    dash = f"{config.remote_host}:{dashboard_port}"
            else:
                logger.info("to get a dashboard link, use NodePort service_type")

        return addr, dash

    @property
    def client(self):
        """Return a dask client.

        For remote clusters the cluster is started when no scheduler address is
        known, and one reconnect (with a restart if the cluster is not running)
        is attempted when the first connection fails. Without a scheduler
        address the default client, or a new local ``Client()``, is returned.
        """
        from dask.distributed import Client, default_client

        if self.spec.remote and not self.status.scheduler_address:
            if not self._load_db_status():
                self._start()

        if self.status.scheduler_address:
            addr, dash = self._remote_addresses()
            logger.info(f"trying dask client at: {addr}")
            try:
                client = Client(addr)
            except OSError as exc:
                logger.warning(
                    f"remote scheduler at {addr} not ready, will try to restart {err_to_str(exc)}"
                )

                status = self.get_status()
                if status != "running":
                    self._start()
                # the addresses may have changed after a restart
                addr, dash = self._remote_addresses()
                client = Client(addr)

            logger.info(
                f"using remote dask scheduler ({self.status.cluster_name}) at: {addr}"
            )
            if dash:
                ipython_display(
                    f'<a href="http://{dash}/status" target="_blank" >dashboard link: {dash}</a>',
                    alt_text=f"remote dashboard: {dash}",
                )

            return client

        try:
            return default_client()
        except ValueError:
            return Client()

    def deploy(
        self,
        watch=True,
        with_mlrun=None,
        skip_deployed=False,
        is_kfp=False,
        mlrun_version_specifier=None,
        builder_env: dict = None,
        show_on_failure: bool = False,
    ):
        """deploy function, build container with dependencies

        :param watch:                   wait for the deploy to complete (and print build logs)
        :param with_mlrun:              add the current mlrun package to the container build
        :param skip_deployed:           skip the build if we already have an image for the function
        :param is_kfp:                  deploy as part of a kfp pipeline
        :param mlrun_version_specifier: which mlrun package version to include (if not current)
        :param builder_env:             Kaniko builder pod env vars dict (for config/credentials)
                                        e.g. builder_env={"GIT_TOKEN": token}
        :param show_on_failure:         show logs only in case of build failure

        :return: True if the function is ready (deployed)
        """
        return super().deploy(
            watch,
            with_mlrun,
            skip_deployed,
            is_kfp=is_kfp,
            mlrun_version_specifier=mlrun_version_specifier,
            builder_env=builder_env,
            show_on_failure=show_on_failure,
        )

    # TODO: Remove in 1.5.0
    @deprecated(
        version="1.3.0",
        reason="'Dask gpus' will be removed in 1.5.0, use 'with_scheduler_limits' / 'with_worker_limits' instead",
        category=FutureWarning,
    )
    def gpus(self, gpus, gpu_type="nvidia.com/gpu"):
        """Set the gpu limit on both the scheduler and the worker resources."""
        update_in(self.spec.scheduler_resources, ["limits", gpu_type], gpus)
        update_in(self.spec.worker_resources, ["limits", gpu_type], gpus)

    def with_limits(
        self,
        mem=None,
        cpu=None,
        gpus=None,
        gpu_type="nvidia.com/gpu",
        patch: bool = False,
    ):
        """Not supported for dask - limits are set per scheduler/worker."""
        raise NotImplementedError(
            "Use with_scheduler_limits/with_worker_limits to set resource limits",
        )

    def with_scheduler_limits(
        self,
        mem: str = None,
        cpu: str = None,
        gpus: int = None,
        gpu_type: str = "nvidia.com/gpu",
        patch: bool = False,
    ):
        """
        set scheduler pod resources limits
        by default it overrides the whole limits section, if you wish to patch specific resources use `patch=True`.
        """
        self.spec._verify_and_set_limits(
            "scheduler_resources", mem, cpu, gpus, gpu_type, patch=patch
        )

    def with_worker_limits(
        self,
        mem: str = None,
        cpu: str = None,
        gpus: int = None,
        gpu_type: str = "nvidia.com/gpu",
        patch: bool = False,
    ):
        """
        set worker pod resources limits
        by default it overrides the whole limits section, if you wish to patch specific resources use `patch=True`.
        """
        self.spec._verify_and_set_limits(
            "worker_resources", mem, cpu, gpus, gpu_type, patch=patch
        )

    def with_requests(self, mem=None, cpu=None, patch: bool = False):
        """Not supported for dask - requests are set per scheduler/worker."""
        raise NotImplementedError(
            "Use with_scheduler_requests/with_worker_requests to set resource requests",
        )

    def with_scheduler_requests(
        self, mem: str = None, cpu: str = None, patch: bool = False
    ):
        """
        set scheduler pod resources requests
        by default it overrides the whole requests section, if you wish to patch specific resources use `patch=True`.
        """
        self.spec._verify_and_set_requests("scheduler_resources", mem, cpu, patch=patch)

    def with_worker_requests(
        self, mem: str = None, cpu: str = None, patch: bool = False
    ):
        """
        set worker pod resources requests
        by default it overrides the whole requests section, if you wish to patch specific resources use `patch=True`.
        """
        self.spec._verify_and_set_requests("worker_resources", mem, cpu, patch=patch)

    def _run(self, runobj: RunObject, execution):
        """Execute the run handler in this process with a dask client on the context.

        :param runobj:    the run to execute; ``runobj.spec.handler`` is either a
                          function or the name of a handler inside ``spec.command``
        :param execution: not used by this implementation
        :raises ValueError: the handler is not a function and no command is set
        :return: the run context as a dict
        """
        handler = runobj.spec.handler
        self._force_handler(handler)

        # the handler runs in-process, so the runtime env goes into os.environ
        extra_env = self._generate_runtime_env(runobj)
        environ.update(extra_env)

        context = MLClientCtx.from_dict(
            runobj.to_dict(),
            rundb=self.spec.rundb,
            autocommit=False,
            host=socket.gethostname(),
        )
        if not inspect.isfunction(handler):
            if not self.spec.command:
                raise ValueError(
                    "specified handler (string) without command "
                    "(py file path), specify command or use handler pointer"
                )
            handler = load_module(self.spec.command, handler, context=context)

        # accessing the client property may start the cluster
        client = self.client
        setattr(context, "dask_client", client)
        sout, serr = exec_from_params(handler, runobj, context)
        log_std(self._db_conn, runobj, sout, serr, skip=self.is_child, show=False)
        return context.to_dict()
def deploy_function( function: DaskCluster, secrets=None, client_version: str = None, client_python_version: str = None, ): _validate_dask_related_libraries_installed() scheduler_pod, worker_pod, function, namespace = enrich_dask_cluster( function, secrets, client_version, client_python_version ) return initialize_dask_cluster(scheduler_pod, worker_pod, function, namespace) def initialize_dask_cluster(scheduler_pod, worker_pod, function, namespace): import dask import dask_kubernetes spec, meta = function.spec, function.metadata svc_temp = dask.config.get("kubernetes.scheduler-service-template") if spec.service_type or spec.node_port: if spec.node_port: spec.service_type = "NodePort" svc_temp["spec"]["ports"][1]["nodePort"] = spec.node_port update_in(svc_temp, "spec.type", spec.service_type) norm_name = normalize_name(meta.name) dask.config.set( { "kubernetes.scheduler-service-template": svc_temp, "kubernetes.name": "mlrun-" + norm_name + "-{uuid}", } ) cluster = dask_kubernetes.KubeCluster( worker_pod, scheduler_pod_template=scheduler_pod, deploy_mode="remote", namespace=namespace, idle_timeout=spec.scheduler_timeout, ) logger.info(f"cluster {cluster.name} started at {cluster.scheduler_address}") function.status.scheduler_address = cluster.scheduler_address function.status.cluster_name = cluster.name if spec.service_type == "NodePort": ports = cluster.scheduler.service.spec.ports function.status.node_ports = { "scheduler": ports[0].node_port, "dashboard": ports[1].node_port, } if spec.replicas: cluster.scale(spec.replicas) else: cluster.adapt(minimum=spec.min_replicas, maximum=spec.max_replicas) return cluster def enrich_dask_cluster( function, secrets, client_version: str = None, client_python_version: str = None ): from dask.distributed import Client, default_client # noqa: F401 from dask_kubernetes import KubeCluster, make_pod_spec # noqa: F401 from kubernetes import client # Is it possible that the function will not have a project at this point? 
if function.metadata.project: function._add_secrets_to_spec_before_running(project=function.metadata.project) spec = function.spec meta = function.metadata spec.remote = True image = ( function.full_image_path( client_version=client_version, client_python_version=client_python_version ) # TODO: we might never enter here, since running a function requires defining an image or "daskdev/dask:latest" ) env = spec.env namespace = meta.namespace or config.namespace if spec.extra_pip: env.append(spec.extra_pip) pod_labels = get_resource_labels(function, scrape_metrics=config.scrape_metrics) # TODO: 'dask-worker' is deprecated, new dask CLI was introduced in 2022.10.0. # Upgrade when we drop python 3.7 support and use 'dask worker' instead worker_args = ["dask-worker", "--nthreads", str(spec.nthreads)] memory_limit = spec.worker_resources.get("limits", {}).get("memory") if memory_limit: worker_args.extend(["--memory-limit", str(memory_limit)]) if spec.args: worker_args.extend(spec.args) # TODO: 'dask-scheduler' is deprecated, new dask CLI was introduced in 2022.10.0. 
# Upgrade when we drop python 3.7 support and use 'dask scheduler' instead scheduler_args = ["dask-scheduler"] container_kwargs = { "name": "base", "image": image, "env": env, "image_pull_policy": spec.image_pull_policy, "volume_mounts": spec.volume_mounts, } scheduler_container = client.V1Container( resources=spec.scheduler_resources, args=scheduler_args, **container_kwargs ) worker_container = client.V1Container( resources=spec.worker_resources, args=worker_args, **container_kwargs ) scheduler_pod_spec = kube_resource_spec_to_pod_spec(spec, scheduler_container) worker_pod_spec = kube_resource_spec_to_pod_spec(spec, worker_container) for pod_spec in [scheduler_pod_spec, worker_pod_spec]: if spec.image_pull_secret: pod_spec.image_pull_secrets = [ client.V1LocalObjectReference(name=spec.image_pull_secret) ] scheduler_pod = client.V1Pod( metadata=client.V1ObjectMeta(namespace=namespace, labels=pod_labels), # annotations=meta.annotation), spec=scheduler_pod_spec, ) worker_pod = client.V1Pod( metadata=client.V1ObjectMeta(namespace=namespace, labels=pod_labels), # annotations=meta.annotation), spec=worker_pod_spec, ) return scheduler_pod, worker_pod, function, namespace def _validate_dask_related_libraries_installed(): try: import dask # noqa: F401 from dask.distributed import Client, default_client # noqa: F401 from dask_kubernetes import KubeCluster, make_pod_spec # noqa: F401 from kubernetes import client # noqa: F401 except ImportError as exc: print( "missing dask or dask_kubernetes, please run " '"pip install dask distributed dask_kubernetes", %s', exc, ) raise exc def get_obj_status(selector=None, namespace=None): if selector is None: selector = [] import mlrun.api.utils.singletons.k8s k8s = mlrun.api.utils.singletons.k8s.get_k8s_helper() namespace = namespace or config.namespace selector = ",".join(["dask.org/component=scheduler"] + selector) pods = k8s.list_pods(namespace, selector=selector) status = "" for pod in pods: status = pod.status.phase.lower() if 
status == "running": cluster = pod.metadata.labels.get("dask.org/cluster-name") logger.info( f"found running dask function {pod.metadata.name}, cluster={cluster}" ) return status logger.info( f"found dask function {pod.metadata.name} in non ready state ({status})" ) return status