# Source code for mlrun.datastore.datastore_profile

# Copyright 2023 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import ast
import base64
import json
import typing
from urllib.parse import ParseResult, urlparse, urlunparse

import pydantic
from mergedeep import merge

import mlrun
import mlrun.errors

from ..secrets import get_secret_or_env


class DatastoreProfile(pydantic.BaseModel):
    """Base pydantic model shared by the datastore profile classes.

    Every profile carries a ``type`` and a ``name``; the name is lower-cased
    during validation. ``_private_attributes`` holds the names of the fields
    a subclass treats as private (empty for the base class).
    """

    type: str
    name: str
    _private_attributes: typing.List = ()

    @pydantic.validator("name")
    def lower_case(cls, v):
        """Normalize the profile name to lower case."""
        return v.lower()

    @staticmethod
    def generate_secret_key(profile_name: str, project: str):
        """Build the dotted secret key for a profile within a project.

        :param profile_name: name of the datastore profile
        :param project: name of the project
        :return: ``datastore-profiles.<project>.<profile_name>``
        """
        key_parts = ("datastore-profiles", project, profile_name)
        return ".".join(key_parts)

    def secrets(self) -> dict:
        """Return the profile's secrets; the base implementation has none."""
        return None

    def url(self, subpath) -> str:
        """Return the URL for ``subpath``; the base implementation has none."""
        return None


class TemporaryClientDatastoreProfiles(metaclass=mlrun.utils.singleton.Singleton):
    """In-memory registry of datastore profiles, keyed by profile name.

    NOTE(review): the ``Singleton`` metaclass is defined outside this module;
    presumably every instantiation returns the same shared object - confirm.
    """

    def __init__(self):
        # Maps profile name -> profile object.
        self._data = dict()

    def add(self, profile: DatastoreProfile):
        """Store ``profile`` under its name, replacing any previous entry."""
        name = profile.name
        self._data[name] = profile

    def get(self, key):
        """Return the profile stored under ``key``, or ``None`` when absent."""
        return self._data.get(key)


class DatastoreProfileBasic(DatastoreProfile):
    """Minimal profile holding one public and one optional private string."""

    type: str = pydantic.Field("basic")
    # Kept as a tuple: with a bare string, ``field in _private_attributes``
    # would be a substring test rather than a membership test.
    _private_attributes = ("private",)
    public: str
    private: typing.Optional[str] = None


class DatastoreProfileKafkaTarget(DatastoreProfile):
    """Profile describing a Kafka target (bootstrap servers and topic)."""

    type: str = pydantic.Field("kafka_target")
    # Kept as a tuple: with a bare string, ``field in _private_attributes``
    # would be a substring test rather than a membership test.
    _private_attributes = ("kwargs_private",)
    bootstrap_servers: str
    topic: str
    # Optional extra attributes; the default of None is spelled out explicitly.
    kwargs_public: typing.Optional[typing.Dict] = None
    kwargs_private: typing.Optional[typing.Dict] = None

    def attributes(self):
        """Return the merged attribute dictionary for this target.

        Starts from ``bootstrap_servers`` and merges ``kwargs_public`` and then
        ``kwargs_private`` on top of it, so later sources are merged last.

        :return: dict of attributes
        """
        attributes = {"bootstrap_servers": self.bootstrap_servers}
        if self.kwargs_public:
            attributes = merge(attributes, self.kwargs_public)
        if self.kwargs_private:
            attributes = merge(attributes, self.kwargs_private)
        return attributes


