Source code for langchain_core.runnables.utils

"""Utility code for runnables."""

from __future__ import annotations

import ast
import asyncio
import inspect
import textwrap
from functools import lru_cache
from inspect import signature
from itertools import groupby
from typing import (
    Any,
    AsyncIterable,
    AsyncIterator,
    Awaitable,
    Callable,
    Coroutine,
    Dict,
    Iterable,
    List,
    Mapping,
    NamedTuple,
    Optional,
    Protocol,
    Sequence,
    Set,
    Type,
    TypeVar,
    Union,
)

from typing_extensions import TypeGuard

from langchain_core.pydantic_v1 import BaseConfig, BaseModel
from langchain_core.pydantic_v1 import create_model as _create_model_base
from langchain_core.runnables.schema import StreamEvent

# Type variable for a runnable's input; declared contravariant, the usual
# variance for a type that is only consumed.
Input = TypeVar("Input", contravariant=True)
# Output type should support combining via ``+`` (``__add__``), as e.g. str and
# list do (a plain dict does not; AddableDict exists for that case).
# NOTE(review): presumably this lets partial outputs be summed with the
# ``add``/``aadd`` helpers in this module — confirm against callers.
Output = TypeVar("Output", covariant=True)


async def gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:
    """Await a coroutine while holding one slot of a semaphore.

    Args:
        semaphore: The semaphore that bounds concurrency.
        coro: The coroutine to await.

    Returns:
        Whatever the awaited coroutine returns.
    """
    await semaphore.acquire()
    try:
        return await coro
    finally:
        # Release even if the coroutine raised or was cancelled.
        semaphore.release()
async def gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> list:
    """Gather coroutines, running at most ``n`` of them at the same time.

    Args:
        n: Maximum number of coroutines in flight, or None for no limit.
        *coros: The coroutines to run.

    Returns:
        The coroutines' results, in the order the coroutines were given.
    """
    if n is not None:
        semaphore = asyncio.Semaphore(n)
        gated = [gated_coro(semaphore, coro) for coro in coros]
        return await asyncio.gather(*gated)
    return await asyncio.gather(*coros)
def accepts_run_manager(callable: Callable[..., Any]) -> bool:
    """Report whether a callable declares a parameter named ``run_manager``.

    Args:
        callable: The callable to inspect.

    Returns:
        bool: True if ``run_manager`` is an explicit parameter; False otherwise,
        including when no signature can be obtained (``ValueError``).
    """
    try:
        params = signature(callable).parameters
    except ValueError:
        return False
    return "run_manager" in params
def accepts_config(callable: Callable[..., Any]) -> bool:
    """Report whether a callable declares a parameter named ``config``.

    Args:
        callable: The callable to inspect.

    Returns:
        bool: True if ``config`` is an explicit parameter; False otherwise,
        including when no signature can be obtained (``ValueError``).
    """
    try:
        params = signature(callable).parameters
    except ValueError:
        return False
    return "config" in params
def accepts_context(callable: Callable[..., Any]) -> bool:
    """Report whether a callable declares a parameter named ``context``.

    Args:
        callable: The callable to inspect.

    Returns:
        bool: True if ``context`` is an explicit parameter; False otherwise,
        including when no signature can be obtained (``ValueError``).
    """
    try:
        params = signature(callable).parameters
    except ValueError:
        return False
    return "context" in params
@lru_cache(maxsize=1)
def asyncio_accepts_context() -> bool:
    """Report whether ``asyncio.create_task`` takes a ``context`` parameter.

    The result depends only on the running interpreter, so it is computed
    once and cached.
    """
    try:
        task_params = signature(asyncio.create_task).parameters
    except ValueError:
        return False
    return "context" in task_params
