Source code for mlrun.serving.states

# Copyright 2023 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = [
    "TaskStep",
    "RouterStep",
    "RootFlowStep",
    "ErrorStep",
    "MonitoringApplicationStep",
]

import inspect
import os
import pathlib
import traceback
from abc import ABC
from copy import copy, deepcopy
from inspect import getfullargspec, signature
from typing import Any, Optional, Union, cast

import storey.utils
from storey import ParallelExecutionMechanisms

import mlrun
import mlrun.artifacts
import mlrun.common.schemas as schemas
from mlrun.artifacts.llm_prompt import LLMPromptArtifact, PlaceholderDefaultDict
from mlrun.artifacts.model import ModelArtifact
from mlrun.datastore.datastore_profile import (
    DatastoreProfileKafkaSource,
    DatastoreProfileKafkaTarget,
    DatastoreProfileV3io,
    datastore_profile_read,
)
from mlrun.datastore.model_provider.model_provider import ModelProvider
from mlrun.datastore.storeytargets import KafkaStoreyTarget, StreamStoreyTarget
from mlrun.utils import get_data_from_path, logger, split_path

from ..config import config
from ..datastore import get_stream_pusher
from ..datastore.utils import (
    get_kafka_brokers_from_dict,
    parse_kafka_url,
)
from ..errors import MLRunInvalidArgumentError, ModelRunnerError, err_to_str
from ..model import ModelObj, ObjectDict
from ..platforms.iguazio import parse_path
from ..utils import get_class, get_function, is_explicit_ack_supported
from .utils import StepToDict, _extract_input_data, _update_result_body

callable_prefix = "_"
path_splitter = "/"
previous_step = "$prev"
queue_class_names = [">>", "$queue"]

MAX_MODELS_PER_ROUTER = 5000


class GraphError(Exception):
    """error in graph topology or configuration"""

    pass


class StepKinds:
    router = "router"
    task = "task"
    flow = "flow"
    queue = "queue"
    choice = "choice"
    root = "root"
    error_step = "error_step"
    monitoring_application = "monitoring_application"
    model_runner = "model_runner"


_task_step_fields = [
    "kind",
    "class_name",
    "class_args",
    "handler",
    "skip_context",
    "after",
    "function",
    "comment",
    "shape",
    "full_event",
    "on_error",
    "responder",
    "input_path",
    "result_path",
    "model_endpoint_creation_strategy",
    "endpoint_type",
]

_default_fields_to_strip_from_step = [
    "model_endpoint_creation_strategy",
    "endpoint_type",
]


def new_remote_endpoint(
    url: str,
    creation_strategy: schemas.ModelEndpointCreationStrategy,
    endpoint_type: schemas.EndpointType,
    **class_args,
):
    class_args = deepcopy(class_args)
    class_args["url"] = url
    return TaskStep(
        "$remote",
        class_args=class_args,
        model_endpoint_creation_strategy=creation_strategy,
        endpoint_type=endpoint_type,
    )


