from __future__ import annotations
import asyncio
import inspect
import uuid
import warnings
from concurrent.futures import ThreadPoolExecutor
from typing import (
Any,
Awaitable,
Callable,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
Type,
TypeVar,
Union,
)
import numpy as np
from astrapy.authentication import EmbeddingHeadersProvider, TokenProvider
from astrapy.db import (
AstraDB as AstraDBClient,
)
from astrapy.db import (
AsyncAstraDB as AsyncAstraDBClient,
)
from astrapy.exceptions import InsertManyException
from astrapy.info import CollectionVectorServiceOptions
from astrapy.results import UpdateResult
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.runnables.utils import gather_with_concurrency
from langchain_core.vectorstores import VectorStore
from langchain_astradb.utils.astradb import (
DEFAULT_DOCUMENT_CHUNK_SIZE,
MAX_CONCURRENT_DOCUMENT_DELETIONS,
MAX_CONCURRENT_DOCUMENT_INSERTIONS,
MAX_CONCURRENT_DOCUMENT_REPLACEMENTS,
SetupMode,
_AstraDBCollectionEnvironment,
)
from langchain_astradb.utils.mmr import maximal_marginal_relevance
# Type variables used to type the generic `_unique_list` helper.
T = TypeVar("T")
U = TypeVar("U")
# Shape of a single entry handed to the collection for insertion.
DocDict = Dict[str, Any]
# Indexing options used when creating a collection and no other indexing
# parameter was supplied: only the `metadata` field is allow-listed.
DEFAULT_INDEXING_OPTIONS = dict(allow=["metadata"])
def _unique_list(lst: List[T], key: Callable[[T], U]) -> List[T]:
    """Return the items of ``lst`` with duplicate keys removed.

    The relative order of the surviving items is preserved and, among items
    sharing the same ``key(item)``, only the first occurrence is kept.
    Keys must be hashable.
    """
    first_by_key: Dict[U, T] = {}
    for item in lst:
        # setdefault keeps the item seen first for a given key
        first_by_key.setdefault(key(item), item)
    return list(first_by_key.values())
