Source code for langchain_text_splitters.base

from __future__ import annotations

import copy
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import (
    AbstractSet,
    Any,
    Callable,
    Collection,
    Iterable,
    List,
    Literal,
    Optional,
    Sequence,
    Type,
    TypeVar,
    Union,
)

from langchain_core.documents import BaseDocumentTransformer, Document

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Type variable bound to TextSplitter; used by alternate constructors such as
# ``from_tiktoken_encoder`` (``cls: Type[TS]`` -> ``TS``) so the return type
# follows the subclass the constructor is called on.
TS = TypeVar("TS", bound="TextSplitter")


class TextSplitter(BaseDocumentTransformer, ABC):
    """Interface for splitting text into chunks.

    Subclasses implement :meth:`split_text`. This base class supplies the
    shared configuration, the conversion of split text into ``Document``
    objects, and the ``_merge_splits`` helper that packs small pieces into
    chunks bounded by ``chunk_size`` with ``chunk_overlap`` carried over.
    """

    def __init__(
        self,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        length_function: Callable[[str], int] = len,
        keep_separator: Union[bool, Literal["start", "end"]] = False,
        add_start_index: bool = False,
        strip_whitespace: bool = True,
    ) -> None:
        """Create a new TextSplitter.

        Args:
            chunk_size: Maximum size of chunks to return, measured with
                ``length_function``. Must be greater than 0.
            chunk_overlap: Overlap between consecutive chunks, measured with
                ``length_function`` (characters for the default ``len``).
                Must be non-negative and not larger than ``chunk_size``.
            length_function: Function that measures the length of given chunks
            keep_separator: Whether to keep the separator and where to place it
                in each corresponding chunk (True='start'). Only stored here;
                the methods of this base class do not read it.
            add_start_index: If `True`, includes chunk's start index in metadata
            strip_whitespace: If `True`, strips whitespace from the start and
                end of every document

        Raises:
            ValueError: If ``chunk_size`` is not positive, ``chunk_overlap`` is
                negative, or ``chunk_overlap`` is larger than ``chunk_size``.
        """
        # Reject sizes that can never produce a meaningful chunking instead of
        # letting them surface later as odd merge behaviour.
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be > 0, got {chunk_size}")
        if chunk_overlap < 0:
            raise ValueError(f"chunk_overlap must be >= 0, got {chunk_overlap}")
        if chunk_overlap > chunk_size:
            raise ValueError(
                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
                f"({chunk_size}), should be smaller."
            )
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._length_function = length_function
        self._keep_separator = keep_separator
        self._add_start_index = add_start_index
        self._strip_whitespace = strip_whitespace

    @abstractmethod
    def split_text(self, text: str) -> List[str]:
        """Split text into multiple components.

        Args:
            text: The text to split.

        Returns:
            The pieces of ``text``, in order.
        """

    def create_documents(
        self, texts: List[str], metadatas: Optional[List[dict]] = None
    ) -> List[Document]:
        """Create documents from a list of texts.

        Each text is split with :meth:`split_text` and every resulting chunk
        becomes one ``Document`` carrying a deep copy of that text's metadata.

        Args:
            texts: Texts to split.
            metadatas: Optional metadata dicts, matched to ``texts`` by
                position. When omitted (or empty), every text gets ``{}``.

        Returns:
            One ``Document`` per chunk, in input order.
        """
        _metadatas = metadatas or [{}] * len(texts)
        documents = []
        for i, text in enumerate(texts):
            index = 0
            previous_chunk_len = 0
            for chunk in self.split_text(text):
                # Deep copy so documents never share (or mutate) the caller's dict.
                metadata = copy.deepcopy(_metadatas[i])
                if self._add_start_index:
                    # Resume the search just before the end of the previous
                    # chunk, backing up by the overlap so overlapping chunks
                    # are still found. ``str.find`` yields -1 when the chunk
                    # does not occur verbatim at or after that offset.
                    offset = index + previous_chunk_len - self._chunk_overlap
                    index = text.find(chunk, max(0, offset))
                    metadata["start_index"] = index
                    previous_chunk_len = len(chunk)
                new_doc = Document(page_content=chunk, metadata=metadata)
                documents.append(new_doc)
        return documents

    def split_documents(self, documents: Iterable[Document]) -> List[Document]:
        """Split documents.

        Args:
            documents: Documents whose ``page_content`` is split; each
                document's ``metadata`` is propagated to its chunks.

        Returns:
            The chunk documents produced by :meth:`create_documents`.
        """
        texts, metadatas = [], []
        for doc in documents:
            texts.append(doc.page_content)
            metadatas.append(doc.metadata)
        return self.create_documents(texts, metadatas=metadatas)

    def _join_docs(self, docs: List[str], separator: str) -> Optional[str]:
        """Join pieces with ``separator``; return ``None`` for an empty result.

        The joined text is stripped first when ``strip_whitespace`` is enabled.
        """
        text = separator.join(docs)
        if self._strip_whitespace:
            text = text.strip()
        if text == "":
            return None
        return text

    def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
        """Combine small pieces into chunks of at most ``chunk_size``.

        Pieces are accumulated until adding the next one (plus a separator)
        would exceed ``chunk_size``; the accumulated pieces are then emitted as
        one chunk and leading pieces are dropped until the remainder fits in
        ``chunk_overlap`` and leaves room for the next piece.

        Args:
            splits: Pieces to merge, in order.
            separator: String placed between pieces inside a chunk.

        Returns:
            The merged chunks; joins that come out empty are skipped.
        """
        # We now want to combine these smaller pieces into medium size
        # chunks to send to the LLM.
        separator_len = self._length_function(separator)

        docs = []
        current_doc: List[str] = []
        total = 0
        for d in splits:
            _len = self._length_function(d)
            if (
                total + _len + (separator_len if len(current_doc) > 0 else 0)
                > self._chunk_size
            ):
                if total > self._chunk_size:
                    logger.warning(
                        "Created a chunk of size %s, "
                        "which is longer than the specified %s",
                        total,
                        self._chunk_size,
                    )
                if len(current_doc) > 0:
                    doc = self._join_docs(current_doc, separator)
                    if doc is not None:
                        docs.append(doc)
                    # Keep on popping if:
                    # - we have a larger chunk than in the chunk overlap
                    # - or if we still have any chunks and the length is long
                    while total > self._chunk_overlap or (
                        total + _len + (separator_len if len(current_doc) > 0 else 0)
                        > self._chunk_size
                        and total > 0
                    ):
                        total -= self._length_function(current_doc[0]) + (
                            separator_len if len(current_doc) > 1 else 0
                        )
                        current_doc = current_doc[1:]
            current_doc.append(d)
            total += _len + (separator_len if len(current_doc) > 1 else 0)
        doc = self._join_docs(current_doc, separator)
        if doc is not None:
            docs.append(doc)
        return docs

    @classmethod
    def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitter:
        """Text splitter that uses HuggingFace tokenizer to count length.

        Args:
            tokenizer: A ``transformers.PreTrainedTokenizerBase`` instance.
            **kwargs: Forwarded to the class constructor.

        Returns:
            A splitter whose length function is the number of ids returned by
            ``tokenizer.encode``.

        Raises:
            ValueError: If ``transformers`` is not installed or ``tokenizer``
                is not a ``PreTrainedTokenizerBase``.
        """
        try:
            from transformers import PreTrainedTokenizerBase
        except ImportError as err:
            # ValueError (not ImportError) is kept for backward compatibility;
            # the original import failure is preserved as the cause.
            raise ValueError(
                "Could not import transformers python package. "
                "Please install it with `pip install transformers`."
            ) from err

        if not isinstance(tokenizer, PreTrainedTokenizerBase):
            raise ValueError(
                "Tokenizer received was not an instance of PreTrainedTokenizerBase"
            )

        def _huggingface_tokenizer_length(text: str) -> int:
            return len(tokenizer.encode(text))

        return cls(length_function=_huggingface_tokenizer_length, **kwargs)

    @classmethod
    def from_tiktoken_encoder(
        cls: Type[TS],
        encoding_name: str = "gpt2",
        model_name: Optional[str] = None,
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> TS:
        """Text splitter that uses tiktoken encoder to count length.

        Args:
            encoding_name: tiktoken encoding to use when ``model_name`` is not
                given.
            model_name: If set, the encoding is looked up for this model and
                ``encoding_name`` is ignored.
            allowed_special: Passed to ``enc.encode`` as ``allowed_special``.
                The default set is never mutated here.
            disallowed_special: Passed to ``enc.encode`` as
                ``disallowed_special``.
            **kwargs: Forwarded to the class constructor.

        Returns:
            An instance of ``cls`` whose length function counts tiktoken
            tokens.

        Raises:
            ImportError: If ``tiktoken`` is not installed.
        """
        try:
            import tiktoken
        except ImportError as err:
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed in order to calculate max_tokens_for_prompt. "
                "Please install it with `pip install tiktoken`."
            ) from err

        if model_name is not None:
            enc = tiktoken.encoding_for_model(model_name)
        else:
            enc = tiktoken.get_encoding(encoding_name)

        def _tiktoken_encoder(text: str) -> int:
            return len(
                enc.encode(
                    text,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )

        # TokenTextSplitter builds its own encoder, so it must receive the same
        # encoding settings that the length function above uses.
        if issubclass(cls, TokenTextSplitter):
            extra_kwargs = {
                "encoding_name": encoding_name,
                "model_name": model_name,
                "allowed_special": allowed_special,
                "disallowed_special": disallowed_special,
            }
            kwargs = {**kwargs, **extra_kwargs}

        return cls(length_function=_tiktoken_encoder, **kwargs)

    def transform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Transform sequence of documents by splitting them.

        Args:
            documents: Documents to split.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            The chunk documents produced by :meth:`split_documents`.
        """
        return self.split_documents(list(documents))
class TokenTextSplitter(TextSplitter):
    """Splitting text to tokens using model tokenizer.

    The text is encoded with a tiktoken encoding, cut into windows of
    ``chunk_size`` token ids that overlap by ``chunk_overlap`` ids, and each
    window is decoded back into a string.
    """

    def __init__(
        self,
        encoding_name: str = "gpt2",
        model_name: Optional[str] = None,
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter.

        Args:
            encoding_name: tiktoken encoding to use when ``model_name`` is not
                given.
            model_name: If set, the encoding is looked up for this model and
                ``encoding_name`` is ignored.
            allowed_special: Passed to the encoder's ``encode`` call as
                ``allowed_special``. The default set is only stored and
                forwarded, never mutated here.
            disallowed_special: Passed to the encoder's ``encode`` call as
                ``disallowed_special``.
            **kwargs: Forwarded to ``TextSplitter.__init__``.

        Raises:
            ImportError: If ``tiktoken`` is not installed.
        """
        super().__init__(**kwargs)
        try:
            import tiktoken
        except ImportError as err:
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed for TokenTextSplitter. "
                "Please install it with `pip install tiktoken`."
            ) from err

        if model_name is not None:
            enc = tiktoken.encoding_for_model(model_name)
        else:
            enc = tiktoken.get_encoding(encoding_name)
        self._tokenizer = enc
        self._allowed_special = allowed_special
        self._disallowed_special = disallowed_special

    def split_text(self, text: str) -> List[str]:
        """Split ``text`` into chunks of at most ``chunk_size`` tokens.

        Args:
            text: The text to split.

        Returns:
            The decoded token windows produced by ``split_text_on_tokens``.
        """

        def _encode(_text: str) -> List[int]:
            return self._tokenizer.encode(
                _text,
                allowed_special=self._allowed_special,
                disallowed_special=self._disallowed_special,
            )

        tokenizer = Tokenizer(
            chunk_overlap=self._chunk_overlap,
            tokens_per_chunk=self._chunk_size,
            decode=self._tokenizer.decode,
            encode=_encode,
        )

        return split_text_on_tokens(text=text, tokenizer=tokenizer)
