import typing
from typing import Any, Dict, List, Optional, Sequence
import cohere
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
from langchain_core.utils import get_from_dict_or_env
from .utils import _create_retry_decorator
[docs]
class CohereEmbeddings(BaseModel, Embeddings):
    """
    Implements the Embeddings interface with Cohere's text representation language
    models.

    Find out more about us at https://cohere.com and https://huggingface.co/CohereForAI

    This implementation uses the Embed API - see https://docs.cohere.com/reference/embed

    To use this you'll need a Cohere API key - either pass it to the cohere_api_key
    parameter or set the COHERE_API_KEY environment variable.

    API keys are available on https://cohere.com - it's free to sign up and trial API
    keys work with this implementation.

    Basic Example:
        .. code-block:: python

            cohere_embeddings = CohereEmbeddings(model="embed-english-light-v3.0")
            text = "This is a test document."
            query_result = cohere_embeddings.embed_query(text)
            print(query_result)
            doc_result = cohere_embeddings.embed_documents([text])
            print(doc_result)
    """
    # Both client fields are filled in by the `validate_environment` root validator.
    client: Any  #: :meta private:
    """Cohere client."""
    async_client: Any  #: :meta private:
    """Cohere async client."""
    # Defaults to None, but `validate_model_specified` rejects a falsy value.
    model: Optional[str] = None
    """Model name to use. It is mandatory to specify the model name."""
    # Forwarded as-is as the `truncate` argument of every embed call.
    truncate: Optional[str] = None
    """Truncate embeddings that are too long from start or end ("NONE"|"START"|"END")"""
    # Cohere API key. When not supplied, `validate_environment` falls back to the
    # COHERE_API_KEY environment variable.
    cohere_api_key: Optional[str] = None
    # Forwarded as-is as the `embedding_types` argument of every embed call.
    embedding_types: Optional[Sequence[str]] = ["float"]
    "Specifies the types of embeddings you want to get back"
    # Passed to `_create_retry_decorator` each time an embed call is made.
    max_retries: int = 3
    """Maximum number of retries to make when generating."""
    # Passed to both clients as `timeout`.
    request_timeout: Optional[float] = None
    """Timeout in seconds for the Cohere API request."""
    # Passed to both clients as `client_name`.
    user_agent: str = "langchain:partner"
    """Identifier for the application making the request."""
    # Passed to both clients as `base_url`.
    base_url: Optional[str] = None
    """Override the default Cohere API URL."""
    class Config:
        """Configuration for this pydantic object."""
        # The client fields hold arbitrary (non-pydantic) objects.
        arbitrary_types_allowed = True
        # Unknown constructor keyword arguments are rejected rather than ignored.
        extra = Extra.forbid
    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Build the sync and async Cohere clients from the collected field values.

        The API key is resolved with ``get_from_dict_or_env`` from the
        ``cohere_api_key`` entry of ``values`` or the ``COHERE_API_KEY``
        environment variable.

        Args:
            values: Field values gathered by pydantic for this model.

        Returns:
            The same mapping with ``client`` and ``async_client`` populated.
        """
        cohere_api_key = get_from_dict_or_env(
            values, "cohere_api_key", "COHERE_API_KEY"
        )
        # Read every optional setting with ``.get``: a field that failed its own
        # validation is missing from ``values``, and indexing it would raise a
        # bare KeyError here instead of letting pydantic report the field error.
        client_kwargs = {
            "timeout": values.get("request_timeout"),
            "client_name": values.get("user_agent"),
            "base_url": values.get("base_url"),
        }
        # Both clients are configured identically; only the class differs.
        values["client"] = cohere.Client(cohere_api_key, **client_kwargs)
        values["async_client"] = cohere.AsyncClient(cohere_api_key, **client_kwargs)
        return values
    @root_validator()
    def validate_model_specified(cls, values: Dict) -> Dict:
        """Reject construction when no model name was supplied.

        Args:
            values: Field values gathered by pydantic for this model.

        Returns:
            ``values`` unchanged when ``model`` is present and non-empty.

        Raises:
            ValueError: If ``model`` is missing, ``None`` or an empty string.
        """
        if values.get("model"):
            return values

        # The message text is kept exactly as originally emitted.
        message = (
            "Did not find `model`! Please "
            " pass `model` as a named parameter."
            " Please check out"
            " https://docs.cohere.com/reference/embed"
            " for available models."
        )
        raise ValueError(message)
[docs]
    def embed_with_retry(self, **kwargs: Any) -> Any:
        """Call the sync client's ``embed`` through the retry decorator.

        Args:
            **kwargs: Forwarded unchanged to ``self.client.embed``.

        Returns:
            Whatever the wrapped ``self.client.embed`` call returns.
        """
        # A fresh decorator is built on every call from the current ``max_retries``.
        with_retry = _create_retry_decorator(self.max_retries)

        def _embed_with_retry(**kwargs: Any) -> Any:
            return self.client.embed(**kwargs)

        return with_retry(_embed_with_retry)(**kwargs)
[docs]
    def aembed_with_retry(self, **kwargs: Any) -> Any:
        """Call the async client's ``embed`` through the retry decorator.

        This method is a plain ``def``; it returns the result of calling the
        decorated coroutine function, which callers in this class ``await``.

        Args:
            **kwargs: Forwarded unchanged to ``self.async_client.embed``.

        Returns:
            The object produced by invoking the decorated coroutine function.
        """
        # A fresh decorator is built on every call from the current ``max_retries``.
        with_retry = _create_retry_decorator(self.max_retries)

        async def _embed_with_retry(**kwargs: Any) -> Any:
            return await self.async_client.embed(**kwargs)

        return with_retry(_embed_with_retry)(**kwargs)
[docs]
    def embed(
        self,
        texts: List[str],
        *,
        input_type: typing.Optional[cohere.EmbedInputType] = None,
    ) -> List[List[float]]:
        """Embed ``texts`` with the sync client and return plain float vectors.

        Args:
            texts: The texts to embed.
            input_type: Passed through as the ``input_type`` of the embed call.

        Returns:
            One list of floats per item in the response's ``embeddings``.
        """
        response = self.embed_with_retry(
            model=self.model,
            texts=texts,
            input_type=input_type,
            truncate=self.truncate,
            embedding_types=self.embedding_types,
        )
        # NOTE(review): assumes ``response.embeddings`` iterates as numeric
        # vectors - confirm this still holds when ``embedding_types`` is set.
        return [[float(value) for value in vector] for vector in response.embeddings]
[docs]
    async def aembed(
        self,
        texts: List[str],
        *,
        input_type: typing.Optional[cohere.EmbedInputType] = None,
    ) -> List[List[float]]:
        """Embed ``texts`` with the async client and return plain float vectors.

        Args:
            texts: The texts to embed.
            input_type: Passed through as the ``input_type`` of the embed call.

        Returns:
            One list of floats per item in the response's ``embeddings``.
        """
        response = await self.aembed_with_retry(
            model=self.model,
            texts=texts,
            input_type=input_type,
            truncate=self.truncate,
            embedding_types=self.embedding_types,
        )
        # NOTE(review): assumes ``response.embeddings`` iterates as numeric
        # vectors - confirm this still holds when ``embedding_types`` is set.
        return [[float(value) for value in vector] for vector in response.embeddings]
[docs]
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed document texts using the ``"search_document"`` input type.

        Args:
            texts: The list of texts to embed.

        Returns:
            The vectors returned by ``embed`` for ``texts``.
        """
        document_vectors = self.embed(texts, input_type="search_document")
        return document_vectors
[docs]
    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously embed document texts as ``"search_document"`` input.

        Args:
            texts: The list of texts to embed.

        Returns:
            The vectors returned by ``aembed`` for ``texts``.
        """
        document_vectors = await self.aembed(texts, input_type="search_document")
        return document_vectors
[docs]
    def embed_query(self, text: str) -> List[float]:
        """Embed a single query text using the ``"search_query"`` input type.

        Args:
            text: The text to embed.

        Returns:
            The first vector returned by ``embed`` for the one-element batch.
        """
        # The text is wrapped in a one-element batch; only its vector is returned.
        query_vectors = self.embed([text], input_type="search_query")
        return query_vectors[0]
[docs]
    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronously embed a single query text as ``"search_query"`` input.

        Args:
            text: The text to embed.

        Returns:
            The first vector returned by ``aembed`` for the one-element batch.
        """
        # The text is wrapped in a one-element batch; only its vector is returned.
        query_vectors = await self.aembed([text], input_type="search_query")
        return query_vectors[0]