[docs]
class AstraDBVectorStore(VectorStore):
"""AstraDB vector store integration.
Setup:
Install ``langchain-astradb`` and head to the [AstraDB website](https://astra.datastax.com), create an account, create a new database and [create an application token](https://docs.datastax.com/en/astra-db-serverless/administration/manage-application-tokens.html#generate-application-token).
.. code-block:: bash
pip install -qU langchain-astradb
Key init args — indexing params:
collection_name: str
Name of the collection.
embedding: Embeddings
Embedding function to use.
Key init args — client params:
api_endpoint: str
AstraDB API endpoint.
token: str
API token for Astra DB usage.
namespace: Optional[str]
Namespace (aka keyspace) where the collection is created.
Instantiate:
Get your API endpoint and application token from the dashboard of your database.
.. code-block:: python
import getpass
from langchain_astradb import AstraDBVectorStore
from langchain_openai import OpenAIEmbeddings
ASTRA_DB_API_ENDPOINT = getpass.getpass("ASTRA_DB_API_ENDPOINT = ")
ASTRA_DB_APPLICATION_TOKEN = getpass.getpass("ASTRA_DB_APPLICATION_TOKEN = ")
vector_store = AstraDBVectorStore(
collection_name="astra_vector_langchain",
embedding=OpenAIEmbeddings(),
api_endpoint=ASTRA_DB_API_ENDPOINT,
token=ASTRA_DB_APPLICATION_TOKEN,
)
Add Documents:
.. code-block:: python
from langchain_core.documents import Document
document_1 = Document(page_content="foo", metadata={"baz": "bar"})
document_2 = Document(page_content="thud", metadata={"bar": "baz"})
document_3 = Document(page_content="i will be deleted :(")
documents = [document_1, document_2, document_3]
ids = ["1", "2", "3"]
vector_store.add_documents(documents=documents, ids=ids)
Delete Documents:
.. code-block:: python
vector_store.delete(ids=["3"])
Search:
.. code-block:: python
results = vector_store.similarity_search(query="thud",k=1)
for doc in results:
print(f"* {doc.page_content} [{doc.metadata}]")
.. code-block:: python
* thud [{'bar': 'baz'}]
Search with filter:
.. code-block:: python
results = vector_store.similarity_search(query="thud",k=1,filter={"bar": "baz"})
for doc in results:
print(f"* {doc.page_content} [{doc.metadata}]")
.. code-block:: python
* thud [{'bar': 'baz'}]
Search with score:
.. code-block:: python
results = vector_store.similarity_search_with_score(query="qux",k=1)
for doc, score in results:
print(f"* [SIM={score:3f}] {doc.page_content} [{doc.metadata}]")
.. code-block:: python
* [SIM=0.916135] foo [{'baz': 'bar'}]
Async:
.. code-block:: python
# add documents
# await vector_store.aadd_documents(documents=documents, ids=ids)
# delete documents
# await vector_store.adelete(ids=["3"])
# search
# results = await vector_store.asimilarity_search(query="thud",k=1)
# search with score
results = await vector_store.asimilarity_search_with_score(query="qux",k=1)
for doc,score in results:
print(f"* [SIM={score:3f}] {doc.page_content} [{doc.metadata}]")
.. code-block:: python
* [SIM=0.916135] foo [{'baz': 'bar'}]
Use as Retriever:
.. code-block:: python
retriever = vector_store.as_retriever(
search_type="similarity_score_threshold",
search_kwargs={"k": 1, "score_threshold": 0.5},
)
retriever.invoke("thud")
.. code-block:: python
[Document(metadata={'bar': 'baz'}, page_content='thud')]
""" # noqa: E501
@staticmethod
def _filter_to_metadata(filter_dict: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Translate a user metadata filter into an Astra DB ``find`` filter.

    Plain keys are rewritten as ``metadata.<key>``. Keys starting with ``$``
    (operators) are kept as-is and their values are translated recursively:
    each element when the value is a list, the value itself otherwise.
    A ``None`` filter becomes an empty filter.
    """
    if filter_dict is None:
        return {}
    translated: Dict[str, Any] = {}
    for key, value in filter_dict.items():
        is_operator = bool(key and key[0] == "$")
        if not is_operator:
            translated[f"metadata.{key}"] = value
        elif isinstance(value, list):
            # assume each list item can be fed back to this function
            translated[key] = [
                AstraDBVectorStore._filter_to_metadata(sub_filter)
                for sub_filter in value
            ]
        else:
            translated[key] = AstraDBVectorStore._filter_to_metadata(value)
    return translated
@staticmethod
def _normalize_metadata_indexing_policy(
    metadata_indexing_include: Optional[Iterable[str]],
    metadata_indexing_exclude: Optional[Iterable[str]],
    collection_indexing_policy: Optional[Dict[str, Any]],
) -> Dict[str, Any]:
    """Validate the indexing constructor parameters and build the policy.

    At most one of the three parameters may be non-None. The result is the
    dict to use as indexing 'options' when creating a collection:
    an ``allow`` list of ``metadata.<field>`` entries for an include list,
    a ``deny`` list for an exclude list, the given full policy verbatim,
    or ``DEFAULT_INDEXING_OPTIONS`` when nothing was specified.

    Raises:
        ValueError: if more than one of the parameters is non-None.
    """
    supplied = [
        param
        for param in (
            metadata_indexing_include,
            metadata_indexing_exclude,
            collection_indexing_policy,
        )
        if param is not None
    ]
    if len(supplied) > 1:
        raise ValueError(
            "At most one of the parameters `metadata_indexing_include`,"
            " `metadata_indexing_exclude` and `collection_indexing_policy`"
            " can be specified as non null."
        )
    if metadata_indexing_include is not None:
        return {"allow": [f"metadata.{field}" for field in metadata_indexing_include]}
    if metadata_indexing_exclude is not None:
        return {"deny": [f"metadata.{field}" for field in metadata_indexing_exclude]}
    if collection_indexing_policy is not None:
        return collection_indexing_policy
    return DEFAULT_INDEXING_OPTIONS
[docs]
def __init__(
    self,
    *,
    collection_name: str,
    embedding: Optional[Embeddings] = None,
    token: Optional[Union[str, TokenProvider]] = None,
    api_endpoint: Optional[str] = None,
    environment: Optional[str] = None,
    astra_db_client: Optional[AstraDBClient] = None,
    async_astra_db_client: Optional[AsyncAstraDBClient] = None,
    namespace: Optional[str] = None,
    metric: Optional[str] = None,
    batch_size: Optional[int] = None,
    bulk_insert_batch_concurrency: Optional[int] = None,
    bulk_insert_overwrite_concurrency: Optional[int] = None,
    bulk_delete_concurrency: Optional[int] = None,
    setup_mode: SetupMode = SetupMode.SYNC,
    pre_delete_collection: bool = False,
    metadata_indexing_include: Optional[Iterable[str]] = None,
    metadata_indexing_exclude: Optional[Iterable[str]] = None,
    collection_indexing_policy: Optional[Dict[str, Any]] = None,
    collection_vector_service_options: Optional[
        CollectionVectorServiceOptions
    ] = None,
    collection_embedding_api_key: Optional[
        Union[str, EmbeddingHeadersProvider]
    ] = None,
) -> None:
    """Wrapper around DataStax Astra DB for vector-store workloads.

    For quickstart and details, visit
    https://docs.datastax.com/en/astra/astra-db-vector/

    Args:
        embedding: the embeddings function or service to use.
            This enables client-side embedding functions or calls to external
            embedding providers. If `embedding` is provided, arguments
            `collection_vector_service_options` and
            `collection_embedding_api_key` cannot be provided.
        collection_name: name of the Astra DB collection to create/use.
        token: API token for Astra DB usage, either in the form of a string
            or a subclass of `astrapy.authentication.TokenProvider`.
            If not provided, the environment variable
            ASTRA_DB_APPLICATION_TOKEN is inspected.
        api_endpoint: full URL to the API endpoint, such as
            `https://<DB-ID>-us-east1.apps.astra.datastax.com`. If not provided,
            the environment variable ASTRA_DB_API_ENDPOINT is inspected.
        environment: a string specifying the environment of the target Data API.
            If omitted, defaults to "prod" (Astra DB production).
            Other values are in `astrapy.constants.Environment` enum class.
        astra_db_client:
            *DEPRECATED starting from version 0.3.5.*
            *Please use 'token', 'api_endpoint' and optionally 'environment'.*
            you can pass an already-created 'astrapy.db.AstraDB' instance
            (alternatively to 'token', 'api_endpoint' and 'environment').
        async_astra_db_client:
            *DEPRECATED starting from version 0.3.5.*
            *Please use 'token', 'api_endpoint' and optionally 'environment'.*
            you can pass an already-created 'astrapy.db.AsyncAstraDB' instance
            (alternatively to 'token', 'api_endpoint' and 'environment').
        namespace: namespace (aka keyspace) where the collection is created.
            If not provided, the environment variable ASTRA_DB_KEYSPACE is
            inspected. Defaults to the database's "default namespace".
        metric: similarity function to use out of those available in Astra DB.
            If left out, it will use Astra DB API's defaults (i.e. "cosine" - but,
            for performance reasons, "dot_product" is suggested if embeddings are
            normalized to one).
        batch_size: Size of document chunks for each individual insertion
            API request. If not provided (or falsy),
            ``DEFAULT_DOCUMENT_CHUNK_SIZE`` is used.
        bulk_insert_batch_concurrency: Number of threads or coroutines to insert
            batches concurrently. If not provided,
            ``MAX_CONCURRENT_DOCUMENT_INSERTIONS`` is used.
        bulk_insert_overwrite_concurrency: Number of threads or coroutines in a
            batch to insert pre-existing entries. If not provided,
            ``MAX_CONCURRENT_DOCUMENT_REPLACEMENTS`` is used.
        bulk_delete_concurrency: Number of threads or coroutines for
            multiple-entry deletes. If not provided,
            ``MAX_CONCURRENT_DOCUMENT_DELETIONS`` is used.
        setup_mode: how the collection setup is carried out (defaults to
            ``SetupMode.SYNC``). When an `embedding` is given, its dimension
            is probed synchronously at construction time for ``SYNC`` and
            ``OFF``; for ``ASYNC`` the probe coroutine is handed to the
            collection environment instead.
        pre_delete_collection: whether to delete the collection before creating it.
            If False and the collection already exists, the collection will be used
            as is.
        metadata_indexing_include: an allowlist of the specific metadata subfields
            that should be indexed for later filtering in searches.
        metadata_indexing_exclude: a denylist of the specific metadata subfields
            that should not be indexed for later filtering in searches.
        collection_indexing_policy: a full "indexing" specification for
            what fields should be indexed for later filtering in searches.
            This dict must conform to the API specifications
            (see docs.datastax.com/en/astra/astra-db-vector/api-reference/
            data-api-commands.html#advanced-feature-indexing-clause-on-createcollection)
        collection_vector_service_options: specifies the use of server-side
            embeddings within Astra DB. If passing this parameter, `embedding`
            cannot be provided.
        collection_embedding_api_key: for usage of server-side embeddings
            within Astra DB. With this parameter one can supply an API Key
            that will be passed to Astra DB with each data request.
            This parameter can be either a string or a subclass of
            `astrapy.authentication.EmbeddingHeadersProvider`.
            This is useful when the service is configured for the collection,
            but no corresponding secret is stored within
            Astra's key management system.
            This parameter cannot be provided without
            specifying `collection_vector_service_options`.

    Raises:
        ValueError: if neither or both of `embedding` and
            `collection_vector_service_options` are given; if
            `collection_embedding_api_key` is given without
            `collection_vector_service_options`; or if more than one of
            `metadata_indexing_include`, `metadata_indexing_exclude` and
            `collection_indexing_policy` is given.

    Note:
        For concurrency in synchronous :meth:`~add_texts`, as a rule of thumb, on a
        typical client machine it is suggested to keep the quantity
        bulk_insert_batch_concurrency * bulk_insert_overwrite_concurrency
        much below 1000 to avoid exhausting the client multithreading/networking
        resources. The hardcoded defaults are somewhat conservative to meet
        most machines' specs, but a sensible choice to test may be:

        - bulk_insert_batch_concurrency = 80
        - bulk_insert_overwrite_concurrency = 10

        A bit of experimentation is required to nail the best results here,
        depending on both the machine/network specs and the expected workload
        (specifically, how often a write is an update of an existing id).
        Remember you can pass concurrency settings to individual calls to
        :meth:`~add_texts` and :meth:`~add_documents` as well.
    """
    # Embedding and the server-side embeddings are mutually exclusive,
    # as both specify how to produce embeddings.
    # (Implicit string concatenation keeps the messages free of the
    # indentation that a backslash continuation would embed in them.)
    if embedding is None and collection_vector_service_options is None:
        raise ValueError(
            "Either an `embedding` or a `collection_vector_service_options`"
            " must be provided."
        )
    if embedding is not None and collection_vector_service_options is not None:
        raise ValueError(
            "Only one of `embedding` or `collection_vector_service_options`"
            " can be provided."
        )
    if (
        collection_vector_service_options is None
        and collection_embedding_api_key is not None
    ):
        raise ValueError(
            "`collection_embedding_api_key` cannot be provided unless"
            " `collection_vector_service_options` is also passed."
        )
    # Validate the indexing parameters before any (possibly remote) embedding
    # work: a bad configuration then fails fast, and in async setup mode no
    # probe coroutine is created only to be left un-awaited.
    self.indexing_policy: Dict[str, Any] = self._normalize_metadata_indexing_policy(
        metadata_indexing_include=metadata_indexing_include,
        metadata_indexing_exclude=metadata_indexing_exclude,
        collection_indexing_policy=collection_indexing_policy,
    )
    self.embedding_dimension: Optional[int] = None
    self.embedding = embedding
    self.collection_name = collection_name
    self.token = token
    self.api_endpoint = api_endpoint
    self.environment = environment
    self.namespace = namespace
    self.collection_vector_service_options = collection_vector_service_options
    self.collection_embedding_api_key = collection_embedding_api_key
    # Concurrency settings (falsy values fall back to the library constants)
    self.batch_size: Optional[int] = batch_size or DEFAULT_DOCUMENT_CHUNK_SIZE
    self.bulk_insert_batch_concurrency: int = (
        bulk_insert_batch_concurrency or MAX_CONCURRENT_DOCUMENT_INSERTIONS
    )
    self.bulk_insert_overwrite_concurrency: int = (
        bulk_insert_overwrite_concurrency or MAX_CONCURRENT_DOCUMENT_REPLACEMENTS
    )
    self.bulk_delete_concurrency: int = (
        bulk_delete_concurrency or MAX_CONCURRENT_DOCUMENT_DELETIONS
    )
    # "vector-related" settings
    self.metric = metric
    # Either the dimension itself (sync probe), an awaitable yielding it
    # (async probe) or None (server-side embeddings: no client-side probe).
    embedding_dimension_m: Union[int, Awaitable[int], None] = None
    if self.embedding is not None:
        if setup_mode == SetupMode.ASYNC:
            embedding_dimension_m = self._aget_embedding_dimension()
        elif setup_mode in (SetupMode.SYNC, SetupMode.OFF):
            embedding_dimension_m = self._get_embedding_dimension()
    self.astra_env = _AstraDBCollectionEnvironment(
        collection_name=collection_name,
        token=self.token,
        api_endpoint=self.api_endpoint,
        environment=self.environment,
        astra_db_client=astra_db_client,
        async_astra_db_client=async_astra_db_client,
        namespace=self.namespace,
        setup_mode=setup_mode,
        pre_delete_collection=pre_delete_collection,
        embedding_dimension=embedding_dimension_m,
        metric=self.metric,
        requested_indexing_policy=self.indexing_policy,
        default_indexing_policy=DEFAULT_INDEXING_OPTIONS,
        collection_vector_service_options=self.collection_vector_service_options,
        collection_embedding_api_key=self.collection_embedding_api_key,
    )
def _get_embedding_dimension(self) -> int:
    """Return (and cache) the length of the vectors produced by `embedding`.

    On first use, a fixed sample sentence is embedded with ``embed_query``
    and the resulting length is stored in ``self.embedding_dimension``;
    later calls return the cached value.
    """
    assert self.embedding is not None
    if self.embedding_dimension is not None:
        return self.embedding_dimension
    probe_vector = self.embedding.embed_query(text="This is a sample sentence.")
    self.embedding_dimension = len(probe_vector)
    return self.embedding_dimension
async def _aget_embedding_dimension(self) -> int:
    """Async twin of `_get_embedding_dimension` (probes via ``aembed_query``).

    The computed length is cached in ``self.embedding_dimension``.
    """
    assert self.embedding is not None
    if self.embedding_dimension is not None:
        return self.embedding_dimension
    probe_vector = await self.embedding.aembed_query(
        text="This is a sample sentence."
    )
    self.embedding_dimension = len(probe_vector)
    return self.embedding_dimension
@property
def embeddings(self) -> Optional[Embeddings]:
    """The client-side embeddings object supplied at construction.

    This is None when the store relies on server-side embeddings.
    """
    configured_embedding = self.embedding
    return configured_embedding
def _using_vectorize(self) -> bool:
    """Tell whether embeddings are computed server-side (``$vectorize``).

    This is the case exactly when ``collection_vector_service_options``
    was supplied.
    """
    service_options = self.collection_vector_service_options
    return service_options is not None
def _select_relevance_score_fn(self) -> Callable[[float], float]:
    """Return the identity function as the relevance-score normalizer.

    The underlying API calls already return a "score proper",
    i.e. one in [0, 1] where higher means more *similar*,
    so no transformation (in particular, no reversal of the interval)
    is applied.
    """

    def _identity(score: float) -> float:
        return score

    return _identity
[docs]
def clear(self) -> None:
    """Empty the collection of all its stored entries.

    The collection itself is kept (see `delete_collection` to drop it).
    """
    env = self.astra_env
    env.ensure_db_setup()
    env.collection.delete_many({})
[docs]
async def aclear(self) -> None:
    """Empty the collection of all its stored entries (async version).

    The collection itself is kept (see `adelete_collection` to drop it).
    """
    env = self.astra_env
    await env.aensure_db_setup()
    await env.async_collection.delete_many({})
[docs]
def delete_by_document_id(self, document_id: str) -> bool:
    """Remove a single document from the store, given its document ID.

    Args:
        document_id: The document ID.

    Returns:
        True if a document has indeed been deleted, False if ID not found.
    """
    env = self.astra_env
    env.ensure_db_setup()
    outcome = env.collection.delete_one({"_id": document_id})
    was_deleted = outcome.deleted_count == 1
    return was_deleted
[docs]
async def adelete_by_document_id(self, document_id: str) -> bool:
    """Remove a single document from the store, given its document ID.

    Args:
        document_id: The document ID.

    Returns:
        True if a document has indeed been deleted, False if ID not found.
    """
    env = self.astra_env
    await env.aensure_db_setup()
    outcome = await env.async_collection.delete_one({"_id": document_id})
    return outcome.deleted_count == 1
[docs]
def delete(
    self,
    ids: Optional[List[str]] = None,
    concurrency: Optional[int] = None,
    **kwargs: Any,
) -> Optional[bool]:
    """Delete by vector ids.

    Each id is removed with its own single-document delete request; the
    requests are issued from a thread pool.

    Args:
        ids: List of ids to delete.
        concurrency: max number of threads issuing single-doc delete requests.
            Defaults to vector-store overall setting.
        **kwargs: not supported; if any are passed, a warning is emitted
            and they are ignored.

    Returns:
        True once every delete request has completed. Ids that are not found
        in the collection do NOT make this method return False; an exception
        raised by any single deletion propagates to the caller instead.

    Raises:
        ValueError: if ``ids`` is None.
    """
    if kwargs:
        warnings.warn(
            "Method 'delete' of AstraDBVectorStore vector store invoked with "
            f"unsupported arguments ({', '.join(sorted(kwargs.keys()))}), "
            "which will be ignored."
        )
    if ids is None:
        raise ValueError("No ids provided to delete.")
    _max_workers = concurrency or self.bulk_delete_concurrency
    with ThreadPoolExecutor(max_workers=_max_workers) as tpe:
        # Drain the result iterator: this waits for every request and
        # re-raises the first worker exception, if any. The per-id booleans
        # are deliberately not folded into the return value (a False return
        # is treated as a failed delete by LangChain's indexing API).
        for _ in tpe.map(self.delete_by_document_id, ids):
            pass
    return True
[docs]
async def adelete(
    self,
    ids: Optional[List[str]] = None,
    concurrency: Optional[int] = None,
    **kwargs: Any,
) -> Optional[bool]:
    """Delete by vector ids.

    Args:
        ids: List of ids to delete.
        concurrency: max number of simultaneous coroutines for single-doc
            delete requests. Defaults to vector-store overall setting.

    Returns:
        True if deletion is (entirely) successful, False otherwise
        (i.e. False as soon as any single-doc delete reports no match).

    Raises:
        ValueError: if ``ids`` is None.
    """
    if kwargs:
        unsupported_names = ", ".join(sorted(kwargs.keys()))
        warnings.warn(
            "Method 'adelete' of AstraDBVectorStore invoked with "
            f"unsupported arguments ({unsupported_names}), "
            "which will be ignored."
        )
    if ids is None:
        raise ValueError("No ids provided to delete.")
    parallelism = concurrency or self.bulk_delete_concurrency
    pending = [self.adelete_by_document_id(doc_id) for doc_id in ids]
    outcomes = await gather_with_concurrency(parallelism, *pending)
    return all(outcomes)
[docs]
def delete_collection(self) -> None:
    """Drop the whole collection from the database.

    Unlike :meth:`~clear`, which only empties the collection, this removes
    it entirely: stored data is lost and unrecoverable, resources are freed.
    Use with caution.
    """
    env = self.astra_env
    env.ensure_db_setup()
    env.collection.drop()
[docs]
async def adelete_collection(self) -> None:
    """Drop the whole collection from the database (async version).

    Unlike :meth:`~aclear`, which only empties the collection, this removes
    it entirely: stored data is lost and unrecoverable, resources are freed.
    Use with caution.
    """
    env = self.astra_env
    await env.aensure_db_setup()
    await env.async_collection.drop()
@staticmethod
def _get_documents_to_insert(
    texts: Iterable[str],
    embedding_vectors: List[List[float]],
    metadatas: Optional[List[dict]] = None,
    ids: Optional[List[str]] = None,
) -> List[DocDict]:
    """Build the insertion payloads for client-side-embedded texts.

    Each payload has ``content``, ``_id``, ``$vector`` and ``metadata``.
    Missing ids are generated as random uuid4 hex strings and missing
    metadata defaults to empty dicts. Inputs are paired positionally
    (``zip``), so the result is as long as the shortest input. When several
    payloads share an ``_id``, only the last one is kept.

    Args:
        texts: the texts; any iterable, including one-shot iterators.
        embedding_vectors: one vector per text.
        metadatas: optional metadata dicts, one per text.
        ids: optional ids, one per text.

    Returns:
        The de-duplicated list of payload dicts.
    """
    # `texts` may be a one-shot iterator (e.g. a generator): materialize it
    # because it is traversed more than once below.
    texts = list(texts)
    if ids is None:
        ids = [uuid.uuid4().hex for _ in texts]
    if metadatas is None:
        metadatas = [{} for _ in texts]
    documents_to_insert = [
        {
            "content": b_txt,
            "_id": b_id,
            "$vector": b_emb,
            "metadata": b_md,
        }
        for b_txt, b_emb, b_id, b_md in zip(
            texts,
            embedding_vectors,
            ids,
            metadatas,
        )
    ]
    # make unique by id, keeping the last
    uniqued_documents_to_insert = _unique_list(
        documents_to_insert[::-1],
        lambda document: document["_id"],
    )[::-1]
    return uniqued_documents_to_insert
@staticmethod
def _get_vectorize_documents_to_insert(
    texts: Iterable[str],
    metadatas: Optional[List[dict]] = None,
    ids: Optional[List[str]] = None,
) -> List[DocDict]:
    """Build the insertion payloads for server-side (``$vectorize``) embedding.

    Each payload has ``_id``, ``$vectorize`` (the text) and ``metadata``.
    Missing ids are generated as random uuid4 hex strings and missing
    metadata defaults to empty dicts. Inputs are paired positionally
    (``zip``), so the result is as long as the shortest input. When several
    payloads share an ``_id``, only the last one is kept.

    Args:
        texts: the texts; any iterable, including one-shot iterators.
        metadatas: optional metadata dicts, one per text.
        ids: optional ids, one per text.

    Returns:
        The de-duplicated list of payload dicts.
    """
    # `texts` may be a one-shot iterator (e.g. a generator): materialize it
    # because it is traversed more than once below.
    texts = list(texts)
    if ids is None:
        ids = [uuid.uuid4().hex for _ in texts]
    if metadatas is None:
        metadatas = [{} for _ in texts]
    documents_to_insert = [
        {
            "_id": b_id,
            "$vectorize": b_txt,
            "metadata": b_md,
        }
        for b_txt, b_id, b_md in zip(
            texts,
            ids,
            metadatas,
        )
    ]
    # make unique by id, keeping the last
    uniqued_documents_to_insert = _unique_list(
        documents_to_insert[::-1],
        lambda document: document["_id"],
    )[::-1]
    return uniqued_documents_to_insert
@staticmethod
def _get_missing_from_batch(
    document_batch: List[DocDict], insert_result: Dict[str, Any]
) -> Tuple[List[str], List[DocDict]]:
    """Split a raw bulk-insert response into inserted ids and docs to retry.

    The documents of the batch whose ``_id`` is absent from the response's
    ``status.insertedIds`` are considered pre-existing and returned so they
    can be handled as upserts.

    Returns:
        A pair ``(inserted ids as reported by the API, missing documents)``.

    Raises:
        ValueError: if the response has no ``status``, or if its ``errors``
            are not exactly one ``DOCUMENT_ALREADY_EXISTS`` error per
            missing document.
    """
    if "status" not in insert_result:
        raise ValueError(
            f"API Exception while running bulk insertion: {str(insert_result)}"
        )
    batch_inserted = insert_result["status"]["insertedIds"]
    reported_inserted = set(batch_inserted)
    # estimation of the preexisting documents that failed
    missed_ids = {
        doc["_id"] for doc in document_batch if doc["_id"] not in reported_inserted
    }
    errors = insert_result.get("errors", [])
    # anything other than "doc already exists" is an unexpected failure
    has_unexpected_error = any(
        err.get("errorCode") != "DOCUMENT_ALREADY_EXISTS" for err in errors
    )
    counts_disagree = len(errors) != len(missed_ids)
    if counts_disagree or has_unexpected_error:
        raise ValueError(
            f"API Exception while running bulk insertion: {str(errors)}"
        )
    documents_to_retry = [doc for doc in document_batch if doc["_id"] in missed_ids]
    return batch_inserted, documents_to_retry
[docs]
def add_texts(
    self,
    texts: Iterable[str],
    metadatas: Optional[List[dict]] = None,
    ids: Optional[List[str]] = None,
    *,
    batch_size: Optional[int] = None,
    batch_concurrency: Optional[int] = None,
    overwrite_concurrency: Optional[int] = None,
    **kwargs: Any,
) -> List[str]:
    """Run texts through the embeddings and add them to the vectorstore.

    If passing explicit ids, those entries whose id is in the store already
    will be replaced.

    Args:
        texts: Texts to add to the vectorstore. Any iterable is accepted,
            including one-shot iterators such as generators.
        metadatas: Optional list of metadatas.
        ids: Optional list of ids.
        batch_size: Size of document chunks for each individual insertion
            API request. Defaults to the vector-store overall setting if
            not provided.
        batch_concurrency: number of threads to process
            insertion batches concurrently. Defaults to the vector-store
            overall setting if not provided.
        overwrite_concurrency: number of threads to process
            pre-existing documents in each batch. Defaults to the vector-store
            overall setting if not provided.

    Note:
        There are constraints on the allowed field names
        in the metadata dictionaries, coming from the underlying Astra DB API.
        For instance, the `$` (dollar sign) cannot be used in the dict keys.
        See this document for details:
        https://docs.datastax.com/en/astra/astra-db-vector/api-reference/data-api.html

    Returns:
        The list of ids of the added texts: the ids inserted directly,
        followed by the ids of the documents that replaced existing entries.

    Raises:
        ValueError: if some documents could be neither inserted nor replaced.
    """
    if kwargs:
        warnings.warn(
            "Method 'add_texts' of AstraDBVectorStore vector store invoked with "
            f"unsupported arguments ({', '.join(sorted(kwargs.keys()))}), "
            "which will be ignored."
        )
    self.astra_env.ensure_db_setup()
    # Materialize once: `texts` may be a one-shot iterator, and it is read
    # both for embedding and for building the documents. Reusing a consumed
    # generator would otherwise silently insert nothing.
    texts = list(texts)
    if not texts:
        return []
    if self._using_vectorize():
        documents_to_insert = self._get_vectorize_documents_to_insert(
            texts, metadatas, ids
        )
    else:
        assert self.embedding is not None
        embedding_vectors = self.embedding.embed_documents(texts)
        documents_to_insert = self._get_documents_to_insert(
            texts, embedding_vectors, metadatas, ids
        )
    # perform an AstraPy insert_many, catching exceptions for overwriting docs
    ids_to_replace: List[str]
    inserted_ids: List[str]
    try:
        insert_many_result = self.astra_env.collection.insert_many(
            documents_to_insert,
            ordered=False,
            concurrency=batch_concurrency or self.bulk_insert_batch_concurrency,
            chunk_size=batch_size or self.batch_size,
        )
    except InsertManyException as err:
        # Copy the list so that appending below does not mutate the
        # exception's partial result. Every document not reported as inserted
        # is retried as a replacement (presumably because its _id exists).
        inserted_ids = list(err.partial_result.inserted_ids)
        inserted_ids_set = set(inserted_ids)
        ids_to_replace = [
            document["_id"]
            for document in documents_to_insert
            if document["_id"] not in inserted_ids_set
        ]
    else:
        inserted_ids = list(insert_many_result.inserted_ids)
        ids_to_replace = []
    # if necessary, replace docs for the non-inserted ids
    if ids_to_replace:
        ids_to_replace_set = set(ids_to_replace)
        documents_to_replace = [
            document
            for document in documents_to_insert
            if document["_id"] in ids_to_replace_set
        ]
        _max_workers = overwrite_concurrency or self.bulk_insert_overwrite_concurrency
        with ThreadPoolExecutor(
            max_workers=_max_workers,
        ) as executor:

            def _replace_document(
                document: Dict[str, Any],
            ) -> Tuple[UpdateResult, str]:
                return self.astra_env.collection.replace_one(
                    {"_id": document["_id"]},
                    document,
                ), document["_id"]

            replace_results = list(
                executor.map(
                    _replace_document,
                    documents_to_replace,
                )
            )
        # update_info["n"] counts the documents matched by each replace_one
        replaced_count = sum(r_res.update_info["n"] for r_res, _ in replace_results)
        inserted_ids += [replaced_id for _, replaced_id in replace_results]
        if replaced_count != len(ids_to_replace):
            missing = len(ids_to_replace) - replaced_count
            raise ValueError(
                "AstraDBVectorStore.add_texts could not insert all requested "
                f"documents ({missing} failed replace_one calls)"
            )
    return inserted_ids
[docs]
async def aadd_texts(
    self,
    texts: Iterable[str],
    metadatas: Optional[List[dict]] = None,
    ids: Optional[List[str]] = None,
    *,
    batch_size: Optional[int] = None,
    batch_concurrency: Optional[int] = None,
    overwrite_concurrency: Optional[int] = None,
    **kwargs: Any,
) -> List[str]:
    """Run texts through the embeddings and add them to the vectorstore.

    If passing explicit ids, those entries whose id is in the store already
    will be replaced.

    Args:
        texts: Texts to add to the vectorstore. Any iterable is accepted,
            including one-shot iterators such as generators.
        metadatas: Optional list of metadatas.
        ids: Optional list of ids.
        batch_size: Size of document chunks for each individual insertion
            API request. Defaults to the vector-store overall setting if
            not provided.
        batch_concurrency: number of simultaneous coroutines to process
            insertion batches concurrently. Defaults to the vector-store
            overall setting if not provided.
        overwrite_concurrency: number of simultaneous coroutines to process
            pre-existing documents in each batch. Defaults to the vector-store
            overall setting if not provided.

    Note:
        There are constraints on the allowed field names
        in the metadata dictionaries, coming from the underlying Astra DB API.
        For instance, the `$` (dollar sign) cannot be used in the dict keys.
        See this document for details:
        https://docs.datastax.com/en/astra/astra-db-vector/api-reference/data-api.html

    Returns:
        The list of ids of the added texts: the ids inserted directly,
        followed by the ids of the documents that replaced existing entries.

    Raises:
        ValueError: if some documents could be neither inserted nor replaced.
    """
    if kwargs:
        warnings.warn(
            "Method 'aadd_texts' of AstraDBVectorStore invoked with "
            f"unsupported arguments ({', '.join(sorted(kwargs.keys()))}), "
            "which will be ignored."
        )
    await self.astra_env.aensure_db_setup()
    # Materialize once: `texts` may be a one-shot iterator, and it is read
    # both for embedding and for building the documents. Reusing a consumed
    # generator would otherwise silently insert nothing.
    texts = list(texts)
    if not texts:
        return []
    if self._using_vectorize():
        # using server-side embeddings
        documents_to_insert = self._get_vectorize_documents_to_insert(
            texts, metadatas, ids
        )
    else:
        assert self.embedding is not None
        embedding_vectors = await self.embedding.aembed_documents(texts)
        documents_to_insert = self._get_documents_to_insert(
            texts, embedding_vectors, metadatas, ids
        )
    # perform an AstraPy insert_many, catching exceptions for overwriting docs
    ids_to_replace: List[str]
    inserted_ids: List[str]
    try:
        insert_many_result = await self.astra_env.async_collection.insert_many(
            documents_to_insert,
            ordered=False,
            concurrency=batch_concurrency or self.bulk_insert_batch_concurrency,
            chunk_size=batch_size or self.batch_size,
        )
    except InsertManyException as err:
        # Copy the list so that appending below does not mutate the
        # exception's partial result. Every document not reported as inserted
        # is retried as a replacement (presumably because its _id exists).
        inserted_ids = list(err.partial_result.inserted_ids)
        inserted_ids_set = set(inserted_ids)
        ids_to_replace = [
            document["_id"]
            for document in documents_to_insert
            if document["_id"] not in inserted_ids_set
        ]
    else:
        inserted_ids = list(insert_many_result.inserted_ids)
        ids_to_replace = []
    # if necessary, replace docs for the non-inserted ids
    if ids_to_replace:
        ids_to_replace_set = set(ids_to_replace)
        documents_to_replace = [
            document
            for document in documents_to_insert
            if document["_id"] in ids_to_replace_set
        ]
        # bounds the number of replace_one calls in flight at once
        sem = asyncio.Semaphore(
            overwrite_concurrency or self.bulk_insert_overwrite_concurrency,
        )
        _async_collection = self.astra_env.async_collection

        async def _replace_document(
            document: Dict[str, Any],
        ) -> Tuple[UpdateResult, str]:
            async with sem:
                return await _async_collection.replace_one(
                    {"_id": document["_id"]},
                    document,
                ), document["_id"]

        tasks = [
            asyncio.create_task(_replace_document(document))
            for document in documents_to_replace
        ]
        replace_results = await asyncio.gather(*tasks, return_exceptions=False)
        # update_info["n"] counts the documents matched by each replace_one
        replaced_count = sum(r_res.update_info["n"] for r_res, _ in replace_results)
        inserted_ids += [replaced_id for _, replaced_id in replace_results]
        if replaced_count != len(ids_to_replace):
            missing = len(ids_to_replace) - replaced_count
            raise ValueError(
                "AstraDBVectorStore.aadd_texts could not insert all requested "
                f"documents ({missing} failed replace_one calls)"
            )
    return inserted_ids
[docs]
def similarity_search_with_score_id_by_vector(
    self,
    embedding: List[float],
    k: int = 4,
    filter: Optional[Dict[str, Any]] = None,
) -> List[Tuple[Document, float, str]]:
    """Find the stored documents closest to a given embedding vector.

    Args:
        embedding: Embedding to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.
        filter: Filter on the metadata to apply.

    Returns:
        Up to ``k`` tuples ``(Document, score, id)`` in the order returned by
        the collection, the most similar to the query vector.
    """
    self.astra_env.ensure_db_setup()
    astra_filter = self._filter_to_metadata(filter)
    wanted_fields = {
        "_id": True,
        "content": True,
        "metadata": True,
    }
    hits = list(
        self.astra_env.collection.find(
            filter=astra_filter,
            projection=wanted_fields,
            limit=k,
            include_similarity=True,
            sort={"$vector": embedding},
        )
    )
    results: List[Tuple[Document, float, str]] = []
    for hit in hits:
        document = Document(page_content=hit["content"], metadata=hit["metadata"])
        results.append((document, hit["$similarity"], hit["_id"]))
    return results
[docs]
async def asimilarity_search_with_score_id_by_vector(
    self,
    embedding: List[float],
    k: int = 4,
    filter: Optional[Dict[str, Any]] = None,
) -> List[Tuple[Document, float, str]]:
    """Find the stored documents closest to a given embedding vector (async).

    Args:
        embedding: Embedding to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.
        filter: Filter on the metadata to apply.

    Returns:
        Up to ``k`` tuples ``(Document, score, id)`` in the order returned by
        the collection, the most similar to the query vector.
    """
    await self.astra_env.aensure_db_setup()
    astra_filter = self._filter_to_metadata(filter)
    wanted_fields = {
        "_id": True,
        "content": True,
        "metadata": True,
    }
    results: List[Tuple[Document, float, str]] = []
    async for hit in self.astra_env.async_collection.find(
        filter=astra_filter,
        projection=wanted_fields,
        limit=k,
        include_similarity=True,
        sort={"$vector": embedding},
    ):
        document = Document(page_content=hit["content"], metadata=hit["metadata"])
        results.append((document, hit["$similarity"], hit["_id"]))
    return results
def _similarity_search_with_score_id_with_vectorize(
    self,
    query: str,
    k: int = 4,
    filter: Optional[Dict[str, Any]] = None,
) -> List[Tuple[Document, float, str]]:
    """Find documents similar to a text query, embedded server-side.

    Sorting uses ``$vectorize`` on the query text, so this is only
    meaningful when server-side embeddings are in use.

    Returns:
        Up to ``k`` tuples ``(Document, score, id)`` in the order returned
        by the collection.
    """
    self.astra_env.ensure_db_setup()
    astra_filter = self._filter_to_metadata(filter)
    wanted_fields = {
        "_id": True,
        "$vectorize": True,
        "metadata": True,
    }
    hits = list(
        self.astra_env.collection.find(
            filter=astra_filter,
            projection=wanted_fields,
            limit=k,
            include_similarity=True,
            sort={"$vectorize": query},
        )
    )
    results: List[Tuple[Document, float, str]] = []
    for hit in hits:
        # with server-side embeddings the text lives in $vectorize, not content
        document = Document(page_content=hit["$vectorize"], metadata=hit["metadata"])
        results.append((document, hit["$similarity"], hit["_id"]))
    return results
async def _asimilarity_search_with_score_id_with_vectorize(
    self,
    query: str,
    k: int = 4,
    filter: Optional[Dict[str, Any]] = None,
) -> List[Tuple[Document, float, str]]:
    """Return docs most similar to the query with score and id using $vectorize.

    This is only available when using server-side embeddings.

    Args:
        query: Query text, used directly as the ``$vectorize`` sort key.
        k: Number of Documents to return. Defaults to 4.
        filter: Filter on the metadata to apply.

    Returns:
        The list of (Document, score, id), the most similar to the query.
    """
    await self.astra_env.aensure_db_setup()
    metadata_parameter = self._filter_to_metadata(filter)
    cursor = self.astra_env.async_collection.find(
        filter=metadata_parameter,
        projection={"_id": True, "$vectorize": True, "metadata": True},
        limit=k,
        include_similarity=True,
        sort={"$vectorize": query},
    )
    matches: List[Tuple[Document, float, str]] = []
    async for hit in cursor:
        # text content is stored in $vectorize instead of content
        document = Document(page_content=hit["$vectorize"], metadata=hit["metadata"])
        matches.append((document, hit["$similarity"], hit["_id"]))
    return matches
[docs]
def similarity_search_with_score_id(
    self,
    query: str,
    k: int = 4,
    filter: Optional[Dict[str, Any]] = None,
) -> List[Tuple[Document, float, str]]:
    """Return docs most similar to the query with score and id.

    With server-side embeddings ($vectorize) the raw query is sent to the
    collection. Otherwise it is embedded locally first.

    Args:
        query: Query to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.
        filter: Filter on the metadata to apply.

    Returns:
        The list of (Document, score, id), the most similar to the query.
    """
    if not self._using_vectorize():
        # Client-side embedding path; narrows the Optional for type checkers.
        assert self.embedding is not None
        query_vector = self.embedding.embed_query(query)
        return self.similarity_search_with_score_id_by_vector(
            embedding=query_vector,
            k=k,
            filter=filter,
        )
    return self._similarity_search_with_score_id_with_vectorize(
        query=query,
        k=k,
        filter=filter,
    )
[docs]
async def asimilarity_search_with_score_id(
    self,
    query: str,
    k: int = 4,
    filter: Optional[Dict[str, Any]] = None,
) -> List[Tuple[Document, float, str]]:
    """Return docs most similar to the query with score and id.

    Async counterpart of ``similarity_search_with_score_id``.

    Args:
        query: Query to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.
        filter: Filter on the metadata to apply.

    Returns:
        The list of (Document, score, id), the most similar to the query.
    """
    if not self._using_vectorize():
        # Client-side embedding path; narrows the Optional for type checkers.
        assert self.embedding is not None
        query_vector = await self.embedding.aembed_query(query)
        return await self.asimilarity_search_with_score_id_by_vector(
            embedding=query_vector,
            k=k,
            filter=filter,
        )
    return await self._asimilarity_search_with_score_id_with_vectorize(
        query=query,
        k=k,
        filter=filter,
    )
[docs]
def similarity_search_with_score_by_vector(
    self,
    embedding: List[float],
    k: int = 4,
    filter: Optional[Dict[str, Any]] = None,
) -> List[Tuple[Document, float]]:
    """Return docs most similar to embedding vector with score.

    Thin wrapper over ``similarity_search_with_score_id_by_vector`` that
    drops the document ids.

    Args:
        embedding: Embedding to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.
        filter: Filter on the metadata to apply.

    Returns:
        The list of (Document, score), the most similar to the query vector.
    """
    scored_with_ids = self.similarity_search_with_score_id_by_vector(
        embedding=embedding,
        k=k,
        filter=filter,
    )
    return [(doc, score) for doc, score, _doc_id in scored_with_ids]
[docs]
async def asimilarity_search_with_score_by_vector(
    self,
    embedding: List[float],
    k: int = 4,
    filter: Optional[Dict[str, Any]] = None,
) -> List[Tuple[Document, float]]:
    """Return docs most similar to embedding vector with score.

    Thin async wrapper over ``asimilarity_search_with_score_id_by_vector``
    that drops the document ids.

    Args:
        embedding: Embedding to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.
        filter: Filter on the metadata to apply.

    Returns:
        The list of (Document, score), the most similar to the query vector.
    """
    scored_with_ids = await self.asimilarity_search_with_score_id_by_vector(
        embedding=embedding,
        k=k,
        filter=filter,
    )
    return [(doc, score) for doc, score, _doc_id in scored_with_ids]
[docs]
def similarity_search(
    self,
    query: str,
    k: int = 4,
    filter: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> List[Document]:
    """Return docs most similar to query.

    Args:
        query: Query to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.
        filter: Filter on the metadata to apply.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        The list of Documents most similar to the query.
    """
    if not self._using_vectorize():
        # Client-side embedding path; narrows the Optional for type checkers.
        assert self.embedding is not None
        query_vector = self.embedding.embed_query(query)
        return self.similarity_search_by_vector(
            query_vector,
            k,
            filter=filter,
        )
    triples = self._similarity_search_with_score_id_with_vectorize(
        query,
        k,
        filter=filter,
    )
    return [doc for doc, _score, _doc_id in triples]
[docs]
async def asimilarity_search(
    self,
    query: str,
    k: int = 4,
    filter: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> List[Document]:
    """Return docs most similar to query.

    Async counterpart of ``similarity_search``.

    Args:
        query: Query to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.
        filter: Filter on the metadata to apply.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        The list of Documents most similar to the query.
    """
    if not self._using_vectorize():
        # Client-side embedding path; narrows the Optional for type checkers.
        assert self.embedding is not None
        query_vector = await self.embedding.aembed_query(query)
        return await self.asimilarity_search_by_vector(
            query_vector,
            k,
            filter=filter,
        )
    triples = await self._asimilarity_search_with_score_id_with_vectorize(
        query,
        k,
        filter=filter,
    )
    return [doc for doc, _score, _doc_id in triples]
[docs]
def similarity_search_by_vector(
    self,
    embedding: List[float],
    k: int = 4,
    filter: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> List[Document]:
    """Return docs most similar to embedding vector.

    Args:
        embedding: Embedding to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.
        filter: Filter on the metadata to apply.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        The list of Documents most similar to the query vector.
    """
    docs_and_scores = self.similarity_search_with_score_by_vector(
        embedding,
        k,
        filter=filter,
    )
    return [doc for doc, _score in docs_and_scores]
[docs]
async def asimilarity_search_by_vector(
    self,
    embedding: List[float],
    k: int = 4,
    filter: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> List[Document]:
    """Return docs most similar to embedding vector.

    Async counterpart of ``similarity_search_by_vector``.

    Args:
        embedding: Embedding to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.
        filter: Filter on the metadata to apply.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        The list of Documents most similar to the query vector.
    """
    docs_and_scores = await self.asimilarity_search_with_score_by_vector(
        embedding,
        k,
        filter=filter,
    )
    return [doc for doc, _score in docs_and_scores]
[docs]
def similarity_search_with_score(
    self,
    query: str,
    k: int = 4,
    filter: Optional[Dict[str, Any]] = None,
) -> List[Tuple[Document, float]]:
    """Return docs most similar to query with score.

    Args:
        query: Query to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.
        filter: Filter on the metadata to apply.

    Returns:
        The list of (Document, score), the most similar to the query vector.
    """
    if not self._using_vectorize():
        # Client-side embedding path; narrows the Optional for type checkers.
        assert self.embedding is not None
        query_vector = self.embedding.embed_query(query)
        return self.similarity_search_with_score_by_vector(
            query_vector,
            k,
            filter=filter,
        )
    triples = self._similarity_search_with_score_id_with_vectorize(
        query=query,
        k=k,
        filter=filter,
    )
    return [(doc, score) for doc, score, _doc_id in triples]
[docs]
async def asimilarity_search_with_score(
    self,
    query: str,
    k: int = 4,
    filter: Optional[Dict[str, Any]] = None,
) -> List[Tuple[Document, float]]:
    """Return docs most similar to query with score.

    Async counterpart of ``similarity_search_with_score``.

    Args:
        query: Query to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.
        filter: Filter on the metadata to apply.

    Returns:
        The list of (Document, score), the most similar to the query vector.
    """
    if not self._using_vectorize():
        # Client-side embedding path; narrows the Optional for type checkers.
        assert self.embedding is not None
        query_vector = await self.embedding.aembed_query(query)
        return await self.asimilarity_search_with_score_by_vector(
            query_vector,
            k,
            filter=filter,
        )
    triples = await self._asimilarity_search_with_score_id_with_vectorize(
        query=query,
        k=k,
        filter=filter,
    )
    return [(doc, score) for doc, score, _doc_id in triples]
def _run_mmr_query_by_sort(
    self,
    sort: Dict[str, Any],
    k: int,
    fetch_k: int,
    lambda_mult: float,
    metadata_parameter: Dict[str, Any],
    **kwargs: Any,
) -> List[Document]:
    """Prefetch ``fetch_k`` hits ordered by ``sort``, then pick ``k`` with MMR.

    Args:
        sort: Sort clause for the collection query. The callers in this class
            pass either ``{"$vector": ...}`` or ``{"$vectorize": ...}``.
        k: Number of Documents to return.
        fetch_k: Number of candidate hits to prefetch for MMR.
        lambda_mult: Relevance/diversity trade-off forwarded to MMR.
        metadata_parameter: Already-normalized metadata filter.
        **kwargs: Accepted and ignored.

    Returns:
        The Documents selected by maximal marginal relevance.
    """
    # Both text fields are requested; only one is used, depending on the mode.
    candidate_fields = {
        "_id": True,
        "content": True,
        "metadata": True,
        "$vector": True,
        "$vectorize": True,
    }
    cursor = self.astra_env.collection.find(
        filter=metadata_parameter,
        projection=candidate_fields,
        limit=fetch_k,
        include_similarity=True,
        include_sort_vector=True,
        sort=sort,
    )
    candidates = list(cursor)
    # The vector the query was sorted by is the MMR reference point.
    reference_vector = cursor.get_sort_vector()
    text_field = "$vectorize" if self._using_vectorize() else "content"
    return self._get_mmr_hits(
        embedding=reference_vector,  # type: ignore[arg-type]
        k=k,
        lambda_mult=lambda_mult,
        prefetch_hits=candidates,
        content_field=text_field,
    )
async def _arun_mmr_query_by_sort(
    self,
    sort: Dict[str, Any],
    k: int,
    fetch_k: int,
    lambda_mult: float,
    metadata_parameter: Dict[str, Any],
    **kwargs: Any,
) -> List[Document]:
    """Async version of ``_run_mmr_query_by_sort``.

    Prefetches ``fetch_k`` hits ordered by ``sort`` through the async
    collection, then picks ``k`` of them with MMR.

    Args:
        sort: Sort clause for the collection query. The callers in this class
            pass either ``{"$vector": ...}`` or ``{"$vectorize": ...}``.
        k: Number of Documents to return.
        fetch_k: Number of candidate hits to prefetch for MMR.
        lambda_mult: Relevance/diversity trade-off forwarded to MMR.
        metadata_parameter: Already-normalized metadata filter.
        **kwargs: Accepted and ignored.

    Returns:
        The Documents selected by maximal marginal relevance.
    """
    # Both text fields are requested; only one is used, depending on the mode.
    candidate_fields = {
        "_id": True,
        "content": True,
        "metadata": True,
        "$vector": True,
        "$vectorize": True,
    }
    cursor = self.astra_env.async_collection.find(
        filter=metadata_parameter,
        projection=candidate_fields,
        limit=fetch_k,
        include_similarity=True,
        include_sort_vector=True,
        sort=sort,
    )
    candidates = []
    async for candidate in cursor:
        candidates.append(candidate)
    # The vector the query was sorted by is the MMR reference point.
    reference_vector = await cursor.get_sort_vector()
    text_field = "$vectorize" if self._using_vectorize() else "content"
    return self._get_mmr_hits(
        embedding=reference_vector,  # type: ignore[arg-type]
        k=k,
        lambda_mult=lambda_mult,
        prefetch_hits=candidates,
        content_field=text_field,
    )
@staticmethod
def _get_mmr_hits(
    embedding: List[float],
    k: int,
    lambda_mult: float,
    prefetch_hits: List[DocDict],
    content_field: str,
) -> List[Document]:
    """Select hits from a prefetched candidate list by maximal marginal relevance.

    Args:
        embedding: Reference (query) vector that MMR measures relevance against.
        k: Number of hits to select.
        lambda_mult: Relevance/diversity trade-off forwarded to
            ``maximal_marginal_relevance``.
        prefetch_hits: Candidate entries read from the collection. Each must
            carry the ``"$vector"``, ``"metadata"`` and ``content_field`` keys.
        content_field: Key whose value becomes each Document's ``page_content``.

    Returns:
        The selected hits as Documents, kept in their original prefetch
        order (not in MMR selection order).
    """
    # Use a set so the per-hit membership test below is O(1) instead of a
    # linear scan of the chosen-index list.
    chosen_indices = set(
        maximal_marginal_relevance(
            np.array(embedding, dtype=np.float32),
            [prefetch_hit["$vector"] for prefetch_hit in prefetch_hits],
            k=k,
            lambda_mult=lambda_mult,
        )
    )
    return [
        Document(
            page_content=hit[content_field],
            metadata=hit["metadata"],
        )
        for hit_index, hit in enumerate(prefetch_hits)
        if hit_index in chosen_indices
    ]
[docs]
def max_marginal_relevance_search_by_vector(
    self,
    embedding: List[float],
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    filter: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> List[Document]:
    """Return docs selected using the maximal marginal relevance.

    Maximal marginal relevance optimizes for similarity to query AND diversity
    among selected documents.

    Args:
        embedding: Embedding to look up documents similar to.
        k: Number of Documents to return.
        fetch_k: Number of Documents to fetch to pass to MMR algorithm.
        lambda_mult: Number between 0 and 1 that determines the degree
            of diversity among the results with 0 corresponding
            to maximum diversity and 1 to minimum diversity.
        filter: Filter on the metadata to apply.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        The list of Documents selected by maximal marginal relevance.
    """
    self.astra_env.ensure_db_setup()
    normalized_filter = self._filter_to_metadata(filter)
    vector_sort = {"$vector": embedding}
    return self._run_mmr_query_by_sort(
        sort=vector_sort,
        k=k,
        fetch_k=fetch_k,
        lambda_mult=lambda_mult,
        metadata_parameter=normalized_filter,
    )
[docs]
async def amax_marginal_relevance_search_by_vector(
    self,
    embedding: List[float],
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    filter: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> List[Document]:
    """Return docs selected using the maximal marginal relevance.

    Maximal marginal relevance optimizes for similarity to query AND diversity
    among selected documents. Async counterpart of
    ``max_marginal_relevance_search_by_vector``.

    Args:
        embedding: Embedding to look up documents similar to.
        k: Number of Documents to return.
        fetch_k: Number of Documents to fetch to pass to MMR algorithm.
        lambda_mult: Number between 0 and 1 that determines the degree
            of diversity among the results with 0 corresponding
            to maximum diversity and 1 to minimum diversity.
        filter: Filter on the metadata to apply.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        The list of Documents selected by maximal marginal relevance.
    """
    await self.astra_env.aensure_db_setup()
    normalized_filter = self._filter_to_metadata(filter)
    vector_sort = {"$vector": embedding}
    return await self._arun_mmr_query_by_sort(
        sort=vector_sort,
        k=k,
        fetch_k=fetch_k,
        lambda_mult=lambda_mult,
        metadata_parameter=normalized_filter,
    )
[docs]
def max_marginal_relevance_search(
    self,
    query: str,
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    filter: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> List[Document]:
    """Return docs selected using the maximal marginal relevance.

    Maximal marginal relevance optimizes for similarity to query AND diversity
    among selected documents.

    Args:
        query: Query to look up documents similar to.
        k: Number of Documents to return.
        fetch_k: Number of Documents to fetch to pass to MMR algorithm.
        lambda_mult: Number between 0 and 1 that determines the degree
            of diversity among the results with 0 corresponding
            to maximum diversity and 1 to minimum diversity.
        filter: Filter on the metadata to apply.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        The list of Documents selected by maximal marginal relevance.
    """
    if self._using_vectorize():
        # this case goes directly to the "_by_sort" method
        # (and does its own filter normalization, as it cannot
        # use the path for the with-embedding mmr querying).
        # "_run_mmr_query_by_sort" does not prepare the collection itself,
        # so run the same setup step every other query path performs first.
        self.astra_env.ensure_db_setup()
        metadata_parameter = self._filter_to_metadata(filter)
        return self._run_mmr_query_by_sort(
            sort={"$vectorize": query},
            k=k,
            fetch_k=fetch_k,
            lambda_mult=lambda_mult,
            metadata_parameter=metadata_parameter,
        )
    # Client-side embedding path; narrows the Optional for type checkers.
    # The by-vector method below performs the DB setup step itself.
    assert self.embedding is not None
    embedding_vector = self.embedding.embed_query(query)
    return self.max_marginal_relevance_search_by_vector(
        embedding_vector,
        k,
        fetch_k,
        lambda_mult=lambda_mult,
        filter=filter,
    )
[docs]
async def amax_marginal_relevance_search(
    self,
    query: str,
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    filter: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> List[Document]:
    """Return docs selected using the maximal marginal relevance.

    Maximal marginal relevance optimizes for similarity to query AND diversity
    among selected documents. Async counterpart of
    ``max_marginal_relevance_search``.

    Args:
        query: Query to look up documents similar to.
        k: Number of Documents to return.
        fetch_k: Number of Documents to fetch to pass to MMR algorithm.
        lambda_mult: Number between 0 and 1 that determines the degree
            of diversity among the results with 0 corresponding
            to maximum diversity and 1 to minimum diversity.
        filter: Filter on the metadata to apply.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        The list of Documents selected by maximal marginal relevance.
    """
    if self._using_vectorize():
        # this case goes directly to the "_by_sort" method
        # (and does its own filter normalization, as it cannot
        # use the path for the with-embedding mmr querying).
        # "_arun_mmr_query_by_sort" does not prepare the collection itself,
        # so run the same setup step every other query path performs first.
        await self.astra_env.aensure_db_setup()
        metadata_parameter = self._filter_to_metadata(filter)
        return await self._arun_mmr_query_by_sort(
            sort={"$vectorize": query},
            k=k,
            fetch_k=fetch_k,
            lambda_mult=lambda_mult,
            metadata_parameter=metadata_parameter,
        )
    # Client-side embedding path; narrows the Optional for type checkers.
    # The by-vector method below performs the DB setup step itself.
    assert self.embedding is not None
    embedding_vector = await self.embedding.aembed_query(query)
    return await self.amax_marginal_relevance_search_by_vector(
        embedding_vector,
        k,
        fetch_k,
        lambda_mult=lambda_mult,
        filter=filter,
    )
@classmethod
def _from_kwargs(
    cls: Type[AstraDBVectorStore],
    **kwargs: Any,
) -> AstraDBVectorStore:
    """Instantiate the store from keyword arguments, dropping unsupported ones.

    Only keywords that name a parameter of ``AstraDBVectorStore.__init__``
    (positional or keyword-only) are forwarded to ``cls``. Any others trigger
    a single ``UserWarning`` that lists them, and are then discarded.

    Args:
        **kwargs: candidate constructor arguments.

    Returns:
        A new instance built from the accepted keyword arguments.
    """
    init_spec = inspect.getfullargspec(AstraDBVectorStore.__init__)
    accepted_names = {*init_spec.args, *init_spec.kwonlyargs} - {"self"}
    rejected_names = sorted(set(kwargs) - accepted_names)
    if rejected_names:
        warnings.warn(
            (
                "Method 'from_texts/afrom_texts' of AstraDBVectorStore "
                "vector store invoked with unsupported arguments "
                f"({', '.join(rejected_names)}), "
                "which will be ignored."
            ),
            UserWarning,
        )
    accepted_kwargs = {
        name: value for name, value in kwargs.items() if name in accepted_names
    }
    return cls(**accepted_kwargs)
[docs]
@classmethod
def from_texts(
    cls: Type[AstraDBVectorStore],
    texts: List[str],
    embedding: Optional[Embeddings] = None,
    metadatas: Optional[List[dict]] = None,
    ids: Optional[List[str]] = None,
    **kwargs: Any,
) -> AstraDBVectorStore:
    """Create an Astra DB vectorstore from raw texts.

    Args:
        texts: the texts to insert.
        embedding: the embedding function to use in the store.
        metadatas: metadata dicts for the texts.
        ids: ids to associate to the texts.
        **kwargs: you can pass any argument that you would
            to :meth:`~add_texts` and/or to the 'AstraDBVectorStore' constructor
            (see these methods for details). These arguments will be
            routed to the respective methods as they are.

    Returns:
        an `AstraDBVectorStore` vectorstore.
    """
    # Any keyword naming a parameter of `add_texts` (positional or
    # keyword-only) goes there. Everything else goes to the constructor.
    # The explicitly handled parameters are excluded from the routing.
    _add_texts_inspection = inspect.getfullargspec(AstraDBVectorStore.add_texts)
    _method_args = (
        set(_add_texts_inspection.args) | set(_add_texts_inspection.kwonlyargs)
    ) - {"self", "texts", "metadatas", "ids"}
    _init_kwargs = {k: v for k, v in kwargs.items() if k not in _method_args}
    _method_kwargs = {k: v for k, v in kwargs.items() if k in _method_args}
    astra_db_store = AstraDBVectorStore._from_kwargs(
        embedding=embedding,
        **_init_kwargs,
    )
    astra_db_store.add_texts(
        texts=texts,
        metadatas=metadatas,
        ids=ids,
        **_method_kwargs,
    )
    return astra_db_store
[docs]
@classmethod
async def afrom_texts(
    cls: Type[AstraDBVectorStore],
    texts: List[str],
    embedding: Optional[Embeddings] = None,
    metadatas: Optional[List[dict]] = None,
    ids: Optional[List[str]] = None,
    **kwargs: Any,
) -> AstraDBVectorStore:
    """Create an Astra DB vectorstore from raw texts.

    Args:
        texts: the texts to insert.
        embedding: the embedding function to use in the store.
        metadatas: metadata dicts for the texts.
        ids: ids to associate to the texts.
        **kwargs: you can pass any argument that you would
            to :meth:`~aadd_texts` and/or to the 'AstraDBVectorStore' constructor
            (see these methods for details). These arguments will be
            routed to the respective methods as they are.

    Returns:
        an `AstraDBVectorStore` vectorstore.
    """
    # Any keyword naming a parameter of `aadd_texts` (positional or
    # keyword-only) goes there. Everything else goes to the constructor.
    # The explicitly handled parameters are excluded from the routing.
    _aadd_texts_inspection = inspect.getfullargspec(AstraDBVectorStore.aadd_texts)
    _method_args = (
        set(_aadd_texts_inspection.args) | set(_aadd_texts_inspection.kwonlyargs)
    ) - {"self", "texts", "metadatas", "ids"}
    _init_kwargs = {k: v for k, v in kwargs.items() if k not in _method_args}
    _method_kwargs = {k: v for k, v in kwargs.items() if k in _method_args}
    astra_db_store = AstraDBVectorStore._from_kwargs(
        embedding=embedding,
        **_init_kwargs,
    )
    await astra_db_store.aadd_texts(
        texts=texts,
        metadatas=metadatas,
        ids=ids,
        **_method_kwargs,
    )
    return astra_db_store
[docs]
@classmethod
def from_documents(
    cls: Type[AstraDBVectorStore],
    documents: List[Document],
    embedding: Optional[Embeddings] = None,
    **kwargs: Any,
) -> AstraDBVectorStore:
    """Create an Astra DB vectorstore from a document list.

    Utility method that splits each Document into its text and metadata and
    defers to 'from_texts' (see that one).

    Args: see 'from_texts', except here you have to supply 'documents'
        in place of 'texts' and 'metadatas'.

    Returns:
        an `AstraDBVectorStore` vectorstore.
    """
    page_texts = [document.page_content for document in documents]
    page_metadatas = [document.metadata for document in documents]
    return cls.from_texts(
        page_texts,
        embedding=embedding,
        metadatas=page_metadatas,
        **kwargs,
    )
[docs]
@classmethod
async def afrom_documents(
    cls: Type[AstraDBVectorStore],
    documents: List[Document],
    embedding: Optional[Embeddings] = None,
    **kwargs: Any,
) -> AstraDBVectorStore:
    """Create an Astra DB vectorstore from a document list.

    Utility method that splits each Document into its text and metadata and
    defers to 'afrom_texts' (see that one).

    Args: see 'afrom_texts', except here you have to supply 'documents'
        in place of 'texts' and 'metadatas'.

    Returns:
        an `AstraDBVectorStore` vectorstore.
    """
    page_texts = [document.page_content for document in documents]
    page_metadatas = [document.metadata for document in documents]
    return await cls.afrom_texts(
        page_texts,
        embedding=embedding,
        metadatas=page_metadatas,
        **kwargs,
    )