from __future__ import annotations
import json
import logging
import warnings
from typing import (
    Any,
    AsyncIterator,
    Callable,
    Dict,
    Iterator,
    List,
    Optional,
    Union,
)
from astrapy.authentication import TokenProvider
from astrapy.db import AstraDB, AsyncAstraDB
from langchain_core.document_loaders import BaseLoader
from langchain_core.documents import Document
from langchain_astradb.utils.astradb import (
    SetupMode,
    _AstraDBCollectionEnvironment,
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Sentinel used as a default value to tell "argument omitted" apart from an
# explicitly passed ``None`` (compared with ``is`` in ``AstraDBLoader.__init__``).
_NOT_SET = object()


class AstraDBLoader(BaseLoader):
    """Document loader reading from a DataStax Astra DB collection.

    Every document returned by the collection's ``find`` is converted into a
    LangChain ``Document`` whose ``page_content`` and ``metadata`` are produced
    by the mapper callables supplied to the constructor.
    """

    def __init__(
        self,
        collection_name: str,
        *,
        token: Optional[Union[str, TokenProvider]] = None,
        api_endpoint: Optional[str] = None,
        environment: Optional[str] = None,
        astra_db_client: Optional[AstraDB] = None,
        async_astra_db_client: Optional[AsyncAstraDB] = None,
        namespace: Optional[str] = None,
        filter_criteria: Optional[Dict[str, Any]] = None,
        projection: Optional[Dict[str, Any]] = _NOT_SET,  # type: ignore[assignment]
        find_options: Optional[Dict[str, Any]] = None,
        limit: Optional[int] = None,
        nb_prefetched: int = _NOT_SET,  # type: ignore[assignment]
        page_content_mapper: Callable[[Dict], str] = json.dumps,
        metadata_mapper: Optional[Callable[[Dict], Dict[str, Any]]] = None,
    ) -> None:
        """Load DataStax Astra DB documents.

        Args:
            collection_name: name of the Astra DB collection to use.
            token: API token for Astra DB usage, either in the form of a string
                or an instance of (a subclass of)
                `astrapy.authentication.TokenProvider`.
                If not provided, the environment variable
                ASTRA_DB_APPLICATION_TOKEN is inspected.
            api_endpoint: full URL to the API endpoint, such as
                `https://<DB-ID>-us-east1.apps.astra.datastax.com`. If not provided,
                the environment variable ASTRA_DB_API_ENDPOINT is inspected.
            environment: a string specifying the environment of the target Data API.
                If omitted, defaults to "prod" (Astra DB production).
                Other values are in `astrapy.constants.Environment` enum class.
            astra_db_client:
                *DEPRECATED starting from version 0.3.5.*
                *Please use 'token', 'api_endpoint' and optionally 'environment'.*
                you can pass an already-created 'astrapy.db.AstraDB' instance
                (alternatively to 'token', 'api_endpoint' and 'environment').
            async_astra_db_client:
                *DEPRECATED starting from version 0.3.5.*
                *Please use 'token', 'api_endpoint' and optionally 'environment'.*
                you can pass an already-created 'astrapy.db.AsyncAstraDB' instance
                (alternatively to 'token', 'api_endpoint' and 'environment').
            namespace: namespace (aka keyspace) where the collection resides.
                If not provided, the environment variable ASTRA_DB_KEYSPACE is
                inspected. Defaults to the database's "default namespace".
            filter_criteria: Criteria to filter documents.
            projection: Specifies the fields to return. If omitted, the
                wildcard projection `{"*": True}` is used. An explicitly
                passed `None` is forwarded as-is to the collection's `find`.
            find_options: Additional options for the query.
                *DEPRECATED starting from version 0.3.5.*
                *For limiting, please use `limit`. Other options are ignored.*
            limit: a maximum number of documents to return in the read query.
            nb_prefetched: Max number of documents to pre-fetch.
                *IGNORED starting from v. 0.3.5: astrapy v1.0+ does not support it.*
            page_content_mapper: Function applied to collection documents to create
                the `page_content` of the LangChain Document. Defaults to `json.dumps`.
            metadata_mapper: Function applied to collection documents to create
                the `metadata` of the LangChain Document. If not provided, the
                metadata is a dictionary with the keys "namespace",
                "api_endpoint" and "collection".

        Raises:
            ValueError: if 'limit' is supplied both as the `limit` parameter
                and as a key of `find_options`.
        """
        astra_db_env = _AstraDBCollectionEnvironment(
            collection_name=collection_name,
            token=token,
            api_endpoint=api_endpoint,
            environment=environment,
            astra_db_client=astra_db_client,
            async_astra_db_client=async_astra_db_client,
            namespace=namespace,
            setup_mode=SetupMode.OFF,
        )
        self.astra_db_env = astra_db_env
        self.filter = filter_criteria
        # The sentinel distinguishes "omitted" (use the wildcard projection)
        # from an explicit None (forwarded unchanged to `find`).
        self._projection: Optional[Dict[str, Any]] = (
            projection if projection is not _NOT_SET else {"*": True}
        )
        # Warn if 'nb_prefetched' was passed at all (even as None).
        # stacklevel=2 attributes the warning to the code constructing the
        # loader rather than to this module.
        if nb_prefetched is not _NOT_SET:
            warnings.warn(
                (
                    "Parameter 'nb_prefetched' is not supported by the Data API "
                    "client and will be ignored in reading document."
                ),
                UserWarning,
                stacklevel=2,
            )
        # Normalize 'limit' and the deprecated 'find_options'.
        # Work on a copy so that the caller's dictionary is never mutated.
        _find_options = find_options.copy() if find_options else {}
        if "limit" in _find_options:
            if limit is not None:
                raise ValueError(
                    "Duplicate 'limit' directive supplied. Please remove it "
                    "from the 'find_options' map parameter."
                )
            warnings.warn(
                (
                    "Passing 'limit' as part of the 'find_options' "
                    "dictionary is deprecated starting from version 0.3.5. "
                    "Please switch to passing 'limit=<number>' "
                    "directly in the constructor."
                ),
                DeprecationWarning,
                stacklevel=2,
            )
        # A 'limit' found in find_options wins; otherwise use the parameter.
        self.limit = _find_options.pop("limit", limit)
        # Anything left in find_options after removing 'limit' is unsupported.
        if _find_options:
            warnings.warn(
                (
                    "Unknown keys passed in the 'find_options' dictionary. "
                    "This parameter is deprecated starting from version 0.3.5."
                ),
                DeprecationWarning,
                stacklevel=2,
            )
        # Stored as received (possibly the sentinel); not used by the queries.
        self.nb_prefetched = nb_prefetched
        self.page_content_mapper = page_content_mapper
        # The default mapper ignores the document and reports where it came
        # from; the database attributes are looked up at call time.
        self.metadata_mapper = metadata_mapper or (
            lambda _: {
                "namespace": self.astra_db_env.database.namespace,
                "api_endpoint": self.astra_db_env.database.api_endpoint,
                "collection": collection_name,
            }
        )

    def _to_langchain_doc(self, doc: Dict[str, Any]) -> Document:
        """Convert a collection document into a LangChain Document.

        Args:
            doc: a document as returned by the collection's `find`.

        Returns:
            A Document built by applying `page_content_mapper` and
            `metadata_mapper` to `doc`.
        """
        return Document(
            page_content=self.page_content_mapper(doc),
            metadata=self.metadata_mapper(doc),
        )

    def lazy_load(self) -> Iterator[Document]:
        """Lazily yield one Document per item found in the collection.

        The query uses the filter, projection and limit set at construction.
        """
        # 'nb_prefetched' is deliberately not forwarded: prefetching is not
        # available at the moment.
        for doc in self.astra_db_env.collection.find(
            filter=self.filter,
            projection=self._projection,
            limit=self.limit,
        ):
            yield self._to_langchain_doc(doc)

    async def aload(self) -> List[Document]:
        """Load data into Document objects."""
        return [doc async for doc in self.alazy_load()]

    async def alazy_load(self) -> AsyncIterator[Document]:
        """Asynchronously yield one Document per item found in the collection.

        The query uses the filter, projection and limit set at construction.
        """
        # 'nb_prefetched' is deliberately not forwarded: prefetching is not
        # available at the moment.
        async for doc in self.astra_db_env.async_collection.find(
            filter=self.filter,
            projection=self._projection,
            limit=self.limit,
        ):
            yield self._to_langchain_doc(doc)