Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 27 additions & 23 deletions src/classifai/indexers/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import shutil
import time
import uuid
from typing import Literal

import numpy as np
import polars as pl
Expand Down Expand Up @@ -69,12 +70,12 @@ class VectorStore:
"""A class to model and create `VectorStore` objects for building and searching vector databases from CSV text files.

Attributes:
file_name (str): the data file contatining the knowledgebase to build the `VectorStore`
data_type (str): the data type of the data file (curently only csv supported)
vectoriser (VectoriserBase): A `Vectoriser` object from the corresponding ClassifAI Pacakge module
        file_name (str | os.PathLike[str]): the data file containing the knowledgebase to build the `VectorStore`
        data_type (Literal["csv"]): the data type of the data file (currently only csv supported)
vectoriser (VectoriserBase): A `Vectoriser` object from the corresponding ClassifAI Package module
batch_size (int): the batch size to pass to the vectoriser when embedding
        meta_data (dict): key-value pairs of metadata to extract from the input file and their corresponding types
output_dir (str): the path to the output directory where the `VectorStore` will be saved
output_dir (str | os.PathLike[str]): the path to the output directory where the `VectorStore` will be saved
vectors (np.array): a numpy array of vectors for the vector database
vector_shape (int): the dimension of the vectors
num_vectors (int): the number of records saved in the `VectorStore`
Expand All @@ -84,22 +85,22 @@ class VectorStore:

def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915
self,
file_name: str,
data_type: str,
file_name: str | os.PathLike[str],
data_type: Literal["csv"],
vectoriser: VectoriserBase,
batch_size: int = 8,
meta_data: dict | None = None,
output_dir: str | None = None,
output_dir: str | os.PathLike[str] | None = None,
overwrite: bool = False,
hooks: dict | None = None,
):
"""Initializes the `VectorStore` object by processing the input CSV file and generating
vector embeddings.

Args:
file_name (str): The name of the input CSV file.
            file_name (str | os.PathLike): The name of the input CSV file.
            data_type (Literal["csv"]): The type of input data (currently supports only "csv").
vectoriser (object): The `Vectoriser` object used to transform text into
vectoriser (VectoriserBase): The `Vectoriser` object used to transform text into
vector embeddings.
batch_size (int): [optional] The batch size for processing the input file and batching to
vectoriser. Defaults to 8.
Expand All @@ -119,8 +120,10 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915
`IndexBuildError`: If there are failures during index building or saving outputs.
"""
# ---- Input validation (caller mistakes) -> DataValidationError / ConfigurationError
if not isinstance(file_name, str) or not file_name.strip():
raise DataValidationError("file_name must be a non-empty string.", context={"file_name": file_name})
if not isinstance(file_name, (str, os.PathLike)) or not os.fspath(file_name).strip():
raise DataValidationError(
"file_name must be a non-empty string or os.PathLike.", context={"file_name": file_name}
)

if not os.path.exists(file_name):
raise DataValidationError("Input file does not exist.", context={"file_name": file_name})
Expand Down Expand Up @@ -149,17 +152,15 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915
raise DataValidationError("hooks must be a dict or None.", context={"hooks_type": type(hooks).__name__})

# ---- Assign fields
        ## all these fields are initialised from inputs
self.file_name = file_name
self.data_type = data_type
self.vectoriser = vectoriser
self.batch_size = batch_size
self.meta_data = meta_data if meta_data is not None else {}
self.output_dir = output_dir
self.vectors = None
self.vector_shape = None
self.num_vectors = None
self.vectoriser_class = vectoriser.__class__.__name__
self.hooks = {} if hooks is None else hooks
self.vectoriser_class = vectoriser.__class__.__name__

# ---- Output directory handling (filesystem problems) -> ConfigurationError
try:
Expand All @@ -185,7 +186,7 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915

# ---- Build index (wrap every unexpected failure) -> IndexBuildError
try:
self._create_vector_store_index()
self._create_vector_store_index(os.fspath(self.file_name))
except ClassifaiError:
# preserve already-classified errors (e.g. vectoriser raised DataValidationError)
raise
Expand Down Expand Up @@ -260,14 +261,17 @@ def _save_metadata(self, path: str):
context={"path": path, "metadata": metadata, "cause_type": type(e).__name__, "cause_message": str(e)},
) from e

def _create_vector_store_index(self): # noqa: C901
def _create_vector_store_index(self, file_name: str): # noqa: C901
"""Processes text strings in batches, generates vector embeddings, and creates the
`VectorStore`.
Called from the constructor once other metadata has been set.
Iterates over data in batches, stores batch data and generated embeddings.
Creates a Polars DataFrame with the captured data and embeddings, and saves it as
a Parquet file in the output_dir attribute, and stores in the vectors attribute.

