# Source code for mlrun.artifacts.document

# Copyright 2024 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import tempfile
from collections.abc import Iterator
from copy import deepcopy
from importlib import import_module
from typing import Optional, Union

import mlrun
import mlrun.artifacts
from mlrun.artifacts import Artifact, ArtifactSpec
from mlrun.model import ModelObj

from ..utils import generate_artifact_uri
from .base import ArtifactStatus


class DocumentLoaderSpec(ModelObj):
    """
    A class to load a document from a file path using a specified loader class.

    This class is responsible for loading documents from a given source path using a specified loader class.
    The loader class is dynamically imported and instantiated with the provided arguments. The loaded documents
    can be optionally uploaded as artifacts. Note that only loader classes that return single results
    (e.g., TextLoader, UnstructuredHTMLLoader, WebBaseLoader(scalar)) are supported - loaders returning multiple
    results like DirectoryLoader or WebBaseLoader(list) are not compatible.

    Attributes:
        loader_class_name (str): The fully qualified name (``package.module.ClassName``) of the loader class
            to use for loading documents.
        src_name (str): The name of the keyword argument through which the source path is passed to the
            loader class.
        download_object (bool): If True, the file is downloaded before launching the loader. If False, the
            loader accepts a link that should not be downloaded.
        kwargs (Optional[dict]): Additional keyword arguments to pass to the loader class.
    """

    _dict_fields = ["loader_class_name", "src_name", "download_object", "kwargs"]

    def __init__(
        self,
        loader_class_name: str = "langchain_community.document_loaders.TextLoader",
        src_name: str = "file_path",
        download_object: bool = True,
        kwargs: Optional[dict] = None,
    ):
        """
        Initialize the document loader.

        Args:
            loader_class_name (str): The fully qualified name of the loader class to use.
            src_name (str): The name of the loader keyword argument that receives the source path.
            download_object (bool, optional): If True, the file will be downloaded before launching the loader.
                If False, the loader accepts a link that should not be downloaded. Defaults to True.
            kwargs (Optional[dict]): Additional keyword arguments to pass to the loader class.

        Example:
            >>> # Create a loader specification for PDF documents
            >>> loader_spec = DocumentLoaderSpec(
            ...     loader_class_name="langchain_community.document_loaders.PyPDFLoader",
            ...     src_name="file_path",
            ...     kwargs={"extract_images": True},
            ... )
            >>> # Create a loader instance for a specific PDF file
            >>> pdf_loader = loader_spec.make_loader("/path/to/document.pdf")
            >>> # Load the documents
            >>> documents = pdf_loader.load()
        """
        self.loader_class_name = loader_class_name
        self.src_name = src_name
        self.download_object = download_object
        self.kwargs = kwargs

    def make_loader(self, src_path):
        """
        Import the configured loader class and instantiate it for ``src_path``.

        Args:
            src_path (str): The path (or link) to load. It is passed to the loader under the
                keyword named by ``self.src_name``.

        Returns:
            An instance of the class named by ``self.loader_class_name``.

        Raises:
            ValueError: If ``loader_class_name`` is not of the form ``module.ClassName``.
            ImportError: If the module part of ``loader_class_name`` cannot be imported.
            AttributeError: If the module does not define the requested class.
        """
        module_name, sep, class_name = self.loader_class_name.rpartition(".")
        if not sep or not module_name or not class_name:
            raise ValueError(
                f"loader_class_name must be a fully qualified 'module.ClassName', "
                f"got {self.loader_class_name!r}"
            )
        module = import_module(module_name)
        loader_class = getattr(module, class_name)
        # Copy so the source-path entry never leaks back into the stored spec kwargs
        kwargs = deepcopy(self.kwargs or {})
        kwargs[self.src_name] = src_path
        return loader_class(**kwargs)