class Language(str, Enum):
    """Identifiers for the supported source and markup languages.

    Every member is also a ``str`` equal to its value, so members compare
    equal to the plain lowercase name (for example ``Language.PYTHON == "python"``).
    """

    CPP = "cpp"
    GO = "go"
    JAVA = "java"
    KOTLIN = "kotlin"
    JS = "js"
    TS = "ts"
    PHP = "php"
    PROTO = "proto"
    PYTHON = "python"
    RST = "rst"
    RUBY = "ruby"
    RUST = "rust"
    SCALA = "scala"
    SWIFT = "swift"
    MARKDOWN = "markdown"
    LATEX = "latex"
    HTML = "html"
    SOL = "sol"
    CSHARP = "csharp"
    COBOL = "cobol"
    C = "c"
    LUA = "lua"
    PERL = "perl"
    HASKELL = "haskell"
    ELIXIR = "elixir"
    POWERSHELL = "powershell"
@dataclass(frozen=True)
class Tokenizer:
    """Immutable bundle of the settings and callables used for token splitting."""

    chunk_overlap: int
    """Number of tokens shared between consecutive chunks."""
    tokens_per_chunk: int
    """Upper bound on the number of tokens in one chunk."""
    decode: Callable[[List[int]], str]
    """Callable turning a list of token ids back into a string."""
    encode: Callable[[str], List[int]]
    """Callable turning a string into a list of token ids."""