Args:
            file_name (str): The filename of the CSV file to read in.

Raises:
`DataValidationError`: If there are issues reading or validating the input file.
`IndexBuildError`: If there are failures during embedding or building the vectors table.
Expand All @@ -276,9 +280,9 @@ def _create_vector_store_index(self): # noqa: C901
try:
if self.data_type == "csv":
self.vectors = pl.read_csv(
self.file_name,
file_name,
columns=["label", "text", *self.meta_data.keys()],
dtypes=self.meta_data | {"label": str, "text": str},
schema_overrides=self.meta_data | {"label": str, "text": str},
)
self.vectors = self.vectors.with_columns(
pl.Series("uuid", [str(uuid.uuid4()) for _ in range(self.vectors.height)])
Expand Down Expand Up @@ -740,7 +744,7 @@ def search(self, query: VectorStoreSearchInput, n_results=10, batch_size=8) -> V
return result_df

@classmethod
def from_filespace(cls, folder_path, vectoriser, hooks: dict | None = None): # noqa: C901, PLR0912, PLR0915
def from_filespace(cls, folder_path: str | os.PathLike[str], vectoriser: VectoriserBase, hooks: dict | None = None): # noqa: C901, PLR0912, PLR0915
"""Creates a `VectorStore` instance from stored metadata and Parquet files.
This method reads the metadata and vectors from the specified folder,
validates the contents, and initializes a `VectorStore` object with the
Expand All @@ -752,8 +756,8 @@ def from_filespace(cls, folder_path, vectoriser, hooks: dict | None = None): #
needing to reprocess the original text data.

Args:
folder_path (str): The folder path containing the metadata and Parquet files.
vectoriser (object): The `Vectoriser` object used to transform text into vector embeddings.
folder_path (str | os.PathLike): The folder path containing the metadata and Parquet files.
vectoriser (VectoriserBase): The `Vectoriser` object used to transform text into vector embeddings.
hooks (dict): [optional] A dictionary of user-defined hooks for preprocessing and postprocessing. Defaults to None.

Returns:
Expand All @@ -765,7 +769,7 @@ def from_filespace(cls, folder_path, vectoriser, hooks: dict | None = None): #
`IndexBuildError`: If there are failures during loading or parsing the files.
"""
# ---- Validate arguments (caller mistakes) -> DataValidationError / ConfigurationError
if not isinstance(folder_path, str) or not folder_path.strip():
if not isinstance(folder_path, (str, os.PathLike)) or not os.fspath(folder_path).strip():
            raise DataValidationError("folder_path must be a non-empty string or os.PathLike.", context={"folder_path": folder_path})

if not os.path.isdir(folder_path):
Expand Down
11 changes: 4 additions & 7 deletions src/classifai/servers/pydantic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"""Pydantic Classes to model request and response data for ClassifAI FastAPI RESTful API."""

import pandas as pd
from pydantic import BaseModel, Extra, Field
from pydantic import BaseModel, ConfigDict, Field


class SearchRequestEntry(BaseModel):
Expand Down Expand Up @@ -33,8 +33,7 @@ class SearchResponseEntry(BaseModel):
rank: int = Field(description="The rank of the result entry for the given query, with 1 being the most relevant.")
score: float = Field(description="The similarity score of the result entry for the given query.")

class Config:
extra = Extra.allow # Allow extra keys (e.g., metadata columns)å
model_config = ConfigDict(extra="allow")


class SearchResponseSet(BaseModel):
Expand Down Expand Up @@ -81,8 +80,7 @@ class ReverseSearchResponseEntry(BaseModel):
doc_label: str
doc_text: str

class Config:
extra = Extra.allow # Allow extra keys (e.g., metadata columns)
model_config = ConfigDict(extra="allow")


class ReverseSearchResponseSet(BaseModel):
Expand Down Expand Up @@ -135,8 +133,7 @@ class EmbedResponseEntry(BaseModel):
description="The vector embedding result for the input text string, represented as a list of floats."
)

class Config:
extra = Extra.allow # Allow extra keys (e.g., metadata columns)
model_config = ConfigDict(extra="allow")


class EmbedResponseBody(BaseModel):
Expand Down
Loading