class MLRunLoader:
    """
    A factory class for creating instances of a dynamically defined document loader.

    Args:
        source_path (str): The source path of the document to be loaded.
        loader_spec (DocumentLoaderSpec): Specification for the document loader.
        artifact_key (str, optional): The key for the artifact to be logged.
            The '%%' pattern in the key will be replaced by the source path
            with any unsupported characters converted to '_'. Defaults to "%%".
        producer (Optional[Union[MlrunProject, str, MLClientCtx]], optional): The producer of the document.
            A string is treated as a project name and resolved with ``mlrun.get_or_create_project``.
            If not specified, the active project (``mlrun.mlconf.active_project``) is used. Defaults to None.
        upload (bool, optional): Flag indicating whether to upload the document. Defaults to False.
        tag (str, optional): Version tag for the artifact. Defaults to "".
        labels (Optional[Dict[str, str]], optional): Key-value labels to attach to the artifact. Defaults to None.

    Returns:
        DynamicDocumentLoader: An instance of a dynamically defined subclass of BaseLoader.

    Raises:
        ValueError: If no producer is given and no active project is set.

    Example:
        >>> # Create a document loader specification
        >>> loader_spec = DocumentLoaderSpec(
        ...     loader_class_name="langchain_community.document_loaders.TextLoader",
        ...     src_name="file_path",
        ... )
        >>> # Create a basic loader for a single file
        >>> loader = MLRunLoader(
        ...     source_path="/path/to/document.txt",
        ...     loader_spec=loader_spec,
        ...     artifact_key="my_doc",
        ...     producer=project,
        ...     upload=True,
        ... )
        >>> documents = loader.load()
        >>> # Create a loader with auto-generated keys
        >>> loader = MLRunLoader(
        ...     source_path="/path/to/document.txt",
        ...     loader_spec=loader_spec,
        ...     artifact_key="%%",  # %% will be replaced with encoded path
        ...     producer=project,
        ... )
        >>> documents = loader.load()
        >>> # Use with DirectoryLoader
        >>> from langchain_community.document_loaders import DirectoryLoader
        >>> dir_loader = DirectoryLoader(
        ...     "/path/to/directory",
        ...     glob="**/*.txt",
        ...     loader_cls=MLRunLoader,
        ...     loader_kwargs={
        ...         "loader_spec": loader_spec,
        ...         "artifact_key": "%%",
        ...         "producer": project,
        ...         "upload": True,
        ...     },
        ... )
        >>> documents = dir_loader.load()
    """

    def __new__(
        cls,
        source_path: str,
        loader_spec: "DocumentLoaderSpec",
        artifact_key="%%",
        producer: Optional[Union["MlrunProject", str, "MLClientCtx"]] = None,  # noqa: F821
        upload: bool = False,
        tag: str = "",
        labels: Optional[dict[str, str]] = None,
    ):
        # Dynamically import BaseLoader so langchain is only required when a loader is built
        from langchain_community.document_loaders.base import BaseLoader

        class DynamicDocumentLoader(BaseLoader):
            """A LangChain loader that logs the source as an MLRun document artifact and returns its documents."""

            def __init__(
                self,
                local_path,
                loader_spec,
                artifact_key,
                producer,
                upload,
                tag,
                labels,
            ):
                self.producer = producer
                self.artifact_key = (
                    MLRunLoader.artifact_key_instance(artifact_key, local_path)
                    if "%%" in artifact_key
                    else artifact_key
                )
                self.loader_spec = loader_spec
                self.local_path = local_path
                self.upload = upload
                self.tag = tag
                self.labels = labels

                # Resolve the producer: fall back to the active project when none is given
                if not self.producer:
                    self.producer = mlrun.mlconf.active_project
                if not self.producer:
                    raise ValueError(
                        "No producer was given and no active project is set; pass a "
                        "project, project name or context as `producer`"
                    )
                if isinstance(self.producer, str):
                    self.producer = mlrun.get_or_create_project(self.producer)

            def lazy_load(self) -> Iterator["Document"]:  # noqa: F821
                collections = None
                try:
                    # Preserve the collection membership of an already-logged artifact
                    artifact = self.producer.get_artifact(self.artifact_key, self.tag)
                    collections = (
                        artifact.status.collections if artifact else collections
                    )
                except mlrun.MLRunNotFoundError:
                    # First time this key is logged - nothing to preserve
                    pass
                artifact = self.producer.log_document(
                    key=self.artifact_key,
                    document_loader_spec=self.loader_spec,
                    local_path=self.local_path,
                    upload=self.upload,
                    labels=self.labels,
                    tag=self.tag,
                    collections=collections,
                )
                # Return a real iterator (not a list) to honor the annotated contract;
                # callers that advance the result with next() require one
                return iter(artifact.to_langchain_documents())

        # Return an instance of the dynamically defined subclass
        instance = DynamicDocumentLoader(
            artifact_key=artifact_key,
            local_path=source_path,
            loader_spec=loader_spec,
            producer=producer,
            upload=upload,
            tag=tag,
            labels=labels,
        )
        return instance

    @staticmethod
    def artifact_key_instance(artifact_key: str, src_path: str) -> str:
        """
        Resolve the '%%' placeholder in an artifact key.

        Args:
            artifact_key (str): The key template; every '%%' is replaced.
            src_path (str): The source path, converted with ``DocumentArtifact.key_from_source``.

        Returns:
            str: The key with '%%' substituted, or ``artifact_key`` unchanged if it has no '%%'.
        """
        if "%%" in artifact_key:
            resolved_path = DocumentArtifact.key_from_source(src_path)
            artifact_key = artifact_key.replace("%%", resolved_path)
        return artifact_key