class IsLocalDict(ast.NodeVisitor):
    """Collect the string keys read from a given local name.

    Records the key of every ``name["key"]`` subscript load and every
    ``name.get("key")`` / ``name.get("key", default)`` call into ``keys``.
    """

    def __init__(self, name: str, keys: Set[str]) -> None:
        """Initialize the visitor.

        Args:
            name: The local variable name whose key accesses are tracked.
            keys: The set to populate; it is mutated in place.
        """
        self.name = name
        self.keys = keys

    def visit_Subscript(self, node: ast.Subscript) -> Any:
        """Record ``name["key"]`` loads, then keep descending.

        Args:
            node: The node to visit.

        Returns:
            None.
        """
        if (
            isinstance(node.ctx, ast.Load)
            and isinstance(node.value, ast.Name)
            and node.value.id == self.name
            and isinstance(node.slice, ast.Constant)
            and isinstance(node.slice.value, str)
        ):
            # we've found a subscript access on the name we're looking for
            self.keys.add(node.slice.value)
        # Defining visit_Subscript disables the default recursion, so recurse
        # explicitly. Otherwise nested accesses such as the inner ``x["a"]`` of
        # ``x["a"]["b"]`` or ``y[x["a"]]`` would be missed.
        self.generic_visit(node)

    def visit_Call(self, node: ast.Call) -> Any:
        """Record ``name.get("key"[, default])`` calls, then keep descending.

        Args:
            node: The node to visit.

        Returns:
            None.
        """
        if (
            isinstance(node.func, ast.Attribute)
            and isinstance(node.func.value, ast.Name)
            and node.func.value.id == self.name
            and node.func.attr == "get"
            and len(node.args) in (1, 2)
            and isinstance(node.args[0], ast.Constant)
            and isinstance(node.args[0].value, str)
        ):
            # we've found a .get() call on the name we're looking for
            self.keys.add(node.args[0].value)
        # Key reads commonly appear inside call arguments, e.g. ``len(x["a"])``
        # or ``x.get("a", x.get("b"))``; recurse so those are collected too.
        self.generic_visit(node)


