From 2da5fc958c5d06e2973c6a6af3658d7a89dce2bf Mon Sep 17 00:00:00 2001 From: eugwne Date: Wed, 11 Feb 2026 22:49:59 +0100 Subject: [PATCH] v 0.3 --- ...ke-single-platform.yml => cmake-linux.yml} | 18 +- .github/workflows/cmake-macos.yml | 56 + .github/workflows/cmake-windows.yml | 94 + .github/workflows/docker-alpine.yml | 27 + .github/workflows/docker-ubuntu.yml | 27 + CMakeLists.txt | 367 +++- DockerAlpine | 25 +- DockerUbuntu | 10 +- README.md | 170 +- expected/aggregates/output_003.txt | 57 + expected/aggregates/output_004.txt | 15 + expected/aggregates/output_020.txt | 8 + expected/aggregates/output_021.txt | 8 + expected/aggregates/output_033.txt | 11 + expected/{ => basics}/output_001.txt | 13 +- expected/{ => basics}/output_002.txt | 1 - expected/basics/output_005.txt | 26 + expected/basics/output_026.txt | 4 + expected/basics/output_028.txt | 6 + expected/basics/output_029.txt | 11 + expected/errors/output_022.txt | 15 + expected/errors/output_027.txt | 2 + expected/errors/output_034.txt | 5 + expected/errors/output_036.txt | 14 + expected/lifetime/output_024.txt | 1 + expected/lifetime/output_025.txt | 1 + expected/limits/output_015.txt | 1 + expected/limits/output_016.txt | 1 + expected/limits/output_017.txt | 1 + expected/limits/output_018.txt | 1 + expected/{ => nesting}/output_008.txt | 1 - expected/nesting/output_009.txt | 14 + expected/output_003.txt | 5 - expected/output_004.txt | 13 - expected/output_005.txt | 13 - .../{ => performance}/output_timed001.txt | 0 expected/{ => types}/output_007.txt | 1 - expected/types/output_012.txt | 2 + expected/types/output_013.txt | 13 + expected/types/output_040.txt | 5 + expected/{ => vtables}/output_006.txt | 1 - expected/vtables/output_010.txt | 16 + expected/vtables/output_011.txt | 2 + expected/vtables/output_014.txt | 2 + expected/vtables/output_019.txt | 7 + expected/vtables/output_030.txt | 19 + expected/vtables/output_031.txt | 5 + expected/vtables/output_032.txt | 5 + expected/vtables/output_035.txt | 6 + expected/vtables/output_037.txt | 7 + expected/vtables/output_038.txt | 5 + expected/vtables/output_039.txt | 2 + sql/aggregates/input_003.sql | 137 ++ sql/{ => aggregates}/input_004.sql | 43 +- sql/aggregates/input_020.sql | 93 + sql/aggregates/input_021.sql | 78 + sql/aggregates/input_033.sql | 40 + sql/{ => basics}/input_001.sql | 43 +- sql/{ => basics}/input_002.sql | 6 +- sql/{ => basics}/input_005.sql | 24 +- sql/basics/input_026.sql | 8 + sql/basics/input_028.sql | 11 + sql/basics/input_029.sql | 65 + sql/errors/input_022.sql | 69 + sql/errors/input_027.sql | 12 + sql/errors/input_034.sql | 48 + sql/errors/input_036.sql | 85 + sql/input_003.sql | 36 - sql/lifetime/input_024.sql | 29 + sql/lifetime/input_025.sql | 29 + sql/limits/input_015.sql | 16 + sql/limits/input_016.sql | 23 + sql/limits/input_017.sql | 23 + sql/limits/input_018.sql | 26 + sql/{ => nesting}/input_008.sql | 10 +- sql/nesting/input_009.sql | 109 + sql/{ => performance}/input_timed001.sql | 0 sql/{ => types}/input_007.sql | 9 +- sql/types/input_012.sql | 12 + sql/types/input_013.sql | 40 + sql/types/input_040.sql | 16 + sql/{ => vtables}/input_006.sql | 6 +- sql/vtables/input_010.sql | 5 + sql/vtables/input_011.sql | 9 + sql/vtables/input_014.sql | 36 + sql/vtables/input_019.sql | 72 + sql/vtables/input_023.sql | 21 + sql/vtables/input_030.sql | 39 + sql/vtables/input_031.sql | 46 + sql/vtables/input_032.sql | 37 + sql/vtables/input_035.sql | 94 + sql/vtables/input_037.sql | 119 ++ sql/vtables/input_038.sql | 48 + sql/vtables/input_039.sql | 43 + src/plugin.c | 1163 ++++++++++- src/sqlite_capi.lua | 65 +- src/sqlite_lj.lua | 1756 +++++++++++++---- src/sqlite_shared.lua | 203 ++ test.py | 195 ++ test.sh | 43 - tests/multiconnection_isolation.py | 93 + tests/multithread_isolated_connections.cpp | 457 +++++ tests/multithread_vtable_queries.cpp | 278 +++ tests/thread_guard_same_connection.cpp | 116 ++ 104 files changed, 6514 insertions(+), 710 deletions(-) rename .github/workflows/{cmake-single-platform.yml => cmake-linux.yml} (83%) create mode 100644 .github/workflows/cmake-macos.yml create mode 100644 .github/workflows/cmake-windows.yml create mode 100644 .github/workflows/docker-alpine.yml create mode 100644 .github/workflows/docker-ubuntu.yml create mode 100644 expected/aggregates/output_003.txt create mode 100644 expected/aggregates/output_004.txt create mode 100644 expected/aggregates/output_020.txt create mode 100644 expected/aggregates/output_021.txt create mode 100644 expected/aggregates/output_033.txt rename expected/{ => basics}/output_001.txt (67%) rename expected/{ => basics}/output_002.txt (67%) create mode 100644 expected/basics/output_005.txt create mode 100644 expected/basics/output_026.txt create mode 100644 expected/basics/output_028.txt create mode 100644 expected/basics/output_029.txt create mode 100644 expected/errors/output_022.txt create mode 100644 expected/errors/output_027.txt create mode 100644 expected/errors/output_034.txt create mode 100644 expected/errors/output_036.txt create mode 100644 expected/lifetime/output_024.txt create mode 100644 expected/lifetime/output_025.txt create mode 100644 expected/limits/output_015.txt create mode 100644 expected/limits/output_016.txt create mode 100644 expected/limits/output_017.txt create mode 100644 expected/limits/output_018.txt rename expected/{ => nesting}/output_008.txt (96%) create mode 100644 expected/nesting/output_009.txt delete mode 100644 expected/output_003.txt delete mode 100644 expected/output_004.txt delete mode 100644 expected/output_005.txt rename expected/{ => performance}/output_timed001.txt (100%) rename expected/{ => types}/output_007.txt (98%) create mode 100644 expected/types/output_012.txt create mode 100644 expected/types/output_013.txt create mode 100644 expected/types/output_040.txt rename expected/{ => vtables}/output_006.txt (99%) create mode 100644 expected/vtables/output_010.txt create mode 100644 expected/vtables/output_011.txt create mode 100644 expected/vtables/output_014.txt create mode 100644 expected/vtables/output_019.txt create mode 100644 expected/vtables/output_030.txt create mode 100644 expected/vtables/output_031.txt create mode 100644 expected/vtables/output_032.txt create mode 100644 expected/vtables/output_035.txt create mode 100644 expected/vtables/output_037.txt create mode 100644 expected/vtables/output_038.txt create mode 100644 expected/vtables/output_039.txt create mode 100644 sql/aggregates/input_003.sql rename sql/{ => aggregates}/input_004.sql (61%) create mode 100644 sql/aggregates/input_020.sql create mode 100644 sql/aggregates/input_021.sql create mode 100644 sql/aggregates/input_033.sql rename sql/{ => basics}/input_001.sql (56%) rename sql/{ => basics}/input_002.sql (78%) rename sql/{ => basics}/input_005.sql (65%) create mode 100644 sql/basics/input_026.sql create mode 100644 sql/basics/input_028.sql create mode 100644 sql/basics/input_029.sql create mode 100644 sql/errors/input_022.sql create mode 100644 sql/errors/input_027.sql create mode 100644 sql/errors/input_034.sql create mode 100644 sql/errors/input_036.sql delete mode 100644 sql/input_003.sql create mode 100644 sql/lifetime/input_024.sql create mode 100644 sql/lifetime/input_025.sql create mode 100644 sql/limits/input_015.sql create mode 100644 sql/limits/input_016.sql create mode 100644 sql/limits/input_017.sql create mode 100644 sql/limits/input_018.sql rename sql/{ => nesting}/input_008.sql (85%) create mode 100644 sql/nesting/input_009.sql rename sql/{ => performance}/input_timed001.sql (100%) rename sql/{ => types}/input_007.sql (76%) create mode 100644 sql/types/input_012.sql create mode 100644 sql/types/input_013.sql create mode 100644 sql/types/input_040.sql rename sql/{ => vtables}/input_006.sql (89%) create mode 100644 sql/vtables/input_010.sql create mode 100644 sql/vtables/input_011.sql create mode 100644 sql/vtables/input_014.sql create mode 100644 sql/vtables/input_019.sql create mode 100644 sql/vtables/input_023.sql create mode 100644 sql/vtables/input_030.sql create mode 100644 sql/vtables/input_031.sql create mode 100644 sql/vtables/input_032.sql create mode 100644 sql/vtables/input_035.sql create mode 100644 sql/vtables/input_037.sql create mode 100644 sql/vtables/input_038.sql create mode 100644 sql/vtables/input_039.sql create mode 100644 src/sqlite_shared.lua create mode 100644 test.py delete mode 100755 test.sh create mode 100644 tests/multiconnection_isolation.py create mode 100644 tests/multithread_isolated_connections.cpp create mode 100644 tests/multithread_vtable_queries.cpp create mode 100644 tests/thread_guard_same_connection.cpp diff --git a/.github/workflows/cmake-single-platform.yml b/.github/workflows/cmake-linux.yml similarity index 83% rename from .github/workflows/cmake-single-platform.yml rename to .github/workflows/cmake-linux.yml index f38d022..2218422 100644 --- a/.github/workflows/cmake-single-platform.yml +++ b/.github/workflows/cmake-linux.yml @@ -1,6 +1,4 @@ -# This starter workflow is for a CMake project running on a single platform. There is a different starter workflow if you need cross-platform coverage. -# See: https://github.com/actions/starter-workflows/blob/main/ci/cmake-multi-platform.yml -name: CMake on a single platform +name: CMake Linux on: push: @@ -30,9 +28,10 @@ jobs: - name: Install libs run: | sudo apt-get update -qq - sudo apt-get install sqlite3 -y + sudo apt-get install sqlite3 libsqlite3-dev -y sudo apt-get install cmake -y sudo apt-get install gcc -y + sudo apt-get install python3 -y sudo apt-get install luajit -y sudo apt-get install libluajit-5.1-dev -y @@ -40,7 +39,7 @@ jobs: - name: Configure CMake # Configure CMake in a 'build' subdirectory. `CMAKE_BUILD_TYPE` is only required if you are using a single-configuration generator such as make. # See https://cmake.org/cmake/help/latest/variable/CMAKE_BUILD_TYPE.html?highlight=cmake_build_type - run: cmake -B ${{github.workspace}}/build -DCMAKE_BUILD_TYPE=${{env.BUILD_TYPE}} + run: cmake -B ${{github.workspace}}/build -DCMAKE_BUILD_TYPE=${{env.BUILD_TYPE}} -DSQLITE_PLUGIN_LJ_PORTABLE_RPATH=ON - name: Build # Build your program with the given configuration @@ -52,7 +51,8 @@ jobs: # See https://cmake.org/cmake/help/latest/manual/ctest.1.html for more detail run: ctest -C ${{env.BUILD_TYPE}} -V - - uses: actions/upload-artifact@v4.3.3 + - if: ${{ github.ref == 'refs/heads/main' }} + uses: actions/upload-artifact@v4.3.3 with: name: sqlite_plugin_lj_src path: | @@ -64,11 +64,11 @@ jobs: ${{github.workspace}}/DockerAlpine ${{github.workspace}}/DockerUbuntu ${{github.workspace}}/README.md - ${{github.workspace}}/test.sh + ${{github.workspace}}/test.py - - uses: actions/upload-artifact@v4.3.3 + - if: ${{ github.ref == 'refs/heads/main' }} + uses: actions/upload-artifact@v4.3.3 with: name: sqlite_plugin_lj_lib path: ${{github.workspace}}/build/libsqlite_plugin_lj.so - diff --git a/.github/workflows/cmake-macos.yml b/.github/workflows/cmake-macos.yml new file mode 100644 index 0000000..8463eb2 --- /dev/null +++ b/.github/workflows/cmake-macos.yml @@ -0,0 +1,56 @@ +name: CMake on macOS + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +env: + BUILD_TYPE: Release + +jobs: + build-macos: + runs-on: macos-latest + + steps: + - name: Ensure Node.js for JS-based actions + run: | + if ! command -v node >/dev/null 2>&1; then + brew update + brew install node + fi + node --version + + - uses: actions/checkout@v4 + + - name: Install dependencies + run: | + brew update + brew install cmake sqlite luajit python + + - name: Configure environment + run: | + SQLITE_PREFIX="$(brew --prefix sqlite)" + LUAJIT_PREFIX="$(brew --prefix luajit)" + echo "SQLITE3_INCLUDE_DIR=${SQLITE_PREFIX}/include" >> "$GITHUB_ENV" + echo "SQLITE3_BIN=${SQLITE_PREFIX}/bin/sqlite3" >> "$GITHUB_ENV" + echo "LUAJIT_INCLUDE_DIR=${LUAJIT_PREFIX}/include/luajit-2.1" >> "$GITHUB_ENV" + echo "LUAJIT_LIBRARY=${LUAJIT_PREFIX}/lib/libluajit-5.1.dylib" >> "$GITHUB_ENV" + echo "LUAJIT_BIN=${LUAJIT_PREFIX}/bin/luajit" >> "$GITHUB_ENV" + + - name: Configure CMake + run: cmake -B ${{github.workspace}}/build -DCMAKE_BUILD_TYPE=${{env.BUILD_TYPE}} -DSQLITE_PLUGIN_LJ_PORTABLE_RPATH=ON + + - name: Build + run: cmake --build ${{github.workspace}}/build --config ${{env.BUILD_TYPE}} + + - name: Test + working-directory: ${{github.workspace}}/build + run: ctest -C ${{env.BUILD_TYPE}} -V + + - if: ${{ github.ref == 'refs/heads/main' }} + uses: actions/upload-artifact@v4.3.3 + with: + name: sqlite_plugin_lj_lib_macos + path: ${{github.workspace}}/build/libsqlite_plugin_lj.dylib diff --git a/.github/workflows/cmake-windows.yml b/.github/workflows/cmake-windows.yml new file mode 100644 index 0000000..1314761 --- /dev/null +++ b/.github/workflows/cmake-windows.yml @@ -0,0 +1,94 @@ +name: CMake on Windows + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +env: + BUILD_TYPE: Release + VCPKG_ROOT: C:\vcpkg + +jobs: + build-windows: + runs-on: windows-latest + + steps: + - uses: actions/checkout@v4 + + - name: Ensure Python is available + shell: cmd + run: | + where python >nul 2>nul && (python --version && exit /b 0) + where py >nul 2>nul && (py -3 --version && exit /b 0) + echo Python 3 is required on the runner but was not found in PATH. + exit /b 1 + + - name: Install SQLite headers with vcpkg + shell: cmd + run: | + if not exist "%VCPKG_ROOT%" git clone https://github.com/microsoft/vcpkg "%VCPKG_ROOT%" + call "%VCPKG_ROOT%\bootstrap-vcpkg.bat" + "%VCPKG_ROOT%\vcpkg.exe" install sqlite3:x64-windows + + - name: Install sqlite3 CLI + shell: cmd + run: | + set "SQLITE_URL=https://sqlite.org/2026/sqlite-tools-win-x64-3510200.zip" + set "SQLITE_ZIP=%RUNNER_TEMP%\sqlite-tools.zip" + set "SQLITE_DIR=%RUNNER_TEMP%\sqlite-tools" + if not exist "%SQLITE_DIR%" mkdir "%SQLITE_DIR%" + curl -L "%SQLITE_URL%" -o "%SQLITE_ZIP%" + tar -xf "%SQLITE_ZIP%" -C "%SQLITE_DIR%" + set "SQLITE_EXE_DIR=" + for /r "%SQLITE_DIR%" %%F in (sqlite3.exe) do ( + set "SQLITE_EXE_DIR=%%~dpF" + goto :found_sqlite + ) + echo sqlite3.exe not found after extracting sqlite-tools archive + exit /b 1 + :found_sqlite + echo %SQLITE_EXE_DIR%>>"%GITHUB_PATH%" + "%SQLITE_EXE_DIR%sqlite3.exe" --version + + - name: Configure CMake (MSVC) + shell: cmd + run: > + cmake -S . -B build -G "Visual Studio 17 2022" -A x64 + -DCMAKE_TOOLCHAIN_FILE=%VCPKG_ROOT%\scripts\buildsystems\vcpkg.cmake + -DVCPKG_TARGET_TRIPLET=x64-windows + + - name: Build + shell: cmd + run: cmake --build build --config %BUILD_TYPE% + + - name: Test + shell: cmd + run: ctest --test-dir build -C %BUILD_TYPE% -V + + - name: Upload source artifact + if: ${{ github.ref == 'refs/heads/main' }} + uses: actions/upload-artifact@v4.3.3 + with: + name: sqlite_plugin_lj_src_windows + path: | + ${{ github.workspace }}/expected + ${{ github.workspace }}/sql + ${{ github.workspace }}/src + ${{ github.workspace }}/LICENSE + ${{ github.workspace }}/CMakeLists.txt + ${{ github.workspace }}/DockerAlpine + ${{ github.workspace }}/DockerUbuntu + ${{ github.workspace }}/README.md + ${{ github.workspace }}/test.py + + - name: Upload windows build artifacts + if: ${{ github.ref == 'refs/heads/main' }} + uses: actions/upload-artifact@v4.3.3 + with: + name: sqlite_plugin_lj_windows + path: | + ${{ github.workspace }}/build/Release/*sqlite_plugin_lj*.dll + ${{ github.workspace }}/build/Release/*sqlite_plugin_lj*.lib + ${{ github.workspace }}/build/libsqlite_plugin_lj.dll diff --git a/.github/workflows/docker-alpine.yml b/.github/workflows/docker-alpine.yml new file mode 100644 index 0000000..14a5d4f --- /dev/null +++ b/.github/workflows/docker-alpine.yml @@ -0,0 +1,27 @@ +name: Docker Alpine Build + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +jobs: + build-docker-alpine: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Build DockerAlpine binaries target + run: | + rm -rf lib + mkdir -p lib + docker build --output=lib --target=binaries -f DockerAlpine . + test -f lib/libsqlite_plugin_lj.so + + - if: ${{ github.ref == 'refs/heads/main' }} + uses: actions/upload-artifact@v4.3.3 + with: + name: sqlite_plugin_lj_lib_alpine_docker + path: lib/libsqlite_plugin_lj.so diff --git a/.github/workflows/docker-ubuntu.yml b/.github/workflows/docker-ubuntu.yml new file mode 100644 index 0000000..55a904e --- /dev/null +++ b/.github/workflows/docker-ubuntu.yml @@ -0,0 +1,27 @@ +name: Docker Ubuntu Build + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +jobs: + build-docker-ubuntu: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Build DockerUbuntu binaries target + run: | + rm -rf lib + mkdir -p lib + docker build --output=lib --target=binaries -f DockerUbuntu . + test -f lib/libsqlite_plugin_lj.so + + - if: ${{ github.ref == 'refs/heads/main' }} + uses: actions/upload-artifact@v4.3.3 + with: + name: sqlite_plugin_lj_lib_ubuntu_docker + path: lib/libsqlite_plugin_lj.so diff --git a/CMakeLists.txt b/CMakeLists.txt index f0f1ff3..c2540cf 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,32 +1,101 @@ -cmake_minimum_required(VERSION 3.5) +cmake_minimum_required(VERSION 3.14) -project(sqlite_plugin_lj LANGUAGES C) +project(sqlite_plugin_lj LANGUAGES C CXX) enable_testing() +find_package(Python3 COMPONENTS Interpreter REQUIRED) +find_package(SQLite3 REQUIRED) +find_package(Threads REQUIRED) add_library(${PROJECT_NAME} SHARED src/plugin.c) +option(SQLITE_PLUGIN_LJ_ENABLE_ASAN "Enable AddressSanitizer for sqlite_plugin_lj target only" OFF) +option(SQLITE_PLUGIN_LJ_PORTABLE_RPATH "Use origin-based runtime search path for distributable binaries" OFF) + +# Enforce strict warning policy for this project target. +if(MSVC) + target_compile_options(${PROJECT_NAME} PRIVATE /W4 /WX) +else() + target_compile_options(${PROJECT_NAME} PRIVATE -Wall -Wextra -Wpedantic -Werror) +endif() + +if(SQLITE_PLUGIN_LJ_ENABLE_ASAN) + message(STATUS "AddressSanitizer enabled for target: ${PROJECT_NAME}") + if(MSVC) + target_compile_options(${PROJECT_NAME} PRIVATE /fsanitize=address) + target_link_options(${PROJECT_NAME} PRIVATE /fsanitize=address) + else() + target_compile_options(${PROJECT_NAME} PRIVATE -fsanitize=address -fno-omit-frame-pointer) + target_link_options(${PROJECT_NAME} PRIVATE -fsanitize=address) + endif() +endif() + +# Optional SQLite headers (sqlite3ext.h) detection. +# Priority: SQLITE3_INCLUDE_DIR env var, then vcpkg (VCPKG_ROOT + triplet). +set(SQLITE3_INCLUDE_DIR "") +if(DEFINED ENV{SQLITE3_INCLUDE_DIR}) + set(SQLITE3_INCLUDE_DIR "$ENV{SQLITE3_INCLUDE_DIR}") +elseif(DEFINED ENV{VCPKG_ROOT}) + set(_vcpkg_root "$ENV{VCPKG_ROOT}") + if(DEFINED ENV{VCPKG_TARGET_TRIPLET}) + set(_vcpkg_triplet "$ENV{VCPKG_TARGET_TRIPLET}") + elseif(DEFINED ENV{VCPKG_DEFAULT_TRIPLET}) + set(_vcpkg_triplet "$ENV{VCPKG_DEFAULT_TRIPLET}") + else() + set(_vcpkg_triplet "x64-windows") + endif() + set(_vcpkg_inc "${_vcpkg_root}/installed/${_vcpkg_triplet}/include") + if(EXISTS "${_vcpkg_inc}/sqlite3ext.h") + set(SQLITE3_INCLUDE_DIR "${_vcpkg_inc}") + endif() +endif() +if(SQLITE3_INCLUDE_DIR) + message(STATUS "Using SQLite headers from: ${SQLITE3_INCLUDE_DIR}") + target_include_directories(${PROJECT_NAME} PRIVATE "${SQLITE3_INCLUDE_DIR}") +else() + message(WARNING "SQLite headers not found. Set SQLITE3_INCLUDE_DIR or VCPKG_ROOT so sqlite3ext.h is available.") +endif() + set_target_properties(${PROJECT_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON) +if(SQLITE_PLUGIN_LJ_PORTABLE_RPATH AND NOT WIN32) + if(APPLE) + set(_portable_rpath "@loader_path") + else() + set(_portable_rpath "\$ORIGIN") + endif() + set_target_properties(${PROJECT_NAME} PROPERTIES + BUILD_WITH_INSTALL_RPATH TRUE + BUILD_RPATH "${_portable_rpath}" + INSTALL_RPATH "${_portable_rpath}" + INSTALL_RPATH_USE_LINK_PATH FALSE + ) +endif() set(LUA_SCRIPT_FILES src/sqlite_lj.lua src/sqlite_capi.lua + src/sqlite_shared.lua ) set(CMAKE_VERBOSE_MAKEFILE ON) -if(DEFINED ENV{LUAJIT_INCLUDE_DIR} AND DEFINED ENV{LUAJIT_LIBRARY}) - set(LUAJIT_INCLUDE_DIR $ENV{LUAJIT_INCLUDE_DIR}) - set(LUAJIT_LIBRARY $ENV{LUAJIT_LIBRARY}) +set(USE_SYSTEM_LUAJIT OFF) +if(DEFINED ENV{LUAJIT_INCLUDE_DIR} AND DEFINED ENV{LUAJIT_LIBRARY} AND DEFINED ENV{LUAJIT_BIN}) + set(LUAJIT_INCLUDE_DIR "$ENV{LUAJIT_INCLUDE_DIR}") + set(LUAJIT_LIBRARY "$ENV{LUAJIT_LIBRARY}") + set(LUAJIT_BIN "$ENV{LUAJIT_BIN}") - set(LUAJIT_BIN $ENV{LUAJIT_BIN}) + if(EXISTS "${LUAJIT_INCLUDE_DIR}" AND EXISTS "${LUAJIT_LIBRARY}" AND EXISTS "${LUAJIT_BIN}") + set(USE_SYSTEM_LUAJIT ON) + else() + message(WARNING "System LuaJIT paths are set but not found. Falling back to CMake-managed LuaJIT build.") + endif() +endif() - target_include_directories(${PROJECT_NAME} PRIVATE ${LUAJIT_INCLUDE_DIR}) - target_link_libraries(${PROJECT_NAME} PRIVATE ${LUAJIT_LIBRARY}) +if(USE_SYSTEM_LUAJIT) + target_include_directories(${PROJECT_NAME} PRIVATE "${LUAJIT_INCLUDE_DIR}") + target_link_libraries(${PROJECT_NAME} PRIVATE "${LUAJIT_LIBRARY}") else() - SET(CMAKE_SKIP_BUILD_RPATH FALSE) - SET(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE) - SET(CMAKE_INSTALL_RPATH "\${ORIGIN}") # use github version include(FetchContent) @@ -38,34 +107,55 @@ else() FetchContent_MakeAvailable(luajit_sources) include(ExternalProject) - ExternalProject_Add(ext_luajit - SOURCE_DIR ${luajit_sources_SOURCE_DIR} - DOWNLOAD_COMMAND "" - CONFIGURE_COMMAND "" - BUILD_COMMAND make - BUILD_IN_SOURCE true - INSTALL_COMMAND make install PREFIX=${CMAKE_CURRENT_BINARY_DIR}/luajit - && cp -a "${CMAKE_CURRENT_BINARY_DIR}/luajit/lib/." "${CMAKE_CURRENT_BINARY_DIR}" - ) - - target_include_directories(${PROJECT_NAME} PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/luajit/include/luajit-2.1") + set(LUAJIT_SRC_DIR "${luajit_sources_SOURCE_DIR}/src") + if(MSVC) + ExternalProject_Add(ext_luajit + SOURCE_DIR ${luajit_sources_SOURCE_DIR} + DOWNLOAD_COMMAND "" + CONFIGURE_COMMAND "" + BUILD_COMMAND ${CMAKE_COMMAND} -E chdir ${LUAJIT_SRC_DIR} cmd /c msvcbuild.bat + BUILD_IN_SOURCE true + INSTALL_COMMAND "" + ) - add_dependencies(${PROJECT_NAME} ext_luajit) - target_link_libraries(${PROJECT_NAME} PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/libluajit-5.1.so") + target_include_directories(${PROJECT_NAME} PRIVATE "${LUAJIT_SRC_DIR}") + add_dependencies(${PROJECT_NAME} ext_luajit) + target_link_libraries(${PROJECT_NAME} PRIVATE "${LUAJIT_SRC_DIR}/lua51.lib") + set(LUAJIT_BIN "${LUAJIT_SRC_DIR}/luajit.exe") + else() + ExternalProject_Add(ext_luajit + SOURCE_DIR ${luajit_sources_SOURCE_DIR} + DOWNLOAD_COMMAND "" + CONFIGURE_COMMAND "" + BUILD_COMMAND make + BUILD_IN_SOURCE true + INSTALL_COMMAND make install PREFIX=${CMAKE_CURRENT_BINARY_DIR}/luajit + && cp -a "${CMAKE_CURRENT_BINARY_DIR}/luajit/lib/." "${CMAKE_CURRENT_BINARY_DIR}" + ) - set(LUAJIT_BIN ${CMAKE_CURRENT_BINARY_DIR}/luajit/bin/luajit) + target_include_directories(${PROJECT_NAME} PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/luajit/include/luajit-2.1") + add_dependencies(${PROJECT_NAME} ext_luajit) + target_link_libraries(${PROJECT_NAME} PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/libluajit-5.1${CMAKE_SHARED_LIBRARY_SUFFIX}") + set(LUAJIT_BIN "${CMAKE_CURRENT_BINARY_DIR}/luajit/bin/luajit") + endif() endif() -set(LUA_PATH "/usr/share/luajit-2.1.0-beta3/?.lua\;${CMAKE_CURRENT_BINARY_DIR}/luajit/share/luajit-2.1/?.lua\;\;") +if(USE_SYSTEM_LUAJIT) + set(LUA_PATH "${CMAKE_CURRENT_BINARY_DIR}/luajit/share/luajit-2.1/?.lua\;${CMAKE_CURRENT_BINARY_DIR}/luajit/share/luajit-2.1/?/?.lua\;\;") +elseif(MSVC) + set(LUA_PATH "${LUAJIT_SRC_DIR}/?.lua\;${LUAJIT_SRC_DIR}/?/?.lua\;\;") +else() + set(LUA_PATH "${CMAKE_CURRENT_BINARY_DIR}/luajit/share/luajit-2.1/?.lua\;${CMAKE_CURRENT_BINARY_DIR}/luajit/share/luajit-2.1/?/?.lua\;\;") +endif() set(BYTECODE_OBJECTS "") foreach(LUA_FILE ${LUA_SCRIPT_FILES}) get_filename_component(FILENAME_WE ${LUA_FILE} NAME_WE) - set(OBJECT_FILE ${CMAKE_CURRENT_BINARY_DIR}/${FILENAME_WE}.obj) + set(OBJECT_FILE ${CMAKE_CURRENT_BINARY_DIR}/${FILENAME_WE}.h) - if(DEFINED ENV{LUAJIT_LIBRARY}) + if(USE_SYSTEM_LUAJIT) add_custom_command( OUTPUT ${OBJECT_FILE} COMMAND ${CMAKE_COMMAND} -E env LUA_PATH="${LUA_PATH}" @@ -89,32 +179,219 @@ endforeach() add_custom_target(generate_bytecode_objects ALL DEPENDS ${BYTECODE_OBJECTS}) add_dependencies(${PROJECT_NAME} generate_bytecode_objects) -target_link_libraries(${PROJECT_NAME} PRIVATE ${BYTECODE_OBJECTS}) -set_target_properties(${PROJECT_NAME} PROPERTIES LINK_FLAGS "-Wl,--strip-all") - +target_include_directories(${PROJECT_NAME} PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) +if(UNIX AND NOT APPLE) + target_link_options(${PROJECT_NAME} PRIVATE -Wl,--strip-all) +endif() +if(WIN32) + add_custom_command(TARGET ${PROJECT_NAME} POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different + "$" + "${CMAKE_BINARY_DIR}/libsqlite_plugin_lj.dll" + ) + if(NOT USE_SYSTEM_LUAJIT) + add_custom_command(TARGET ${PROJECT_NAME} POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different + "${LUAJIT_SRC_DIR}/lua51.dll" + "${CMAKE_BINARY_DIR}/lua51.dll" + ) + endif() +endif() set(TEST_NAMES - 001 - 002 - 003 - 004 - 005 - 006 - 007 - 008 + 011 + 010 + 009 + 001 + 002 + 003 + 004 + 005 + 006 + 007 + 008 + 012 + 013 + 014 + 015 + 016 + 017 + 018 + 019 + 020 + 021 + 022 + 027 + 024 + 025 + 026 + 028 + 029 + 030 + 031 + 032 + 033 + 034 + 035 + 036 + 037 + 038 + 039 + 040 # timed001 ) +# set(TEST_NAMES 009) + +function(resolve_fixture out_var root_dir filename kind test_name) + file(GLOB_RECURSE MATCHES "${root_dir}/${filename}") + list(LENGTH MATCHES MATCH_COUNT) + if(MATCH_COUNT EQUAL 0) + message(FATAL_ERROR "Missing ${kind} fixture for test ${test_name}: ${filename}") + endif() + if(MATCH_COUNT GREATER 1) + message(FATAL_ERROR "Ambiguous ${kind} fixture for test ${test_name}: ${filename} -> ${MATCHES}") + endif() + list(GET MATCHES 0 MATCH_PATH) + set(${out_var} "${MATCH_PATH}" PARENT_SCOPE) +endfunction() + # Copy each file to the build directory foreach(TNAME ${TEST_NAMES}) - configure_file(sql/input_${TNAME}.sql ${CMAKE_BINARY_DIR}/sql/input_${TNAME}.sql COPYONLY) - configure_file(expected/output_${TNAME}.txt ${CMAKE_BINARY_DIR}/expected/output_${TNAME}.txt COPYONLY) + resolve_fixture(INPUT_FILE "${CMAKE_CURRENT_SOURCE_DIR}/sql" "input_${TNAME}.sql" "SQL" "${TNAME}") + resolve_fixture(EXPECTED_FILE "${CMAKE_CURRENT_SOURCE_DIR}/expected" "output_${TNAME}.txt" "expected" "${TNAME}") + + configure_file(${INPUT_FILE} ${CMAKE_BINARY_DIR}/sql/input_${TNAME}.sql COPYONLY) + configure_file(${EXPECTED_FILE} ${CMAKE_BINARY_DIR}/expected/output_${TNAME}.txt COPYONLY) endforeach() +foreach(TNAME ${TEST_NAMES}) + add_test( + NAME sqlite_plugin_test_${TNAME} + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/test.py --tests "${TNAME}" + ) + set_tests_properties(sqlite_plugin_test_${TNAME} PROPERTIES WORKING_DIRECTORY ${CMAKE_BINARY_DIR}) +endforeach() + +# Vtable churn regression: ensure repeated create/drop remains stable. +add_test( + NAME sqlite_plugin_vtable_churn_repro + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/test.py --sql-file ${CMAKE_CURRENT_SOURCE_DIR}/sql/vtables/input_023.sql +) +set_tests_properties(sqlite_plugin_vtable_churn_repro PROPERTIES WORKING_DIRECTORY ${CMAKE_BINARY_DIR}) + +# Regression test: plugin must work independently on multiple SQLite connections +# in the same process. add_test( - NAME sqlite_plugin_tests - COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/test.sh "${TEST_NAMES}" + NAME sqlite_plugin_multiconnection_isolation + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/tests/multiconnection_isolation.py ) +set_tests_properties(sqlite_plugin_multiconnection_isolation PROPERTIES + WORKING_DIRECTORY ${CMAKE_BINARY_DIR} +) + +add_executable(sqlite_plugin_multithread_isolated_connections + ${CMAKE_CURRENT_SOURCE_DIR}/tests/multithread_isolated_connections.cpp +) +set_property(TARGET sqlite_plugin_multithread_isolated_connections PROPERTY CXX_STANDARD 23) +set_property(TARGET sqlite_plugin_multithread_isolated_connections PROPERTY CXX_STANDARD_REQUIRED ON) +set_property(TARGET sqlite_plugin_multithread_isolated_connections PROPERTY CXX_EXTENSIONS OFF) +set(SQLITE3_TEST_LINK_LIB "") +if(APPLE AND DEFINED ENV{SQLITE3_BIN}) + get_filename_component(_sqlite3_bin_dir "$ENV{SQLITE3_BIN}" DIRECTORY) + get_filename_component(_sqlite3_prefix "${_sqlite3_bin_dir}" DIRECTORY) + set(_sqlite3_brew_lib "${_sqlite3_prefix}/lib/libsqlite3.dylib") + if(EXISTS "${_sqlite3_brew_lib}") + set(SQLITE3_TEST_LINK_LIB "${_sqlite3_brew_lib}") + message(STATUS "Using Homebrew SQLite for multithread test: ${SQLITE3_TEST_LINK_LIB}") + endif() +endif() +if(SQLITE3_TEST_LINK_LIB) + target_link_libraries(sqlite_plugin_multithread_isolated_connections PRIVATE "${SQLITE3_TEST_LINK_LIB}" Threads::Threads) +else() + target_link_libraries(sqlite_plugin_multithread_isolated_connections PRIVATE SQLite::SQLite3 Threads::Threads) +endif() +if(MSVC) + target_compile_options(sqlite_plugin_multithread_isolated_connections PRIVATE /WX) +else() + target_compile_options(sqlite_plugin_multithread_isolated_connections PRIVATE -Werror) +endif() + +add_executable(sqlite_plugin_multithread_vtable_queries + ${CMAKE_CURRENT_SOURCE_DIR}/tests/multithread_vtable_queries.cpp +) +set_property(TARGET sqlite_plugin_multithread_vtable_queries PROPERTY CXX_STANDARD 23) +set_property(TARGET sqlite_plugin_multithread_vtable_queries PROPERTY CXX_STANDARD_REQUIRED ON) +set_property(TARGET sqlite_plugin_multithread_vtable_queries PROPERTY CXX_EXTENSIONS OFF) +if(SQLITE3_TEST_LINK_LIB) + target_link_libraries(sqlite_plugin_multithread_vtable_queries PRIVATE "${SQLITE3_TEST_LINK_LIB}" Threads::Threads) +else() + target_link_libraries(sqlite_plugin_multithread_vtable_queries PRIVATE SQLite::SQLite3 Threads::Threads) +endif() +if(MSVC) + target_compile_options(sqlite_plugin_multithread_vtable_queries PRIVATE /WX) +else() + target_compile_options(sqlite_plugin_multithread_vtable_queries PRIVATE -Werror) +endif() -# Set the working directory for the test -set_tests_properties(sqlite_plugin_tests PROPERTIES WORKING_DIRECTORY ${CMAKE_BINARY_DIR}) +add_executable(sqlite_plugin_thread_guard_same_connection + ${CMAKE_CURRENT_SOURCE_DIR}/tests/thread_guard_same_connection.cpp +) +set_property(TARGET sqlite_plugin_thread_guard_same_connection PROPERTY CXX_STANDARD 23) +set_property(TARGET sqlite_plugin_thread_guard_same_connection PROPERTY CXX_STANDARD_REQUIRED ON) +set_property(TARGET sqlite_plugin_thread_guard_same_connection PROPERTY CXX_EXTENSIONS OFF) +if(SQLITE3_TEST_LINK_LIB) + target_link_libraries(sqlite_plugin_thread_guard_same_connection PRIVATE "${SQLITE3_TEST_LINK_LIB}" Threads::Threads) +else() + target_link_libraries(sqlite_plugin_thread_guard_same_connection PRIVATE SQLite::SQLite3 Threads::Threads) +endif() +if(MSVC) + target_compile_options(sqlite_plugin_thread_guard_same_connection PRIVATE /WX) +else() + target_compile_options(sqlite_plugin_thread_guard_same_connection PRIVATE -Werror) +endif() + +add_test( + NAME sqlite_plugin_multithread_isolated_connections + COMMAND sqlite_plugin_multithread_isolated_connections + $ + ${CMAKE_BINARY_DIR}/sql + ${CMAKE_BINARY_DIR}/expected +) +add_test( + NAME sqlite_plugin_multithread_vtable_queries + COMMAND sqlite_plugin_multithread_vtable_queries + $ +) +add_test( + NAME sqlite_plugin_thread_guard_same_connection + COMMAND sqlite_plugin_thread_guard_same_connection + $ +) +if(WIN32) + get_filename_component(_sqlite3_lib_dir "${SQLite3_LIBRARY}" DIRECTORY) + get_filename_component(_sqlite3_prefix "${_sqlite3_lib_dir}" DIRECTORY) + set(_sqlite3_bin_dir "${_sqlite3_prefix}/bin") + set_tests_properties(sqlite_plugin_multithread_isolated_connections PROPERTIES + WORKING_DIRECTORY ${CMAKE_BINARY_DIR} + ENVIRONMENT "PATH=$;${_sqlite3_bin_dir};$ENV{PATH}" + ) + set_tests_properties(sqlite_plugin_multithread_vtable_queries PROPERTIES + WORKING_DIRECTORY ${CMAKE_BINARY_DIR} + ENVIRONMENT "PATH=$;${_sqlite3_bin_dir};$ENV{PATH}" + ) + set_tests_properties(sqlite_plugin_thread_guard_same_connection PROPERTIES + WORKING_DIRECTORY ${CMAKE_BINARY_DIR} + ENVIRONMENT "PATH=$;${_sqlite3_bin_dir};$ENV{PATH}" + ) +else() + set_tests_properties(sqlite_plugin_multithread_isolated_connections PROPERTIES + WORKING_DIRECTORY ${CMAKE_BINARY_DIR} + ) + set_tests_properties(sqlite_plugin_multithread_vtable_queries PROPERTIES + WORKING_DIRECTORY ${CMAKE_BINARY_DIR} + ) + set_tests_properties(sqlite_plugin_thread_guard_same_connection PROPERTIES + WORKING_DIRECTORY ${CMAKE_BINARY_DIR} + ) +endif() diff --git a/DockerAlpine b/DockerAlpine index 9f7e2d8..3f1a3a3 100644 --- a/DockerAlpine +++ b/DockerAlpine @@ -1,21 +1,24 @@ FROM alpine AS builder RUN apk add --update-cache \ -apk add bash \ -apk add luajit \ -apk add luajit-dev \ -apk add sqlite \ -apk add cmake \ -apk add build-base \ -rm -rf /var/cache/apk/* + bash \ + luajit \ + luajit-dev \ + sqlite \ + sqlite-dev \ + cmake \ + python3 \ + build-base \ + && rm -rf /var/cache/apk/* WORKDIR /app COPY . /app ENV LUAJIT_LIBRARY=/usr/lib/libluajit-5.1.so \ LUAJIT_INCLUDE_DIR=/usr/include/luajit-2.1 \ LUAJIT_BIN=/usr/bin/luajit -RUN mkdir build && cd build && USE_SYS_LJ=1 cmake -DCMAKE_BUILD_TYPE=Release .. && cmake --build . && cmake --install . && ctest -R sqlite_plugin_tests -V +RUN USE_SYS_LJ=1 cmake -S . -B /tmp/build -DCMAKE_BUILD_TYPE=Release -DSQLITE_PLUGIN_LJ_PORTABLE_RPATH=ON \ + && cmake --build /tmp/build \ + && cmake --install /tmp/build \ + && ctest --test-dir /tmp/build -V FROM scratch AS binaries -COPY --from=builder /app/build/libsqlite_plugin_lj.so /libsqlite_plugin_lj.so - - +COPY --from=builder /tmp/build/libsqlite_plugin_lj.so /libsqlite_plugin_lj.so diff --git a/DockerUbuntu b/DockerUbuntu index 8c51e65..ff5a7c9 100644 --- a/DockerUbuntu +++ b/DockerUbuntu @@ -1,8 +1,11 @@ FROM ubuntu AS builder RUN apt update \ && apt install sqlite3 -y \ +&& apt install libsqlite3-dev -y \ && apt install cmake -y \ && apt install gcc -y \ +&& apt install g++ -y \ +&& apt install python3 -y \ && apt install luajit -y \ && apt install libluajit-5.1-dev -y @@ -11,7 +14,10 @@ COPY . /app ENV LUAJIT_LIBRARY=/usr/lib/x86_64-linux-gnu/libluajit-5.1.so \ LUAJIT_INCLUDE_DIR=/usr/include/luajit-2.1 \ LUAJIT_BIN=/usr/bin/luajit -RUN mkdir build && cd build && cmake -DCMAKE_BUILD_TYPE=Release .. && cmake --build . && cmake --install . && ctest -R sqlite_plugin_tests -V +RUN cmake -S . -B /tmp/build -DCMAKE_BUILD_TYPE=Release -DSQLITE_PLUGIN_LJ_PORTABLE_RPATH=ON \ + && cmake --build /tmp/build \ + && cmake --install /tmp/build \ + && ctest --test-dir /tmp/build -V FROM scratch AS binaries -COPY --from=builder /app/build/libsqlite_plugin_lj.so /libsqlite_plugin_lj.so +COPY --from=builder /tmp/build/libsqlite_plugin_lj.so /libsqlite_plugin_lj.so diff --git a/README.md b/README.md index f362385..1bc09e8 100644 --- a/README.md +++ b/README.md @@ -29,18 +29,39 @@ sqlite> select L('return arg[1] + arg[2] ', 1, 3); 4 ``` -Api functions stored in sqlite_lj module, make them available in lua global and add some functions to sqlite: +API functions are available in the `sqlite` module. + +Register all internal functions with default names: ```sql -sqlite> select L(' - _G.sqlite = require("sqlite_lj") - sqlite.create_function("make_fn", sqlite.make_fn, -1) - sqlite.create_function("make_int", sqlite.make_int, 2) - sqlite.create_function("make_chk", sqlite.make_chk, 3) - sqlite.create_function("make_stored_agg3", sqlite.make_function_agg_chk, -1) - sqlite.create_function("make_stored_aggc", sqlite.make_function_agg, -1) -'); +sqlite> select L('sqlite.register_internal_functions()'); ``` +Register one function (default name) or rename it: +```sql +sqlite> select L('sqlite.register_internal_function("make_chk")'); +sqlite> select L('sqlite.register_internal_function("make_function_agg_chk", "make_stored_agg3")'); +``` + +Internal API reference: + +| Function | Description | Arguments | Example | +| --- | --- | --- | --- | +| `make_fn` | Create SQL scalar function from Lua factory code. | `(name, code_text, argc)` | `select make_fn('inc', 'return function(a) return a + 1 end', 1);` | +| `make_chk` | Create SQL scalar function from Lua chunk using `arg`. | `(name, code_text, argc)` | `select make_chk('inc_c', 'return arg[1] + 1', 1);` | +| `make_int` | Create constant integer SQL function. | `(name, value)` | `select make_int('answer', 42);` | +| `make_str` | Create constant text SQL function. | `(name, value)` | `select make_str('hello_const', 'hello');` | +| `make_function_agg_chk` | Create aggregate from init/step/final chunks. | `(name, init, step, final, argc)` | `select make_function_agg_chk('sum_ac', 'n=0', 'n=n+arg[1]', 'return n', 1);` | +| `create_function_agg_coro_text` | Create aggregate from coroutine-based code. | `(name, code_text, argc)` | `select create_function_agg_coro_text('sum_a', 'return function() ... end', 1);` | +| `create_function_agg_chk` | Aggregate registration helper (chunk-based). | `(name, init, step, final, argc)` | `select create_function_agg_chk('sum_b', 'n=0', 'n=n+arg[1]', 'return n', 1);` | +| `create_function_agg_chk_window` | Window aggregate helper (with optional inverse). | `(name, init, step, final, argc)` or `(name, init, step, inverse, final, argc)` | `select create_function_agg_chk_window('sum_w', 'n=0', 'n=n+arg[1]', 'return n', 1);` | +| `make_vtable` | Create virtual table from Lua table descriptor. | `(table_name, spec)` | `select L('sqlite.make_vtable(\"table_a\", {columns={\"a\"}, rows={{1}}})');` | +| `run_sql` | Execute SQL text from Lua. | `(sql)` | `select L('sqlite.run_sql(\"create table t(a)\")');` | +| `fetch_first` | Run query and return first row as Lua table. | `(sql, [binds])` | `select L('return sqlite.fetch_first(\"select 1 as v\").v');` | +| `fetch_all` | Run query and return all rows. | `(sql, [binds])` | `select L('return #sqlite.fetch_all(\"select 1 as v\")');` | +| `nrows` | Iterator over query rows (name-keyed row table). | `(sql, [binds])` | `select L('for r in sqlite.nrows(\"select 1 as v\") do return r.v end');` | +| `rows` | Row iterator helper (array-like row table). | `(sql, [binds])` | `select L('for r in sqlite.rows(\"select 1 as v\") do return r[1] end');` | +| `urows` | Unpacked-row iterator helper. | `(sql, [binds])` | `select L('for v in sqlite.urows(\"select 1\") do return v end');` | + Create sqlite callable function: ```sql select make_fn('inc', 'return function(a) return a + 1 end', 1 /*expected arguments count*/); @@ -53,6 +74,12 @@ select make_chk('inc_c', 'return arg[1] + 1', 1); select inc_c(14); ``` +Create fast constant string function: +```sql +select make_str('hello_const', 'hello from sqlite_plugin_lj'); +select hello_const(); +``` + - Custom SQL aggregates @@ -85,8 +112,9 @@ select make_stored_aggc('sum_a', One field virtual table: ```sql --- store list_iterator in lua global -select L(' +-- query 1: initialize helper in vtable VM +SELECT * +FROM L(' _G.list_iterator = function(t) local i = 0 local n = #t @@ -95,8 +123,10 @@ select L(' if i <= n then return t[i] end end end + return function() end '); +-- query 2: use helper SELECT * , typeof(value) FROM L(' local tbl = {123, 324, math.pi, NULL, "test", -1377409902473561268LL} @@ -142,6 +172,10 @@ select * from table_a o1; - LuaJIT - SQLite +Windows notes: +- `vcpkg install sqlite3:x64-windows` provides SQLite headers (`sqlite3.h`, `sqlite3ext.h`) and libraries. +- Tests also require the SQLite CLI executable (`sqlite3.exe`). Install the official SQLite command-line tools from sqlite.org and add `sqlite3.exe` to `PATH`. + ### Building from Source Docker builds: @@ -155,17 +189,125 @@ Ubuntu docker build --output=lib --target=binaries -f DockerUbuntu . ``` -Local build with luajit repo: +GitHub Actions workflows: +- `cmake-linux.yml` (Linux CMake build/test) +- `cmake-macos.yml` (macOS CMake build/test) +- `cmake-windows.yml` (Windows/MSVC CMake build/test) +- `docker-alpine.yml` (Docker Alpine binary build) +- `docker-ubuntu.yml` (Docker Ubuntu binary build) + +Local build with LuaJIT fetched by CMake (Linux): ``` mkdir build cd build cmake -DCMAKE_BUILD_TYPE=Release .. cmake --build . -ctest -R sqlite_plugin_tests -V +ctest -V +``` + +Native Windows build sample (Developer Command Prompt for VS): +```bat +set VCPKG_ROOT=C:\git\vcpkg +cmake -S . -B build-vs -G "Visual Studio 17 2022" -A x64 +cmake --build build-vs --config Release +ctest --test-dir build-vs -C Release -V +``` + +Windows build sample with explicit SQLite headers directory (without `VCPKG_ROOT`): +```bat +set SQLITE3_INCLUDE_DIR=C:\path\to\sqlite\include +cmake -S . -B build-vs -G "Visual Studio 17 2022" -A x64 +cmake --build build-vs --config Release +ctest --test-dir build-vs -C Release -V +``` + +Cross-platform Python test runner (used by CTest): +```bash +cd build-vs # or build on Linux/macOS +python ../test.py --tests "011;010;009;001;002;003;004;005;006;007;008;012;013;014;015;016;017;018;019;020;021;022;024;025;026" +python ../test.py --sql-file ../sql/vtables/input_023.sql +``` + +All Lua scripts are embedded into the built library (`.so` / `.dll`). + +### Threading + +Use this extension in a single-threaded way per SQLite connection. Do not execute `L(...)`, custom Lua functions, aggregates, or virtual tables concurrently on the same connection. + +### VM Paths + +The extension uses separate Lua VM paths per SQLite connection: + +- `SELECT L('...')` executes in the function-callback VM. +- `FROM L('...')` / `FROM L10('...')` executes in the virtual-table VM. + +These are different Lua states. `_G` is not shared between them. + +That means this does **not** work: + +```sql +-- sets global in function-callback VM +SELECT L('_G.list_iterator = function(t) ... end'); + +-- runs in virtual-table VM, where list_iterator is nil +SELECT * FROM L10('return list_iterator({{1,2}})'); +``` + +Use this pattern instead: + +```sql +-- initialize global in virtual-table VM +SELECT * FROM L(' + _G.list_iterator = function(t) ... end + return function() return nil end +'); + +-- now virtual-table queries can use it +SELECT * FROM L10('return list_iterator({{1,2}})'); ``` -All lua scripts are included in the .so file. +Plain Lua globals are VM-local. If you need shared state across VM paths, pass values as function arguments or persist them in SQL tables. + +Nested virtual-table execution is also restricted. If a virtual-table callback is already running, trying to enter another virtual-table path from inside it is rejected (for example `L`/`L10` re-entry from within vtable-driven Lua code). Expected error text includes: + +`vtable cannot be used from nested context` + +Global access note for nested paths: nested virtual-table calls would target the same virtual-table VM (so globals would be the same VM globals), but the re-entry guard blocks execution before that nested call proceeds. + +Function-path nesting has a different caveat: nested function execution can run in a different depth VM. Runtime `_G` mutations made in one function call are not automatically propagated to other depth VMs, so nested calls (for example via `nrows("SELECT my_fn(...)")`) may observe a different value or `nil`. + +### Retention/Lifetime + +Shared bridge state (function contexts and stored Lua objects) is retained for the process/extension lifetime in the current implementation. Runtime function unregister/reclaim is currently unsupported as an operational path. Virtual table metadata remains the main reclaimable exception via vtable destroy callbacks. + +### Runtime Limits + +The extension uses a shared in-memory bridge between Lua VMs. + +- `shared object` + Serialized Lua data stored in bridge memory (for example function source chunks and vtable metadata). +- `shared buffer` + The backing byte buffer that holds all shared objects. +- `function context` + Stored callback metadata (`FunctionContext`) used by SQLite function/aggregate callbacks. + +Configure limits with environment variables before starting `sqlite3`: + +- `SQLITE_LJ_MAX_BUFFER_BYTES` + Total bytes available for the shared buffer. + Default: `67108864` (64 MiB). +- `SQLITE_LJ_MAX_OBJECT_BYTES` + Maximum size of a single shared object. + Default: `4194304` (4 MiB). +- `SQLITE_LJ_MAX_OBJECTS` + Maximum number of shared objects stored in the bridge. + Default: `200000`. +- `SQLITE_LJ_MAX_FUNCTION_CONTEXTS` + Maximum number of stored function contexts. + Default: `50000`. + +When any limit is exceeded, registration fails with an explicit Lua error. ## Examples diff --git a/expected/aggregates/output_003.txt b/expected/aggregates/output_003.txt new file mode 100644 index 0000000..276290b --- /dev/null +++ b/expected/aggregates/output_003.txt @@ -0,0 +1,57 @@ + +------------- + +7|7.0 +7|7.0 +8|8.0 +7 +7 +--- empty table aggregates --- + +0|0| +0 +21|21|21 +51|51 +--- window functions --- + +A|7 +B|8 +A|37 +B|28 +1|A|1 +2|A|3 +4|A|7 +3|B|3 +5|B|8 +1|A|11 +2|A|23 +4|A|37 +3|B|13 +5|B|28 +1|1 +2|3 +3|6 +4|10 +5|15 +1|11 +2|23 +3|36 +4|50 +5|65 +--- sliding window error --- +error caught: sliding windows not supported +--- sliding window with inverse --- + +1|1 +2|3 +3|5 +4|7 +5|9 +1|11 +2|23 +3|25 +4|27 +5|29 +--- nested aggregate --- + +21 \ No newline at end of file diff --git a/expected/aggregates/output_004.txt b/expected/aggregates/output_004.txt new file mode 100644 index 0000000..b25dba1 --- /dev/null +++ b/expected/aggregates/output_004.txt @@ -0,0 +1,15 @@ + +------------- + + +21|21 + +21|21 + +PASS: sum_error correctly reported indexing error +PASS: L table reports missing function error + + + | [value] 8LL + | [value] 10LL +table \ No newline at end of file diff --git a/expected/aggregates/output_020.txt b/expected/aggregates/output_020.txt new file mode 100644 index 0000000..f7be818 --- /dev/null +++ b/expected/aggregates/output_020.txt @@ -0,0 +1,8 @@ + +--- coroutine aggregate error propagation --- + +OK:init:query-failed + +OK:step:query-failed + +OK:final:query-failed \ No newline at end of file diff --git a/expected/aggregates/output_021.txt b/expected/aggregates/output_021.txt new file mode 100644 index 0000000..06bb3f2 --- /dev/null +++ b/expected/aggregates/output_021.txt @@ -0,0 +1,8 @@ + +--- coroutine aggregate error propagation expected --- + +OK:init:query-failed + +OK:step:query-failed + +OK:final:query-failed \ No newline at end of file diff --git a/expected/aggregates/output_033.txt b/expected/aggregates/output_033.txt new file mode 100644 index 0000000..bc56d4d --- /dev/null +++ b/expected/aggregates/output_033.txt @@ -0,0 +1,11 @@ + + +--- window on empty table ORDER BY --- +--- window on empty table PARTITION BY --- +--- window inverse on empty table --- +--- window on single row --- +42|1 +--- window inverse on single row --- +42|42 +--- window partition on single row --- +42|A|1 diff --git a/expected/output_001.txt b/expected/basics/output_001.txt similarity index 67% rename from expected/output_001.txt rename to expected/basics/output_001.txt index 406d586..6f37f7f 100644 --- a/expected/output_001.txt +++ b/expected/basics/output_001.txt @@ -17,14 +17,21 @@ 2 = 22LL 1 = 3LL 2 = 33LL - num1 = 1LL num2 = 11LL num1 = 2LL num2 = 22LL num1 = 3LL num2 = 33LL - 1LL 11LL 2LL 22LL -3LL 33LL \ No newline at end of file +3LL 33LL +1LL 1LL +1LL 2LL +1LL 3LL +2LL 1LL +2LL 2LL +2LL 3LL +3LL 1LL +3LL 2LL +3LL 3LL \ No newline at end of file diff --git a/expected/output_002.txt b/expected/basics/output_002.txt similarity index 67% rename from expected/output_002.txt rename to expected/basics/output_002.txt index c693b4d..2a8ec1a 100644 --- a/expected/output_002.txt +++ b/expected/basics/output_002.txt @@ -14,4 +14,3 @@ unable to delete/modify user-function due to active statements echo [21LL] -1377409902473561267|-1377409902473561267 -DROP FUNCTION error: unable to delete/modify user-function due to active statements \ No newline at end of file diff --git a/expected/basics/output_005.txt b/expected/basics/output_005.txt new file mode 100644 index 0000000..435a9a9 --- /dev/null +++ b/expected/basics/output_005.txt @@ -0,0 +1,26 @@ + +------------- + + +36|15|27 + +3|4||3|4| +5|8||5|8| + +1|2| +3|4| +5|8| +9|10| +--- query same vtable twice --- +1|2| +3|4| +5|8| +9|10| +3|4| +5|8| +9|10| +--- join two different vtables --- + +1|2|1|one +3|4|3|three +5|8|5|five \ No newline at end of file diff --git a/expected/basics/output_026.txt b/expected/basics/output_026.txt new file mode 100644 index 0000000..a82f9b6 --- /dev/null +++ b/expected/basics/output_026.txt @@ -0,0 +1,4 @@ + + +1 +hello_fast diff --git a/expected/basics/output_028.txt b/expected/basics/output_028.txt new file mode 100644 index 0000000..a0416f8 --- /dev/null +++ b/expected/basics/output_028.txt @@ -0,0 +1,6 @@ + + + +30 + +42 diff --git a/expected/basics/output_029.txt b/expected/basics/output_029.txt new file mode 100644 index 0000000..abffa49 --- /dev/null +++ b/expected/basics/output_029.txt @@ -0,0 +1,11 @@ + +PASS :param fetch_first +PASS $param fetch_first +PASS @param fetch_first +PASS multi named params +PASS fetch_all named +PASS nrows named +PASS urows named +PASS positional fallback +PASS NULL missing param +PASS positional with table diff --git a/expected/errors/output_022.txt b/expected/errors/output_022.txt new file mode 100644 index 0000000..eb2c2b0 --- /dev/null +++ b/expected/errors/output_022.txt @@ -0,0 +1,15 @@ + +--- malformed lua handling --- +PASS:L +PASS:L10 +PASS:make_fn +PASS:make_chk +PASS:make_stored_agg3:init +PASS:make_stored_agg3:step +PASS:make_stored_agg3:final +PASS:make_stored_aggc:code +PASS:make_stored_win:init +PASS:make_stored_win:step +PASS:make_stored_win:final +PASS:make_stored_win:inverse + diff --git a/expected/errors/output_027.txt b/expected/errors/output_027.txt new file mode 100644 index 0000000..0dd36de --- /dev/null +++ b/expected/errors/output_027.txt @@ -0,0 +1,2 @@ + +[string "temp_fn"]:3: api_bind_any: unsupported type table value table diff --git a/expected/errors/output_034.txt b/expected/errors/output_034.txt new file mode 100644 index 0000000..a837319 --- /dev/null +++ b/expected/errors/output_034.txt @@ -0,0 +1,5 @@ + +PASS L xFilter syntax error +PASS L xNext runtime error +PASS L xFilter runtime error +PASS L10 xFilter syntax error diff --git a/expected/errors/output_036.txt b/expected/errors/output_036.txt new file mode 100644 index 0000000..0a9c7b3 --- /dev/null +++ b/expected/errors/output_036.txt @@ -0,0 +1,14 @@ + + + +PASS abandon iterator returns correct value + +PASS repeated abandon stable + +PASS error after partial iteration propagates +PASS db usable after error cleanup + + +PASS nested functions with cleanup +PASS subsequent queries not corrupted + diff --git a/expected/lifetime/output_024.txt b/expected/lifetime/output_024.txt new file mode 100644 index 0000000..3451462 --- /dev/null +++ b/expected/lifetime/output_024.txt @@ -0,0 +1 @@ +PASS:bind_text_gc diff --git a/expected/lifetime/output_025.txt b/expected/lifetime/output_025.txt new file mode 100644 index 0000000..cf44bf0 --- /dev/null +++ b/expected/lifetime/output_025.txt @@ -0,0 +1 @@ +PASS:callback_gc diff --git a/expected/limits/output_015.txt b/expected/limits/output_015.txt new file mode 100644 index 0000000..c8f87da --- /dev/null +++ b/expected/limits/output_015.txt @@ -0,0 +1 @@ +PASS: shared object size ceiling enforced diff --git a/expected/limits/output_016.txt b/expected/limits/output_016.txt new file mode 100644 index 0000000..f6fdcf9 --- /dev/null +++ b/expected/limits/output_016.txt @@ -0,0 +1 @@ +PASS: shared object count ceiling enforced diff --git a/expected/limits/output_017.txt b/expected/limits/output_017.txt new file mode 100644 index 0000000..bef38c3 --- /dev/null +++ b/expected/limits/output_017.txt @@ -0,0 +1 @@ +PASS: function context ceiling enforced diff --git a/expected/limits/output_018.txt b/expected/limits/output_018.txt new file mode 100644 index 0000000..b618b79 --- /dev/null +++ b/expected/limits/output_018.txt @@ -0,0 +1 @@ +PASS: shared buffer ceiling enforced diff --git a/expected/output_008.txt b/expected/nesting/output_008.txt similarity index 96% rename from expected/output_008.txt rename to expected/nesting/output_008.txt index 69a113d..82fc587 100644 --- a/expected/output_008.txt +++ b/expected/nesting/output_008.txt @@ -3,7 +3,6 @@ 11 - 11 11 \ No newline at end of file diff --git a/expected/nesting/output_009.txt b/expected/nesting/output_009.txt new file mode 100644 index 0000000..96c973c --- /dev/null +++ b/expected/nesting/output_009.txt @@ -0,0 +1,14 @@ + +------------- + + +55LL +55 + +PASS: fib2 correctly rejected nested closure call + + +PASS: sub4 correctly rejected nested closure call + +6 +6 \ No newline at end of file diff --git a/expected/output_003.txt b/expected/output_003.txt deleted file mode 100644 index 5b0fdc0..0000000 --- a/expected/output_003.txt +++ /dev/null @@ -1,5 +0,0 @@ - -------------- - -7|7.0 -7|7.0 diff --git a/expected/output_004.txt b/expected/output_004.txt deleted file mode 100644 index 107d505..0000000 --- a/expected/output_004.txt +++ /dev/null @@ -1,13 +0,0 @@ - -------------- - - -21|21 - -21|21 - - - - | [value] 8LL - | [value] 10LL -table \ No newline at end of file diff --git a/expected/output_005.txt b/expected/output_005.txt deleted file mode 100644 index 553367d..0000000 --- a/expected/output_005.txt +++ /dev/null @@ -1,13 +0,0 @@ - -------------- - - -36|15|27 - -3|4||3|4| -5|8||5|8| - -1|2| -3|4| -5|8| -9|10| \ No newline at end of file diff --git a/expected/output_timed001.txt b/expected/performance/output_timed001.txt similarity index 100% rename from expected/output_timed001.txt rename to expected/performance/output_timed001.txt diff --git a/expected/output_007.txt b/expected/types/output_007.txt similarity index 98% rename from expected/output_007.txt rename to expected/types/output_007.txt index 417c982..2bc20fd 100644 --- a/expected/output_007.txt +++ b/expected/types/output_007.txt @@ -3,6 +3,5 @@ FF00FF FF00FF - F500F9|| F500F9 NULL 0 \ No newline at end of file diff --git a/expected/types/output_012.txt b/expected/types/output_012.txt new file mode 100644 index 0000000..d35e3fc --- /dev/null +++ b/expected/types/output_012.txt @@ -0,0 +1,2 @@ + +PASS: fetch_first bind error triggered diff --git a/expected/types/output_013.txt b/expected/types/output_013.txt new file mode 100644 index 0000000..4e6b98d --- /dev/null +++ b/expected/types/output_013.txt @@ -0,0 +1,13 @@ +PASS bind sqlite.int_t +PASS bind sqlite.int64_t +PASS bind sqlite.uint64_t +PASS bind sqlite.double_t +PASS bind NULL +PASS bind number int +PASS bind number real +PASS bind string +PASS bind boolean +PASS bind blob empty +PASS bind blob bytes +PASS bind invalid table +PASS bind invalid function diff --git a/expected/types/output_040.txt b/expected/types/output_040.txt new file mode 100644 index 0000000..280d31c --- /dev/null +++ b/expected/types/output_040.txt @@ -0,0 +1,5 @@ +PASS register inf functions +PASS return +inf stays float +PASS return -inf stays float +PASS bind +inf stays float +PASS bind -inf stays float diff --git a/expected/output_006.txt b/expected/vtables/output_006.txt similarity index 99% rename from expected/output_006.txt rename to expected/vtables/output_006.txt index 7391dcd..23c8b20 100644 --- a/expected/output_006.txt +++ b/expected/vtables/output_006.txt @@ -3,7 +3,6 @@ 1|2| ------------- - value typeof(value) -------------------- ------------- 123 integer diff --git a/expected/vtables/output_010.txt b/expected/vtables/output_010.txt new file mode 100644 index 0000000..6e4de01 --- /dev/null +++ b/expected/vtables/output_010.txt @@ -0,0 +1,16 @@ +count(*) +-------- +0 +value typeof(value) +----- ------------- +1 integer +2 integer + null +3 integer +value typeof(value) +----- ------------- +6 integer + null +7 integer +8 integer + diff --git a/expected/vtables/output_011.txt b/expected/vtables/output_011.txt new file mode 100644 index 0000000..f8105ef --- /dev/null +++ b/expected/vtables/output_011.txt @@ -0,0 +1,2 @@ + +1|2 diff --git a/expected/vtables/output_014.txt b/expected/vtables/output_014.txt new file mode 100644 index 0000000..b50273d --- /dev/null +++ b/expected/vtables/output_014.txt @@ -0,0 +1,2 @@ +PASS: make_vtable quoted identifiers +PASS: make_vtable rejects empty identifier diff --git a/expected/vtables/output_019.txt b/expected/vtables/output_019.txt new file mode 100644 index 0000000..47133c5 --- /dev/null +++ b/expected/vtables/output_019.txt @@ -0,0 +1,7 @@ +PASS rows nil means empty table +PASS rows empty means empty table +PASS rows table-empty-row gives null row +PASS duplicate columns +PASS rows must be table +PASS extra row values ignored +PASS garbage rows skipped diff --git a/expected/vtables/output_030.txt b/expected/vtables/output_030.txt new file mode 100644 index 0000000..03815aa --- /dev/null +++ b/expected/vtables/output_030.txt @@ -0,0 +1,19 @@ + +--- L10 all 10 columns --- +r0 r1 r2 r3 r4 r5 r6 r7 r8 r9 +-- -- --- -- -- --- -- -- ---- --- +1 a 1.1 10 x 100 1 -1 end1 999 +2 b 2.2 20 y 200 0 -2 end2 888 +('--- L10 filter on last columns ---') +-------------------------------------- +--- L10 filter on last columns --- +r0 r8 r9 +-- ---- --- +2 end2 888 +3 end3 777 +('--- L10 single row all columns ---') +-------------------------------------- +--- L10 single row all columns --- +r0 r1 r2 r3 r4 r5 r6 r7 r8 r9 +-- -- -- -- -- -- -- -- -- --- +10 20 30 40 50 60 70 80 90 100 diff --git a/expected/vtables/output_031.txt b/expected/vtables/output_031.txt new file mode 100644 index 0000000..f05650a --- /dev/null +++ b/expected/vtables/output_031.txt @@ -0,0 +1,5 @@ + +PASS initial create +PASS re-create different data +PASS re-create fewer rows +PASS re-create empty diff --git a/expected/vtables/output_032.txt b/expected/vtables/output_032.txt new file mode 100644 index 0000000..3550485 --- /dev/null +++ b/expected/vtables/output_032.txt @@ -0,0 +1,5 @@ + + +--- re-entry from L into make_vtable --- +PASS L into make_vtable blocked +PASS L into L blocked diff --git a/expected/vtables/output_035.txt b/expected/vtables/output_035.txt new file mode 100644 index 0000000..61a3651 --- /dev/null +++ b/expected/vtables/output_035.txt @@ -0,0 +1,6 @@ + +PASS L10 baseline no join +PASS L10 JOIN regular table +PASS L JOIN regular table +PASS L10 LEFT JOIN +PASS L10 IN subquery diff --git a/expected/vtables/output_037.txt b/expected/vtables/output_037.txt new file mode 100644 index 0000000..a1e3ccd --- /dev/null +++ b/expected/vtables/output_037.txt @@ -0,0 +1,7 @@ + +PASS make_vtable TEMP works +PASS repeated queries stable +PASS drop and re-create +PASS re-create different schema +PASS non-TEMP xCreate succeeds (query works) +PASS make_vtable works after xCreate test diff --git a/expected/vtables/output_038.txt b/expected/vtables/output_038.txt new file mode 100644 index 0000000..75e81d2 --- /dev/null +++ b/expected/vtables/output_038.txt @@ -0,0 +1,5 @@ + +PASS baseline table created +PASS re-create without drop errors +PASS original table preserved after failed re-create +PASS re-create works after explicit drop diff --git a/expected/vtables/output_039.txt b/expected/vtables/output_039.txt new file mode 100644 index 0000000..f871ae0 --- /dev/null +++ b/expected/vtables/output_039.txt @@ -0,0 +1,2 @@ + +PASS xColumn re-entry blocked diff --git a/sql/aggregates/input_003.sql b/sql/aggregates/input_003.sql new file mode 100644 index 0000000..55f61af --- /dev/null +++ b/sql/aggregates/input_003.sql @@ -0,0 +1,137 @@ +select load_extension('./libsqlite_plugin_lj'); + +select('-------------'); +CREATE TABLE data(val NUMERIC); +INSERT INTO data(val) VALUES (5), (7), (9); + +select L(' + + local avgCoroutine = [[ return function () + local acc = 0 + local n = 0 + while true do + local has_next, value = coroutine.yield() -- Receive values from the producer + if has_next then + acc = acc + value + n = n + 1 + else + acc = tonumber(acc) * (1/n) + break -- Break the loop when no more data is available + end + end + return acc + end ]] + + sqlite.create_function_agg_coro_text("agg_avg_coro", avgCoroutine, 1) + + + sqlite.create_function_agg_chk("agg_avg_chk", + "acc = 0; n = 0;", + "n = n + 1; acc = acc + tonumber(arg[1]);", + "return acc * (1/n);", + 1) + + sqlite.create_function_agg_chk("agg_avg_chk2", + "acc = 0; n = 0;", + "n = n + 1; acc = acc + tonumber(arg[1]) + tonumber(arg[2]);", + "return acc * (1/n);", + 2) +'); + +SELECT agg_avg_coro(val), avg(val) FROM data; +SELECT agg_avg_chk(val), avg(val) FROM data; +SELECT agg_avg_chk2(val, 1), avg(val) + 1 FROM data; + +select L(' + local out = {} + for value in sqlite.urows("SELECT agg_avg_coro(val), avg(val) FROM data") do + out[#out + 1] = tostring(tonumber(value)) + end + for value in sqlite.urows("SELECT agg_avg_chk(val), avg(val) FROM data") do + out[#out + 1] = tostring(tonumber(value)) + end + return table.concat(out, "\n") +'); + +-- Test aggregates on empty table +select('--- empty table aggregates ---'); +CREATE TABLE empty_data(val NUMERIC); + +select L(' + sqlite.create_function_agg_chk("sum_chk", "n = 0", "n = n + arg[1]", "return n", 1) + sqlite.create_function_agg_chk("sum_chk2", "n = 0", "n = n + arg[1] + arg[2]", "return n", 2) + + sqlite.create_function_agg_coro_text("sum_coro", [[ return function() + local acc = 0 + while true do + local has_next, value = coroutine.yield() + if has_next then + acc = acc + value + else + break + end + end + return acc + end ]], 1) +'); + +SELECT sum_chk(val), sum_coro(val), sum(val) FROM empty_data; +SELECT sum_chk2(val, 10) FROM empty_data; +SELECT sum_chk(val), sum_coro(val), sum(val) FROM data; +SELECT sum_chk2(val, 10), sum(val) + count(*) * 10 FROM data; + +-- Test window functions +select('--- window functions ---'); +CREATE TABLE win_data(x INT, cat TEXT); +INSERT INTO win_data VALUES (1,'A'), (2,'A'), (3,'B'), (4,'A'), (5,'B'); + +select L(' + sqlite.create_function_agg_chk_window("win_sum", "n = 0", "n = n + arg[1]", "return n", 1) + sqlite.create_function_agg_chk_window("win_sum2", "n = 0", "n = n + arg[1] + arg[2]", "return n", 2) +'); + +-- As regular aggregate (GROUP BY) +SELECT cat, win_sum(x) FROM win_data GROUP BY cat; +SELECT cat, win_sum2(x, 10) FROM win_data GROUP BY cat; + +-- As window function (PARTITION BY) +SELECT x, cat, win_sum(x) OVER (PARTITION BY cat ORDER BY x) FROM win_data ORDER BY cat, x; +SELECT x, cat, win_sum2(x, 10) OVER (PARTITION BY cat ORDER BY x) FROM win_data ORDER BY cat, x; + +-- As window function (ORDER BY only - running total) +SELECT x, win_sum(x) OVER (ORDER BY x) FROM win_data; +SELECT x, win_sum2(x, 10) OVER (ORDER BY x) FROM win_data; + +-- Test sliding window error (no inverse provided) +select('--- sliding window error ---'); +select L(' + local ok, err = pcall(function() + for row in sqlite.rows("SELECT x, win_sum(x) OVER (ORDER BY x ROWS BETWEEN 1 PRECEDING AND CURRENT ROW) FROM win_data") do + end + end) + if not ok and err:find("sliding windows not supported") then + return "error caught: sliding windows not supported" + end + return "error not caught" +'); + +-- Test sliding window with inverse (5-arg version) +select('--- sliding window with inverse ---'); +select L(' + sqlite.create_function_agg_chk_window("win_sum_inv", "n = 0", "n = n + arg[1]", "n = n - arg[1]", "return n", 1) + sqlite.create_function_agg_chk_window("win_sum_inv2", "n = 0", "n = n + arg[1] + arg[2]", "n = n - arg[1] - arg[2]", "return n", 2) +'); + +SELECT x, win_sum_inv(x) OVER (ORDER BY x ROWS BETWEEN 1 PRECEDING AND CURRENT ROW) FROM win_data; +SELECT x, win_sum_inv2(x, 10) OVER (ORDER BY x ROWS BETWEEN 1 PRECEDING AND CURRENT ROW) FROM win_data; + +-- Test nested aggregate call +select('--- nested aggregate ---'); +select L(' + sqlite.create_function_agg_chk("nested_sum", "n = 0", "n = n + arg[1]", "return n", 1) +'); + +select L(' + local result = sqlite.fetch_first("SELECT nested_sum(val) as s FROM data") + return result["s"] +'); diff --git a/sql/input_004.sql b/sql/aggregates/input_004.sql similarity index 61% rename from sql/input_004.sql rename to sql/aggregates/input_004.sql index 633f128..b5ff908 100644 --- a/sql/input_004.sql +++ b/sql/aggregates/input_004.sql @@ -4,9 +4,9 @@ CREATE TABLE data4(val NUMERIC); INSERT INTO data4(val) VALUES (5), (7), (9); select('-------------'); SELECT L(' - local sqlite = require("sqlite_lj") + local sqlite = sqlite sqlite.create_function("make_stored_agg3", sqlite.make_function_agg_chk, -1) - sqlite.create_function("make_stored_aggc", sqlite.make_function_agg, -1) + sqlite.create_function("make_stored_aggc", sqlite.create_function_agg_coro_text, -1) '); select make_stored_agg3('sum_ac', 'n = 0', 'n = n + arg[1]', 'return n', 1); SELECT sum_ac(val), sum(val) FROM data4; @@ -29,10 +29,32 @@ select make_stored_aggc('sum_a', SELECT sum_a(val), sum(val) FROM data4; select make_stored_agg3('sum_error', 'n = 0', 'n = n[1] + arg[1]', 'return n', 1); -SELECT sum_error(val), sum(val) FROM data4; +select L(' + sqlite.config.use_traceback = 0 + local ok, err = pcall(function() + return sqlite.fetch_first("SELECT sum_error(val) FROM data4") + end) + if not ok and tostring(err):match("attempt to index") then + return "PASS: sum_error correctly reported indexing error" + else + return "FAIL: expected indexing error, got: " .. tostring(ok) .. " " .. tostring(err) + end +'); + +select L(' + sqlite.config.use_traceback = 0 + local ok, err = pcall(function() + return sqlite.fetch_all([[select * from L(''return ter{1,2,NULL,3}'')]]) + end) + if not ok and tostring(err):match("attempt to call global ''ter''") then + return "PASS: L table reports missing function error" + else + return "FAIL: expected missing function error, got: " .. tostring(ok) .. " " .. tostring(err) + end +'); select L(' - local sqlite = require("sqlite_lj") + local sqlite = sqlite local fn = function(name, text_code, argc) sqlite.make_fn(name, text_code, argc) end @@ -40,8 +62,9 @@ select L(' '); select make_fn('rs2', ' -local sqlite = require("sqlite_lj") +local sqlite = sqlite return function () + local out = {} sqlite.run_sql(''create table test_table2(value)''); sqlite.run_sql(''INSERT INTO test_table2(value) VALUES (8), (10), (12);'') for row in sqlite.nrows(''select * from test_table2 where value < ?'', {11}) do @@ -49,7 +72,7 @@ return function () for k,v in pairs(row) do test_row = test_row .. '' | ['' .. (k) .. ''] '' .. tostring(v) end - print(test_row) + out[#out + 1] = test_row end --[[ for row in sqlite.nrows(''select name, file from PRAGMA_database_list;'') do @@ -57,12 +80,12 @@ return function () for _,v in ipairs({''name'', ''file''}) do test_row = test_row .. '' | ['' .. (v) .. ''] '' .. tostring(row[v]) end - print(test_row) + out[#out + 1] = test_row end ]] local database_list = sqlite.fetch_all(''select * from PRAGMA_database_list;'') - print(type(database_list)) - + out[#out + 1] = type(database_list) + return table.concat(out, "\n") end '); -select rs2(); \ No newline at end of file +select rs2(); diff --git a/sql/aggregates/input_020.sql b/sql/aggregates/input_020.sql new file mode 100644 index 0000000..515198b --- /dev/null +++ b/sql/aggregates/input_020.sql @@ -0,0 +1,93 @@ +select load_extension('./libsqlite_plugin_lj'); + +CREATE TABLE agg_err_data(v NUMERIC); +INSERT INTO agg_err_data(v) VALUES (1), (2); + +select('--- coroutine aggregate error propagation ---'); + +select L(' + sqlite.create_function_agg_coro_text("agg_init_error", [[ + return function () + error("initialization error") + end + ]], 1) +'); + +select L(' + local ok, res = pcall(function() + return sqlite.fetch_first("SELECT agg_init_error(v) AS s FROM agg_err_data") + end) + if ok then + local s = tostring(res.s) + if s:match("dead coroutine") then + return "ISSUE:init:resume-error-returned-as-value" + else + return "ISSUE:init:unexpected:" .. s + end + else + return "OK:init:query-failed" + end +'); + +select L(' + sqlite.create_function_agg_coro_text("agg_step_error", [[ + return function () + while true do + local has_next, value = coroutine.yield() + if has_next then + error("step error") + else + return 0 + end + end + end + ]], 1) +'); + +select L(' + local ok, res = pcall(function() + return sqlite.fetch_first("SELECT agg_step_error(v) AS s FROM agg_err_data") + end) + if ok then + local s = tostring(res.s) + if s:match("dead coroutine") then + return "ISSUE:step:resume-error-returned-as-value" + else + return "ISSUE:step:unexpected:" .. s + end + else + return "OK:step:query-failed" + end +'); + +select L(' + sqlite.create_function_agg_coro_text("agg_final_error", [[ + return function () + local acc = 0 + while true do + local has_next, value = coroutine.yield() + if has_next then + acc = acc + value + else + error("finalization error") + end + end + end + ]], 1) +'); + +select L(' + local ok, res = pcall(function() + return sqlite.fetch_first("SELECT agg_final_error(v) AS s FROM agg_err_data") + end) + if ok then + local s = tostring(res.s) + if s:match("finalization error") then + return "ISSUE:final:error-text-returned-as-value" + else + return "ISSUE:final:unexpected:" .. s + end + else + return "OK:final:query-failed" + end +'); diff --git a/sql/aggregates/input_021.sql b/sql/aggregates/input_021.sql new file mode 100644 index 0000000..89c35d0 --- /dev/null +++ b/sql/aggregates/input_021.sql @@ -0,0 +1,78 @@ +select load_extension('./libsqlite_plugin_lj'); + +CREATE TABLE agg_err_data_future(v NUMERIC); +INSERT INTO agg_err_data_future(v) VALUES (1), (2); + +select('--- coroutine aggregate error propagation expected ---'); + +select L(' + sqlite.create_function_agg_coro_text("agg_init_error_future", [[ + return function () + error("initialization error") + end + ]], 1) +'); + +select L(' + local ok, _ = pcall(function() + return sqlite.fetch_first("SELECT agg_init_error_future(v) AS s FROM agg_err_data_future") + end) + if ok then + return "ISSUE:init:query-succeeded" + else + return "OK:init:query-failed" + end +'); + +select L(' + sqlite.create_function_agg_coro_text("agg_step_error_future", [[ + return function () + while true do + local has_next, value = coroutine.yield() + if has_next then + error("step error") + else + return 0 + end + end + end + ]], 1) +'); + +select L(' + local ok, _ = pcall(function() + return sqlite.fetch_first("SELECT agg_step_error_future(v) AS s FROM agg_err_data_future") + end) + if ok then + return "ISSUE:step:query-succeeded" + else + return "OK:step:query-failed" + end +'); + +select L(' + sqlite.create_function_agg_coro_text("agg_final_error_future", [[ + return function () + local acc = 0 + while true do + local has_next, value = coroutine.yield() + if has_next then + acc = acc + value + else + error("finalization error") + end + end + end + ]], 1) +'); + +select L(' + local ok, _ = pcall(function() + return sqlite.fetch_first("SELECT agg_final_error_future(v) AS s FROM agg_err_data_future") + end) + if ok then + return "ISSUE:final:query-succeeded" + else + return "OK:final:query-failed" + end +'); diff --git a/sql/aggregates/input_033.sql b/sql/aggregates/input_033.sql new file mode 100644 index 0000000..3c8e9fa --- /dev/null +++ b/sql/aggregates/input_033.sql @@ -0,0 +1,40 @@ +select load_extension('./libsqlite_plugin_lj'); + +CREATE TABLE empty_win(x INT, cat TEXT); +CREATE TABLE single_win(x INT, cat TEXT); +INSERT INTO single_win VALUES (42, 'A'); + +select L(' + -- Window function without inverse (4-arg) + sqlite.create_function_agg_chk_window("win_cnt", + "n = 0", + "n = n + 1", + "return n", + 1) + + -- Window function with inverse (5-arg) + sqlite.create_function_agg_chk_window("win_total", + "n = 0", + "n = n + arg[1]", + "n = n - arg[1]", + "return n", + 1) +'); + +select('--- window on empty table ORDER BY ---'); +SELECT x, win_cnt(x) OVER (ORDER BY x) FROM empty_win; + +select('--- window on empty table PARTITION BY ---'); +SELECT x, win_cnt(x) OVER (PARTITION BY cat) FROM empty_win; + +select('--- window inverse on empty table ---'); +SELECT x, win_total(x) OVER (ORDER BY x ROWS BETWEEN 1 PRECEDING AND CURRENT ROW) FROM empty_win; + +select('--- window on single row ---'); +SELECT x, win_cnt(x) OVER (ORDER BY x) FROM single_win; + +select('--- window inverse on single row ---'); +SELECT x, win_total(x) OVER (ORDER BY x ROWS BETWEEN 1 PRECEDING AND CURRENT ROW) FROM single_win; + +select('--- window partition on single row ---'); +SELECT x, cat, win_cnt(x) OVER (PARTITION BY cat ORDER BY x) FROM single_win; diff --git a/sql/input_001.sql b/sql/basics/input_001.sql similarity index 56% rename from sql/input_001.sql rename to sql/basics/input_001.sql index 9c1b145..95883f1 100644 --- a/sql/input_001.sql +++ b/sql/basics/input_001.sql @@ -1,10 +1,9 @@ select load_extension('./libsqlite_plugin_lj'); select L(' - _G.sqlite = require("sqlite_lj") - sqlite.create_function("make_fn", sqlite.make_fn, -1) - sqlite.create_function("make_int", sqlite.make_int, 2) - sqlite.create_function("make_chk", sqlite.make_chk, 3) + sqlite.register_internal_function("make_fn") + sqlite.register_internal_function("make_int") + sqlite.register_internal_function("make_chk") '); select make_fn('Lua', ' return function (code_text, ...) @@ -50,7 +49,8 @@ select L('sqlite.run_sql[[ '); select L(' -_G.pprint_table = function (tbl, indent) +_G.pprint_table = function (tbl, out, indent) + out = out or {} indent = indent or 0 local keys = {} for k, _ in pairs(tbl) do @@ -61,24 +61,43 @@ _G.pprint_table = function (tbl, indent) for _, k in ipairs(keys) do local v = tbl[k] if type(v) == "table" then - print(string.rep(" ", indent) .. tostring(k) .. " = {") - pprint_table(v, indent + 4) - print(string.rep(" ", indent) .. "}") + out[#out + 1] = string.rep(" ", indent) .. tostring(k) .. " = {" + pprint_table(v, out, indent + 4) + out[#out + 1] = string.rep(" ", indent) .. "}" else - print(string.rep(" ", indent) .. tostring(k) .. " = " .. tostring(v)) + out[#out + 1] = string.rep(" ", indent) .. tostring(k) .. " = " .. tostring(v) end end + return out end '); select L(' -for a in sqlite.rows("SELECT * FROM numbers") do pprint_table(a) end +local out = {} +for a in sqlite.rows("SELECT * FROM numbers") do pprint_table(a, out) end +return table.concat(out, "\n") '); select L(' -for a in sqlite.nrows("SELECT * FROM numbers") do pprint_table(a) end +local out = {} +for a in sqlite.nrows("SELECT * FROM numbers") do pprint_table(a, out) end +return table.concat(out, "\n") '); select L(' -for num1, num2 in sqlite.urows("SELECT * FROM numbers") do print(num1,num2) end +local out = {} +for num1, num2 in sqlite.urows("SELECT * FROM numbers") do + out[#out + 1] = tostring(num1) .. "\t" .. tostring(num2) +end +return table.concat(out, "\n") +'); + +select L(' +local out = {} +for num1, num2 in sqlite.urows("SELECT * FROM numbers") do +for num3, num4 in sqlite.urows("SELECT * FROM numbers") do + out[#out + 1] = tostring(num1) .. "\t" .. tostring(num3) + end +end +return table.concat(out, "\n") '); diff --git a/sql/input_002.sql b/sql/basics/input_002.sql similarity index 78% rename from sql/input_002.sql rename to sql/basics/input_002.sql index 93297ff..6c144c6 100644 --- a/sql/input_002.sql +++ b/sql/basics/input_002.sql @@ -1,11 +1,9 @@ select load_extension('./libsqlite_plugin_lj'); select L(' - _G.sqlite = require("sqlite_lj") sqlite.make_int("const_x", 9999) sqlite.make_int("const_x2", 10000) - sqlite.create_function("make_fn", sqlite.make_fn, -1) - sqlite.create_function("drop_function", sqlite.drop_function, -1) + sqlite.register_internal_function("make_fn") '); select const_x(); select const_x2(); @@ -24,5 +22,3 @@ select L(' sqlite.create_function("inc_w", increment_w, 1) '); select inc(-1377409902473561268), inc_w(-1377409902473561268); -select drop_function('echo1', 1); - diff --git a/sql/input_005.sql b/sql/basics/input_005.sql similarity index 65% rename from sql/input_005.sql rename to sql/basics/input_005.sql index 88ec1d6..bce2022 100644 --- a/sql/input_005.sql +++ b/sql/basics/input_005.sql @@ -3,8 +3,7 @@ select load_extension('./libsqlite_plugin_lj'); select('-------------'); select L(' - local sqlite = require("sqlite_lj") - sqlite.create_function("make_fn", sqlite.make_fn, -1) + sqlite.register_internal_function("make_fn") '); select make_fn('Lua', ' return function (code_text, ...) @@ -27,7 +26,6 @@ select L('return arg[1] + arg[2]', 12, 24), Lua('return arg[1]', 15, 24), L('ret --select L('print (int64_t(1))'); --SELECT * FROM sqlite_master;-- WHERE type='table'; select L(' - local sqlite = require("sqlite_lj") sqlite.run_sql("DROP TABLE IF EXISTS TEMP.table_x") local data_vt = {columns = {"a", "b", "c"}, rows = {[0] = {1,2}, {3,4}, a = {5,8}}} sqlite.make_vtable("table_x", data_vt) @@ -38,7 +36,6 @@ inner join table_x o2 on o1.a = o2.a where o1.b > 2; select L(' - local sqlite = require("sqlite_lj") sqlite.run_sql("DROP TABLE IF EXISTS TEMP.table_a") sqlite.make_vtable("table_a", { @@ -50,4 +47,23 @@ select L(' select * from table_a o1; +-- Test: query the same vtable twice (separate queries) +select('--- query same vtable twice ---'); +select * from table_a; +select * from table_a where a > 1; +-- Test: join two different vtables +select('--- join two different vtables ---'); +select L(' + sqlite.run_sql("DROP TABLE IF EXISTS TEMP.table_b") + sqlite.make_vtable("table_b", + { + columns = {"x", "y"}, + rows = {{1, "one"}, {3, "three"}, {5, "five"}} + } + ) + '); + +select table_a.a, table_a.b, table_b.x, table_b.y +from table_a +inner join table_b on table_a.a = table_b.x; diff --git a/sql/basics/input_026.sql b/sql/basics/input_026.sql new file mode 100644 index 0000000..4d544d7 --- /dev/null +++ b/sql/basics/input_026.sql @@ -0,0 +1,8 @@ +select load_extension('./libsqlite_plugin_lj'); + +select L(' + sqlite.register_internal_function("make_str", "mkstr") +'); + +select mkstr('const_s', 'hello_fast'); +select const_s(); diff --git a/sql/basics/input_028.sql b/sql/basics/input_028.sql new file mode 100644 index 0000000..7e8ab68 --- /dev/null +++ b/sql/basics/input_028.sql @@ -0,0 +1,11 @@ +select load_extension('./libsqlite_plugin_lj'); + +select L(' + sqlite.register_internal_function("make_chk") +'); + +select make_chk('sum2', 'return arg[1] + arg[2]', 2); +select sum2(10, 20); + +select make_chk('answer', 'return 42', 0); +select answer(); diff --git a/sql/basics/input_029.sql b/sql/basics/input_029.sql new file mode 100644 index 0000000..fba4c53 --- /dev/null +++ b/sql/basics/input_029.sql @@ -0,0 +1,65 @@ +select load_extension('./libsqlite_plugin_lj'); + +select L(' + sqlite.config.use_traceback = 0 + local out = {} + + local function check(label, ok) + out[#out + 1] = (ok and "PASS " or "FAIL ") .. label + end + + -- Setup test data + sqlite.run_sql("CREATE TABLE np_data(id INTEGER, name TEXT, val REAL)") + sqlite.run_sql("INSERT INTO np_data VALUES(1, ''alice'', 1.5)") + sqlite.run_sql("INSERT INTO np_data VALUES(2, ''bob'', 2.5)") + sqlite.run_sql("INSERT INTO np_data VALUES(3, ''charlie'', 3.5)") + + -- Test 1: :param style with fetch_first + local r = sqlite.fetch_first("SELECT id, name FROM np_data WHERE id = :id", {id = 2}) + check(":param fetch_first", r and tonumber(r.id) == 2 and r.name == "bob") + + -- Test 2: $param style with fetch_first + local r2 = sqlite.fetch_first("SELECT id, name FROM np_data WHERE name = $name", {name = "alice"}) + check("$param fetch_first", r2 and tonumber(r2.id) == 1) + + -- Test 3: @param style with fetch_first + local r3 = sqlite.fetch_first("SELECT id FROM np_data WHERE val = @val", {val = 3.5}) + check("@param fetch_first", r3 and tonumber(r3.id) == 3) + + -- Test 4: Multiple named params + local r4 = sqlite.fetch_first("SELECT id FROM np_data WHERE id = :id AND name = :name", {id = 1, name = "alice"}) + check("multi named params", r4 and tonumber(r4.id) == 1) + + -- Test 5: fetch_all with named params + local rows = sqlite.fetch_all("SELECT id FROM np_data WHERE val >= :min_val ORDER BY id", {min_val = 2.5}) + check("fetch_all named", #rows == 2 and tonumber(rows[1].id) == 2 and tonumber(rows[2].id) == 3) + + -- Test 6: nrows iterator with named params + local ids = {} + for row in sqlite.nrows("SELECT id FROM np_data WHERE id > :min_id ORDER BY id", {min_id = 1}) do + ids[#ids + 1] = tonumber(row.id) + end + check("nrows named", #ids == 2 and ids[1] == 2 and ids[2] == 3) + + -- Test 7: urows iterator with named params + local names = {} + for name in sqlite.urows("SELECT name FROM np_data WHERE val < :max_val ORDER BY name", {max_val = 3.0}) do + names[#names + 1] = name + end + check("urows named", #names == 2 and names[1] == "alice" and names[2] == "bob") + + -- Test 8: Positional fallback when named param not found in table + -- :b is param index 2, so params[2] = 20 is the positional fallback + local r5 = sqlite.fetch_first("SELECT :a as a, :b as b", {a = 10, [2] = 20}) + check("positional fallback", r5 and tonumber(r5.a) == 10 and tonumber(r5.b) == 20) + + -- Test 9: NULL for missing named params (key not in table, no positional fallback) + local r6 = sqlite.fetch_first("SELECT :x as x, :y as y", {x = 42}) + check("NULL missing param", r6 and tonumber(r6.x) == 42 and r6.y == nil) + + -- Test 10: All positional (? style) still works with table + local r7 = sqlite.fetch_first("SELECT ? as a, ? as b", {100, 200}) + check("positional with table", r7 and tonumber(r7.a) == 100 and tonumber(r7.b) == 200) + + return table.concat(out, "\n") +'); diff --git a/sql/errors/input_022.sql b/sql/errors/input_022.sql new file mode 100644 index 0000000..ffc39e3 --- /dev/null +++ b/sql/errors/input_022.sql @@ -0,0 +1,69 @@ +select load_extension('./libsqlite_plugin_lj'); + +select('--- malformed lua handling ---'); + +select L(' + sqlite.config.use_traceback = 0 + local out = {} + + local function expect_error(label, fn, pattern) + local _, res = pcall(fn) + local text = tostring(res) + if text:match(pattern) then + out[#out + 1] = "PASS:" .. label + else + out[#out + 1] = "FAIL:" .. label + out[#out + 1] = "DETAIL:" .. label .. ":" .. text + end + end + + expect_error("L", function() + return sqlite.fetch_first([[SELECT L(''return ('') AS v]]) + end, "Create failed %[temp_fn%]") + + expect_error("L10", function() + return sqlite.fetch_all([[SELECT * FROM L10(''return function('')]]) + end, "Create temporary function failed") + + expect_error("make_fn", function() + return sqlite.make_fn("bad_fn", "return function(a) return (a + ) end", 1) + end, "Create failed %[bad_fn%]") + + expect_error("make_chk", function() + return sqlite.make_chk("bad_chk", "return (arg[1] + )", 1) + end, "Create failed %[bad_chk%]") + + expect_error("make_stored_agg3:init", function() + return sqlite.make_function_agg_chk("bad_agg3_init", "n =", "n = n + arg[1]", "return n", 1) + end, "bad_agg3_init:init") + + expect_error("make_stored_agg3:step", function() + return sqlite.make_function_agg_chk("bad_agg3_step", "n = 0", "n = n +", "return n", 1) + end, "bad_agg3_step:step") + + expect_error("make_stored_agg3:final", function() + return sqlite.make_function_agg_chk("bad_agg3_final", "n = 0", "n = n + arg[1]", "return n +", 1) + end, "bad_agg3_final:final") + + expect_error("make_stored_aggc:code", function() + return sqlite.create_function_agg_coro_text("bad_aggc", "return function() local x = 1 + end", 1) + end, "Create failed %[bad_aggc%]") + + expect_error("make_stored_win:init", function() + return sqlite.create_function_agg_chk_window("bad_win_init", "n =", "n = n + arg[1]", "return n", 1) + end, "bad_win_init:init") + + expect_error("make_stored_win:step", function() + return sqlite.create_function_agg_chk_window("bad_win_step", "n = 0", "n = n +", "return n", 1) + end, "bad_win_step:step") + + expect_error("make_stored_win:final", function() + return sqlite.create_function_agg_chk_window("bad_win_final", "n = 0", "n = n + arg[1]", "return n +", 1) + end, "bad_win_final:final") + + expect_error("make_stored_win:inverse", function() + return sqlite.create_function_agg_chk_window("bad_win_inverse", "n = 0", "n = n + arg[1]", "n = n -", "return n", 1) + end, "bad_win_inverse:inverse") + + return table.concat(out, "\n") +'); diff --git a/sql/errors/input_027.sql b/sql/errors/input_027.sql new file mode 100644 index 0000000..1d5252f --- /dev/null +++ b/sql/errors/input_027.sql @@ -0,0 +1,12 @@ +select load_extension('./libsqlite_plugin_lj'); + +select L(' + local ok, err = pcall(function() + sqlite.fetch_first("SELECT :v as v", { v = { bad = 1 } }) + end) + if ok then + return "unexpected-success" + end + err = tostring(err):gsub("table: 0x%x+", "table") + return err +'); diff --git a/sql/errors/input_034.sql b/sql/errors/input_034.sql new file mode 100644 index 0000000..89d3146 --- /dev/null +++ b/sql/errors/input_034.sql @@ -0,0 +1,48 @@ +select load_extension('./libsqlite_plugin_lj'); + +select L(' + sqlite.config.use_traceback = 0 + local out = {} + + local function expect_error(label, fn, pattern) + local ok, err = pcall(fn) + local text = tostring(err) + if (not ok) and text:match(pattern) then + out[#out + 1] = "PASS " .. label + else + out[#out + 1] = "FAIL " .. label + if ok then + out[#out + 1] = "DETAIL:" .. label .. ":unexpected success" + else + out[#out + 1] = "DETAIL:" .. label .. ":" .. text + end + end + end + + -- Test 1: L vtable with code that errors on xFilter (bad Lua code) + expect_error("L xFilter syntax error", function() + for row in sqlite.nrows("SELECT * FROM L(''return (('')") do end + end, "Create temporary function failed") + + -- Test 2: L vtable with iterator function that errors mid-iteration + expect_error("L xNext runtime error", function() + local code = "local i = 0; return function() i = i + 1; if i == 1 then return 1 end; if i == 2 then error(\"deliberate iteration error\") end; return nil end" + local sql = "SELECT * FROM L(''" .. code .. "'')" + local results = {} + for row in sqlite.nrows(sql) do + results[#results + 1] = row + end + end, "deliberate iteration error") + + -- Test 3: L vtable with code that errors immediately (runtime error in xFilter) + expect_error("L xFilter runtime error", function() + for row in sqlite.nrows("SELECT * FROM L(''error(\"immediate xFilter fail\")'')") do end + end, "immediate xFilter fail") + + -- Test 4: L10 with bad Lua code + expect_error("L10 xFilter syntax error", function() + for row in sqlite.nrows("SELECT * FROM L10(''return (('')") do end + end, "Create temporary function failed") + + return table.concat(out, "\n") +'); diff --git a/sql/errors/input_036.sql b/sql/errors/input_036.sql new file mode 100644 index 0000000..cb7d05f --- /dev/null +++ b/sql/errors/input_036.sql @@ -0,0 +1,85 @@ +select load_extension('./libsqlite_plugin_lj'); + +-- Setup: create helper and test table +select L(' + sqlite.config.use_traceback = 0 + sqlite.register_internal_function("make_fn") + sqlite.register_internal_function("make_chk") + + sqlite.run_sql("CREATE TABLE IF NOT EXISTS t036(id INTEGER PRIMARY KEY, val TEXT)") + sqlite.run_sql("DELETE FROM t036") + sqlite.run_sql("INSERT INTO t036 VALUES(1, ''a'')") + sqlite.run_sql("INSERT INTO t036 VALUES(2, ''b'')") + sqlite.run_sql("INSERT INTO t036 VALUES(3, ''c'')") +'); + +-- Test 1: Function that abandons an iterator mid-iteration returns correct value +-- make_chk uses caller_chk which has close_unfinalized cleanup +select make_chk('test036_abandon', + 'local first = sqlite.fetch_first("SELECT id, val FROM t036 ORDER BY id") return tonumber(first.id)', + 0); +select L(' + local r = sqlite.fetch_first("SELECT test036_abandon() as v") + local ok = tonumber(r.v) == 1 + return (ok and "PASS" or "FAIL") .. " abandon iterator returns correct value" +'); + +-- Test 2: Repeated calls with abandoned iterators work correctly +select make_chk('test036_abandon2', + 'local first = sqlite.fetch_first("SELECT id FROM t036 ORDER BY id LIMIT 1") return tonumber(first.id)', + 0); +select L(' + local ok = true + for i = 1, 5 do + local r = sqlite.fetch_first("SELECT test036_abandon2() as v") + if tonumber(r.v) ~= 1 then ok = false end + end + return (ok and "PASS" or "FAIL") .. " repeated abandon stable" +'); + +-- Test 3: Function that errors after partial iteration propagates the error +select make_chk('test036_error_mid', + 'local first = sqlite.fetch_first("SELECT id FROM t036 ORDER BY id") error("deliberate error after partial read")', + 0); +select L(' + sqlite.config.use_traceback = 0 + local ok, err = pcall(function() + sqlite.fetch_first("SELECT test036_error_mid() as v") + end) + local err_ok = (not ok) and tostring(err):find("deliberate error after partial read") + return (err_ok and "PASS" or "FAIL") .. " error after partial iteration propagates" +'); + +-- Test 4: Database remains usable after error+cleanup +select L(' + local r = sqlite.fetch_first("SELECT count(*) as c FROM t036") + local ok = tonumber(r.c) == 3 + return (ok and "PASS" or "FAIL") .. " db usable after error cleanup" +'); + +-- Test 5: Nested function calls with cleanup +-- Inner function reads a value; outer calls inner and multiplies +select make_chk('test036_inner', + 'local r = sqlite.fetch_first("SELECT max(id) as m FROM t036") return tonumber(r.m)', + 0); +select make_chk('test036_outer', + 'local r = sqlite.fetch_first("SELECT test036_inner() as v") return tonumber(r.v) * 10', + 0); +select L(' + local r = sqlite.fetch_first("SELECT test036_outer() as v") + local ok = tonumber(r.v) == 30 + return (ok and "PASS" or "FAIL") .. " nested functions with cleanup" +'); + +-- Test 6: Subsequent queries not corrupted by cleanup +select L(' + local rows = sqlite.fetch_all("SELECT id, val FROM t036 ORDER BY id") + local ok = #rows == 3 + and tonumber(rows[1].id) == 1 and rows[1].val == "a" + and tonumber(rows[2].id) == 2 and rows[2].val == "b" + and tonumber(rows[3].id) == 3 and rows[3].val == "c" + return (ok and "PASS" or "FAIL") .. " subsequent queries not corrupted" +'); + +-- Cleanup +select L('sqlite.run_sql("DROP TABLE IF EXISTS t036")'); diff --git a/sql/input_003.sql b/sql/input_003.sql deleted file mode 100644 index 8592204..0000000 --- a/sql/input_003.sql +++ /dev/null @@ -1,36 +0,0 @@ -select load_extension('./libsqlite_plugin_lj'); - -select('-------------'); -CREATE TABLE data(val NUMERIC); -INSERT INTO data(val) VALUES (5), (7), (9); - -select L(' - local sqlite = require("sqlite_lj") - local avgCoroutine = function () - local acc = 0 - local n = 0 - while true do - local has_next, value = coroutine.yield() -- Receive values from the producer - if has_next then - acc = acc + value - n = n + 1 - else - acc = tonumber(acc) * (1/n) - break -- Break the loop when no more data is available - end - end - return acc - end - - sqlite.create_function_agg_coro("agg_avg_coro", avgCoroutine, 1) - - - sqlite.create_function_agg_chk("agg_avg_chk", - "acc = 0; n = 0;", - "n = n + 1; acc = acc + tonumber(arg[1]);", - "return acc * (1/n);", - 1) -'); - -SELECT agg_avg_coro(val), avg(val) FROM data; -SELECT agg_avg_chk(val), avg(val) FROM data; diff --git a/sql/lifetime/input_024.sql b/sql/lifetime/input_024.sql new file mode 100644 index 0000000..3322290 --- /dev/null +++ b/sql/lifetime/input_024.sql @@ -0,0 +1,29 @@ +.load ./libsqlite_plugin_lj +select L(' + sqlite.config.use_traceback = 0 + + local function make_payload(i) + local head = string.char(65 + (i % 26)) + return head .. ":" .. tostring(i) .. ":" .. string.rep("z", 8192) + end + + local result = "PASS:bind_text_gc" + for i = 1, 2000 do + local payload = make_payload(i) + local iter = sqlite.rows("select ?1 as v", {payload}) + payload = nil + + collectgarbage("collect") + collectgarbage("collect") + + local row = iter() + local expected = make_payload(i) + if not row or row[1] ~= expected then + result = "FAIL:bind_text_gc:" .. tostring(i) + break + end + iter() + end + + return result +'); diff --git a/sql/lifetime/input_025.sql b/sql/lifetime/input_025.sql new file mode 100644 index 0000000..1f39d2f --- /dev/null +++ b/sql/lifetime/input_025.sql @@ -0,0 +1,29 @@ +.load ./libsqlite_plugin_lj +select L(' + sqlite.config.use_traceback = 0 + + local result = "PASS:callback_gc" + for i = 1, 600 do + local fname = "w_gc_" .. tostring(i) + sqlite.create_function_agg_chk_window( + fname, + "n = 0", + "n = n + arg[1]", + "return n", + 1 + ) + + collectgarbage("collect") + collectgarbage("collect") + + local row = sqlite.fetch_first( + "WITH t(v) AS (VALUES(1),(2),(3)) SELECT " .. fname .. "(v) OVER () AS s FROM t LIMIT 1" + ) + if not row or tonumber(row.s) ~= 6 then + result = "FAIL:callback_gc:" .. tostring(i) + break + end + end + + return result +'); diff --git a/sql/limits/input_015.sql b/sql/limits/input_015.sql new file mode 100644 index 0000000..376c516 --- /dev/null +++ b/sql/limits/input_015.sql @@ -0,0 +1,16 @@ +.load ./libsqlite_plugin_lj +select L(' + sqlite.config.use_traceback = 0 + local payload = string.rep("a", 5 * 1024 * 1024) + local chunk = "return function() return [[" .. payload .. "]] end" + + local ok, err = pcall(function() + sqlite.make_fn("too_big", chunk, 0) + end) + + if (not ok) and tostring(err):match("shared object storage limit exceeded") then + return "PASS: shared object size ceiling enforced" + else + return "FAIL: shared object size ceiling enforced" + end +'); diff --git a/sql/limits/input_016.sql b/sql/limits/input_016.sql new file mode 100644 index 0000000..22a74f7 --- /dev/null +++ b/sql/limits/input_016.sql @@ -0,0 +1,23 @@ +.load ./libsqlite_plugin_lj +select L(' + sqlite.config.use_traceback = 0 + + local ok = true + local errtxt = "" + for i = 1, 20 do + local status, err = pcall(function() + sqlite.make_fn("obj_lim_" .. i, "return function() return " .. i .. " end", 0) + end) + if not status then + ok = false + errtxt = tostring(err) + break + end + end + + if (not ok) and errtxt:match("shared object storage limit exceeded") then + return "PASS: shared object count ceiling enforced" + else + return "FAIL: shared object count ceiling enforced" + end +'); diff --git a/sql/limits/input_017.sql b/sql/limits/input_017.sql new file mode 100644 index 0000000..6a29981 --- /dev/null +++ b/sql/limits/input_017.sql @@ -0,0 +1,23 @@ +.load ./libsqlite_plugin_lj +select L(' + sqlite.config.use_traceback = 0 + + local ok = true + local errtxt = "" + for i = 1, 20 do + local status, err = pcall(function() + sqlite.create_function("ctx_lim_" .. i, function(a) return a end, 1) + end) + if not status then + ok = false + errtxt = tostring(err) + break + end + end + + if (not ok) and errtxt:match("function context storage limit exceeded") then + return "PASS: function context ceiling enforced" + else + return "FAIL: function context ceiling enforced" + end +'); diff --git a/sql/limits/input_018.sql b/sql/limits/input_018.sql new file mode 100644 index 0000000..46ca1ab --- /dev/null +++ b/sql/limits/input_018.sql @@ -0,0 +1,26 @@ +.load ./libsqlite_plugin_lj +select L(' + sqlite.config.use_traceback = 0 + + local payload = string.rep("b", 50000) + local chunk = "return function() return [[" .. payload .. "]] end" + + local ok = true + local errtxt = "" + for i = 1, 20 do + local status, err = pcall(function() + sqlite.make_fn("buf_lim_" .. i, chunk, 0) + end) + if not status then + ok = false + errtxt = tostring(err) + break + end + end + + if (not ok) and errtxt:match("shared object storage limit exceeded") then + return "PASS: shared buffer ceiling enforced" + else + return "FAIL: shared buffer ceiling enforced" + end +'); diff --git a/sql/input_008.sql b/sql/nesting/input_008.sql similarity index 85% rename from sql/input_008.sql rename to sql/nesting/input_008.sql index 3363c49..925c511 100644 --- a/sql/input_008.sql +++ b/sql/nesting/input_008.sql @@ -2,7 +2,7 @@ select load_extension('./libsqlite_plugin_lj'); select('-------------'); select L(' -_G.sqlite = require("sqlite_lj") +sqlite.register_internal_function("make_fn") sqlite.run_sql[[ CREATE TABLE numbers008(num1,num2); INSERT INTO numbers008 VALUES(1,11); @@ -13,20 +13,18 @@ sqlite.run_sql[[ select L(' for value in sqlite.urows("select num2 from numbers008") do - print(tonumber(value)) - break; + return tonumber(value) end -- urows statement stored in unfinalized_statements internal map and finalize(stmt) called here '); -select L(' - local sub1 = function() +select make_fn('sub1', ' + return function() for value in sqlite.urows("select num2 from numbers008") do return tonumber(value) -- urows in unfinalized_statements end end - sqlite.create_function("sub1", sub1, 0) '); select sub1(); diff --git a/sql/nesting/input_009.sql b/sql/nesting/input_009.sql new file mode 100644 index 0000000..a26218d --- /dev/null +++ b/sql/nesting/input_009.sql @@ -0,0 +1,109 @@ +select load_extension('./libsqlite_plugin_lj'); +select('-------------'); +select L(' + sqlite.register_internal_function("make_fn") + sqlite.run_sql[[ + CREATE TABLE numbers009(num1,num2); + INSERT INTO numbers009 VALUES(1,11); + INSERT INTO numbers009 VALUES(2,22); + INSERT INTO numbers009 VALUES(3,33); +]] +'); + +select make_fn('fib', ' +return function (value) + if (value <= 1) then + return value + end + local value1 = sqlite.fetch_first("select fib(?) as f", {value - 1})["f"] + local value2 = sqlite.fetch_first("select fib(?) as f", {value - 2})["f"] + return value1 + value2; +end +'); + +select L(' +local res = sqlite.fetch_first("select fib(10) as result"); +return tostring(res["result"]) +'); +select fib(10); + + +-- select L(' +-- error("test error") +-- '); + +select L(' + sqlite.config.use_traceback = 0 + + local fib2 = function (value) + if (value <= 1) then + return value + end + + local value1 = sqlite.fetch_first("select fib2(?) as f", {value - 1})["f"] + local value2 = sqlite.fetch_first("select fib(?) as f", {value - 2})["f"] + return value1 + value2; + end + + sqlite.create_function("fib2", fib2, 1) +'); + +-- Test that closure-based function fails in nested call (expected behavior) +select L(' + local ok, err = pcall(function() + return sqlite.fetch_first("select fib2(2)") + end) + if not ok and tostring(err):match("max call depth") then + return "PASS: fib2 correctly rejected nested closure call" + else + return "FAIL: fib2 expected max call depth error, got: " .. tostring(ok) .. " " .. tostring(err) + end +'); + +-- check close of unfinalized statements in nested calls + +select L(' + local sub3 = function() + for value in sqlite.urows("select num2 from numbers009") do + return tonumber(value) + -- urows in unfinalized_statements + end + end + sqlite.create_function("sub3", sub3, 0) +'); + +select L(' + local sub4 = function() + for value in sqlite.urows("select sub3() from numbers009") do + return tonumber(value) + -- urows in unfinalized_statements + end + end + sqlite.create_function("sub4", sub4, 0) +'); + +-- Test that nested closure calls fail (expected behavior) +select L(' + local ok, err = pcall(function() + sqlite.fetch_first("select sub4()") + end) + if not ok and tostring(err):match("max call depth") then + return "PASS: sub4 correctly rejected nested closure call" + else + return "FAIL: expected max call depth error, got: " .. tostring(ok) .. " " .. tostring(err) + end +'); + +-- Test nested aggregate (create_function_agg_chk) +select L(' + sqlite.create_function_agg_chk("nested_sum", "n = 0", "n = n + arg[1]", "return n", 1) +'); + +-- Direct call +select nested_sum(num1) from numbers009; + +-- Nested call from L() +select L(' + local result = sqlite.fetch_first("SELECT nested_sum(num1) as s FROM numbers009") + return result["s"] +'); diff --git a/sql/input_timed001.sql b/sql/performance/input_timed001.sql similarity index 100% rename from sql/input_timed001.sql rename to sql/performance/input_timed001.sql diff --git a/sql/input_007.sql b/sql/types/input_007.sql similarity index 76% rename from sql/input_007.sql rename to sql/types/input_007.sql index 1a19b18..e648a71 100644 --- a/sql/input_007.sql +++ b/sql/types/input_007.sql @@ -5,7 +5,6 @@ select('-------------'); select hex(x'ff00ff'); select L(' -_G.sqlite = require("sqlite_lj") _G.blob_to_hex_string = function (blob) local hex_string = "" for i = 0, blob.size - 1 do @@ -17,7 +16,9 @@ end '); select L(' -for a in sqlite.urows [[select (x''ff00ff''); ]] do print( blob_to_hex_string(a) ) end +for a in sqlite.urows [[select (x''ff00ff''); ]] do + return blob_to_hex_string(a) +end '); select hex(L(' @@ -34,8 +35,8 @@ select L(' for a, b in sqlite.urows ("select ?1, ?2", { sqlite.make_blob ({0xF5, 0x00, 0xF9}), sqlite.make_blob () - }) do - print( blob_to_hex_string(a), b.data, b.size ) + }) do + return blob_to_hex_string(a) .. "\t" .. tostring(b.data) .. "\t" .. tostring(b.size) end '); diff --git a/sql/types/input_012.sql b/sql/types/input_012.sql new file mode 100644 index 0000000..ac91994 --- /dev/null +++ b/sql/types/input_012.sql @@ -0,0 +1,12 @@ +.load ./libsqlite_plugin_lj +select L('sqlite.config.use_traceback = 0'); +select L(' + local ok, err = pcall(function() + return sqlite.fetch_first("select ?1 as v", {[1] = function() end}) + end) + if not ok and tostring(err):match("api_bind_any: unsupported type function") then + return "PASS: fetch_first bind error triggered" + else + return "FAIL: unexpected fetch_first result" + end +'); diff --git a/sql/types/input_013.sql b/sql/types/input_013.sql new file mode 100644 index 0000000..ba6fb24 --- /dev/null +++ b/sql/types/input_013.sql @@ -0,0 +1,40 @@ +.load ./libsqlite_plugin_lj +select L(' + sqlite.config.use_traceback = 0 + local out = {} + + local function run_matrix(kind, cases, run_one) + for _, c in ipairs(cases) do + local ok, err = pcall(run_one, c.value) + local name = kind .. " " .. c.name + if c.expect == "ok" then + out[#out + 1] = (ok and "PASS " or "FAIL ") .. name + else + local matched = (not ok) and tostring(err):match(c.err_pat or "") + out[#out + 1] = (matched and "PASS " or "FAIL ") .. name + end + end + end + + local bind_cases = { + { name = "sqlite.int_t", value = sqlite.int_t(7), expect = "ok" }, + { name = "sqlite.int64_t", value = sqlite.int64_t(8), expect = "ok" }, + { name = "sqlite.uint64_t", value = sqlite.uint64_t(9), expect = "ok" }, + { name = "sqlite.double_t", value = sqlite.double_t(2.5), expect = "ok" }, + { name = "NULL", value = NULL, expect = "ok" }, + { name = "number int", value = 42, expect = "ok" }, + { name = "number real", value = 3.25, expect = "ok" }, + { name = "string", value = "abc", expect = "ok" }, + { name = "boolean", value = true, expect = "ok" }, + { name = "blob empty", value = sqlite.make_blob({}), expect = "ok" }, + { name = "blob bytes", value = sqlite.make_blob({65,66,67}),expect = "ok" }, + { name = "invalid table", value = {foo = 1}, expect = "err", err_pat = "unsupported type table" }, + { name = "invalid function",value = function() end, expect = "err", err_pat = "unsupported type function" } + } + + run_matrix("bind", bind_cases, function(v) + sqlite.fetch_first("select ?1 as v", {[1] = v}) + end) + + return table.concat(out, "\n") +'); diff --git a/sql/types/input_040.sql b/sql/types/input_040.sql new file mode 100644 index 0000000..ecb3dda --- /dev/null +++ b/sql/types/input_040.sql @@ -0,0 +1,16 @@ +.load ./libsqlite_plugin_lj +select L(' + sqlite.create_function("ret_pos_inf", function() return math.huge end, 0) + sqlite.create_function("ret_neg_inf", function() return -math.huge end, 0) + return "PASS register inf functions" +'); +select case when ret_pos_inf() > 1e308 then 'PASS return +inf stays float' else 'FAIL return +inf stays float' end; +select case when ret_neg_inf() < -1e308 then 'PASS return -inf stays float' else 'FAIL return -inf stays float' end; +select L(' + local out = {} + local pos = sqlite.fetch_first("SELECT (?1 > 1e308) AS v", {math.huge}) + out[#out + 1] = (tonumber(pos.v) == 1 and "PASS bind +inf stays float" or ("FAIL bind +inf stays float: " .. tostring(pos.v))) + local neg = sqlite.fetch_first("SELECT (?1 < -1e308) AS v", {-math.huge}) + out[#out + 1] = (tonumber(neg.v) == 1 and "PASS bind -inf stays float" or ("FAIL bind -inf stays float: " .. tostring(neg.v))) + return table.concat(out, "\n") +'); diff --git a/sql/input_006.sql b/sql/vtables/input_006.sql similarity index 89% rename from sql/input_006.sql rename to sql/vtables/input_006.sql index 9b4530f..a99738f 100644 --- a/sql/input_006.sql +++ b/sql/vtables/input_006.sql @@ -4,7 +4,6 @@ select('-------------'); select L(' - local sqlite = require("sqlite_lj") sqlite.run_sql("DROP TABLE IF EXISTS TEMP.table_a") sqlite.make_vtable("table_a", { @@ -17,7 +16,8 @@ select L(' select * from table_a o1; select('-------------'); -select L(' +-- Init step: define list_iterator in vtable_vm scope (returns 0 rows) +SELECT * FROM L(' _G.list_iterator = function(t) local i = 0 local n = #t @@ -26,9 +26,9 @@ select L(' if i <= n then return t[i] end end end + return function() return nil end '); - .mode column SELECT * , typeof(value) FROM L(' diff --git a/sql/vtables/input_010.sql b/sql/vtables/input_010.sql new file mode 100644 index 0000000..75bfd70 --- /dev/null +++ b/sql/vtables/input_010.sql @@ -0,0 +1,5 @@ +.load ./libsqlite_plugin_lj +.mode column +SELECT count(*) FROM L; +SELECT * , typeof(value) from L('return {1,2,NULL,3}'); +SELECT * , typeof(value) from L('return 6,NULL,7,8'); \ No newline at end of file diff --git a/sql/vtables/input_011.sql b/sql/vtables/input_011.sql new file mode 100644 index 0000000..fe25bf5 --- /dev/null +++ b/sql/vtables/input_011.sql @@ -0,0 +1,9 @@ +.load ./libsqlite_plugin_lj +select L(' + sqlite.run_sql("DROP TABLE IF EXISTS TEMP.crash_vt") + sqlite.make_vtable("crash_vt", { + columns = {"a", "b"}, + rows = {{1, 2}} + }) +'); +select * from crash_vt; diff --git a/sql/vtables/input_014.sql b/sql/vtables/input_014.sql new file mode 100644 index 0000000..2e9c4b8 --- /dev/null +++ b/sql/vtables/input_014.sql @@ -0,0 +1,36 @@ +.load ./libsqlite_plugin_lj +select L(' + sqlite.config.use_traceback = 0 + local out = {} + + local ok1, err1 = pcall(function() + sqlite.run_sql("DROP TABLE IF EXISTS TEMP.\"tbl weird\"") + sqlite.make_vtable("tbl weird", { + columns = {"select", "a\"b", "sp ace"}, + rows = {{11, 22, 33}} + }) + local row = sqlite.fetch_first("select \"select\", \"a\"\"b\", \"sp ace\" from \"tbl weird\"") + if row and row["select"] == 11 and row["a\"b"] == 22 and row["sp ace"] == 33 then + out[#out + 1] = "PASS: make_vtable quoted identifiers" + else + out[#out + 1] = "FAIL: make_vtable quoted identifiers mismatch" + end + end) + if not ok1 then + out[#out + 1] = "FAIL: make_vtable quoted identifiers error" + end + + local ok2, err2 = pcall(function() + sqlite.make_vtable("bad_table", { + columns = {""}, + rows = {{1}} + }) + end) + if (not ok2) and tostring(err2):match("identifier must be a non%-empty string") then + out[#out + 1] = "PASS: make_vtable rejects empty identifier" + else + out[#out + 1] = "FAIL: make_vtable rejects empty identifier" + end + + return table.concat(out, "\n") +'); diff --git a/sql/vtables/input_019.sql b/sql/vtables/input_019.sql new file mode 100644 index 0000000..671c0f2 --- /dev/null +++ b/sql/vtables/input_019.sql @@ -0,0 +1,72 @@ +.load ./libsqlite_plugin_lj +select L(' + sqlite.config.use_traceback = 0 + local out = {} + + local function expect_fail(name, fn, pat) + local ok, err = pcall(fn) + local matched = (not ok) and tostring(err):match(pat) + out[#out + 1] = (matched and "PASS " or "FAIL ") .. name + end + + local function expect_ok(name, fn) + local ok = pcall(fn) + out[#out + 1] = (ok and "PASS " or "FAIL ") .. name + end + + expect_ok("rows nil means empty table", function() + sqlite.run_sql("DROP TABLE IF EXISTS TEMP.rows_nil") + sqlite.make_vtable("rows_nil", { columns = {"a"}, rows = nil }) + local r = sqlite.fetch_first("select count(*) as c from rows_nil") + assert(r and tonumber(r.c) == 0) + end) + + expect_ok("rows empty means empty table", function() + sqlite.run_sql("DROP TABLE IF EXISTS TEMP.rows_empty") + sqlite.make_vtable("rows_empty", { columns = {"a"}, rows = {} }) + local r = sqlite.fetch_first("select count(*) as c from rows_empty") + assert(r and tonumber(r.c) == 0) + end) + + expect_ok("rows table-empty-row gives null row", function() + sqlite.run_sql("DROP TABLE IF EXISTS TEMP.rows_null") + sqlite.make_vtable("rows_null", { columns = {"a", "b"}, rows = {{}} }) + local r = sqlite.fetch_first("select count(*) as c from rows_null where a is null and b is null") + assert(r and tonumber(r.c) == 1) + end) + + expect_fail("duplicate columns", function() + sqlite.make_vtable("dup_cols", { + columns = {"a", "A"}, + rows = {{1, 2}} + }) + end, "duplicate column name") + + expect_fail("rows must be table", function() + sqlite.make_vtable("bad_rows", { + columns = {"a"}, + rows = 1 + }) + end, "rows must be a table") + + expect_ok("extra row values ignored", function() + sqlite.run_sql("DROP TABLE IF EXISTS TEMP.wide_rows") + sqlite.make_vtable("wide_rows", { + columns = {"a"}, + rows = {{1, 2}} + }) + local r = sqlite.fetch_first("select a from wide_rows") + assert(r and tonumber(r.a) == 1) + end) + + expect_ok("garbage rows skipped", function() + sqlite.make_vtable("flat_rows", { + columns = {"a", "b"}, + rows = {1, 2} + }) + local r = sqlite.fetch_first("select count(*) as c from flat_rows") + assert(r and tonumber(r.c) == 0) + end) + + return table.concat(out, "\n") +'); diff --git a/sql/vtables/input_023.sql b/sql/vtables/input_023.sql new file mode 100644 index 0000000..bb0276e --- /dev/null +++ b/sql/vtables/input_023.sql @@ -0,0 +1,21 @@ +-- Regression stress test: repeated make_vtable/drop churn should remain stable. +-- This fixture is wired as a dedicated CTest target. + +select load_extension('./libsqlite_plugin_lj'); + +select L(' + sqlite.config.use_traceback = 0 + for i = 1, 260 do + sqlite.run_sql("DROP TABLE IF EXISTS TEMP.gc_vt") + sqlite.make_vtable("gc_vt", { + columns = {"a", "b"}, + rows = {{1, 2}, {3, 4}} + }) + + local r = sqlite.fetch_first("select count(*) as c from gc_vt") + if (not r) or tonumber(r.c) ~= 2 then + return "FAIL:count:" .. tostring(i) + end + end + return "PASS:vtable-churn-stable" +'); diff --git a/sql/vtables/input_030.sql b/sql/vtables/input_030.sql new file mode 100644 index 0000000..d8b5942 --- /dev/null +++ b/sql/vtables/input_030.sql @@ -0,0 +1,39 @@ +select load_extension('./libsqlite_plugin_lj'); + +-- Init list_iterator in vtable_vm scope +SELECT * FROM L(' + _G.list_iterator = function(t) + local i = 0 + local n = #t + return function () + i = i + 1 + if i <= n then return t[i] end + end + end + return function() return nil end + '); + +select('--- L10 all 10 columns ---'); +.mode column +SELECT * FROM L10(' + local tbl = { + {1, "a", 1.1, 10, "x", 100, true, -1, "end1", 999}, + {2, "b", 2.2, 20, "y", 200, false, -2, "end2", 888} + } + return list_iterator(tbl) +'); + +select('--- L10 filter on last columns ---'); +SELECT r0, r8, r9 FROM L10(' + local tbl = { + {1, "a", 1.1, 10, "x", 100, true, -1, "end1", 999}, + {2, "b", 2.2, 20, "y", 200, false, -2, "end2", 888}, + {3, "c", 3.3, 30, "z", 300, true, -3, "end3", 777} + } + return list_iterator(tbl) +') WHERE r9 < 900; + +select('--- L10 single row all columns ---'); +SELECT * FROM L10(' + return list_iterator({{10, 20, 30, 40, 50, 60, 70, 80, 90, 100}}) +'); diff --git a/sql/vtables/input_031.sql b/sql/vtables/input_031.sql new file mode 100644 index 0000000..3055698 --- /dev/null +++ b/sql/vtables/input_031.sql @@ -0,0 +1,46 @@ +select load_extension('./libsqlite_plugin_lj'); + +select L(' + sqlite.config.use_traceback = 0 + local out = {} + + local function check(label, ok) + out[#out + 1] = (ok and "PASS " or "FAIL ") .. label + end + + -- Test 1: Create vtable, query, drop, re-create with different data + sqlite.make_vtable("reuse_t", { + columns = {"a", "b"}, + rows = {{1, "x"}, {2, "y"}} + }) + local r1 = sqlite.fetch_all("SELECT a, b FROM reuse_t ORDER BY a") + check("initial create", #r1 == 2 and tonumber(r1[1].a) == 1 and r1[2].b == "y") + + sqlite.run_sql("DROP TABLE IF EXISTS TEMP.reuse_t") + sqlite.make_vtable("reuse_t", { + columns = {"a", "b"}, + rows = {{10, "p"}, {20, "q"}, {30, "r"}} + }) + local r2 = sqlite.fetch_all("SELECT a, b FROM reuse_t ORDER BY a") + check("re-create different data", #r2 == 3 and tonumber(r2[1].a) == 10 and r2[3].b == "r") + + -- Test 2: Re-create with different row count (fewer rows) + sqlite.run_sql("DROP TABLE IF EXISTS TEMP.reuse_t") + sqlite.make_vtable("reuse_t", { + columns = {"a", "b"}, + rows = {{99, "only"}} + }) + local r3 = sqlite.fetch_all("SELECT a, b FROM reuse_t") + check("re-create fewer rows", #r3 == 1 and tonumber(r3[1].a) == 99) + + -- Test 3: Re-create with empty rows + sqlite.run_sql("DROP TABLE IF EXISTS TEMP.reuse_t") + sqlite.make_vtable("reuse_t", { + columns = {"a", "b"}, + rows = {} + }) + local r4 = sqlite.fetch_first("SELECT count(*) as c FROM reuse_t") + check("re-create empty", r4 and tonumber(r4.c) == 0) + + return table.concat(out, "\n") +'); diff --git a/sql/vtables/input_032.sql b/sql/vtables/input_032.sql new file mode 100644 index 0000000..d69a1c0 --- /dev/null +++ b/sql/vtables/input_032.sql @@ -0,0 +1,37 @@ +select load_extension('./libsqlite_plugin_lj'); + +select L(' + sqlite.config.use_traceback = 0 + + sqlite.make_vtable("inner_t", { + columns = {"v"}, + rows = {{1}, {2}, {3}} + }) +'); + +select('--- re-entry from L into make_vtable ---'); + +select L(' + sqlite.config.use_traceback = 0 + local out = {} + + local function check(label, ok) + out[#out + 1] = (ok and "PASS " or "FAIL ") .. label + end + + -- Test 1: L callback tries to SELECT from a make_vtable -> should fail with re-entry error + local ok1, err1 = pcall(function() + for row in sqlite.nrows("SELECT * FROM L(''return function() return sqlite.fetch_first(\"SELECT count(*) as c FROM inner_t\").c end'')") do + end + end) + check("L into make_vtable blocked", not ok1 and tostring(err1):find("vtable cannot be used from nested context") ~= nil) + + -- Test 2: L callback tries to SELECT from L -> should also fail + local ok2, err2 = pcall(function() + for row in sqlite.nrows("SELECT * FROM L(''return function() return sqlite.fetch_first(\"SELECT * FROM L(''''return 1'''')\").value end'')") do + end + end) + check("L into L blocked", not ok2 and tostring(err2):find("vtable cannot be used from nested context") ~= nil) + + return table.concat(out, "\n") +'); diff --git a/sql/vtables/input_035.sql b/sql/vtables/input_035.sql new file mode 100644 index 0000000..0fe7382 --- /dev/null +++ b/sql/vtables/input_035.sql @@ -0,0 +1,94 @@ +select load_extension('./libsqlite_plugin_lj'); + +-- Init list_iterator in vtable_vm scope +SELECT * FROM L(' + _G.list_iterator = function(t) + local i = 0 + local n = #t + return function () + i = i + 1 + if i <= n then return t[i] end + end + end + return function() return nil end +'); + +select L(' + sqlite.config.use_traceback = 0 + local out = {} + + local function check(label, ok, detail) + if ok then + out[#out + 1] = "PASS " .. label + else + out[#out + 1] = "FAIL " .. label + if detail then + out[#out + 1] = "DETAIL:" .. label .. ":" .. tostring(detail) + end + end + end + + -- Setup: create a regular table to JOIN with + sqlite.run_sql("CREATE TABLE IF NOT EXISTS t035(id INTEGER PRIMARY KEY, val TEXT)") + sqlite.run_sql("DELETE FROM t035") + sqlite.run_sql("INSERT INTO t035 VALUES(1, ''alpha'')") + sqlite.run_sql("INSERT INTO t035 VALUES(2, ''beta'')") + sqlite.run_sql("INSERT INTO t035 VALUES(3, ''gamma'')") + + -- Test 1: Baseline - simple L10 query (no join, single constraint) + local ok1, err1 = pcall(function() + local rows = sqlite.fetch_all("SELECT r0, r1 FROM L10(''return list_iterator({{10, 20}, {30, 40}})'') ") + assert(#rows == 2, "expected 2 rows, got " .. #rows) + assert(tonumber(rows[1].r0) == 10, "expected r0=10, got " .. tostring(rows[1].r0)) + assert(tonumber(rows[2].r1) == 40, "expected r1=40, got " .. tostring(rows[2].r1)) + end) + check("L10 baseline no join", ok1, err1) + + -- Test 2: JOIN L10 with regular table + WHERE on L10 column + -- This creates non-usable constraints from the join condition + local ok2, err2 = pcall(function() + local sql = "SELECT t.id, t.val, v.r0, v.r1 FROM t035 t JOIN L10(''return list_iterator({{1, \"x\"}, {2, \"y\"}, {3, \"z\"}})'') v ON t.id = v.r0 ORDER BY t.id" + local rows = sqlite.fetch_all(sql) + assert(#rows == 3, "expected 3 rows, got " .. #rows) + assert(tonumber(rows[1].id) == 1, "expected id=1") + assert(rows[1].r1 == "x", "expected r1=x, got " .. tostring(rows[1].r1)) + assert(tonumber(rows[3].id) == 3, "expected id=3") + assert(rows[3].r1 == "z", "expected r1=z, got " .. tostring(rows[3].r1)) + end) + check("L10 JOIN regular table", ok2, err2) + + -- Test 3: JOIN L (single-column) with regular table + local ok3, err3 = pcall(function() + local sql = "SELECT t.id, t.val, v.value FROM t035 t JOIN L(''return list_iterator({1, 2, 3})'') v ON t.id = v.value ORDER BY t.id" + local rows = sqlite.fetch_all(sql) + assert(#rows == 3, "expected 3 rows, got " .. #rows) + assert(tonumber(rows[1].value) == 1, "expected value=1") + assert(rows[2].val == "beta", "expected val=beta") + end) + check("L JOIN regular table", ok3, err3) + + -- Test 4: LEFT JOIN with L10 as right side + local ok4, err4 = pcall(function() + local sql = "SELECT t.id, v.r0 FROM t035 t LEFT JOIN L10(''return list_iterator({{1, \"match\"}})'') v ON t.id = v.r0 ORDER BY t.id" + local rows = sqlite.fetch_all(sql) + assert(#rows == 3, "expected 3 rows, got " .. #rows) + assert(tonumber(rows[1].r0) == 1, "expected r0=1 for id=1") + assert(rows[2].r0 == nil or rows[2].r0 == "" or rows[2].r0 == NULL, "expected NULL r0 for id=2") + end) + check("L10 LEFT JOIN", ok4, err4) + + -- Test 5: Subquery with IN (SELECT ... FROM L10) + local ok5, err5 = pcall(function() + local sql = "SELECT id, val FROM t035 WHERE id IN (SELECT r0 FROM L10(''return list_iterator({{1}, {3}})'')) ORDER BY id" + local rows = sqlite.fetch_all(sql) + assert(#rows == 2, "expected 2 rows, got " .. #rows) + assert(tonumber(rows[1].id) == 1, "expected id=1") + assert(tonumber(rows[2].id) == 3, "expected id=3") + end) + check("L10 IN subquery", ok5, err5) + + -- Cleanup + sqlite.run_sql("DROP TABLE IF EXISTS t035") + + return table.concat(out, "\n") +'); diff --git a/sql/vtables/input_037.sql b/sql/vtables/input_037.sql new file mode 100644 index 0000000..a93827b --- /dev/null +++ b/sql/vtables/input_037.sql @@ -0,0 +1,119 @@ +select load_extension('./libsqlite_plugin_lj'); + +select L(' + sqlite.config.use_traceback = 0 + local out = {} + + local function check(label, ok, detail) + if ok then + out[#out + 1] = "PASS " .. label + else + out[#out + 1] = "FAIL " .. label + if detail then + out[#out + 1] = "DETAIL:" .. label .. ":" .. tostring(detail) + end + end + end + + -- Test 1: Standard make_vtable TEMP table works (exercises xConnect path) + local ok1, err1 = pcall(function() + sqlite.run_sql("DROP TABLE IF EXISTS TEMP.vt037") + sqlite.make_vtable("vt037", { + columns = {"id", "name"}, + rows = {{1, "alice"}, {2, "bob"}} + }) + local rows = sqlite.fetch_all("SELECT id, name FROM vt037 ORDER BY id") + assert(#rows == 2, "expected 2 rows, got " .. #rows) + assert(tonumber(rows[1].id) == 1, "expected id=1") + assert(rows[2].name == "bob", "expected name=bob") + end) + check("make_vtable TEMP works", ok1, err1) + + -- Test 2: Repeated queries on same vtable + local ok2, err2 = pcall(function() + for i = 1, 3 do + local rows = sqlite.fetch_all("SELECT id, name FROM vt037 ORDER BY id") + assert(#rows == 2, "query " .. i .. ": expected 2 rows") + assert(tonumber(rows[1].id) == 1, "query " .. i .. ": expected id=1") + end + end) + check("repeated queries stable", ok2, err2) + + -- Test 3: Drop and re-create (xDestroy then new xCreate) + local ok3, err3 = pcall(function() + sqlite.run_sql("DROP TABLE IF EXISTS TEMP.vt037") + sqlite.make_vtable("vt037", { + columns = {"id", "name"}, + rows = {{10, "charlie"}, {20, "diana"}, {30, "eve"}} + }) + local rows = sqlite.fetch_all("SELECT id, name FROM vt037 ORDER BY id") + assert(#rows == 3, "expected 3 rows after re-create") + assert(tonumber(rows[1].id) == 10, "expected id=10") + assert(rows[3].name == "eve", "expected name=eve") + end) + check("drop and re-create", ok3, err3) + + -- Test 4: Re-create with different column schema + local ok4, err4 = pcall(function() + sqlite.run_sql("DROP TABLE IF EXISTS TEMP.vt037") + sqlite.make_vtable("vt037", { + columns = {"x", "y", "z"}, + rows = {{100, 200, 300}} + }) + local rows = sqlite.fetch_all("SELECT x, y, z FROM vt037") + assert(#rows == 1, "expected 1 row") + assert(tonumber(rows[1].x) == 100, "expected x=100") + assert(tonumber(rows[1].z) == 300, "expected z=300") + end) + check("re-create different schema", ok4, err4) + + -- Test 5: Non-TEMP vtable creation using module name directly + -- make_vtable registers a module named after the table. Creating a non-TEMP + -- virtual table exercises the xCreate path (not xConnect). + -- lua_vtable_module.xCreate just returns SQLITE_OK without setting ppVTab, + -- so this should fail (SQLite should detect the NULL vtab pointer). + local ok5, err5 = pcall(function() + -- Ensure temp version is dropped first so module is available + sqlite.run_sql("DROP TABLE IF EXISTS TEMP.vt037") + -- Re-register the module by creating a temp table first + sqlite.make_vtable("vt037", { + columns = {"a"}, + rows = {{1}} + }) + -- Drop the temp table but module remains registered + sqlite.run_sql("DROP TABLE IF EXISTS TEMP.vt037") + -- Now try creating a non-TEMP virtual table using the same module name + -- This goes through xCreate (not xConnect) which doesn''t set ppVTab + sqlite.run_sql("CREATE VIRTUAL TABLE main.vt037_main USING vt037()") + end) + -- Whether it errors or not, document the behavior + if ok5 then + -- If it succeeded, try to query it + local qok, qerr = pcall(function() + local rows = sqlite.fetch_all("SELECT * FROM main.vt037_main") + end) + check("non-TEMP xCreate succeeds (query " .. (qok and "works" or "fails") .. ")", true) + pcall(function() sqlite.run_sql("DROP TABLE IF EXISTS main.vt037_main") end) + else + -- Expected: xCreate doesn''t set ppVTab, so SQLite should error + check("non-TEMP xCreate errors as expected", true) + end + + -- Test 6: Verify make_vtable still works after the xCreate test + local ok6, err6 = pcall(function() + sqlite.run_sql("DROP TABLE IF EXISTS TEMP.vt037") + sqlite.make_vtable("vt037", { + columns = {"v"}, + rows = {{42}} + }) + local r = sqlite.fetch_first("SELECT v FROM vt037") + assert(tonumber(r.v) == 42, "expected v=42, got " .. tostring(r.v)) + end) + check("make_vtable works after xCreate test", ok6, err6) + + -- Cleanup + pcall(function() sqlite.run_sql("DROP TABLE IF EXISTS TEMP.vt037") end) + pcall(function() sqlite.run_sql("DROP TABLE IF EXISTS main.vt037_main") end) + + return table.concat(out, "\n") +'); diff --git a/sql/vtables/input_038.sql b/sql/vtables/input_038.sql new file mode 100644 index 0000000..d140af7 --- /dev/null +++ b/sql/vtables/input_038.sql @@ -0,0 +1,48 @@ +select load_extension('./libsqlite_plugin_lj'); + +select L(' + sqlite.config.use_traceback = 0 + local out = {} + + local function check(label, ok, detail) + if ok then + out[#out + 1] = "PASS " .. label + else + out[#out + 1] = "FAIL " .. label + if detail then + out[#out + 1] = "DETAIL:" .. label .. ":" .. tostring(detail) + end + end + end + + sqlite.run_sql("DROP TABLE IF EXISTS TEMP.vt038") + + sqlite.make_vtable("vt038", { + columns = {"v"}, + rows = {{1}, {2}} + }) + local rows1 = sqlite.fetch_all("SELECT v FROM vt038 ORDER BY v") + check("baseline table created", #rows1 == 2 and tonumber(rows1[1].v) == 1 and tonumber(rows1[2].v) == 2) + + local ok2, err2 = pcall(function() + sqlite.make_vtable("vt038", { + columns = {"v"}, + rows = {{10}, {20}, {30}} + }) + end) + check("re-create without drop errors", (not ok2) and tostring(err2):find("already exists") ~= nil, err2) + + local rows3 = sqlite.fetch_all("SELECT v FROM vt038 ORDER BY v") + check("original table preserved after failed re-create", #rows3 == 2 and tonumber(rows3[1].v) == 1 and tonumber(rows3[2].v) == 2) + + sqlite.run_sql("DROP TABLE IF EXISTS TEMP.vt038") + sqlite.make_vtable("vt038", { + columns = {"v"}, + rows = {{7}} + }) + local rows4 = sqlite.fetch_all("SELECT v FROM vt038") + check("re-create works after explicit drop", #rows4 == 1 and tonumber(rows4[1].v) == 7) + + sqlite.run_sql("DROP TABLE IF EXISTS TEMP.vt038") + return table.concat(out, "\n") +'); diff --git a/sql/vtables/input_039.sql b/sql/vtables/input_039.sql new file mode 100644 index 0000000..33909a2 --- /dev/null +++ b/sql/vtables/input_039.sql @@ -0,0 +1,43 @@ +select load_extension('./libsqlite_plugin_lj'); + +-- Initialize list_iterator in vtable VM scope +SELECT * +FROM L(' + _G.list_iterator = function(t) + local i = 0 + local n = #t + return function () + i = i + 1 + if i <= n then return t[i] end + end + end + return function() return nil end +'); + +select L(' + sqlite.config.use_traceback = 0 + local ok, err = pcall(function() + sqlite.fetch_first([[ + SELECT r0 + FROM L10('' + local done = false + return function() + if done then return nil end + done = true + local row = setmetatable({}, { + __index = function(_, k) + local nested = sqlite.fetch_first("SELECT value FROM L(''''return list_iterator({42})'''')") + return tonumber(nested.value) + end + }) + return row + end + '') + ]]) + end) + + local msg = tostring(err) + local blocked = (not ok) + and msg:find("vtable cannot be used from nested context") ~= nil + return (blocked and "PASS" or "FAIL") .. " xColumn re-entry blocked" +'); diff --git a/src/plugin.c b/src/plugin.c index 07055b4..8570328 100644 --- a/src/plugin.c +++ b/src/plugin.c @@ -3,109 +3,1142 @@ #include #include +#include + #include +#include +#include +#include +#include +#if defined(_WIN32) +#include +#else +#include +#endif -static lua_State *L = NULL; -static int extension_init_ref = 0; -static int extension_deinit_ref = 0; +#if defined(_MSC_VER) +#define ATTR_HIDDEN +#define ATTR_CONSTRUCTOR +#define ATTR_DESTRUCTOR +#define EXT_EXPORT __declspec(dllexport) +#else +#define ATTR_HIDDEN __attribute__((visibility("hidden"))) +#define ATTR_CONSTRUCTOR __attribute__((constructor)) +#define ATTR_DESTRUCTOR __attribute__((destructor)) +#define EXT_EXPORT +#endif -void __attribute__((visibility("hidden"))) checkLuaError(int status); +int ATTR_HIDDEN checkLuaError(lua_State* L, int status); -void __attribute__((constructor)) before_main() -{ - L = lua_open(); - //LUAJIT_VERSION_SYM(); - lua_gc(L, LUA_GCSTOP, 0); - luaL_openlibs(L); - lua_gc(L, LUA_GCRESTART, -1); +static int loadBytecodeObject(lua_State* L, const unsigned char* bytecode, size_t size, const char* name) { + if (luaL_loadbuffer(L, (const char*)bytecode, size, name) != LUA_OK) { + const char* err = lua_tostring(L, -1); + printf("Load error: %s\n", err); + lua_pop(L, 1); + return -1; + } + + if (lua_pcall(L, 0, 1, 0) != LUA_OK) { + const char* err = lua_tostring(L, -1); + printf("Runtime error: %s\n", err); + lua_pop(L, 1); + return -1; + } + + lua_setglobal(L, name); + lua_settop(L, 0); + + return 0; } -void __attribute__((destructor)) after_main(); -void after_main() -{ - lua_rawgeti(L, LUA_REGISTRYINDEX, extension_deinit_ref); - int status = lua_pcall(L, 0, 0, 0); - checkLuaError(status); +static const sqlite3_api_routines *sqlite3_api; - lua_close(L); +static lua_State *L_shared; +static int bridge_ready = 0; +static char bridge_init_error[256] = {0}; + +static void set_bridge_error(const char* msg) { + if (!msg || bridge_init_error[0] != '\0') { + return; + } + snprintf(bridge_init_error, sizeof(bridge_init_error), "%s", msg); } +typedef struct FunctionContext { + void (*fn_ptr)(sqlite3_context *ctx, int argc, sqlite3_value **argv); + void (*fn_step_ptr)(sqlite3_context *ctx, int argc, sqlite3_value **argv); + void (*fn_final_ptr)(sqlite3_context *ctx); + void (*fn_destroy_ptr)(void*); + int64_t udata; + int allowed_nested; +} FunctionContext; -void checkLuaError(int status) -{ - if (status) - { - printf( "Lua error %d\n", status); - if( status == LUA_ERRRUN) { - if (lua_type(L, -1) == LUA_TSTRING){ - printf("Error: %s", lua_tostring((L), -1)); - lua_pop(L, lua_gettop(L)); - - }else { - printf("Lua error: [object]"); - } - } else if (status == LUA_ERRMEM) { - printf("%s %s","Memory error:",lua_tostring(L, -1)); - } else if (status == LUA_ERRERR) { - printf("%s %s","Error:",lua_tostring(L, -1)); - } +typedef int (*set_object_cb_t)(const char* data, size_t len); +typedef const char* (*get_object_cb_t)(int id, size_t* len); +typedef int (*set_fn_context_cb_t)(const FunctionContext* ctx); +typedef const FunctionContext* (*get_fn_context_cb_t)(int64_t id); +typedef struct { + set_object_cb_t set_object; + get_object_cb_t get_object; + set_fn_context_cb_t set_fn_context; + get_fn_context_cb_t get_fn_context; +} SharedBridge; +static SharedBridge bridge = {0}; + +#if defined(_WIN32) +typedef DWORD conn_thread_id_t; +static conn_thread_id_t current_thread_id(void) { + return GetCurrentThreadId(); +} +static int thread_id_equal(conn_thread_id_t a, conn_thread_id_t b) { + return a == b; +} +#else +typedef pthread_t conn_thread_id_t; +static conn_thread_id_t current_thread_id(void) { + return pthread_self(); +} +static int thread_id_equal(conn_thread_id_t a, conn_thread_id_t b) { + return pthread_equal(a, b); +} +#endif + +static const char* THREAD_GUARD_ERR = "sqlite_plugin_lj: connection used from multiple threads"; +static const char* VM_INIT_ERR = "sqlite_plugin_lj: Lua VM init failed or call depth exceeded"; +static const char* FN_CONTEXT_ERR = "sqlite_plugin_lj: function context callback is null"; + +#if defined(_WIN32) +static INIT_ONCE state_lock_once = INIT_ONCE_STATIC_INIT; +static CRITICAL_SECTION state_lock_cs; +static BOOL CALLBACK init_state_lock(PINIT_ONCE once, PVOID param, PVOID *ctx) { + (void)once; + (void)param; + (void)ctx; + InitializeCriticalSection(&state_lock_cs); + return TRUE; +} +static void state_lock(void) { + InitOnceExecuteOnce(&state_lock_once, init_state_lock, NULL, NULL); + EnterCriticalSection(&state_lock_cs); +} +static void state_unlock(void) { + LeaveCriticalSection(&state_lock_cs); +} +#else +static pthread_mutex_t state_lock_mutex = PTHREAD_MUTEX_INITIALIZER; +static void state_lock(void) { + pthread_mutex_lock(&state_lock_mutex); +} +static void state_unlock(void) { + pthread_mutex_unlock(&state_lock_mutex); +} +#endif + +static sqlite3_mutex* global_mutex(void) { + if (!sqlite3_api || !sqlite3_api->mutex_alloc) { + return NULL; + } + return sqlite3_api->mutex_alloc(SQLITE_MUTEX_STATIC_MAIN); +} + +static void global_lock(void) { + sqlite3_mutex *m = global_mutex(); + if (m && sqlite3_api->mutex_enter) { + sqlite3_api->mutex_enter(m); + } +} + +static void global_unlock(void) { + sqlite3_mutex *m = global_mutex(); + if (m && sqlite3_api->mutex_leave) { + sqlite3_api->mutex_leave(m); + } +} + +static int pushFunctionContext(FunctionContext ctx) { + if (!bridge.set_fn_context) { + return 0; + } + global_lock(); + int rc = bridge.set_fn_context(&ctx); + global_unlock(); + return rc; +} + +static const FunctionContext* getFunctionContext(int64_t index) { + if (!bridge.get_fn_context) { + return NULL; + } + global_lock(); + const FunctionContext* rc = bridge.get_fn_context(index); + global_unlock(); + return rc; +} + +static int bridge_set_object_locked(const char* data, size_t len) { + if (!bridge.set_object) { + return 0; + } + global_lock(); + int rc = bridge.set_object(data, len); + global_unlock(); + return rc; +} + +static const char* bridge_get_object_locked(int id, size_t* len) { + if (!bridge.get_object) { + return NULL; } + global_lock(); + const char* rc = bridge.get_object(id, len); + global_unlock(); + return rc; } -typedef struct sqlite3 sqlite3; -typedef struct sqlite3_api_routines sqlite3_api_routines; +#define SAVED_VM 10 typedef struct LJFunctionData { sqlite3 * db; char ** msg; const sqlite3_api_routines *api; + void (*callback)(sqlite3_context *ctx, int argc, sqlite3_value **argv); + void (*cb_context_fn)(sqlite3_context *ctx, int argc, sqlite3_value **argv); + void (*cb_context_step_fn)(sqlite3_context *ctx, int argc, sqlite3_value **argv); + void (*cb_context_final_fn)(sqlite3_context *ctx); + void (*cb_context_destroy_fn)(void*); + void (*sqlite_return_int_cb)(sqlite3_context *ctx, int argc, sqlite3_value **argv); + void (*sqlite_return_text_cb)(sqlite3_context *ctx, int argc, sqlite3_value **argv); + void (*sqlite_free_cb)(void*); + + // context functions + int (*pushFunctionContext)(FunctionContext ctx); + const FunctionContext* (*getFunctionContext)(int64_t index); + //shared bridge functions + int (*set_object)(const char* data, size_t len); + const char* (*get_object)(int id, size_t* len); + + + int call_depth; + int is_vtable_vm; + int conn_slot; + // lua write + void (*caller_fn)(sqlite3_context *ctx, int argc, sqlite3_value **argv); + + // vtable callbacks - Lua implementations (set by Lua) + int (*vtab_xOpen_lua)(sqlite3_vtab*, sqlite3_vtab_cursor**); + int (*vtab_xClose_lua)(sqlite3_vtab_cursor*); + int (*vtab_xFilter_lua)(sqlite3_vtab_cursor*, int, const char*, int, sqlite3_value**); + int (*vtab_xNext_lua)(sqlite3_vtab_cursor*); + int (*vtab_xEof_lua)(sqlite3_vtab_cursor*); + int (*vtab_xColumn_lua)(sqlite3_vtab_cursor*, sqlite3_context*, int); + int (*vtab_xRowid_lua)(sqlite3_vtab_cursor*, sqlite3_int64*); + // vtable callbacks - C wrappers (provided by C, used by Lua module) + int (*cb_vtab_xOpen)(sqlite3_vtab*, sqlite3_vtab_cursor**); + int (*cb_vtab_xClose)(sqlite3_vtab_cursor*); + int (*cb_vtab_xFilter)(sqlite3_vtab_cursor*, int, const char*, int, sqlite3_value**); + int (*cb_vtab_xNext)(sqlite3_vtab_cursor*); + int (*cb_vtab_xEof)(sqlite3_vtab_cursor*); + int (*cb_vtab_xColumn)(sqlite3_vtab_cursor*, sqlite3_context*, int); + int (*cb_vtab_xRowid)(sqlite3_vtab_cursor*, sqlite3_int64*); } LJFunctionData; +typedef struct LJFunctionArgs { + sqlite3_context *ctx; + int argc; + sqlite3_value **argv; +} LJFunctionArgs; + +typedef struct Worker { + lua_State *L; + int extension_init_ref; + int extension_deinit_ref; + int extension_call_ref; + LJFunctionData* udata; +} Worker; + +typedef struct ConnState { + sqlite3 *db; + int slot; + bool cleanup_registered; + bool owner_thread_set; + conn_thread_id_t owner_thread_id; + Worker vm_stack[SAVED_VM]; + int call_depth; + Worker vtable_vm; + bool vtable_vm_busy; + bool vtable_vm_initialized; +} ConnState; -#define SQLITE_OK 0 -#define SQLITE_ERROR 1 +#define MAX_CONN_STATES 32 +static ConnState conn_states[MAX_CONN_STATES] = {0}; + +typedef struct lua_vtab_local { + sqlite3_vtab base; + uint32_t index; + int32_t conn_slot; +} lua_vtab_local; + +static ConnState* find_conn_state_unlocked(sqlite3 *db); +static ConnState* get_or_create_conn_state(sqlite3 *db); +static ConnState* find_conn_state_by_slot(int slot); +static ConnState* conn_state_from_context(sqlite3_context *ctx); +static ConnState* conn_state_from_vtab(sqlite3_vtab *pVtab); +static void cleanup_conn_state(ConnState *state); +static void sqlite_conn_cleanup_destroy_cb(void *p); +static void sqlite_conn_cleanup_noop_fn(sqlite3_context *ctx, int argc, sqlite3_value **argv); +static bool claim_or_validate_owner_thread(ConnState *state); +Worker ATTR_HIDDEN push_vm(ConnState *state); +void ATTR_HIDDEN pop_vm(ConnState *state, Worker w); + +static ConnState* find_conn_state_unlocked(sqlite3 *db) { + if (!db) { + return NULL; + } + for (int i = 0; i < MAX_CONN_STATES; ++i) { + if (conn_states[i].db == db) { + return &conn_states[i]; + } + } + return NULL; +} + +static ConnState* get_or_create_conn_state(sqlite3 *db) { + state_lock(); + ConnState *found = find_conn_state_unlocked(db); + if (found) { + state_unlock(); + return found; + } + for (int i = 0; i < MAX_CONN_STATES; ++i) { + if (conn_states[i].db == NULL) { + memset(&conn_states[i], 0, sizeof(conn_states[i])); + conn_states[i].db = db; + conn_states[i].slot = i + 1; + state_unlock(); + return &conn_states[i]; + } + } + state_unlock(); + return NULL; +} + +static ConnState* find_conn_state_by_slot(int slot) { + state_lock(); + if (slot <= 0 || slot > MAX_CONN_STATES) { + state_unlock(); + return NULL; + } + ConnState *state = &conn_states[slot - 1]; + if (state->db == NULL) { + state_unlock(); + return NULL; + } + state_unlock(); + return state; +} + +static ConnState* conn_state_from_context(sqlite3_context *ctx) { + if (!ctx || !sqlite3_api || !sqlite3_api->context_db_handle) { + return NULL; + } + sqlite3 *db = sqlite3_api->context_db_handle(ctx); + state_lock(); + ConnState *state = find_conn_state_unlocked(db); + state_unlock(); + return state; +} + +static ConnState* conn_state_from_vtab(sqlite3_vtab *pVtab) { + if (!pVtab) { + return NULL; + } + lua_vtab_local *local_vtab = (lua_vtab_local*)pVtab; + return find_conn_state_by_slot(local_vtab->conn_slot); +} + +static bool claim_or_validate_owner_thread(ConnState *state) { + if (!state) { + return false; + } + conn_thread_id_t tid = current_thread_id(); + bool ok = false; + state_lock(); + if (!state->owner_thread_set) { + state->owner_thread_id = tid; + state->owner_thread_set = true; + ok = true; + } else { + ok = thread_id_equal(state->owner_thread_id, tid) != 0; + } + state_unlock(); + return ok; +} + +static void cleanup_conn_state(ConnState *state) { + Worker vm_copy[SAVED_VM] = {0}; + Worker vtable_copy = {0}; + state_lock(); + if (!state || !state->db) { + state_unlock(); + return; + } + for (int j = 0; j < SAVED_VM; ++j) { + vm_copy[j] = state->vm_stack[j]; + memset(&state->vm_stack[j], 0, sizeof(state->vm_stack[j])); + } + if (state->vtable_vm.L) { + vtable_copy = state->vtable_vm; + memset(&state->vtable_vm, 0, sizeof(state->vtable_vm)); + } + int slot = state->slot; + memset(state, 0, sizeof(*state)); + state->slot = slot; + state_unlock(); -extern int sqlite3_extension_init( + for (int j = 0; j < SAVED_VM; ++j) { + Worker *w = &vm_copy[j]; + if (!w->L) { + continue; + } + if (w->extension_deinit_ref != LUA_NOREF) { + lua_rawgeti(w->L, LUA_REGISTRYINDEX, w->extension_deinit_ref); + int status = lua_pcall(w->L, 0, 0, 0); + checkLuaError(w->L, status); + } + free(w->udata); + w->udata = NULL; + lua_close(w->L); + } + if (vtable_copy.L) { + if (vtable_copy.extension_deinit_ref != LUA_NOREF) { + lua_rawgeti(vtable_copy.L, LUA_REGISTRYINDEX, vtable_copy.extension_deinit_ref); + int status = lua_pcall(vtable_copy.L, 0, 0, 0); + checkLuaError(vtable_copy.L, status); + } + free(vtable_copy.udata); + vtable_copy.udata = NULL; + lua_close(vtable_copy.L); + } +} + +static void sqlite_conn_cleanup_destroy_cb(void *p) { + cleanup_conn_state((ConnState*)p); +} + +static void sqlite_conn_cleanup_noop_fn(sqlite3_context *ctx, int argc, sqlite3_value **argv) { + (void)argc; + (void)argv; + sqlite3_api->result_null(ctx); +} + +static void sqlite_luajit_callback(sqlite3_context *ctx, int argc, sqlite3_value **argv) { + ConnState *state = conn_state_from_context(ctx); + if (!state) { + sqlite3_api->result_error(ctx, "sqlite_plugin_lj: connection state missing", -1); + return; + } + if (!claim_or_validate_owner_thread(state)) { + sqlite3_api->result_error(ctx, THREAD_GUARD_ERR, -1); + return; + } + Worker w = push_vm(state); + if (!w.L) { + // Handle error + sqlite3_api->result_error(ctx, VM_INIT_ERR, -1); + pop_vm(state, w); + return; + } + + w.udata->caller_fn(ctx, argc, argv); + + pop_vm(state, w); +} + +static void sqlite_luajit_callback_context_fn(sqlite3_context *ctx, int argc, sqlite3_value **argv) { + ConnState *state = conn_state_from_context(ctx); + if (!state) { + sqlite3_api->result_error(ctx, "sqlite_plugin_lj: connection state missing", -1); + return; + } + if (!claim_or_validate_owner_thread(state)) { + sqlite3_api->result_error(ctx, THREAD_GUARD_ERR, -1); + return; + } + Worker w = push_vm(state); + if (!w.L) { + sqlite3_api->result_error(ctx, VM_INIT_ERR, -1); + pop_vm(state, w); + return; + } + const FunctionContext* fn_context = getFunctionContext((int64_t)sqlite3_api->user_data(ctx)); + if (!fn_context) { + pop_vm(state, w); + return; + } + if (!fn_context->fn_ptr) { + sqlite3_api->result_error(ctx, FN_CONTEXT_ERR, -1); + pop_vm(state, w); + return; + } + if (!fn_context->allowed_nested && state && state->call_depth > 1) { + // Handle error + sqlite3_api->result_error(ctx, "max call depth[1] exceeded", -1); + pop_vm(state, w); + return; + } + + + fn_context->fn_ptr(ctx, argc, argv); + + pop_vm(state, w); +} + +static void sqlite_luajit_callback_context_step_fn(sqlite3_context *ctx, int argc, sqlite3_value **argv) { + ConnState *state = conn_state_from_context(ctx); + if (!state) { + sqlite3_api->result_error(ctx, "sqlite_plugin_lj: connection state missing", -1); + return; + } + if (!claim_or_validate_owner_thread(state)) { + sqlite3_api->result_error(ctx, THREAD_GUARD_ERR, -1); + return; + } + Worker w = push_vm(state); + if (!w.L) { + sqlite3_api->result_error(ctx, VM_INIT_ERR, -1); + pop_vm(state, w); + return; + } + const FunctionContext* fn_context = getFunctionContext((int64_t)sqlite3_api->user_data(ctx)); + if (!fn_context) { + pop_vm(state, w); + return; + } + if (!fn_context->fn_step_ptr) { + sqlite3_api->result_error(ctx, FN_CONTEXT_ERR, -1); + pop_vm(state, w); + return; + } + + if (!fn_context->allowed_nested && state && state->call_depth > 1) { + sqlite3_api->result_error(ctx, "max call depth[1] exceeded", -1); + pop_vm(state, w); + return; + } + + fn_context->fn_step_ptr(ctx, argc, argv); + pop_vm(state, w); +} + +static void sqlite_luajit_callback_context_final_fn(sqlite3_context *ctx) { + ConnState *state = conn_state_from_context(ctx); + if (!state) { + sqlite3_api->result_error(ctx, "sqlite_plugin_lj: connection state missing", -1); + return; + } + if (!claim_or_validate_owner_thread(state)) { + sqlite3_api->result_error(ctx, THREAD_GUARD_ERR, -1); + return; + } + Worker w = push_vm(state); + if (!w.L) { + sqlite3_api->result_error(ctx, VM_INIT_ERR, -1); + pop_vm(state, w); + return; + } + const FunctionContext* fn_context = getFunctionContext((int64_t)sqlite3_api->user_data(ctx)); + if (!fn_context) { + pop_vm(state, w); + return; + } + if (!fn_context->fn_final_ptr) { + sqlite3_api->result_error(ctx, FN_CONTEXT_ERR, -1); + pop_vm(state, w); + return; + } + + if (!fn_context->allowed_nested && state && state->call_depth > 1) { + sqlite3_api->result_error(ctx, "max call depth[1] exceeded", -1); + pop_vm(state, w); + return; + } + + fn_context->fn_final_ptr(ctx); + pop_vm(state, w); +} + +static void sqlite_luajit_callback_context_destroy_fn(void *unused) { + (void)unused; + // to be implemented if needed +} + + +static void sqlite_return_int_cb(sqlite3_context *ctx, int argc, sqlite3_value **argv) { + (void)argc; + (void)argv; + int64_t saved_constant = (int64_t) sqlite3_api->user_data(ctx); + sqlite3_api->result_int64(ctx, saved_constant); +} + +static void sqlite_return_text_cb(sqlite3_context *ctx, int argc, sqlite3_value **argv) { + (void)argc; + (void)argv; + const char* saved_constant = (const char*) sqlite3_api->user_data(ctx); + if (saved_constant == NULL) { + sqlite3_api->result_null(ctx); + return; + } + sqlite3_api->result_text(ctx, saved_constant, -1, SQLITE_STATIC); +} + +static void sqlite_user_data_free_cb(void* user_data) { + sqlite3_api->free(user_data); +} + + +// Forward declaration for get_vm with vtable flag +static Worker get_vm_internal(ConnState *state, int is_vtable); + +// Initialize the dedicated vtable VM +static void ensure_vtable_vm(ConnState *state) { + if (!state) { + return; + } + if (!state->vtable_vm_initialized) { + state->vtable_vm = get_vm_internal(state, 1); + state->vtable_vm_initialized = state->vtable_vm.L != NULL; + } +} + +// C wrapper for vtable xOpen +static int cb_vtab_xOpen(sqlite3_vtab* pVtab, sqlite3_vtab_cursor** ppCursor) { + ConnState *state = conn_state_from_vtab(pVtab); + if (!state) { + pVtab->zErrMsg = sqlite3_api->mprintf("vtable connection state missing"); + return SQLITE_ERROR; + } + if (!claim_or_validate_owner_thread(state)) { + pVtab->zErrMsg = sqlite3_api->mprintf("%s", THREAD_GUARD_ERR); + return SQLITE_ERROR; + } + if (state->vtable_vm_busy) { + pVtab->zErrMsg = sqlite3_api->mprintf("vtable cannot be used from nested context (vtable VM is busy)"); + return SQLITE_ERROR; + } + + ensure_vtable_vm(state); + if (!state->vtable_vm.L || !state->vtable_vm.udata || !state->vtable_vm.udata->vtab_xOpen_lua) { + pVtab->zErrMsg = sqlite3_api->mprintf("vtable VM not properly initialized"); + return SQLITE_ERROR; + } + + state->vtable_vm_busy = true; + int result = state->vtable_vm.udata->vtab_xOpen_lua(pVtab, ppCursor); + state->vtable_vm_busy = false; + + return result; +} + +// C wrapper for vtable xClose +static int cb_vtab_xClose(sqlite3_vtab_cursor* cursor) { + ConnState *state = conn_state_from_vtab(cursor ? cursor->pVtab : NULL); + if (!state) { + if (cursor && cursor->pVtab) { + cursor->pVtab->zErrMsg = sqlite3_api->mprintf("vtable connection state missing"); + } + return SQLITE_ERROR; + } + if (!claim_or_validate_owner_thread(state)) { + cursor->pVtab->zErrMsg = sqlite3_api->mprintf("%s", THREAD_GUARD_ERR); + return SQLITE_ERROR; + } + if (state->vtable_vm_busy) { + cursor->pVtab->zErrMsg = sqlite3_api->mprintf("vtable cannot be used from nested context (vtable VM is busy)"); + return SQLITE_ERROR; + } + + ensure_vtable_vm(state); + if (!state->vtable_vm.L || !state->vtable_vm.udata || !state->vtable_vm.udata->vtab_xClose_lua) { + cursor->pVtab->zErrMsg = sqlite3_api->mprintf("vtable VM not properly initialized"); + return SQLITE_ERROR; + } + + state->vtable_vm_busy = true; + int result = state->vtable_vm.udata->vtab_xClose_lua(cursor); + state->vtable_vm_busy = false; + + return result; +} + +// C wrapper for vtable xFilter - checks busy flag before calling Lua +static int cb_vtab_xFilter(sqlite3_vtab_cursor* cursor, int idxNum, const char* idxStr, int argc, sqlite3_value** argv) { + ConnState *state = conn_state_from_vtab(cursor ? cursor->pVtab : NULL); + if (!state) { + if (cursor && cursor->pVtab) { + cursor->pVtab->zErrMsg = sqlite3_api->mprintf("vtable connection state missing"); + } + return SQLITE_ERROR; + } + if (!claim_or_validate_owner_thread(state)) { + cursor->pVtab->zErrMsg = sqlite3_api->mprintf("%s", THREAD_GUARD_ERR); + return SQLITE_ERROR; + } + if (state->vtable_vm_busy) { + cursor->pVtab->zErrMsg = sqlite3_api->mprintf("vtable cannot be used from nested context (vtable VM is busy)"); + return SQLITE_ERROR; + } + + ensure_vtable_vm(state); + if (!state->vtable_vm.L || !state->vtable_vm.udata || !state->vtable_vm.udata->vtab_xFilter_lua) { + cursor->pVtab->zErrMsg = sqlite3_api->mprintf("vtable VM not properly initialized"); + return SQLITE_ERROR; + } + + state->vtable_vm_busy = true; + int result = state->vtable_vm.udata->vtab_xFilter_lua(cursor, idxNum, idxStr, argc, argv); + state->vtable_vm_busy = false; + + return result; +} + +// C wrapper for vtable xNext - checks busy flag before calling Lua +static int cb_vtab_xNext(sqlite3_vtab_cursor* cursor) { + ConnState *state = conn_state_from_vtab(cursor ? cursor->pVtab : NULL); + if (!state) { + if (cursor && cursor->pVtab) { + cursor->pVtab->zErrMsg = sqlite3_api->mprintf("vtable connection state missing"); + } + return SQLITE_ERROR; + } + if (!claim_or_validate_owner_thread(state)) { + cursor->pVtab->zErrMsg = sqlite3_api->mprintf("%s", THREAD_GUARD_ERR); + return SQLITE_ERROR; + } + if (state->vtable_vm_busy) { + cursor->pVtab->zErrMsg = sqlite3_api->mprintf("vtable cannot be used from nested context (vtable VM is busy)"); + return SQLITE_ERROR; + } + + ensure_vtable_vm(state); + if (!state->vtable_vm.L || !state->vtable_vm.udata || !state->vtable_vm.udata->vtab_xNext_lua) { + cursor->pVtab->zErrMsg = sqlite3_api->mprintf("vtable VM not properly initialized"); + return SQLITE_ERROR; + } + + state->vtable_vm_busy = true; + int result = state->vtable_vm.udata->vtab_xNext_lua(cursor); + state->vtable_vm_busy = false; + + return result; +} + +// C wrapper for vtable xEof +static int cb_vtab_xEof(sqlite3_vtab_cursor* cursor) { + ConnState *state = conn_state_from_vtab(cursor ? cursor->pVtab : NULL); + if (!state) { + return 1; + } + if (!claim_or_validate_owner_thread(state)) { + return 1; + } + // xEof doesn't set error message, just returns true/false (non-zero = EOF) + if (state->vtable_vm_busy) { + return 1; // return EOF to stop iteration + } + + ensure_vtable_vm(state); + if (!state->vtable_vm.L || !state->vtable_vm.udata || !state->vtable_vm.udata->vtab_xEof_lua) { + return 1; // return EOF + } + + state->vtable_vm_busy = true; + int result = state->vtable_vm.udata->vtab_xEof_lua(cursor); + state->vtable_vm_busy = false; + + return result; +} + +// C wrapper for vtable xColumn +static int cb_vtab_xColumn(sqlite3_vtab_cursor* cursor, sqlite3_context* ctx, int col) { + ConnState *state = conn_state_from_vtab(cursor ? cursor->pVtab : NULL); + if (!state) { + sqlite3_api->result_error(ctx, "vtable connection state missing", -1); + return SQLITE_ERROR; + } + if (!claim_or_validate_owner_thread(state)) { + sqlite3_api->result_error(ctx, THREAD_GUARD_ERR, -1); + return SQLITE_ERROR; + } + if (state->vtable_vm_busy) { + sqlite3_api->result_error(ctx, "vtable cannot be used from nested context", -1); + return SQLITE_ERROR; + } + + ensure_vtable_vm(state); + if (!state->vtable_vm.L || !state->vtable_vm.udata || !state->vtable_vm.udata->vtab_xColumn_lua) { + sqlite3_api->result_error(ctx, "vtable VM not properly initialized", -1); + return SQLITE_ERROR; + } + + state->vtable_vm_busy = true; + int result = state->vtable_vm.udata->vtab_xColumn_lua(cursor, ctx, col); + state->vtable_vm_busy = false; + return result; +} + +// C wrapper for vtable xRowid +static int cb_vtab_xRowid(sqlite3_vtab_cursor* cursor, sqlite3_int64* pRowid) { + ConnState *state = conn_state_from_vtab(cursor ? cursor->pVtab : NULL); + if (!state) { + if (cursor && cursor->pVtab) { + cursor->pVtab->zErrMsg = sqlite3_api->mprintf("vtable connection state missing"); + } + return SQLITE_ERROR; + } + if (!claim_or_validate_owner_thread(state)) { + cursor->pVtab->zErrMsg = sqlite3_api->mprintf("%s", THREAD_GUARD_ERR); + return SQLITE_ERROR; + } + if (state->vtable_vm_busy) { + cursor->pVtab->zErrMsg = sqlite3_api->mprintf("vtable cannot be used from nested context (vtable VM is busy)"); + return SQLITE_ERROR; + } + + ensure_vtable_vm(state); + if (!state->vtable_vm.L || !state->vtable_vm.udata || !state->vtable_vm.udata->vtab_xRowid_lua) { + cursor->pVtab->zErrMsg = sqlite3_api->mprintf("vtable VM not properly initialized"); + return SQLITE_ERROR; + } + + state->vtable_vm_busy = true; + int result = state->vtable_vm.udata->vtab_xRowid_lua(cursor, pRowid); + state->vtable_vm_busy = false; + return result; +} + +#include "sqlite_capi.h" +#include "sqlite_lj.h" +static Worker get_vm_internal(ConnState *state, int is_vtable) { + Worker w = (Worker){0}; + w.extension_init_ref = LUA_NOREF; + w.extension_deinit_ref = LUA_NOREF; + w.extension_call_ref = LUA_NOREF; + if (!state) { + return w; + } + w.L = lua_open(); + if (!w.L) { + return w; + } + //LUAJIT_VERSION_SYM(); + lua_gc(w.L, LUA_GCSTOP, 0); + luaL_openlibs(w.L); + lua_gc(w.L, LUA_GCRESTART, -1); + + if (loadBytecodeObject( + w.L, + luaJIT_BC_sqlite_capi, + luaJIT_BC_sqlite_capi_SIZE, "sqlite_capi") != 0) + { + lua_close(w.L); + w.L = NULL; + return w; + } + + if (loadBytecodeObject( + w.L, + luaJIT_BC_sqlite_lj, + luaJIT_BC_sqlite_lj_SIZE, "sqlite") != 0) + { + lua_close(w.L); + w.L = NULL; + return w; + } + + lua_getglobal(w.L, "sqlite"); + + { + lua_getfield(w.L, 1, "extension_init"); + w.extension_init_ref = luaL_ref(w.L, LUA_REGISTRYINDEX); + lua_getfield(w.L, 1, "extension_deinit"); + w.extension_deinit_ref = luaL_ref(w.L, LUA_REGISTRYINDEX); + lua_getfield(w.L, 1, "extension_call"); + w.extension_call_ref = luaL_ref(w.L, LUA_REGISTRYINDEX); + lua_settop(w.L, 0); + + + + lua_rawgeti(w.L, LUA_REGISTRYINDEX, w.extension_init_ref); + + w.udata = (LJFunctionData*)calloc(1, sizeof(LJFunctionData)); + if (!w.udata) { + lua_close(w.L); + w.L = NULL; + return w; + } + + w.udata->db = state->db; + w.udata->api = sqlite3_api; + w.udata->callback = sqlite_luajit_callback; + w.udata->cb_context_fn = sqlite_luajit_callback_context_fn; + w.udata->cb_context_step_fn = sqlite_luajit_callback_context_step_fn; + w.udata->cb_context_final_fn = sqlite_luajit_callback_context_final_fn; + w.udata->cb_context_destroy_fn = sqlite_luajit_callback_context_destroy_fn; + w.udata->sqlite_return_int_cb = sqlite_return_int_cb; + w.udata->sqlite_return_text_cb = sqlite_return_text_cb; + w.udata->sqlite_free_cb = sqlite_user_data_free_cb; + + // bridge functions + w.udata->set_object = bridge_set_object_locked; + w.udata->get_object = bridge_get_object_locked; + + w.udata->pushFunctionContext = pushFunctionContext; + w.udata->getFunctionContext = getFunctionContext; + + w.udata->call_depth = state->call_depth - 1; + w.udata->is_vtable_vm = is_vtable; + w.udata->conn_slot = state->slot; + w.udata->caller_fn = NULL; + + // vtable callbacks - Lua implementations set by Lua, C wrappers provided here + w.udata->vtab_xOpen_lua = NULL; + w.udata->vtab_xClose_lua = NULL; + w.udata->vtab_xFilter_lua = NULL; + w.udata->vtab_xNext_lua = NULL; + w.udata->vtab_xEof_lua = NULL; + w.udata->vtab_xColumn_lua = NULL; + w.udata->vtab_xRowid_lua = NULL; + w.udata->cb_vtab_xOpen = cb_vtab_xOpen; + w.udata->cb_vtab_xClose = cb_vtab_xClose; + w.udata->cb_vtab_xFilter = cb_vtab_xFilter; + w.udata->cb_vtab_xNext = cb_vtab_xNext; + w.udata->cb_vtab_xEof = cb_vtab_xEof; + w.udata->cb_vtab_xColumn = cb_vtab_xColumn; + w.udata->cb_vtab_xRowid = cb_vtab_xRowid; + + lua_pushlightuserdata(w.L, w.udata); + + int status = lua_pcall(w.L, 1, 0, 0); + if (!checkLuaError(w.L, status)) { + free(w.udata); + w.udata = NULL; + lua_close(w.L); + w.L = NULL; + w.extension_init_ref = LUA_NOREF; + w.extension_deinit_ref = LUA_NOREF; + w.extension_call_ref = LUA_NOREF; + return w; + } + lua_settop(w.L, 0); + } + return w; +} + +Worker push_vm(ConnState *state) { + Worker empty = (Worker){0}; + if (!state) { + return empty; + } + ++state->call_depth; + if(state->call_depth > SAVED_VM) { + return get_vm_internal(state, 0); + } else { + if (!state->vm_stack[state->call_depth-1].L) { + state->vm_stack[state->call_depth-1] = get_vm_internal(state, 0); + } + return state->vm_stack[state->call_depth-1]; + } + +} + +void pop_vm(ConnState *state, Worker w){ + if (!state) { + return; + } + if (!w.L) { + --state->call_depth; + return; + } + + if(state->call_depth > SAVED_VM) { + lua_State* L = w.L; + if (w.extension_deinit_ref != LUA_NOREF) { + lua_rawgeti(L, LUA_REGISTRYINDEX, w.extension_deinit_ref); + int status = lua_pcall(L, 0, 0, 0); + checkLuaError(L, status); + } + free(w.udata); + w.udata = NULL; + + lua_close(L); + } + --state->call_depth; +} + + + +#include "sqlite_shared.h" +void ATTR_CONSTRUCTOR before_main(void) +{ + if (L_shared || bridge_ready) { + return; + } + L_shared = lua_open(); + if (!L_shared) { + set_bridge_error("sqlite_plugin_lj: failed to create shared Lua state"); + return; + } + lua_gc(L_shared, LUA_GCSTOP, 0); + luaL_openlibs(L_shared); + lua_gc(L_shared, LUA_GCRESTART, -1); + + if (loadBytecodeObject( + L_shared, + luaJIT_BC_sqlite_shared, + luaJIT_BC_sqlite_shared_SIZE, "sqlite_shared") != 0) { + set_bridge_error("sqlite_plugin_lj: failed to load sqlite_shared bytecode"); + return; + } + + if (luaL_dostring(L_shared, + "return sqlite_shared\n" + + ) != LUA_OK) { + const char* err = lua_tostring(L_shared, -1); + printf("Lua load error: %s\n", err); + set_bridge_error("sqlite_plugin_lj: failed to load sqlite_shared module"); + lua_pop(L_shared, 1); + return; + } + lua_getfield(L_shared, -1, "init"); // plugin.init + if (!lua_isfunction(L_shared, -1)) { + printf("Error: init is not a function\n"); + set_bridge_error("sqlite_plugin_lj: sqlite_shared.init is not a function"); + lua_pop(L_shared, 2); // remove plugin and whatever is on top + return; + } + + lua_pushlightuserdata(L_shared, &bridge); + + if (lua_pcall(L_shared, 1, 0, 0) != LUA_OK) { + const char* err = lua_tostring(L_shared, -1); + printf("Lua init error: %s\n", err); + set_bridge_error("sqlite_plugin_lj: sqlite_shared.init failed"); + lua_pop(L_shared, 1); + return; + } + + if (bridge.set_fn_context && bridge.get_fn_context) { + bridge_ready = 1; + } else { + set_bridge_error("sqlite_plugin_lj: bridge callbacks not initialized"); + } +} + +void ATTR_DESTRUCTOR after_main(void); +void after_main(void) +{ + for (int i = 0; i < MAX_CONN_STATES; ++i) { + cleanup_conn_state(&conn_states[i]); + } + if (L_shared) { + lua_close(L_shared); + } + +} + + +int checkLuaError(lua_State* L, int status) +{ + if (status == LUA_OK) { + return 1; + } + const char* detail = NULL; + if (L && lua_gettop(L) > 0) { + detail = lua_tostring(L, -1); + } + if (!detail) { + detail = "[non-string lua error]"; + } + printf("Lua error %d: %s\n", status, detail); + if (L && lua_gettop(L) > 0) { + lua_pop(L, 1); + } + return 0; +} + +extern EXT_EXPORT int sqlite3_extension_init( sqlite3 *, char **, const sqlite3_api_routines *pApi); -extern int sqlite3_extension_init( +extern EXT_EXPORT int sqlite3_extension_init( sqlite3 * db, char **msg, const sqlite3_api_routines *api) { - lua_getglobal(L, "require"); - lua_pushstring(L, "sqlite_lj"); - int status = lua_pcall(L, 1, 1, 0); - if (status) - { - checkLuaError(status); + sqlite3_api = api; + ConnState *state = get_or_create_conn_state(db); + if (!state) { + if (msg) { + *msg = sqlite3_api->mprintf("%s", "sqlite_plugin_lj: too many active sqlite connections"); + } return SQLITE_ERROR; } - else - { - lua_getfield(L, 1, "extension_init"); - extension_init_ref = luaL_ref(L, LUA_REGISTRYINDEX); - lua_getfield(L, 1, "extension_deinit"); - extension_deinit_ref = luaL_ref(L, LUA_REGISTRYINDEX); - lua_settop(L, 0); + if (state->vm_stack[0].L || state->vtable_vm.L) { + cleanup_conn_state(state); + state->db = db; + } + +#if defined(_MSC_VER) + /* MSVC does not support GCC constructor attributes; initialize lazily. */ + state_lock(); + if (!L_shared && !bridge_ready) { + before_main(); + } + state_unlock(); +#endif - LJFunctionData* udata; + if (!bridge_ready || !bridge.set_fn_context || !bridge.get_fn_context) { + if (msg) { + const char* detail = bridge_init_error[0] ? bridge_init_error : "sqlite_plugin_lj: shared bridge not initialized"; + *msg = sqlite3_api->mprintf("%s", detail); + } + return SQLITE_ERROR; + } - lua_rawgeti(L, LUA_REGISTRYINDEX, extension_init_ref); + if (!state->cleanup_registered) { + int rc = sqlite3_api->create_function_v2( + db, + "__sqlite_lj_conn_cleanup__", + 0, + SQLITE_UTF8, + state, + sqlite_conn_cleanup_noop_fn, + NULL, + NULL, + sqlite_conn_cleanup_destroy_cb + ); + if (rc != SQLITE_OK) { + if (msg) { + *msg = sqlite3_api->mprintf("%s", sqlite3_api->errmsg(db)); + } + return SQLITE_ERROR; + } + state->cleanup_registered = true; + } - udata = (LJFunctionData*) lua_newuserdata(L, sizeof(LJFunctionData)); - udata->db = db; - udata->msg = msg; - udata->api = api; + pushFunctionContext((FunctionContext){0}); // index 0 is empty context - status = lua_pcall(L, 1, 0, 0); - checkLuaError(status); + Worker w = push_vm(state); + if (!w.L){ + if (msg) { + *msg = sqlite3_api->mprintf("%s", VM_INIT_ERR); + } + return SQLITE_ERROR; } + pop_vm(state, w); return SQLITE_OK; } - - diff --git a/src/sqlite_capi.lua b/src/sqlite_capi.lua index c151461..2099358 100644 --- a/src/sqlite_capi.lua +++ b/src/sqlite_capi.lua @@ -469,19 +469,81 @@ struct sqlite3_index_info { typedef struct lua_vtab { sqlite3_vtab base; /* Base class - must be first */ uint32_t index; /* (this) Derived class data */ + int32_t conn_slot; /* Connection slot bound at xConnect time */ } lua_vtab; +typedef struct LJFunctionData_old { + sqlite3 * db; + char ** msg; + const sqlite3_api_routines *api; +} LJFunctionData_old; + +typedef struct LJFunctionArgs { + sqlite3_context *ctx; + int argc; + sqlite3_value **argv; +} LJFunctionArgs; + +typedef struct { + void (*fn_ptr)(sqlite3_context *ctx, int argc, sqlite3_value **argv); + void (*fn_step_ptr)(sqlite3_context *ctx, int argc, sqlite3_value **argv); + void (*fn_final_ptr)(sqlite3_context *ctx); + void (*fn_destroy_ptr)(void*); + int64_t udata; + int allowed_nested; +} FunctionContext; + typedef struct LJFunctionData { sqlite3 * db; char ** msg; const sqlite3_api_routines *api; + void (*callback)(sqlite3_context *ctx, int argc, sqlite3_value **argv); + void (*cb_context_fn)(sqlite3_context *ctx, int argc, sqlite3_value **argv); + void (*cb_context_step_fn)(sqlite3_context *ctx, int argc, sqlite3_value **argv); + void (*cb_context_final_fn)(sqlite3_context *ctx); + void (*cb_context_destroy_fn)(void*); + void (*sqlite_return_int_cb)(sqlite3_context *ctx, int argc, sqlite3_value **argv); + void (*sqlite_return_text_cb)(sqlite3_context *ctx, int argc, sqlite3_value **argv); + void (*sqlite_free_cb)(void*); + + // context functions + int (*pushFunctionContext)(FunctionContext ctx); + const FunctionContext* (*getFunctionContext)(int64_t index); + //shared bridge functions + int (*set_object)(const char* data, size_t len); + const char* (*get_object)(int id, size_t* len); + + int call_depth; + int is_vtable_vm; + int conn_slot; + // lua write + void (*caller_fn)(sqlite3_context *ctx, int argc, sqlite3_value **argv); + + // vtable callbacks - Lua implementations (set by Lua) + int (*vtab_xOpen_lua)(sqlite3_vtab*, sqlite3_vtab_cursor**); + int (*vtab_xClose_lua)(sqlite3_vtab_cursor*); + int (*vtab_xFilter_lua)(sqlite3_vtab_cursor*, int, const char*, int, sqlite3_value**); + int (*vtab_xNext_lua)(sqlite3_vtab_cursor*); + int (*vtab_xEof_lua)(sqlite3_vtab_cursor*); + int (*vtab_xColumn_lua)(sqlite3_vtab_cursor*, sqlite3_context*, int); + int (*vtab_xRowid_lua)(sqlite3_vtab_cursor*, sqlite_int64*); + // vtable callbacks - C wrappers (provided by C, used by Lua module) + int (*cb_vtab_xOpen)(sqlite3_vtab*, sqlite3_vtab_cursor**); + int (*cb_vtab_xClose)(sqlite3_vtab_cursor*); + int (*cb_vtab_xFilter)(sqlite3_vtab_cursor*, int, const char*, int, sqlite3_value**); + int (*cb_vtab_xNext)(sqlite3_vtab_cursor*); + int (*cb_vtab_xEof)(sqlite3_vtab_cursor*); + int (*cb_vtab_xColumn)(sqlite3_vtab_cursor*, sqlite3_context*, int); + int (*cb_vtab_xRowid)(sqlite3_vtab_cursor*, sqlite_int64*); } LJFunctionData; ]] local defines = { SQLITE_UTF8 = 1, SQLITE_OK = 0, + SQLITE_ERROR = 1, SQLITE_NOMEM = 7, --A malloc() failed + SQLITE_MISUSE = 21, SQLITE_DETERMINISTIC = 0x000000800, SQLITE_DIRECTONLY = 0x000080000, @@ -500,7 +562,8 @@ local defines = { SQLITE_NULL = 5, SQLITE_TEXT = 3, - SQLITE_VTAB_INNOCUOUS = 2 + SQLITE_VTAB_INNOCUOUS = 2, + SQLITE_INDEX_CONSTRAINT_EQ = 2 } local SQLITE = {} diff --git a/src/sqlite_lj.lua b/src/sqlite_lj.lua index baa7f84..46548ff 100644 --- a/src/sqlite_lj.lua +++ b/src/sqlite_lj.lua @@ -1,12 +1,13 @@ local plugin = {} plugin._DESCRIPTION = "LuaJIT FFI sqlite language extension" -plugin._VERSION = "sqlite lj 0.2" +plugin._VERSION = "sqlite lj 0.3" local ffi = require('ffi') -local SQLITE = require('sqlite_capi').SQLITE -local NULL = require('sqlite_capi').NULL +local sqlite_capi = sqlite_capi or require('sqlite_capi') +local SQLITE = sqlite_capi.SQLITE +local NULL = sqlite_capi.NULL local int_t = ffi.typeof("int") local int64_t = ffi.typeof("int64_t") @@ -19,22 +20,37 @@ local vtab_cursor_t = ffi.typeof('lua_vtab_cursor') local bor = require("bit").bor +local plugin_init_data local sqlite_db local sqlite_api local unfinalized_statements = setmetatable({}, { __mode = "k" }) -local close_unfinalized = function () - for stmt, isopen in pairs(unfinalized_statements) do - if not isopen then - return - end +local safe_finalize_stmt = function(stmt_pp) + if stmt_pp == nil then + return SQLITE.OK + end + local stmt_p = stmt_pp[0] + if stmt_p == nil then + return SQLITE.OK + end + stmt_pp[0] = nil + return sqlite_api.finalize(stmt_p) +end - local step_rc = sqlite_api.finalize(ffi.gc(stmt, nil)[0]) - if step_rc ~= SQLITE.OK then - print('close unfinalized statement failed', step_rc, stmt) +local close_unfinalized = function () + local close_error + for stmt, status in pairs(unfinalized_statements) do + if status.isopen and stmt[0] ~= nil then + status.isopen = false + unfinalized_statements[stmt] = nil + local step_rc = safe_finalize_stmt(stmt) + if step_rc ~= SQLITE.OK and close_error == nil then + close_error = 'close unfinalized statement failed: ' .. tostring(status.sql) + end end end + return close_error end local LJError = (function () @@ -73,6 +89,35 @@ public_env.double_t = double_t local config = {} public_env.config = config +local function error_xcall(err) + if type(err) == "table" then + if err.detail == nil then + err.detail = config.use_traceback == 0 and '' or debug.traceback() + end + return err + else + return { message = err, detail = config.use_traceback == 0 and '' or debug.traceback() } + end +end + +local buffer = require("string.buffer") +local function setObject(obj) + local encoded = buffer.encode(obj) + local key = tonumber(plugin_init_data.set_object(encoded, #encoded)) + if not key or key <= 0 then + return error("shared object storage limit exceeded") + end + return key +end + +local function getObject(id) + local size_out = ffi.new("size_t[1]") + local ptr = plugin_init_data.get_object(id, size_out) + if ptr == nil then return nil end + local encoded = ffi.string(ptr, size_out[0]) + return buffer.decode(encoded) +end + local function Storage() local self = {} local map = {} @@ -100,17 +145,190 @@ local function Storage() return self end + + +local function FNStorage() + local self = {} + local map = {} + + self.new_key = function (name, code_text) + if type(code_text) ~= 'string' then + return nil, 'code_text is not a string' + end + local fn, err = self.__make_fn(name, code_text) + if not fn then + return nil, 'Failed to create function: ' .. tostring(err) + end + local key = setObject({name = name, code_text = code_text}) + return tonumber(key) + end + + self.get = function (key) + local key = tonumber(key) + if not map[key] then + local obj = getObject(key) + if obj then + local fn_name = obj.name or tostring(key) + local fn, err = self.__make_fn(fn_name, ffi.string(obj.code_text), -1) + if not fn then + return nil, 'Failed to load function: ' .. tostring(err) + end + map[key] = fn + end + end + return map[key] + end + + self.remove = function (key) + key = tonumber(key) + map[key] = nil + end + + self.__make_fn = function(name, code_text) + if code_text == nil then return end + + local fn_env = {} + setmetatable(fn_env, public_env_mt) + + local f, err = loadstring(code_text, name, "t", fn_env) + if (f) then + local status, res = xpcall(f, error_xcall) + if not status then + local msg = 'Create failed ['.. name .. ']\n' ..tostring(res.message) .. '\n' .. res.detail + return nil, msg + end + + setfenv(res, fn_env) + return res, nil + + else + local msg = 'Create failed ['.. name .. ']\n' .. tostring(err) + return nil, msg + end + end + + return self +end + +local function AggChkStorage() + local self = {} + local map = {} + + self.new_key = function(name, init_text, step_text, final_text, inverse_text) + -- Validate and compile to check for errors early + local init_fn, err = loadstring(init_text, name..':init', "t") + if not init_fn then return nil, err end + local step_fn, err = loadstring(step_text, name..':step', "t") + if not step_fn then return nil, err end + local final_fn, err = loadstring(final_text, name..':final', "t") + if not final_fn then return nil, err end + if inverse_text then + local inverse_fn, err = loadstring(inverse_text, name..':inverse', "t") + if not inverse_fn then return nil, err end + end + + -- Store text in shared memory (include name for better error messages) + local key = setObject({name = name, init = init_text, step = step_text, final = final_text, inverse = inverse_text}) + return tonumber(key) + end + + self.get = function(key) + key = tonumber(key) + if key == nil then + return nil + end + if not map[key] then + local obj = getObject(key) + if obj then + local fn_name = obj.name or tostring(key) + local init_fn, err = loadstring(obj.init, fn_name..':init', "t") + if not init_fn then + error('aggregate load failed ['..fn_name..':init]: '..tostring(err)) + end + local step_fn, err = loadstring(obj.step, fn_name..':step', "t") + if not step_fn then + error('aggregate load failed ['..fn_name..':step]: '..tostring(err)) + end + local final_fn, err = loadstring(obj.final, fn_name..':final', "t") + if not final_fn then + error('aggregate load failed ['..fn_name..':final]: '..tostring(err)) + end + local inverse_fn = nil + if obj.inverse then + inverse_fn, err = loadstring(obj.inverse, fn_name..':inverse', "t") + if not inverse_fn then + error('aggregate load failed ['..fn_name..':inverse]: '..tostring(err)) + end + end + map[key] = { + init = init_fn, + step = step_fn, + final = final_fn, + inverse = inverse_fn + } + end + end + return map[key] + end + + self.remove = function(key) + key = tonumber(key) + if key == nil then + return + end + map[key] = nil + end + + return self +end + +local function_refs_text = FNStorage() +local agg_chk_refs_text = AggChkStorage() local function_refs = Storage() local agg_function_refs = Storage() -local vfunc_data = Storage() local vfunc_cur = Storage() +-- Shared storage for vtable metadata (accessible from vtable_vm) +local function SharedStorage() + local self = {} + local map = {} + + self.new_key = function(data) + -- Store in shared memory + local key = setObject(data) + return tonumber(key) + end + + self.get = function(key) + key = tonumber(key) + if not map[key] then + -- Retrieve from shared memory and cache locally + local obj = getObject(key) + if obj then + map[key] = obj + end + end + return map[key] + end + + self.remove = function(key) + key = tonumber(key) + map[key] = nil + end + + return self +end + +local shared_vtab_data = SharedStorage() +local make_vtable_modules = {} + local wrap_csafe +local wrap_csafe_cs local create_function_agg_chk -local create_function_agg_coro local run_sql local fetch_all +local fetch_first local nrows local map_sqlite_to_lj = { @@ -143,12 +361,10 @@ local sqlite_to_l = function(value) return ret_value end ---sqlite3_context *context, int argc, sqlite3_value **argv local api_get_args = function(context, argc, argv) local tmp = {} for i = 0, argc -1 do - local value = sqlite_to_l (argv[i]) - table.insert(tmp, value) + tmp[i + 1] = sqlite_to_l(argv[i]) end return unpack(tmp) end @@ -181,7 +397,7 @@ local return_handlers = { end, number = function(context, value) - if value == math.floor(value) then + if value ~= math.huge and value ~= -math.huge and value == math.floor(value) then sqlite_api.result_int64(context, value) else sqlite_api.result_double(context, value) @@ -239,7 +455,7 @@ local bind_handlers = { end, number = function(stmt, index, value) - if value == math.floor(value) then + if value ~= math.huge and value ~= -math.huge and value == math.floor(value) then sqlite_api.bind_int64(stmt, index, value) else sqlite_api.bind_double(stmt, index, value) @@ -247,7 +463,7 @@ local bind_handlers = { end, string = function(stmt, index, value) - sqlite_api.bind_text(stmt, index, value, -1, nil) + sqlite_api.bind_text(stmt, index, value, -1, SQLITE.TRANSIENT) end, boolean = function(stmt, index, value) @@ -271,7 +487,7 @@ local api_bind_any = function(stmt, index, value) elseif type(value.data) ~= 'nil' and value.size ~= nil then sqlite_api.bind_blob(stmt, index, value.data, value.size, SQLITE.TRANSIENT) else - local msg = "api_function_return_any: unsupported type " .. tostring(type(value)) .. " value " .. tostring(value) + local msg = "api_bind_any: unsupported type " .. tostring(type(value)) .. " value " .. tostring(value) error(msg) end else @@ -280,6 +496,71 @@ local api_bind_any = function(stmt, index, value) end end +local validate_bind_value = function(value) + local value_type = type(value) + if bind_handlers[value_type] then + if value_type == 'cdata' then + if ffi.istype(value, int_t) or ffi.istype(value, int64_t) then + return true + elseif ffi.istype(value, uint64_t) then + return true + elseif ffi.istype(value, float_t) or ffi.istype(value, double_t) then + return true + elseif value == NULL then + return true + else + return false, "api_bind_any: unsupported type " .. tostring(ffi.typeof(value)) .. " value " .. tostring(value) + end + elseif value_type == 'table' then + if value.size == 0 then + return true + elseif type(value.data) ~= 'nil' and value.size ~= nil then + return true + end + return false, "api_bind_any: unsupported type " .. tostring(type(value)) .. " value " .. tostring(value) + end + return true + elseif value == NULL then + return true + elseif is_error(value) then + return false, LJError.get() + elseif value_type == 'table' then + if value.size == 0 then + return true + elseif type(value.data) ~= 'nil' and value.size ~= nil then + return true + end + return false, "api_bind_any: unsupported type " .. tostring(type(value)) .. " value " .. tostring(value) + else + return false, "api_bind_any: unsupported type " .. value_type .. " value " .. tostring(value) + end +end + +local resolve_bind_keys = function(stmt_p) + local parameter_count = sqlite_api.bind_parameter_count(stmt_p) + local bind_keys = {} + for k = 1, parameter_count do + local c_param_name = sqlite_api.bind_parameter_name(stmt_p, k) + if c_param_name == nil then + bind_keys[k] = k + else + bind_keys[k] = ffi.string(c_param_name):sub(2) -- remove prefix + end + end + return parameter_count, bind_keys +end + +local resolve_bind_value = function(params, key, index) + if type(key) == "number" then + return params[key] + end + local value = params[key] + if value == nil then + value = params[index] + end + return value +end + local api_create_function_v2 = function(zFunctionName, nArg, eTextRep, pApp, xFunc, xStep, xFinal, xDestroy) local rc = sqlite_api.create_function_v2(sqlite_db, zFunctionName, nArg, eTextRep, pApp, @@ -295,18 +576,62 @@ local api_create_function_v2 = function(zFunctionName, nArg, eTextRep, pApp, xFu return true end ---sqlite3_context *context, int argc, sqlite3_value **argv -local return_const = function(context, argc, argv) - local saved_constant = ffi.cast('int64_t', sqlite_api.user_data(context)) - sqlite_api.result_int64(context, saved_constant) +local api_create_function_v2_c = function(zFunctionName, nArg, eTextRep, pApp, xFunc, xStep, xFinal, xDestroy) + local rc = sqlite_api.create_function_v2(sqlite_db, zFunctionName, nArg, eTextRep, pApp, + (xFunc), + (xStep), + (xFinal), + (xDestroy) + ); + + if (rc ~= SQLITE.OK) then + return false, ffi.string(sqlite_api.errmsg(sqlite_db)), rc + end + return true, nil, nil end local make_int = function(name, value) - local ok = api_create_function_v2(name, 0, bor(SQLITE.DETERMINISTIC, SQLITE.INNOCUOUS) , ffi.cast('void*', value), return_const, nil, nil, nil) + local ok = api_create_function_v2_c( + name, + 0, + bor(SQLITE.DETERMINISTIC, SQLITE.INNOCUOUS) , + ffi.cast('void*', value), + plugin_init_data.sqlite_return_int_cb, nil, nil, nil) return ok end public_env.make_int = make_int +local make_str = function(name, value) + if type(value) ~= "string" then + return false, "make_str: value must be a string", SQLITE.MISUSE + end + + local size = #value + 1 + local str_ptr = sqlite_api.malloc(size) + if str_ptr == nil then + return false, "make_str: out of memory", SQLITE.NOMEM + end + + ffi.copy(str_ptr, value, #value) + ffi.cast("char*", str_ptr)[#value] = 0 + + local ok, err, errcode = api_create_function_v2_c( + name, + 0, + bor(SQLITE.DETERMINISTIC, SQLITE.INNOCUOUS), + str_ptr, + plugin_init_data.sqlite_return_text_cb, nil, nil, plugin_init_data.sqlite_free_cb + ) + + if not ok then + sqlite_api.free(str_ptr) + return ok, err, errcode + end + + return ok, err, errcode +end +public_env.make_str = make_str + local exec_lua = function (code_text, ...) local fn_env = {arg = {...}} setmetatable(fn_env, public_env_mt) @@ -318,50 +643,83 @@ local exec_lua = function (code_text, ...) return fn() end -local function error_xcall(err) - if type(err) == "table" then - if err.detail == nil then - err.detail = config.use_traceback == 0 and '' or debug.traceback() + +wrap_csafe = (function () + local wrapped_functions = {} + return function (fn) + if fn then + if not wrapped_functions[fn] then + wrapped_functions[fn] = function (context, argc, argv) + local status, result = xpcall(fn, error_xcall, context, argc, argv) + if not status then + local msg = tostring(result.message) .. '\n' .. result.detail + api_function_return_any(context, LJError.set(msg)) + end + end + end + + return wrapped_functions[fn] end - return err - else - return { message = err, detail = config.use_traceback == 0 and '' or debug.traceback() } + return fn end -end - +end)() -local wrapped_functions = {} -wrap_csafe = function (fn) - if fn then - if not wrapped_functions[fn] then - --sqlite3_context *context, int argc, sqlite3_value **argv - wrapped_functions[fn] = function (context, argc, argv) - local status, result = xpcall(fn, error_xcall, context, argc, argv) - if not status then - local msg = tostring(result.message) .. '\n' .. result.detail - api_function_return_any(context, LJError.set(msg)) +wrap_csafe_cs = (function () + local wrapped_functions = {} + return function (fn) + if fn then + if not wrapped_functions[fn] then + wrapped_functions[fn] = function (context, argc, argv) + local outer_statements = unfinalized_statements + unfinalized_statements = setmetatable({}, { __mode = "k" }) + local status, result = xpcall(fn, error_xcall, context, argc, argv) + local error_msg = nil + if not status then + error_msg = tostring(result.message) .. '\n' .. result.detail + end + local close_error = close_unfinalized() + if close_error and error_msg then + error_msg = error_msg .. '\n' .. tostring(close_error) + elseif close_error then + -- Function may already have produced a result for this context. + -- Do not call result_* again; report cleanup failure as diagnostic. + io.stderr:write("sqlite_plugin_lj: cleanup warning: " .. tostring(close_error) .. "\n") + end + if error_msg then + api_function_return_any(context, LJError.set(error_msg)) + end + unfinalized_statements = outer_statements end end + + return wrapped_functions[fn] end + return fn + end +end)() + - return wrapped_functions[fn] +local get_fn_context = function(index) + return plugin_init_data.getFunctionContext(index) +end + +local allocate_function_context = function(ctx) + local fn_context = tonumber(plugin_init_data.pushFunctionContext(ctx)) + if not fn_context or fn_context <= 0 then + return error("function context storage limit exceeded") end - return fn + return fn_context end ---sqlite3_context *context, int argc, sqlite3_value **argv local caller_fn = function(context, argc, argv) - local outer_statements = unfinalized_statements - unfinalized_statements = setmetatable({}, { __mode = "k" }) + local fn_context = ffi.cast('int64_t', sqlite_api.user_data(context)) + local fc = get_fn_context(fn_context) + local saved_ref = fc and fc.udata or nil - local saved_ref = ffi.cast('int64_t', sqlite_api.user_data(context)) local fn = function_refs.get(saved_ref) local result = fn(api_get_args(context, argc, argv)) api_function_return_any(context, result) - - close_unfinalized() - unfinalized_statements = outer_statements end local create_function = function(name, fn, argc) @@ -369,9 +727,13 @@ local create_function = function(name, fn, argc) return error('create_function: "fn" is not a function') end - local ref = function_refs.new_key(fn) -- items count + local ref = function_refs.new_key(fn) + --local callback_fn = wrap_csafe(caller_fn_cs) + local callback_fn = wrap_csafe_cs(caller_fn) + + local fn_context = allocate_function_context(ffi.new("FunctionContext", {fn_ptr = callback_fn, udata = ref, allowed_nested = false})) - local status, err, errcode = api_create_function_v2(name, argc, SQLITE.UTF8, ffi.cast('void*', ref) , caller_fn, nil, nil, nil) + local status, err, errcode = api_create_function_v2_c(name, argc, SQLITE.UTF8, ffi.cast('void*', fn_context) , plugin_init_data.cb_context_fn, nil, nil, nil) if not status then return error(err) end @@ -379,7 +741,6 @@ local create_function = function(name, fn, argc) end public_env.create_function = create_function ---sqlite3_context *context, int argc, sqlite3_value **argv local caller_chk = function(context, argc, argv) local outer_statements = unfinalized_statements unfinalized_statements = setmetatable({}, { __mode = "k" }) @@ -395,7 +756,10 @@ local caller_chk = function(context, argc, argv) local result = fn() api_function_return_any(context, result) - close_unfinalized() + local close_error = close_unfinalized() + if close_error then + io.stderr:write("sqlite_plugin_lj: cleanup warning: " .. tostring(close_error) .. "\n") + end unfinalized_statements = outer_statements end @@ -420,37 +784,35 @@ local create_function_chk = function(name, code_text, argc, wrapper_function) end local make_fn = function(name, code_text, argc) - argc = argc or -1 - local fn_env = {} - setmetatable(fn_env, public_env_mt) - - if code_text == nil then return end - local f, err = loadstring(ffi.string(code_text), name, "t", fn_env) - if (f) then - local status, res = xpcall(f, error_xcall) + local create_fn = function(name, argc, fid) + local status, err, errcode = api_create_function_v2_c(name, argc, SQLITE.UTF8, ffi.cast('void*', fid) , plugin_init_data.callback, nil, nil, nil) if not status then - local msg = 'Create failed ['.. name .. ']\n' ..tostring(res.message) .. '\n' .. res.detail - return msg + return error(err) end + end - setfenv(res, fn_env) - local status, err = xpcall(create_function, error_xcall, name, res, argc) + argc = argc or -1 - if not status then - local msg = 'Create failed ['.. name .. ']\n' .. tostring(err.message) .. '\n' .. err.detail - return msg - end - - else - local msg = 'Create failed ['.. name .. ']\n' .. tostring(err) + local fid, err = function_refs_text.new_key(name, code_text) + if not fid then + return err + end + + local status, err = xpcall(create_fn, error_xcall, name, argc, fid) + + if not status then + local msg = 'Create failed ['.. name .. ']\n' .. tostring(err.message) .. '\n' .. err.detail return msg end + + end public_env.make_fn = make_fn local make_chk = function (name, chunk_text, argc) - local status, err = xpcall(create_function_chk, error_xcall, name, chunk_text, 1, caller_chk) + argc = argc or -1 + local status, err = xpcall(create_function_chk, error_xcall, name, chunk_text, argc, caller_chk) if not status then local msg = 'Create failed ['.. name .. ']\n' .. tostring(err.message) .. '\n' .. err.detail return msg @@ -464,160 +826,414 @@ local make_function_agg_chk = function(fname, finit, fstep, ffinal, fargc) end public_env.make_function_agg_chk = make_function_agg_chk -local make_function_agg = function(name, code_text, argc) - argc = argc or -1 - if code_text == nil then return end - local f, err = loadstring(ffi.string(code_text), name, "t", _G) - if (f) then - local status, res = xpcall(f, error_xcall) - if not status then - local msg = 'Create failed ['.. name .. ']\n' ..tostring(res.message) .. '\n' .. res.detail - return msg - end +-- Aggregate helper functions +local get_agg_storage = function(context) + return ffi.cast('int*', sqlite_api.aggregate_context(context, ffi.sizeof('int'))) +end - local status, err = xpcall(create_function_agg_coro, error_xcall, name, res, argc) - if not status then - local msg = 'Create failed ['.. name .. ']\n' .. tostring(err.message) .. '\n' .. err.detail - return msg - end - else - local msg = 'Create failed ['.. name .. ']\n' .. tostring(err) - return msg - end +local get_fn_udata = function(context) + local fn_context = ffi.cast('int64_t', sqlite_api.user_data(context)) + local fc = get_fn_context(fn_context) + return fc and fc.udata or nil end -public_env.make_function_agg = make_function_agg --- sqlite3_context *context, int argc, sqlite3_value **argv -local agg_cb_coro = function(context, argc, argv) - local storage = ffi.cast('int*', sqlite_api.aggregate_context(context, ffi.sizeof('int'))) - if (storage == nil) then - return sqlite_api.result_error_nomem(context); +local resume_or_raise = function(agg_fn, phase, ...) + local ok, result = coroutine.resume(agg_fn, ...) + if not ok then + error("aggregate coroutine " .. phase .. " failed: " .. tostring(result)) end + return result +end - if (storage[0] == 0) then - local saved_ref = ffi.cast('int64_t', sqlite_api.user_data(context)) - local agg_fn = coroutine.create(function_refs.get(saved_ref)) - - local key = agg_function_refs.new_key(agg_fn) - storage[0] = key +-- Ensure coroutine is initialized, returns the coroutine +local ensure_coro_init = function(context, storage) + if storage[0] == 0 then + local coro_fn = function_refs_text.get(get_fn_udata(context)) + local agg_fn = coroutine.create(coro_fn) + storage[0] = agg_function_refs.new_key(agg_fn) + resume_or_raise(agg_fn, "init") + end + return agg_function_refs.get(storage[0]) +end - coroutine.resume(agg_fn, api_get_args(context, argc, argv)) --run coro +local agg_cb_coro_text = function(context, argc, argv) + local storage = get_agg_storage(context) + if storage == nil then + return sqlite_api.result_error_nomem(context) end - local agg_fn = agg_function_refs.get(storage[0]) - coroutine.resume(agg_fn, true, api_get_args(context, argc, argv)) + local agg_fn = ensure_coro_init(context, storage) + resume_or_raise(agg_fn, "step", true, api_get_args(context, argc, argv)) end ---sqlite3_context *context -local agg_final_coro = function(context) - local storage = ffi.cast('int*', sqlite_api.aggregate_context(context, ffi.sizeof('int'))) - local agg_index = storage[0] - local agg_fn = agg_function_refs.get(agg_index) - local _, result = coroutine.resume(agg_fn, false) - agg_function_refs.remove(agg_index) +local agg_final_coro_text = function(context) + local storage = get_agg_storage(context) + local agg_fn = ensure_coro_init(context, storage) + agg_function_refs.remove(storage[0]) + local result = resume_or_raise(agg_fn, "final", false) api_function_return_any(context, result) end -create_function_agg_coro = function(name, fn, argc) - if type(fn) ~= 'function' then - return error('create_function: "fn" is not a function') +local create_function_agg_coro_text = function(name, code_text, argc) + local fid, err = function_refs_text.new_key(name, code_text) + if not fid then + return err end - local ref = function_refs.new_key(fn) - - local status, err, errcode = api_create_function_v2(name, argc, SQLITE.UTF8, ffi.cast('void*', ref) , nil, agg_cb_coro, agg_final_coro, nil) + local fn_context = allocate_function_context( + ffi.new("FunctionContext", { + fn_ptr = nil, + fn_step_ptr = wrap_csafe_cs(agg_cb_coro_text), + fn_final_ptr = wrap_csafe_cs(agg_final_coro_text), + fn_destroy_ptr = nil, + udata = fid, + allowed_nested = true + })) + + local status, err, errcode = api_create_function_v2_c(name, argc, SQLITE.UTF8, ffi.cast('void*', fn_context), + nil, plugin_init_data.cb_context_step_fn, plugin_init_data.cb_context_final_fn, nil) if not status then return error(err) end - end -public_env.create_function_agg_coro = create_function_agg_coro - ---sqlite3_context *context, int argc, sqlite3_value **argv -local agg_cb_chk = function(context, argc, argv) - local storage = ffi.cast('int*', sqlite_api.aggregate_context(context, ffi.sizeof('int'))) - if (storage == nil) then - return sqlite_api.result_error_nomem(context); - end - - local saved_ref = ffi.cast('int64_t', sqlite_api.user_data(context)) - local fn = function_refs.get(saved_ref) - - local args = {api_get_args(context, argc, argv)} - - if (storage[0] == 0) then - local chk_env = {} - setmetatable(chk_env, public_env_mt) - chk_env['arg'] = args - - local key = agg_function_refs.new_key(chk_env) - storage[0] = key +public_env.create_function_agg_coro_text = create_function_agg_coro_text + +-- Ensure chunk env is initialized, returns the env +local ensure_chk_init = function(context, storage) + if storage[0] == 0 then + local fn = agg_chk_refs_text.get(get_fn_udata(context)) + local chk_env = setmetatable({arg = {}}, public_env_mt) + storage[0] = agg_function_refs.new_key(chk_env) setfenv(fn.init, chk_env) fn.init() end + return agg_function_refs.get(storage[0]) +end + +local agg_cb_chk_text = function(context, argc, argv) + local storage = get_agg_storage(context) + if storage == nil then + return sqlite_api.result_error_nomem(context) + end - local chk_env = agg_function_refs.get(storage[0]) - chk_env['arg'] = args + local chk_env = ensure_chk_init(context, storage) + chk_env['arg'] = {api_get_args(context, argc, argv)} + local fn = agg_chk_refs_text.get(get_fn_udata(context)) setfenv(fn.step, chk_env) fn.step() end ---sqlite3_context *context -local agg_final_chk = function(context) - local storage = ffi.cast('int*', sqlite_api.aggregate_context(context, ffi.sizeof('int'))) - local agg_index = storage[0] - local chk_env = agg_function_refs.get(agg_index) - agg_function_refs.remove(agg_index) +local agg_final_chk_text = function(context) + local storage = get_agg_storage(context) + local chk_env = ensure_chk_init(context, storage) + agg_function_refs.remove(storage[0]) - local saved_ref = ffi.cast('int64_t', sqlite_api.user_data(context)) - local fn = function_refs.get(saved_ref) + local fn = agg_chk_refs_text.get(get_fn_udata(context)) setfenv(fn.final, chk_env) - local result = fn.final() api_function_return_any(context, result) end + create_function_agg_chk = function(name, init, step, final, argc) - local agg_code = { - init = loadstring(init, name..':init', "t"), - step = loadstring(step, name..':step', "t"), - final = loadstring(final, name..':final', "t") - } + argc = argc or -1 - local ref = function_refs.new_key(agg_code) + local fid, err = agg_chk_refs_text.new_key(name, init, step, final) + if not fid then + return error(err) + end - local status, err, errcode = api_create_function_v2(name, argc, SQLITE.UTF8, ffi.cast('void*', ref) , nil, agg_cb_chk, agg_final_chk, nil) + local fn_context = allocate_function_context( + ffi.new("FunctionContext", { + fn_ptr = nil, + fn_step_ptr = wrap_csafe_cs(agg_cb_chk_text), + fn_final_ptr = wrap_csafe_cs(agg_final_chk_text), + fn_destroy_ptr = nil, + udata = fid, + allowed_nested = true + })) + + local status, err, errcode = api_create_function_v2_c( + name, argc, SQLITE.UTF8, ffi.cast('void*', fn_context), + nil, + plugin_init_data.cb_context_step_fn, + plugin_init_data.cb_context_final_fn, + nil) if not status then return error(err) end end public_env.create_function_agg_chk = create_function_agg_chk -run_sql = function (sql) - sql = tostring(sql) - local rc = sqlite_api.exec(sqlite_db, sql, nil, nil, nil) - if rc ~= SQLITE.OK then - local msg = ffi.string(sqlite_api.errmsg(sqlite_db)) - return error("Failed to execute query: \n[" .. sql .. "] \n" .. msg) - end +-- Window function callbacks (use direct udata, not FunctionContext) +local get_udata_direct = function(context) + return tonumber(ffi.cast('int64_t', sqlite_api.user_data(context))) end -public_env.run_sql = run_sql -local finalize_stmt = function (stmt) - return sqlite_api.finalize(stmt[0]) +local ensure_chk_init_window = function(context, storage) + if storage[0] == 0 then + local fn = agg_chk_refs_text.get(get_udata_direct(context)) + local chk_env = setmetatable({arg = {}}, public_env_mt) + storage[0] = agg_function_refs.new_key(chk_env) + setfenv(fn.init, chk_env) + fn.init() + end + return agg_function_refs.get(storage[0]) end -fetch_all = function (...) - local rows = {} - for row in nrows(...) do - table.insert(rows, row); +local agg_cb_chk_window = function(context, argc, argv) + local storage = get_agg_storage(context) + if storage == nil then + return sqlite_api.result_error_nomem(context) end - return rows -end -public_env.fetch_all = fetch_all -local Statement = function (sql, params) - local self = {} + local chk_env = ensure_chk_init_window(context, storage) + chk_env['arg'] = {api_get_args(context, argc, argv)} + + local fn = agg_chk_refs_text.get(get_udata_direct(context)) + setfenv(fn.step, chk_env) + fn.step() +end + +local agg_final_chk_window = function(context) + local storage = get_agg_storage(context) + local chk_env = ensure_chk_init_window(context, storage) + agg_function_refs.remove(storage[0]) + + local fn = agg_chk_refs_text.get(get_udata_direct(context)) + setfenv(fn.final, chk_env) + local result = fn.final() + api_function_return_any(context, result) +end + +local agg_value_chk_window = function(context) + local storage = get_agg_storage(context) + local chk_env = ensure_chk_init_window(context, storage) + + local fn = agg_chk_refs_text.get(get_udata_direct(context)) + setfenv(fn.final, chk_env) + local result = fn.final() + api_function_return_any(context, result) +end + +local agg_inverse_chk_window = function(context, argc, argv) + local storage = get_agg_storage(context) + local chk_env = ensure_chk_init_window(context, storage) + + local fn = agg_chk_refs_text.get(get_udata_direct(context)) + if not fn.inverse then + return sqlite_api.result_error(context, "sliding windows not supported for this aggregate", -1) + end + chk_env['arg'] = {api_get_args(context, argc, argv)} + setfenv(fn.inverse, chk_env) + fn.inverse() +end + +local window_cb_anchors = nil +local get_window_cb_anchors = function() + if window_cb_anchors ~= nil then + return window_cb_anchors + end + window_cb_anchors = { + step = ffi.cast( + 'void (*)(sqlite3_context*, int, sqlite3_value**)', + wrap_csafe_cs(agg_cb_chk_window) + ), + final = ffi.cast( + 'void (*)(sqlite3_context*)', + wrap_csafe_cs(agg_final_chk_window) + ), + value = ffi.cast( + 'void (*)(sqlite3_context*)', + wrap_csafe_cs(agg_value_chk_window) + ), + inverse = ffi.cast( + 'void (*)(sqlite3_context*, int, sqlite3_value**)', + wrap_csafe_cs(agg_inverse_chk_window) + ), + } + return window_cb_anchors +end + +-- create_function_agg_chk_window(name, init, step, final, argc) - no sliding window support +-- create_function_agg_chk_window(name, init, step, inverse, final, argc) - full sliding window support +local create_function_agg_chk_window = function(name, init, step, arg4, arg5, arg6) + local inverse, final, argc + + if type(arg5) == "number" or arg5 == nil then + -- 4 string args: init, step, final, argc (no inverse) + inverse = nil + final = arg4 + argc = arg5 or -1 + else + -- 5 string args: init, step, inverse, final, argc + inverse = arg4 + final = arg5 + argc = arg6 or -1 + end + + local fid, err = agg_chk_refs_text.new_key(name, init, step, final, inverse) + if not fid then + return error(err) + end + + local cbs = get_window_cb_anchors() + local rc = sqlite_api.create_window_function( + sqlite_db, name, argc, SQLITE.UTF8, ffi.cast('void*', fid), + cbs.step, + cbs.final, + cbs.value, + cbs.inverse, + nil) + if rc ~= SQLITE.OK then + return error(ffi.string(sqlite_api.errmsg(sqlite_db))) + end +end +public_env.create_function_agg_chk_window = create_function_agg_chk_window + +local internal_sql_builtin_specs = { + make_fn = { fn = make_fn, argc = -1 }, + make_int = { fn = make_int, argc = 2 }, + make_str = { fn = make_str, argc = 2 }, + make_chk = { fn = make_chk, argc = 3 }, + make_function_agg_chk = { fn = make_function_agg_chk, argc = -1 }, + create_function_agg_coro_text = { fn = create_function_agg_coro_text, argc = -1 }, + create_function_agg_chk = { fn = create_function_agg_chk, argc = -1 }, + create_function_agg_chk_window = { fn = create_function_agg_chk_window, argc = -1 }, +} + +local register_internal_function = function(source_name, sql_name) + if type(source_name) ~= "string" or source_name == "" then + return error("register_internal_function: source_name must be a non-empty string") + end + + local spec = internal_sql_builtin_specs[source_name] + if not spec then + return error("register_internal_function: unknown source_name [" .. source_name .. "]") + end + + sql_name = sql_name or source_name + if type(sql_name) ~= "string" or sql_name == "" then + return error("register_internal_function: sql_name must be a non-empty string") + end + + return create_function(sql_name, spec.fn, spec.argc) +end +public_env.register_internal_function = register_internal_function + +local register_internal_functions = function(name_map) + if name_map == nil then + name_map = {} + end + if type(name_map) ~= "table" then + return error("register_internal_functions: name_map must be a table") + end + + for source_name, spec in pairs(internal_sql_builtin_specs) do + local mapped_name = name_map[source_name] + if mapped_name ~= false then + if mapped_name == nil then + mapped_name = source_name + end + create_function(mapped_name, spec.fn, spec.argc) + end + end +end +public_env.register_internal_functions = register_internal_functions + +run_sql = function (sql) + sql = tostring(sql) + local rc = sqlite_api.exec(sqlite_db, sql, nil, nil, nil) + if rc ~= SQLITE.OK then + local msg = ffi.string(sqlite_api.errmsg(sqlite_db)) + return error("Failed to execute query: \n[" .. sql .. "] \n" .. msg) + end +end +public_env.run_sql = run_sql + +fetch_all = function (...) + local rows = {} + local n = 0 + for row in nrows(...) do + n = n + 1 + rows[n] = row + end + return rows +end +public_env.fetch_all = fetch_all + +local fetchOneStatement = function (sql, params) + local stmt = ffi.new("sqlite3_stmt*[?]", 1) + local columns_count + local step_rc + + local self = {} + self.get_columns = function () + local columns = {} + for i = 0, columns_count - 1 do + columns[i] = ffi.string(sqlite_api.column_name(stmt[0], i)) + end + return columns + end + + self.to_table = function () + local row = {} + local cols = self.get_columns() + for i = 0, columns_count - 1 do + local value = sqlite_to_l(sqlite_api.column_value(stmt[0], i)) + row[cols[i]] = value + end + return row + end + + sql = tostring(sql) + local rc = sqlite_api.prepare_v2(sqlite_db, sql, -1, stmt, nil) + if rc ~= SQLITE.OK then + error("Failed to prepare query \n[" .. sql.. "]" ) + end + + local result = nil + local error_text + if type(params) == "table" then + local bind_error + local parameter_count, bind_keys = resolve_bind_keys(stmt[0]) + for k = 1, parameter_count do + local key = bind_keys[k] + local value = resolve_bind_value(params, key, k) + local can_bind, bind_err = validate_bind_value(value) + if not can_bind then + bind_error = bind_err + break + end + api_bind_any(stmt[0], k, value) + end + + if bind_error then + safe_finalize_stmt(stmt) + return error(bind_error) + end + end + columns_count = sqlite_api.column_count(stmt[0]) + step_rc = sqlite_api.step(stmt[0]) + if step_rc == SQLITE.ROW then + result = self.to_table() + elseif step_rc ~= SQLITE.DONE then + error_text = 'fetch one failed: \n\tmsg: ' .. ffi.string(sqlite_api.errmsg(sqlite_db)) .. '\n\tsql: ' .. sql + end + + local finalize_rc = safe_finalize_stmt(stmt) + if finalize_rc ~= SQLITE.OK and not error_text then + error('fetch one finalize failed: \n\tsql: ' .. sql) + end + if error_text then + return error(error_text) + end + return result +end + + +local Statement = function (sql, params) + local self = {} local stmt local columns_count @@ -625,31 +1241,25 @@ local Statement = function (sql, params) local step_rc sql = tostring(sql) - stmt = ffi.gc(ffi.new("sqlite3_stmt*[?]", 1), finalize_stmt) - unfinalized_statements[stmt] = true + stmt = ffi.new("sqlite3_stmt*[?]", 1) local rc = sqlite_api.prepare_v2(sqlite_db, sql, -1, stmt, nil) if rc ~= SQLITE.OK then error("Failed to prepare query \n[" .. sql.. "]" ) end + + unfinalized_statements[stmt] = {isopen = true, sql = sql} if type(params) == "table" then - local parameter_count = sqlite_api.bind_parameter_count(stmt[0]) - for k = 1, parameter_count do - local c_param_name = sqlite_api.bind_parameter_name(stmt[0], k) - if c_param_name == nil then - api_bind_any(stmt[0], k, params[k]) - else - local name = ffi.string(c_param_name):sub(2) -- remove prefix - local value = params[name] - if value == nil then - value = params[k] - end - api_bind_any(stmt[0], k, value) - end + local parameter_count, bind_keys = resolve_bind_keys(stmt[0]) + for k = 1, parameter_count do + local key = bind_keys[k] + local value = resolve_bind_value(params, key, k) + api_bind_any(stmt[0], k, value) end end columns_count = sqlite_api.column_count(stmt[0]) + local finalized = false self.get_columns = function () if not columns then @@ -663,10 +1273,12 @@ local Statement = function (sql, params) self.to_array = function() local row = {} + local out = 1 for i = 0, columns_count - 1 do local value = sqlite_to_l(sqlite_api.column_value(stmt[0], i)) - table.insert(row, value) + row[out] = value + out = out + 1 end return row end @@ -683,14 +1295,27 @@ local Statement = function (sql, params) end self.finalize_done = function () + if finalized then + return + end + finalized = true + local state = unfinalized_statements[stmt] + local error_text if step_rc ~= SQLITE.DONE then - error('iterate sql failed '.. tostring(step_rc)) + error_text = 'iterate sql failed: \n\tmsg: ' .. ffi.string(sqlite_api.errmsg(sqlite_db)) .. '\n\tsql: ' .. sql + end + if state ~= nil then + state.isopen = false end unfinalized_statements[stmt] = nil - step_rc = sqlite_api.finalize(ffi.gc(stmt, nil)[0]) + step_rc = safe_finalize_stmt(stmt) - if step_rc ~= SQLITE.OK then - error('finalize last failed '.. tostring(step_rc)) + -- if statement failed then closing it also returns error + if step_rc ~= SQLITE.OK and not error_text then + error('finalize last failed: '.. sql ) + end + if error_text then + error(error_text) end end @@ -709,6 +1334,11 @@ local Statement = function (sql, params) end +fetch_first = function (sql, params) + return fetchOneStatement(sql, params) +end +public_env.fetch_first = fetch_first + nrows = function (sql, params) local stmt = Statement(sql, params) @@ -741,60 +1371,64 @@ local urows = function (sql, params) return unpack(stmt.to_array()) end - end + end end public_env.urows = urows -local drop_function = function(name, argc) - local status, err, errcode = api_create_function_v2(name, argc, 0, nil , nil, nil, nil, nil) - if not status then - --TODO fix it - return "DROP FUNCTION error: ".. err +--local sqlite3_module_test = ffi.cast('sqlite3_module *', api.malloc(ffi.sizeof(sqlite3_module_t)))--sqlite3_module_t{} +local connect_make_vtable = function (db, pAux, argc, argv, ppVTab, pzErrUnused) + local shared_key_ptr = ffi.cast('int64_t*', pAux) + local shared_key = shared_key_ptr ~= nil and tonumber(shared_key_ptr[0]) or nil + if not shared_key then + if pzErrUnused ~= nil then + pzErrUnused[0] = sqlite_api.mprintf("lua_vtab: module context missing") + end + return SQLITE.ERROR + end + local vtab_data = shared_vtab_data.get(shared_key) + if not vtab_data or not vtab_data.xtable_text then + if pzErrUnused ~= nil then + pzErrUnused[0] = sqlite_api.mprintf("lua_vtab: shared metadata missing") + end + return SQLITE.ERROR end - return "" + local xtable_text = vtab_data.xtable_text + + local rc = sqlite_api.declare_vtab(db, xtable_text) --"CREATE TABLE x(...)"; + if rc == SQLITE.OK then + local pTable = ffi.cast('lua_vtab*', sqlite_api.malloc(ffi.sizeof('lua_vtab'))) + + if( pTable==nil ) then + return SQLITE.NOMEM + end + + pTable.base.pModule = nil; + pTable.base.nRef = 0 + pTable.base.zErrMsg = nil + pTable.index = shared_key -- store shared_key for vtable_vm to access + pTable.conn_slot = plugin_init_data.conn_slot + + ppVTab[0] = ffi.cast('sqlite3_vtab*', pTable) + end + return rc end -public_env.drop_function = drop_function ---local sqlite3_module_test = ffi.cast('sqlite3_module *', api.malloc(ffi.sizeof(sqlite3_module_t)))--sqlite3_module_t{} local lua_vtable_module = sqlite3_module_t{ iVersion = 1, - xCreate = function () - return SQLITE.OK - end, - xDestroy = function () + xCreate = connect_make_vtable, + xDestroy = function (pVtab) + local pTable = ffi.cast('lua_vtab*', pVtab) + local shared_key = tonumber(pTable.index) + shared_vtab_data.remove(shared_key) + sqlite_api.free(pVtab) return SQLITE.OK end, - xConnect = function (db, pAux, argc, argv, ppVTab, pzErrUnused) - local vfunc_data_index = ffi.cast('int64_t', pAux) - local xtable_text = vfunc_data.get(vfunc_data_index).xtable_text - - local rc = sqlite_api.declare_vtab(db, xtable_text) --"CREATE TABLE x(...)"; - if rc == SQLITE.OK then - local pTable = ffi.cast('lua_vtab*', sqlite_api.malloc(ffi.sizeof('lua_vtab'))) - - if( pTable==nil ) then - return SQLITE.NOMEM - end - - pTable.base.pModule = nil; - pTable.base.nRef = 0 - pTable.base.zErrMsg = nil - pTable.index = vfunc_data_index - - ppVTab[0] = ffi.cast('sqlite3_vtab*', pTable) - end - return rc - - end, + xConnect = connect_make_vtable, --sqlite3_vtab *pVtab xDisconnect = function(pVtab) - local pTable = ffi.cast('lua_vtab*', pVtab) - local vfunc_data_index = tonumber(pTable.index) - - vfunc_data.remove(vfunc_data_index) - + -- Don't remove from shared_vtab_data cache - data may be reused on next query sqlite_api.free(pVtab); return SQLITE.OK; end , @@ -803,104 +1437,102 @@ local lua_vtable_module = sqlite3_module_t{ return SQLITE.OK; end, - --sqlite3_vtab * pVtab, sqlite3_vtab_cursor **ppCursor - xOpen = function( pVtab, ppCursor) - local pTable = ffi.cast('lua_vtab*', pVtab) - local input = vfunc_data.get(pTable.index).input - local first_index = next(input.rows, nil) - - local cursor_index = vfunc_cur.new_key({input = input, iterator = first_index, rowid = 1}) - - local cur_sz = ffi.sizeof(vtab_cursor_t) - local cur_sqlptr = ffi.cast('lua_vtab_cursor *', sqlite_api.malloc(cur_sz)) - - if( cur_sqlptr==nil ) then - return SQLITE.NOMEM - end - - cur_sqlptr.index = cursor_index - cur_sqlptr.base.pVtab = pVtab - - ppCursor[0] = ffi.cast('sqlite3_vtab_cursor *', cur_sqlptr) - - return SQLITE.OK; - end, - - --sqlite3_vtab_cursor * - xClose = function(cur) - local cur_sqlptr = ffi.cast('lua_vtab_cursor *', cur) - vfunc_cur.remove(tonumber(cur_sqlptr.index)) - sqlite_api.free(cur); - return SQLITE.OK; - end, - - xFilter = function (pVtabCursor, idxNum, idxStrUnused, argc, argv) - local cur_sqlptr = ffi.cast('lua_vtab_cursor *', pVtabCursor) - local cursor = vfunc_cur.get(cur_sqlptr.index) - - cursor.iterator = next(cursor.input.rows, nil) - cursor.rowid = 1 - - return SQLITE.OK; - end, - - --sqlite3_vtab_cursor * cur - xNext = function(cur) - local cur_sqlptr = ffi.cast('lua_vtab_cursor *', cur) - local cursor = vfunc_cur.get(cur_sqlptr.index) - - cursor.iterator = next(cursor.input.rows, cursor.iterator) - cursor.rowid = cursor.rowid + 1 - - return SQLITE.OK; - end, - - --sqlite3_vtab_cursor * cur - xEof = function(cur) - local cur_sqlptr = ffi.cast('lua_vtab_cursor *', cur) - local cursor = vfunc_cur.get(cur_sqlptr.index) + -- Cursor callbacks are set dynamically in create_vtable_functions + -- to use C wrappers that route through vtable_vm + xOpen = nil, + xClose = nil, + xFilter = nil, + xNext = nil, + xEof = nil, + xColumn = nil, + xRowid = nil +} - return not cursor.iterator - end, +local quote_ident_sqlite = function(name) + if type(name) ~= "string" or name == "" then + return error("identifier must be a non-empty string") + end + local escaped = sqlite_api.mprintf("%w", name) + if escaped == nil then + return error("identifier quote failed: out of memory") + end + local quoted = '"' .. ffi.string(escaped) .. '"' + sqlite_api.free(escaped) + return quoted +end - --sqlite3_vtab_cursor *cur, sqlite3_context *ctx, int col - xColumn = function(cur, ctx, col) - local cur_sqlptr = ffi.cast('lua_vtab_cursor *', cur) - local cursor = vfunc_cur.get(cur_sqlptr.index) +local normalize_make_vtable_input = function(table_name, input) + if type(table_name) ~= "string" or table_name == "" then + return error("make_vtable: table_name must be a non-empty string") + end + if type(input) ~= "table" then + return error("make_vtable: input must be a table") + end + if type(input.columns) ~= "table" or #input.columns == 0 then + return error("make_vtable: columns must be a non-empty array") + end + local rows = input.rows + if rows == nil then + rows = {} + end + if type(rows) ~= "table" then + return error("make_vtable: rows must be a table") + end - -- zero index to 1 index - local data = cursor.input.rows[cursor.iterator][col + 1] - api_function_return_any(ctx, data) + local seen = {} + for i, col in ipairs(input.columns) do + if type(col) ~= "string" then + return error("make_vtable: column names must be non-empty strings") + end + if col == "" then + return error("identifier must be a non-empty string") + end + local normalized = string.lower(col) + if seen[normalized] then + return error("make_vtable: duplicate column name [" .. col .. "]") + end + seen[normalized] = true + end - return SQLITE.OK - end, + return { + columns = input.columns, + rows = rows + } +end - --sqlite3_vtab_cursor *cur, --sqlite_int64 *pRowid - xRowid = function(cur, pRowid) - local cur_sqlptr = ffi.cast('lua_vtab_cursor *', cur) - local cursor = vfunc_cur.get(cur_sqlptr.index) +local make_vtable = function(table_name, input) + local normalized_input = normalize_make_vtable_input(table_name, input) - pRowid[0] = cursor.rowid - return SQLITE.OK; + local quoted_columns = {} + for i, col in ipairs(normalized_input.columns) do + quoted_columns[i] = quote_ident_sqlite(col) end -} - -local make_vtable = function(table_name, input) - local column_names = table.concat(input.columns, ', ') + local column_names = table.concat(quoted_columns, ', ') local xtable_text = string.format("CREATE TABLE x(%s)", column_names) - - local key = vfunc_data.new_key({ + local quoted_table_name = quote_ident_sqlite(table_name) + + -- Store vtab metadata in shared memory so vtable_vm can access it + local key = shared_vtab_data.new_key({ xtable_text = xtable_text, - input = input, - module = lua_vtable_module -- anchor to keep it alive + input = normalized_input, -- normalized input data in shared storage + module_type = "make_vtable" }) - local rc = sqlite_api.create_module_v2(sqlite_db, table_name, lua_vtable_module, ffi.cast('void*', key), nil); - if rc ~= SQLITE.OK then - return error("setup module failed") + local module_aux = make_vtable_modules[table_name] + if module_aux == nil then + module_aux = ffi.new("int64_t[1]", key) + local rc = sqlite_api.create_module_v2(sqlite_db, table_name, lua_vtable_module, ffi.cast('void*', module_aux), nil); + if rc ~= SQLITE.OK then + -- Shared storage is append-only; clear local cache entry at least. + shared_vtab_data.remove(key) + return error("setup module failed") + end + make_vtable_modules[table_name] = module_aux + else + module_aux[0] = key end - local create_vtable_text = "CREATE VIRTUAL TABLE TEMP." .. table_name .. " USING " .. table_name .. "();" + local create_vtable_text = "CREATE VIRTUAL TABLE TEMP." .. quoted_table_name .. " USING " .. quoted_table_name .. "();" run_sql(create_vtable_text); end public_env.make_vtable = make_vtable @@ -909,89 +1541,219 @@ local lua_vtable_module2 = sqlite3_module_t{ iVersion = 0, xConnect = function (db, pAux, argc, argv, ppVTab, pzErrUnused) - local vfunc_data_index = tonumber(ffi.cast('int64_t', pAux)) - local xtable_text = vfunc_data.get(vfunc_data_index).xtable_text + local shared_key = tonumber(ffi.cast('int64_t', pAux)) + local vtab_data = shared_vtab_data.get(shared_key) + if not vtab_data or not vtab_data.xtable_text then + if pzErrUnused ~= nil then + pzErrUnused[0] = sqlite_api.mprintf("lua_vtab: shared metadata missing") + end + return SQLITE.ERROR + end + local xtable_text = vtab_data.xtable_text local rc = sqlite_api.declare_vtab(db, xtable_text) --"CREATE TABLE x(...)"; if rc == SQLITE.OK then local pTable = ffi.cast('lua_vtab*', sqlite_api.malloc(ffi.sizeof('lua_vtab'))) - + if( pTable==nil ) then return SQLITE.NOMEM end - + pTable.base.pModule = nil; pTable.base.nRef = 0 pTable.base.zErrMsg = nil - pTable.index = vfunc_data_index - + pTable.index = shared_key -- store shared_key for vtable_vm to access + pTable.conn_slot = plugin_init_data.conn_slot + ppVTab[0] = ffi.cast('sqlite3_vtab*', pTable) sqlite_api.vtab_config(db, SQLITE.VTAB_INNOCUOUS); end return rc - + end, xDisconnect = function(pVTab) - local pTable = ffi.cast('lua_vtab*', pVTab) - local vfunc_data_index = tonumber(pTable.index) - - vfunc_data.remove(vfunc_data_index) - + -- Don't remove from shared_vtab_data cache - data may be reused on next query sqlite_api.free(pVTab); return SQLITE.OK; end, + xDestroy = function(pVTab) + local pTable = ffi.cast('lua_vtab*', pVTab) + local shared_key = tonumber(pTable.index) + shared_vtab_data.remove(shared_key) + sqlite_api.free(pVTab) + return SQLITE.OK + end, + xBestIndex = function (pVTab, pIdxInfo) local pTable = ffi.cast('lua_vtab*', pVTab) - local vfunc_data_index = tonumber(pTable.index) + local shared_key = tonumber(pTable.index) + local vtab_data = shared_vtab_data.get(shared_key) + if not vtab_data then + pVTab.zErrMsg = sqlite_api.mprintf("lua_vtab: shared metadata missing") + return SQLITE.ERROR + end local pConstraint = pIdxInfo.aConstraint; + local code_column_idx = vtab_data.code_column_idx + local has_code_eq = false - local code_column_idx = vfunc_data.get(vfunc_data_index).code_column_idx - - local constraint_code_idx = 0 for i = 0, pIdxInfo.nConstraint - 1 do local iColumn = pConstraint[i].iColumn - pIdxInfo.aConstraintUsage[i].argvIndex = i + 1 + pIdxInfo.aConstraintUsage[i].argvIndex = 0 pIdxInfo.aConstraintUsage[i].omit = 0 - if iColumn == code_column_idx then - constraint_code_idx = i + if (not has_code_eq) + and iColumn == code_column_idx + and pConstraint[i].usable == 1 + and pConstraint[i].op == SQLITE.INDEX_CONSTRAINT_EQ then + has_code_eq = true + pIdxInfo.aConstraintUsage[i].argvIndex = 1 pIdxInfo.aConstraintUsage[i].omit = 1 end end - pIdxInfo.idxNum = constraint_code_idx + pIdxInfo.idxNum = has_code_eq and 1 or 0 + pIdxInfo.estimatedCost = has_code_eq and 10 or 1000000 + pIdxInfo.estimatedRows = has_code_eq and 16 or 1048576 return SQLITE.OK; end, - xOpen = function( pVTab, ppCursor) - local pTable = ffi.cast('lua_vtab*', pVTab) - local code_column_idx = vfunc_data.get(pTable.index).code_column_idx + -- All cursor callbacks are set dynamically in create_vtable_functions + -- to use C wrappers that route through vtable_vm + xOpen = nil, + xClose = nil, + xFilter = nil, + xNext = nil, + xEof = nil, + xColumn = nil, + xRowid = nil +} + +-- Keep module structs strongly referenced for SQLite module lifetime. +plugin._vtable_module_anchors = { + lua_vtable_module, + lua_vtable_module2, +} + +local list_iterator = function(t) + local i = 0 + local n = #t + return function () + i = i + 1 + if i <= n then return t[i] end + end +end - local getter - if code_column_idx > 1 then - getter = function (self, col) - return self.value[1][col] +-- Setup vtable Lua callbacks (only needed in vtable_vm) +local setup_vtable_callbacks = function() + -- Define all cursor callbacks as standalone functions + -- These handle both lua_vtable_module (make_vtable) and lua_vtable_module2 (L/L10) + + local function set_vtab_error(cur_sqlptr, err) + local msg + if type(err) == 'table' then + msg = tostring(err.message) + if err.detail and err.detail ~= '' then + msg = msg .. '\n' .. err.detail end else - getter = function (self, _) - return self.value[1] + msg = tostring(err) + end + + cur_sqlptr.base.pVtab.zErrMsg = sqlite_api.mprintf("%s", msg) + return SQLITE.ERROR + end + + local function is_make_vtable_array_key(k, array_len) + return type(k) == "number" and k == math.floor(k) and k >= 1 and k <= array_len + end + + local function advance_make_vtable_extra(cursor_data) + while true do + local k, row = next(cursor_data.input.rows, cursor_data.extra_key) + cursor_data.extra_key = k + if k == nil then + cursor_data.current_row = nil + return + end + if (not is_make_vtable_array_key(k, cursor_data.array_len)) and type(row) == "table" then + cursor_data.current_row = row + return end end - - local cursor_index = vfunc_cur.new_key({rowid = 1, getter = getter}) + end + + local function reset_make_vtable_cursor(cursor_data) + cursor_data.array_len = #(cursor_data.input.rows) + cursor_data.array_idx = 1 + cursor_data.extra_key = nil + while cursor_data.array_idx <= cursor_data.array_len do + local row = cursor_data.input.rows[cursor_data.array_idx] + if type(row) == "table" then + cursor_data.current_row = row + return + end + cursor_data.array_idx = cursor_data.array_idx + 1 + end + advance_make_vtable_extra(cursor_data) + end + + local function advance_make_vtable_cursor(cursor_data) + cursor_data.array_idx = (cursor_data.array_idx or 1) + 1 + while cursor_data.array_idx <= cursor_data.array_len do + local row = cursor_data.input.rows[cursor_data.array_idx] + if type(row) == "table" then + cursor_data.current_row = row + return + end + cursor_data.array_idx = cursor_data.array_idx + 1 + end + advance_make_vtable_extra(cursor_data) + end + + local xOpen_lua = function(pVTab, ppCursor) + local pTable = ffi.cast('lua_vtab*', pVTab) + local shared_key = tonumber(pTable.index) + local vtab_data = shared_vtab_data.get(shared_key) + if not vtab_data then + pVTab.zErrMsg = sqlite_api.mprintf("lua_vtab: shared metadata missing") + return SQLITE.ERROR + end + + local cursor_data = {rowid = 1} + + if vtab_data.code_column_idx then + -- L/L10 module: set up getter for iterator results + local code_column_idx = vtab_data.code_column_idx + if code_column_idx > 1 then + cursor_data.getter = function(self, col) + return self.value[1][col] + end + else + cursor_data.getter = function(self, _) + return self.value[1] + end + end + else + -- make_vtable module: input is stored directly in shared data + local input = vtab_data.input + cursor_data.input = input + reset_make_vtable_cursor(cursor_data) + end + + local cursor_index = vfunc_cur.new_key(cursor_data) local cur_sz = ffi.sizeof(vtab_cursor_t) - local cur_sqlptr = ffi.cast('lua_vtab_cursor *', sqlite_api.malloc(cur_sz)) + local cur_sqlptr = ffi.cast('lua_vtab_cursor *', sqlite_api.malloc(cur_sz)) - if( cur_sqlptr==nil ) then - return SQLITE.NOMEM + if cur_sqlptr == nil then + return SQLITE.NOMEM end cur_sqlptr.index = cursor_index @@ -999,86 +1761,234 @@ local lua_vtable_module2 = sqlite3_module_t{ ppCursor[0] = ffi.cast('sqlite3_vtab_cursor *', cur_sqlptr) - return SQLITE.OK; - end, + return SQLITE.OK + end - xClose = function(cur) + local xClose_lua = function(cur) local cur_sqlptr = ffi.cast('lua_vtab_cursor *', cur) vfunc_cur.remove(tonumber(cur_sqlptr.index)) - sqlite_api.free(cur); - return SQLITE.OK; - end, + sqlite_api.free(cur) + return SQLITE.OK + end - xFilter = function (pVtabCursor, idxNum, idxStrUnused, argc, argv) - local code_field = ffi.string(sqlite_api.value_text(argv[idxNum])) + local xFilter_lua = function(pVtabCursor, idxNum, idxStrUnused, argc, argv) local cur_sqlptr = ffi.cast('lua_vtab_cursor *', pVtabCursor) local cursor = vfunc_cur.get(cur_sqlptr.index) - cursor.fn = exec_lua(code_field) - cursor.value = {cursor.fn()} + if cursor == nil then + -- stale handle or double-close; treat as empty result + return SQLITE.OK + end + + if cursor.input then + -- make_vtable module: reset iterator + reset_make_vtable_cursor(cursor) + else + -- L/L10 module: execute code when provided; otherwise yield nothing + local missing = (idxNum ~= 1) or (argv == nil or argc < 1) + if missing then + cursor.fn = function() return nil end + else + local code_value = argv[0] + if code_value ~= nil and code_value ~= ffi.NULL then + local code_field = ffi.string(sqlite_api.value_text(code_value)) + local results + local ok, err = xpcall(function() + results = {exec_lua(code_field)} + end, error_xcall) + if not ok then + return set_vtab_error(cur_sqlptr, err) + end + + if type(results[1]) == "table" then + results = list_iterator(results[1]) + elseif type(results[1]) == "function" then + results = results[1] + else + results = list_iterator(results) + end + + cursor.fn = results + else + cursor.fn = function() return nil end + end + end + local ok, err = xpcall(function() + cursor.value = {cursor.fn()} + end, error_xcall) + if not ok then + return set_vtab_error(cur_sqlptr, err) + end + end cursor.rowid = 1 - return SQLITE.OK; - end, + return SQLITE.OK + end - xNext = function(cur) + local xNext_lua = function(cur) local cur_sqlptr = ffi.cast('lua_vtab_cursor *', cur) local cursor = vfunc_cur.get(cur_sqlptr.index) - cursor.value = {cursor.fn()} - + if cursor == nil then + cur_sqlptr.base.pVtab.zErrMsg = sqlite_api.mprintf("lua_vtab: cursor missing") + return SQLITE.MISUSE + end + + if cursor.input then + -- make_vtable module: advance iterator + advance_make_vtable_cursor(cursor) + else + -- L/L10 module: call iterator function + local ok, err = xpcall(function() + cursor.value = {cursor.fn()} + end, error_xcall) + if not ok then + return set_vtab_error(cur_sqlptr, err) + end + end + cursor.rowid = cursor.rowid + 1 - return SQLITE.OK; - end, + return SQLITE.OK + end - xEof = function(cur) + local xEof_lua = function(cur) local cur_sqlptr = ffi.cast('lua_vtab_cursor *', cur) local cursor = vfunc_cur.get(cur_sqlptr.index) - return #cursor.value == 0 - end, + if cursor == nil then + return 1 -- treat as exhausted instead of crashing + end - xColumn = function(cur, ctx, col) + if cursor.input then + -- make_vtable module: check if iterator exhausted + return cursor.current_row == nil and 1 or 0 + else + -- L/L10 module: check if value is empty + return #cursor.value == 0 and 1 or 0 + end + end + + local xColumn_lua = function(cur, ctx, col) local cur_sqlptr = ffi.cast('lua_vtab_cursor *', cur) local cursor = vfunc_cur.get(cur_sqlptr.index) - api_function_return_any(ctx, cursor:getter(col + 1)) + if cursor == nil then + sqlite_api.result_error(ctx, "lua_vtab: cursor missing", -1) + return SQLITE.MISUSE + end + + if cursor.input then + -- make_vtable module: get column from current row + local row = cursor.current_row + local data + if type(row) == "table" then + data = row[col + 1] + else + data = nil + end + api_function_return_any(ctx, data) + else + -- L/L10 module: use getter + local ok, result = xpcall(function() + return cursor:getter(col + 1) + end, error_xcall) + if not ok then + return set_vtab_error(cur_sqlptr, result) + end + + api_function_return_any(ctx, result) + end - -- only when pIdxInfo.aConstraintUsage[...].omit = 0 - -- api_function_return_any(ctx, cursor.filter) return SQLITE.OK - end, + end - xRowid = function(cur, pRowid) + local xRowid_lua = function(cur, pRowid) local cur_sqlptr = ffi.cast('lua_vtab_cursor *', cur) local cursor = vfunc_cur.get(cur_sqlptr.index) + if cursor == nil then + cur_sqlptr.base.pVtab.zErrMsg = sqlite_api.mprintf("lua_vtab: cursor missing") + return SQLITE.MISUSE + end pRowid[0] = cursor.rowid - return SQLITE.OK; + return SQLITE.OK end -} -local create_vtable_functions = function() + -- Keep FFI callback cdata strongly referenced for VM lifetime. + local vtab_cbs = { + xOpen = ffi.cast('int (*)(sqlite3_vtab*, sqlite3_vtab_cursor**)', xOpen_lua), + xClose = ffi.cast('int (*)(sqlite3_vtab_cursor*)', xClose_lua), + xFilter = ffi.cast('int (*)(sqlite3_vtab_cursor*, int, const char*, int, sqlite3_value**)', xFilter_lua), + xNext = ffi.cast('int (*)(sqlite3_vtab_cursor*)', xNext_lua), + xEof = ffi.cast('int (*)(sqlite3_vtab_cursor*)', xEof_lua), + xColumn = ffi.cast('int (*)(sqlite3_vtab_cursor*, sqlite3_context*, int)', xColumn_lua), + xRowid = ffi.cast('int (*)(sqlite3_vtab_cursor*, sqlite_int64*)', xRowid_lua), + } + plugin._vtable_lua_callback_anchors = vtab_cbs + + -- Store Lua callbacks in plugin_init_data (C code will call these from vtable_vm) + plugin_init_data.vtab_xOpen_lua = vtab_cbs.xOpen + plugin_init_data.vtab_xClose_lua = vtab_cbs.xClose + plugin_init_data.vtab_xFilter_lua = vtab_cbs.xFilter + plugin_init_data.vtab_xNext_lua = vtab_cbs.xNext + plugin_init_data.vtab_xEof_lua = vtab_cbs.xEof + plugin_init_data.vtab_xColumn_lua = vtab_cbs.xColumn + plugin_init_data.vtab_xRowid_lua = vtab_cbs.xRowid +end + +-- Setup vtable module callbacks to use C wrappers (needed at depth 1 before create_module_v2) +local setup_vtable_module_callbacks = function() + -- Set both modules to use C wrappers (which route through vtable_vm) + lua_vtable_module.xOpen = plugin_init_data.cb_vtab_xOpen + lua_vtable_module.xClose = plugin_init_data.cb_vtab_xClose + lua_vtable_module.xFilter = plugin_init_data.cb_vtab_xFilter + lua_vtable_module.xNext = plugin_init_data.cb_vtab_xNext + lua_vtable_module.xEof = plugin_init_data.cb_vtab_xEof + lua_vtable_module.xColumn = plugin_init_data.cb_vtab_xColumn + lua_vtable_module.xRowid = plugin_init_data.cb_vtab_xRowid + + lua_vtable_module2.xOpen = plugin_init_data.cb_vtab_xOpen + lua_vtable_module2.xClose = plugin_init_data.cb_vtab_xClose + lua_vtable_module2.xFilter = plugin_init_data.cb_vtab_xFilter + lua_vtable_module2.xNext = plugin_init_data.cb_vtab_xNext + lua_vtable_module2.xEof = plugin_init_data.cb_vtab_xEof + lua_vtable_module2.xColumn = plugin_init_data.cb_vtab_xColumn + lua_vtable_module2.xRowid = plugin_init_data.cb_vtab_xRowid +end - local key = vfunc_data.new_key({ - module = lua_vtable_module2, -- anchor to keep it alive +-- Create L and L10 vtable modules (only needed once at depth 1) +local create_vtable_modules = function() + local key = shared_vtab_data.new_key({ xtable_text = "CREATE TABLE x(value, code hidden)", - code_column_idx = 1 + code_column_idx = 1, + module_type = "L" }) - local rc = sqlite_api.create_module_v2(sqlite_db, "L", lua_vtable_module2, ffi.cast('void*', key), nil); + local rc = sqlite_api.create_module_v2( + sqlite_db, + "L", + lua_vtable_module2, + ffi.cast('void*', key), + nil + ); if rc ~= SQLITE.OK then return error("setup L module failed") end - key = vfunc_data.new_key({ - module = lua_vtable_module2, -- anchor to keep it alive + key = shared_vtab_data.new_key({ xtable_text = "CREATE TABLE x(r0,r1,r2,r3,r4,r5,r6,r7,r8,r9, code hidden)", - code_column_idx = 10 + code_column_idx = 10, + module_type = "L10" }) - local rc = sqlite_api.create_module_v2(sqlite_db, 'L10', lua_vtable_module2, ffi.cast('void*', key), nil); + local rc = sqlite_api.create_module_v2( + sqlite_db, + "L10", + lua_vtable_module2, + ffi.cast('void*', key), + nil + ); if rc ~= SQLITE.OK then return error("setup L10 module failed") end @@ -1094,6 +2004,19 @@ local copy_to_plugin = function (map) end end +local __depth__ +plugin.extension_call = function () +end + + +local text_caller_fn = function(context, argc, argv) + --call function stored as text, no depth limit + local saved_ref = tonumber(ffi.cast('int64_t', sqlite_api.user_data(context))) + local fn = function_refs_text.get(saved_ref) + local result = fn(api_get_args(context, argc, argv)) + + api_function_return_any(context, result) +end plugin.extension_init = function ( ctx ) copy_to_global(public_env) @@ -1102,15 +2025,44 @@ plugin.extension_init = function ( ctx ) plugin.extension_init = nil plugin.extension_deinit = nil + plugin.extension_call = nil - local plugin_init_data = ffi.cast('LJFunctionData *', ctx) + plugin_init_data = ffi.cast('LJFunctionData *', ctx) sqlite_db = plugin_init_data.db sqlite_api = plugin_init_data.api + __depth__ = plugin_init_data.call_depth - create_vtable_functions() - create_function("L", exec_lua, -1) + plugin_init_data.caller_fn = wrap_csafe_cs(text_caller_fn) + if plugin_init_data.is_vtable_vm == 1 then + setup_vtable_callbacks() + return + end + + --package.loaded["sqlite_lj"] = plugin + if __depth__ > 1 then + return + end + + setup_vtable_module_callbacks() + create_vtable_modules() + + make_fn('L', [[ + return function (code_text, ...) + local name = "temp_fn" + local fn_env = {} + setmetatable(fn_env, { __index = _G }) + fn_env["arg"] = {...} + + local fn, err = loadstring(code_text, name, "t", fn_env) + if not fn then + local msg = "Create failed [".. name .. "]\n" .. tostring(err) + return error(msg) + end + return fn() + end + ]]); end plugin.extension_deinit = function () diff --git a/src/sqlite_shared.lua b/src/sqlite_shared.lua new file mode 100644 index 0000000..e91b34c --- /dev/null +++ b/src/sqlite_shared.lua @@ -0,0 +1,203 @@ +local ffi = require('ffi') +local C = ffi.C + +local plugin = {} +local cache_index = 1 + +local function read_limit(name, default) + local raw = os.getenv(name) + local n = raw and tonumber(raw) or nil + if n and n > 0 then + return math.floor(n) + end + return default +end + +local LIMITS = { + max_buffer_bytes = read_limit("SQLITE_LJ_MAX_BUFFER_BYTES", 64 * 1024 * 1024), + max_object_bytes = read_limit("SQLITE_LJ_MAX_OBJECT_BYTES", 4 * 1024 * 1024), + max_objects = read_limit("SQLITE_LJ_MAX_OBJECTS", 200000), + max_function_contexts = read_limit("SQLITE_LJ_MAX_FUNCTION_CONTEXTS", 50000), +} + +ffi.cdef[[ +typedef long long int int64_t; +typedef struct { + void* fn_ptr; + void* fn_step_ptr; + void* fn_final_ptr; + void* fn_destroy_ptr; + int64_t udata; + int allowed_nested; +} FunctionContext; +typedef int (*set_object_cb_t)(const char* data, size_t len); +typedef const char* (*get_object_cb_t)(int id, size_t* len); +typedef int (*set_fn_context_cb_t)(const FunctionContext* ctx); +typedef const FunctionContext* (*get_fn_context_cb_t)(int64_t id); +typedef struct { + set_object_cb_t set_object; + get_object_cb_t get_object; + set_fn_context_cb_t set_fn_context; + get_fn_context_cb_t get_fn_context; +} SharedBridge; +]] + +-- Extendable buffer manager +local BufferManager = {} +BufferManager.__index = BufferManager + +function BufferManager.new(size) + local self = setmetatable({}, BufferManager) + self.size = size or (1024*1024) + local ok, buf = pcall(ffi.new, "uint8_t[?]", self.size) + if not ok then + return nil, "shared buffer allocation failed" + end + self.buf = buf + self.offset = 0 + self.registry = {} + self.object_count = 0 + return self +end + +function BufferManager:ensure_capacity(additional) + local needed = self.offset + additional + if needed > LIMITS.max_buffer_bytes then + return nil, "shared buffer max bytes exceeded" + end + + if needed > self.size then + local new_size = math.max(self.size * 2, needed) + if new_size > LIMITS.max_buffer_bytes then + new_size = LIMITS.max_buffer_bytes + end + local ok, new_buf = pcall(ffi.new, "uint8_t[?]", new_size) + if not ok then + return nil, "shared buffer grow failed" + end + ffi.copy(new_buf, self.buf, self.offset) + self.buf = new_buf + self.size = new_size + end + return true +end + +function BufferManager:store(id, data) + local len = #data + if len > LIMITS.max_object_bytes then + return nil, "shared object exceeds max bytes" + end + if self.object_count >= LIMITS.max_objects then + return nil, "shared object count limit exceeded" + end + + local ok, err = self:ensure_capacity(len) + if not ok then + return nil, err + end + ffi.copy(self.buf + self.offset, data, len) + self.registry[id] = { start = self.offset, len = len } + self.offset = self.offset + len + self.object_count = self.object_count + 1 + return true +end + +function BufferManager:retrieve(id) + local e = self.registry[id] + if not e then return nil end + return ffi.string(self.buf + e.start, e.len) +end + +-- singleton manager +local manager, manager_err = BufferManager.new() +if not manager then + error(manager_err) +end + +-- Lua functions to expose +local function setObject(data, len) + if len > LIMITS.max_object_bytes then + return 0 + end + if cache_index > 0x7FFFFFFF or cache_index > LIMITS.max_objects then + return 0 + end + local id = cache_index + + local s = ffi.string(data, len) + local ok = manager:store(id, s) + if not ok then + return 0 + end + cache_index = cache_index + 1 + return id +end + +local function getObjectString(id, size_out) + if type(id) ~= 'number' or id <= 0 or id ~= math.floor(id) then + size_out[0] = 0 + return nil + end + local s = manager:retrieve(id) + if not s then + size_out[0] = 0 + return nil + end + size_out[0] = #s + return s +end + +local function_contexts = {} +local function_contexts_next_id = 1 + +local function setFunctionContext(ctx) + if function_contexts_next_id > LIMITS.max_function_contexts then + return 0 + end + local id = function_contexts_next_id + function_contexts_next_id = function_contexts_next_id + 1 + local fc = ffi.new("FunctionContext[1]") + ffi.copy(fc, ctx, ffi.sizeof("FunctionContext")) + function_contexts[id] = fc + return id +end + +local function getFunctionContext(id) + id = tonumber(id) + if not id then + return nil + end + local fc = function_contexts[id] + if not fc then + return nil + end + return fc +end + +-- -- Create C struct with function pointers +-- local bridge = ffi.new("SharedBridge") +-- bridge.set_object = ffi.cast("set_object_cb_t", setObject) +-- bridge.get_object = ffi.cast("get_object_cb_t", getObjectString) + +plugin.init = function(bridge_ptr) + local bridge = ffi.cast('SharedBridge*', bridge_ptr) + --print('Bridge set', bridge) + local cb_set_object = ffi.cast("set_object_cb_t", setObject) + local cb_get_object = ffi.cast("get_object_cb_t", getObjectString) + local cb_set_fn_context = ffi.cast("set_fn_context_cb_t", setFunctionContext) + local cb_get_fn_context = ffi.cast("get_fn_context_cb_t", getFunctionContext) + + bridge.set_object = cb_set_object + bridge.get_object = cb_get_object + bridge.set_fn_context = cb_set_fn_context + bridge.get_fn_context = cb_get_fn_context + + -- Keep callbacks alive + plugin._bridge_callbacks = { + set_object = cb_set_object, + get_object = cb_get_object, + set_fn_context = cb_set_fn_context, + get_fn_context = cb_get_fn_context + } +end +return plugin diff --git a/test.py b/test.py new file mode 100644 index 0000000..881829b --- /dev/null +++ b/test.py @@ -0,0 +1,195 @@ +#!/usr/bin/env python3 +import argparse +import os +import re +import subprocess +import sys +from pathlib import Path + + +LIMIT_ENV_BY_TEST = { + "016": { + "SQLITE_LJ_MAX_OBJECTS": "6", + "SQLITE_LJ_MAX_OBJECT_BYTES": "1048576", + "SQLITE_LJ_MAX_BUFFER_BYTES": "67108864", + "SQLITE_LJ_MAX_FUNCTION_CONTEXTS": "50000", + }, + "017": { + "SQLITE_LJ_MAX_FUNCTION_CONTEXTS": "5", + "SQLITE_LJ_MAX_OBJECTS": "200000", + "SQLITE_LJ_MAX_BUFFER_BYTES": "67108864", + "SQLITE_LJ_MAX_OBJECT_BYTES": "4194304", + }, + "018": { + "SQLITE_LJ_MAX_BUFFER_BYTES": "200000", + "SQLITE_LJ_MAX_OBJECT_BYTES": "120000", + "SQLITE_LJ_MAX_OBJECTS": "1000", + "SQLITE_LJ_MAX_FUNCTION_CONTEXTS": "50000", + }, +} + +SQLITE3_BIN = os.getenv("SQLITE3_BIN", "sqlite3") + + +GREEN = "\033[0;32m" +RED = "\033[0;31m" +NC = "\033[0m" + + +def mark(text: str, ok: bool) -> str: + prefix = "PASSED" if ok else "FAILED" + color = GREEN if ok else RED + return f"{color}{text}: {prefix}{NC}" + + +def normalize_text(text: str) -> str: + return text.replace("\r\n", "\n").replace("\r", "\n") + + +def shell_like_text(text: str) -> str: + # test.sh uses command substitution, which drops trailing newlines. + return normalize_text(text).rstrip("\n") + + +def extension_base_path() -> str: + return "./libsqlite_plugin_lj" + + +def strip_sql_load_statements(sql: str) -> str: + out_lines: list[str] = [] + for line in sql.splitlines(): + if re.match(r"^\s*\.load\b", line): + continue + if re.match(r"^\s*select\s+load_extension\s*\(", line, flags=re.IGNORECASE): + continue + out_lines.append(line) + return "\n".join(out_lines) + ("\n" if sql.endswith("\n") else "") + + +def execute_sqlite(sql: str, env: dict[str, str], extra_args: list[str] | None = None) -> tuple[int, str]: + args = [SQLITE3_BIN, ":memory:"] + if extra_args: + args.extend(extra_args) + proc = subprocess.run( + args, + input=sql, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=env, + ) + return proc.returncode, normalize_text(proc.stdout) + + +def resolve_fixture_path(root_dir: Path, flat_path: Path, filename: str) -> Path | None: + if flat_path.exists(): + return flat_path + + matches = sorted(root_dir.rglob(filename)) + if not matches: + return None + return matches[0] + + +def run_sqlite(input_file: Path, extra_env: dict[str, str] | None = None) -> tuple[int, str]: + sql = input_file.read_text(encoding="utf-8") + env = os.environ.copy() + if extra_env: + env.update(extra_env) + + code, output = execute_sqlite(sql, env) + if code == 0: + return code, output + + needs_load_fallback = ( + "no such function: load_extension" in output + or 'unknown command or invalid arguments: "load"' in output + ) + if not needs_load_fallback: + return code, output + + stripped_sql = strip_sql_load_statements(sql) + fallback_args = ["-cmd", f".load {extension_base_path()}"] + fb_code, fb_output = execute_sqlite(stripped_sql, env, fallback_args) + if fb_code == 0: + return fb_code, fb_output + return code, output + + +def run_suite(test_ids: str) -> int: + failed = 0 + root = Path.cwd() + + for test_id in test_ids.split(";"): + test_id = test_id.strip() + if not test_id: + continue + + print(f"Run test: {test_id}") + + input_file = resolve_fixture_path( + root / "sql", root / "sql" / f"input_{test_id}.sql", f"input_{test_id}.sql" + ) + if input_file is None: + print(f"Test [{test_id}]: FAILED") + print(f"Missing SQL fixture: input_{test_id}.sql") + failed += 1 + continue + + expected_file = resolve_fixture_path( + root / "expected", + root / "expected" / f"output_{test_id}.txt", + f"output_{test_id}.txt", + ) + if expected_file is None: + print(f"Test [{test_id}]: FAILED") + print(f"Missing expected fixture: output_{test_id}.txt") + failed += 1 + continue + + extra_env = LIMIT_ENV_BY_TEST.get(test_id, {}) + _, output_text = run_sqlite(input_file, extra_env) + expected_text = normalize_text(expected_file.read_text(encoding="utf-8")) + + output_cmp = shell_like_text(output_text) + expected_cmp = shell_like_text(expected_text) + + if output_cmp == expected_cmp: + print(mark(f"Test [{test_id}]", True)) + else: + print(mark(f"Test [{test_id}]", False)) + print(f"Expected: [{expected_cmp}]") + print(f"Actual: [{output_cmp}]") + failed += 1 + + if failed > 0: + print(f"{RED}{failed} test(s) failed.{NC}") + return 1 + return 0 + + +def run_sql_file(sql_file: Path) -> int: + code, out = run_sqlite(sql_file, None) + text = out.strip("\n") + if text: + print(text) + return code + + +def main() -> int: + parser = argparse.ArgumentParser() + parser.add_argument("--tests", help="Semicolon-separated test ids") + parser.add_argument("--sql-file", help="SQL file to run as a single test") + args = parser.parse_args() + + if bool(args.tests) == bool(args.sql_file): + print("Use exactly one of --tests or --sql-file.", file=sys.stderr) + return 2 + + if args.tests: + return run_suite(args.tests) + return run_sql_file(Path(args.sql_file)) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/test.sh b/test.sh deleted file mode 100755 index 62b2d07..0000000 --- a/test.sh +++ /dev/null @@ -1,43 +0,0 @@ -#!/bin/bash - -# Define SQLite database file -DATABASE_FILE=":memory:" -DATABASE_FILE="test.db" - -rm -f $DATABASE_FILE - -GREEN='\033[0;32m' -RED='\033[0;31m' -NC='\033[0m' # No Color - -TESTS_TO_COMPARE=("$1") - -FAILED_TESTS=0 - -# Function to execute SQLite CLI with queries and compare outputs -function run_tests() { - IFS=';' - for item in $TESTS_TO_COMPARE; do - echo "Run test: $item" - OUTPUT=$(sqlite3 $DATABASE_FILE < sql/input_"$item".sql) - EXPECTED=$(cat expected/output_"$item".txt) - if [ "$OUTPUT" == "$EXPECTED" ]; then - echo -e "${GREEN}Test [$item]: PASSED${NC}" - else - echo -e "${RED}Test [$item]: FAILED${NC}" - echo "Expected: [$EXPECTED]" - echo "Actual: [$OUTPUT]" - ((FAILED_TESTS++)) - fi - - done - - # Check if any test failed and return error if so - if [ $FAILED_TESTS -gt 0 ]; then - echo "$FAILED_TESTS test(s) failed." - exit 1 - fi -} - -# Run tests -run_tests diff --git a/tests/multiconnection_isolation.py b/tests/multiconnection_isolation.py new file mode 100644 index 0000000..e6a5783 --- /dev/null +++ b/tests/multiconnection_isolation.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python3 +import sqlite3 +import sys +from pathlib import Path + + +def resolve_extension_path(cwd: Path) -> Path: + candidates = [ + cwd / "libsqlite_plugin_lj.so", + cwd / "libsqlite_plugin_lj.dylib", + cwd / "libsqlite_plugin_lj.dll", + cwd / "sqlite_plugin_lj.dll", + ] + for path in candidates: + if path.exists(): + return path + raise FileNotFoundError( + "Could not find built extension in build dir. " + "Expected one of: " + ", ".join(str(p.name) for p in candidates) + ) + + +def scalar_int(conn: sqlite3.Connection, sql: str) -> int: + row = conn.execute(sql).fetchone() + if row is None or row[0] is None: + raise RuntimeError(f"Query returned no scalar value: {sql}") + return int(row[0]) + + +def require_l_function(conn: sqlite3.Connection, label: str, expected: int) -> None: + value = scalar_int(conn, f"select L('return {expected}')") + if value != expected: + raise AssertionError(f"{label}: expected {expected}, got {value}") + + +def main() -> int: + ext_path = resolve_extension_path(Path.cwd()).resolve() + + db1 = sqlite3.connect(":memory:") + db2 = sqlite3.connect(":memory:") + + try: + for db in (db1, db2): + db.enable_load_extension(True) + + db1.load_extension(str(ext_path)) + require_l_function(db1, "db1 initial", 1) + + db2.load_extension(str(ext_path)) + + # Both connections should keep independent usable plugin state. + require_l_function(db1, "db1 after db2 load", 2) + require_l_function(db2, "db2 after load", 3) + + # Verify DB isolation through plugin-executed SQL. + db1.execute( + "select L('sqlite.run_sql[[create table if not exists t_mc(v integer);" + "delete from t_mc; insert into t_mc(v) values(10);]]')" + ) + db2.execute( + "select L('sqlite.run_sql[[create table if not exists t_mc(v integer);" + "delete from t_mc; insert into t_mc(v) values(20);]]')" + ) + + c1 = scalar_int(db1, "select sum(v) from t_mc") + c2 = scalar_int(db2, "select sum(v) from t_mc") + if c1 != 10 or c2 != 20: + raise AssertionError(f"cross-db isolation failed: db1 sum={c1}, db2 sum={c2}") + + # Churn a second connection repeatedly while keeping db1 open. + for i in range(1, 34): + churn = sqlite3.connect(":memory:") + try: + churn.enable_load_extension(True) + churn.load_extension(str(ext_path)) + require_l_function(churn, f"churn db2 iter {i}", 1000 + i) + finally: + churn.close() + + require_l_function(db1, f"db1 after churn iter {i}", 2000 + i) + + print("PASS multiconnection isolation") + return 0 + except Exception as exc: # noqa: BLE001 + print(f"FAIL multiconnection isolation: {exc}") + return 1 + finally: + db1.close() + db2.close() + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/multithread_isolated_connections.cpp b/tests/multithread_isolated_connections.cpp new file mode 100644 index 0000000..93bb89d --- /dev/null +++ b/tests/multithread_isolated_connections.cpp @@ -0,0 +1,457 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(__APPLE__) +extern "C" int sqlite3_enable_load_extension(sqlite3* db, int onoff); +extern "C" int sqlite3_load_extension(sqlite3* db, const char* zFile, const char* zProc, char** pzErrMsg); +#endif + +namespace { + +constexpr int WORKER_COUNT = 6; +constexpr int MIN_LOOP_COUNT = 20; +constexpr auto DURATION = std::chrono::milliseconds(1500); + +struct SharedState { + bool has_error = false; + std::string message; + std::mutex lock; + + void set_error(std::string msg) { + std::scoped_lock guard(lock); + if (!has_error) { + has_error = true; + message = std::move(msg); + } + } +}; + +struct WorkerArgs { + int worker_id = 0; + std::string ext_path; + std::filesystem::path sql_root; + std::filesystem::path expected_root; + std::vector fixture_tests; + std::chrono::steady_clock::time_point deadline; + int min_loops = MIN_LOOP_COUNT; + SharedState* shared = nullptr; + int completed_loops = 0; + long long started_ms = 0; + long long finished_ms = 0; +}; + +long long monotonic_ms() { + using namespace std::chrono; + return duration_cast(steady_clock::now().time_since_epoch()).count(); +} + +std::string sqlite_error(sqlite3* db, const char* fallback = "sqlite error") { + if (!db) { + return fallback; + } + const char* msg = sqlite3_errmsg(db); + return msg ? msg : fallback; +} + +struct SqliteDb { + sqlite3* db = nullptr; + + ~SqliteDb() { + if (db) { + sqlite3_close(db); + db = nullptr; + } + } + + SqliteDb() = default; + SqliteDb(const SqliteDb&) = delete; + SqliteDb& operator=(const SqliteDb&) = delete; +}; + +int query_output_cb(void* udata, int argc, char** argv, char** /*colnames*/) { + auto* out = static_cast(udata); + for (int i = 0; i < argc; ++i) { + if (i > 0) { + out->push_back('|'); + } + if (argv[i]) { + out->append(argv[i]); + } + } + out->push_back('\n'); + return 0; +} + +bool read_text_file(const std::filesystem::path& path, std::string& out) { + std::ifstream in(path, std::ios::binary); + if (!in) { + return false; + } + std::ostringstream ss; + ss << in.rdbuf(); + out = ss.str(); + return true; +} + +std::string trim_right_newlines(std::string s) { + while (!s.empty() && (s.back() == '\n' || s.back() == '\r')) { + s.pop_back(); + } + return s; +} + +std::string normalize_for_compare(std::string s) { + std::replace(s.begin(), s.end(), '\r', '\n'); + + std::vector lines; + { + std::istringstream iss(s); + std::string line; + while (std::getline(iss, line)) { + if (!line.empty()) { + lines.push_back(line); + } + } + } + + std::string out; + for (size_t i = 0; i < lines.size(); ++i) { + if (i > 0) { + out.push_back('\n'); + } + out.append(lines[i]); + } + return trim_right_newlines(out); +} + +bool line_has_load_extension(std::string_view line) { + auto lower = std::string(line); + std::transform(lower.begin(), lower.end(), lower.begin(), [](unsigned char c) { + return static_cast(std::tolower(c)); + }); + return lower.find("load_extension(") != std::string::npos; +} + +std::string strip_load_extension_lines(const std::string& sql) { + std::istringstream iss(sql); + std::string line; + std::string out; + bool first = true; + + while (std::getline(iss, line)) { + if (line_has_load_extension(line)) { + continue; + } + if (!first) { + out.push_back('\n'); + } + first = false; + out.append(line); + } + + if (!sql.empty() && sql.back() == '\n') { + out.push_back('\n'); + } + return out; +} + +bool scalar_int(sqlite3* db, const std::string& sql, int& out_value) { + sqlite3_stmt* stmt = nullptr; + int rc = sqlite3_prepare_v2(db, sql.c_str(), -1, &stmt, nullptr); + if (rc != SQLITE_OK) { + return false; + } + + rc = sqlite3_step(stmt); + if (rc != SQLITE_ROW) { + sqlite3_finalize(stmt); + return false; + } + + out_value = sqlite3_column_int(stmt, 0); + rc = sqlite3_finalize(stmt); + return rc == SQLITE_OK; +} + +bool exec_sql(sqlite3* db, const std::string& sql, std::string& err) { + char* exec_err = nullptr; + const int rc = sqlite3_exec(db, sql.c_str(), nullptr, nullptr, &exec_err); + if (rc == SQLITE_OK) { + return true; + } + + err = exec_err ? exec_err : sqlite_error(db); + if (exec_err) { + sqlite3_free(exec_err); + } + return false; +} + +bool load_extension(sqlite3* db, const std::string& ext_path, std::string& err) { + if (sqlite3_enable_load_extension(db, 1) != SQLITE_OK) { + err = "enable_load_extension failed: " + sqlite_error(db); + return false; + } + + char* load_err = nullptr; + const int rc = sqlite3_load_extension(db, ext_path.c_str(), nullptr, &load_err); + if (rc == SQLITE_OK) { + return true; + } + err = load_err ? std::string(load_err) : sqlite_error(db); + if (load_err) { + sqlite3_free(load_err); + } + return false; +} + +bool run_fixture_test(const WorkerArgs& arg, const std::string& test_id, std::string& err) { + const auto sql_path = arg.sql_root / ("input_" + test_id + ".sql"); + const auto expected_path = arg.expected_root / ("output_" + test_id + ".txt"); + + std::string sql_text; + std::string expected_text; + std::string actual_text; + + if (!read_text_file(sql_path, sql_text)) { + err = "cannot read SQL fixture: " + sql_path.string(); + return false; + } + if (!read_text_file(expected_path, expected_text)) { + err = "cannot read expected fixture: " + expected_path.string(); + return false; + } + + SqliteDb holder; + if (sqlite3_open(":memory:", &holder.db) != SQLITE_OK) { + err = "sqlite3_open failed for fixture " + test_id; + return false; + } + + if (!load_extension(holder.db, arg.ext_path, err)) { + err = "load_extension failed for fixture " + test_id + ": " + err; + return false; + } + + const std::string filtered_sql = strip_load_extension_lines(sql_text); + char* exec_err = nullptr; + const int exec_rc = sqlite3_exec(holder.db, filtered_sql.c_str(), query_output_cb, &actual_text, &exec_err); + if (exec_rc != SQLITE_OK) { + err = "fixture " + test_id + " SQL failed: " + (exec_err ? std::string(exec_err) : sqlite_error(holder.db)); + if (exec_err) { + sqlite3_free(exec_err); + } + return false; + } + + const std::string expected_cmp = normalize_for_compare(expected_text); + const std::string actual_cmp = normalize_for_compare(actual_text); + if (expected_cmp != actual_cmp) { + err = "fixture " + test_id + " mismatch"; + return false; + } + + return true; +} + +bool run_fixture_suite(const WorkerArgs& arg, std::string& err) { + for (const auto& test_id : arg.fixture_tests) { + if (!run_fixture_test(arg, test_id, err)) { + return false; + } + } + return true; +} + +void print_report(const std::vector& args, long long total_ms) { + int total_loops = 0; + int min_loops = 0; + int max_loops = 0; + bool initialized = false; + + std::printf("Run summary: total_time_ms=%lld workers=%d\n", total_ms, static_cast(args.size())); + + for (const auto& arg : args) { + const int loops = arg.completed_loops; + long long elapsed = 0; + if (arg.started_ms > 0 && arg.finished_ms >= arg.started_ms) { + elapsed = arg.finished_ms - arg.started_ms; + } + + std::printf( + " worker=%d loops=%d elapsed_ms=%lld mode=%s\n", + arg.worker_id, + loops, + elapsed, + arg.fixture_tests.empty() ? "churn" : "churn+fixture" + ); + + total_loops += loops; + if (!initialized) { + min_loops = loops; + max_loops = loops; + initialized = true; + } else { + min_loops = std::min(min_loops, loops); + max_loops = std::max(max_loops, loops); + } + } + + if (!initialized) { + min_loops = 0; + max_loops = 0; + } + + const double avg = args.empty() ? 0.0 : static_cast(total_loops) / static_cast(args.size()); + std::printf( + " Aggregate: total_loops=%d min_loops=%d max_loops=%d avg_loops=%.2f\n", + total_loops, + min_loops, + max_loops, + avg + ); +} + +void worker_main(WorkerArgs& arg) { + arg.started_ms = monotonic_ms(); + + int i = 1; + while (true) { + if (i > arg.min_loops && std::chrono::steady_clock::now() >= arg.deadline) { + break; + } + + SqliteDb holder; + if (sqlite3_open(":memory:", &holder.db) != SQLITE_OK) { + arg.shared->set_error( + "worker " + std::to_string(arg.worker_id) + + " iter " + std::to_string(i) + + ": open failed: " + sqlite_error(holder.db, "unknown") + ); + break; + } + + std::string load_err; + if (!load_extension(holder.db, arg.ext_path, load_err)) { + arg.shared->set_error( + "worker " + std::to_string(arg.worker_id) + + " iter " + std::to_string(i) + + ": load_extension failed: " + load_err + ); + break; + } + + const int expected = arg.worker_id * 100000 + i; + const std::string scalar_sql = "select L('return " + std::to_string(expected) + "')"; + int got = 0; + if (!scalar_int(holder.db, scalar_sql, got) || got != expected) { + arg.shared->set_error( + "worker " + std::to_string(arg.worker_id) + + " iter " + std::to_string(i) + + ": L() mismatch expected=" + std::to_string(expected) + + " got=" + std::to_string(got) + ); + break; + } + + std::string sql_err; + const std::string setup_sql = + "select L('sqlite.run_sql[[create table if not exists t(v integer);" + "delete from t; insert into t(v) values(" + std::to_string(arg.worker_id) + ");]]')"; + if (!exec_sql(holder.db, setup_sql, sql_err)) { + arg.shared->set_error( + "worker " + std::to_string(arg.worker_id) + + " iter " + std::to_string(i) + + ": setup sql failed: " + sql_err + ); + break; + } + + if (!scalar_int(holder.db, "select sum(v) from t", got) || got != arg.worker_id) { + arg.shared->set_error( + "worker " + std::to_string(arg.worker_id) + + " iter " + std::to_string(i) + + ": sum mismatch expected=" + std::to_string(arg.worker_id) + + " got=" + std::to_string(got) + ); + break; + } + + if (!arg.fixture_tests.empty() && i == 1) { + std::string fixture_err; + if (!run_fixture_suite(arg, fixture_err)) { + arg.shared->set_error( + "worker " + std::to_string(arg.worker_id) + + " iter " + std::to_string(i) + + ": fixture suite failed: " + fixture_err + ); + break; + } + } + + arg.completed_loops = i; + ++i; + } + + arg.finished_ms = monotonic_ms(); +} + +} // namespace + +int main(int argc, char** argv) { + if (argc < 4) { + std::fprintf(stderr, "Usage: %s \n", argv[0]); + return 2; + } + + SharedState shared; + const auto suite_start = monotonic_ms(); + const auto deadline = std::chrono::steady_clock::now() + DURATION; + + std::vector args(WORKER_COUNT); + for (int i = 0; i < WORKER_COUNT; ++i) { + args[i].worker_id = i + 1; + args[i].ext_path = argv[1]; + args[i].sql_root = argv[2]; + args[i].expected_root = argv[3]; + args[i].deadline = deadline; + args[i].min_loops = MIN_LOOP_COUNT; + args[i].shared = &shared; + } + + args[0].fixture_tests = {"035"}; + args[1].fixture_tests = {"037"}; + + std::vector threads; + threads.reserve(args.size()); + for (auto& arg : args) { + threads.emplace_back([&arg]() { worker_main(arg); }); + } + + for (auto& t : threads) { + t.join(); + } + + const auto suite_end = monotonic_ms(); + print_report(args, suite_end - suite_start); + + if (shared.has_error) { + std::printf("FAIL multithread isolated connections: %s\n", shared.message.c_str()); + return 1; + } + + std::printf("PASS multithread isolated connections\n"); + return 0; +} diff --git a/tests/multithread_vtable_queries.cpp b/tests/multithread_vtable_queries.cpp new file mode 100644 index 0000000..0d076c6 --- /dev/null +++ b/tests/multithread_vtable_queries.cpp @@ -0,0 +1,278 @@ +#include + +#include +#include +#include +#include +#include +#include + +#if defined(__APPLE__) +extern "C" int sqlite3_enable_load_extension(sqlite3* db, int onoff); +extern "C" int sqlite3_load_extension(sqlite3* db, const char* zFile, const char* zProc, char** pzErrMsg); +#endif + +namespace { + +constexpr int WORKER_COUNT = 6; +constexpr int MIN_LOOP_COUNT = 30; +constexpr auto DURATION = std::chrono::milliseconds(1500); + +struct SharedState { + bool has_error = false; + std::string message; + std::mutex lock; + + void set_error(std::string msg) { + std::scoped_lock guard(lock); + if (!has_error) { + has_error = true; + message = std::move(msg); + } + } +}; + +struct WorkerArgs { + int worker_id = 0; + std::string ext_path; + std::chrono::steady_clock::time_point deadline; + int min_loops = MIN_LOOP_COUNT; + SharedState* shared = nullptr; + int completed_loops = 0; + long long started_ms = 0; + long long finished_ms = 0; +}; + +long long monotonic_ms() { + using namespace std::chrono; + return duration_cast(steady_clock::now().time_since_epoch()).count(); +} + +std::string sqlite_error(sqlite3* db, const char* fallback = "sqlite error") { + if (!db) { + return fallback; + } + const char* msg = sqlite3_errmsg(db); + return msg ? msg : fallback; +} + +struct SqliteDb { + sqlite3* db = nullptr; + + ~SqliteDb() { + if (db) { + sqlite3_close(db); + db = nullptr; + } + } + + SqliteDb() = default; + SqliteDb(const SqliteDb&) = delete; + SqliteDb& operator=(const SqliteDb&) = delete; +}; + +bool load_extension(sqlite3* db, const std::string& ext_path, std::string& err) { + if (sqlite3_enable_load_extension(db, 1) != SQLITE_OK) { + err = "enable_load_extension failed: " + sqlite_error(db); + return false; + } + char* load_err = nullptr; + const int rc = sqlite3_load_extension(db, ext_path.c_str(), nullptr, &load_err); + if (rc == SQLITE_OK) { + return true; + } + err = load_err ? std::string(load_err) : sqlite_error(db); + if (load_err) { + sqlite3_free(load_err); + } + return false; +} + +bool scalar_int(sqlite3* db, const std::string& sql, int& out_value, std::string& err) { + sqlite3_stmt* stmt = nullptr; + int rc = sqlite3_prepare_v2(db, sql.c_str(), -1, &stmt, nullptr); + if (rc != SQLITE_OK) { + err = "prepare failed: " + sqlite_error(db); + return false; + } + rc = sqlite3_step(stmt); + if (rc != SQLITE_ROW) { + sqlite3_finalize(stmt); + err = "step failed: " + sqlite_error(db); + return false; + } + out_value = sqlite3_column_int(stmt, 0); + rc = sqlite3_finalize(stmt); + if (rc != SQLITE_OK) { + err = "finalize failed: " + sqlite_error(db); + return false; + } + return true; +} + +bool exec_sql(sqlite3* db, const std::string& sql, std::string& err) { + char* exec_err = nullptr; + const int rc = sqlite3_exec(db, sql.c_str(), nullptr, nullptr, &exec_err); + if (rc == SQLITE_OK) { + return true; + } + err = exec_err ? std::string(exec_err) : sqlite_error(db); + if (exec_err) { + sqlite3_free(exec_err); + } + return false; +} + +void print_report(const std::vector& args, long long total_ms) { + int total_loops = 0; + int min_loops = 0; + int max_loops = 0; + bool initialized = false; + + std::printf("Run summary: total_time_ms=%lld workers=%d\n", total_ms, static_cast(args.size())); + for (const auto& arg : args) { + int loops = arg.completed_loops; + long long elapsed = 0; + if (arg.started_ms > 0 && arg.finished_ms >= arg.started_ms) { + elapsed = arg.finished_ms - arg.started_ms; + } + std::printf(" worker=%d loops=%d elapsed_ms=%lld\n", arg.worker_id, loops, elapsed); + total_loops += loops; + if (!initialized) { + min_loops = loops; + max_loops = loops; + initialized = true; + } else { + if (loops < min_loops) min_loops = loops; + if (loops > max_loops) max_loops = loops; + } + } + if (!initialized) { + min_loops = 0; + max_loops = 0; + } + const double avg = args.empty() ? 0.0 : static_cast(total_loops) / static_cast(args.size()); + std::printf( + " Aggregate: total_loops=%d min_loops=%d max_loops=%d avg_loops=%.2f\n", + total_loops, + min_loops, + max_loops, + avg + ); +} + +void worker_main(WorkerArgs& arg) { + arg.started_ms = monotonic_ms(); + + SqliteDb holder; + if (sqlite3_open(":memory:", &holder.db) != SQLITE_OK) { + arg.shared->set_error("worker " + std::to_string(arg.worker_id) + ": open failed: " + sqlite_error(holder.db, "unknown")); + arg.finished_ms = monotonic_ms(); + return; + } + + std::string load_err; + if (!load_extension(holder.db, arg.ext_path, load_err)) { + arg.shared->set_error("worker " + std::to_string(arg.worker_id) + ": load_extension failed: " + load_err); + arg.finished_ms = monotonic_ms(); + return; + } + + std::string sql_err; + if (!exec_sql( + holder.db, + "SELECT * FROM L('" + "_G.list_iterator = function(t) " + " local i = 0 " + " local n = #t " + " return function () " + " i = i + 1 " + " if i <= n then return t[i] end " + " end " + "end " + "return function() return nil end" + "')", + sql_err + )) { + arg.shared->set_error("worker " + std::to_string(arg.worker_id) + ": list_iterator init failed: " + sql_err); + arg.finished_ms = monotonic_ms(); + return; + } + + int i = 1; + while (true) { + if (i > arg.min_loops && std::chrono::steady_clock::now() >= arg.deadline) { + break; + } + + int got = 0; + if (!scalar_int(holder.db, "SELECT count(*) FROM L10('return list_iterator({{1,10},{2,20},{3,30}})')", got, sql_err) || got != 3) { + arg.shared->set_error("worker " + std::to_string(arg.worker_id) + " iter " + std::to_string(i) + ": q1 mismatch: " + sql_err); + break; + } + + if (!scalar_int( + holder.db, + "SELECT count(*) FROM (" + "SELECT t.id, v.r1 " + "FROM (SELECT 1 AS id UNION ALL SELECT 2 UNION ALL SELECT 3) t " + "JOIN L10('return list_iterator({{1,100},{2,200},{3,300}})') v ON t.id = v.r0" + ")", + got, + sql_err + ) || got != 3) { + arg.shared->set_error("worker " + std::to_string(arg.worker_id) + " iter " + std::to_string(i) + ": q2 mismatch: " + sql_err); + break; + } + + if (!scalar_int(holder.db, "SELECT count(*) FROM L('return list_iterator({1,2,3,4})') WHERE value >= 2", got, sql_err) || got != 3) { + arg.shared->set_error("worker " + std::to_string(arg.worker_id) + " iter " + std::to_string(i) + ": q3 mismatch: " + sql_err); + break; + } + + arg.completed_loops = i; + ++i; + } + + arg.finished_ms = monotonic_ms(); +} + +} // namespace + +int main(int argc, char** argv) { + if (argc < 2) { + std::fprintf(stderr, "Usage: %s \n", argv[0]); + return 2; + } + + SharedState shared; + const auto suite_start = monotonic_ms(); + const auto deadline = std::chrono::steady_clock::now() + DURATION; + + std::vector args(WORKER_COUNT); + for (int i = 0; i < WORKER_COUNT; ++i) { + args[i].worker_id = i + 1; + args[i].ext_path = argv[1]; + args[i].deadline = deadline; + args[i].shared = &shared; + } + + std::vector threads; + threads.reserve(args.size()); + for (auto& arg : args) { + threads.emplace_back([&arg]() { worker_main(arg); }); + } + for (auto& t : threads) { + t.join(); + } + + const auto suite_end = monotonic_ms(); + print_report(args, suite_end - suite_start); + + if (shared.has_error) { + std::printf("FAIL multithread vtable queries: %s\n", shared.message.c_str()); + return 1; + } + std::printf("PASS multithread vtable queries\n"); + return 0; +} diff --git a/tests/thread_guard_same_connection.cpp b/tests/thread_guard_same_connection.cpp new file mode 100644 index 0000000..a0974be --- /dev/null +++ b/tests/thread_guard_same_connection.cpp @@ -0,0 +1,116 @@ +#include + +#include +#include +#include + +#if defined(__APPLE__) +extern "C" int sqlite3_enable_load_extension(sqlite3* db, int onoff); +extern "C" int sqlite3_load_extension(sqlite3* db, const char* zFile, const char* zProc, char** pzErrMsg); +#endif + +namespace { + +constexpr const char* kExpectedErr = "sqlite_plugin_lj: connection used from multiple threads"; + +std::string sqlite_error(sqlite3* db, const char* fallback = "sqlite error") { + if (!db) { + return fallback; + } + const char* msg = sqlite3_errmsg(db); + return msg ? msg : fallback; +} + +bool load_extension(sqlite3* db, const std::string& ext_path, std::string& err) { + if (sqlite3_enable_load_extension(db, 1) != SQLITE_OK) { + err = "enable_load_extension failed: " + sqlite_error(db); + return false; + } + char* load_err = nullptr; + const int rc = sqlite3_load_extension(db, ext_path.c_str(), nullptr, &load_err); + if (rc == SQLITE_OK) { + return true; + } + err = load_err ? std::string(load_err) : sqlite_error(db); + if (load_err) { + sqlite3_free(load_err); + } + return false; +} + +bool scalar_int(sqlite3* db, const char* sql, int* out, std::string& err) { + sqlite3_stmt* stmt = nullptr; + int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, nullptr); + if (rc != SQLITE_OK) { + err = sqlite_error(db); + return false; + } + rc = sqlite3_step(stmt); + if (rc != SQLITE_ROW) { + err = sqlite_error(db); + sqlite3_finalize(stmt); + return false; + } + *out = sqlite3_column_int(stmt, 0); + rc = sqlite3_finalize(stmt); + if (rc != SQLITE_OK) { + err = sqlite_error(db); + return false; + } + return true; +} + +} // namespace + +int main(int argc, char** argv) { + if (argc < 2) { + std::fprintf(stderr, "Usage: %s \n", argv[0]); + return 2; + } + + sqlite3* db = nullptr; + if (sqlite3_open(":memory:", &db) != SQLITE_OK) { + std::printf("FAIL thread guard same connection: open failed\n"); + if (db) sqlite3_close(db); + return 1; + } + + std::string load_err; + if (!load_extension(db, argv[1], load_err)) { + std::printf("FAIL thread guard same connection: load failed: %s\n", load_err.c_str()); + sqlite3_close(db); + return 1; + } + + int v = 0; + std::string err; + if (!scalar_int(db, "select L('return 1')", &v, err) || v != 1) { + std::printf("FAIL thread guard same connection: owner setup failed: %s\n", err.c_str()); + sqlite3_close(db); + return 1; + } + + bool thread_ok = false; + std::string thread_err; + std::thread t([&]() { + int out = 0; + std::string e; + bool ok = scalar_int(db, "select L('return 2')", &out, e); + if (!ok && e.find(kExpectedErr) != std::string::npos) { + thread_ok = true; + return; + } + thread_err = ok ? "unexpected success" : e; + }); + t.join(); + + if (!thread_ok) { + std::printf("FAIL thread guard same connection: expected cross-thread error, got: %s\n", thread_err.c_str()); + sqlite3_close(db); + return 1; + } + + std::printf("PASS thread guard same connection\n"); + sqlite3_close(db); + return 0; +}