# Source code for langchain.chains.query_constructor.base

"""LLM Chain for turning a user text query into a structured query."""

from __future__ import annotations

import json
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, cast

from langchain_core.exceptions import OutputParserException
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.output_parsers.json import parse_and_check_json_markdown
from langchain_core.prompts import BasePromptTemplate
from langchain_core.prompts.few_shot import FewShotPromptTemplate
from langchain_core.runnables import Runnable
from langchain_core.structured_query import (
    Comparator,
    Comparison,
    FilterDirective,
    Operation,
    Operator,
    StructuredQuery,
)

from langchain.chains.llm import LLMChain
from langchain.chains.query_constructor.parser import get_parser
from langchain.chains.query_constructor.prompt import (
    DEFAULT_EXAMPLES,
    DEFAULT_PREFIX,
    DEFAULT_SCHEMA_PROMPT,
    DEFAULT_SUFFIX,
    EXAMPLE_PROMPT,
    EXAMPLES_WITH_LIMIT,
    PREFIX_WITH_DATA_SOURCE,
    SCHEMA_WITH_LIMIT_PROMPT,
    SUFFIX_WITHOUT_DATA_SOURCE,
    USER_SPECIFIED_EXAMPLE_PROMPT,
)
from langchain.chains.query_constructor.schema import AttributeInfo


class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
    """Output parser that parses a structured query."""

    ast_parse: Callable
    """Callable that parses dict into internal representation of query language."""

    def parse(self, text: str) -> StructuredQuery:
        """Parse LLM output text into a ``StructuredQuery``.

        The text is handed to ``parse_and_check_json_markdown``, which must
        yield a mapping containing at least the ``query`` and ``filter`` keys.
        An optional ``limit`` key is kept; any other keys are dropped.

        Args:
            text: Raw LLM output.

        Returns:
            The parsed structured query.

        Raises:
            OutputParserException: If any step of parsing fails. The original
                exception is chained as ``__cause__``.
        """
        try:
            expected_keys = ["query", "filter"]
            allowed_keys = ["query", "filter", "limit"]
            parsed = parse_and_check_json_markdown(text, expected_keys)
            # A missing or empty query is normalised to a single space.
            if parsed["query"] is None or len(parsed["query"]) == 0:
                parsed["query"] = " "
            # The "NO_FILTER" sentinel (or any falsy value) means "no filter".
            if parsed["filter"] == "NO_FILTER" or not parsed["filter"]:
                parsed["filter"] = None
            else:
                parsed["filter"] = self.ast_parse(parsed["filter"])
            # A falsy limit (missing, None, 0, "") is treated as "no limit".
            if not parsed.get("limit"):
                parsed.pop("limit", None)
            return StructuredQuery(
                **{k: v for k, v in parsed.items() if k in allowed_keys}
            )
        except Exception as e:
            # Chain the cause so the underlying traceback is preserved.
            raise OutputParserException(
                f"Parsing text\n{text}\n raised following error:\n{e}"
            ) from e

    @classmethod
    def from_components(
        cls,
        allowed_comparators: Optional[Sequence[Comparator]] = None,
        allowed_operators: Optional[Sequence[Operator]] = None,
        allowed_attributes: Optional[Sequence[str]] = None,
        fix_invalid: bool = False,
    ) -> StructuredQueryOutputParser:
        """Create a structured query output parser from components.

        Args:
            allowed_comparators: allowed comparators.
            allowed_operators: allowed operators.
            allowed_attributes: allowed attribute names.
            fix_invalid: If False, the restrictions above are passed to
                ``get_parser`` and applied while parsing the filter. If True,
                the filter is parsed without restrictions and then passed
                through ``fix_filter_directive``, which prunes disallowed
                comparisons and operations instead of rejecting the filter.

        Returns:
            a structured query output parser
        """
        ast_parse: Callable
        if fix_invalid:
            # Build the unrestricted parser once and reuse it on every call,
            # just as the non-fixing branch below reuses its parser's .parse.
            unrestricted_parser = get_parser()

            def ast_parse(raw_filter: str) -> Optional[FilterDirective]:
                parsed_filter = cast(
                    Optional[FilterDirective], unrestricted_parser.parse(raw_filter)
                )
                return fix_filter_directive(
                    parsed_filter,
                    allowed_comparators=allowed_comparators,
                    allowed_operators=allowed_operators,
                    allowed_attributes=allowed_attributes,
                )

        else:
            ast_parse = get_parser(
                allowed_comparators=allowed_comparators,
                allowed_operators=allowed_operators,
                allowed_attributes=allowed_attributes,
            ).parse
        return cls(ast_parse=ast_parse)