class DocumentArtifact(Artifact):
    """
    A specific artifact class inheriting from generic artifact, used to maintain Document meta-data.
    """

    @staticmethod
    def key_from_source(src_path: str) -> str:
        """Convert a source path into a valid artifact key by replacing invalid characters with underscores.

        Args:
            src_path (str): The source path to be converted into a valid artifact key

        Returns:
            str: A modified version of the source path where all invalid characters are replaced
                 with underscores while preserving valid sequences in their original positions.
                 Leading underscores are stripped from the result.

        Examples:
            >>> DocumentArtifact.key_from_source("data/file-name(v1).txt")
            "data_file-name_v1__txt"
        """
        pattern = mlrun.utils.regex.artifact_key[0]
        # Convert anchored pattern (^...$) to non-anchored version for finditer
        search_pattern = pattern.strip("^$")
        result = []
        current_pos = 0

        # Keep each valid sequence as-is; emit one '_' per character in the gaps between them
        for match in re.finditer(search_pattern, src_path):
            result.append("_" * (match.start() - current_pos))
            result.append(match.group())
            current_pos = match.end()

        # Handle any remaining characters after the last match
        result.append("_" * (len(src_path) - current_pos))
        return "".join(result).lstrip("_")

    class DocumentArtifactSpec(ArtifactSpec):
        _dict_fields = ArtifactSpec._dict_fields + [
            "document_loader",
            "original_source",
        ]

        def __init__(
            self,
            *args,
            document_loader: Optional[DocumentLoaderSpec] = None,
            original_source: Optional[str] = None,
            **kwargs,
        ):
            super().__init__(*args, **kwargs)
            self.document_loader = document_loader
            self.original_source = original_source

    class DocumentArtifactStatus(ArtifactStatus):
        _dict_fields = ArtifactStatus._dict_fields + ["collections"]

        def __init__(
            self,
            *args,
            collections: Optional[dict] = None,
            **kwargs,
        ):
            super().__init__(*args, **kwargs)
            self.collections = collections if collections is not None else {}

    kind = "document"

    METADATA_SOURCE_KEY = "source"
    METADATA_ORIGINAL_SOURCE_KEY = "original_source"
    METADATA_CHUNK_KEY = "mlrun_chunk"
    METADATA_ARTIFACT_TARGET_PATH_KEY = "mlrun_target_path"
    METADATA_ARTIFACT_TAG = "mlrun_tag"
    METADATA_ARTIFACT_KEY = "mlrun_key"
    METADATA_ARTIFACT_PROJECT = "mlrun_project"

    def __init__(
        self,
        original_source: Optional[str] = None,
        document_loader_spec: Optional[DocumentLoaderSpec] = None,
        collections: Optional[dict] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # The loader spec is stored in serialized (dict) form on the spec
        self.spec.document_loader = (
            document_loader_spec.to_dict()
            if document_loader_spec
            else self.spec.document_loader
        )
        self.spec.original_source = original_source or self.spec.original_source
        self.status = DocumentArtifact.DocumentArtifactStatus(collections=collections)

    @property
    def status(self) -> DocumentArtifactStatus:
        return self._status

    @status.setter
    def status(self, status):
        self._status = self._verify_dict(
            status, "status", DocumentArtifact.DocumentArtifactStatus
        )

    @property
    def spec(self) -> DocumentArtifactSpec:
        return self._spec

    @spec.setter
    def spec(self, spec):
        self._spec = self._verify_dict(
            spec, "spec", DocumentArtifact.DocumentArtifactSpec
        )

    def get_source(self):
        """Get the source URI for this artifact."""
        return generate_artifact_uri(self.metadata.project, self.spec.db_key)

    def to_langchain_documents(
        self,
        splitter: Optional["TextSplitter"] = None,  # noqa: F821
    ) -> list["Document"]:  # noqa: F821
        """
        Create LC documents from the artifact

        Args:
            splitter (Optional[TextSplitter]): A LangChain TextSplitter to split the document into chunks.

        Returns:
            list[Document]: A list of LangChain Document objects. Each one carries the loader metadata plus
                the artifact's original source, source URI, tag, key, project, target path (when set) and
                a running chunk index.

        Raises:
            ValueError: If the document cannot be loaded, meaning there is no target path to download
                (or the loader spec disables downloading) and no original source is set.
        """
        from langchain.schema import Document

        loader_spec = DocumentLoaderSpec.from_dict(self.spec.document_loader)
        target_path = self.get_target_path()
        if loader_spec.download_object and target_path:
            with tempfile.NamedTemporaryFile() as tmp_file:
                mlrun.datastore.store_manager.object(url=target_path).download(
                    tmp_file.name
                )
                loader = loader_spec.make_loader(tmp_file.name)
                documents = loader.load()
        elif self.spec.original_source:
            loader = loader_spec.make_loader(self.spec.original_source)
            documents = loader.load()
        else:
            raise ValueError(
                "No src_path or target_path provided. Cannot load document."
            )

        results = []
        # The chunk index runs across all loaded documents, not per document
        idx = 0
        for document in documents:
            if splitter:
                texts = splitter.split_text(document.page_content)
            else:
                texts = [document.page_content]

            metadata = document.metadata

            metadata[self.METADATA_ORIGINAL_SOURCE_KEY] = self.spec.original_source
            metadata[self.METADATA_SOURCE_KEY] = self.get_source()
            metadata[self.METADATA_ARTIFACT_TAG] = self.tag or "latest"
            metadata[self.METADATA_ARTIFACT_KEY] = self.db_key
            metadata[self.METADATA_ARTIFACT_PROJECT] = self.metadata.project

            if target_path:
                metadata[self.METADATA_ARTIFACT_TARGET_PATH_KEY] = target_path

            for text in texts:
                metadata[self.METADATA_CHUNK_KEY] = str(idx)
                # Copy so each chunk keeps its own chunk index
                doc = Document(
                    page_content=text,
                    metadata=metadata.copy(),
                )
                results.append(doc)
                idx = idx + 1
        return results

    def collection_add(self, collection_id: str) -> bool:
        """
        Add a collection ID to the artifact's collection list.

        Adds the specified collection ID to the artifact's collection mapping if it
        doesn't already exist.
        This method only modifies the client-side artifact object and does not persist
        the changes to the MLRun DB. To save the changes permanently, you must call
        project.update_artifact() or context.update_artifact() after this method.

        Args:
            collection_id (str): The ID of the collection to add

        Returns:
            bool: True if the collection ID was added, False if it was already present.
        """
        if collection_id not in self.status.collections:
            self.status.collections[collection_id] = "1"
            return True
        return False

    def collection_remove(self, collection_id: str) -> bool:
        """
        Remove a collection ID from the artifact's collection list.

        Removes the specified collection ID from the artifact's local collection mapping.
        This method only modifies the client-side artifact object and does not persist
        the changes to the MLRun DB. To save the changes permanently, you must call
        project.update_artifact() or context.update_artifact() after this method.

        Args:
            collection_id (str): The ID of the collection to remove

        Returns:
            bool: True if the collection ID was removed, False if it was not present.
        """
        if collection_id in self.status.collections:
            self.status.collections.pop(collection_id)
            return True
        return False