diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml new file mode 100644 index 00000000..40b173c5 --- /dev/null +++ b/.github/workflows/check.yml @@ -0,0 +1,169 @@ +name: Check + +on: + pull_request: + paths: + - ".cargo" + - ".github/workflows/check.yml" + - "crates/**" + - "src/**" + - "tests/**" + - "build.rs" + - "Cargo.lock" + - "Cargo.toml" + - "rust-toolchain.toml" + - "rustfmt.toml" + - "taplo.toml" + - "vchord.control" + push: + paths: + - ".cargo" + - ".github/workflows/check.yml" + - "crates/**" + - "src/**" + - "tests/**" + - "build.rs" + - "Cargo.lock" + - "Cargo.toml" + - "rust-toolchain.toml" + - "rustfmt.toml" + - "taplo.toml" + - "vchord.control" + merge_group: + workflow_dispatch: + +concurrency: + group: ${{ github.ref }}-${{ github.workflow }} + cancel-in-progress: true + +env: + CARGO_TERM_COLOR: always + RUST_BACKTRACE: 1 + SCCACHE_GHA_ENABLED: true + RUSTC_WRAPPER: sccache + RUSTFLAGS: "-Dwarnings" + +jobs: + style: + runs-on: "ubuntu-24.04" + + steps: + - name: Set up Environment + run: | + curl -fsSL https://github.com/tamasfe/taplo/releases/latest/download/taplo-full-linux-$(uname -m).gz | gzip -d - | install -m 755 /dev/stdin /usr/local/bin/taplo + + - name: Checkout + uses: actions/checkout@v4 + + - name: Typos + uses: crate-ci/typos@master + + - name: Taplo + run: taplo fmt --check + + - name: Ruff + uses: astral-sh/ruff-action@v1 + + - name: Rustfmt + run: cargo fmt --check + + lint: + strategy: + matrix: + runner: ["ubuntu-24.04", "ubicloud-standard-8-arm-ubuntu-2404"] + runs-on: ${{ matrix.runner }} + + steps: + - name: Set up Environment + run: | + sudo apt-get update + + if [ "$(uname -m)" == "x86_64" ]; then + wget https://downloadmirror.intel.com/843185/sde-external-9.48.0-2024-11-25-lin.tar.xz -O /tmp/sde-external.tar.xz + sudo tar -xf /tmp/sde-external.tar.xz -C /opt + sudo mv /opt/sde-external-9.48.0-2024-11-25-lin /opt/sde + fi + + if [ "$(uname -m)" == "aarch64" ]; then + sudo apt-get install -y qemu-user-static + fi + + - name: Set up Sccache + uses: mozilla-actions/sccache-action@v0.0.7 + + - name: Checkout + uses: actions/checkout@v4 + + - name: Clippy + run: cargo clippy --workspace --exclude vchord + + - name: Cargo Test + run: cargo test --workspace --exclude vchord --exclude simd --no-fail-fast + + - name: Cargo Test (simd) + run: | + cargo test -p simd --release -- --nocapture + + if [ "$(uname -m)" == "x86_64" ]; then + cargo \ + --config 'target.'\''cfg(all())'\''.runner = ["/opt/sde/sde64", "-spr", "--"]' \ + test -p simd --release -- --nocapture + fi + if [ "$(uname -m)" == "aarch64" ]; then + cargo \ + --config 'target.'\''cfg(all())'\''.runner = ["qemu-aarch64-static", "-cpu", "max"]' \ + test -p simd --release -- --nocapture + fi + + psql: + strategy: + matrix: + version: ["13", "14", "15", "16", "17"] + runner: ["ubuntu-24.04", "ubicloud-standard-8-arm-ubuntu-2404"] + runs-on: ${{ matrix.runner }} + + steps: + - name: Set up Environment + run: | + sudo apt-get update + + sudo apt-get remove -y '^postgres.*' '^libpq.*' + sudo apt-get purge -y '^postgres.*' '^libpq.*' + + sudo update-alternatives --install /usr/bin/clang clang $(which clang-18) 255 + + sudo apt-get install -y postgresql-common + sudo /usr/share/postgresql-common/pgdg/apt.postgresql.org.sh -y + sudo apt-get install -y postgresql-server-dev-${{ matrix.version }} + + sudo apt-get install -y postgresql-${{ matrix.version }} postgresql-${{ matrix.version }}-pgvector + echo "local all all trust" | sudo tee /etc/postgresql/${{ matrix.version 
}}/main/pg_hba.conf + echo "host all all 127.0.0.1/32 trust" | sudo tee -a /etc/postgresql/${{ matrix.version }}/main/pg_hba.conf + echo "host all all ::1/128 trust" | sudo tee -a /etc/postgresql/${{ matrix.version }}/main/pg_hba.conf + sudo -iu postgres createuser -s -r $USER + sudo -iu postgres createdb -O $USER $USER + sudo -iu postgres psql -c 'ALTER SYSTEM SET shared_preload_libraries = "vchord.so"' + sudo systemctl stop postgresql + + curl -fsSL https://github.com/tensorchord/pgrx/releases/download/v0.12.9/cargo-pgrx-v0.12.9-$(uname -m)-unknown-linux-musl.tar.gz | tar -xOzf - ./cargo-pgrx | install -m 755 /dev/stdin /usr/local/bin/cargo-pgrx + cargo pgrx init --pg${{ matrix.version }}=$(which pg_config) + + curl -fsSL https://github.com/risinglightdb/sqllogictest-rs/releases/download/v0.26.4/sqllogictest-bin-v0.26.4-$(uname -m)-unknown-linux-musl.tar.gz | tar -xOzf - ./sqllogictest | install -m 755 /dev/stdin /usr/local/bin/sqllogictest + + - name: Set up Sccache + uses: mozilla-actions/sccache-action@v0.0.7 + + - name: Checkout + uses: actions/checkout@v4 + + - name: Clippy + run: cargo clippy -p vchord --features pg${{ matrix.version }} -- --no-deps + + - name: Install + run: cargo pgrx install -p vchord --features pg${{ matrix.version }} --release --sudo + + - name: Sqllogictest + run: | + sudo systemctl start postgresql + psql -c 'CREATE EXTENSION IF NOT EXISTS vchord CASCADE;' + sqllogictest --db $USER --user $USER './tests/**/*.slt' diff --git a/.github/workflows/pgrx.yml b/.github/workflows/pgrx.yml deleted file mode 100644 index 9bd8489b..00000000 --- a/.github/workflows/pgrx.yml +++ /dev/null @@ -1,53 +0,0 @@ -name: Build pgrx Image - -on: - workflow_dispatch: - inputs: - version: - description: 'pgrx version' - required: true - type: string - toolchain: - description: 'additional rust toolchain' - required: true - type: string - -concurrency: - group: ${{ github.ref }}-${{ github.workflow }} - cancel-in-progress: true - -env: - IMAGE_NAME: "ghcr.io/tensorchord/vectorchord-pgrx" - -permissions: - contents: write - packages: write - -jobs: - build: - runs-on: ubuntu-22.04 - steps: - - uses: actions/checkout@v4 - - name: Set up QEMU - uses: docker/setup-qemu-action@v3 - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v3 - - name: Login to ghcr.io - uses: docker/login-action@v3 - with: - registry: ghcr.io - username: ${{ github.actor }} - password: ${{ secrets.GITHUB_TOKEN }} - - name: Build and push - uses: docker/build-push-action@v6 - with: - context: . 
- file: ./docker/pgrx.Dockerfile - push: true - cache-from: type=gha - cache-to: type=gha,mode=max - platforms: "linux/amd64,linux/arm64" - build-args: | - PGRX_VERSION=${{ github.event.inputs.version }} - RUST_TOOLCHAIN=${{ github.event.inputs.toolchain }} - tags: ${{ env.IMAGE_NAME }}:${{ github.event.inputs.version }}-${{ github.event.inputs.toolchain }} diff --git a/.github/workflows/psql.yml b/.github/workflows/psql.yml deleted file mode 100644 index 42c108c0..00000000 --- a/.github/workflows/psql.yml +++ /dev/null @@ -1,101 +0,0 @@ -name: PostgresSQL - -on: - pull_request: - paths: - - '.github/workflows/psql.yml' - - 'src/**' - - 'Cargo.lock' - - 'Cargo.toml' - - '*.control' - - 'rust-toolchain.toml' - - 'tests/**' - - 'tools/**' - push: - branches: - - main - paths: - - '.github/workflows/psql.yml' - - 'src/**' - - 'Cargo.lock' - - 'Cargo.toml' - - '*.control' - - 'rust-toolchain.toml' - - 'tests/**' - - 'tools/**' - merge_group: - workflow_dispatch: - -concurrency: - group: ${{ github.ref }}-${{ github.workflow }} - cancel-in-progress: true - -env: - CARGO_TERM_COLOR: always - RUST_BACKTRACE: 1 - RUSTFLAGS: "-Dwarnings" - -jobs: - test: - runs-on: ${{ matrix.runner }} - strategy: - matrix: - version: ["14", "15", "16", "17"] - runner: ["ubuntu-22.04", "ubuntu-22.04-arm"] - env: - PGRX_IMAGE: "ghcr.io/tensorchord/vectorchord-pgrx:0.12.9-nightly-2024-12-25" - SQLLOGICTEST: "0.25.0" - ARCH: ${{ matrix.runner == 'ubuntu-22.04' && 'x86_64' || 'aarch64' }} - PLATFORM: ${{ matrix.runner == 'ubuntu-22.04' && 'amd64' || 'arm64' }} - - steps: - - name: Set up Environment - run: | - sudo apt-get remove -y '^postgres.*' '^libpq.*' '^clang.*' '^llvm.*' '^libclang.*' '^libllvm.*' '^mono-llvm.*' - sudo apt-get purge -y '^postgres.*' '^libpq.*' '^clang.*' '^llvm.*' '^libclang.*' '^libllvm.*' '^mono-llvm.*' - - sudo apt-get install -y postgresql-common - sudo /usr/share/postgresql-common/pgdg/apt.postgresql.org.sh -y - sudo apt-get install -y postgresql-client-17 - - uses: actions/checkout@v4 - - name: Configure sccache - uses: actions/github-script@v7 - with: - script: | - const url = process.env.ACTIONS_CACHE_URL || ''; - const token = process.env.ACTIONS_RUNTIME_TOKEN || ''; - core.exportVariable( - 'CACHE_ENVS', - `-e CARGO_INCREMENTAL=0 -e SCCACHE_GHA_ENABLED=true -e RUSTC_WRAPPER=sccache -e ACTIONS_CACHE_URL=${url} -e ACTIONS_RUNTIME_TOKEN=${token}`, - ); - - name: Set up pgrx docker images and permissions - run: | - docker pull $PGRX_IMAGE - echo "Default user: $(id -u):$(id -g)" - sudo chmod -R 777 . - - - name: Build - run: | - docker run --rm -v .:/workspace $CACHE_ENVS \ - -e SEMVER=0.0.0 \ - -e VERSION=${{ matrix.version }} \ - -e ARCH=$ARCH \ - -e PLATFORM=$PLATFORM \ - $PGRX_IMAGE ./tools/package.sh - docker build -t vchord:pg${{ matrix.version }} --build-arg PG_VERSION=${{ matrix.version }} --build-arg SEMVER=0.0.0 -f ./docker/Dockerfile . 
- - - name: Setup SQL Logic Test - run: | - curl -fsSL -o sqllogictest.tar.gz https://github.com/risinglightdb/sqllogictest-rs/releases/download/v${SQLLOGICTEST}/sqllogictest-bin-v${SQLLOGICTEST}-$ARCH-unknown-linux-musl.tar.gz - tar -xzf sqllogictest.tar.gz - mv sqllogictest /usr/local/bin/ - - - name: SQL Test - env: - PGPASSWORD: postgres - run: | - docker run --rm --name test -d -e POSTGRES_PASSWORD=${PGPASSWORD} -p 5432:5432 vchord:pg${{ matrix.version }} - sleep 5 - psql -h localhost -U postgres -d postgres -c 'CREATE EXTENSION IF NOT EXISTS vchord CASCADE;' - sqllogictest './tests/**/*.slt' - docker stop test diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 1662bbd0..c057afad 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -6,7 +6,7 @@ on: workflow_dispatch: inputs: tag: - description: 'tag name (semver without v-prefix)' + description: "tag name (semver without v-prefix)" required: true type: string @@ -16,12 +16,12 @@ concurrency: jobs: semver: - runs-on: ubuntu-latest - outputs: - SEMVER: ${{ steps.semver.outputs.SEMVER }} + runs-on: "ubuntu-latest" + steps: - - uses: actions/github-script@v7 + - name: Semver id: semver + uses: actions/github-script@v7 with: script: | const tag = "${{ github.event.inputs.tag }}" || "${{ github.event.release.tag_name }}"; @@ -32,58 +32,92 @@ jobs: } core.setOutput('SEMVER', tag); + outputs: + SEMVER: ${{ steps.semver.outputs.SEMVER }} + build: - runs-on: ${{ matrix.runner }} needs: ["semver"] strategy: matrix: - version: ["14", "15", "16", "17"] + version: ["13", "14", "15", "16", "17"] runner: ["ubuntu-22.04", "ubuntu-22.04-arm"] + runs-on: ${{ matrix.runner }} + env: - PGRX_IMAGE: "ghcr.io/tensorchord/vectorchord-pgrx:0.12.9-nightly-2024-12-25" - SEMVER: ${{ needs.semver.outputs.SEMVER }} - ARCH: ${{ matrix.runner == 'ubuntu-22.04' && 'x86_64' || 'aarch64' }} - PLATFORM: ${{ matrix.runner == 'ubuntu-22.04' && 'amd64' || 'arm64' }} + CARGO_TERM_COLOR: always + RUST_BACKTRACE: 1 + RUSTFLAGS: "-Dwarnings" steps: - name: Set up Environment run: | - sudo apt-get remove -y '^postgres.*' '^libpq.*' '^clang.*' '^llvm.*' '^libclang.*' '^libllvm.*' '^mono-llvm.*' - sudo apt-get purge -y '^postgres.*' '^libpq.*' '^clang.*' '^llvm.*' '^libclang.*' '^libllvm.*' '^mono-llvm.*' + sudo apt-get update + + sudo apt-get remove -y '^postgres.*' '^libpq.*' + sudo apt-get purge -y '^postgres.*' '^libpq.*' + + curl --proto '=https' --tlsv1.2 -sSf https://apt.llvm.org/llvm.sh | sudo bash -s -- 18 + sudo update-alternatives --install /usr/bin/clang clang $(which clang-18) 255 sudo apt-get install -y postgresql-common sudo /usr/share/postgresql-common/pgdg/apt.postgresql.org.sh -y - sudo apt-get install -y postgresql-client-17 - - uses: actions/checkout@v4 - - name: Configure sccache - uses: actions/github-script@v7 - with: - script: | - const url = process.env.ACTIONS_CACHE_URL || ''; - const token = process.env.ACTIONS_RUNTIME_TOKEN || ''; - core.exportVariable( - 'CACHE_ENVS', - `-e CARGO_INCREMENTAL=0 -e SCCACHE_GHA_ENABLED=true -e RUSTC_WRAPPER=sccache -e ACTIONS_CACHE_URL=${url} -e ACTIONS_RUNTIME_TOKEN=${token}`, - ); - - name: Set up pgrx docker images and permissions - run: | - docker pull $PGRX_IMAGE - echo "Default user: $(id -u):$(id -g)" - sudo chmod -R 777 . 
+ sudo apt-get install -y postgresql-server-dev-${{ matrix.version }} + + sudo apt-get install -y postgresql-${{ matrix.version }} postgresql-${{ matrix.version }}-pgvector + + curl -fsSL https://github.com/tensorchord/pgrx/releases/download/v0.12.9/cargo-pgrx-v0.12.9-$(uname -m)-unknown-linux-musl.tar.gz | tar -xOzf - ./cargo-pgrx | install -m 755 /dev/stdin /usr/local/bin/cargo-pgrx + cargo pgrx init --pg${{ matrix.version }}=$(which pg_config) + + - name: Checkout + uses: actions/checkout@v4 - name: Build env: + SEMVER: ${{ needs.semver.outputs.SEMVER }} + VERSION: ${{ matrix.version }} + ARCH: ${{ matrix.runner == 'ubuntu-22.04' && 'x86_64' || 'aarch64' }} + PLATFORM: ${{ matrix.runner == 'ubuntu-22.04' && 'amd64' || 'arm64' }} GH_TOKEN: ${{ github.token }} run: | - docker run --rm -v .:/workspace $CACHE_ENVS \ - -e SEMVER=$SEMVER \ - -e VERSION=${{ matrix.version }} \ - -e ARCH=$ARCH \ - -e PLATFORM=$PLATFORM \ - $PGRX_IMAGE ./tools/package.sh + cargo build --lib --features pg$VERSION --release + + mkdir -p ./build/zip + cp -a ./sql/upgrade/. ./build/zip/ + cp ./sql/install/vchord--$SEMVER.sql ./build/zip/vchord--$SEMVER.sql + sed -e "s/@CARGO_VERSION@/$SEMVER/g" < ./vchord.control > ./build/zip/vchord.control + cp ./target/release/libvchord.so ./build/zip/vchord.so + zip ./build/postgresql-${VERSION}-vchord_${SEMVER}_${ARCH}-linux-gnu.zip -j ./build/zip/* + + mkdir -p ./build/deb + mkdir -p ./build/deb/DEBIAN + mkdir -p ./build/deb/usr/share/postgresql/$VERSION/extension/ + mkdir -p ./build/deb/usr/lib/postgresql/$VERSION/lib/ + for file in $(ls ./build/zip/*.sql | xargs -n 1 basename); do + cp ./build/zip/$file ./build/deb/usr/share/postgresql/$VERSION/extension/$file + done + for file in $(ls ./build/zip/*.control | xargs -n 1 basename); do + cp ./build/zip/$file ./build/deb/usr/share/postgresql/$VERSION/extension/$file + done + for file in $(ls ./build/zip/*.so | xargs -n 1 basename); do + cp ./build/zip/$file ./build/deb/usr/lib/postgresql/$VERSION/lib/$file + done + echo "Package: postgresql-${VERSION}-vchord + Version: ${SEMVER}-1 + Section: database + Priority: optional + Architecture: ${PLATFORM} + Maintainer: Tensorchord + Description: Vector database plugin for Postgres, written in Rust, specifically designed for LLM + Homepage: https://vectorchord.ai/ + License: AGPL-3 or Elastic-2" \ + > ./build/deb/DEBIAN/control + (cd ./build/deb && md5sum usr/share/postgresql/$VERSION/extension/* usr/lib/postgresql/$VERSION/lib/*) > ./build/deb/DEBIAN/md5sums + dpkg-deb --root-owner-group -Zxz --build ./build/deb/ ./build/postgresql-${VERSION}-vchord_${SEMVER}-1_${PLATFORM}.deb + ls ./build - gh release upload --clobber $SEMVER ./build/postgresql-${{ matrix.version }}-vchord_${SEMVER}-1_${PLATFORM}.deb - gh release upload --clobber $SEMVER ./build/postgresql-${{ matrix.version }}-vchord_${SEMVER}_${ARCH}-linux-gnu.zip + + gh release upload --clobber $SEMVER ./build/postgresql-${VERSION}-vchord_${SEMVER}-1_${PLATFORM}.deb + gh release upload --clobber $SEMVER ./build/postgresql-${VERSION}-vchord_${SEMVER}_${ARCH}-linux-gnu.zip docker: runs-on: ubuntu-latest @@ -156,7 +190,7 @@ jobs: TARGETARCH=amd64 PGVECTOR=0.8.0 tags: modelzai/vchord-cnpg:${{ matrix.version }}-v${{ env.SEMVER }} - + test: name: Run tests runs-on: @@ -190,5 +224,4 @@ jobs: sleep 5 curl https://registry.pgtrunk.io/extensions/all | jq -r ".[] | .name" > /tmp/extensions.txt trunk-install.sh | tee /tmp/output.txt - cat /tmp/output.txt - + cat /tmp/output.txt diff --git a/.github/workflows/rust.yml 
b/.github/workflows/rust.yml deleted file mode 100644 index 8b1b4e2d..00000000 --- a/.github/workflows/rust.yml +++ /dev/null @@ -1,81 +0,0 @@ -name: Rust - -on: - pull_request: - paths: - - '.github/workflows/rust.yml' - - 'src/**' - - 'Cargo.lock' - - 'Cargo.toml' - - '*.control' - - 'rust-toolchain.toml' - push: - branches: - - main - paths: - - '.github/workflows/rust.yml' - - 'src/**' - - 'Cargo.lock' - - 'Cargo.toml' - - '*.control' - - 'rust-toolchain.toml' - merge_group: - workflow_dispatch: - -concurrency: - group: ${{ github.ref }}-${{ github.workflow }} - cancel-in-progress: true - -jobs: - test: - strategy: - matrix: - include: - - runner: "ubuntu-22.04" - arch: "x86_64" - - runner: "ubuntu-22.04-arm" - arch: "aarch64" - runs-on: ${{ matrix.runner }} - env: - PGRX_IMAGE: "ghcr.io/tensorchord/vectorchord-pgrx:0.12.9-nightly-2024-12-25" - - steps: - - name: Set up Environment - run: | - sudo apt-get remove -y '^postgres.*' '^libpq.*' '^clang.*' '^llvm.*' '^libclang.*' '^libllvm.*' '^mono-llvm.*' - sudo apt-get purge -y '^postgres.*' '^libpq.*' '^clang.*' '^llvm.*' '^libclang.*' '^libllvm.*' '^mono-llvm.*' - - sudo apt-get install -y postgresql-common - sudo /usr/share/postgresql-common/pgdg/apt.postgresql.org.sh -y - sudo apt-get install -y postgresql-client-17 - - uses: actions/checkout@v4 - - name: Configure sccache - uses: actions/github-script@v7 - with: - script: | - const url = process.env.ACTIONS_CACHE_URL || ''; - const token = process.env.ACTIONS_RUNTIME_TOKEN || ''; - core.exportVariable( - 'CACHE_ENVS', - `-e CARGO_INCREMENTAL=0 -e SCCACHE_GHA_ENABLED=true -e RUSTC_WRAPPER=sccache -e ACTIONS_CACHE_URL=${url} -e ACTIONS_RUNTIME_TOKEN=${token}`, - ); - - name: Set up docker images and permissions - run: | - docker pull $PGRX_IMAGE - echo "Default user: $(id -u):$(id -g)" - sudo chown -R 1000:1000 . 
- - - name: Clippy - run: | - for v in {14..17}; do - docker run --rm -v .:/workspace $CACHE_ENVS $PGRX_IMAGE cargo clippy --target ${{ matrix.arch }}-unknown-linux-gnu --features "pg$v" -- -D warnings - done - - name: Build - run: | - for v in {14..17}; do - docker run --rm -v .:/workspace $CACHE_ENVS $PGRX_IMAGE cargo build --lib --target ${{ matrix.arch }}-unknown-linux-gnu --features "pg$v" - done - - name: Test - run: | - # pg agnostic tests - docker run --rm -v .:/workspace $CACHE_ENVS $PGRX_IMAGE cargo test --no-fail-fast --target ${{ matrix.arch }}-unknown-linux-gnu --features pg17 diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml deleted file mode 100644 index ac5f385e..00000000 --- a/.github/workflows/style.yml +++ /dev/null @@ -1,40 +0,0 @@ -name: Style - -on: - pull_request: - branches: - - main - push: - branches: - - main - merge_group: - workflow_dispatch: - -concurrency: - group: ${{ github.ref }}-${{ github.workflow }} - cancel-in-progress: true - -env: - CARGO_TERM_COLOR: always - -jobs: - lint: - runs-on: ubuntu-latest - timeout-minutes: 5 - steps: - - uses: actions/checkout@v4 - - - name: Typos - uses: crate-ci/typos@master - - - name: TOML lint - run: | - curl -fsSL https://github.com/tamasfe/taplo/releases/latest/download/taplo-full-linux-x86_64.gz | gzip -d - | install -m 755 /dev/stdin /usr/local/bin/taplo - taplo fmt --check - - - name: Cargo Lint - run: | - cargo fmt --check - - - name: Python lint - uses: astral-sh/ruff-action@v1 diff --git a/.taplo.toml b/.taplo.toml deleted file mode 100644 index d9a9fdac..00000000 --- a/.taplo.toml +++ /dev/null @@ -1,10 +0,0 @@ -[formatting] -indent_string = " " - -[[rule]] -keys = ["dependencies", "*-denpendencies", "lints", "patch.*", "profile.*"] - -[rule.formatting] -reorder_keys = true -reorder_arrays = true -align_comments = true diff --git a/Cargo.lock b/Cargo.lock index f5214d8e..a5a23736 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11,6 +11,28 @@ dependencies = [ "memchr", ] +[[package]] +name = "algorithm" +version = "0.0.0" +dependencies = [ + "always_equal", + "distance", + "half 2.4.1", + "heapify", + "k_means", + "paste", + "rabitq", + "rand", + "random_orthogonal_matrix", + "serde", + "simd", + "turboselect", + "validator", + "vector", + "zerocopy 0.8.17", + "zerocopy-derive 0.8.17", +] + [[package]] name = "always_equal" version = "0.0.0" @@ -77,9 +99,9 @@ dependencies = [ [[package]] name = "bitflags" -version = "2.7.0" +version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1be3f42a67d6d345ecd59f675f3f012d6974981560836e938c22b424b85ce1be" +checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36" [[package]] name = "bitvec" @@ -117,9 +139,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.9" +version = "1.2.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8293772165d9345bdaaa39b45b2109591e63fe5e6fbc23c6ff930a048aa310b" +checksum = "c7777341816418c02e033934a09f20dc0ccaf65a5201ef8a450ae0105a573fda" dependencies = [ "shlex", ] @@ -196,9 +218,9 @@ checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" [[package]] name = "crunchy" -version = "0.2.2" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" +checksum = "43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929" [[package]] name = "darling" @@ -321,13 +343,14 @@ checksum = 
"e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" [[package]] name = "getrandom" -version = "0.2.15" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +checksum = "43a49c392881ce6d5c3b8cb70f98717b7c07aabbdff06687b9030dbfbe2725f8" dependencies = [ "cfg-if", "libc", "wasi", + "windows-targets", ] [[package]] @@ -350,8 +373,8 @@ dependencies = [ "cfg-if", "crunchy", "serde", - "zerocopy 0.8.14", - "zerocopy-derive 0.8.14", + "zerocopy 0.8.17", + "zerocopy-derive 0.8.17", ] [[package]] @@ -369,6 +392,12 @@ version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" +[[package]] +name = "heapify" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0049b265b7f201ca9ab25475b22b47fe444060126a51abe00f77d986fc5cc52e" + [[package]] name = "heapless" version = "0.8.0" @@ -391,7 +420,7 @@ version = "0.5.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf" dependencies = [ - "windows-sys 0.59.0", + "windows-sys", ] [[package]] @@ -547,9 +576,9 @@ checksum = "ce23b50ad8242c51a442f3ff322d56b02f08852c77e4c0b4d3fd684abc89c683" [[package]] name = "indexmap" -version = "2.7.0" +version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62f822373a4fe84d4bb149bf54e584a7f4abec90e072ed49cda0edea5b95471f" +checksum = "8c9c992b02b5b4c94ea26e32fe5bccb7aa7d9f390ab5c1221ff895bc7ea8b652" dependencies = [ "equivalent", "hashbrown", @@ -557,13 +586,13 @@ dependencies = [ [[package]] name = "is-terminal" -version = "0.4.13" +version = "0.4.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "261f68e344040fbd0edea105bef17c66edf46f984ddb1115b775ce31be948f4b" +checksum = "e19b23d53f35ce9f56aebc7d1bb4e6ac1e9c0db7ac85c8d1760c04379edced37" dependencies = [ "hermit-abi", "libc", - "windows-sys 0.52.0", + "windows-sys", ] [[package]] @@ -587,6 +616,17 @@ version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674" +[[package]] +name = "k_means" +version = "0.0.0" +dependencies = [ + "half 2.4.1", + "rabitq", + "rand", + "rayon", + "simd", +] + [[package]] name = "libc" version = "0.2.169" @@ -615,12 +655,6 @@ version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ee93343901ab17bd981295f2cf0026d4ad018c7c31ba84549a4ddbb47a45104" -[[package]] -name = "log" -version = "0.4.25" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f" - [[package]] name = "matrixmultiply" version = "0.3.9" @@ -731,9 +765,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.20.2" +version = "1.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" +checksum = "945462a4b81e43c4e3ba96bd7b49d834c6f61198356aa858733bc4acf3cbe62e" [[package]] name = "owo-colors" @@ -966,20 +1000,20 @@ checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" [[package]] name = "rand" -version = "0.8.5" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum 
= "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +checksum = "3779b94aeb87e8bd4e834cee3650289ee9e0d5677f976ecdb6d219e5f4f6cd94" dependencies = [ - "libc", "rand_chacha", "rand_core", + "zerocopy 0.8.17", ] [[package]] name = "rand_chacha" -version = "0.3.1" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" dependencies = [ "ppv-lite86", "rand_core", @@ -987,18 +1021,19 @@ dependencies = [ [[package]] name = "rand_core" -version = "0.6.4" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +checksum = "b08f3c9802962f7e1b25113931d94f43ed9725bebc59db9d0c3e9a23b67e15ff" dependencies = [ "getrandom", + "zerocopy 0.8.17", ] [[package]] name = "rand_distr" -version = "0.4.3" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +checksum = "ddc3b5afe4c995c44540865b8ca5c52e6a59fa362da96c5d30886930ddc8da1c" dependencies = [ "num-traits", "rand", @@ -1086,9 +1121,9 @@ dependencies = [ [[package]] name = "ryu" -version = "1.0.18" +version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" +checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd" [[package]] name = "safe_arch" @@ -1164,9 +1199,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.135" +version = "1.0.138" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b0d7ba2887406110130a978386c4e1befb98c674b4fba677954e4db976630d9" +checksum = "d434192e7da787e94a6ea7e9670b26a036d0ca41e0b7efb2676dd32bae872949" dependencies = [ "itoa", "memchr", @@ -1209,8 +1244,8 @@ dependencies = [ "cc", "half 2.4.1", "rand", - "serde", "simd_macros", + "zerocopy 0.8.17", ] [[package]] @@ -1267,9 +1302,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.96" +version = "2.0.98" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d5d0adab1ae378d7f53bdebc67a39f1f151407ef230f0ce2883572f5d8985c80" +checksum = "36147f1a48ae0ec2b5b3bc5b537d267457555a10dc06f3dbc8cb11ba3006d3b1" dependencies = [ "proc-macro2", "quote", @@ -1345,9 +1380,9 @@ dependencies = [ [[package]] name = "toml" -version = "0.8.19" +version = "0.8.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1ed1f98e3fdc28d6d910e6737ae6ab1a93bf1985935a1193e68f93eeb68d24e" +checksum = "cd87a5cdd6ffab733b2f74bc4fd7ee5fff6634124999ac278c35fc78c6120148" dependencies = [ "serde", "serde_spanned", @@ -1366,9 +1401,9 @@ dependencies = [ [[package]] name = "toml_edit" -version = "0.22.22" +version = "0.22.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5" +checksum = "17b4795ff5edd201c7cd6dca065ae59972ce77d1b80fa0a84d94950ece7d1474" dependencies = [ "indexmap", "serde", @@ -1377,6 +1412,11 @@ dependencies = [ "winnow", ] +[[package]] +name = "turboselect" +version = "0.1.0" +source = "git+https://github.com/tensorchord/turboselect.git?rev=d8753c4ffe5b47f28670fea21d56cf3658d51b9b#d8753c4ffe5b47f28670fea21d56cf3658d51b9b" + [[package]] name = "typenum" version = "1.17.0" 
@@ -1397,9 +1437,9 @@ checksum = "ccb97dac3243214f8d8507998906ca3e2e0b900bf9bf4870477f125b82e68f6e" [[package]] name = "unicode-ident" -version = "1.0.14" +version = "1.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83" +checksum = "a210d160f08b701c8721ba1c726c11662f877ea6b7094007e1ca9a1041945034" [[package]] name = "unicode-segmentation" @@ -1438,18 +1478,18 @@ checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" [[package]] name = "uuid" -version = "1.11.1" +version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b913a3b5fe84142e269d63cc62b64319ccaf89b748fc31fe025177f767a756c4" +checksum = "ced87ca4be083373936a67f8de945faa23b6b42384bd5b64434850802c6dccd0" dependencies = [ "getrandom", ] [[package]] name = "validator" -version = "0.19.0" +version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d0b4a29d8709210980a09379f27ee31549b73292c87ab9899beee1c0d3be6303" +checksum = "43fb22e1a008ece370ce08a3e9e4447a910e92621bb49b85d6e48a45397e7cfa" dependencies = [ "idna", "once_cell", @@ -1463,9 +1503,9 @@ dependencies = [ [[package]] name = "validator_derive" -version = "0.19.0" +version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bac855a2ce6f843beb229757e6e570a42e837bcb15e5f449dd48d5747d41bf77" +checksum = "b7df16e474ef958526d1205f6dda359fdfab79d9aa6d54bafcb92dcd07673dca" dependencies = [ "darling", "once_cell", @@ -1479,24 +1519,22 @@ dependencies = [ name = "vchord" version = "0.0.0" dependencies = [ - "always_equal", + "algorithm", "distance", "half 2.4.1", - "log", + "k_means", "paste", "pgrx", "pgrx-catalog", - "rabitq", "rand", "random_orthogonal_matrix", - "rayon", "serde", "simd", "toml", "validator", "vector", - "zerocopy 0.8.14", - "zerocopy-derive 0.8.14", + "zerocopy 0.8.17", + "zerocopy-derive 0.8.17", ] [[package]] @@ -1505,7 +1543,6 @@ version = "0.0.0" dependencies = [ "distance", "half 2.4.1", - "serde", "simd", ] @@ -1521,9 +1558,12 @@ dependencies = [ [[package]] name = "wasi" -version = "0.11.0+wasi-snapshot-preview1" +version = "0.13.3+wasi-0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +checksum = "26816d2e1a4a36a2940b96c5296ce403917633dff8f3440e9b236ed6f6bacad2" +dependencies = [ + "wit-bindgen-rt", +] [[package]] name = "wide" @@ -1557,7 +1597,7 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" dependencies = [ - "windows-sys 0.59.0", + "windows-sys", ] [[package]] @@ -1566,15 +1606,6 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" -[[package]] -name = "windows-sys" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" -dependencies = [ - "windows-targets", -] - [[package]] name = "windows-sys" version = "0.59.0" @@ -1650,13 +1681,22 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winnow" -version = "0.6.24" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"c8d71a593cc5c42ad7876e2c1fda56f314f3754c084128833e64f1345ff8a03a" +checksum = "59690dea168f2198d1a3b0cac23b8063efcd11012f10ae4698f284808c8ef603" dependencies = [ "memchr", ] +[[package]] +name = "wit-bindgen-rt" +version = "0.33.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3268f3d866458b787f390cf61f4bbb563b922d091359f9608842999eaee3943c" +dependencies = [ + "bitflags", +] + [[package]] name = "write16" version = "1.0.0" @@ -1723,11 +1763,11 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.14" +version = "0.8.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a367f292d93d4eab890745e75a778da40909cab4d6ff8173693812f79c4a2468" +checksum = "aa91407dacce3a68c56de03abe2760159582b846c6a4acd2f456618087f12713" dependencies = [ - "zerocopy-derive 0.8.14", + "zerocopy-derive 0.8.17", ] [[package]] @@ -1743,9 +1783,9 @@ dependencies = [ [[package]] name = "zerocopy-derive" -version = "0.8.14" +version = "0.8.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3931cb58c62c13adec22e38686b559c86a30565e16ad6e8510a337cedc611e1" +checksum = "06718a168365cad3d5ff0bb133aad346959a2074bd4a85c121255a11304a8626" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index 3583ae26..395f98a5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "vchord" version.workspace = true -edition.workspace = true +edition = "2021" [lib] name = "vchord" @@ -20,25 +20,23 @@ pg16 = ["pgrx/pg16", "pgrx-catalog/pg16"] pg17 = ["pgrx/pg17", "pgrx-catalog/pg17"] [dependencies] -always_equal = { path = "./crates/always_equal" } +algorithm = { path = "./crates/algorithm" } distance = { path = "./crates/distance" } -rabitq = { path = "./crates/rabitq" } +k_means = { path = "./crates/k_means" } random_orthogonal_matrix = { path = "./crates/random_orthogonal_matrix" } simd = { path = "./crates/simd" } vector = { path = "./crates/vector" } half.workspace = true -log = "0.4.25" -paste = "1" +paste.workspace = true pgrx = { version = "=0.12.9", default-features = false, features = ["cshim"] } pgrx-catalog = "0.1.0" rand.workspace = true -rayon = "1.10.0" serde.workspace = true -toml = "0.8.19" -validator = { version = "0.19.0", features = ["derive"] } -zerocopy = "0.8.14" -zerocopy-derive = "0.8.14" +toml = "0.8.20" +validator.workspace = true +zerocopy.workspace = true +zerocopy-derive.workspace = true [patch.crates-io] half = { git = "https://github.com/tensorchord/half-rs.git", rev = "3f9a8843d6722bd1833de2289347640ad8770146" } @@ -56,14 +54,19 @@ edition = "2021" [workspace.dependencies] half = { version = "2.4.1", features = ["serde", "zerocopy"] } -rand = "0.8.5" +paste = "1" +rand = "0.9.0" serde = "1" +validator = { version = "0.20.0", features = ["derive"] } +zerocopy = "0.8.17" +zerocopy-derive = "0.8.17" [workspace.lints] clippy.identity_op = "allow" clippy.int_plus_one = "allow" clippy.needless_range_loop = "allow" clippy.nonminimal_bool = "allow" +rust.unsafe_code = "deny" rust.unsafe_op_in_unsafe_fn = "deny" rust.unused_lifetimes = "warn" rust.unused_qualifications = "warn" diff --git a/crates/algorithm/Cargo.toml b/crates/algorithm/Cargo.toml new file mode 100644 index 00000000..83adc488 --- /dev/null +++ b/crates/algorithm/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "algorithm" +version.workspace = true +edition.workspace = true + +[dependencies] +always_equal = { path = "../always_equal" } +distance = { path = "../distance" } +k_means = { path = "../k_means" } 
+rabitq = { path = "../rabitq" } +random_orthogonal_matrix = { path = "../random_orthogonal_matrix" } +simd = { path = "../simd" } +vector = { path = "../vector" } + +half.workspace = true +heapify = "0.2.0" +paste.workspace = true +rand.workspace = true +serde.workspace = true +turboselect = { git = "https://github.com/tensorchord/turboselect.git", rev = "d8753c4ffe5b47f28670fea21d56cf3658d51b9b" } +validator.workspace = true +zerocopy.workspace = true +zerocopy-derive.workspace = true + +[lints] +workspace = true diff --git a/crates/algorithm/src/build.rs b/crates/algorithm/src/build.rs new file mode 100644 index 00000000..d8c9b62d --- /dev/null +++ b/crates/algorithm/src/build.rs @@ -0,0 +1,102 @@ +use crate::RelationWrite; +use crate::operator::{Accessor2, Operator, Vector}; +use crate::tape::*; +use crate::tuples::*; +use crate::types::*; +use vector::VectorOwned; + +pub fn build( + vector_options: VectorOptions, + vchordrq_options: VchordrqIndexOptions, + index: impl RelationWrite, + structures: Vec>, +) { + let dims = vector_options.dims; + let is_residual = vchordrq_options.residual_quantization && O::SUPPORTS_RESIDUAL; + let mut meta = TapeWriter::<_, _, MetaTuple>::create(|| index.extend(false)); + assert_eq!(meta.first(), 0); + let freepage = TapeWriter::<_, _, FreepageTuple>::create(|| index.extend(false)); + let mut vectors = TapeWriter::<_, _, VectorTuple>::create(|| index.extend(true)); + let mut pointer_of_means = Vec::>::new(); + for i in 0..structures.len() { + let mut level = Vec::new(); + for j in 0..structures[i].len() { + let vector = structures[i].means[j].as_borrowed(); + let (metadata, slices) = O::Vector::vector_split(vector); + let mut chain = Ok(metadata); + for i in (0..slices.len()).rev() { + chain = Err(vectors.push(match chain { + Ok(metadata) => VectorTuple::_0 { + payload: None, + elements: slices[i].to_vec(), + metadata, + }, + Err(pointer) => VectorTuple::_1 { + payload: None, + elements: slices[i].to_vec(), + pointer, + }, + })); + } + level.push(chain.err().unwrap()); + } + pointer_of_means.push(level); + } + let mut pointer_of_firsts = Vec::>::new(); + for i in 0..structures.len() { + let mut level = Vec::new(); + for j in 0..structures[i].len() { + if i == 0 { + let tape = TapeWriter::<_, _, H0Tuple>::create(|| index.extend(false)); + let mut jump = TapeWriter::<_, _, JumpTuple>::create(|| index.extend(false)); + jump.push(JumpTuple { + first: tape.first(), + }); + level.push(jump.first()); + } else { + let mut tape = H1TapeWriter::<_, _>::create(|| index.extend(false)); + let h2_mean = structures[i].means[j].as_borrowed(); + let h2_children = structures[i].children[j].as_slice(); + for child in h2_children.iter().copied() { + let h1_mean = structures[i - 1].means[child as usize].as_borrowed(); + let code = if is_residual { + let mut residual_accessor = O::ResidualAccessor::default(); + residual_accessor.push( + O::Vector::elements_and_metadata(h1_mean).0, + O::Vector::elements_and_metadata(h2_mean).0, + ); + let residual = residual_accessor.finish( + O::Vector::elements_and_metadata(h1_mean).1, + O::Vector::elements_and_metadata(h2_mean).1, + ); + O::Vector::code(residual.as_borrowed()) + } else { + O::Vector::code(h1_mean) + }; + tape.push(H1Branch { + mean: pointer_of_means[i - 1][child as usize], + dis_u_2: code.dis_u_2, + factor_ppc: code.factor_ppc, + factor_ip: code.factor_ip, + factor_err: code.factor_err, + signs: code.signs, + first: pointer_of_firsts[i - 1][child as usize], + }); + } + let tape = tape.into_inner(); + 
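+            // `into_inner()` appears to flush any branches still buffered in the H1 tape
+            // writer and return the underlying tape, so that `first()` below records the
+            // entry page of this child's finished tape.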
level.push(tape.first()); + } + } + pointer_of_firsts.push(level); + } + meta.push(MetaTuple { + dims, + height_of_root: structures.len() as u32, + is_residual, + rerank_in_heap: vchordrq_options.rerank_in_table, + vectors_first: vectors.first(), + root_mean: pointer_of_means.last().unwrap()[0], + root_first: pointer_of_firsts.last().unwrap()[0], + freepage_first: freepage.first(), + }); +} diff --git a/crates/algorithm/src/bulkdelete.rs b/crates/algorithm/src/bulkdelete.rs new file mode 100644 index 00000000..524ec908 --- /dev/null +++ b/crates/algorithm/src/bulkdelete.rs @@ -0,0 +1,167 @@ +use crate::operator::Operator; +use crate::pipe::Pipe; +use crate::tuples::*; +use crate::{Page, RelationWrite}; +use std::num::NonZeroU64; + +pub fn bulkdelete( + index: impl RelationWrite, + check: impl Fn(), + callback: impl Fn(NonZeroU64) -> bool, +) { + let meta_guard = index.read(0); + let meta_tuple = meta_guard.get(1).unwrap().pipe(read_tuple::); + let height_of_root = meta_tuple.height_of_root(); + let root_first = meta_tuple.root_first(); + let vectors_first = meta_tuple.vectors_first(); + drop(meta_guard); + { + type State = Vec; + let mut state: State = vec![root_first]; + let step = |state: State| { + let mut results = Vec::new(); + for first in state { + let mut current = first; + while current != u32::MAX { + let h1_guard = index.read(current); + for i in 1..=h1_guard.len() { + let h1_tuple = h1_guard + .get(i) + .expect("data corruption") + .pipe(read_tuple::); + match h1_tuple { + H1TupleReader::_0(h1_tuple) => { + for first in h1_tuple.first().iter().copied() { + results.push(first); + } + } + H1TupleReader::_1(_) => (), + } + } + current = h1_guard.get_opaque().next; + } + } + results + }; + for _ in (1..height_of_root).rev() { + state = step(state); + } + for first in state { + let jump_guard = index.read(first); + let jump_tuple = jump_guard + .get(1) + .expect("data corruption") + .pipe(read_tuple::); + let first = jump_tuple.first(); + let mut current = first; + while current != u32::MAX { + check(); + let read = index.read(current); + let flag = 'flag: { + for i in 1..=read.len() { + let h0_tuple = read + .get(i) + .expect("data corruption") + .pipe(read_tuple::); + match h0_tuple { + H0TupleReader::_0(h0_tuple) => { + let p = h0_tuple.payload(); + if let Some(payload) = p { + if callback(payload) { + break 'flag true; + } + } + } + H0TupleReader::_1(h0_tuple) => { + let p = h0_tuple.payload(); + for j in 0..32 { + if let Some(payload) = p[j] { + if callback(payload) { + break 'flag true; + } + } + } + } + H0TupleReader::_2(_) => (), + } + } + false + }; + if flag { + drop(read); + let mut write = index.write(current, false); + for i in 1..=write.len() { + let h0_tuple = write + .get_mut(i) + .expect("data corruption") + .pipe(write_tuple::); + match h0_tuple { + H0TupleWriter::_0(mut h0_tuple) => { + let p = h0_tuple.payload(); + if let Some(payload) = *p { + if callback(payload) { + *p = None; + } + } + } + H0TupleWriter::_1(mut h0_tuple) => { + let p = h0_tuple.payload(); + for j in 0..32 { + if let Some(payload) = p[j] { + if callback(payload) { + p[j] = None; + } + } + } + } + H0TupleWriter::_2(_) => (), + } + } + current = write.get_opaque().next; + } else { + current = read.get_opaque().next; + } + } + } + } + { + let first = vectors_first; + let mut current = first; + while current != u32::MAX { + check(); + let read = index.read(current); + let flag = 'flag: { + for i in 1..=read.len() { + if let Some(vector_bytes) = read.get(i) { + let vector_tuple = 
vector_bytes.pipe(read_tuple::>); + let p = vector_tuple.payload(); + if let Some(payload) = p { + if callback(payload) { + break 'flag true; + } + } + } + } + false + }; + if flag { + drop(read); + let mut write = index.write(current, true); + for i in 1..=write.len() { + if let Some(vector_bytes) = write.get(i) { + let vector_tuple = vector_bytes.pipe(read_tuple::>); + let p = vector_tuple.payload(); + if let Some(payload) = p { + if callback(payload) { + write.free(i); + } + } + }; + } + current = write.get_opaque().next; + } else { + current = read.get_opaque().next; + } + } + } +} diff --git a/crates/algorithm/src/cache.rs b/crates/algorithm/src/cache.rs new file mode 100644 index 00000000..0405708e --- /dev/null +++ b/crates/algorithm/src/cache.rs @@ -0,0 +1,51 @@ +use crate::pipe::Pipe; +use crate::tuples::{H1Tuple, H1TupleReader, MetaTuple, read_tuple}; +use crate::{Page, RelationRead}; + +pub fn cache(index: impl RelationRead) -> Vec { + let mut trace = Vec::::new(); + let mut read = |id| { + let result = index.read(id); + trace.push(id); + result + }; + let meta_guard = read(0); + let meta_tuple = meta_guard.get(1).unwrap().pipe(read_tuple::); + let height_of_root = meta_tuple.height_of_root(); + let root_first = meta_tuple.root_first(); + drop(meta_guard); + type State = Vec; + let mut state: State = vec![root_first]; + let mut step = |state: State| { + let mut results = Vec::new(); + for first in state { + let mut current = first; + while current != u32::MAX { + let h1_guard = read(current); + for i in 1..=h1_guard.len() { + let h1_tuple = h1_guard + .get(i) + .expect("data corruption") + .pipe(read_tuple::); + match h1_tuple { + H1TupleReader::_0(h1_tuple) => { + for first in h1_tuple.first().iter().copied() { + results.push(first); + } + } + H1TupleReader::_1(_) => (), + } + } + current = h1_guard.get_opaque().next; + } + } + results + }; + for _ in (1..height_of_root).rev() { + state = step(state); + } + for first in state { + let _ = read(first); + } + trace +} diff --git a/src/algorithm/freepages.rs b/crates/algorithm/src/freepages.rs similarity index 80% rename from src/algorithm/freepages.rs rename to crates/algorithm/src/freepages.rs index 8984d3ce..6dc505c2 100644 --- a/src/algorithm/freepages.rs +++ b/crates/algorithm/src/freepages.rs @@ -1,9 +1,9 @@ -use crate::algorithm::tuples::*; -use crate::algorithm::*; -use crate::utils::pipe::Pipe; +use crate::pipe::Pipe; +use crate::tuples::*; +use crate::*; use std::cmp::Reverse; -pub fn mark(relation: impl RelationWrite, freepage_first: u32, pages: &[u32]) { +pub fn mark(index: impl RelationWrite, freepage_first: u32, pages: &[u32]) { let mut pages = pages.to_vec(); pages.sort_by_key(|x| Reverse(*x)); pages.dedup(); @@ -18,7 +18,7 @@ pub fn mark(relation: impl RelationWrite, freepage_first: u32, pages: &[u32]) { } local }; - let mut freespace_guard = relation.write(current, false); + let mut freespace_guard = index.write(current, false); if freespace_guard.len() == 0 { freespace_guard.alloc(&serialize(&FreepageTuple {})); } @@ -30,19 +30,19 @@ pub fn mark(relation: impl RelationWrite, freepage_first: u32, pages: &[u32]) { freespace_tuple.mark(local as _); } if freespace_guard.get_opaque().next == u32::MAX { - let extend = relation.extend(false); + let extend = index.extend(false); freespace_guard.get_opaque_mut().next = extend.id(); } (current, offset) = (freespace_guard.get_opaque().next, offset + 32768); } } -pub fn fetch(relation: impl RelationWrite, freepage_first: u32) -> Option { +pub fn fetch(index: impl 
RelationWrite, freepage_first: u32) -> Option { let first = freepage_first; assert!(first != u32::MAX); let (mut current, mut offset) = (first, 0_u32); loop { - let mut freespace_guard = relation.write(current, false); + let mut freespace_guard = index.write(current, false); if freespace_guard.len() == 0 { return None; } diff --git a/src/algorithm/insert.rs b/crates/algorithm/src/insert.rs similarity index 60% rename from src/algorithm/insert.rs rename to crates/algorithm/src/insert.rs index f2b8cfbc..9b71bae4 100644 --- a/src/algorithm/insert.rs +++ b/crates/algorithm/src/insert.rs @@ -1,27 +1,24 @@ -use crate::algorithm::operator::*; -use crate::algorithm::tape::read_h1_tape; -use crate::algorithm::tuples::*; -use crate::algorithm::vectors::{self}; -use crate::algorithm::{Page, PageGuard, RelationWrite}; -use crate::utils::pipe::Pipe; +use crate::linked_vec::LinkedVec; +use crate::operator::*; +use crate::pipe::Pipe; +use crate::select_heap::SelectHeap; +use crate::tape::{access_1, append}; +use crate::tuples::*; +use crate::vectors::{self}; +use crate::{Page, RelationWrite}; use always_equal::AlwaysEqual; use distance::Distance; use std::cmp::Reverse; use std::collections::BinaryHeap; use std::num::NonZeroU64; -use vector::VectorBorrowed; -use vector::VectorOwned; +use vector::{VectorBorrowed, VectorOwned}; -pub fn insert( - relation: impl RelationWrite + Clone, - payload: NonZeroU64, - vector: O::Vector, -) { - let vector = O::Vector::random_projection(vector.as_borrowed()); - let meta_guard = relation.read(0); +pub fn insert(index: impl RelationWrite, payload: NonZeroU64, vector: O::Vector) { + let meta_guard = index.read(0); let meta_tuple = meta_guard.get(1).unwrap().pipe(read_tuple::); let dims = meta_tuple.dims(); let is_residual = meta_tuple.is_residual(); + let rerank_in_heap = meta_tuple.rerank_in_heap(); let height_of_root = meta_tuple.height_of_root(); assert_eq!(dims, vector.as_borrowed().dims(), "unmatched dimensions"); let root_mean = meta_tuple.root_mean(); @@ -35,19 +32,18 @@ pub fn insert( None }; - let mean = vectors::vector_append::( - relation.clone(), - vectors_first, - vector.as_borrowed(), - payload, - ); + let mean = if !rerank_in_heap { + vectors::append::(index.clone(), vectors_first, vector.as_borrowed(), payload) + } else { + IndexPointer::default() + }; type State = (u32, Option<::Vector>); let mut state: State = { let mean = root_mean; if is_residual { - let residual_u = vectors::vector_access_1::( - relation.clone(), + let residual_u = vectors::access_1::( + index.clone(), mean, LAccess::new( O::Vector::elements_and_metadata(vector.as_borrowed()), @@ -60,7 +56,7 @@ pub fn insert( } }; let step = |state: State| { - let mut results = Vec::new(); + let mut results = LinkedVec::new(); { let (first, residual) = state; let lut = if let Some(residual) = residual { @@ -68,8 +64,8 @@ pub fn insert( } else { default_lut_block.as_ref().unwrap() }; - read_h1_tape( - relation.clone(), + access_1( + index.clone(), first, || { RAccess::new( @@ -82,14 +78,14 @@ pub fn insert( }, ); } - let mut heap = BinaryHeap::from(results); + let mut heap = SelectHeap::from_vec(results.into_vec()); let mut cache = BinaryHeap::<(Reverse, _, _)>::new(); { while !heap.is_empty() && heap.peek().map(|x| x.0) > cache.peek().map(|x| x.0) { let (_, AlwaysEqual(mean), AlwaysEqual(first)) = heap.pop().unwrap(); if is_residual { - let (dis_u, residual_u) = vectors::vector_access_1::( - relation.clone(), + let (dis_u, residual_u) = vectors::access_1::( + index.clone(), mean, LAccess::new( 
O::Vector::elements_and_metadata(vector.as_borrowed()), @@ -105,8 +101,8 @@ pub fn insert( AlwaysEqual(Some(residual_u)), )); } else { - let dis_u = vectors::vector_access_1::( - relation.clone(), + let dis_u = vectors::access_1::( + index.clone(), mean, LAccess::new( O::Vector::elements_and_metadata(vector.as_borrowed()), @@ -140,7 +136,7 @@ pub fn insert( elements: rabitq::pack_to_u64(&code.signs), }); - let jump_guard = relation.read(first); + let jump_guard = index.read(first); let jump_tuple = jump_guard .get(1) .expect("data corruption") @@ -148,43 +144,5 @@ pub fn insert( let first = jump_tuple.first(); - assert!(first != u32::MAX); - let mut current = first; - loop { - let read = relation.read(current); - if read.get_opaque().next == u32::MAX { - drop(read); - let mut write = relation.write(current, false); - if write.get_opaque().next == u32::MAX { - if write.alloc(&bytes).is_some() { - return; - } - let mut extend = relation.extend(false); - write.get_opaque_mut().next = extend.id(); - drop(write); - let fresh = extend.id(); - if extend.alloc(&bytes).is_some() { - drop(extend); - let mut past = relation.write(first, false); - past.get_opaque_mut().skip = std::cmp::max(past.get_opaque_mut().skip, fresh); - drop(past); - return; - } else { - panic!("a tuple cannot even be fit in a fresh page"); - } - } else { - if current == first && write.get_opaque().skip != first { - current = write.get_opaque().skip; - } else { - current = write.get_opaque().next; - } - } - } else { - if current == first && read.get_opaque().skip != first { - current = read.get_opaque().skip; - } else { - current = read.get_opaque().next; - } - } - } + append(index.clone(), first, &bytes, false); } diff --git a/src/algorithm/mod.rs b/crates/algorithm/src/lib.rs similarity index 58% rename from src/algorithm/mod.rs rename to crates/algorithm/src/lib.rs index 7a9c3ff2..372774f3 100644 --- a/src/algorithm/mod.rs +++ b/crates/algorithm/src/lib.rs @@ -1,17 +1,40 @@ -pub mod build; -pub mod freepages; -pub mod insert; +#![allow(clippy::collapsible_else_if)] +#![allow(clippy::type_complexity)] +#![allow(clippy::len_without_is_empty)] + +mod build; +mod bulkdelete; +mod cache; +mod freepages; +mod insert; +mod linked_vec; +mod maintain; +mod pipe; +mod prewarm; +mod rerank; +mod search; +mod select_heap; +mod tape; +mod tuples; +mod vectors; + pub mod operator; -pub mod prewarm; -pub mod scan; -pub mod tape; -pub mod tuples; -pub mod vacuum; -pub mod vectors; +pub mod types; + +pub use build::build; +pub use bulkdelete::bulkdelete; +pub use cache::cache; +pub use insert::insert; +pub use maintain::maintain; +pub use prewarm::prewarm; +pub use rerank::{rerank_heap, rerank_index}; +pub use search::search; use std::ops::{Deref, DerefMut}; +use zerocopy_derive::{FromBytes, Immutable, IntoBytes, KnownLayout}; #[repr(C, align(8))] +#[derive(Debug, Clone, PartialEq, FromBytes, IntoBytes, Immutable, KnownLayout)] pub struct Opaque { pub next: u32, pub skip: u32, @@ -34,7 +57,7 @@ pub trait PageGuard { fn id(&self) -> u32; } -pub trait RelationRead { +pub trait RelationRead: Clone { type Page: Page; type ReadGuard<'a>: PageGuard + Deref where @@ -50,3 +73,9 @@ pub trait RelationWrite: RelationRead { fn extend(&self, tracking_freespace: bool) -> Self::WriteGuard<'_>; fn search(&self, freespace: usize) -> Option>; } + +#[derive(Debug, Clone, Copy)] +pub enum RerankMethod { + Index, + Heap, +} diff --git a/crates/algorithm/src/linked_vec.rs b/crates/algorithm/src/linked_vec.rs new file mode 100644 index 00000000..179a0489 --- 
/dev/null +++ b/crates/algorithm/src/linked_vec.rs @@ -0,0 +1,36 @@ +pub struct LinkedVec { + inner: Vec>, + last: Vec, +} + +impl LinkedVec { + pub fn new() -> Self { + Self { + inner: Vec::new(), + last: Vec::with_capacity(4096), + } + } + pub fn push(&mut self, value: T) { + if self.last.len() == self.last.capacity() { + self.reserve(); + } + #[allow(unsafe_code)] + unsafe { + std::hint::assert_unchecked(self.last.len() != self.last.capacity()); + } + self.last.push(value); + } + #[cold] + fn reserve(&mut self) { + let fresh = Vec::with_capacity(self.last.capacity() * 4); + self.inner.push(core::mem::replace(&mut self.last, fresh)); + } + pub fn into_vec(self) -> Vec { + let mut last = self.last; + last.reserve(self.inner.iter().map(Vec::len).sum::()); + for x in self.inner { + last.extend(x); + } + last + } +} diff --git a/crates/algorithm/src/maintain.rs b/crates/algorithm/src/maintain.rs new file mode 100644 index 00000000..fec057b6 --- /dev/null +++ b/crates/algorithm/src/maintain.rs @@ -0,0 +1,147 @@ +use crate::operator::Operator; +use crate::pipe::Pipe; +use crate::tape::*; +use crate::tuples::*; +use crate::{Page, RelationWrite, freepages}; +use simd::fast_scan::unpack; + +pub fn maintain(index: impl RelationWrite, check: impl Fn()) { + let meta_guard = index.read(0); + let meta_tuple = meta_guard.get(1).unwrap().pipe(read_tuple::); + let dims = meta_tuple.dims(); + let height_of_root = meta_tuple.height_of_root(); + let root_first = meta_tuple.root_first(); + let freepage_first = meta_tuple.freepage_first(); + drop(meta_guard); + + let firsts = { + type State = Vec; + let mut state: State = vec![root_first]; + let step = |state: State| { + let mut results = Vec::new(); + for first in state { + let mut current = first; + while current != u32::MAX { + check(); + let h1_guard = index.read(current); + for i in 1..=h1_guard.len() { + let h1_tuple = h1_guard + .get(i) + .expect("data corruption") + .pipe(read_tuple::); + match h1_tuple { + H1TupleReader::_0(h1_tuple) => { + for first in h1_tuple.first().iter().copied() { + results.push(first); + } + } + H1TupleReader::_1(_) => (), + } + } + current = h1_guard.get_opaque().next; + } + } + results + }; + for _ in (1..height_of_root).rev() { + state = step(state); + } + state + }; + + for first in firsts { + let mut jump_guard = index.write(first, false); + let mut jump_tuple = jump_guard + .get_mut(1) + .expect("data corruption") + .pipe(write_tuple::); + + let mut tape = H0TapeWriter::<_, _>::create(|| { + if let Some(id) = freepages::fetch(index.clone(), freepage_first) { + let mut write = index.write(id, false); + write.clear(); + write + } else { + index.extend(false) + } + }); + + let mut trace = Vec::new(); + + let first = *jump_tuple.first(); + let mut current = first; + let mut computing = None; + while current != u32::MAX { + check(); + trace.push(current); + let h0_guard = index.read(current); + for i in 1..=h0_guard.len() { + let h0_tuple = h0_guard + .get(i) + .expect("data corruption") + .pipe(read_tuple::); + match h0_tuple { + H0TupleReader::_0(h0_tuple) => { + if let Some(payload) = h0_tuple.payload() { + tape.push(H0Branch { + mean: h0_tuple.mean(), + dis_u_2: h0_tuple.code().0, + factor_ppc: h0_tuple.code().1, + factor_ip: h0_tuple.code().2, + factor_err: h0_tuple.code().3, + signs: h0_tuple + .code() + .4 + .iter() + .flat_map(|x| { + std::array::from_fn::<_, 64, _>(|i| *x & (1 << i) != 0) + }) + .take(dims as _) + .collect::>(), + payload, + }); + } + } + H0TupleReader::_1(h0_tuple) => { + let computing = &mut 
computing.take().unwrap_or_else(Vec::new); + computing.extend_from_slice(h0_tuple.elements()); + let unpacked = unpack(computing); + for j in 0..32 { + if let Some(payload) = h0_tuple.payload()[j] { + tape.push(H0Branch { + mean: h0_tuple.mean()[j], + dis_u_2: h0_tuple.metadata().0[j], + factor_ppc: h0_tuple.metadata().1[j], + factor_ip: h0_tuple.metadata().2[j], + factor_err: h0_tuple.metadata().3[j], + signs: unpacked[j] + .iter() + .flat_map(|&x| { + [x & 1 != 0, x & 2 != 0, x & 4 != 0, x & 8 != 0] + }) + .collect(), + payload, + }); + } + } + } + H0TupleReader::_2(h0_tuple) => { + let computing = computing.get_or_insert_with(Vec::new); + computing.extend_from_slice(h0_tuple.elements()); + } + } + } + current = h0_guard.get_opaque().next; + drop(h0_guard); + } + + let tape = tape.into_inner(); + let new = tape.first(); + drop(tape); + + *jump_tuple.first() = new; + drop(jump_guard); + + freepages::mark(index.clone(), freepage_first, &trace); + } +} diff --git a/src/algorithm/operator.rs b/crates/algorithm/src/operator.rs similarity index 88% rename from src/algorithm/operator.rs rename to crates/algorithm/src/operator.rs index 9506d7a2..e8a57bd4 100644 --- a/src/algorithm/operator.rs +++ b/crates/algorithm/src/operator.rs @@ -1,4 +1,4 @@ -use crate::types::{DistanceKind, OwnedVector}; +use crate::types::*; use distance::Distance; use half::f16; use simd::Floating; @@ -174,15 +174,15 @@ impl Default for Block { impl Accessor2< - [u64; 2], - [u64; 2], + [u8; 16], + [u8; 16], (&[f32; 32], &[f32; 32], &[f32; 32], &[f32; 32]), (f32, f32, f32, f32, f32), > for Block { type Output = [Distance; 32]; - fn push(&mut self, input: &[[u64; 2]], target: &[[u64; 2]]) { + fn push(&mut self, input: &[[u8; 16]], target: &[[u8; 16]]) { let t = simd::fast_scan::fast_scan(input, target); for i in 0..32 { self.0[i] += t[i]; @@ -212,15 +212,15 @@ impl impl Accessor2< - [u64; 2], - [u64; 2], + [u8; 16], + [u8; 16], (&[f32; 32], &[f32; 32], &[f32; 32], &[f32; 32]), (f32, f32, f32, f32, f32), > for Block { type Output = [Distance; 32]; - fn push(&mut self, input: &[[u64; 2]], target: &[[u64; 2]]) { + fn push(&mut self, input: &[[u8; 16]], target: &[[u8; 16]]) { let t = simd::fast_scan::fast_scan(input, target); for i in 0..32 { self.0[i] += t[i]; @@ -324,7 +324,6 @@ pub struct RAccess<'a, E, M, A> { } impl<'a, E, M, A> RAccess<'a, E, M, A> { - #[allow(dead_code)] pub fn new((elements, metadata): (&'a [E], M), accessor: A) -> Self { Self { elements, @@ -356,22 +355,16 @@ pub trait Vector: VectorOwned { fn elements_and_metadata(vector: Self::Borrowed<'_>) -> (&[Self::Element], Self::Metadata); fn from_owned(vector: OwnedVector) -> Self; - fn random_projection(vector: Self::Borrowed<'_>) -> Self; - - fn compute_lut_block(vector: Self::Borrowed<'_>) -> (f32, f32, f32, f32, Vec<[u64; 2]>); + fn compute_lut_block(vector: Self::Borrowed<'_>) -> (f32, f32, f32, f32, Vec<[u8; 16]>); fn compute_lut( vector: Self::Borrowed<'_>, ) -> ( - (f32, f32, f32, f32, Vec<[u64; 2]>), + (f32, f32, f32, f32, Vec<[u8; 16]>), (f32, f32, f32, f32, (Vec, Vec, Vec, Vec)), ); fn code(vector: Self::Borrowed<'_>) -> rabitq::Code; - - fn build_to_vecf32(vector: Self::Borrowed<'_>) -> Vec; - - fn build_from_vecf32(x: &[f32]) -> Self; } impl Vector for VectOwned { @@ -381,11 +374,14 @@ impl Vector for VectOwned { fn vector_split(vector: Self::Borrowed<'_>) -> ((), Vec<&[f32]>) { let vector = vector.slice(); - ((), match vector.len() { - 0..=960 => vec![vector], - 961..=1280 => vec![&vector[..640], &vector[640..]], - 1281.. 
=> vector.chunks(1920).collect(), - }) + ( + (), + match vector.len() { + 0..=960 => vec![vector], + 961..=1280 => vec![&vector[..640], &vector[640..]], + 1281.. => vector.chunks(1920).collect(), + }, + ) } fn elements_and_metadata(vector: Self::Borrowed<'_>) -> (&[Self::Element], Self::Metadata) { @@ -399,18 +395,14 @@ impl Vector for VectOwned { } } - fn random_projection(vector: Self::Borrowed<'_>) -> Self { - Self::new(crate::projection::project(vector.slice())) - } - - fn compute_lut_block(vector: Self::Borrowed<'_>) -> (f32, f32, f32, f32, Vec<[u64; 2]>) { + fn compute_lut_block(vector: Self::Borrowed<'_>) -> (f32, f32, f32, f32, Vec<[u8; 16]>) { rabitq::block::preprocess(vector.slice()) } fn compute_lut( vector: Self::Borrowed<'_>, ) -> ( - (f32, f32, f32, f32, Vec<[u64; 2]>), + (f32, f32, f32, f32, Vec<[u8; 16]>), (f32, f32, f32, f32, (Vec, Vec, Vec, Vec)), ) { rabitq::compute_lut(vector.slice()) @@ -419,14 +411,6 @@ impl Vector for VectOwned { fn code(vector: Self::Borrowed<'_>) -> rabitq::Code { rabitq::code(vector.dims(), vector.slice()) } - - fn build_to_vecf32(vector: Self::Borrowed<'_>) -> Vec { - vector.slice().to_vec() - } - - fn build_from_vecf32(x: &[f32]) -> Self { - Self::new(x.to_vec()) - } } impl Vector for VectOwned { @@ -436,11 +420,14 @@ impl Vector for VectOwned { fn vector_split(vector: Self::Borrowed<'_>) -> ((), Vec<&[f16]>) { let vector = vector.slice(); - ((), match vector.len() { - 0..=1920 => vec![vector], - 1921..=2560 => vec![&vector[..1280], &vector[1280..]], - 2561.. => vector.chunks(3840).collect(), - }) + ( + (), + match vector.len() { + 0..=1920 => vec![vector], + 1921..=2560 => vec![&vector[..1280], &vector[1280..]], + 2561.. => vector.chunks(3840).collect(), + }, + ) } fn elements_and_metadata(vector: Self::Borrowed<'_>) -> (&[Self::Element], Self::Metadata) { @@ -454,20 +441,14 @@ impl Vector for VectOwned { } } - fn random_projection(vector: Self::Borrowed<'_>) -> Self { - Self::new(f16::vector_from_f32(&crate::projection::project( - &f16::vector_to_f32(vector.slice()), - ))) - } - - fn compute_lut_block(vector: Self::Borrowed<'_>) -> (f32, f32, f32, f32, Vec<[u64; 2]>) { + fn compute_lut_block(vector: Self::Borrowed<'_>) -> (f32, f32, f32, f32, Vec<[u8; 16]>) { rabitq::block::preprocess(&f16::vector_to_f32(vector.slice())) } fn compute_lut( vector: Self::Borrowed<'_>, ) -> ( - (f32, f32, f32, f32, Vec<[u64; 2]>), + (f32, f32, f32, f32, Vec<[u8; 16]>), (f32, f32, f32, f32, (Vec, Vec, Vec, Vec)), ) { rabitq::compute_lut(&f16::vector_to_f32(vector.slice())) @@ -476,14 +457,6 @@ impl Vector for VectOwned { fn code(vector: Self::Borrowed<'_>) -> rabitq::Code { rabitq::code(vector.dims(), &f16::vector_to_f32(vector.slice())) } - - fn build_to_vecf32(vector: Self::Borrowed<'_>) -> Vec { - f16::vector_to_f32(vector.slice()) - } - - fn build_from_vecf32(x: &[f32]) -> Self { - Self::new(f16::vector_from_f32(x)) - } } pub trait OperatorDistance: 'static + Debug + Copy { @@ -496,8 +469,8 @@ pub trait OperatorDistance: 'static + Debug + Copy { ) -> Distance; type BlockAccessor: for<'a> Accessor2< - [u64; 2], - [u64; 2], + [u8; 16], + [u8; 16], (&'a [f32; 32], &'a [f32; 32], &'a [f32; 32], &'a [f32; 32]), (f32, f32, f32, f32, f32), Output = [Distance; 32], diff --git a/src/utils/pipe.rs b/crates/algorithm/src/pipe.rs similarity index 100% rename from src/utils/pipe.rs rename to crates/algorithm/src/pipe.rs diff --git a/src/algorithm/prewarm.rs b/crates/algorithm/src/prewarm.rs similarity index 84% rename from src/algorithm/prewarm.rs rename to 
crates/algorithm/src/prewarm.rs index ab2731fa..587f7528 100644 --- a/src/algorithm/prewarm.rs +++ b/crates/algorithm/src/prewarm.rs @@ -1,12 +1,11 @@ -use crate::algorithm::operator::Operator; -use crate::algorithm::tuples::*; -use crate::algorithm::vectors; -use crate::algorithm::{Page, RelationRead}; -use crate::utils::pipe::Pipe; +use crate::operator::Operator; +use crate::pipe::Pipe; +use crate::tuples::*; +use crate::{Page, RelationRead, vectors}; use std::fmt::Write; -pub fn prewarm(relation: impl RelationRead + Clone, height: i32) -> String { - let meta_guard = relation.read(0); +pub fn prewarm(index: impl RelationRead, height: i32, check: impl Fn()) -> String { + let meta_guard = index.read(0); let meta_tuple = meta_guard.get(1).unwrap().pipe(read_tuple::); let height_of_root = meta_tuple.height_of_root(); let root_mean = meta_tuple.root_mean(); @@ -23,7 +22,7 @@ pub fn prewarm(relation: impl RelationRead + Clone, height: i32) -> let mut state: State = { let mut nodes = Vec::new(); { - vectors::vector_access_1::(relation.clone(), root_mean, ()); + vectors::access_1::(index.clone(), root_mean, ()); nodes.push(root_first); } writeln!(message, "------------------------").unwrap(); @@ -40,8 +39,8 @@ pub fn prewarm(relation: impl RelationRead + Clone, height: i32) -> let mut current = list; while current != u32::MAX { counter_pages += 1; - pgrx::check_for_interrupts!(); - let h1_guard = relation.read(current); + check(); + let h1_guard = index.read(current); for i in 1..=h1_guard.len() { counter_tuples += 1; let h1_tuple = h1_guard @@ -51,7 +50,7 @@ pub fn prewarm(relation: impl RelationRead + Clone, height: i32) -> match h1_tuple { H1TupleReader::_0(h1_tuple) => { for mean in h1_tuple.mean().iter().copied() { - vectors::vector_access_1::(relation.clone(), mean, ()); + vectors::access_1::(index.clone(), mean, ()); } for first in h1_tuple.first().iter().copied() { nodes.push(first); @@ -77,7 +76,7 @@ pub fn prewarm(relation: impl RelationRead + Clone, height: i32) -> let mut counter_tuples = 0_usize; let mut counter_nodes = 0_usize; for list in state { - let jump_guard = relation.read(list); + let jump_guard = index.read(list); let jump_tuple = jump_guard .get(1) .expect("data corruption") @@ -86,8 +85,8 @@ pub fn prewarm(relation: impl RelationRead + Clone, height: i32) -> let mut current = first; while current != u32::MAX { counter_pages += 1; - pgrx::check_for_interrupts!(); - let h0_guard = relation.read(current); + check(); + let h0_guard = index.read(current); for i in 1..=h0_guard.len() { counter_tuples += 1; let h0_tuple = h0_guard diff --git a/crates/algorithm/src/rerank.rs b/crates/algorithm/src/rerank.rs new file mode 100644 index 00000000..840fdd4c --- /dev/null +++ b/crates/algorithm/src/rerank.rs @@ -0,0 +1,71 @@ +use crate::operator::*; +use crate::tuples::*; +use crate::{RelationRead, vectors}; +use always_equal::AlwaysEqual; +use distance::Distance; +use std::cmp::Reverse; +use std::collections::BinaryHeap; +use std::num::NonZeroU64; +use vector::VectorOwned; + +pub fn rerank_index( + index: impl RelationRead, + vector: O::Vector, + results: Vec<( + Reverse, + AlwaysEqual, + AlwaysEqual, + )>, +) -> impl Iterator { + let mut heap = BinaryHeap::from(results); + let mut cache = BinaryHeap::<(Reverse, _)>::new(); + std::iter::from_fn(move || { + while !heap.is_empty() && heap.peek().map(|x| x.0) > cache.peek().map(|x| x.0) { + let (_, AlwaysEqual(mean), AlwaysEqual(pay_u)) = heap.pop().unwrap(); + if let Some(dis_u) = vectors::access_0::( + index.clone(), + mean, + pay_u, 
+ LAccess::new( + O::Vector::elements_and_metadata(vector.as_borrowed()), + O::DistanceAccessor::default(), + ), + ) { + cache.push((Reverse(dis_u), AlwaysEqual(pay_u))); + }; + } + let (Reverse(dis_u), AlwaysEqual(pay_u)) = cache.pop()?; + Some((dis_u, pay_u)) + }) +} + +pub fn rerank_heap( + vector: O::Vector, + results: Vec<( + Reverse, + AlwaysEqual, + AlwaysEqual, + )>, + fetch: F, +) -> impl Iterator +where + F: Fn(NonZeroU64) -> Option, +{ + let mut heap = BinaryHeap::from(results); + let mut cache = BinaryHeap::<(Reverse, _)>::new(); + std::iter::from_fn(move || { + let vector = O::Vector::elements_and_metadata(vector.as_borrowed()); + while !heap.is_empty() && heap.peek().map(|x| x.0) > cache.peek().map(|x| x.0) { + let (_, AlwaysEqual(_), AlwaysEqual(pay_u)) = heap.pop().unwrap(); + if let Some(vec_u) = fetch(pay_u) { + let vec_u = O::Vector::elements_and_metadata(vec_u.as_borrowed()); + let mut accessor = O::DistanceAccessor::default(); + accessor.push(vector.0, vec_u.0); + let dis_u = accessor.finish(vector.1, vec_u.1); + cache.push((Reverse(dis_u), AlwaysEqual(pay_u))); + } + } + let (Reverse(dis_u), AlwaysEqual(pay_u)) = cache.pop()?; + Some((dis_u, pay_u)) + }) +} diff --git a/src/algorithm/scan.rs b/crates/algorithm/src/search.rs similarity index 68% rename from src/algorithm/scan.rs rename to crates/algorithm/src/search.rs index fee6df3b..9972e861 100644 --- a/src/algorithm/scan.rs +++ b/crates/algorithm/src/search.rs @@ -1,32 +1,43 @@ -use crate::algorithm::operator::*; -use crate::algorithm::tape::read_h0_tape; -use crate::algorithm::tape::read_h1_tape; -use crate::algorithm::tuples::*; -use crate::algorithm::vectors; -use crate::algorithm::{Page, RelationRead}; -use crate::utils::pipe::Pipe; +use crate::linked_vec::LinkedVec; +use crate::operator::*; +use crate::pipe::Pipe; +use crate::tape::{access_0, access_1}; +use crate::tuples::*; +use crate::{Page, RelationRead, RerankMethod, vectors}; use always_equal::AlwaysEqual; use distance::Distance; use std::cmp::Reverse; use std::collections::BinaryHeap; use std::num::NonZeroU64; -use vector::VectorBorrowed; -use vector::VectorOwned; +use vector::{VectorBorrowed, VectorOwned}; -pub fn scan( - relation: impl RelationRead + Clone, +pub fn search( + index: impl RelationRead, vector: O::Vector, probes: Vec, epsilon: f32, -) -> impl Iterator { - let vector = O::Vector::random_projection(vector.as_borrowed()); - let meta_guard = relation.read(0); +) -> ( + RerankMethod, + Vec<( + Reverse, + AlwaysEqual, + AlwaysEqual, + )>, +) { + let meta_guard = index.read(0); let meta_tuple = meta_guard.get(1).unwrap().pipe(read_tuple::); let dims = meta_tuple.dims(); let is_residual = meta_tuple.is_residual(); + let rerank_in_heap = meta_tuple.rerank_in_heap(); let height_of_root = meta_tuple.height_of_root(); assert_eq!(dims, vector.as_borrowed().dims(), "unmatched dimensions"); - assert_eq!(height_of_root as usize, 1 + probes.len(), "invalid probes"); + if height_of_root as usize != 1 + probes.len() { + panic!( + "need {} probes, but {} probes provided", + height_of_root - 1, + probes.len() + ); + } let root_mean = meta_tuple.root_mean(); let root_first = meta_tuple.root_first(); drop(meta_guard); @@ -41,8 +52,8 @@ pub fn scan( let mut state: State = vec![{ let mean = root_mean; if is_residual { - let residual_u = vectors::vector_access_1::( - relation.clone(), + let residual_u = vectors::access_1::( + index.clone(), mean, LAccess::new( O::Vector::elements_and_metadata(vector.as_borrowed()), @@ -55,15 +66,15 @@ pub fn scan( } }]; let step 
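Both rerank functions above share one invariant: keep moving candidates from the approximate heap into an exact-distance cache while the smallest remaining lower bound still beats the best exact distance seen so far, and only then yield. A self-contained sketch of that loop with plain integers (the `rerank` name and the `exact` closure are illustrative, not part of the crate):

use std::cmp::Reverse;
use std::collections::BinaryHeap;

// `bounds` pairs a lower bound with a candidate id; `exact` computes the true distance.
fn rerank(bounds: Vec<(u32, u32)>, exact: impl Fn(u32) -> u32) -> Vec<(u32, u32)> {
    let mut heap: BinaryHeap<(Reverse<u32>, u32)> =
        bounds.into_iter().map(|(b, id)| (Reverse(b), id)).collect();
    let mut cache = BinaryHeap::<(Reverse<u32>, u32)>::new();
    let mut out = Vec::new();
    loop {
        // same guard as rerank_index/rerank_heap: pull while the best remaining
        // lower bound is still smaller than the best exact distance in the cache
        while !heap.is_empty() && heap.peek().map(|x| x.0) > cache.peek().map(|x| x.0) {
            let (_, id) = heap.pop().unwrap();
            cache.push((Reverse(exact(id)), id));
        }
        match cache.pop() {
            // safe to yield: no remaining lower bound can beat `dis`
            Some((Reverse(dis), id)) => out.push((dis, id)),
            None => break out,
        }
    }
}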
= |state: State, probes| { - let mut results = Vec::new(); + let mut results = LinkedVec::new(); for (first, residual) in state { let lut = if let Some(residual) = residual { &O::Vector::compute_lut_block(residual.as_borrowed()) } else { default_lut.as_ref().map(|x| &x.0).unwrap() }; - read_h1_tape( - relation.clone(), + access_1( + index.clone(), first, || { RAccess::new( @@ -76,14 +87,14 @@ pub fn scan( }, ); } - let mut heap = BinaryHeap::from(results); + let mut heap = BinaryHeap::from(results.into_vec()); let mut cache = BinaryHeap::<(Reverse, _, _)>::new(); std::iter::from_fn(|| { while !heap.is_empty() && heap.peek().map(|x| x.0) > cache.peek().map(|x| x.0) { let (_, AlwaysEqual(mean), AlwaysEqual(first)) = heap.pop().unwrap(); if is_residual { - let (dis_u, residual_u) = vectors::vector_access_1::( - relation.clone(), + let (dis_u, residual_u) = vectors::access_1::( + index.clone(), mean, LAccess::new( O::Vector::elements_and_metadata(vector.as_borrowed()), @@ -99,8 +110,8 @@ pub fn scan( AlwaysEqual(Some(residual_u)), )); } else { - let dis_u = vectors::vector_access_1::( - relation.clone(), + let dis_u = vectors::access_1::( + index.clone(), mean, LAccess::new( O::Vector::elements_and_metadata(vector.as_borrowed()), @@ -120,21 +131,21 @@ pub fn scan( state = step(state, probes[i as usize - 1]); } - let mut results = Vec::new(); + let mut results = LinkedVec::new(); for (first, residual) in state { let lut = if let Some(residual) = residual.as_ref().map(|x| x.as_borrowed()) { &O::Vector::compute_lut(residual) } else { default_lut.as_ref().unwrap() }; - let jump_guard = relation.read(first); + let jump_guard = index.read(first); let jump_tuple = jump_guard .get(1) .expect("data corruption") .pipe(read_tuple::); let first = jump_tuple.first(); - read_h0_tape( - relation.clone(), + access_0( + index.clone(), first, || { RAccess::new( @@ -148,24 +159,12 @@ pub fn scan( }, ); } - let mut heap = BinaryHeap::from(results); - let mut cache = BinaryHeap::<(Reverse, _)>::new(); - std::iter::from_fn(move || { - while !heap.is_empty() && heap.peek().map(|x| x.0) > cache.peek().map(|x| x.0) { - let (_, AlwaysEqual(mean), AlwaysEqual(pay_u)) = heap.pop().unwrap(); - if let Some(dis_u) = vectors::vector_access_0::( - relation.clone(), - mean, - pay_u, - LAccess::new( - O::Vector::elements_and_metadata(vector.as_borrowed()), - O::DistanceAccessor::default(), - ), - ) { - cache.push((Reverse(dis_u), AlwaysEqual(pay_u))); - }; - } - let (Reverse(dis_u), AlwaysEqual(pay_u)) = cache.pop()?; - Some((dis_u, pay_u)) - }) + ( + if rerank_in_heap { + RerankMethod::Heap + } else { + RerankMethod::Index + }, + results.into_vec(), + ) } diff --git a/crates/algorithm/src/select_heap.rs b/crates/algorithm/src/select_heap.rs new file mode 100644 index 00000000..86723881 --- /dev/null +++ b/crates/algorithm/src/select_heap.rs @@ -0,0 +1,63 @@ +pub struct SelectHeap { + threshold: usize, + inner: Vec, +} + +impl SelectHeap { + pub fn from_vec(mut vec: Vec) -> Self { + let n = vec.len(); + if n != 0 { + let threshold = n.saturating_sub(n.div_ceil(384)); + turboselect::select_nth_unstable(&mut vec, threshold); + vec[threshold..].sort_unstable(); + Self { + threshold, + inner: vec, + } + } else { + Self { + threshold: 0, + inner: Vec::new(), + } + } + } + pub fn is_empty(&self) -> bool { + self.inner.is_empty() + } + pub fn pop(&mut self) -> Option { + if self.inner.len() <= self.threshold { + heapify::pop_heap(&mut self.inner); + } + let t = self.inner.pop(); + if self.inner.len() == self.threshold { + 
heapify::make_heap(&mut self.inner); + } + t + } + pub fn peek(&self) -> Option<&T> { + if self.inner.len() <= self.threshold { + self.inner.first() + } else { + self.inner.last() + } + } +} + +#[test] +fn test_select_heap() { + for _ in 0..1000 { + let sequence = (0..10000) + .map(|_| rand::random::()) + .collect::>(); + let answer = { + let mut x = sequence.clone(); + x.sort_by_key(|x| std::cmp::Reverse(*x)); + x + }; + let result = { + let mut x = SelectHeap::from_vec(sequence.clone()); + std::iter::from_fn(|| x.pop()).collect::>() + }; + assert_eq!(answer, result); + } +} diff --git a/src/algorithm/tape.rs b/crates/algorithm/src/tape.rs similarity index 82% rename from src/algorithm/tape.rs rename to crates/algorithm/src/tape.rs index 4ee722a5..edc1c4ee 100644 --- a/src/algorithm/tape.rs +++ b/crates/algorithm/src/tape.rs @@ -1,12 +1,9 @@ -use super::RelationRead; -use super::operator::Accessor1; -use crate::algorithm::Page; -use crate::algorithm::PageGuard; -use crate::algorithm::tuples::*; -use crate::utils::pipe::Pipe; +use crate::operator::Accessor1; +use crate::pipe::Pipe; +use crate::tuples::*; +use crate::{Page, PageGuard, RelationRead, RelationWrite}; use distance::Distance; -use simd::fast_scan::any_pack; -use simd::fast_scan::padding_pack; +use simd::fast_scan::{any_pack, padding_pack}; use std::marker::PhantomData; use std::num::NonZeroU64; use std::ops::DerefMut; @@ -178,7 +175,7 @@ where } } -pub struct H0BranchWriter { +pub struct H0Branch { pub mean: IndexPointer, pub dis_u_2: f32, pub factor_ppc: f32, @@ -188,12 +185,12 @@ pub struct H0BranchWriter { pub payload: NonZeroU64, } -pub struct H0Tape { +pub struct H0TapeWriter { tape: TapeWriter, - branches: Vec, + branches: Vec, } -impl H0Tape +impl H0TapeWriter where G: PageGuard + DerefMut, G::Target: Page, @@ -205,7 +202,7 @@ where branches: Vec::new(), } } - pub fn push(&mut self, branch: H0BranchWriter) { + pub fn push(&mut self, branch: H0Branch) { self.branches.push(branch); if self.branches.len() == 32 { let chunk = std::array::from_fn::<_, 32, _>(|_| self.branches.pop().unwrap()); @@ -253,14 +250,14 @@ where } } -pub fn read_h1_tape( - relation: impl RelationRead, +pub fn access_1( + index: impl RelationRead, first: u32, - compute_block: impl Fn() -> A + Copy, + make_block_accessor: impl Fn() -> A + Copy, mut callback: impl FnMut(Distance, IndexPointer, u32), ) where A: for<'a> Accessor1< - [u64; 2], + [u8; 16], (&'a [f32; 32], &'a [f32; 32], &'a [f32; 32], &'a [f32; 32]), Output = [Distance; 32], >, @@ -269,7 +266,7 @@ pub fn read_h1_tape( let mut current = first; let mut computing = None; while current != u32::MAX { - let h1_guard = relation.read(current); + let h1_guard = index.read(current); for i in 1..=h1_guard.len() { let h1_tuple = h1_guard .get(i) @@ -277,7 +274,7 @@ pub fn read_h1_tape( .pipe(read_tuple::); match h1_tuple { H1TupleReader::_0(h1_tuple) => { - let mut compute = computing.take().unwrap_or_else(compute_block); + let mut compute = computing.take().unwrap_or_else(make_block_accessor); compute.push(h1_tuple.elements()); let lowerbounds = compute.finish(h1_tuple.metadata()); for i in 0..h1_tuple.len() { @@ -289,7 +286,7 @@ pub fn read_h1_tape( } } H1TupleReader::_1(h1_tuple) => { - let computing = computing.get_or_insert_with(compute_block); + let computing = computing.get_or_insert_with(make_block_accessor); computing.push(h1_tuple.elements()); } } @@ -298,15 +295,15 @@ pub fn read_h1_tape( } } -pub fn read_h0_tape( - relation: impl RelationRead, +pub fn access_0( + index: impl RelationRead, 
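SelectHeap above (read here as `SelectHeap<T>` with `T: Ord`) avoids heapifying the whole candidate list up front: turboselect moves roughly the largest n/384 elements to the tail, only that tail is sorted, and the remainder is heapified lazily once pops run past the sorted part. A usage sketch under that reading; pop order matches `BinaryHeap`, largest first:

#[test]
fn select_heap_pops_in_descending_order() {
    let data: Vec<u32> = (0..1_000).map(|_| rand::random()).collect();
    let mut heap = SelectHeap::from_vec(data.clone());
    let mut popped = Vec::new();
    while let Some(x) = heap.pop() {
        popped.push(x); // early pops come from the pre-sorted tail, later ones from the lazy heap
    }
    let mut expected = data;
    expected.sort_unstable_by(|a, b| b.cmp(a)); // descending
    assert_eq!(popped, expected);
}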
first: u32, - compute_block: impl Fn() -> A + Copy, + make_block_accessor: impl Fn() -> A + Copy, compute_binary: impl Fn((f32, f32, f32, f32, &[u64])) -> Distance, mut callback: impl FnMut(Distance, IndexPointer, NonZeroU64), ) where A: for<'a> Accessor1< - [u64; 2], + [u8; 16], (&'a [f32; 32], &'a [f32; 32], &'a [f32; 32], &'a [f32; 32]), Output = [Distance; 32], >, @@ -315,7 +312,7 @@ pub fn read_h0_tape( let mut current = first; let mut computing = None; while current != u32::MAX { - let h0_guard = relation.read(current); + let h0_guard = index.read(current); for i in 1..=h0_guard.len() { let h0_tuple = h0_guard .get(i) @@ -329,7 +326,7 @@ pub fn read_h0_tape( } } H0TupleReader::_1(h0_tuple) => { - let mut compute = computing.take().unwrap_or_else(compute_block); + let mut compute = computing.take().unwrap_or_else(make_block_accessor); compute.push(h0_tuple.elements()); let lowerbounds = compute.finish(h0_tuple.metadata()); for j in 0..32 { @@ -339,7 +336,7 @@ pub fn read_h0_tape( } } H0TupleReader::_2(h0_tuple) => { - let computing = computing.get_or_insert_with(compute_block); + let computing = computing.get_or_insert_with(make_block_accessor); computing.push(h0_tuple.elements()); } } @@ -347,3 +344,48 @@ pub fn read_h0_tape( current = h0_guard.get_opaque().next; } } + +pub fn append( + index: impl RelationWrite, + first: u32, + bytes: &[u8], + tracking_freespace: bool, +) -> IndexPointer { + assert!(first != u32::MAX); + let mut current = first; + loop { + let read = index.read(current); + if read.freespace() as usize >= bytes.len() || read.get_opaque().next == u32::MAX { + drop(read); + let mut write = index.write(current, tracking_freespace); + if write.get_opaque().next == u32::MAX { + if let Some(i) = write.alloc(bytes) { + return pair_to_pointer((current, i)); + } + let mut extend = index.extend(tracking_freespace); + write.get_opaque_mut().next = extend.id(); + drop(write); + let fresh = extend.id(); + if let Some(i) = extend.alloc(bytes) { + drop(extend); + let mut past = index.write(first, tracking_freespace); + past.get_opaque_mut().skip = fresh.max(past.get_opaque().skip); + return pair_to_pointer((fresh, i)); + } else { + panic!("a tuple cannot even be fit in a fresh page"); + } + } + if current == first && write.get_opaque().skip != first { + current = write.get_opaque().skip; + } else { + current = write.get_opaque().next; + } + } else { + if current == first && read.get_opaque().skip != first { + current = read.get_opaque().skip; + } else { + current = read.get_opaque().next; + } + } + } +} diff --git a/src/algorithm/tuples.rs b/crates/algorithm/src/tuples.rs similarity index 91% rename from src/algorithm/tuples.rs rename to crates/algorithm/src/tuples.rs index 06d7cedb..f60d5527 100644 --- a/src/algorithm/tuples.rs +++ b/crates/algorithm/src/tuples.rs @@ -1,4 +1,5 @@ -use crate::algorithm::operator::Vector; +use crate::operator::Vector; +use std::marker::PhantomData; use std::num::{NonZeroU8, NonZeroU64}; use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout}; use zerocopy_derive::{FromBytes, Immutable, IntoBytes, KnownLayout}; @@ -6,7 +7,7 @@ use zerocopy_derive::{FromBytes, Immutable, IntoBytes, KnownLayout}; pub const ALIGN: usize = 8; pub type Tag = u64; const MAGIC: u64 = u64::from_ne_bytes(*b"vchordrq"); -const VERSION: u64 = 1; +const VERSION: u64 = 2; pub trait Tuple: 'static { type Reader<'a>: TupleReader<'a, Tuple = Self>; @@ -49,7 +50,8 @@ struct MetaTupleHeader { dims: u32, height_of_root: u32, is_residual: Bool, - _padding_0: [ZeroU8; 3], + 
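A toy in-memory model of the skip pointer that `append` maintains above: the first page of a chain remembers a page near the tail, so an append jumps there instead of walking every page from `first`. `ToyPage` and `CAP` are illustrative only; the real function also honors page freespace and the `tracking_freespace` flag.

// Pages of a chain; `skip` is only read on the first page and starts equal to `first`.
struct ToyPage {
    next: Option<usize>,
    skip: usize,
    used: usize,
}

const CAP: usize = 4; // toy capacity in "tuples" per page

fn toy_append(pages: &mut Vec<ToyPage>, first: usize, len: usize) -> usize {
    let mut current = pages[first].skip; // jump near the tail instead of scanning from `first`
    loop {
        if pages[current].used + len <= CAP {
            pages[current].used += len;
            return current;
        }
        let next = pages[current].next; // Option<usize> is Copy, so the borrow ends here
        match next {
            Some(next) => current = next,
            None => {
                // chain is full: extend it and remember the new tail on the first page,
                // mirroring `past.get_opaque_mut().skip = fresh.max(past.get_opaque().skip)`
                let fresh = pages.len();
                pages.push(ToyPage { next: None, skip: fresh, used: len });
                pages[current].next = Some(fresh);
                pages[first].skip = pages[first].skip.max(fresh);
                return fresh;
            }
        }
    }
}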
rerank_in_heap: Bool, + _padding_0: [ZeroU8; 2], vectors_first: u32, // raw vector root_mean: IndexPointer, @@ -62,6 +64,7 @@ pub struct MetaTuple { pub dims: u32, pub height_of_root: u32, pub is_residual: bool, + pub rerank_in_heap: bool, pub vectors_first: u32, pub root_mean: IndexPointer, pub root_first: u32, @@ -78,6 +81,7 @@ impl Tuple for MetaTuple { dims: self.dims, height_of_root: self.height_of_root, is_residual: self.is_residual.into(), + rerank_in_heap: self.rerank_in_heap.into(), _padding_0: Default::default(), vectors_first: self.vectors_first, root_mean: self.root_mean, @@ -124,6 +128,9 @@ impl MetaTupleReader<'_> { pub fn is_residual(self) -> bool { self.header.is_residual.into() } + pub fn rerank_in_heap(self) -> bool { + self.header.rerank_in_heap.into() + } pub fn vectors_first(self) -> u32 { self.header.vectors_first } @@ -141,7 +148,7 @@ impl MetaTupleReader<'_> { // freepage tuple #[repr(C, align(8))] -#[derive(Debug, Clone, Copy, PartialEq, FromBytes, IntoBytes, Immutable, KnownLayout)] +#[derive(Debug, Clone, PartialEq, FromBytes, IntoBytes, Immutable, KnownLayout)] struct FreepageTupleHeader { a: [u32; 1], b: [u32; 32], @@ -233,16 +240,18 @@ impl FreepageTupleWriter<'_> { // vector tuple #[repr(C, align(8))] -#[derive(Debug, Clone, Copy, PartialEq, FromBytes, IntoBytes, Immutable, KnownLayout)] +#[derive(Debug, Clone, PartialEq, FromBytes, IntoBytes, Immutable, KnownLayout)] struct VectorTupleHeader0 { payload: Option, metadata_s: usize, elements_s: usize, elements_e: usize, + #[cfg(target_pointer_width = "32")] + _padding_0: [ZeroU8; 4], } #[repr(C, align(8))] -#[derive(Debug, Clone, Copy, PartialEq, FromBytes, IntoBytes, Immutable, KnownLayout)] +#[derive(Debug, Clone, PartialEq, FromBytes, IntoBytes, Immutable, KnownLayout)] struct VectorTupleHeader1 { payload: Option, pointer: IndexPointer, @@ -276,7 +285,7 @@ impl Tuple for VectorTuple { elements, } => { buffer.extend((0 as Tag).to_ne_bytes()); - buffer.extend(std::iter::repeat(0).take(size_of::())); + buffer.extend(std::iter::repeat_n(0, size_of::())); while buffer.len() % ALIGN != 0 { buffer.push(0); } @@ -294,6 +303,8 @@ impl Tuple for VectorTuple { metadata_s, elements_s, elements_e, + #[cfg(target_pointer_width = "32")] + _padding_0: Default::default(), } .as_bytes(), ); @@ -304,7 +315,7 @@ impl Tuple for VectorTuple { elements, } => { buffer.extend((1 as Tag).to_ne_bytes()); - buffer.extend(std::iter::repeat(0).take(size_of::())); + buffer.extend(std::iter::repeat_n(0, size_of::())); while buffer.len() % ALIGN != 0 { buffer.push(0); } @@ -435,10 +446,10 @@ pub enum H1Tuple { factor_err: [f32; 32], first: [u32; 32], len: u32, - elements: Vec<[u64; 2]>, + elements: Vec<[u8; 16]>, }, _1 { - elements: Vec<[u64; 2]>, + elements: Vec<[u8; 16]>, }, } @@ -448,7 +459,7 @@ impl H1Tuple { freespace -= size_of::() as isize; freespace -= size_of::() as isize; if freespace >= 0 { - Some(freespace as usize / size_of::<[u64; 2]>()) + Some(freespace as usize / size_of::<[u8; 16]>()) } else { None } @@ -458,7 +469,7 @@ impl H1Tuple { freespace -= size_of::() as isize; freespace -= size_of::() as isize; if freespace >= 0 { - Some(freespace as usize / size_of::<[u64; 2]>()) + Some(freespace as usize / size_of::<[u8; 16]>()) } else { None } @@ -482,7 +493,7 @@ impl Tuple for H1Tuple { elements, } => { buffer.extend((0 as Tag).to_ne_bytes()); - buffer.extend(std::iter::repeat(0).take(size_of::())); + buffer.extend(std::iter::repeat_n(0, size_of::())); while buffer.len() % ALIGN != 0 { buffer.push(0); } @@ -507,7 +518,7 @@ impl 
Tuple for H1Tuple { } Self::_1 { elements } => { buffer.extend((1 as Tag).to_ne_bytes()); - buffer.extend(std::iter::repeat(0).take(size_of::())); + buffer.extend(std::iter::repeat_n(0, size_of::())); while buffer.len() % ALIGN != 0 { buffer.push(0); } @@ -536,13 +547,13 @@ pub enum H1TupleReader<'a> { #[derive(Debug, Clone, Copy, PartialEq)] pub struct H1TupleReader0<'a> { header: &'a H1TupleHeader0, - elements: &'a [[u64; 2]], + elements: &'a [[u8; 16]], } #[derive(Debug, Clone, Copy, PartialEq)] pub struct H1TupleReader1<'a> { header: &'a H1TupleHeader1, - elements: &'a [[u64; 2]], + elements: &'a [[u8; 16]], } impl<'a> TupleReader<'a> for H1TupleReader<'a> { @@ -586,13 +597,13 @@ impl<'a> H1TupleReader0<'a> { pub fn first(self) -> &'a [u32] { &self.header.first[..self.header.len as usize] } - pub fn elements(&self) -> &'a [[u64; 2]] { + pub fn elements(&self) -> &'a [[u8; 16]] { self.elements } } impl<'a> H1TupleReader1<'a> { - pub fn elements(&self) -> &'a [[u64; 2]] { + pub fn elements(&self) -> &'a [[u8; 16]] { self.elements } } @@ -720,10 +731,10 @@ pub enum H0Tuple { factor_ip: [f32; 32], factor_err: [f32; 32], payload: [Option; 32], - elements: Vec<[u64; 2]>, + elements: Vec<[u8; 16]>, }, _2 { - elements: Vec<[u64; 2]>, + elements: Vec<[u8; 16]>, }, } @@ -733,7 +744,7 @@ impl H0Tuple { freespace -= size_of::() as isize; freespace -= size_of::() as isize; if freespace >= 0 { - Some(freespace as usize / size_of::<[u64; 2]>()) + Some(freespace as usize / size_of::<[u8; 16]>()) } else { None } @@ -743,7 +754,7 @@ impl H0Tuple { freespace -= size_of::() as isize; freespace -= size_of::() as isize; if freespace >= 0 { - Some(freespace as usize / size_of::<[u64; 2]>()) + Some(freespace as usize / size_of::<[u8; 16]>()) } else { None } @@ -766,7 +777,7 @@ impl Tuple for H0Tuple { elements, } => { buffer.extend((0 as Tag).to_ne_bytes()); - buffer.extend(std::iter::repeat(0).take(size_of::())); + buffer.extend(std::iter::repeat_n(0, size_of::())); while buffer.len() % ALIGN != 0 { buffer.push(0); } @@ -797,7 +808,7 @@ impl Tuple for H0Tuple { elements, } => { buffer.extend((1 as Tag).to_ne_bytes()); - buffer.extend(std::iter::repeat(0).take(size_of::())); + buffer.extend(std::iter::repeat_n(0, size_of::())); while buffer.len() % ALIGN != 0 { buffer.push(0); } @@ -820,7 +831,7 @@ impl Tuple for H0Tuple { } Self::_2 { elements } => { buffer.extend((2 as Tag).to_ne_bytes()); - buffer.extend(std::iter::repeat(0).take(size_of::())); + buffer.extend(std::iter::repeat_n(0, size_of::())); while buffer.len() % ALIGN != 0 { buffer.push(0); } @@ -878,7 +889,7 @@ impl<'a> H0TupleReader0<'a> { #[derive(Debug, Clone, Copy, PartialEq)] pub struct H0TupleReader1<'a> { header: &'a H0TupleHeader1, - elements: &'a [[u64; 2]], + elements: &'a [[u8; 16]], } impl<'a> H0TupleReader1<'a> { @@ -893,7 +904,7 @@ impl<'a> H0TupleReader1<'a> { &self.header.factor_err, ) } - pub fn elements(self) -> &'a [[u64; 2]] { + pub fn elements(self) -> &'a [[u8; 16]] { self.elements } pub fn payload(self) -> &'a [Option; 32] { @@ -904,11 +915,11 @@ impl<'a> H0TupleReader1<'a> { #[derive(Debug, Clone, Copy, PartialEq)] pub struct H0TupleReader2<'a> { header: &'a H0TupleHeader2, - elements: &'a [[u64; 2]], + elements: &'a [[u8; 16]], } impl<'a> H0TupleReader2<'a> { - pub fn elements(self) -> &'a [[u64; 2]] { + pub fn elements(self) -> &'a [[u8; 16]] { self.elements } } @@ -961,7 +972,7 @@ pub struct H0TupleWriter0<'a> { pub struct H0TupleWriter1<'a> { header: &'a mut H0TupleHeader1, #[allow(dead_code)] - elements: &'a mut [[u64; 
2]], + elements: &'a mut [[u8; 16]], } #[derive(Debug)] @@ -969,7 +980,7 @@ pub struct H0TupleWriter2<'a> { #[allow(dead_code)] header: &'a mut H0TupleHeader2, #[allow(dead_code)] - elements: &'a mut [[u64; 2]], + elements: &'a mut [[u8; 16]], } impl<'a> TupleWriter<'a> for H0TupleWriter<'a> { @@ -1025,16 +1036,15 @@ pub const fn pair_to_pointer(pair: (u32, u16)) -> IndexPointer { IndexPointer(value) } -#[allow(dead_code)] -const fn soundness_check(a: (u32, u16)) { +#[test] +const fn soundness_check() { + let a = (111, 222); let b = pair_to_pointer(a); let c = pointer_to_pair(b); assert!(a.0 == c.0); assert!(a.1 == c.1); } -const _: () = soundness_check((111, 222)); - #[repr(transparent)] #[derive( Debug, @@ -1123,15 +1133,17 @@ impl<'a> RefChecker<'a> { } pub struct MutChecker<'a> { - flag: Vec, - bytes: &'a mut [u8], + flag: usize, + bytes: *mut [u8], + phantom: PhantomData<&'a mut [u8]>, } impl<'a> MutChecker<'a> { pub fn new(bytes: &'a mut [u8]) -> Self { Self { - flag: vec![0u64; bytes.len().div_ceil(64)], + flag: 0, bytes, + phantom: PhantomData, } } pub fn prefix( @@ -1143,15 +1155,14 @@ impl<'a> MutChecker<'a> { if !(start <= end && end <= self.bytes.len()) { panic!("bad bytes"); } - for i in start..end { - if (self.flag[i / 64] & (1 << (i % 64))) != 0 { - panic!("bad bytes"); - } else { - self.flag[i / 64] |= 1 << (i % 64); - } + if !(self.flag <= start) { + panic!("bad bytes"); + } else { + self.flag = end; } + #[allow(unsafe_code)] let bytes = unsafe { - std::slice::from_raw_parts_mut(self.bytes.as_mut_ptr().add(start), end - start) + std::slice::from_raw_parts_mut((self.bytes as *mut u8).add(start), end - start) }; FromBytes::mut_from_bytes(bytes).expect("bad bytes") } @@ -1165,21 +1176,19 @@ impl<'a> MutChecker<'a> { if !(start <= end && end <= self.bytes.len()) { panic!("bad bytes"); } - for i in start..end { - if (self.flag[i / 64] & (1 << (i % 64))) != 0 { - panic!("bad bytes"); - } else { - self.flag[i / 64] |= 1 << (i % 64); - } + if !(self.flag <= start) { + panic!("bad bytes"); + } else { + self.flag = end; } + #[allow(unsafe_code)] let bytes = unsafe { - std::slice::from_raw_parts_mut(self.bytes.as_mut_ptr().add(start), end - start) + std::slice::from_raw_parts_mut((self.bytes as *mut u8).add(start), end - start) }; FromBytes::mut_from_bytes(bytes).expect("bad bytes") } } -// this test only passes if `MIRIFLAGS="-Zmiri-tree-borrows"` is set #[test] fn aliasing_test() { #[repr(C, align(8))] @@ -1191,7 +1200,7 @@ fn aliasing_test() { let serialized = { let elements = (0u32..1111).collect::>(); let mut buffer = Vec::::new(); - buffer.extend(std::iter::repeat(0).take(size_of::())); + buffer.extend(std::iter::repeat_n(0, size_of::())); while buffer.len() % ALIGN != 0 { buffer.push(0); } diff --git a/crates/algorithm/src/types.rs b/crates/algorithm/src/types.rs new file mode 100644 index 00000000..ea9784af --- /dev/null +++ b/crates/algorithm/src/types.rs @@ -0,0 +1,84 @@ +use half::f16; +use serde::{Deserialize, Serialize}; +use validator::{Validate, ValidationError}; +use vector::vect::{VectBorrowed, VectOwned}; + +#[derive(Debug, Clone, Default, Serialize, Deserialize, Validate)] +#[serde(deny_unknown_fields)] +pub struct VchordrqIndexOptions { + #[serde(default = "VchordrqIndexOptions::default_residual_quantization")] + pub residual_quantization: bool, + #[serde(default = "VchordrqIndexOptions::default_rerank_in_table")] + pub rerank_in_table: bool, +} + +impl VchordrqIndexOptions { + fn default_residual_quantization() -> bool { + false + } + fn default_rerank_in_table() 
-> bool { + false + } +} + +#[derive(Debug, Clone)] +pub enum OwnedVector { + Vecf32(VectOwned), + Vecf16(VectOwned), +} + +#[derive(Debug, Clone, Copy)] +pub enum BorrowedVector<'a> { + Vecf32(VectBorrowed<'a, f32>), + Vecf16(VectBorrowed<'a, f16>), +} + +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub enum DistanceKind { + L2, + Dot, +} + +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub enum VectorKind { + Vecf32, + Vecf16, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Validate)] +#[serde(deny_unknown_fields)] +#[validate(schema(function = "Self::validate_self"))] +pub struct VectorOptions { + #[validate(range(min = 1, max = 1_048_575))] + #[serde(rename = "dimensions")] + pub dims: u32, + #[serde(rename = "vector")] + pub v: VectorKind, + #[serde(rename = "distance")] + pub d: DistanceKind, +} + +impl VectorOptions { + pub fn validate_self(&self) -> Result<(), ValidationError> { + match (self.v, self.d, self.dims) { + (VectorKind::Vecf32, DistanceKind::L2, 1..65536) => Ok(()), + (VectorKind::Vecf32, DistanceKind::Dot, 1..65536) => Ok(()), + (VectorKind::Vecf16, DistanceKind::L2, 1..65536) => Ok(()), + (VectorKind::Vecf16, DistanceKind::Dot, 1..65536) => Ok(()), + _ => Err(ValidationError::new("not valid vector options")), + } + } +} + +pub struct Structure { + pub means: Vec, + pub children: Vec>, +} + +impl Structure { + pub fn len(&self) -> usize { + self.children.len() + } +} diff --git a/crates/algorithm/src/vectors.rs b/crates/algorithm/src/vectors.rs new file mode 100644 index 00000000..dcc11af0 --- /dev/null +++ b/crates/algorithm/src/vectors.rs @@ -0,0 +1,92 @@ +use crate::operator::*; +use crate::pipe::Pipe; +use crate::tuples::*; +use crate::{Page, PageGuard, RelationRead, RelationWrite, tape}; +use std::num::NonZeroU64; +use vector::VectorOwned; + +pub fn access_1< + O: Operator, + A: Accessor1<::Element, ::Metadata>, +>( + index: impl RelationRead, + mean: IndexPointer, + accessor: A, +) -> A::Output { + let mut cursor = Err(mean); + let mut result = accessor; + while let Err(mean) = cursor.map_err(pointer_to_pair) { + let vector_guard = index.read(mean.0); + let vector_tuple = vector_guard + .get(mean.1) + .expect("data corruption") + .pipe(read_tuple::>); + if vector_tuple.payload().is_some() { + panic!("data corruption"); + } + result.push(vector_tuple.elements()); + cursor = vector_tuple.metadata_or_pointer(); + } + result.finish(cursor.expect("data corruption")) +} + +pub fn access_0< + O: Operator, + A: Accessor1<::Element, ::Metadata>, +>( + index: impl RelationRead, + mean: IndexPointer, + payload: NonZeroU64, + accessor: A, +) -> Option { + let mut cursor = Err(mean); + let mut result = accessor; + while let Err(mean) = cursor.map_err(pointer_to_pair) { + let vector_guard = index.read(mean.0); + let vector_tuple = vector_guard + .get(mean.1)? 
+ .pipe(read_tuple::>); + if vector_tuple.payload().is_none() { + panic!("data corruption"); + } + if vector_tuple.payload() != Some(payload) { + return None; + } + result.push(vector_tuple.elements()); + cursor = vector_tuple.metadata_or_pointer(); + } + Some(result.finish(cursor.ok()?)) +} + +pub fn append( + index: impl RelationWrite, + vectors_first: u32, + vector: ::Borrowed<'_>, + payload: NonZeroU64, +) -> IndexPointer { + fn append(index: impl RelationWrite, first: u32, bytes: &[u8]) -> IndexPointer { + if let Some(mut write) = index.search(bytes.len()) { + let i = write.alloc(bytes).unwrap(); + return pair_to_pointer((write.id(), i)); + } + tape::append(index, first, bytes, true) + } + let (metadata, slices) = O::Vector::vector_split(vector); + let mut chain = Ok(metadata); + for i in (0..slices.len()).rev() { + let bytes = serialize::>(&match chain { + Ok(metadata) => VectorTuple::_0 { + elements: slices[i].to_vec(), + payload: Some(payload), + metadata, + }, + Err(pointer) => VectorTuple::_1 { + elements: slices[i].to_vec(), + payload: Some(payload), + pointer, + }, + }); + chain = Err(append(index.clone(), vectors_first, &bytes)); + } + chain.err().unwrap() +} diff --git a/crates/k_means/Cargo.toml b/crates/k_means/Cargo.toml new file mode 100644 index 00000000..6ba2ccf7 --- /dev/null +++ b/crates/k_means/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "k_means" +version.workspace = true +edition.workspace = true + +[dependencies] +rabitq = { path = "../rabitq" } +simd = { path = "../simd" } + +half.workspace = true +rand.workspace = true +rayon = "1.10.0" + +[lints] +workspace = true diff --git a/src/utils/k_means.rs b/crates/k_means/src/lib.rs similarity index 74% rename from src/utils/k_means.rs rename to crates/k_means/src/lib.rs index b1808c90..ecbea29b 100644 --- a/src/utils/k_means.rs +++ b/crates/k_means/src/lib.rs @@ -1,12 +1,15 @@ -use super::parallelism::{ParallelIterator, Parallelism}; +#![allow(clippy::type_complexity)] + use half::f16; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; +use rayon::iter::{IntoParallelIterator, ParallelIterator}; use simd::Floating; use simd::fast_scan::{any_pack, padding_pack}; -pub fn k_means( - parallelism: &P, +pub fn k_means( + num_threads: usize, + check: impl Fn(), c: usize, dims: usize, samples: &[Vec], @@ -19,22 +22,30 @@ pub fn k_means( if n <= c { quick_centers(c, dims, samples.to_vec(), is_spherical) } else { - let compute = |parallelism: &P, centroids: &[Vec]| { - if n >= 1000 && c >= 1000 { - rabitq_index(parallelism, dims, n, c, samples, centroids) - } else { - flat_index(parallelism, dims, n, c, samples, centroids) - } - }; - let mut lloyd_k_means = - LloydKMeans::new(parallelism, c, dims, samples, is_spherical, compute); - for _ in 0..iterations { - parallelism.check(); - if lloyd_k_means.iterate() { - break; - } - } - lloyd_k_means.finish() + rayon::ThreadPoolBuilder::new() + .num_threads(num_threads) + .build_scoped( + |thread| thread.run(), + move |_| { + let compute = |centroids: &[Vec]| { + if n >= 1000 && c >= 1000 { + rabitq_index(dims, n, c, samples, centroids) + } else { + flat_index(dims, n, c, samples, centroids) + } + }; + let mut lloyd_k_means = + LloydKMeans::new(c, dims, samples, is_spherical, compute); + for _ in 0..iterations { + check(); + if lloyd_k_means.iterate() { + break; + } + } + lloyd_k_means.finish() + }, + ) + .expect("failed to build thread pool") } } @@ -58,10 +69,15 @@ fn quick_centers( ) -> Vec> { let n = samples.len(); assert!(c >= n); - let mut rng = rand::thread_rng(); + if c == 
1 && n == 0 { + return vec![vec![0.0; dims]]; + } + let mut rng = rand::rng(); let mut centroids = samples; for _ in n..c { - let r = (0..dims).map(|_| rng.gen_range(-1.0f32..1.0f32)).collect(); + let r = (0..dims) + .map(|_| rng.random_range(-1.0f32..1.0f32)) + .collect(); centroids.push(r); } if is_spherical { @@ -74,8 +90,7 @@ fn quick_centers( centroids } -fn rabitq_index( - parallelism: &P, +fn rabitq_index( dims: usize, n: usize, c: usize, @@ -108,10 +123,10 @@ fn rabitq_index( factor_ppc: [f32; 32], factor_ip: [f32; 32], factor_err: [f32; 32], - elements: Vec<[u64; 2]>, + elements: Vec<[u8; 16]>, } impl Block { - fn code(&self) -> (&[f32; 32], &[f32; 32], &[f32; 32], &[f32; 32], &[[u64; 2]]) { + fn code(&self) -> (&[f32; 32], &[f32; 32], &[f32; 32], &[f32; 32], &[[u8; 16]]) { ( &self.dis_u_2, &self.factor_ppc, @@ -131,19 +146,17 @@ fn rabitq_index( elements: padding_pack(chunk.iter().map(|x| rabitq::pack_to_u4(&x.signs))), }); } - parallelism - .rayon_into_par_iter(0..n) + (0..n) + .into_par_iter() .map(|i| { - use distance::Distance; let lut = rabitq::block::preprocess(&samples[i]); - let mut result = (Distance::INFINITY, 0); + let mut result = (f32::INFINITY, 0); for block in 0..c.div_ceil(32) { let lowerbound = rabitq::block::process_lowerbound_l2(&lut, blocks[block].code(), 1.9); for j in block * 32..std::cmp::min(block * 32 + 32, c) { - if lowerbound[j - block * 32] < result.0 { - let dis = - Distance::from_f32(f32::reduce_sum_of_d2(&samples[i], ¢roids[j])); + if lowerbound[j - block * 32].to_f32() < result.0 { + let dis = f32::reduce_sum_of_d2(&samples[i], ¢roids[j]); if dis <= result.0 { result = (dis, j); } @@ -155,16 +168,15 @@ fn rabitq_index( .collect::>() } -fn flat_index( - parallelism: &P, +fn flat_index( _dims: usize, n: usize, c: usize, samples: &[Vec], centroids: &[Vec], ) -> Vec { - parallelism - .rayon_into_par_iter(0..n) + (0..n) + .into_par_iter() .map(|i| { let mut result = (f32::INFINITY, 0); for j in 0..c { @@ -178,8 +190,7 @@ fn flat_index( .collect::>() } -struct LloydKMeans<'a, P, F> { - parallelism: &'a P, +struct LloydKMeans<'a, F> { dims: usize, c: usize, is_spherical: bool, @@ -192,26 +203,19 @@ struct LloydKMeans<'a, P, F> { const DELTA: f32 = f16::EPSILON.to_f32_const(); -impl<'a, P: Parallelism, F: Fn(&P, &[Vec]) -> Vec> LloydKMeans<'a, P, F> { - fn new( - parallelism: &'a P, - c: usize, - dims: usize, - samples: &'a [Vec], - is_spherical: bool, - compute: F, - ) -> Self { +impl<'a, F: Fn(&[Vec]) -> Vec> LloydKMeans<'a, F> { + fn new(c: usize, dims: usize, samples: &'a [Vec], is_spherical: bool, compute: F) -> Self { let n = samples.len(); - let mut rng = StdRng::from_entropy(); + let mut rng = StdRng::from_seed([7; 32]); let mut centroids = Vec::with_capacity(c); for index in rand::seq::index::sample(&mut rng, n, c).into_iter() { centroids.push(samples[index].clone()); } - let assign = parallelism - .rayon_into_par_iter(0..n) + let assign = (0..n) + .into_par_iter() .map(|i| { let mut result = (f32::INFINITY, 0); for j in 0..c { @@ -225,7 +229,6 @@ impl<'a, P: Parallelism, F: Fn(&P, &[Vec]) -> Vec> LloydKMeans<'a, P .collect::>(); Self { - parallelism, dims, c, is_spherical, @@ -244,9 +247,8 @@ impl<'a, P: Parallelism, F: Fn(&P, &[Vec]) -> Vec> LloydKMeans<'a, P let samples = self.samples; let n = samples.len(); - let (sum, mut count) = self - .parallelism - .rayon_into_par_iter(0..n) + let (sum, mut count) = (0..n) + .into_par_iter() .fold( || (vec![vec![f32::zero(); dims]; c], vec![0.0f32; c]), |(mut sum, mut count), i| { @@ -266,9 +268,8 @@ 
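The rewritten k_means replaces the old Parallelism abstraction with a scoped rayon pool sized by `num_threads`. A minimal sketch of that pattern; the `with_pool` helper is illustrative, while `build_scoped`, `ThreadBuilder::run` and `install` are the rayon APIs used above:

use rayon::iter::{IntoParallelIterator, ParallelIterator};

// Illustrative helper: run `f` on a scoped pool with `num_threads` workers.
fn with_pool<R: Send>(num_threads: usize, f: impl FnOnce() -> R + Send) -> R {
    rayon::ThreadPoolBuilder::new()
        .num_threads(num_threads)
        .build_scoped(
            |thread| thread.run(),  // each worker just runs its event loop
            |pool| pool.install(f), // execute `f` inside the scoped pool
        )
        .expect("failed to build thread pool")
}

fn main() {
    let sum: u64 = with_pool(4, || (0..1_000u64).into_par_iter().sum());
    assert_eq!(sum, 499_500);
}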
impl<'a, P: Parallelism, F: Fn(&P, &[Vec]) -> Vec> LloydKMeans<'a, P }, ); - let mut centroids = self - .parallelism - .rayon_into_par_iter(0..c) + let mut centroids = (0..c) + .into_par_iter() .map(|i| f32::vector_mul_scalar(&sum[i], 1.0 / count[i])) .collect::>(); @@ -278,7 +279,7 @@ impl<'a, P: Parallelism, F: Fn(&P, &[Vec]) -> Vec> LloydKMeans<'a, P } let mut o = 0; loop { - let alpha = rand.gen_range(0.0..1.0f32); + let alpha = rand.random_range(0.0..1.0f32); let beta = (count[o] - 1.0) / (n - c) as f32; if alpha < beta { break; @@ -293,15 +294,13 @@ impl<'a, P: Parallelism, F: Fn(&P, &[Vec]) -> Vec> LloydKMeans<'a, P } if self.is_spherical { - self.parallelism - .rayon_into_par_iter(&mut centroids) - .for_each(|centroid| { - let l = f32::reduce_sum_of_x2(centroid).sqrt(); - f32::vector_mul_scalar_inplace(centroid, 1.0 / l); - }); + (&mut centroids).into_par_iter().for_each(|centroid| { + let l = f32::reduce_sum_of_x2(centroid).sqrt(); + f32::vector_mul_scalar_inplace(centroid, 1.0 / l); + }); } - let assign = (self.compute)(self.parallelism, ¢roids); + let assign = (self.compute)(¢roids); let result = (0..n).all(|i| assign[i] == self.assign[i]); diff --git a/crates/rabitq/src/block.rs b/crates/rabitq/src/block.rs index 9f26fce1..52d4a82d 100644 --- a/crates/rabitq/src/block.rs +++ b/crates/rabitq/src/block.rs @@ -1,7 +1,7 @@ use distance::Distance; use simd::Floating; -pub fn preprocess(vector: &[f32]) -> (f32, f32, f32, f32, Vec<[u64; 2]>) { +pub fn preprocess(vector: &[f32]) -> (f32, f32, f32, f32, Vec<[u8; 16]>) { let dis_v_2 = f32::reduce_sum_of_x2(vector); let (k, b, qvector) = simd::quantize::quantize(vector, 15.0); let qvector_sum = if vector.len() <= 4369 { @@ -13,13 +13,13 @@ pub fn preprocess(vector: &[f32]) -> (f32, f32, f32, f32, Vec<[u64; 2]>) { } pub fn process_lowerbound_l2( - lut: &(f32, f32, f32, f32, Vec<[u64; 2]>), + lut: &(f32, f32, f32, f32, Vec<[u8; 16]>), (dis_u_2, factor_ppc, factor_ip, factor_err, t): ( &[f32; 32], &[f32; 32], &[f32; 32], &[f32; 32], - &[[u64; 2]], + &[[u8; 16]], ), epsilon: f32, ) -> [Distance; 32] { @@ -36,13 +36,13 @@ pub fn process_lowerbound_l2( } pub fn process_lowerbound_dot( - lut: &(f32, f32, f32, f32, Vec<[u64; 2]>), + lut: &(f32, f32, f32, f32, Vec<[u8; 16]>), (_, factor_ppc, factor_ip, factor_err, t): ( &[f32; 32], &[f32; 32], &[f32; 32], &[f32; 32], - &[[u64; 2]], + &[[u8; 16]], ), epsilon: f32, ) -> [Distance; 32] { @@ -56,11 +56,12 @@ pub fn process_lowerbound_dot( }) } -pub fn compress(mut vector: Vec) -> Vec<[u64; 2]> { - let width = vector.len().div_ceil(4); - vector.resize(width * 4, 0); - let mut result = vec![[0u64, 0u64]; width]; - for i in 0..width { +pub fn compress(mut vector: Vec) -> Vec<[u8; 16]> { + let n = vector.len().div_ceil(4); + vector.resize(n * 4, 0); + let mut result = vec![[0u8; 16]; n]; + for i in 0..n { + #[allow(unsafe_code)] unsafe { // this hint is used to skip bound checks std::hint::assert_unchecked(4 * i + 3 < vector.len()); @@ -70,26 +71,22 @@ pub fn compress(mut vector: Vec) -> Vec<[u64; 2]> { let t_2 = vector[4 * i + 2]; let t_3 = vector[4 * i + 3]; result[i] = [ - u64::from_le_bytes([ - 0, - t_0, - t_1, - t_1 + t_0, - t_2, - t_2 + t_0, - t_2 + t_1, - t_2 + t_1 + t_0, - ]), - u64::from_le_bytes([ - t_3, - t_3 + t_0, - t_3 + t_1, - t_3 + t_1 + t_0, - t_3 + t_2, - t_3 + t_2 + t_0, - t_3 + t_2 + t_1, - t_3 + t_2 + t_1 + t_0, - ]), + 0, + t_0, + t_1, + t_1 + t_0, + t_2, + t_2 + t_0, + t_2 + t_1, + t_2 + t_1 + t_0, + t_3, + t_3 + t_0, + t_3 + t_1, + t_3 + t_1 + t_0, + t_3 + t_2, + t_3 + t_2 + t_0, + 
t_3 + t_2 + t_1, + t_3 + t_2 + t_1 + t_0, ]; } result diff --git a/crates/rabitq/src/lib.rs b/crates/rabitq/src/lib.rs index 1e379a30..a90f4453 100644 --- a/crates/rabitq/src/lib.rs +++ b/crates/rabitq/src/lib.rs @@ -77,7 +77,7 @@ pub fn code(dims: u32, vector: &[f32]) -> Code { pub fn compute_lut( vector: &[f32], ) -> ( - (f32, f32, f32, f32, Vec<[u64; 2]>), + (f32, f32, f32, f32, Vec<[u8; 16]>), (f32, f32, f32, f32, (Vec, Vec, Vec, Vec)), ) { use simd::Floating; diff --git a/crates/random_orthogonal_matrix/Cargo.toml b/crates/random_orthogonal_matrix/Cargo.toml index 16c3a39a..4e8f733c 100644 --- a/crates/random_orthogonal_matrix/Cargo.toml +++ b/crates/random_orthogonal_matrix/Cargo.toml @@ -8,8 +8,8 @@ edition.workspace = true nalgebra = "=0.33.0" rand.workspace = true -rand_chacha = "0.3.1" -rand_distr = "0.4.3" +rand_chacha = "0.9.0" +rand_distr = "0.5.0" [lints] workspace = true diff --git a/crates/random_orthogonal_matrix/src/lib.rs b/crates/random_orthogonal_matrix/src/lib.rs index e2711504..a9cbb5c8 100644 --- a/crates/random_orthogonal_matrix/src/lib.rs +++ b/crates/random_orthogonal_matrix/src/lib.rs @@ -30,15 +30,18 @@ fn random_full_rank_matrix(n: usize) -> DMatrix { #[test] fn check_random_orthogonal_matrix() { - assert_eq!(random_orthogonal_matrix(2), vec![ - vec![-0.5424608, -0.8400813], - vec![0.8400813, -0.54246056] - ]); - assert_eq!(random_orthogonal_matrix(3), vec![ - vec![-0.5309615, -0.69094884, -0.49058124], - vec![0.8222731, -0.56002235, -0.10120347], - vec![0.20481002, 0.45712686, -0.86549866] - ]); + assert_eq!( + random_orthogonal_matrix(2), + vec![vec![-0.5424608, -0.8400813], vec![0.8400813, -0.54246056]] + ); + assert_eq!( + random_orthogonal_matrix(3), + vec![ + vec![-0.5309615, -0.69094884, -0.49058124], + vec![0.8222731, -0.56002235, -0.10120347], + vec![0.20481002, 0.45712686, -0.86549866] + ] + ); } pub fn random_orthogonal_matrix(n: usize) -> Vec> { diff --git a/crates/simd/Cargo.toml b/crates/simd/Cargo.toml index 0905fb14..bed14a33 100644 --- a/crates/simd/Cargo.toml +++ b/crates/simd/Cargo.toml @@ -4,15 +4,16 @@ version.workspace = true edition.workspace = true [dependencies] -half.workspace = true -serde.workspace = true simd_macros = { path = "../simd_macros" } +half.workspace = true +zerocopy.workspace = true + [dev-dependencies] rand.workspace = true [build-dependencies] -cc = "1.2.6" +cc = "1.2.13" [lints] workspace = true diff --git a/crates/simd/build.rs b/crates/simd/build.rs index 22ebf935..c7a49d65 100644 --- a/crates/simd/build.rs +++ b/crates/simd/build.rs @@ -8,5 +8,5 @@ fn main() { .flag("-ffp-contract=fast") .flag("-freciprocal-math") .flag("-fno-signed-zeros") - .compile("base_cshim"); + .compile("simd_cshim"); } diff --git a/crates/simd/cshim.c b/crates/simd/cshim.c index 1374bbfe..7c0f044a 100644 --- a/crates/simd/cshim.c +++ b/crates/simd/cshim.c @@ -2,17 +2,18 @@ #error "clang version must be >= 16" #endif -#include -#include - #ifdef __aarch64__ #include #include +#include +#include + +typedef __fp16 f16; +typedef float f32; -__attribute__((target("v8.3a,fp16"))) float -fp16_reduce_sum_of_xy_v8_3a_fp16_unroll(__fp16 *__restrict a, - __fp16 *__restrict b, size_t n) { +__attribute__((target("fp16"))) float +fp16_reduce_sum_of_xy_a2_fp16(f16 *restrict a, f16 *restrict b, size_t n) { float16x8_t xy_0 = vdupq_n_f16(0.0); float16x8_t xy_1 = vdupq_n_f16(0.0); float16x8_t xy_2 = vdupq_n_f16(0.0); @@ -35,8 +36,8 @@ fp16_reduce_sum_of_xy_v8_3a_fp16_unroll(__fp16 *__restrict a, xy_3 = vfmaq_f16(xy_3, x_3, y_3); } if (n > 0) { - __fp16 
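`compress` above now emits `[u8; 16]` groups directly: for each four 4-bit quantized values t_0..t_3 it tabulates all 16 subset sums, indexed by a 4-bit mask whose bit j selects t_j, which is exactly what a byte-shuffle fast-scan kernel looks up. A reference re-computation of one group (not the production routine):

// Recomputes one 16-entry group the way `compress` lays it out.
fn subset_sums(t: [u8; 4]) -> [u8; 16] {
    std::array::from_fn(|mask| {
        (0..4usize)
            .filter(|&j| mask & (1 << j) != 0)
            .map(|j| t[j])
            .sum()
    })
}

fn main() {
    let group = subset_sums([1, 2, 4, 8]);
    assert_eq!(group[0b0000], 0);
    assert_eq!(group[0b0101], 1 + 4);         // t_0 + t_2
    assert_eq!(group[0b1111], 1 + 2 + 4 + 8); // t_0 + t_1 + t_2 + t_3
}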
A[32] = {}; - __fp16 B[32] = {}; + f16 A[32] = {}; + f16 B[32] = {}; for (size_t i = 0; i < n; i += 1) { A[i] = a[i]; B[i] = b[i]; @@ -55,14 +56,13 @@ fp16_reduce_sum_of_xy_v8_3a_fp16_unroll(__fp16 *__restrict a, xy_3 = vfmaq_f16(xy_3, x_3, y_3); } float16x8_t xy = vaddq_f16(vaddq_f16(xy_0, xy_1), vaddq_f16(xy_2, xy_3)); - return vgetq_lane_f16(xy, 0) + vgetq_lane_f16(xy, 1) + vgetq_lane_f16(xy, 2) + - vgetq_lane_f16(xy, 3) + vgetq_lane_f16(xy, 4) + vgetq_lane_f16(xy, 5) + - vgetq_lane_f16(xy, 6) + vgetq_lane_f16(xy, 7); + float32x4_t lo = vcvt_f32_f16(vget_low_f16(xy)); + float32x4_t hi = vcvt_f32_f16(vget_high_f16(xy)); + return vaddvq_f32(lo) + vaddvq_f32(hi); } -__attribute__((target("v8.3a,sve"))) float -fp16_reduce_sum_of_xy_v8_3a_sve(__fp16 *__restrict a, __fp16 *__restrict b, - size_t n) { +__attribute__((target("sve"))) float +fp16_reduce_sum_of_xy_a3_512(f16 *restrict a, f16 *restrict b, size_t n) { svfloat16_t xy = svdup_f16(0.0); for (size_t i = 0; i < n; i += svcnth()) { svbool_t mask = svwhilelt_b16(i, n); @@ -73,9 +73,8 @@ fp16_reduce_sum_of_xy_v8_3a_sve(__fp16 *__restrict a, __fp16 *__restrict b, return svaddv_f16(svptrue_b16(), xy); } -__attribute__((target("v8.3a,fp16"))) float -fp16_reduce_sum_of_d2_v8_3a_fp16_unroll(__fp16 *__restrict a, - __fp16 *__restrict b, size_t n) { +__attribute__((target("fp16"))) float +fp16_reduce_sum_of_d2_a2_fp16(f16 *restrict a, f16 *restrict b, size_t n) { float16x8_t d2_0 = vdupq_n_f16(0.0); float16x8_t d2_1 = vdupq_n_f16(0.0); float16x8_t d2_2 = vdupq_n_f16(0.0); @@ -102,8 +101,8 @@ fp16_reduce_sum_of_d2_v8_3a_fp16_unroll(__fp16 *__restrict a, d2_3 = vfmaq_f16(d2_3, d_3, d_3); } if (n > 0) { - __fp16 A[32] = {}; - __fp16 B[32] = {}; + f16 A[32] = {}; + f16 B[32] = {}; for (size_t i = 0; i < n; i += 1) { A[i] = a[i]; B[i] = b[i]; @@ -126,14 +125,13 @@ fp16_reduce_sum_of_d2_v8_3a_fp16_unroll(__fp16 *__restrict a, d2_3 = vfmaq_f16(d2_3, d_3, d_3); } float16x8_t d2 = vaddq_f16(vaddq_f16(d2_0, d2_1), vaddq_f16(d2_2, d2_3)); - return vgetq_lane_f16(d2, 0) + vgetq_lane_f16(d2, 1) + vgetq_lane_f16(d2, 2) + - vgetq_lane_f16(d2, 3) + vgetq_lane_f16(d2, 4) + vgetq_lane_f16(d2, 5) + - vgetq_lane_f16(d2, 6) + vgetq_lane_f16(d2, 7); + float32x4_t lo = vcvt_f32_f16(vget_low_f16(d2)); + float32x4_t hi = vcvt_f32_f16(vget_high_f16(d2)); + return vaddvq_f32(lo) + vaddvq_f32(hi); } -__attribute__((target("v8.3a,sve"))) float -fp16_reduce_sum_of_d2_v8_3a_sve(__fp16 *__restrict a, __fp16 *__restrict b, - size_t n) { +__attribute__((target("sve"))) float +fp16_reduce_sum_of_d2_a3_512(f16 *restrict a, f16 *restrict b, size_t n) { svfloat16_t d2 = svdup_f16(0.0); for (size_t i = 0; i < n; i += svcnth()) { svbool_t mask = svwhilelt_b16(i, n); @@ -145,8 +143,8 @@ fp16_reduce_sum_of_d2_v8_3a_sve(__fp16 *__restrict a, __fp16 *__restrict b, return svaddv_f16(svptrue_b16(), d2); } -__attribute__((target("v8.3a,sve"))) float -fp32_reduce_sum_of_x_v8_3a_sve(float *__restrict this, size_t n) { +__attribute__((target("sve"))) float +fp32_reduce_sum_of_x_a3_256(float *restrict this, size_t n) { svfloat32_t sum = svdup_f32(0.0); for (size_t i = 0; i < n; i += svcntw()) { svbool_t mask = svwhilelt_b32(i, n); @@ -156,8 +154,8 @@ fp32_reduce_sum_of_x_v8_3a_sve(float *__restrict this, size_t n) { return svaddv_f32(svptrue_b32(), sum); } -__attribute__((target("v8.3a,sve"))) float -fp32_reduce_sum_of_abs_x_v8_3a_sve(float *__restrict this, size_t n) { +__attribute__((target("sve"))) float +fp32_reduce_sum_of_abs_x_a3_256(float *restrict this, size_t n) { svfloat32_t sum = 
svdup_f32(0.0); for (size_t i = 0; i < n; i += svcntw()) { svbool_t mask = svwhilelt_b32(i, n); @@ -167,8 +165,8 @@ fp32_reduce_sum_of_abs_x_v8_3a_sve(float *__restrict this, size_t n) { return svaddv_f32(svptrue_b32(), sum); } -__attribute__((target("v8.3a,sve"))) float -fp32_reduce_sum_of_x2_v8_3a_sve(float *__restrict this, size_t n) { +__attribute__((target("sve"))) float +fp32_reduce_sum_of_x2_a3_256(float *restrict this, size_t n) { svfloat32_t sum = svdup_f32(0.0); for (size_t i = 0; i < n; i += svcntw()) { svbool_t mask = svwhilelt_b32(i, n); @@ -178,9 +176,9 @@ fp32_reduce_sum_of_x2_v8_3a_sve(float *__restrict this, size_t n) { return svaddv_f32(svptrue_b32(), sum); } -__attribute__((target("v8.3a,sve"))) void -fp32_reduce_min_max_of_x_v8_3a_sve(float *__restrict this, size_t n, - float *out_min, float *out_max) { +__attribute__((target("sve"))) void +fp32_reduce_min_max_of_x_a3_256(float *restrict this, size_t n, float *out_min, + float *out_max) { svfloat32_t min = svdup_f32(1.0 / 0.0); svfloat32_t max = svdup_f32(-1.0 / 0.0); for (size_t i = 0; i < n; i += svcntw()) { @@ -193,9 +191,9 @@ fp32_reduce_min_max_of_x_v8_3a_sve(float *__restrict this, size_t n, *out_max = svmaxv_f32(svptrue_b32(), max); } -__attribute__((target("v8.3a,sve"))) float -fp32_reduce_sum_of_xy_v8_3a_sve(float *__restrict lhs, float *__restrict rhs, - size_t n) { +__attribute__((target("sve"))) float +fp32_reduce_sum_of_xy_a3_256(float *restrict lhs, float *restrict rhs, + size_t n) { svfloat32_t sum = svdup_f32(0.0); for (size_t i = 0; i < n; i += svcntw()) { svbool_t mask = svwhilelt_b32(i, n); @@ -206,9 +204,9 @@ fp32_reduce_sum_of_xy_v8_3a_sve(float *__restrict lhs, float *__restrict rhs, return svaddv_f32(svptrue_b32(), sum); } -__attribute__((target("v8.3a,sve"))) float -fp32_reduce_sum_of_d2_v8_3a_sve(float *__restrict lhs, float *__restrict rhs, - size_t n) { +__attribute__((target("sve"))) float +fp32_reduce_sum_of_d2_a3_256(float *restrict lhs, float *restrict rhs, + size_t n) { svfloat32_t sum = svdup_f32(0.0); for (size_t i = 0; i < n; i += svcntw()) { svbool_t mask = svwhilelt_b32(i, n); diff --git a/crates/simd/src/bit.rs b/crates/simd/src/bit.rs index dc417bda..d8321dc4 100644 --- a/crates/simd/src/bit.rs +++ b/crates/simd/src/bit.rs @@ -8,9 +8,9 @@ mod sum_of_and { #[inline] #[cfg(target_arch = "x86_64")] - #[crate::target_cpu(enable = "v4")] + #[crate::target_cpu(enable = "v4.512")] #[target_feature(enable = "avx512vpopcntdq")] - fn sum_of_and_v4_avx512vpopcntdq(lhs: &[u64], rhs: &[u64]) -> u32 { + fn sum_of_and_v4_512_avx512vpopcntdq(lhs: &[u64], rhs: &[u64]) -> u32 { assert!(lhs.len() == rhs.len()); unsafe { use std::arch::x86_64::*; @@ -38,21 +38,24 @@ mod sum_of_and { #[cfg(all(target_arch = "x86_64", test, not(miri)))] #[test] - fn sum_of_and_v4_avx512vpopcntdq_test() { - if !crate::is_cpu_detected!("v4") || !crate::is_feature_detected!("avx512vpopcntdq") { - println!("test {} ... skipped (v4:avx512vpopcntdq)", module_path!()); + fn sum_of_and_v4_512_avx512vpopcntdq_test() { + if !crate::is_cpu_detected!("v4.512") || !crate::is_feature_detected!("avx512vpopcntdq") { + println!( + "test {} ... 
skipped (v4.512:avx512vpopcntdq)", + module_path!() + ); return; } for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let lhs = (0..126).map(|_| rand::random::()).collect::>(); let rhs = (0..126).map(|_| rand::random::()).collect::>(); - let specialized = unsafe { sum_of_and_v4_avx512vpopcntdq(&lhs, &rhs) }; - let fallback = sum_of_and_fallback(&lhs, &rhs); + let specialized = unsafe { sum_of_and_v4_512_avx512vpopcntdq(&lhs, &rhs) }; + let fallback = fallback(&lhs, &rhs); assert_eq!(specialized, fallback); } } - #[crate::multiversion(@"v4:avx512vpopcntdq", "v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + #[crate::multiversion(@"v4.512:avx512vpopcntdq", "v4.512", "v3", "v2", "a2")] pub fn sum_of_and(lhs: &[u64], rhs: &[u64]) -> u32 { assert_eq!(lhs.len(), rhs.len()); let n = lhs.len(); @@ -74,9 +77,9 @@ mod sum_of_or { #[inline] #[cfg(target_arch = "x86_64")] - #[crate::target_cpu(enable = "v4")] + #[crate::target_cpu(enable = "v4.512")] #[target_feature(enable = "avx512vpopcntdq")] - fn sum_of_or_v4_avx512vpopcntdq(lhs: &[u64], rhs: &[u64]) -> u32 { + fn sum_of_or_v4_512_avx512vpopcntdq(lhs: &[u64], rhs: &[u64]) -> u32 { assert!(lhs.len() == rhs.len()); unsafe { use std::arch::x86_64::*; @@ -104,21 +107,24 @@ mod sum_of_or { #[cfg(all(target_arch = "x86_64", test, not(miri)))] #[test] - fn sum_of_or_v4_avx512vpopcntdq_test() { - if !crate::is_cpu_detected!("v4") || !crate::is_feature_detected!("avx512vpopcntdq") { - println!("test {} ... skipped (v4:avx512vpopcntdq)", module_path!()); + fn sum_of_or_v4_512_avx512vpopcntdq_test() { + if !crate::is_cpu_detected!("v4.512") || !crate::is_feature_detected!("avx512vpopcntdq") { + println!( + "test {} ... skipped (v4.512:avx512vpopcntdq)", + module_path!() + ); return; } for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let lhs = (0..126).map(|_| rand::random::()).collect::>(); let rhs = (0..126).map(|_| rand::random::()).collect::>(); - let specialized = unsafe { sum_of_or_v4_avx512vpopcntdq(&lhs, &rhs) }; - let fallback = sum_of_or_fallback(&lhs, &rhs); + let specialized = unsafe { sum_of_or_v4_512_avx512vpopcntdq(&lhs, &rhs) }; + let fallback = fallback(&lhs, &rhs); assert_eq!(specialized, fallback); } } - #[crate::multiversion(@"v4:avx512vpopcntdq", "v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + #[crate::multiversion(@"v4.512:avx512vpopcntdq", "v4.512", "v3", "v2", "a2")] pub fn sum_of_or(lhs: &[u64], rhs: &[u64]) -> u32 { assert_eq!(lhs.len(), rhs.len()); let n = lhs.len(); @@ -140,9 +146,9 @@ mod sum_of_xor { #[inline] #[cfg(target_arch = "x86_64")] - #[crate::target_cpu(enable = "v4")] + #[crate::target_cpu(enable = "v4.512")] #[target_feature(enable = "avx512vpopcntdq")] - fn sum_of_xor_v4_avx512vpopcntdq(lhs: &[u64], rhs: &[u64]) -> u32 { + fn sum_of_xor_v4_512_avx512vpopcntdq(lhs: &[u64], rhs: &[u64]) -> u32 { assert!(lhs.len() == rhs.len()); unsafe { use std::arch::x86_64::*; @@ -170,21 +176,24 @@ mod sum_of_xor { #[cfg(all(target_arch = "x86_64", test, not(miri)))] #[test] - fn sum_of_xor_v4_avx512vpopcntdq_test() { - if !crate::is_cpu_detected!("v4") || !crate::is_feature_detected!("avx512vpopcntdq") { - println!("test {} ... skipped (v4:avx512vpopcntdq)", module_path!()); + fn sum_of_xor_v4_512_avx512vpopcntdq_test() { + if !crate::is_cpu_detected!("v4.512") || !crate::is_feature_detected!("avx512vpopcntdq") { + println!( + "test {} ... 
skipped (v4.512:avx512vpopcntdq)", + module_path!() + ); return; } for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let lhs = (0..126).map(|_| rand::random::()).collect::>(); let rhs = (0..126).map(|_| rand::random::()).collect::>(); - let specialized = unsafe { sum_of_xor_v4_avx512vpopcntdq(&lhs, &rhs) }; - let fallback = sum_of_xor_fallback(&lhs, &rhs); + let specialized = unsafe { sum_of_xor_v4_512_avx512vpopcntdq(&lhs, &rhs) }; + let fallback = fallback(&lhs, &rhs); assert_eq!(specialized, fallback); } } - #[crate::multiversion(@"v4:avx512vpopcntdq", "v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + #[crate::multiversion(@"v4.512:avx512vpopcntdq", "v4.512", "v3", "v2", "a2")] pub fn sum_of_xor(lhs: &[u64], rhs: &[u64]) -> u32 { assert_eq!(lhs.len(), rhs.len()); let n = lhs.len(); @@ -206,9 +215,9 @@ mod sum_of_and_or { #[inline] #[cfg(target_arch = "x86_64")] - #[crate::target_cpu(enable = "v4")] + #[crate::target_cpu(enable = "v4.512")] #[target_feature(enable = "avx512vpopcntdq")] - fn sum_of_and_or_v4_avx512vpopcntdq(lhs: &[u64], rhs: &[u64]) -> (u32, u32) { + fn sum_of_and_or_v4_512_avx512vpopcntdq(lhs: &[u64], rhs: &[u64]) -> (u32, u32) { assert!(lhs.len() == rhs.len()); unsafe { use std::arch::x86_64::*; @@ -242,21 +251,24 @@ mod sum_of_and_or { #[cfg(all(target_arch = "x86_64", test, not(miri)))] #[test] - fn sum_of_xor_v4_avx512vpopcntdq_test() { - if !crate::is_cpu_detected!("v4") || !crate::is_feature_detected!("avx512vpopcntdq") { - println!("test {} ... skipped (v4:avx512vpopcntdq)", module_path!()); + fn sum_of_xor_v4_512_avx512vpopcntdq_test() { + if !crate::is_cpu_detected!("v4.512") || !crate::is_feature_detected!("avx512vpopcntdq") { + println!( + "test {} ... skipped (v4.512:avx512vpopcntdq)", + module_path!() + ); return; } for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let lhs = (0..126).map(|_| rand::random::()).collect::>(); let rhs = (0..126).map(|_| rand::random::()).collect::>(); - let specialized = unsafe { sum_of_and_or_v4_avx512vpopcntdq(&lhs, &rhs) }; - let fallback = sum_of_and_or_fallback(&lhs, &rhs); + let specialized = unsafe { sum_of_and_or_v4_512_avx512vpopcntdq(&lhs, &rhs) }; + let fallback = fallback(&lhs, &rhs); assert_eq!(specialized, fallback); } } - #[crate::multiversion(@"v4:avx512vpopcntdq", "v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + #[crate::multiversion(@"v4.512:avx512vpopcntdq", "v4.512", "v3", "v2", "a2")] pub fn sum_of_and_or(lhs: &[u64], rhs: &[u64]) -> (u32, u32) { assert_eq!(lhs.len(), rhs.len()); let n = lhs.len(); @@ -280,9 +292,9 @@ mod sum_of_x { #[inline] #[cfg(target_arch = "x86_64")] - #[crate::target_cpu(enable = "v4")] + #[crate::target_cpu(enable = "v4.512")] #[target_feature(enable = "avx512vpopcntdq")] - fn sum_of_x_v4_avx512vpopcntdq(this: &[u64]) -> u32 { + fn sum_of_x_v4_512_avx512vpopcntdq(this: &[u64]) -> u32 { unsafe { use std::arch::x86_64::*; let mut and = _mm512_setzero_si512(); @@ -305,20 +317,23 @@ mod sum_of_x { #[cfg(all(target_arch = "x86_64", test, not(miri)))] #[test] - fn sum_of_x_v4_avx512vpopcntdq_test() { - if !crate::is_cpu_detected!("v4") || !crate::is_feature_detected!("avx512vpopcntdq") { - println!("test {} ... skipped (v4:avx512vpopcntdq)", module_path!()); + fn sum_of_x_v4_512_avx512vpopcntdq_test() { + if !crate::is_cpu_detected!("v4.512") || !crate::is_feature_detected!("avx512vpopcntdq") { + println!( + "test {} ... 
skipped (v4.512:avx512vpopcntdq)", + module_path!() + ); return; } for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let this = (0..126).map(|_| rand::random::()).collect::>(); - let specialized = unsafe { sum_of_x_v4_avx512vpopcntdq(&this) }; - let fallback = sum_of_x_fallback(&this); + let specialized = unsafe { sum_of_x_v4_512_avx512vpopcntdq(&this) }; + let fallback = fallback(&this); assert_eq!(specialized, fallback); } } - #[crate::multiversion(@"v4:avx512vpopcntdq", "v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + #[crate::multiversion(@"v4.512:avx512vpopcntdq", "v4.512", "v3", "v2", "a2")] pub fn sum_of_x(this: &[u64]) -> u32 { let n = this.len(); let mut and = 0; @@ -335,7 +350,7 @@ pub fn vector_and(lhs: &[u64], rhs: &[u64]) -> Vec { } mod vector_and { - #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + #[crate::multiversion("v4.512", "v3", "v2", "a2")] pub fn vector_and(lhs: &[u64], rhs: &[u64]) -> Vec { assert_eq!(lhs.len(), rhs.len()); let n = lhs.len(); @@ -358,7 +373,7 @@ pub fn vector_or(lhs: &[u64], rhs: &[u64]) -> Vec { } mod vector_or { - #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + #[crate::multiversion("v4.512", "v3", "v2", "a2")] pub fn vector_or(lhs: &[u64], rhs: &[u64]) -> Vec { assert_eq!(lhs.len(), rhs.len()); let n = lhs.len(); @@ -381,7 +396,7 @@ pub fn vector_xor(lhs: &[u64], rhs: &[u64]) -> Vec { } mod vector_xor { - #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + #[crate::multiversion("v4.512", "v3", "v2", "a2")] pub fn vector_xor(lhs: &[u64], rhs: &[u64]) -> Vec { assert_eq!(lhs.len(), rhs.len()); let n = lhs.len(); diff --git a/crates/simd/src/emulate.rs b/crates/simd/src/emulate.rs index 520b6a1d..293aa746 100644 --- a/crates/simd/src/emulate.rs +++ b/crates/simd/src/emulate.rs @@ -3,7 +3,7 @@ // Instructions. arXiv preprint arXiv:2112.06342. 
#[inline] #[cfg(target_arch = "x86_64")] -#[crate::target_cpu(enable = "v4")] +#[crate::target_cpu(enable = "v4.512")] pub fn emulate_mm512_2intersect_epi32( a: std::arch::x86_64::__m512i, b: std::arch::x86_64::__m512i, @@ -85,7 +85,7 @@ pub fn emulate_mm_reduce_add_ps(mut x: std::arch::x86_64::__m128) -> f32 { #[inline] #[cfg(target_arch = "x86_64")] -#[crate::target_cpu(enable = "v4")] +#[crate::target_cpu(enable = "v4.512")] pub fn emulate_mm512_reduce_add_epi16(x: std::arch::x86_64::__m512i) -> i16 { unsafe { use std::arch::x86_64::*; @@ -134,7 +134,6 @@ pub fn emulate_mm256_reduce_add_epi32(mut x: std::arch::x86_64::__m256i) -> i32 } } -#[expect(dead_code)] #[inline] #[cfg(target_arch = "x86_64")] #[crate::target_cpu(enable = "v2")] diff --git a/crates/simd/src/f16.rs b/crates/simd/src/f16.rs index ce906187..2fb176be 100644 --- a/crates/simd/src/f16.rs +++ b/crates/simd/src/f16.rs @@ -129,11 +129,9 @@ impl Floating for f16 { } mod reduce_or_of_is_zero_x { - // FIXME: add manually-implemented SIMD version - use half::f16; - #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + #[crate::multiversion("v4.512", "v3", "v2", "a2")] pub fn reduce_or_of_is_zero_x(this: &[f16]) -> bool { for &x in this { if x == f16::ZERO { @@ -149,7 +147,7 @@ mod reduce_sum_of_x { use half::f16; - #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + #[crate::multiversion("v4.512", "v3", "v2", "a2")] pub fn reduce_sum_of_x(this: &[f16]) -> f32 { let n = this.len(); let mut x = 0.0f32; @@ -165,7 +163,7 @@ mod reduce_sum_of_abs_x { use half::f16; - #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + #[crate::multiversion("v4.512", "v3", "v2", "a2")] pub fn reduce_sum_of_abs_x(this: &[f16]) -> f32 { let n = this.len(); let mut x = 0.0f32; @@ -181,7 +179,7 @@ mod reduce_sum_of_x2 { use half::f16; - #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + #[crate::multiversion("v4.512", "v3", "v2", "a2")] pub fn reduce_sum_of_x2(this: &[f16]) -> f32 { let n = this.len(); let mut x2 = 0.0f32; @@ -197,7 +195,7 @@ mod reduce_min_max_of_x { use half::f16; - #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + #[crate::multiversion("v4.512", "v3", "v2", "a2")] pub fn reduce_min_max_of_x(this: &[f16]) -> (f32, f32) { let mut min = f32::INFINITY; let mut max = f32::NEG_INFINITY; @@ -215,9 +213,9 @@ mod reduce_sum_of_xy { #[inline] #[cfg(target_arch = "x86_64")] - #[crate::target_cpu(enable = "v4")] + #[crate::target_cpu(enable = "v4.512")] #[target_feature(enable = "avx512fp16")] - pub fn reduce_sum_of_xy_v4_avx512fp16(lhs: &[f16], rhs: &[f16]) -> f32 { + pub fn reduce_sum_of_xy_v4_512_avx512fp16(lhs: &[f16], rhs: &[f16]) -> f32 { assert!(lhs.len() == rhs.len()); unsafe { use std::arch::x86_64::*; @@ -245,27 +243,27 @@ mod reduce_sum_of_xy { #[cfg(all(target_arch = "x86_64", test, not(miri)))] #[test] - fn reduce_sum_of_xy_v4_avx512fp16_test() { + fn reduce_sum_of_xy_v4_512_avx512fp16_test() { use rand::Rng; const EPSILON: f32 = 2.0; - if !crate::is_cpu_detected!("v4") || !crate::is_feature_detected!("avx512fp16") { - println!("test {} ... skipped (v4:avx512fp16)", module_path!()); + if !crate::is_cpu_detected!("v4.512") || !crate::is_feature_detected!("avx512fp16") { + println!("test {} ... 
skipped (v4_512:avx512fp16)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 4016; let lhs = (0..n) - .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .map(|_| f16::from_f32(rng.random_range(-1.0..=1.0))) .collect::>(); let rhs = (0..n) - .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .map(|_| f16::from_f32(rng.random_range(-1.0..=1.0))) .collect::>(); for z in 3984..4016 { let lhs = &lhs[..z]; let rhs = &rhs[..z]; - let specialized = unsafe { reduce_sum_of_xy_v4_avx512fp16(lhs, rhs) }; - let fallback = reduce_sum_of_xy_fallback(lhs, rhs); + let specialized = unsafe { reduce_sum_of_xy_v4_512_avx512fp16(lhs, rhs) }; + let fallback = fallback(lhs, rhs); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." @@ -276,8 +274,8 @@ mod reduce_sum_of_xy { #[inline] #[cfg(target_arch = "x86_64")] - #[crate::target_cpu(enable = "v4")] - pub fn reduce_sum_of_xy_v4(lhs: &[f16], rhs: &[f16]) -> f32 { + #[crate::target_cpu(enable = "v4.512")] + pub fn reduce_sum_of_xy_v4_512(lhs: &[f16], rhs: &[f16]) -> f32 { assert!(lhs.len() == rhs.len()); unsafe { use std::arch::x86_64::*; @@ -308,21 +306,21 @@ mod reduce_sum_of_xy { fn reduce_sum_of_xy_v4_test() { use rand::Rng; const EPSILON: f32 = 2.0; - if !crate::is_cpu_detected!("v4") { + if !crate::is_cpu_detected!("v4.512") { println!("test {} ... skipped (v4)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 4016; let lhs = (0..n) - .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .map(|_| f16::from_f32(rng.random_range(-1.0..=1.0))) .collect::>(); let rhs = (0..n) - .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .map(|_| f16::from_f32(rng.random_range(-1.0..=1.0))) .collect::>(); - let specialized = unsafe { reduce_sum_of_xy_v4(&lhs, &rhs) }; - let fallback = reduce_sum_of_xy_fallback(&lhs, &rhs); + let specialized = unsafe { reduce_sum_of_xy_v4_512(&lhs, &rhs) }; + let fallback = fallback(&lhs, &rhs); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." @@ -372,89 +370,20 @@ mod reduce_sum_of_xy { println!("test {} ... skipped (v3)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 4016; let lhs = (0..n) - .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .map(|_| f16::from_f32(rng.random_range(-1.0..=1.0))) .collect::>(); let rhs = (0..n) - .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .map(|_| f16::from_f32(rng.random_range(-1.0..=1.0))) .collect::>(); for z in 3984..4016 { let lhs = &lhs[..z]; let rhs = &rhs[..z]; let specialized = unsafe { reduce_sum_of_xy_v3(lhs, rhs) }; - let fallback = reduce_sum_of_xy_fallback(lhs, rhs); - assert!( - (specialized - fallback).abs() < EPSILON, - "specialized = {specialized}, fallback = {fallback}." 
- ); - } - } - } - - #[inline] - #[cfg(target_arch = "x86_64")] - #[crate::target_cpu(enable = "v2")] - #[target_feature(enable = "f16c")] - #[target_feature(enable = "fma")] - pub fn reduce_sum_of_xy_v2_f16c_fma(lhs: &[f16], rhs: &[f16]) -> f32 { - use crate::emulate::emulate_mm_reduce_add_ps; - assert!(lhs.len() == rhs.len()); - unsafe { - use std::arch::x86_64::*; - let mut n = lhs.len(); - let mut a = lhs.as_ptr(); - let mut b = rhs.as_ptr(); - let mut xy = _mm_setzero_ps(); - while n >= 4 { - let x = _mm_cvtph_ps(_mm_loadu_si128(a.cast())); - let y = _mm_cvtph_ps(_mm_loadu_si128(b.cast())); - a = a.add(4); - b = b.add(4); - n -= 4; - xy = _mm_fmadd_ps(x, y, xy); - } - let mut xy = emulate_mm_reduce_add_ps(xy); - while n > 0 { - let x = a.read().to_f32(); - let y = b.read().to_f32(); - a = a.add(1); - b = b.add(1); - n -= 1; - xy += x * y; - } - xy - } - } - - #[cfg(all(target_arch = "x86_64", test, not(miri)))] - #[test] - fn reduce_sum_of_xy_v2_f16c_fma_test() { - use rand::Rng; - const EPSILON: f32 = 2.0; - if !crate::is_cpu_detected!("v2") - || !crate::is_feature_detected!("f16c") - || !crate::is_feature_detected!("fma") - { - println!("test {} ... skipped (v2:f16c:fma)", module_path!()); - return; - } - let mut rng = rand::thread_rng(); - for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { - let n = 4016; - let lhs = (0..n) - .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) - .collect::>(); - let rhs = (0..n) - .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) - .collect::>(); - for z in 3984..4016 { - let lhs = &lhs[..z]; - let rhs = &rhs[..z]; - let specialized = unsafe { reduce_sum_of_xy_v2_f16c_fma(lhs, rhs) }; - let fallback = reduce_sum_of_xy_fallback(lhs, rhs); + let fallback = fallback(lhs, rhs); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." @@ -465,49 +394,41 @@ mod reduce_sum_of_xy { #[inline] #[cfg(target_arch = "aarch64")] - #[crate::target_cpu(enable = "v8.3a")] + #[crate::target_cpu(enable = "a2")] #[target_feature(enable = "fp16")] - pub fn reduce_sum_of_xy_v8_3a_fp16(lhs: &[f16], rhs: &[f16]) -> f32 { + pub fn reduce_sum_of_xy_a2_fp16(lhs: &[f16], rhs: &[f16]) -> f32 { assert!(lhs.len() == rhs.len()); unsafe { extern "C" { - fn fp16_reduce_sum_of_xy_v8_3a_fp16_unroll( - a: *const (), - b: *const (), - n: usize, - ) -> f32; + fn fp16_reduce_sum_of_xy_a2_fp16(a: *const (), b: *const (), n: usize) -> f32; } - fp16_reduce_sum_of_xy_v8_3a_fp16_unroll( - lhs.as_ptr().cast(), - rhs.as_ptr().cast(), - lhs.len(), - ) + fp16_reduce_sum_of_xy_a2_fp16(lhs.as_ptr().cast(), rhs.as_ptr().cast(), lhs.len()) } } #[cfg(all(target_arch = "aarch64", test, not(miri)))] #[test] - fn reduce_sum_of_xy_v8_3a_fp16_test() { + fn reduce_sum_of_xy_a2_fp16_test() { use rand::Rng; const EPSILON: f32 = 2.0; - if !crate::is_cpu_detected!("v8.3a") || !crate::is_feature_detected!("fp16") { - println!("test {} ... skipped (v8.3a:fp16)", module_path!()); + if !crate::is_cpu_detected!("a2") || !crate::is_feature_detected!("fp16") { + println!("test {} ... 
skipped (a2:fp16)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 4016; let lhs = (0..n) - .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .map(|_| f16::from_f32(rng.random_range(-1.0..=1.0))) .collect::>(); let rhs = (0..n) - .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .map(|_| f16::from_f32(rng.random_range(-1.0..=1.0))) .collect::>(); for z in 3984..4016 { let lhs = &lhs[..z]; let rhs = &rhs[..z]; - let specialized = unsafe { reduce_sum_of_xy_v8_3a_fp16(lhs, rhs) }; - let fallback = reduce_sum_of_xy_fallback(lhs, rhs); + let specialized = unsafe { reduce_sum_of_xy_a2_fp16(lhs, rhs) }; + let fallback = fallback(lhs, rhs); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." @@ -516,45 +437,43 @@ mod reduce_sum_of_xy { } } - // temporarily disables this for uncertain precision - #[cfg_attr(not(test), expect(dead_code))] #[inline] #[cfg(target_arch = "aarch64")] - #[crate::target_cpu(enable = "v8.3a")] + #[crate::target_cpu(enable = "a2")] #[target_feature(enable = "sve")] - pub fn reduce_sum_of_xy_v8_3a_sve(lhs: &[f16], rhs: &[f16]) -> f32 { + pub fn reduce_sum_of_xy_a3_512(lhs: &[f16], rhs: &[f16]) -> f32 { assert!(lhs.len() == rhs.len()); unsafe { extern "C" { - fn fp16_reduce_sum_of_xy_v8_3a_sve(a: *const (), b: *const (), n: usize) -> f32; + fn fp16_reduce_sum_of_xy_a3_512(a: *const (), b: *const (), n: usize) -> f32; } - fp16_reduce_sum_of_xy_v8_3a_sve(lhs.as_ptr().cast(), rhs.as_ptr().cast(), lhs.len()) + fp16_reduce_sum_of_xy_a3_512(lhs.as_ptr().cast(), rhs.as_ptr().cast(), lhs.len()) } } #[cfg(all(target_arch = "aarch64", test, not(miri)))] #[test] - fn reduce_sum_of_xy_v8_3a_sve_test() { + fn reduce_sum_of_xy_a3_512_test() { use rand::Rng; const EPSILON: f32 = 2.0; - if !crate::is_cpu_detected!("v8.3a") || !crate::is_feature_detected!("sve") { - println!("test {} ... skipped (v8.3a:sve)", module_path!()); + if !crate::is_cpu_detected!("a3.512") { + println!("test {} ... skipped (a3.512)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 4016; let lhs = (0..n) - .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .map(|_| f16::from_f32(rng.random_range(-1.0..=1.0))) .collect::>(); let rhs = (0..n) - .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .map(|_| f16::from_f32(rng.random_range(-1.0..=1.0))) .collect::>(); for z in 3984..4016 { let lhs = &lhs[..z]; let rhs = &rhs[..z]; - let specialized = unsafe { reduce_sum_of_xy_v8_3a_sve(lhs, rhs) }; - let fallback = reduce_sum_of_xy_fallback(lhs, rhs); + let specialized = unsafe { reduce_sum_of_xy_a3_512(lhs, rhs) }; + let fallback = fallback(lhs, rhs); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." 
@@ -563,7 +482,7 @@ mod reduce_sum_of_xy { } } - #[crate::multiversion(@"v4:avx512fp16", @"v4", @"v3", @"v2:f16c:fma", @"v8.3a:fp16")] + #[crate::multiversion(@"v4.512:avx512fp16", @"v4.512", @"v3", @"a3.512", @"a2:fp16")] pub fn reduce_sum_of_xy(lhs: &[f16], rhs: &[f16]) -> f32 { assert!(lhs.len() == rhs.len()); let n = lhs.len(); @@ -580,9 +499,9 @@ mod reduce_sum_of_d2 { #[inline] #[cfg(target_arch = "x86_64")] - #[crate::target_cpu(enable = "v4")] + #[crate::target_cpu(enable = "v4.512")] #[target_feature(enable = "avx512fp16")] - pub fn reduce_sum_of_d2_v4_avx512fp16(lhs: &[f16], rhs: &[f16]) -> f32 { + pub fn reduce_sum_of_d2_v4_512_avx512fp16(lhs: &[f16], rhs: &[f16]) -> f32 { assert!(lhs.len() == rhs.len()); unsafe { use std::arch::x86_64::*; @@ -612,27 +531,27 @@ mod reduce_sum_of_d2 { #[cfg(all(target_arch = "x86_64", test, not(miri)))] #[test] - fn reduce_sum_of_d2_v4_avx512fp16_test() { + fn reduce_sum_of_d2_v4_512_avx512fp16_test() { use rand::Rng; const EPSILON: f32 = 6.0; - if !crate::is_cpu_detected!("v4") || !crate::is_feature_detected!("avx512fp16") { - println!("test {} ... skipped (v4:avx512fp16)", module_path!()); + if !crate::is_cpu_detected!("v4.512") || !crate::is_feature_detected!("avx512fp16") { + println!("test {} ... skipped (v4_512:avx512fp16)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 4016; let lhs = (0..n) - .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .map(|_| f16::from_f32(rng.random_range(-1.0..=1.0))) .collect::>(); let rhs = (0..n) - .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .map(|_| f16::from_f32(rng.random_range(-1.0..=1.0))) .collect::>(); for z in 3984..4016 { let lhs = &lhs[..z]; let rhs = &rhs[..z]; - let specialized = unsafe { reduce_sum_of_d2_v4_avx512fp16(lhs, rhs) }; - let fallback = reduce_sum_of_d2_fallback(lhs, rhs); + let specialized = unsafe { reduce_sum_of_d2_v4_512_avx512fp16(lhs, rhs) }; + let fallback = fallback(lhs, rhs); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." @@ -643,8 +562,8 @@ mod reduce_sum_of_d2 { #[inline] #[cfg(target_arch = "x86_64")] - #[crate::target_cpu(enable = "v4")] - pub fn reduce_sum_of_d2_v4(lhs: &[f16], rhs: &[f16]) -> f32 { + #[crate::target_cpu(enable = "v4.512")] + pub fn reduce_sum_of_d2_v4_512(lhs: &[f16], rhs: &[f16]) -> f32 { assert!(lhs.len() == rhs.len()); unsafe { use std::arch::x86_64::*; @@ -677,24 +596,24 @@ mod reduce_sum_of_d2 { fn reduce_sum_of_d2_v4_test() { use rand::Rng; const EPSILON: f32 = 2.0; - if !crate::is_cpu_detected!("v4") { + if !crate::is_cpu_detected!("v4.512") { println!("test {} ... 
skipped (v4)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 4016; let lhs = (0..n) - .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .map(|_| f16::from_f32(rng.random_range(-1.0..=1.0))) .collect::>(); let rhs = (0..n) - .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .map(|_| f16::from_f32(rng.random_range(-1.0..=1.0))) .collect::>(); for z in 3984..4016 { let lhs = &lhs[..z]; let rhs = &rhs[..z]; - let specialized = unsafe { reduce_sum_of_d2_v4(lhs, rhs) }; - let fallback = reduce_sum_of_d2_fallback(lhs, rhs); + let specialized = unsafe { reduce_sum_of_d2_v4_512(lhs, rhs) }; + let fallback = fallback(lhs, rhs); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." @@ -747,91 +666,20 @@ mod reduce_sum_of_d2 { println!("test {} ... skipped (v3)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 4016; let lhs = (0..n) - .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .map(|_| f16::from_f32(rng.random_range(-1.0..=1.0))) .collect::>(); let rhs = (0..n) - .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .map(|_| f16::from_f32(rng.random_range(-1.0..=1.0))) .collect::>(); for z in 3984..4016 { let lhs = &lhs[..z]; let rhs = &rhs[..z]; let specialized = unsafe { reduce_sum_of_d2_v3(lhs, rhs) }; - let fallback = reduce_sum_of_d2_fallback(lhs, rhs); - assert!( - (specialized - fallback).abs() < EPSILON, - "specialized = {specialized}, fallback = {fallback}." - ); - } - } - } - - #[inline] - #[cfg(target_arch = "x86_64")] - #[crate::target_cpu(enable = "v2")] - #[target_feature(enable = "f16c")] - #[target_feature(enable = "fma")] - pub fn reduce_sum_of_d2_v2_f16c_fma(lhs: &[f16], rhs: &[f16]) -> f32 { - use crate::emulate::emulate_mm_reduce_add_ps; - assert!(lhs.len() == rhs.len()); - unsafe { - use std::arch::x86_64::*; - let mut n = lhs.len() as u32; - let mut a = lhs.as_ptr(); - let mut b = rhs.as_ptr(); - let mut d2 = _mm_setzero_ps(); - while n >= 4 { - let x = _mm_cvtph_ps(_mm_loadu_si128(a.cast())); - let y = _mm_cvtph_ps(_mm_loadu_si128(b.cast())); - a = a.add(4); - b = b.add(4); - n -= 4; - let d = _mm_sub_ps(x, y); - d2 = _mm_fmadd_ps(d, d, d2); - } - let mut d2 = emulate_mm_reduce_add_ps(d2); - while n > 0 { - let x = a.read().to_f32(); - let y = b.read().to_f32(); - a = a.add(1); - b = b.add(1); - n -= 1; - let d = x - y; - d2 += d * d; - } - d2 - } - } - - #[cfg(all(target_arch = "x86_64", test, not(miri)))] - #[test] - fn reduce_sum_of_d2_v2_f16c_fma_test() { - use rand::Rng; - const EPSILON: f32 = 2.0; - if !crate::is_cpu_detected!("v2") - || !crate::is_feature_detected!("f16c") - || !crate::is_feature_detected!("fma") - { - println!("test {} ... 
skipped (v2:f16c:fma)", module_path!()); - return; - } - let mut rng = rand::thread_rng(); - for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { - let n = 4016; - let lhs = (0..n) - .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) - .collect::>(); - let rhs = (0..n) - .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) - .collect::>(); - for z in 3984..4016 { - let lhs = &lhs[..z]; - let rhs = &rhs[..z]; - let specialized = unsafe { reduce_sum_of_d2_v2_f16c_fma(lhs, rhs) }; - let fallback = reduce_sum_of_d2_fallback(lhs, rhs); + let fallback = fallback(lhs, rhs); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." @@ -842,49 +690,41 @@ mod reduce_sum_of_d2 { #[inline] #[cfg(target_arch = "aarch64")] - #[crate::target_cpu(enable = "v8.3a")] + #[crate::target_cpu(enable = "a2")] #[target_feature(enable = "fp16")] - pub fn reduce_sum_of_d2_v8_3a_fp16(lhs: &[f16], rhs: &[f16]) -> f32 { + pub fn reduce_sum_of_d2_a2_fp16(lhs: &[f16], rhs: &[f16]) -> f32 { assert!(lhs.len() == rhs.len()); unsafe { extern "C" { - fn fp16_reduce_sum_of_d2_v8_3a_fp16_unroll( - a: *const (), - b: *const (), - n: usize, - ) -> f32; + fn fp16_reduce_sum_of_d2_a2_fp16(a: *const (), b: *const (), n: usize) -> f32; } - fp16_reduce_sum_of_d2_v8_3a_fp16_unroll( - lhs.as_ptr().cast(), - rhs.as_ptr().cast(), - lhs.len(), - ) + fp16_reduce_sum_of_d2_a2_fp16(lhs.as_ptr().cast(), rhs.as_ptr().cast(), lhs.len()) } } #[cfg(all(target_arch = "aarch64", test, not(miri)))] #[test] - fn reduce_sum_of_d2_v8_3a_fp16_test() { + fn reduce_sum_of_d2_a2_fp16_test() { use rand::Rng; const EPSILON: f32 = 6.0; - if !crate::is_cpu_detected!("v8.3a") || !crate::is_feature_detected!("fp16") { - println!("test {} ... skipped (v8.3a:fp16)", module_path!()); + if !crate::is_cpu_detected!("a2") || !crate::is_feature_detected!("fp16") { + println!("test {} ... skipped (a2:fp16)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 4016; let lhs = (0..n) - .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .map(|_| f16::from_f32(rng.random_range(-1.0..=1.0))) .collect::>(); let rhs = (0..n) - .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .map(|_| f16::from_f32(rng.random_range(-1.0..=1.0))) .collect::>(); for z in 3984..4016 { let lhs = &lhs[..z]; let rhs = &rhs[..z]; - let specialized = unsafe { reduce_sum_of_d2_v8_3a_fp16(lhs, rhs) }; - let fallback = reduce_sum_of_d2_fallback(lhs, rhs); + let specialized = unsafe { reduce_sum_of_d2_a2_fp16(lhs, rhs) }; + let fallback = fallback(lhs, rhs); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." 
@@ -893,45 +733,43 @@ mod reduce_sum_of_d2 { } } - // temporarily disables this for uncertain precision - #[cfg_attr(not(test), expect(dead_code))] #[inline] #[cfg(target_arch = "aarch64")] - #[crate::target_cpu(enable = "v8.3a")] + #[crate::target_cpu(enable = "a2")] #[target_feature(enable = "sve")] - pub fn reduce_sum_of_d2_v8_3a_sve(lhs: &[f16], rhs: &[f16]) -> f32 { + pub fn reduce_sum_of_d2_a3_512(lhs: &[f16], rhs: &[f16]) -> f32 { assert!(lhs.len() == rhs.len()); unsafe { extern "C" { - fn fp16_reduce_sum_of_d2_v8_3a_sve(a: *const (), b: *const (), n: usize) -> f32; + fn fp16_reduce_sum_of_d2_a3_512(a: *const (), b: *const (), n: usize) -> f32; } - fp16_reduce_sum_of_d2_v8_3a_sve(lhs.as_ptr().cast(), rhs.as_ptr().cast(), lhs.len()) + fp16_reduce_sum_of_d2_a3_512(lhs.as_ptr().cast(), rhs.as_ptr().cast(), lhs.len()) } } #[cfg(all(target_arch = "aarch64", test, not(miri)))] #[test] - fn reduce_sum_of_d2_v8_3a_sve_test() { + fn reduce_sum_of_d2_a3_512_test() { use rand::Rng; const EPSILON: f32 = 6.0; - if !crate::is_cpu_detected!("v8.3a") || !crate::is_feature_detected!("sve") { - println!("test {} ... skipped (v8.3a:sve)", module_path!()); + if !crate::is_cpu_detected!("a3.512") { + println!("test {} ... skipped (a3.512)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 4016; let lhs = (0..n) - .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .map(|_| f16::from_f32(rng.random_range(-1.0..=1.0))) .collect::>(); let rhs = (0..n) - .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .map(|_| f16::from_f32(rng.random_range(-1.0..=1.0))) .collect::>(); for z in 3984..4016 { let lhs = &lhs[..z]; let rhs = &rhs[..z]; - let specialized = unsafe { reduce_sum_of_d2_v8_3a_sve(lhs, rhs) }; - let fallback = reduce_sum_of_d2_fallback(lhs, rhs); + let specialized = unsafe { reduce_sum_of_d2_a3_512(lhs, rhs) }; + let fallback = fallback(lhs, rhs); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." 
@@ -940,7 +778,7 @@ mod reduce_sum_of_d2 { } } - #[crate::multiversion(@"v4:avx512fp16", @"v4", @"v3", @"v2:f16c:fma", @"v8.3a:fp16")] + #[crate::multiversion(@"v4.512:avx512fp16", @"v4.512", @"v3", @"a3.512", @"a2:fp16")] pub fn reduce_sum_of_d2(lhs: &[f16], rhs: &[f16]) -> f32 { assert!(lhs.len() == rhs.len()); let n = lhs.len(); @@ -959,7 +797,7 @@ mod reduce_sum_of_xy_sparse { use half::f16; - #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + #[crate::multiversion("v4.512", "v3", "v2", "a2")] pub fn reduce_sum_of_xy_sparse(lidx: &[u32], lval: &[f16], ridx: &[u32], rval: &[f16]) -> f32 { use std::cmp::Ordering; assert_eq!(lidx.len(), lval.len()); @@ -992,7 +830,7 @@ mod reduce_sum_of_d2_sparse { use half::f16; - #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + #[crate::multiversion("v4.512", "v3", "v2", "a2")] pub fn reduce_sum_of_d2_sparse(lidx: &[u32], lval: &[f16], ridx: &[u32], rval: &[f16]) -> f32 { use std::cmp::Ordering; assert_eq!(lidx.len(), lval.len()); @@ -1031,7 +869,7 @@ mod reduce_sum_of_d2_sparse { mod vector_add { use half::f16; - #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + #[crate::multiversion("v4.512", "v3", "v2", "a2")] pub fn vector_add(lhs: &[f16], rhs: &[f16]) -> Vec { assert_eq!(lhs.len(), rhs.len()); let n = lhs.len(); @@ -1051,7 +889,7 @@ mod vector_add { mod vector_add_inplace { use half::f16; - #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + #[crate::multiversion("v4.512", "v3", "v2", "a2")] pub fn vector_add_inplace(lhs: &mut [f16], rhs: &[f16]) { assert_eq!(lhs.len(), rhs.len()); let n = lhs.len(); @@ -1064,7 +902,7 @@ mod vector_add_inplace { mod vector_sub { use half::f16; - #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + #[crate::multiversion("v4.512", "v3", "v2", "a2")] pub fn vector_sub(lhs: &[f16], rhs: &[f16]) -> Vec { assert_eq!(lhs.len(), rhs.len()); let n = lhs.len(); @@ -1084,7 +922,7 @@ mod vector_sub { mod vector_mul { use half::f16; - #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + #[crate::multiversion("v4.512", "v3", "v2", "a2")] pub fn vector_mul(lhs: &[f16], rhs: &[f16]) -> Vec { assert_eq!(lhs.len(), rhs.len()); let n = lhs.len(); @@ -1104,7 +942,7 @@ mod vector_mul { mod vector_mul_scalar { use half::f16; - #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + #[crate::multiversion("v4.512", "v3", "v2", "a2")] pub fn vector_mul_scalar(lhs: &[f16], rhs: f32) -> Vec { let rhs = f16::from_f32(rhs); let n = lhs.len(); @@ -1124,7 +962,7 @@ mod vector_mul_scalar { mod vector_mul_scalar_inplace { use half::f16; - #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + #[crate::multiversion("v4.512", "v3", "v2", "a2")] pub fn vector_mul_scalar_inplace(lhs: &mut [f16], rhs: f32) { let rhs = f16::from_f32(rhs); let n = lhs.len(); @@ -1137,7 +975,7 @@ mod vector_mul_scalar_inplace { mod vector_abs_inplace { use half::f16; - #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + #[crate::multiversion("v4.512", "v3", "v2", "a2")] pub fn vector_abs_inplace(this: &mut [f16]) { let n = this.len(); for i in 0..n { @@ -1149,7 +987,7 @@ mod vector_abs_inplace { mod vector_from_f32 { use half::f16; - #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + #[crate::multiversion("v4.512", "v3", "v2", "a2")] pub fn vector_from_f32(this: &[f32]) -> Vec { let n = this.len(); let mut r = Vec::::with_capacity(n); @@ -1168,7 +1006,7 @@ mod vector_from_f32 { mod vector_to_f32 { use half::f16; - #[crate::multiversion("v4", 
"v3", "v2", "v8.3a:sve", "v8.3a")] + #[crate::multiversion("v4.512", "v3", "v2", "a2")] pub fn vector_to_f32(this: &[f16]) -> Vec { let n = this.len(); let mut r = Vec::::with_capacity(n); diff --git a/crates/simd/src/f32.rs b/crates/simd/src/f32.rs index 555da5c7..54670deb 100644 --- a/crates/simd/src/f32.rs +++ b/crates/simd/src/f32.rs @@ -119,9 +119,7 @@ impl Floating for f32 { } mod reduce_or_of_is_zero_x { - // FIXME: add manually-implemented SIMD version - - #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + #[crate::multiversion("v4.512", "v3", "v2", "a2")] pub fn reduce_or_of_is_zero_x(this: &[f32]) -> bool { for &x in this { if x == 0.0f32 { @@ -135,8 +133,8 @@ mod reduce_or_of_is_zero_x { mod reduce_sum_of_x { #[inline] #[cfg(target_arch = "x86_64")] - #[crate::target_cpu(enable = "v4")] - fn reduce_sum_of_x_v4(this: &[f32]) -> f32 { + #[crate::target_cpu(enable = "v4.512")] + fn reduce_sum_of_x_v4_512(this: &[f32]) -> f32 { unsafe { use std::arch::x86_64::*; let mut n = this.len(); @@ -162,20 +160,20 @@ mod reduce_sum_of_x { fn reduce_sum_of_x_v4_test() { use rand::Rng; const EPSILON: f32 = 0.008; - if !crate::is_cpu_detected!("v4") { + if !crate::is_cpu_detected!("v4.512") { println!("test {} ... skipped (v4)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 4016; let this = (0..n) - .map(|_| rng.gen_range(-1.0..=1.0)) + .map(|_| rng.random_range(-1.0..=1.0)) .collect::>(); for z in 3984..4016 { let this = &this[..z]; - let specialized = unsafe { reduce_sum_of_x_v4(this) }; - let fallback = reduce_sum_of_x_fallback(this); + let specialized = unsafe { reduce_sum_of_x_v4_512(this) }; + let fallback = fallback(this); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." @@ -227,16 +225,16 @@ mod reduce_sum_of_x { println!("test {} ... skipped (v3)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 4016; let this = (0..n) - .map(|_| rng.gen_range(-1.0..=1.0)) + .map(|_| rng.random_range(-1.0..=1.0)) .collect::>(); for z in 3984..4016 { let this = &this[..z]; let specialized = unsafe { reduce_sum_of_x_v3(this) }; - let fallback = reduce_sum_of_x_fallback(this); + let fallback = fallback(this); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." @@ -282,16 +280,16 @@ mod reduce_sum_of_x { println!("test {} ... skipped (v2)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 4016; let this = (0..n) - .map(|_| rng.gen_range(-1.0..=1.0)) + .map(|_| rng.random_range(-1.0..=1.0)) .collect::>(); for z in 3984..4016 { let this = &this[..z]; let specialized = unsafe { reduce_sum_of_x_v2(this) }; - let fallback = reduce_sum_of_x_fallback(this); + let fallback = fallback(this); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." 
@@ -302,8 +300,8 @@ mod reduce_sum_of_x { #[inline] #[cfg(target_arch = "aarch64")] - #[crate::target_cpu(enable = "v8.3a")] - fn reduce_sum_of_x_v8_3a(this: &[f32]) -> f32 { + #[crate::target_cpu(enable = "a2")] + fn reduce_sum_of_x_a2(this: &[f32]) -> f32 { unsafe { use std::arch::aarch64::*; let mut n = this.len(); @@ -329,23 +327,23 @@ mod reduce_sum_of_x { #[cfg(all(target_arch = "aarch64", test, not(miri)))] #[test] - fn reduce_sum_of_x_v8_3a_test() { + fn reduce_sum_of_x_a2_test() { use rand::Rng; const EPSILON: f32 = 0.008; - if !crate::is_cpu_detected!("v8.3a") { - println!("test {} ... skipped (v8_3a)", module_path!()); + if !crate::is_cpu_detected!("a2") { + println!("test {} ... skipped (a2)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 4016; let this = (0..n) - .map(|_| rng.gen_range(-1.0..=1.0)) + .map(|_| rng.random_range(-1.0..=1.0)) .collect::>(); for z in 3984..4016 { let this = &this[..z]; - let specialized = unsafe { reduce_sum_of_x_v8_3a(this) }; - let fallback = reduce_sum_of_x_fallback(this); + let specialized = unsafe { reduce_sum_of_x_a2(this) }; + let fallback = fallback(this); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." @@ -356,36 +354,36 @@ mod reduce_sum_of_x { #[inline] #[cfg(target_arch = "aarch64")] - #[crate::target_cpu(enable = "v8.3a")] + #[crate::target_cpu(enable = "a2")] #[target_feature(enable = "sve")] - fn reduce_sum_of_x_v8_3a_sve(this: &[f32]) -> f32 { + fn reduce_sum_of_x_a3_256(this: &[f32]) -> f32 { unsafe { extern "C" { - fn fp32_reduce_sum_of_x_v8_3a_sve(this: *const f32, n: usize) -> f32; + fn fp32_reduce_sum_of_x_a3_256(this: *const f32, n: usize) -> f32; } - fp32_reduce_sum_of_x_v8_3a_sve(this.as_ptr(), this.len()) + fp32_reduce_sum_of_x_a3_256(this.as_ptr(), this.len()) } } #[cfg(all(target_arch = "aarch64", test, not(miri)))] #[test] - fn reduce_sum_of_x_v8_3a_sve_test() { + fn reduce_sum_of_x_a3_256_test() { use rand::Rng; const EPSILON: f32 = 0.008; - if !crate::is_cpu_detected!("v8.3a") || !crate::is_feature_detected!("sve") { - println!("test {} ... skipped (v8_3a:sve)", module_path!()); + if !crate::is_cpu_detected!("a3.256") { + println!("test {} ... skipped (a3.256)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 4016; let this = (0..n) - .map(|_| rng.gen_range(-1.0..=1.0)) + .map(|_| rng.random_range(-1.0..=1.0)) .collect::>(); for z in 3984..4016 { let this = &this[..z]; - let specialized = unsafe { reduce_sum_of_x_v8_3a_sve(this) }; - let fallback = reduce_sum_of_x_fallback(this); + let specialized = unsafe { reduce_sum_of_x_a3_256(this) }; + let fallback = fallback(this); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." 
@@ -394,7 +392,7 @@ mod reduce_sum_of_x { } } - #[crate::multiversion(@"v4", @"v3", @"v2", @"v8.3a:sve", @"v8.3a")] + #[crate::multiversion(@"v4.512", @"v3", @"v2", @"a3.256", @"a2")] pub fn reduce_sum_of_x(this: &[f32]) -> f32 { let n = this.len(); let mut sum = 0.0f32; @@ -408,8 +406,8 @@ mod reduce_sum_of_x { mod reduce_sum_of_abs_x { #[inline] #[cfg(target_arch = "x86_64")] - #[crate::target_cpu(enable = "v4")] - fn reduce_sum_of_abs_x_v4(this: &[f32]) -> f32 { + #[crate::target_cpu(enable = "v4.512")] + fn reduce_sum_of_abs_x_v4_512(this: &[f32]) -> f32 { unsafe { use std::arch::x86_64::*; let mut n = this.len(); @@ -437,20 +435,20 @@ mod reduce_sum_of_abs_x { fn reduce_sum_of_abs_x_v4_test() { use rand::Rng; const EPSILON: f32 = 0.008; - if !crate::is_cpu_detected!("v4") { + if !crate::is_cpu_detected!("v4.512") { println!("test {} ... skipped (v4)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 4016; let this = (0..n) - .map(|_| rng.gen_range(-1.0..=1.0)) + .map(|_| rng.random_range(-1.0..=1.0)) .collect::>(); for z in 3984..4016 { let this = &this[..z]; - let specialized = unsafe { reduce_sum_of_abs_x_v4(this) }; - let fallback = reduce_sum_of_abs_x_fallback(this); + let specialized = unsafe { reduce_sum_of_abs_x_v4_512(this) }; + let fallback = fallback(this); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." @@ -506,16 +504,16 @@ mod reduce_sum_of_abs_x { println!("test {} ... skipped (v3)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 4016; let this = (0..n) - .map(|_| rng.gen_range(-1.0..=1.0)) + .map(|_| rng.random_range(-1.0..=1.0)) .collect::>(); for z in 3984..4016 { let this = &this[..z]; let specialized = unsafe { reduce_sum_of_abs_x_v3(this) }; - let fallback = reduce_sum_of_abs_x_fallback(this); + let fallback = fallback(this); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." @@ -564,16 +562,16 @@ mod reduce_sum_of_abs_x { println!("test {} ... skipped (v2)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 4016; let this = (0..n) - .map(|_| rng.gen_range(-1.0..=1.0)) + .map(|_| rng.random_range(-1.0..=1.0)) .collect::>(); for z in 3984..4016 { let this = &this[..z]; let specialized = unsafe { reduce_sum_of_abs_x_v2(this) }; - let fallback = reduce_sum_of_abs_x_fallback(this); + let fallback = fallback(this); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." @@ -584,8 +582,8 @@ mod reduce_sum_of_abs_x { #[inline] #[cfg(target_arch = "aarch64")] - #[crate::target_cpu(enable = "v8.3a")] - fn reduce_sum_of_abs_x_v8_3a(this: &[f32]) -> f32 { + #[crate::target_cpu(enable = "a2")] + fn reduce_sum_of_abs_x_a2(this: &[f32]) -> f32 { unsafe { use std::arch::aarch64::*; let mut n = this.len(); @@ -613,23 +611,23 @@ mod reduce_sum_of_abs_x { #[cfg(all(target_arch = "aarch64", test, not(miri)))] #[test] - fn reduce_sum_of_abs_x_v8_3a_test() { + fn reduce_sum_of_abs_x_a2_test() { use rand::Rng; const EPSILON: f32 = 0.008; - if !crate::is_cpu_detected!("v8.3a") { - println!("test {} ... skipped (v8.3a)", module_path!()); + if !crate::is_cpu_detected!("a2") { + println!("test {} ... 
skipped (a2)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 4016; let this = (0..n) - .map(|_| rng.gen_range(-1.0..=1.0)) + .map(|_| rng.random_range(-1.0..=1.0)) .collect::>(); for z in 3984..4016 { let this = &this[..z]; - let specialized = unsafe { reduce_sum_of_abs_x_v8_3a(this) }; - let fallback = reduce_sum_of_abs_x_fallback(this); + let specialized = unsafe { reduce_sum_of_abs_x_a2(this) }; + let fallback = fallback(this); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." @@ -640,36 +638,36 @@ mod reduce_sum_of_abs_x { #[inline] #[cfg(target_arch = "aarch64")] - #[crate::target_cpu(enable = "v8.3a")] + #[crate::target_cpu(enable = "a2")] #[target_feature(enable = "sve")] - fn reduce_sum_of_abs_x_v8_3a_sve(this: &[f32]) -> f32 { + fn reduce_sum_of_abs_x_a3_256(this: &[f32]) -> f32 { unsafe { extern "C" { - fn fp32_reduce_sum_of_abs_x_v8_3a_sve(this: *const f32, n: usize) -> f32; + fn fp32_reduce_sum_of_abs_x_a3_256(this: *const f32, n: usize) -> f32; } - fp32_reduce_sum_of_abs_x_v8_3a_sve(this.as_ptr(), this.len()) + fp32_reduce_sum_of_abs_x_a3_256(this.as_ptr(), this.len()) } } #[cfg(all(target_arch = "aarch64", test, not(miri)))] #[test] - fn reduce_sum_of_abs_x_v8_3a_sve_test() { + fn reduce_sum_of_abs_x_a3_256_test() { use rand::Rng; const EPSILON: f32 = 0.008; - if !crate::is_cpu_detected!("v8.3a") || !crate::is_feature_detected!("sve") { - println!("test {} ... skipped (v8.3a:sve)", module_path!()); + if !crate::is_cpu_detected!("a3.256") { + println!("test {} ... skipped (a3.256)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 4016; let this = (0..n) - .map(|_| rng.gen_range(-1.0..=1.0)) + .map(|_| rng.random_range(-1.0..=1.0)) .collect::>(); for z in 3984..4016 { let this = &this[..z]; - let specialized = unsafe { reduce_sum_of_abs_x_v8_3a_sve(this) }; - let fallback = reduce_sum_of_abs_x_fallback(this); + let specialized = unsafe { reduce_sum_of_abs_x_a3_256(this) }; + let fallback = fallback(this); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." @@ -678,7 +676,7 @@ mod reduce_sum_of_abs_x { } } - #[crate::multiversion(@"v4", @"v3", @"v2", @"v8.3a:sve", @"v8.3a")] + #[crate::multiversion(@"v4.512", @"v3", @"v2", @"a3.256", @"a2")] pub fn reduce_sum_of_abs_x(this: &[f32]) -> f32 { let n = this.len(); let mut sum = 0.0f32; @@ -692,8 +690,8 @@ mod reduce_sum_of_abs_x { mod reduce_sum_of_x2 { #[inline] #[cfg(target_arch = "x86_64")] - #[crate::target_cpu(enable = "v4")] - fn reduce_sum_of_x2_v4(this: &[f32]) -> f32 { + #[crate::target_cpu(enable = "v4.512")] + fn reduce_sum_of_x2_v4_512(this: &[f32]) -> f32 { unsafe { use std::arch::x86_64::*; let mut n = this.len(); @@ -719,20 +717,20 @@ mod reduce_sum_of_x2 { fn reduce_sum_of_x2_v4_test() { use rand::Rng; const EPSILON: f32 = 0.006; - if !crate::is_cpu_detected!("v4") { + if !crate::is_cpu_detected!("v4.512") { println!("test {} ... 
skipped (v4)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 4016; let this = (0..n) - .map(|_| rng.gen_range(-1.0..=1.0)) + .map(|_| rng.random_range(-1.0..=1.0)) .collect::>(); for z in 3984..4016 { let this = &this[..z]; - let specialized = unsafe { reduce_sum_of_x2_v4(this) }; - let fallback = reduce_sum_of_x2_fallback(this); + let specialized = unsafe { reduce_sum_of_x2_v4_512(this) }; + let fallback = fallback(this); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." @@ -784,16 +782,16 @@ mod reduce_sum_of_x2 { println!("test {} ... skipped (v3)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 4016; let this = (0..n) - .map(|_| rng.gen_range(-1.0..=1.0)) + .map(|_| rng.random_range(-1.0..=1.0)) .collect::>(); for z in 3984..4016 { let this = &this[..z]; let specialized = unsafe { reduce_sum_of_x2_v3(this) }; - let fallback = reduce_sum_of_x2_fallback(this); + let fallback = fallback(this); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." @@ -840,16 +838,16 @@ mod reduce_sum_of_x2 { println!("test {} ... skipped (v2:fma)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 4016; let this = (0..n) - .map(|_| rng.gen_range(-1.0..=1.0)) + .map(|_| rng.random_range(-1.0..=1.0)) .collect::>(); for z in 3984..4016 { let this = &this[..z]; let specialized = unsafe { reduce_sum_of_x2_v2_fma(this) }; - let fallback = reduce_sum_of_x2_fallback(this); + let fallback = fallback(this); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." @@ -860,8 +858,8 @@ mod reduce_sum_of_x2 { #[inline] #[cfg(target_arch = "aarch64")] - #[crate::target_cpu(enable = "v8.3a")] - fn reduce_sum_of_x2_v8_3a(this: &[f32]) -> f32 { + #[crate::target_cpu(enable = "a2")] + fn reduce_sum_of_x2_a2(this: &[f32]) -> f32 { unsafe { use std::arch::aarch64::*; let mut n = this.len(); @@ -887,23 +885,23 @@ mod reduce_sum_of_x2 { #[cfg(all(target_arch = "aarch64", test, not(miri)))] #[test] - fn reduce_sum_of_x2_v8_3a_test() { + fn reduce_sum_of_x2_a2_test() { use rand::Rng; const EPSILON: f32 = 0.006; - if !crate::is_cpu_detected!("v8.3a") { - println!("test {} ... skipped (v8.3a)", module_path!()); + if !crate::is_cpu_detected!("a2") { + println!("test {} ... skipped (a2)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 4016; let this = (0..n) - .map(|_| rng.gen_range(-1.0..=1.0)) + .map(|_| rng.random_range(-1.0..=1.0)) .collect::>(); for z in 3984..4016 { let this = &this[..z]; - let specialized = unsafe { reduce_sum_of_x2_v8_3a(this) }; - let fallback = reduce_sum_of_x2_fallback(this); + let specialized = unsafe { reduce_sum_of_x2_a2(this) }; + let fallback = fallback(this); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." 
@@ -914,36 +912,36 @@ mod reduce_sum_of_x2 { #[inline] #[cfg(target_arch = "aarch64")] - #[crate::target_cpu(enable = "v8.3a")] + #[crate::target_cpu(enable = "a2")] #[target_feature(enable = "sve")] - fn reduce_sum_of_x2_v8_3a_sve(this: &[f32]) -> f32 { + fn reduce_sum_of_x2_a3_256(this: &[f32]) -> f32 { unsafe { extern "C" { - fn fp32_reduce_sum_of_x2_v8_3a_sve(this: *const f32, n: usize) -> f32; + fn fp32_reduce_sum_of_x2_a3_256(this: *const f32, n: usize) -> f32; } - fp32_reduce_sum_of_x2_v8_3a_sve(this.as_ptr(), this.len()) + fp32_reduce_sum_of_x2_a3_256(this.as_ptr(), this.len()) } } #[cfg(all(target_arch = "aarch64", test, not(miri)))] #[test] - fn reduce_sum_of_x2_v8_3a_sve_test() { + fn reduce_sum_of_x2_a3_256_test() { use rand::Rng; const EPSILON: f32 = 0.006; - if !crate::is_cpu_detected!("v8.3a") || !crate::is_feature_detected!("sve") { - println!("test {} ... skipped (v8.3a:sve)", module_path!()); + if !crate::is_cpu_detected!("a3.256") { + println!("test {} ... skipped (a3.256)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 4016; let this = (0..n) - .map(|_| rng.gen_range(-1.0..=1.0)) + .map(|_| rng.random_range(-1.0..=1.0)) .collect::>(); for z in 3984..4016 { let this = &this[..z]; - let specialized = unsafe { reduce_sum_of_x2_v8_3a_sve(this) }; - let fallback = reduce_sum_of_x2_fallback(this); + let specialized = unsafe { reduce_sum_of_x2_a3_256(this) }; + let fallback = fallback(this); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." @@ -952,7 +950,7 @@ mod reduce_sum_of_x2 { } } - #[crate::multiversion(@"v4", @"v3", @"v2:fma", @"v8.3a:sve", @"v8.3a")] + #[crate::multiversion(@"v4.512", @"v3", @"v2:fma", @"a3.256", @"a2")] pub fn reduce_sum_of_x2(this: &[f32]) -> f32 { let n = this.len(); let mut x2 = 0.0f32; @@ -969,8 +967,8 @@ mod reduce_min_max_of_x { #[inline] #[cfg(target_arch = "x86_64")] - #[crate::target_cpu(enable = "v4")] - fn reduce_min_max_of_x_v4(this: &[f32]) -> (f32, f32) { + #[crate::target_cpu(enable = "v4.512")] + fn reduce_min_max_of_x_v4_512(this: &[f32]) -> (f32, f32) { unsafe { use std::arch::x86_64::*; let mut n = this.len(); @@ -1000,20 +998,20 @@ mod reduce_min_max_of_x { #[test] fn reduce_min_max_of_x_v4_test() { use rand::Rng; - if !crate::is_cpu_detected!("v4") { + if !crate::is_cpu_detected!("v4.512") { println!("test {} ... 
skipped (v4)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 200; let x = (0..n) - .map(|_| rng.gen_range(-1.0..=1.0)) + .map(|_| rng.random_range(-1.0..=1.0)) .collect::>(); for z in 50..200 { let x = &x[..z]; - let specialized = unsafe { reduce_min_max_of_x_v4(x) }; - let fallback = reduce_min_max_of_x_fallback(x); + let specialized = unsafe { reduce_min_max_of_x_v4_512(x) }; + let fallback = fallback(x); assert_eq!(specialized.0, fallback.0); assert_eq!(specialized.1, fallback.1); } @@ -1024,8 +1022,7 @@ mod reduce_min_max_of_x { #[cfg(target_arch = "x86_64")] #[crate::target_cpu(enable = "v3")] fn reduce_min_max_of_x_v3(this: &[f32]) -> (f32, f32) { - use crate::emulate::emulate_mm256_reduce_max_ps; - use crate::emulate::emulate_mm256_reduce_min_ps; + use crate::emulate::{emulate_mm256_reduce_max_ps, emulate_mm256_reduce_min_ps}; unsafe { use std::arch::x86_64::*; let mut n = this.len(); @@ -1061,16 +1058,16 @@ mod reduce_min_max_of_x { println!("test {} ... skipped (v3)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 200; let x = (0..n) - .map(|_| rng.gen_range(-1.0..=1.0)) + .map(|_| rng.random_range(-1.0..=1.0)) .collect::>(); for z in 50..200 { let x = &x[..z]; let specialized = unsafe { reduce_min_max_of_x_v3(x) }; - let fallback = reduce_min_max_of_x_fallback(x); + let fallback = fallback(x); assert_eq!(specialized.0, fallback.0,); assert_eq!(specialized.1, fallback.1,); } @@ -1081,8 +1078,7 @@ mod reduce_min_max_of_x { #[cfg(target_arch = "x86_64")] #[crate::target_cpu(enable = "v2")] fn reduce_min_max_of_x_v2(this: &[f32]) -> (f32, f32) { - use crate::emulate::emulate_mm_reduce_max_ps; - use crate::emulate::emulate_mm_reduce_min_ps; + use crate::emulate::{emulate_mm_reduce_max_ps, emulate_mm_reduce_min_ps}; unsafe { use std::arch::x86_64::*; let mut n = this.len(); @@ -1118,16 +1114,16 @@ mod reduce_min_max_of_x { println!("test {} ... skipped (v2)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 200; let x = (0..n) - .map(|_| rng.gen_range(-1.0..=1.0)) + .map(|_| rng.random_range(-1.0..=1.0)) .collect::>(); for z in 50..200 { let x = &x[..z]; let specialized = unsafe { reduce_min_max_of_x_v2(x) }; - let fallback = reduce_min_max_of_x_fallback(x); + let fallback = fallback(x); assert_eq!(specialized.0, fallback.0,); assert_eq!(specialized.1, fallback.1,); } @@ -1136,8 +1132,8 @@ mod reduce_min_max_of_x { #[inline] #[cfg(target_arch = "aarch64")] - #[crate::target_cpu(enable = "v8.3a")] - fn reduce_min_max_of_x_v8_3a(this: &[f32]) -> (f32, f32) { + #[crate::target_cpu(enable = "a2")] + fn reduce_min_max_of_x_a2(this: &[f32]) -> (f32, f32) { unsafe { use std::arch::aarch64::*; let mut n = this.len(); @@ -1167,22 +1163,22 @@ mod reduce_min_max_of_x { #[cfg(all(target_arch = "aarch64", test, not(miri)))] #[test] - fn reduce_min_max_of_x_v8_3a_test() { + fn reduce_min_max_of_x_a2_test() { use rand::Rng; - if !crate::is_cpu_detected!("v8.3a") { - println!("test {} ... skipped (v8.3a)", module_path!()); + if !crate::is_cpu_detected!("a2") { + println!("test {} ... 
skipped (a2)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 200; let x = (0..n) - .map(|_| rng.gen_range(-1.0..=1.0)) + .map(|_| rng.random_range(-1.0..=1.0)) .collect::>(); for z in 50..200 { let x = &x[..z]; - let specialized = unsafe { reduce_min_max_of_x_v8_3a(x) }; - let fallback = reduce_min_max_of_x_fallback(x); + let specialized = unsafe { reduce_min_max_of_x_a2(x) }; + let fallback = fallback(x); assert_eq!(specialized.0, fallback.0,); assert_eq!(specialized.1, fallback.1,); } @@ -1191,50 +1187,50 @@ mod reduce_min_max_of_x { #[inline] #[cfg(target_arch = "aarch64")] - #[crate::target_cpu(enable = "v8.3a")] + #[crate::target_cpu(enable = "a2")] #[target_feature(enable = "sve")] - fn reduce_min_max_of_x_v8_3a_sve(this: &[f32]) -> (f32, f32) { + fn reduce_min_max_of_x_a3_256(this: &[f32]) -> (f32, f32) { let mut min = f32::INFINITY; let mut max = -f32::INFINITY; unsafe { extern "C" { - fn fp32_reduce_min_max_of_x_v8_3a_sve( + fn fp32_reduce_min_max_of_x_a3_256( this: *const f32, n: usize, out_min: &mut f32, out_max: &mut f32, ); } - fp32_reduce_min_max_of_x_v8_3a_sve(this.as_ptr(), this.len(), &mut min, &mut max); + fp32_reduce_min_max_of_x_a3_256(this.as_ptr(), this.len(), &mut min, &mut max); } (min, max) } #[cfg(all(target_arch = "aarch64", test, not(miri)))] #[test] - fn reduce_min_max_of_x_v8_3a_sve_test() { + fn reduce_min_max_of_x_a3_256_test() { use rand::Rng; - if !crate::is_cpu_detected!("v8.3a") || !crate::is_feature_detected!("sve") { - println!("test {} ... skipped (v8.3a:sve)", module_path!()); + if !crate::is_cpu_detected!("a3.256") { + println!("test {} ... skipped (a3.256)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 200; let x = (0..n) - .map(|_| rng.gen_range(-1.0..=1.0)) + .map(|_| rng.random_range(-1.0..=1.0)) .collect::>(); for z in 50..200 { let x = &x[..z]; - let specialized = unsafe { reduce_min_max_of_x_v8_3a_sve(x) }; - let fallback = reduce_min_max_of_x_fallback(x); + let specialized = unsafe { reduce_min_max_of_x_a3_256(x) }; + let fallback = fallback(x); assert_eq!(specialized.0, fallback.0,); assert_eq!(specialized.1, fallback.1,); } } } - #[crate::multiversion(@"v4", @"v3", @"v2", @"v8.3a:sve", @"v8.3a")] + #[crate::multiversion(@"v4.512", @"v3", @"v2", @"a3.256", @"a2")] pub fn reduce_min_max_of_x(this: &[f32]) -> (f32, f32) { let mut min = f32::INFINITY; let mut max = f32::NEG_INFINITY; @@ -1250,8 +1246,8 @@ mod reduce_min_max_of_x { mod reduce_sum_of_xy { #[inline] #[cfg(target_arch = "x86_64")] - #[crate::target_cpu(enable = "v4")] - fn reduce_sum_of_xy_v4(lhs: &[f32], rhs: &[f32]) -> f32 { + #[crate::target_cpu(enable = "v4.512")] + fn reduce_sum_of_xy_v4_512(lhs: &[f32], rhs: &[f32]) -> f32 { assert!(lhs.len() == rhs.len()); unsafe { use std::arch::x86_64::*; @@ -1282,24 +1278,24 @@ mod reduce_sum_of_xy { fn reduce_sum_of_xy_v4_test() { use rand::Rng; const EPSILON: f32 = 0.004; - if !crate::is_cpu_detected!("v4") { + if !crate::is_cpu_detected!("v4.512") { println!("test {} ... 
skipped (v4)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 4016; let lhs = (0..n) - .map(|_| rng.gen_range(-1.0..=1.0)) + .map(|_| rng.random_range(-1.0..=1.0)) .collect::>(); let rhs = (0..n) - .map(|_| rng.gen_range(-1.0..=1.0)) + .map(|_| rng.random_range(-1.0..=1.0)) .collect::>(); for z in 3984..4016 { let lhs = &lhs[..z]; let rhs = &rhs[..z]; - let specialized = unsafe { reduce_sum_of_xy_v4(lhs, rhs) }; - let fallback = reduce_sum_of_xy_fallback(lhs, rhs); + let specialized = unsafe { reduce_sum_of_xy_v4_512(lhs, rhs) }; + let fallback = fallback(lhs, rhs); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." @@ -1359,20 +1355,20 @@ mod reduce_sum_of_xy { println!("test {} ... skipped (v3)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 4016; let lhs = (0..n) - .map(|_| rng.gen_range(-1.0..=1.0)) + .map(|_| rng.random_range(-1.0..=1.0)) .collect::>(); let rhs = (0..n) - .map(|_| rng.gen_range(-1.0..=1.0)) + .map(|_| rng.random_range(-1.0..=1.0)) .collect::>(); for z in 3984..4016 { let lhs = &lhs[..z]; let rhs = &rhs[..z]; let specialized = unsafe { reduce_sum_of_xy_v3(lhs, rhs) }; - let fallback = reduce_sum_of_xy_fallback(lhs, rhs); + let fallback = fallback(lhs, rhs); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." @@ -1425,20 +1421,20 @@ mod reduce_sum_of_xy { println!("test {} ... skipped (v2:fma)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 4016; let lhs = (0..n) - .map(|_| rng.gen_range(-1.0..=1.0)) + .map(|_| rng.random_range(-1.0..=1.0)) .collect::>(); let rhs = (0..n) - .map(|_| rng.gen_range(-1.0..=1.0)) + .map(|_| rng.random_range(-1.0..=1.0)) .collect::>(); for z in 3984..4016 { let lhs = &lhs[..z]; let rhs = &rhs[..z]; let specialized = unsafe { reduce_sum_of_xy_v2_fma(lhs, rhs) }; - let fallback = reduce_sum_of_xy_fallback(lhs, rhs); + let fallback = fallback(lhs, rhs); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." @@ -1449,8 +1445,8 @@ mod reduce_sum_of_xy { #[inline] #[cfg(target_arch = "aarch64")] - #[crate::target_cpu(enable = "v8.3a")] - fn reduce_sum_of_xy_v8_3a(lhs: &[f32], rhs: &[f32]) -> f32 { + #[crate::target_cpu(enable = "a2")] + fn reduce_sum_of_xy_a2(lhs: &[f32], rhs: &[f32]) -> f32 { assert!(lhs.len() == rhs.len()); unsafe { use std::arch::aarch64::*; @@ -1482,27 +1478,27 @@ mod reduce_sum_of_xy { #[cfg(all(target_arch = "aarch64", test, not(miri)))] #[test] - fn reduce_sum_of_xy_v8_3a_test() { + fn reduce_sum_of_xy_a2_test() { use rand::Rng; const EPSILON: f32 = 0.004; - if !crate::is_cpu_detected!("v8.3a") { - println!("test {} ... skipped (v8.3a)", module_path!()); + if !crate::is_cpu_detected!("a2") { + println!("test {} ... 
skipped (a2)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 4016; let lhs = (0..n) - .map(|_| rng.gen_range(-1.0..=1.0)) + .map(|_| rng.random_range(-1.0..=1.0)) .collect::>(); let rhs = (0..n) - .map(|_| rng.gen_range(-1.0..=1.0)) + .map(|_| rng.random_range(-1.0..=1.0)) .collect::>(); for z in 3984..4016 { let lhs = &lhs[..z]; let rhs = &rhs[..z]; - let specialized = unsafe { reduce_sum_of_xy_v8_3a(lhs, rhs) }; - let fallback = reduce_sum_of_xy_fallback(lhs, rhs); + let specialized = unsafe { reduce_sum_of_xy_a2(lhs, rhs) }; + let fallback = fallback(lhs, rhs); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." @@ -1513,41 +1509,41 @@ mod reduce_sum_of_xy { #[inline] #[cfg(target_arch = "aarch64")] - #[crate::target_cpu(enable = "v8.3a")] + #[crate::target_cpu(enable = "a2")] #[target_feature(enable = "sve")] - fn reduce_sum_of_xy_v8_3a_sve(lhs: &[f32], rhs: &[f32]) -> f32 { + fn reduce_sum_of_xy_a3_256(lhs: &[f32], rhs: &[f32]) -> f32 { assert!(lhs.len() == rhs.len()); unsafe { extern "C" { - fn fp32_reduce_sum_of_xy_v8_3a_sve(a: *const f32, b: *const f32, n: usize) -> f32; + fn fp32_reduce_sum_of_xy_a3_256(a: *const f32, b: *const f32, n: usize) -> f32; } - fp32_reduce_sum_of_xy_v8_3a_sve(lhs.as_ptr(), rhs.as_ptr(), lhs.len()) + fp32_reduce_sum_of_xy_a3_256(lhs.as_ptr(), rhs.as_ptr(), lhs.len()) } } #[cfg(all(target_arch = "aarch64", test, not(miri)))] #[test] - fn reduce_sum_of_xy_v8_3a_sve_test() { + fn reduce_sum_of_xy_a3_256_test() { use rand::Rng; const EPSILON: f32 = 0.004; - if !crate::is_cpu_detected!("v8.3a") || !crate::is_feature_detected!("sve") { - println!("test {} ... skipped (v8.3a:sve)", module_path!()); + if !crate::is_cpu_detected!("a3.256") { + println!("test {} ... skipped (a3.256)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 4016; let lhs = (0..n) - .map(|_| rng.gen_range(-1.0..=1.0)) + .map(|_| rng.random_range(-1.0..=1.0)) .collect::>(); let rhs = (0..n) - .map(|_| rng.gen_range(-1.0..=1.0)) + .map(|_| rng.random_range(-1.0..=1.0)) .collect::>(); for z in 3984..4016 { let lhs = &lhs[..z]; let rhs = &rhs[..z]; - let specialized = unsafe { reduce_sum_of_xy_v8_3a_sve(lhs, rhs) }; - let fallback = reduce_sum_of_xy_fallback(lhs, rhs); + let specialized = unsafe { reduce_sum_of_xy_a3_256(lhs, rhs) }; + let fallback = fallback(lhs, rhs); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." 
@@ -1556,7 +1552,7 @@ mod reduce_sum_of_xy { } } - #[crate::multiversion(@"v4", @"v3", @"v2:fma", @"v8.3a:sve", @"v8.3a")] + #[crate::multiversion(@"v4.512", @"v3", @"v2:fma", @"a3.256", @"a2")] pub fn reduce_sum_of_xy(lhs: &[f32], rhs: &[f32]) -> f32 { assert!(lhs.len() == rhs.len()); let n = lhs.len(); @@ -1571,8 +1567,8 @@ mod reduce_sum_of_xy { mod reduce_sum_of_d2 { #[inline] #[cfg(target_arch = "x86_64")] - #[crate::target_cpu(enable = "v4")] - fn reduce_sum_of_d2_v4(lhs: &[f32], rhs: &[f32]) -> f32 { + #[crate::target_cpu(enable = "v4.512")] + fn reduce_sum_of_d2_v4_512(lhs: &[f32], rhs: &[f32]) -> f32 { assert!(lhs.len() == rhs.len()); unsafe { use std::arch::x86_64::*; @@ -1605,24 +1601,24 @@ mod reduce_sum_of_d2 { fn reduce_sum_of_d2_v4_test() { use rand::Rng; const EPSILON: f32 = 0.02; - if !crate::is_cpu_detected!("v4") { + if !crate::is_cpu_detected!("v4.512") { println!("test {} ... skipped (v4)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 4016; let lhs = (0..n) - .map(|_| rng.gen_range(-1.0..=1.0)) + .map(|_| rng.random_range(-1.0..=1.0)) .collect::>(); let rhs = (0..n) - .map(|_| rng.gen_range(-1.0..=1.0)) + .map(|_| rng.random_range(-1.0..=1.0)) .collect::>(); for z in 3984..4016 { let lhs = &lhs[..z]; let rhs = &rhs[..z]; - let specialized = unsafe { reduce_sum_of_d2_v4(lhs, rhs) }; - let fallback = reduce_sum_of_d2_fallback(lhs, rhs); + let specialized = unsafe { reduce_sum_of_d2_v4_512(lhs, rhs) }; + let fallback = fallback(lhs, rhs); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." @@ -1685,20 +1681,20 @@ mod reduce_sum_of_d2 { println!("test {} ... skipped (v3)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 4016; let lhs = (0..n) - .map(|_| rng.gen_range(-1.0..=1.0)) + .map(|_| rng.random_range(-1.0..=1.0)) .collect::>(); let rhs = (0..n) - .map(|_| rng.gen_range(-1.0..=1.0)) + .map(|_| rng.random_range(-1.0..=1.0)) .collect::>(); for z in 3984..4016 { let lhs = &lhs[..z]; let rhs = &rhs[..z]; let specialized = unsafe { reduce_sum_of_d2_v3(lhs, rhs) }; - let fallback = reduce_sum_of_d2_fallback(lhs, rhs); + let fallback = fallback(lhs, rhs); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." @@ -1753,20 +1749,20 @@ mod reduce_sum_of_d2 { println!("test {} ... skipped (v2:fma)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 4016; let lhs = (0..n) - .map(|_| rng.gen_range(-1.0..=1.0)) + .map(|_| rng.random_range(-1.0..=1.0)) .collect::>(); let rhs = (0..n) - .map(|_| rng.gen_range(-1.0..=1.0)) + .map(|_| rng.random_range(-1.0..=1.0)) .collect::>(); for z in 3984..4016 { let lhs = &lhs[..z]; let rhs = &rhs[..z]; let specialized = unsafe { reduce_sum_of_d2_v2_fma(lhs, rhs) }; - let fallback = reduce_sum_of_d2_fallback(lhs, rhs); + let fallback = fallback(lhs, rhs); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." 
@@ -1777,8 +1773,8 @@ mod reduce_sum_of_d2 { #[inline] #[cfg(target_arch = "aarch64")] - #[crate::target_cpu(enable = "v8.3a")] - fn reduce_sum_of_d2_v8_3a(lhs: &[f32], rhs: &[f32]) -> f32 { + #[crate::target_cpu(enable = "a2")] + fn reduce_sum_of_d2_a2(lhs: &[f32], rhs: &[f32]) -> f32 { assert!(lhs.len() == rhs.len()); unsafe { use std::arch::aarch64::*; @@ -1812,27 +1808,27 @@ mod reduce_sum_of_d2 { #[cfg(all(target_arch = "aarch64", test, not(miri)))] #[test] - fn reduce_sum_of_d2_v8_3a_test() { + fn reduce_sum_of_d2_a2_test() { use rand::Rng; const EPSILON: f32 = 0.02; - if !crate::is_cpu_detected!("v8.3a") { - println!("test {} ... skipped (v8.3a)", module_path!()); + if !crate::is_cpu_detected!("a2") { + println!("test {} ... skipped (a2)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 4016; let lhs = (0..n) - .map(|_| rng.gen_range(-1.0..=1.0)) + .map(|_| rng.random_range(-1.0..=1.0)) .collect::>(); let rhs = (0..n) - .map(|_| rng.gen_range(-1.0..=1.0)) + .map(|_| rng.random_range(-1.0..=1.0)) .collect::>(); for z in 3984..4016 { let lhs = &lhs[..z]; let rhs = &rhs[..z]; - let specialized = unsafe { reduce_sum_of_d2_v8_3a(lhs, rhs) }; - let fallback = reduce_sum_of_d2_fallback(lhs, rhs); + let specialized = unsafe { reduce_sum_of_d2_a2(lhs, rhs) }; + let fallback = fallback(lhs, rhs); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." @@ -1843,41 +1839,41 @@ mod reduce_sum_of_d2 { #[inline] #[cfg(target_arch = "aarch64")] - #[crate::target_cpu(enable = "v8.3a")] + #[crate::target_cpu(enable = "a2")] #[target_feature(enable = "sve")] - fn reduce_sum_of_d2_v8_3a_sve(lhs: &[f32], rhs: &[f32]) -> f32 { + fn reduce_sum_of_d2_a3_256(lhs: &[f32], rhs: &[f32]) -> f32 { assert!(lhs.len() == rhs.len()); unsafe { extern "C" { - fn fp32_reduce_sum_of_d2_v8_3a_sve(a: *const f32, b: *const f32, n: usize) -> f32; + fn fp32_reduce_sum_of_d2_a3_256(a: *const f32, b: *const f32, n: usize) -> f32; } - fp32_reduce_sum_of_d2_v8_3a_sve(lhs.as_ptr(), rhs.as_ptr(), lhs.len()) + fp32_reduce_sum_of_d2_a3_256(lhs.as_ptr(), rhs.as_ptr(), lhs.len()) } } #[cfg(all(target_arch = "aarch64", test, not(miri)))] #[test] - fn reduce_sum_of_d2_v8_3a_sve_test() { + fn reduce_sum_of_d2_a3_256_test() { use rand::Rng; const EPSILON: f32 = 0.02; - if !crate::is_cpu_detected!("v8.3a") || !crate::is_feature_detected!("sve") { - println!("test {} ... skipped (v8.3a:sve)", module_path!()); + if !crate::is_cpu_detected!("a3.256") { + println!("test {} ... skipped (a3.256)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 4016; let lhs = (0..n) - .map(|_| rng.gen_range(-1.0..=1.0)) + .map(|_| rng.random_range(-1.0..=1.0)) .collect::>(); let rhs = (0..n) - .map(|_| rng.gen_range(-1.0..=1.0)) + .map(|_| rng.random_range(-1.0..=1.0)) .collect::>(); for z in 3984..4016 { let lhs = &lhs[..z]; let rhs = &rhs[..z]; - let specialized = unsafe { reduce_sum_of_d2_v8_3a_sve(lhs, rhs) }; - let fallback = reduce_sum_of_d2_fallback(lhs, rhs); + let specialized = unsafe { reduce_sum_of_d2_a3_256(lhs, rhs) }; + let fallback = fallback(lhs, rhs); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." 
@@ -1886,7 +1882,7 @@ mod reduce_sum_of_d2 { } } - #[crate::multiversion(@"v4", @"v3", @"v2:fma", @"v8.3a:sve", @"v8.3a")] + #[crate::multiversion(@"v4.512", @"v3", @"v2:fma", @"a3.256", @"a2")] pub fn reduce_sum_of_d2(lhs: &[f32], rhs: &[f32]) -> f32 { assert!(lhs.len() == rhs.len()); let n = lhs.len(); @@ -1902,8 +1898,8 @@ mod reduce_sum_of_d2 { mod reduce_sum_of_xy_sparse { #[inline] #[cfg(target_arch = "x86_64")] - #[crate::target_cpu(enable = "v4")] - fn reduce_sum_of_xy_sparse_v4(li: &[u32], lv: &[f32], ri: &[u32], rv: &[f32]) -> f32 { + #[crate::target_cpu(enable = "v4.512")] + fn reduce_sum_of_xy_sparse_v4_512(li: &[u32], lv: &[f32], ri: &[u32], rv: &[f32]) -> f32 { use crate::emulate::emulate_mm512_2intersect_epi32; assert_eq!(li.len(), lv.len()); assert_eq!(ri.len(), rv.len()); @@ -1951,11 +1947,11 @@ mod reduce_sum_of_xy_sparse { fn reduce_sum_of_xy_sparse_v4_test() { use rand::Rng; const EPSILON: f32 = 0.000001; - if !crate::is_cpu_detected!("v4") { + if !crate::is_cpu_detected!("v4.512") { println!("test {} ... skipped (v4)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); pub fn sample_u32_sorted( rng: &mut (impl Rng + ?Sized), length: u32, @@ -1972,15 +1968,15 @@ mod reduce_sum_of_xy_sparse { let lm = 300; let lidx = sample_u32_sorted(&mut rng, 10000, lm); let lval = (0..lm) - .map(|_| rng.gen_range(-1.0..=1.0)) + .map(|_| rng.random_range(-1.0..=1.0)) .collect::>(); let rm = 350; let ridx = sample_u32_sorted(&mut rng, 10000, rm); let rval = (0..rm) - .map(|_| rng.gen_range(-1.0..=1.0)) + .map(|_| rng.random_range(-1.0..=1.0)) .collect::>(); - let specialized = unsafe { reduce_sum_of_xy_sparse_v4(&lidx, &lval, &ridx, &rval) }; - let fallback = reduce_sum_of_xy_sparse_fallback(&lidx, &lval, &ridx, &rval); + let specialized = unsafe { reduce_sum_of_xy_sparse_v4_512(&lidx, &lval, &ridx, &rval) }; + let fallback = fallback(&lidx, &lval, &ridx, &rval); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." @@ -1988,7 +1984,7 @@ mod reduce_sum_of_xy_sparse { } } - #[crate::multiversion(@"v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + #[crate::multiversion(@"v4.512", "v3", "v2", "a2")] pub fn reduce_sum_of_xy_sparse(lidx: &[u32], lval: &[f32], ridx: &[u32], rval: &[f32]) -> f32 { use std::cmp::Ordering; assert_eq!(lidx.len(), lval.len()); @@ -2018,8 +2014,8 @@ mod reduce_sum_of_xy_sparse { mod reduce_sum_of_d2_sparse { #[inline] #[cfg(target_arch = "x86_64")] - #[crate::target_cpu(enable = "v4")] - fn reduce_sum_of_d2_sparse_v4(li: &[u32], lv: &[f32], ri: &[u32], rv: &[f32]) -> f32 { + #[crate::target_cpu(enable = "v4.512")] + fn reduce_sum_of_d2_sparse_v4_512(li: &[u32], lv: &[f32], ri: &[u32], rv: &[f32]) -> f32 { use crate::emulate::emulate_mm512_2intersect_epi32; assert_eq!(li.len(), lv.len()); assert_eq!(ri.len(), rv.len()); @@ -2101,11 +2097,11 @@ mod reduce_sum_of_d2_sparse { fn reduce_sum_of_d2_sparse_v4_test() { use rand::Rng; const EPSILON: f32 = 0.0004; - if !crate::is_cpu_detected!("v4") { + if !crate::is_cpu_detected!("v4.512") { println!("test {} ... 
skipped (v4)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); pub fn sample_u32_sorted( rng: &mut (impl Rng + ?Sized), length: u32, @@ -2122,15 +2118,15 @@ mod reduce_sum_of_d2_sparse { let lm = 300; let lidx = sample_u32_sorted(&mut rng, 10000, lm); let lval = (0..lm) - .map(|_| rng.gen_range(-1.0..=1.0)) + .map(|_| rng.random_range(-1.0..=1.0)) .collect::>(); let rm = 350; let ridx = sample_u32_sorted(&mut rng, 10000, rm); let rval = (0..rm) - .map(|_| rng.gen_range(-1.0..=1.0)) + .map(|_| rng.random_range(-1.0..=1.0)) .collect::>(); - let specialized = unsafe { reduce_sum_of_d2_sparse_v4(&lidx, &lval, &ridx, &rval) }; - let fallback = reduce_sum_of_d2_sparse_fallback(&lidx, &lval, &ridx, &rval); + let specialized = unsafe { reduce_sum_of_d2_sparse_v4_512(&lidx, &lval, &ridx, &rval) }; + let fallback = fallback(&lidx, &lval, &ridx, &rval); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." @@ -2138,7 +2134,7 @@ mod reduce_sum_of_d2_sparse { } } - #[crate::multiversion(@"v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + #[crate::multiversion(@"v4.512", "v3", "v2", "a2")] pub fn reduce_sum_of_d2_sparse(lidx: &[u32], lval: &[f32], ridx: &[u32], rval: &[f32]) -> f32 { use std::cmp::Ordering; assert_eq!(lidx.len(), lval.len()); @@ -2175,7 +2171,7 @@ mod reduce_sum_of_d2_sparse { } mod vector_add { - #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + #[crate::multiversion("v4.512", "v3", "v2", "a2")] pub fn vector_add(lhs: &[f32], rhs: &[f32]) -> Vec { assert_eq!(lhs.len(), rhs.len()); let n = lhs.len(); @@ -2193,7 +2189,7 @@ mod vector_add { } mod vector_add_inplace { - #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + #[crate::multiversion("v4.512", "v3", "v2", "a2")] pub fn vector_add_inplace(lhs: &mut [f32], rhs: &[f32]) { assert_eq!(lhs.len(), rhs.len()); let n = lhs.len(); @@ -2204,7 +2200,7 @@ mod vector_add_inplace { } mod vector_sub { - #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + #[crate::multiversion("v4.512", "v3", "v2", "a2")] pub fn vector_sub(lhs: &[f32], rhs: &[f32]) -> Vec { assert_eq!(lhs.len(), rhs.len()); let n = lhs.len(); @@ -2222,7 +2218,7 @@ mod vector_sub { } mod vector_mul { - #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + #[crate::multiversion("v4.512", "v3", "v2", "a2")] pub fn vector_mul(lhs: &[f32], rhs: &[f32]) -> Vec { assert_eq!(lhs.len(), rhs.len()); let n = lhs.len(); @@ -2240,7 +2236,7 @@ mod vector_mul { } mod vector_mul_scalar { - #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + #[crate::multiversion("v4.512", "v3", "v2", "a2")] pub fn vector_mul_scalar(lhs: &[f32], rhs: f32) -> Vec { let n = lhs.len(); let mut r = Vec::::with_capacity(n); @@ -2257,7 +2253,7 @@ mod vector_mul_scalar { } mod vector_mul_scalar_inplace { - #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + #[crate::multiversion("v4.512", "v3", "v2", "a2")] pub fn vector_mul_scalar_inplace(lhs: &mut [f32], rhs: f32) { let n = lhs.len(); for i in 0..n { @@ -2267,7 +2263,7 @@ mod vector_mul_scalar_inplace { } mod vector_abs_inplace { - #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + #[crate::multiversion("v4.512", "v3", "v2", "a2")] pub fn vector_abs_inplace(this: &mut [f32]) { let n = this.len(); for i in 0..n { diff --git a/crates/simd/src/fast_scan/mod.rs b/crates/simd/src/fast_scan/mod.rs index 84ad9901..a99c227f 100644 --- a/crates/simd/src/fast_scan/mod.rs +++ 
b/crates/simd/src/fast_scan/mod.rs @@ -32,7 +32,7 @@ bits 4..7 | code (n-1),vector 16 | code (n-1),vector 24 | ... | code (n-1),vecto */ -pub fn pack(x: [&[u8]; 32]) -> Vec<[u64; 2]> { +pub fn pack(x: [&[u8]; 32]) -> Vec<[u8; 16]> { let n = { let l = x.each_ref().map(|i| i.len()); for i in 1..32 { @@ -43,74 +43,68 @@ pub fn pack(x: [&[u8]; 32]) -> Vec<[u64; 2]> { let mut result = Vec::with_capacity(n); for i in 0..n { result.push([ - u64::from_le_bytes([ - x[0][i] | (x[16][i] << 4), - x[8][i] | (x[24][i] << 4), - x[1][i] | (x[17][i] << 4), - x[9][i] | (x[25][i] << 4), - x[2][i] | (x[18][i] << 4), - x[10][i] | (x[26][i] << 4), - x[3][i] | (x[19][i] << 4), - x[11][i] | (x[27][i] << 4), - ]), - u64::from_le_bytes([ - x[4][i] | (x[20][i] << 4), - x[12][i] | (x[28][i] << 4), - x[5][i] | (x[21][i] << 4), - x[13][i] | (x[29][i] << 4), - x[6][i] | (x[22][i] << 4), - x[14][i] | (x[30][i] << 4), - x[7][i] | (x[23][i] << 4), - x[15][i] | (x[31][i] << 4), - ]), + x[0][i] | (x[16][i] << 4), + x[8][i] | (x[24][i] << 4), + x[1][i] | (x[17][i] << 4), + x[9][i] | (x[25][i] << 4), + x[2][i] | (x[18][i] << 4), + x[10][i] | (x[26][i] << 4), + x[3][i] | (x[19][i] << 4), + x[11][i] | (x[27][i] << 4), + x[4][i] | (x[20][i] << 4), + x[12][i] | (x[28][i] << 4), + x[5][i] | (x[21][i] << 4), + x[13][i] | (x[29][i] << 4), + x[6][i] | (x[22][i] << 4), + x[14][i] | (x[30][i] << 4), + x[7][i] | (x[23][i] << 4), + x[15][i] | (x[31][i] << 4), ]); } result } -pub fn unpack(x: &[[u64; 2]]) -> [Vec; 32] { +pub fn unpack(x: &[[u8; 16]]) -> [Vec; 32] { let n = x.len(); let mut result = std::array::from_fn(|_| Vec::with_capacity(n)); for i in 0..n { - let a = x[i][0].to_le_bytes(); - let b = x[i][1].to_le_bytes(); - result[0].push(a[0] & 0xf); - result[1].push(a[2] & 0xf); - result[2].push(a[4] & 0xf); - result[3].push(a[6] & 0xf); - result[4].push(b[0] & 0xf); - result[5].push(b[2] & 0xf); - result[6].push(b[4] & 0xf); - result[7].push(b[6] & 0xf); - result[8].push(a[1] & 0xf); - result[9].push(a[3] & 0xf); - result[10].push(a[5] & 0xf); - result[11].push(a[7] & 0xf); - result[12].push(b[1] & 0xf); - result[13].push(b[3] & 0xf); - result[14].push(b[5] & 0xf); - result[15].push(b[7] & 0xf); - result[16].push(a[0] >> 4); - result[17].push(a[2] >> 4); - result[18].push(a[4] >> 4); - result[19].push(a[6] >> 4); - result[20].push(b[0] >> 4); - result[21].push(b[2] >> 4); - result[22].push(b[4] >> 4); - result[23].push(b[6] >> 4); - result[24].push(a[1] >> 4); - result[25].push(a[3] >> 4); - result[26].push(a[5] >> 4); - result[27].push(a[7] >> 4); - result[28].push(b[1] >> 4); - result[29].push(b[3] >> 4); - result[30].push(b[5] >> 4); - result[31].push(b[7] >> 4); + result[0].push(x[i][0] & 0xf); + result[1].push(x[i][2] & 0xf); + result[2].push(x[i][4] & 0xf); + result[3].push(x[i][6] & 0xf); + result[4].push(x[i][8] & 0xf); + result[5].push(x[i][10] & 0xf); + result[6].push(x[i][12] & 0xf); + result[7].push(x[i][14] & 0xf); + result[8].push(x[i][1] & 0xf); + result[9].push(x[i][3] & 0xf); + result[10].push(x[i][5] & 0xf); + result[11].push(x[i][7] & 0xf); + result[12].push(x[i][9] & 0xf); + result[13].push(x[i][11] & 0xf); + result[14].push(x[i][13] & 0xf); + result[15].push(x[i][15] & 0xf); + result[16].push(x[i][0] >> 4); + result[17].push(x[i][2] >> 4); + result[18].push(x[i][4] >> 4); + result[19].push(x[i][6] >> 4); + result[20].push(x[i][8] >> 4); + result[21].push(x[i][10] >> 4); + result[22].push(x[i][12] >> 4); + result[23].push(x[i][14] >> 4); + result[24].push(x[i][1] >> 4); + result[25].push(x[i][3] >> 4); + 
result[26].push(x[i][5] >> 4); + result[27].push(x[i][7] >> 4); + result[28].push(x[i][9] >> 4); + result[29].push(x[i][11] >> 4); + result[30].push(x[i][13] >> 4); + result[31].push(x[i][15] >> 4); } result } -pub fn padding_pack(x: impl IntoIterator>) -> Vec<[u64; 2]> { +pub fn padding_pack(x: impl IntoIterator>) -> Vec<[u8; 16]> { let x = x.into_iter().collect::>(); let x = x.iter().map(|x| x.as_ref()).collect::>(); if x.is_empty() || x.len() > 32 { @@ -131,8 +125,8 @@ pub fn any_pack(mut x: impl Iterator) -> [T; 32] { mod fast_scan { #[inline] #[cfg(target_arch = "x86_64")] - #[crate::target_cpu(enable = "v4")] - fn fast_scan_v4(code: &[[u64; 2]], lut: &[[u64; 2]]) -> [u16; 32] { + #[crate::target_cpu(enable = "v4.512")] + fn fast_scan_v4_512(code: &[[u8; 16]], lut: &[[u8; 16]]) -> [u16; 32] { // bounds checking is not enforced by compiler, so check it manually assert_eq!(code.len(), lut.len()); let n = code.len(); @@ -141,7 +135,7 @@ mod fast_scan { use std::arch::x86_64::*; #[inline] - #[crate::target_cpu(enable = "v4")] + #[crate::target_cpu(enable = "v4.512")] fn combine2x2(x0x1: __m256i, y0y1: __m256i) -> __m256i { unsafe { let x1y0 = _mm256_permute2f128_si256(x0x1, y0y1, 0x21); @@ -151,7 +145,7 @@ mod fast_scan { } #[inline] - #[crate::target_cpu(enable = "v4")] + #[crate::target_cpu(enable = "v4.512")] fn combine4x2(x0x1x2x3: __m512i, y0y1y2y3: __m512i) -> __m256i { unsafe { let x0x1 = _mm512_castsi512_si256(x0x1x2x3); @@ -171,11 +165,11 @@ mod fast_scan { let mut i = 0_usize; while i + 4 <= n { - let c = _mm512_loadu_si512(code.as_ptr().add(i).cast()); + let code = _mm512_loadu_si512(code.as_ptr().add(i).cast()); let mask = _mm512_set1_epi8(0xf); - let clo = _mm512_and_si512(c, mask); - let chi = _mm512_and_si512(_mm512_srli_epi16(c, 4), mask); + let clo = _mm512_and_si512(code, mask); + let chi = _mm512_and_si512(_mm512_srli_epi16(code, 4), mask); let lut = _mm512_loadu_si512(lut.as_ptr().add(i).cast()); let res_lo = _mm512_shuffle_epi8(lut, clo); @@ -188,11 +182,11 @@ mod fast_scan { i += 4; } if i + 2 <= n { - let c = _mm256_loadu_si256(code.as_ptr().add(i).cast()); + let code = _mm256_loadu_si256(code.as_ptr().add(i).cast()); let mask = _mm256_set1_epi8(0xf); - let clo = _mm256_and_si256(c, mask); - let chi = _mm256_and_si256(_mm256_srli_epi16(c, 4), mask); + let clo = _mm256_and_si256(code, mask); + let chi = _mm256_and_si256(_mm256_srli_epi16(code, 4), mask); let lut = _mm256_loadu_si256(lut.as_ptr().add(i).cast()); let res_lo = _mm512_zextsi256_si512(_mm256_shuffle_epi8(lut, clo)); @@ -205,11 +199,11 @@ mod fast_scan { i += 2; } if i < n { - let c = _mm_loadu_si128(code.as_ptr().add(i).cast()); + let code = _mm_loadu_si128(code.as_ptr().add(i).cast()); let mask = _mm_set1_epi8(0xf); - let clo = _mm_and_si128(c, mask); - let chi = _mm_and_si128(_mm_srli_epi16(c, 4), mask); + let clo = _mm_and_si128(code, mask); + let chi = _mm_and_si128(_mm_srli_epi16(code, 4), mask); let lut = _mm_loadu_si128(lut.as_ptr().add(i).cast()); let res_lo = _mm512_zextsi128_si512(_mm_shuffle_epi8(lut, clo)); @@ -244,20 +238,20 @@ mod fast_scan { #[cfg(all(target_arch = "x86_64", test, not(miri)))] #[test] fn fast_scan_v4_test() { - if !crate::is_cpu_detected!("v4") { + if !crate::is_cpu_detected!("v4.512") { println!("test {} ... 
skipped (v4)", module_path!()); return; } for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { for n in 90..110 { let code = (0..n) - .map(|_| [rand::random(), rand::random()]) - .collect::>(); + .map(|_| std::array::from_fn(|_| rand::random())) + .collect::>(); let lut = (0..n) - .map(|_| [rand::random(), rand::random()]) - .collect::>(); + .map(|_| std::array::from_fn(|_| rand::random())) + .collect::>(); unsafe { - assert_eq!(fast_scan_v4(&code, &lut), fast_scan_fallback(&code, &lut)); + assert_eq!(fast_scan_v4_512(&code, &lut), fallback(&code, &lut)); } } } @@ -266,7 +260,7 @@ mod fast_scan { #[inline] #[cfg(target_arch = "x86_64")] #[crate::target_cpu(enable = "v3")] - fn fast_scan_v3(code: &[[u64; 2]], lut: &[[u64; 2]]) -> [u16; 32] { + fn fast_scan_v3(code: &[[u8; 16]], lut: &[[u8; 16]]) -> [u16; 32] { // bounds checking is not enforced by compiler, so check it manually assert_eq!(code.len(), lut.len()); let n = code.len(); @@ -291,11 +285,11 @@ mod fast_scan { let mut i = 0_usize; while i + 2 <= n { - let c = _mm256_loadu_si256(code.as_ptr().add(i).cast()); + let code = _mm256_loadu_si256(code.as_ptr().add(i).cast()); let mask = _mm256_set1_epi8(0xf); - let clo = _mm256_and_si256(c, mask); - let chi = _mm256_and_si256(_mm256_srli_epi16(c, 4), mask); + let clo = _mm256_and_si256(code, mask); + let chi = _mm256_and_si256(_mm256_srli_epi16(code, 4), mask); let lut = _mm256_loadu_si256(lut.as_ptr().add(i).cast()); let res_lo = _mm256_shuffle_epi8(lut, clo); @@ -308,11 +302,11 @@ mod fast_scan { i += 2; } if i < n { - let c = _mm_loadu_si128(code.as_ptr().add(i).cast()); + let code = _mm_loadu_si128(code.as_ptr().add(i).cast()); let mask = _mm_set1_epi8(0xf); - let clo = _mm_and_si128(c, mask); - let chi = _mm_and_si128(_mm_srli_epi16(c, 4), mask); + let clo = _mm_and_si128(code, mask); + let chi = _mm_and_si128(_mm_srli_epi16(code, 4), mask); let lut = _mm_loadu_si128(lut.as_ptr().add(i).cast()); let res_lo = _mm256_zextsi128_si256(_mm_shuffle_epi8(lut, clo)); @@ -354,13 +348,13 @@ mod fast_scan { for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { for n in 90..110 { let code = (0..n) - .map(|_| [rand::random(), rand::random()]) - .collect::>(); + .map(|_| std::array::from_fn(|_| rand::random())) + .collect::>(); let lut = (0..n) - .map(|_| [rand::random(), rand::random()]) - .collect::>(); + .map(|_| std::array::from_fn(|_| rand::random())) + .collect::>(); unsafe { - assert_eq!(fast_scan_v3(&code, &lut), fast_scan_fallback(&code, &lut)); + assert_eq!(fast_scan_v3(&code, &lut), fallback(&code, &lut)); } } } @@ -368,7 +362,7 @@ mod fast_scan { #[cfg(target_arch = "x86_64")] #[crate::target_cpu(enable = "v2")] - fn fast_scan_v2(code: &[[u64; 2]], lut: &[[u64; 2]]) -> [u16; 32] { + fn fast_scan_v2(code: &[[u8; 16]], lut: &[[u8; 16]]) -> [u16; 32] { // bounds checking is not enforced by compiler, so check it manually assert_eq!(code.len(), lut.len()); let n = code.len(); @@ -383,11 +377,11 @@ mod fast_scan { let mut i = 0_usize; while i < n { - let c = _mm_loadu_si128(code.as_ptr().add(i).cast()); + let code = _mm_loadu_si128(code.as_ptr().add(i).cast()); let mask = _mm_set1_epi8(0xf); - let clo = _mm_and_si128(c, mask); - let chi = _mm_and_si128(_mm_srli_epi16(c, 4), mask); + let clo = _mm_and_si128(code, mask); + let chi = _mm_and_si128(_mm_srli_epi16(code, 4), mask); let lut = _mm_loadu_si128(lut.as_ptr().add(i).cast()); let res_lo = _mm_shuffle_epi8(lut, clo); @@ -425,21 +419,21 @@ mod fast_scan { for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { for n in 90..110 { let code = (0..n) 
- .map(|_| [rand::random(), rand::random()]) - .collect::>(); + .map(|_| std::array::from_fn(|_| rand::random())) + .collect::>(); let lut = (0..n) - .map(|_| [rand::random(), rand::random()]) - .collect::>(); + .map(|_| std::array::from_fn(|_| rand::random())) + .collect::>(); unsafe { - assert_eq!(fast_scan_v2(&code, &lut), fast_scan_fallback(&code, &lut)); + assert_eq!(fast_scan_v2(&code, &lut), fallback(&code, &lut)); } } } } #[cfg(target_arch = "aarch64")] - #[crate::target_cpu(enable = "v8.3a")] - fn fast_scan_v8_3a(code: &[[u64; 2]], lut: &[[u64; 2]]) -> [u16; 32] { + #[crate::target_cpu(enable = "a2")] + fn fast_scan_a2(code: &[[u8; 16]], lut: &[[u8; 16]]) -> [u16; 32] { // bounds checking is not enforced by compiler, so check it manually assert_eq!(code.len(), lut.len()); let n = code.len(); @@ -454,11 +448,10 @@ mod fast_scan { let mut i = 0_usize; while i < n { - let c = vld1q_u8(code.as_ptr().add(i).cast()); + let code = vld1q_u8(code.as_ptr().add(i).cast()); - let mask = vdupq_n_u8(0xf); - let clo = vandq_u8(c, mask); - let chi = vandq_u8(vshrq_n_u8(c, 4), mask); + let clo = vandq_u8(code, vdupq_n_u8(0xf)); + let chi = vshrq_n_u8(code, 4); let lut = vld1q_u8(lut.as_ptr().add(i).cast()); let res_lo = vreinterpretq_u16_u8(vqtbl1q_u8(lut, clo)); @@ -488,49 +481,37 @@ mod fast_scan { #[cfg(all(target_arch = "aarch64", test, not(miri)))] #[test] - fn fast_scan_v8_3a_test() { - if !crate::is_cpu_detected!("v8.3a") { - println!("test {} ... skipped (v8.3a)", module_path!()); + fn fast_scan_a2_test() { + if !crate::is_cpu_detected!("a2") { + println!("test {} ... skipped (a2)", module_path!()); return; } for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { for n in 90..110 { let code = (0..n) - .map(|_| [rand::random(), rand::random()]) - .collect::>(); + .map(|_| std::array::from_fn(|_| rand::random())) + .collect::>(); let lut = (0..n) - .map(|_| [rand::random(), rand::random()]) - .collect::>(); + .map(|_| std::array::from_fn(|_| rand::random())) + .collect::>(); unsafe { - assert_eq!( - fast_scan_v8_3a(&code, &lut), - fast_scan_fallback(&code, &lut) - ); + assert_eq!(fast_scan_a2(&code, &lut), fallback(&code, &lut)); } } } } - #[crate::multiversion(@"v4", @"v3", @"v2", @"v8.3a")] - pub fn fast_scan(code: &[[u64; 2]], lut: &[[u64; 2]]) -> [u16; 32] { - assert_eq!(code.len(), lut.len()); - let n = code.len(); - - fn unary(op: impl Fn(T) -> U, a: [T; N]) -> [U; N] { - std::array::from_fn(|i| op(a[i])) - } - fn binary(op: impl Fn(T, T) -> T, a: [T; N], b: [T; N]) -> [T; N] { + #[crate::multiversion(@"v4.512", @"v3", @"v2", @"a2")] + pub fn fast_scan(code: &[[u8; 16]], lut: &[[u8; 16]]) -> [u16; 32] { + fn binary(op: impl Fn(u16, u16) -> u16, a: [u16; 8], b: [u16; 8]) -> [u16; 8] { std::array::from_fn(|i| op(a[i], b[i])) } - fn shuffle(a: [T; N], b: [u8; N]) -> [T; N] { + fn shuffle(a: [u8; 16], b: [u8; 16]) -> [u8; 16] { std::array::from_fn(|i| a[b[i] as usize]) } - fn cast(x: [u8; 16]) -> [u16; 8] { - std::array::from_fn(|i| u16::from_le_bytes([x[i << 1 | 0], x[i << 1 | 1]])) - } - fn setr(x: [[T; 8]; 4]) -> [T; 32] { - std::array::from_fn(|i| x[i >> 3][i & 7]) - } + + assert_eq!(code.len(), lut.len()); + let n = code.len(); let mut a_0 = [0u16; 8]; let mut a_1 = [0u16; 8]; @@ -538,29 +519,28 @@ mod fast_scan { let mut a_3 = [0u16; 8]; for i in 0..n { - let c = unsafe { std::mem::transmute::<[u64; 2], [u8; 16]>(code[i]) }; + let code = code[i]; - let mask = [0xfu8; 16]; - let clo = binary(std::ops::BitAnd::bitand, c, mask); - let chi = binary(std::ops::BitAnd::bitand, unary(|x| x >> 
4, c), mask); + let clo = code.map(|x| x & 0xf); + let chi = code.map(|x| x >> 4); - let lut = unsafe { std::mem::transmute::<[u64; 2], [u8; 16]>(lut[i]) }; - let res_lo = cast(shuffle(lut, clo)); + let lut = lut[i]; + let res_lo = zerocopy::transmute!(shuffle(lut, clo)); a_0 = binary(u16::wrapping_add, a_0, res_lo); - a_1 = binary(u16::wrapping_add, a_1, unary(|x| x >> 8, res_lo)); - let res_hi = cast(shuffle(lut, chi)); + a_1 = binary(u16::wrapping_add, a_1, res_lo.map(|x| x >> 8)); + let res_hi = zerocopy::transmute!(shuffle(lut, chi)); a_2 = binary(u16::wrapping_add, a_2, res_hi); - a_3 = binary(u16::wrapping_add, a_3, unary(|x| x >> 8, res_hi)); + a_3 = binary(u16::wrapping_add, a_3, res_hi.map(|x| x >> 8)); } - a_0 = binary(u16::wrapping_sub, a_0, unary(|x| x.wrapping_shl(8), a_1)); - a_2 = binary(u16::wrapping_sub, a_2, unary(|x| x.wrapping_shl(8), a_3)); + a_0 = binary(u16::wrapping_sub, a_0, a_1.map(|x| x << 8)); + a_2 = binary(u16::wrapping_sub, a_2, a_3.map(|x| x << 8)); - setr([a_0, a_1, a_2, a_3]) + zerocopy::transmute!([a_0, a_1, a_2, a_3]) } } #[inline(always)] -pub fn fast_scan(code: &[[u64; 2]], lut: &[[u64; 2]]) -> [u16; 32] { +pub fn fast_scan(code: &[[u8; 16]], lut: &[[u8; 16]]) -> [u16; 32] { fast_scan::fast_scan(code, lut) } diff --git a/crates/simd/src/lib.rs b/crates/simd/src/lib.rs index aec5c523..e2ce8304 100644 --- a/crates/simd/src/lib.rs +++ b/crates/simd/src/lib.rs @@ -1,7 +1,7 @@ -#![feature(target_feature_11)] #![feature(avx512_target_feature)] #![cfg_attr(target_arch = "x86_64", feature(stdarch_x86_avx512))] #![cfg_attr(target_arch = "x86_64", feature(stdarch_x86_avx512_f16))] +#![allow(unsafe_code)] mod aligned; mod emulate; @@ -15,16 +15,7 @@ pub mod quantize; pub mod u8; pub trait Floating: - Copy - + Send - + Sync - + std::fmt::Debug - + serde::Serialize - + for<'a> serde::Deserialize<'a> - + Default - + 'static - + PartialEq - + PartialOrd + Copy + Send + Sync + std::fmt::Debug + Default + 'static + PartialEq + PartialOrd { fn zero() -> Self; fn infinity() -> Self; @@ -77,10 +68,77 @@ mod internal { #[cfg(target_arch = "riscv64")] #[allow(unused_imports)] pub use is_riscv64_cpu_detected; + + #[cfg(target_arch = "x86_64")] + pub fn is_v4_512_detected() -> bool { + std::arch::is_x86_feature_detected!("avx512bw") + && std::arch::is_x86_feature_detected!("avx512cd") + && std::arch::is_x86_feature_detected!("avx512dq") + && std::arch::is_x86_feature_detected!("avx512vl") + && std::arch::is_x86_feature_detected!("bmi1") + && std::arch::is_x86_feature_detected!("bmi2") + && std::arch::is_x86_feature_detected!("lzcnt") + && std::arch::is_x86_feature_detected!("movbe") + && std::arch::is_x86_feature_detected!("popcnt") + } + + #[cfg(target_arch = "x86_64")] + pub fn is_v3_detected() -> bool { + std::arch::is_x86_feature_detected!("avx2") + && std::arch::is_x86_feature_detected!("f16c") + && std::arch::is_x86_feature_detected!("fma") + && std::arch::is_x86_feature_detected!("bmi1") + && std::arch::is_x86_feature_detected!("bmi2") + && std::arch::is_x86_feature_detected!("lzcnt") + && std::arch::is_x86_feature_detected!("movbe") + && std::arch::is_x86_feature_detected!("popcnt") + } + + #[cfg(target_arch = "x86_64")] + pub fn is_v2_detected() -> bool { + std::arch::is_x86_feature_detected!("sse4.2") + && std::arch::is_x86_feature_detected!("popcnt") + } + + #[cfg(target_arch = "aarch64")] + pub fn is_a3_512_detected() -> bool { + #[target_feature(enable = "sve")] + fn is_512_detected() -> bool { + let vl: u64; + unsafe { + core::arch::asm!( + "rdvl {0}, #8", + 
out(reg) vl + ); + } + vl >= 512 + } + std::arch::is_aarch64_feature_detected!("sve") && unsafe { is_512_detected() } + } + + #[cfg(target_arch = "aarch64")] + pub fn is_a3_256_detected() -> bool { + #[target_feature(enable = "sve")] + fn is_256_detected() -> bool { + let vl: u64; + unsafe { + core::arch::asm!( + "rdvl {0}, #8", + out(reg) vl + ); + } + vl >= 256 + } + std::arch::is_aarch64_feature_detected!("sve") && unsafe { is_256_detected() } + } + + #[cfg(target_arch = "aarch64")] + pub fn is_a2_detected() -> bool { + std::arch::is_aarch64_feature_detected!("neon") + } } -pub use simd_macros::multiversion; -pub use simd_macros::target_cpu; +pub use simd_macros::{multiversion, target_cpu}; #[cfg(target_arch = "x86_64")] #[allow(unused_imports)] diff --git a/crates/simd/src/packed_u4.rs b/crates/simd/src/packed_u4.rs index d68333c7..ae32aebb 100644 --- a/crates/simd/src/packed_u4.rs +++ b/crates/simd/src/packed_u4.rs @@ -3,7 +3,7 @@ pub fn reduce_sum_of_xy(s: &[u8], t: &[u8]) -> u32 { } mod reduce_sum_of_xy { - #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + #[crate::multiversion("v4.512", "v3", "v2", "a2")] pub fn reduce_sum_of_xy(s: &[u8], t: &[u8]) -> u32 { assert_eq!(s.len(), t.len()); let n = s.len(); diff --git a/crates/simd/src/quantize.rs b/crates/simd/src/quantize.rs index baaf594a..880f64bc 100644 --- a/crates/simd/src/quantize.rs +++ b/crates/simd/src/quantize.rs @@ -1,8 +1,8 @@ mod mul_add_round { #[inline] #[cfg(target_arch = "x86_64")] - #[crate::target_cpu(enable = "v4")] - fn mul_add_round_v4(this: &[f32], k: f32, b: f32) -> Vec { + #[crate::target_cpu(enable = "v4.512")] + fn mul_add_round_v4_512(this: &[f32], k: f32, b: f32) -> Vec { let n = this.len(); let mut r = Vec::::with_capacity(n); unsafe { @@ -42,7 +42,7 @@ mod mul_add_round { #[cfg(all(target_arch = "x86_64", test, not(miri)))] #[test] fn mul_add_round_v4_test() { - if !crate::is_cpu_detected!("v4") { + if !crate::is_cpu_detected!("v4.512") { println!("test {} ... skipped (v4)", module_path!()); return; } @@ -53,8 +53,8 @@ mod mul_add_round { let x = &x[..z]; let k = 20.0; let b = 20.0; - let specialized = unsafe { mul_add_round_v4(x, k, b) }; - let fallback = mul_add_round_fallback(x, k, b); + let specialized = unsafe { mul_add_round_v4_512(x, k, b) }; + let fallback = fallback(x, k, b); assert_eq!(specialized, fallback); } } @@ -123,7 +123,7 @@ mod mul_add_round { let k = 20.0; let b = 20.0; let specialized = unsafe { mul_add_round_v3(x, k, b) }; - let fallback = mul_add_round_fallback(x, k, b); + let fallback = fallback(x, k, b); assert_eq!(specialized, fallback); } } @@ -189,7 +189,7 @@ mod mul_add_round { let k = 20.0; let b = 20.0; let specialized = unsafe { mul_add_round_v2_fma(x, k, b) }; - let fallback = mul_add_round_fallback(x, k, b); + let fallback = fallback(x, k, b); assert_eq!(specialized, fallback); } } @@ -197,8 +197,8 @@ mod mul_add_round { #[inline] #[cfg(target_arch = "aarch64")] - #[crate::target_cpu(enable = "v8.3a")] - fn mul_add_round_v8_3a(this: &[f32], k: f32, b: f32) -> Vec { + #[crate::target_cpu(enable = "a2")] + fn mul_add_round_a2(this: &[f32], k: f32, b: f32) -> Vec { let n = this.len(); let mut r = Vec::::with_capacity(n); unsafe { @@ -244,9 +244,9 @@ mod mul_add_round { #[cfg(all(target_arch = "aarch64", test, not(miri)))] #[test] - fn mul_add_round_v8_3a_test() { - if !crate::is_cpu_detected!("v8.3a") { - println!("test {} ... skipped (v8.3a)", module_path!()); + fn mul_add_round_a2_test() { + if !crate::is_cpu_detected!("a2") { + println!("test {} ... 
skipped (a2)", module_path!()); return; } for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { @@ -256,14 +256,14 @@ mod mul_add_round { let x = &x[..z]; let k = 20.0; let b = 20.0; - let specialized = unsafe { mul_add_round_v8_3a(x, k, b) }; - let fallback = mul_add_round_fallback(x, k, b); + let specialized = unsafe { mul_add_round_a2(x, k, b) }; + let fallback = fallback(x, k, b); assert_eq!(specialized, fallback); } } } - #[crate::multiversion(@"v4", @"v3", @"v2:fma", @"v8.3a")] + #[crate::multiversion(@"v4.512", @"v3", @"v2:fma", @"a2")] pub fn mul_add_round(this: &[f32], k: f32, b: f32) -> Vec { let n = this.len(); let mut r = Vec::::with_capacity(n); diff --git a/crates/simd/src/u8.rs b/crates/simd/src/u8.rs index bc845332..4c8e238e 100644 --- a/crates/simd/src/u8.rs +++ b/crates/simd/src/u8.rs @@ -1,5 +1,5 @@ mod reduce_sum_of_xy { - #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + #[crate::multiversion("v4.512", "v3", "v2", "a2")] pub fn reduce_sum_of_xy(s: &[u8], t: &[u8]) -> u32 { assert_eq!(s.len(), t.len()); let n = s.len(); @@ -19,8 +19,8 @@ pub fn reduce_sum_of_xy(s: &[u8], t: &[u8]) -> u32 { mod reduce_sum_of_x_as_u16 { #[inline] #[cfg(target_arch = "x86_64")] - #[crate::target_cpu(enable = "v4")] - fn reduce_sum_of_x_as_u16_v4(this: &[u8]) -> u16 { + #[crate::target_cpu(enable = "v4.512")] + fn reduce_sum_of_x_as_u16_v4_512(this: &[u8]) -> u16 { use crate::emulate::emulate_mm512_reduce_add_epi16; unsafe { use std::arch::x86_64::*; @@ -47,18 +47,18 @@ mod reduce_sum_of_x_as_u16 { #[test] fn reduce_sum_of_x_as_u16_v4_test() { use rand::Rng; - if !crate::is_cpu_detected!("v4") { + if !crate::is_cpu_detected!("v4.512") { println!("test {} ... skipped (v4)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 4016; - let this = (0..n).map(|_| rng.gen_range(0..16)).collect::>(); + let this = (0..n).map(|_| rng.random_range(0..16)).collect::>(); for z in 3984..4016 { let this = &this[..z]; - let specialized = unsafe { reduce_sum_of_x_as_u16_v4(this) }; - let fallback = reduce_sum_of_x_as_u16_fallback(this); + let specialized = unsafe { reduce_sum_of_x_as_u16_v4_512(this) }; + let fallback = fallback(this); assert_eq!(specialized, fallback); } } @@ -101,14 +101,14 @@ mod reduce_sum_of_x_as_u16 { println!("test {} ... skipped (v3)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 4016; - let this = (0..n).map(|_| rng.gen_range(0..16)).collect::>(); + let this = (0..n).map(|_| rng.random_range(0..16)).collect::>(); for z in 3984..4016 { let this = &this[..z]; let specialized = unsafe { reduce_sum_of_x_as_u16_v3(this) }; - let fallback = reduce_sum_of_x_as_u16_fallback(this); + let fallback = fallback(this); assert_eq!(specialized, fallback); } } @@ -151,14 +151,14 @@ mod reduce_sum_of_x_as_u16 { println!("test {} ... 
skipped (v2)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 4016; - let this = (0..n).map(|_| rng.gen_range(0..16)).collect::>(); + let this = (0..n).map(|_| rng.random_range(0..16)).collect::>(); for z in 3984..4016 { let this = &this[..z]; let specialized = unsafe { reduce_sum_of_x_as_u16_v2(this) }; - let fallback = reduce_sum_of_x_as_u16_fallback(this); + let fallback = fallback(this); assert_eq!(specialized, fallback); } } @@ -166,8 +166,8 @@ mod reduce_sum_of_x_as_u16 { #[inline] #[cfg(target_arch = "aarch64")] - #[crate::target_cpu(enable = "v8.3a")] - fn reduce_sum_of_x_as_u16_v8_3a(this: &[u8]) -> u16 { + #[crate::target_cpu(enable = "a2")] + fn reduce_sum_of_x_as_u16_a2(this: &[u8]) -> u16 { unsafe { use std::arch::aarch64::*; let us = vdupq_n_u16(255); @@ -194,26 +194,26 @@ mod reduce_sum_of_x_as_u16 { #[cfg(all(target_arch = "aarch64", test, not(miri)))] #[test] - fn reduce_sum_of_x_as_u16_v8_3a_test() { + fn reduce_sum_of_x_as_u16_a2_test() { use rand::Rng; - if !crate::is_cpu_detected!("v8.3a") { - println!("test {} ... skipped (v8.3a)", module_path!()); + if !crate::is_cpu_detected!("a2") { + println!("test {} ... skipped (a2)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 4016; - let this = (0..n).map(|_| rng.gen_range(0..16)).collect::>(); + let this = (0..n).map(|_| rng.random_range(0..16)).collect::>(); for z in 3984..4016 { let this = &this[..z]; - let specialized = unsafe { reduce_sum_of_x_as_u16_v8_3a(this) }; - let fallback = reduce_sum_of_x_as_u16_fallback(this); + let specialized = unsafe { reduce_sum_of_x_as_u16_a2(this) }; + let fallback = fallback(this); assert_eq!(specialized, fallback); } } } - #[crate::multiversion(@"v4", @"v3", @"v2", @"v8.3a")] + #[crate::multiversion(@"v4.512", @"v3", @"v2", @"a2")] pub fn reduce_sum_of_x_as_u16(this: &[u8]) -> u16 { let n = this.len(); let mut sum = 0; @@ -232,8 +232,8 @@ pub fn reduce_sum_of_x_as_u16(vector: &[u8]) -> u16 { mod reduce_sum_of_x { #[inline] #[cfg(target_arch = "x86_64")] - #[crate::target_cpu(enable = "v4")] - fn reduce_sum_of_x_v4(this: &[u8]) -> u32 { + #[crate::target_cpu(enable = "v4.512")] + fn reduce_sum_of_x_v4_512(this: &[u8]) -> u32 { unsafe { use std::arch::x86_64::*; let us = _mm512_set1_epi32(255); @@ -259,18 +259,18 @@ mod reduce_sum_of_x { #[test] fn reduce_sum_of_x_v4_test() { use rand::Rng; - if !crate::is_cpu_detected!("v4") { + if !crate::is_cpu_detected!("v4.512") { println!("test {} ... skipped (v4)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 4016; - let this = (0..n).map(|_| rng.gen_range(0..16)).collect::>(); + let this = (0..n).map(|_| rng.random_range(0..16)).collect::>(); for z in 3984..4016 { let this = &this[..z]; - let specialized = unsafe { reduce_sum_of_x_v4(this) }; - let fallback = reduce_sum_of_x_fallback(this); + let specialized = unsafe { reduce_sum_of_x_v4_512(this) }; + let fallback = fallback(this); assert_eq!(specialized, fallback); } } @@ -307,26 +307,126 @@ mod reduce_sum_of_x { #[cfg(all(target_arch = "x86_64", test))] #[test] - fn reduce_sum_of_x_as_u16_v3_test() { + fn reduce_sum_of_x_v3_test() { use rand::Rng; if !crate::is_cpu_detected!("v3") { println!("test {} ... 
skipped (v3)", module_path!()); return; } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { let n = 4016; - let this = (0..n).map(|_| rng.gen_range(0..16)).collect::>(); + let this = (0..n).map(|_| rng.random_range(0..16)).collect::>(); for z in 3984..4016 { let this = &this[..z]; let specialized = unsafe { reduce_sum_of_x_v3(this) }; - let fallback = reduce_sum_of_x_fallback(this); + let fallback = fallback(this); + assert_eq!(specialized, fallback); + } + } + } + + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v2")] + fn reduce_sum_of_x_v2(this: &[u8]) -> u32 { + use crate::emulate::emulate_mm_reduce_add_epi32; + unsafe { + use std::arch::x86_64::*; + let us = _mm_set1_epi32(255); + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut sum = _mm_setzero_si128(); + while n >= 4 { + let x = _mm_cvtsi32_si128(a.cast::().read_unaligned()); + a = a.add(4); + n -= 4; + sum = _mm_add_epi32(_mm_and_si128(us, _mm_cvtepi8_epi32(x)), sum); + } + let mut sum = emulate_mm_reduce_add_epi32(sum) as u32; + // this hint is used to disable loop unrolling + while std::hint::black_box(n) > 0 { + let x = a.read(); + a = a.add(1); + n -= 1; + sum += x as u32; + } + sum + } + } + + #[cfg(all(target_arch = "x86_64", test))] + #[test] + fn reduce_sum_of_x_v2_test() { + use rand::Rng; + if !crate::is_cpu_detected!("v2") { + println!("test {} ... skipped (v2)", module_path!()); + return; + } + let mut rng = rand::rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let this = (0..n).map(|_| rng.random_range(0..16)).collect::>(); + for z in 3984..4016 { + let this = &this[..z]; + let specialized = unsafe { reduce_sum_of_x_v2(this) }; + let fallback = fallback(this); + assert_eq!(specialized, fallback); + } + } + } + + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::target_cpu(enable = "a2")] + fn reduce_sum_of_x_a2(this: &[u8]) -> u32 { + unsafe { + use std::arch::aarch64::*; + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut sum_0 = vdupq_n_u32(0); + let mut sum_1 = vdupq_n_u32(0); + while n >= 8 { + let x = vmovl_u8(vld1_u8(a.cast())); + a = a.add(8); + n -= 8; + sum_0 = vaddq_u32(vmovl_u16(vget_low_u16(x)), sum_0); + sum_1 = vaddq_u32(vmovl_u16(vget_high_u16(x)), sum_1); + } + let mut sum = vaddvq_u32(vaddq_u32(sum_0, sum_1)); + // this hint is used to disable loop unrolling + while std::hint::black_box(n) > 0 { + let x = a.read(); + a = a.add(1); + n -= 1; + sum += x as u32; + } + sum + } + } + + #[cfg(all(target_arch = "aarch64", test))] + #[test] + fn reduce_sum_of_x_a2_test() { + use rand::Rng; + if !crate::is_cpu_detected!("a2") { + println!("test {} ... 
skipped (a2)", module_path!()); + return; + } + let mut rng = rand::rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let this = (0..n).map(|_| rng.random_range(0..16)).collect::>(); + for z in 3984..4016 { + let this = &this[..z]; + let specialized = unsafe { reduce_sum_of_x_a2(this) }; + let fallback = fallback(this); assert_eq!(specialized, fallback); } } } - #[crate::multiversion(@"v4", @"v3", "v2", "v8.3a:sve", "v8.3a")] + #[crate::multiversion(@"v4.512", @"v3", @"v2", @"a2")] pub fn reduce_sum_of_x(this: &[u8]) -> u32 { let n = this.len(); let mut sum = 0; diff --git a/crates/simd_macros/Cargo.toml b/crates/simd_macros/Cargo.toml index 799db570..c608b3e4 100644 --- a/crates/simd_macros/Cargo.toml +++ b/crates/simd_macros/Cargo.toml @@ -7,9 +7,9 @@ edition.workspace = true proc-macro = true [dependencies] -proc-macro2 = { version = "1.0.79", features = ["proc-macro"] } -quote = "1.0.35" -syn = { version = "2.0.53", default-features = false, features = [ +proc-macro2 = { version = "1.0.93", features = ["proc-macro"] } +quote = "1.0.38" +syn = { version = "2.0.98", default-features = false, features = [ "clone-impls", "full", "parsing", diff --git a/crates/simd_macros/src/lib.rs b/crates/simd_macros/src/lib.rs index fe455646..68a5b8c6 100644 --- a/crates/simd_macros/src/lib.rs +++ b/crates/simd_macros/src/lib.rs @@ -118,11 +118,9 @@ pub fn multiversion( } }); } - let fallback_name = - syn::Ident::new(&format!("{name}_fallback"), proc_macro2::Span::mixed_site()); quote::quote! { #versions - fn #fallback_name < #generics_params > (#inputs) #output #generics_where { #block } + fn fallback < #generics_params > (#inputs) #output #generics_where { #block } #[inline(always)] #(#attrs)* #vis #sig { static CACHE: core::sync::atomic::AtomicPtr<()> = core::sync::atomic::AtomicPtr::new(core::ptr::null_mut()); @@ -132,7 +130,7 @@ pub fn multiversion( return unsafe { f(#(#arguments,)*) }; } #branches - let _multiversion_internal: unsafe fn(#inputs) #output = #fallback_name; + let _multiversion_internal: unsafe fn(#inputs) #output = fallback; CACHE.store(_multiversion_internal as *mut (), core::sync::atomic::Ordering::Relaxed); unsafe { _multiversion_internal(#(#arguments,)*) } } @@ -184,21 +182,22 @@ pub fn define_is_cpu_detected(input: proc_macro::TokenStream) -> proc_macro::Tok if target_cpu.target_arch != target_arch { continue; } - let target_features = target_cpu.target_features; let target_cpu = target_cpu.target_cpu; + let ident = syn::Ident::new( + &format!("is_{}_detected", target_cpu.replace('.', "_")), + proc_macro2::Span::mixed_site(), + ); arms.extend(quote::quote! { - (#target_cpu) => { - true #(&& $crate::is_feature_detected!(#target_features))* - }; + (#target_cpu) => { $crate::internal::#ident() }; }); } - let name = syn::Ident::new( + let ident = syn::Ident::new( &format!("is_{target_arch}_cpu_detected"), proc_macro2::Span::mixed_site(), ); quote::quote! { #[macro_export] - macro_rules! #name { + macro_rules! 
#ident { #arms } } diff --git a/crates/simd_macros/src/target.rs b/crates/simd_macros/src/target.rs index c2584013..e2be8534 100644 --- a/crates/simd_macros/src/target.rs +++ b/crates/simd_macros/src/target.rs @@ -6,78 +6,48 @@ pub struct TargetCpu { pub const TARGET_CPUS: &[TargetCpu] = &[ TargetCpu { - target_cpu: "v4", + target_cpu: "v4.512", target_arch: "x86_64", target_features: &[ - "avx", - "avx2", - "avx512bw", - "avx512cd", - "avx512dq", - "avx512f", - "avx512vl", - "bmi1", - "bmi2", - "cmpxchg16b", - "f16c", - "fma", - "fxsr", - "lzcnt", - "movbe", - "popcnt", - "sse", - "sse2", - "sse3", - "sse4.1", - "sse4.2", - "ssse3", - "xsave", + "avx512bw", "avx512cd", "avx512dq", "avx512vl", // simd + "bmi1", "bmi2", "lzcnt", "movbe", "popcnt", // bit-operations ], }, TargetCpu { target_cpu: "v3", target_arch: "x86_64", target_features: &[ - "avx", - "avx2", - "bmi1", - "bmi2", - "cmpxchg16b", - "f16c", - "fma", - "fxsr", - "lzcnt", - "movbe", - "popcnt", - "sse", - "sse2", - "sse3", - "sse4.1", - "sse4.2", - "ssse3", - "xsave", + "avx2", "f16c", "fma", // simd + "bmi1", "bmi2", "lzcnt", "movbe", "popcnt", // bit-operations ], }, TargetCpu { target_cpu: "v2", target_arch: "x86_64", target_features: &[ - "cmpxchg16b", - "fxsr", - "popcnt", - "sse", - "sse2", - "sse3", - "sse4.1", - "sse4.2", - "ssse3", + "sse4.2", // simd + "popcnt", // bit-operations ], }, TargetCpu { - target_cpu: "v8.3a", + target_cpu: "a3.512", target_arch: "aarch64", target_features: &[ - "crc", "dpb", "fcma", "jsconv", "lse", "neon", "paca", "pacg", "rcpc", "rdm", + "sve", // simd + ], + }, + TargetCpu { + target_cpu: "a3.256", + target_arch: "aarch64", + target_features: &[ + "sve", // simd + ], + }, + TargetCpu { + target_cpu: "a2", + target_arch: "aarch64", + target_features: &[ + "neon", // simd ], }, ]; diff --git a/crates/vector/Cargo.toml b/crates/vector/Cargo.toml index b910d361..692bf642 100644 --- a/crates/vector/Cargo.toml +++ b/crates/vector/Cargo.toml @@ -5,9 +5,9 @@ edition.workspace = true [dependencies] distance = { path = "../distance" } -half.workspace = true -serde.workspace = true simd = { path = "../simd" } +half.workspace = true + [lints] workspace = true diff --git a/crates/vector/src/bvect.rs b/crates/vector/src/bvect.rs index fe80164c..7877addb 100644 --- a/crates/vector/src/bvect.rs +++ b/crates/vector/src/bvect.rs @@ -1,12 +1,11 @@ use crate::{VectorBorrowed, VectorOwned}; use distance::Distance; -use serde::{Deserialize, Serialize}; use std::ops::{Bound, RangeBounds}; pub const BVECTOR_WIDTH: u32 = u64::BITS; // When using binary vector, please ensure that the padding bits are always zero. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone)] pub struct BVectOwned { dims: u32, data: Vec, @@ -29,7 +28,10 @@ impl BVectOwned { if dims % BVECTOR_WIDTH != 0 && data[data.len() - 1] >> (dims % BVECTOR_WIDTH) != 0 { return None; } - unsafe { Some(Self::new_unchecked(dims, data)) } + #[allow(unsafe_code)] + unsafe { + Some(Self::new_unchecked(dims, data)) + } } /// # Safety @@ -37,6 +39,7 @@ impl BVectOwned { /// * `dims` must be in `1..=65535`. /// * `data` must be of the correct length. /// * The padding bits must be zero. 
+ #[allow(unsafe_code)] #[inline(always)] pub unsafe fn new_unchecked(dims: u32, data: Vec<u64>) -> Self { Self { dims, data } @@ -83,7 +86,10 @@ impl<'a> BVectBorrowed<'a> { if dims % BVECTOR_WIDTH != 0 && data[data.len() - 1] >> (dims % BVECTOR_WIDTH) != 0 { return None; } - unsafe { Some(Self::new_unchecked(dims, data)) } + #[allow(unsafe_code)] + unsafe { + Some(Self::new_unchecked(dims, data)) + } } /// # Safety @@ -91,6 +97,7 @@ impl<'a> BVectBorrowed<'a> { /// * `dims` must be in `1..=65535`. /// * `data` must be of the correct length. /// * The padding bits must be zero. + #[allow(unsafe_code)] #[inline(always)] pub unsafe fn new_unchecked(dims: u32, data: &'a [u64]) -> Self { Self { dims, data } diff --git a/crates/vector/src/lib.rs b/crates/vector/src/lib.rs index e82e4a6c..64128c7a 100644 --- a/crates/vector/src/lib.rs +++ b/crates/vector/src/lib.rs @@ -3,7 +3,7 @@ pub mod scalar8; pub mod svect; pub mod vect; -pub trait VectorOwned: Clone + serde::Serialize + for<'a> serde::Deserialize<'a> + 'static { +pub trait VectorOwned: Clone + 'static { type Borrowed<'a>: VectorBorrowed; fn as_borrowed(&self) -> Self::Borrowed<'_>; diff --git a/crates/vector/src/scalar8.rs b/crates/vector/src/scalar8.rs index ff9095aa..3792e270 100644 --- a/crates/vector/src/scalar8.rs +++ b/crates/vector/src/scalar8.rs @@ -1,9 +1,8 @@ use crate::{VectorBorrowed, VectorOwned}; use distance::Distance; -use serde::{Deserialize, Serialize}; use std::ops::RangeBounds; -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone)] pub struct Scalar8Owned { sum_of_x2: f32, k: f32, @@ -29,12 +28,14 @@ impl Scalar8Owned { if !(1..=65535).contains(&code.len()) { return None; } + #[allow(unsafe_code)] Some(unsafe { Self::new_unchecked(sum_of_x2, k, b, sum_of_code, code) }) } /// # Safety /// /// * `code.len()` must not be zero. + #[allow(unsafe_code)] #[inline(always)] pub unsafe fn new_unchecked( sum_of_x2: f32, @@ -105,12 +106,14 @@ impl<'a> Scalar8Borrowed<'a> { if !(1..=65535).contains(&code.len()) { return None; } + #[allow(unsafe_code)] Some(unsafe { Self::new_unchecked(sum_of_x2, k, b, sum_of_code, code) }) } /// # Safety /// /// * `code.len()` must not be zero. + #[allow(unsafe_code)] #[inline(always)] pub unsafe fn new_unchecked( sum_of_x2: f32, diff --git a/crates/vector/src/svect.rs b/crates/vector/src/svect.rs index 08f678d3..26cbf197 100644 --- a/crates/vector/src/svect.rs +++ b/crates/vector/src/svect.rs @@ -1,10 +1,9 @@ use crate::{VectorBorrowed, VectorOwned}; use distance::Distance; -use serde::{Deserialize, Serialize}; use simd::Floating; use std::ops::{Bound, RangeBounds}; -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone)] pub struct SVectOwned<S> { dims: u32, indexes: Vec<u32>, values: Vec<S>, } @@ -37,7 +36,10 @@ impl<S: Floating> SVectOwned<S> { if S::reduce_or_of_is_zero_x(&values) { return None; } - unsafe { Some(Self::new_unchecked(dims, indexes, values)) } + #[allow(unsafe_code)] + unsafe { + Some(Self::new_unchecked(dims, indexes, values)) + } } /// # Safety @@ -46,6 +48,7 @@ impl<S: Floating> SVectOwned<S> { /// * `dims` must be in `1..=65535`. /// * `indexes.len()` must be equal to `values.len()`. /// * `indexes` must be a strictly increasing sequence and the last in the sequence must be less than `dims`. /// * A floating number in `values` must not be positive zero or negative zero. 
+ #[allow(unsafe_code)] #[inline(always)] pub unsafe fn new_unchecked(dims: u32, indexes: Vec<u32>, values: Vec<S>) -> Self { Self { @@ -119,7 +122,10 @@ impl<'a, S: Floating> SVectBorrowed<'a, S> { return None; } } - unsafe { Some(Self::new_unchecked(dims, indexes, values)) } + #[allow(unsafe_code)] + unsafe { + Some(Self::new_unchecked(dims, indexes, values)) + } } /// # Safety @@ -129,6 +135,7 @@ impl<'a, S: Floating> SVectBorrowed<'a, S> { /// * `indexes` must be a strictly increasing sequence and the last in the sequence must be less than `dims`. /// * A floating number in `values` must not be positive zero or negative zero. #[inline(always)] + #[allow(unsafe_code)] pub unsafe fn new_unchecked(dims: u32, indexes: &'a [u32], values: &'a [S]) -> Self { Self { dims, diff --git a/crates/vector/src/vect.rs b/crates/vector/src/vect.rs index 34b186f7..527fc6ad 100644 --- a/crates/vector/src/vect.rs +++ b/crates/vector/src/vect.rs @@ -1,11 +1,10 @@ use super::{VectorBorrowed, VectorOwned}; use distance::Distance; -use serde::{Deserialize, Serialize}; use simd::Floating; use std::cmp::Ordering; use std::ops::RangeBounds; -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone)] #[repr(transparent)] pub struct VectOwned<S>(Vec<S>); @@ -20,12 +19,14 @@ impl<S: Floating> VectOwned<S> { if !(1..=65535).contains(&slice.len()) { return None; } + #[allow(unsafe_code)] Some(unsafe { Self::new_unchecked(slice) }) } /// # Safety /// /// * `slice.len()` must not be zero. + #[allow(unsafe_code)] #[inline(always)] pub unsafe fn new_unchecked(slice: Vec<S>) -> Self { Self(slice) @@ -76,12 +77,14 @@ impl<'a, S: Floating> VectBorrowed<'a, S> { if !(1..=65535).contains(&slice.len()) { return None; } + #[allow(unsafe_code)] Some(unsafe { Self::new_unchecked(slice) }) } /// # Safety /// /// * `slice.len()` must not be zero. 
+ #[allow(unsafe_code)] #[inline(always)] pub unsafe fn new_unchecked(slice: &'a [S]) -> Self { Self(slice) diff --git a/docker/pgrx.Dockerfile b/docker/pgrx.Dockerfile deleted file mode 100644 index 93ec4aff..00000000 --- a/docker/pgrx.Dockerfile +++ /dev/null @@ -1,64 +0,0 @@ -# CNPG only support Debian 12 (Bookworm) -FROM ubuntu:22.04 - -ARG PGRX_VERSION -ARG RUST_TOOLCHAIN -ARG TARGETARCH - -ENV DEBIAN_FRONTEND=noninteractive \ - LANG=en_US.UTF-8 \ - LC_ALL=en_US.UTF-8 \ - RUSTFLAGS="-Dwarnings" \ - RUST_BACKTRACE=1 \ - CARGO_TERM_COLOR=always \ - SCCACHE_VERSION=0.9.0 - -RUN set -eux; \ - apt update; \ - apt install -y --no-install-recommends \ - curl \ - ca-certificates \ - build-essential \ - postgresql-common gnupg \ - libreadline-dev zlib1g-dev flex bison libxml2-dev libxslt-dev libssl-dev libxml2-utils xsltproc ccache pkg-config \ - zip - -RUN set -eux; \ - apt -y install lsb-release wget software-properties-common gnupg; \ - curl --proto '=https' --tlsv1.2 -sSf https://apt.llvm.org/llvm.sh | bash -s -- 18; \ - update-alternatives --install /usr/bin/clang clang $(which clang-18) 255 - -# set up sccache -RUN set -ex; \ - curl -fsSL -o sccache.tar.gz https://github.com/mozilla/sccache/releases/download/v${SCCACHE_VERSION}/sccache-v${SCCACHE_VERSION}-$(uname -m)-unknown-linux-musl.tar.gz; \ - tar -xzf sccache.tar.gz --strip-components=1; \ - rm sccache.tar.gz; \ - mv sccache /usr/local/bin/ - -RUN /usr/share/postgresql-common/pgdg/apt.postgresql.org.sh -y -# install all the PostgresQL -RUN set -ex; \ - for v in $(seq 14 17); do \ - apt install -y --no-install-recommends postgresql-$v postgresql-server-dev-$v postgresql-$v-pgvector; \ - done; \ - rm -rf /var/lib/apt/lists/*; - -# create a non-root user (make it compatible with Ubuntu 24.04) -RUN useradd -u 1000 -U -m ubuntu -RUN chown -R ubuntu:ubuntu /usr/share/postgresql/ /usr/lib/postgresql/ -USER ubuntu -ENV PATH="$PATH:/home/ubuntu/.cargo/bin" -RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y - -WORKDIR /workspace -RUN rustup toolchain install ${RUST_TOOLCHAIN} -RUN rustup target add $(uname -m)-unknown-linux-gnu - -RUN cargo install cargo-pgrx --locked --version=${PGRX_VERSION} - -RUN set -ex; \ - for v in $(seq 14 17); do \ - cargo pgrx init --pg$v=/usr/lib/postgresql/$v/bin/pg_config; \ - done; - -CMD [ "/bin/bash" ] diff --git a/rust-toolchain.toml b/rust-toolchain.toml index eb254600..d3e25b13 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,2 +1,2 @@ [toolchain] -channel = "nightly-2024-12-25" +channel = "nightly-2025-02-14" diff --git a/rustfmt.toml b/rustfmt.toml index 35011368..c32b643a 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1 +1,2 @@ style_edition = "2024" +imports_granularity = "Module" diff --git a/sql/install/vchord--0.2.1.sql b/sql/install/vchord--0.2.1.sql new file mode 100644 index 00000000..e5cf547b --- /dev/null +++ b/sql/install/vchord--0.2.1.sql @@ -0,0 +1,500 @@ +/* */ +/* +This file is auto generated by pgrx. + +The ordering of items is not stable, it is driven by a dependency graph. 
+*/ +/* */ + +/* */ +-- src/lib.rs:10 +-- bootstrap +-- List of shell types + +CREATE TYPE scalar8; +CREATE TYPE sphere_vector; +CREATE TYPE sphere_halfvec; +CREATE TYPE sphere_scalar8; +/* */ + +/* */ +-- src/datatype/functions_scalar8.rs:18 +-- vchord::datatype::functions_scalar8::_vchord_halfvec_quantize_to_scalar8 +/* */ + +/* */ +-- src/datatype/operators_halfvec.rs:54 +-- vchord::datatype::operators_halfvec::_vchord_halfvec_sphere_cosine_in +CREATE FUNCTION "_vchord_halfvec_sphere_cosine_in"( + "lhs" halfvec, /* vchord::datatype::memory_halfvec::HalfvecInput */ + "rhs" sphere_halfvec /* pgrx::heap_tuple::PgHeapTuple */ +) RETURNS bool /* bool */ +IMMUTABLE STRICT PARALLEL SAFE +LANGUAGE c /* Rust */ +AS 'MODULE_PATHNAME', '_vchord_halfvec_sphere_cosine_in_wrapper'; +/* */ + +/* */ +-- src/datatype/operators_halfvec.rs:30 +-- vchord::datatype::operators_halfvec::_vchord_halfvec_sphere_ip_in +CREATE FUNCTION "_vchord_halfvec_sphere_ip_in"( + "lhs" halfvec, /* vchord::datatype::memory_halfvec::HalfvecInput */ + "rhs" sphere_halfvec /* pgrx::heap_tuple::PgHeapTuple */ +) RETURNS bool /* bool */ +IMMUTABLE STRICT PARALLEL SAFE +LANGUAGE c /* Rust */ +AS 'MODULE_PATHNAME', '_vchord_halfvec_sphere_ip_in_wrapper'; +/* */ + +/* */ +-- src/datatype/operators_halfvec.rs:6 +-- vchord::datatype::operators_halfvec::_vchord_halfvec_sphere_l2_in +CREATE FUNCTION "_vchord_halfvec_sphere_l2_in"( + "lhs" halfvec, /* vchord::datatype::memory_halfvec::HalfvecInput */ + "rhs" sphere_halfvec /* pgrx::heap_tuple::PgHeapTuple */ +) RETURNS bool /* bool */ +IMMUTABLE STRICT PARALLEL SAFE +LANGUAGE c /* Rust */ +AS 'MODULE_PATHNAME', '_vchord_halfvec_sphere_l2_in_wrapper'; +/* */ + +/* */ +-- src/datatype/text_scalar8.rs:7 +-- vchord::datatype::text_scalar8::_vchord_scalar8_in +CREATE FUNCTION "_vchord_scalar8_in"( + "input" cstring, /* &core::ffi::c_str::CStr */ + "oid" oid, /* pgrx_pg_sys::submodules::oids::Oid */ + "typmod" INT /* i32 */ +) RETURNS scalar8 /* vchord::datatype::memory_scalar8::Scalar8Output */ +IMMUTABLE STRICT PARALLEL SAFE +LANGUAGE c /* Rust */ +AS 'MODULE_PATHNAME', '_vchord_scalar8_in_wrapper'; +/* */ + +/* */ +-- src/datatype/operators_scalar8.rs:26 +-- vchord::datatype::operators_scalar8::_vchord_scalar8_operator_cosine +CREATE FUNCTION "_vchord_scalar8_operator_cosine"( + "lhs" scalar8, /* vchord::datatype::memory_scalar8::Scalar8Input */ + "rhs" scalar8 /* vchord::datatype::memory_scalar8::Scalar8Input */ +) RETURNS real /* f32 */ +IMMUTABLE STRICT PARALLEL SAFE +LANGUAGE c /* Rust */ +AS 'MODULE_PATHNAME', '_vchord_scalar8_operator_cosine_wrapper'; +/* */ + +/* */ +-- src/datatype/operators_scalar8.rs:6 +-- vchord::datatype::operators_scalar8::_vchord_scalar8_operator_ip +CREATE FUNCTION "_vchord_scalar8_operator_ip"( + "lhs" scalar8, /* vchord::datatype::memory_scalar8::Scalar8Input */ + "rhs" scalar8 /* vchord::datatype::memory_scalar8::Scalar8Input */ +) RETURNS real /* f32 */ +IMMUTABLE STRICT PARALLEL SAFE +LANGUAGE c /* Rust */ +AS 'MODULE_PATHNAME', '_vchord_scalar8_operator_ip_wrapper'; +/* */ + +/* */ +-- src/datatype/operators_scalar8.rs:16 +-- vchord::datatype::operators_scalar8::_vchord_scalar8_operator_l2 +CREATE FUNCTION "_vchord_scalar8_operator_l2"( + "lhs" scalar8, /* vchord::datatype::memory_scalar8::Scalar8Input */ + "rhs" scalar8 /* vchord::datatype::memory_scalar8::Scalar8Input */ +) RETURNS real /* f32 */ +IMMUTABLE STRICT PARALLEL SAFE +LANGUAGE c /* Rust */ +AS 'MODULE_PATHNAME', '_vchord_scalar8_operator_l2_wrapper'; +/* */ + +/* */ +-- 
src/datatype/text_scalar8.rs:119 +-- vchord::datatype::text_scalar8::_vchord_scalar8_out +CREATE FUNCTION "_vchord_scalar8_out"( + "vector" scalar8 /* vchord::datatype::memory_scalar8::Scalar8Input */ +) RETURNS cstring /* alloc::ffi::c_str::CString */ +IMMUTABLE STRICT PARALLEL SAFE +LANGUAGE c /* Rust */ +AS 'MODULE_PATHNAME', '_vchord_scalar8_out_wrapper'; +/* */ + +/* */ +-- src/datatype/binary_scalar8.rs:22 +-- vchord::datatype::binary_scalar8::_vchord_scalar8_recv +CREATE FUNCTION "_vchord_scalar8_recv"( + "internal" internal, /* pgrx::datum::internal::Internal */ + "oid" oid, /* pgrx_pg_sys::submodules::oids::Oid */ + "typmod" INT /* i32 */ +) RETURNS scalar8 /* vchord::datatype::memory_scalar8::Scalar8Output */ +IMMUTABLE STRICT PARALLEL SAFE +LANGUAGE c /* Rust */ +AS 'MODULE_PATHNAME', '_vchord_scalar8_recv_wrapper'; +/* */ + +/* */ +-- src/datatype/binary_scalar8.rs:7 +-- vchord::datatype::binary_scalar8::_vchord_scalar8_send +CREATE FUNCTION "_vchord_scalar8_send"( + "vector" scalar8 /* vchord::datatype::memory_scalar8::Scalar8Input */ +) RETURNS bytea /* alloc::vec::Vec */ +IMMUTABLE STRICT PARALLEL SAFE +LANGUAGE c /* Rust */ +AS 'MODULE_PATHNAME', '_vchord_scalar8_send_wrapper'; +/* */ + +/* */ +-- src/datatype/operators_scalar8.rs:84 +-- vchord::datatype::operators_scalar8::_vchord_scalar8_sphere_cosine_in +CREATE FUNCTION "_vchord_scalar8_sphere_cosine_in"( + "lhs" scalar8, /* vchord::datatype::memory_scalar8::Scalar8Input */ + "rhs" sphere_scalar8 /* pgrx::heap_tuple::PgHeapTuple */ +) RETURNS bool /* bool */ +IMMUTABLE STRICT PARALLEL SAFE +LANGUAGE c /* Rust */ +AS 'MODULE_PATHNAME', '_vchord_scalar8_sphere_cosine_in_wrapper'; +/* */ + +/* */ +-- src/datatype/operators_scalar8.rs:36 +-- vchord::datatype::operators_scalar8::_vchord_scalar8_sphere_ip_in +CREATE FUNCTION "_vchord_scalar8_sphere_ip_in"( + "lhs" scalar8, /* vchord::datatype::memory_scalar8::Scalar8Input */ + "rhs" sphere_scalar8 /* pgrx::heap_tuple::PgHeapTuple */ +) RETURNS bool /* bool */ +IMMUTABLE STRICT PARALLEL SAFE +LANGUAGE c /* Rust */ +AS 'MODULE_PATHNAME', '_vchord_scalar8_sphere_ip_in_wrapper'; +/* */ + +/* */ +-- src/datatype/operators_scalar8.rs:60 +-- vchord::datatype::operators_scalar8::_vchord_scalar8_sphere_l2_in +CREATE FUNCTION "_vchord_scalar8_sphere_l2_in"( + "lhs" scalar8, /* vchord::datatype::memory_scalar8::Scalar8Input */ + "rhs" sphere_scalar8 /* pgrx::heap_tuple::PgHeapTuple */ +) RETURNS bool /* bool */ +IMMUTABLE STRICT PARALLEL SAFE +LANGUAGE c /* Rust */ +AS 'MODULE_PATHNAME', '_vchord_scalar8_sphere_l2_in_wrapper'; +/* */ + +/* */ +-- src/datatype/typmod.rs:45 +-- vchord::datatype::typmod::_vchord_typmod_in_65535 +CREATE FUNCTION "_vchord_typmod_in_65535"( + "list" cstring[] /* pgrx::datum::array::Array<&core::ffi::c_str::CStr> */ +) RETURNS INT /* i32 */ +IMMUTABLE STRICT PARALLEL SAFE +LANGUAGE c /* Rust */ +AS 'MODULE_PATHNAME', '_vchord_typmod_in_65535_wrapper'; +/* */ + +/* */ +-- src/datatype/typmod.rs:63 +-- vchord::datatype::typmod::_vchord_typmod_out +CREATE FUNCTION "_vchord_typmod_out"( + "typmod" INT /* i32 */ +) RETURNS cstring /* alloc::ffi::c_str::CString */ +IMMUTABLE STRICT PARALLEL SAFE +LANGUAGE c /* Rust */ +AS 'MODULE_PATHNAME', '_vchord_typmod_out_wrapper'; +/* */ + +/* */ +-- src/datatype/functions_scalar8.rs:8 +-- vchord::datatype::functions_scalar8::_vchord_vector_quantize_to_scalar8 +/* */ + +/* */ +-- src/datatype/operators_vector.rs:54 +-- vchord::datatype::operators_vector::_vchord_vector_sphere_cosine_in +CREATE FUNCTION 
"_vchord_vector_sphere_cosine_in"( + "lhs" vector, /* vchord::datatype::memory_vector::VectorInput */ + "rhs" sphere_vector /* pgrx::heap_tuple::PgHeapTuple */ +) RETURNS bool /* bool */ +IMMUTABLE STRICT PARALLEL SAFE +LANGUAGE c /* Rust */ +AS 'MODULE_PATHNAME', '_vchord_vector_sphere_cosine_in_wrapper'; +/* */ + +/* */ +-- src/datatype/operators_vector.rs:30 +-- vchord::datatype::operators_vector::_vchord_vector_sphere_ip_in +CREATE FUNCTION "_vchord_vector_sphere_ip_in"( + "lhs" vector, /* vchord::datatype::memory_vector::VectorInput */ + "rhs" sphere_vector /* pgrx::heap_tuple::PgHeapTuple */ +) RETURNS bool /* bool */ +IMMUTABLE STRICT PARALLEL SAFE +LANGUAGE c /* Rust */ +AS 'MODULE_PATHNAME', '_vchord_vector_sphere_ip_in_wrapper'; +/* */ + +/* */ +-- src/datatype/operators_vector.rs:6 +-- vchord::datatype::operators_vector::_vchord_vector_sphere_l2_in +CREATE FUNCTION "_vchord_vector_sphere_l2_in"( + "lhs" vector, /* vchord::datatype::memory_vector::VectorInput */ + "rhs" sphere_vector /* pgrx::heap_tuple::PgHeapTuple */ +) RETURNS bool /* bool */ +IMMUTABLE STRICT PARALLEL SAFE +LANGUAGE c /* Rust */ +AS 'MODULE_PATHNAME', '_vchord_vector_sphere_l2_in_wrapper'; +/* */ + +/* */ +-- src/index/am/mod.rs:60 +-- vchord::index::am::_vchordrq_amhandler +/* */ + +/* */ +-- src/index/functions.rs:9 +-- vchord::index::functions::_vchordrq_prewarm +/* */ + +/* */ +-- src/index/opclass.rs:36 +-- vchord::index::opclass::_vchordrq_support_halfvec_cosine_ops +CREATE FUNCTION "_vchordrq_support_halfvec_cosine_ops"() RETURNS TEXT /* alloc::string::String */ +IMMUTABLE STRICT PARALLEL SAFE +LANGUAGE c /* Rust */ +AS 'MODULE_PATHNAME', '_vchordrq_support_halfvec_cosine_ops_wrapper'; +/* */ + +/* */ +-- src/index/opclass.rs:31 +-- vchord::index::opclass::_vchordrq_support_halfvec_ip_ops +CREATE FUNCTION "_vchordrq_support_halfvec_ip_ops"() RETURNS TEXT /* alloc::string::String */ +IMMUTABLE STRICT PARALLEL SAFE +LANGUAGE c /* Rust */ +AS 'MODULE_PATHNAME', '_vchordrq_support_halfvec_ip_ops_wrapper'; +/* */ + +/* */ +-- src/index/opclass.rs:26 +-- vchord::index::opclass::_vchordrq_support_halfvec_l2_ops +CREATE FUNCTION "_vchordrq_support_halfvec_l2_ops"() RETURNS TEXT /* alloc::string::String */ +IMMUTABLE STRICT PARALLEL SAFE +LANGUAGE c /* Rust */ +AS 'MODULE_PATHNAME', '_vchordrq_support_halfvec_l2_ops_wrapper'; +/* */ + +/* */ +-- src/index/opclass.rs:21 +-- vchord::index::opclass::_vchordrq_support_vector_cosine_ops +CREATE FUNCTION "_vchordrq_support_vector_cosine_ops"() RETURNS TEXT /* alloc::string::String */ +IMMUTABLE STRICT PARALLEL SAFE +LANGUAGE c /* Rust */ +AS 'MODULE_PATHNAME', '_vchordrq_support_vector_cosine_ops_wrapper'; +/* */ + +/* */ +-- src/index/opclass.rs:16 +-- vchord::index::opclass::_vchordrq_support_vector_ip_ops +CREATE FUNCTION "_vchordrq_support_vector_ip_ops"() RETURNS TEXT /* alloc::string::String */ +IMMUTABLE STRICT PARALLEL SAFE +LANGUAGE c /* Rust */ +AS 'MODULE_PATHNAME', '_vchordrq_support_vector_ip_ops_wrapper'; +/* */ + +/* */ +-- src/index/opclass.rs:11 +-- vchord::index::opclass::_vchordrq_support_vector_l2_ops +CREATE FUNCTION "_vchordrq_support_vector_l2_ops"() RETURNS TEXT /* alloc::string::String */ +IMMUTABLE STRICT PARALLEL SAFE +LANGUAGE c /* Rust */ +AS 'MODULE_PATHNAME', '_vchordrq_support_vector_l2_ops_wrapper'; +/* */ + +/* */ +-- src/lib.rs:11 +-- finalize +-- List of data types + +CREATE TYPE scalar8 ( + INPUT = _vchord_scalar8_in, + OUTPUT = _vchord_scalar8_out, + RECEIVE = _vchord_scalar8_recv, + SEND = _vchord_scalar8_send, + TYPMOD_IN = 
_vchord_typmod_in_65535, + TYPMOD_OUT = _vchord_typmod_out, + STORAGE = EXTERNAL, + INTERNALLENGTH = VARIABLE, + ALIGNMENT = double +); + +CREATE TYPE sphere_vector AS ( + center vector, + radius REAL +); + +CREATE TYPE sphere_halfvec AS ( + center halfvec, + radius REAL +); + +CREATE TYPE sphere_scalar8 AS ( + center scalar8, + radius REAL +); + +-- List of operators + +CREATE OPERATOR <-> ( + PROCEDURE = _vchord_scalar8_operator_l2, + LEFTARG = scalar8, + RIGHTARG = scalar8, + COMMUTATOR = <-> +); + +CREATE OPERATOR <#> ( + PROCEDURE = _vchord_scalar8_operator_ip, + LEFTARG = scalar8, + RIGHTARG = scalar8, + COMMUTATOR = <#> +); + +CREATE OPERATOR <=> ( + PROCEDURE = _vchord_scalar8_operator_cosine, + LEFTARG = scalar8, + RIGHTARG = scalar8, + COMMUTATOR = <=> +); + +CREATE OPERATOR <<->> ( + PROCEDURE = _vchord_vector_sphere_l2_in, + LEFTARG = vector, + RIGHTARG = sphere_vector, + COMMUTATOR = <<->> +); + +CREATE OPERATOR <<->> ( + PROCEDURE = _vchord_halfvec_sphere_l2_in, + LEFTARG = halfvec, + RIGHTARG = sphere_halfvec, + COMMUTATOR = <<->> +); + +CREATE OPERATOR <<->> ( + PROCEDURE = _vchord_scalar8_sphere_l2_in, + LEFTARG = scalar8, + RIGHTARG = sphere_scalar8, + COMMUTATOR = <<->> +); + +CREATE OPERATOR <<#>> ( + PROCEDURE = _vchord_vector_sphere_ip_in, + LEFTARG = vector, + RIGHTARG = sphere_vector, + COMMUTATOR = <<#>> +); + +CREATE OPERATOR <<#>> ( + PROCEDURE = _vchord_halfvec_sphere_ip_in, + LEFTARG = halfvec, + RIGHTARG = sphere_halfvec, + COMMUTATOR = <<#>> +); + +CREATE OPERATOR <<#>> ( + PROCEDURE = _vchord_scalar8_sphere_ip_in, + LEFTARG = scalar8, + RIGHTARG = sphere_scalar8, + COMMUTATOR = <<#>> +); + +CREATE OPERATOR <<=>> ( + PROCEDURE = _vchord_vector_sphere_cosine_in, + LEFTARG = vector, + RIGHTARG = sphere_vector, + COMMUTATOR = <<=>> +); + +CREATE OPERATOR <<=>> ( + PROCEDURE = _vchord_halfvec_sphere_cosine_in, + LEFTARG = halfvec, + RIGHTARG = sphere_halfvec, + COMMUTATOR = <<=>> +); + +CREATE OPERATOR <<=>> ( + PROCEDURE = _vchord_scalar8_sphere_cosine_in, + LEFTARG = scalar8, + RIGHTARG = sphere_scalar8, + COMMUTATOR = <<=>> +); + +-- List of functions + +CREATE FUNCTION sphere(vector, real) RETURNS sphere_vector +IMMUTABLE PARALLEL SAFE LANGUAGE sql AS 'SELECT ROW($1, $2)'; + +CREATE FUNCTION sphere(halfvec, real) RETURNS sphere_halfvec +IMMUTABLE PARALLEL SAFE LANGUAGE sql AS 'SELECT ROW($1, $2)'; + +CREATE FUNCTION sphere(scalar8, real) RETURNS sphere_scalar8 +IMMUTABLE PARALLEL SAFE LANGUAGE sql AS 'SELECT ROW($1, $2)'; + +CREATE FUNCTION quantize_to_scalar8(vector) RETURNS scalar8 +IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '_vchord_vector_quantize_to_scalar8_wrapper'; + +CREATE FUNCTION quantize_to_scalar8(halfvec) RETURNS scalar8 +IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '_vchord_halfvec_quantize_to_scalar8_wrapper'; + +CREATE FUNCTION vchordrq_amhandler(internal) RETURNS index_am_handler +IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '_vchordrq_amhandler_wrapper'; + +CREATE FUNCTION vchordrq_prewarm(regclass, integer default 0) RETURNS TEXT +STRICT LANGUAGE c AS 'MODULE_PATHNAME', '_vchordrq_prewarm_wrapper'; + +-- List of access methods + +CREATE ACCESS METHOD vchordrq TYPE INDEX HANDLER vchordrq_amhandler; + +-- List of operator families + +CREATE OPERATOR FAMILY vector_l2_ops USING vchordrq; +CREATE OPERATOR FAMILY vector_ip_ops USING vchordrq; +CREATE OPERATOR FAMILY vector_cosine_ops USING vchordrq; +CREATE OPERATOR FAMILY halfvec_l2_ops USING vchordrq; +CREATE OPERATOR FAMILY 
halfvec_ip_ops USING vchordrq; +CREATE OPERATOR FAMILY halfvec_cosine_ops USING vchordrq; + +-- List of operator classes + +CREATE OPERATOR CLASS vector_l2_ops + FOR TYPE vector USING vchordrq FAMILY vector_l2_ops AS + OPERATOR 1 <-> (vector, vector) FOR ORDER BY float_ops, + OPERATOR 2 <<->> (vector, sphere_vector) FOR SEARCH, + FUNCTION 1 _vchordrq_support_vector_l2_ops(); + +CREATE OPERATOR CLASS vector_ip_ops + FOR TYPE vector USING vchordrq FAMILY vector_ip_ops AS + OPERATOR 1 <#> (vector, vector) FOR ORDER BY float_ops, + OPERATOR 2 <<#>> (vector, sphere_vector) FOR SEARCH, + FUNCTION 1 _vchordrq_support_vector_ip_ops(); + +CREATE OPERATOR CLASS vector_cosine_ops + FOR TYPE vector USING vchordrq FAMILY vector_cosine_ops AS + OPERATOR 1 <=> (vector, vector) FOR ORDER BY float_ops, + OPERATOR 2 <<=>> (vector, sphere_vector) FOR SEARCH, + FUNCTION 1 _vchordrq_support_vector_cosine_ops(); + +CREATE OPERATOR CLASS halfvec_l2_ops + FOR TYPE halfvec USING vchordrq FAMILY halfvec_l2_ops AS + OPERATOR 1 <-> (halfvec, halfvec) FOR ORDER BY float_ops, + OPERATOR 2 <<->> (halfvec, sphere_halfvec) FOR SEARCH, + FUNCTION 1 _vchordrq_support_halfvec_l2_ops(); + +CREATE OPERATOR CLASS halfvec_ip_ops + FOR TYPE halfvec USING vchordrq FAMILY halfvec_ip_ops AS + OPERATOR 1 <#> (halfvec, halfvec) FOR ORDER BY float_ops, + OPERATOR 2 <<#>> (halfvec, sphere_halfvec) FOR SEARCH, + FUNCTION 1 _vchordrq_support_halfvec_ip_ops(); + +CREATE OPERATOR CLASS halfvec_cosine_ops + FOR TYPE halfvec USING vchordrq FAMILY halfvec_cosine_ops AS + OPERATOR 1 <=> (halfvec, halfvec) FOR ORDER BY float_ops, + OPERATOR 2 <<=>> (halfvec, sphere_halfvec) FOR SEARCH, + FUNCTION 1 _vchordrq_support_halfvec_cosine_ops(); +/* */ + diff --git a/sql/upgrade/vchord--0.2.0--0.2.1.sql b/sql/upgrade/vchord--0.2.0--0.2.1.sql new file mode 100644 index 00000000..be321fe8 --- /dev/null +++ b/sql/upgrade/vchord--0.2.0--0.2.1.sql @@ -0,0 +1 @@ +/* nothing to do */ diff --git a/src/algorithm/build.rs b/src/algorithm/build.rs deleted file mode 100644 index 4c893b7f..00000000 --- a/src/algorithm/build.rs +++ /dev/null @@ -1,373 +0,0 @@ -use crate::algorithm::RelationWrite; -use crate::algorithm::operator::{Operator, Vector}; -use crate::algorithm::tape::*; -use crate::algorithm::tuples::*; -use crate::index::am_options::Opfamily; -use crate::types::VchordrqBuildOptions; -use crate::types::VchordrqExternalBuildOptions; -use crate::types::VchordrqIndexingOptions; -use crate::types::VchordrqInternalBuildOptions; -use crate::types::VectorOptions; -use rand::Rng; -use simd::Floating; -use std::num::NonZeroU64; -use std::sync::Arc; -use vector::VectorBorrowed; -use vector::VectorOwned; - -pub trait HeapRelation { - fn traverse(&self, progress: bool, callback: F) - where - F: FnMut((NonZeroU64, O::Vector)); - fn opfamily(&self) -> Opfamily; -} - -pub trait Reporter { - fn tuples_total(&mut self, tuples_total: u64); -} - -pub fn build, R: Reporter>( - vector_options: VectorOptions, - vchordrq_options: VchordrqIndexingOptions, - heap_relation: T, - relation: impl RelationWrite, - mut reporter: R, -) { - let dims = vector_options.dims; - let is_residual = vchordrq_options.residual_quantization && O::SUPPORTS_RESIDUAL; - let structures = match vchordrq_options.build { - VchordrqBuildOptions::External(external_build) => Structure::extern_build( - vector_options.clone(), - heap_relation.opfamily(), - external_build.clone(), - ), - VchordrqBuildOptions::Internal(internal_build) => { - let mut tuples_total = 0_u64; - let samples = { - let mut rand = 
rand::thread_rng(); - let max_number_of_samples = internal_build - .lists - .last() - .unwrap() - .saturating_mul(internal_build.sampling_factor); - let mut samples = Vec::new(); - let mut number_of_samples = 0_u32; - heap_relation.traverse(false, |(_, vector)| { - let vector = vector.as_borrowed(); - assert_eq!(dims, vector.dims(), "invalid vector dimensions"); - if number_of_samples < max_number_of_samples { - samples.push(O::Vector::build_to_vecf32(vector)); - number_of_samples += 1; - } else { - let index = rand.gen_range(0..max_number_of_samples) as usize; - samples[index] = O::Vector::build_to_vecf32(vector); - } - tuples_total += 1; - }); - samples - }; - reporter.tuples_total(tuples_total); - Structure::internal_build(vector_options.clone(), internal_build.clone(), samples) - } - }; - let mut meta = TapeWriter::<_, _, MetaTuple>::create(|| relation.extend(false)); - assert_eq!(meta.first(), 0); - let freepage = TapeWriter::<_, _, FreepageTuple>::create(|| relation.extend(false)); - let mut vectors = TapeWriter::<_, _, VectorTuple>::create(|| relation.extend(true)); - let mut pointer_of_means = Vec::>::new(); - for i in 0..structures.len() { - let mut level = Vec::new(); - for j in 0..structures[i].len() { - let vector = O::Vector::build_from_vecf32(&structures[i].means[j]); - let (metadata, slices) = O::Vector::vector_split(vector.as_borrowed()); - let mut chain = Ok(metadata); - for i in (0..slices.len()).rev() { - chain = Err(vectors.push(match chain { - Ok(metadata) => VectorTuple::_0 { - payload: None, - elements: slices[i].to_vec(), - metadata, - }, - Err(pointer) => VectorTuple::_1 { - payload: None, - elements: slices[i].to_vec(), - pointer, - }, - })); - } - level.push(chain.err().unwrap()); - } - pointer_of_means.push(level); - } - let mut pointer_of_firsts = Vec::>::new(); - for i in 0..structures.len() { - let mut level = Vec::new(); - for j in 0..structures[i].len() { - if i == 0 { - let tape = TapeWriter::<_, _, H0Tuple>::create(|| relation.extend(false)); - let mut jump = TapeWriter::<_, _, JumpTuple>::create(|| relation.extend(false)); - jump.push(JumpTuple { - first: tape.first(), - }); - level.push(jump.first()); - } else { - let mut tape = H1TapeWriter::<_, _>::create(|| relation.extend(false)); - let h2_mean = &structures[i].means[j]; - let h2_children = &structures[i].children[j]; - for child in h2_children.iter().copied() { - let h1_mean = &structures[i - 1].means[child as usize]; - let code = if is_residual { - rabitq::code(dims, &f32::vector_sub(h1_mean, h2_mean)) - } else { - rabitq::code(dims, h1_mean) - }; - tape.push(H1Branch { - mean: pointer_of_means[i - 1][child as usize], - dis_u_2: code.dis_u_2, - factor_ppc: code.factor_ppc, - factor_ip: code.factor_ip, - factor_err: code.factor_err, - signs: code.signs, - first: pointer_of_firsts[i - 1][child as usize], - }); - } - let tape = tape.into_inner(); - level.push(tape.first()); - } - } - pointer_of_firsts.push(level); - } - meta.push(MetaTuple { - dims, - height_of_root: structures.len() as u32, - is_residual, - vectors_first: vectors.first(), - root_mean: pointer_of_means.last().unwrap()[0], - root_first: pointer_of_firsts.last().unwrap()[0], - freepage_first: freepage.first(), - }); -} - -struct Structure { - means: Vec>, - children: Vec>, -} - -impl Structure { - fn len(&self) -> usize { - self.children.len() - } - fn internal_build( - vector_options: VectorOptions, - internal_build: VchordrqInternalBuildOptions, - mut samples: Vec>, - ) -> Vec { - use std::iter::once; - for sample in 
samples.iter_mut() { - *sample = crate::projection::project(sample); - } - let mut result = Vec::::new(); - for w in internal_build.lists.iter().rev().copied().chain(once(1)) { - let means = crate::utils::parallelism::RayonParallelism::scoped( - internal_build.build_threads as _, - Arc::new(|| { - pgrx::check_for_interrupts!(); - }), - |parallelism| { - crate::utils::k_means::k_means( - parallelism, - w as usize, - vector_options.dims as usize, - if let Some(structure) = result.last() { - &structure.means - } else { - &samples - }, - internal_build.spherical_centroids, - 10, - ) - }, - ) - .expect("failed to create thread pool"); - if let Some(structure) = result.last() { - let mut children = vec![Vec::new(); means.len()]; - for i in 0..structure.len() as u32 { - let target = - crate::utils::k_means::k_means_lookup(&structure.means[i as usize], &means); - children[target].push(i); - } - let (means, children) = std::iter::zip(means, children) - .filter(|(_, x)| !x.is_empty()) - .unzip::<_, _, Vec<_>, Vec<_>>(); - result.push(Structure { means, children }); - } else { - let children = vec![Vec::new(); means.len()]; - result.push(Structure { means, children }); - } - } - result - } - fn extern_build( - vector_options: VectorOptions, - _opfamily: Opfamily, - external_build: VchordrqExternalBuildOptions, - ) -> Vec { - use std::collections::BTreeMap; - let VchordrqExternalBuildOptions { table } = external_build; - let mut parents = BTreeMap::new(); - let mut vectors = BTreeMap::new(); - pgrx::spi::Spi::connect(|client| { - use crate::datatype::memory_vector::VectorOutput; - use pgrx::pg_sys::panic::ErrorReportable; - use vector::VectorBorrowed; - let schema_query = "SELECT n.nspname::TEXT - FROM pg_catalog.pg_extension e - LEFT JOIN pg_catalog.pg_namespace n ON n.oid = e.extnamespace - WHERE e.extname = 'vector';"; - let pgvector_schema: String = client - .select(schema_query, None, None) - .unwrap_or_report() - .first() - .get_by_name("nspname") - .expect("external build: cannot get schema of pgvector") - .expect("external build: cannot get schema of pgvector"); - let dump_query = - format!("SELECT id, parent, vector::{pgvector_schema}.vector FROM {table};"); - let centroids = client.select(&dump_query, None, None).unwrap_or_report(); - for row in centroids { - let id: Option = row.get_by_name("id").unwrap(); - let parent: Option = row.get_by_name("parent").unwrap(); - let vector: Option = row.get_by_name("vector").unwrap(); - let id = id.expect("external build: id could not be NULL"); - let vector = vector.expect("external build: vector could not be NULL"); - let pop = parents.insert(id, parent); - if pop.is_some() { - pgrx::error!( - "external build: there are at least two lines have same id, id = {id}" - ); - } - if vector_options.dims != vector.as_borrowed().dims() { - pgrx::error!("external build: incorrect dimension, id = {id}"); - } - vectors.insert(id, crate::projection::project(vector.as_borrowed().slice())); - } - }); - if parents.len() >= 2 && parents.values().all(|x| x.is_none()) { - // if there are more than one vertexs and no edges, - // assume there is an implicit root - let n = parents.len(); - let mut result = Vec::new(); - result.push(Structure { - means: vectors.values().cloned().collect::>(), - children: vec![Vec::new(); n], - }); - result.push(Structure { - means: vec![{ - // compute the vector on root, without normalizing it - let mut sum = vec![0.0f32; vector_options.dims as _]; - for vector in vectors.values() { - f32::vector_add_inplace(&mut sum, vector); - } - 
f32::vector_mul_scalar_inplace(&mut sum, 1.0 / n as f32); - sum - }], - children: vec![(0..n as u32).collect()], - }); - return result; - } - let mut children = parents - .keys() - .map(|x| (*x, Vec::new())) - .collect::>(); - let mut root = None; - for (&id, &parent) in parents.iter() { - if let Some(parent) = parent { - if let Some(parent) = children.get_mut(&parent) { - parent.push(id); - } else { - pgrx::error!( - "external build: parent does not exist, id = {id}, parent = {parent}" - ); - } - } else { - if let Some(root) = root { - pgrx::error!("external build: two root, id = {root}, id = {id}"); - } else { - root = Some(id); - } - } - } - let Some(root) = root else { - pgrx::error!("external build: there are no root"); - }; - let mut heights = BTreeMap::<_, _>::new(); - fn dfs_for_heights( - heights: &mut BTreeMap>, - children: &BTreeMap>, - u: i32, - ) { - if heights.contains_key(&u) { - pgrx::error!("external build: detect a cycle, id = {u}"); - } - heights.insert(u, None); - let mut height = None; - for &v in children[&u].iter() { - dfs_for_heights(heights, children, v); - let new = heights[&v].unwrap() + 1; - if let Some(height) = height { - if height != new { - pgrx::error!("external build: two heights, id = {u}"); - } - } else { - height = Some(new); - } - } - if height.is_none() { - height = Some(1); - } - heights.insert(u, height); - } - dfs_for_heights(&mut heights, &children, root); - let heights = heights - .into_iter() - .map(|(k, v)| (k, v.expect("not a connected graph"))) - .collect::>(); - if !(1..=8).contains(&(heights[&root] - 1)) { - pgrx::error!( - "external build: unexpected tree height, height = {}", - heights[&root] - ); - } - let mut cursors = vec![0_u32; 1 + heights[&root] as usize]; - let mut labels = BTreeMap::new(); - for id in parents.keys().copied() { - let height = heights[&id]; - let cursor = cursors[height as usize]; - labels.insert(id, (height, cursor)); - cursors[height as usize] += 1; - } - fn extract( - height: u32, - labels: &BTreeMap, - vectors: &BTreeMap>, - children: &BTreeMap>, - ) -> (Vec>, Vec>) { - labels - .iter() - .filter(|(_, (h, _))| *h == height) - .map(|(id, _)| { - ( - vectors[id].clone(), - children[id].iter().map(|id| labels[id].1).collect(), - ) - }) - .unzip() - } - let mut result = Vec::new(); - for height in 1..=heights[&root] { - let (means, children) = extract(height, &labels, &vectors, &children); - result.push(Structure { means, children }); - } - result - } -} diff --git a/src/algorithm/vacuum.rs b/src/algorithm/vacuum.rs deleted file mode 100644 index 17366256..00000000 --- a/src/algorithm/vacuum.rs +++ /dev/null @@ -1,311 +0,0 @@ -use crate::algorithm::freepages; -use crate::algorithm::operator::Operator; -use crate::algorithm::tape::*; -use crate::algorithm::tuples::*; -use crate::algorithm::{Page, RelationWrite}; -use crate::utils::pipe::Pipe; -use simd::fast_scan::unpack; -use std::num::NonZeroU64; - -pub fn bulkdelete( - relation: impl RelationWrite, - delay: impl Fn(), - callback: impl Fn(NonZeroU64) -> bool, -) { - let meta_guard = relation.read(0); - let meta_tuple = meta_guard.get(1).unwrap().pipe(read_tuple::); - let height_of_root = meta_tuple.height_of_root(); - let root_first = meta_tuple.root_first(); - let vectors_first = meta_tuple.vectors_first(); - drop(meta_guard); - { - type State = Vec; - let mut state: State = vec![root_first]; - let step = |state: State| { - let mut results = Vec::new(); - for first in state { - let mut current = first; - while current != u32::MAX { - let h1_guard = 
relation.read(current); - for i in 1..=h1_guard.len() { - let h1_tuple = h1_guard - .get(i) - .expect("data corruption") - .pipe(read_tuple::); - match h1_tuple { - H1TupleReader::_0(h1_tuple) => { - for first in h1_tuple.first().iter().copied() { - results.push(first); - } - } - H1TupleReader::_1(_) => (), - } - } - current = h1_guard.get_opaque().next; - } - } - results - }; - for _ in (1..height_of_root).rev() { - state = step(state); - } - for first in state { - let jump_guard = relation.read(first); - let jump_tuple = jump_guard - .get(1) - .expect("data corruption") - .pipe(read_tuple::); - let first = jump_tuple.first(); - let mut current = first; - while current != u32::MAX { - delay(); - let read = relation.read(current); - let flag = 'flag: { - for i in 1..=read.len() { - let h0_tuple = read - .get(i) - .expect("data corruption") - .pipe(read_tuple::); - match h0_tuple { - H0TupleReader::_0(h0_tuple) => { - let p = h0_tuple.payload(); - if let Some(payload) = p { - if callback(payload) { - break 'flag true; - } - } - } - H0TupleReader::_1(h0_tuple) => { - let p = h0_tuple.payload(); - for j in 0..32 { - if let Some(payload) = p[j] { - if callback(payload) { - break 'flag true; - } - } - } - } - H0TupleReader::_2(_) => (), - } - } - false - }; - if flag { - drop(read); - let mut write = relation.write(current, false); - for i in 1..=write.len() { - let h0_tuple = write - .get_mut(i) - .expect("data corruption") - .pipe(write_tuple::); - match h0_tuple { - H0TupleWriter::_0(mut h0_tuple) => { - let p = h0_tuple.payload(); - if let Some(payload) = *p { - if callback(payload) { - *p = None; - } - } - } - H0TupleWriter::_1(mut h0_tuple) => { - let p = h0_tuple.payload(); - for j in 0..32 { - if let Some(payload) = p[j] { - if callback(payload) { - p[j] = None; - } - } - } - } - H0TupleWriter::_2(_) => (), - } - } - current = write.get_opaque().next; - } else { - current = read.get_opaque().next; - } - } - } - } - { - let first = vectors_first; - let mut current = first; - while current != u32::MAX { - delay(); - let read = relation.read(current); - let flag = 'flag: { - for i in 1..=read.len() { - if let Some(vector_bytes) = read.get(i) { - let vector_tuple = vector_bytes.pipe(read_tuple::>); - let p = vector_tuple.payload(); - if let Some(payload) = p { - if callback(payload) { - break 'flag true; - } - } - } - } - false - }; - if flag { - drop(read); - let mut write = relation.write(current, true); - for i in 1..=write.len() { - if let Some(vector_bytes) = write.get(i) { - let vector_tuple = vector_bytes.pipe(read_tuple::>); - let p = vector_tuple.payload(); - if let Some(payload) = p { - if callback(payload) { - write.free(i); - } - } - }; - } - current = write.get_opaque().next; - } else { - current = read.get_opaque().next; - } - } - } -} - -pub fn maintain(relation: impl RelationWrite + Clone, delay: impl Fn()) { - let meta_guard = relation.read(0); - let meta_tuple = meta_guard.get(1).unwrap().pipe(read_tuple::); - let dims = meta_tuple.dims(); - let height_of_root = meta_tuple.height_of_root(); - let root_first = meta_tuple.root_first(); - let freepage_first = meta_tuple.freepage_first(); - drop(meta_guard); - - let firsts = { - type State = Vec; - let mut state: State = vec![root_first]; - let step = |state: State| { - let mut results = Vec::new(); - for first in state { - let mut current = first; - while current != u32::MAX { - delay(); - let h1_guard = relation.read(current); - for i in 1..=h1_guard.len() { - let h1_tuple = h1_guard - .get(i) - .expect("data corruption") - 
.pipe(read_tuple::); - match h1_tuple { - H1TupleReader::_0(h1_tuple) => { - for first in h1_tuple.first().iter().copied() { - results.push(first); - } - } - H1TupleReader::_1(_) => (), - } - } - current = h1_guard.get_opaque().next; - } - } - results - }; - for _ in (1..height_of_root).rev() { - state = step(state); - } - state - }; - - for first in firsts { - let mut jump_guard = relation.write(first, false); - let mut jump_tuple = jump_guard - .get_mut(1) - .expect("data corruption") - .pipe(write_tuple::); - - let mut tape = H0Tape::<_, _>::create(|| { - if let Some(id) = freepages::fetch(relation.clone(), freepage_first) { - let mut write = relation.write(id, false); - write.clear(); - write - } else { - relation.extend(false) - } - }); - - let mut trace = Vec::new(); - - let first = *jump_tuple.first(); - let mut current = first; - let mut computing = None; - while current != u32::MAX { - delay(); - trace.push(current); - let h0_guard = relation.read(current); - for i in 1..=h0_guard.len() { - let h0_tuple = h0_guard - .get(i) - .expect("data corruption") - .pipe(read_tuple::); - match h0_tuple { - H0TupleReader::_0(h0_tuple) => { - if let Some(payload) = h0_tuple.payload() { - tape.push(H0BranchWriter { - mean: h0_tuple.mean(), - dis_u_2: h0_tuple.code().0, - factor_ppc: h0_tuple.code().1, - factor_ip: h0_tuple.code().2, - factor_err: h0_tuple.code().3, - signs: h0_tuple - .code() - .4 - .iter() - .flat_map(|x| { - std::array::from_fn::<_, 64, _>(|i| *x & (1 << i) != 0) - }) - .take(dims as _) - .collect::>(), - payload, - }); - } - } - H0TupleReader::_1(h0_tuple) => { - let computing = &mut computing.take().unwrap_or_else(Vec::new); - computing.extend_from_slice(h0_tuple.elements()); - let unpacked = unpack(computing); - for j in 0..32 { - if let Some(payload) = h0_tuple.payload()[j] { - tape.push(H0BranchWriter { - mean: h0_tuple.mean()[j], - dis_u_2: h0_tuple.metadata().0[j], - factor_ppc: h0_tuple.metadata().1[j], - factor_ip: h0_tuple.metadata().2[j], - factor_err: h0_tuple.metadata().3[j], - signs: unpacked[j] - .iter() - .flat_map(|&x| { - [x & 1 != 0, x & 2 != 0, x & 4 != 0, x & 8 != 0] - }) - .collect(), - payload, - }); - } - } - } - H0TupleReader::_2(h0_tuple) => { - let computing = computing.get_or_insert_with(Vec::new); - computing.extend_from_slice(h0_tuple.elements()); - } - } - } - current = h0_guard.get_opaque().next; - drop(h0_guard); - } - - let tape = tape.into_inner(); - let new = tape.first(); - drop(tape); - - *jump_tuple.first() = new; - drop(jump_guard); - - freepages::mark(relation.clone(), freepage_first, &trace); - } -} diff --git a/src/algorithm/vectors.rs b/src/algorithm/vectors.rs deleted file mode 100644 index d71499bb..00000000 --- a/src/algorithm/vectors.rs +++ /dev/null @@ -1,133 +0,0 @@ -use crate::algorithm::operator::*; -use crate::algorithm::tuples::*; -use crate::algorithm::{Page, PageGuard, RelationRead, RelationWrite}; -use crate::utils::pipe::Pipe; -use std::num::NonZeroU64; -use vector::VectorOwned; - -pub fn vector_access_1< - O: Operator, - A: Accessor1<::Element, ::Metadata>, ->( - relation: impl RelationRead, - mean: IndexPointer, - accessor: A, -) -> A::Output { - let mut cursor = Err(mean); - let mut result = accessor; - while let Err(mean) = cursor.map_err(pointer_to_pair) { - let vector_guard = relation.read(mean.0); - let vector_tuple = vector_guard - .get(mean.1) - .expect("data corruption") - .pipe(read_tuple::>); - if vector_tuple.payload().is_some() { - panic!("data corruption"); - } - result.push(vector_tuple.elements()); - 
cursor = vector_tuple.metadata_or_pointer(); - } - result.finish(cursor.expect("data corruption")) -} - -pub fn vector_access_0< - O: Operator, - A: Accessor1<::Element, ::Metadata>, ->( - relation: impl RelationRead, - mean: IndexPointer, - payload: NonZeroU64, - accessor: A, -) -> Option { - let mut cursor = Err(mean); - let mut result = accessor; - while let Err(mean) = cursor.map_err(pointer_to_pair) { - let vector_guard = relation.read(mean.0); - let vector_tuple = vector_guard - .get(mean.1)? - .pipe(read_tuple::>); - if vector_tuple.payload().is_none() { - panic!("data corruption"); - } - if vector_tuple.payload() != Some(payload) { - return None; - } - result.push(vector_tuple.elements()); - cursor = vector_tuple.metadata_or_pointer(); - } - Some(result.finish(cursor.ok()?)) -} - -pub fn vector_append( - relation: impl RelationWrite + Clone, - vectors_first: u32, - vector: ::Borrowed<'_>, - payload: NonZeroU64, -) -> IndexPointer { - fn append(relation: impl RelationWrite, first: u32, bytes: &[u8]) -> IndexPointer { - if let Some(mut write) = relation.search(bytes.len()) { - let i = write.alloc(bytes).unwrap(); - return pair_to_pointer((write.id(), i)); - } - assert!(first != u32::MAX); - let mut current = first; - loop { - let read = relation.read(current); - if read.freespace() as usize >= bytes.len() || read.get_opaque().next == u32::MAX { - drop(read); - let mut write = relation.write(current, true); - if let Some(i) = write.alloc(bytes) { - return pair_to_pointer((current, i)); - } - if write.get_opaque().next == u32::MAX { - let mut extend = relation.extend(true); - write.get_opaque_mut().next = extend.id(); - drop(write); - if let Some(i) = extend.alloc(bytes) { - let result = (extend.id(), i); - drop(extend); - let mut past = relation.write(first, true); - let skip = &mut past.get_opaque_mut().skip; - assert!(*skip != u32::MAX); - *skip = std::cmp::max(*skip, result.0); - return pair_to_pointer(result); - } else { - panic!("a tuple cannot even be fit in a fresh page"); - } - } - if current == first && write.get_opaque().skip != first { - current = write.get_opaque().skip; - } else { - current = write.get_opaque().next; - } - } else { - if current == first && read.get_opaque().skip != first { - current = read.get_opaque().skip; - } else { - current = read.get_opaque().next; - } - } - } - } - let (metadata, slices) = O::Vector::vector_split(vector); - let mut chain = Ok(metadata); - for i in (0..slices.len()).rev() { - chain = Err(append( - relation.clone(), - vectors_first, - &serialize::>(&match chain { - Ok(metadata) => VectorTuple::_0 { - elements: slices[i].to_vec(), - payload: Some(payload), - metadata, - }, - Err(pointer) => VectorTuple::_1 { - elements: slices[i].to_vec(), - payload: Some(payload), - pointer, - }, - }), - )); - } - chain.err().unwrap() -} diff --git a/src/bin/pgrx_embed.rs b/src/bin/pgrx_embed.rs index 5f5c4d85..afd0164c 100644 --- a/src/bin/pgrx_embed.rs +++ b/src/bin/pgrx_embed.rs @@ -1 +1,2 @@ +#![allow(unsafe_code)] ::pgrx::pgrx_embed!(); diff --git a/src/datatype/memory_halfvec.rs b/src/datatype/memory_halfvec.rs index b60f6c5f..3e9fe09f 100644 --- a/src/datatype/memory_halfvec.rs +++ b/src/datatype/memory_halfvec.rs @@ -1,20 +1,14 @@ use half::f16; -use pgrx::datum::FromDatum; -use pgrx::datum::IntoDatum; -use pgrx::pg_sys::Datum; -use pgrx::pg_sys::Oid; -use pgrx::pgrx_sql_entity_graph::metadata::ArgumentError; -use pgrx::pgrx_sql_entity_graph::metadata::Returns; -use pgrx::pgrx_sql_entity_graph::metadata::ReturnsError; -use 
pgrx::pgrx_sql_entity_graph::metadata::SqlMapping; -use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable; +use pgrx::datum::{FromDatum, IntoDatum}; +use pgrx::pg_sys::{Datum, Oid}; +use pgrx::pgrx_sql_entity_graph::metadata::*; use std::marker::PhantomData; use std::ptr::NonNull; use vector::VectorBorrowed; use vector::vect::VectBorrowed; #[repr(C, align(8))] -pub struct HalfvecHeader { +struct HalfvecHeader { varlena: u32, dims: u16, unused: u16, @@ -28,10 +22,10 @@ impl HalfvecHeader { } (size_of::() + size_of::() * len).next_multiple_of(8) } - pub unsafe fn as_borrowed<'a>(this: NonNull) -> VectBorrowed<'a, f16> { + unsafe fn as_borrowed<'a>(this: NonNull) -> VectBorrowed<'a, f16> { unsafe { let this = this.as_ptr(); - VectBorrowed::new_unchecked(std::slice::from_raw_parts( + VectBorrowed::new(std::slice::from_raw_parts( (&raw const (*this).elements).cast(), (&raw const (*this).dims).read() as usize, )) @@ -93,7 +87,7 @@ impl HalfvecOutput { pub fn as_borrowed(&self) -> VectBorrowed<'_, f16> { unsafe { HalfvecHeader::as_borrowed(self.0) } } - pub fn into_raw(self) -> *mut HalfvecHeader { + fn into_raw(self) -> *mut HalfvecHeader { let result = self.0.as_ptr(); std::mem::forget(self); result diff --git a/src/datatype/memory_scalar8.rs b/src/datatype/memory_scalar8.rs index 4f306548..19e4ff47 100644 --- a/src/datatype/memory_scalar8.rs +++ b/src/datatype/memory_scalar8.rs @@ -1,19 +1,13 @@ -use pgrx::datum::FromDatum; -use pgrx::datum::IntoDatum; -use pgrx::pg_sys::Datum; -use pgrx::pg_sys::Oid; -use pgrx::pgrx_sql_entity_graph::metadata::ArgumentError; -use pgrx::pgrx_sql_entity_graph::metadata::Returns; -use pgrx::pgrx_sql_entity_graph::metadata::ReturnsError; -use pgrx::pgrx_sql_entity_graph::metadata::SqlMapping; -use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable; +use pgrx::datum::{FromDatum, IntoDatum}; +use pgrx::pg_sys::{Datum, Oid}; +use pgrx::pgrx_sql_entity_graph::metadata::*; use std::marker::PhantomData; use std::ptr::NonNull; use vector::VectorBorrowed; use vector::scalar8::Scalar8Borrowed; #[repr(C, align(8))] -pub struct Scalar8Header { +struct Scalar8Header { varlena: u32, dims: u16, unused: u16, @@ -31,10 +25,10 @@ impl Scalar8Header { } (size_of::() + size_of::() * len).next_multiple_of(8) } - pub unsafe fn as_borrowed<'a>(this: NonNull) -> Scalar8Borrowed<'a> { + unsafe fn as_borrowed<'a>(this: NonNull) -> Scalar8Borrowed<'a> { unsafe { let this = this.as_ptr(); - Scalar8Borrowed::new_unchecked( + Scalar8Borrowed::new( (&raw const (*this).sum_of_x2).read(), (&raw const (*this).k).read(), (&raw const (*this).b).read(), @@ -105,7 +99,7 @@ impl Scalar8Output { pub fn as_borrowed(&self) -> Scalar8Borrowed<'_> { unsafe { Scalar8Header::as_borrowed(self.0) } } - pub fn into_raw(self) -> *mut Scalar8Header { + fn into_raw(self) -> *mut Scalar8Header { let result = self.0.as_ptr(); std::mem::forget(self); result diff --git a/src/datatype/memory_vector.rs b/src/datatype/memory_vector.rs index de70ba16..4d9f9f21 100644 --- a/src/datatype/memory_vector.rs +++ b/src/datatype/memory_vector.rs @@ -1,19 +1,13 @@ -use pgrx::datum::FromDatum; -use pgrx::datum::IntoDatum; -use pgrx::pg_sys::Datum; -use pgrx::pg_sys::Oid; -use pgrx::pgrx_sql_entity_graph::metadata::ArgumentError; -use pgrx::pgrx_sql_entity_graph::metadata::Returns; -use pgrx::pgrx_sql_entity_graph::metadata::ReturnsError; -use pgrx::pgrx_sql_entity_graph::metadata::SqlMapping; -use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable; +use pgrx::datum::{FromDatum, IntoDatum}; +use pgrx::pg_sys::{Datum, 
Oid}; +use pgrx::pgrx_sql_entity_graph::metadata::*; use std::marker::PhantomData; use std::ptr::NonNull; use vector::VectorBorrowed; use vector::vect::VectBorrowed; #[repr(C, align(8))] -pub struct VectorHeader { +struct VectorHeader { varlena: u32, dims: u16, unused: u16, @@ -27,10 +21,10 @@ impl VectorHeader { } (size_of::() + size_of::() * len).next_multiple_of(8) } - pub unsafe fn as_borrowed<'a>(this: NonNull) -> VectBorrowed<'a, f32> { + unsafe fn as_borrowed<'a>(this: NonNull) -> VectBorrowed<'a, f32> { unsafe { let this = this.as_ptr(); - VectBorrowed::new_unchecked(std::slice::from_raw_parts( + VectBorrowed::new(std::slice::from_raw_parts( (&raw const (*this).elements).cast(), (&raw const (*this).dims).read() as usize, )) @@ -92,7 +86,7 @@ impl VectorOutput { pub fn as_borrowed(&self) -> VectBorrowed<'_, f32> { unsafe { VectorHeader::as_borrowed(self.0) } } - pub fn into_raw(self) -> *mut VectorHeader { + fn into_raw(self) -> *mut VectorHeader { let result = self.0.as_ptr(); std::mem::forget(self); result diff --git a/src/gucs/mod.rs b/src/gucs/mod.rs deleted file mode 100644 index 2fb489e1..00000000 --- a/src/gucs/mod.rs +++ /dev/null @@ -1,14 +0,0 @@ -pub mod executing; -pub mod prewarm; - -pub unsafe fn init() { - unsafe { - executing::init(); - prewarm::init(); - prewarm::prewarm(); - #[cfg(any(feature = "pg13", feature = "pg14"))] - pgrx::pg_sys::EmitWarningsOnPlaceholders(c"vchordrq".as_ptr()); - #[cfg(any(feature = "pg15", feature = "pg16", feature = "pg17"))] - pgrx::pg_sys::MarkGUCPrefixReserved(c"vchordrq".as_ptr()); - } -} diff --git a/src/gucs/prewarm.rs b/src/gucs/prewarm.rs deleted file mode 100644 index bc484367..00000000 --- a/src/gucs/prewarm.rs +++ /dev/null @@ -1,32 +0,0 @@ -use pgrx::guc::{GucContext, GucFlags, GucRegistry, GucSetting}; -use std::ffi::CStr; - -static PREWARM_DIM: GucSetting> = - GucSetting::>::new(Some(c"64,128,256,384,512,768,1024,1536")); - -pub unsafe fn init() { - GucRegistry::define_string_guc( - "vchordrq.prewarm_dim", - "prewarm_dim when the extension is loading.", - "prewarm_dim when the extension is loading.", - &PREWARM_DIM, - GucContext::Userset, - GucFlags::default(), - ); -} - -pub fn prewarm() { - if let Some(prewarm_dim) = PREWARM_DIM.get() { - if let Ok(prewarm_dim) = prewarm_dim.to_str() { - for dim in prewarm_dim.split(',') { - if let Ok(dim) = dim.trim().parse::() { - crate::projection::prewarm(dim as _); - } else { - pgrx::warning!("{dim:?} is not a valid integer"); - } - } - } else { - pgrx::warning!("vchordrq.prewarm_dim is not a valid UTF-8 string"); - } - } -} diff --git a/src/index/am.rs b/src/index/am.rs deleted file mode 100644 index 5db07d24..00000000 --- a/src/index/am.rs +++ /dev/null @@ -1,1103 +0,0 @@ -use crate::algorithm; -use crate::algorithm::build::{HeapRelation, Reporter}; -use crate::algorithm::operator::{Dot, L2, Op}; -use crate::algorithm::operator::{Operator, Vector}; -use crate::index::am_options::{Opfamily, Reloption}; -use crate::index::am_scan::Scanner; -use crate::index::utils::{ctid_to_pointer, pointer_to_ctid}; -use crate::index::{am_options, am_scan}; -use crate::postgres::PostgresRelation; -use crate::types::{DistanceKind, VectorKind}; -use half::f16; -use pgrx::datum::Internal; -use pgrx::pg_sys::Datum; -use std::num::NonZeroU64; -use vector::vect::VectOwned; - -static mut RELOPT_KIND_VCHORDRQ: pgrx::pg_sys::relopt_kind::Type = 0; - -pub unsafe fn init() { - unsafe { - (&raw mut RELOPT_KIND_VCHORDRQ).write(pgrx::pg_sys::add_reloption_kind()); - pgrx::pg_sys::add_string_reloption( - (&raw 
const RELOPT_KIND_VCHORDRQ).read(), - c"options".as_ptr(), - c"Vector index options, represented as a TOML string.".as_ptr(), - c"".as_ptr(), - None, - pgrx::pg_sys::AccessExclusiveLock as pgrx::pg_sys::LOCKMODE, - ); - } -} - -#[pgrx::pg_extern(sql = "")] -fn _vchordrq_amhandler(_fcinfo: pgrx::pg_sys::FunctionCallInfo) -> Internal { - type T = pgrx::pg_sys::IndexAmRoutine; - unsafe { - let index_am_routine = pgrx::pg_sys::palloc0(size_of::()) as *mut T; - index_am_routine.write(AM_HANDLER); - Internal::from(Some(Datum::from(index_am_routine))) - } -} - -const AM_HANDLER: pgrx::pg_sys::IndexAmRoutine = { - let mut am_routine = - unsafe { std::mem::MaybeUninit::::zeroed().assume_init() }; - - am_routine.type_ = pgrx::pg_sys::NodeTag::T_IndexAmRoutine; - - am_routine.amsupport = 1; - am_routine.amcanorderbyop = true; - - #[cfg(feature = "pg17")] - { - am_routine.amcanbuildparallel = true; - } - - // Index access methods that set `amoptionalkey` to `false` - // must index all tuples, even if the first column is `NULL`. - // However, PostgreSQL does not generate a path if there is no - // index clauses, even if there is a `ORDER BY` clause. - // So we have to set it to `true` and set costs of every path - // for vector index scans without `ORDER BY` clauses a large number - // and throw errors if someone really wants such a path. - am_routine.amoptionalkey = true; - - am_routine.amvalidate = Some(amvalidate); - am_routine.amoptions = Some(amoptions); - am_routine.amcostestimate = Some(amcostestimate); - - am_routine.ambuild = Some(ambuild); - am_routine.ambuildempty = Some(ambuildempty); - am_routine.aminsert = Some(aminsert); - am_routine.ambulkdelete = Some(ambulkdelete); - am_routine.amvacuumcleanup = Some(amvacuumcleanup); - - am_routine.ambeginscan = Some(ambeginscan); - am_routine.amrescan = Some(amrescan); - am_routine.amgettuple = Some(amgettuple); - am_routine.amendscan = Some(amendscan); - - am_routine -}; - -#[pgrx::pg_guard] -pub unsafe extern "C" fn amvalidate(_opclass_oid: pgrx::pg_sys::Oid) -> bool { - true -} - -#[pgrx::pg_guard] -pub unsafe extern "C" fn amoptions(reloptions: Datum, validate: bool) -> *mut pgrx::pg_sys::bytea { - let rdopts = unsafe { - pgrx::pg_sys::build_reloptions( - reloptions, - validate, - (&raw const RELOPT_KIND_VCHORDRQ).read(), - size_of::(), - Reloption::TAB.as_ptr(), - Reloption::TAB.len() as _, - ) - }; - rdopts as *mut pgrx::pg_sys::bytea -} - -#[pgrx::pg_guard] -pub unsafe extern "C" fn amcostestimate( - _root: *mut pgrx::pg_sys::PlannerInfo, - path: *mut pgrx::pg_sys::IndexPath, - _loop_count: f64, - index_startup_cost: *mut pgrx::pg_sys::Cost, - index_total_cost: *mut pgrx::pg_sys::Cost, - index_selectivity: *mut pgrx::pg_sys::Selectivity, - index_correlation: *mut f64, - index_pages: *mut f64, -) { - unsafe { - if (*path).indexorderbys.is_null() && (*path).indexclauses.is_null() { - *index_startup_cost = f64::MAX; - *index_total_cost = f64::MAX; - *index_selectivity = 0.0; - *index_correlation = 0.0; - *index_pages = 0.0; - return; - } - *index_startup_cost = 0.0; - *index_total_cost = 0.0; - *index_selectivity = 1.0; - *index_correlation = 1.0; - *index_pages = 0.0; - } -} - -#[derive(Debug, Clone)] -struct PgReporter {} - -impl Reporter for PgReporter { - fn tuples_total(&mut self, tuples_total: u64) { - unsafe { - pgrx::pg_sys::pgstat_progress_update_param( - pgrx::pg_sys::PROGRESS_CREATEIDX_TUPLES_TOTAL as _, - tuples_total as _, - ); - } - } -} - -impl PgReporter { - fn tuples_done(&mut self, tuples_done: u64) { - unsafe { - 
pgrx::pg_sys::pgstat_progress_update_param( - pgrx::pg_sys::PROGRESS_CREATEIDX_TUPLES_DONE as _, - tuples_done as _, - ); - } - } -} - -#[pgrx::pg_guard] -pub unsafe extern "C" fn ambuild( - heap: pgrx::pg_sys::Relation, - index: pgrx::pg_sys::Relation, - index_info: *mut pgrx::pg_sys::IndexInfo, -) -> *mut pgrx::pg_sys::IndexBuildResult { - use validator::Validate; - #[derive(Debug, Clone)] - pub struct Heap { - heap: pgrx::pg_sys::Relation, - index: pgrx::pg_sys::Relation, - index_info: *mut pgrx::pg_sys::IndexInfo, - opfamily: Opfamily, - } - impl HeapRelation for Heap { - fn traverse(&self, progress: bool, callback: F) - where - F: FnMut((NonZeroU64, O::Vector)), - { - pub struct State<'a, F> { - pub this: &'a Heap, - pub callback: F, - } - #[pgrx::pg_guard] - unsafe extern "C" fn call( - _index: pgrx::pg_sys::Relation, - ctid: pgrx::pg_sys::ItemPointer, - values: *mut Datum, - is_null: *mut bool, - _tuple_is_alive: bool, - state: *mut core::ffi::c_void, - ) where - F: FnMut((NonZeroU64, O::Vector)), - { - let state = unsafe { &mut *state.cast::>() }; - let opfamily = state.this.opfamily; - let vector = unsafe { opfamily.datum_to_vector(*values.add(0), *is_null.add(0)) }; - let pointer = unsafe { ctid_to_pointer(ctid.read()) }; - if let Some(vector) = vector { - (state.callback)((pointer, O::Vector::from_owned(vector))); - } - } - let table_am = unsafe { &*(*self.heap).rd_tableam }; - let mut state = State { - this: self, - callback, - }; - unsafe { - table_am.index_build_range_scan.unwrap()( - self.heap, - self.index, - self.index_info, - true, - false, - progress, - 0, - pgrx::pg_sys::InvalidBlockNumber, - Some(call::), - (&mut state) as *mut State as *mut _, - std::ptr::null_mut(), - ); - } - } - - fn opfamily(&self) -> Opfamily { - self.opfamily - } - } - let (vector_options, vchordrq_options) = unsafe { am_options::options(index) }; - if let Err(errors) = Validate::validate(&vector_options) { - pgrx::error!("error while validating options: {}", errors); - } - if vector_options.dims == 0 { - pgrx::error!("error while validating options: dimension cannot be 0"); - } - if vector_options.dims > 60000 { - pgrx::error!("error while validating options: dimension is too large"); - } - if let Err(errors) = Validate::validate(&vchordrq_options) { - pgrx::error!("error while validating options: {}", errors); - } - let opfamily = unsafe { am_options::opfamily(index) }; - let heap_relation = Heap { - heap, - index, - index_info, - opfamily, - }; - let mut reporter = PgReporter {}; - let index_relation = unsafe { PostgresRelation::new(index) }; - match (opfamily.vector_kind(), opfamily.distance_kind()) { - (VectorKind::Vecf32, DistanceKind::L2) => { - algorithm::build::build::, L2>, Heap, _>( - vector_options, - vchordrq_options, - heap_relation.clone(), - index_relation.clone(), - reporter.clone(), - ) - } - (VectorKind::Vecf32, DistanceKind::Dot) => { - algorithm::build::build::, Dot>, Heap, _>( - vector_options, - vchordrq_options, - heap_relation.clone(), - index_relation.clone(), - reporter.clone(), - ) - } - (VectorKind::Vecf16, DistanceKind::L2) => { - algorithm::build::build::, L2>, Heap, _>( - vector_options, - vchordrq_options, - heap_relation.clone(), - index_relation.clone(), - reporter.clone(), - ) - } - (VectorKind::Vecf16, DistanceKind::Dot) => { - algorithm::build::build::, Dot>, Heap, _>( - vector_options, - vchordrq_options, - heap_relation.clone(), - index_relation.clone(), - reporter.clone(), - ) - } - } - if let Some(leader) = unsafe { VchordrqLeader::enter(heap, index, 
(*index_info).ii_Concurrent) } - { - unsafe { - parallel_build( - index, - heap, - index_info, - leader.tablescandesc, - leader.vchordrqshared, - Some(reporter), - ); - leader.wait(); - let nparticipants = leader.nparticipants; - loop { - pgrx::pg_sys::SpinLockAcquire(&raw mut (*leader.vchordrqshared).mutex); - if (*leader.vchordrqshared).nparticipantsdone == nparticipants { - pgrx::pg_sys::SpinLockRelease(&raw mut (*leader.vchordrqshared).mutex); - break; - } - pgrx::pg_sys::SpinLockRelease(&raw mut (*leader.vchordrqshared).mutex); - pgrx::pg_sys::ConditionVariableSleep( - &raw mut (*leader.vchordrqshared).workersdonecv, - pgrx::pg_sys::WaitEventIPC::WAIT_EVENT_PARALLEL_CREATE_INDEX_SCAN, - ); - } - pgrx::pg_sys::ConditionVariableCancelSleep(); - } - } else { - let mut indtuples = 0; - reporter.tuples_done(indtuples); - let relation = unsafe { PostgresRelation::new(index) }; - match (opfamily.vector_kind(), opfamily.distance_kind()) { - (VectorKind::Vecf32, DistanceKind::L2) => { - HeapRelation::, L2>>::traverse( - &heap_relation, - true, - |(pointer, vector)| { - algorithm::insert::insert::, L2>>( - relation.clone(), - pointer, - vector, - ); - indtuples += 1; - reporter.tuples_done(indtuples); - }, - ); - } - (VectorKind::Vecf32, DistanceKind::Dot) => { - HeapRelation::, Dot>>::traverse( - &heap_relation, - true, - |(pointer, vector)| { - algorithm::insert::insert::, Dot>>( - relation.clone(), - pointer, - vector, - ); - indtuples += 1; - reporter.tuples_done(indtuples); - }, - ); - } - (VectorKind::Vecf16, DistanceKind::L2) => { - HeapRelation::, L2>>::traverse( - &heap_relation, - true, - |(pointer, vector)| { - algorithm::insert::insert::, L2>>( - relation.clone(), - pointer, - vector, - ); - indtuples += 1; - reporter.tuples_done(indtuples); - }, - ); - } - (VectorKind::Vecf16, DistanceKind::Dot) => { - HeapRelation::, Dot>>::traverse( - &heap_relation, - true, - |(pointer, vector)| { - algorithm::insert::insert::, Dot>>( - relation.clone(), - pointer, - vector, - ); - indtuples += 1; - reporter.tuples_done(indtuples); - }, - ); - } - } - } - let relation = unsafe { PostgresRelation::new(index) }; - let delay = || { - pgrx::check_for_interrupts!(); - }; - match (opfamily.vector_kind(), opfamily.distance_kind()) { - (VectorKind::Vecf32, DistanceKind::L2) => { - type O = Op, L2>; - algorithm::vacuum::maintain::(relation, delay); - } - (VectorKind::Vecf32, DistanceKind::Dot) => { - type O = Op, Dot>; - algorithm::vacuum::maintain::(relation, delay); - } - (VectorKind::Vecf16, DistanceKind::L2) => { - type O = Op, L2>; - algorithm::vacuum::maintain::(relation, delay); - } - (VectorKind::Vecf16, DistanceKind::Dot) => { - type O = Op, Dot>; - algorithm::vacuum::maintain::(relation, delay); - } - } - unsafe { pgrx::pgbox::PgBox::::alloc0().into_pg() } -} - -struct VchordrqShared { - /* Immutable state */ - heaprelid: pgrx::pg_sys::Oid, - indexrelid: pgrx::pg_sys::Oid, - isconcurrent: bool, - - /* Worker progress */ - workersdonecv: pgrx::pg_sys::ConditionVariable, - - /* Mutex for mutable state */ - mutex: pgrx::pg_sys::slock_t, - - /* Mutable state */ - nparticipantsdone: i32, - indtuples: u64, -} - -fn is_mvcc_snapshot(snapshot: *mut pgrx::pg_sys::SnapshotData) -> bool { - matches!( - unsafe { (*snapshot).snapshot_type }, - pgrx::pg_sys::SnapshotType::SNAPSHOT_MVCC - | pgrx::pg_sys::SnapshotType::SNAPSHOT_HISTORIC_MVCC - ) -} - -struct VchordrqLeader { - pcxt: *mut pgrx::pg_sys::ParallelContext, - nparticipants: i32, - vchordrqshared: *mut VchordrqShared, - tablescandesc: *mut 
pgrx::pg_sys::ParallelTableScanDescData, - snapshot: pgrx::pg_sys::Snapshot, -} - -impl VchordrqLeader { - pub unsafe fn enter( - heap: pgrx::pg_sys::Relation, - index: pgrx::pg_sys::Relation, - isconcurrent: bool, - ) -> Option { - unsafe fn compute_parallel_workers( - heap: pgrx::pg_sys::Relation, - index: pgrx::pg_sys::Relation, - ) -> i32 { - unsafe { - if pgrx::pg_sys::plan_create_index_workers((*heap).rd_id, (*index).rd_id) == 0 { - return 0; - } - if !(*heap).rd_options.is_null() { - let std_options = (*heap).rd_options.cast::(); - std::cmp::min( - (*std_options).parallel_workers, - pgrx::pg_sys::max_parallel_maintenance_workers, - ) - } else { - pgrx::pg_sys::max_parallel_maintenance_workers - } - } - } - - let request = unsafe { compute_parallel_workers(heap, index) }; - if request <= 0 { - return None; - } - - unsafe { - pgrx::pg_sys::EnterParallelMode(); - } - let pcxt = unsafe { - pgrx::pg_sys::CreateParallelContext( - c"vchord".as_ptr(), - c"vchordrq_parallel_build_main".as_ptr(), - request, - ) - }; - - let snapshot = if isconcurrent { - unsafe { pgrx::pg_sys::RegisterSnapshot(pgrx::pg_sys::GetTransactionSnapshot()) } - } else { - &raw mut pgrx::pg_sys::SnapshotAnyData - }; - - fn estimate_chunk(e: &mut pgrx::pg_sys::shm_toc_estimator, x: usize) { - e.space_for_chunks += x.next_multiple_of(pgrx::pg_sys::ALIGNOF_BUFFER as _); - } - fn estimate_keys(e: &mut pgrx::pg_sys::shm_toc_estimator, x: usize) { - e.number_of_keys += x; - } - let est_tablescandesc = - unsafe { pgrx::pg_sys::table_parallelscan_estimate(heap, snapshot) }; - unsafe { - estimate_chunk(&mut (*pcxt).estimator, size_of::()); - estimate_keys(&mut (*pcxt).estimator, 1); - estimate_chunk(&mut (*pcxt).estimator, est_tablescandesc); - estimate_keys(&mut (*pcxt).estimator, 1); - } - - unsafe { - pgrx::pg_sys::InitializeParallelDSM(pcxt); - if (*pcxt).seg.is_null() { - if is_mvcc_snapshot(snapshot) { - pgrx::pg_sys::UnregisterSnapshot(snapshot); - } - pgrx::pg_sys::DestroyParallelContext(pcxt); - pgrx::pg_sys::ExitParallelMode(); - return None; - } - } - - let vchordrqshared = unsafe { - let vchordrqshared = - pgrx::pg_sys::shm_toc_allocate((*pcxt).toc, size_of::()) - .cast::(); - vchordrqshared.write(VchordrqShared { - heaprelid: (*heap).rd_id, - indexrelid: (*index).rd_id, - isconcurrent, - workersdonecv: std::mem::zeroed(), - mutex: std::mem::zeroed(), - nparticipantsdone: 0, - indtuples: 0, - }); - pgrx::pg_sys::ConditionVariableInit(&raw mut (*vchordrqshared).workersdonecv); - pgrx::pg_sys::SpinLockInit(&raw mut (*vchordrqshared).mutex); - vchordrqshared - }; - - let tablescandesc = unsafe { - let tablescandesc = pgrx::pg_sys::shm_toc_allocate((*pcxt).toc, est_tablescandesc) - .cast::(); - pgrx::pg_sys::table_parallelscan_initialize(heap, tablescandesc, snapshot); - tablescandesc - }; - - unsafe { - pgrx::pg_sys::shm_toc_insert((*pcxt).toc, 0xA000000000000001, vchordrqshared.cast()); - pgrx::pg_sys::shm_toc_insert((*pcxt).toc, 0xA000000000000002, tablescandesc.cast()); - } - - unsafe { - pgrx::pg_sys::LaunchParallelWorkers(pcxt); - } - - let nworkers_launched = unsafe { (*pcxt).nworkers_launched }; - - unsafe { - if nworkers_launched == 0 { - pgrx::pg_sys::WaitForParallelWorkersToFinish(pcxt); - if is_mvcc_snapshot(snapshot) { - pgrx::pg_sys::UnregisterSnapshot(snapshot); - } - pgrx::pg_sys::DestroyParallelContext(pcxt); - pgrx::pg_sys::ExitParallelMode(); - return None; - } - } - - Some(Self { - pcxt, - nparticipants: nworkers_launched + 1, - vchordrqshared, - tablescandesc, - snapshot, - }) - } - - pub fn 
wait(&self) { - unsafe { - pgrx::pg_sys::WaitForParallelWorkersToAttach(self.pcxt); - } - } -} - -impl Drop for VchordrqLeader { - fn drop(&mut self) { - if !std::thread::panicking() { - unsafe { - pgrx::pg_sys::WaitForParallelWorkersToFinish(self.pcxt); - if is_mvcc_snapshot(self.snapshot) { - pgrx::pg_sys::UnregisterSnapshot(self.snapshot); - } - pgrx::pg_sys::DestroyParallelContext(self.pcxt); - pgrx::pg_sys::ExitParallelMode(); - } - } - } -} - -#[pgrx::pg_guard] -#[unsafe(no_mangle)] -pub unsafe extern "C" fn vchordrq_parallel_build_main( - _seg: *mut pgrx::pg_sys::dsm_segment, - toc: *mut pgrx::pg_sys::shm_toc, -) { - let vchordrqshared = unsafe { - pgrx::pg_sys::shm_toc_lookup(toc, 0xA000000000000001, false).cast::() - }; - let tablescandesc = unsafe { - pgrx::pg_sys::shm_toc_lookup(toc, 0xA000000000000002, false) - .cast::() - }; - let heap_lockmode; - let index_lockmode; - if unsafe { !(*vchordrqshared).isconcurrent } { - heap_lockmode = pgrx::pg_sys::ShareLock as pgrx::pg_sys::LOCKMODE; - index_lockmode = pgrx::pg_sys::AccessExclusiveLock as pgrx::pg_sys::LOCKMODE; - } else { - heap_lockmode = pgrx::pg_sys::ShareUpdateExclusiveLock as pgrx::pg_sys::LOCKMODE; - index_lockmode = pgrx::pg_sys::RowExclusiveLock as pgrx::pg_sys::LOCKMODE; - } - let heap = unsafe { pgrx::pg_sys::table_open((*vchordrqshared).heaprelid, heap_lockmode) }; - let index = unsafe { pgrx::pg_sys::index_open((*vchordrqshared).indexrelid, index_lockmode) }; - let index_info = unsafe { pgrx::pg_sys::BuildIndexInfo(index) }; - unsafe { - (*index_info).ii_Concurrent = (*vchordrqshared).isconcurrent; - } - - unsafe { - parallel_build(index, heap, index_info, tablescandesc, vchordrqshared, None); - } - - unsafe { - pgrx::pg_sys::index_close(index, index_lockmode); - pgrx::pg_sys::table_close(heap, heap_lockmode); - } -} - -unsafe fn parallel_build( - index: *mut pgrx::pg_sys::RelationData, - heap: pgrx::pg_sys::Relation, - index_info: *mut pgrx::pg_sys::IndexInfo, - tablescandesc: *mut pgrx::pg_sys::ParallelTableScanDescData, - vchordrqshared: *mut VchordrqShared, - mut reporter: Option, -) { - #[derive(Debug, Clone)] - pub struct Heap { - heap: pgrx::pg_sys::Relation, - index: pgrx::pg_sys::Relation, - index_info: *mut pgrx::pg_sys::IndexInfo, - opfamily: Opfamily, - scan: *mut pgrx::pg_sys::TableScanDescData, - } - impl HeapRelation for Heap { - fn traverse(&self, progress: bool, callback: F) - where - F: FnMut((NonZeroU64, O::Vector)), - { - pub struct State<'a, F> { - pub this: &'a Heap, - pub callback: F, - } - #[pgrx::pg_guard] - unsafe extern "C" fn call( - _index: pgrx::pg_sys::Relation, - ctid: pgrx::pg_sys::ItemPointer, - values: *mut Datum, - is_null: *mut bool, - _tuple_is_alive: bool, - state: *mut core::ffi::c_void, - ) where - F: FnMut((NonZeroU64, O::Vector)), - { - let state = unsafe { &mut *state.cast::>() }; - let opfamily = state.this.opfamily; - let vector = unsafe { opfamily.datum_to_vector(*values.add(0), *is_null.add(0)) }; - let pointer = unsafe { ctid_to_pointer(ctid.read()) }; - if let Some(vector) = vector { - (state.callback)((pointer, O::Vector::from_owned(vector))); - } - } - let table_am = unsafe { &*(*self.heap).rd_tableam }; - let mut state = State { - this: self, - callback, - }; - unsafe { - table_am.index_build_range_scan.unwrap()( - self.heap, - self.index, - self.index_info, - true, - false, - progress, - 0, - pgrx::pg_sys::InvalidBlockNumber, - Some(call::), - (&mut state) as *mut State as *mut _, - self.scan, - ); - } - } - - fn opfamily(&self) -> Opfamily { - self.opfamily - 
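// Illustrative sketch, not part of the patch: VchordrqLeader's Drop impl above follows the
// usual RAII shape of "clean up on scope exit, but skip the cleanup while unwinding",
// presumably because an error path lets PostgreSQL's own abort processing release the
// parallel context. `Cleanup` below is an assumed name used only for this restatement.
struct Cleanup<F: FnMut()> {
    on_drop: F,
}

impl<F: FnMut()> Drop for Cleanup<F> {
    fn drop(&mut self) {
        // Mirrors the `if !std::thread::panicking()` guard in the impl above.
        if !std::thread::panicking() {
            (self.on_drop)();
        }
    }
}

fn leader_scope() {
    let _guard = Cleanup {
        on_drop: || println!("wait for workers, destroy parallel context, exit parallel mode"),
    };
    // ... normal build work; the closure runs when `_guard` leaves scope without a panic ...
}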
} - } - - let index_relation = unsafe { PostgresRelation::new(index) }; - - let scan = unsafe { pgrx::pg_sys::table_beginscan_parallel(heap, tablescandesc) }; - let opfamily = unsafe { am_options::opfamily(index) }; - let heap_relation = Heap { - heap, - index, - index_info, - opfamily, - scan, - }; - match (opfamily.vector_kind(), opfamily.distance_kind()) { - (VectorKind::Vecf32, DistanceKind::L2) => { - HeapRelation::, L2>>::traverse( - &heap_relation, - true, - |(pointer, vector)| { - algorithm::insert::insert::, L2>>( - index_relation.clone(), - pointer, - vector, - ); - unsafe { - let indtuples; - { - pgrx::pg_sys::SpinLockAcquire(&raw mut (*vchordrqshared).mutex); - (*vchordrqshared).indtuples += 1; - indtuples = (*vchordrqshared).indtuples; - pgrx::pg_sys::SpinLockRelease(&raw mut (*vchordrqshared).mutex); - } - if let Some(reporter) = reporter.as_mut() { - reporter.tuples_done(indtuples); - } - } - }, - ); - } - (VectorKind::Vecf32, DistanceKind::Dot) => { - HeapRelation::, Dot>>::traverse( - &heap_relation, - true, - |(pointer, vector)| { - algorithm::insert::insert::, Dot>>( - index_relation.clone(), - pointer, - vector, - ); - unsafe { - let indtuples; - { - pgrx::pg_sys::SpinLockAcquire(&raw mut (*vchordrqshared).mutex); - (*vchordrqshared).indtuples += 1; - indtuples = (*vchordrqshared).indtuples; - pgrx::pg_sys::SpinLockRelease(&raw mut (*vchordrqshared).mutex); - } - if let Some(reporter) = reporter.as_mut() { - reporter.tuples_done(indtuples); - } - } - }, - ); - } - (VectorKind::Vecf16, DistanceKind::L2) => { - HeapRelation::, L2>>::traverse( - &heap_relation, - true, - |(pointer, vector)| { - algorithm::insert::insert::, L2>>( - index_relation.clone(), - pointer, - vector, - ); - unsafe { - let indtuples; - { - pgrx::pg_sys::SpinLockAcquire(&raw mut (*vchordrqshared).mutex); - (*vchordrqshared).indtuples += 1; - indtuples = (*vchordrqshared).indtuples; - pgrx::pg_sys::SpinLockRelease(&raw mut (*vchordrqshared).mutex); - } - if let Some(reporter) = reporter.as_mut() { - reporter.tuples_done(indtuples); - } - } - }, - ); - } - (VectorKind::Vecf16, DistanceKind::Dot) => { - HeapRelation::, Dot>>::traverse( - &heap_relation, - true, - |(pointer, vector)| { - algorithm::insert::insert::, Dot>>( - index_relation.clone(), - pointer, - vector, - ); - unsafe { - let indtuples; - { - pgrx::pg_sys::SpinLockAcquire(&raw mut (*vchordrqshared).mutex); - (*vchordrqshared).indtuples += 1; - indtuples = (*vchordrqshared).indtuples; - pgrx::pg_sys::SpinLockRelease(&raw mut (*vchordrqshared).mutex); - } - if let Some(reporter) = reporter.as_mut() { - reporter.tuples_done(indtuples); - } - } - }, - ); - } - } - unsafe { - pgrx::pg_sys::SpinLockAcquire(&raw mut (*vchordrqshared).mutex); - (*vchordrqshared).nparticipantsdone += 1; - pgrx::pg_sys::SpinLockRelease(&raw mut (*vchordrqshared).mutex); - pgrx::pg_sys::ConditionVariableSignal(&raw mut (*vchordrqshared).workersdonecv); - } -} - -#[pgrx::pg_guard] -pub unsafe extern "C" fn ambuildempty(_index: pgrx::pg_sys::Relation) { - pgrx::error!("Unlogged indexes are not supported."); -} - -#[cfg(feature = "pg13")] -#[pgrx::pg_guard] -pub unsafe extern "C" fn aminsert( - index: pgrx::pg_sys::Relation, - values: *mut Datum, - is_null: *mut bool, - heap_tid: pgrx::pg_sys::ItemPointer, - _heap: pgrx::pg_sys::Relation, - _check_unique: pgrx::pg_sys::IndexUniqueCheck::Type, - _index_info: *mut pgrx::pg_sys::IndexInfo, -) -> bool { - let opfamily = unsafe { am_options::opfamily(index) }; - let vector = unsafe { opfamily.datum_to_vector(*values.add(0), 
*is_null.add(0)) }; - if let Some(vector) = vector { - let pointer = ctid_to_pointer(unsafe { heap_tid.read() }); - match (opfamily.vector_kind(), opfamily.distance_kind()) { - (VectorKind::Vecf32, DistanceKind::L2) => { - algorithm::insert::insert::, L2>>( - unsafe { PostgresRelation::new(index) }, - pointer, - VectOwned::::from_owned(vector), - ) - } - (VectorKind::Vecf32, DistanceKind::Dot) => { - algorithm::insert::insert::, Dot>>( - unsafe { PostgresRelation::new(index) }, - pointer, - VectOwned::::from_owned(vector), - ) - } - (VectorKind::Vecf16, DistanceKind::L2) => { - algorithm::insert::insert::, L2>>( - unsafe { PostgresRelation::new(index) }, - pointer, - VectOwned::::from_owned(vector), - ) - } - (VectorKind::Vecf16, DistanceKind::Dot) => { - algorithm::insert::insert::, Dot>>( - unsafe { PostgresRelation::new(index) }, - pointer, - VectOwned::::from_owned(vector), - ) - } - } - } - false -} - -#[cfg(any(feature = "pg14", feature = "pg15", feature = "pg16", feature = "pg17"))] -#[pgrx::pg_guard] -pub unsafe extern "C" fn aminsert( - index: pgrx::pg_sys::Relation, - values: *mut Datum, - is_null: *mut bool, - heap_tid: pgrx::pg_sys::ItemPointer, - _heap: pgrx::pg_sys::Relation, - _check_unique: pgrx::pg_sys::IndexUniqueCheck::Type, - _index_unchanged: bool, - _index_info: *mut pgrx::pg_sys::IndexInfo, -) -> bool { - let opfamily = unsafe { am_options::opfamily(index) }; - let vector = unsafe { opfamily.datum_to_vector(*values.add(0), *is_null.add(0)) }; - if let Some(vector) = vector { - let pointer = ctid_to_pointer(unsafe { heap_tid.read() }); - match (opfamily.vector_kind(), opfamily.distance_kind()) { - (VectorKind::Vecf32, DistanceKind::L2) => { - algorithm::insert::insert::, L2>>( - unsafe { PostgresRelation::new(index) }, - pointer, - VectOwned::::from_owned(vector), - ) - } - (VectorKind::Vecf32, DistanceKind::Dot) => { - algorithm::insert::insert::, Dot>>( - unsafe { PostgresRelation::new(index) }, - pointer, - VectOwned::::from_owned(vector), - ) - } - (VectorKind::Vecf16, DistanceKind::L2) => { - algorithm::insert::insert::, L2>>( - unsafe { PostgresRelation::new(index) }, - pointer, - VectOwned::::from_owned(vector), - ) - } - (VectorKind::Vecf16, DistanceKind::Dot) => { - algorithm::insert::insert::, Dot>>( - unsafe { PostgresRelation::new(index) }, - pointer, - VectOwned::::from_owned(vector), - ) - } - } - } - false -} - -#[pgrx::pg_guard] -pub unsafe extern "C" fn ambeginscan( - index: pgrx::pg_sys::Relation, - n_keys: std::os::raw::c_int, - n_orderbys: std::os::raw::c_int, -) -> pgrx::pg_sys::IndexScanDesc { - use pgrx::memcxt::PgMemoryContexts::CurrentMemoryContext; - - let scan = unsafe { pgrx::pg_sys::RelationGetIndexScan(index, n_keys, n_orderbys) }; - unsafe { - let scanner = am_scan::scan_make(None, None, false); - (*scan).opaque = CurrentMemoryContext.leak_and_drop_on_delete(scanner).cast(); - } - scan -} - -#[pgrx::pg_guard] -pub unsafe extern "C" fn amrescan( - scan: pgrx::pg_sys::IndexScanDesc, - keys: pgrx::pg_sys::ScanKey, - _n_keys: std::os::raw::c_int, - orderbys: pgrx::pg_sys::ScanKey, - _n_orderbys: std::os::raw::c_int, -) { - unsafe { - if !keys.is_null() && (*scan).numberOfKeys > 0 { - std::ptr::copy(keys, (*scan).keyData, (*scan).numberOfKeys as _); - } - if !orderbys.is_null() && (*scan).numberOfOrderBys > 0 { - std::ptr::copy(orderbys, (*scan).orderByData, (*scan).numberOfOrderBys as _); - } - let opfamily = am_options::opfamily((*scan).indexRelation); - let (orderbys, spheres) = { - let mut orderbys = Vec::new(); - let mut spheres = 
Vec::new(); - if (*scan).numberOfOrderBys == 0 && (*scan).numberOfKeys == 0 { - pgrx::error!( - "vector search with no WHERE clause and no ORDER BY clause is not supported" - ); - } - for i in 0..(*scan).numberOfOrderBys { - let data = (*scan).orderByData.add(i as usize); - let value = (*data).sk_argument; - let is_null = ((*data).sk_flags & pgrx::pg_sys::SK_ISNULL as i32) != 0; - match (*data).sk_strategy { - 1 => orderbys.push(opfamily.datum_to_vector(value, is_null)), - _ => unreachable!(), - } - } - for i in 0..(*scan).numberOfKeys { - let data = (*scan).keyData.add(i as usize); - let value = (*data).sk_argument; - let is_null = ((*data).sk_flags & pgrx::pg_sys::SK_ISNULL as i32) != 0; - match (*data).sk_strategy { - 2 => spheres.push(opfamily.datum_to_sphere(value, is_null)), - _ => unreachable!(), - } - } - (orderbys, spheres) - }; - let (vector, threshold, recheck) = am_scan::scan_build(orderbys, spheres, opfamily); - let scanner = (*scan).opaque.cast::().as_mut().unwrap_unchecked(); - let scanner = std::mem::replace(scanner, am_scan::scan_make(vector, threshold, recheck)); - am_scan::scan_release(scanner); - } -} - -#[pgrx::pg_guard] -pub unsafe extern "C" fn amgettuple( - scan: pgrx::pg_sys::IndexScanDesc, - direction: pgrx::pg_sys::ScanDirection::Type, -) -> bool { - if direction != pgrx::pg_sys::ScanDirection::ForwardScanDirection { - pgrx::error!("vector search without a forward scan direction is not supported"); - } - // https://www.postgresql.org/docs/current/index-locking.html - // If heap entries referenced physical pointers are deleted before - // they are consumed by PostgreSQL, PostgreSQL will received wrong - // physical pointers: no rows or irreverent rows are referenced. - if unsafe { (*(*scan).xs_snapshot).snapshot_type } != pgrx::pg_sys::SnapshotType::SNAPSHOT_MVCC - { - pgrx::error!("scanning with a non-MVCC-compliant snapshot is not supported"); - } - let scanner = unsafe { (*scan).opaque.cast::().as_mut().unwrap_unchecked() }; - let relation = unsafe { PostgresRelation::new((*scan).indexRelation) }; - if let Some((pointer, recheck)) = am_scan::scan_next(scanner, relation) { - let ctid = pointer_to_ctid(pointer); - unsafe { - (*scan).xs_heaptid = ctid; - (*scan).xs_recheckorderby = false; - (*scan).xs_recheck = recheck; - } - true - } else { - false - } -} - -#[pgrx::pg_guard] -pub unsafe extern "C" fn amendscan(scan: pgrx::pg_sys::IndexScanDesc) { - unsafe { - let scanner = (*scan).opaque.cast::().as_mut().unwrap_unchecked(); - let scanner = std::mem::replace(scanner, am_scan::scan_make(None, None, false)); - am_scan::scan_release(scanner); - } -} - -#[pgrx::pg_guard] -pub unsafe extern "C" fn ambulkdelete( - info: *mut pgrx::pg_sys::IndexVacuumInfo, - stats: *mut pgrx::pg_sys::IndexBulkDeleteResult, - callback: pgrx::pg_sys::IndexBulkDeleteCallback, - callback_state: *mut std::os::raw::c_void, -) -> *mut pgrx::pg_sys::IndexBulkDeleteResult { - let mut stats = stats; - if stats.is_null() { - stats = unsafe { - pgrx::pg_sys::palloc0(size_of::()).cast() - }; - } - let opfamily = unsafe { am_options::opfamily((*info).index) }; - let callback = callback.unwrap(); - let callback = |p: NonZeroU64| unsafe { callback(&mut pointer_to_ctid(p), callback_state) }; - let index = unsafe { PostgresRelation::new((*info).index) }; - let delay = || unsafe { - pgrx::pg_sys::vacuum_delay_point(); - }; - match (opfamily.vector_kind(), opfamily.distance_kind()) { - (VectorKind::Vecf32, DistanceKind::L2) => { - type O = Op, L2>; - algorithm::vacuum::bulkdelete::(index.clone(), delay, 
callback); - } - (VectorKind::Vecf32, DistanceKind::Dot) => { - type O = Op, Dot>; - algorithm::vacuum::bulkdelete::(index.clone(), delay, callback); - } - (VectorKind::Vecf16, DistanceKind::L2) => { - type O = Op, L2>; - algorithm::vacuum::bulkdelete::(index.clone(), delay, callback); - } - (VectorKind::Vecf16, DistanceKind::Dot) => { - type O = Op, Dot>; - algorithm::vacuum::bulkdelete::(index.clone(), delay, callback); - } - } - stats -} - -#[pgrx::pg_guard] -pub unsafe extern "C" fn amvacuumcleanup( - info: *mut pgrx::pg_sys::IndexVacuumInfo, - _stats: *mut pgrx::pg_sys::IndexBulkDeleteResult, -) -> *mut pgrx::pg_sys::IndexBulkDeleteResult { - let opfamily = unsafe { am_options::opfamily((*info).index) }; - let index = unsafe { PostgresRelation::new((*info).index) }; - let delay = || unsafe { - pgrx::pg_sys::vacuum_delay_point(); - }; - match (opfamily.vector_kind(), opfamily.distance_kind()) { - (VectorKind::Vecf32, DistanceKind::L2) => { - type O = Op, L2>; - algorithm::vacuum::maintain::(index, delay); - } - (VectorKind::Vecf32, DistanceKind::Dot) => { - type O = Op, Dot>; - algorithm::vacuum::maintain::(index, delay); - } - (VectorKind::Vecf16, DistanceKind::L2) => { - type O = Op, L2>; - algorithm::vacuum::maintain::(index, delay); - } - (VectorKind::Vecf16, DistanceKind::Dot) => { - type O = Op, Dot>; - algorithm::vacuum::maintain::(index, delay); - } - } - std::ptr::null_mut() -} diff --git a/src/index/am/am_build.rs b/src/index/am/am_build.rs new file mode 100644 index 00000000..e8dab2d1 --- /dev/null +++ b/src/index/am/am_build.rs @@ -0,0 +1,1351 @@ +use crate::datatype::typmod::Typmod; +use crate::index::am::{Reloption, ctid_to_pointer}; +use crate::index::opclass::{Opfamily, opfamily}; +use crate::index::projection::RandomProject; +use crate::index::storage::{PostgresPage, PostgresRelation}; +use crate::index::types::*; +use algorithm::operator::{Dot, L2, Op, Vector}; +use algorithm::types::*; +use algorithm::{PageGuard, RelationRead, RelationWrite}; +use half::f16; +use pgrx::pg_sys::Datum; +use rand::Rng; +use simd::Floating; +use std::num::NonZeroU64; +use std::ops::Deref; +use vector::vect::VectOwned; +use vector::{VectorBorrowed, VectorOwned}; + +#[derive(Debug, Clone)] +struct Heap { + heap_relation: pgrx::pg_sys::Relation, + index_relation: pgrx::pg_sys::Relation, + index_info: *mut pgrx::pg_sys::IndexInfo, + opfamily: Opfamily, + scan: *mut pgrx::pg_sys::TableScanDescData, +} + +impl Heap { + fn traverse(&self, progress: bool, callback: F) { + pub struct State<'a, F> { + pub this: &'a Heap, + pub callback: F, + } + #[pgrx::pg_guard] + unsafe extern "C" fn call( + _index_relation: pgrx::pg_sys::Relation, + ctid: pgrx::pg_sys::ItemPointer, + values: *mut Datum, + is_null: *mut bool, + _tuple_is_alive: bool, + state: *mut core::ffi::c_void, + ) where + F: FnMut((NonZeroU64, V)), + { + let state = unsafe { &mut *state.cast::>() }; + let opfamily = state.this.opfamily; + let vector = unsafe { opfamily.input_vector(*values.add(0), *is_null.add(0)) }; + let pointer = unsafe { ctid_to_pointer(ctid.read()) }; + if let Some(vector) = vector { + (state.callback)((pointer, V::from_owned(vector))); + } + } + let table_am = unsafe { &*(*self.heap_relation).rd_tableam }; + let mut state = State { + this: self, + callback, + }; + unsafe { + table_am.index_build_range_scan.unwrap()( + self.heap_relation, + self.index_relation, + self.index_info, + true, + false, + progress, + 0, + pgrx::pg_sys::InvalidBlockNumber, + Some(call::), + (&mut state) as *mut State as *mut _, + self.scan, + 
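// Illustrative sketch, not part of the patch: the `State` + `call` pair above is the standard
// trampoline for handing a Rust closure to a C callback that only provides a `void *` context
// pointer (here, the table AM's index_build_range_scan). A stripped-down version of the same
// idea; `trampoline` and `scan_with` are assumed names, and the direct call stands in for the
// C entry point so the sketch stays self-contained.
use std::ffi::c_void;

unsafe extern "C" fn trampoline<F: FnMut(u64)>(value: u64, state: *mut c_void) {
    // Cast the opaque pointer back to the concrete closure type chosen at the call site.
    let callback = unsafe { &mut *state.cast::<F>() };
    callback(value);
}

fn scan_with<F: FnMut(u64)>(mut callback: F) {
    // Pass a monomorphized trampoline plus a pointer to the closure; the closure must stay
    // alive and unmoved for the duration of the foreign call, as in `traverse` above.
    let f: unsafe extern "C" fn(u64, *mut c_void) = trampoline::<F>;
    let state = (&mut callback) as *mut F as *mut c_void;
    unsafe { f(42, state) };
}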
); + } + } +} + +#[derive(Debug, Clone)] +struct PostgresReporter {} + +impl PostgresReporter { + fn tuples_total(&mut self, tuples_total: u64) { + unsafe { + pgrx::pg_sys::pgstat_progress_update_param( + pgrx::pg_sys::PROGRESS_CREATEIDX_TUPLES_TOTAL as _, + tuples_total as _, + ); + } + } + fn tuples_done(&mut self, tuples_done: u64) { + unsafe { + pgrx::pg_sys::pgstat_progress_update_param( + pgrx::pg_sys::PROGRESS_CREATEIDX_TUPLES_DONE as _, + tuples_done as _, + ); + } + } +} + +#[pgrx::pg_guard] +pub unsafe extern "C" fn ambuild( + heap_relation: pgrx::pg_sys::Relation, + index_relation: pgrx::pg_sys::Relation, + index_info: *mut pgrx::pg_sys::IndexInfo, +) -> *mut pgrx::pg_sys::IndexBuildResult { + use validator::Validate; + let (vector_options, vchordrq_options) = unsafe { options(index_relation) }; + if let Err(errors) = Validate::validate(&vector_options) { + pgrx::error!("error while validating options: {}", errors); + } + if vector_options.dims == 0 { + pgrx::error!("error while validating options: dimension cannot be 0"); + } + if vector_options.dims > 60000 { + pgrx::error!("error while validating options: dimension is too large"); + } + if let Err(errors) = Validate::validate(&vchordrq_options) { + pgrx::error!("error while validating options: {}", errors); + } + let opfamily = unsafe { opfamily(index_relation) }; + let heap = Heap { + heap_relation, + index_relation, + index_info, + opfamily, + scan: std::ptr::null_mut(), + }; + let index = unsafe { PostgresRelation::new(index_relation) }; + let mut reporter = PostgresReporter {}; + let structures = match vchordrq_options.build.source.clone() { + VchordrqBuildSourceOptions::External(external_build) => { + make_external_build(vector_options.clone(), opfamily, external_build.clone()) + } + VchordrqBuildSourceOptions::Internal(internal_build) => { + let mut tuples_total = 0_u64; + let samples = 'a: { + let mut rand = rand::rng(); + let Some(max_number_of_samples) = internal_build + .lists + .last() + .map(|x| x.saturating_mul(internal_build.sampling_factor)) + else { + break 'a Vec::new(); + }; + let mut samples = Vec::new(); + let mut number_of_samples = 0_u32; + match opfamily.vector_kind() { + VectorKind::Vecf32 => { + heap.traverse(false, |(_, vector): (_, VectOwned)| { + let vector = vector.as_borrowed(); + assert_eq!( + vector_options.dims, + vector.dims(), + "invalid vector dimensions" + ); + if number_of_samples < max_number_of_samples { + samples.push(VectOwned::::build_to_vecf32(vector)); + number_of_samples += 1; + } else { + let index = rand.random_range(0..max_number_of_samples) as usize; + samples[index] = VectOwned::::build_to_vecf32(vector); + } + tuples_total += 1; + }); + } + VectorKind::Vecf16 => { + heap.traverse(false, |(_, vector): (_, VectOwned)| { + let vector = vector.as_borrowed(); + assert_eq!( + vector_options.dims, + vector.dims(), + "invalid vector dimensions" + ); + if number_of_samples < max_number_of_samples { + samples.push(VectOwned::::build_to_vecf32(vector)); + number_of_samples += 1; + } else { + let index = rand.random_range(0..max_number_of_samples) as usize; + samples[index] = VectOwned::::build_to_vecf32(vector); + } + tuples_total += 1; + }); + } + } + samples + }; + reporter.tuples_total(tuples_total); + make_internal_build(vector_options.clone(), internal_build.clone(), samples) + } + }; + match (opfamily.vector_kind(), opfamily.distance_kind()) { + (VectorKind::Vecf32, DistanceKind::L2) => algorithm::build::, L2>>( + vector_options, + vchordrq_options.index, + index.clone(), + 
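// Illustrative sketch, not part of the patch: the sampling loop above caps the sample set at
// `max_number_of_samples` while streaming over the heap — it fills the buffer first and then
// overwrites a random slot for every further tuple. Classic reservoir sampling (Algorithm R),
// which keeps every seen item with equal probability, differs only in making that replacement
// conditional; a sketch using the same rand 0.9 calls the patch already uses:
use rand::Rng;

fn reservoir_sample<T>(items: impl IntoIterator<Item = T>, k: usize) -> Vec<T> {
    let mut rng = rand::rng();
    let mut reservoir = Vec::with_capacity(k);
    for (seen, item) in items.into_iter().enumerate() {
        if reservoir.len() < k {
            reservoir.push(item);
        } else {
            // Keep the new item with probability k / (seen + 1).
            let j = rng.random_range(0..=seen);
            if j < k {
                reservoir[j] = item;
            }
        }
    }
    reservoir
}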
map_structures(structures, |x| InternalBuild::build_from_vecf32(&x)), + ), + (VectorKind::Vecf32, DistanceKind::Dot) => algorithm::build::, Dot>>( + vector_options, + vchordrq_options.index, + index.clone(), + map_structures(structures, |x| InternalBuild::build_from_vecf32(&x)), + ), + (VectorKind::Vecf16, DistanceKind::L2) => algorithm::build::, L2>>( + vector_options, + vchordrq_options.index, + index.clone(), + map_structures(structures, |x| InternalBuild::build_from_vecf32(&x)), + ), + (VectorKind::Vecf16, DistanceKind::Dot) => algorithm::build::, Dot>>( + vector_options, + vchordrq_options.index, + index.clone(), + map_structures(structures, |x| InternalBuild::build_from_vecf32(&x)), + ), + } + let cache = if vchordrq_options.build.pin { + let mut trace = algorithm::cache(index.clone()); + trace.sort(); + trace.dedup(); + if let Some(max) = trace.last().copied() { + let mut mapping = vec![u32::MAX; 1 + max as usize]; + let mut pages = Vec::>::with_capacity(trace.len()); + for id in trace { + mapping[id as usize] = pages.len() as u32; + pages.push(index.read(id).clone_into_boxed()); + } + vchordrq_cached::VchordrqCached::_1 { mapping, pages } + } else { + vchordrq_cached::VchordrqCached::_0 {} + } + } else { + vchordrq_cached::VchordrqCached::_0 {} + }; + if let Some(leader) = unsafe { + VchordrqLeader::enter( + heap_relation, + index_relation, + (*index_info).ii_Concurrent, + cache, + ) + } { + unsafe { + parallel_build( + index_relation, + heap_relation, + index_info, + leader.tablescandesc, + leader.vchordrqshared, + leader.vchordrqcached, + Some(reporter), + ); + leader.wait(); + let nparticipants = leader.nparticipants; + loop { + pgrx::pg_sys::SpinLockAcquire(&raw mut (*leader.vchordrqshared).mutex); + if (*leader.vchordrqshared).nparticipantsdone == nparticipants { + pgrx::pg_sys::SpinLockRelease(&raw mut (*leader.vchordrqshared).mutex); + break; + } + pgrx::pg_sys::SpinLockRelease(&raw mut (*leader.vchordrqshared).mutex); + pgrx::pg_sys::ConditionVariableSleep( + &raw mut (*leader.vchordrqshared).workersdonecv, + pgrx::pg_sys::WaitEventIPC::WAIT_EVENT_PARALLEL_CREATE_INDEX_SCAN, + ); + } + pgrx::pg_sys::ConditionVariableCancelSleep(); + } + } else { + let mut indtuples = 0; + reporter.tuples_done(indtuples); + let relation = unsafe { PostgresRelation::new(index_relation) }; + match (opfamily.vector_kind(), opfamily.distance_kind()) { + (VectorKind::Vecf32, DistanceKind::L2) => { + heap.traverse(true, |(pointer, vector): (_, VectOwned)| { + algorithm::insert::, L2>>( + relation.clone(), + pointer, + RandomProject::project(vector.as_borrowed()), + ); + indtuples += 1; + reporter.tuples_done(indtuples); + }); + } + (VectorKind::Vecf32, DistanceKind::Dot) => { + heap.traverse(true, |(pointer, vector): (_, VectOwned)| { + algorithm::insert::, Dot>>( + relation.clone(), + pointer, + RandomProject::project(vector.as_borrowed()), + ); + indtuples += 1; + reporter.tuples_done(indtuples); + }); + } + (VectorKind::Vecf16, DistanceKind::L2) => { + heap.traverse(true, |(pointer, vector): (_, VectOwned)| { + algorithm::insert::, L2>>( + relation.clone(), + pointer, + RandomProject::project(vector.as_borrowed()), + ); + indtuples += 1; + reporter.tuples_done(indtuples); + }); + } + (VectorKind::Vecf16, DistanceKind::Dot) => { + heap.traverse(true, |(pointer, vector): (_, VectOwned)| { + algorithm::insert::, Dot>>( + relation.clone(), + pointer, + RandomProject::project(vector.as_borrowed()), + ); + indtuples += 1; + reporter.tuples_done(indtuples); + }); + } + } + } + let check = || { + 
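// Illustrative sketch, not part of the patch: the loop above is the leader side of a
// "wait until every participant reports done" protocol — take the spinlock, re-read
// nparticipantsdone, and sleep on workersdonecv if the count is still short. The same shape
// with std::sync primitives instead of slock_t/ConditionVariable (single-process analogy;
// `wait_for_workers`/`worker_finished` are assumed names):
use std::sync::{Condvar, Mutex};

fn wait_for_workers(state: &(Mutex<i32>, Condvar), total: i32) {
    let (lock, cv) = state;
    let mut done = lock.lock().unwrap();
    // Re-check the predicate after every wakeup, exactly like the loop above re-acquires the
    // lock and re-reads the counter after ConditionVariableSleep.
    while *done != total {
        done = cv.wait(done).unwrap();
    }
}

fn worker_finished(state: &(Mutex<i32>, Condvar)) {
    let (lock, cv) = state;
    *lock.lock().unwrap() += 1;
    // Mirrors ConditionVariableSignal on workersdonecv.
    cv.notify_one();
}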
pgrx::check_for_interrupts!(); + }; + match (opfamily.vector_kind(), opfamily.distance_kind()) { + (VectorKind::Vecf32, DistanceKind::L2) => { + algorithm::maintain::, L2>>(index, check); + } + (VectorKind::Vecf32, DistanceKind::Dot) => { + algorithm::maintain::, Dot>>(index, check); + } + (VectorKind::Vecf16, DistanceKind::L2) => { + algorithm::maintain::, L2>>(index, check); + } + (VectorKind::Vecf16, DistanceKind::Dot) => { + algorithm::maintain::, Dot>>(index, check); + } + } + unsafe { pgrx::pgbox::PgBox::::alloc0().into_pg() } +} + +struct VchordrqShared { + /* Immutable state */ + heaprelid: pgrx::pg_sys::Oid, + indexrelid: pgrx::pg_sys::Oid, + isconcurrent: bool, + + /* Worker progress */ + workersdonecv: pgrx::pg_sys::ConditionVariable, + + /* Mutex for mutable state */ + mutex: pgrx::pg_sys::slock_t, + + /* Mutable state */ + nparticipantsdone: i32, + indtuples: u64, +} + +mod vchordrq_cached { + pub const ALIGN: usize = 8; + pub type Tag = u64; + + use crate::index::storage::PostgresPage; + use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout}; + use zerocopy_derive::{FromBytes, Immutable, IntoBytes, KnownLayout}; + + #[repr(C, align(8))] + #[derive(Debug, Clone, PartialEq, FromBytes, IntoBytes, Immutable, KnownLayout)] + struct VchordrqCachedHeader0 {} + + #[repr(C, align(8))] + #[derive(Debug, Clone, PartialEq, FromBytes, IntoBytes, Immutable, KnownLayout)] + struct VchordrqCachedHeader1 { + mapping_s: usize, + mapping_e: usize, + pages_s: usize, + pages_e: usize, + } + + pub enum VchordrqCached { + _0 {}, + _1 { + mapping: Vec, + pages: Vec>, + }, + } + + impl VchordrqCached { + pub fn serialize(&self) -> Vec { + let mut buffer = Vec::new(); + match self { + VchordrqCached::_0 {} => { + buffer.extend((0 as Tag).to_ne_bytes()); + buffer.extend(std::iter::repeat_n(0, size_of::())); + buffer[size_of::()..][..size_of::()] + .copy_from_slice(VchordrqCachedHeader0 {}.as_bytes()); + } + VchordrqCached::_1 { mapping, pages } => { + buffer.extend((1 as Tag).to_ne_bytes()); + buffer.extend(std::iter::repeat_n(0, size_of::())); + let mapping_s = buffer.len(); + buffer.extend(mapping.as_bytes()); + let mapping_e = buffer.len(); + while buffer.len() % ALIGN != 0 { + buffer.push(0u8); + } + let pages_s = buffer.len(); + buffer.extend(pages.iter().flat_map(|x| unsafe { + std::mem::transmute::<&PostgresPage, &[u8; 8192]>(x.as_ref()) + })); + let pages_e = buffer.len(); + while buffer.len() % ALIGN != 0 { + buffer.push(0u8); + } + buffer[size_of::()..][..size_of::()] + .copy_from_slice( + VchordrqCachedHeader1 { + mapping_s, + mapping_e, + pages_s, + pages_e, + } + .as_bytes(), + ); + } + } + buffer + } + } + + #[derive(Debug, Clone, Copy)] + pub enum VchordrqCachedReader<'a> { + _0(#[allow(dead_code)] VchordrqCachedReader0<'a>), + _1(VchordrqCachedReader1<'a>), + } + + #[derive(Debug, Clone, Copy)] + pub struct VchordrqCachedReader0<'a> { + #[allow(dead_code)] + header: &'a VchordrqCachedHeader0, + } + + #[derive(Debug, Clone, Copy)] + pub struct VchordrqCachedReader1<'a> { + #[allow(dead_code)] + header: &'a VchordrqCachedHeader1, + mapping: &'a [u32], + pages: &'a [PostgresPage], + } + + impl<'a> VchordrqCachedReader1<'a> { + pub fn get(&self, id: u32) -> Option<&'a PostgresPage> { + let index = *self.mapping.get(id as usize)?; + if index == u32::MAX { + return None; + } + Some(&self.pages[index as usize]) + } + } + + impl<'a> VchordrqCachedReader<'a> { + pub fn deserialize_ref(source: &'a [u8]) -> Self { + let tag = u64::from_ne_bytes(std::array::from_fn(|i| source[i])); + match 
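// Illustrative sketch, not part of the patch: serialize() above lays the cache out as
// [ 8-byte tag | fixed header | payload sections ], padding each section to an 8-byte boundary
// and patching the header with the sections' byte offsets once they are known. The same layout
// for a single section, with assumed names (`Header`, `write_blob`) and plain std instead of
// zerocopy:
const ALIGN: usize = 8;

#[repr(C)]
struct Header {
    payload_s: usize,
    payload_e: usize,
}

fn write_blob(tag: u64, payload: &[u8]) -> Vec<u8> {
    let mut buf = Vec::new();
    buf.extend(tag.to_ne_bytes());
    // Reserve the header slot now; fill it in once the payload offsets are known.
    let header_at = buf.len();
    buf.extend(std::iter::repeat(0u8).take(size_of::<Header>()));
    let payload_s = buf.len();
    buf.extend_from_slice(payload);
    let payload_e = buf.len();
    while buf.len() % ALIGN != 0 {
        buf.push(0);
    }
    let header = Header { payload_s, payload_e };
    // Patch the reserved slot; with #[repr(C)] and only usize fields this is plain old data.
    let bytes = unsafe {
        std::slice::from_raw_parts((&header as *const Header).cast::<u8>(), size_of::<Header>())
    };
    buf[header_at..header_at + size_of::<Header>()].copy_from_slice(bytes);
    buf
}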
tag { + 0 => { + let checker = RefChecker::new(source); + let header: &VchordrqCachedHeader0 = checker.prefix(size_of::()); + Self::_0(VchordrqCachedReader0 { header }) + } + 1 => { + let checker = RefChecker::new(source); + let header: &VchordrqCachedHeader1 = checker.prefix(size_of::()); + let mapping = checker.bytes(header.mapping_s, header.mapping_e); + let pages = + unsafe { checker.bytes_slice_unchecked(header.pages_s, header.pages_e) }; + Self::_1(VchordrqCachedReader1 { + header, + mapping, + pages, + }) + } + _ => panic!("bad bytes"), + } + } + } + + pub struct RefChecker<'a> { + bytes: &'a [u8], + } + + impl<'a> RefChecker<'a> { + pub fn new(bytes: &'a [u8]) -> Self { + Self { bytes } + } + pub fn prefix( + &self, + s: usize, + ) -> &'a T { + let start = s; + let end = s + size_of::(); + let bytes = &self.bytes[start..end]; + FromBytes::ref_from_bytes(bytes).expect("bad bytes") + } + pub fn bytes( + &self, + s: usize, + e: usize, + ) -> &'a T { + let start = s; + let end = e; + let bytes = &self.bytes[start..end]; + FromBytes::ref_from_bytes(bytes).expect("bad bytes") + } + pub unsafe fn bytes_slice_unchecked(&self, s: usize, e: usize) -> &'a [T] { + let start = s; + let end = e; + let bytes = &self.bytes[start..end]; + if size_of::() == 0 || bytes.len() % size_of::() == 0 { + let ptr = bytes as *const [u8] as *const T; + if ptr.is_aligned() { + unsafe { std::slice::from_raw_parts(ptr, bytes.len() / size_of::()) } + } else { + panic!("bad bytes") + } + } else { + panic!("bad bytes") + } + } + } +} + +fn is_mvcc_snapshot(snapshot: *mut pgrx::pg_sys::SnapshotData) -> bool { + matches!( + unsafe { (*snapshot).snapshot_type }, + pgrx::pg_sys::SnapshotType::SNAPSHOT_MVCC + | pgrx::pg_sys::SnapshotType::SNAPSHOT_HISTORIC_MVCC + ) +} + +struct VchordrqLeader { + pcxt: *mut pgrx::pg_sys::ParallelContext, + nparticipants: i32, + snapshot: pgrx::pg_sys::Snapshot, + vchordrqshared: *mut VchordrqShared, + tablescandesc: *mut pgrx::pg_sys::ParallelTableScanDescData, + vchordrqcached: *const u8, +} + +impl VchordrqLeader { + pub unsafe fn enter( + heap_relation: pgrx::pg_sys::Relation, + index_relation: pgrx::pg_sys::Relation, + isconcurrent: bool, + cache: vchordrq_cached::VchordrqCached, + ) -> Option { + let _cache = cache.serialize(); + drop(cache); + let cache = _cache; + + unsafe fn compute_parallel_workers( + heap_relation: pgrx::pg_sys::Relation, + index_relation: pgrx::pg_sys::Relation, + ) -> i32 { + unsafe { + if pgrx::pg_sys::plan_create_index_workers( + (*heap_relation).rd_id, + (*index_relation).rd_id, + ) == 0 + { + return 0; + } + if !(*heap_relation).rd_options.is_null() { + let std_options = (*heap_relation) + .rd_options + .cast::(); + std::cmp::min( + (*std_options).parallel_workers, + pgrx::pg_sys::max_parallel_maintenance_workers, + ) + } else { + pgrx::pg_sys::max_parallel_maintenance_workers + } + } + } + + let request = unsafe { compute_parallel_workers(heap_relation, index_relation) }; + if request <= 0 { + return None; + } + + unsafe { + pgrx::pg_sys::EnterParallelMode(); + } + let pcxt = unsafe { + pgrx::pg_sys::CreateParallelContext( + c"vchord".as_ptr(), + c"vchordrq_parallel_build_main".as_ptr(), + request, + ) + }; + + let snapshot = if isconcurrent { + unsafe { pgrx::pg_sys::RegisterSnapshot(pgrx::pg_sys::GetTransactionSnapshot()) } + } else { + &raw mut pgrx::pg_sys::SnapshotAnyData + }; + + fn estimate_chunk(e: &mut pgrx::pg_sys::shm_toc_estimator, x: usize) { + e.space_for_chunks += x.next_multiple_of(pgrx::pg_sys::ALIGNOF_BUFFER as _); + } + fn 
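// Illustrative sketch, not part of the patch: deserialize_ref above rebuilds borrowed views
// directly over the shared bytes, and the raw slice cast is only sound because the writer kept
// each section aligned and sized to a whole number of elements — which is exactly what the
// length and ptr.is_aligned() checks re-verify before from_raw_parts. The same guard for a
// concrete element type, as a standalone function (assumed name `bytes_as_u32s`):
fn bytes_as_u32s(bytes: &[u8]) -> &[u32] {
    assert_eq!(bytes.len() % size_of::<u32>(), 0, "bad length");
    let ptr = bytes.as_ptr().cast::<u32>();
    assert!(ptr.is_aligned(), "bad alignment");
    // SAFETY: length and alignment were just checked, and u32 has no invalid bit patterns.
    unsafe { std::slice::from_raw_parts(ptr, bytes.len() / size_of::<u32>()) }
}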
estimate_keys(e: &mut pgrx::pg_sys::shm_toc_estimator, x: usize) { + e.number_of_keys += x; + } + let est_tablescandesc = + unsafe { pgrx::pg_sys::table_parallelscan_estimate(heap_relation, snapshot) }; + unsafe { + estimate_chunk(&mut (*pcxt).estimator, size_of::()); + estimate_keys(&mut (*pcxt).estimator, 1); + estimate_chunk(&mut (*pcxt).estimator, est_tablescandesc); + estimate_keys(&mut (*pcxt).estimator, 1); + estimate_chunk(&mut (*pcxt).estimator, 8 + cache.len()); + estimate_keys(&mut (*pcxt).estimator, 1); + } + + unsafe { + pgrx::pg_sys::InitializeParallelDSM(pcxt); + if (*pcxt).seg.is_null() { + if is_mvcc_snapshot(snapshot) { + pgrx::pg_sys::UnregisterSnapshot(snapshot); + } + pgrx::pg_sys::DestroyParallelContext(pcxt); + pgrx::pg_sys::ExitParallelMode(); + return None; + } + } + + let vchordrqshared = unsafe { + let vchordrqshared = + pgrx::pg_sys::shm_toc_allocate((*pcxt).toc, size_of::()) + .cast::(); + vchordrqshared.write(VchordrqShared { + heaprelid: (*heap_relation).rd_id, + indexrelid: (*index_relation).rd_id, + isconcurrent, + workersdonecv: std::mem::zeroed(), + mutex: std::mem::zeroed(), + nparticipantsdone: 0, + indtuples: 0, + }); + pgrx::pg_sys::ConditionVariableInit(&raw mut (*vchordrqshared).workersdonecv); + pgrx::pg_sys::SpinLockInit(&raw mut (*vchordrqshared).mutex); + vchordrqshared + }; + + let tablescandesc = unsafe { + let tablescandesc = pgrx::pg_sys::shm_toc_allocate((*pcxt).toc, est_tablescandesc) + .cast::(); + pgrx::pg_sys::table_parallelscan_initialize(heap_relation, tablescandesc, snapshot); + tablescandesc + }; + + let vchordrqcached = unsafe { + let x = pgrx::pg_sys::shm_toc_allocate((*pcxt).toc, 8 + cache.len()).cast::(); + (x as *mut u64).write_unaligned(cache.len() as _); + std::ptr::copy(cache.as_ptr(), x.add(8), cache.len()); + x + }; + + unsafe { + pgrx::pg_sys::shm_toc_insert((*pcxt).toc, 0xA000000000000001, vchordrqshared.cast()); + pgrx::pg_sys::shm_toc_insert((*pcxt).toc, 0xA000000000000002, tablescandesc.cast()); + pgrx::pg_sys::shm_toc_insert((*pcxt).toc, 0xA000000000000003, vchordrqcached.cast()); + } + + unsafe { + pgrx::pg_sys::LaunchParallelWorkers(pcxt); + } + + let nworkers_launched = unsafe { (*pcxt).nworkers_launched }; + + unsafe { + if nworkers_launched == 0 { + pgrx::pg_sys::WaitForParallelWorkersToFinish(pcxt); + if is_mvcc_snapshot(snapshot) { + pgrx::pg_sys::UnregisterSnapshot(snapshot); + } + pgrx::pg_sys::DestroyParallelContext(pcxt); + pgrx::pg_sys::ExitParallelMode(); + return None; + } + } + + Some(Self { + pcxt, + nparticipants: nworkers_launched + 1, + snapshot, + vchordrqshared, + tablescandesc, + vchordrqcached, + }) + } + + pub fn wait(&self) { + unsafe { + pgrx::pg_sys::WaitForParallelWorkersToAttach(self.pcxt); + } + } +} + +impl Drop for VchordrqLeader { + fn drop(&mut self) { + if !std::thread::panicking() { + unsafe { + pgrx::pg_sys::WaitForParallelWorkersToFinish(self.pcxt); + if is_mvcc_snapshot(self.snapshot) { + pgrx::pg_sys::UnregisterSnapshot(self.snapshot); + } + pgrx::pg_sys::DestroyParallelContext(self.pcxt); + pgrx::pg_sys::ExitParallelMode(); + } + } + } +} + +#[pgrx::pg_guard] +#[unsafe(no_mangle)] +pub unsafe extern "C" fn vchordrq_parallel_build_main( + _seg: *mut pgrx::pg_sys::dsm_segment, + toc: *mut pgrx::pg_sys::shm_toc, +) { + let vchordrqshared = unsafe { + pgrx::pg_sys::shm_toc_lookup(toc, 0xA000000000000001, false).cast::() + }; + let tablescandesc = unsafe { + pgrx::pg_sys::shm_toc_lookup(toc, 0xA000000000000002, false) + .cast::() + }; + let vchordrqcached = unsafe { + 
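// Illustrative sketch, not part of the patch: before InitializeParallelDSM, every chunk that
// will later be shm_toc_allocate'd has to be estimated up front, rounded up to the buffer
// alignment, plus one TOC key per chunk — that is all estimate_chunk/estimate_keys do above.
// The same arithmetic in plain Rust, with the alignment passed in as a parameter (the real
// code rounds to pgrx::pg_sys::ALIGNOF_BUFFER):
fn estimate_space(chunks: &[usize], align: usize) -> (usize, usize) {
    let space: usize = chunks.iter().map(|x| x.next_multiple_of(align)).sum();
    let keys = chunks.len();
    (space, keys)
}

fn example() {
    // Shared state, the parallel scan descriptor, and the length-prefixed page cache, as above.
    let (space, keys) = estimate_space(&[64, 192, 8 + 4096], 8);
    assert_eq!(keys, 3);
    assert!(space >= 64 + 192 + 8 + 4096);
}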
pgrx::pg_sys::shm_toc_lookup(toc, 0xA000000000000003, false) + .cast::() + .cast_const() + }; + let heap_lockmode; + let index_lockmode; + if unsafe { !(*vchordrqshared).isconcurrent } { + heap_lockmode = pgrx::pg_sys::ShareLock as pgrx::pg_sys::LOCKMODE; + index_lockmode = pgrx::pg_sys::AccessExclusiveLock as pgrx::pg_sys::LOCKMODE; + } else { + heap_lockmode = pgrx::pg_sys::ShareUpdateExclusiveLock as pgrx::pg_sys::LOCKMODE; + index_lockmode = pgrx::pg_sys::RowExclusiveLock as pgrx::pg_sys::LOCKMODE; + } + let heap = unsafe { pgrx::pg_sys::table_open((*vchordrqshared).heaprelid, heap_lockmode) }; + let index = unsafe { pgrx::pg_sys::index_open((*vchordrqshared).indexrelid, index_lockmode) }; + let index_info = unsafe { pgrx::pg_sys::BuildIndexInfo(index) }; + unsafe { + (*index_info).ii_Concurrent = (*vchordrqshared).isconcurrent; + } + + unsafe { + parallel_build( + index, + heap, + index_info, + tablescandesc, + vchordrqshared, + vchordrqcached, + None, + ); + } + + unsafe { + pgrx::pg_sys::index_close(index, index_lockmode); + pgrx::pg_sys::table_close(heap, heap_lockmode); + } +} + +unsafe fn parallel_build( + index_relation: pgrx::pg_sys::Relation, + heap_relation: pgrx::pg_sys::Relation, + index_info: *mut pgrx::pg_sys::IndexInfo, + tablescandesc: *mut pgrx::pg_sys::ParallelTableScanDescData, + vchordrqshared: *mut VchordrqShared, + vchordrqcached: *const u8, + mut reporter: Option, +) { + use vchordrq_cached::VchordrqCachedReader; + let cached = VchordrqCachedReader::deserialize_ref(unsafe { + let bytes = (vchordrqcached as *const u64).read_unaligned(); + std::slice::from_raw_parts(vchordrqcached.add(8), bytes as _) + }); + + let index = unsafe { PostgresRelation::new(index_relation) }; + + let scan = unsafe { pgrx::pg_sys::table_beginscan_parallel(heap_relation, tablescandesc) }; + let opfamily = unsafe { opfamily(index_relation) }; + let heap = Heap { + heap_relation, + index_relation, + index_info, + opfamily, + scan, + }; + match cached { + VchordrqCachedReader::_0(_) => match (opfamily.vector_kind(), opfamily.distance_kind()) { + (VectorKind::Vecf32, DistanceKind::L2) => { + heap.traverse(true, |(pointer, vector): (_, VectOwned)| { + algorithm::insert::, L2>>( + index.clone(), + pointer, + RandomProject::project(vector.as_borrowed()), + ); + unsafe { + let indtuples; + { + pgrx::pg_sys::SpinLockAcquire(&raw mut (*vchordrqshared).mutex); + (*vchordrqshared).indtuples += 1; + indtuples = (*vchordrqshared).indtuples; + pgrx::pg_sys::SpinLockRelease(&raw mut (*vchordrqshared).mutex); + } + if let Some(reporter) = reporter.as_mut() { + reporter.tuples_done(indtuples); + } + } + }); + } + (VectorKind::Vecf32, DistanceKind::Dot) => { + heap.traverse(true, |(pointer, vector): (_, VectOwned)| { + algorithm::insert::, Dot>>( + index.clone(), + pointer, + RandomProject::project(vector.as_borrowed()), + ); + unsafe { + let indtuples; + { + pgrx::pg_sys::SpinLockAcquire(&raw mut (*vchordrqshared).mutex); + (*vchordrqshared).indtuples += 1; + indtuples = (*vchordrqshared).indtuples; + pgrx::pg_sys::SpinLockRelease(&raw mut (*vchordrqshared).mutex); + } + if let Some(reporter) = reporter.as_mut() { + reporter.tuples_done(indtuples); + } + } + }); + } + (VectorKind::Vecf16, DistanceKind::L2) => { + heap.traverse(true, |(pointer, vector): (_, VectOwned)| { + algorithm::insert::, L2>>( + index.clone(), + pointer, + RandomProject::project(vector.as_borrowed()), + ); + unsafe { + let indtuples; + { + pgrx::pg_sys::SpinLockAcquire(&raw mut (*vchordrqshared).mutex); + (*vchordrqshared).indtuples 
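// Illustrative sketch, not part of the patch: the cached pages travel to the workers as a raw,
// length-prefixed byte blob — the leader writes an unaligned u64 length followed by the bytes,
// and parallel_build above reads the length back and borrows the remainder as a slice. The
// same round trip over an ordinary buffer (assumed names; write_unaligned/read_unaligned
// mirror the patch and avoid assuming anything about the chunk's alignment):
fn write_length_prefixed(payload: &[u8]) -> Vec<u8> {
    let mut chunk = vec![0u8; 8 + payload.len()];
    unsafe { chunk.as_mut_ptr().cast::<u64>().write_unaligned(payload.len() as u64) };
    chunk[8..].copy_from_slice(payload);
    chunk
}

unsafe fn read_length_prefixed<'a>(chunk: *const u8) -> &'a [u8] {
    // In the patch the backing bytes live in the DSM segment for the whole build, which is
    // what justifies handing out a borrowed slice here.
    let len = unsafe { chunk.cast::<u64>().read_unaligned() } as usize;
    unsafe { std::slice::from_raw_parts(chunk.add(8), len) }
}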
+= 1; + indtuples = (*vchordrqshared).indtuples; + pgrx::pg_sys::SpinLockRelease(&raw mut (*vchordrqshared).mutex); + } + if let Some(reporter) = reporter.as_mut() { + reporter.tuples_done(indtuples); + } + } + }); + } + (VectorKind::Vecf16, DistanceKind::Dot) => { + heap.traverse(true, |(pointer, vector): (_, VectOwned)| { + algorithm::insert::, Dot>>( + index.clone(), + pointer, + RandomProject::project(vector.as_borrowed()), + ); + unsafe { + let indtuples; + { + pgrx::pg_sys::SpinLockAcquire(&raw mut (*vchordrqshared).mutex); + (*vchordrqshared).indtuples += 1; + indtuples = (*vchordrqshared).indtuples; + pgrx::pg_sys::SpinLockRelease(&raw mut (*vchordrqshared).mutex); + } + if let Some(reporter) = reporter.as_mut() { + reporter.tuples_done(indtuples); + } + } + }); + } + }, + VchordrqCachedReader::_1(cached) => { + let index = CachingRelation { + cache: cached, + relation: index, + }; + match (opfamily.vector_kind(), opfamily.distance_kind()) { + (VectorKind::Vecf32, DistanceKind::L2) => { + heap.traverse(true, |(pointer, vector): (_, VectOwned)| { + algorithm::insert::, L2>>( + index.clone(), + pointer, + RandomProject::project(vector.as_borrowed()), + ); + unsafe { + let indtuples; + { + pgrx::pg_sys::SpinLockAcquire(&raw mut (*vchordrqshared).mutex); + (*vchordrqshared).indtuples += 1; + indtuples = (*vchordrqshared).indtuples; + pgrx::pg_sys::SpinLockRelease(&raw mut (*vchordrqshared).mutex); + } + if let Some(reporter) = reporter.as_mut() { + reporter.tuples_done(indtuples); + } + } + }); + } + (VectorKind::Vecf32, DistanceKind::Dot) => { + heap.traverse(true, |(pointer, vector): (_, VectOwned)| { + algorithm::insert::, Dot>>( + index.clone(), + pointer, + RandomProject::project(vector.as_borrowed()), + ); + unsafe { + let indtuples; + { + pgrx::pg_sys::SpinLockAcquire(&raw mut (*vchordrqshared).mutex); + (*vchordrqshared).indtuples += 1; + indtuples = (*vchordrqshared).indtuples; + pgrx::pg_sys::SpinLockRelease(&raw mut (*vchordrqshared).mutex); + } + if let Some(reporter) = reporter.as_mut() { + reporter.tuples_done(indtuples); + } + } + }); + } + (VectorKind::Vecf16, DistanceKind::L2) => { + heap.traverse(true, |(pointer, vector): (_, VectOwned)| { + algorithm::insert::, L2>>( + index.clone(), + pointer, + RandomProject::project(vector.as_borrowed()), + ); + unsafe { + let indtuples; + { + pgrx::pg_sys::SpinLockAcquire(&raw mut (*vchordrqshared).mutex); + (*vchordrqshared).indtuples += 1; + indtuples = (*vchordrqshared).indtuples; + pgrx::pg_sys::SpinLockRelease(&raw mut (*vchordrqshared).mutex); + } + if let Some(reporter) = reporter.as_mut() { + reporter.tuples_done(indtuples); + } + } + }); + } + (VectorKind::Vecf16, DistanceKind::Dot) => { + heap.traverse(true, |(pointer, vector): (_, VectOwned)| { + algorithm::insert::, Dot>>( + index.clone(), + pointer, + RandomProject::project(vector.as_borrowed()), + ); + unsafe { + let indtuples; + { + pgrx::pg_sys::SpinLockAcquire(&raw mut (*vchordrqshared).mutex); + (*vchordrqshared).indtuples += 1; + indtuples = (*vchordrqshared).indtuples; + pgrx::pg_sys::SpinLockRelease(&raw mut (*vchordrqshared).mutex); + } + if let Some(reporter) = reporter.as_mut() { + reporter.tuples_done(indtuples); + } + } + }); + } + } + } + } + unsafe { + pgrx::pg_sys::SpinLockAcquire(&raw mut (*vchordrqshared).mutex); + (*vchordrqshared).nparticipantsdone += 1; + pgrx::pg_sys::SpinLockRelease(&raw mut (*vchordrqshared).mutex); + pgrx::pg_sys::ConditionVariableSignal(&raw mut (*vchordrqshared).workersdonecv); + } +} + +#[pgrx::pg_guard] +pub unsafe extern 
"C" fn ambuildempty(_index_relation: pgrx::pg_sys::Relation) { + pgrx::error!("Unlogged indexes are not supported."); +} + +unsafe fn options( + index_relation: pgrx::pg_sys::Relation, +) -> (VectorOptions, VchordrqIndexingOptions) { + let att = unsafe { &mut *(*index_relation).rd_att }; + let atts = unsafe { att.attrs.as_slice(att.natts as _) }; + if atts.is_empty() { + pgrx::error!("indexing on no columns is not supported"); + } + if atts.len() != 1 { + pgrx::error!("multicolumn index is not supported"); + } + // get dims + let typmod = Typmod::parse_from_i32(atts[0].type_mod()).unwrap(); + let dims = if let Some(dims) = typmod.dims() { + dims.get() + } else { + pgrx::error!( + "Dimensions type modifier of a vector column is needed for building the index." + ); + }; + // get v, d + let opfamily = unsafe { opfamily(index_relation) }; + let vector = VectorOptions { + dims, + v: opfamily.vector_kind(), + d: opfamily.distance_kind(), + }; + // get indexing, segment, optimizing + let rabitq = 'rabitq: { + let reloption = unsafe { (*index_relation).rd_options as *const Reloption }; + if reloption.is_null() || unsafe { (*reloption).options == 0 } { + break 'rabitq Default::default(); + } + let s = unsafe { Reloption::options(reloption) }.to_string_lossy(); + match toml::from_str::(&s) { + Ok(p) => p, + Err(e) => pgrx::error!("failed to parse options: {}", e), + } + }; + (vector, rabitq) +} + +pub fn make_internal_build( + vector_options: VectorOptions, + internal_build: VchordrqInternalBuildOptions, + mut samples: Vec>, +) -> Vec>> { + use std::iter::once; + for sample in samples.iter_mut() { + *sample = crate::index::projection::project(sample); + } + let mut result = Vec::>>::new(); + for w in internal_build.lists.iter().rev().copied().chain(once(1)) { + let means = k_means::k_means( + internal_build.build_threads as _, + || { + pgrx::check_for_interrupts!(); + }, + w as usize, + vector_options.dims as usize, + if let Some(structure) = result.last() { + &structure.means + } else { + &samples + }, + internal_build.spherical_centroids, + 10, + ); + if let Some(structure) = result.last() { + let mut children = vec![Vec::new(); means.len()]; + for i in 0..structure.len() as u32 { + let target = k_means::k_means_lookup(&structure.means[i as usize], &means); + children[target].push(i); + } + let (means, children) = std::iter::zip(means, children) + .filter(|(_, x)| !x.is_empty()) + .unzip::<_, _, Vec<_>, Vec<_>>(); + result.push(Structure { means, children }); + } else { + let children = vec![Vec::new(); means.len()]; + result.push(Structure { means, children }); + } + } + result +} + +pub fn make_external_build( + vector_options: VectorOptions, + _opfamily: Opfamily, + external_build: VchordrqExternalBuildOptions, +) -> Vec>> { + use std::collections::BTreeMap; + let VchordrqExternalBuildOptions { table } = external_build; + let mut parents = BTreeMap::new(); + let mut vectors = BTreeMap::new(); + pgrx::spi::Spi::connect(|client| { + use crate::datatype::memory_vector::VectorOutput; + use pgrx::pg_sys::panic::ErrorReportable; + use vector::VectorBorrowed; + let schema_query = "SELECT n.nspname::TEXT + FROM pg_catalog.pg_extension e + LEFT JOIN pg_catalog.pg_namespace n ON n.oid = e.extnamespace + WHERE e.extname = 'vector';"; + let pgvector_schema: String = client + .select(schema_query, None, None) + .unwrap_or_report() + .first() + .get_by_name("nspname") + .expect("external build: cannot get schema of pgvector") + .expect("external build: cannot get schema of pgvector"); + let dump_query = + 
format!("SELECT id, parent, vector::{pgvector_schema}.vector FROM {table};"); + let centroids = client.select(&dump_query, None, None).unwrap_or_report(); + for row in centroids { + let id: Option = row.get_by_name("id").unwrap(); + let parent: Option = row.get_by_name("parent").unwrap(); + let vector: Option = row.get_by_name("vector").unwrap(); + let id = id.expect("external build: id could not be NULL"); + let vector = vector.expect("external build: vector could not be NULL"); + let pop = parents.insert(id, parent); + if pop.is_some() { + pgrx::error!( + "external build: there are at least two lines have same id, id = {id}" + ); + } + if vector_options.dims != vector.as_borrowed().dims() { + pgrx::error!("external build: incorrect dimension, id = {id}"); + } + vectors.insert( + id, + crate::index::projection::project(vector.as_borrowed().slice()), + ); + } + }); + if parents.len() >= 2 && parents.values().all(|x| x.is_none()) { + // if there are more than one vertexs and no edges, + // assume there is an implicit root + let n = parents.len(); + let mut result = Vec::new(); + result.push(Structure { + means: vectors.values().cloned().collect::>(), + children: vec![Vec::new(); n], + }); + result.push(Structure { + means: vec![{ + // compute the vector on root, without normalizing it + let mut sum = vec![0.0f32; vector_options.dims as _]; + for vector in vectors.values() { + f32::vector_add_inplace(&mut sum, vector); + } + f32::vector_mul_scalar_inplace(&mut sum, 1.0 / n as f32); + sum + }], + children: vec![(0..n as u32).collect()], + }); + return result; + } + let mut children = parents + .keys() + .map(|x| (*x, Vec::new())) + .collect::>(); + let mut root = None; + for (&id, &parent) in parents.iter() { + if let Some(parent) = parent { + if let Some(parent) = children.get_mut(&parent) { + parent.push(id); + } else { + pgrx::error!("external build: parent does not exist, id = {id}, parent = {parent}"); + } + } else { + if let Some(root) = root { + pgrx::error!("external build: two root, id = {root}, id = {id}"); + } else { + root = Some(id); + } + } + } + let Some(root) = root else { + pgrx::error!("external build: there are no root"); + }; + let mut heights = BTreeMap::<_, _>::new(); + fn dfs_for_heights( + heights: &mut BTreeMap>, + children: &BTreeMap>, + u: i32, + ) { + if heights.contains_key(&u) { + pgrx::error!("external build: detect a cycle, id = {u}"); + } + heights.insert(u, None); + let mut height = None; + for &v in children[&u].iter() { + dfs_for_heights(heights, children, v); + let new = heights[&v].unwrap() + 1; + if let Some(height) = height { + if height != new { + pgrx::error!("external build: two heights, id = {u}"); + } + } else { + height = Some(new); + } + } + if height.is_none() { + height = Some(1); + } + heights.insert(u, height); + } + dfs_for_heights(&mut heights, &children, root); + let heights = heights + .into_iter() + .map(|(k, v)| (k, v.expect("not a connected graph"))) + .collect::>(); + if !(1..=8).contains(&(heights[&root] - 1)) { + pgrx::error!( + "external build: unexpected tree height, height = {}", + heights[&root] + ); + } + let mut cursors = vec![0_u32; 1 + heights[&root] as usize]; + let mut labels = BTreeMap::new(); + for id in parents.keys().copied() { + let height = heights[&id]; + let cursor = cursors[height as usize]; + labels.insert(id, (height, cursor)); + cursors[height as usize] += 1; + } + fn extract( + height: u32, + labels: &BTreeMap, + vectors: &BTreeMap>, + children: &BTreeMap>, + ) -> (Vec>, Vec>) { + labels + .iter() + 
.filter(|(_, (h, _))| *h == height) + .map(|(id, _)| { + ( + vectors[id].clone(), + children[id].iter().map(|id| labels[id].1).collect(), + ) + }) + .unzip() + } + let mut result = Vec::new(); + for height in 1..=heights[&root] { + let (means, children) = extract(height, &labels, &vectors, &children); + result.push(Structure { means, children }); + } + result +} + +pub fn map_structures(x: Vec>, f: impl Fn(T) -> U + Copy) -> Vec> { + x.into_iter() + .map(|Structure { means, children }| Structure { + means: means.into_iter().map(f).collect(), + children, + }) + .collect() +} + +pub trait InternalBuild: VectorOwned { + fn build_to_vecf32(vector: Self::Borrowed<'_>) -> Vec; + + fn build_from_vecf32(x: &[f32]) -> Self; +} + +impl InternalBuild for VectOwned { + fn build_to_vecf32(vector: Self::Borrowed<'_>) -> Vec { + vector.slice().to_vec() + } + + fn build_from_vecf32(x: &[f32]) -> Self { + Self::new(x.to_vec()) + } +} + +impl InternalBuild for VectOwned { + fn build_to_vecf32(vector: Self::Borrowed<'_>) -> Vec { + f16::vector_to_f32(vector.slice()) + } + + fn build_from_vecf32(x: &[f32]) -> Self { + Self::new(f16::vector_from_f32(x)) + } +} + +struct CachingRelation<'a, R> { + cache: vchordrq_cached::VchordrqCachedReader1<'a>, + relation: R, +} + +impl Clone for CachingRelation<'_, R> { + fn clone(&self) -> Self { + Self { + cache: self.cache, + relation: self.relation.clone(), + } + } +} + +enum CachingRelationReadGuard<'a, G: Deref> { + Wrapping(G), + Cached(u32, &'a G::Target), +} + +impl PageGuard for CachingRelationReadGuard<'_, G> { + fn id(&self) -> u32 { + match self { + CachingRelationReadGuard::Wrapping(x) => x.id(), + CachingRelationReadGuard::Cached(id, _) => *id, + } + } +} + +impl Deref for CachingRelationReadGuard<'_, G> { + type Target = G::Target; + + fn deref(&self) -> &Self::Target { + match self { + CachingRelationReadGuard::Wrapping(x) => x, + CachingRelationReadGuard::Cached(_, page) => page, + } + } +} + +impl> RelationRead for CachingRelation<'_, R> { + type Page = R::Page; + + type ReadGuard<'a> + = CachingRelationReadGuard<'a, R::ReadGuard<'a>> + where + Self: 'a; + + fn read(&self, id: u32) -> Self::ReadGuard<'_> { + if let Some(x) = self.cache.get(id) { + CachingRelationReadGuard::Cached(id, x) + } else { + CachingRelationReadGuard::Wrapping(self.relation.read(id)) + } + } +} + +impl> RelationWrite for CachingRelation<'_, R> { + type WriteGuard<'a> + = R::WriteGuard<'a> + where + Self: 'a; + + fn write(&self, id: u32, tracking_freespace: bool) -> Self::WriteGuard<'_> { + self.relation.write(id, tracking_freespace) + } + + fn extend(&self, tracking_freespace: bool) -> Self::WriteGuard<'_> { + self.relation.extend(tracking_freespace) + } + + fn search(&self, freespace: usize) -> Option> { + self.relation.search(freespace) + } +} diff --git a/src/index/am/am_scan.rs b/src/index/am/am_scan.rs new file mode 100644 index 00000000..02b13b6f --- /dev/null +++ b/src/index/am/am_scan.rs @@ -0,0 +1,692 @@ +use crate::index::am::pointer_to_ctid; +use crate::index::gucs::{epsilon, max_scan_tuples, probes}; +use crate::index::opclass::{Opfamily, Sphere, opfamily}; +use crate::index::projection::RandomProject; +use crate::index::storage::PostgresRelation; +use algorithm::RerankMethod; +use algorithm::operator::{Dot, L2, Op, Vector}; +use algorithm::types::*; +use half::f16; +use pgrx::pg_sys::Datum; +use std::cell::LazyCell; +use std::num::NonZeroU64; +use vector::VectorOwned; +use vector::vect::VectOwned; + +#[pgrx::pg_guard] +pub unsafe extern "C" fn ambeginscan( + 
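// Illustrative sketch, not part of the patch: CachingRelation above serves immutable page
// reads from the snapshot captured at build time and only falls back to a real buffer read on
// a cache miss; the enum guard lets both paths hand out the same &Page through Deref. A
// simplified shape of that read-through pattern (ReadGuard/read are assumed names standing in
// for the patch's CachingRelationReadGuard and RelationRead::read):
use std::ops::Deref;

enum ReadGuard<'a, G: Deref> {
    Cached(&'a G::Target),
    Loaded(G),
}

impl<G: Deref> Deref for ReadGuard<'_, G> {
    type Target = G::Target;
    fn deref(&self) -> &Self::Target {
        match self {
            ReadGuard::Cached(x) => x,
            ReadGuard::Loaded(g) => g,
        }
    }
}

fn read<'a, G: Deref>(cache: Option<&'a G::Target>, load: impl FnOnce() -> G) -> ReadGuard<'a, G> {
    match cache {
        Some(page) => ReadGuard::Cached(page),
        None => ReadGuard::Loaded(load()),
    }
}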
index_relation: pgrx::pg_sys::Relation, + n_keys: std::os::raw::c_int, + n_orderbys: std::os::raw::c_int, +) -> pgrx::pg_sys::IndexScanDesc { + use pgrx::memcxt::PgMemoryContexts::CurrentMemoryContext; + + let scan = unsafe { pgrx::pg_sys::RelationGetIndexScan(index_relation, n_keys, n_orderbys) }; + unsafe { + let scanner = Scanner { + opfamily: opfamily(index_relation), + scanning: Scanning::Empty {}, + }; + (*scan).opaque = CurrentMemoryContext.leak_and_drop_on_delete(scanner).cast(); + } + scan +} + +#[pgrx::pg_guard] +pub unsafe extern "C" fn amrescan( + scan: pgrx::pg_sys::IndexScanDesc, + keys: pgrx::pg_sys::ScanKey, + _n_keys: std::os::raw::c_int, + orderbys: pgrx::pg_sys::ScanKey, + _n_orderbys: std::os::raw::c_int, +) { + unsafe { + if !keys.is_null() && (*scan).numberOfKeys > 0 { + std::ptr::copy(keys, (*scan).keyData, (*scan).numberOfKeys as _); + } + if !orderbys.is_null() && (*scan).numberOfOrderBys > 0 { + std::ptr::copy(orderbys, (*scan).orderByData, (*scan).numberOfOrderBys as _); + } + let opfamily = opfamily((*scan).indexRelation); + let (orderbys, spheres) = { + let mut orderbys = Vec::new(); + let mut spheres = Vec::new(); + if (*scan).numberOfOrderBys == 0 && (*scan).numberOfKeys == 0 { + pgrx::error!( + "vector search with no WHERE clause and no ORDER BY clause is not supported" + ); + } + for i in 0..(*scan).numberOfOrderBys { + let data = (*scan).orderByData.add(i as usize); + let value = (*data).sk_argument; + let is_null = ((*data).sk_flags & pgrx::pg_sys::SK_ISNULL as i32) != 0; + match (*data).sk_strategy { + 1 => orderbys.push(opfamily.input_vector(value, is_null)), + _ => unreachable!(), + } + } + for i in 0..(*scan).numberOfKeys { + let data = (*scan).keyData.add(i as usize); + let value = (*data).sk_argument; + let is_null = ((*data).sk_flags & pgrx::pg_sys::SK_ISNULL as i32) != 0; + match (*data).sk_strategy { + 2 => spheres.push(opfamily.input_sphere(value, is_null)), + _ => unreachable!(), + } + } + (orderbys, spheres) + }; + let scanning; + if let Some((vector, threshold, recheck)) = scanner_build(orderbys, spheres) { + scanning = Scanning::Initial { + vector, + threshold, + recheck, + }; + } else { + scanning = Scanning::Empty {}; + }; + let scanner = &mut *(*scan).opaque.cast::(); + scanner.scanning = scanning; + } +} + +#[pgrx::pg_guard] +pub unsafe extern "C" fn amgettuple( + scan: pgrx::pg_sys::IndexScanDesc, + direction: pgrx::pg_sys::ScanDirection::Type, +) -> bool { + if direction != pgrx::pg_sys::ScanDirection::ForwardScanDirection { + pgrx::error!("vector search without a forward scan direction is not supported"); + } + // https://www.postgresql.org/docs/current/index-locking.html + // If heap entries referenced physical pointers are deleted before + // they are consumed by PostgreSQL, PostgreSQL will received wrong + // physical pointers: no rows or irreverent rows are referenced. 
+ if unsafe { (*(*scan).xs_snapshot).snapshot_type } != pgrx::pg_sys::SnapshotType::SNAPSHOT_MVCC + { + pgrx::error!("scanning with a non-MVCC-compliant snapshot is not supported"); + } + let scanner = unsafe { (*scan).opaque.cast::().as_mut().unwrap_unchecked() }; + let relation = unsafe { PostgresRelation::new((*scan).indexRelation) }; + if let Some((pointer, recheck)) = scanner_next( + scanner, + relation, + LazyCell::new(move || unsafe { + let index_info = pgrx::pg_sys::BuildIndexInfo((*scan).indexRelation); + let heap_relation = (*scan).heapRelation; + let estate = Scopeguard::new(pgrx::pg_sys::CreateExecutorState(), |estate| { + pgrx::pg_sys::FreeExecutorState(estate); + }); + let econtext = pgrx::pg_sys::MakePerTupleExprContext(*estate.get()); + move |opfamily: Opfamily, payload| { + let slot = Scopeguard::new( + pgrx::pg_sys::table_slot_create(heap_relation, std::ptr::null_mut()), + |slot| pgrx::pg_sys::ExecDropSingleTupleTableSlot(slot), + ); + (*econtext).ecxt_scantuple = *slot.get(); + let table_am = (*heap_relation).rd_tableam; + let fetch_row_version = (*table_am).tuple_fetch_row_version.unwrap(); + let mut ctid = pointer_to_ctid(payload); + if !fetch_row_version(heap_relation, &mut ctid, (*scan).xs_snapshot, *slot.get()) { + return None; + } + let mut values = [Datum::from(0); 32]; + let mut is_null = [true; 32]; + pgrx::pg_sys::FormIndexDatum( + index_info, + *slot.get(), + *estate.get(), + values.as_mut_ptr(), + is_null.as_mut_ptr(), + ); + opfamily.input_vector(values[0], is_null[0]) + } + }), + ) { + let ctid = pointer_to_ctid(pointer); + unsafe { + (*scan).xs_heaptid = ctid; + (*scan).xs_recheckorderby = false; + (*scan).xs_recheck = recheck; + } + true + } else { + false + } +} + +#[pgrx::pg_guard] +pub unsafe extern "C" fn amendscan(scan: pgrx::pg_sys::IndexScanDesc) { + let scanner = unsafe { &mut *(*scan).opaque.cast::() }; + scanner.scanning = Scanning::Empty {}; +} + +struct Scanner { + opfamily: Opfamily, + scanning: Scanning, +} + +enum Scanning { + Initial { + vector: OwnedVector, + threshold: Option, + recheck: bool, + }, + Vbase { + vbase: Box>, + recheck: bool, + }, + Empty {}, +} + +fn scanner_build( + orderbys: Vec>, + spheres: Vec>>, +) -> Option<(OwnedVector, Option, bool)> { + let mut vector = None; + let mut threshold = None; + let mut recheck = false; + for orderby_vector in orderbys.into_iter().flatten() { + if vector.is_none() { + vector = Some(orderby_vector); + } else { + pgrx::error!("vector search with multiple vectors is not supported"); + } + } + for Sphere { center, radius } in spheres.into_iter().flatten() { + if vector.is_none() { + (vector, threshold) = (Some(center), Some(radius)); + } else { + recheck = true; + } + } + Some((vector?, threshold, recheck)) +} + +fn scanner_next( + scanner: &mut Scanner, + relation: PostgresRelation, + fetch: LazyCell, +) -> Option<(NonZeroU64, bool)> +where + F: Fn(Opfamily, NonZeroU64) -> Option + 'static, + I: FnOnce() -> F + 'static, +{ + if let Scanning::Initial { + vector, + threshold, + recheck, + } = &scanner.scanning + { + let opfamily = scanner.opfamily; + let vector = vector.clone(); + let threshold = *threshold; + let recheck = *recheck; + let max_scan_tuples = max_scan_tuples(); + let probes = probes(); + let epsilon = epsilon(); + scanner.scanning = Scanning::Vbase { + vbase: match (opfamily.vector_kind(), opfamily.distance_kind()) { + (VectorKind::Vecf32, DistanceKind::L2) => { + let vector = RandomProject::project( + VectOwned::::from_owned(vector.clone()).as_borrowed(), + ); + let (method, 
results) = algorithm::search::, L2>>( + relation.clone(), + vector.clone(), + probes, + epsilon, + ); + match (method, max_scan_tuples, threshold) { + (RerankMethod::Index, None, None) => { + let vbase = algorithm::rerank_index::, L2>>( + relation, vector, results, + ) + .map(move |(distance, payload)| (opfamily.output(distance), payload)); + Box::new(vbase.fuse()) as Box> + } + (RerankMethod::Index, None, Some(threshold)) => { + let vbase = algorithm::rerank_index::, L2>>( + relation, vector, results, + ) + .map(move |(distance, payload)| (opfamily.output(distance), payload)); + Box::new(vbase.take_while(move |(x, _)| *x < threshold)) + } + (RerankMethod::Index, Some(max_scan_tuples), None) => { + let vbase = algorithm::rerank_index::, L2>>( + relation, vector, results, + ) + .map(move |(distance, payload)| (opfamily.output(distance), payload)); + Box::new(vbase.take(max_scan_tuples as _)) + } + (RerankMethod::Index, Some(max_scan_tuples), Some(threshold)) => { + let vbase = algorithm::rerank_index::, L2>>( + relation, vector, results, + ) + .map(move |(distance, payload)| (opfamily.output(distance), payload)); + Box::new( + vbase + .take_while(move |(x, _)| *x < threshold) + .take(max_scan_tuples as _), + ) + } + (RerankMethod::Heap, None, None) => { + let vbase = algorithm::rerank_heap::, L2>, _>( + vector, + results, + move |payload| { + Some(RandomProject::project( + VectOwned::::from_owned(fetch(opfamily, payload)?) + .as_borrowed(), + )) + }, + ) + .map(move |(distance, payload)| (opfamily.output(distance), payload)); + Box::new(vbase.fuse()) as Box> + } + (RerankMethod::Heap, None, Some(threshold)) => { + let vbase = algorithm::rerank_heap::, L2>, _>( + vector, + results, + move |payload| { + Some(RandomProject::project( + VectOwned::::from_owned(fetch(opfamily, payload)?) + .as_borrowed(), + )) + }, + ) + .map(move |(distance, payload)| (opfamily.output(distance), payload)); + Box::new(vbase.take_while(move |(x, _)| *x < threshold)) + } + (RerankMethod::Heap, Some(max_scan_tuples), None) => { + let vbase = algorithm::rerank_heap::, L2>, _>( + vector, + results, + move |payload| { + Some(RandomProject::project( + VectOwned::::from_owned(fetch(opfamily, payload)?) + .as_borrowed(), + )) + }, + ) + .map(move |(distance, payload)| (opfamily.output(distance), payload)); + Box::new(vbase.take(max_scan_tuples as _)) + } + (RerankMethod::Heap, Some(max_scan_tuples), Some(threshold)) => { + let vbase = algorithm::rerank_heap::, L2>, _>( + vector, + results, + move |payload| { + Some(RandomProject::project( + VectOwned::::from_owned(fetch(opfamily, payload)?) 
+ .as_borrowed(), + )) + }, + ) + .map(move |(distance, payload)| (opfamily.output(distance), payload)); + Box::new( + vbase + .take_while(move |(x, _)| *x < threshold) + .take(max_scan_tuples as _), + ) + } + } + } + (VectorKind::Vecf32, DistanceKind::Dot) => { + let vector = RandomProject::project( + VectOwned::::from_owned(vector.clone()).as_borrowed(), + ); + let (method, results) = algorithm::search::, Dot>>( + relation.clone(), + vector.clone(), + probes, + epsilon, + ); + match (method, max_scan_tuples, threshold) { + (RerankMethod::Index, None, None) => { + let vbase = algorithm::rerank_index::, Dot>>( + relation, vector, results, + ) + .map(move |(distance, payload)| (opfamily.output(distance), payload)); + Box::new(vbase.fuse()) as Box> + } + (RerankMethod::Index, None, Some(threshold)) => { + let vbase = algorithm::rerank_index::, Dot>>( + relation, vector, results, + ) + .map(move |(distance, payload)| (opfamily.output(distance), payload)); + Box::new(vbase.take_while(move |(x, _)| *x < threshold)) + } + (RerankMethod::Index, Some(max_scan_tuples), None) => { + let vbase = algorithm::rerank_index::, Dot>>( + relation, vector, results, + ) + .map(move |(distance, payload)| (opfamily.output(distance), payload)); + Box::new(vbase.take(max_scan_tuples as _)) + } + (RerankMethod::Index, Some(max_scan_tuples), Some(threshold)) => { + let vbase = algorithm::rerank_index::, Dot>>( + relation, vector, results, + ) + .map(move |(distance, payload)| (opfamily.output(distance), payload)); + Box::new( + vbase + .take_while(move |(x, _)| *x < threshold) + .take(max_scan_tuples as _), + ) + } + (RerankMethod::Heap, None, None) => { + let vbase = algorithm::rerank_heap::, Dot>, _>( + vector, + results, + move |payload| { + Some(RandomProject::project( + VectOwned::::from_owned(fetch(opfamily, payload)?) + .as_borrowed(), + )) + }, + ) + .map(move |(distance, payload)| (opfamily.output(distance), payload)); + Box::new(vbase.fuse()) as Box> + } + (RerankMethod::Heap, None, Some(threshold)) => { + let vbase = algorithm::rerank_heap::, Dot>, _>( + vector, + results, + move |payload| { + Some(RandomProject::project( + VectOwned::::from_owned(fetch(opfamily, payload)?) + .as_borrowed(), + )) + }, + ) + .map(move |(distance, payload)| (opfamily.output(distance), payload)); + Box::new(vbase.take_while(move |(x, _)| *x < threshold)) + } + (RerankMethod::Heap, Some(max_scan_tuples), None) => { + let vbase = algorithm::rerank_heap::, Dot>, _>( + vector, + results, + move |payload| { + Some(RandomProject::project( + VectOwned::::from_owned(fetch(opfamily, payload)?) + .as_borrowed(), + )) + }, + ) + .map(move |(distance, payload)| (opfamily.output(distance), payload)); + Box::new(vbase.take(max_scan_tuples as _)) + } + (RerankMethod::Heap, Some(max_scan_tuples), Some(threshold)) => { + let vbase = algorithm::rerank_heap::, Dot>, _>( + vector, + results, + move |payload| { + Some(RandomProject::project( + VectOwned::::from_owned(fetch(opfamily, payload)?) 
+ .as_borrowed(), + )) + }, + ) + .map(move |(distance, payload)| (opfamily.output(distance), payload)); + Box::new( + vbase + .take_while(move |(x, _)| *x < threshold) + .take(max_scan_tuples as _), + ) + } + } + } + (VectorKind::Vecf16, DistanceKind::L2) => { + let vector = RandomProject::project( + VectOwned::::from_owned(vector.clone()).as_borrowed(), + ); + let (method, results) = algorithm::search::, L2>>( + relation.clone(), + vector.clone(), + probes, + epsilon, + ); + match (method, max_scan_tuples, threshold) { + (RerankMethod::Index, None, None) => { + let vbase = algorithm::rerank_index::, L2>>( + relation, vector, results, + ) + .map(move |(distance, payload)| (opfamily.output(distance), payload)); + Box::new(vbase.fuse()) as Box> + } + (RerankMethod::Index, None, Some(threshold)) => { + let vbase = algorithm::rerank_index::, L2>>( + relation, vector, results, + ) + .map(move |(distance, payload)| (opfamily.output(distance), payload)); + Box::new(vbase.take_while(move |(x, _)| *x < threshold)) + } + (RerankMethod::Index, Some(max_scan_tuples), None) => { + let vbase = algorithm::rerank_index::, L2>>( + relation, vector, results, + ) + .map(move |(distance, payload)| (opfamily.output(distance), payload)); + Box::new(vbase.take(max_scan_tuples as _)) + } + (RerankMethod::Index, Some(max_scan_tuples), Some(threshold)) => { + let vbase = algorithm::rerank_index::, L2>>( + relation, vector, results, + ) + .map(move |(distance, payload)| (opfamily.output(distance), payload)); + Box::new( + vbase + .take_while(move |(x, _)| *x < threshold) + .take(max_scan_tuples as _), + ) + } + (RerankMethod::Heap, None, None) => { + let vbase = algorithm::rerank_heap::, L2>, _>( + vector, + results, + move |payload| { + Some(RandomProject::project( + VectOwned::::from_owned(fetch(opfamily, payload)?) + .as_borrowed(), + )) + }, + ) + .map(move |(distance, payload)| (opfamily.output(distance), payload)); + Box::new(vbase.fuse()) as Box> + } + (RerankMethod::Heap, None, Some(threshold)) => { + let vbase = algorithm::rerank_heap::, L2>, _>( + vector, + results, + move |payload| { + Some(RandomProject::project( + VectOwned::::from_owned(fetch(opfamily, payload)?) + .as_borrowed(), + )) + }, + ) + .map(move |(distance, payload)| (opfamily.output(distance), payload)); + Box::new(vbase.take_while(move |(x, _)| *x < threshold)) + } + (RerankMethod::Heap, Some(max_scan_tuples), None) => { + let vbase = algorithm::rerank_heap::, L2>, _>( + vector, + results, + move |payload| { + Some(RandomProject::project( + VectOwned::::from_owned(fetch(opfamily, payload)?) + .as_borrowed(), + )) + }, + ) + .map(move |(distance, payload)| (opfamily.output(distance), payload)); + Box::new(vbase.take(max_scan_tuples as _)) + } + (RerankMethod::Heap, Some(max_scan_tuples), Some(threshold)) => { + let vbase = algorithm::rerank_heap::, L2>, _>( + vector, + results, + move |payload| { + Some(RandomProject::project( + VectOwned::::from_owned(fetch(opfamily, payload)?) 
+ .as_borrowed(), + )) + }, + ) + .map(move |(distance, payload)| (opfamily.output(distance), payload)); + Box::new( + vbase + .take_while(move |(x, _)| *x < threshold) + .take(max_scan_tuples as _), + ) + } + } + } + (VectorKind::Vecf16, DistanceKind::Dot) => { + let vector = RandomProject::project( + VectOwned::::from_owned(vector.clone()).as_borrowed(), + ); + let (method, results) = algorithm::search::, Dot>>( + relation.clone(), + vector.clone(), + probes, + epsilon, + ); + match (method, max_scan_tuples, threshold) { + (RerankMethod::Index, None, None) => { + let vbase = algorithm::rerank_index::, Dot>>( + relation, vector, results, + ) + .map(move |(distance, payload)| (opfamily.output(distance), payload)); + Box::new(vbase.fuse()) as Box> + } + (RerankMethod::Index, None, Some(threshold)) => { + let vbase = algorithm::rerank_index::, Dot>>( + relation, vector, results, + ) + .map(move |(distance, payload)| (opfamily.output(distance), payload)); + Box::new(vbase.take_while(move |(x, _)| *x < threshold)) + } + (RerankMethod::Index, Some(max_scan_tuples), None) => { + let vbase = algorithm::rerank_index::, Dot>>( + relation, vector, results, + ) + .map(move |(distance, payload)| (opfamily.output(distance), payload)); + Box::new(vbase.take(max_scan_tuples as _)) + } + (RerankMethod::Index, Some(max_scan_tuples), Some(threshold)) => { + let vbase = algorithm::rerank_index::, Dot>>( + relation, vector, results, + ) + .map(move |(distance, payload)| (opfamily.output(distance), payload)); + Box::new( + vbase + .take_while(move |(x, _)| *x < threshold) + .take(max_scan_tuples as _), + ) + } + (RerankMethod::Heap, None, None) => { + let vbase = algorithm::rerank_heap::, Dot>, _>( + vector, + results, + move |payload| { + Some(RandomProject::project( + VectOwned::::from_owned(fetch(opfamily, payload)?) + .as_borrowed(), + )) + }, + ) + .map(move |(distance, payload)| (opfamily.output(distance), payload)); + Box::new(vbase.fuse()) as Box> + } + (RerankMethod::Heap, None, Some(threshold)) => { + let vbase = algorithm::rerank_heap::, Dot>, _>( + vector, + results, + move |payload| { + Some(RandomProject::project( + VectOwned::::from_owned(fetch(opfamily, payload)?) + .as_borrowed(), + )) + }, + ) + .map(move |(distance, payload)| (opfamily.output(distance), payload)); + Box::new(vbase.take_while(move |(x, _)| *x < threshold)) + } + (RerankMethod::Heap, Some(max_scan_tuples), None) => { + let vbase = algorithm::rerank_heap::, Dot>, _>( + vector, + results, + move |payload| { + Some(RandomProject::project( + VectOwned::::from_owned(fetch(opfamily, payload)?) + .as_borrowed(), + )) + }, + ) + .map(move |(distance, payload)| (opfamily.output(distance), payload)); + Box::new(vbase.take(max_scan_tuples as _)) + } + (RerankMethod::Heap, Some(max_scan_tuples), Some(threshold)) => { + let vbase = algorithm::rerank_heap::, Dot>, _>( + vector, + results, + move |payload| { + Some(RandomProject::project( + VectOwned::::from_owned(fetch(opfamily, payload)?) + .as_borrowed(), + )) + }, + ) + .map(move |(distance, payload)| (opfamily.output(distance), payload)); + Box::new( + vbase + .take_while(move |(x, _)| *x < threshold) + .take(max_scan_tuples as _), + ) + } + } + } + }, + recheck, + }; + } + match &mut scanner.scanning { + Scanning::Initial { .. 
} => unreachable!(), + Scanning::Vbase { vbase, recheck } => vbase.next().map(|(_, x)| (x, *recheck)), + Scanning::Empty {} => None, + } +} + +struct Scopeguard +where + T: Copy, + F: FnMut(T), +{ + t: T, + f: F, +} + +impl Scopeguard +where + T: Copy, + F: FnMut(T), +{ + fn new(t: T, f: F) -> Self { + Scopeguard { t, f } + } + fn get(&self) -> &T { + &self.t + } +} + +impl Drop for Scopeguard +where + T: Copy, + F: FnMut(T), +{ + fn drop(&mut self) { + (self.f)(self.t); + } +} diff --git a/src/index/am/mod.rs b/src/index/am/mod.rs new file mode 100644 index 00000000..dcac2314 --- /dev/null +++ b/src/index/am/mod.rs @@ -0,0 +1,324 @@ +pub mod am_build; +pub mod am_scan; + +use crate::index::projection::RandomProject; +use crate::index::storage::PostgresRelation; +use algorithm::operator::{Dot, L2, Op, Vector}; +use algorithm::types::*; +use half::f16; +use pgrx::datum::Internal; +use pgrx::pg_sys::Datum; +use std::ffi::CStr; +use std::num::NonZeroU64; +use std::sync::OnceLock; +use vector::VectorOwned; +use vector::vect::VectOwned; + +#[repr(C)] +struct Reloption { + vl_len_: i32, + pub options: i32, +} + +impl Reloption { + unsafe fn options<'a>(this: *const Self) -> &'a CStr { + unsafe { + let ptr = this + .cast::() + .add((&raw const (*this).options).read() as _); + CStr::from_ptr(ptr.cast()) + } + } +} + +const TABLE: &[pgrx::pg_sys::relopt_parse_elt] = &[pgrx::pg_sys::relopt_parse_elt { + optname: c"options".as_ptr(), + opttype: pgrx::pg_sys::relopt_type::RELOPT_TYPE_STRING, + offset: std::mem::offset_of!(Reloption, options) as i32, +}]; + +static RELOPT_KIND: OnceLock = OnceLock::new(); + +pub fn init() { + RELOPT_KIND.get_or_init(|| { + let kind; + unsafe { + kind = pgrx::pg_sys::add_reloption_kind(); + pgrx::pg_sys::add_string_reloption( + kind, + c"options".as_ptr(), + c"Vector index options, represented as a TOML string.".as_ptr(), + c"".as_ptr(), + None, + pgrx::pg_sys::AccessExclusiveLock as pgrx::pg_sys::LOCKMODE, + ); + } + kind + }); +} + +#[pgrx::pg_extern(sql = "")] +fn _vchordrq_amhandler(_fcinfo: pgrx::pg_sys::FunctionCallInfo) -> Internal { + type T = pgrx::pg_sys::IndexAmRoutine; + unsafe { + let index_am_routine = pgrx::pg_sys::palloc0(size_of::()) as *mut T; + index_am_routine.write(AM_HANDLER); + Internal::from(Some(Datum::from(index_am_routine))) + } +} + +const AM_HANDLER: pgrx::pg_sys::IndexAmRoutine = const { + let mut am_routine = unsafe { std::mem::zeroed::() }; + + am_routine.type_ = pgrx::pg_sys::NodeTag::T_IndexAmRoutine; + + am_routine.amsupport = 1; + am_routine.amcanorderbyop = true; + + #[cfg(feature = "pg17")] + { + am_routine.amcanbuildparallel = true; + } + + // Index access methods that set `amoptionalkey` to `false` + // must index all tuples, even if the first column is `NULL`. + // However, PostgreSQL does not generate a path if there is no + // index clauses, even if there is a `ORDER BY` clause. + // So we have to set it to `true` and set costs of every path + // for vector index scans without `ORDER BY` clauses a large number + // and throw errors if someone really wants such a path. 
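+    // Concretely (descriptive comment, not in the original patch): `amcostestimate`
+    // below returns `f64::MAX` startup and total costs when a path has neither index
+    // clauses nor order-by clauses, and `amrescan` raises an error if such a scan is
+    // executed anyway.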
+ am_routine.amoptionalkey = true; + + am_routine.amvalidate = Some(amvalidate); + am_routine.amoptions = Some(amoptions); + am_routine.amcostestimate = Some(amcostestimate); + + am_routine.ambuild = Some(am_build::ambuild); + am_routine.ambuildempty = Some(am_build::ambuildempty); + am_routine.aminsert = Some(aminsert); + am_routine.ambulkdelete = Some(ambulkdelete); + am_routine.amvacuumcleanup = Some(amvacuumcleanup); + + am_routine.ambeginscan = Some(am_scan::ambeginscan); + am_routine.amrescan = Some(am_scan::amrescan); + am_routine.amgettuple = Some(am_scan::amgettuple); + am_routine.amendscan = Some(am_scan::amendscan); + + am_routine +}; + +#[pgrx::pg_guard] +pub unsafe extern "C" fn amvalidate(_opclass_oid: pgrx::pg_sys::Oid) -> bool { + true +} + +#[pgrx::pg_guard] +pub unsafe extern "C" fn amoptions(reloptions: Datum, validate: bool) -> *mut pgrx::pg_sys::bytea { + let relopt_kind = RELOPT_KIND.get().copied().expect("init is not called"); + let rdopts = unsafe { + pgrx::pg_sys::build_reloptions( + reloptions, + validate, + relopt_kind, + size_of::(), + TABLE.as_ptr(), + TABLE.len() as _, + ) + }; + rdopts as *mut pgrx::pg_sys::bytea +} + +#[pgrx::pg_guard] +pub unsafe extern "C" fn amcostestimate( + _root: *mut pgrx::pg_sys::PlannerInfo, + path: *mut pgrx::pg_sys::IndexPath, + _loop_count: f64, + index_startup_cost: *mut pgrx::pg_sys::Cost, + index_total_cost: *mut pgrx::pg_sys::Cost, + index_selectivity: *mut pgrx::pg_sys::Selectivity, + index_correlation: *mut f64, + index_pages: *mut f64, +) { + unsafe { + if (*path).indexorderbys.is_null() && (*path).indexclauses.is_null() { + *index_startup_cost = f64::MAX; + *index_total_cost = f64::MAX; + *index_selectivity = 0.0; + *index_correlation = 0.0; + *index_pages = 0.0; + return; + } + *index_startup_cost = 0.0; + *index_total_cost = 0.0; + *index_selectivity = 1.0; + *index_correlation = 1.0; + *index_pages = 0.0; + } +} + +#[cfg(feature = "pg13")] +#[pgrx::pg_guard] +pub unsafe extern "C" fn aminsert( + index_relation: pgrx::pg_sys::Relation, + values: *mut Datum, + is_null: *mut bool, + heap_tid: pgrx::pg_sys::ItemPointer, + _heap_relation: pgrx::pg_sys::Relation, + _check_unique: pgrx::pg_sys::IndexUniqueCheck::Type, + _index_info: *mut pgrx::pg_sys::IndexInfo, +) -> bool { + unsafe { aminsertinner(index_relation, values, is_null, heap_tid) } +} + +#[cfg(any(feature = "pg14", feature = "pg15", feature = "pg16", feature = "pg17"))] +#[pgrx::pg_guard] +pub unsafe extern "C" fn aminsert( + index_relation: pgrx::pg_sys::Relation, + values: *mut Datum, + is_null: *mut bool, + heap_tid: pgrx::pg_sys::ItemPointer, + _heap_relation: pgrx::pg_sys::Relation, + _check_unique: pgrx::pg_sys::IndexUniqueCheck::Type, + _index_unchanged: bool, + _index_info: *mut pgrx::pg_sys::IndexInfo, +) -> bool { + unsafe { aminsertinner(index_relation, values, is_null, heap_tid) } +} + +unsafe fn aminsertinner( + index_relation: pgrx::pg_sys::Relation, + values: *mut Datum, + is_null: *mut bool, + heap_tid: pgrx::pg_sys::ItemPointer, +) -> bool { + let opfamily = unsafe { crate::index::opclass::opfamily(index_relation) }; + let index = unsafe { PostgresRelation::new(index_relation) }; + let payload = ctid_to_pointer(unsafe { heap_tid.read() }); + let vector = unsafe { opfamily.input_vector(*values.add(0), *is_null.add(0)) }; + let Some(vector) = vector else { return false }; + match (opfamily.vector_kind(), opfamily.distance_kind()) { + (VectorKind::Vecf32, DistanceKind::L2) => algorithm::insert::, L2>>( + index, + payload, + 
RandomProject::project(VectOwned::::from_owned(vector).as_borrowed()), + ), + (VectorKind::Vecf32, DistanceKind::Dot) => algorithm::insert::, Dot>>( + index, + payload, + RandomProject::project(VectOwned::::from_owned(vector).as_borrowed()), + ), + (VectorKind::Vecf16, DistanceKind::L2) => algorithm::insert::, L2>>( + index, + payload, + RandomProject::project(VectOwned::::from_owned(vector).as_borrowed()), + ), + (VectorKind::Vecf16, DistanceKind::Dot) => algorithm::insert::, Dot>>( + index, + payload, + RandomProject::project(VectOwned::::from_owned(vector).as_borrowed()), + ), + } + false +} + +#[pgrx::pg_guard] +pub unsafe extern "C" fn ambulkdelete( + info: *mut pgrx::pg_sys::IndexVacuumInfo, + stats: *mut pgrx::pg_sys::IndexBulkDeleteResult, + callback: pgrx::pg_sys::IndexBulkDeleteCallback, + callback_state: *mut std::os::raw::c_void, +) -> *mut pgrx::pg_sys::IndexBulkDeleteResult { + let mut stats = stats; + if stats.is_null() { + stats = unsafe { + pgrx::pg_sys::palloc0(size_of::()).cast() + }; + } + let opfamily = unsafe { crate::index::opclass::opfamily((*info).index) }; + let index = unsafe { PostgresRelation::new((*info).index) }; + let check = || unsafe { + pgrx::pg_sys::vacuum_delay_point(); + }; + let callback = callback.expect("null function pointer"); + let callback = |p: NonZeroU64| unsafe { callback(&mut pointer_to_ctid(p), callback_state) }; + match (opfamily.vector_kind(), opfamily.distance_kind()) { + (VectorKind::Vecf32, DistanceKind::L2) => { + algorithm::bulkdelete::, L2>>(index, check, callback); + } + (VectorKind::Vecf32, DistanceKind::Dot) => { + algorithm::bulkdelete::, Dot>>(index, check, callback); + } + (VectorKind::Vecf16, DistanceKind::L2) => { + algorithm::bulkdelete::, L2>>(index, check, callback); + } + (VectorKind::Vecf16, DistanceKind::Dot) => { + algorithm::bulkdelete::, Dot>>(index, check, callback); + } + } + stats +} + +#[pgrx::pg_guard] +pub unsafe extern "C" fn amvacuumcleanup( + info: *mut pgrx::pg_sys::IndexVacuumInfo, + stats: *mut pgrx::pg_sys::IndexBulkDeleteResult, +) -> *mut pgrx::pg_sys::IndexBulkDeleteResult { + let mut stats = stats; + if stats.is_null() { + stats = unsafe { + pgrx::pg_sys::palloc0(size_of::()).cast() + }; + } + let opfamily = unsafe { crate::index::opclass::opfamily((*info).index) }; + let index = unsafe { PostgresRelation::new((*info).index) }; + let check = || unsafe { + pgrx::pg_sys::vacuum_delay_point(); + }; + match (opfamily.vector_kind(), opfamily.distance_kind()) { + (VectorKind::Vecf32, DistanceKind::L2) => { + algorithm::maintain::, L2>>(index, check); + } + (VectorKind::Vecf32, DistanceKind::Dot) => { + algorithm::maintain::, Dot>>(index, check); + } + (VectorKind::Vecf16, DistanceKind::L2) => { + algorithm::maintain::, L2>>(index, check); + } + (VectorKind::Vecf16, DistanceKind::Dot) => { + algorithm::maintain::, Dot>>(index, check); + } + } + stats +} + +const fn pointer_to_ctid(pointer: NonZeroU64) -> pgrx::pg_sys::ItemPointerData { + let value = pointer.get(); + pgrx::pg_sys::ItemPointerData { + ip_blkid: pgrx::pg_sys::BlockIdData { + bi_hi: ((value >> 32) & 0xffff) as u16, + bi_lo: ((value >> 16) & 0xffff) as u16, + }, + ip_posid: (value & 0xffff) as u16, + } +} + +const fn ctid_to_pointer(ctid: pgrx::pg_sys::ItemPointerData) -> NonZeroU64 { + let mut value = 0; + value |= (ctid.ip_blkid.bi_hi as u64) << 32; + value |= (ctid.ip_blkid.bi_lo as u64) << 16; + value |= ctid.ip_posid as u64; + NonZeroU64::new(value).expect("invalid pointer") +} + +#[test] +const fn soundness_check() { + let a = 
pgrx::pg_sys::ItemPointerData { + ip_blkid: pgrx::pg_sys::BlockIdData { bi_hi: 1, bi_lo: 2 }, + ip_posid: 3, + }; + let b = ctid_to_pointer(a); + let c = pointer_to_ctid(b); + assert!(a.ip_blkid.bi_hi == c.ip_blkid.bi_hi); + assert!(a.ip_blkid.bi_lo == c.ip_blkid.bi_lo); + assert!(a.ip_posid == c.ip_posid); +} diff --git a/src/index/am_options.rs b/src/index/am_options.rs deleted file mode 100644 index 06ca8e50..00000000 --- a/src/index/am_options.rs +++ /dev/null @@ -1,235 +0,0 @@ -use crate::datatype::memory_halfvec::HalfvecInput; -use crate::datatype::memory_halfvec::HalfvecOutput; -use crate::datatype::memory_vector::VectorInput; -use crate::datatype::memory_vector::VectorOutput; -use crate::datatype::typmod::Typmod; -use crate::types::{BorrowedVector, OwnedVector}; -use crate::types::{DistanceKind, VectorKind}; -use crate::types::{VchordrqIndexingOptions, VectorOptions}; -use distance::Distance; -use pgrx::datum::FromDatum; -use pgrx::heap_tuple::PgHeapTuple; -use serde::Deserialize; -use std::ffi::CStr; -use std::num::NonZero; -use vector::VectorBorrowed; - -#[derive(Copy, Clone, Debug, Default)] -#[repr(C)] -pub struct Reloption { - vl_len_: i32, - pub options: i32, -} - -impl Reloption { - pub const TAB: &'static [pgrx::pg_sys::relopt_parse_elt] = &[pgrx::pg_sys::relopt_parse_elt { - optname: c"options".as_ptr(), - opttype: pgrx::pg_sys::relopt_type::RELOPT_TYPE_STRING, - offset: std::mem::offset_of!(Reloption, options) as i32, - }]; - unsafe fn options(&self) -> &CStr { - unsafe { - let ptr = (&raw const *self) - .cast::() - .offset(self.options as _); - CStr::from_ptr(ptr) - } - } -} - -#[repr(u8)] -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] -pub enum PgDistanceKind { - L2, - Dot, - Cos, -} - -impl PgDistanceKind { - pub fn to_distance(self) -> DistanceKind { - match self { - PgDistanceKind::L2 => DistanceKind::L2, - PgDistanceKind::Dot | PgDistanceKind::Cos => DistanceKind::Dot, - } - } -} - -fn convert_name_to_vd(name: &str) -> Option<(VectorKind, PgDistanceKind)> { - match name.strip_suffix("_ops") { - Some("vector_l2") => Some((VectorKind::Vecf32, PgDistanceKind::L2)), - Some("vector_ip") => Some((VectorKind::Vecf32, PgDistanceKind::Dot)), - Some("vector_cosine") => Some((VectorKind::Vecf32, PgDistanceKind::Cos)), - Some("halfvec_l2") => Some((VectorKind::Vecf16, PgDistanceKind::L2)), - Some("halfvec_ip") => Some((VectorKind::Vecf16, PgDistanceKind::Dot)), - Some("halfvec_cosine") => Some((VectorKind::Vecf16, PgDistanceKind::Cos)), - _ => None, - } -} - -unsafe fn convert_reloptions_to_options( - reloptions: *const pgrx::pg_sys::varlena, -) -> VchordrqIndexingOptions { - #[derive(Debug, Clone, Deserialize, Default)] - #[serde(deny_unknown_fields)] - struct Parsed { - #[serde(flatten)] - rabitq: VchordrqIndexingOptions, - } - let reloption = reloptions as *const Reloption; - if reloption.is_null() || unsafe { (*reloption).options == 0 } { - return Default::default(); - } - let s = unsafe { (*reloption).options() }.to_string_lossy(); - match toml::from_str::(&s) { - Ok(p) => p.rabitq, - Err(e) => pgrx::error!("failed to parse options: {}", e), - } -} - -pub unsafe fn options(index: pgrx::pg_sys::Relation) -> (VectorOptions, VchordrqIndexingOptions) { - let att = unsafe { &mut *(*index).rd_att }; - let atts = unsafe { att.attrs.as_slice(att.natts as _) }; - if atts.is_empty() { - pgrx::error!("indexing on no columns is not supported"); - } - if atts.len() != 1 { - pgrx::error!("multicolumn index is not supported"); - } - // get dims - let typmod = 
Typmod::parse_from_i32(atts[0].type_mod()).unwrap(); - let dims = if let Some(dims) = typmod.dims() { - dims.get() - } else { - pgrx::error!( - "Dimensions type modifier of a vector column is needed for building the index." - ); - }; - // get v, d - let opfamily = unsafe { opfamily(index) }; - let vector = VectorOptions { - dims, - v: opfamily.vector, - d: opfamily.distance_kind(), - }; - // get indexing, segment, optimizing - let rabitq = unsafe { convert_reloptions_to_options((*index).rd_options) }; - (vector, rabitq) -} - -#[derive(Debug, Clone, Copy)] -pub struct Opfamily { - vector: VectorKind, - pg_distance: PgDistanceKind, -} - -impl Opfamily { - pub unsafe fn datum_to_vector( - self, - datum: pgrx::pg_sys::Datum, - is_null: bool, - ) -> Option { - if is_null || datum.is_null() { - return None; - } - let vector = match self.vector { - VectorKind::Vecf32 => { - let vector = unsafe { VectorInput::from_datum(datum, false).unwrap() }; - self.preprocess(BorrowedVector::Vecf32(vector.as_borrowed())) - } - VectorKind::Vecf16 => { - let vector = unsafe { HalfvecInput::from_datum(datum, false).unwrap() }; - self.preprocess(BorrowedVector::Vecf16(vector.as_borrowed())) - } - }; - Some(vector) - } - pub unsafe fn datum_to_sphere( - self, - datum: pgrx::pg_sys::Datum, - is_null: bool, - ) -> (Option, Option) { - if is_null || datum.is_null() { - return (None, None); - } - let tuple = unsafe { PgHeapTuple::from_composite_datum(datum) }; - let center = match self.vector { - VectorKind::Vecf32 => tuple - .get_by_index::(NonZero::new(1).unwrap()) - .unwrap() - .map(|vector| self.preprocess(BorrowedVector::Vecf32(vector.as_borrowed()))), - VectorKind::Vecf16 => tuple - .get_by_index::(NonZero::new(1).unwrap()) - .unwrap() - .map(|vector| self.preprocess(BorrowedVector::Vecf16(vector.as_borrowed()))), - }; - let radius = tuple.get_by_index::(NonZero::new(2).unwrap()).unwrap(); - (center, radius) - } - pub fn preprocess(self, vector: BorrowedVector<'_>) -> OwnedVector { - use BorrowedVector as B; - use OwnedVector as O; - match (vector, self.pg_distance) { - (B::Vecf32(x), PgDistanceKind::L2) => O::Vecf32(x.own()), - (B::Vecf32(x), PgDistanceKind::Dot) => O::Vecf32(x.own()), - (B::Vecf32(x), PgDistanceKind::Cos) => O::Vecf32(x.function_normalize()), - (B::Vecf16(x), PgDistanceKind::L2) => O::Vecf16(x.own()), - (B::Vecf16(x), PgDistanceKind::Dot) => O::Vecf16(x.own()), - (B::Vecf16(x), PgDistanceKind::Cos) => O::Vecf16(x.function_normalize()), - } - } - pub fn process(self, x: Distance) -> f32 { - match self.pg_distance { - PgDistanceKind::Cos => f32::from(x) + 1.0f32, - PgDistanceKind::L2 => f32::from(x).sqrt(), - PgDistanceKind::Dot => x.into(), - } - } - pub fn distance_kind(self) -> DistanceKind { - self.pg_distance.to_distance() - } - pub fn vector_kind(self) -> VectorKind { - self.vector - } -} - -pub unsafe fn opfamily(index: pgrx::pg_sys::Relation) -> Opfamily { - use pgrx::pg_sys::Oid; - - let proc = unsafe { pgrx::pg_sys::index_getprocid(index, 1, 1) }; - - if proc == Oid::INVALID { - pgrx::error!("support function 1 is not found"); - } - - let mut flinfo = pgrx::pg_sys::FmgrInfo::default(); - unsafe { - pgrx::pg_sys::fmgr_info(proc, &mut flinfo); - } - - let fn_addr = flinfo.fn_addr.expect("null function pointer"); - - let mut fcinfo = unsafe { std::mem::zeroed::() }; - fcinfo.flinfo = &mut flinfo; - fcinfo.fncollation = pgrx::pg_sys::DEFAULT_COLLATION_OID; - fcinfo.context = std::ptr::null_mut(); - fcinfo.resultinfo = std::ptr::null_mut(); - fcinfo.isnull = true; - fcinfo.nargs = 0; - - let 
result_datum = unsafe { pgrx::pg_sys::ffi::pg_guard_ffi_boundary(|| fn_addr(&mut fcinfo)) }; - - let result_option = unsafe { String::from_datum(result_datum, fcinfo.isnull) }; - - let result_string = result_option.expect("null string"); - - let (vector, pg_distance) = convert_name_to_vd(&result_string).unwrap(); - - unsafe { - pgrx::pg_sys::pfree(result_datum.cast_mut_ptr()); - } - - Opfamily { - vector, - pg_distance, - } -} diff --git a/src/index/am_scan.rs b/src/index/am_scan.rs deleted file mode 100644 index 83e62f33..00000000 --- a/src/index/am_scan.rs +++ /dev/null @@ -1,186 +0,0 @@ -use super::am_options::Opfamily; -use crate::algorithm::operator::Vector; -use crate::algorithm::operator::{Dot, L2, Op}; -use crate::algorithm::scan::scan; -use crate::gucs::executing::epsilon; -use crate::gucs::executing::max_scan_tuples; -use crate::gucs::executing::probes; -use crate::postgres::PostgresRelation; -use crate::types::DistanceKind; -use crate::types::OwnedVector; -use crate::types::VectorKind; -use distance::Distance; -use half::f16; -use std::num::NonZeroU64; -use vector::vect::VectOwned; - -pub enum Scanner { - Initial { - vector: Option<(OwnedVector, Opfamily)>, - threshold: Option, - recheck: bool, - }, - Vbase { - vbase: Box>, - threshold: Option, - recheck: bool, - opfamily: Opfamily, - }, - Empty {}, -} - -pub fn scan_build( - orderbys: Vec>, - spheres: Vec<(Option, Option)>, - opfamily: Opfamily, -) -> (Option<(OwnedVector, Opfamily)>, Option, bool) { - let mut pair = None; - let mut threshold = None; - let mut recheck = false; - for orderby_vector in orderbys { - if pair.is_none() { - pair = orderby_vector; - } else if orderby_vector.is_some() { - pgrx::error!("vector search with multiple vectors is not supported"); - } - } - for (sphere_vector, sphere_threshold) in spheres { - if pair.is_none() { - pair = sphere_vector; - threshold = sphere_threshold; - } else { - recheck = true; - break; - } - } - (pair.map(|x| (x, opfamily)), threshold, recheck) -} - -pub fn scan_make( - vector: Option<(OwnedVector, Opfamily)>, - threshold: Option, - recheck: bool, -) -> Scanner { - Scanner::Initial { - vector, - threshold, - recheck, - } -} - -pub fn scan_next(scanner: &mut Scanner, relation: PostgresRelation) -> Option<(NonZeroU64, bool)> { - if let Scanner::Initial { - vector, - threshold, - recheck, - } = scanner - { - if let Some((vector, opfamily)) = vector.as_ref() { - match (opfamily.vector_kind(), opfamily.distance_kind()) { - (VectorKind::Vecf32, DistanceKind::L2) => { - let vbase = scan::, L2>>( - relation, - VectOwned::::from_owned(vector.clone()), - probes(), - epsilon(), - ); - *scanner = Scanner::Vbase { - vbase: if let Some(max_scan_tuples) = max_scan_tuples() { - Box::new(vbase.take(max_scan_tuples as usize)) - } else { - Box::new(vbase) - }, - threshold: *threshold, - recheck: *recheck, - opfamily: *opfamily, - }; - } - (VectorKind::Vecf32, DistanceKind::Dot) => { - let vbase = scan::, Dot>>( - relation, - VectOwned::::from_owned(vector.clone()), - probes(), - epsilon(), - ); - *scanner = Scanner::Vbase { - vbase: if let Some(max_scan_tuples) = max_scan_tuples() { - Box::new(vbase.take(max_scan_tuples as usize)) - } else { - Box::new(vbase) - }, - threshold: *threshold, - recheck: *recheck, - opfamily: *opfamily, - }; - } - (VectorKind::Vecf16, DistanceKind::L2) => { - let vbase = scan::, L2>>( - relation, - VectOwned::::from_owned(vector.clone()), - probes(), - epsilon(), - ); - *scanner = Scanner::Vbase { - vbase: if let Some(max_scan_tuples) = max_scan_tuples() { - 
Box::new(vbase.take(max_scan_tuples as usize)) - } else { - Box::new(vbase) - }, - threshold: *threshold, - recheck: *recheck, - opfamily: *opfamily, - }; - } - (VectorKind::Vecf16, DistanceKind::Dot) => { - let vbase = scan::, Dot>>( - relation, - VectOwned::::from_owned(vector.clone()), - probes(), - epsilon(), - ); - *scanner = Scanner::Vbase { - vbase: if let Some(max_scan_tuples) = max_scan_tuples() { - Box::new(vbase.take(max_scan_tuples as usize)) - } else { - Box::new(vbase) - }, - threshold: *threshold, - recheck: *recheck, - opfamily: *opfamily, - }; - } - } - } else { - *scanner = Scanner::Empty {}; - } - } - match scanner { - Scanner::Initial { .. } => unreachable!(), - Scanner::Vbase { - vbase, - threshold, - recheck, - opfamily, - } => match ( - vbase.next().map(|(d, p)| (opfamily.process(d), p)), - threshold, - ) { - (Some((_, ptr)), None) => Some((ptr, *recheck)), - (Some((distance, ptr)), Some(t)) if distance < *t => Some((ptr, *recheck)), - _ => { - let scanner = std::mem::replace(scanner, Scanner::Empty {}); - scan_release(scanner); - None - } - }, - Scanner::Empty {} => None, - } -} - -pub fn scan_release(scanner: Scanner) { - match scanner { - Scanner::Initial { .. } => {} - Scanner::Vbase { .. } => {} - Scanner::Empty {} => {} - } -} diff --git a/src/index/functions.rs b/src/index/functions.rs index 1f3b4e28..64acde82 100644 --- a/src/index/functions.rs +++ b/src/index/functions.rs @@ -1,9 +1,6 @@ -use super::am_options; -use crate::algorithm::operator::{Dot, L2, Op}; -use crate::algorithm::prewarm::prewarm; -use crate::postgres::PostgresRelation; -use crate::types::DistanceKind; -use crate::types::VectorKind; +use crate::index::storage::PostgresRelation; +use algorithm::operator::{Dot, L2, Op}; +use algorithm::types::*; use half::f16; use pgrx::pg_sys::Oid; use pgrx_catalog::{PgAm, PgClass}; @@ -22,25 +19,33 @@ fn _vchordrq_prewarm(indexrelid: Oid, height: i32) -> String { if pg_class.relam() != pg_am.oid() { pgrx::error!("{:?} is not a vchordrq index", pg_class.relname()); } - let index = unsafe { pgrx::pg_sys::index_open(indexrelid, pgrx::pg_sys::ShareLock as _) }; + let index = unsafe { pgrx::pg_sys::index_open(indexrelid, pgrx::pg_sys::AccessShareLock as _) }; let relation = unsafe { PostgresRelation::new(index) }; - let opfamily = unsafe { am_options::opfamily(index) }; + let opfamily = unsafe { crate::index::opclass::opfamily(index) }; let message = match (opfamily.vector_kind(), opfamily.distance_kind()) { (VectorKind::Vecf32, DistanceKind::L2) => { - prewarm::, L2>>(relation, height) + algorithm::prewarm::, L2>>(relation, height, || { + pgrx::check_for_interrupts!(); + }) } (VectorKind::Vecf32, DistanceKind::Dot) => { - prewarm::, Dot>>(relation, height) + algorithm::prewarm::, Dot>>(relation, height, || { + pgrx::check_for_interrupts!(); + }) } (VectorKind::Vecf16, DistanceKind::L2) => { - prewarm::, L2>>(relation, height) + algorithm::prewarm::, L2>>(relation, height, || { + pgrx::check_for_interrupts!(); + }) } (VectorKind::Vecf16, DistanceKind::Dot) => { - prewarm::, Dot>>(relation, height) + algorithm::prewarm::, Dot>>(relation, height, || { + pgrx::check_for_interrupts!(); + }) } }; unsafe { - pgrx::pg_sys::index_close(index, pgrx::pg_sys::ShareLock as _); + pgrx::pg_sys::index_close(index, pgrx::pg_sys::AccessShareLock as _); } message } diff --git a/src/gucs/executing.rs b/src/index/gucs.rs similarity index 59% rename from src/gucs/executing.rs rename to src/index/gucs.rs index af6cce7d..26d061ab 100644 --- a/src/gucs/executing.rs +++ 
b/src/index/gucs.rs @@ -4,8 +4,10 @@ use std::ffi::CStr; static PROBES: GucSetting> = GucSetting::>::new(Some(c"10")); static EPSILON: GucSetting = GucSetting::::new(1.9); static MAX_SCAN_TUPLES: GucSetting = GucSetting::::new(-1); +static PREWARM_DIM: GucSetting> = + GucSetting::>::new(Some(c"64,128,256,384,512,768,1024,1536")); -pub unsafe fn init() { +pub fn init() { GucRegistry::define_string_guc( "vchordrq.probes", "`probes` argument of vchordrq.", @@ -34,6 +36,20 @@ pub unsafe fn init() { GucContext::Userset, GucFlags::default(), ); + GucRegistry::define_string_guc( + "vchordrq.prewarm_dim", + "prewarm_dim when the extension is loading.", + "prewarm_dim when the extension is loading.", + &PREWARM_DIM, + GucContext::Userset, + GucFlags::default(), + ); + unsafe { + #[cfg(any(feature = "pg13", feature = "pg14"))] + pgrx::pg_sys::EmitWarningsOnPlaceholders(c"vchordrq".as_ptr()); + #[cfg(any(feature = "pg15", feature = "pg16", feature = "pg17"))] + pgrx::pg_sys::MarkGUCPrefixReserved(c"vchordrq".as_ptr()); + } } pub fn probes() -> Vec { @@ -56,7 +72,9 @@ pub fn probes() -> Vec { c => pgrx::error!("unknown character in probes: ASCII = {c}"), } } - result.push(current.take().expect("empty probes")); + if let Some(current) = current { + result.push(current); + } result } } @@ -70,3 +88,24 @@ pub fn max_scan_tuples() -> Option { let x = MAX_SCAN_TUPLES.get(); if x < 0 { None } else { Some(x as u32) } } + +pub fn prewarm_dim() -> Vec { + if let Some(prewarm_dim) = PREWARM_DIM.get() { + if let Ok(prewarm_dim) = prewarm_dim.to_str() { + let mut result = Vec::new(); + for dim in prewarm_dim.split(',') { + if let Ok(dim) = dim.trim().parse::() { + result.push(dim); + } else { + pgrx::warning!("{dim:?} is not a valid integer"); + } + } + result + } else { + pgrx::warning!("vchordrq.prewarm_dim is not a valid UTF-8 string"); + Vec::new() + } + } else { + Vec::new() + } +} diff --git a/src/index/mod.rs b/src/index/mod.rs index 5203e4fb..e7a309ae 100644 --- a/src/index/mod.rs +++ b/src/index/mod.rs @@ -1,12 +1,15 @@ pub mod am; -pub mod am_options; -pub mod am_scan; pub mod functions; +pub mod gucs; pub mod opclass; -pub mod utils; +pub mod projection; +pub mod storage; +pub mod types; -pub unsafe fn init() { - unsafe { - am::init(); +pub fn init() { + am::init(); + gucs::init(); + for x in gucs::prewarm_dim() { + projection::prewarm(x as _); } } diff --git a/src/index/opclass.rs b/src/index/opclass.rs index a2dc8618..63a6be5d 100644 --- a/src/index/opclass.rs +++ b/src/index/opclass.rs @@ -1,3 +1,13 @@ +use crate::datatype::memory_halfvec::{HalfvecInput, HalfvecOutput}; +use crate::datatype::memory_vector::{VectorInput, VectorOutput}; +use algorithm::types::*; +use distance::Distance; +use pgrx::datum::FromDatum; +use pgrx::heap_tuple::PgHeapTuple; +use pgrx::pg_sys::Datum; +use std::num::NonZero; +use vector::VectorBorrowed; + #[pgrx::pg_extern(immutable, strict, parallel_safe)] fn _vchordrq_support_vector_l2_ops() -> String { "vector_l2_ops".to_string() @@ -27,3 +37,139 @@ fn _vchordrq_support_halfvec_ip_ops() -> String { fn _vchordrq_support_halfvec_cosine_ops() -> String { "halfvec_cosine_ops".to_string() } + +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +enum PostgresDistanceKind { + L2, + Ip, + Cosine, +} + +pub struct Sphere { + pub center: T, + pub radius: f32, +} + +#[derive(Debug, Clone, Copy)] +pub struct Opfamily { + vector: VectorKind, + postgres_distance: PostgresDistanceKind, +} + +impl Opfamily { + fn input(self, vector: BorrowedVector<'_>) -> 
OwnedVector { + use {BorrowedVector as B, OwnedVector as O, PostgresDistanceKind as D}; + match (vector, self.postgres_distance) { + (B::Vecf32(x), D::L2) => O::Vecf32(x.own()), + (B::Vecf32(x), D::Ip) => O::Vecf32(x.own()), + (B::Vecf32(x), D::Cosine) => O::Vecf32(x.function_normalize()), + (B::Vecf16(x), D::L2) => O::Vecf16(x.own()), + (B::Vecf16(x), D::Ip) => O::Vecf16(x.own()), + (B::Vecf16(x), D::Cosine) => O::Vecf16(x.function_normalize()), + } + } + pub unsafe fn input_vector(self, datum: Datum, is_null: bool) -> Option { + if is_null || datum.is_null() { + return None; + } + let vector = match self.vector { + VectorKind::Vecf32 => { + let vector = unsafe { VectorInput::from_datum(datum, false).unwrap() }; + self.input(BorrowedVector::Vecf32(vector.as_borrowed())) + } + VectorKind::Vecf16 => { + let vector = unsafe { HalfvecInput::from_datum(datum, false).unwrap() }; + self.input(BorrowedVector::Vecf16(vector.as_borrowed())) + } + }; + Some(vector) + } + pub unsafe fn input_sphere(self, datum: Datum, is_null: bool) -> Option> { + if is_null || datum.is_null() { + return None; + } + let attno_1 = NonZero::new(1_usize).unwrap(); + let attno_2 = NonZero::new(2_usize).unwrap(); + let tuple = unsafe { PgHeapTuple::from_composite_datum(datum) }; + let center = match self.vector { + VectorKind::Vecf32 => { + let vector = tuple.get_by_index::(attno_1).unwrap()?; + self.input(BorrowedVector::Vecf32(vector.as_borrowed())) + } + VectorKind::Vecf16 => { + let vector = tuple.get_by_index::(attno_1).unwrap()?; + self.input(BorrowedVector::Vecf16(vector.as_borrowed())) + } + }; + let radius = tuple.get_by_index::(attno_2).unwrap()?; + Some(Sphere { center, radius }) + } + pub fn output(self, x: Distance) -> f32 { + match self.postgres_distance { + PostgresDistanceKind::Cosine => x.to_f32() + 1.0f32, + PostgresDistanceKind::L2 => x.to_f32().sqrt(), + PostgresDistanceKind::Ip => x.to_f32(), + } + } + pub const fn distance_kind(self) -> DistanceKind { + match self.postgres_distance { + PostgresDistanceKind::L2 => DistanceKind::L2, + PostgresDistanceKind::Ip | PostgresDistanceKind::Cosine => DistanceKind::Dot, + } + } + pub const fn vector_kind(self) -> VectorKind { + self.vector + } +} + +pub unsafe fn opfamily(index_relation: pgrx::pg_sys::Relation) -> Opfamily { + use pgrx::pg_sys::Oid; + + let proc = unsafe { pgrx::pg_sys::index_getprocid(index_relation, 1, 1) }; + + if proc == Oid::INVALID { + pgrx::error!("support function 1 is not found"); + } + + let mut flinfo = pgrx::pg_sys::FmgrInfo::default(); + + unsafe { + pgrx::pg_sys::fmgr_info(proc, &mut flinfo); + } + + let fn_addr = flinfo.fn_addr.expect("null function pointer"); + + let mut fcinfo = unsafe { std::mem::zeroed::() }; + fcinfo.flinfo = &mut flinfo; + fcinfo.fncollation = pgrx::pg_sys::DEFAULT_COLLATION_OID; + fcinfo.context = std::ptr::null_mut(); + fcinfo.resultinfo = std::ptr::null_mut(); + fcinfo.isnull = true; + fcinfo.nargs = 0; + + let result_datum = unsafe { pgrx::pg_sys::ffi::pg_guard_ffi_boundary(|| fn_addr(&mut fcinfo)) }; + + let result_option = unsafe { String::from_datum(result_datum, fcinfo.isnull) }; + + let result_string = result_option.expect("null return value"); + + let (vector, postgres_distance) = match result_string.as_str() { + "vector_l2_ops" => (VectorKind::Vecf32, PostgresDistanceKind::L2), + "vector_ip_ops" => (VectorKind::Vecf32, PostgresDistanceKind::Ip), + "vector_cosine_ops" => (VectorKind::Vecf32, PostgresDistanceKind::Cosine), + "halfvec_l2_ops" => (VectorKind::Vecf16, PostgresDistanceKind::L2), + 
"halfvec_ip_ops" => (VectorKind::Vecf16, PostgresDistanceKind::Ip), + "halfvec_cosine_ops" => (VectorKind::Vecf16, PostgresDistanceKind::Cosine), + _ => pgrx::error!("unknown operator class"), + }; + + unsafe { + pgrx::pg_sys::pfree(result_datum.cast_mut_ptr()); + } + + Opfamily { + vector, + postgres_distance, + } +} diff --git a/src/projection.rs b/src/index/projection.rs similarity index 51% rename from src/projection.rs rename to src/index/projection.rs index fbcaeffa..ca07e24f 100644 --- a/src/projection.rs +++ b/src/index/projection.rs @@ -1,5 +1,7 @@ +use half::f16; use random_orthogonal_matrix::random_orthogonal_matrix; use std::sync::OnceLock; +use vector::vect::{VectBorrowed, VectOwned}; fn matrix(n: usize) -> Option<&'static Vec>> { static MATRIXS: [OnceLock>>; 1 + 60000] = [const { OnceLock::new() }; 1 + 60000]; @@ -20,3 +22,25 @@ pub fn project(vector: &[f32]) -> Vec { .map(|i| f32::reduce_sum_of_xy(vector, &matrix[i])) .collect() } + +pub trait RandomProject { + type Output; + fn project(self) -> Self::Output; +} + +impl RandomProject for VectBorrowed<'_, f32> { + type Output = VectOwned; + fn project(self) -> VectOwned { + VectOwned::new(project(self.slice())) + } +} + +impl RandomProject for VectBorrowed<'_, f16> { + type Output = VectOwned; + fn project(self) -> VectOwned { + use simd::Floating; + VectOwned::new(f16::vector_from_f32(&project(&f16::vector_to_f32( + self.slice(), + )))) + } +} diff --git a/src/postgres.rs b/src/index/storage.rs similarity index 92% rename from src/postgres.rs rename to src/index/storage.rs index f68d0fa7..5d0cfc94 100644 --- a/src/postgres.rs +++ b/src/index/storage.rs @@ -1,4 +1,4 @@ -use crate::algorithm::{Opaque, Page, PageGuard, RelationRead, RelationWrite}; +use algorithm::{Opaque, Page, PageGuard, RelationRead, RelationWrite}; use std::mem::{MaybeUninit, offset_of}; use std::ops::{Deref, DerefMut}; use std::ptr::NonNull; @@ -16,6 +16,7 @@ const fn size_of_contents() -> usize { } #[repr(C, align(8))] +#[derive(Debug)] pub struct PostgresPage { header: pgrx::pg_sys::PageHeaderData, content: [u8; size_of_contents()], @@ -42,31 +43,13 @@ impl PostgresPage { assert_eq!(offset_of!(Self, opaque), this.header.pd_special as usize); this } - #[allow(dead_code)] - fn clone_into_boxed(&self) -> Box { + pub fn clone_into_boxed(&self) -> Box { let mut result = Box::new_uninit(); unsafe { std::ptr::copy(self as *const Self, result.as_mut_ptr(), 1); result.assume_init() } } - #[allow(dead_code)] - fn reconstruct(&mut self, removes: &[u16]) { - let mut removes = removes.to_vec(); - removes.sort(); - removes.dedup(); - let n = removes.len(); - if n > 0 { - assert!(removes[n - 1] <= self.len()); - unsafe { - pgrx::pg_sys::PageIndexMultiDelete( - (self as *mut Self).cast(), - removes.as_ptr().cast_mut(), - removes.len() as _, - ); - } - } - } } impl Page for PostgresPage { @@ -257,16 +240,6 @@ impl PostgresRelation { pub unsafe fn new(raw: pgrx::pg_sys::Relation) -> Self { Self { raw } } - - #[allow(dead_code)] - pub fn len(&self) -> u32 { - unsafe { - pgrx::pg_sys::RelationGetNumberOfBlocksInFork( - self.raw, - pgrx::pg_sys::ForkNumber::MAIN_FORKNUM, - ) - } - } } impl RelationRead for PostgresRelation { diff --git a/src/types.rs b/src/index/types.rs similarity index 58% rename from src/types.rs rename to src/index/types.rs index 4ef2171d..df238462 100644 --- a/src/types.rs +++ b/src/index/types.rs @@ -1,13 +1,12 @@ -use half::f16; +use algorithm::types::VchordrqIndexOptions; use serde::{Deserialize, Serialize}; use validator::{Validate, ValidationError, 
ValidationErrors}; -use vector::vect::{VectBorrowed, VectOwned}; #[derive(Debug, Clone, Serialize, Deserialize, Validate)] #[serde(deny_unknown_fields)] pub struct VchordrqInternalBuildOptions { #[serde(default = "VchordrqInternalBuildOptions::default_lists")] - #[validate(length(min = 1, max = 8), custom(function = VchordrqInternalBuildOptions::validate_lists))] + #[validate(length(min = 0, max = 8), custom(function = VchordrqInternalBuildOptions::validate_lists))] pub lists: Vec, #[serde(default = "VchordrqInternalBuildOptions::default_spherical_centroids")] pub spherical_centroids: bool, @@ -63,20 +62,20 @@ pub struct VchordrqExternalBuildOptions { #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(deny_unknown_fields)] #[serde(rename_all = "snake_case")] -pub enum VchordrqBuildOptions { +pub enum VchordrqBuildSourceOptions { Internal(VchordrqInternalBuildOptions), External(VchordrqExternalBuildOptions), } -impl Default for VchordrqBuildOptions { +impl Default for VchordrqBuildSourceOptions { fn default() -> Self { Self::Internal(Default::default()) } } -impl Validate for VchordrqBuildOptions { +impl Validate for VchordrqBuildSourceOptions { fn validate(&self) -> Result<(), ValidationErrors> { - use VchordrqBuildOptions::*; + use VchordrqBuildSourceOptions::*; match self { Internal(internal_build) => internal_build.validate(), External(external_build) => external_build.validate(), @@ -86,65 +85,24 @@ impl Validate for VchordrqBuildOptions { #[derive(Debug, Clone, Default, Serialize, Deserialize, Validate)] #[serde(deny_unknown_fields)] -pub struct VchordrqIndexingOptions { - #[serde(default = "VchordrqIndexingOptions::default_residual_quantization")] - pub residual_quantization: bool, - pub build: VchordrqBuildOptions, +#[serde(rename_all = "snake_case")] +pub struct VchordrqBuildOptions { + #[serde(flatten)] + pub source: VchordrqBuildSourceOptions, + #[serde(default = "VchordrqBuildOptions::default_pin")] + pub pin: bool, } -impl VchordrqIndexingOptions { - fn default_residual_quantization() -> bool { +impl VchordrqBuildOptions { + pub fn default_pin() -> bool { false } } -#[derive(Debug, Clone, Serialize, Deserialize)] -pub enum OwnedVector { - Vecf32(VectOwned), - Vecf16(VectOwned), -} - -#[derive(Debug, Clone, Copy)] -pub enum BorrowedVector<'a> { - Vecf32(VectBorrowed<'a, f32>), - Vecf16(VectBorrowed<'a, f16>), -} - -#[repr(u8)] -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] -pub enum DistanceKind { - L2, - Dot, -} - -#[repr(u8)] -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] -pub enum VectorKind { - Vecf32, - Vecf16, -} - -#[derive(Debug, Clone, Serialize, Deserialize, Validate)] +#[derive(Debug, Clone, Default, Serialize, Deserialize, Validate)] #[serde(deny_unknown_fields)] -#[validate(schema(function = "Self::validate_self"))] -pub struct VectorOptions { - #[validate(range(min = 1, max = 1_048_575))] - #[serde(rename = "dimensions")] - pub dims: u32, - #[serde(rename = "vector")] - pub v: VectorKind, - #[serde(rename = "distance")] - pub d: DistanceKind, -} - -impl VectorOptions { - pub fn validate_self(&self) -> Result<(), ValidationError> { - match (self.v, self.d, self.dims) { - (VectorKind::Vecf32, DistanceKind::L2, 1..65536) => Ok(()), - (VectorKind::Vecf32, DistanceKind::Dot, 1..65536) => Ok(()), - (VectorKind::Vecf16, DistanceKind::L2, 1..65536) => Ok(()), - (VectorKind::Vecf16, DistanceKind::Dot, 1..65536) => Ok(()), - _ => Err(ValidationError::new("not valid vector options")), - } - } 
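+// Serde layout note (descriptive comment, not in the original patch): with
+// `#[serde(flatten)]`, the keys of the shared `VchordrqIndexOptions` sit at the top
+// level of the deserialized options document while build-time settings are nested
+// under `build` (cf. the `options = $$ ... $$` strings in tests/logic/pin.slt and
+// tests/logic/rerank_in_table.slt added in this patch).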
+pub struct VchordrqIndexingOptions { + #[serde(flatten)] + pub index: VchordrqIndexOptions, + pub build: VchordrqBuildOptions, } diff --git a/src/index/utils.rs b/src/index/utils.rs deleted file mode 100644 index 18234ac0..00000000 --- a/src/index/utils.rs +++ /dev/null @@ -1,34 +0,0 @@ -use std::num::NonZeroU64; - -pub const fn pointer_to_ctid(pointer: NonZeroU64) -> pgrx::pg_sys::ItemPointerData { - let value = pointer.get(); - pgrx::pg_sys::ItemPointerData { - ip_blkid: pgrx::pg_sys::BlockIdData { - bi_hi: ((value >> 32) & 0xffff) as u16, - bi_lo: ((value >> 16) & 0xffff) as u16, - }, - ip_posid: (value & 0xffff) as u16, - } -} - -pub const fn ctid_to_pointer(ctid: pgrx::pg_sys::ItemPointerData) -> NonZeroU64 { - let mut value = 0; - value |= (ctid.ip_blkid.bi_hi as u64) << 32; - value |= (ctid.ip_blkid.bi_lo as u64) << 16; - value |= ctid.ip_posid as u64; - NonZeroU64::new(value).expect("invalid pointer") -} - -#[allow(dead_code)] -const fn soundness_check(a: pgrx::pg_sys::ItemPointerData) { - let b = ctid_to_pointer(a); - let c = pointer_to_ctid(b); - assert!(a.ip_blkid.bi_hi == c.ip_blkid.bi_hi); - assert!(a.ip_blkid.bi_lo == c.ip_blkid.bi_lo); - assert!(a.ip_posid == c.ip_posid); -} - -const _: () = soundness_check(pgrx::pg_sys::ItemPointerData { - ip_blkid: pgrx::pg_sys::BlockIdData { bi_hi: 1, bi_lo: 2 }, - ip_posid: 3, -}); diff --git a/src/lib.rs b/src/lib.rs index 187e3cda..20ce1836 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,32 +1,22 @@ -#![feature(vec_pop_if)] #![allow(clippy::collapsible_else_if)] -#![allow(clippy::infallible_destructuring_match)] #![allow(clippy::too_many_arguments)] -#![allow(clippy::type_complexity)] +#![allow(unsafe_code)] -mod algorithm; mod datatype; -mod gucs; mod index; -mod postgres; -mod projection; -mod types; mod upgrade; -mod utils; pgrx::pg_module_magic!(); pgrx::extension_sql_file!("./sql/bootstrap.sql", bootstrap); pgrx::extension_sql_file!("./sql/finalize.sql", finalize); #[pgrx::pg_guard] -unsafe extern "C" fn _PG_init() { +extern "C" fn _PG_init() { if unsafe { pgrx::pg_sys::IsUnderPostmaster } { pgrx::error!("vchord must be loaded via shared_preload_libraries."); } + index::init(); unsafe { - index::init(); - gucs::init(); - #[cfg(any(feature = "pg13", feature = "pg14"))] pgrx::pg_sys::EmitWarningsOnPlaceholders(c"vchord".as_ptr()); #[cfg(any(feature = "pg15", feature = "pg16", feature = "pg17"))] diff --git a/src/upgrade/symbols.rs b/src/upgrade.rs similarity index 100% rename from src/upgrade/symbols.rs rename to src/upgrade.rs diff --git a/src/upgrade/mod.rs b/src/upgrade/mod.rs deleted file mode 100644 index 6eb441db..00000000 --- a/src/upgrade/mod.rs +++ /dev/null @@ -1 +0,0 @@ -mod symbols; diff --git a/src/utils/mod.rs b/src/utils/mod.rs deleted file mode 100644 index 85a84e0e..00000000 --- a/src/utils/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub mod k_means; -pub mod parallelism; -pub mod pipe; diff --git a/src/utils/parallelism.rs b/src/utils/parallelism.rs deleted file mode 100644 index b960b568..00000000 --- a/src/utils/parallelism.rs +++ /dev/null @@ -1,62 +0,0 @@ -use std::any::Any; -use std::panic::AssertUnwindSafe; -use std::sync::Arc; - -pub use rayon::iter::ParallelIterator; - -pub trait Parallelism: Send + Sync { - fn check(&self); - - fn rayon_into_par_iter(&self, x: I) -> I::Iter; -} - -struct ParallelismCheckPanic(Box); - -pub struct RayonParallelism { - stop: Arc, -} - -impl RayonParallelism { - pub fn scoped( - num_threads: usize, - stop: Arc, - f: impl FnOnce(&Self) -> R, - ) -> Result { - match 
diff --git a/src/lib.rs b/src/lib.rs
index 187e3cda..20ce1836 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,32 +1,22 @@
-#![feature(vec_pop_if)]
 #![allow(clippy::collapsible_else_if)]
-#![allow(clippy::infallible_destructuring_match)]
 #![allow(clippy::too_many_arguments)]
-#![allow(clippy::type_complexity)]
+#![allow(unsafe_code)]
 
-mod algorithm;
 mod datatype;
-mod gucs;
 mod index;
-mod postgres;
-mod projection;
-mod types;
 mod upgrade;
-mod utils;
 
 pgrx::pg_module_magic!();
 pgrx::extension_sql_file!("./sql/bootstrap.sql", bootstrap);
 pgrx::extension_sql_file!("./sql/finalize.sql", finalize);
 
 #[pgrx::pg_guard]
-unsafe extern "C" fn _PG_init() {
+extern "C" fn _PG_init() {
     if unsafe { pgrx::pg_sys::IsUnderPostmaster } {
         pgrx::error!("vchord must be loaded via shared_preload_libraries.");
     }
+    index::init();
     unsafe {
-        index::init();
-        gucs::init();
-
         #[cfg(any(feature = "pg13", feature = "pg14"))]
         pgrx::pg_sys::EmitWarningsOnPlaceholders(c"vchord".as_ptr());
         #[cfg(any(feature = "pg15", feature = "pg16", feature = "pg17"))]
diff --git a/src/upgrade/symbols.rs b/src/upgrade.rs
similarity index 100%
rename from src/upgrade/symbols.rs
rename to src/upgrade.rs
diff --git a/src/upgrade/mod.rs b/src/upgrade/mod.rs
deleted file mode 100644
index 6eb441db..00000000
--- a/src/upgrade/mod.rs
+++ /dev/null
@@ -1 +0,0 @@
-mod symbols;
diff --git a/src/utils/mod.rs b/src/utils/mod.rs
deleted file mode 100644
index 85a84e0e..00000000
--- a/src/utils/mod.rs
+++ /dev/null
@@ -1,3 +0,0 @@
-pub mod k_means;
-pub mod parallelism;
-pub mod pipe;
diff --git a/src/utils/parallelism.rs b/src/utils/parallelism.rs
deleted file mode 100644
index b960b568..00000000
--- a/src/utils/parallelism.rs
+++ /dev/null
@@ -1,62 +0,0 @@
-use std::any::Any;
-use std::panic::AssertUnwindSafe;
-use std::sync::Arc;
-
-pub use rayon::iter::ParallelIterator;
-
-pub trait Parallelism: Send + Sync {
-    fn check(&self);
-
-    fn rayon_into_par_iter<I: rayon::iter::IntoParallelIterator>(&self, x: I) -> I::Iter;
-}
-
-struct ParallelismCheckPanic(Box<dyn Any + Send>);
-
-pub struct RayonParallelism {
-    stop: Arc<dyn Fn() + Send + Sync>,
-}
-
-impl RayonParallelism {
-    pub fn scoped<R>(
-        num_threads: usize,
-        stop: Arc<dyn Fn() + Send + Sync>,
-        f: impl FnOnce(&Self) -> R,
-    ) -> Result<R, rayon::ThreadPoolBuildError> {
-        match std::panic::catch_unwind(AssertUnwindSafe(|| {
-            rayon::ThreadPoolBuilder::new()
-                .num_threads(num_threads)
-                .panic_handler(|e| {
-                    if e.downcast_ref::<ParallelismCheckPanic>().is_some() {
-                        return;
-                    }
-                    log::error!("Asynchronous task panickied.");
-                })
-                .build_scoped(
-                    |thread| thread.run(),
-                    |_| {
-                        let pool = Self { stop: stop.clone() };
-                        f(&pool)
-                    },
-                )
-        })) {
-            Ok(x) => x,
-            Err(e) => match e.downcast::<ParallelismCheckPanic>() {
-                Ok(payload) => std::panic::resume_unwind((*payload).0),
-                Err(e) => std::panic::resume_unwind(e),
-            },
-        }
-    }
-}
-
-impl Parallelism for RayonParallelism {
-    fn check(&self) {
-        match std::panic::catch_unwind(AssertUnwindSafe(|| (self.stop)())) {
-            Ok(()) => (),
-            Err(payload) => std::panic::panic_any(ParallelismCheckPanic(payload)),
-        }
-    }
-
-    fn rayon_into_par_iter<I: rayon::iter::IntoParallelIterator>(&self, x: I) -> I::Iter {
-        x.into_par_iter()
-    }
-}
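The deleted `RayonParallelism` tells a cooperative stop apart from a genuine worker panic by wrapping the stop callback's panic payload in a private marker type, silencing that marker in the pool's panic handler, and unwrapping it again at the pool boundary. A standalone sketch of that mechanism, with hypothetical names and no rayon dependency:

```rust
// Standalone sketch (made-up names, not the deleted API) of the pattern in
// parallelism.rs: `check()` tags the stop callback's panic payload with a
// marker type so the caller can distinguish "cancelled" from a real panic.
use std::any::Any;
use std::panic::{catch_unwind, panic_any, resume_unwind, AssertUnwindSafe};

struct CheckPanic(Box<dyn Any + Send>);

fn check(stop: &dyn Fn()) {
    // If the stop callback panics (e.g. a query-cancel check fires), rewrap
    // the payload so it is not mistaken for an ordinary worker panic.
    if let Err(payload) = catch_unwind(AssertUnwindSafe(stop)) {
        panic_any(CheckPanic(payload));
    }
}

fn main() {
    // Silence the default "thread panicked" printouts for this demo.
    std::panic::set_hook(Box::new(|_| {}));

    let stop = || panic!("stop requested");
    let outcome = catch_unwind(AssertUnwindSafe(|| {
        // ... parallel work would call `check(&stop)` between batches ...
        check(&stop);
    }));
    match outcome {
        Ok(()) => println!("finished"),
        // A tagged payload means "cancelled"; anything else is a real panic.
        Err(e) => match e.downcast::<CheckPanic>() {
            Ok(tagged) => {
                // `tagged.0` still holds the original payload, so the caller
                // could `resume_unwind` it instead of swallowing it here.
                println!("cancelled by stop callback");
                drop(tagged);
            }
            Err(other) => resume_unwind(other),
        },
    }
}
```

A worker loop would call `check(&stop)` periodically; any payload that is not the marker type is re-raised untouched with `resume_unwind`.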
diff --git a/taplo.toml b/taplo.toml
new file mode 100644
index 00000000..41af6e14
--- /dev/null
+++ b/taplo.toml
@@ -0,0 +1,17 @@
+[formatting]
+indent_string = " "
+
+[[rule]]
+include = ["**/Cargo.toml"]
+keys = [
+    "dependencies",
+    "dev-dependencies",
+    "build-dependencies",
+    "target.*.dependencies",
+    "lints",
+    "patch.*",
+    "profile.*",
+    "workspace.dependencies",
+    "lints.dependencies",
+]
+formatting = { reorder_keys = true, reorder_arrays = true, align_comments = true }
diff --git a/tests/logic/pin.slt b/tests/logic/pin.slt
new file mode 100644
index 00000000..fbe923c3
--- /dev/null
+++ b/tests/logic/pin.slt
@@ -0,0 +1,18 @@
+statement ok
+CREATE TABLE t (val vector(3));
+
+statement ok
+INSERT INTO t (val) SELECT ARRAY[random(), random(), random()]::real[] FROM generate_series(1, 1000);
+
+statement ok
+CREATE INDEX ON t USING vchordrq (val vector_ip_ops)
+WITH (options = $$
+residual_quantization = false
+build.pin = true
+[build.internal]
+lists = [32]
+spherical_centroids = true
+$$);
+
+statement ok
+DROP TABLE t;
diff --git a/tests/logic/pushdown_plan.slt b/tests/logic/pushdown_plan.slt
index 78bfe440..f6aecc48 100644
--- a/tests/logic/pushdown_plan.slt
+++ b/tests/logic/pushdown_plan.slt
@@ -12,6 +12,9 @@ FROM generate_series(1, 10000);
 statement ok
 CREATE INDEX ind0 ON t USING vchordrq (val0 vector_l2_ops);
 
+statement ok
+SET enable_seqscan TO off;
+
 # statement ok
 # CREATE INDEX ind1 ON t USING vchordrq (val1 halfvec_dot_ops);
 
diff --git a/tests/logic/rerank_in_table.slt b/tests/logic/rerank_in_table.slt
new file mode 100644
index 00000000..f6a138a1
--- /dev/null
+++ b/tests/logic/rerank_in_table.slt
@@ -0,0 +1,67 @@
+statement ok
+CREATE TABLE t_column (id integer, val vector(3));
+
+statement ok
+INSERT INTO t_column (id, val) SELECT id, ARRAY[id, id, id]::real[] FROM generate_series(1, 10000) s(id);
+
+statement ok
+CREATE INDEX ON t_column USING vchordrq (val vector_l2_ops)
+WITH (options = $$
+residual_quantization = false
+rerank_in_table = true
+[build.internal]
+lists = []
+$$);
+
+statement ok
+SET vchordrq.probes = '';
+
+query I
+SELECT id FROM t_column ORDER BY val <-> '[1.9, 1.9, 1.9]' limit 9;
+----
+2
+1
+3
+4
+5
+6
+7
+8
+9
+
+statement ok
+DROP TABLE t_column;
+
+statement ok
+CREATE TABLE t_expr (id integer);
+
+statement ok
+INSERT INTO t_expr (id) SELECT id FROM generate_series(1, 10000) s(id);
+
+statement ok
+CREATE INDEX ON t_expr USING vchordrq ((ARRAY[id::real, id::real, id::real]::vector(3)) vector_l2_ops)
+WITH (options = $$
+residual_quantization = false
+rerank_in_table = true
+[build.internal]
+lists = []
+$$);
+
+statement ok
+SET vchordrq.probes = '';
+
+query I
+SELECT id FROM t_expr ORDER BY ARRAY[id::real, id::real, id::real]::vector(3) <-> '[1.9, 1.9, 1.9]' limit 9;
+----
+2
+1
+3
+4
+5
+6
+7
+8
+9
+
+statement ok
+DROP TABLE t_expr;
diff --git a/tools/package.sh b/tools/package.sh
deleted file mode 100755
index e8c92bf1..00000000
--- a/tools/package.sh
+++ /dev/null
@@ -1,45 +0,0 @@
-#!/usr/bin/env bash
-set -eu
-
-printf "SEMVER = ${SEMVER}\n"
-printf "VERSION = ${VERSION}\n"
-printf "ARCH = ${ARCH}\n"
-printf "PLATFORM = ${PLATFORM}\n"
-
-cargo build --lib --features pg$VERSION --release
-cargo pgrx schema --features pg$VERSION --out ./target/schema.sql
-
-rm -rf ./build
-
-mkdir -p ./build/zip
-cp -a ./sql/upgrade/. ./build/zip/
-cp ./target/schema.sql ./build/zip/vchord--$SEMVER.sql
-sed -e "s/@CARGO_VERSION@/$SEMVER/g" < ./vchord.control > ./build/zip/vchord.control
-cp ./target/release/libvchord.so ./build/zip/vchord.so
-zip ./build/postgresql-${VERSION}-vchord_${SEMVER}_${ARCH}-linux-gnu.zip -j ./build/zip/*
-
-mkdir -p ./build/deb
-mkdir -p ./build/deb/DEBIAN
-mkdir -p ./build/deb/usr/share/postgresql/$VERSION/extension/
-mkdir -p ./build/deb/usr/lib/postgresql/$VERSION/lib/
-for file in $(ls ./build/zip/*.sql | xargs -n 1 basename); do
-    cp ./build/zip/$file ./build/deb/usr/share/postgresql/$VERSION/extension/$file
-done
-for file in $(ls ./build/zip/*.control | xargs -n 1 basename); do
-    cp ./build/zip/$file ./build/deb/usr/share/postgresql/$VERSION/extension/$file
-done
-for file in $(ls ./build/zip/*.so | xargs -n 1 basename); do
-    cp ./build/zip/$file ./build/deb/usr/lib/postgresql/$VERSION/lib/$file
-done
-echo "Package: postgresql-${VERSION}-vchord
-Version: ${SEMVER}-1
-Section: database
-Priority: optional
-Architecture: ${PLATFORM}
-Maintainer: Tensorchord
-Description: Vector database plugin for Postgres, written in Rust, specifically designed for LLM
-Homepage: https://vectorchord.ai/
-License: AGPL-3 or Elastic-2" \
-> ./build/deb/DEBIAN/control
-(cd ./build/deb && md5sum usr/share/postgresql/$VERSION/extension/* usr/lib/postgresql/$VERSION/lib/*) > ./build/deb/DEBIAN/md5sums
-dpkg-deb --root-owner-group -Zxz --build ./build/deb/ ./build/postgresql-${VERSION}-vchord_${SEMVER}-1_${PLATFORM}.deb