def fix_filter_directive(
    filter: Optional[FilterDirective],
    *,
    allowed_comparators: Optional[Sequence[Comparator]] = None,
    allowed_operators: Optional[Sequence[Operator]] = None,
    allowed_attributes: Optional[Sequence[str]] = None,
) -> Optional[FilterDirective]:
    """Fix invalid filter directive.

    A comparison is dropped (becomes ``None``) if its comparator or attribute
    is not allowed. An operation is dropped if its operator is not allowed;
    otherwise its arguments are fixed recursively and any that were dropped
    are removed. An operation left with no arguments is dropped, and an
    ``AND``/``OR`` left with a single argument collapses to that argument.

    Args:
        filter: Filter directive to fix.
        allowed_comparators: allowed comparators. Defaults to all comparators.
        allowed_operators: allowed operators. Defaults to all operators.
        allowed_attributes: allowed attributes. Defaults to all attributes.

    Returns:
        Fixed filter directive, or ``None`` if nothing valid remains.
    """
    if (
        not (allowed_comparators or allowed_operators or allowed_attributes)
    ) or not filter:
        return filter
    elif isinstance(filter, Comparison):
        if allowed_comparators and filter.comparator not in allowed_comparators:
            return None
        if allowed_attributes and filter.attribute not in allowed_attributes:
            return None
        return filter
    elif isinstance(filter, Operation):
        if allowed_operators and filter.operator not in allowed_operators:
            return None
        fixed_args = (
            fix_filter_directive(
                arg,
                allowed_comparators=allowed_comparators,
                allowed_operators=allowed_operators,
                allowed_attributes=allowed_attributes,
            )
            for arg in filter.arguments
            if arg is not None
        )
        # Filter on the *fixed* result: a child pruned to None must be removed,
        # otherwise the rebuilt Operation would carry a None argument and the
        # empty / single-argument checks below would miss it.
        args = [arg for arg in fixed_args if arg is not None]
        if not args:
            return None
        elif len(args) == 1 and filter.operator in (Operator.AND, Operator.OR):
            return args[0]
        else:
            return Operation(
                operator=filter.operator,
                arguments=args,
            )
    else:
        return filter
def _format_attribute_info(info: Sequence[Union[AttributeInfo, dict]]) -> str:
    """Render attribute descriptions as indented JSON keyed by attribute name.

    Each entry is copied into a plain dict, its ``name`` field becomes the key
    and the remaining fields become the value. Curly braces in the output are
    doubled so the text can be embedded in a ``str.format`` template.

    Args:
        info: Attribute descriptions (``AttributeInfo`` objects or dicts),
            each of which must provide a ``name`` field.

    Returns:
        The brace-escaped JSON string.
    """
    by_name = {}
    for entry in info:
        fields = dict(entry)  # copy, so the caller's object is not mutated
        attr_name = fields.pop("name")
        by_name[attr_name] = fields
    rendered = json.dumps(by_name, indent=4)
    return rendered.replace("{", "{{").replace("}", "}}")
