diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index 10e2d61b..d30073ab 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -89,7 +89,7 @@ jobs: POSTGRES_USER: root POSTGRES_PASSWORD: rootpassword POSTGRES_DB: nextcloud - options: --health-cmd pg_isready --health-interval 5s --health-timeout 2s --health-retries 5 + options: --health-cmd pg_isready --health-interval 5s --health-timeout 2s --health-retries 5 --name postgres --hostname postgres steps: - name: Checkout server @@ -113,6 +113,8 @@ jobs: repository: nextcloud/context_chat path: apps/context_chat persist-credentials: false + # todo: remove later + ref: feat/reverse-content-flow - name: Checkout backend uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 @@ -167,6 +169,10 @@ jobs: cd .. rm -rf documentation + - name: Run files scan + run: | + ./occ files:scan --all + - name: Setup python 3.11 uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5 with: @@ -195,28 +201,91 @@ jobs: timeout 10 ./occ app_api:daemon:register --net host manual_install "Manual Install" manual-install http localhost http://localhost:8080 timeout 120 ./occ app_api:app:register context_chat_backend manual_install --json-info "{\"appid\":\"context_chat_backend\",\"name\":\"Context Chat Backend\",\"daemon_config_name\":\"manual_install\",\"version\":\"${{ fromJson(steps.appinfo.outputs.result).version }}\",\"secret\":\"12345\",\"port\":10034,\"scopes\":[],\"system_app\":0}" --force-scopes --wait-finish ls -la context_chat_backend/persistent_storage/* - sleep 30 # Wait for the em server to get ready - - name: Scan files, baseline - run: | - ./occ files:scan admin - ./occ context_chat:scan admin -m text/plain - - - name: Check python memory usage + - name: Initial memory usage check run: | ps -p $(cat pid.txt) -o pid,cmd,%mem,rss --sort=-%mem ps -p $(cat pid.txt) -o %mem --no-headers > initial_mem.txt - - name: 
Scan files + - name: Run cron jobs run: | - ./occ files:scan admin - ./occ context_chat:scan admin -m text/markdown & - ./occ context_chat:scan admin -m text/x-rst + # every 10 seconds indefinitely + while true; do + php cron.php + sleep 10 + done & + sleep 30 + # list all the bg jobs + ./occ background-job:list + + - name: Initial dump of DB with context_chat_queue populated + run: | + docker exec postgres pg_dump nextcloud > /tmp/0_pgdump_nextcloud - - name: Check python memory usage + - name: Periodically check context_chat stats for 15 minutes to allow the backend to index the files run: | - ps -p $(cat pid.txt) -o pid,cmd,%mem,rss --sort=-%mem - ps -p $(cat pid.txt) -o %mem --no-headers > after_scan_mem.txt + success=0 + echo "::group::Checking stats periodically for 15 minutes to allow the backend to index the files" + for i in {1..90}; do + echo "Checking stats, attempt $i..." + + stats_err=$(mktemp) + stats=$(timeout 5 ./occ context_chat:stats --json 2>"$stats_err") + stats_exit=$? + echo "Stats output:" + echo "$stats" + if [ -s "$stats_err" ]; then + echo "Stderr:" + cat "$stats_err" + fi + echo "---" + rm -f "$stats_err" + + # Check for critical errors in output + if [ $stats_exit -ne 0 ] || echo "$stats" | grep -q "Error during request"; then + echo "Backend connection error detected (exit=$stats_exit), retrying..." 
+ sleep 10 + continue + fi + + # Extract total eligible files + total_eligible_files=$(echo "$stats" | jq '.eligible_files_count' || echo "") + + # Extract indexed documents count (files__default) + indexed_count=$(echo "$stats" | jq '.vectordb_document_counts.files__default' || echo "") + + echo "Total eligible files: $total_eligible_files" + echo "Indexed documents (files__default): $indexed_count" + + diff=$((total_eligible_files - indexed_count)) + threshold=$((total_eligible_files * 3 / 100)) + + # Check if difference is within tolerance + if [ $diff -le $threshold ]; then + echo "Indexing within 3% tolerance (diff=$diff, threshold=$threshold)" + success=1 + break + else + progress=$((diff * 100 / total_eligible_files)) + echo "Outside 3% tolerance: diff=$diff (${progress}%), threshold=$threshold" + fi + + # Check if backend is still alive + ccb_alive=$(ps -p $(cat pid.txt) -o cmd= | grep -c "main.py" || echo "0") + if [ "$ccb_alive" -eq 0 ]; then + echo "Error: Context Chat Backend process is not running. Exiting." + exit 1 + fi + + sleep 10 + done + + echo "::endgroup::" + + if [ $success -ne 1 ]; then + echo "Max attempts reached" + exit 1 + fi - name: Run the prompts run: | @@ -250,18 +319,9 @@ jobs: echo "Memory usage during scan is stable. No memory leak detected." fi - - name: Compare memory usage and detect leak + - name: Final dump of DB with vectordb populated run: | - initial_mem=$(cat after_scan_mem.txt | tr -d ' ') - final_mem=$(cat after_prompt_mem.txt | tr -d ' ') - echo "Initial Memory Usage: $initial_mem%" - echo "Memory Usage after prompt: $final_mem%" - - if (( $(echo "$final_mem > $initial_mem" | bc -l) )); then - echo "Memory usage has increased during prompt. Possible memory leak detected!" - else - echo "Memory usage during prompt is stable. No memory leak detected." 
- fi + docker exec postgres pg_dump nextcloud > /tmp/1_pgdump_nextcloud - name: Show server logs if: always() @@ -298,6 +358,19 @@ jobs: run: | tail -v -n +1 context_chat_backend/persistent_storage/logs/em_server.log* || echo "No logs in logs directory" + - name: Upload database dumps + uses: actions/upload-artifact@v4 + with: + name: database-dumps-${{ matrix.server-versions }}-php@${{ matrix.php-versions }} + path: | + /tmp/0_pgdump_nextcloud + /tmp/1_pgdump_nextcloud + + - name: Final stats log + run: | + ./occ context_chat:stats + ./occ context_chat:stats --json + summary: permissions: contents: none diff --git a/appinfo/info.xml b/appinfo/info.xml index 9760cd29..30194baa 100644 --- a/appinfo/info.xml +++ b/appinfo/info.xml @@ -82,5 +82,19 @@ Setup background job workers as described here: https://docs.nextcloud.com/serve Password to be used for authenticating requests to the OpenAI-compatible endpoint set in CC_EM_BASE_URL. + + + rp + Request Processing Mode + APP_ROLE=rp + true + + + indexing + Indexing Mode + APP_ROLE=indexing + false + + diff --git a/context_chat_backend/chain/ingest/doc_loader.py b/context_chat_backend/chain/ingest/doc_loader.py index efb81b6d..832c8331 100644 --- a/context_chat_backend/chain/ingest/doc_loader.py +++ b/context_chat_backend/chain/ingest/doc_loader.py @@ -3,15 +3,13 @@ # SPDX-License-Identifier: AGPL-3.0-or-later # -import logging import re import tempfile from collections.abc import Callable -from typing import BinaryIO +from io import BytesIO import docx2txt from epub2txt import epub2txt -from fastapi import UploadFile from langchain_unstructured import UnstructuredLoader from odfdo import Document from pandas import read_csv, read_excel @@ -19,9 +17,10 @@ from pypdf.errors import FileNotDecryptedError as PdfFileNotDecryptedError from striprtf import striprtf -logger = logging.getLogger('ccb.doc_loader') +from ...types import IndexingException, SourceItem -def _temp_file_wrapper(file: BinaryIO, loader: Callable, sep: str 
= '\n') -> str: + +def _temp_file_wrapper(file: BytesIO, loader: Callable, sep: str = '\n') -> str: raw_bytes = file.read() with tempfile.NamedTemporaryFile(mode='wb') as tmp: tmp.write(raw_bytes) @@ -35,49 +34,49 @@ def _temp_file_wrapper(file: BinaryIO, loader: Callable, sep: str = '\n') -> str # -- LOADERS -- # -def _load_pdf(file: BinaryIO) -> str: +def _load_pdf(file: BytesIO) -> str: pdf_reader = PdfReader(file) return '\n\n'.join([page.extract_text().strip() for page in pdf_reader.pages]) -def _load_csv(file: BinaryIO) -> str: +def _load_csv(file: BytesIO) -> str: return read_csv(file).to_string(header=False, na_rep='') -def _load_epub(file: BinaryIO) -> str: +def _load_epub(file: BytesIO) -> str: return _temp_file_wrapper(file, epub2txt).strip() -def _load_docx(file: BinaryIO) -> str: +def _load_docx(file: BytesIO) -> str: return docx2txt.process(file).strip() -def _load_odt(file: BinaryIO) -> str: +def _load_odt(file: BytesIO) -> str: return _temp_file_wrapper(file, lambda fp: Document(fp).get_formatted_text()).strip() -def _load_ppt_x(file: BinaryIO) -> str: +def _load_ppt_x(file: BytesIO) -> str: return _temp_file_wrapper(file, lambda fp: UnstructuredLoader(fp).load()).strip() -def _load_rtf(file: BinaryIO) -> str: +def _load_rtf(file: BytesIO) -> str: return striprtf.rtf_to_text(file.read().decode('utf-8', 'ignore')).strip() -def _load_xml(file: BinaryIO) -> str: +def _load_xml(file: BytesIO) -> str: data = file.read().decode('utf-8', 'ignore') data = re.sub(r'', '', data) return data.strip() -def _load_xlsx(file: BinaryIO) -> str: +def _load_xlsx(file: BytesIO) -> str: return read_excel(file, na_filter=False).to_string(header=False, na_rep='') -def _load_email(file: BinaryIO, ext: str = 'eml') -> str | None: +def _load_email(file: BytesIO, ext: str = 'eml') -> str: # NOTE: msg format is not tested if ext not in ['eml', 'msg']: - return None + raise IndexingException(f'Unsupported email format: {ext}') # TODO: implement attachment partitioner using 
unstructured.partition.partition_{email,msg} # since langchain does not pass through the attachment_partitioner kwarg @@ -115,30 +114,36 @@ def attachment_partitioner( } -def decode_source(source: UploadFile) -> str | None: +def decode_source(source: SourceItem) -> str: + ''' + Raises + ------ + IndexingException + ''' + + io_obj: BytesIO | None = None try: # .pot files are powerpoint templates but also plain text files, # so we skip them to prevent decoding errors - if source.headers['title'].endswith('.pot'): - return None - - mimetype = source.headers['type'] - if mimetype is None: - return None - - if _loader_map.get(mimetype): - result = _loader_map[mimetype](source.file) - source.file.close() - return result.encode('utf-8', 'ignore').decode('utf-8', 'ignore') - - result = source.file.read().decode('utf-8', 'ignore') - source.file.close() - return result - except PdfFileNotDecryptedError: - logger.warning(f'PDF file ({source.filename}) is encrypted and cannot be read') - return None - except Exception: - logger.exception(f'Error decoding source file ({source.filename})', stack_info=True) - return None + if source.title.endswith('.pot'): + raise IndexingException('PowerPoint template files (.pot) are not supported') + + if isinstance(source.content, str): + io_obj = BytesIO(source.content.encode('utf-8', 'ignore')) + else: + io_obj = source.content + + if _loader_map.get(source.type): + result = _loader_map[source.type](io_obj) + return result.encode('utf-8', 'ignore').decode('utf-8', 'ignore').strip() + + return io_obj.read().decode('utf-8', 'ignore').strip() + except IndexingException: + raise + except PdfFileNotDecryptedError as e: + raise IndexingException('PDF file is encrypted and cannot be read') from e + except Exception as e: + raise IndexingException(f'Error decoding source file: {e}') from e finally: - source.file.close() # Ensure file is closed after processing + if io_obj is not None: + io_obj.close() diff --git 
a/context_chat_backend/chain/ingest/injest.py b/context_chat_backend/chain/ingest/injest.py index 5871ebb8..190eebd4 100644 --- a/context_chat_backend/chain/ingest/injest.py +++ b/context_chat_backend/chain/ingest/injest.py @@ -2,65 +2,241 @@ # SPDX-FileCopyrightText: 2023 Nextcloud GmbH and Nextcloud contributors # SPDX-License-Identifier: AGPL-3.0-or-later # +import asyncio import logging import re +from collections.abc import Mapping +from io import BytesIO +from time import perf_counter_ns -from fastapi.datastructures import UploadFile +import niquests from langchain.schema import Document +from nc_py_api import AsyncNextcloudApp from ...dyn_loader import VectorDBLoader -from ...types import TConfig -from ...utils import is_valid_source_id, to_int +from ...types import IndexingError, IndexingException, ReceivedFileItem, SourceItem, TConfig from ...vectordb.base import BaseVectorDB from ...vectordb.types import DbException, SafeDbException, UpdateAccessOp from ..types import InDocument from .doc_loader import decode_source from .doc_splitter import get_splitter_for -from .mimetype_list import SUPPORTED_MIMETYPES logger = logging.getLogger('ccb.injest') -def _allowed_file(file: UploadFile) -> bool: - return file.headers['type'] in SUPPORTED_MIMETYPES +# max concurrent fetches to avoid overloading the NC server or hitting rate limits +CONCURRENT_FILE_FETCHES = 10 # todo: config? +MAX_FILE_SIZE = 100 * 1024 * 1024 # 100 MB, all loaded in RAM at once, todo: config? 
+ + +async def __fetch_file_content( + semaphore: asyncio.Semaphore, + file_id: int, + user_id: str, + _rlimit = 3, +) -> BytesIO: + ''' + Raises + ------ + IndexingException + ''' + + async with semaphore: + nc = AsyncNextcloudApp() + try: + # a file pointer for storing the stream in memory until it is consumed + fp = BytesIO() + await nc._session.download2fp( + url_path=f'/ocs/v2.php/apps/context_chat/files/{file_id}', + fp=fp, + dav=False, + params={ 'userId': user_id }, + ) + fp.seek(0) + return fp + except niquests.exceptions.RequestException as e: + if e.response is None: + raise + + if e.response.status_code == niquests.codes.too_many_requests: # pyright: ignore[reportAttributeAccessIssue] + # todo: implement rate limits in php CC? + wait_for = int(e.response.headers.get('Retry-After', '30')) + if _rlimit <= 0: + raise IndexingException( + f'Rate limited when fetching content for file id {file_id}, user id {user_id},' + ' max retries exceeded', + retryable=True, + ) from e + logger.warning( + f'Rate limited when fetching content for file id {file_id}, user id {user_id},' + f' waiting {wait_for} before retrying', + exc_info=e, + ) + await asyncio.sleep(wait_for) + return await __fetch_file_content(semaphore, file_id, user_id, _rlimit - 1) + + raise + except IndexingException: + raise + except Exception as e: + logger.error(f'Error fetching content for file id {file_id}, user id {user_id}: {e}', exc_info=e) + raise IndexingException(f'Error fetching content for file id {file_id}, user id {user_id}: {e}') from e + + +async def __fetch_files_content( + sources: Mapping[int, SourceItem | ReceivedFileItem] +) -> tuple[Mapping[int, SourceItem], Mapping[int, IndexingError]]: + source_items = {} + error_items = {} + semaphore = asyncio.Semaphore(CONCURRENT_FILE_FETCHES) + tasks = [] + task_sources = {} + + file_count = sum(1 for s in sources.values() if isinstance(s, ReceivedFileItem)) + logger.debug('Fetching content for %d file(s) (max %d concurrent)', file_count, 
CONCURRENT_FILE_FETCHES) + + for db_id, file in sources.items(): + if isinstance(file, SourceItem): + continue + + try: + # to detect any validation errors but it should not happen since file.reference is validated + file.file_id # noqa: B018 + except ValueError as e: + logger.error( + f'Invalid file reference format for db id {db_id}, file reference {file.reference}: {e}', + exc_info=e, + ) + error_items[db_id] = IndexingError( + error=f'Invalid file reference format: {file.reference}', + retryable=False, + ) + continue + + if file.size > MAX_FILE_SIZE: + logger.info( + f'Skipping db id {db_id}, file id {file.file_id}, source id {file.reference} due to size' + f' {(file.size/(1024*1024)):.2f} MiB exceeding the limit {(MAX_FILE_SIZE/(1024*1024)):.2f} MiB', + ) + error_items[db_id] = IndexingError( + error=( + f'File size {(file.size/(1024*1024)):.2f} MiB' + f' exceeds the limit {(MAX_FILE_SIZE/(1024*1024)):.2f} MiB' + ), + retryable=False, + ) + continue + # any user id from the list should have read access to the file + tasks.append(asyncio.ensure_future(__fetch_file_content(semaphore, file.file_id, file.userIds[0]))) + task_sources[db_id] = file + + results = await asyncio.gather(*tasks, return_exceptions=True) + for (db_id, file), result in zip(task_sources.items(), results, strict=True): + if isinstance(result, str) or isinstance(result, BytesIO): + source_items[db_id] = SourceItem( + **{ + **file.model_dump(), + 'content': result, + } + ) + elif isinstance(result, IndexingException): + logger.error( + f'Error fetching content for db id {db_id}, file id {file.file_id}, reference {file.reference}' + f': {result}', + exc_info=result, + ) + error_items[db_id] = IndexingError( + error=str(result), + retryable=result.retryable, + ) + elif isinstance(result, BaseException): + logger.error( + f'Unexpected error fetching content for db id {db_id}, file id {file.file_id},' + f' reference {file.reference}: {result}', + exc_info=result, + ) + error_items[db_id] = 
IndexingError( + error=f'Unexpected error: {result}', + retryable=True, + ) + else: + logger.error( + f'Unknown error fetching content for db id {db_id}, file id {file.file_id}, reference {file.reference}' + f': {result}', + exc_info=True, + ) + error_items[db_id] = IndexingError( + error='Unknown error', + retryable=True, + ) + + # add the content providers from the original "sources" to the result unprocessed + for db_id, source in sources.items(): + if isinstance(source, SourceItem): + source_items[db_id] = source + + return source_items, error_items def _filter_sources( vectordb: BaseVectorDB, - sources: list[UploadFile] -) -> tuple[list[UploadFile], list[UploadFile]]: + sources: Mapping[int, SourceItem | ReceivedFileItem] +) -> tuple[Mapping[int, SourceItem | ReceivedFileItem], Mapping[int, SourceItem | ReceivedFileItem]]: ''' Returns ------- - tuple[list[str], list[UploadFile]] + tuple[Mapping[int, SourceItem | ReceivedFileItem], Mapping[int, SourceItem | ReceivedFileItem]]: First value is a list of sources that already exist in the vectordb. Second value is a list of sources that are new and should be embedded. 
''' try: - existing_sources, new_sources = vectordb.check_sources(sources) + existing_source_ids, to_embed_source_ids = vectordb.check_sources(sources) except Exception as e: - raise DbException('Error: Vectordb sources_to_embed error') from e + raise DbException('Error: Vectordb error while checking existing sources in indexing') from e + + existing_sources = {} + to_embed_sources = {} + + for db_id, source in sources.items(): + if source.reference in existing_source_ids: + existing_sources[db_id] = source + elif source.reference in to_embed_source_ids: + to_embed_sources[db_id] = source - return ([ - source for source in sources - if source.filename in existing_sources - ], [ - source for source in sources - if source.filename in new_sources - ]) + return existing_sources, to_embed_sources -def _sources_to_indocuments(config: TConfig, sources: list[UploadFile]) -> list[InDocument]: - indocuments = [] +def _sources_to_indocuments( + config: TConfig, + sources: Mapping[int, SourceItem] +) -> tuple[Mapping[int, InDocument], Mapping[int, IndexingError]]: + indocuments = {} + errored_docs = {} - for source in sources: - logger.debug('processing source', extra={ 'source_id': source.filename }) + for db_id, source in sources.items(): + logger.debug('processing source', extra={ 'source_id': source.reference }) # transform the source to have text data - content = decode_source(source) + try: + logger.debug('Decoding source %s (type: %s)', source.reference, source.type) + t0 = perf_counter_ns() + content = decode_source(source) + elapsed_ms = (perf_counter_ns() - t0) / 1e6 + logger.debug('Decoded source %s in %.2f ms (%d chars)', source.reference, elapsed_ms, len(content)) + except IndexingException as e: + logger.error(f'Error decoding source ({source.reference}): {e}', exc_info=e) + errored_docs[db_id] = IndexingError( + error=str(e), + retryable=False, + ) + continue - if content is None or (content := content.strip()) == '': - logger.debug('decoded empty source', 
extra={ 'source_id': source.filename }) + if content == '': + logger.debug('decoded empty source', extra={ 'source_id': source.reference }) + errored_docs[db_id] = IndexingError( + error='Decoded content is empty', + retryable=False, + ) continue # replace more than two newlines with two newlines (also blank spaces, more than 4) @@ -68,97 +244,149 @@ def _sources_to_indocuments(config: TConfig, sources: list[UploadFile]) -> list[ # NOTE: do not use this with all docs when programming files are added content = re.sub(r'(\s){5,}', r'\g<1>', content) # filter out null bytes - content = content.replace('\0', '') - - if content is None or content == '': - logger.debug('decoded empty source after cleanup', extra={ 'source_id': source.filename }) + content = content.replace('\0', '').strip() + + if content == '': + logger.debug('decoded empty source after cleanup', extra={ 'source_id': source.reference }) + errored_docs[db_id] = IndexingError( + error='Cleaned up content is empty', + retryable=False, + ) continue - logger.debug('decoded non empty source', extra={ 'source_id': source.filename }) + logger.debug('decoded non empty source', extra={ 'source_id': source.reference }) metadata = { - 'source': source.filename, - 'title': _decode_latin_1(source.headers['title']), - 'type': source.headers['type'], + 'source': source.reference, + 'title': _decode_latin_1(source.title), + 'type': source.type, } doc = Document(page_content=content, metadata=metadata) - splitter = get_splitter_for(config.embedding_chunk_size, source.headers['type']) + splitter = get_splitter_for(config.embedding_chunk_size, source.type) split_docs = splitter.split_documents([doc]) logger.debug('split document into chunks', extra={ - 'source_id': source.filename, + 'source_id': source.reference, 'len(split_docs)': len(split_docs), }) - indocuments.append(InDocument( + indocuments[db_id] = InDocument( documents=split_docs, - userIds=list(map(_decode_latin_1, source.headers['userIds'].split(','))), - 
source_id=source.filename, # pyright: ignore[reportArgumentType] - provider=source.headers['provider'], - modified=to_int(source.headers['modified']), - )) + userIds=list(map(_decode_latin_1, source.userIds)), + source_id=source.reference, + provider=source.provider, + modified=source.modified, # pyright: ignore[reportArgumentType] + ) - return indocuments + return indocuments, errored_docs + + +def _increase_access_for_existing_sources( + vectordb: BaseVectorDB, + existing_sources: Mapping[int, SourceItem | ReceivedFileItem] +) -> Mapping[int, IndexingError | None]: + ''' + update userIds for existing sources + allow the userIds as additional users, not as the only users + ''' + if len(existing_sources) == 0: + return {} + + results = {} + logger.debug('Increasing access for existing sources', extra={ + 'source_ids': [source.reference for source in existing_sources.values()] + }) + for db_id, source in existing_sources.items(): + try: + vectordb.update_access( + UpdateAccessOp.ALLOW, + list(map(_decode_latin_1, source.userIds)), + source.reference, + ) + results[db_id] = None + except SafeDbException as e: + logger.error(f'Failed to update access for source ({source.reference}): {e.args[0]}') + results[db_id] = IndexingError( + error=str(e), + retryable=False, + ) + continue + except Exception as e: + logger.error(f'Unexpected error while updating access for source ({source.reference}): {e}') + results[db_id] = IndexingError( + error='Unexpected error while updating access', + retryable=True, + ) + continue + return results def _process_sources( vectordb: BaseVectorDB, config: TConfig, - sources: list[UploadFile], -) -> tuple[list[str],list[str]]: + sources: Mapping[int, SourceItem | ReceivedFileItem] +) -> Mapping[int, IndexingError | None]: ''' Processes the sources and adds them to the vectordb. Returns the list of source ids that were successfully added and those that need to be retried. 
''' - existing_sources, filtered_sources = _filter_sources(vectordb, sources) + existing_sources, to_embed_sources = _filter_sources(vectordb, sources) logger.debug('db filter source results', extra={ 'len(existing_sources)': len(existing_sources), 'existing_sources': existing_sources, - 'len(filtered_sources)': len(filtered_sources), - 'filtered_sources': filtered_sources, + 'len(to_embed_sources)': len(to_embed_sources), + 'to_embed_sources': to_embed_sources, }) - loaded_source_ids = [source.filename for source in existing_sources] - # update userIds for existing sources - # allow the userIds as additional users, not as the only users - if len(existing_sources) > 0: - logger.debug('Increasing access for existing sources', extra={ - 'source_ids': [source.filename for source in existing_sources] - }) - for source in existing_sources: - try: - vectordb.update_access( - UpdateAccessOp.allow, - list(map(_decode_latin_1, source.headers['userIds'].split(','))), - source.filename, # pyright: ignore[reportArgumentType] - ) - except SafeDbException as e: - logger.error(f'Failed to update access for source ({source.filename}): {e.args[0]}') - continue - - if len(filtered_sources) == 0: + source_proc_results = _increase_access_for_existing_sources(vectordb, existing_sources) + + logger.debug( + 'Fetching file contents for %d source(s) from Nextcloud', + len(to_embed_sources), + ) + t0 = perf_counter_ns() + populated_to_embed_sources, errored_sources = asyncio.run(__fetch_files_content(to_embed_sources)) + elapsed_ms = (perf_counter_ns() - t0) / 1e6 + logger.debug( + 'File content fetch complete in %.2f ms: %d fetched, %d errored', + elapsed_ms, len(populated_to_embed_sources), len(errored_sources), + ) + source_proc_results.update(errored_sources) # pyright: ignore[reportAttributeAccessIssue] + + if len(populated_to_embed_sources) == 0: # no new sources to embed logger.debug('Filtered all sources, nothing to embed') - return loaded_source_ids, [] # pyright: 
ignore[reportReturnType] + return source_proc_results logger.debug('Filtered sources:', extra={ - 'source_ids': [source.filename for source in filtered_sources] + 'source_ids': [source.reference for source in populated_to_embed_sources.values()] }) # invalid/empty sources are filtered out here and not counted in loaded/retryable - indocuments = _sources_to_indocuments(config, filtered_sources) + indocuments, errored_docs = _sources_to_indocuments(config, populated_to_embed_sources) - logger.debug('Converted all sources to documents') + source_proc_results.update(errored_docs) # pyright: ignore[reportAttributeAccessIssue] + logger.debug('Converted sources to documents') if len(indocuments) == 0: # filtered document(s) were invalid/empty, not an error logger.debug('All documents were found empty after being processed') - return loaded_source_ids, [] # pyright: ignore[reportReturnType] + return source_proc_results + + logger.debug('Adding documents to vectordb', extra={ + 'source_ids': [indoc.source_id for indoc in indocuments.values()] + }) - added_source_ids, retry_source_ids = vectordb.add_indocuments(indocuments) - loaded_source_ids.extend(added_source_ids) + t0 = perf_counter_ns() + doc_add_results = vectordb.add_indocuments(indocuments) + elapsed_ms = (perf_counter_ns() - t0) / 1e6 + logger.info( + 'vectordb.add_indocuments completed in %.2f ms for %d document(s)', + elapsed_ms, len(indocuments), + ) + source_proc_results.update(doc_add_results) # pyright: ignore[reportAttributeAccessIssue] logger.debug('Added documents to vectordb') - return loaded_source_ids, retry_source_ids # pyright: ignore[reportReturnType] + return source_proc_results def _decode_latin_1(s: str) -> str: @@ -172,31 +400,15 @@ def _decode_latin_1(s: str) -> str: def embed_sources( vectordb_loader: VectorDBLoader, config: TConfig, - sources: list[UploadFile], -) -> tuple[list[str],list[str]]: - # either not a file or a file that is allowed - sources_filtered = [ - source for source in 
sources - if is_valid_source_id(source.filename) # pyright: ignore[reportArgumentType] - or _allowed_file(source) - ] - + sources: Mapping[int, SourceItem | ReceivedFileItem] +) -> Mapping[int, IndexingError | None]: logger.debug('Embedding sources:', extra={ 'source_ids': [ - f'{source.filename} ({_decode_latin_1(source.headers["title"])})' - for source in sources_filtered - ], - 'invalid_source_ids': [ - source.filename for source in sources - if not is_valid_source_id(source.filename) # pyright: ignore[reportArgumentType] - ], - 'not_allowed_file_ids': [ - source.filename for source in sources - if not _allowed_file(source) + f'{source.reference} ({_decode_latin_1(source.title)})' + for source in sources.values() ], - 'len(source_ids)': len(sources_filtered), - 'len(total_source_ids)': len(sources), + 'len(source_ids)': len(sources), }) vectordb = vectordb_loader.load() - return _process_sources(vectordb, config, sources_filtered) + return _process_sources(vectordb, config, sources) diff --git a/context_chat_backend/chain/one_shot.py b/context_chat_backend/chain/one_shot.py index 1c0521bf..c79f272e 100644 --- a/context_chat_backend/chain/one_shot.py +++ b/context_chat_backend/chain/one_shot.py @@ -10,7 +10,7 @@ from ..types import TConfig from .context import get_context_chunks, get_context_docs from .query_proc import get_pruned_query -from .types import ContextException, LLMOutput, ScopeType +from .types import ContextException, LLMOutput, ScopeType, SearchResult _LLM_TEMPLATE = '''Answer based only on this context and do not add any imaginative details. Make sure to use the same language as the question in your answer. 
{context} @@ -20,6 +20,7 @@ logger = logging.getLogger('ccb.chain') +# todo: remove this maybe def process_query( user_id: str, llm: LLM, @@ -78,6 +79,9 @@ def process_context_query( stop=[end_separator], userid=user_id, ).strip() - unique_sources: list[str] = list({source for d in context_docs if (source := d.metadata.get('source'))}) + unique_sources = [SearchResult( + source_id=source, + title=d.metadata.get('title', ''), + ) for d in context_docs if (source := d.metadata.get('source'))] return LLMOutput(output=output, sources=unique_sources) diff --git a/context_chat_backend/chain/types.py b/context_chat_backend/chain/types.py index b006ad1a..3afdf297 100644 --- a/context_chat_backend/chain/types.py +++ b/context_chat_backend/chain/types.py @@ -33,12 +33,24 @@ class ContextException(Exception): ... +class SearchResult(TypedDict): + source_id: str + title: str + + class LLMOutput(TypedDict): output: str - sources: list[str] - # todo: add "titles" field + sources: list[SearchResult] -class SearchResult(TypedDict): - source_id: str - title: str +class EnrichedSource(BaseModel): + id: str + label: str + icon: str + url: str + +class EnrichedSourceList(BaseModel): + sources: list[EnrichedSource] + +class ScopeList(BaseModel): + source_ids: list[str] diff --git a/context_chat_backend/controller.py b/context_chat_backend/controller.py index c26b930a..9c3812e9 100644 --- a/context_chat_backend/controller.py +++ b/context_chat_backend/controller.py @@ -2,11 +2,12 @@ # SPDX-FileCopyrightText: 2023 Nextcloud GmbH and Nextcloud contributors # SPDX-License-Identifier: AGPL-3.0-or-later # +from nc_py_api.ex_app.providers.task_processing import TaskProcessingProvider # isort: off -from .chain.types import ContextException, LLMOutput, ScopeType, SearchResult +from .chain.types import ContextException from .types import LoaderException, EmbeddingException -from .vectordb.types import DbException, SafeDbException, UpdateAccessOp +from .vectordb.types import DbException, 
SafeDbException from .setup_functions import ensure_config_file, repair_run, setup_env_vars # setup env vars before importing other modules @@ -23,39 +24,29 @@ from collections.abc import Callable from contextlib import asynccontextmanager from functools import wraps -from threading import Event, Thread -from time import sleep -from typing import Annotated, Any -from fastapi import Body, FastAPI, Request, UploadFile -from langchain.llms.base import LLM +from fastapi import FastAPI, Request from nc_py_api import AsyncNextcloudApp, NextcloudApp from nc_py_api.ex_app import persistent_storage, set_handlers -from pydantic import BaseModel, ValidationInfo, field_validator from starlette.responses import FileResponse -from .chain.context import do_doc_search -from .chain.ingest.injest import embed_sources -from .chain.one_shot import process_context_query, process_query from .config_parser import get_config -from .dyn_loader import LLMModelLoader, VectorDBLoader +from .dyn_loader import VectorDBLoader from .models.types import LlmException from nc_py_api.ex_app import AppAPIAuthMiddleware -from .utils import JSONResponse, exec_in_proc, is_valid_provider_id, is_valid_source_id, value_of -from .vectordb.service import ( - count_documents_by_provider, - decl_update_access, - delete_by_provider, - delete_by_source, - delete_user, - update_access, -) +from .utils import JSONResponse, exec_in_proc +from .task_fetcher import start_bg_threads, trigger_handler, wait_for_bg_threads +from .vectordb.service import count_documents_by_provider # setup -repair_run() -ensure_config_file() +# only run once +if mp.current_process().name == 'MainProcess': + repair_run() + ensure_config_file() + logger = logging.getLogger('ccb.controller') +app_config = get_config(os.environ['CC_CONFIG_PATH']) __download_models_from_hf = os.environ.get('CC_DOWNLOAD_MODELS_FROM_HF', 'true').lower() in ('1', 'true', 'yes') models_to_fetch = { @@ -70,13 +61,33 @@ 'revision': 
'607a30d783dfa663caf39e06633721c8d4cfcd7e', } } if __download_models_from_hf else {} -app_enabled = Event() +app_enabled = threading.Event() -def enabled_handler(enabled: bool, _: NextcloudApp | AsyncNextcloudApp) -> str: - if enabled: - app_enabled.set() - else: - app_enabled.clear() +def enabled_handler(enabled: bool, nc: NextcloudApp | AsyncNextcloudApp) -> str: + try: + if enabled: + provider = TaskProcessingProvider( + id="context_chat-context_chat_search", + name="Context Chat", + task_type="context_chat:context_chat_search", + expected_runtime=30, + ) + nc.providers.task_processing.register(provider) + provider = TaskProcessingProvider( + id="context_chat-context_chat", + name="Context Chat", + task_type="context_chat:context_chat", + expected_runtime=30, + ) + nc.providers.task_processing.register(provider) + app_enabled.set() + start_bg_threads(app_config, app_enabled) + else: + app_enabled.clear() + wait_for_bg_threads() + except Exception as e: + logger.exception('Error in enabled handler:', exc_info=e) + return f'Error in enabled handler: {e}' logger.info(f'App {("disabled", "enabled")[enabled]}') return '' @@ -84,19 +95,17 @@ def enabled_handler(enabled: bool, _: NextcloudApp | AsyncNextcloudApp) -> str: @asynccontextmanager async def lifespan(app: FastAPI): - set_handlers(app, enabled_handler, models_to_fetch=models_to_fetch) + set_handlers(app, enabled_handler, models_to_fetch=models_to_fetch, trigger_handler=trigger_handler) nc = NextcloudApp() if nc.enabled_state: app_enabled.set() + start_bg_threads(app_config, app_enabled) logger.info(f'App enable state at startup: {app_enabled.is_set()}') - t = Thread(target=background_thread_task, args=()) - t.start() yield vectordb_loader.offload() - llm_loader.offload() + wait_for_bg_threads() -app_config = get_config(os.environ['CC_CONFIG_PATH']) app = FastAPI(debug=app_config.debug, lifespan=lifespan) # pyright: ignore[reportArgumentType] app.extra['CONFIG'] = app_config @@ -105,7 +114,6 @@ async def 
lifespan(app: FastAPI): # loaders vectordb_loader = VectorDBLoader(app_config) -llm_loader = LLMModelLoader(app, app_config) # locks and semaphores @@ -117,22 +125,12 @@ async def lifespan(app: FastAPI): index_lock = threading.Lock() _indexing = {} -# limit the number of concurrent document parsing -doc_parse_semaphore = mp.Semaphore(app_config.doc_parser_worker_limit) - # middlewares if not app_config.disable_aaa: app.add_middleware(AppAPIAuthMiddleware) -# logger background thread - -def background_thread_task(): - while(True): - logger.info(f'Currently indexing {len(_indexing)} documents (filename, size): ', extra={'_indexing': _indexing}) - sleep(10) - # exception handlers @app.exception_handler(DbException) @@ -213,121 +211,6 @@ def _(): return JSONResponse(content={'enabled': app_enabled.is_set()}, status_code=200) -@app.post('/updateAccessDeclarative') -@enabled_guard(app) -def _( - userIds: Annotated[list[str], Body()], - sourceId: Annotated[str, Body()], -): - logger.debug('Update access declarative request:', extra={ - 'user_ids': userIds, - 'source_id': sourceId, - }) - - if len(userIds) == 0: - return JSONResponse('Empty list of user ids', 400) - - if not is_valid_source_id(sourceId): - return JSONResponse('Invalid source id', 400) - - exec_in_proc(target=decl_update_access, args=(vectordb_loader, userIds, sourceId)) - - return JSONResponse('Access updated') - - -@app.post('/updateAccess') -@enabled_guard(app) -def _( - op: Annotated[UpdateAccessOp, Body()], - userIds: Annotated[list[str], Body()], - sourceId: Annotated[str, Body()], -): - logger.debug('Update access request', extra={ - 'op': op, - 'user_ids': userIds, - 'source_id': sourceId, - }) - - if len(userIds) == 0: - return JSONResponse('Empty list of user ids', 400) - - if not is_valid_source_id(sourceId): - return JSONResponse('Invalid source id', 400) - - exec_in_proc(target=update_access, args=(vectordb_loader, op, userIds, sourceId)) - - return JSONResponse('Access updated') - - 
-@app.post('/updateAccessProvider') -@enabled_guard(app) -def _( - op: Annotated[UpdateAccessOp, Body()], - userIds: Annotated[list[str], Body()], - providerId: Annotated[str, Body()], -): - logger.debug('Update access by provider request', extra={ - 'op': op, - 'user_ids': userIds, - 'provider_id': providerId, - }) - - if len(userIds) == 0: - return JSONResponse('Empty list of user ids', 400) - - if not is_valid_provider_id(providerId): - return JSONResponse('Invalid provider id', 400) - - exec_in_proc(target=update_access, args=(vectordb_loader, op, userIds, providerId)) - - return JSONResponse('Access updated') - - -@app.post('/deleteSources') -@enabled_guard(app) -def _(sourceIds: Annotated[list[str], Body(embed=True)]): - logger.debug('Delete sources request', extra={ - 'source_ids': sourceIds, - }) - - sourceIds = [source.strip() for source in sourceIds if source.strip() != ''] - - if len(sourceIds) == 0: - return JSONResponse('No sources provided', 400) - - res = exec_in_proc(target=delete_by_source, args=(vectordb_loader, sourceIds)) - if res is False: - return JSONResponse('Error: VectorDB delete failed, check vectordb logs for more info.', 400) - - return JSONResponse('All valid sources deleted') - - -@app.post('/deleteProvider') -@enabled_guard(app) -def _(providerKey: str = Body(embed=True)): - logger.debug('Delete sources by provider for all users request', extra={ 'provider_key': providerKey }) - - if value_of(providerKey) is None: - return JSONResponse('Invalid provider key provided', 400) - - exec_in_proc(target=delete_by_provider, args=(vectordb_loader, providerKey)) - - return JSONResponse('All valid sources deleted') - - -@app.post('/deleteUser') -@enabled_guard(app) -def _(userId: str = Body(embed=True)): - logger.debug('Remove access list for user, and orphaned sources', extra={ 'user_id': userId }) - - if value_of(userId) is None: - return JSONResponse('Invalid userId provided', 400) - - exec_in_proc(target=delete_user, args=(vectordb_loader, 
userId)) - - return JSONResponse('User deleted') - - @app.post('/countIndexedDocuments') @enabled_guard(app) def _(): @@ -335,177 +218,6 @@ def _(): return JSONResponse(counts) -@app.put('/loadSources') -@enabled_guard(app) -def _(sources: list[UploadFile]): - global _indexing - - if len(sources) == 0: - return JSONResponse('No sources provided', 400) - - filtered_sources = [] - - for source in sources: - if not value_of(source.filename): - logger.warning('Skipping source with invalid source_id', extra={ - 'source_id': source.filename, - 'title': source.headers.get('title'), - }) - continue - - with index_lock: - if source.filename in _indexing: - # this request will be retried by the client - return JSONResponse( - f'This source ({source.filename}) is already being processed in another request, try again later', - 503, - headers={'cc-retry': 'true'}, - ) - - if not ( - value_of(source.headers.get('userIds')) - and source.headers.get('title', None) is not None - and value_of(source.headers.get('type')) - and value_of(source.headers.get('modified')) - and source.headers['modified'].isdigit() - and value_of(source.headers.get('provider')) - ): - logger.warning('Skipping source with invalid/missing headers', extra={ - 'source_id': source.filename, - 'title': source.headers.get('title'), - 'headers': source.headers, - }) - continue - - filtered_sources.append(source) - - # wait for 10 minutes before failing the request - semres = doc_parse_semaphore.acquire(block=True, timeout=10*60) - if not semres: - return JSONResponse( - 'Document parser worker limit reached, try again in some time or consider increasing the limit', - 503, - headers={'cc-retry': 'true'} - ) - - with index_lock: - for source in filtered_sources: - _indexing[source.filename] = source.size - - try: - loaded_sources, not_added_sources = exec_in_proc( - target=embed_sources, - args=(vectordb_loader, app.extra['CONFIG'], filtered_sources) - ) - except (DbException, EmbeddingException): - raise - except 
Exception as e: - raise DbException('Error: failed to load sources') from e - finally: - with index_lock: - for source in filtered_sources: - _indexing.pop(source.filename, None) - doc_parse_semaphore.release() - - if len(loaded_sources) != len(filtered_sources): - logger.debug('Some sources were not loaded', extra={ - 'Count of loaded sources': f'{len(loaded_sources)}/{len(filtered_sources)}', - 'source_ids': loaded_sources, - }) - - # loaded sources include the existing sources that may only have their access updated - return JSONResponse({'loaded_sources': loaded_sources, 'sources_to_retry': not_added_sources}) - - -class Query(BaseModel): - userId: str - query: str - useContext: bool = True - scopeType: ScopeType | None = None - scopeList: list[str] | None = None - ctxLimit: int = 20 - - @field_validator('userId', 'query', 'ctxLimit') - @classmethod - def check_empty_values(cls, value: Any, info: ValidationInfo): - if value_of(value) is None: - raise ValueError('Empty value for field', info.field_name) - - return value - - @field_validator('ctxLimit') - @classmethod - def at_least_one_context(cls, value: int): - if value < 1: - raise ValueError('Invalid context chunk limit') - - return value - - -def execute_query(query: Query, in_proc: bool = True) -> LLMOutput: - llm: LLM = llm_loader.load() - template = app.extra.get('LLM_TEMPLATE') - no_ctx_template = app.extra['LLM_NO_CTX_TEMPLATE'] - # todo: array - end_separator = app.extra.get('LLM_END_SEPARATOR', '') - - if query.useContext: - target = process_context_query - args=( - query.userId, - vectordb_loader, - llm, - app_config, - query.query, - query.ctxLimit, - query.scopeType, - query.scopeList, - template, - end_separator, - ) - else: - target=process_query - args=( - query.userId, - llm, - app_config, - query.query, - no_ctx_template, - end_separator, - ) - - if in_proc: - return exec_in_proc(target=target, args=args) - - return target(*args) # pyright: ignore - - -@app.post('/query') -@enabled_guard(app) 
-def _(query: Query) -> LLMOutput: - logger.debug('received query request', extra={ 'query': query.dict() }) - - if app_config.llm[0] == 'nc_texttotext': - return execute_query(query) - - with llm_lock: - return execute_query(query, in_proc=False) - - -@app.post('/docSearch') -@enabled_guard(app) -def _(query: Query) -> list[SearchResult]: - # useContext from Query is not used here - return exec_in_proc(target=do_doc_search, args=( - query.userId, - query.query, - vectordb_loader, - query.ctxLimit, - query.scopeType, - query.scopeList, - )) - - @app.get('/downloadLogs') def download_logs() -> FileResponse: with tempfile.NamedTemporaryFile('wb', delete=False) as tmp: diff --git a/context_chat_backend/dyn_loader.py b/context_chat_backend/dyn_loader.py index d67310ff..47b19575 100644 --- a/context_chat_backend/dyn_loader.py +++ b/context_chat_backend/dyn_loader.py @@ -7,11 +7,9 @@ import gc import logging from abc import ABC, abstractmethod -from time import time from typing import Any import torch -from fastapi import FastAPI from langchain.llms.base import LLM from .models.loader import init_model @@ -54,19 +52,11 @@ def offload(self) -> None: class LLMModelLoader(Loader): - def __init__(self, app: FastAPI, config: TConfig) -> None: + def __init__(self, config: TConfig) -> None: self.config = config - self.app = app def load(self) -> LLM: - if self.app.extra.get('LLM_MODEL') is not None: - self.app.extra['LLM_LAST_ACCESSED'] = time() - return self.app.extra['LLM_MODEL'] - llm_name, llm_config = self.config.llm - self.app.extra['LLM_TEMPLATE'] = llm_config.pop('template', '') - self.app.extra['LLM_NO_CTX_TEMPLATE'] = llm_config.pop('no_ctx_template', '') - self.app.extra['LLM_END_SEPARATOR'] = llm_config.pop('end_separator', '') try: model = init_model('llm', (llm_name, llm_config)) @@ -75,13 +65,9 @@ def load(self) -> LLM: if not isinstance(model, LLM): raise LoaderException(f'Error: {model} does not implement "llm" type or has returned an invalid object') - 
self.app.extra['LLM_MODEL'] = model - self.app.extra['LLM_LAST_ACCESSED'] = time() return model def offload(self) -> None: - if self.app.extra.get('LLM_MODEL') is not None: - del self.app.extra['LLM_MODEL'] clear_cache() diff --git a/context_chat_backend/chain/ingest/mimetype_list.py b/context_chat_backend/mimetype_list.py similarity index 100% rename from context_chat_backend/chain/ingest/mimetype_list.py rename to context_chat_backend/mimetype_list.py diff --git a/context_chat_backend/network_em.py b/context_chat_backend/network_em.py index 18bb11f4..ba1edc9e 100644 --- a/context_chat_backend/network_em.py +++ b/context_chat_backend/network_em.py @@ -3,12 +3,13 @@ # SPDX-License-Identifier: AGPL-3.0-or-later # import logging +import socket from time import sleep from typing import Literal, TypedDict +from urllib.parse import urlparse import niquests from langchain_core.embeddings import Embeddings -from pydantic import BaseModel from .types import ( EmbeddingException, @@ -20,6 +21,7 @@ ) logger = logging.getLogger('ccb.nextwork_em') +TCP_CONNECT_TIMEOUT = 2.0 # seconds # Copied from llama_cpp/llama_types.py @@ -41,8 +43,35 @@ class CreateEmbeddingResponse(TypedDict): usage: EmbeddingUsage -class NetworkEmbeddings(Embeddings, BaseModel): - app_config: TConfig +class NetworkEmbeddings(Embeddings): + def __init__(self, app_config: TConfig): + self.app_config = app_config + + def _get_host_and_port(self) -> tuple[str, int]: + parsed = urlparse(self.app_config.embedding.base_url) + host = parsed.hostname + + if not host: + raise ValueError("Invalid URL: Missing hostname") + + if parsed.port: + port = parsed.port + else: + port = 443 if parsed.scheme == "https" else 80 + + return host, port + + def check_connection(self, check_origin: str) -> bool: + try: + host, port = self._get_host_and_port() + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(TCP_CONNECT_TIMEOUT) + sock.connect((host, port)) + sock.close() + return True + except 
(ValueError, TimeoutError, ConnectionRefusedError, socket.gaierror) as e: + logger.warning(f'[{check_origin}] Embedding server is not reachable, retrying after some time: {e}') + return False def _get_embedding(self, input_: str | list[str], try_: int = 3) -> list[float] | list[list[float]]: emconf = self.app_config.embedding @@ -79,6 +108,7 @@ def _get_embedding(self, input_: str | list[str], try_: int = 3) -> list[float] raise FatalEmbeddingException(response.text) if response.status_code // 100 != 2: raise EmbeddingException(response.text) + # todo: rework exception handling and their downstream interpretation except FatalEmbeddingException as e: logger.error('Fatal error while getting embeddings: %s', str(e), exc_info=e) raise e @@ -108,10 +138,14 @@ def _get_embedding(self, input_: str | list[str], try_: int = 3) -> list[float] logger.error('Unexpected error while getting embeddings', exc_info=e) raise EmbeddingException('Error: unexpected error while getting embeddings') from e - # converts TypedDict to a pydantic model - resp = CreateEmbeddingResponse(**response.json()) - if isinstance(input_, str): - return resp['data'][0]['embedding'] + try: + # converts TypedDict to a pydantic model + resp = CreateEmbeddingResponse(**response.json()) + if isinstance(input_, str): + return resp['data'][0]['embedding'] + except Exception as e: + logger.error('Error parsing embedding response', exc_info=e) + raise EmbeddingException('Error: failed to parse embedding response') from e # only one embedding in d['embedding'] since truncate is True return [d['embedding'] for d in resp['data']] # pyright: ignore[reportReturnType] diff --git a/context_chat_backend/task_fetcher.py b/context_chat_backend/task_fetcher.py new file mode 100644 index 00000000..be74b316 --- /dev/null +++ b/context_chat_backend/task_fetcher.py @@ -0,0 +1,744 @@ +# +# SPDX-FileCopyrightText: 2026 Nextcloud GmbH and Nextcloud contributors +# SPDX-License-Identifier: AGPL-3.0-or-later +# +import logging 
+import math +import os +from collections.abc import Mapping +from concurrent.futures import ThreadPoolExecutor +from contextlib import suppress +from enum import Enum +from threading import Event, Thread +from time import sleep +from typing import Any + +import niquests +from langchain.llms.base import LLM +from nc_py_api import NextcloudApp, NextcloudException +from niquests import JSONDecodeError, RequestException +from pydantic import ValidationError + +from .chain.context import do_doc_search +from .chain.ingest.injest import embed_sources +from .chain.one_shot import process_context_query +from .chain.types import ContextException, EnrichedSourceList, LLMOutput, ScopeList, SearchResult +from .dyn_loader import LLMModelLoader, VectorDBLoader +from .network_em import NetworkEmbeddings +from .types import ( + ActionsQueueItems, + ActionType, + AppRole, + FilesQueueItems, + IndexingError, + LoaderException, + ReceivedFileItem, + SourceItem, + TConfig, +) +from .utils import SubprocessKilledError, exec_in_proc, get_app_role +from .vectordb.service import ( + decl_update_access, + delete_by_provider, + delete_by_source, + delete_user, + update_access, + update_access_provider, +) +from .vectordb.types import DbException, SafeDbException + +APP_ROLE = get_app_role() +THREADS = {} +THREAD_STOP_EVENT = Event() +LOGGER = logging.getLogger('ccb.task_fetcher') +FILES_INDEXING_BATCH_SIZE = 16 # theoretical max RAM usage: 16 * 100 MiB, todo: config? +MIN_FILES_PER_CPU = 4 +# divides the batch into these many chunks +PARALLEL_FILE_PARSING_COUNT = max(1, (os.cpu_count() or 2) - 1) # todo: config? +LOGGER.info(f'Using {PARALLEL_FILE_PARSING_COUNT} parallel file parsing workers') +ACTIONS_BATCH_SIZE = 512 # todo: config? 
+POLLING_COOLDOWN = 30 +TRIGGER = Event() +CHECK_INTERVAL = 5 +CHECK_INTERVAL_WITH_TRIGGER = 5 * 60 +CHECK_INTERVAL_ON_ERROR = 15 +CONTEXT_LIMIT=20 + + +class ThreadType(Enum): + FILES_INDEXING = 'files_indexing' + UPDATES_PROCESSING = 'updates_processing' + REQUEST_PROCESSING = 'request_processing' + + +def files_indexing_thread(app_config: TConfig, app_enabled: Event) -> None: + try: + network_em = NetworkEmbeddings(app_config) + vectordb_loader = VectorDBLoader(app_config) + except LoaderException as e: + LOGGER.error('Error initializing vector DB loader, files indexing thread will not start:', exc_info=e) + return + + def _load_sources(source_items: Mapping[int, SourceItem | ReceivedFileItem]) -> Mapping[int, IndexingError | None]: + source_refs = [s.reference for s in source_items.values()] + LOGGER.info('Starting embed_sources subprocess for %d source(s)', len(source_items), extra={ + 'source_ids': source_refs, + }) + try: + result = exec_in_proc( + target=embed_sources, + args=(vectordb_loader, app_config, source_items), + ) + errors = {k: v for k, v in result.items() if isinstance(v, IndexingError)} + LOGGER.info( + 'embed_sources finished for %d source(s): %d succeeded, %d errored', + len(source_items), len(result) - len(errors), len(errors), + extra={'errors': errors}, + ) + return result + except SubprocessKilledError as e: + LOGGER.error( + 'embed_sources subprocess was killed for %d source(s) with exitcode %s', + len(source_items), e.exitcode, exc_info=e, extra={ + 'source_ids': source_refs, + }, + ) + if len(source_items) == 1: + return dict.fromkeys( + source_items, + IndexingError(error=f'Subprocess killed with exitcode {e.exitcode}: {e}', retryable=False), + ) + + # Fall back to one-by-one to isolate the problematic file. 
+ LOGGER.warning( + 'Falling back to individual processing for %d sources', + len(source_items), + ) + fallback: dict[int, IndexingError | None] = {} + for db_id, item in source_items.items(): + fallback.update(_load_sources({db_id: item})) + return fallback + except Exception as e: + err = IndexingError( + error=f'{e.__class__.__name__}: {e}', + retryable=True, + ) + LOGGER.error( + 'embed_sources subprocess raised a %s error for %d sources, marking all as retryable', + e.__class__.__name__, len(source_refs), exc_info=e, extra={ + 'source_ids': source_refs, + } + ) + return dict.fromkeys(source_items, err) + + + while True: + if THREAD_STOP_EVENT.is_set(): + LOGGER.info('Files indexing thread is stopping due to stop event being set') + return + + try: + if not network_em.check_connection(ThreadType.FILES_INDEXING.value): + sleep(POLLING_COOLDOWN) + continue + + nc = NextcloudApp() + q_items_res = nc.ocs( + 'GET', + '/ocs/v2.php/apps/context_chat/queues/documents', + params={ 'n': FILES_INDEXING_BATCH_SIZE } + ) + + try: + q_items: FilesQueueItems = FilesQueueItems.model_validate(q_items_res) + except ValidationError as e: + raise Exception(f'Error validating queue items response: {e}\nResponse content: {q_items_res}') from e + + if not q_items.files and not q_items.content_providers: + LOGGER.debug('No documents to index') + sleep(POLLING_COOLDOWN) + continue + + files_result = {} + providers_result = {} + + # chunk file parsing for better file operation parallelism + file_chunk_size = max(MIN_FILES_PER_CPU, math.ceil(len(q_items.files) / PARALLEL_FILE_PARSING_COUNT)) + file_chunks = [ + dict(list(q_items.files.items())[i:i+file_chunk_size]) + for i in range(0, len(q_items.files), file_chunk_size) + ] + provider_chunk_size = max( + MIN_FILES_PER_CPU, + math.ceil(len(q_items.content_providers) / PARALLEL_FILE_PARSING_COUNT), + ) + provider_chunks = [ + dict(list(q_items.content_providers.items())[i:i+provider_chunk_size]) + for i in range(0, 
len(q_items.content_providers), provider_chunk_size) + ] + + with ThreadPoolExecutor( + max_workers=PARALLEL_FILE_PARSING_COUNT, + thread_name_prefix='IndexingPool', + ) as executor: + LOGGER.info( + 'Dispatching %d file chunk(s) and %d provider chunk(s) to %d IndexingPool worker(s)', + len(file_chunks), len(provider_chunks), PARALLEL_FILE_PARSING_COUNT, + ) + file_futures = [executor.submit(_load_sources, chunk) for chunk in file_chunks] + provider_futures = [executor.submit(_load_sources, chunk) for chunk in provider_chunks] + + for i, future in enumerate(file_futures): + LOGGER.debug('Waiting for file chunk %d/%d future to complete', i + 1, len(file_futures)) + files_result.update(future.result()) + LOGGER.debug('File chunk %d/%d future completed', i + 1, len(file_futures)) + for i, future in enumerate(provider_futures): + LOGGER.debug('Waiting for provider chunk %d/%d future to complete', i + 1, len(provider_futures)) + providers_result.update(future.result()) + LOGGER.debug('Provider chunk %d/%d future completed', i + 1, len(provider_futures)) + + if ( + any(isinstance(res, IndexingError) for res in files_result.values()) + or any(isinstance(res, IndexingError) for res in providers_result.values()) + ): + LOGGER.error('Some sources failed to index', extra={ + 'file_errors': { + db_id: error + for db_id, error in files_result.items() + if isinstance(error, IndexingError) + }, + 'provider_errors': { + provider_id: error + for provider_id, error in providers_result.items() + if isinstance(error, IndexingError) + }, + }) + except ( + niquests.exceptions.ConnectionError, + niquests.exceptions.Timeout, + ) as e: + LOGGER.info('Temporary error fetching documents to index, will retry:', exc_info=e) + sleep(5) + continue + except Exception as e: + LOGGER.exception('Error fetching documents to index:', exc_info=e) + sleep(5) + continue + + # delete the entries from the PHP side queue where indexing succeeded or the error is not retryable + to_delete_files_db_ids = [ + 
db_id for db_id, result in files_result.items() + if result is None or (isinstance(result, IndexingError) and not result.retryable) + ] + to_delete_provider_db_ids = [ + db_id for db_id, result in providers_result.items() + if result is None or (isinstance(result, IndexingError) and not result.retryable) + ] + + try: + nc.ocs( + 'DELETE', + '/ocs/v2.php/apps/context_chat/queues/documents/', + json={ + 'files': to_delete_files_db_ids, + 'content_providers': to_delete_provider_db_ids, + }, + ) + except ( + niquests.exceptions.ConnectionError, + niquests.exceptions.Timeout, + ) as e: + LOGGER.info('Temporary error reporting indexing results, will retry:', exc_info=e) + sleep(5) + with suppress(Exception): + nc = NextcloudApp() + nc.ocs( + 'DELETE', + '/ocs/v2.php/apps/context_chat/queues/documents/', + json={ + 'files': to_delete_files_db_ids, + 'content_providers': to_delete_provider_db_ids, + }, + ) + continue + except Exception as e: + LOGGER.exception('Error reporting indexing results:', exc_info=e) + sleep(5) + continue + + + +def updates_processing_thread(app_config: TConfig, app_enabled: Event) -> None: + try: + vectordb_loader = VectorDBLoader(app_config) + except LoaderException as e: + LOGGER.error('Error initializing vector DB loader, files indexing thread will not start:', exc_info=e) + return + + while True: + if THREAD_STOP_EVENT.is_set(): + LOGGER.info('Updates processing thread is stopping due to stop event being set') + return + + try: + nc = NextcloudApp() + q_items_res = nc.ocs( + 'GET', + '/ocs/v2.php/apps/context_chat/queues/actions', + params={ 'n': ACTIONS_BATCH_SIZE } + ) + + try: + q_items: ActionsQueueItems = ActionsQueueItems.model_validate(q_items_res) + except ValidationError as e: + raise Exception(f'Error validating queue items response: {e}\nResponse content: {q_items_res}') from e + except ( + niquests.exceptions.ConnectionError, + niquests.exceptions.Timeout, + ) as e: + LOGGER.info('Temporary error fetching updates to process, will 
retry:', exc_info=e) + sleep(5) + continue + except Exception as e: + LOGGER.exception('Error fetching updates to process:', exc_info=e) + sleep(5) + continue + + if not q_items.actions: + LOGGER.debug('No updates to process') + sleep(POLLING_COOLDOWN) + continue + + processed_event_ids = [] + errored_events = {} + for i, (db_id, action_item) in enumerate(q_items.actions.items()): + try: + match action_item.type: + case ActionType.DELETE_SOURCE_IDS: + exec_in_proc(target=delete_by_source, args=(vectordb_loader, action_item.payload.sourceIds)) + + case ActionType.DELETE_PROVIDER_ID: + exec_in_proc(target=delete_by_provider, args=(vectordb_loader, action_item.payload.providerId)) + + case ActionType.DELETE_USER_ID: + exec_in_proc(target=delete_user, args=(vectordb_loader, action_item.payload.userId)) + + case ActionType.UPDATE_ACCESS_SOURCE_ID: + exec_in_proc( + target=update_access, + args=( + vectordb_loader, + action_item.payload.op, + action_item.payload.userIds, + action_item.payload.sourceId, + ), + ) + + case ActionType.UPDATE_ACCESS_PROVIDER_ID: + exec_in_proc( + target=update_access_provider, + args=( + vectordb_loader, + action_item.payload.op, + action_item.payload.userIds, + action_item.payload.providerId, + ), + ) + + case ActionType.UPDATE_ACCESS_DECL_SOURCE_ID: + exec_in_proc( + target=decl_update_access, + args=( + vectordb_loader, + action_item.payload.userIds, + action_item.payload.sourceId, + ), + ) + + case _: + LOGGER.warning( + f'Unknown action type {action_item.type} for action id {db_id},' + f' type {action_item.type}, skipping and marking as processed', + extra={ 'action_item': action_item }, + ) + continue + + processed_event_ids.append(db_id) + except SafeDbException as e: + LOGGER.debug( + f'Safe DB error thrown while processing action id {db_id}, type {action_item.type},' + " it's safe to ignore and mark as processed.", + exc_info=e, + extra={ 'action_item': action_item }, + ) + processed_event_ids.append(db_id) + continue + + except 
(LoaderException, DbException) as e: + LOGGER.error( + f'Error deleting source for action id {db_id}, type {action_item.type}: {e}', + exc_info=e, + extra={ 'action_item': action_item }, + ) + errored_events[db_id] = str(e) + continue + + except Exception as e: + LOGGER.error( + f'Unexpected error processing action id {db_id}, type {action_item.type}: {e}', + exc_info=e, + extra={ 'action_item': action_item }, + ) + errored_events[db_id] = f'Unexpected error: {e}' + continue + + if (i + 1) % 20 == 0: + LOGGER.debug(f'Processed {i + 1} updates, sleeping for a bit to allow other operations to proceed') + sleep(2) + + LOGGER.info(f'Processed {len(processed_event_ids)} updates with {len(errored_events)} errors', extra={ + 'errored_events': errored_events, + }) + + if len(processed_event_ids) == 0: + LOGGER.debug('No updates processed, skipping reporting to the server') + continue + + try: + nc.ocs( + 'DELETE', + '/ocs/v2.php/apps/context_chat/queues/actions/', + json={ 'actions': processed_event_ids }, + ) + except ( + niquests.exceptions.ConnectionError, + niquests.exceptions.Timeout, + ) as e: + LOGGER.info('Temporary error reporting processed updates, will retry:', exc_info=e) + sleep(5) + with suppress(Exception): + nc = NextcloudApp() + nc.ocs( + 'DELETE', + '/ocs/v2.php/apps/context_chat/queues/actions/', + json={ 'ids': processed_event_ids }, + ) + continue + except Exception as e: + LOGGER.exception('Error reporting processed updates:', exc_info=e) + sleep(5) + continue + + +def resolve_scope_list(source_ids: list[str], userId: str) -> list[str]: + """ + + Parameters + ---------- + source_ids + + Returns + ------- + source_ids with only files, no folders (or source_ids in case of non-file provider) + """ + nc = NextcloudApp() + data = nc.ocs('POST', '/ocs/v2.php/apps/context_chat/resolve_scope_list', json={ + 'source_ids': source_ids, + 'userId': userId, + }) + return ScopeList.model_validate(data).source_ids + + +def request_processing_thread(app_config: 
TConfig, app_enabled: Event) -> None: + LOGGER.info('Starting task fetcher loop') + + try: + network_em = NetworkEmbeddings(app_config) + vectordb_loader = VectorDBLoader(app_config) + llm_loader = LLMModelLoader(app_config) + except LoaderException as e: + LOGGER.error('Error initializing vector DB loader, files indexing thread will not start:', exc_info=e) + return + + nc = NextcloudApp() + llm: LLM = llm_loader.load() + + while True: + if THREAD_STOP_EVENT.is_set(): + LOGGER.info('Updates processing thread is stopping due to stop event being set') + return + + if not network_em.check_connection(ThreadType.REQUEST_PROCESSING.value): + sleep(POLLING_COOLDOWN) + continue + + try: + # Fetch pending task + try: + response = nc.providers.task_processing.next_task( + ['context_chat-context_chat', 'context_chat-context_chat_search'], + ['context_chat:context_chat', 'context_chat:context_chat_search'], + ) + if not response: + wait_for_tasks() + continue + except (NextcloudException, RequestException, JSONDecodeError) as e: + LOGGER.error(f"Network error fetching the next task {e}", exc_info=e) + wait_for_tasks(CHECK_INTERVAL_ON_ERROR) + continue + + # Process task + task = response["task"] + userId = task['userId'] + + try: + LOGGER.debug(f'Processing task {task["id"]}') + + if task['input'].get('scopeType') == 'source': + # Resolve scope list to only files, no folders + task['input']['scopeList'] = resolve_scope_list(task['input'].get('scopeList'), userId) + + if task['type'] == 'context_chat:context_chat': + result: LLMOutput = process_normal_task(task, vectordb_loader, llm, app_config) + # Return result to Nextcloud + success = return_result_to_nextcloud(task['id'], userId, { + 'output': result['output'], + 'sources': enrich_sources(result['sources'], userId), + }) + elif task['type'] == 'context_chat:context_chat_search': + search_result: list[SearchResult] = process_search_task(task, vectordb_loader) + # Return result to Nextcloud + success = 
return_result_to_nextcloud(task['id'], userId, { + 'sources': enrich_sources(search_result, userId), + }) + else: + LOGGER.error(f'Unknown task type {task["type"]}') + success = return_error_to_nextcloud(task['id'], Exception(f'Unknown task type {task["type"]}')) + + if success: + LOGGER.info(f'Task {task["id"]} completed successfully') + else: + LOGGER.error(f'Failed to return result for task {task["id"]}') + + except ContextException as e: + LOGGER.warning(f'Context error for task {task["id"]}: {e}') + return_error_to_nextcloud(task['id'], e) + except ValueError as e: + LOGGER.warning(f'Validation error for task {task["id"]}: {e}') + return_error_to_nextcloud(task['id'], e) + except Exception as e: + LOGGER.exception(f'Unexpected error processing task {task["id"]}', exc_info=e) + return_error_to_nextcloud(task['id'], e) + + except Exception as e: + LOGGER.exception('Error in task fetcher loop', exc_info=e) + wait_for_tasks(CHECK_INTERVAL_ON_ERROR) + +def trigger_handler(providerId: str): + global TRIGGER + print('TRIGGER called') + TRIGGER.set() + +def wait_for_tasks(interval = None): + global TRIGGER + global CHECK_INTERVAL + global CHECK_INTERVAL_WITH_TRIGGER + actual_interval = CHECK_INTERVAL if interval is None else interval + if TRIGGER.wait(timeout=actual_interval): + CHECK_INTERVAL = CHECK_INTERVAL_WITH_TRIGGER + TRIGGER.clear() + + +def enrich_sources(results: list[SearchResult], userId: str) -> list[str]: + nc = NextcloudApp() + data = nc.ocs('POST', '/ocs/v2.php/apps/context_chat/enrich_sources', json={'sources': results, 'userId': userId}) + sources = EnrichedSourceList.model_validate(data).sources + return [s.model_dump_json() for s in sources] + + +def return_result_to_nextcloud(task_id: int, userId: str, result: dict[str, Any]) -> bool: + """ + Return query result back to Nextcloud. 
+ + Args: + result: dict[str, Any] + + Returns: + True if successful, False otherwise + """ + LOGGER.debug('Returning result to Nextcloud', extra={ + 'task_id': task_id, + 'result': result, + }) + + nc = NextcloudApp() + + try: + nc.providers.task_processing.report_result(task_id, result) + except (NextcloudException, RequestException, JSONDecodeError) as e: + LOGGER.error(f"Network error reporting task result {e}", exc_info=e) + return False + + return True + + +def return_error_to_nextcloud(task_id: int, e: Exception) -> bool: + """ + Return error result back to Nextcloud. + + Args: + task_id: Unique task identifier + e: error object + + Returns: + True if successful, False otherwise + """ + LOGGER.debug('Returning error to Nextcloud', exc_info=e) + + nc = NextcloudApp() + + if isinstance(e, ValueError): + message = "Validation error: " + str(e) + elif isinstance(e, ContextException): + message = "Context error" + str(e) + else: + message = "Unexpected error" + str(e) + + try: + nc.providers.task_processing.report_result(task_id, None, message) + except (NextcloudException, RequestException, JSONDecodeError) as e: + LOGGER.error(f"Network error reporting task result {e}", exc_info=e) + return False + + return True + + +def process_normal_task( + task: dict[str, Any], + vectordb_loader: VectorDBLoader, + llm: LLM, + app_config: TConfig, +) -> LLMOutput: + """ + Process a single query task. 
+ + Args: + task: Task dictionary from fetch_query_tasks_from_nextcloud + vectordb_loader: Vector database loader instance + llm: Language model instance + app_config: Application configuration + + Returns: + LLMOutput with generated text and sources + + Raises: + Various exceptions from query execution + """ + user_id = task['userId'] + task_input = task['input'] + if task_input.get('scopeType') == 'none': + task_input['scopeType'] = None + + # todo: document no template support + return exec_in_proc(target=process_context_query, + args=( + user_id, + vectordb_loader, + llm, + app_config, + task_input.get('prompt'), + CONTEXT_LIMIT, + task_input.get('scopeType'), + task_input.get('scopeList'), + ) + ) + +def process_search_task( + task: dict[str, Any], + vectordb_loader: VectorDBLoader, +) -> list[SearchResult]: + """ + Process a single search task. + + Args: + task: Task dictionary from fetch_query_tasks_from_nextcloud + vectordb_loader: Vector database loader instance + + Returns: + list of Search results + + Raises: + Various exceptions from query execution + """ + user_id = task['userId'] + task_input = task['input'] + if task_input.get('scopeType') == 'none': + task_input['scopeType'] = None + + return exec_in_proc(target=do_doc_search, + args=( + user_id, + task_input.get('prompt'), + vectordb_loader, + CONTEXT_LIMIT, + task_input.get('scopeType'), + task_input.get('scopeList'), + ) + ) + + +def start_bg_threads(app_config: TConfig, app_enabled: Event): + if APP_ROLE == AppRole.INDEXING or APP_ROLE == AppRole.NORMAL: + if ( + ThreadType.FILES_INDEXING in THREADS + or ThreadType.UPDATES_PROCESSING in THREADS + ): + LOGGER.info('Background threads already running, skipping start') + return + + THREAD_STOP_EVENT.clear() + THREADS[ThreadType.FILES_INDEXING] = Thread( + target=files_indexing_thread, + args=(app_config, app_enabled), + name='FilesIndexingThread', + ) + THREADS[ThreadType.UPDATES_PROCESSING] = Thread( + target=updates_processing_thread, + 
args=(app_config, app_enabled), + name='UpdatesProcessingThread', + ) + THREADS[ThreadType.FILES_INDEXING].start() + THREADS[ThreadType.UPDATES_PROCESSING].start() + + if APP_ROLE == AppRole.RP or APP_ROLE == AppRole.NORMAL: + if ThreadType.REQUEST_PROCESSING in THREADS: + LOGGER.info('Background threads already running, skipping start') + return + + THREAD_STOP_EVENT.clear() + THREADS[ThreadType.REQUEST_PROCESSING] = Thread( + target=request_processing_thread, + args=(app_config, app_enabled), + name='RequestProcessingThread', + ) + THREADS[ThreadType.REQUEST_PROCESSING].start() + + +def wait_for_bg_threads(): + if APP_ROLE == AppRole.INDEXING or APP_ROLE == AppRole.NORMAL: + if (ThreadType.FILES_INDEXING not in THREADS or ThreadType.UPDATES_PROCESSING not in THREADS): + return + + THREAD_STOP_EVENT.set() + THREADS[ThreadType.FILES_INDEXING].join() + THREADS[ThreadType.UPDATES_PROCESSING].join() + THREADS.pop(ThreadType.FILES_INDEXING) + THREADS.pop(ThreadType.UPDATES_PROCESSING) + + if APP_ROLE == AppRole.RP or APP_ROLE == AppRole.NORMAL: + if (ThreadType.REQUEST_PROCESSING not in THREADS): + return + + THREAD_STOP_EVENT.set() + THREADS[ThreadType.REQUEST_PROCESSING].join() + THREADS.pop(ThreadType.REQUEST_PROCESSING) diff --git a/context_chat_backend/types.py b/context_chat_backend/types.py index 500a97d0..59d2568f 100644 --- a/context_chat_backend/types.py +++ b/context_chat_backend/types.py @@ -2,7 +2,16 @@ # SPDX-FileCopyrightText: 2024 Nextcloud GmbH and Nextcloud contributors # SPDX-License-Identifier: AGPL-3.0-or-later # -from pydantic import BaseModel +import re +from collections.abc import Mapping +from enum import Enum +from io import BytesIO +from typing import Annotated, Literal, Self + +from pydantic import AfterValidator, BaseModel, Discriminator, computed_field, field_validator, model_validator + +from .mimetype_list import SUPPORTED_MIMETYPES +from .vectordb.types import UpdateAccessOp __all__ = [ 'DEFAULT_EM_MODEL_ALIAS', @@ -15,6 +24,65 @@ ] 
DEFAULT_EM_MODEL_ALIAS = 'em_model' +FILES_PROVIDER_ID = 'files__default' + + +def is_valid_source_id(source_id: str) -> bool: + # note the ":" in the item id part + return re.match(r'^[a-zA-Z0-9_-]+__[a-zA-Z0-9_-]+: [a-zA-Z0-9:-]+$', source_id) is not None + + +def is_valid_provider_id(provider_id: str) -> bool: + return re.match(r'^[a-zA-Z0-9_-]+__[a-zA-Z0-9_-]+$', provider_id) is not None + + +def _validate_source_ids(source_ids: list[str]) -> list[str]: + if ( + not isinstance(source_ids, list) + or not all(isinstance(sid, str) and sid.strip() != '' for sid in source_ids) + or len(source_ids) == 0 + ): + raise ValueError('sourceIds must be a non-empty list of non-empty strings') + return [sid.strip() for sid in source_ids] + + +def _validate_source_id(source_id: str) -> str: + return _validate_source_ids([source_id])[0] + + +def _validate_provider_id(provider_id: str) -> str: + if not isinstance(provider_id, str) or not is_valid_provider_id(provider_id): + raise ValueError('providerId must be a valid provider ID string') + return provider_id + + +def _validate_user_ids(user_ids: list[str]) -> list[str]: + if ( + not isinstance(user_ids, list) + or not all(isinstance(uid, str) and uid.strip() != '' for uid in user_ids) + or len(user_ids) == 0 + ): + raise ValueError('userIds must be a non-empty list of non-empty strings') + return [uid.strip() for uid in user_ids] + + +def _validate_user_id(user_id: str) -> str: + return _validate_user_ids([user_id])[0] + + +def _get_file_id_from_source_ref(source_ref: str) -> int: + ''' + source reference is in the format "FILES_PROVIDER_ID: ". 
+ ''' + if not source_ref.startswith(f'{FILES_PROVIDER_ID}: '): + raise ValueError(f'Source reference does not start with expected prefix: {source_ref}') + + try: + return int(source_ref[len(f'{FILES_PROVIDER_ID}: '):]) + except ValueError as e: + raise ValueError( + f'Invalid source reference format for extracting file_id: {source_ref}' + ) from e class TEmbeddingAuthApiKey(BaseModel): @@ -71,3 +139,209 @@ class FatalEmbeddingException(EmbeddingException): Either malformed request, authentication error, or other non-retryable error. """ + + +class AppRole(str, Enum): + NORMAL = 'normal' + INDEXING = 'indexing' + RP = 'rp' + + +class CommonSourceItem(BaseModel): + userIds: Annotated[list[str], AfterValidator(_validate_user_ids)] + # source_id of the form "appId__providerId: itemId" + reference: Annotated[str, AfterValidator(_validate_source_id)] + title: str + modified: int + type: str + provider: Annotated[str, AfterValidator(_validate_provider_id)] + size: float + + @field_validator('modified', mode='before') + @classmethod + def validate_modified(cls, v): + if isinstance(v, int): + return v + if isinstance(v, str): + try: + return int(v) + except ValueError as e: + raise ValueError(f'Invalid modified value: {v}') from e + raise ValueError(f'Invalid modified type: {type(v)}') + + @field_validator('reference', 'title', 'type', 'provider') + @classmethod + def validate_strings_non_empty(cls, v): + if not isinstance(v, str) or v.strip() == '': + raise ValueError('Must be a non-empty string') + return v.strip() + + @field_validator('size') + @classmethod + def validate_size(cls, v): + if isinstance(v, int | float) and v >= 0: + return float(v) + raise ValueError(f'Invalid size value: {v}, must be a non-negative number') + + @model_validator(mode='after') + def validate_type(self) -> Self: + if self.reference.startswith(FILES_PROVIDER_ID) and self.type not in SUPPORTED_MIMETYPES: + raise ValueError(f'Unsupported file type: {self.type} for reference {self.reference}') 
+ return self + + +class ReceivedFileItem(CommonSourceItem): + content: None + + @computed_field + @property + def file_id(self) -> int: + return _get_file_id_from_source_ref(self.reference) + + +class SourceItem(CommonSourceItem): + ''' + Used for the unified queue of items to process, after fetching the content for files + and for directly fetched content providers. + ''' + content: str | BytesIO + + @field_validator('content') + @classmethod + def validate_content(cls, v): + if isinstance(v, str): + if v.strip() == '': + raise ValueError('Content must be a non-empty string') + return v.strip() + if isinstance(v, BytesIO): + if v.getbuffer().nbytes == 0: + raise ValueError('Content must be a non-empty BytesIO') + return v + raise ValueError('Content must be either a non-empty string or a non-empty BytesIO') + + class Config: + # to allow BytesIO in content field + arbitrary_types_allowed = True + + +class FilesQueueItems(BaseModel): + files: Mapping[int, ReceivedFileItem] # [db id]: FileItem + content_providers: Mapping[int, SourceItem] # [db id]: SourceItem + + +class IndexingException(Exception): + retryable: bool = False + + def __init__(self, message: str, retryable: bool = False): + super().__init__(message) + self.retryable = retryable + + +class IndexingError(BaseModel): + error: str + retryable: bool = False + + +# PHP equivalent for reference: + +# class ActionType { +# // { sourceIds: array } +# public const DELETE_SOURCE_IDS = 'delete_source_ids'; +# // { providerId: string } +# public const DELETE_PROVIDER_ID = 'delete_provider_id'; +# // { userId: string } +# public const DELETE_USER_ID = 'delete_user_id'; +# // { op: string, userIds: array, sourceId: string } +# public const UPDATE_ACCESS_SOURCE_ID = 'update_access_source_id'; +# // { op: string, userIds: array, providerId: string } +# public const UPDATE_ACCESS_PROVIDER_ID = 'update_access_provider_id'; +# // { userIds: array, sourceId: string } +# public const UPDATE_ACCESS_DECL_SOURCE_ID = 
'update_access_decl_source_id'; +# } + + +class ActionPayloadDeleteSourceIds(BaseModel): + sourceIds: Annotated[list[str], AfterValidator(_validate_source_ids)] + + +class ActionPayloadDeleteProviderId(BaseModel): + providerId: Annotated[str, AfterValidator(_validate_provider_id)] + + +class ActionPayloadDeleteUserId(BaseModel): + userId: Annotated[str, AfterValidator(_validate_user_id)] + + +class ActionPayloadUpdateAccessSourceId(BaseModel): + op: UpdateAccessOp + userIds: Annotated[list[str], AfterValidator(_validate_user_ids)] + sourceId: Annotated[str, AfterValidator(_validate_source_id)] + + +class ActionPayloadUpdateAccessProviderId(BaseModel): + op: UpdateAccessOp + userIds: Annotated[list[str], AfterValidator(_validate_user_ids)] + providerId: Annotated[str, AfterValidator(_validate_provider_id)] + + +class ActionPayloadUpdateAccessDeclSourceId(BaseModel): + userIds: Annotated[list[str], AfterValidator(_validate_user_ids)] + sourceId: Annotated[str, AfterValidator(_validate_source_id)] + + +class ActionType(str, Enum): + DELETE_SOURCE_IDS = 'delete_source_ids' + DELETE_PROVIDER_ID = 'delete_provider_id' + DELETE_USER_ID = 'delete_user_id' + UPDATE_ACCESS_SOURCE_ID = 'update_access_source_id' + UPDATE_ACCESS_PROVIDER_ID = 'update_access_provider_id' + UPDATE_ACCESS_DECL_SOURCE_ID = 'update_access_decl_source_id' + + +class CommonActionsQueueItem(BaseModel): + id: int + + +class ActionsQueueItemDeleteSourceIds(CommonActionsQueueItem): + type: Literal[ActionType.DELETE_SOURCE_IDS] + payload: ActionPayloadDeleteSourceIds + + +class ActionsQueueItemDeleteProviderId(CommonActionsQueueItem): + type: Literal[ActionType.DELETE_PROVIDER_ID] + payload: ActionPayloadDeleteProviderId + + +class ActionsQueueItemDeleteUserId(CommonActionsQueueItem): + type: Literal[ActionType.DELETE_USER_ID] + payload: ActionPayloadDeleteUserId + + +class ActionsQueueItemUpdateAccessSourceId(CommonActionsQueueItem): + type: Literal[ActionType.UPDATE_ACCESS_SOURCE_ID] + payload: 
ActionPayloadUpdateAccessSourceId + + +class ActionsQueueItemUpdateAccessProviderId(CommonActionsQueueItem): + type: Literal[ActionType.UPDATE_ACCESS_PROVIDER_ID] + payload: ActionPayloadUpdateAccessProviderId + + +class ActionsQueueItemUpdateAccessDeclSourceId(CommonActionsQueueItem): + type: Literal[ActionType.UPDATE_ACCESS_DECL_SOURCE_ID] + payload: ActionPayloadUpdateAccessDeclSourceId + + +ActionsQueueItem = Annotated[ + ActionsQueueItemDeleteSourceIds + | ActionsQueueItemDeleteProviderId + | ActionsQueueItemDeleteUserId + | ActionsQueueItemUpdateAccessSourceId + | ActionsQueueItemUpdateAccessProviderId + | ActionsQueueItemUpdateAccessDeclSourceId, + Discriminator('type'), +] + + +class ActionsQueueItems(BaseModel): + actions: Mapping[int, ActionsQueueItem] diff --git a/context_chat_backend/utils.py b/context_chat_backend/utils.py index f6d6e672..4552e320 100644 --- a/context_chat_backend/utils.py +++ b/context_chat_backend/utils.py @@ -2,11 +2,15 @@ # SPDX-FileCopyrightText: 2023 Nextcloud GmbH and Nextcloud contributors # SPDX-License-Identifier: AGPL-3.0-or-later # +import faulthandler +import io import logging import multiprocessing as mp -import re +import os +import sys import traceback from collections.abc import Callable +from contextlib import suppress from functools import partial, wraps from multiprocessing.connection import Connection from time import perf_counter_ns @@ -14,10 +18,11 @@ from fastapi.responses import JSONResponse as FastAPIJSONResponse -from .types import TConfig, TEmbeddingAuthApiKey, TEmbeddingAuthBasic, TEmbeddingConfig +from .types import AppRole, TConfig, TEmbeddingAuthApiKey, TEmbeddingAuthBasic, TEmbeddingConfig T = TypeVar('T') _logger = logging.getLogger('ccb.utils') +_MAX_STD_CAPTURE_CHARS = 64 * 1024 def not_none(value: T | None) -> TypeGuard[T]: @@ -69,19 +74,98 @@ def JSONResponse( return FastAPIJSONResponse(content, status_code, **kwargs) -def exception_wrap(fun: Callable | None, *args, resconn: Connection, **kwargs): 
+class SubprocessKilledError(RuntimeError): + """Raised when a subprocess is terminated by a signal (for example SIGKILL).""" + + def __init__(self, pid: int | None, target_name: str, exitcode: int): + super().__init__( + f'Subprocess PID {pid} for {target_name} exited with signal {abs(exitcode)} ' + f'(raw exit code: {exitcode})' + ) + self.exitcode = exitcode + + +class SubprocessExecutionError(RuntimeError): + """Raised when a subprocess exits without a recoverable Python exception payload.""" + + def __init__(self, pid: int | None, target_name: str, exitcode: int, details: str = ''): + msg = f'Subprocess PID {pid} for {target_name} exited with exit code {exitcode}' + if details: + msg = f'{msg}: {details}' + super().__init__(msg) + self.exitcode = exitcode + + +def _truncate_capture(text: str) -> str: + if len(text) <= _MAX_STD_CAPTURE_CHARS: + return text + + head = _MAX_STD_CAPTURE_CHARS // 2 + tail = _MAX_STD_CAPTURE_CHARS - head + omitted = len(text) - _MAX_STD_CAPTURE_CHARS + return ( + f'[truncated {omitted} chars]\n' + f'{text[:head]}\n' + '[...snip...]\n' + f'{text[-tail:]}' + ) + + +def exception_wrap(fun: Callable | None, *args, resconn: Connection, stdconn: Connection, **kwargs): + # Preserve real stderr FD for faulthandler before we redirect sys.stderr. 
+ _faulthandler_fd = os.dup(2) + with suppress(Exception): + faulthandler.enable( + file=os.fdopen(_faulthandler_fd, 'w', closefd=False), + all_threads=True, + ) + + stdout_capture = io.StringIO() + stderr_capture = io.StringIO() + orig_stdout = sys.stdout + orig_stderr = sys.stderr + sys.stdout = stdout_capture + sys.stderr = stderr_capture + try: if fun is None: - return resconn.send({ 'value': None, 'error': None }) - resconn.send({ 'value': fun(*args, **kwargs), 'error': None }) - except Exception as e: + resconn.send({ 'value': None, 'error': None }) + else: + resconn.send({ 'value': fun(*args, **kwargs), 'error': None }) + except BaseException as e: tb = traceback.format_exc() - resconn.send({ 'value': None, 'error': e, 'traceback': tb }) - - -def exec_in_proc(group=None, target=None, name=None, args=(), kwargs={}, *, daemon=None): # noqa: B006 + payload = { + 'value': None, + 'error': e, + 'traceback': tb, + } + try: + resconn.send(payload) + except Exception as send_err: + stderr_capture.write(f'Original error: {e}, pipe send error: {send_err}') + finally: + sys.stdout = orig_stdout + sys.stderr = orig_stderr + stdout_text = _truncate_capture(stdout_capture.getvalue()) + stderr_text = _truncate_capture(stderr_capture.getvalue()) + with suppress(Exception): + stdconn.send({ + 'stdout': stdout_text, + 'stderr': stderr_text, + }) + with suppress(Exception): + os.close(_faulthandler_fd) + + +def exec_in_proc(group=None, target=None, name=None, args=(), kwargs=None, *, daemon=None): + if not kwargs: + kwargs = {} + + # parent, child pconn, cconn = mp.Pipe() + std_pconn, std_cconn = mp.Pipe() kwargs['resconn'] = cconn + kwargs['stdconn'] = std_cconn p = mp.Process( group=group, target=partial(exception_wrap, target), @@ -90,24 +174,92 @@ def exec_in_proc(group=None, target=None, name=None, args=(), kwargs={}, *, daem kwargs=kwargs, daemon=daemon, ) + target_name = getattr(target, '__name__', str(target)) + start = perf_counter_ns() p.start() + 
_logger.debug('Subprocess PID %d started for %s', p.pid, target_name) + + result = None + stdobj = { 'stdout': '', 'stderr': '' } + got_result = False + got_std = False + + # Drain result/std pipes while child is still alive to avoid deadlock on full pipe buffers. + # Pipe's buffer size is 64 KiB + while p.is_alive() and (not got_result or not got_std): + if not got_result and pconn.poll(0.1): + with suppress(EOFError, OSError, BrokenPipeError): + result = pconn.recv() + got_result = True + if not got_std and std_pconn.poll(): + with suppress(EOFError, OSError, BrokenPipeError): + stdobj = std_pconn.recv() + got_std = True + p.join() + elapsed_ms = (perf_counter_ns() - start) / 1e6 + _logger.debug( + 'Subprocess PID %d for %s finished in %.2f ms (exit code: %s)', + p.pid, target_name, elapsed_ms, p.exitcode, + ) - result = pconn.recv() - if result['error'] is not None: - _logger.error('original traceback: %s', result['traceback']) + if not got_std: + with suppress(EOFError, OSError, BrokenPipeError): + if std_pconn.poll(): + stdobj = std_pconn.recv() + # no need to update got_std here + if stdobj.get('stdout') or stdobj.get('stderr'): + _logger.info('std info for %s', target_name, extra={ + 'stdout': stdobj.get('stdout', ''), + 'stderr': stdobj.get('stderr', ''), + }) + + if not got_result: + with suppress(EOFError, OSError, BrokenPipeError): + if pconn.poll(): + result = pconn.recv() + # no need to update got_result here + + if result is not None and result.get('error') is not None: + _logger.error( + 'original traceback of %s (PID %d, exitcode: %s): %s', + target_name, + p.pid, + p.exitcode, + result.get('traceback', ''), + ) raise result['error'] - return result['value'] - - -def is_valid_source_id(source_id: str) -> bool: - # note the ":" in the item id part - return re.match(r'^[a-zA-Z0-9_-]+__[a-zA-Z0-9_-]+: [a-zA-Z0-9:-]+$', source_id) is not None + if result is not None and 'value' in result: + if p.exitcode not in (None, 0): + _logger.warning( + 
'Subprocess PID %d for %s exited with code %s after %.2f ms' + ' but returned a valid result', + p.pid, target_name, p.exitcode, elapsed_ms, + ) + return result['value'] + if p.exitcode and p.exitcode < 0: + _logger.warning( + 'Subprocess PID %d for %s exited due to signal %d, exitcode %d after %.2f ms', + p.pid, target_name, abs(p.exitcode), p.exitcode, elapsed_ms, + ) + raise SubprocessKilledError(p.pid, target_name, p.exitcode) + + if p.exitcode not in (None, 0): + raise SubprocessExecutionError( + p.pid, + target_name, + p.exitcode, + f'No structured exception payload received from child process: {result}', + ) -def is_valid_provider_id(provider_id: str) -> bool: - return re.match(r'^[a-zA-Z0-9_-]+__[a-zA-Z0-9_-]+$', provider_id) is not None + raise SubprocessExecutionError( + p.pid, + target_name, + 0, + f'Subprocess exited successfully but returned no result payload: {result}', + ) def timed(func: Callable): @@ -144,3 +296,13 @@ def redact_config(config: TConfig | TEmbeddingConfig) -> TConfig | TEmbeddingCon em_conf.auth.password = '***REDACTED***' # noqa: S105 return config_copy + + +def get_app_role() -> AppRole: + role = os.getenv('APP_ROLE', '').lower() + if role == '': + return AppRole.NORMAL + if role not in ['indexing', 'rp']: + _logger.warning(f'Invalid app role: {role}, defaulting to all roles') + return AppRole.NORMAL + return AppRole(role) diff --git a/context_chat_backend/vectordb/base.py b/context_chat_backend/vectordb/base.py index 0bf10200..2b4aa35e 100644 --- a/context_chat_backend/vectordb/base.py +++ b/context_chat_backend/vectordb/base.py @@ -3,14 +3,15 @@ # SPDX-License-Identifier: AGPL-3.0-or-later # from abc import ABC, abstractmethod +from collections.abc import Mapping from typing import Any -from fastapi import UploadFile from langchain.schema import Document from langchain.schema.embeddings import Embeddings from langchain.schema.vectorstore import VectorStore from ..chain.types import InDocument, ScopeType +from ..types import 
IndexingError, ReceivedFileItem, SourceItem from ..utils import timed from .types import UpdateAccessOp @@ -62,7 +63,7 @@ def get_instance(self) -> VectorStore: ''' @abstractmethod - def add_indocuments(self, indocuments: list[InDocument]) -> tuple[list[str],list[str]]: + def add_indocuments(self, indocuments: Mapping[int, InDocument]) -> Mapping[int, IndexingError | None]: ''' Adds the given indocuments to the vectordb and updates the docs + access tables. @@ -79,10 +80,7 @@ def add_indocuments(self, indocuments: list[InDocument]) -> tuple[list[str],list @timed @abstractmethod - def check_sources( - self, - sources: list[UploadFile], - ) -> tuple[list[str], list[str]]: + def check_sources(self, sources: Mapping[int, SourceItem | ReceivedFileItem]) -> tuple[list[str], list[str]]: ''' Checks the sources in the vectordb if they are already embedded and are up to date. @@ -91,8 +89,8 @@ def check_sources( Args ---- - sources: list[UploadFile] - List of source ids to check. + sources: Mapping[int, SourceItem | ReceivedFileItem] + Dict of sources to check. 
Returns ------- diff --git a/context_chat_backend/vectordb/pgvector.py b/context_chat_backend/vectordb/pgvector.py index 2b7fc060..d7b718dc 100644 --- a/context_chat_backend/vectordb/pgvector.py +++ b/context_chat_backend/vectordb/pgvector.py @@ -4,21 +4,29 @@ # import logging import os +from collections.abc import Mapping from datetime import datetime +from time import perf_counter_ns import psycopg import sqlalchemy as sa import sqlalchemy.dialects.postgresql as postgresql_dialects import sqlalchemy.orm as orm from dotenv import load_dotenv -from fastapi import UploadFile from langchain.schema import Document from langchain.vectorstores import VectorStore from langchain_core.embeddings import Embeddings from langchain_postgres.vectorstores import Base, PGVector from ..chain.types import InDocument, ScopeType -from ..types import EmbeddingException, RetryableEmbeddingException +from ..types import ( + EmbeddingException, + FatalEmbeddingException, + IndexingError, + ReceivedFileItem, + RetryableEmbeddingException, + SourceItem, +) from ..utils import timed from .base import BaseVectorDB from .types import DbException, SafeDbException, UpdateAccessOp @@ -112,7 +120,15 @@ def __init__(self, embedding: Embeddings | None = None, **kwargs): kwargs['connection'] = os.environ['CCB_DB_URL'] # setup langchain db + our access list table - self.client = PGVector(embedding, collection_name=COLLECTION_NAME, **kwargs) + try: + self.client = PGVector(embedding, collection_name=COLLECTION_NAME, **kwargs) + except sa.exc.IntegrityError as ie: # pyright: ignore[reportAttributeAccessIssue] + if not isinstance(ie.orig, psycopg.errors.UniqueViolation): + raise + + # tried to create the tables but it was already created in another process + # init the client again to detect it already exists, and continue from there + self.client = PGVector(embedding, collection_name=COLLECTION_NAME, **kwargs) def get_instance(self) -> VectorStore: return self.client @@ -130,24 +146,40 @@ def 
get_users(self) -> list[str]: except Exception as e: raise DbException('Error: getting a list of all users from access list') from e - def add_indocuments(self, indocuments: list[InDocument]) -> tuple[list[str], list[str]]: + def add_indocuments(self, indocuments: Mapping[int, InDocument]) -> Mapping[int, IndexingError | None]: """ Raises EmbeddingException: if the embedding request definitively fails """ - added_sources = [] - retry_sources = [] + results = {} batch_size = PG_BATCH_SIZE // 5 with self.session_maker() as session: - for indoc in indocuments: + for php_db_id, indoc in indocuments.items(): try: # query paramerters limitation in postgres is 65535 (https://www.postgresql.org/docs/current/limits.html) # so we chunk the documents into (5 values * 10k) chunks # change the chunk size when there are more inserted values per document chunk_ids = [] - for i in range(0, len(indoc.documents), batch_size): + total_chunks = len(indoc.documents) + num_batches = max(1, -(-total_chunks // batch_size)) # ceiling division + logger.debug( + 'Embedding source %s: %d chunk(s) in %d batch(es)', + indoc.source_id, total_chunks, num_batches, + ) + for i in range(0, total_chunks, batch_size): + batch_num = i // batch_size + 1 + logger.debug( + 'Sending embedding batch %d/%d (%d chunk(s)) for source %s', + batch_num, num_batches, len(indoc.documents[i:i+batch_size]), indoc.source_id, + ) + t0 = perf_counter_ns() chunk_ids.extend(self.client.add_documents(indoc.documents[i:i+batch_size])) + elapsed_ms = (perf_counter_ns() - t0) / 1e6 + logger.debug( + 'Embedding batch %d/%d for source %s completed in %.2f ms', + batch_num, num_batches, indoc.source_id, elapsed_ms, + ) doc = DocumentsStore( source_id=indoc.source_id, @@ -170,7 +202,7 @@ def add_indocuments(self, indocuments: list[InDocument]) -> tuple[list[str], lis ) self.decl_update_access(indoc.userIds, indoc.source_id, session) - added_sources.append(indoc.source_id) + results[php_db_id] = None session.commit() except 
SafeDbException as e: # for when the source_id is not found. This here can be an error in the DB @@ -178,51 +210,62 @@ def add_indocuments(self, indocuments: list[InDocument]) -> tuple[list[str], lis logger.exception('Error adding documents to vectordb', exc_info=e, extra={ 'source_id': indoc.source_id, }) - retry_sources.append(indoc.source_id) + results[php_db_id] = IndexingError( + error=str(e), + retryable=True, + ) continue - except RetryableEmbeddingException as e: + except FatalEmbeddingException as e: + raise EmbeddingException( + f'Fatal error while embedding documents for source {indoc.source_id}: {e}' + ) from e + except (RetryableEmbeddingException, EmbeddingException) as e: # temporary error, continue with the next document logger.exception('Error adding documents to vectordb, should be retried later.', exc_info=e, extra={ 'source_id': indoc.source_id, }) - retry_sources.append(indoc.source_id) + results[php_db_id] = IndexingError( + error=str(e), + retryable=True, + ) continue - except EmbeddingException as e: - logger.exception('Error adding documents to vectordb', exc_info=e, extra={ - 'source_id': indoc.source_id, - }) - raise except Exception as e: logger.exception('Error adding documents to vectordb', exc_info=e, extra={ 'source_id': indoc.source_id, }) - retry_sources.append(indoc.source_id) + results[php_db_id] = IndexingError( + error='An unexpected error occurred while adding documents to the database.', + retryable=True, + ) continue - return added_sources, retry_sources + return results @timed - def check_sources(self, sources: list[UploadFile]) -> tuple[list[str], list[str]]: + def check_sources(self, sources: Mapping[int, SourceItem | ReceivedFileItem]) -> tuple[list[str], list[str]]: + ''' + returns a tuple of (existing_source_ids, to_embed_source_ids) + ''' with self.session_maker() as session: try: stmt = ( sa.select(DocumentsStore.source_id) - .filter(DocumentsStore.source_id.in_([source.filename for source in sources])) + 
.filter(DocumentsStore.source_id.in_([source.reference for source in sources.values()])) .with_for_update() ) results = session.execute(stmt).fetchall() existing_sources = {r.source_id for r in results} - to_embed = [source.filename for source in sources if source.filename not in existing_sources] + to_embed = [source.reference for source in sources.values() if source.reference not in existing_sources] to_delete = [] - for source in sources: + for source in sources.values(): stmt = ( sa.select(DocumentsStore.source_id) - .filter(DocumentsStore.source_id == source.filename) + .filter(DocumentsStore.source_id == source.reference) .filter(DocumentsStore.modified < sa.cast( - datetime.fromtimestamp(int(source.headers['modified'])), + datetime.fromtimestamp(int(source.modified)), sa.DateTime, )) ) @@ -239,14 +282,13 @@ def check_sources(self, sources: list[UploadFile]) -> tuple[list[str], list[str] session.rollback() raise DbException('Error: checking sources in vectordb') from e - still_existing_sources = [ - source - for source in existing_sources - if source not in to_delete + still_existing_source_ids = [ + source_id + for source_id in existing_sources + if source_id not in to_delete ] - # the pyright issue stems from source.filename, which has already been validated - return list(still_existing_sources), to_embed # pyright: ignore[reportReturnType] + return list(still_existing_source_ids), to_embed def decl_update_access(self, user_ids: list[str], source_id: str, session_: orm.Session | None = None): session = session_ or self.session_maker() @@ -325,7 +367,7 @@ def update_access( ) match op: - case UpdateAccessOp.allow: + case UpdateAccessOp.ALLOW: for i in range(0, len(user_ids), PG_BATCH_SIZE): batched_uids = user_ids[i:i+PG_BATCH_SIZE] stmt = ( @@ -342,7 +384,7 @@ def update_access( session.execute(stmt) session.commit() - case UpdateAccessOp.deny: + case UpdateAccessOp.DENY: for i in range(0, len(user_ids), PG_BATCH_SIZE): batched_uids = 
user_ids[i:i+PG_BATCH_SIZE] stmt = ( @@ -435,15 +477,17 @@ def delete_source_ids(self, source_ids: list[str], session_: orm.Session | None # entry from "AccessListStore" is deleted automatically due to the foreign key constraint # batch the deletion to avoid hitting the query parameter limit chunks_to_delete = [] + deleted_source_ids = [] for i in range(0, len(source_ids), PG_BATCH_SIZE): batched_ids = source_ids[i:i+PG_BATCH_SIZE] stmt_doc = ( sa.delete(DocumentsStore) .filter(DocumentsStore.source_id.in_(batched_ids)) - .returning(DocumentsStore.chunks) + .returning(DocumentsStore.chunks, DocumentsStore.source_id) ) doc_result = session.execute(stmt_doc) chunks_to_delete.extend(str(c) for res in doc_result for c in res.chunks) + deleted_source_ids.extend(str(res.source_id) for res in doc_result) for i in range(0, len(chunks_to_delete), PG_BATCH_SIZE): batched_chunks = chunks_to_delete[i:i+PG_BATCH_SIZE] @@ -463,6 +507,14 @@ def delete_source_ids(self, source_ids: list[str], session_: orm.Session | None if session_ is None: session.close() + undeleted_source_ids = set(source_ids) - set(deleted_source_ids) + if len(undeleted_source_ids) > 0: + logger.info( + f'Source ids {undeleted_source_ids} were not deleted from documents store.' + ' This can be due to the source ids not existing in the documents store due to' + ' already being deleted or not having been added yet.' 
+ ) + def delete_provider(self, provider_key: str): with self.session_maker() as session: try: @@ -506,7 +558,16 @@ def delete_user(self, user_id: str): session.rollback() raise DbException('Error: deleting user from access list') from e - self._cleanup_if_orphaned(list(source_ids), session) + try: + self._cleanup_if_orphaned(list(source_ids), session) + except Exception as e: + session.rollback() + logger.error( + 'Error cleaning up orphaned source ids after deleting user, manual cleanup might be required', + exc_info=e, + extra={ 'source_ids': list(source_ids) }, + ) + raise DbException('Error: cleaning up orphaned source ids after deleting user') from e def count_documents_by_provider(self) -> dict[str, int]: try: diff --git a/context_chat_backend/vectordb/service.py b/context_chat_backend/vectordb/service.py index 620a0b39..06a8e19e 100644 --- a/context_chat_backend/vectordb/service.py +++ b/context_chat_backend/vectordb/service.py @@ -6,27 +6,42 @@ from ..dyn_loader import VectorDBLoader from .base import BaseVectorDB -from .types import DbException, UpdateAccessOp +from .types import UpdateAccessOp logger = logging.getLogger('ccb.vectordb') -# todo: return source ids that were successfully deleted + def delete_by_source(vectordb_loader: VectorDBLoader, source_ids: list[str]): + ''' + Raises + ------ + DbException + LoaderException + ''' db: BaseVectorDB = vectordb_loader.load() logger.debug('deleting sources by id', extra={ 'source_ids': source_ids }) - try: - db.delete_source_ids(source_ids) - except Exception as e: - raise DbException('Error: Vectordb delete_source_ids error') from e + db.delete_source_ids(source_ids) def delete_by_provider(vectordb_loader: VectorDBLoader, provider_key: str): + ''' + Raises + ------ + DbException + LoaderException + ''' db: BaseVectorDB = vectordb_loader.load() logger.debug(f'deleting sources by provider: {provider_key}') db.delete_provider(provider_key) def delete_user(vectordb_loader: VectorDBLoader, user_id: str): + ''' 
+ Raises + ------ + DbException + LoaderException + ''' db: BaseVectorDB = vectordb_loader.load() logger.debug(f'deleting user from db: {user_id}') db.delete_user(user_id) @@ -38,6 +53,13 @@ def update_access( user_ids: list[str], source_id: str, ): + ''' + Raises + ------ + DbException + LoaderException + SafeDbException + ''' db: BaseVectorDB = vectordb_loader.load() logger.debug('updating access', extra={ 'op': op, 'user_ids': user_ids, 'source_id': source_id }) db.update_access(op, user_ids, source_id) @@ -49,6 +71,13 @@ def update_access_provider( user_ids: list[str], provider_id: str, ): + ''' + Raises + ------ + DbException + LoaderException + SafeDbException + ''' db: BaseVectorDB = vectordb_loader.load() logger.debug('updating access by provider', extra={ 'op': op, 'user_ids': user_ids, 'provider_id': provider_id }) db.update_access_provider(op, user_ids, provider_id) @@ -59,11 +88,24 @@ def decl_update_access( user_ids: list[str], source_id: str, ): + ''' + Raises + ------ + DbException + LoaderException + SafeDbException + ''' db: BaseVectorDB = vectordb_loader.load() logger.debug('decl update access', extra={ 'user_ids': user_ids, 'source_id': source_id }) db.decl_update_access(user_ids, source_id) def count_documents_by_provider(vectordb_loader: VectorDBLoader): + ''' + Raises + ------ + DbException + LoaderException + ''' db: BaseVectorDB = vectordb_loader.load() logger.debug('counting documents by provider') return db.count_documents_by_provider() diff --git a/context_chat_backend/vectordb/types.py b/context_chat_backend/vectordb/types.py index df5c6dd7..30811797 100644 --- a/context_chat_backend/vectordb/types.py +++ b/context_chat_backend/vectordb/types.py @@ -14,5 +14,5 @@ class SafeDbException(Exception): class UpdateAccessOp(Enum): - allow = 'allow' - deny = 'deny' + ALLOW = 'allow' + DENY = 'deny' diff --git a/main.py b/main.py index c4ffa1fd..076b7db0 100755 --- a/main.py +++ b/main.py @@ -3,9 +3,12 @@ # SPDX-FileCopyrightText: 2023 Nextcloud 
GmbH and Nextcloud contributors # SPDX-License-Identifier: AGPL-3.0-or-later # + import logging -from os import getenv +import multiprocessing as mp +from os import cpu_count, getenv +import psutil import uvicorn from nc_py_api.ex_app import run_app @@ -48,6 +51,18 @@ def _setup_log_levels(debug: bool): app_config: TConfig = app.extra['CONFIG'] _setup_log_levels(app_config.debug) + # do forks from a clean process that doesn't have any threads or locks + mp.set_start_method('forkserver') + mp.set_forkserver_preload([ + 'context_chat_backend.chain.ingest.injest', + 'context_chat_backend.vectordb.pgvector', + 'langchain', + 'logging', + 'numpy', + 'sqlalchemy', + ]) + + print(f'CPU count: {cpu_count()}, Memory: {psutil.virtual_memory()}') print('App config:\n' + redact_config(app_config).model_dump_json(indent=2), flush=True) uv_log_config = uvicorn.config.LOGGING_CONFIG # pyright: ignore[reportAttributeAccessIssue]