diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml
index 10e2d61b..d30073ab 100644
--- a/.github/workflows/integration-test.yml
+++ b/.github/workflows/integration-test.yml
@@ -89,7 +89,7 @@ jobs:
POSTGRES_USER: root
POSTGRES_PASSWORD: rootpassword
POSTGRES_DB: nextcloud
- options: --health-cmd pg_isready --health-interval 5s --health-timeout 2s --health-retries 5
+ options: --health-cmd pg_isready --health-interval 5s --health-timeout 2s --health-retries 5 --name postgres --hostname postgres
steps:
- name: Checkout server
@@ -113,6 +113,8 @@ jobs:
repository: nextcloud/context_chat
path: apps/context_chat
persist-credentials: false
+      # TODO: drop this ref and use the default branch once feat/reverse-content-flow is merged into nextcloud/context_chat
+ ref: feat/reverse-content-flow
- name: Checkout backend
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
@@ -167,6 +169,10 @@ jobs:
cd ..
rm -rf documentation
+ - name: Run files scan
+ run: |
+ ./occ files:scan --all
+
- name: Setup python 3.11
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5
with:
@@ -195,28 +201,91 @@ jobs:
timeout 10 ./occ app_api:daemon:register --net host manual_install "Manual Install" manual-install http localhost http://localhost:8080
timeout 120 ./occ app_api:app:register context_chat_backend manual_install --json-info "{\"appid\":\"context_chat_backend\",\"name\":\"Context Chat Backend\",\"daemon_config_name\":\"manual_install\",\"version\":\"${{ fromJson(steps.appinfo.outputs.result).version }}\",\"secret\":\"12345\",\"port\":10034,\"scopes\":[],\"system_app\":0}" --force-scopes --wait-finish
ls -la context_chat_backend/persistent_storage/*
- sleep 30 # Wait for the em server to get ready
- - name: Scan files, baseline
- run: |
- ./occ files:scan admin
- ./occ context_chat:scan admin -m text/plain
-
- - name: Check python memory usage
+ - name: Initial memory usage check
run: |
ps -p $(cat pid.txt) -o pid,cmd,%mem,rss --sort=-%mem
ps -p $(cat pid.txt) -o %mem --no-headers > initial_mem.txt
- - name: Scan files
+ - name: Run cron jobs
run: |
- ./occ files:scan admin
- ./occ context_chat:scan admin -m text/markdown &
- ./occ context_chat:scan admin -m text/x-rst
+ # every 10 seconds indefinitely
+ while true; do
+ php cron.php
+ sleep 10
+ done &
+ sleep 30
+ # list all the bg jobs
+ ./occ background-job:list
+
+ - name: Initial dump of DB with context_chat_queue populated
+ run: |
+ docker exec postgres pg_dump nextcloud > /tmp/0_pgdump_nextcloud
- - name: Check python memory usage
+ - name: Periodically check context_chat stats for 15 minutes to allow the backend to index the files
run: |
- ps -p $(cat pid.txt) -o pid,cmd,%mem,rss --sort=-%mem
- ps -p $(cat pid.txt) -o %mem --no-headers > after_scan_mem.txt
+ success=0
+ echo "::group::Checking stats periodically for 15 minutes to allow the backend to index the files"
+ for i in {1..90}; do
+ echo "Checking stats, attempt $i..."
+
+ stats_err=$(mktemp)
+ stats=$(timeout 5 ./occ context_chat:stats --json 2>"$stats_err")
+ stats_exit=$?
+ echo "Stats output:"
+ echo "$stats"
+ if [ -s "$stats_err" ]; then
+ echo "Stderr:"
+ cat "$stats_err"
+ fi
+ echo "---"
+ rm -f "$stats_err"
+
+ # Check for critical errors in output
+ if [ $stats_exit -ne 0 ] || echo "$stats" | grep -q "Error during request"; then
+ echo "Backend connection error detected (exit=$stats_exit), retrying..."
+ sleep 10
+ continue
+ fi
+
+ # Extract total eligible files
+ total_eligible_files=$(echo "$stats" | jq '.eligible_files_count' || echo "")
+
+ # Extract indexed documents count (files__default)
+ indexed_count=$(echo "$stats" | jq '.vectordb_document_counts.files__default' || echo "")
+
+ echo "Total eligible files: $total_eligible_files"
+ echo "Indexed documents (files__default): $indexed_count"
+
+ diff=$((total_eligible_files - indexed_count))
+ threshold=$((total_eligible_files * 3 / 100))
+
+ # Check if difference is within tolerance
+ if [ $diff -le $threshold ]; then
+ echo "Indexing within 3% tolerance (diff=$diff, threshold=$threshold)"
+ success=1
+ break
+ else
+ progress=$((diff * 100 / total_eligible_files))
+ echo "Outside 3% tolerance: diff=$diff (${progress}%), threshold=$threshold"
+ fi
+
+ # Check if backend is still alive
+ ccb_alive=$(ps -p $(cat pid.txt) -o cmd= | grep -c "main.py" || echo "0")
+ if [ "$ccb_alive" -eq 0 ]; then
+ echo "Error: Context Chat Backend process is not running. Exiting."
+ exit 1
+ fi
+
+ sleep 10
+ done
+
+ echo "::endgroup::"
+
+ if [ $success -ne 1 ]; then
+ echo "Max attempts reached"
+ exit 1
+ fi
- name: Run the prompts
run: |
@@ -250,18 +319,9 @@ jobs:
echo "Memory usage during scan is stable. No memory leak detected."
fi
- - name: Compare memory usage and detect leak
+ - name: Final dump of DB with vectordb populated
run: |
- initial_mem=$(cat after_scan_mem.txt | tr -d ' ')
- final_mem=$(cat after_prompt_mem.txt | tr -d ' ')
- echo "Initial Memory Usage: $initial_mem%"
- echo "Memory Usage after prompt: $final_mem%"
-
- if (( $(echo "$final_mem > $initial_mem" | bc -l) )); then
- echo "Memory usage has increased during prompt. Possible memory leak detected!"
- else
- echo "Memory usage during prompt is stable. No memory leak detected."
- fi
+ docker exec postgres pg_dump nextcloud > /tmp/1_pgdump_nextcloud
- name: Show server logs
if: always()
@@ -298,6 +358,19 @@ jobs:
run: |
tail -v -n +1 context_chat_backend/persistent_storage/logs/em_server.log* || echo "No logs in logs directory"
+ - name: Upload database dumps
+ uses: actions/upload-artifact@v4
+ with:
+ name: database-dumps-${{ matrix.server-versions }}-php@${{ matrix.php-versions }}
+ path: |
+ /tmp/0_pgdump_nextcloud
+ /tmp/1_pgdump_nextcloud
+
+ - name: Final stats log
+ run: |
+ ./occ context_chat:stats
+ ./occ context_chat:stats --json
+
summary:
permissions:
contents: none
diff --git a/appinfo/info.xml b/appinfo/info.xml
index 9760cd29..30194baa 100644
--- a/appinfo/info.xml
+++ b/appinfo/info.xml
@@ -82,5 +82,19 @@ Setup background job workers as described here: https://docs.nextcloud.com/serve
Password to be used for authenticating requests to the OpenAI-compatible endpoint set in CC_EM_BASE_URL.
+
+
+ rp
+ Request Processing Mode
+ APP_ROLE=rp
+ true
+
+
+ indexing
+ Indexing Mode
+ APP_ROLE=indexing
+ false
+
+
diff --git a/context_chat_backend/chain/ingest/doc_loader.py b/context_chat_backend/chain/ingest/doc_loader.py
index efb81b6d..832c8331 100644
--- a/context_chat_backend/chain/ingest/doc_loader.py
+++ b/context_chat_backend/chain/ingest/doc_loader.py
@@ -3,15 +3,13 @@
# SPDX-License-Identifier: AGPL-3.0-or-later
#
-import logging
import re
import tempfile
from collections.abc import Callable
-from typing import BinaryIO
+from io import BytesIO
import docx2txt
from epub2txt import epub2txt
-from fastapi import UploadFile
from langchain_unstructured import UnstructuredLoader
from odfdo import Document
from pandas import read_csv, read_excel
@@ -19,9 +17,10 @@
from pypdf.errors import FileNotDecryptedError as PdfFileNotDecryptedError
from striprtf import striprtf
-logger = logging.getLogger('ccb.doc_loader')
+from ...types import IndexingException, SourceItem
-def _temp_file_wrapper(file: BinaryIO, loader: Callable, sep: str = '\n') -> str:
+
+def _temp_file_wrapper(file: BytesIO, loader: Callable, sep: str = '\n') -> str:
raw_bytes = file.read()
with tempfile.NamedTemporaryFile(mode='wb') as tmp:
tmp.write(raw_bytes)
@@ -35,49 +34,49 @@ def _temp_file_wrapper(file: BinaryIO, loader: Callable, sep: str = '\n') -> str
# -- LOADERS -- #
-def _load_pdf(file: BinaryIO) -> str:
+def _load_pdf(file: BytesIO) -> str:
pdf_reader = PdfReader(file)
return '\n\n'.join([page.extract_text().strip() for page in pdf_reader.pages])
-def _load_csv(file: BinaryIO) -> str:
+def _load_csv(file: BytesIO) -> str:
return read_csv(file).to_string(header=False, na_rep='')
-def _load_epub(file: BinaryIO) -> str:
+def _load_epub(file: BytesIO) -> str:
return _temp_file_wrapper(file, epub2txt).strip()
-def _load_docx(file: BinaryIO) -> str:
+def _load_docx(file: BytesIO) -> str:
return docx2txt.process(file).strip()
-def _load_odt(file: BinaryIO) -> str:
+def _load_odt(file: BytesIO) -> str:
return _temp_file_wrapper(file, lambda fp: Document(fp).get_formatted_text()).strip()
-def _load_ppt_x(file: BinaryIO) -> str:
+def _load_ppt_x(file: BytesIO) -> str:
return _temp_file_wrapper(file, lambda fp: UnstructuredLoader(fp).load()).strip()
-def _load_rtf(file: BinaryIO) -> str:
+def _load_rtf(file: BytesIO) -> str:
return striprtf.rtf_to_text(file.read().decode('utf-8', 'ignore')).strip()
-def _load_xml(file: BinaryIO) -> str:
+def _load_xml(file: BytesIO) -> str:
data = file.read().decode('utf-8', 'ignore')
 	data = re.sub(r'<\?xml.*\?>', '', data)
return data.strip()
-def _load_xlsx(file: BinaryIO) -> str:
+def _load_xlsx(file: BytesIO) -> str:
return read_excel(file, na_filter=False).to_string(header=False, na_rep='')
-def _load_email(file: BinaryIO, ext: str = 'eml') -> str | None:
+def _load_email(file: BytesIO, ext: str = 'eml') -> str:
# NOTE: msg format is not tested
if ext not in ['eml', 'msg']:
- return None
+ raise IndexingException(f'Unsupported email format: {ext}')
# TODO: implement attachment partitioner using unstructured.partition.partition_{email,msg}
# since langchain does not pass through the attachment_partitioner kwarg
@@ -115,30 +114,36 @@ def attachment_partitioner(
}
-def decode_source(source: UploadFile) -> str | None:
+def decode_source(source: SourceItem) -> str:
+ '''
+ Raises
+ ------
+ IndexingException
+ '''
+
+ io_obj: BytesIO | None = None
try:
# .pot files are powerpoint templates but also plain text files,
# so we skip them to prevent decoding errors
- if source.headers['title'].endswith('.pot'):
- return None
-
- mimetype = source.headers['type']
- if mimetype is None:
- return None
-
- if _loader_map.get(mimetype):
- result = _loader_map[mimetype](source.file)
- source.file.close()
- return result.encode('utf-8', 'ignore').decode('utf-8', 'ignore')
-
- result = source.file.read().decode('utf-8', 'ignore')
- source.file.close()
- return result
- except PdfFileNotDecryptedError:
- logger.warning(f'PDF file ({source.filename}) is encrypted and cannot be read')
- return None
- except Exception:
- logger.exception(f'Error decoding source file ({source.filename})', stack_info=True)
- return None
+ if source.title.endswith('.pot'):
+ raise IndexingException('PowerPoint template files (.pot) are not supported')
+
+ if isinstance(source.content, str):
+ io_obj = BytesIO(source.content.encode('utf-8', 'ignore'))
+ else:
+ io_obj = source.content
+
+ if _loader_map.get(source.type):
+ result = _loader_map[source.type](io_obj)
+ return result.encode('utf-8', 'ignore').decode('utf-8', 'ignore').strip()
+
+ return io_obj.read().decode('utf-8', 'ignore').strip()
+ except IndexingException:
+ raise
+ except PdfFileNotDecryptedError as e:
+ raise IndexingException('PDF file is encrypted and cannot be read') from e
+ except Exception as e:
+ raise IndexingException(f'Error decoding source file: {e}') from e
finally:
- source.file.close() # Ensure file is closed after processing
+ if io_obj is not None:
+ io_obj.close()
diff --git a/context_chat_backend/chain/ingest/injest.py b/context_chat_backend/chain/ingest/injest.py
index 5871ebb8..190eebd4 100644
--- a/context_chat_backend/chain/ingest/injest.py
+++ b/context_chat_backend/chain/ingest/injest.py
@@ -2,65 +2,241 @@
# SPDX-FileCopyrightText: 2023 Nextcloud GmbH and Nextcloud contributors
# SPDX-License-Identifier: AGPL-3.0-or-later
#
+import asyncio
import logging
import re
+from collections.abc import Mapping
+from io import BytesIO
+from time import perf_counter_ns
-from fastapi.datastructures import UploadFile
+import niquests
from langchain.schema import Document
+from nc_py_api import AsyncNextcloudApp
from ...dyn_loader import VectorDBLoader
-from ...types import TConfig
-from ...utils import is_valid_source_id, to_int
+from ...types import IndexingError, IndexingException, ReceivedFileItem, SourceItem, TConfig
from ...vectordb.base import BaseVectorDB
from ...vectordb.types import DbException, SafeDbException, UpdateAccessOp
from ..types import InDocument
from .doc_loader import decode_source
from .doc_splitter import get_splitter_for
-from .mimetype_list import SUPPORTED_MIMETYPES
logger = logging.getLogger('ccb.injest')
-def _allowed_file(file: UploadFile) -> bool:
- return file.headers['type'] in SUPPORTED_MIMETYPES
+# max concurrent fetches to avoid overloading the NC server or hitting rate limits
+CONCURRENT_FILE_FETCHES = 10 # todo: config?
+MAX_FILE_SIZE = 100 * 1024 * 1024 # 100 MB, all loaded in RAM at once, todo: config?
+
+
+async def __fetch_file_content(
+ semaphore: asyncio.Semaphore,
+ file_id: int,
+ user_id: str,
+ _rlimit = 3,
+) -> BytesIO:
+ '''
+ Raises
+ ------
+ IndexingException
+ '''
+
+ async with semaphore:
+ nc = AsyncNextcloudApp()
+ try:
+ # a file pointer for storing the stream in memory until it is consumed
+ fp = BytesIO()
+ await nc._session.download2fp(
+ url_path=f'/ocs/v2.php/apps/context_chat/files/{file_id}',
+ fp=fp,
+ dav=False,
+ params={ 'userId': user_id },
+ )
+ fp.seek(0)
+ return fp
+ except niquests.exceptions.RequestException as e:
+ if e.response is None:
+ raise
+
+ if e.response.status_code == niquests.codes.too_many_requests: # pyright: ignore[reportAttributeAccessIssue]
+ # todo: implement rate limits in php CC?
+ wait_for = int(e.response.headers.get('Retry-After', '30'))
+ if _rlimit <= 0:
+ raise IndexingException(
+ f'Rate limited when fetching content for file id {file_id}, user id {user_id},'
+ ' max retries exceeded',
+ retryable=True,
+ ) from e
+ logger.warning(
+ f'Rate limited when fetching content for file id {file_id}, user id {user_id},'
+ f' waiting {wait_for} before retrying',
+ exc_info=e,
+ )
+ await asyncio.sleep(wait_for)
+ return await __fetch_file_content(semaphore, file_id, user_id, _rlimit - 1)
+
+ raise
+ except IndexingException:
+ raise
+ except Exception as e:
+ logger.error(f'Error fetching content for file id {file_id}, user id {user_id}: {e}', exc_info=e)
+ raise IndexingException(f'Error fetching content for file id {file_id}, user id {user_id}: {e}') from e
+
+
+async def __fetch_files_content(
+ sources: Mapping[int, SourceItem | ReceivedFileItem]
+) -> tuple[Mapping[int, SourceItem], Mapping[int, IndexingError]]:
+ source_items = {}
+ error_items = {}
+ semaphore = asyncio.Semaphore(CONCURRENT_FILE_FETCHES)
+ tasks = []
+ task_sources = {}
+
+ file_count = sum(1 for s in sources.values() if isinstance(s, ReceivedFileItem))
+ logger.debug('Fetching content for %d file(s) (max %d concurrent)', file_count, CONCURRENT_FILE_FETCHES)
+
+ for db_id, file in sources.items():
+ if isinstance(file, SourceItem):
+ continue
+
+ try:
+			# touch file_id to surface validation errors early; it should not raise, since file.reference is validated upstream
+ file.file_id # noqa: B018
+ except ValueError as e:
+ logger.error(
+ f'Invalid file reference format for db id {db_id}, file reference {file.reference}: {e}',
+ exc_info=e,
+ )
+ error_items[db_id] = IndexingError(
+ error=f'Invalid file reference format: {file.reference}',
+ retryable=False,
+ )
+ continue
+
+ if file.size > MAX_FILE_SIZE:
+ logger.info(
+ f'Skipping db id {db_id}, file id {file.file_id}, source id {file.reference} due to size'
+ f' {(file.size/(1024*1024)):.2f} MiB exceeding the limit {(MAX_FILE_SIZE/(1024*1024)):.2f} MiB',
+ )
+ error_items[db_id] = IndexingError(
+ error=(
+ f'File size {(file.size/(1024*1024)):.2f} MiB'
+ f' exceeds the limit {(MAX_FILE_SIZE/(1024*1024)):.2f} MiB'
+ ),
+ retryable=False,
+ )
+ continue
+ # any user id from the list should have read access to the file
+ tasks.append(asyncio.ensure_future(__fetch_file_content(semaphore, file.file_id, file.userIds[0])))
+ task_sources[db_id] = file
+
+ results = await asyncio.gather(*tasks, return_exceptions=True)
+ for (db_id, file), result in zip(task_sources.items(), results, strict=True):
+ if isinstance(result, str) or isinstance(result, BytesIO):
+ source_items[db_id] = SourceItem(
+ **{
+ **file.model_dump(),
+ 'content': result,
+ }
+ )
+ elif isinstance(result, IndexingException):
+ logger.error(
+ f'Error fetching content for db id {db_id}, file id {file.file_id}, reference {file.reference}'
+ f': {result}',
+ exc_info=result,
+ )
+ error_items[db_id] = IndexingError(
+ error=str(result),
+ retryable=result.retryable,
+ )
+ elif isinstance(result, BaseException):
+ logger.error(
+ f'Unexpected error fetching content for db id {db_id}, file id {file.file_id},'
+ f' reference {file.reference}: {result}',
+ exc_info=result,
+ )
+ error_items[db_id] = IndexingError(
+ error=f'Unexpected error: {result}',
+ retryable=True,
+ )
+ else:
+ logger.error(
+ f'Unknown error fetching content for db id {db_id}, file id {file.file_id}, reference {file.reference}'
+ f': {result}',
+ exc_info=True,
+ )
+ error_items[db_id] = IndexingError(
+ error='Unknown error',
+ retryable=True,
+ )
+
+	# add the content providers from the original "sources" to the result unprocessed
+ for db_id, source in sources.items():
+ if isinstance(source, SourceItem):
+ source_items[db_id] = source
+
+ return source_items, error_items
def _filter_sources(
vectordb: BaseVectorDB,
- sources: list[UploadFile]
-) -> tuple[list[UploadFile], list[UploadFile]]:
+ sources: Mapping[int, SourceItem | ReceivedFileItem]
+) -> tuple[Mapping[int, SourceItem | ReceivedFileItem], Mapping[int, SourceItem | ReceivedFileItem]]:
'''
Returns
-------
- tuple[list[str], list[UploadFile]]
+ tuple[Mapping[int, SourceItem | ReceivedFileItem], Mapping[int, SourceItem | ReceivedFileItem]]:
First value is a list of sources that already exist in the vectordb.
Second value is a list of sources that are new and should be embedded.
'''
try:
- existing_sources, new_sources = vectordb.check_sources(sources)
+ existing_source_ids, to_embed_source_ids = vectordb.check_sources(sources)
except Exception as e:
- raise DbException('Error: Vectordb sources_to_embed error') from e
+ raise DbException('Error: Vectordb error while checking existing sources in indexing') from e
+
+ existing_sources = {}
+ to_embed_sources = {}
+
+ for db_id, source in sources.items():
+ if source.reference in existing_source_ids:
+ existing_sources[db_id] = source
+ elif source.reference in to_embed_source_ids:
+ to_embed_sources[db_id] = source
- return ([
- source for source in sources
- if source.filename in existing_sources
- ], [
- source for source in sources
- if source.filename in new_sources
- ])
+ return existing_sources, to_embed_sources
-def _sources_to_indocuments(config: TConfig, sources: list[UploadFile]) -> list[InDocument]:
- indocuments = []
+def _sources_to_indocuments(
+ config: TConfig,
+ sources: Mapping[int, SourceItem]
+) -> tuple[Mapping[int, InDocument], Mapping[int, IndexingError]]:
+ indocuments = {}
+ errored_docs = {}
- for source in sources:
- logger.debug('processing source', extra={ 'source_id': source.filename })
+ for db_id, source in sources.items():
+ logger.debug('processing source', extra={ 'source_id': source.reference })
# transform the source to have text data
- content = decode_source(source)
+ try:
+ logger.debug('Decoding source %s (type: %s)', source.reference, source.type)
+ t0 = perf_counter_ns()
+ content = decode_source(source)
+ elapsed_ms = (perf_counter_ns() - t0) / 1e6
+ logger.debug('Decoded source %s in %.2f ms (%d chars)', source.reference, elapsed_ms, len(content))
+ except IndexingException as e:
+ logger.error(f'Error decoding source ({source.reference}): {e}', exc_info=e)
+ errored_docs[db_id] = IndexingError(
+ error=str(e),
+ retryable=False,
+ )
+ continue
- if content is None or (content := content.strip()) == '':
- logger.debug('decoded empty source', extra={ 'source_id': source.filename })
+ if content == '':
+ logger.debug('decoded empty source', extra={ 'source_id': source.reference })
+ errored_docs[db_id] = IndexingError(
+ error='Decoded content is empty',
+ retryable=False,
+ )
continue
# replace more than two newlines with two newlines (also blank spaces, more than 4)
@@ -68,97 +244,149 @@ def _sources_to_indocuments(config: TConfig, sources: list[UploadFile]) -> list[
# NOTE: do not use this with all docs when programming files are added
content = re.sub(r'(\s){5,}', r'\g<1>', content)
# filter out null bytes
- content = content.replace('\0', '')
-
- if content is None or content == '':
- logger.debug('decoded empty source after cleanup', extra={ 'source_id': source.filename })
+ content = content.replace('\0', '').strip()
+
+ if content == '':
+ logger.debug('decoded empty source after cleanup', extra={ 'source_id': source.reference })
+ errored_docs[db_id] = IndexingError(
+ error='Cleaned up content is empty',
+ retryable=False,
+ )
continue
- logger.debug('decoded non empty source', extra={ 'source_id': source.filename })
+ logger.debug('decoded non empty source', extra={ 'source_id': source.reference })
metadata = {
- 'source': source.filename,
- 'title': _decode_latin_1(source.headers['title']),
- 'type': source.headers['type'],
+ 'source': source.reference,
+ 'title': _decode_latin_1(source.title),
+ 'type': source.type,
}
doc = Document(page_content=content, metadata=metadata)
- splitter = get_splitter_for(config.embedding_chunk_size, source.headers['type'])
+ splitter = get_splitter_for(config.embedding_chunk_size, source.type)
split_docs = splitter.split_documents([doc])
logger.debug('split document into chunks', extra={
- 'source_id': source.filename,
+ 'source_id': source.reference,
'len(split_docs)': len(split_docs),
})
- indocuments.append(InDocument(
+ indocuments[db_id] = InDocument(
documents=split_docs,
- userIds=list(map(_decode_latin_1, source.headers['userIds'].split(','))),
- source_id=source.filename, # pyright: ignore[reportArgumentType]
- provider=source.headers['provider'],
- modified=to_int(source.headers['modified']),
- ))
+ userIds=list(map(_decode_latin_1, source.userIds)),
+ source_id=source.reference,
+ provider=source.provider,
+ modified=source.modified, # pyright: ignore[reportArgumentType]
+ )
- return indocuments
+ return indocuments, errored_docs
+
+
+def _increase_access_for_existing_sources(
+ vectordb: BaseVectorDB,
+ existing_sources: Mapping[int, SourceItem | ReceivedFileItem]
+) -> Mapping[int, IndexingError | None]:
+ '''
+ update userIds for existing sources
+ allow the userIds as additional users, not as the only users
+ '''
+ if len(existing_sources) == 0:
+ return {}
+
+ results = {}
+ logger.debug('Increasing access for existing sources', extra={
+ 'source_ids': [source.reference for source in existing_sources.values()]
+ })
+ for db_id, source in existing_sources.items():
+ try:
+ vectordb.update_access(
+ UpdateAccessOp.ALLOW,
+ list(map(_decode_latin_1, source.userIds)),
+ source.reference,
+ )
+ results[db_id] = None
+ except SafeDbException as e:
+ logger.error(f'Failed to update access for source ({source.reference}): {e.args[0]}')
+ results[db_id] = IndexingError(
+ error=str(e),
+ retryable=False,
+ )
+ continue
+ except Exception as e:
+ logger.error(f'Unexpected error while updating access for source ({source.reference}): {e}')
+ results[db_id] = IndexingError(
+ error='Unexpected error while updating access',
+ retryable=True,
+ )
+ continue
+ return results
def _process_sources(
vectordb: BaseVectorDB,
config: TConfig,
- sources: list[UploadFile],
-) -> tuple[list[str],list[str]]:
+ sources: Mapping[int, SourceItem | ReceivedFileItem]
+) -> Mapping[int, IndexingError | None]:
'''
Processes the sources and adds them to the vectordb.
Returns the list of source ids that were successfully added and those that need to be retried.
'''
- existing_sources, filtered_sources = _filter_sources(vectordb, sources)
+ existing_sources, to_embed_sources = _filter_sources(vectordb, sources)
logger.debug('db filter source results', extra={
'len(existing_sources)': len(existing_sources),
'existing_sources': existing_sources,
- 'len(filtered_sources)': len(filtered_sources),
- 'filtered_sources': filtered_sources,
+ 'len(to_embed_sources)': len(to_embed_sources),
+ 'to_embed_sources': to_embed_sources,
})
- loaded_source_ids = [source.filename for source in existing_sources]
- # update userIds for existing sources
- # allow the userIds as additional users, not as the only users
- if len(existing_sources) > 0:
- logger.debug('Increasing access for existing sources', extra={
- 'source_ids': [source.filename for source in existing_sources]
- })
- for source in existing_sources:
- try:
- vectordb.update_access(
- UpdateAccessOp.allow,
- list(map(_decode_latin_1, source.headers['userIds'].split(','))),
- source.filename, # pyright: ignore[reportArgumentType]
- )
- except SafeDbException as e:
- logger.error(f'Failed to update access for source ({source.filename}): {e.args[0]}')
- continue
-
- if len(filtered_sources) == 0:
+ source_proc_results = _increase_access_for_existing_sources(vectordb, existing_sources)
+
+ logger.debug(
+ 'Fetching file contents for %d source(s) from Nextcloud',
+ len(to_embed_sources),
+ )
+ t0 = perf_counter_ns()
+ populated_to_embed_sources, errored_sources = asyncio.run(__fetch_files_content(to_embed_sources))
+ elapsed_ms = (perf_counter_ns() - t0) / 1e6
+ logger.debug(
+ 'File content fetch complete in %.2f ms: %d fetched, %d errored',
+ elapsed_ms, len(populated_to_embed_sources), len(errored_sources),
+ )
+ source_proc_results.update(errored_sources) # pyright: ignore[reportAttributeAccessIssue]
+
+ if len(populated_to_embed_sources) == 0:
# no new sources to embed
logger.debug('Filtered all sources, nothing to embed')
- return loaded_source_ids, [] # pyright: ignore[reportReturnType]
+ return source_proc_results
logger.debug('Filtered sources:', extra={
- 'source_ids': [source.filename for source in filtered_sources]
+ 'source_ids': [source.reference for source in populated_to_embed_sources.values()]
})
# invalid/empty sources are filtered out here and not counted in loaded/retryable
- indocuments = _sources_to_indocuments(config, filtered_sources)
+ indocuments, errored_docs = _sources_to_indocuments(config, populated_to_embed_sources)
- logger.debug('Converted all sources to documents')
+ source_proc_results.update(errored_docs) # pyright: ignore[reportAttributeAccessIssue]
+ logger.debug('Converted sources to documents')
if len(indocuments) == 0:
# filtered document(s) were invalid/empty, not an error
logger.debug('All documents were found empty after being processed')
- return loaded_source_ids, [] # pyright: ignore[reportReturnType]
+ return source_proc_results
+
+ logger.debug('Adding documents to vectordb', extra={
+ 'source_ids': [indoc.source_id for indoc in indocuments.values()]
+ })
- added_source_ids, retry_source_ids = vectordb.add_indocuments(indocuments)
- loaded_source_ids.extend(added_source_ids)
+ t0 = perf_counter_ns()
+ doc_add_results = vectordb.add_indocuments(indocuments)
+ elapsed_ms = (perf_counter_ns() - t0) / 1e6
+ logger.info(
+ 'vectordb.add_indocuments completed in %.2f ms for %d document(s)',
+ elapsed_ms, len(indocuments),
+ )
+ source_proc_results.update(doc_add_results) # pyright: ignore[reportAttributeAccessIssue]
logger.debug('Added documents to vectordb')
- return loaded_source_ids, retry_source_ids # pyright: ignore[reportReturnType]
+ return source_proc_results
def _decode_latin_1(s: str) -> str:
@@ -172,31 +400,15 @@ def _decode_latin_1(s: str) -> str:
def embed_sources(
vectordb_loader: VectorDBLoader,
config: TConfig,
- sources: list[UploadFile],
-) -> tuple[list[str],list[str]]:
- # either not a file or a file that is allowed
- sources_filtered = [
- source for source in sources
- if is_valid_source_id(source.filename) # pyright: ignore[reportArgumentType]
- or _allowed_file(source)
- ]
-
+ sources: Mapping[int, SourceItem | ReceivedFileItem]
+) -> Mapping[int, IndexingError | None]:
logger.debug('Embedding sources:', extra={
'source_ids': [
- f'{source.filename} ({_decode_latin_1(source.headers["title"])})'
- for source in sources_filtered
- ],
- 'invalid_source_ids': [
- source.filename for source in sources
- if not is_valid_source_id(source.filename) # pyright: ignore[reportArgumentType]
- ],
- 'not_allowed_file_ids': [
- source.filename for source in sources
- if not _allowed_file(source)
+ f'{source.reference} ({_decode_latin_1(source.title)})'
+ for source in sources.values()
],
- 'len(source_ids)': len(sources_filtered),
- 'len(total_source_ids)': len(sources),
+ 'len(source_ids)': len(sources),
})
vectordb = vectordb_loader.load()
- return _process_sources(vectordb, config, sources_filtered)
+ return _process_sources(vectordb, config, sources)
diff --git a/context_chat_backend/chain/one_shot.py b/context_chat_backend/chain/one_shot.py
index 1c0521bf..c79f272e 100644
--- a/context_chat_backend/chain/one_shot.py
+++ b/context_chat_backend/chain/one_shot.py
@@ -10,7 +10,7 @@
from ..types import TConfig
from .context import get_context_chunks, get_context_docs
from .query_proc import get_pruned_query
-from .types import ContextException, LLMOutput, ScopeType
+from .types import ContextException, LLMOutput, ScopeType, SearchResult
_LLM_TEMPLATE = '''Answer based only on this context and do not add any imaginative details. Make sure to use the same language as the question in your answer.
{context}
@@ -20,6 +20,7 @@
logger = logging.getLogger('ccb.chain')
+# TODO: consider removing process_query once process_context_query fully replaces it
def process_query(
user_id: str,
llm: LLM,
@@ -78,6 +79,9 @@ def process_context_query(
stop=[end_separator],
userid=user_id,
).strip()
- unique_sources: list[str] = list({source for d in context_docs if (source := d.metadata.get('source'))})
+ unique_sources = [SearchResult(
+ source_id=source,
+ title=d.metadata.get('title', ''),
+ ) for d in context_docs if (source := d.metadata.get('source'))]
return LLMOutput(output=output, sources=unique_sources)
diff --git a/context_chat_backend/chain/types.py b/context_chat_backend/chain/types.py
index b006ad1a..3afdf297 100644
--- a/context_chat_backend/chain/types.py
+++ b/context_chat_backend/chain/types.py
@@ -33,12 +33,24 @@ class ContextException(Exception):
...
+class SearchResult(TypedDict):
+ source_id: str
+ title: str
+
+
class LLMOutput(TypedDict):
output: str
- sources: list[str]
- # todo: add "titles" field
+ sources: list[SearchResult]
-class SearchResult(TypedDict):
- source_id: str
- title: str
+class EnrichedSource(BaseModel):
+ id: str
+ label: str
+ icon: str
+ url: str
+
+class EnrichedSourceList(BaseModel):
+ sources: list[EnrichedSource]
+
+class ScopeList(BaseModel):
+ source_ids: list[str]
diff --git a/context_chat_backend/controller.py b/context_chat_backend/controller.py
index c26b930a..9c3812e9 100644
--- a/context_chat_backend/controller.py
+++ b/context_chat_backend/controller.py
@@ -2,11 +2,12 @@
# SPDX-FileCopyrightText: 2023 Nextcloud GmbH and Nextcloud contributors
# SPDX-License-Identifier: AGPL-3.0-or-later
#
+from nc_py_api.ex_app.providers.task_processing import TaskProcessingProvider
# isort: off
-from .chain.types import ContextException, LLMOutput, ScopeType, SearchResult
+from .chain.types import ContextException
from .types import LoaderException, EmbeddingException
-from .vectordb.types import DbException, SafeDbException, UpdateAccessOp
+from .vectordb.types import DbException, SafeDbException
from .setup_functions import ensure_config_file, repair_run, setup_env_vars
# setup env vars before importing other modules
@@ -23,39 +24,29 @@
from collections.abc import Callable
from contextlib import asynccontextmanager
from functools import wraps
-from threading import Event, Thread
-from time import sleep
-from typing import Annotated, Any
-from fastapi import Body, FastAPI, Request, UploadFile
-from langchain.llms.base import LLM
+from fastapi import FastAPI, Request
from nc_py_api import AsyncNextcloudApp, NextcloudApp
from nc_py_api.ex_app import persistent_storage, set_handlers
-from pydantic import BaseModel, ValidationInfo, field_validator
from starlette.responses import FileResponse
-from .chain.context import do_doc_search
-from .chain.ingest.injest import embed_sources
-from .chain.one_shot import process_context_query, process_query
from .config_parser import get_config
-from .dyn_loader import LLMModelLoader, VectorDBLoader
+from .dyn_loader import VectorDBLoader
from .models.types import LlmException
from nc_py_api.ex_app import AppAPIAuthMiddleware
-from .utils import JSONResponse, exec_in_proc, is_valid_provider_id, is_valid_source_id, value_of
-from .vectordb.service import (
- count_documents_by_provider,
- decl_update_access,
- delete_by_provider,
- delete_by_source,
- delete_user,
- update_access,
-)
+from .utils import JSONResponse, exec_in_proc
+from .task_fetcher import start_bg_threads, trigger_handler, wait_for_bg_threads
+from .vectordb.service import count_documents_by_provider
# setup
-repair_run()
-ensure_config_file()
+# only run once
+if mp.current_process().name == 'MainProcess':
+ repair_run()
+ ensure_config_file()
+
logger = logging.getLogger('ccb.controller')
+app_config = get_config(os.environ['CC_CONFIG_PATH'])
__download_models_from_hf = os.environ.get('CC_DOWNLOAD_MODELS_FROM_HF', 'true').lower() in ('1', 'true', 'yes')
models_to_fetch = {
@@ -70,13 +61,33 @@
'revision': '607a30d783dfa663caf39e06633721c8d4cfcd7e',
}
} if __download_models_from_hf else {}
-app_enabled = Event()
+app_enabled = threading.Event()
-def enabled_handler(enabled: bool, _: NextcloudApp | AsyncNextcloudApp) -> str:
- if enabled:
- app_enabled.set()
- else:
- app_enabled.clear()
+def enabled_handler(enabled: bool, nc: NextcloudApp | AsyncNextcloudApp) -> str:
+ try:
+ if enabled:
+ provider = TaskProcessingProvider(
+ id="context_chat-context_chat_search",
+ name="Context Chat",
+ task_type="context_chat:context_chat_search",
+ expected_runtime=30,
+ )
+ nc.providers.task_processing.register(provider)
+ provider = TaskProcessingProvider(
+ id="context_chat-context_chat",
+ name="Context Chat",
+ task_type="context_chat:context_chat",
+ expected_runtime=30,
+ )
+ nc.providers.task_processing.register(provider)
+ app_enabled.set()
+ start_bg_threads(app_config, app_enabled)
+ else:
+ app_enabled.clear()
+ wait_for_bg_threads()
+ except Exception as e:
+ logger.exception('Error in enabled handler:', exc_info=e)
+ return f'Error in enabled handler: {e}'
logger.info(f'App {("disabled", "enabled")[enabled]}')
return ''
@@ -84,19 +95,17 @@ def enabled_handler(enabled: bool, _: NextcloudApp | AsyncNextcloudApp) -> str:
@asynccontextmanager
async def lifespan(app: FastAPI):
- set_handlers(app, enabled_handler, models_to_fetch=models_to_fetch)
+ set_handlers(app, enabled_handler, models_to_fetch=models_to_fetch, trigger_handler=trigger_handler)
nc = NextcloudApp()
if nc.enabled_state:
app_enabled.set()
+ start_bg_threads(app_config, app_enabled)
logger.info(f'App enable state at startup: {app_enabled.is_set()}')
- t = Thread(target=background_thread_task, args=())
- t.start()
yield
vectordb_loader.offload()
- llm_loader.offload()
+ wait_for_bg_threads()
-app_config = get_config(os.environ['CC_CONFIG_PATH'])
app = FastAPI(debug=app_config.debug, lifespan=lifespan) # pyright: ignore[reportArgumentType]
app.extra['CONFIG'] = app_config
@@ -105,7 +114,6 @@ async def lifespan(app: FastAPI):
# loaders
vectordb_loader = VectorDBLoader(app_config)
-llm_loader = LLMModelLoader(app, app_config)
# locks and semaphores
@@ -117,22 +125,12 @@ async def lifespan(app: FastAPI):
index_lock = threading.Lock()
_indexing = {}
-# limit the number of concurrent document parsing
-doc_parse_semaphore = mp.Semaphore(app_config.doc_parser_worker_limit)
-
# middlewares
if not app_config.disable_aaa:
app.add_middleware(AppAPIAuthMiddleware)
-# logger background thread
-
-def background_thread_task():
- while(True):
- logger.info(f'Currently indexing {len(_indexing)} documents (filename, size): ', extra={'_indexing': _indexing})
- sleep(10)
-
# exception handlers
@app.exception_handler(DbException)
@@ -213,121 +211,6 @@ def _():
return JSONResponse(content={'enabled': app_enabled.is_set()}, status_code=200)
-@app.post('/updateAccessDeclarative')
-@enabled_guard(app)
-def _(
- userIds: Annotated[list[str], Body()],
- sourceId: Annotated[str, Body()],
-):
- logger.debug('Update access declarative request:', extra={
- 'user_ids': userIds,
- 'source_id': sourceId,
- })
-
- if len(userIds) == 0:
- return JSONResponse('Empty list of user ids', 400)
-
- if not is_valid_source_id(sourceId):
- return JSONResponse('Invalid source id', 400)
-
- exec_in_proc(target=decl_update_access, args=(vectordb_loader, userIds, sourceId))
-
- return JSONResponse('Access updated')
-
-
-@app.post('/updateAccess')
-@enabled_guard(app)
-def _(
- op: Annotated[UpdateAccessOp, Body()],
- userIds: Annotated[list[str], Body()],
- sourceId: Annotated[str, Body()],
-):
- logger.debug('Update access request', extra={
- 'op': op,
- 'user_ids': userIds,
- 'source_id': sourceId,
- })
-
- if len(userIds) == 0:
- return JSONResponse('Empty list of user ids', 400)
-
- if not is_valid_source_id(sourceId):
- return JSONResponse('Invalid source id', 400)
-
- exec_in_proc(target=update_access, args=(vectordb_loader, op, userIds, sourceId))
-
- return JSONResponse('Access updated')
-
-
-@app.post('/updateAccessProvider')
-@enabled_guard(app)
-def _(
- op: Annotated[UpdateAccessOp, Body()],
- userIds: Annotated[list[str], Body()],
- providerId: Annotated[str, Body()],
-):
- logger.debug('Update access by provider request', extra={
- 'op': op,
- 'user_ids': userIds,
- 'provider_id': providerId,
- })
-
- if len(userIds) == 0:
- return JSONResponse('Empty list of user ids', 400)
-
- if not is_valid_provider_id(providerId):
- return JSONResponse('Invalid provider id', 400)
-
- exec_in_proc(target=update_access, args=(vectordb_loader, op, userIds, providerId))
-
- return JSONResponse('Access updated')
-
-
-@app.post('/deleteSources')
-@enabled_guard(app)
-def _(sourceIds: Annotated[list[str], Body(embed=True)]):
- logger.debug('Delete sources request', extra={
- 'source_ids': sourceIds,
- })
-
- sourceIds = [source.strip() for source in sourceIds if source.strip() != '']
-
- if len(sourceIds) == 0:
- return JSONResponse('No sources provided', 400)
-
- res = exec_in_proc(target=delete_by_source, args=(vectordb_loader, sourceIds))
- if res is False:
- return JSONResponse('Error: VectorDB delete failed, check vectordb logs for more info.', 400)
-
- return JSONResponse('All valid sources deleted')
-
-
-@app.post('/deleteProvider')
-@enabled_guard(app)
-def _(providerKey: str = Body(embed=True)):
- logger.debug('Delete sources by provider for all users request', extra={ 'provider_key': providerKey })
-
- if value_of(providerKey) is None:
- return JSONResponse('Invalid provider key provided', 400)
-
- exec_in_proc(target=delete_by_provider, args=(vectordb_loader, providerKey))
-
- return JSONResponse('All valid sources deleted')
-
-
-@app.post('/deleteUser')
-@enabled_guard(app)
-def _(userId: str = Body(embed=True)):
- logger.debug('Remove access list for user, and orphaned sources', extra={ 'user_id': userId })
-
- if value_of(userId) is None:
- return JSONResponse('Invalid userId provided', 400)
-
- exec_in_proc(target=delete_user, args=(vectordb_loader, userId))
-
- return JSONResponse('User deleted')
-
-
@app.post('/countIndexedDocuments')
@enabled_guard(app)
def _():
@@ -335,177 +218,6 @@ def _():
return JSONResponse(counts)
-@app.put('/loadSources')
-@enabled_guard(app)
-def _(sources: list[UploadFile]):
- global _indexing
-
- if len(sources) == 0:
- return JSONResponse('No sources provided', 400)
-
- filtered_sources = []
-
- for source in sources:
- if not value_of(source.filename):
- logger.warning('Skipping source with invalid source_id', extra={
- 'source_id': source.filename,
- 'title': source.headers.get('title'),
- })
- continue
-
- with index_lock:
- if source.filename in _indexing:
- # this request will be retried by the client
- return JSONResponse(
- f'This source ({source.filename}) is already being processed in another request, try again later',
- 503,
- headers={'cc-retry': 'true'},
- )
-
- if not (
- value_of(source.headers.get('userIds'))
- and source.headers.get('title', None) is not None
- and value_of(source.headers.get('type'))
- and value_of(source.headers.get('modified'))
- and source.headers['modified'].isdigit()
- and value_of(source.headers.get('provider'))
- ):
- logger.warning('Skipping source with invalid/missing headers', extra={
- 'source_id': source.filename,
- 'title': source.headers.get('title'),
- 'headers': source.headers,
- })
- continue
-
- filtered_sources.append(source)
-
- # wait for 10 minutes before failing the request
- semres = doc_parse_semaphore.acquire(block=True, timeout=10*60)
- if not semres:
- return JSONResponse(
- 'Document parser worker limit reached, try again in some time or consider increasing the limit',
- 503,
- headers={'cc-retry': 'true'}
- )
-
- with index_lock:
- for source in filtered_sources:
- _indexing[source.filename] = source.size
-
- try:
- loaded_sources, not_added_sources = exec_in_proc(
- target=embed_sources,
- args=(vectordb_loader, app.extra['CONFIG'], filtered_sources)
- )
- except (DbException, EmbeddingException):
- raise
- except Exception as e:
- raise DbException('Error: failed to load sources') from e
- finally:
- with index_lock:
- for source in filtered_sources:
- _indexing.pop(source.filename, None)
- doc_parse_semaphore.release()
-
- if len(loaded_sources) != len(filtered_sources):
- logger.debug('Some sources were not loaded', extra={
- 'Count of loaded sources': f'{len(loaded_sources)}/{len(filtered_sources)}',
- 'source_ids': loaded_sources,
- })
-
- # loaded sources include the existing sources that may only have their access updated
- return JSONResponse({'loaded_sources': loaded_sources, 'sources_to_retry': not_added_sources})
-
-
-class Query(BaseModel):
- userId: str
- query: str
- useContext: bool = True
- scopeType: ScopeType | None = None
- scopeList: list[str] | None = None
- ctxLimit: int = 20
-
- @field_validator('userId', 'query', 'ctxLimit')
- @classmethod
- def check_empty_values(cls, value: Any, info: ValidationInfo):
- if value_of(value) is None:
- raise ValueError('Empty value for field', info.field_name)
-
- return value
-
- @field_validator('ctxLimit')
- @classmethod
- def at_least_one_context(cls, value: int):
- if value < 1:
- raise ValueError('Invalid context chunk limit')
-
- return value
-
-
-def execute_query(query: Query, in_proc: bool = True) -> LLMOutput:
- llm: LLM = llm_loader.load()
- template = app.extra.get('LLM_TEMPLATE')
- no_ctx_template = app.extra['LLM_NO_CTX_TEMPLATE']
- # todo: array
- end_separator = app.extra.get('LLM_END_SEPARATOR', '')
-
- if query.useContext:
- target = process_context_query
- args=(
- query.userId,
- vectordb_loader,
- llm,
- app_config,
- query.query,
- query.ctxLimit,
- query.scopeType,
- query.scopeList,
- template,
- end_separator,
- )
- else:
- target=process_query
- args=(
- query.userId,
- llm,
- app_config,
- query.query,
- no_ctx_template,
- end_separator,
- )
-
- if in_proc:
- return exec_in_proc(target=target, args=args)
-
- return target(*args) # pyright: ignore
-
-
-@app.post('/query')
-@enabled_guard(app)
-def _(query: Query) -> LLMOutput:
- logger.debug('received query request', extra={ 'query': query.dict() })
-
- if app_config.llm[0] == 'nc_texttotext':
- return execute_query(query)
-
- with llm_lock:
- return execute_query(query, in_proc=False)
-
-
-@app.post('/docSearch')
-@enabled_guard(app)
-def _(query: Query) -> list[SearchResult]:
- # useContext from Query is not used here
- return exec_in_proc(target=do_doc_search, args=(
- query.userId,
- query.query,
- vectordb_loader,
- query.ctxLimit,
- query.scopeType,
- query.scopeList,
- ))
-
-
@app.get('/downloadLogs')
def download_logs() -> FileResponse:
with tempfile.NamedTemporaryFile('wb', delete=False) as tmp:
diff --git a/context_chat_backend/dyn_loader.py b/context_chat_backend/dyn_loader.py
index d67310ff..47b19575 100644
--- a/context_chat_backend/dyn_loader.py
+++ b/context_chat_backend/dyn_loader.py
@@ -7,11 +7,9 @@
import gc
import logging
from abc import ABC, abstractmethod
-from time import time
from typing import Any
import torch
-from fastapi import FastAPI
from langchain.llms.base import LLM
from .models.loader import init_model
@@ -54,19 +52,11 @@ def offload(self) -> None:
class LLMModelLoader(Loader):
- def __init__(self, app: FastAPI, config: TConfig) -> None:
+ def __init__(self, config: TConfig) -> None:
self.config = config
- self.app = app
def load(self) -> LLM:
- if self.app.extra.get('LLM_MODEL') is not None:
- self.app.extra['LLM_LAST_ACCESSED'] = time()
- return self.app.extra['LLM_MODEL']
-
llm_name, llm_config = self.config.llm
- self.app.extra['LLM_TEMPLATE'] = llm_config.pop('template', '')
- self.app.extra['LLM_NO_CTX_TEMPLATE'] = llm_config.pop('no_ctx_template', '')
- self.app.extra['LLM_END_SEPARATOR'] = llm_config.pop('end_separator', '')
try:
model = init_model('llm', (llm_name, llm_config))
@@ -75,13 +65,9 @@ def load(self) -> LLM:
if not isinstance(model, LLM):
raise LoaderException(f'Error: {model} does not implement "llm" type or has returned an invalid object')
- self.app.extra['LLM_MODEL'] = model
- self.app.extra['LLM_LAST_ACCESSED'] = time()
return model
def offload(self) -> None:
- if self.app.extra.get('LLM_MODEL') is not None:
- del self.app.extra['LLM_MODEL']
clear_cache()
diff --git a/context_chat_backend/chain/ingest/mimetype_list.py b/context_chat_backend/mimetype_list.py
similarity index 100%
rename from context_chat_backend/chain/ingest/mimetype_list.py
rename to context_chat_backend/mimetype_list.py
diff --git a/context_chat_backend/network_em.py b/context_chat_backend/network_em.py
index 18bb11f4..ba1edc9e 100644
--- a/context_chat_backend/network_em.py
+++ b/context_chat_backend/network_em.py
@@ -3,12 +3,13 @@
# SPDX-License-Identifier: AGPL-3.0-or-later
#
import logging
+import socket
from time import sleep
from typing import Literal, TypedDict
+from urllib.parse import urlparse
import niquests
from langchain_core.embeddings import Embeddings
-from pydantic import BaseModel
from .types import (
EmbeddingException,
@@ -20,6 +21,7 @@
)
logger = logging.getLogger('ccb.nextwork_em')
+TCP_CONNECT_TIMEOUT = 2.0 # seconds
# Copied from llama_cpp/llama_types.py
@@ -41,8 +43,35 @@ class CreateEmbeddingResponse(TypedDict):
usage: EmbeddingUsage
-class NetworkEmbeddings(Embeddings, BaseModel):
- app_config: TConfig
+class NetworkEmbeddings(Embeddings):
+ def __init__(self, app_config: TConfig):
+ self.app_config = app_config
+
+ def _get_host_and_port(self) -> tuple[str, int]:
+ parsed = urlparse(self.app_config.embedding.base_url)
+ host = parsed.hostname
+
+ if not host:
+ raise ValueError("Invalid URL: Missing hostname")
+
+ if parsed.port:
+ port = parsed.port
+ else:
+ port = 443 if parsed.scheme == "https" else 80
+
+ return host, port
+
+ def check_connection(self, check_origin: str) -> bool:
+ try:
+ host, port = self._get_host_and_port()
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.settimeout(TCP_CONNECT_TIMEOUT)
+ sock.connect((host, port))
+ sock.close()
+ return True
+ except (ValueError, TimeoutError, ConnectionRefusedError, socket.gaierror) as e:
+ logger.warning(f'[{check_origin}] Embedding server is not reachable, retrying after some time: {e}')
+ return False
def _get_embedding(self, input_: str | list[str], try_: int = 3) -> list[float] | list[list[float]]:
emconf = self.app_config.embedding
@@ -79,6 +108,7 @@ def _get_embedding(self, input_: str | list[str], try_: int = 3) -> list[float]
raise FatalEmbeddingException(response.text)
if response.status_code // 100 != 2:
raise EmbeddingException(response.text)
+ # todo: rework exception handling and their downstream interpretation
except FatalEmbeddingException as e:
logger.error('Fatal error while getting embeddings: %s', str(e), exc_info=e)
raise e
@@ -108,10 +138,14 @@ def _get_embedding(self, input_: str | list[str], try_: int = 3) -> list[float]
logger.error('Unexpected error while getting embeddings', exc_info=e)
raise EmbeddingException('Error: unexpected error while getting embeddings') from e
- # converts TypedDict to a pydantic model
- resp = CreateEmbeddingResponse(**response.json())
- if isinstance(input_, str):
- return resp['data'][0]['embedding']
+ try:
+ # converts TypedDict to a pydantic model
+ resp = CreateEmbeddingResponse(**response.json())
+ if isinstance(input_, str):
+ return resp['data'][0]['embedding']
+ except Exception as e:
+ logger.error('Error parsing embedding response', exc_info=e)
+ raise EmbeddingException('Error: failed to parse embedding response') from e
# only one embedding in d['embedding'] since truncate is True
return [d['embedding'] for d in resp['data']] # pyright: ignore[reportReturnType]
diff --git a/context_chat_backend/task_fetcher.py b/context_chat_backend/task_fetcher.py
new file mode 100644
index 00000000..be74b316
--- /dev/null
+++ b/context_chat_backend/task_fetcher.py
@@ -0,0 +1,744 @@
+#
+# SPDX-FileCopyrightText: 2026 Nextcloud GmbH and Nextcloud contributors
+# SPDX-License-Identifier: AGPL-3.0-or-later
+#
+import logging
+import math
+import os
+from collections.abc import Mapping
+from concurrent.futures import ThreadPoolExecutor
+from contextlib import suppress
+from enum import Enum
+from threading import Event, Thread
+from time import sleep
+from typing import Any
+
+import niquests
+from langchain.llms.base import LLM
+from nc_py_api import NextcloudApp, NextcloudException
+from niquests import JSONDecodeError, RequestException
+from pydantic import ValidationError
+
+from .chain.context import do_doc_search
+from .chain.ingest.injest import embed_sources
+from .chain.one_shot import process_context_query
+from .chain.types import ContextException, EnrichedSourceList, LLMOutput, ScopeList, SearchResult
+from .dyn_loader import LLMModelLoader, VectorDBLoader
+from .network_em import NetworkEmbeddings
+from .types import (
+ ActionsQueueItems,
+ ActionType,
+ AppRole,
+ FilesQueueItems,
+ IndexingError,
+ LoaderException,
+ ReceivedFileItem,
+ SourceItem,
+ TConfig,
+)
+from .utils import SubprocessKilledError, exec_in_proc, get_app_role
+from .vectordb.service import (
+ decl_update_access,
+ delete_by_provider,
+ delete_by_source,
+ delete_user,
+ update_access,
+ update_access_provider,
+)
+from .vectordb.types import DbException, SafeDbException
+
+APP_ROLE = get_app_role()
+THREADS = {}
+THREAD_STOP_EVENT = Event()
+LOGGER = logging.getLogger('ccb.task_fetcher')
+FILES_INDEXING_BATCH_SIZE = 16 # theoretical max RAM usage: 16 * 100 MiB, todo: config?
+MIN_FILES_PER_CPU = 4
+# divides the batch into these many chunks
+PARALLEL_FILE_PARSING_COUNT = max(1, (os.cpu_count() or 2) - 1) # todo: config?
+LOGGER.info(f'Using {PARALLEL_FILE_PARSING_COUNT} parallel file parsing workers')
+ACTIONS_BATCH_SIZE = 512 # todo: config?
+POLLING_COOLDOWN = 30
+TRIGGER = Event()
+CHECK_INTERVAL = 5
+CHECK_INTERVAL_WITH_TRIGGER = 5 * 60
+CHECK_INTERVAL_ON_ERROR = 15
+CONTEXT_LIMIT=20
+
+
+class ThreadType(Enum):
+ FILES_INDEXING = 'files_indexing'
+ UPDATES_PROCESSING = 'updates_processing'
+ REQUEST_PROCESSING = 'request_processing'
+
+
+def files_indexing_thread(app_config: TConfig, app_enabled: Event) -> None:
+ try:
+ network_em = NetworkEmbeddings(app_config)
+ vectordb_loader = VectorDBLoader(app_config)
+ except LoaderException as e:
+ LOGGER.error('Error initializing vector DB loader, files indexing thread will not start:', exc_info=e)
+ return
+
+ def _load_sources(source_items: Mapping[int, SourceItem | ReceivedFileItem]) -> Mapping[int, IndexingError | None]:
+ source_refs = [s.reference for s in source_items.values()]
+ LOGGER.info('Starting embed_sources subprocess for %d source(s)', len(source_items), extra={
+ 'source_ids': source_refs,
+ })
+ try:
+ result = exec_in_proc(
+ target=embed_sources,
+ args=(vectordb_loader, app_config, source_items),
+ )
+ errors = {k: v for k, v in result.items() if isinstance(v, IndexingError)}
+ LOGGER.info(
+ 'embed_sources finished for %d source(s): %d succeeded, %d errored',
+ len(source_items), len(result) - len(errors), len(errors),
+ extra={'errors': errors},
+ )
+ return result
+ except SubprocessKilledError as e:
+ LOGGER.error(
+ 'embed_sources subprocess was killed for %d source(s) with exitcode %s',
+ len(source_items), e.exitcode, exc_info=e, extra={
+ 'source_ids': source_refs,
+ },
+ )
+ if len(source_items) == 1:
+ return dict.fromkeys(
+ source_items,
+ IndexingError(error=f'Subprocess killed with exitcode {e.exitcode}: {e}', retryable=False),
+ )
+
+ # Fall back to one-by-one to isolate the problematic file.
+ LOGGER.warning(
+ 'Falling back to individual processing for %d sources',
+ len(source_items),
+ )
+ fallback: dict[int, IndexingError | None] = {}
+ for db_id, item in source_items.items():
+ fallback.update(_load_sources({db_id: item}))
+ return fallback
+ except Exception as e:
+ err = IndexingError(
+ error=f'{e.__class__.__name__}: {e}',
+ retryable=True,
+ )
+ LOGGER.error(
+ 'embed_sources subprocess raised a %s error for %d sources, marking all as retryable',
+ e.__class__.__name__, len(source_refs), exc_info=e, extra={
+ 'source_ids': source_refs,
+ }
+ )
+ return dict.fromkeys(source_items, err)
+
+
+ while True:
+ if THREAD_STOP_EVENT.is_set():
+ LOGGER.info('Files indexing thread is stopping due to stop event being set')
+ return
+
+ try:
+ if not network_em.check_connection(ThreadType.FILES_INDEXING.value):
+ sleep(POLLING_COOLDOWN)
+ continue
+
+ nc = NextcloudApp()
+ q_items_res = nc.ocs(
+ 'GET',
+ '/ocs/v2.php/apps/context_chat/queues/documents',
+ params={ 'n': FILES_INDEXING_BATCH_SIZE }
+ )
+
+ try:
+ q_items: FilesQueueItems = FilesQueueItems.model_validate(q_items_res)
+ except ValidationError as e:
+ raise Exception(f'Error validating queue items response: {e}\nResponse content: {q_items_res}') from e
+
+ if not q_items.files and not q_items.content_providers:
+ LOGGER.debug('No documents to index')
+ sleep(POLLING_COOLDOWN)
+ continue
+
+ files_result = {}
+ providers_result = {}
+
+ # chunk file parsing for better file operation parallelism
+ file_chunk_size = max(MIN_FILES_PER_CPU, math.ceil(len(q_items.files) / PARALLEL_FILE_PARSING_COUNT))
+ file_chunks = [
+ dict(list(q_items.files.items())[i:i+file_chunk_size])
+ for i in range(0, len(q_items.files), file_chunk_size)
+ ]
+ provider_chunk_size = max(
+ MIN_FILES_PER_CPU,
+ math.ceil(len(q_items.content_providers) / PARALLEL_FILE_PARSING_COUNT),
+ )
+ provider_chunks = [
+ dict(list(q_items.content_providers.items())[i:i+provider_chunk_size])
+ for i in range(0, len(q_items.content_providers), provider_chunk_size)
+ ]
+
+ with ThreadPoolExecutor(
+ max_workers=PARALLEL_FILE_PARSING_COUNT,
+ thread_name_prefix='IndexingPool',
+ ) as executor:
+ LOGGER.info(
+ 'Dispatching %d file chunk(s) and %d provider chunk(s) to %d IndexingPool worker(s)',
+ len(file_chunks), len(provider_chunks), PARALLEL_FILE_PARSING_COUNT,
+ )
+ file_futures = [executor.submit(_load_sources, chunk) for chunk in file_chunks]
+ provider_futures = [executor.submit(_load_sources, chunk) for chunk in provider_chunks]
+
+ for i, future in enumerate(file_futures):
+ LOGGER.debug('Waiting for file chunk %d/%d future to complete', i + 1, len(file_futures))
+ files_result.update(future.result())
+ LOGGER.debug('File chunk %d/%d future completed', i + 1, len(file_futures))
+ for i, future in enumerate(provider_futures):
+ LOGGER.debug('Waiting for provider chunk %d/%d future to complete', i + 1, len(provider_futures))
+ providers_result.update(future.result())
+ LOGGER.debug('Provider chunk %d/%d future completed', i + 1, len(provider_futures))
+
+ if (
+ any(isinstance(res, IndexingError) for res in files_result.values())
+ or any(isinstance(res, IndexingError) for res in providers_result.values())
+ ):
+ LOGGER.error('Some sources failed to index', extra={
+ 'file_errors': {
+ db_id: error
+ for db_id, error in files_result.items()
+ if isinstance(error, IndexingError)
+ },
+ 'provider_errors': {
+ provider_id: error
+ for provider_id, error in providers_result.items()
+ if isinstance(error, IndexingError)
+ },
+ })
+ except (
+ niquests.exceptions.ConnectionError,
+ niquests.exceptions.Timeout,
+ ) as e:
+ LOGGER.info('Temporary error fetching documents to index, will retry:', exc_info=e)
+ sleep(5)
+ continue
+ except Exception as e:
+ LOGGER.exception('Error fetching documents to index:', exc_info=e)
+ sleep(5)
+ continue
+
+ # delete the entries from the PHP side queue where indexing succeeded or the error is not retryable
+ to_delete_files_db_ids = [
+ db_id for db_id, result in files_result.items()
+ if result is None or (isinstance(result, IndexingError) and not result.retryable)
+ ]
+ to_delete_provider_db_ids = [
+ db_id for db_id, result in providers_result.items()
+ if result is None or (isinstance(result, IndexingError) and not result.retryable)
+ ]
+
+ try:
+ nc.ocs(
+ 'DELETE',
+ '/ocs/v2.php/apps/context_chat/queues/documents/',
+ json={
+ 'files': to_delete_files_db_ids,
+ 'content_providers': to_delete_provider_db_ids,
+ },
+ )
+ except (
+ niquests.exceptions.ConnectionError,
+ niquests.exceptions.Timeout,
+ ) as e:
+ LOGGER.info('Temporary error reporting indexing results, will retry:', exc_info=e)
+ sleep(5)
+ with suppress(Exception):
+ nc = NextcloudApp()
+ nc.ocs(
+ 'DELETE',
+ '/ocs/v2.php/apps/context_chat/queues/documents/',
+ json={
+ 'files': to_delete_files_db_ids,
+ 'content_providers': to_delete_provider_db_ids,
+ },
+ )
+ continue
+ except Exception as e:
+ LOGGER.exception('Error reporting indexing results:', exc_info=e)
+ sleep(5)
+ continue
+
+
+
+def updates_processing_thread(app_config: TConfig, app_enabled: Event) -> None:
+ try:
+ vectordb_loader = VectorDBLoader(app_config)
+ except LoaderException as e:
+		LOGGER.error('Error initializing vector DB loader, updates processing thread will not start:', exc_info=e)
+ return
+
+ while True:
+ if THREAD_STOP_EVENT.is_set():
+ LOGGER.info('Updates processing thread is stopping due to stop event being set')
+ return
+
+ try:
+ nc = NextcloudApp()
+ q_items_res = nc.ocs(
+ 'GET',
+ '/ocs/v2.php/apps/context_chat/queues/actions',
+ params={ 'n': ACTIONS_BATCH_SIZE }
+ )
+
+ try:
+ q_items: ActionsQueueItems = ActionsQueueItems.model_validate(q_items_res)
+ except ValidationError as e:
+ raise Exception(f'Error validating queue items response: {e}\nResponse content: {q_items_res}') from e
+ except (
+ niquests.exceptions.ConnectionError,
+ niquests.exceptions.Timeout,
+ ) as e:
+ LOGGER.info('Temporary error fetching updates to process, will retry:', exc_info=e)
+ sleep(5)
+ continue
+ except Exception as e:
+ LOGGER.exception('Error fetching updates to process:', exc_info=e)
+ sleep(5)
+ continue
+
+ if not q_items.actions:
+ LOGGER.debug('No updates to process')
+ sleep(POLLING_COOLDOWN)
+ continue
+
+ processed_event_ids = []
+ errored_events = {}
+ for i, (db_id, action_item) in enumerate(q_items.actions.items()):
+ try:
+ match action_item.type:
+ case ActionType.DELETE_SOURCE_IDS:
+ exec_in_proc(target=delete_by_source, args=(vectordb_loader, action_item.payload.sourceIds))
+
+ case ActionType.DELETE_PROVIDER_ID:
+ exec_in_proc(target=delete_by_provider, args=(vectordb_loader, action_item.payload.providerId))
+
+ case ActionType.DELETE_USER_ID:
+ exec_in_proc(target=delete_user, args=(vectordb_loader, action_item.payload.userId))
+
+ case ActionType.UPDATE_ACCESS_SOURCE_ID:
+ exec_in_proc(
+ target=update_access,
+ args=(
+ vectordb_loader,
+ action_item.payload.op,
+ action_item.payload.userIds,
+ action_item.payload.sourceId,
+ ),
+ )
+
+ case ActionType.UPDATE_ACCESS_PROVIDER_ID:
+ exec_in_proc(
+ target=update_access_provider,
+ args=(
+ vectordb_loader,
+ action_item.payload.op,
+ action_item.payload.userIds,
+ action_item.payload.providerId,
+ ),
+ )
+
+ case ActionType.UPDATE_ACCESS_DECL_SOURCE_ID:
+ exec_in_proc(
+ target=decl_update_access,
+ args=(
+ vectordb_loader,
+ action_item.payload.userIds,
+ action_item.payload.sourceId,
+ ),
+ )
+
+					case _:
+						LOGGER.warning(
+							f'Unknown action type {action_item.type} for action id {db_id}, skipping and marking as processed',
+							extra={ 'action_item': action_item },
+						)
+						processed_event_ids.append(db_id)
+						continue
+
+ processed_event_ids.append(db_id)
+ except SafeDbException as e:
+ LOGGER.debug(
+ f'Safe DB error thrown while processing action id {db_id}, type {action_item.type},'
+ " it's safe to ignore and mark as processed.",
+ exc_info=e,
+ extra={ 'action_item': action_item },
+ )
+ processed_event_ids.append(db_id)
+ continue
+
+ except (LoaderException, DbException) as e:
+ LOGGER.error(
+ f'Error deleting source for action id {db_id}, type {action_item.type}: {e}',
+ exc_info=e,
+ extra={ 'action_item': action_item },
+ )
+ errored_events[db_id] = str(e)
+ continue
+
+ except Exception as e:
+ LOGGER.error(
+ f'Unexpected error processing action id {db_id}, type {action_item.type}: {e}',
+ exc_info=e,
+ extra={ 'action_item': action_item },
+ )
+ errored_events[db_id] = f'Unexpected error: {e}'
+ continue
+
+ if (i + 1) % 20 == 0:
+ LOGGER.debug(f'Processed {i + 1} updates, sleeping for a bit to allow other operations to proceed')
+ sleep(2)
+
+ LOGGER.info(f'Processed {len(processed_event_ids)} updates with {len(errored_events)} errors', extra={
+ 'errored_events': errored_events,
+ })
+
+ if len(processed_event_ids) == 0:
+ LOGGER.debug('No updates processed, skipping reporting to the server')
+ continue
+
+ try:
+ nc.ocs(
+ 'DELETE',
+ '/ocs/v2.php/apps/context_chat/queues/actions/',
+ json={ 'actions': processed_event_ids },
+ )
+ except (
+ niquests.exceptions.ConnectionError,
+ niquests.exceptions.Timeout,
+ ) as e:
+ LOGGER.info('Temporary error reporting processed updates, will retry:', exc_info=e)
+ sleep(5)
+ with suppress(Exception):
+ nc = NextcloudApp()
+ nc.ocs(
+ 'DELETE',
+ '/ocs/v2.php/apps/context_chat/queues/actions/',
+					json={ 'actions': processed_event_ids },
+ )
+ continue
+ except Exception as e:
+ LOGGER.exception('Error reporting processed updates:', exc_info=e)
+ sleep(5)
+ continue
+
+
+def resolve_scope_list(source_ids: list[str], userId: str) -> list[str]:
+ """
+
+ Parameters
+ ----------
+ source_ids
+
+ Returns
+ -------
+ source_ids with only files, no folders (or source_ids in case of non-file provider)
+ """
+ nc = NextcloudApp()
+ data = nc.ocs('POST', '/ocs/v2.php/apps/context_chat/resolve_scope_list', json={
+ 'source_ids': source_ids,
+ 'userId': userId,
+ })
+ return ScopeList.model_validate(data).source_ids
+
+
+def request_processing_thread(app_config: TConfig, app_enabled: Event) -> None:
+ LOGGER.info('Starting task fetcher loop')
+
+ try:
+ network_em = NetworkEmbeddings(app_config)
+ vectordb_loader = VectorDBLoader(app_config)
+ llm_loader = LLMModelLoader(app_config)
+ except LoaderException as e:
+		LOGGER.error('Error initializing loaders, request processing thread will not start:', exc_info=e)
+ return
+
+ nc = NextcloudApp()
+ llm: LLM = llm_loader.load()
+
+ while True:
+ if THREAD_STOP_EVENT.is_set():
+			LOGGER.info('Request processing thread is stopping due to stop event being set')
+ return
+
+ if not network_em.check_connection(ThreadType.REQUEST_PROCESSING.value):
+ sleep(POLLING_COOLDOWN)
+ continue
+
+ try:
+ # Fetch pending task
+ try:
+ response = nc.providers.task_processing.next_task(
+ ['context_chat-context_chat', 'context_chat-context_chat_search'],
+ ['context_chat:context_chat', 'context_chat:context_chat_search'],
+ )
+ if not response:
+ wait_for_tasks()
+ continue
+ except (NextcloudException, RequestException, JSONDecodeError) as e:
+ LOGGER.error(f"Network error fetching the next task {e}", exc_info=e)
+ wait_for_tasks(CHECK_INTERVAL_ON_ERROR)
+ continue
+
+ # Process task
+ task = response["task"]
+ userId = task['userId']
+
+ try:
+ LOGGER.debug(f'Processing task {task["id"]}')
+
+ if task['input'].get('scopeType') == 'source':
+ # Resolve scope list to only files, no folders
+ task['input']['scopeList'] = resolve_scope_list(task['input'].get('scopeList'), userId)
+
+ if task['type'] == 'context_chat:context_chat':
+ result: LLMOutput = process_normal_task(task, vectordb_loader, llm, app_config)
+ # Return result to Nextcloud
+ success = return_result_to_nextcloud(task['id'], userId, {
+ 'output': result['output'],
+ 'sources': enrich_sources(result['sources'], userId),
+ })
+ elif task['type'] == 'context_chat:context_chat_search':
+ search_result: list[SearchResult] = process_search_task(task, vectordb_loader)
+ # Return result to Nextcloud
+ success = return_result_to_nextcloud(task['id'], userId, {
+ 'sources': enrich_sources(search_result, userId),
+ })
+ else:
+ LOGGER.error(f'Unknown task type {task["type"]}')
+ success = return_error_to_nextcloud(task['id'], Exception(f'Unknown task type {task["type"]}'))
+
+ if success:
+ LOGGER.info(f'Task {task["id"]} completed successfully')
+ else:
+ LOGGER.error(f'Failed to return result for task {task["id"]}')
+
+ except ContextException as e:
+ LOGGER.warning(f'Context error for task {task["id"]}: {e}')
+ return_error_to_nextcloud(task['id'], e)
+ except ValueError as e:
+ LOGGER.warning(f'Validation error for task {task["id"]}: {e}')
+ return_error_to_nextcloud(task['id'], e)
+ except Exception as e:
+ LOGGER.exception(f'Unexpected error processing task {task["id"]}', exc_info=e)
+ return_error_to_nextcloud(task['id'], e)
+
+ except Exception as e:
+ LOGGER.exception('Error in task fetcher loop', exc_info=e)
+ wait_for_tasks(CHECK_INTERVAL_ON_ERROR)
+
+def trigger_handler(providerId: str):
+ global TRIGGER
+ LOGGER.debug('Trigger called for provider %s', providerId)
+ TRIGGER.set()
+
+def wait_for_tasks(interval=None):
+ global TRIGGER
+ global CHECK_INTERVAL
+ global CHECK_INTERVAL_WITH_TRIGGER
+ actual_interval = CHECK_INTERVAL if interval is None else interval
+ if TRIGGER.wait(timeout=actual_interval):
+ CHECK_INTERVAL = CHECK_INTERVAL_WITH_TRIGGER
+ TRIGGER.clear()
+
+
+def enrich_sources(results: list[SearchResult], userId: str) -> list[str]:
+ nc = NextcloudApp()
+ data = nc.ocs('POST', '/ocs/v2.php/apps/context_chat/enrich_sources', json={'sources': results, 'userId': userId})
+ sources = EnrichedSourceList.model_validate(data).sources
+ return [s.model_dump_json() for s in sources]
+
+
+def return_result_to_nextcloud(task_id: int, userId: str, result: dict[str, Any]) -> bool:
+ """
+ Return query result back to Nextcloud.
+
+ Args:
+ result: dict[str, Any]
+
+ Returns:
+ True if successful, False otherwise
+ """
+ LOGGER.debug('Returning result to Nextcloud', extra={
+ 'task_id': task_id,
+ 'result': result,
+ })
+
+ nc = NextcloudApp()
+
+ try:
+ nc.providers.task_processing.report_result(task_id, result)
+ except (NextcloudException, RequestException, JSONDecodeError) as e:
+ LOGGER.error(f"Network error reporting task result: {e}", exc_info=e)
+ return False
+
+ return True
+
+
+def return_error_to_nextcloud(task_id: int, e: Exception) -> bool:
+ """
+ Return error result back to Nextcloud.
+
+ Args:
+ task_id: Unique task identifier
+ e: error object
+
+ Returns:
+ True if successful, False otherwise
+ """
+ LOGGER.debug('Returning error to Nextcloud', exc_info=e)
+
+ nc = NextcloudApp()
+
+ if isinstance(e, ValueError):
+ message = "Validation error: " + str(e)
+ elif isinstance(e, ContextException):
+ message = "Context error: " + str(e)
+ else:
+ message = "Unexpected error: " + str(e)
+
+ try:
+ nc.providers.task_processing.report_result(task_id, None, message)
+ except (NextcloudException, RequestException, JSONDecodeError) as e:
+ LOGGER.error(f"Network error reporting task error result: {e}", exc_info=e)
+ return False
+
+ return True
+
+
+def process_normal_task(
+ task: dict[str, Any],
+ vectordb_loader: VectorDBLoader,
+ llm: LLM,
+ app_config: TConfig,
+) -> LLMOutput:
+ """
+ Process a single query task.
+
+ Args:
+ task: Task dictionary from fetch_query_tasks_from_nextcloud
+ vectordb_loader: Vector database loader instance
+ llm: Language model instance
+ app_config: Application configuration
+
+ Returns:
+ LLMOutput with generated text and sources
+
+ Raises:
+ Various exceptions from query execution
+ """
+ user_id = task['userId']
+ task_input = task['input']
+ if task_input.get('scopeType') == 'none':
+ task_input['scopeType'] = None
+
+ # todo: document no template support
+ return exec_in_proc(target=process_context_query,
+ args=(
+ user_id,
+ vectordb_loader,
+ llm,
+ app_config,
+ task_input.get('prompt'),
+ CONTEXT_LIMIT,
+ task_input.get('scopeType'),
+ task_input.get('scopeList'),
+ )
+ )
+
+def process_search_task(
+ task: dict[str, Any],
+ vectordb_loader: VectorDBLoader,
+) -> list[SearchResult]:
+ """
+ Process a single search task.
+
+ Args:
+ task: Task dictionary from fetch_query_tasks_from_nextcloud
+ vectordb_loader: Vector database loader instance
+
+ Returns:
+ list of Search results
+
+ Raises:
+ Various exceptions from query execution
+ """
+ user_id = task['userId']
+ task_input = task['input']
+ if task_input.get('scopeType') == 'none':
+ task_input['scopeType'] = None
+
+ return exec_in_proc(target=do_doc_search,
+ args=(
+ user_id,
+ task_input.get('prompt'),
+ vectordb_loader,
+ CONTEXT_LIMIT,
+ task_input.get('scopeType'),
+ task_input.get('scopeList'),
+ )
+ )
+
+
+def start_bg_threads(app_config: TConfig, app_enabled: Event):
+ if APP_ROLE == AppRole.INDEXING or APP_ROLE == AppRole.NORMAL:
+ if (
+ ThreadType.FILES_INDEXING in THREADS
+ or ThreadType.UPDATES_PROCESSING in THREADS
+ ):
+ LOGGER.info('Indexing background threads already running, skipping start')
+ return
+
+ THREAD_STOP_EVENT.clear()
+ THREADS[ThreadType.FILES_INDEXING] = Thread(
+ target=files_indexing_thread,
+ args=(app_config, app_enabled),
+ name='FilesIndexingThread',
+ )
+ THREADS[ThreadType.UPDATES_PROCESSING] = Thread(
+ target=updates_processing_thread,
+ args=(app_config, app_enabled),
+ name='UpdatesProcessingThread',
+ )
+ THREADS[ThreadType.FILES_INDEXING].start()
+ THREADS[ThreadType.UPDATES_PROCESSING].start()
+
+ if APP_ROLE == AppRole.RP or APP_ROLE == AppRole.NORMAL:
+ if ThreadType.REQUEST_PROCESSING in THREADS:
+ LOGGER.info('Request processing thread already running, skipping start')
+ return
+
+ THREAD_STOP_EVENT.clear()
+ THREADS[ThreadType.REQUEST_PROCESSING] = Thread(
+ target=request_processing_thread,
+ args=(app_config, app_enabled),
+ name='RequestProcessingThread',
+ )
+ THREADS[ThreadType.REQUEST_PROCESSING].start()
+
+
+def wait_for_bg_threads():
+ if APP_ROLE == AppRole.INDEXING or APP_ROLE == AppRole.NORMAL:
+ if (ThreadType.FILES_INDEXING not in THREADS or ThreadType.UPDATES_PROCESSING not in THREADS):
+ return
+
+ THREAD_STOP_EVENT.set()
+ THREADS[ThreadType.FILES_INDEXING].join()
+ THREADS[ThreadType.UPDATES_PROCESSING].join()
+ THREADS.pop(ThreadType.FILES_INDEXING)
+ THREADS.pop(ThreadType.UPDATES_PROCESSING)
+
+ if APP_ROLE == AppRole.RP or APP_ROLE == AppRole.NORMAL:
+ if (ThreadType.REQUEST_PROCESSING not in THREADS):
+ return
+
+ THREAD_STOP_EVENT.set()
+ THREADS[ThreadType.REQUEST_PROCESSING].join()
+ THREADS.pop(ThreadType.REQUEST_PROCESSING)
diff --git a/context_chat_backend/types.py b/context_chat_backend/types.py
index 500a97d0..59d2568f 100644
--- a/context_chat_backend/types.py
+++ b/context_chat_backend/types.py
@@ -2,7 +2,16 @@
# SPDX-FileCopyrightText: 2024 Nextcloud GmbH and Nextcloud contributors
# SPDX-License-Identifier: AGPL-3.0-or-later
#
-from pydantic import BaseModel
+import re
+from collections.abc import Mapping
+from enum import Enum
+from io import BytesIO
+from typing import Annotated, Literal, Self
+
+from pydantic import AfterValidator, BaseModel, Discriminator, computed_field, field_validator, model_validator
+
+from .mimetype_list import SUPPORTED_MIMETYPES
+from .vectordb.types import UpdateAccessOp
__all__ = [
'DEFAULT_EM_MODEL_ALIAS',
@@ -15,6 +24,65 @@
]
DEFAULT_EM_MODEL_ALIAS = 'em_model'
+FILES_PROVIDER_ID = 'files__default'
+
+
+def is_valid_source_id(source_id: str) -> bool:
+ # note the ":" in the item id part
+ return re.match(r'^[a-zA-Z0-9_-]+__[a-zA-Z0-9_-]+: [a-zA-Z0-9:-]+$', source_id) is not None
+
+
+def is_valid_provider_id(provider_id: str) -> bool:
+ return re.match(r'^[a-zA-Z0-9_-]+__[a-zA-Z0-9_-]+$', provider_id) is not None
+
+
+def _validate_source_ids(source_ids: list[str]) -> list[str]:
+ if (
+ not isinstance(source_ids, list)
+ or not all(isinstance(sid, str) and sid.strip() != '' for sid in source_ids)
+ or len(source_ids) == 0
+ ):
+ raise ValueError('sourceIds must be a non-empty list of non-empty strings')
+ return [sid.strip() for sid in source_ids]
+
+
+def _validate_source_id(source_id: str) -> str:
+ return _validate_source_ids([source_id])[0]
+
+
+def _validate_provider_id(provider_id: str) -> str:
+ if not isinstance(provider_id, str) or not is_valid_provider_id(provider_id):
+ raise ValueError('providerId must be a valid provider ID string')
+ return provider_id
+
+
+def _validate_user_ids(user_ids: list[str]) -> list[str]:
+ if (
+ not isinstance(user_ids, list)
+ or not all(isinstance(uid, str) and uid.strip() != '' for uid in user_ids)
+ or len(user_ids) == 0
+ ):
+ raise ValueError('userIds must be a non-empty list of non-empty strings')
+ return [uid.strip() for uid in user_ids]
+
+
+def _validate_user_id(user_id: str) -> str:
+ return _validate_user_ids([user_id])[0]
+
+
+def _get_file_id_from_source_ref(source_ref: str) -> int:
+ '''
+ source reference is in the format "<FILES_PROVIDER_ID>: <file_id>".
+ '''
+ if not source_ref.startswith(f'{FILES_PROVIDER_ID}: '):
+ raise ValueError(f'Source reference does not start with expected prefix: {source_ref}')
+
+ try:
+ return int(source_ref[len(f'{FILES_PROVIDER_ID}: '):])
+ except ValueError as e:
+ raise ValueError(
+ f'Invalid source reference format for extracting file_id: {source_ref}'
+ ) from e
class TEmbeddingAuthApiKey(BaseModel):
@@ -71,3 +139,209 @@ class FatalEmbeddingException(EmbeddingException):
Either malformed request, authentication error, or other non-retryable error.
"""
+
+
+class AppRole(str, Enum):
+ NORMAL = 'normal'
+ INDEXING = 'indexing'
+ RP = 'rp'
+
+
+class CommonSourceItem(BaseModel):
+ userIds: Annotated[list[str], AfterValidator(_validate_user_ids)]
+ # source_id of the form "appId__providerId: itemId"
+ reference: Annotated[str, AfterValidator(_validate_source_id)]
+ title: str
+ modified: int
+ type: str
+ provider: Annotated[str, AfterValidator(_validate_provider_id)]
+ size: float
+
+ @field_validator('modified', mode='before')
+ @classmethod
+ def validate_modified(cls, v):
+ if isinstance(v, int):
+ return v
+ if isinstance(v, str):
+ try:
+ return int(v)
+ except ValueError as e:
+ raise ValueError(f'Invalid modified value: {v}') from e
+ raise ValueError(f'Invalid modified type: {type(v)}')
+
+ @field_validator('reference', 'title', 'type', 'provider')
+ @classmethod
+ def validate_strings_non_empty(cls, v):
+ if not isinstance(v, str) or v.strip() == '':
+ raise ValueError('Must be a non-empty string')
+ return v.strip()
+
+ @field_validator('size')
+ @classmethod
+ def validate_size(cls, v):
+ if isinstance(v, int | float) and v >= 0:
+ return float(v)
+ raise ValueError(f'Invalid size value: {v}, must be a non-negative number')
+
+ @model_validator(mode='after')
+ def validate_type(self) -> Self:
+ if self.reference.startswith(FILES_PROVIDER_ID) and self.type not in SUPPORTED_MIMETYPES:
+ raise ValueError(f'Unsupported file type: {self.type} for reference {self.reference}')
+ return self
+
+
+class ReceivedFileItem(CommonSourceItem):
+ content: None
+
+ @computed_field
+ @property
+ def file_id(self) -> int:
+ return _get_file_id_from_source_ref(self.reference)
+
+
+class SourceItem(CommonSourceItem):
+ '''
+ Used for the unified queue of items to process, after fetching the content for files
+ and for directly fetched content providers.
+ '''
+ content: str | BytesIO
+
+ @field_validator('content')
+ @classmethod
+ def validate_content(cls, v):
+ if isinstance(v, str):
+ if v.strip() == '':
+ raise ValueError('Content must be a non-empty string')
+ return v.strip()
+ if isinstance(v, BytesIO):
+ if v.getbuffer().nbytes == 0:
+ raise ValueError('Content must be a non-empty BytesIO')
+ return v
+ raise ValueError('Content must be either a non-empty string or a non-empty BytesIO')
+
+ class Config:
+ # to allow BytesIO in content field
+ arbitrary_types_allowed = True
+
+
+class FilesQueueItems(BaseModel):
+ files: Mapping[int, ReceivedFileItem] # [db id]: FileItem
+ content_providers: Mapping[int, SourceItem] # [db id]: SourceItem
+
+
+class IndexingException(Exception):
+ retryable: bool = False
+
+ def __init__(self, message: str, retryable: bool = False):
+ super().__init__(message)
+ self.retryable = retryable
+
+
+class IndexingError(BaseModel):
+ error: str
+ retryable: bool = False
+
+
+# PHP equivalent for reference:
+
+# class ActionType {
+# // { sourceIds: array }
+# public const DELETE_SOURCE_IDS = 'delete_source_ids';
+# // { providerId: string }
+# public const DELETE_PROVIDER_ID = 'delete_provider_id';
+# // { userId: string }
+# public const DELETE_USER_ID = 'delete_user_id';
+# // { op: string, userIds: array, sourceId: string }
+# public const UPDATE_ACCESS_SOURCE_ID = 'update_access_source_id';
+# // { op: string, userIds: array, providerId: string }
+# public const UPDATE_ACCESS_PROVIDER_ID = 'update_access_provider_id';
+# // { userIds: array, sourceId: string }
+# public const UPDATE_ACCESS_DECL_SOURCE_ID = 'update_access_decl_source_id';
+# }
+
+
+class ActionPayloadDeleteSourceIds(BaseModel):
+ sourceIds: Annotated[list[str], AfterValidator(_validate_source_ids)]
+
+
+class ActionPayloadDeleteProviderId(BaseModel):
+ providerId: Annotated[str, AfterValidator(_validate_provider_id)]
+
+
+class ActionPayloadDeleteUserId(BaseModel):
+ userId: Annotated[str, AfterValidator(_validate_user_id)]
+
+
+class ActionPayloadUpdateAccessSourceId(BaseModel):
+ op: UpdateAccessOp
+ userIds: Annotated[list[str], AfterValidator(_validate_user_ids)]
+ sourceId: Annotated[str, AfterValidator(_validate_source_id)]
+
+
+class ActionPayloadUpdateAccessProviderId(BaseModel):
+ op: UpdateAccessOp
+ userIds: Annotated[list[str], AfterValidator(_validate_user_ids)]
+ providerId: Annotated[str, AfterValidator(_validate_provider_id)]
+
+
+class ActionPayloadUpdateAccessDeclSourceId(BaseModel):
+ userIds: Annotated[list[str], AfterValidator(_validate_user_ids)]
+ sourceId: Annotated[str, AfterValidator(_validate_source_id)]
+
+
+class ActionType(str, Enum):
+ DELETE_SOURCE_IDS = 'delete_source_ids'
+ DELETE_PROVIDER_ID = 'delete_provider_id'
+ DELETE_USER_ID = 'delete_user_id'
+ UPDATE_ACCESS_SOURCE_ID = 'update_access_source_id'
+ UPDATE_ACCESS_PROVIDER_ID = 'update_access_provider_id'
+ UPDATE_ACCESS_DECL_SOURCE_ID = 'update_access_decl_source_id'
+
+
+class CommonActionsQueueItem(BaseModel):
+ id: int
+
+
+class ActionsQueueItemDeleteSourceIds(CommonActionsQueueItem):
+ type: Literal[ActionType.DELETE_SOURCE_IDS]
+ payload: ActionPayloadDeleteSourceIds
+
+
+class ActionsQueueItemDeleteProviderId(CommonActionsQueueItem):
+ type: Literal[ActionType.DELETE_PROVIDER_ID]
+ payload: ActionPayloadDeleteProviderId
+
+
+class ActionsQueueItemDeleteUserId(CommonActionsQueueItem):
+ type: Literal[ActionType.DELETE_USER_ID]
+ payload: ActionPayloadDeleteUserId
+
+
+class ActionsQueueItemUpdateAccessSourceId(CommonActionsQueueItem):
+ type: Literal[ActionType.UPDATE_ACCESS_SOURCE_ID]
+ payload: ActionPayloadUpdateAccessSourceId
+
+
+class ActionsQueueItemUpdateAccessProviderId(CommonActionsQueueItem):
+ type: Literal[ActionType.UPDATE_ACCESS_PROVIDER_ID]
+ payload: ActionPayloadUpdateAccessProviderId
+
+
+class ActionsQueueItemUpdateAccessDeclSourceId(CommonActionsQueueItem):
+ type: Literal[ActionType.UPDATE_ACCESS_DECL_SOURCE_ID]
+ payload: ActionPayloadUpdateAccessDeclSourceId
+
+
+ActionsQueueItem = Annotated[
+ ActionsQueueItemDeleteSourceIds
+ | ActionsQueueItemDeleteProviderId
+ | ActionsQueueItemDeleteUserId
+ | ActionsQueueItemUpdateAccessSourceId
+ | ActionsQueueItemUpdateAccessProviderId
+ | ActionsQueueItemUpdateAccessDeclSourceId,
+ Discriminator('type'),
+]
+
+
+class ActionsQueueItems(BaseModel):
+ actions: Mapping[int, ActionsQueueItem]
diff --git a/context_chat_backend/utils.py b/context_chat_backend/utils.py
index f6d6e672..4552e320 100644
--- a/context_chat_backend/utils.py
+++ b/context_chat_backend/utils.py
@@ -2,11 +2,15 @@
# SPDX-FileCopyrightText: 2023 Nextcloud GmbH and Nextcloud contributors
# SPDX-License-Identifier: AGPL-3.0-or-later
#
+import faulthandler
+import io
import logging
import multiprocessing as mp
-import re
+import os
+import sys
import traceback
from collections.abc import Callable
+from contextlib import suppress
from functools import partial, wraps
from multiprocessing.connection import Connection
from time import perf_counter_ns
@@ -14,10 +18,11 @@
from fastapi.responses import JSONResponse as FastAPIJSONResponse
-from .types import TConfig, TEmbeddingAuthApiKey, TEmbeddingAuthBasic, TEmbeddingConfig
+from .types import AppRole, TConfig, TEmbeddingAuthApiKey, TEmbeddingAuthBasic, TEmbeddingConfig
T = TypeVar('T')
_logger = logging.getLogger('ccb.utils')
+_MAX_STD_CAPTURE_CHARS = 64 * 1024
def not_none(value: T | None) -> TypeGuard[T]:
@@ -69,19 +74,98 @@ def JSONResponse(
return FastAPIJSONResponse(content, status_code, **kwargs)
-def exception_wrap(fun: Callable | None, *args, resconn: Connection, **kwargs):
+class SubprocessKilledError(RuntimeError):
+ """Raised when a subprocess is terminated by a signal (for example SIGKILL)."""
+
+ def __init__(self, pid: int | None, target_name: str, exitcode: int):
+ super().__init__(
+ f'Subprocess PID {pid} for {target_name} exited with signal {abs(exitcode)} '
+ f'(raw exit code: {exitcode})'
+ )
+ self.exitcode = exitcode
+
+
+class SubprocessExecutionError(RuntimeError):
+ """Raised when a subprocess exits without a recoverable Python exception payload."""
+
+ def __init__(self, pid: int | None, target_name: str, exitcode: int, details: str = ''):
+ msg = f'Subprocess PID {pid} for {target_name} exited with exit code {exitcode}'
+ if details:
+ msg = f'{msg}: {details}'
+ super().__init__(msg)
+ self.exitcode = exitcode
+
+
+def _truncate_capture(text: str) -> str:
+ if len(text) <= _MAX_STD_CAPTURE_CHARS:
+ return text
+
+ head = _MAX_STD_CAPTURE_CHARS // 2
+ tail = _MAX_STD_CAPTURE_CHARS - head
+ omitted = len(text) - _MAX_STD_CAPTURE_CHARS
+ return (
+ f'[truncated {omitted} chars]\n'
+ f'{text[:head]}\n'
+ '[...snip...]\n'
+ f'{text[-tail:]}'
+ )
+
+
+def exception_wrap(fun: Callable | None, *args, resconn: Connection, stdconn: Connection, **kwargs):
+ # Preserve real stderr FD for faulthandler before we redirect sys.stderr.
+ _faulthandler_fd = os.dup(2)
+ with suppress(Exception):
+ faulthandler.enable(
+ file=os.fdopen(_faulthandler_fd, 'w', closefd=False),
+ all_threads=True,
+ )
+
+ stdout_capture = io.StringIO()
+ stderr_capture = io.StringIO()
+ orig_stdout = sys.stdout
+ orig_stderr = sys.stderr
+ sys.stdout = stdout_capture
+ sys.stderr = stderr_capture
+
try:
if fun is None:
- return resconn.send({ 'value': None, 'error': None })
- resconn.send({ 'value': fun(*args, **kwargs), 'error': None })
- except Exception as e:
+ resconn.send({ 'value': None, 'error': None })
+ else:
+ resconn.send({ 'value': fun(*args, **kwargs), 'error': None })
+ except BaseException as e:
tb = traceback.format_exc()
- resconn.send({ 'value': None, 'error': e, 'traceback': tb })
-
-
-def exec_in_proc(group=None, target=None, name=None, args=(), kwargs={}, *, daemon=None): # noqa: B006
+ payload = {
+ 'value': None,
+ 'error': e,
+ 'traceback': tb,
+ }
+ try:
+ resconn.send(payload)
+ except Exception as send_err:
+ stderr_capture.write(f'Original error: {e}, pipe send error: {send_err}')
+ finally:
+ sys.stdout = orig_stdout
+ sys.stderr = orig_stderr
+ stdout_text = _truncate_capture(stdout_capture.getvalue())
+ stderr_text = _truncate_capture(stderr_capture.getvalue())
+ with suppress(Exception):
+ stdconn.send({
+ 'stdout': stdout_text,
+ 'stderr': stderr_text,
+ })
+ with suppress(Exception):
+ os.close(_faulthandler_fd)
+
+
+def exec_in_proc(group=None, target=None, name=None, args=(), kwargs=None, *, daemon=None):
+ if not kwargs:
+ kwargs = {}
+
+ # parent, child
pconn, cconn = mp.Pipe()
+ std_pconn, std_cconn = mp.Pipe()
kwargs['resconn'] = cconn
+ kwargs['stdconn'] = std_cconn
p = mp.Process(
group=group,
target=partial(exception_wrap, target),
@@ -90,24 +174,92 @@ def exec_in_proc(group=None, target=None, name=None, args=(), kwargs={}, *, daem
kwargs=kwargs,
daemon=daemon,
)
+ target_name = getattr(target, '__name__', str(target))
+ start = perf_counter_ns()
p.start()
+ _logger.debug('Subprocess PID %d started for %s', p.pid, target_name)
+
+ result = None
+ stdobj = { 'stdout': '', 'stderr': '' }
+ got_result = False
+ got_std = False
+
+ # Drain result/std pipes while child is still alive to avoid deadlock on full pipe buffers.
+ # Pipe's buffer size is 64 KiB
+ while p.is_alive() and (not got_result or not got_std):
+ if not got_result and pconn.poll(0.1):
+ with suppress(EOFError, OSError, BrokenPipeError):
+ result = pconn.recv()
+ got_result = True
+ if not got_std and std_pconn.poll():
+ with suppress(EOFError, OSError, BrokenPipeError):
+ stdobj = std_pconn.recv()
+ got_std = True
+
p.join()
+ elapsed_ms = (perf_counter_ns() - start) / 1e6
+ _logger.debug(
+ 'Subprocess PID %d for %s finished in %.2f ms (exit code: %s)',
+ p.pid, target_name, elapsed_ms, p.exitcode,
+ )
- result = pconn.recv()
- if result['error'] is not None:
- _logger.error('original traceback: %s', result['traceback'])
+ if not got_std:
+ with suppress(EOFError, OSError, BrokenPipeError):
+ if std_pconn.poll():
+ stdobj = std_pconn.recv()
+ # no need to update got_std here
+ if stdobj.get('stdout') or stdobj.get('stderr'):
+ _logger.info('std info for %s', target_name, extra={
+ 'stdout': stdobj.get('stdout', ''),
+ 'stderr': stdobj.get('stderr', ''),
+ })
+
+ if not got_result:
+ with suppress(EOFError, OSError, BrokenPipeError):
+ if pconn.poll():
+ result = pconn.recv()
+ # no need to update got_result here
+
+ if result is not None and result.get('error') is not None:
+ _logger.error(
+ 'original traceback of %s (PID %d, exitcode: %s): %s',
+ target_name,
+ p.pid,
+ p.exitcode,
+ result.get('traceback', ''),
+ )
raise result['error']
- return result['value']
-
-
-def is_valid_source_id(source_id: str) -> bool:
- # note the ":" in the item id part
- return re.match(r'^[a-zA-Z0-9_-]+__[a-zA-Z0-9_-]+: [a-zA-Z0-9:-]+$', source_id) is not None
+ if result is not None and 'value' in result:
+ if p.exitcode not in (None, 0):
+ _logger.warning(
+ 'Subprocess PID %d for %s exited with code %s after %.2f ms'
+ ' but returned a valid result',
+ p.pid, target_name, p.exitcode, elapsed_ms,
+ )
+ return result['value']
+ if p.exitcode and p.exitcode < 0:
+ _logger.warning(
+ 'Subprocess PID %d for %s exited due to signal %d, exitcode %d after %.2f ms',
+ p.pid, target_name, abs(p.exitcode), p.exitcode, elapsed_ms,
+ )
+ raise SubprocessKilledError(p.pid, target_name, p.exitcode)
+
+ if p.exitcode not in (None, 0):
+ raise SubprocessExecutionError(
+ p.pid,
+ target_name,
+ p.exitcode,
+ f'No structured exception payload received from child process: {result}',
+ )
-def is_valid_provider_id(provider_id: str) -> bool:
- return re.match(r'^[a-zA-Z0-9_-]+__[a-zA-Z0-9_-]+$', provider_id) is not None
+ raise SubprocessExecutionError(
+ p.pid,
+ target_name,
+ 0,
+ f'Subprocess exited successfully but returned no result payload: {result}',
+ )
def timed(func: Callable):
@@ -144,3 +296,13 @@ def redact_config(config: TConfig | TEmbeddingConfig) -> TConfig | TEmbeddingCon
em_conf.auth.password = '***REDACTED***' # noqa: S105
return config_copy
+
+
+def get_app_role() -> AppRole:
+ role = os.getenv('APP_ROLE', '').lower()
+ if role == '':
+ return AppRole.NORMAL
+ if role not in ['indexing', 'rp']:
+ _logger.warning(f'Invalid app role: {role}, defaulting to normal role')
+ return AppRole.NORMAL
+ return AppRole(role)
diff --git a/context_chat_backend/vectordb/base.py b/context_chat_backend/vectordb/base.py
index 0bf10200..2b4aa35e 100644
--- a/context_chat_backend/vectordb/base.py
+++ b/context_chat_backend/vectordb/base.py
@@ -3,14 +3,15 @@
# SPDX-License-Identifier: AGPL-3.0-or-later
#
from abc import ABC, abstractmethod
+from collections.abc import Mapping
from typing import Any
-from fastapi import UploadFile
from langchain.schema import Document
from langchain.schema.embeddings import Embeddings
from langchain.schema.vectorstore import VectorStore
from ..chain.types import InDocument, ScopeType
+from ..types import IndexingError, ReceivedFileItem, SourceItem
from ..utils import timed
from .types import UpdateAccessOp
@@ -62,7 +63,7 @@ def get_instance(self) -> VectorStore:
'''
@abstractmethod
- def add_indocuments(self, indocuments: list[InDocument]) -> tuple[list[str],list[str]]:
+ def add_indocuments(self, indocuments: Mapping[int, InDocument]) -> Mapping[int, IndexingError | None]:
'''
Adds the given indocuments to the vectordb and updates the docs + access tables.
@@ -79,10 +80,7 @@ def add_indocuments(self, indocuments: list[InDocument]) -> tuple[list[str],list
@timed
@abstractmethod
- def check_sources(
- self,
- sources: list[UploadFile],
- ) -> tuple[list[str], list[str]]:
+ def check_sources(self, sources: Mapping[int, SourceItem | ReceivedFileItem]) -> tuple[list[str], list[str]]:
'''
Checks the sources in the vectordb if they are already embedded
and are up to date.
@@ -91,8 +89,8 @@ def check_sources(
Args
----
- sources: list[UploadFile]
- List of source ids to check.
+ sources: Mapping[int, SourceItem | ReceivedFileItem]
+ Dict of sources to check.
Returns
-------
diff --git a/context_chat_backend/vectordb/pgvector.py b/context_chat_backend/vectordb/pgvector.py
index 2b7fc060..d7b718dc 100644
--- a/context_chat_backend/vectordb/pgvector.py
+++ b/context_chat_backend/vectordb/pgvector.py
@@ -4,21 +4,29 @@
#
import logging
import os
+from collections.abc import Mapping
from datetime import datetime
+from time import perf_counter_ns
import psycopg
import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as postgresql_dialects
import sqlalchemy.orm as orm
from dotenv import load_dotenv
-from fastapi import UploadFile
from langchain.schema import Document
from langchain.vectorstores import VectorStore
from langchain_core.embeddings import Embeddings
from langchain_postgres.vectorstores import Base, PGVector
from ..chain.types import InDocument, ScopeType
-from ..types import EmbeddingException, RetryableEmbeddingException
+from ..types import (
+ EmbeddingException,
+ FatalEmbeddingException,
+ IndexingError,
+ ReceivedFileItem,
+ RetryableEmbeddingException,
+ SourceItem,
+)
from ..utils import timed
from .base import BaseVectorDB
from .types import DbException, SafeDbException, UpdateAccessOp
@@ -112,7 +120,15 @@ def __init__(self, embedding: Embeddings | None = None, **kwargs):
kwargs['connection'] = os.environ['CCB_DB_URL']
# setup langchain db + our access list table
- self.client = PGVector(embedding, collection_name=COLLECTION_NAME, **kwargs)
+ try:
+ self.client = PGVector(embedding, collection_name=COLLECTION_NAME, **kwargs)
+ except sa.exc.IntegrityError as ie: # pyright: ignore[reportAttributeAccessIssue]
+ if not isinstance(ie.orig, psycopg.errors.UniqueViolation):
+ raise
+
+ # tried to create the tables but it was already created in another process
+ # init the client again to detect it already exists, and continue from there
+ self.client = PGVector(embedding, collection_name=COLLECTION_NAME, **kwargs)
def get_instance(self) -> VectorStore:
return self.client
@@ -130,24 +146,40 @@ def get_users(self) -> list[str]:
except Exception as e:
raise DbException('Error: getting a list of all users from access list') from e
- def add_indocuments(self, indocuments: list[InDocument]) -> tuple[list[str], list[str]]:
+ def add_indocuments(self, indocuments: Mapping[int, InDocument]) -> Mapping[int, IndexingError | None]:
"""
Raises
EmbeddingException: if the embedding request definitively fails
"""
- added_sources = []
- retry_sources = []
+ results = {}
batch_size = PG_BATCH_SIZE // 5
with self.session_maker() as session:
- for indoc in indocuments:
+ for php_db_id, indoc in indocuments.items():
try:
# query paramerters limitation in postgres is 65535 (https://www.postgresql.org/docs/current/limits.html)
# so we chunk the documents into (5 values * 10k) chunks
# change the chunk size when there are more inserted values per document
chunk_ids = []
- for i in range(0, len(indoc.documents), batch_size):
+ total_chunks = len(indoc.documents)
+ num_batches = max(1, -(-total_chunks // batch_size)) # ceiling division
+ logger.debug(
+ 'Embedding source %s: %d chunk(s) in %d batch(es)',
+ indoc.source_id, total_chunks, num_batches,
+ )
+ for i in range(0, total_chunks, batch_size):
+ batch_num = i // batch_size + 1
+ logger.debug(
+ 'Sending embedding batch %d/%d (%d chunk(s)) for source %s',
+ batch_num, num_batches, len(indoc.documents[i:i+batch_size]), indoc.source_id,
+ )
+ t0 = perf_counter_ns()
chunk_ids.extend(self.client.add_documents(indoc.documents[i:i+batch_size]))
+ elapsed_ms = (perf_counter_ns() - t0) / 1e6
+ logger.debug(
+ 'Embedding batch %d/%d for source %s completed in %.2f ms',
+ batch_num, num_batches, indoc.source_id, elapsed_ms,
+ )
doc = DocumentsStore(
source_id=indoc.source_id,
@@ -170,7 +202,7 @@ def add_indocuments(self, indocuments: list[InDocument]) -> tuple[list[str], lis
)
self.decl_update_access(indoc.userIds, indoc.source_id, session)
- added_sources.append(indoc.source_id)
+ results[php_db_id] = None
session.commit()
except SafeDbException as e:
# for when the source_id is not found. This here can be an error in the DB
@@ -178,51 +210,62 @@ def add_indocuments(self, indocuments: list[InDocument]) -> tuple[list[str], lis
logger.exception('Error adding documents to vectordb', exc_info=e, extra={
'source_id': indoc.source_id,
})
- retry_sources.append(indoc.source_id)
+ results[php_db_id] = IndexingError(
+ error=str(e),
+ retryable=True,
+ )
continue
- except RetryableEmbeddingException as e:
+ except FatalEmbeddingException as e:
+ raise EmbeddingException(
+ f'Fatal error while embedding documents for source {indoc.source_id}: {e}'
+ ) from e
+ except (RetryableEmbeddingException, EmbeddingException) as e:
# temporary error, continue with the next document
logger.exception('Error adding documents to vectordb, should be retried later.', exc_info=e, extra={
'source_id': indoc.source_id,
})
- retry_sources.append(indoc.source_id)
+ results[php_db_id] = IndexingError(
+ error=str(e),
+ retryable=True,
+ )
continue
- except EmbeddingException as e:
- logger.exception('Error adding documents to vectordb', exc_info=e, extra={
- 'source_id': indoc.source_id,
- })
- raise
except Exception as e:
logger.exception('Error adding documents to vectordb', exc_info=e, extra={
'source_id': indoc.source_id,
})
- retry_sources.append(indoc.source_id)
+ results[php_db_id] = IndexingError(
+ error='An unexpected error occurred while adding documents to the database.',
+ retryable=True,
+ )
continue
- return added_sources, retry_sources
+ return results
@timed
- def check_sources(self, sources: list[UploadFile]) -> tuple[list[str], list[str]]:
+ def check_sources(self, sources: Mapping[int, SourceItem | ReceivedFileItem]) -> tuple[list[str], list[str]]:
+ '''
+ returns a tuple of (existing_source_ids, to_embed_source_ids)
+ '''
with self.session_maker() as session:
try:
stmt = (
sa.select(DocumentsStore.source_id)
- .filter(DocumentsStore.source_id.in_([source.filename for source in sources]))
+ .filter(DocumentsStore.source_id.in_([source.reference for source in sources.values()]))
.with_for_update()
)
results = session.execute(stmt).fetchall()
existing_sources = {r.source_id for r in results}
- to_embed = [source.filename for source in sources if source.filename not in existing_sources]
+ to_embed = [source.reference for source in sources.values() if source.reference not in existing_sources]
to_delete = []
- for source in sources:
+ for source in sources.values():
stmt = (
sa.select(DocumentsStore.source_id)
- .filter(DocumentsStore.source_id == source.filename)
+ .filter(DocumentsStore.source_id == source.reference)
.filter(DocumentsStore.modified < sa.cast(
- datetime.fromtimestamp(int(source.headers['modified'])),
+ datetime.fromtimestamp(int(source.modified)),
sa.DateTime,
))
)
@@ -239,14 +282,13 @@ def check_sources(self, sources: list[UploadFile]) -> tuple[list[str], list[str]
session.rollback()
raise DbException('Error: checking sources in vectordb') from e
- still_existing_sources = [
- source
- for source in existing_sources
- if source not in to_delete
+ still_existing_source_ids = [
+ source_id
+ for source_id in existing_sources
+ if source_id not in to_delete
]
- # the pyright issue stems from source.filename, which has already been validated
- return list(still_existing_sources), to_embed # pyright: ignore[reportReturnType]
+ return list(still_existing_source_ids), to_embed
def decl_update_access(self, user_ids: list[str], source_id: str, session_: orm.Session | None = None):
session = session_ or self.session_maker()
@@ -325,7 +367,7 @@ def update_access(
)
match op:
- case UpdateAccessOp.allow:
+ case UpdateAccessOp.ALLOW:
for i in range(0, len(user_ids), PG_BATCH_SIZE):
batched_uids = user_ids[i:i+PG_BATCH_SIZE]
stmt = (
@@ -342,7 +384,7 @@ def update_access(
session.execute(stmt)
session.commit()
- case UpdateAccessOp.deny:
+ case UpdateAccessOp.DENY:
for i in range(0, len(user_ids), PG_BATCH_SIZE):
batched_uids = user_ids[i:i+PG_BATCH_SIZE]
stmt = (
@@ -435,15 +477,17 @@ def delete_source_ids(self, source_ids: list[str], session_: orm.Session | None
# entry from "AccessListStore" is deleted automatically due to the foreign key constraint
# batch the deletion to avoid hitting the query parameter limit
chunks_to_delete = []
+ deleted_source_ids = []
for i in range(0, len(source_ids), PG_BATCH_SIZE):
batched_ids = source_ids[i:i+PG_BATCH_SIZE]
stmt_doc = (
sa.delete(DocumentsStore)
.filter(DocumentsStore.source_id.in_(batched_ids))
- .returning(DocumentsStore.chunks)
+ .returning(DocumentsStore.chunks, DocumentsStore.source_id)
)
doc_result = session.execute(stmt_doc)
chunks_to_delete.extend(str(c) for res in doc_result for c in res.chunks)
+ deleted_source_ids.extend(str(res.source_id) for res in doc_result)
for i in range(0, len(chunks_to_delete), PG_BATCH_SIZE):
batched_chunks = chunks_to_delete[i:i+PG_BATCH_SIZE]
@@ -463,6 +507,14 @@ def delete_source_ids(self, source_ids: list[str], session_: orm.Session | None
if session_ is None:
session.close()
+ undeleted_source_ids = set(source_ids) - set(deleted_source_ids)
+ if len(undeleted_source_ids) > 0:
+ logger.info(
+ f'Source ids {undeleted_source_ids} were not deleted from documents store.'
+ ' This can be due to the source ids not existing in the documents store due to'
+ ' already being deleted or not having been added yet.'
+ )
+
def delete_provider(self, provider_key: str):
with self.session_maker() as session:
try:
@@ -506,7 +558,16 @@ def delete_user(self, user_id: str):
session.rollback()
raise DbException('Error: deleting user from access list') from e
- self._cleanup_if_orphaned(list(source_ids), session)
+ try:
+ self._cleanup_if_orphaned(list(source_ids), session)
+ except Exception as e:
+ session.rollback()
+ logger.error(
+ 'Error cleaning up orphaned source ids after deleting user, manual cleanup might be required',
+ exc_info=e,
+ extra={ 'source_ids': list(source_ids) },
+ )
+ raise DbException('Error: cleaning up orphaned source ids after deleting user') from e
def count_documents_by_provider(self) -> dict[str, int]:
try:
diff --git a/context_chat_backend/vectordb/service.py b/context_chat_backend/vectordb/service.py
index 620a0b39..06a8e19e 100644
--- a/context_chat_backend/vectordb/service.py
+++ b/context_chat_backend/vectordb/service.py
@@ -6,27 +6,42 @@
from ..dyn_loader import VectorDBLoader
from .base import BaseVectorDB
-from .types import DbException, UpdateAccessOp
+from .types import UpdateAccessOp
logger = logging.getLogger('ccb.vectordb')
-# todo: return source ids that were successfully deleted
+
def delete_by_source(vectordb_loader: VectorDBLoader, source_ids: list[str]):
+ '''
+ Raises
+ ------
+ DbException
+ LoaderException
+ '''
db: BaseVectorDB = vectordb_loader.load()
logger.debug('deleting sources by id', extra={ 'source_ids': source_ids })
- try:
- db.delete_source_ids(source_ids)
- except Exception as e:
- raise DbException('Error: Vectordb delete_source_ids error') from e
+ db.delete_source_ids(source_ids)
def delete_by_provider(vectordb_loader: VectorDBLoader, provider_key: str):
+ '''
+ Raises
+ ------
+ DbException
+ LoaderException
+ '''
db: BaseVectorDB = vectordb_loader.load()
logger.debug(f'deleting sources by provider: {provider_key}')
db.delete_provider(provider_key)
def delete_user(vectordb_loader: VectorDBLoader, user_id: str):
+ '''
+ Raises
+ ------
+ DbException
+ LoaderException
+ '''
db: BaseVectorDB = vectordb_loader.load()
logger.debug(f'deleting user from db: {user_id}')
db.delete_user(user_id)
@@ -38,6 +53,13 @@ def update_access(
user_ids: list[str],
source_id: str,
):
+ '''
+ Raises
+ ------
+ DbException
+ LoaderException
+ SafeDbException
+ '''
db: BaseVectorDB = vectordb_loader.load()
logger.debug('updating access', extra={ 'op': op, 'user_ids': user_ids, 'source_id': source_id })
db.update_access(op, user_ids, source_id)
@@ -49,6 +71,13 @@ def update_access_provider(
user_ids: list[str],
provider_id: str,
):
+ '''
+ Raises
+ ------
+ DbException
+ LoaderException
+ SafeDbException
+ '''
db: BaseVectorDB = vectordb_loader.load()
logger.debug('updating access by provider', extra={ 'op': op, 'user_ids': user_ids, 'provider_id': provider_id })
db.update_access_provider(op, user_ids, provider_id)
@@ -59,11 +88,24 @@ def decl_update_access(
user_ids: list[str],
source_id: str,
):
+ '''
+ Raises
+ ------
+ DbException
+ LoaderException
+ SafeDbException
+ '''
db: BaseVectorDB = vectordb_loader.load()
logger.debug('decl update access', extra={ 'user_ids': user_ids, 'source_id': source_id })
db.decl_update_access(user_ids, source_id)
def count_documents_by_provider(vectordb_loader: VectorDBLoader):
+ '''
+ Raises
+ ------
+ DbException
+ LoaderException
+ '''
db: BaseVectorDB = vectordb_loader.load()
logger.debug('counting documents by provider')
return db.count_documents_by_provider()
diff --git a/context_chat_backend/vectordb/types.py b/context_chat_backend/vectordb/types.py
index df5c6dd7..30811797 100644
--- a/context_chat_backend/vectordb/types.py
+++ b/context_chat_backend/vectordb/types.py
@@ -14,5 +14,5 @@ class SafeDbException(Exception):
class UpdateAccessOp(Enum):
- allow = 'allow'
- deny = 'deny'
+ ALLOW = 'allow'
+ DENY = 'deny'
diff --git a/main.py b/main.py
index c4ffa1fd..076b7db0 100755
--- a/main.py
+++ b/main.py
@@ -3,9 +3,12 @@
# SPDX-FileCopyrightText: 2023 Nextcloud GmbH and Nextcloud contributors
# SPDX-License-Identifier: AGPL-3.0-or-later
#
+
import logging
-from os import getenv
+import multiprocessing as mp
+from os import cpu_count, getenv
+import psutil
import uvicorn
from nc_py_api.ex_app import run_app
@@ -48,6 +51,18 @@ def _setup_log_levels(debug: bool):
app_config: TConfig = app.extra['CONFIG']
_setup_log_levels(app_config.debug)
+ # do forks from a clean process that doesn't have any threads or locks
+ mp.set_start_method('forkserver')
+ mp.set_forkserver_preload([
+ 'context_chat_backend.chain.ingest.injest',
+ 'context_chat_backend.vectordb.pgvector',
+ 'langchain',
+ 'logging',
+ 'numpy',
+ 'sqlalchemy',
+ ])
+
+ print(f'CPU count: {cpu_count()}, Memory: {psutil.virtual_memory()}')
print('App config:\n' + redact_config(app_config).model_dump_json(indent=2), flush=True)
uv_log_config = uvicorn.config.LOGGING_CONFIG # pyright: ignore[reportAttributeAccessIssue]