class IsFunctionArgDict(ast.NodeVisitor):
    """Collect the dict keys read from the first positional argument of functions."""

    def __init__(self) -> None:
        self.keys: Set[str] = set()

    def _collect_keys(self, args: ast.arguments, scope: ast.AST) -> None:
        """Scan ``scope`` for key reads on the first positional parameter in ``args``."""
        # Positional-only parameters (``def f(x, /)``) precede the regular
        # ones and are stored separately from ``args.args``.
        positional = [*args.posonlyargs, *args.args]
        if not positional:
            return
        IsLocalDict(positional[0].arg, self.keys).visit(scope)

    def visit_Lambda(self, node: ast.Lambda) -> Any:
        """Collect keys read from a lambda's first positional argument in its body.

        Args:
            node: The node to visit.

        Returns:
            None.
        """
        self._collect_keys(node.args, node.body)

    def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
        """Collect keys read from a function's first positional argument.

        Args:
            node: The node to visit.

        Returns:
            None.
        """
        self._collect_keys(node.args, node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
        """Collect keys read from an async function's first positional argument.

        Args:
            node: The node to visit.

        Returns:
            None.
        """
        self._collect_keys(node.args, node)
class NonLocals(ast.NodeVisitor):
    """Get nonlocal variables accessed.

    ``loads`` collects names read (dotted for attribute chains such as
    ``obj.attr.sub``) and ``stores`` collects names assigned.
    """

    def __init__(self) -> None:
        self.loads: Set[str] = set()
        self.stores: Set[str] = set()

    def visit_Name(self, node: ast.Name) -> Any:
        """Record a bare name as loaded or stored.

        Args:
            node: The node to visit.

        Returns:
            None.
        """
        if isinstance(node.ctx, ast.Load):
            self.loads.add(node.id)
        elif isinstance(node.ctx, ast.Store):
            self.stores.add(node.id)

    def visit_Attribute(self, node: ast.Attribute) -> Any:
        """Record a loaded attribute chain.

        A chain rooted at a name (``a.b.c``) is recorded as the dotted string,
        and the bare root name is discarded from ``loads``. A chain rooted at
        any other expression is visited so that the names it reads are still
        recorded.

        Args:
            node: The node to visit.

        Returns:
            None.
        """
        if isinstance(node.ctx, ast.Load):
            parent = node.value
            attr_expr = node.attr
            while isinstance(parent, ast.Attribute):
                attr_expr = parent.attr + "." + attr_expr
                parent = parent.value
            if isinstance(parent, ast.Name):
                self.loads.add(parent.id + "." + attr_expr)
                self.loads.discard(parent.id)
            else:
                # Root is a call, subscript, etc. (``foo(x).bar``,
                # ``a.b(c).d``). Since this method replaces the default
                # recursion, visit the root explicitly or its names are lost.
                self.visit(parent)


class FunctionNonLocals(ast.NodeVisitor):
    """Get the nonlocal variables accessed of a function.

    For each function, async function or lambda reached from the visited
    tree, adds the names it loads but never stores to ``nonlocals``.
    Visiting stops at the first such definition on each path; nested
    definitions are covered by that definition's own scan.
    """

    def __init__(self) -> None:
        self.nonlocals: Set[str] = set()

    def _record(self, node: ast.AST) -> None:
        """Add the names loaded but not stored anywhere within ``node``."""
        scanner = NonLocals()
        scanner.visit(node)
        self.nonlocals.update(scanner.loads - scanner.stores)

    def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
        """Visit a function definition.

        Args:
            node: The node to visit.

        Returns:
            None.
        """
        self._record(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
        """Visit an async function definition.

        Args:
            node: The node to visit.

        Returns:
            None.
        """
        self._record(node)

    def visit_Lambda(self, node: ast.Lambda) -> Any:
        """Visit a lambda function.

        Args:
            node: The node to visit.

        Returns:
            None.
        """
        self._record(node)
[docs] class GetLambdaSource(ast.NodeVisitor): """Get the source code of a lambda function."""
[docs] def __init__(self) -> None: """Initialize the visitor.""" self.source: Optional[str] = None self.count = 0
[docs] def visit_Lambda(self, node: ast.Lambda) -> Any: """Visit a lambda function. Args: node: The node to visit. Returns: Any: The result of the visit. """ self.count += 1 if hasattr(ast, "unparse"): self.source = ast.unparse(node)
def get_function_first_arg_dict_keys(func: Callable) -> Optional[List[str]]:
    """Get the dict keys a function reads from its first argument.

    Args:
        func: The function to inspect.

    Returns:
        Optional[List[str]]: The keys found, or None when no keys are found
        or the source cannot be retrieved or parsed.
    """
    try:
        tree = ast.parse(textwrap.dedent(inspect.getsource(func)))
        collector = IsFunctionArgDict()
        collector.visit(tree)
    except (SyntaxError, TypeError, OSError, SystemError):
        return None
    # An empty list collapses to None.
    return list(collector.keys) or None
def get_lambda_source(func: Callable) -> Optional[str]:
    """Get the source code of a lambda function.

    Args:
        func: a Callable that can be a lambda function.

    Returns:
        Optional[str]: The callable's ``__name__`` for any callable not named
        ``"<lambda>"``. For a lambda, its unparsed source, provided exactly one
        lambda is found in the source text retrieved for it. None otherwise.
    """
    try:
        name = func.__name__ if func.__name__ != "<lambda>" else None
    except AttributeError:
        name = None
    if name is not None:
        # A named callable is not a lambda. Parsing its source would wrongly
        # return the text of a single lambda nested inside it (e.g. a sort
        # key), so report the name directly.
        return name
    try:
        code = inspect.getsource(func)
        tree = ast.parse(textwrap.dedent(code))
        visitor = GetLambdaSource()
        visitor.visit(tree)
    except (SyntaxError, TypeError, OSError, SystemError):
        return None
    # getsource returns whole lines, which may contain several lambdas; only
    # an unambiguous single match is reported.
    return visitor.source if visitor.count == 1 else None
def get_function_nonlocals(func: Callable) -> List[Any]:
    """Get the nonlocal variables accessed by a function.

    Collects the values of closure variables that the function's source reads.
    For attribute chains read on a closure variable (``obj.a.b``), the value is
    resolved with ``getattr``. A chain is skipped if it hits None or a missing
    attribute.

    Args:
        func: The function to check.

    Returns:
        List[Any]: The nonlocal values accessed by the function. Empty if the
        source cannot be retrieved or parsed.
    """
    try:
        code = inspect.getsource(func)
        tree = ast.parse(textwrap.dedent(code))
        visitor = FunctionNonLocals()
        visitor.visit(tree)
        values: List[Any] = []
        for k, v in inspect.getclosurevars(func).nonlocals.items():
            if k in visitor.nonlocals:
                values.append(v)
            # Require the dot so that closure var ``a`` does not match an
            # unrelated dotted name such as ``ab.c``.
            prefix = k + "."
            for kk in visitor.nonlocals:
                if not kk.startswith(prefix):
                    continue
                vv = v
                for part in kk.split(".")[1:]:
                    if vv is None:
                        break
                    try:
                        vv = getattr(vv, part)
                    except AttributeError:
                        break
                else:
                    # Only reached when every attribute in the chain resolved.
                    values.append(vv)
        return values
    except (SyntaxError, TypeError, OSError, SystemError):
        return []
def indent_lines_after_first(text: str, prefix: str) -> str:
    """Indent all lines of text after the first line.

    Args:
        text: The text to indent.
        prefix: Used to determine the number of spaces to indent; each line
            after the first gets ``len(prefix)`` spaces.

    Returns:
        str: The indented text (lines joined with ``"\\n"``). Empty text is
        returned unchanged.
    """
    lines = text.splitlines()
    if not lines:
        # "".splitlines() is [], and indexing lines[0] would raise IndexError.
        return text
    spaces = " " * len(prefix)
    return "\n".join([lines[0]] + [spaces + line for line in lines[1:]])
class AddableDict(Dict[str, Any]):
    """
    Dictionary that can be added to another dictionary.

    ``left + right`` produces a new AddableDict starting from ``left``. For
    each key of ``right``:

    - a missing or None value on the left is replaced by the right value;
    - a None value on the right leaves the left value untouched;
    - otherwise the two values are combined with ``+``, falling back to the
      right value when they cannot be added (``TypeError``).
    """

    @staticmethod
    def _merge(left: Mapping[str, Any], right: Mapping[str, Any]) -> AddableDict:
        merged = AddableDict(left)
        for key in right:
            incoming = right[key]
            current = merged.get(key)
            if current is None:
                merged[key] = incoming
            elif incoming is not None:
                try:
                    merged[key] = current + incoming
                except TypeError:
                    merged[key] = incoming
        return merged

    def __add__(self, other: AddableDict) -> AddableDict:
        return self._merge(self, other)

    def __radd__(self, other: AddableDict) -> AddableDict:
        return self._merge(other, self)
_T_co = TypeVar("_T_co", covariant=True)
_T_contra = TypeVar("_T_contra", contravariant=True)


class SupportsAdd(Protocol[_T_contra, _T_co]):
    """Protocol for objects that support addition.

    Structural type for any value whose ``__add__`` accepts a ``_T_contra``
    operand and produces a ``_T_co`` result.
    """

    def __add__(self, other: _T_contra, /) -> _T_co:
        ...


# Type variable bound to SupportsAdd: any value that can be combined with ``+``.
Addable = TypeVar("Addable", bound=SupportsAdd[Any, Any])
def add(addables: Iterable[Addable]) -> Optional[Addable]:
    """Fold a sequence of addable objects together with ``+``.

    Args:
        addables: The addable objects to add, in order.

    Returns:
        Optional[Addable]: The running sum, or None when ``addables`` is empty.
    """
    total = None
    for item in addables:
        total = item if total is None else total + item
    return total
async def aadd(addables: AsyncIterable[Addable]) -> Optional[Addable]:
    """Fold an async stream of addable objects together with ``+``.

    Args:
        addables: The addable objects to add, in arrival order.

    Returns:
        Optional[Addable]: The running sum, or None when the stream is empty.
    """
    total = None
    async for item in addables:
        total = item if total is None else total + item
    return total
class ConfigurableField(NamedTuple):
    """A field whose value the user can configure.

    Parameters:
        id: Unique identifier of the field.
        name: Display name of the field. Defaults to None.
        description: Description of the field. Defaults to None.
        annotation: The annotation of the field. Defaults to None.
        is_shared: Whether the field is shared. Defaults to False.
    """

    id: str
    name: Optional[str] = None
    description: Optional[str] = None
    annotation: Optional[Any] = None
    is_shared: bool = False

    def __hash__(self) -> int:
        # Only id and annotation feed the hash; this is still consistent with
        # tuple equality, since equal fields imply equal hashes.
        identity = (self.id, self.annotation)
        return hash(identity)
class ConfigurableFieldSingleOption(NamedTuple):
    """A user-configurable field that selects one value from named options.

    Parameters:
        id: Unique identifier of the field.
        options: Mapping from option key to option value.
        default: The default value for the field.
        name: Display name of the field. Defaults to None.
        description: Description of the field. Defaults to None.
        is_shared: Whether the field is shared. Defaults to False.
    """

    id: str
    options: Mapping[str, Any]
    default: str
    name: Optional[str] = None
    description: Optional[str] = None
    is_shared: bool = False

    def __hash__(self) -> int:
        # Hash the option keys (not values), which need not be hashable.
        option_keys = tuple(self.options)
        return hash((self.id, option_keys, self.default))
class ConfigurableFieldMultiOption(NamedTuple):
    """A user-configurable field that selects several values from named options.

    Parameters:
        id: Unique identifier of the field.
        options: Mapping from option key to option value.
        default: The default values for the field.
        name: Display name of the field. Defaults to None.
        description: Description of the field. Defaults to None.
        is_shared: Whether the field is shared. Defaults to False.
    """

    id: str
    options: Mapping[str, Any]
    default: Sequence[str]
    name: Optional[str] = None
    description: Optional[str] = None
    is_shared: bool = False

    def __hash__(self) -> int:
        # Option keys and defaults are converted to tuples so they can be hashed.
        option_keys = tuple(self.options)
        default_keys = tuple(self.default)
        return hash((self.id, option_keys, default_keys))
# Union of the three user-facing configurable field kinds: a plain field,
# a single-choice option field, and a multiple-choice option field.
AnyConfigurableField = Union[
    ConfigurableField, ConfigurableFieldSingleOption, ConfigurableFieldMultiOption
]
class ConfigurableFieldSpec(NamedTuple):
    """Specification of a field the user can configure.

    Parameters:
        id: Unique identifier of the field.
        annotation: The annotation of the field.
        name: Display name of the field. Defaults to None.
        description: Description of the field. Defaults to None.
        default: Default value of the field. Defaults to None.
        is_shared: Whether the field is shared. Defaults to False.
        dependencies: The dependencies of the field. Defaults to None.
    """

    # Required fields.
    id: str
    annotation: Any

    # Optional fields.
    name: Optional[str] = None
    description: Optional[str] = None
    default: Any = None
    is_shared: bool = False
    dependencies: Optional[List[str]] = None
def get_unique_config_specs(
    specs: Iterable[ConfigurableFieldSpec],
) -> List[ConfigurableFieldSpec]:
    """Get the unique config specs from a sequence of config specs.

    Specs sharing an id are collapsed into one when they are all equal.

    Args:
        specs: The config specs.

    Returns:
        List[ConfigurableFieldSpec]: The unique config specs, ordered by id.

    Raises:
        ValueError: If specs with the same id are not all equal.
    """
    # groupby only groups adjacent items, so sort by id first.
    grouped = groupby(
        sorted(specs, key=lambda s: (s.id, *(s.dependencies or []))), lambda s: s.id
    )
    unique: List[ConfigurableFieldSpec] = []
    for spec_id, dupes in grouped:
        first = next(dupes)
        others = list(dupes)
        # all() is vacuously True when there is only one spec with this id.
        if all(o == first for o in others):
            unique.append(first)
        else:
            raise ValueError(
                "RunnableSequence contains conflicting config specs "
                f"for {spec_id}: {[first] + others}"
            )
    return unique
class _RootEventFilter:
    def __init__(
        self,
        *,
        include_names: Optional[Sequence[str]] = None,
        include_types: Optional[Sequence[str]] = None,
        include_tags: Optional[Sequence[str]] = None,
        exclude_names: Optional[Sequence[str]] = None,
        exclude_types: Optional[Sequence[str]] = None,
        exclude_tags: Optional[Sequence[str]] = None,
    ) -> None:
        """Utility to filter the root event in the astream_events implementation.

        This simply binds the arguments to the instance namespace to save a
        bit of typing in the astream_events implementation.
        """
        self.include_names = include_names
        self.include_types = include_types
        self.include_tags = include_tags
        self.exclude_names = exclude_names
        self.exclude_types = exclude_types
        self.exclude_tags = exclude_tags

    def include_event(self, event: StreamEvent, root_type: str) -> bool:
        """Determine whether to include an event.

        With no include filters set, every event passes the include stage.
        Otherwise the event must match at least one include filter: its name,
        ``root_type``, or any of its tags. An event that passes is then
        rejected if it matches any exclude filter.
        """
        tags = event.get("tags") or []

        has_include_filter = not (
            self.include_names is None
            and self.include_types is None
            and self.include_tags is None
        )
        if has_include_filter:
            matched = (
                (
                    self.include_names is not None
                    and event["name"] in self.include_names
                )
                or (self.include_types is not None and root_type in self.include_types)
                or (
                    self.include_tags is not None
                    and any(tag in self.include_tags for tag in tags)
                )
            )
            if not matched:
                return False

        if self.exclude_names is not None and event["name"] in self.exclude_names:
            return False
        if self.exclude_types is not None and root_type in self.exclude_types:
            return False
        if self.exclude_tags is not None and any(
            tag in self.exclude_tags for tag in tags
        ):
            return False
        return True


class _SchemaConfig(BaseConfig):
    # Pydantic (v1) config applied to the models built by create_model below.
    arbitrary_types_allowed = True
    frozen = True
def create_model(
    __model_name: str,
    **field_definitions: Any,
) -> Type[BaseModel]:
    """Create a pydantic model with the given field definitions.

    Identical definitions reuse a cached model class when every argument is
    hashable. Otherwise a fresh, uncached model is built.

    Args:
        __model_name: The name of the model.
        **field_definitions: The field definitions for the model.

    Returns:
        Type[BaseModel]: The created model.
    """
    try:
        return _create_model_cached(__model_name, **field_definitions)
    except TypeError:
        # lru_cache raises TypeError for an unhashable argument. Any TypeError
        # here is treated that way, and the model is built without the cache.
        pass
    return _create_model_base(
        __model_name, __config__=_SchemaConfig, **field_definitions
    )
@lru_cache(maxsize=256)
def _create_model_cached(
    __model_name: str,
    **field_definitions: Any,
) -> Type[BaseModel]:
    # Memoized model factory: repeated calls with the same (hashable) name and
    # field definitions return the same model class, up to 256 entries.
    config = _SchemaConfig
    return _create_model_base(__model_name, __config__=config, **field_definitions)
def is_async_generator(
    func: Any,
) -> TypeGuard[Callable[..., AsyncIterator]]:
    """Check if a function is an async generator.

    Also recognizes callable objects whose ``__call__`` is an async
    generator function.

    Args:
        func: The function to check.

    Returns:
        TypeGuard[Callable[..., AsyncIterator]]: True if the function is an
            async generator, False otherwise.
    """
    if inspect.isasyncgenfunction(func):
        return True
    if not hasattr(func, "__call__"):
        return False
    return inspect.isasyncgenfunction(func.__call__)
def is_async_callable(
    func: Any,
) -> TypeGuard[Callable[..., Awaitable]]:
    """Check if a function is async.

    Also recognizes callable objects whose ``__call__`` is a coroutine
    function.

    Args:
        func: The function to check.

    Returns:
        TypeGuard[Callable[..., Awaitable]]: True if the function is async,
            False otherwise.
    """
    if asyncio.iscoroutinefunction(func):
        return True
    if not hasattr(func, "__call__"):
        return False
    return asyncio.iscoroutinefunction(func.__call__)