class DatastoreProfileKafkaSource(DatastoreProfile):
    """Profile describing a Kafka source (brokers, topics and consumer options)."""

    type: str = pydantic.Field("kafka_source")
    _private_attributes = ("kwargs_private", "sasl_user", "sasl_pass")
    brokers: typing.Union[str, typing.List[str]]
    topics: typing.Union[str, typing.List[str]]
    group: typing.Optional[str] = "serving"
    initial_offset: typing.Optional[str] = "earliest"
    partitions: typing.Optional[typing.Union[str, typing.List[str]]]
    sasl_user: typing.Optional[str]
    sasl_pass: typing.Optional[str]
    kwargs_public: typing.Optional[typing.Dict]
    kwargs_private: typing.Optional[typing.Dict]

    def attributes(self):
        """Return the attribute dictionary for this source.

        ``kwargs_public`` and then ``kwargs_private`` are merged first; the
        explicit profile fields are written afterwards and therefore override
        same-named keys coming from the kwargs. A ``sasl`` entry from the
        kwargs is extended with the profile's user/password when both are set.

        :return: dict of attributes
        """
        attributes = {}
        for extra_kwargs in (self.kwargs_public, self.kwargs_private):
            if extra_kwargs:
                attributes = merge(attributes, extra_kwargs)

        # A single string is wrapped so both values are always lists.
        broker_list = self.brokers
        if isinstance(broker_list, str):
            broker_list = [broker_list]
        topic_list = self.topics
        if isinstance(topic_list, str):
            topic_list = [topic_list]

        attributes.update(
            brokers=broker_list,
            topics=topic_list,
            group=self.group,
            initial_offset=self.initial_offset,
        )
        if self.partitions is not None:
            attributes["partitions"] = self.partitions

        sasl_config = attributes.pop("sasl", {})
        if self.sasl_user and self.sasl_pass:
            sasl_config.update(
                enabled=True,
                user=self.sasl_user,
                password=self.sasl_pass,
            )
        # Only re-attach the entry when there is something to report.
        if sasl_config:
            attributes["sasl"] = sasl_config
        return attributes


class DatastoreProfileS3(DatastoreProfile):
    """Profile holding S3 connection settings and credentials."""

    type: str = pydantic.Field("s3")
    _private_attributes = ("access_key", "secret_key")
    endpoint_url: typing.Optional[str] = None
    force_non_anonymous: typing.Optional[str] = None
    profile_name: typing.Optional[str] = None
    assume_role_arn: typing.Optional[str] = None
    access_key: typing.Optional[str] = None
    secret_key: typing.Optional[str] = None

    def secrets(self) -> dict:
        """Map the populated fields to their secret/environment key names.

        :return: dict of key -> value for every truthy field, or ``None``
                 when no field is set
        """
        candidates = (
            ("AWS_ACCESS_KEY_ID", self.access_key),
            ("AWS_SECRET_ACCESS_KEY", self.secret_key),
            ("S3_ENDPOINT_URL", self.endpoint_url),
            ("S3_NON_ANONYMOUS", self.force_non_anonymous),
            ("AWS_PROFILE", self.profile_name),
            ("MLRUN_AWS_ROLE_ARN", self.assume_role_arn),
        )
        populated = {key: value for key, value in candidates if value}
        return populated or None

    def url(self, subpath):
        """Return ``subpath`` prefixed with ``s3:/``.

        NOTE(review): only one slash is added, so ``subpath`` presumably
        starts with ``/`` - confirm against callers.
        """
        return "s3:/{}".format(subpath)


class DatastoreProfileRedis(DatastoreProfile):
    """Profile holding a Redis endpoint URL and optional credentials."""

    type: str = pydantic.Field("redis")
    _private_attributes = ("username", "password")
    endpoint_url: str
    username: typing.Optional[str] = None
    password: typing.Optional[str] = None

    def url_with_credentials(self):
        """Return ``endpoint_url`` with the profile's credentials embedded.

        The network location is rebuilt from the endpoint's host and port,
        so any credentials already present in ``endpoint_url`` are replaced
        (or dropped when the profile has no username).

        :return: the URL as a string
        """
        parsed_url = urlparse(self.endpoint_url)

        # ``hostname`` is None when the URL has no host; use an empty string so
        # the concatenations below neither fail nor render the text "None".
        host = parsed_url.hostname or ""
        # ``hostname`` strips the brackets of an IPv6 literal; restore them so
        # the host's colons are not confused with the port separator.
        if ":" in host:
            host = f"[{host}]"

        # NOTE(review): username/password are inserted verbatim; values
        # containing '@', ':' or '/' would yield an ambiguous URL - confirm
        # whether consumers expect percent-encoded credentials.
        netloc = host
        if self.username:
            if self.password:
                netloc = f"{self.username}:{self.password}@{host}"
            else:
                netloc = f"{self.username}@{host}"

        if parsed_url.port:
            netloc += f":{parsed_url.port}"

        # Every component except the network location is carried over as-is.
        return urlunparse(parsed_url._replace(netloc=netloc))

    def secrets(self) -> dict:
        """Return the populated credentials keyed by their secret names.

        :return: dict with ``REDIS_USER`` / ``REDIS_PASSWORD`` for the fields
                 that are set, or ``None`` when neither is set
        """
        res = {}
        if self.username:
            res["REDIS_USER"] = self.username
        if self.password:
            res["REDIS_PASSWORD"] = self.password
        return res if res else None

    def url(self, subpath):
        """Return ``endpoint_url`` with ``subpath`` appended."""
        return self.endpoint_url + subpath