[docs] class BaseStep(ModelObj): kind = "BaseStep" default_shape = "ellipse" _dict_fields = ["kind", "comment", "after", "on_error"] _default_fields_to_strip = _default_fields_to_strip_from_step def __init__( self, name: Optional[str] = None, after: Optional[list] = None, shape: Optional[str] = None, ): self.name = name self._parent = None self.comment = None self.context = None self.after = after or [] self._next = None self.shape = shape self.on_error = None self._on_error_handler = None self.model_endpoint_creation_strategy = ( schemas.ModelEndpointCreationStrategy.SKIP ) def get_shape(self): """graphviz shape""" return self.shape or self.default_shape def set_parent(self, parent): """set/link the step parent (flow/router)""" self._parent = parent @property def next(self): return self._next @property def parent(self): """step parent (flow/router)""" return self._parent def set_next(self, key: str): """set/insert the key as next after this step, optionally remove other keys""" if not self.next: self._next = [key] elif key not in self.next: self._next.append(key) return self def after_step(self, *after, append=True): """specify the previous step names""" # add new steps to the after list if not append: self.after = [] for name in after: # if its a step/task class (vs a str) extract its name name = name if isinstance(name, str) else name.name if name not in self.after: self.after.append(name) return self
[docs] def error_handler( self, name: Optional[str] = None, class_name=None, handler=None, before=None, function=None, full_event: Optional[bool] = None, input_path: Optional[str] = None, result_path: Optional[str] = None, **class_args, ): """set error handler on a step or the entire graph (to be executed on failure/raise) When setting the error_handler on the graph object, the graph completes after the error handler execution. example: in the below example, an 'error_catcher' step is set as the error_handler of the 'raise' step: in case of error/raise in 'raise' step, the handle_error will be run. after that, the 'echo' step will be run. graph = function.set_topology('flow', engine='async') graph.to(name='raise', handler='raising_step')\ .error_handler(name='error_catcher', handler='handle_error', full_event=True, before='echo') graph.add_step(name="echo", handler='echo', after="raise").respond() :param name: unique name (and path) for the error handler step, default is class name :param class_name: class name or step object to build the step from the error handler step is derived from task step (ie no router/queue functionally) :param handler: class/function handler to invoke on run/event :param before: string or list of next step(s) names that will run after this step. the `before` param must not specify upstream steps as it will cause a loop. if `before` is not specified, the graph will complete after the error handler execution. :param function: function this step should run in :param full_event: this step accepts the full event (not just the body) :param input_path: selects the key/path in the event to use as input to the step this requires that the event body will behave like a dict, for example: event: {"data": {"a": 5, "b": 7}}, input_path="data.b" means the step will receive 7 as input :param result_path: selects the key/path in the event to write the results to this requires that the event body will behave like a dict, for example: event: {"x": 5} , result_path="y" means the output of the step will be written to event["y"] resulting in {"x": 5, "y": <result>} :param class_args: class init arguments """ if not (class_name or handler): raise MLRunInvalidArgumentError("class_name or handler must be provided") if isinstance(self, RootFlowStep) and before: raise MLRunInvalidArgumentError( "`before` arg can't be specified for graph error handler" ) name = get_name(name, class_name) step = ErrorStep( class_name, class_args, handler, name=name, function=function, full_event=full_event, input_path=input_path, result_path=result_path, ) self.on_error = name before = [before] if isinstance(before, str) else before step.before = before or [] step.base_step = self.name if hasattr(self, "_parent") and self._parent: # when self is a step step = self._parent._steps.update(name, step) step.set_parent(self._parent) else: # when self is the graph step = self._steps.update(name, step) step.set_parent(self) return self
def init_object(self, context, namespace, mode="sync", reset=False, **extra_kwargs): """init the step class""" self.context = context def _is_local_function(self, context): return True def get_children(self): """get child steps (for router/flow)""" return [] def __iter__(self): yield from [] @property def fullname(self): """full path/name (include parents)""" name = self.name or "" if self._parent and self._parent.fullname: name = path_splitter.join([self._parent.fullname, name]) return name.replace(":", "_") # replace for graphviz escaping def _post_init(self, mode="sync"): pass def _set_error_handler(self): """init/link the error handler for this step""" if self.on_error: error_step = self.context.root.path_to_step(self.on_error) self._on_error_handler = error_step.run def _log_error(self, event, err, **kwargs): """on failure log (for sync mode)""" error_message = err_to_str(err) self.context.logger.error( f"step {self.name} got error {error_message} when processing an event:\n {event.body}" ) error_trace = traceback.format_exc() self.context.logger.error(error_trace) self.context.push_error( event, f"{error_message}\n{error_trace}", source=self.fullname, **kwargs ) def _call_error_handler(self, event, err, **kwargs): """call the error handler if exist""" if not event.error: event.error = {} event.error[self.name] = err_to_str(err) event.origin_state = self.fullname return self._on_error_handler(event) def path_to_step(self, path: str): """return step object from step relative/fullname""" path = path or "" tree = path.split(path_splitter) next_level = self for step in tree: if step not in next_level: raise GraphError( f"step {step} doesnt exist in the graph under {next_level.fullname}" ) next_level = next_level[step] return next_level
[docs] def to( self, class_name: Union[str, StepToDict] = None, name: Optional[str] = None, handler: Optional[str] = None, graph_shape: Optional[str] = None, function: Optional[str] = None, full_event: Optional[bool] = None, input_path: Optional[str] = None, result_path: Optional[str] = None, model_endpoint_creation_strategy: Optional[ schemas.ModelEndpointCreationStrategy ] = None, **class_args, ): """add a step right after this step and return the new step example: a 4-step pipeline ending with a stream: graph.to('URLDownloader')\ .to('ToParagraphs')\ .to(name='to_json', handler='json.dumps')\ .to('>>', 'to_v3io', path=stream_path)\ :param class_name: class name or step object to build the step from for router steps the class name should start with '*' for queue/stream step the class should be '>>' or '$queue' :param name: unique name (and path) for the child step, default is class name :param handler: class/function handler to invoke on run/event :param graph_shape: graphviz shape name :param function: function this step should run in :param full_event: this step accepts the full event (not just body) :param input_path: selects the key/path in the event to use as input to the step this requires that the event body will behave like a dict, example: event: {"data": {"a": 5, "b": 7}}, input_path="data.b" means the step will receive 7 as input :param result_path: selects the key/path in the event to write the results to this require that the event body will behave like a dict, example: event: {"x": 5} , result_path="y" means the output of the step will be written to event["y"] resulting in {"x": 5, "y": <result>} :param model_endpoint_creation_strategy: Strategy for creating or updating the model endpoint: * **overwrite**: 1. If model endpoints with the same name exist, delete the `latest` one. 2. Create a new model endpoint entry and set it as `latest`. * **inplace** (default): 1. If model endpoints with the same name exist, update the `latest` entry. 2. Otherwise, create a new entry. * **archive**: 1. If model endpoints with the same name exist, preserve them. 2. Create a new model endpoint with the same name and set it to `latest`. :param class_args: class init arguments """ if hasattr(self, "steps"): parent = self elif self._parent: parent = self._parent else: raise GraphError( f"step {self.name} parent is not set or it's not part of a graph" ) if not name and isinstance(class_name, BaseStep): name = class_name.name name, step = params_to_step( class_name, name, handler, graph_shape=graph_shape, function=function, full_event=full_event, input_path=input_path, result_path=result_path, class_args=class_args, model_endpoint_creation_strategy=model_endpoint_creation_strategy, ) # Make sure model endpoint was not introduce in ModelRunnerStep self.check_model_endpoint_existence(step, model_endpoint_creation_strategy) self.verify_model_runner_step(step) step = parent._steps.update(name, step) step.set_parent(parent) if not hasattr(self, "steps"): # check that its not the root, todo: in future may gave nested flows step.after_step(self.name) parent._last_added = step return step
[docs] def set_flow( self, steps: list[Union[str, StepToDict, dict[str, Any]]], force: bool = False, ): """ Set list of steps as downstream from this step, in the order specified. This will overwrite any existing downstream steps. :param steps: list of steps to follow this one :param force: whether to overwrite existing downstream steps. If False, this method will fail if any downstream steps have already been defined. Defaults to False. :return: the last step added to the flow example:: The below code sets the downstream nodes of step1 by using a list of steps (provided to `set_flow()`) and a single step (provided to `to()`), resulting in the graph (step1 -> step2 -> step3 -> step4). Notice that using `force=True` is required in case step1 already had downstream nodes (e.g. if the existing graph is step1 -> step2_old) and that following the execution of this code the existing downstream steps are removed. If the intention is to split the graph (and not to overwrite), please use `to()`. step1.set_flow( [ dict(name="step2", handler="step2_handler"), dict(name="step3", class_name="Step3Class"), ], force=True, ).to(dict(name="step4", class_name="Step4Class")) """ raise NotImplementedError("set_flow() can only be called on a FlowStep")
def supports_termination(self): return False def check_model_endpoint_existence(self, step, model_endpoint_creation_strategy): """ Verify that model endpoint name is not duplicate, in flow graph. :param step: ModelRunnerStep to verify :param model_endpoint_creation_strategy: model_endpoint_creation_strategy: Strategy for creating or updating the model endpoint: """ if ( isinstance(step, TaskStep) and not isinstance(step, ModelRunnerStep) and model_endpoint_creation_strategy != schemas.ModelEndpointCreationStrategy.SKIP ): root = self._extract_root_step() if not isinstance(root, RootFlowStep): return models = [] if isinstance(step, RouterStep): for route in step.routes.values(): if route.name in root.model_endpoints_names: raise GraphError( f"The graph already contains the model endpoints named - {route.name}." ) models.append(route.name) else: if step.name in root.model_endpoints_names: raise GraphError( f"The graph already contains the model endpoints named - {step.name}." ) models.append(step.name) root.update_model_endpoints_routes_names(models) return def _extract_root_step(self): root = self while root.parent is not None: root = root.parent return root def verify_model_runner_step( self, step: "ModelRunnerStep", step_model_endpoints_names: Optional[list[str]] = None, verify_shared_models: bool = True, ): """ Verify ModelRunnerStep, can be part of Flow graph and models can not repeat in graph. :param step: ModelRunnerStep to verify :param step_model_endpoints_names: List of model endpoints names that are in the step. if provided will ignore step models and verify only the models on list. :param verify_shared_models: If True, verify that shared models are defined in the graph. """ if not isinstance(step, ModelRunnerStep): return root = self._extract_root_step() if not isinstance(root, RootFlowStep): raise GraphError( "ModelRunnerStep can be added to 'Flow' topology graph only" ) step_model_endpoints_names = step_model_endpoints_names or list( step.class_args.get(schemas.ModelRunnerStepData.MODELS, {}).keys() ) # Get all model_endpoints names that are in both lists common_endpoints_names = list( set(root.model_endpoints_names) & set(step_model_endpoints_names) ) or list( set(root.model_endpoints_routes_names) & set(step_model_endpoints_names) ) if common_endpoints_names: raise GraphError( f"The graph already contains the model endpoints named - {common_endpoints_names}." ) if verify_shared_models: # Check if shared models are defined in the graph self._verify_shared_models(root, step, step_model_endpoints_names) # Update model endpoints names in the root step root.update_model_endpoints_names(step_model_endpoints_names) @staticmethod def _verify_shared_models( root: "RootFlowStep", step: "ModelRunnerStep", step_model_endpoints_names: list[str], ) -> None: proxy_endpoints = [ name for name in step_model_endpoints_names if step.class_args.get( schemas.ModelRunnerStepData.MODEL_TO_EXECUTION_MECHANISM, {} ).get(name) == ParallelExecutionMechanisms.shared_executor ] shared_models = [] for name in proxy_endpoints: shared_runnable_name = ( step.class_args.get(schemas.ModelRunnerStepData.MODELS, {}) .get(name, ["", {}])[schemas.ModelsData.MODEL_PARAMETERS.value] .get("shared_runnable_name") ) model_artifact_uri = ( step.class_args.get(schemas.ModelRunnerStepData.MODELS, {}) .get(name, ["", {}])[schemas.ModelsData.MODEL_PARAMETERS.value] .get("artifact_uri") ) prefix, _ = mlrun.datastore.parse_store_uri(model_artifact_uri) # if the model artifact is a prompt, we need to get the model URI # to ensure that the shared runnable name is correct if prefix == mlrun.utils.StorePrefix.LLMPrompt: llm_artifact, _ = mlrun.store_manager.get_store_artifact( model_artifact_uri ) model_artifact_uri = mlrun.utils.remove_tag_from_artifact_uri( llm_artifact.spec.parent_uri ) actual_shared_name = root.get_shared_model_name_by_artifact_uri( model_artifact_uri ) if not shared_runnable_name: if not actual_shared_name: raise GraphError( f"Can't find shared model for {name} model endpoint" ) else: step.class_args[schemas.ModelRunnerStepData.MODELS][name][ schemas.ModelsData.MODEL_PARAMETERS.value ]["shared_runnable_name"] = actual_shared_name shared_models.append(actual_shared_name) elif actual_shared_name != shared_runnable_name: raise GraphError( f"Model endpoint {name} shared runnable name mismatch: " f"expected {actual_shared_name}, got {shared_runnable_name}" ) else: shared_models.append(actual_shared_name) undefined_shared_models = list( set(shared_models) - set(root.shared_models.keys()) ) if undefined_shared_models: raise GraphError( f"The following shared models are not defined in the graph: {undefined_shared_models}." )
[docs] class TaskStep(BaseStep): """task execution step, runs a class or handler""" kind = "task" _dict_fields = _task_step_fields _default_class = "" def __init__( self, class_name: Optional[Union[str, type]] = None, class_args: Optional[dict] = None, handler: Optional[str] = None, name: Optional[str] = None, after: Optional[list] = None, full_event: Optional[bool] = None, function: Optional[str] = None, responder: Optional[bool] = None, input_path: Optional[str] = None, result_path: Optional[str] = None, model_endpoint_creation_strategy: Optional[ schemas.ModelEndpointCreationStrategy ] = schemas.ModelEndpointCreationStrategy.SKIP, endpoint_type: Optional[schemas.EndpointType] = schemas.EndpointType.NODE_EP, ): super().__init__(name, after) self.class_name = class_name self.class_args = class_args or {} self.handler = handler self.function = function self._handler = None self._object = None self._async_object = None self.skip_context = None self.context = None self._class_object = None self.responder = responder self.full_event = full_event self.input_path = input_path self.result_path = result_path self.on_error = None self._inject_context = False self._call_with_event = False self.model_endpoint_creation_strategy = model_endpoint_creation_strategy self.endpoint_type = endpoint_type
[docs] def init_object(self, context, namespace, mode="sync", reset=False, **extra_kwargs): self.context = context self._async_object = None if not self._is_local_function(context): # skip init of non local functions return if self.handler and not self.class_name: # link to function if callable(self.handler): self._handler = self.handler self.handler = self.handler.__name__ else: self._handler = get_function(self.handler, namespace) args = signature(self._handler).parameters if args and "context" in list(args.keys()): self._inject_context = True self._set_error_handler() return self._class_object, self.class_name = self.get_step_class_object( namespace=namespace ) if not self._object or reset: # init the step class + args extracted_class_args = self.get_full_class_args( namespace=namespace, class_object=self._class_object, **extra_kwargs, ) try: self._object = self._class_object(**extracted_class_args) except TypeError as exc: raise TypeError( f"failed to init step {self.name}\n args={self.class_args}" ) from exc # determine the right class handler to use handler = self.handler if handler: if not hasattr(self._object, handler): raise GraphError( f"handler ({handler}) specified but doesnt exist in class {self.class_name}" ) else: if hasattr(self._object, "do_event"): handler = "do_event" self._call_with_event = True elif hasattr(self._object, "do"): handler = "do" if handler: self._handler = getattr(self._object, handler, None) self._set_error_handler() if mode != "skip": self._post_init(mode)
[docs] def get_full_class_args(self, namespace, class_object, **extra_kwargs): class_args = {} for key, arg in self.class_args.items(): if key.startswith(callable_prefix): class_args[key[1:]] = get_function(arg, namespace) else: class_args[key] = arg class_args.update(extra_kwargs) if not isinstance(self, MonitoringApplicationStep): # add common args (name, context, ..) only if target class can accept them argspec = getfullargspec(class_object) for key in ["name", "context", "input_path", "result_path", "full_event"]: if argspec.varkw or key in argspec.args: class_args[key] = getattr(self, key) if argspec.varkw or "graph_step" in argspec.args: class_args["graph_step"] = self return class_args
[docs] def get_step_class_object(self, namespace): class_name = self.class_name class_object = self._class_object if isinstance(class_name, type): class_object = class_name class_name = class_name.__name__ elif not class_object: if class_name == "$remote": from mlrun.serving.remote import RemoteStep class_object = RemoteStep else: class_object = get_class(class_name or self._default_class, namespace) return class_object, class_name
def _is_local_function(self, context): # detect if the class is local (and should be initialized) current_function = get_current_function(context) if current_function == "*": return True if not self.function and not current_function: return True if ( self.function and self.function == "*" ) or self.function == current_function: return True return False @property def async_object(self): """return the sync or async (storey) class instance""" return self._async_object or self._object
[docs] def clear_object(self): self._object = None
def _post_init(self, mode="sync"): if self._object and hasattr(self._object, "post_init"): self._object.post_init( mode, creation_strategy=self.model_endpoint_creation_strategy, endpoint_type=self.endpoint_type, )
[docs] def respond(self): """mark this step as the responder. step output will be returned as the flow result, no other step can follow """ self.responder = True return self
[docs] def run(self, event, *args, **kwargs): """run this step, in async flows the run is done through storey""" if not self._is_local_function(self.context): # todo invoke remote via REST call return event if self.context and self.context.verbose: self.context.logger.info(f"step {self.name} got event {event.body}") # inject context parameter if it is expected by the handler if self._inject_context: kwargs["context"] = self.context elif kwargs and "context" in kwargs: del kwargs["context"] try: if self.full_event or self._call_with_event: return self._handler(event, *args, **kwargs) if self._handler is None: raise MLRunInvalidArgumentError( f"step {self.name} does not have a handler" ) result = self._handler( _extract_input_data(self.input_path, event.body), *args, **kwargs ) event.body = _update_result_body(self.result_path, event.body, result) except Exception as exc: if self._on_error_handler: self._log_error(event, exc) result = self._call_error_handler(event, exc) event.body = _update_result_body(self.result_path, event.body, result) else: raise exc return event
[docs] def to_dict( self, fields: Optional[list] = None, exclude: Optional[list] = None, strip: bool = False, ) -> dict: self.endpoint_type = ( self.endpoint_type.value if isinstance(self.endpoint_type, schemas.EndpointType) else self.endpoint_type ) self.model_endpoint_creation_strategy = ( self.model_endpoint_creation_strategy.value if isinstance( self.model_endpoint_creation_strategy, schemas.ModelEndpointCreationStrategy, ) else self.model_endpoint_creation_strategy ) return super().to_dict(fields, exclude, strip)
[docs] class MonitoringApplicationStep(TaskStep): """monitoring application execution step, runs users class code""" kind = "monitoring_application" _default_class = "" def __init__( self, class_name: Optional[Union[str, type]] = None, class_args: Optional[dict] = None, handler: Optional[str] = None, name: Optional[str] = None, after: Optional[list] = None, full_event: Optional[bool] = None, function: Optional[str] = None, responder: Optional[bool] = None, input_path: Optional[str] = None, result_path: Optional[str] = None, ): super().__init__( class_name=class_name, class_args=class_args, handler=handler, name=name, after=after, full_event=full_event, function=function, responder=responder, input_path=input_path, result_path=result_path, )
[docs] class ErrorStep(TaskStep): """error execution step, runs a class or handler""" kind = "error_step" _dict_fields = _task_step_fields + ["before", "base_step"] _default_class = "" def __init__( self, class_name: Optional[Union[str, type]] = None, class_args: Optional[dict] = None, handler: Optional[str] = None, name: Optional[str] = None, after: Optional[list] = None, full_event: Optional[bool] = None, function: Optional[str] = None, responder: Optional[bool] = None, input_path: Optional[str] = None, result_path: Optional[str] = None, ): super().__init__( class_name=class_name, class_args=class_args, handler=handler, name=name, after=after, full_event=full_event, function=function, responder=responder, input_path=input_path, result_path=result_path, ) self.before = None self.base_step = None
[docs] class RouterStep(TaskStep): """router step, implement routing logic for running child routes""" kind = "router" default_shape = "doubleoctagon" _dict_fields = _task_step_fields + ["routes", "name"] _default_class = "mlrun.serving.ModelRouter" def __init__( self, class_name: Optional[Union[str, type]] = None, class_args: Optional[dict] = None, handler: Optional[str] = None, routes: Optional[list] = None, name: Optional[str] = None, function: Optional[str] = None, input_path: Optional[str] = None, result_path: Optional[str] = None, ): super().__init__( class_name, class_args, handler, name=get_name(name, class_name or RouterStep.kind), function=function, input_path=input_path, result_path=result_path, ) self._routes: ObjectDict = None self.routes = routes self.endpoint_type = schemas.EndpointType.ROUTER if isinstance(class_name, type): class_name = class_name.__name__ self.model_endpoint_creation_strategy = ( schemas.ModelEndpointCreationStrategy.INPLACE if class_name and "VotingEnsemble" in class_name else schemas.ModelEndpointCreationStrategy.SKIP )
[docs] def get_children(self): """get child steps (routes)""" return self._routes.values()
@property def routes(self): """child routes/steps, traffic is routed to routes based on router logic""" return self._routes @routes.setter def routes(self, routes: dict): self._routes = ObjectDict.from_dict(classes_map, routes, "task")
[docs] def add_route( self, key, route=None, class_name=None, handler=None, function=None, creation_strategy: schemas.ModelEndpointCreationStrategy = schemas.ModelEndpointCreationStrategy.INPLACE, **class_args, ): """add child route step or class to the router, if key exists it will be updated :param key: unique name (and route path) for the child step :param route: child step object (Task, ..) :param class_name: class name to build the route step from (when route is not provided) :param class_args: class init arguments :param handler: class handler to invoke on run/event :param function: function this step should run in :param creation_strategy: Strategy for creating or updating the model endpoint: * **overwrite**: 1. If model endpoints with the same name exist, delete the `latest` one. 2. Create a new model endpoint entry and set it as `latest`. * **inplace** (default): 1. If model endpoints with the same name exist, update the `latest` entry. 2. Otherwise, create a new entry. * **archive**: 1. If model endpoints with the same name exist, preserve them. 2. Create a new model endpoint with the same name and set it to `latest`. """ if len(self.routes.keys()) >= MAX_MODELS_PER_ROUTER and key not in self.routes: raise mlrun.errors.MLRunModelLimitExceededError( f"Router cannot support more than {MAX_MODELS_PER_ROUTER} model endpoints. " f"To add a new route, edit an existing one by passing the same key." ) if key in self.routes: logger.info(f"Model {key} already exists, updating it.") if not route and not class_name and not handler: raise MLRunInvalidArgumentError("route or class_name must be specified") if not route: route = TaskStep( class_name, class_args, name=key, handler=handler, model_endpoint_creation_strategy=creation_strategy, endpoint_type=schemas.EndpointType.LEAF_EP if self.class_name and "serving.VotingEnsemble" in self.class_name else schemas.EndpointType.NODE_EP, ) route.function = function or route.function self.check_model_endpoint_existence(route, creation_strategy) route = self._routes.update(key, route) route.set_parent(self) return route
[docs] def clear_children(self, routes: list): """clear child steps (routes)""" if not routes: routes = self._routes.keys() for key in routes: del self._routes[key]
[docs] def init_object(self, context, namespace, mode="sync", reset=False, **extra_kwargs): if not self.routes: raise mlrun.errors.MLRunRuntimeError( "You have to add models to the router step before initializing it" ) if not self._is_local_function(context): return self.class_args = self.class_args or {} super().init_object( context, namespace, "skip", reset=reset, routes=self._routes, **extra_kwargs ) for route in self._routes.values(): if self.function and not route.function: # if the router runs on a child function and the # model function is not specified use the router function route.function = self.function route.set_parent(self) route.init_object(context, namespace, mode, reset=reset) self._set_error_handler() self._post_init(mode)
def __getitem__(self, name): return self._routes[name] def __setitem__(self, name, route): self.add_route(name, route) def __delitem__(self, key): del self._routes[key] def __iter__(self): yield from self._routes.keys()
[docs] def plot(self, filename=None, format=None, source=None, **kw): """plot/save graph using graphviz :param filename: target filepath for the image (None for the notebook) :param format: The output format used for rendering (``'pdf'``, ``'png'``, etc.) :param source: source step to add to the graph :param kw: kwargs passed to graphviz, e.g. rankdir="LR" (see: https://graphviz.org/doc/info/attrs.html) :return: graphviz graph object """ return _generate_graphviz( self, _add_graphviz_router, filename, format, source=source, **kw )
[docs] class Model(storey.ParallelExecutionRunnable, ModelObj): _dict_fields = [ "name", "raise_exception", "artifact_uri", "shared_runnable_name", "shared_proxy_mapping", ] kind = "model" def __init__( self, name: str, raise_exception: bool = True, artifact_uri: Optional[str] = None, shared_proxy_mapping: Optional[dict] = None, **kwargs, ): super().__init__(name=name, raise_exception=raise_exception, **kwargs) if artifact_uri is not None and not isinstance(artifact_uri, str): raise MLRunInvalidArgumentError("'artifact_uri' argument must be a string") self.artifact_uri = artifact_uri self.shared_proxy_mapping: dict[ str : Union[str, ModelArtifact, LLMPromptArtifact] ] = shared_proxy_mapping self.invocation_artifact: Optional[LLMPromptArtifact] = None self.model_artifact: Optional[ModelArtifact] = None self.model_provider: Optional[ModelProvider] = None def __init_subclass__(cls): super().__init_subclass__() cls._dict_fields = list( set(cls._dict_fields) | set(inspect.signature(cls.__init__).parameters.keys()) ) cls._dict_fields.remove("self")
[docs] def load(self) -> None: """Override to load model if needed.""" self._load_artifacts() if self.model_artifact: self.model_provider = mlrun.get_model_provider( url=self.model_artifact.model_url, default_invoke_kwargs=self.model_artifact.default_config, raise_missing_schema_exception=False, )
def _load_artifacts(self) -> None: artifact = self._get_artifact_object() if isinstance(artifact, LLMPromptArtifact): self.invocation_artifact = artifact self.model_artifact = self.invocation_artifact.model_artifact else: self.model_artifact = artifact def _get_artifact_object( self, proxy_uri: Optional[str] = None ) -> Union[ModelArtifact, LLMPromptArtifact, None]: uri = proxy_uri or self.artifact_uri if uri: if mlrun.datastore.is_store_uri(uri): artifact, _ = mlrun.store_manager.get_store_artifact(uri) return artifact else: raise ValueError( "Could not get artifact, 'artifact_uri' must be a valid artifact store URI" ) else: return None
[docs] def init(self): self.load()
[docs] def predict(self, body: Any, **kwargs) -> Any: """Override to implement prediction logic. If the logic requires asyncio, override predict_async() instead.""" return body
[docs] async def predict_async(self, body: Any, **kwargs) -> Any: """Override to implement prediction logic if the logic requires asyncio.""" return body
[docs] def run(self, body: Any, path: str, origin_name: Optional[str] = None) -> Any: return self.predict(body)
[docs] async def run_async( self, body: Any, path: str, origin_name: Optional[str] = None ) -> Any: return await self.predict_async(body)
[docs] def get_local_model_path(self, suffix="") -> (str, dict): """ Get local model file(s) and extra data items by using artifact. If the model file is stored in remote cloud storage, this method downloads it to the local file system. :param suffix: Optional; model file suffix (used when the model path is a directory). :type suffix: str :return: A tuple containing: - str: Local model file path. - dict: Dictionary of extra data items. :rtype: tuple :example: def load(self): model_file, extra_data = self.get_local_model_path(suffix=".pkl") self.model = load(open(model_file, "rb")) categories = extra_data["categories"].as_df() """ artifact = self._get_artifact_object() if artifact: model_file, _, extra_dataitems = mlrun.artifacts.get_model( suffix=suffix, model_dir=artifact ) return model_file, extra_dataitems return None, None
[docs] class LLModel(Model): def __init__( self, name: str, input_path: Optional[Union[str, list[str]]], **kwargs ): super().__init__(name, **kwargs) self._input_path = split_path(input_path)
[docs] def predict( self, body: Any, messages: Optional[list[dict]] = None, model_configuration: Optional[dict] = None, **kwargs, ) -> Any: if isinstance( self.invocation_artifact, mlrun.artifacts.LLMPromptArtifact ) and isinstance(self.model_provider, ModelProvider): body["result"] = self.model_provider.invoke( messages=messages, as_str=True, **(model_configuration or {}), ) return body
[docs] async def predict_async( self, body: Any, messages: Optional[list[dict]] = None, model_configuration: Optional[dict] = None, **kwargs, ) -> Any: if isinstance( self.invocation_artifact, mlrun.artifacts.LLMPromptArtifact ) and isinstance(self.model_provider, ModelProvider): body["result"] = await self.model_provider.async_invoke( messages=messages, as_str=True, **(model_configuration or {}), ) return body
[docs] def run(self, body: Any, path: str, origin_name: Optional[str] = None) -> Any: messages, model_configuration = self.enrich_prompt(body, origin_name) return self.predict( body, messages=messages, model_configuration=model_configuration )
[docs] async def run_async( self, body: Any, path: str, origin_name: Optional[str] = None ) -> Any: messages, model_configuration = self.enrich_prompt(body, origin_name) return await self.predict_async( body, messages=messages, model_configuration=model_configuration )
[docs] def enrich_prompt( self, body: dict, origin_name: str ) -> Union[tuple[list[dict], dict], tuple[None, None]]: if origin_name and self.shared_proxy_mapping: llm_prompt_artifact = self.shared_proxy_mapping.get(origin_name) if isinstance(llm_prompt_artifact, str): llm_prompt_artifact = self._get_artifact_object(llm_prompt_artifact) self.shared_proxy_mapping[origin_name] = llm_prompt_artifact else: llm_prompt_artifact = ( self.invocation_artifact or self._get_artifact_object() ) if not ( llm_prompt_artifact and isinstance(llm_prompt_artifact, LLMPromptArtifact) ): logger.warning( "LLMModel must be provided with LLMPromptArtifact", llm_prompt_artifact=llm_prompt_artifact, ) return None, None prompt_legend = llm_prompt_artifact.spec.prompt_legend prompt_template = deepcopy(llm_prompt_artifact.read_prompt()) input_data = copy(get_data_from_path(self._input_path, body)) if isinstance(input_data, dict): kwargs = ( { place_holder: input_data.get(body_map["field"]) for place_holder, body_map in prompt_legend.items() } if prompt_legend else {} ) input_data.update(kwargs) default_place_holders = PlaceholderDefaultDict(lambda: None, input_data) for message in prompt_template: try: message["content"] = message["content"].format(**input_data) except KeyError as e: logger.warning( "Input data was missing a placeholder, placeholder stay unformatted", key_error=e, ) message["content"] = message["content"].format_map( default_place_holders ) else: logger.warning( f"Expected input data to be a dict, but received input data from type {type(input_data)} prompt " f"template stay unformatted", ) return prompt_template, llm_prompt_artifact.spec.model_configuration
[docs] class ModelSelector: """Used to select which models to run on each event."""
[docs] def select( self, event, available_models: list[Model] ) -> Union[list[str], list[Model]]: """ Given an event, returns a list of model names or a list of model objects to run on the event. If None is returned, all models will be run. :param event: The full event :param available_models: List of available models """ pass
[docs] class ModelRunner(storey.ParallelExecution): """ Runs multiple Models on each event. See ModelRunnerStep. :param model_selector: ModelSelector instance whose select() method will be used to select models to run on each event. Optional. If not passed, all models will be run. """ def __init__( self, *args, context, model_selector: Optional[ModelSelector] = None, **kwargs ): super().__init__(*args, **kwargs) self.model_selector = model_selector or ModelSelector() self.context = context
[docs] def preprocess_event(self, event): if not hasattr(event, "_metadata"): event._metadata = {} event._metadata["model_runner_name"] = self.name event._metadata["inputs"] = deepcopy(event.body) return event
[docs] def select_runnables(self, event): models = cast(list[Model], self.runnables) return self.model_selector.select(event, models)
[docs] class MonitoredStep(ABC, TaskStep, StepToDict): kind = "monitored" _dict_fields = TaskStep._dict_fields + ["raise_exception"] def __init__(self, *args, name: str, raise_exception=True, **kwargs): super().__init__(*args, name=name, **kwargs) self.raise_exception = raise_exception self._monitoring_data = None
[docs] def _calculate_monitoring_data(self) -> dict[str, Any]: """ Child class must override `_calculate_monitoring_data()` method and provide meaningful data-structure to the pre-process step in the monitoring flow. Monitoring data structure should support the following schema: :: { "inputs": inputs features, "outputs": output schema expected, "input_path": the path where inputs are, "result_path": the path where results are, "creation_strategy": model endpoint creation strategy, "labels": model endpoint labels, "model_endpoint_uid": model endpoint uid (added in deployment), "model_class": the model class } """ raise NotImplementedError
@property def monitoring_data(self) -> dict[str, Any]: self._monitoring_data = self._calculate_monitoring_data() return self._monitoring_data
[docs] class ModelRunnerStep(MonitoredStep): """ Runs multiple Models on each event. example:: model_runner_step = ModelRunnerStep(name="my_model_runner") model_runner_step.add_model(..., model_class=MyModel(name="my_model")) graph.to(model_runner_step) :param model_selector: ModelSelector instance whose select() method will be used to select models to run on each event. Optional. If not passed, all models will be run. :param raise_exception: If True, an error will be raised when model selection fails or if one of the models raised an error. If False, the error will appear in the output event. :raise ModelRunnerError - when a model raise an error the ModelRunnerStep will handle it, collect errors and outputs from added models, If raise_exception is True will raise ModelRunnerError Else will add the error msg as part of the event body mapped by model name if more than one model was added to the ModelRunnerStep """ kind = "model_runner" _dict_fields = MonitoredStep._dict_fields + ["_shared_proxy_mapping"] def __init__( self, *args, name: Optional[str] = None, model_selector: Optional[Union[str, ModelSelector]] = None, raise_exception: bool = True, **kwargs, ): super().__init__( *args, name=name, raise_exception=raise_exception, class_name="mlrun.serving.ModelRunner", class_args=dict(model_selector=model_selector), **kwargs, ) self.raise_exception = raise_exception self.shape = "folder" self._shared_proxy_mapping = {}
[docs] def add_shared_model_proxy( self, endpoint_name: str, model_artifact: Union[str, ModelArtifact, LLMPromptArtifact], shared_model_name: Optional[str] = None, labels: Optional[Union[list[str], dict[str, str]]] = None, model_endpoint_creation_strategy: Optional[ schemas.ModelEndpointCreationStrategy ] = schemas.ModelEndpointCreationStrategy.INPLACE, inputs: Optional[list[str]] = None, outputs: Optional[list[str]] = None, input_path: Optional[str] = None, result_path: Optional[str] = None, override: bool = False, ) -> None: """ Add a proxy model to the ModelRunnerStep, which is a proxy for a model that is already defined as shared model within the graph :param endpoint_name: str, will identify the model in the ModelRunnerStep, and assign model endpoint name :param model_artifact: model artifact or mlrun model artifact uri, according to the model artifact we will match the model endpoint to the correct shared model. :param shared_model_name: str, the name of the shared model that is already defined within the graph :param labels: model endpoint labels, should be list of str or mapping of str:str :param model_endpoint_creation_strategy: Strategy for creating or updating the model endpoint: * **overwrite**: 1. If model endpoints with the same name exist, delete the `latest` one. 2. Create a new model endpoint entry and set it as `latest`. * **inplace** (default): 1. If model endpoints with the same name exist, update the `latest` entry. 2. Otherwise, create a new entry. * **archive**: 1. If model endpoints with the same name exist, preserve them. 2. Create a new model endpoint with the same name and set it to `latest`. :param inputs: list of the model inputs (e.g. features) ,if provided will override the inputs that been configured in the model artifact, please note that those inputs need to be equal in length and order to the inputs that model_class predict method expects :param outputs: list of the model outputs (e.g. labels) ,if provided will override the outputs that been configured in the model artifact, please note that those outputs need to be equal to the model_class predict method outputs (length, and order) :param input_path: input path inside the user event, expect scopes to be defined by dot notation (e.g "inputs.my_model_inputs"). expects list or dictionary type object in path. :param result_path: result path inside the user output event, expect scopes to be defined by dot notation (e.g "outputs.my_model_outputs") expects list or dictionary type object in path. :param override: bool allow override existing model on the current ModelRunnerStep. """ model_class, model_params = ( "mlrun.serving.Model", {"name": endpoint_name, "shared_runnable_name": shared_model_name}, ) if isinstance(model_artifact, str): model_artifact_uri = model_artifact elif isinstance(model_artifact, ModelArtifact): model_artifact_uri = model_artifact.uri elif isinstance(model_artifact, LLMPromptArtifact): model_artifact_uri = model_artifact.model_artifact.uri else: raise MLRunInvalidArgumentError( "model_artifact must be a string, ModelArtifact or LLMPromptArtifact" ) root = self._extract_root_step() if isinstance(root, RootFlowStep): shared_model_name = ( shared_model_name or root.get_shared_model_name_by_artifact_uri(model_artifact_uri) ) if not root.shared_models or ( root.shared_models and shared_model_name and shared_model_name not in root.shared_models.keys() ): raise GraphError( f"ModelRunnerStep can only add proxy models that were added to the root flow step, " f"model {shared_model_name} is not in the shared models." ) if shared_model_name not in self._shared_proxy_mapping: self._shared_proxy_mapping[shared_model_name] = { endpoint_name: model_artifact.uri if isinstance(model_artifact, (ModelArtifact, LLMPromptArtifact)) else model_artifact } else: self._shared_proxy_mapping[shared_model_name].update( { endpoint_name: model_artifact.uri if isinstance(model_artifact, (ModelArtifact, LLMPromptArtifact)) else model_artifact } ) self.add_model( endpoint_name=endpoint_name, model_class=model_class, execution_mechanism=ParallelExecutionMechanisms.shared_executor, model_artifact=model_artifact, labels=labels, model_endpoint_creation_strategy=model_endpoint_creation_strategy, override=override, inputs=inputs, outputs=outputs, input_path=input_path, result_path=result_path, **model_params, )
[docs] def add_model( self, endpoint_name: str, model_class: Union[str, Model], execution_mechanism: Union[str, ParallelExecutionMechanisms], model_artifact: Optional[Union[str, ModelArtifact, LLMPromptArtifact]] = None, labels: Optional[Union[list[str], dict[str, str]]] = None, model_endpoint_creation_strategy: Optional[ schemas.ModelEndpointCreationStrategy ] = schemas.ModelEndpointCreationStrategy.INPLACE, inputs: Optional[list[str]] = None, outputs: Optional[list[str]] = None, input_path: Optional[str] = None, result_path: Optional[str] = None, override: bool = False, **model_parameters, ) -> None: """ Add a Model to this ModelRunner. :param endpoint_name: str, will identify the model in the ModelRunnerStep, and assign model endpoint name :param model_class: Model class name :param execution_mechanism: Parallel execution mechanism to be used to execute this model. Must be one of: * "process_pool" – To run in a separate process from a process pool. This is appropriate for CPU or GPU intensive tasks as they would otherwise block the main process by holding Python's Global Interpreter Lock (GIL). * "dedicated_process" – To run in a separate dedicated process. This is appropriate for CPU or GPU intensive tasks that also require significant Runnable-specific initialization (e.g. a large model). * "thread_pool" – To run in a separate thread. This is appropriate for blocking I/O tasks, as they would otherwise block the main event loop thread. * "asyncio" – To run in an asyncio task. This is appropriate for I/O tasks that use asyncio, allowing the event loop to continue running while waiting for a response. * "shared_executor" – Reuses an external executor (typically managed by the flow or context) to execute the runnable. Should be used only if you have multiply `ParallelExecution` in the same flow and especially useful when: - You want to share a heavy resource like a large model loaded onto a GPU. - You want to centralize task scheduling or coordination for multiple lightweight tasks. - You aim to minimize overhead from creating new executors or processes/threads per runnable. The runnable is expected to be pre-initialized and reused across events, enabling efficient use of memory and hardware accelerators. * "naive" – To run in the main event loop. This is appropriate only for trivial computation and/or file I/O. It means that the runnable will not actually be run in parallel to anything else. :param model_artifact: model artifact or mlrun model artifact uri :param labels: model endpoint labels, should be list of str or mapping of str:str :param model_endpoint_creation_strategy: Strategy for creating or updating the model endpoint: * **overwrite**: 1. If model endpoints with the same name exist, delete the `latest` one. 2. Create a new model endpoint entry and set it as `latest`. * **inplace** (default): 1. If model endpoints with the same name exist, update the `latest` entry. 2. Otherwise, create a new entry. * **archive**: 1. If model endpoints with the same name exist, preserve them. 2. Create a new model endpoint with the same name and set it to `latest`. :param inputs: list of the model inputs (e.g. features) ,if provided will override the inputs that been configured in the model artifact, please note that those inputs need to be equal in length and order to the inputs that model_class predict method expects :param outputs: list of the model outputs (e.g. labels) ,if provided will override the outputs that been configured in the model artifact, please note that those outputs need to be equal to the model_class predict method outputs (length, and order) :param input_path: when specified selects the key/path in the event to use as model monitoring inputs this require that the event body will behave like a dict, expects scopes to be defined by dot notation (e.g "data.d"). examples: input_path="data.b" event: {"data":{"a": 5, "b": 7}}, means monitored body will be 7. event: {"data":{"a": [5, 9], "b": [7, 8]}} means monitored body will be [7,8]. event: {"data":{"a": "extra_data", "b": {"f0": [1, 2]}}} means monitored body will be {"f0": [1, 2]}. if a ``list`` or ``list of lists`` is provided, it must follow the order and size defined by the input schema. :param result_path: when specified selects the key/path in the output event to use as model monitoring outputs this require that the output event body will behave like a dict, expects scopes to be defined by dot notation (e.g "data.d"). examples: result_path="out.b" event: {"out":{"a": 5, "b": 7}}, means monitored body will be 7. event: {"out":{"a": [5, 9], "b": [7, 8]}} means monitored body will be [7,8] event: {"out":{"a": "extra_data", "b": {"f0": [1, 2]}}} means monitored body will be {"f0": [1, 2]} if a ``list`` or ``list of lists`` is provided, it must follow the order and size defined by the output schema. :param override: bool allow override existing model on the current ModelRunnerStep. :param model_parameters: Parameters for model instantiation """ if isinstance(model_class, Model) and model_parameters: raise mlrun.errors.MLRunInvalidArgumentError( "Cannot provide a model object as argument to `model_class` and also provide `model_parameters`." ) model_parameters = model_parameters or ( model_class.to_dict() if isinstance(model_class, Model) else {} ) if isinstance( model_artifact, str, ): try: model_artifact, _ = mlrun.store_manager.get_store_artifact( mlrun.utils.remove_tag_from_artifact_uri(model_artifact) ) except mlrun.errors.MLRunNotFoundError: raise mlrun.errors.MLRunInvalidArgumentError("Artifact not found.") outputs = outputs or self._get_model_output_schema(model_artifact) model_artifact = ( model_artifact.uri if isinstance(model_artifact, mlrun.artifacts.Artifact) else model_artifact ) model_artifact = ( mlrun.utils.remove_tag_from_artifact_uri(model_artifact) if model_artifact else None ) model_parameters["artifact_uri"] = model_parameters.get( "artifact_uri", model_artifact ) if model_parameters.get("name", endpoint_name) != endpoint_name or ( isinstance(model_class, Model) and model_class.name != endpoint_name ): raise mlrun.errors.MLRunInvalidArgumentError( "Inconsistent name for model added to ModelRunnerStep." ) models = self.class_args.get(schemas.ModelRunnerStepData.MODELS, {}) if endpoint_name in models and not override: raise mlrun.errors.MLRunInvalidArgumentError( f"Model with name {endpoint_name} already exists in this ModelRunnerStep." ) root = self._extract_root_step() if isinstance(root, RootFlowStep): self.verify_model_runner_step( self, [endpoint_name], verify_shared_models=False ) ParallelExecutionMechanisms.validate(execution_mechanism) self.class_args[schemas.ModelRunnerStepData.MODEL_TO_EXECUTION_MECHANISM] = ( self.class_args.get( schemas.ModelRunnerStepData.MODEL_TO_EXECUTION_MECHANISM, {}, ) ) self.class_args[schemas.ModelRunnerStepData.MODEL_TO_EXECUTION_MECHANISM][ endpoint_name ] = execution_mechanism model_parameters["name"] = endpoint_name monitoring_data = self.class_args.get( schemas.ModelRunnerStepData.MONITORING_DATA, {} ) model_class = ( model_class if isinstance(model_class, str) else model_class.__class__.__name__ ) models[endpoint_name] = (model_class, model_parameters) monitoring_data[endpoint_name] = { schemas.MonitoringData.INPUTS: inputs, schemas.MonitoringData.OUTPUTS: outputs, schemas.MonitoringData.INPUT_PATH: input_path, schemas.MonitoringData.RESULT_PATH: result_path, schemas.MonitoringData.CREATION_STRATEGY: model_endpoint_creation_strategy, schemas.MonitoringData.LABELS: labels, schemas.MonitoringData.MODEL_PATH: model_artifact, schemas.MonitoringData.MODEL_CLASS: model_class, } self.class_args[schemas.ModelRunnerStepData.MODELS] = models self.class_args[schemas.ModelRunnerStepData.MONITORING_DATA] = monitoring_data
@staticmethod def _get_model_output_schema( model_artifact: Union[ModelArtifact, LLMPromptArtifact], ) -> Optional[list[str]]: if isinstance( model_artifact, ModelArtifact, ): return [feature.name for feature in model_artifact.spec.outputs] elif isinstance( model_artifact, LLMPromptArtifact, ): _model_artifact = model_artifact.model_artifact return [feature.name for feature in _model_artifact.spec.outputs] @staticmethod def _get_model_endpoint_output_schema( name: str, project: str, uid: str, ) -> list[str]: output_schema = None try: model_endpoint: mlrun.common.schemas.model_monitoring.ModelEndpoint = ( mlrun.db.get_run_db().get_model_endpoint( name=name, project=project, endpoint_id=uid, tsdb_metrics=False, ) ) output_schema = model_endpoint.spec.label_names except ( mlrun.errors.MLRunNotFoundError, mlrun.errors.MLRunInvalidArgumentError, ) as ex: logger.warning( f"Model endpoint not found, using default output schema for model {name}", error=f"{type(ex).__name__}: {ex}", ) return output_schema def _calculate_monitoring_data(self) -> dict[str, dict[str, str]]: monitoring_data = deepcopy( self.class_args.get( mlrun.common.schemas.ModelRunnerStepData.MONITORING_DATA ) ) if isinstance(monitoring_data, dict): for model in monitoring_data: monitoring_data[model][schemas.MonitoringData.INPUT_PATH] = split_path( monitoring_data[model][schemas.MonitoringData.INPUT_PATH] ) monitoring_data[model][schemas.MonitoringData.RESULT_PATH] = split_path( monitoring_data[model][schemas.MonitoringData.RESULT_PATH] ) return monitoring_data else: raise mlrun.errors.MLRunInvalidArgumentError( "Monitoring data must be a dictionary." )
[docs] def init_object(self, context, namespace, mode="sync", reset=False, **extra_kwargs): self.context = context if not self._is_local_function(context): # skip init of non local functions return model_selector = self.class_args.get("model_selector") execution_mechanism_by_model_name = self.class_args.get( schemas.ModelRunnerStepData.MODEL_TO_EXECUTION_MECHANISM ) models = self.class_args.get(schemas.ModelRunnerStepData.MODELS, {}) if isinstance(model_selector, str): model_selector = get_class(model_selector, namespace)() model_objects = [] for model, model_params in models.values(): model_params[schemas.MonitoringData.INPUT_PATH] = ( self.class_args.get( mlrun.common.schemas.ModelRunnerStepData.MONITORING_DATA, {} ) .get(model_params.get("name"), {}) .get(schemas.MonitoringData.INPUT_PATH) ) model = get_class(model, namespace).from_dict( model_params, init_with_params=True ) model._raise_exception = False model_objects.append(model) self._async_object = ModelRunner( model_selector=model_selector, runnables=model_objects, execution_mechanism_by_runnable_name=execution_mechanism_by_model_name, shared_proxy_mapping=self._shared_proxy_mapping or None, name=self.name, context=context, )
class ModelRunnerErrorRaiser(storey.MapClass): def __init__(self, raise_exception: bool, models_names: list[str], **kwargs): super().__init__(**kwargs) self._raise_exception = raise_exception self._models_names = models_names def do(self, event): if self._raise_exception: errors = {} should_raise = False if len(self._models_names) == 1: should_raise = event.body.get("error") is not None errors[self._models_names[0]] = event.body.get("error") else: for model in event.body: errors[model] = event.body.get(model).get("error") if errors[model] is not None: should_raise = True if should_raise: raise ModelRunnerError(models_errors=errors) return event
[docs] class QueueStep(BaseStep, StepToDict): """queue step, implement an async queue or represent a stream""" kind = "queue" default_shape = "cds" _dict_fields = BaseStep._dict_fields + [ "path", "shards", "retention_in_hours", "trigger_args", "options", ] def __init__( self, name: Optional[str] = None, path: Optional[str] = None, after: Optional[list] = None, shards: Optional[int] = None, retention_in_hours: Optional[int] = None, trigger_args: Optional[dict] = None, **options, ): super().__init__(name, after) self.path = path self.shards = shards self.retention_in_hours = retention_in_hours self.options = options self.trigger_args = trigger_args self._stream = None self._async_object = None
[docs] def init_object(self, context, namespace, mode="sync", reset=False, **extra_kwargs): self.context = context if self.path: self._stream = get_stream_pusher( self.path, shards=self.shards, retention_in_hours=self.retention_in_hours, **self.options, ) if hasattr(self._stream, "create_stream"): self._stream.create_stream() self._set_error_handler()
@property def async_object(self): return self._async_object
[docs] def to( self, class_name: Union[str, StepToDict] = None, name: Optional[str] = None, handler: Optional[str] = None, graph_shape: Optional[str] = None, function: Optional[str] = None, full_event: Optional[bool] = None, input_path: Optional[str] = None, result_path: Optional[str] = None, model_endpoint_creation_strategy: Optional[ schemas.ModelEndpointCreationStrategy ] = None, **class_args, ): if not function: name = get_name(name, class_name) raise mlrun.errors.MLRunInvalidArgumentError( f"step '{name}' must specify a function, because it follows a queue step" ) return super().to( class_name, name, handler, graph_shape, function, full_event, input_path, result_path, model_endpoint_creation_strategy, **class_args, )
[docs] def run(self, event, *args, **kwargs): data = event.body if not data: return event if self._stream: full_event = self.options.get("full_event") if full_event or full_event is None and self.next: data = storey.utils.wrap_event_for_serialization(event, data) self._stream.push(data) event.terminated = True event.body = None return event
class FlowStep(BaseStep): """flow step, represent a workflow or DAG""" kind = "flow" _dict_fields = BaseStep._dict_fields + [ "steps", "engine", "default_final_step", ] def __init__( self, name=None, steps=None, after: Optional[list] = None, engine=None, final_step=None, ): super().__init__(name, after) self._steps = None self.steps = steps self.engine = engine self.from_step = os.environ.get("START_FROM_STEP", None) self.final_step = final_step self._last_added = None self._controller = None self._wait_for_result = False self._source = None self._start_steps = [] def get_children(self): return self._steps.values() @property def steps(self): """child (workflow) steps""" return self._steps @property def controller(self): """async (storey) flow controller""" return self._controller @steps.setter def steps(self, steps): self._steps = ObjectDict.from_dict(classes_map, steps, "task") def add_step( self, class_name=None, name=None, handler=None, after=None, before=None, graph_shape=None, function=None, full_event: Optional[bool] = None, input_path: Optional[str] = None, result_path: Optional[str] = None, model_endpoint_creation_strategy: Optional[ schemas.ModelEndpointCreationStrategy ] = None, **class_args, ): """add task, queue or router step/class to the flow use after/before to insert into a specific location example: graph = fn.set_topology("flow", exist_ok=True) graph.add_step(class_name="Chain", name="s1") graph.add_step(class_name="Chain", name="s3", after="$prev") graph.add_step(class_name="Chain", name="s2", after="s1", before="s3") :param class_name: class name or step object to build the step from for router steps the class name should start with '*' for queue/stream step the class should be '>>' or '$queue' :param name: unique name (and path) for the child step, default is class name :param handler: class/function handler to invoke on run/event :param after: the step name this step comes after can use $prev to indicate the last added step :param before: string or list of next step names that will run after this step :param graph_shape: graphviz shape name :param function: function this step should run in :param full_event: this step accepts the full event (not just body) :param input_path: selects the key/path in the event to use as input to the step this require that the event body will behave like a dict, example: event: {"data": {"a": 5, "b": 7}}, input_path="data.b" means the step will receive 7 as input :param result_path: selects the key/path in the event to write the results to this require that the event body will behave like a dict, example: event: {"x": 5} , result_path="y" means the output of the step will be written to event["y"] resulting in {"x": 5, "y": <result>} :param model_endpoint_creation_strategy: Strategy for creating or updating the model endpoint: * **overwrite**: 1. If model endpoints with the same name exist, delete the `latest` one. 2. Create a new model endpoint entry and set it as `latest`. * **inplace** (default): 1. If model endpoints with the same name exist, update the `latest` entry. 2. Otherwise, create a new entry. * **archive**: 1. If model endpoints with the same name exist, preserve them. 2. Create a new model endpoint with the same name and set it to `latest`. :param class_args: class init arguments """ if not name and isinstance(class_name, BaseStep): name = class_name.name name, step = params_to_step( class_name, name, handler, graph_shape=graph_shape, function=function, full_event=full_event, input_path=input_path, result_path=result_path, model_endpoint_creation_strategy=model_endpoint_creation_strategy, class_args=class_args, ) # Make sure model endpoint was not introduce in ModelRunnerStep self.check_model_endpoint_existence(step, model_endpoint_creation_strategy) self.verify_model_runner_step(step) after_list = after if isinstance(after, list) else [after] for after in after_list: self.insert_step(name, step, after, before) return step def insert_step(self, key, step, after, before=None): """insert step object into the flow, specify before and after""" step = self._steps.update(key, step) step.set_parent(self) if after == "$prev" and len(self._steps) == 1: after = None previous = "" if after: if after == "$prev" and self._last_added: previous = self._last_added.name else: if after not in self._steps.keys(): raise MLRunInvalidArgumentError( f"cant set after, there is no step named {after}" ) previous = after step.after_step(previous) if before: if before not in self._steps.keys(): raise MLRunInvalidArgumentError( f"cant set before, there is no step named {before}" ) if before == step.name or before == previous: raise GraphError( f"graph loop, step {before} is specified in before and/or after {key}" ) self[step.name].after_step(*self[before].after, append=False) self[before].after_step(step.name, append=False) self._last_added = step return step def clear_children(self, steps: Optional[list] = None): """remove some or all of the states, empty/None for all""" if not steps: steps = self._steps.keys() for key in steps: del self._steps[key] def __getitem__(self, name): return self._steps[name] def __setitem__(self, name, step): self.add_step(name, step) def __delitem__(self, key): del self._steps[key] def __iter__(self): yield from self._steps.keys() def init_object(self, context, namespace, mode="sync", reset=False, **extra_kwargs): """initialize graph objects and classes""" self.context = context self._insert_all_error_handlers() self.check_and_process_graph() for step in self.steps.values(): step.set_parent(self) step.init_object(context, namespace, mode, reset=reset) self._set_error_handler() self._post_init(mode) if self.engine != "sync": self._build_async_flow() self._run_async_flow() def check_and_process_graph(self, allow_empty=False): """validate correct graph layout and initialize the .next links""" if self.is_empty() and allow_empty: self._start_steps = [] return [], None, [] def has_loop(step, previous): for next_step in step.after or []: if next_step in previous: return step.name downstream = has_loop(self[next_step], previous + [next_step]) if downstream: return downstream return None start_steps = [] for step in self._steps.values(): step._next = None step._visited = False if step.after: loop_step = has_loop(step, []) if loop_step: raise GraphError( f"Error, loop detected in step {loop_step}, graph must be acyclic (DAG)" ) else: start_steps.append(step.name) responders = [] for step in self._steps.values(): if ( hasattr(step, "responder") and step.responder and step.kind != "error_step" ): responders.append(step.name) if step.on_error and step.on_error in start_steps: start_steps.remove(step.on_error) if step.after: for prev_step in step.after: self[prev_step].set_next(step.name) if self.on_error and self.on_error in start_steps: start_steps.remove(self.on_error) if ( len(responders) > 1 ): # should not have multiple steps which respond to request raise GraphError( f'there are more than one responder steps in the graph ({",".join(responders)})' ) if self.from_step: if self.from_step not in self.steps: raise GraphError( f"from_step ({self.from_step}) specified and not found in graph steps" ) start_steps = [self.from_step] self._start_steps = [self[name] for name in start_steps] def get_first_function_step(step, current_function): # find the first step which belongs to the function if ( hasattr(step, "function") and step.function and step.function == current_function ): return step for item in step.next or []: next_step = self[item] returned_step = get_first_function_step(next_step, current_function) if returned_step: return returned_step current_function = get_current_function(self.context) if current_function and current_function != "*": new_start_steps = [] for from_step in self._start_steps: step = get_first_function_step(from_step, current_function) if step: new_start_steps.append(step) if not new_start_steps: raise GraphError( f"did not find steps pointing to current function ({current_function})" ) self._start_steps = new_start_steps if self.engine == "sync" and len(self._start_steps) > 1: raise GraphError( "sync engine can only have one starting step (without .after)" ) default_final_step = None if self.final_step: if self.final_step not in self.steps: raise GraphError( f"final_step ({self.final_step}) specified and not found in graph steps" ) default_final_step = self.final_step elif len(self._start_steps) == 1: # find the final step in case if a simple sequence of steps next_obj = self._start_steps[0] while next_obj: next = next_obj.next if not next: default_final_step = next_obj.name break next_obj = self[next[0]] if len(next) == 1 else None return self._start_steps, default_final_step, responders def set_flow_source(self, source): """set the async flow (storey) source""" self._source = source def _build_async_flow(self): """initialize and build the async/storey DAG""" def process_step(state, step, root): if not state._is_local_function(self.context) or state._visited: return for item in state.next or []: next_state = root[item] if next_state.async_object: next_step = step.to(next_state.async_object) process_step(next_state, next_step, root) state._visited = ( True # mark visited to avoid re-visit in case of multiple uplinks ) default_source, self._wait_for_result = _init_async_objects( self.context, self._steps.values() ) source = self._source or default_source for next_state in self._start_steps: next_step = source.to(next_state.async_object) process_step(next_state, next_step, self) for step in self._steps.values(): # add error handler hooks if (step.on_error or self.on_error) and step.async_object: error_step = self._steps[step.on_error or self.on_error] # never set a step as its own error handler if step != error_step: step.async_object.set_recovery_step(error_step.async_object) for next_step in error_step.next or []: next_state = self[next_step] if next_state.async_object and error_step.async_object: error_step.async_object.to(next_state.async_object) self._async_flow = source def _run_async_flow(self): self._controller = self._async_flow.run() def get_queue_links(self): """return dict of function and queue its listening on, for building stream triggers""" links = {} for step in self.get_children(): if step.kind == StepKinds.queue: for item in step.next or []: next_step = self[item] if not next_step.function: raise GraphError( f"child function name must be specified in steps ({next_step.name}) which follow a queue" ) if next_step.function in links: raise GraphError( f"function ({next_step.function}) cannot read from multiple queues" ) links[next_step.function] = step return links def create_queue_streams(self): """create the streams used in this flow""" for step in self.get_children(): if step.kind == StepKinds.queue: step.init_object(self.context, None) def list_child_functions(self): """return a list of child function names referred to in the steps""" functions = [] for step in self.get_children(): if ( hasattr(step, "function") and step.function and step.function not in functions ): functions.append(step.function) return functions def is_empty(self): """is the graph empty (no child steps)""" return len(self.steps) == 0 @staticmethod async def _await_and_return_id(awaitable, event): await awaitable event = copy(event) event.body = {"id": event.id} return event def run(self, event, *args, **kwargs): if self._controller: # async flow (using storey) event._awaitable_result = None if self.context.is_mock: resp = self._controller.emit( event, return_awaitable_result=self._wait_for_result ) if self._wait_for_result and resp: return resp.await_result() else: resp_awaitable = self._controller.emit( event, await_result=self._wait_for_result ) if self._wait_for_result: return resp_awaitable return self._await_and_return_id(resp_awaitable, event) event = copy(event) event.body = {"id": event.id} return event event = storey.utils.unpack_event_if_wrapped(event) if len(self._start_steps) == 0: return event next_obj = self._start_steps[0] while next_obj: try: event = next_obj.run(event, *args, **kwargs) except Exception as exc: if self._on_error_handler: self._log_error(event, exc, failed_step=next_obj.name) event.body = self._call_error_handler(event, exc) event.terminated = True return event else: raise exc if hasattr(event, "terminated") and event.terminated: return event if ( hasattr(event, "error") and isinstance(event.error, dict) and next_obj.name in event.error ): next_obj = self._steps[next_obj.on_error] next = next_obj.next if next and len(next) > 1: raise GraphError( f"synchronous flow engine doesnt support branches use async, step={next_obj.name}" ) next_obj = self[next[0]] if next else None return event def wait_for_completion(self): """wait for completion of run in async flows""" if self._controller: if hasattr(self._controller, "terminate"): return self._controller.terminate(wait=True) else: return self._controller.await_termination() def plot(self, filename=None, format=None, source=None, targets=None, **kw): """plot/save graph using graphviz :param filename: target filepath for the graph image (None for the notebook) :param format: the output format used for rendering (``'pdf'``, ``'png'``, etc.) :param source: source step to add to the graph image :param targets: list of target steps to add to the graph image :param kw: kwargs passed to graphviz, e.g. rankdir="LR" (see https://graphviz.org/doc/info/attrs.html) :return: graphviz graph object """ return _generate_graphviz( self, _add_graphviz_flow, filename, format, source=source, targets=targets, **kw, ) def _insert_all_error_handlers(self): """ insert all error steps to the graph run after deployment """ for name, step in self._steps.items(): if step.kind == "error_step": self._insert_error_step(name, step) def _insert_error_step(self, name, step): """ insert error step to the graph run after deployment """ if not step.before and not any( [step.name in other_step.after for other_step in self._steps.values()] ): if any( [ getattr(step_in_graph, "responder", False) for step_in_graph in self._steps.values() ] ): step.responder = True return for step_name in step.before: if step_name not in self._steps.keys(): raise MLRunInvalidArgumentError( f"cant set before, there is no step named {step_name}" ) self[step_name].after_step(name) def set_flow( self, steps: list[Union[str, StepToDict, dict[str, Any]]], force: bool = False, ): if not force and self.steps: raise mlrun.errors.MLRunInvalidArgumentError( "set_flow() called on a step that already has downstream steps. " "If you want to overwrite existing steps, set force=True." ) self.steps = None step = self for next_step in steps: if isinstance(next_step, dict): step = step.to(**next_step) else: step = step.to(next_step) return step def supports_termination(self): return self.engine != "sync" class RootFlowStep(FlowStep): """root flow step""" kind = "root" _dict_fields = [ "steps", "engine", "final_step", "on_error", "model_endpoints_names", "model_endpoints_routes_names", "track_models", "shared_max_processes", "shared_max_threads", "shared_models", "shared_models_mechanism", "pool_factor", ] def __init__( self, name=None, steps=None, after: Optional[list] = None, engine=None, final_step=None, ): super().__init__( name, steps, after, engine, final_step, ) self._models = set() self._route_models = set() self._track_models = False self._shared_models: dict[str, tuple[str, dict]] = {} self._shared_models_mechanism: dict[str, ParallelExecutionMechanisms] = {} self._shared_max_processes = None self._shared_max_threads = None self._pool_factor = None def add_shared_model( self, name: str, model_class: Union[str, Model], execution_mechanism: Union[str, ParallelExecutionMechanisms], model_artifact: Union[str, ModelArtifact], override: bool = False, **model_parameters, ) -> None: """ Add a shared model to the graph, this model will be available to all the ModelRunners in the graph :param name: Name of the shared model (should be unique in the graph) :param model_class: Model class name :param execution_mechanism: Parallel execution mechanism to be used to execute this model. Must be one of: * "process_pool" – To run in a separate process from a process pool. This is appropriate for CPU or GPU intensive tasks as they would otherwise block the main process by holding Python's Global Interpreter Lock (GIL). * "dedicated_process" – To run in a separate dedicated process. This is appropriate for CPU or GPU intensive tasks that also require significant Runnable-specific initialization (e.g. a large model). * "thread_pool" – To run in a separate thread. This is appropriate for blocking I/O tasks, as they would otherwise block the main event loop thread. * "asyncio" – To run in an asyncio task. This is appropriate for I/O tasks that use asyncio, allowing the event loop to continue running while waiting for a response. * "shared_executor" – Reuses an external executor (typically managed by the flow or context) to execute the runnable. Should be used only if you have multiply `ParallelExecution` in the same flow and especially useful when: - You want to share a heavy resource like a large model loaded onto a GPU. - You want to centralize task scheduling or coordination for multiple lightweight tasks. - You aim to minimize overhead from creating new executors or processes/threads per runnable. The runnable is expected to be pre-initialized and reused across events, enabling efficient use of memory and hardware accelerators. * "naive" – To run in the main event loop. This is appropriate only for trivial computation and/or file I/O. It means that the runnable will not actually be run in parallel to anything else. :param model_artifact: model artifact or mlrun model artifact uri :param override: bool allow override existing model on the current ModelRunnerStep. :param model_parameters: Parameters for model instantiation """ if isinstance(model_class, Model) and model_parameters: raise mlrun.errors.MLRunInvalidArgumentError( "Cannot provide a model object as argument to `model_class` and also provide `model_parameters`." ) if execution_mechanism == ParallelExecutionMechanisms.shared_executor: raise mlrun.errors.MLRunInvalidArgumentError( "Cannot add a shared model with execution mechanism 'shared_executor'" ) ParallelExecutionMechanisms.validate(execution_mechanism) model_parameters = model_parameters or ( model_class.to_dict() if isinstance(model_class, Model) else {} ) model_artifact = ( model_artifact.uri if isinstance(model_artifact, mlrun.artifacts.Artifact) else model_artifact ) model_artifact = mlrun.utils.remove_tag_from_artifact_uri(model_artifact) model_parameters["artifact_uri"] = model_parameters.get( "artifact_uri", model_artifact ) if model_parameters.get("name", name) != name or ( isinstance(model_class, Model) and model_class.name != name ): raise mlrun.errors.MLRunInvalidArgumentError( "Inconsistent name for the added model." ) model_parameters["name"] = name if name in self.shared_models and not override: raise mlrun.errors.MLRunInvalidArgumentError( f"Model with name {name} already exists in this graph." ) model_class = ( model_class if isinstance(model_class, str) else model_class.__class__.__name__ ) self.shared_models[name] = (model_class, model_parameters) self.shared_models_mechanism[name] = execution_mechanism def get_shared_model_name_by_artifact_uri(self, artifact_uri: str) -> Optional[str]: """ Get a shared model by its artifact URI. :param artifact_uri: The artifact URI of the model. :return: A tuple of (model_class, model_parameters) if found, otherwise None. """ for model_name, (model_class, model_params) in self.shared_models.items(): if model_params.get("artifact_uri") == artifact_uri: return model_name return None def config_pool_resource( self, max_processes: Optional[int] = None, max_threads: Optional[int] = None, pool_factor: Optional[int] = None, ) -> None: """ Configure the resource limits for the shared models in the graph. :param max_processes: Maximum number of processes to spawn (excluding dedicated processes). Defaults to the number of CPUs or 16 if undetectable. :param max_threads: Maximum number of threads to spawn. Defaults to 32. :param pool_factor: Multiplier to scale the number of process/thread workers per runnable. Defaults to 1. """ self.shared_max_processes = max_processes self.shared_max_threads = max_threads self.pool_factor = pool_factor def init_object(self, context, namespace, mode="sync", reset=False, **extra_kwargs): self.context = context if self.shared_models: self.context.executor = storey.flow.RunnableExecutor( max_processes=self.shared_max_processes, max_threads=self.shared_max_threads, pool_factor=self.pool_factor, ) monitored_steps = self.get_monitored_steps().values() for monitored_step in monitored_steps: if isinstance(monitored_step, ModelRunnerStep): for model, model_params in self.shared_models.values(): if "shared_proxy_mapping" in model_params: model_params["shared_proxy_mapping"].update( deepcopy( monitored_step._shared_proxy_mapping.get( model_params.get("name"), {} ) ) ) else: model_params["shared_proxy_mapping"] = deepcopy( monitored_step._shared_proxy_mapping.get( model_params.get("name"), {} ) ) for model, model_params in self.shared_models.values(): model = get_class(model, namespace).from_dict( model_params, init_with_params=True ) model._raise_exception = False self.context.executor.add_runnable( model, self._shared_models_mechanism[model.name] ) super().init_object(context, namespace, mode, reset=reset, **extra_kwargs) @property def model_endpoints_names(self) -> list[str]: return list(self._models) @model_endpoints_names.setter def model_endpoints_names(self, models: list[str]): self._models = set(models) def update_model_endpoints_names(self, model_endpoints_names: list): self._models.update(model_endpoints_names) @property def model_endpoints_routes_names(self) -> list[str]: return list(self._route_models) @model_endpoints_routes_names.setter def model_endpoints_routes_names(self, models: list[str]): self._route_models = set(models) @property def track_models(self): return self._track_models @track_models.setter def track_models(self, track_models: bool): self._track_models = track_models @property def shared_models(self) -> dict[str, tuple[str, dict]]: return self._shared_models @shared_models.setter def shared_models(self, shared_models: dict[str, tuple[str, dict]]): self._shared_models = shared_models @property def shared_models_mechanism(self) -> dict[str, ParallelExecutionMechanisms]: return self._shared_models_mechanism @shared_models_mechanism.setter def shared_models_mechanism( self, shared_models_mechanism: dict[str, ParallelExecutionMechanisms] ): self._shared_models_mechanism = shared_models_mechanism @property def shared_max_processes(self) -> Optional[int]: return self._shared_max_processes @shared_max_processes.setter def shared_max_processes(self, max_processes: Optional[int]): self._shared_max_processes = max_processes @property def shared_max_threads(self) -> Optional[int]: return self._shared_max_threads @shared_max_threads.setter def shared_max_threads(self, max_threads: Optional[int]): self._shared_max_threads = max_threads @property def pool_factor(self) -> Optional[int]: return self._pool_factor @pool_factor.setter def pool_factor(self, pool_factor: Optional[int]): self._pool_factor = pool_factor def update_model_endpoints_routes_names(self, model_endpoints_names: list): self._route_models.update(model_endpoints_names) def include_monitored_step(self) -> bool: for step in self.steps.values(): if isinstance(step, mlrun.serving.MonitoredStep): return True return False def get_monitored_steps(self) -> dict[str, "MonitoredStep"]: return { step.name: step for step in self.steps.values() if isinstance(step, mlrun.serving.MonitoredStep) } classes_map = { "task": TaskStep, "router": RouterStep, "flow": FlowStep, "queue": QueueStep, "error_step": ErrorStep, "monitoring_application": MonitoringApplicationStep, "model_runner": ModelRunnerStep, } def get_current_function(context): if context and hasattr(context, "current_function"): return context.current_function or "" return "" def _add_graphviz_router(graph, step, source=None, **kwargs): if source: graph.node("_start", source.name, shape=source.shape, style="filled") graph.edge("_start", step.fullname) graph.node(step.fullname, label=step.name, shape=step.get_shape()) for route in step.get_children(): graph.node(route.fullname, label=route.name, shape=route.get_shape()) graph.edge(step.fullname, route.fullname) def _add_graphviz_model_runner(graph, step, source=None): if source: graph.node("_start", source.name, shape=source.shape, style="filled") graph.edge("_start", step.fullname) is_monitored = step._extract_root_step().track_models m_cell = '<FONT POINT-SIZE="9">🄼</FONT>' if is_monitored else "" number_of_models = len( list(step.class_args.get(schemas.ModelRunnerStepData.MODELS, {}).keys()) ) number_badge = f""" <TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" BGCOLOR="black" CELLPADDING="2"> <TR> <TD><FONT COLOR="white" POINT-SIZE="9"><B>{number_of_models}</B></FONT></TD> </TR> </TABLE> """ html_label = f"""< <TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" CELLPADDING="4"> <TR> <TD ALIGN="LEFT">{m_cell}</TD> <TD ALIGN="RIGHT">{number_badge}</TD> </TR> <TR> <TD COLSPAN="2" ALIGN="CENTER"><FONT POINT-SIZE="14">{step.name}</FONT></TD> </TR> </TABLE> >""" graph.node(step.fullname, label=html_label, shape=step.get_shape()) def _add_graphviz_flow( graph, step, source=None, targets=None, ): start_steps, default_final_step, responders = step.check_and_process_graph( allow_empty=True ) graph.node("_start", source.name, shape=source.shape, style="filled") for start_step in start_steps: graph.edge("_start", start_step.fullname) for child in step.get_children(): kind = child.kind if kind == StepKinds.router: with graph.subgraph(name="cluster_" + child.fullname) as sg: _add_graphviz_router(sg, child) elif kind == StepKinds.model_runner: _add_graphviz_model_runner(graph, child) else: graph.node(child.fullname, label=child.name, shape=child.get_shape()) _add_edges(child.after or [], step, graph, child) _add_edges(getattr(child, "before", []), step, graph, child, after=False) if child.on_error: graph.edge(child.fullname, child.on_error, style="dashed") # draw targets after the last step (if specified) if targets: for target in targets or []: target_kind, target_name = target.name.split("/", 1) if target_kind != target_name: label = ( f"<{target_name}<br/><font point-size='8'>({target_kind})</font>>" ) else: label = target_name graph.node(target.fullname, label=label, shape=target.get_shape()) last_step = target.after or default_final_step if last_step: graph.edge(last_step, target.fullname) def _add_edges(items, step, graph, child, after=True): for item in items: next_or_prev_object = step[item] kw = {} if next_or_prev_object.kind == StepKinds.router: kw["ltail"] = f"cluster_{next_or_prev_object.fullname}" if after: graph.edge(next_or_prev_object.fullname, child.fullname, **kw) else: graph.edge(child.fullname, next_or_prev_object.fullname, **kw) def _generate_graphviz( step, renderer, filename=None, format=None, source=None, targets=None, **kw, ): try: from graphviz import Digraph except ImportError: raise ImportError( 'graphviz is not installed, run "pip install graphviz" first!' ) graph = Digraph("mlrun-flow", format="jpg") graph.attr(compound="true", **kw) source = source or BaseStep("start", shape="egg") renderer(graph, step, source=source, targets=targets) if filename: suffix = pathlib.Path(filename).suffix if suffix: filename = filename[: -len(suffix)] format = format or suffix[1:] format = format or "png" graph.render(filename, format=format) return graph def graph_root_setter(server, graph): """set graph root object from class or dict""" if graph: if isinstance(graph, dict): kind = graph.get("kind") elif hasattr(graph, "kind"): kind = graph.kind else: raise MLRunInvalidArgumentError("graph must be a dict or a valid object") if kind == StepKinds.router: server._graph = server._verify_dict(graph, "graph", RouterStep) elif not kind or kind == StepKinds.root: server._graph = server._verify_dict(graph, "graph", RootFlowStep) else: raise GraphError(f"illegal root step {kind}") def get_name(name, class_name): """get task name from provided name or class""" if name: return name if not class_name: raise MLRunInvalidArgumentError("name or class_name must be provided") if isinstance(class_name, type): return class_name.__name__ return class_name.split(".")[-1] def params_to_step( class_name, name, handler=None, graph_shape=None, function=None, full_event=None, input_path: Optional[str] = None, result_path: Optional[str] = None, class_args=None, model_endpoint_creation_strategy: Optional[ schemas.ModelEndpointCreationStrategy ] = None, endpoint_type: Optional[schemas.EndpointType] = None, ): """return step object from provided params or classes/objects""" class_args = class_args or {} if isinstance(class_name, QueueStep): if not (name or class_name.name): raise MLRunInvalidArgumentError("queue name must be specified") step = class_name elif class_name in queue_class_names: if "path" not in class_args: raise MLRunInvalidArgumentError( "path=<stream path or None> must be specified for queues" ) if not name: raise MLRunInvalidArgumentError("queue name must be specified") # Pass full_event on only if it's explicitly defined if full_event is not None: class_args = class_args.copy() class_args["full_event"] = full_event step = QueueStep(name, **class_args) elif class_name and hasattr(class_name, "to_dict"): struct = deepcopy(class_name.to_dict()) kind = struct.get("kind", StepKinds.task) name = ( name or struct.get("name", struct.get("class_name")) or class_name.to_dict(["name"]).get("name") ) cls = classes_map.get(kind, RootFlowStep) step = cls.from_dict(struct) step.function = function step.full_event = full_event or step.full_event step.input_path = input_path or step.input_path step.result_path = result_path or step.result_path if kind == StepKinds.task: step.model_endpoint_creation_strategy = model_endpoint_creation_strategy step.endpoint_type = endpoint_type elif class_name and class_name.startswith("*"): routes = class_args.get("routes", None) class_name = class_name[1:] name = get_name(name, class_name or "router") step = RouterStep( class_name, class_args, handler, name=name, function=function, routes=routes, input_path=input_path, result_path=result_path, ) elif class_name or handler: name = get_name(name, class_name) step = TaskStep( class_name, class_args, handler, name=name, function=function, full_event=full_event, input_path=input_path, result_path=result_path, model_endpoint_creation_strategy=model_endpoint_creation_strategy, endpoint_type=endpoint_type, ) else: raise MLRunInvalidArgumentError("class_name or handler must be provided") if graph_shape: step.shape = graph_shape return name, step def _init_async_objects(context, steps): try: import storey except ImportError: raise GraphError("storey package is not installed, use pip install storey") wait_for_result = False trigger = getattr(context, "trigger", None) context.logger.debug(f"trigger is {trigger or 'unknown'}") # respond is only supported for HTTP trigger respond_supported = trigger is None or trigger == "http" for step in steps: if hasattr(step, "async_object") and step._is_local_function(context): if step.kind == StepKinds.queue: skip_stream = context.is_mock and step.next if step.path and not skip_stream: stream_path = step.path endpoint = None # in case of a queue, we default to a full_event=True full_event = step.options.get("full_event") options = { "full_event": full_event or full_event is None and step.next } options.update(step.options) kafka_brokers = get_kafka_brokers_from_dict(options, pop=True) if stream_path and stream_path.startswith("ds://"): datastore_profile = datastore_profile_read(stream_path) if isinstance( datastore_profile, (DatastoreProfileKafkaTarget, DatastoreProfileKafkaSource), ): step._async_object = KafkaStoreyTarget( path=stream_path, context=context, **options, ) elif isinstance(datastore_profile, DatastoreProfileV3io): step._async_object = StreamStoreyTarget( stream_path=stream_path, context=context, **options, ) else: raise mlrun.errors.MLRunValueError( f"Received an unexpected stream profile type: {type(datastore_profile)}\n" "Expects `DatastoreProfileV3io` or `DatastoreProfileKafkaSource`." ) elif stream_path.startswith("kafka://") or kafka_brokers: topic, brokers = parse_kafka_url(stream_path, kafka_brokers) kafka_producer_options = options.pop( "kafka_producer_options", None ) step._async_object = storey.KafkaTarget( topic=topic, brokers=brokers, producer_options=kafka_producer_options, context=context, **options, ) else: if stream_path.startswith("v3io://"): endpoint, stream_path = parse_path(step.path) stream_path = stream_path.strip("/") step._async_object = storey.StreamTarget( storey.V3ioDriver(endpoint or config.v3io_api), stream_path, context=context, **options, ) else: step._async_object = storey.Map(lambda x: x) elif not step.async_object or not hasattr(step.async_object, "_outlets"): # if regular class, wrap with storey Map step._async_object = storey.Map( step._handler, full_event=step.full_event or step._call_with_event, input_path=step.input_path, result_path=step.result_path, name=step.name, context=context, pass_context=step._inject_context, ) if ( respond_supported and not step.next and hasattr(step, "responder") and step.responder ): # if responder step (return result), add Complete() step.async_object.to(storey.Complete(full_event=True)) wait_for_result = True source_args = context.get_param("source_args", {}) explicit_ack = ( is_explicit_ack_supported(context) and mlrun.mlconf.is_explicit_ack_enabled() ) if context.is_mock: source_class = storey.SyncEmitSource else: source_class = storey.AsyncEmitSource default_source = source_class( context=context, explicit_ack=explicit_ack, **source_args, ) return default_source, wait_for_result