Misc. bug: CUDA ggml_top_k() implementation crashes for large tensor shapes #21162

@fairydreaming

Description

Name and Version

./bin/llama-cli --version
ggml_cuda_init: found 1 CUDA devices (Total VRAM: 97247 MiB):
Device 0: NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition, compute capability 12.0, VMM: yes, VRAM: 97247 MiB
version: 8580 (7c20367)
built with GNU 13.3.0 for Linux x86_64

Operating systems

Linux

Which llama.cpp modules do you know to be affected?

Other (Please specify in the next section)

Command line

./bin/test-backend-ops -o TOP_K

Problem description & steps to reproduce

How to reproduce the problem:

Step 1: Apply the following patch to test-backend-ops.cpp:

diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 6a4f9b634..096ec58b1 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -8461,6 +8461,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     //    test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {i, 2, 1, 3}, rand() % i + 1));
     //}
 
+    test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {512, 512, 1, 1}, 512));
+    test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {1024, 512, 1, 1}, 1024));
+    test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {1536, 512, 1, 1}, 1536));
+
     for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR, GGML_SCALE_MODE_BICUBIC, ggml_scale_mode(GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS)}) {
         test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode));
         test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode, true));

Step 2: Make a CUDA build of llama.cpp. The CUDA version I used:

nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2025 NVIDIA Corporation
Built on Fri_Feb_21_20:23:50_PST_2025
Cuda compilation tools, release 12.8, V12.8.93
Build cuda_12.8.r12.8/compiler.35583870_0
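For reference, a typical CMake configuration for a CUDA build (a sketch; adjust build directory and generator flags to your setup):

```shell
# Configure and build llama.cpp with the CUDA backend enabled.
cmake -B build-cuda -DGGML_CUDA=ON
cmake --build build-cuda --config Release -j
```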

Step 3: Run TOP_K op tests:

./bin/test-backend-ops -o TOP_K

Observed result: the test run crashes with an illegal memory access:

ggml_backend_cuda_graph_compute: CUDA graph warmup reset
  TOP_K(type=f32,ne=[512,512,1,1],k=512,ties=0): OK
ggml_backend_cuda_graph_compute: CUDA graph warmup complete
  TOP_K(type=f32,ne=[1024,512,1,1],k=1024,ties=0): OK
CUDA error: an illegal memory access was encountered
  current device: 0, in function ggml_backend_cuda_synchronize at /home/phm/projects/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu:2924

Full logs below.

Expected result: all tests finish successfully.

Rationale: #21149 needs TOP_K to work for tensor shapes up to {context size, ubatch size, 1, B}.

Note: everything works correctly if you build llama.cpp against the NVIDIA CCCL library (version >3.2) with GGML_CUDA_USE_CUB enabled. Tested with CUDA 13.2 and CCCL 13.2.27.
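For reference, the CUB-based build mentioned in the note can be selected at configure time (a sketch; assumes a recent CCCL is visible to nvcc, and that GGML_CUDA_USE_CUB is exposed as a CMake cache variable):

```shell
# Build with CUB-based kernels enabled, which avoids the crash.
cmake -B build-cuda-cub -DGGML_CUDA=ON -DGGML_CUDA_USE_CUB=ON
cmake --build build-cuda-cub --config Release -j
```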

First Bad Commit

No idea if this ever worked correctly for tensors this big.

Relevant log output

Logs
$ ./bin/test-backend-ops -o TOP_K
ggml_cuda_init: found 1 CUDA devices (Total VRAM: 97247 MiB):
  Device 0: NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition, compute capability 12.0, VMM: yes, VRAM: 97247 MiB
Testing 2 devices

Backend 1/2: CUDA0
  Device description: NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition
  Device memory: 97247 MB (96640 MB free)

  TOP_K(type=f32,ne=[1,2,1,3],k=1,ties=1): OK
  TOP_K(type=f32,ne=[2,2,1,3],k=1,ties=1): OK
  TOP_K(type=f32,ne=[2,2,1,3],k=2,ties=1): OK
  TOP_K(type=f32,ne=[3,2,1,3],k=1,ties=1): OK
  TOP_K(type=f32,ne=[3,2,1,3],k=2,ties=1): OK
  TOP_K(type=f32,ne=[3,2,1,3],k=3,ties=1): OK
  TOP_K(type=f32,ne=[4,2,1,3],k=1,ties=1): OK
  TOP_K(type=f32,ne=[4,2,1,3],k=2,ties=1): OK
  TOP_K(type=f32,ne=[4,2,1,3],k=3,ties=1): OK
  TOP_K(type=f32,ne=[4,2,1,3],k=4,ties=1): OK
  TOP_K(type=f32,ne=[1,1,1,1],k=1,ties=0): OK
  TOP_K(type=f32,ne=[12,1,2,1],k=1,ties=0): OK
  TOP_K(type=f32,ne=[12,1,2,1],k=1,ties=1): OK
  TOP_K(type=f32,ne=[2,1,1,1],k=1,ties=0): OK
  TOP_K(type=f32,ne=[13,1,2,1],k=1,ties=0): OK
  TOP_K(type=f32,ne=[13,1,2,1],k=1,ties=1): OK
ggml_backend_cuda_graph_compute: CUDA graph warmup complete
ggml_backend_cuda_graph_compute: CUDA graph warmup complete
ggml_backend_cuda_graph_compute: CUDA graph warmup complete
  TOP_K(type=f32,ne=[2,1,1,1],k=2,ties=0): OK
  TOP_K(type=f32,ne=[13,1,2,1],k=2,ties=0): OK
  TOP_K(type=f32,ne=[13,1,2,1],k=2,ties=1): OK
  TOP_K(type=f32,ne=[4,1,1,1],k=1,ties=0): OK
  TOP_K(type=f32,ne=[15,1,2,1],k=1,ties=0): OK
  TOP_K(type=f32,ne=[15,1,2,1],k=1,ties=1): OK
  TOP_K(type=f32,ne=[4,1,1,1],k=2,ties=0): OK
  TOP_K(type=f32,ne=[15,1,2,1],k=2,ties=0): OK
  TOP_K(type=f32,ne=[15,1,2,1],k=2,ties=1): OK
  TOP_K(type=f32,ne=[4,1,1,1],k=3,ties=0): OK
  TOP_K(type=f32,ne=[15,1,2,1],k=3,ties=0): OK
  TOP_K(type=f32,ne=[15,1,2,1],k=3,ties=1): OK
  TOP_K(type=f32,ne=[8,1,1,1],k=1,ties=0): OK
ggml_backend_cuda_graph_compute: CUDA graph warmup reset
ggml_backend_cuda_graph_compute: CUDA graph warmup reset
  TOP_K(type=f32,ne=[19,1,2,1],k=1,ties=0): OK
  TOP_K(type=f32,ne=[19,1,2,1],k=1,ties=1): OK
  TOP_K(type=f32,ne=[8,1,1,1],k=2,ties=0): OK
  TOP_K(type=f32,ne=[19,1,2,1],k=2,ties=0): OK
  TOP_K(type=f32,ne=[19,1,2,1],k=2,ties=1): OK
  TOP_K(type=f32,ne=[8,1,1,1],k=3,ties=0): OK
  TOP_K(type=f32,ne=[19,1,2,1],k=3,ties=0): OK
  TOP_K(type=f32,ne=[19,1,2,1],k=3,ties=1): OK
  TOP_K(type=f32,ne=[8,1,1,1],k=7,ties=0): OK
  TOP_K(type=f32,ne=[19,1,2,1],k=7,ties=0): OK
  TOP_K(type=f32,ne=[19,1,2,1],k=7,ties=1): OK
  TOP_K(type=f32,ne=[16,1,1,1],k=1,ties=0): OK
  TOP_K(type=f32,ne=[27,1,2,1],k=1,ties=0): OK
  TOP_K(type=f32,ne=[27,1,2,1],k=1,ties=1): OK
  TOP_K(type=f32,ne=[16,1,1,1],k=2,ties=0): OK
  TOP_K(type=f32,ne=[27,1,2,1],k=2,ties=0): OK
  TOP_K(type=f32,ne=[27,1,2,1],k=2,ties=1): OK
  TOP_K(type=f32,ne=[16,1,1,1],k=3,ties=0): OK
  TOP_K(type=f32,ne=[27,1,2,1],k=3,ties=0): OK
  TOP_K(type=f32,ne=[27,1,2,1],k=3,ties=1): OK
  TOP_K(type=f32,ne=[16,1,1,1],k=7,ties=0): OK
  TOP_K(type=f32,ne=[27,1,2,1],k=7,ties=0): OK
  TOP_K(type=f32,ne=[27,1,2,1],k=7,ties=1): OK
  TOP_K(type=f32,ne=[16,1,1,1],k=15,ties=0): OK
  TOP_K(type=f32,ne=[27,1,2,1],k=15,ties=0): OK
  TOP_K(type=f32,ne=[27,1,2,1],k=15,ties=1): OK
  TOP_K(type=f32,ne=[32,1,1,1],k=1,ties=0): OK
  TOP_K(type=f32,ne=[43,1,2,1],k=1,ties=0): OK
  TOP_K(type=f32,ne=[43,1,2,1],k=1,ties=1): OK
  TOP_K(type=f32,ne=[32,1,1,1],k=2,ties=0): OK
  TOP_K(type=f32,ne=[43,1,2,1],k=2,ties=0): OK
  TOP_K(type=f32,ne=[43,1,2,1],k=2,ties=1): OK
  TOP_K(type=f32,ne=[32,1,1,1],k=3,ties=0): OK
  TOP_K(type=f32,ne=[43,1,2,1],k=3,ties=0): OK
  TOP_K(type=f32,ne=[43,1,2,1],k=3,ties=1): OK
  TOP_K(type=f32,ne=[32,1,1,1],k=7,ties=0): OK
  TOP_K(type=f32,ne=[43,1,2,1],k=7,ties=0): OK
  TOP_K(type=f32,ne=[43,1,2,1],k=7,ties=1): OK
  TOP_K(type=f32,ne=[32,1,1,1],k=15,ties=0): OK
  TOP_K(type=f32,ne=[43,1,2,1],k=15,ties=0): OK
  TOP_K(type=f32,ne=[43,1,2,1],k=15,ties=1): OK
  TOP_K(type=f32,ne=[64,1,1,1],k=1,ties=0): OK
  TOP_K(type=f32,ne=[75,1,2,1],k=1,ties=0): OK
  TOP_K(type=f32,ne=[75,1,2,1],k=1,ties=1): OK
  TOP_K(type=f32,ne=[64,1,1,1],k=2,ties=0): OK
  TOP_K(type=f32,ne=[75,1,2,1],k=2,ties=0): OK
  TOP_K(type=f32,ne=[75,1,2,1],k=2,ties=1): OK
  TOP_K(type=f32,ne=[64,1,1,1],k=3,ties=0): OK
  TOP_K(type=f32,ne=[75,1,2,1],k=3,ties=0): OK
  TOP_K(type=f32,ne=[75,1,2,1],k=3,ties=1): OK
  TOP_K(type=f32,ne=[64,1,1,1],k=7,ties=0): OK
  TOP_K(type=f32,ne=[75,1,2,1],k=7,ties=0): OK
  TOP_K(type=f32,ne=[75,1,2,1],k=7,ties=1): OK
  TOP_K(type=f32,ne=[64,1,1,1],k=15,ties=0): OK
  TOP_K(type=f32,ne=[75,1,2,1],k=15,ties=0): OK
  TOP_K(type=f32,ne=[75,1,2,1],k=15,ties=1): OK
  TOP_K(type=f32,ne=[128,1,1,1],k=1,ties=0): OK
  TOP_K(type=f32,ne=[139,1,2,1],k=1,ties=0): OK
  TOP_K(type=f32,ne=[139,1,2,1],k=1,ties=1): OK
  TOP_K(type=f32,ne=[128,1,1,1],k=2,ties=0): OK
  TOP_K(type=f32,ne=[139,1,2,1],k=2,ties=0): OK
  TOP_K(type=f32,ne=[139,1,2,1],k=2,ties=1): OK
  TOP_K(type=f32,ne=[128,1,1,1],k=3,ties=0): OK
  TOP_K(type=f32,ne=[139,1,2,1],k=3,ties=0): OK
  TOP_K(type=f32,ne=[139,1,2,1],k=3,ties=1): OK
  TOP_K(type=f32,ne=[128,1,1,1],k=7,ties=0): OK
  TOP_K(type=f32,ne=[139,1,2,1],k=7,ties=0): OK
  TOP_K(type=f32,ne=[139,1,2,1],k=7,ties=1): OK
  TOP_K(type=f32,ne=[128,1,1,1],k=15,ties=0): OK
  TOP_K(type=f32,ne=[139,1,2,1],k=15,ties=0): OK
  TOP_K(type=f32,ne=[139,1,2,1],k=15,ties=1): OK
  TOP_K(type=f32,ne=[128,1,1,1],k=100,ties=0): OK
  TOP_K(type=f32,ne=[139,1,2,1],k=100,ties=0): OK
  TOP_K(type=f32,ne=[139,1,2,1],k=100,ties=1): OK
  TOP_K(type=f32,ne=[256,1,1,1],k=1,ties=0): OK
  TOP_K(type=f32,ne=[267,1,2,1],k=1,ties=0): OK
  TOP_K(type=f32,ne=[267,1,2,1],k=1,ties=1): OK
  TOP_K(type=f32,ne=[256,1,1,1],k=2,ties=0): OK
  TOP_K(type=f32,ne=[267,1,2,1],k=2,ties=0): OK
  TOP_K(type=f32,ne=[267,1,2,1],k=2,ties=1): OK
  TOP_K(type=f32,ne=[256,1,1,1],k=3,ties=0): OK
  TOP_K(type=f32,ne=[267,1,2,1],k=3,ties=0): OK
  TOP_K(type=f32,ne=[267,1,2,1],k=3,ties=1): OK
  TOP_K(type=f32,ne=[256,1,1,1],k=7,ties=0): OK
  TOP_K(type=f32,ne=[267,1,2,1],k=7,ties=0): OK
  TOP_K(type=f32,ne=[267,1,2,1],k=7,ties=1): OK
  TOP_K(type=f32,ne=[256,1,1,1],k=15,ties=0): OK
  TOP_K(type=f32,ne=[267,1,2,1],k=15,ties=0): OK
  TOP_K(type=f32,ne=[267,1,2,1],k=15,ties=1): OK
  TOP_K(type=f32,ne=[256,1,1,1],k=100,ties=0): OK
  TOP_K(type=f32,ne=[267,1,2,1],k=100,ties=0): OK
  TOP_K(type=f32,ne=[267,1,2,1],k=100,ties=1): OK
  TOP_K(type=f32,ne=[512,1,1,1],k=1,ties=0): OK
  TOP_K(type=f32,ne=[523,1,2,1],k=1,ties=0): OK
  TOP_K(type=f32,ne=[523,1,2,1],k=1,ties=1): OK
  TOP_K(type=f32,ne=[512,1,1,1],k=2,ties=0): OK
  TOP_K(type=f32,ne=[523,1,2,1],k=2,ties=0): OK
  TOP_K(type=f32,ne=[523,1,2,1],k=2,ties=1): OK
  TOP_K(type=f32,ne=[512,1,1,1],k=3,ties=0): OK
  TOP_K(type=f32,ne=[523,1,2,1],k=3,ties=0): OK
  TOP_K(type=f32,ne=[523,1,2,1],k=3,ties=1): OK
  TOP_K(type=f32,ne=[512,1,1,1],k=7,ties=0): OK
  TOP_K(type=f32,ne=[523,1,2,1],k=7,ties=0): OK
  TOP_K(type=f32,ne=[523,1,2,1],k=7,ties=1): OK
  TOP_K(type=f32,ne=[512,1,1,1],k=15,ties=0): OK
  TOP_K(type=f32,ne=[523,1,2,1],k=15,ties=0): OK
  TOP_K(type=f32,ne=[523,1,2,1],k=15,ties=1): OK
  TOP_K(type=f32,ne=[512,1,1,1],k=100,ties=0): OK
  TOP_K(type=f32,ne=[523,1,2,1],k=100,ties=0): OK
  TOP_K(type=f32,ne=[523,1,2,1],k=100,ties=1): OK
  TOP_K(type=f32,ne=[512,1,1,1],k=500,ties=0): OK
  TOP_K(type=f32,ne=[523,1,2,1],k=500,ties=0): OK
  TOP_K(type=f32,ne=[523,1,2,1],k=500,ties=1): OK
  TOP_K(type=f32,ne=[1024,1,1,1],k=1,ties=0): OK
  TOP_K(type=f32,ne=[1035,1,2,1],k=1,ties=0): OK
  TOP_K(type=f32,ne=[1035,1,2,1],k=1,ties=1): OK
  TOP_K(type=f32,ne=[1024,1,1,1],k=2,ties=0): OK
  TOP_K(type=f32,ne=[1035,1,2,1],k=2,ties=0): OK
  TOP_K(type=f32,ne=[1035,1,2,1],k=2,ties=1): OK
  TOP_K(type=f32,ne=[1024,1,1,1],k=3,ties=0): OK
  TOP_K(type=f32,ne=[1035,1,2,1],k=3,ties=0): OK
  TOP_K(type=f32,ne=[1035,1,2,1],k=3,ties=1): OK
  TOP_K(type=f32,ne=[1024,1,1,1],k=7,ties=0): OK
  TOP_K(type=f32,ne=[1035,1,2,1],k=7,ties=0): OK
  TOP_K(type=f32,ne=[1035,1,2,1],k=7,ties=1): OK
  TOP_K(type=f32,ne=[1024,1,1,1],k=15,ties=0): OK
  TOP_K(type=f32,ne=[1035,1,2,1],k=15,ties=0): OK
  TOP_K(type=f32,ne=[1035,1,2,1],k=15,ties=1): OK
  TOP_K(type=f32,ne=[1024,1,1,1],k=100,ties=0): OK
  TOP_K(type=f32,ne=[1035,1,2,1],k=100,ties=0): OK
  TOP_K(type=f32,ne=[1035,1,2,1],k=100,ties=1): OK
  TOP_K(type=f32,ne=[1024,1,1,1],k=500,ties=0): OK
  TOP_K(type=f32,ne=[1035,1,2,1],k=500,ties=0): OK
  TOP_K(type=f32,ne=[1035,1,2,1],k=500,ties=1): OK
  TOP_K(type=f32,ne=[1024,1,1,1],k=1023,ties=0): OK
  TOP_K(type=f32,ne=[1035,1,2,1],k=1023,ties=0): OK
  TOP_K(type=f32,ne=[1035,1,2,1],k=1023,ties=1): OK
  TOP_K(type=f32,ne=[2048,1,1,1],k=1,ties=0): OK
  TOP_K(type=f32,ne=[2059,1,2,1],k=1,ties=0): OK
  TOP_K(type=f32,ne=[2059,1,2,1],k=1,ties=1): OK
  TOP_K(type=f32,ne=[2048,1,1,1],k=2,ties=0): OK
  TOP_K(type=f32,ne=[2059,1,2,1],k=2,ties=0): OK
  TOP_K(type=f32,ne=[2059,1,2,1],k=2,ties=1): OK
  TOP_K(type=f32,ne=[2048,1,1,1],k=3,ties=0): OK
  TOP_K(type=f32,ne=[2059,1,2,1],k=3,ties=0): OK
  TOP_K(type=f32,ne=[2059,1,2,1],k=3,ties=1): OK
  TOP_K(type=f32,ne=[2048,1,1,1],k=7,ties=0): OK
  TOP_K(type=f32,ne=[2059,1,2,1],k=7,ties=0): OK
  TOP_K(type=f32,ne=[2059,1,2,1],k=7,ties=1): OK
  TOP_K(type=f32,ne=[2048,1,1,1],k=15,ties=0): OK
  TOP_K(type=f32,ne=[2059,1,2,1],k=15,ties=0): OK
  TOP_K(type=f32,ne=[2059,1,2,1],k=15,ties=1): OK
  TOP_K(type=f32,ne=[2048,1,1,1],k=100,ties=0): OK
  TOP_K(type=f32,ne=[2059,1,2,1],k=100,ties=0): OK
  TOP_K(type=f32,ne=[2059,1,2,1],k=100,ties=1): OK
  TOP_K(type=f32,ne=[2048,1,1,1],k=500,ties=0): OK
  TOP_K(type=f32,ne=[2059,1,2,1],k=500,ties=0): OK
  TOP_K(type=f32,ne=[2059,1,2,1],k=500,ties=1): OK
  TOP_K(type=f32,ne=[2048,1,1,1],k=1023,ties=0): OK
  TOP_K(type=f32,ne=[2059,1,2,1],k=1023,ties=0): OK
  TOP_K(type=f32,ne=[2059,1,2,1],k=1023,ties=1): OK
  TOP_K(type=f32,ne=[4096,1,1,1],k=1,ties=0): OK
  TOP_K(type=f32,ne=[4107,1,2,1],k=1,ties=0): OK
  TOP_K(type=f32,ne=[4107,1,2,1],k=1,ties=1): OK
  TOP_K(type=f32,ne=[4096,1,1,1],k=2,ties=0): OK
  TOP_K(type=f32,ne=[4107,1,2,1],k=2,ties=0): OK
  TOP_K(type=f32,ne=[4107,1,2,1],k=2,ties=1): OK
  TOP_K(type=f32,ne=[4096,1,1,1],k=3,ties=0): OK
  TOP_K(type=f32,ne=[4107,1,2,1],k=3,ties=0): OK
  TOP_K(type=f32,ne=[4107,1,2,1],k=3,ties=1): OK
  TOP_K(type=f32,ne=[4096,1,1,1],k=7,ties=0): OK
  TOP_K(type=f32,ne=[4107,1,2,1],k=7,ties=0): OK
  TOP_K(type=f32,ne=[4107,1,2,1],k=7,ties=1): OK
  TOP_K(type=f32,ne=[4096,1,1,1],k=15,ties=0): OK
  TOP_K(type=f32,ne=[4107,1,2,1],k=15,ties=0): OK
  TOP_K(type=f32,ne=[4107,1,2,1],k=15,ties=1): OK
  TOP_K(type=f32,ne=[4096,1,1,1],k=100,ties=0): OK
  TOP_K(type=f32,ne=[4107,1,2,1],k=100,ties=0): OK
  TOP_K(type=f32,ne=[4107,1,2,1],k=100,ties=1): OK
  TOP_K(type=f32,ne=[4096,1,1,1],k=500,ties=0): OK
  TOP_K(type=f32,ne=[4107,1,2,1],k=500,ties=0): OK
  TOP_K(type=f32,ne=[4107,1,2,1],k=500,ties=1): OK
  TOP_K(type=f32,ne=[4096,1,1,1],k=1023,ties=0): OK
  TOP_K(type=f32,ne=[4107,1,2,1],k=1023,ties=0): OK
  TOP_K(type=f32,ne=[4107,1,2,1],k=1023,ties=1): OK
  TOP_K(type=f32,ne=[8192,1,1,1],k=1,ties=0): OK
  TOP_K(type=f32,ne=[8203,1,2,1],k=1,ties=0): OK
  TOP_K(type=f32,ne=[8203,1,2,1],k=1,ties=1): OK
  TOP_K(type=f32,ne=[8192,1,1,1],k=2,ties=0): OK
  TOP_K(type=f32,ne=[8203,1,2,1],k=2,ties=0): OK
  TOP_K(type=f32,ne=[8203,1,2,1],k=2,ties=1): OK
  TOP_K(type=f32,ne=[8192,1,1,1],k=3,ties=0): OK
  TOP_K(type=f32,ne=[8203,1,2,1],k=3,ties=0): OK
  TOP_K(type=f32,ne=[8203,1,2,1],k=3,ties=1): OK
  TOP_K(type=f32,ne=[8192,1,1,1],k=7,ties=0): OK
  TOP_K(type=f32,ne=[8203,1,2,1],k=7,ties=0): OK
  TOP_K(type=f32,ne=[8203,1,2,1],k=7,ties=1): OK
  TOP_K(type=f32,ne=[8192,1,1,1],k=15,ties=0): OK
  TOP_K(type=f32,ne=[8203,1,2,1],k=15,ties=0): OK
  TOP_K(type=f32,ne=[8203,1,2,1],k=15,ties=1): OK
  TOP_K(type=f32,ne=[8192,1,1,1],k=100,ties=0): OK
  TOP_K(type=f32,ne=[8203,1,2,1],k=100,ties=0): OK
  TOP_K(type=f32,ne=[8203,1,2,1],k=100,ties=1): OK
  TOP_K(type=f32,ne=[8192,1,1,1],k=500,ties=0): OK
  TOP_K(type=f32,ne=[8203,1,2,1],k=500,ties=0): OK
  TOP_K(type=f32,ne=[8203,1,2,1],k=500,ties=1): OK
  TOP_K(type=f32,ne=[8192,1,1,1],k=1023,ties=0): OK
  TOP_K(type=f32,ne=[8203,1,2,1],k=1023,ties=0): OK
  TOP_K(type=f32,ne=[8203,1,2,1],k=1023,ties=1): OK
  TOP_K(type=f32,ne=[16384,1,1,1],k=1,ties=0): OK
  TOP_K(type=f32,ne=[16395,1,2,1],k=1,ties=0): OK
  TOP_K(type=f32,ne=[16395,1,2,1],k=1,ties=1): OK
  TOP_K(type=f32,ne=[16384,1,1,1],k=2,ties=0): OK
  TOP_K(type=f32,ne=[16395,1,2,1],k=2,ties=0): OK
  TOP_K(type=f32,ne=[16395,1,2,1],k=2,ties=1): OK
  TOP_K(type=f32,ne=[16384,1,1,1],k=3,ties=0): OK
  TOP_K(type=f32,ne=[16395,1,2,1],k=3,ties=0): OK
  TOP_K(type=f32,ne=[16395,1,2,1],k=3,ties=1): OK
  TOP_K(type=f32,ne=[16384,1,1,1],k=7,ties=0): OK
  TOP_K(type=f32,ne=[16395,1,2,1],k=7,ties=0): OK
  TOP_K(type=f32,ne=[16395,1,2,1],k=7,ties=1): OK
  TOP_K(type=f32,ne=[16384,1,1,1],k=15,ties=0): OK
  TOP_K(type=f32,ne=[16395,1,2,1],k=15,ties=0): OK
  TOP_K(type=f32,ne=[16395,1,2,1],k=15,ties=1): OK
  TOP_K(type=f32,ne=[16384,1,1,1],k=100,ties=0): OK
  TOP_K(type=f32,ne=[16395,1,2,1],k=100,ties=0): OK
  TOP_K(type=f32,ne=[16395,1,2,1],k=100,ties=1): OK
  TOP_K(type=f32,ne=[16384,1,1,1],k=500,ties=0): OK
  TOP_K(type=f32,ne=[16395,1,2,1],k=500,ties=0): OK
  TOP_K(type=f32,ne=[16395,1,2,1],k=500,ties=1): OK
  TOP_K(type=f32,ne=[16384,1,1,1],k=1023,ties=0): OK
  TOP_K(type=f32,ne=[16395,1,2,1],k=1023,ties=0): OK
  TOP_K(type=f32,ne=[16395,1,2,1],k=1023,ties=1): OK
  TOP_K(type=f32,ne=[16384,1,1,1],k=9999,ties=0): OK
  TOP_K(type=f32,ne=[16395,1,2,1],k=9999,ties=0): OK
  TOP_K(type=f32,ne=[16395,1,2,1],k=9999,ties=1): OK
  TOP_K(type=f32,ne=[32768,1,1,1],k=1,ties=0): OK
  TOP_K(type=f32,ne=[32779,1,2,1],k=1,ties=0): OK
  TOP_K(type=f32,ne=[32779,1,2,1],k=1,ties=1): OK
  TOP_K(type=f32,ne=[32768,1,1,1],k=2,ties=0): OK
  TOP_K(type=f32,ne=[32779,1,2,1],k=2,ties=0): OK
  TOP_K(type=f32,ne=[32779,1,2,1],k=2,ties=1): OK
  TOP_K(type=f32,ne=[32768,1,1,1],k=3,ties=0): OK
  TOP_K(type=f32,ne=[32779,1,2,1],k=3,ties=0): OK
  TOP_K(type=f32,ne=[32779,1,2,1],k=3,ties=1): OK
  TOP_K(type=f32,ne=[32768,1,1,1],k=7,ties=0): OK
  TOP_K(type=f32,ne=[32779,1,2,1],k=7,ties=0): OK
  TOP_K(type=f32,ne=[32779,1,2,1],k=7,ties=1): OK
  TOP_K(type=f32,ne=[32768,1,1,1],k=15,ties=0): OK
  TOP_K(type=f32,ne=[32779,1,2,1],k=15,ties=0): OK
  TOP_K(type=f32,ne=[32779,1,2,1],k=15,ties=1): OK
  TOP_K(type=f32,ne=[32768,1,1,1],k=100,ties=0): OK
  TOP_K(type=f32,ne=[32779,1,2,1],k=100,ties=0): OK
  TOP_K(type=f32,ne=[32779,1,2,1],k=100,ties=1): OK
  TOP_K(type=f32,ne=[32768,1,1,1],k=500,ties=0): OK
  TOP_K(type=f32,ne=[32779,1,2,1],k=500,ties=0): OK
  TOP_K(type=f32,ne=[32779,1,2,1],k=500,ties=1): OK
  TOP_K(type=f32,ne=[32768,1,1,1],k=1023,ties=0): OK
  TOP_K(type=f32,ne=[32779,1,2,1],k=1023,ties=0): OK
  TOP_K(type=f32,ne=[32779,1,2,1],k=1023,ties=1): OK
  TOP_K(type=f32,ne=[32768,1,1,1],k=9999,ties=0): OK
  TOP_K(type=f32,ne=[32779,1,2,1],k=9999,ties=0): OK
  TOP_K(type=f32,ne=[32779,1,2,1],k=9999,ties=1): OK
  TOP_K(type=f32,ne=[65536,1,1,1],k=1,ties=0): OK
  TOP_K(type=f32,ne=[65547,1,2,1],k=1,ties=0): OK
  TOP_K(type=f32,ne=[65547,1,2,1],k=1,ties=1): OK
  TOP_K(type=f32,ne=[65536,1,1,1],k=2,ties=0): OK
  TOP_K(type=f32,ne=[65547,1,2,1],k=2,ties=0): OK
  TOP_K(type=f32,ne=[65547,1,2,1],k=2,ties=1): OK
  TOP_K(type=f32,ne=[65536,1,1,1],k=3,ties=0): OK
  TOP_K(type=f32,ne=[65547,1,2,1],k=3,ties=0): OK
  TOP_K(type=f32,ne=[65547,1,2,1],k=3,ties=1): OK
  TOP_K(type=f32,ne=[65536,1,1,1],k=7,ties=0): OK
  TOP_K(type=f32,ne=[65547,1,2,1],k=7,ties=0): OK
  TOP_K(type=f32,ne=[65547,1,2,1],k=7,ties=1): OK
  TOP_K(type=f32,ne=[65536,1,1,1],k=15,ties=0): OK
  TOP_K(type=f32,ne=[65547,1,2,1],k=15,ties=0): OK
  TOP_K(type=f32,ne=[65547,1,2,1],k=15,ties=1): OK
  TOP_K(type=f32,ne=[65536,1,1,1],k=100,ties=0): OK
  TOP_K(type=f32,ne=[65547,1,2,1],k=100,ties=0): OK
  TOP_K(type=f32,ne=[65547,1,2,1],k=100,ties=1): OK
  TOP_K(type=f32,ne=[65536,1,1,1],k=500,ties=0): OK
  TOP_K(type=f32,ne=[65547,1,2,1],k=500,ties=0): OK
  TOP_K(type=f32,ne=[65547,1,2,1],k=500,ties=1): OK
  TOP_K(type=f32,ne=[65536,1,1,1],k=1023,ties=0): OK
  TOP_K(type=f32,ne=[65547,1,2,1],k=1023,ties=0): OK
  TOP_K(type=f32,ne=[65547,1,2,1],k=1023,ties=1): OK
  TOP_K(type=f32,ne=[65536,1,1,1],k=9999,ties=0): OK
  TOP_K(type=f32,ne=[65547,1,2,1],k=9999,ties=0): OK
  TOP_K(type=f32,ne=[65547,1,2,1],k=9999,ties=1): OK
  TOP_K(type=f32,ne=[131072,1,1,1],k=1,ties=0): OK
  TOP_K(type=f32,ne=[131083,1,2,1],k=1,ties=0): OK
  TOP_K(type=f32,ne=[131083,1,2,1],k=1,ties=1): OK
  TOP_K(type=f32,ne=[131072,1,1,1],k=2,ties=0): OK
  TOP_K(type=f32,ne=[131083,1,2,1],k=2,ties=0): OK
  TOP_K(type=f32,ne=[131083,1,2,1],k=2,ties=1): OK
  TOP_K(type=f32,ne=[131072,1,1,1],k=3,ties=0): OK
  TOP_K(type=f32,ne=[131083,1,2,1],k=3,ties=0): OK
  TOP_K(type=f32,ne=[131083,1,2,1],k=3,ties=1): OK
  TOP_K(type=f32,ne=[131072,1,1,1],k=7,ties=0): OK
  TOP_K(type=f32,ne=[131083,1,2,1],k=7,ties=0): OK
  TOP_K(type=f32,ne=[131083,1,2,1],k=7,ties=1): OK
  TOP_K(type=f32,ne=[131072,1,1,1],k=15,ties=0): OK
  TOP_K(type=f32,ne=[131083,1,2,1],k=15,ties=0): OK
  TOP_K(type=f32,ne=[131083,1,2,1],k=15,ties=1): OK
  TOP_K(type=f32,ne=[131072,1,1,1],k=100,ties=0): OK
  TOP_K(type=f32,ne=[131083,1,2,1],k=100,ties=0): OK
  TOP_K(type=f32,ne=[131083,1,2,1],k=100,ties=1): OK
  TOP_K(type=f32,ne=[131072,1,1,1],k=500,ties=0): OK
  TOP_K(type=f32,ne=[131083,1,2,1],k=500,ties=0): OK
  TOP_K(type=f32,ne=[131083,1,2,1],k=500,ties=1): OK
  TOP_K(type=f32,ne=[131072,1,1,1],k=1023,ties=0): OK
  TOP_K(type=f32,ne=[131083,1,2,1],k=1023,ties=0): OK
  TOP_K(type=f32,ne=[131083,1,2,1],k=1023,ties=1): OK
  TOP_K(type=f32,ne=[131072,1,1,1],k=9999,ties=0): OK
  TOP_K(type=f32,ne=[131083,1,2,1],k=9999,ties=0): OK
  TOP_K(type=f32,ne=[131083,1,2,1],k=9999,ties=1): OK
  TOP_K(type=f32,ne=[262144,1,1,1],k=1,ties=0): OK
ggml_backend_cuda_graph_compute: CUDA graph warmup reset
  TOP_K(type=f32,ne=[262155,1,2,1],k=1,ties=0): OK
  TOP_K(type=f32,ne=[262155,1,2,1],k=1,ties=1): OK
  TOP_K(type=f32,ne=[262144,1,1,1],k=2,ties=0): OK
  TOP_K(type=f32,ne=[262155,1,2,1],k=2,ties=0): OK
  TOP_K(type=f32,ne=[262155,1,2,1],k=2,ties=1): OK
  TOP_K(type=f32,ne=[262144,1,1,1],k=3,ties=0): OK
  TOP_K(type=f32,ne=[262155,1,2,1],k=3,ties=0): OK
  TOP_K(type=f32,ne=[262155,1,2,1],k=3,ties=1): OK
  TOP_K(type=f32,ne=[262144,1,1,1],k=7,ties=0): OK
  TOP_K(type=f32,ne=[262155,1,2,1],k=7,ties=0): OK
  TOP_K(type=f32,ne=[262155,1,2,1],k=7,ties=1): OK
  TOP_K(type=f32,ne=[262144,1,1,1],k=15,ties=0): OK
  TOP_K(type=f32,ne=[262155,1,2,1],k=15,ties=0): OK
  TOP_K(type=f32,ne=[262155,1,2,1],k=15,ties=1): OK
  TOP_K(type=f32,ne=[262144,1,1,1],k=100,ties=0): OK
  TOP_K(type=f32,ne=[262155,1,2,1],k=100,ties=0): OK
  TOP_K(type=f32,ne=[262155,1,2,1],k=100,ties=1): OK
  TOP_K(type=f32,ne=[262144,1,1,1],k=500,ties=0): OK
  TOP_K(type=f32,ne=[262155,1,2,1],k=500,ties=0): OK
  TOP_K(type=f32,ne=[262155,1,2,1],k=500,ties=1): OK
  TOP_K(type=f32,ne=[262144,1,1,1],k=1023,ties=0): OK
  TOP_K(type=f32,ne=[262155,1,2,1],k=1023,ties=0): OK
  TOP_K(type=f32,ne=[262155,1,2,1],k=1023,ties=1): OK
  TOP_K(type=f32,ne=[262144,1,1,1],k=9999,ties=0): OK
  TOP_K(type=f32,ne=[262155,1,2,1],k=9999,ties=0): OK
  TOP_K(type=f32,ne=[262155,1,2,1],k=9999,ties=1): OK
ggml_backend_cuda_graph_compute: CUDA graph warmup complete
  TOP_K(type=f32,ne=[524288,1,1,1],k=1,ties=0): OK
  TOP_K(type=f32,ne=[524299,1,2,1],k=1,ties=0): OK
  TOP_K(type=f32,ne=[524299,1,2,1],k=1,ties=1): OK
  TOP_K(type=f32,ne=[524288,1,1,1],k=2,ties=0): OK
  TOP_K(type=f32,ne=[524299,1,2,1],k=2,ties=0): OK
  TOP_K(type=f32,ne=[524299,1,2,1],k=2,ties=1): OK
  TOP_K(type=f32,ne=[524288,1,1,1],k=3,ties=0): OK
  TOP_K(type=f32,ne=[524299,1,2,1],k=3,ties=0): OK
  TOP_K(type=f32,ne=[524299,1,2,1],k=3,ties=1): OK
  TOP_K(type=f32,ne=[524288,1,1,1],k=7,ties=0): OK
  TOP_K(type=f32,ne=[524299,1,2,1],k=7,ties=0): OK
  TOP_K(type=f32,ne=[524299,1,2,1],k=7,ties=1): OK
  TOP_K(type=f32,ne=[524288,1,1,1],k=15,ties=0): OK
  TOP_K(type=f32,ne=[524299,1,2,1],k=15,ties=0): OK
  TOP_K(type=f32,ne=[524299,1,2,1],k=15,ties=1): OK
  TOP_K(type=f32,ne=[524288,1,1,1],k=100,ties=0): OK
  TOP_K(type=f32,ne=[524299,1,2,1],k=100,ties=0): OK
  TOP_K(type=f32,ne=[524299,1,2,1],k=100,ties=1): OK
  TOP_K(type=f32,ne=[524288,1,1,1],k=500,ties=0): OK
  TOP_K(type=f32,ne=[524299,1,2,1],k=500,ties=0): OK
  TOP_K(type=f32,ne=[524299,1,2,1],k=500,ties=1): OK
  TOP_K(type=f32,ne=[524288,1,1,1],k=1023,ties=0): OK
  TOP_K(type=f32,ne=[524299,1,2,1],k=1023,ties=0): OK
  TOP_K(type=f32,ne=[524299,1,2,1],k=1023,ties=1): OK
  TOP_K(type=f32,ne=[524288,1,1,1],k=9999,ties=0): OK
  TOP_K(type=f32,ne=[524299,1,2,1],k=9999,ties=0): OK
  TOP_K(type=f32,ne=[524299,1,2,1],k=9999,ties=1): OK
ggml_backend_cuda_graph_compute: CUDA graph warmup reset
  TOP_K(type=f32,ne=[16,10,10,10],k=1,ties=0): OK
ggml_backend_cuda_graph_compute: CUDA graph warmup complete
  TOP_K(type=f32,ne=[60,10,10,10],k=1,ties=0): OK
  TOP_K(type=f32,ne=[1023,2,1,3],k=1,ties=0): OK
ggml_backend_cuda_graph_compute: CUDA graph warmup complete
ggml_backend_cuda_graph_compute: CUDA graph warmup complete
  TOP_K(type=f32,ne=[1024,2,1,3],k=1,ties=0): OK
ggml_backend_cuda_graph_compute: CUDA graph warmup reset
ggml_backend_cuda_graph_compute: CUDA graph warmup reset
  TOP_K(type=f32,ne=[1025,2,1,3],k=1,ties=0): OK
  TOP_K(type=f32,ne=[16384,1,1,1],k=1,ties=0): OK
  TOP_K(type=f32,ne=[2047,2,1,3],k=1,ties=0): OK
ggml_backend_cuda_graph_compute: CUDA graph warmup complete
ggml_backend_cuda_graph_compute: CUDA graph warmup complete
  TOP_K(type=f32,ne=[2048,2,1,3],k=1,ties=0): OK
ggml_backend_cuda_graph_compute: CUDA graph warmup reset
ggml_backend_cuda_graph_compute: CUDA graph warmup reset
  TOP_K(type=f32,ne=[2049,2,1,3],k=1,ties=0): OK
  TOP_K(type=f32,ne=[16,10,10,10],k=2,ties=0): OK
  TOP_K(type=f32,ne=[60,10,10,10],k=2,ties=0): OK
  TOP_K(type=f32,ne=[1023,2,1,3],k=2,ties=0): OK
ggml_backend_cuda_graph_compute: CUDA graph warmup complete
ggml_backend_cuda_graph_compute: CUDA graph warmup complete
  TOP_K(type=f32,ne=[1024,2,1,3],k=2,ties=0): OK
ggml_backend_cuda_graph_compute: CUDA graph warmup reset
ggml_backend_cuda_graph_compute: CUDA graph warmup reset
  TOP_K(type=f32,ne=[1025,2,1,3],k=2,ties=0): OK
  TOP_K(type=f32,ne=[16384,1,1,1],k=2,ties=0): OK
  TOP_K(type=f32,ne=[2047,2,1,3],k=2,ties=0): OK
ggml_backend_cuda_graph_compute: CUDA graph warmup complete
ggml_backend_cuda_graph_compute: CUDA graph warmup complete
  TOP_K(type=f32,ne=[2048,2,1,3],k=2,ties=0): OK
ggml_backend_cuda_graph_compute: CUDA graph warmup reset
ggml_backend_cuda_graph_compute: CUDA graph warmup reset
  TOP_K(type=f32,ne=[2049,2,1,3],k=2,ties=0): OK
  TOP_K(type=f32,ne=[16,10,10,10],k=3,ties=0): OK
  TOP_K(type=f32,ne=[60,10,10,10],k=3,ties=0): OK
  TOP_K(type=f32,ne=[1023,2,1,3],k=3,ties=0): OK
ggml_backend_cuda_graph_compute: CUDA graph warmup complete
ggml_backend_cuda_graph_compute: CUDA graph warmup complete
  TOP_K(type=f32,ne=[1024,2,1,3],k=3,ties=0): OK
ggml_backend_cuda_graph_compute: CUDA graph warmup reset
ggml_backend_cuda_graph_compute: CUDA graph warmup reset
  TOP_K(type=f32,ne=[1025,2,1,3],k=3,ties=0): OK
  TOP_K(type=f32,ne=[16384,1,1,1],k=3,ties=0): OK
  TOP_K(type=f32,ne=[2047,2,1,3],k=3,ties=0): OK
ggml_backend_cuda_graph_compute: CUDA graph warmup complete
ggml_backend_cuda_graph_compute: CUDA graph warmup complete
  TOP_K(type=f32,ne=[2048,2,1,3],k=3,ties=0): OK
ggml_backend_cuda_graph_compute: CUDA graph warmup reset
ggml_backend_cuda_graph_compute: CUDA graph warmup reset
  TOP_K(type=f32,ne=[2049,2,1,3],k=3,ties=0): OK
  TOP_K(type=f32,ne=[16,10,10,10],k=7,ties=0): OK
  TOP_K(type=f32,ne=[60,10,10,10],k=7,ties=0): OK
  TOP_K(type=f32,ne=[1023,2,1,3],k=7,ties=0): OK
ggml_backend_cuda_graph_compute: CUDA graph warmup complete
ggml_backend_cuda_graph_compute: CUDA graph warmup complete
  TOP_K(type=f32,ne=[1024,2,1,3],k=7,ties=0): OK
ggml_backend_cuda_graph_compute: CUDA graph warmup reset
ggml_backend_cuda_graph_compute: CUDA graph warmup reset
  TOP_K(type=f32,ne=[1025,2,1,3],k=7,ties=0): OK
  TOP_K(type=f32,ne=[16384,1,1,1],k=7,ties=0): OK
  TOP_K(type=f32,ne=[2047,2,1,3],k=7,ties=0): OK
ggml_backend_cuda_graph_compute: CUDA graph warmup complete
ggml_backend_cuda_graph_compute: CUDA graph warmup complete
  TOP_K(type=f32,ne=[2048,2,1,3],k=7,ties=0): OK
ggml_backend_cuda_graph_compute: CUDA graph warmup reset
ggml_backend_cuda_graph_compute: CUDA graph warmup reset
  TOP_K(type=f32,ne=[2049,2,1,3],k=7,ties=0): OK
  TOP_K(type=f32,ne=[16,10,10,10],k=15,ties=0): OK
  TOP_K(type=f32,ne=[60,10,10,10],k=15,ties=0): OK
  TOP_K(type=f32,ne=[1023,2,1,3],k=15,ties=0): OK
ggml_backend_cuda_graph_compute: CUDA graph warmup complete
ggml_backend_cuda_graph_compute: CUDA graph warmup complete
  TOP_K(type=f32,ne=[1024,2,1,3],k=15,ties=0): OK
ggml_backend_cuda_graph_compute: CUDA graph warmup reset
ggml_backend_cuda_graph_compute: CUDA graph warmup reset
  TOP_K(type=f32,ne=[1025,2,1,3],k=15,ties=0): OK
  TOP_K(type=f32,ne=[16384,1,1,1],k=15,ties=0): OK
  TOP_K(type=f32,ne=[2047,2,1,3],k=15,ties=0): OK
ggml_backend_cuda_graph_compute: CUDA graph warmup complete
ggml_backend_cuda_graph_compute: CUDA graph warmup complete
  TOP_K(type=f32,ne=[2048,2,1,3],k=15,ties=0): OK
ggml_backend_cuda_graph_compute: CUDA graph warmup reset
ggml_backend_cuda_graph_compute: CUDA graph warmup reset
  TOP_K(type=f32,ne=[2049,2,1,3],k=15,ties=0): OK
ggml_backend_cuda_graph_compute: CUDA graph warmup reset
  TOP_K(type=f32,ne=[512,512,1,1],k=512,ties=0): OK
ggml_backend_cuda_graph_compute: CUDA graph warmup complete
  TOP_K(type=f32,ne=[1024,512,1,1],k=1024,ties=0): OK
CUDA error: an illegal memory access was encountered
  current device: 0, in function ggml_backend_cuda_synchronize at /home/phm/projects/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu:2924
  cudaStreamSynchronize(cuda_ctx->stream())
/home/phm/projects/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu:98: CUDA error
[New LWP 32805]
[New LWP 32804]
[New LWP 32803]
[New LWP 32802]
[New LWP 32801]
[New LWP 32797]
[Thread debugging using libthread_db enabled]
Using host libthread_db library "/lib/x86_64-linux-gnu/libthread_db.so.1".
0x00007a1f50310813 in __GI___wait4 (pid=32806, stat_loc=0x0, options=0, usage=0x0) at ../sysdeps/unix/sysv/linux/wait4.c:30
warning: 30	../sysdeps/unix/sysv/linux/wait4.c: No such file or directory
#0  0x00007a1f50310813 in __GI___wait4 (pid=32806, stat_loc=0x0, options=0, usage=0x0) at ../sysdeps/unix/sysv/linux/wait4.c:30
30	in ../sysdeps/unix/sysv/linux/wait4.c
#1  0x00007a1f509d97b3 in ggml_print_backtrace () from /home/phm/projects/llama.cpp/build-cuda/bin/libggml-base.so.0
#2  0x00007a1f509d995b in ggml_abort () from /home/phm/projects/llama.cpp/build-cuda/bin/libggml-base.so.0
#3  0x00007a1f4df8b627 in ggml_cuda_error(char const*, char const*, char const*, int, char const*) () from /home/phm/projects/llama.cpp/build-cuda/bin/libggml-cuda.so.0
#4  0x00007a1f4df8cad8 in ggml_backend_cuda_synchronize(ggml_backend*) () from /home/phm/projects/llama.cpp/build-cuda/bin/libggml-cuda.so.0
#5  0x00007a1f509f0c7c in ggml_backend_graph_compute () from /home/phm/projects/llama.cpp/build-cuda/bin/libggml-base.so.0
#6  0x00007a1f509f5be1 in ggml_backend_compare_graph_backend () from /home/phm/projects/llama.cpp/build-cuda/bin/libggml-base.so.0
#7  0x000058fda0bda438 in test_case::eval(ggml_backend*, ggml_backend*, char const*, printer*) ()
#8  0x000058fda0ba110e in test_backend(ggml_backend*, test_mode, char const*, char const*, printer*, char const*) ()
#9  0x000058fda0b70ec0 in main ()
[Inferior 1 (process 32796) detached]
Aborted (core dumped)