def construct_examples(input_output_pairs: Sequence[Tuple[str, dict]]) -> List[dict]:
    """Turn (user query, structured output) pairs into few-shot example records.

    Args:
        input_output_pairs: Pairs of a user query string and the structured
            request (a JSON-serialisable dict) it should map to.

    Returns:
        One dict per pair with keys ``i`` (1-based position), ``user_query``
        and ``structured_request``. The structured request is rendered as
        indented JSON with curly braces doubled for ``str.format`` templates.
    """
    return [
        {
            "i": position,
            "user_query": user_query,
            "structured_request": json.dumps(expected, indent=4)
            .replace("{", "{{")
            .replace("}", "}}"),
        }
        for position, (user_query, expected) in enumerate(
            input_output_pairs, start=1
        )
    ]
def get_query_constructor_prompt(
    document_contents: str,
    attribute_info: Sequence[Union[AttributeInfo, dict]],
    *,
    examples: Optional[Sequence] = None,
    allowed_comparators: Sequence[Comparator] = tuple(Comparator),
    allowed_operators: Sequence[Operator] = tuple(Operator),
    enable_limit: bool = False,
    schema_prompt: Optional[BasePromptTemplate] = None,
    **kwargs: Any,
) -> BasePromptTemplate:
    """Build the few-shot prompt used to turn a user query into a structured query.

    Two layouts are produced. If ``examples`` is a non-empty sequence whose
    first element is a tuple, the examples are treated as (query, output)
    pairs and rendered with ``construct_examples``. The schema, document
    contents and attributes then go into the prefix. Otherwise the given
    examples (or the built-in defaults) are used as-is, and the document
    contents and attributes go into the suffix.

    Args:
        document_contents: The contents of the document to be queried.
        attribute_info: A list of AttributeInfo objects describing
            the attributes of the document.
        examples: Optional list of examples to use for the chain.
        allowed_comparators: Sequence of allowed comparators.
        allowed_operators: Sequence of allowed operators.
        enable_limit: Whether to enable the limit operator. Defaults to False.
        schema_prompt: Prompt for describing query schema. Should have string input
            variables allowed_comparators and allowed_operators.
        kwargs: Additional named params to pass to FewShotPromptTemplate init.

    Returns:
        A prompt template that can be used to construct queries.
    """
    if not schema_prompt:
        schema_prompt = SCHEMA_WITH_LIMIT_PROMPT if enable_limit else DEFAULT_SCHEMA_PROMPT
    attributes = _format_attribute_info(attribute_info)
    schema = schema_prompt.format(
        allowed_comparators=" | ".join(allowed_comparators),
        allowed_operators=" | ".join(allowed_operators),
    )

    user_pairs_given = bool(examples) and isinstance(examples[0], tuple)
    if user_pairs_given:
        # Caller supplied (query, output) pairs: render them here and describe
        # the data source up front in the prefix.
        few_shot = construct_examples(examples)
        example_prompt = USER_SPECIFIED_EXAMPLE_PROMPT
        prefix = PREFIX_WITH_DATA_SOURCE.format(
            schema=schema, content=document_contents, attributes=attributes
        )
        suffix = SUFFIX_WITHOUT_DATA_SOURCE.format(i=len(few_shot) + 1)
    else:
        few_shot = examples or (
            EXAMPLES_WITH_LIMIT if enable_limit else DEFAULT_EXAMPLES
        )
        example_prompt = EXAMPLE_PROMPT
        prefix = DEFAULT_PREFIX.format(schema=schema)
        suffix = DEFAULT_SUFFIX.format(
            i=len(few_shot) + 1, content=document_contents, attributes=attributes
        )

    return FewShotPromptTemplate(
        examples=list(few_shot),
        example_prompt=example_prompt,
        input_variables=["query"],
        suffix=suffix,
        prefix=prefix,
        **kwargs,
    )