class DatastoreProfileDBFS(DatastoreProfile):
    """Profile holding a DBFS host and access token."""

    type: str = pydantic.Field("dbfs")
    _private_attributes = ("token",)
    endpoint_url: typing.Optional[str] = None  # host
    token: typing.Optional[str] = None

    def url(self, subpath) -> str:
        """Return ``subpath`` prefixed with ``dbfs://``."""
        return "dbfs://{}".format(subpath)

    def secrets(self) -> dict:
        """Return the populated settings keyed by their secret names.

        :return: dict with ``DATABRICKS_TOKEN`` / ``DATABRICKS_HOST`` for the
                 fields that are set, or ``None`` when neither is set
        """
        candidates = (
            ("DATABRICKS_TOKEN", self.token),
            ("DATABRICKS_HOST", self.endpoint_url),
        )
        populated = {key: value for key, value in candidates if value}
        return populated or None


class DatastoreProfile2Json(pydantic.BaseModel):
    """Convert datastore profiles to and from their JSON representation.

    The JSON form is a flat object whose values are the base64-encoded
    ``str()`` of each profile field. Fields are split into a public and a
    private document according to the profile's ``_private_attributes``.
    """

    @staticmethod
    def _to_json(attributes):
        """Serialize ``attributes`` to compact JSON with base64-encoded values.

        :param attributes: dict of field name -> value
        :return: JSON string without whitespace between tokens
        """
        # First, base64 encode the string form of each value
        encoded_dict = {
            k: base64.b64encode(str(v).encode()).decode() for k, v in attributes.items()
        }
        # Compact separators give the space-free output directly, without
        # touching spaces that may appear inside a key.
        return json.dumps(encoded_dict, separators=(",", ":"))

    @staticmethod
    def _decode_values(encoded_json):
        """Parse a JSON object and base64-decode each of its values to text."""
        return {
            k: base64.b64decode(str(v).encode()).decode()
            for k, v in json.loads(encoded_json).items()
        }

    @staticmethod
    def get_json_public(profile: DatastoreProfile) -> str:
        """Return the JSON document of the profile's non-private fields."""
        return DatastoreProfile2Json._to_json(
            {
                k: v
                for k, v in profile.dict().items()
                if str(k) not in profile._private_attributes
            }
        )

    @staticmethod
    def get_json_private(profile: DatastoreProfile) -> str:
        """Return the JSON document of the profile's private fields."""
        return DatastoreProfile2Json._to_json(
            {
                k: v
                for k, v in profile.dict().items()
                if str(k) in profile._private_attributes
            }
        )

    @staticmethod
    def create_from_json(public_json: str, private_json: str = "{}"):
        """Rebuild a profile object from its public and private JSON documents.

        :param public_json: JSON produced for the public fields
        :param private_json: JSON produced for the private fields
        :return: an instance of the profile class selected by the ``type`` field
        :raises mlrun.errors.MLRunInvalidArgumentError: when ``type`` is
            missing or does not name a known profile class
        """
        decoded_dict = merge(
            DatastoreProfile2Json._decode_values(public_json),
            DatastoreProfile2Json._decode_values(private_json),
        )

        def safe_literal_eval(value):
            # Values were serialized with str(); try to recover None, numbers,
            # dicts and lists, and keep the raw text when it is not a literal.
            # NOTE(review): a plain string that happens to look like a literal
            # (e.g. "1e5") is converted too - confirm this is acceptable.
            try:
                return ast.literal_eval(value)
            except (
                ValueError,
                TypeError,
                SyntaxError,
                MemoryError,
                RecursionError,
            ):
                return value

        decoded_dict = {k: safe_literal_eval(v) for k, v in decoded_dict.items()}
        datastore_type = decoded_dict.get("type")
        ds_profile_factory = {
            "s3": DatastoreProfileS3,
            "redis": DatastoreProfileRedis,
            "basic": DatastoreProfileBasic,
            "kafka_target": DatastoreProfileKafkaTarget,
            "kafka_source": DatastoreProfileKafkaSource,
            "dbfs": DatastoreProfileDBFS,
        }
        # Only look up string types: a decoded list/dict is unhashable and
        # would otherwise raise TypeError instead of the intended error below.
        profile_class = (
            ds_profile_factory.get(datastore_type)
            if isinstance(datastore_type, str)
            else None
        )
        if profile_class is not None:
            return profile_class.parse_obj(decoded_dict)

        if datastore_type:
            reason = f"unexpected type '{datastore_type}'"
        else:
            reason = "missing type"
        raise mlrun.errors.MLRunInvalidArgumentError(
            f"Datastore profile failed to create from json due to {reason}"
        )