def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> List[str]:
    """Split incoming text and return chunks using tokenizer.

    The text is encoded once, then a window of ``tokens_per_chunk`` ids is
    moved over the ids in steps of ``tokens_per_chunk - chunk_overlap``; each
    window is decoded into one chunk.

    Args:
        text: The text to split.
        tokenizer: Provides ``encode``, ``decode``, ``tokens_per_chunk`` and
            ``chunk_overlap``.

    Returns:
        The decoded chunks, in order. Empty when ``encode`` yields no ids.

    Raises:
        ValueError: If the text needs more than one chunk but
            ``chunk_overlap`` is not smaller than ``tokens_per_chunk``, since
            the window could then never advance.
    """
    input_ids = tokenizer.encode(text)
    total = len(input_ids)
    window = tokenizer.tokens_per_chunk
    stride = window - tokenizer.chunk_overlap

    # With a non-positive stride the window cannot move forward. That is
    # harmless when everything fits in the first window, but otherwise the
    # loop below would never terminate, so fail loudly instead.
    if stride <= 0 and total > window:
        raise ValueError(
            f"chunk_overlap ({tokenizer.chunk_overlap}) must be smaller than "
            f"tokens_per_chunk ({window}) to split text longer than one chunk."
        )

    splits: List[str] = []
    start_idx = 0
    while start_idx < total:
        cur_idx = min(start_idx + window, total)
        splits.append(tokenizer.decode(input_ids[start_idx:cur_idx]))
        if cur_idx == total:
            # The window reached the end of the ids; nothing left to emit.
            break
        start_idx += stride
    return splits