def load_query_constructor_chain(
    llm: BaseLanguageModel,
    document_contents: str,
    attribute_info: Sequence[Union[AttributeInfo, dict]],
    examples: Optional[List] = None,
    allowed_comparators: Sequence[Comparator] = tuple(Comparator),
    allowed_operators: Sequence[Operator] = tuple(Operator),
    enable_limit: bool = False,
    schema_prompt: Optional[BasePromptTemplate] = None,
    **kwargs: Any,
) -> LLMChain:
    """Build an ``LLMChain`` that turns a user query into a structured query.

    Args:
        llm: BaseLanguageModel to use for the chain.
        document_contents: The contents of the document to be queried.
        attribute_info: Sequence of attributes in the document.
        examples: Optional list of examples to use for the chain.
        allowed_comparators: Sequence of allowed comparators. Defaults to all
            Comparators.
        allowed_operators: Sequence of allowed operators. Defaults to all Operators.
        enable_limit: Whether to enable the limit operator. Defaults to False.
        schema_prompt: Prompt for describing query schema. Should have string input
            variables allowed_comparators and allowed_operators.
        **kwargs: Arbitrary named params to pass to LLMChain.

    Returns:
        A LLMChain that can be used to construct queries.
    """
    prompt = get_query_constructor_prompt(
        document_contents,
        attribute_info,
        examples=examples,
        allowed_comparators=allowed_comparators,
        allowed_operators=allowed_operators,
        enable_limit=enable_limit,
        schema_prompt=schema_prompt,
    )
    attribute_names = [
        ainfo.name if isinstance(ainfo, AttributeInfo) else ainfo["name"]
        for ainfo in attribute_info
    ]
    parser = StructuredQueryOutputParser.from_components(
        allowed_comparators=allowed_comparators,
        allowed_operators=allowed_operators,
        allowed_attributes=attribute_names,
    )
    # Kept for backwards compatibility: the parser is also attached to the prompt.
    prompt.output_parser = parser
    return LLMChain(llm=llm, prompt=prompt, output_parser=parser, **kwargs)
def load_query_constructor_runnable(
    llm: BaseLanguageModel,
    document_contents: str,
    attribute_info: Sequence[Union[AttributeInfo, dict]],
    *,
    examples: Optional[Sequence] = None,
    allowed_comparators: Sequence[Comparator] = tuple(Comparator),
    allowed_operators: Sequence[Operator] = tuple(Operator),
    enable_limit: bool = False,
    schema_prompt: Optional[BasePromptTemplate] = None,
    fix_invalid: bool = False,
    **kwargs: Any,
) -> Runnable:
    """Compose ``prompt | llm | output_parser`` for structured-query construction.

    Args:
        llm: BaseLanguageModel to use for the chain.
        document_contents: Description of the page contents of the document to be
            queried.
        attribute_info: Sequence of attributes in the document.
        examples: Optional list of examples to use for the chain.
        allowed_comparators: Sequence of allowed comparators. Defaults to all
            Comparators.
        allowed_operators: Sequence of allowed operators. Defaults to all Operators.
        enable_limit: Whether to enable the limit operator. Defaults to False.
        schema_prompt: Prompt for describing query schema. Should have string input
            variables allowed_comparators and allowed_operators.
        fix_invalid: Whether to fix invalid filter directives by ignoring invalid
            operators, comparators and attributes.
        kwargs: Additional named params to pass to FewShotPromptTemplate init.

    Returns:
        A Runnable that can be used to construct queries.
    """
    prompt = get_query_constructor_prompt(
        document_contents,
        attribute_info,
        examples=examples,
        allowed_comparators=allowed_comparators,
        allowed_operators=allowed_operators,
        enable_limit=enable_limit,
        schema_prompt=schema_prompt,
        **kwargs,
    )
    attribute_names = [
        ainfo.name if isinstance(ainfo, AttributeInfo) else ainfo["name"]
        for ainfo in attribute_info
    ]
    parser = StructuredQueryOutputParser.from_components(
        allowed_comparators=allowed_comparators,
        allowed_operators=allowed_operators,
        allowed_attributes=attribute_names,
        fix_invalid=fix_invalid,
    )
    pipeline = prompt | llm | parser
    return pipeline