def datastore_profile_read(url):
    """Resolve a ``ds://`` URL to its datastore profile object.

    The profile name is taken from the URL host and the project from the URL
    username, falling back to ``mlrun.mlconf.default_project``. A profile
    registered in ``TemporaryClientDatastoreProfiles`` takes precedence;
    otherwise the public part is fetched from the run DB and the private part
    is looked up under the key built by ``DatastoreProfile.generate_secret_key``.

    :param url: resource URL whose scheme must be ``ds``
    :return: the datastore profile object
    :raises mlrun.errors.MLRunInvalidArgumentError: when the scheme is not
        ``ds``, or when the public or private part cannot be retrieved
    """
    parsed_url = urlparse(url)
    if parsed_url.scheme.lower() != "ds":
        raise mlrun.errors.MLRunInvalidArgumentError(
            f"resource URL '{url}' cannot be read as a datastore profile because its scheme is not 'ds'"
        )

    profile_name = parsed_url.hostname
    project_name = parsed_url.username or mlrun.mlconf.default_project

    # Locally registered temporary profiles win over anything stored remotely.
    datastore = TemporaryClientDatastoreProfiles().get(profile_name)
    if datastore:
        return datastore

    public_profile = mlrun.db.get_run_db().get_datastore_profile(
        profile_name, project_name
    )
    project_ds_name_private = DatastoreProfile.generate_secret_key(
        profile_name, project_name
    )
    private_body = get_secret_or_env(project_ds_name_private)
    if not public_profile or not private_body:
        # Each literal ends with a space so the sentences do not run together.
        raise mlrun.errors.MLRunInvalidArgumentError(
            f"Unable to retrieve the datastore profile '{url}' from either the server or local environment. "
            "Make sure the profile is registered correctly, or if running in a local environment, "
            "use register_temporary_client_datastore_profile() to provide credentials locally."
        )

    datastore = DatastoreProfile2Json.create_from_json(
        public_json=DatastoreProfile2Json.get_json_public(public_profile),
        private_json=private_body,
    )
    return datastore


def register_temporary_client_datastore_profile(profile: DatastoreProfile):
    """Register the datastore profile.

    This profile is temporary and remains valid only for the duration of the
    caller's session. It's beneficial for testing purposes.

    :param profile: the profile to register; it is stored under its name in
        ``TemporaryClientDatastoreProfiles``
    """
    TemporaryClientDatastoreProfiles().add(profile)


def datastore_profile_embed_url_scheme(url):
    """Return ``url`` with the profile's type embedded as the URL password.

    The profile is resolved with ``datastore_profile_read`` and its ``type``
    is placed in the password position of the network location, keeping the
    URL's username (if any) and all other components.

    :param url: a datastore profile URL
    :return: the rewritten URL as a string
    """
    profile = datastore_profile_read(url)
    parsed_url = urlparse(url)
    scheme = profile.type

    # ``netloc`` still contains any ``user@`` prefix; strip it so the userinfo
    # is not duplicated (``user:type@user@host``) when the URL has a username.
    host_part = parsed_url.netloc.rpartition("@")[2]

    # Add scheme as a password to the network location part
    netloc = f"{parsed_url.username or ''}:{scheme}@{host_part}"

    # Construct the new URL
    new_url = urlunparse(
        [
            parsed_url.scheme,
            netloc,
            parsed_url.path,
            parsed_url.params,
            parsed_url.query,
            parsed_url.fragment,
        ]
    )
    return new_url