From 370d7f44838694f7546ef3932c2559aae179429c Mon Sep 17 00:00:00 2001
From: Stephen Cox
Date: Sun, 12 Apr 2026 11:16:55 +1200
Subject: [PATCH] ggml-metal: add Metal kernel for ggml_roll (circular shift)

Add native Metal GPU support for the ROLL operation, which performs
circular shifts along tensor dimensions. Previously this op had no Metal
kernel, causing CPU fallbacks and graph splits on Apple Silicon.

The kernel uses the same wrap-around index logic as the CPU
implementation: for each element, compute the source index as
(dst_idx - shift) mod dim_size for each dimension.

The kernel copies elements as 32-bit floats, so the op is only reported
as supported for F32 tensors; other types keep falling back to the CPU.

Files changed:
- ggml-metal-impl.h: add ggml_metal_kargs_roll struct
- ggml-metal-device.m: register GGML_OP_ROLL as supported (F32 only)
- ggml-metal-device.cpp: add pipeline name mapping
- ggml-metal-ops.h: declare ggml_metal_op_roll
- ggml-metal-ops.cpp: dispatch function
- ggml-metal.metal: kernel_roll shader

Co-Authored-By: Claude Opus 4.6 (1M context)
---
 ggml/src/ggml-metal/ggml-metal-device.cpp |  1 +
 ggml/src/ggml-metal/ggml-metal-device.m   |  3 ++
 ggml/src/ggml-metal/ggml-metal-impl.h     | 20 +++++++++
 ggml/src/ggml-metal/ggml-metal-ops.cpp    | 55 ++++++++++++++++++++++++
 ggml/src/ggml-metal/ggml-metal-ops.h      |  1 +
 ggml/src/ggml-metal/ggml-metal.metal      | 41 ++++++++++++++++++
 6 files changed, 121 insertions(+)

diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp
index e8548b053e8..aabfb7228f0 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -67,6 +67,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base(ggml
     switch (op) {
         case GGML_OP_ADD_ID: op_str = "add_id"; break;
         case GGML_OP_CONCAT: op_str = "concat"; break;
+        case GGML_OP_ROLL: op_str = "roll"; break;
         default: GGML_ABORT("fatal error");
     };
 
diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m
index 40cacb46520..5adf04ab21a 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.m
+++ b/ggml/src/ggml-metal/ggml-metal-device.m
@@ -1065,6 +1065,9 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_TRANSPOSE:
         case GGML_OP_PERMUTE:
         case GGML_OP_CONCAT:
             return true;
+        case GGML_OP_ROLL:
+            // kernel_roll copies elements as 32-bit floats, so only F32 tensors can run on Metal
+            return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_ADD:
         case GGML_OP_SUB:
diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
index 62b028f4a4a..1a846c38b6f 100644
--- a/ggml/src/ggml-metal/ggml-metal-impl.h
+++ b/ggml/src/ggml-metal/ggml-metal-impl.h
@@ -166,6 +166,26 @@ typedef struct {
     int32_t  dim;
 } ggml_metal_kargs_concat;
 
+// arguments of kernel_roll: src0 extents (ne0x) and byte strides (nb0x), dst byte strides (nbx), shifts (sx)
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+    int32_t  s0;
+    int32_t  s1;
+    int32_t  s2;
+    int32_t  s3;
+} ggml_metal_kargs_roll;
+
 typedef struct {
     int32_t  ne00;
     int32_t  ne01;
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
index 846225d9077..e3d6792c5f3 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -267,6 +267,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = ggml_metal_op_concat(ctx, idx);
             } break;
+        case GGML_OP_ROLL:
+            {
+                n_fuse = ggml_metal_op_roll(ctx, idx);
+            } break;
         case GGML_OP_ADD:
         case GGML_OP_SUB:
         case GGML_OP_MUL:
@@ -567,6 +571,57 @@ int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) {
     return 1;
 }
 
+// encode a ROLL (circular shift) node: one threadgroup per row, grid sized from src0 (ne01 x ne02 x ne03)
+int ggml_metal_op_roll(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
+
+    // the per-dimension shift amounts are read from op params 0..3
+    const int32_t s0 = ggml_get_op_params_i32(op, 0);
+    const int32_t s1 = ggml_get_op_params_i32(op, 1);
+    const int32_t s2 = ggml_get_op_params_i32(op, 2);
+    const int32_t s3 = ggml_get_op_params_i32(op, 3);
+
+    ggml_metal_kargs_roll args = {
+        /*.ne00 =*/ ne00,
+        /*.ne01 =*/ ne01,
+        /*.ne02 =*/ ne02,
+        /*.ne03 =*/ ne03,
+        /*.nb00 =*/ nb00,
+        /*.nb01 =*/ nb01,
+        /*.nb02 =*/ nb02,
+        /*.nb03 =*/ nb03,
+        /*.nb0  =*/ nb0,
+        /*.nb1  =*/ nb1,
+        /*.nb2  =*/ nb2,
+        /*.nb3  =*/ nb3,
+        /*.s0   =*/ s0,
+        /*.s1   =*/ s1,
+        /*.s2   =*/ s2,
+        /*.s3   =*/ s3,
+    };
+
+    auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_ROLL);
+
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);
+
+    // up to one thread per element of a row, capped at 1024; the kernel strides over longer rows
+    const int nth = std::min(1024, ne00);
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
+
+    return 1;
+}
+
 int ggml_metal_op_repeat(ggml_metal_op_t ctx, int idx) {
     ggml_tensor * op = ctx->node(idx);
 
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h
index 50e3c5c77a1..8559603e8f1 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.h
+++ b/ggml/src/ggml-metal/ggml-metal-ops.h
@@ -44,6 +44,7 @@ size_t ggml_metal_op_flash_attn_ext_extra_blk(const struct ggml_tensor * op);
 size_t ggml_metal_op_flash_attn_ext_extra_tmp(const struct ggml_tensor * op);
 
 int ggml_metal_op_concat            (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_roll              (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_repeat            (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_acc               (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_unary             (ggml_metal_op_t ctx, int idx);
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index f67c5cd8a1d..06c05f492ff 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -7399,6 +7399,47 @@ kernel void kernel_concat(
     }
 }
 
+// circular shift (roll) along all dimensions
+// - F32 only: every element is read and written through a float pointer
+// - each threadgroup handles one dst row (i1, i2, i3); its threads stride over dim 0
+// - every wrap is a single add/subtract, so it is only correct for shifts with |s| < ne
+kernel void kernel_roll(
+        constant ggml_metal_kargs_roll & args,
+        device const char * src0,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+
+    const int i3 = tgpig.z;
+    const int i2 = tgpig.y;
+    const int i1 = tgpig.x;
+
+    // wrap source indices for dims 1-3
+    int i01 = i1 - args.s1;
+    i01 = i01 < 0 ? i01 + args.ne01 : (i01 >= args.ne01 ? i01 - args.ne01 : i01);
+    int i02 = i2 - args.s2;
+    i02 = i02 < 0 ? i02 + args.ne02 : (i02 >= args.ne02 ? i02 - args.ne02 : i02);
+    int i03 = i3 - args.s3;
+    i03 = i03 < 0 ? i03 + args.ne03 : (i03 >= args.ne03 ? i03 - args.ne03 : i03);
+
+    device const char * src_row = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
+    device       char * dst_row = dst  + i3 *args.nb3  + i2 *args.nb2  + i1 *args.nb1;
+
+    // wrap source index for dim 0
+    int s = -args.s0;
+    s = s < 0 ? s + args.ne00 : (s >= args.ne00 ? s - args.ne00 : s);
+
+    for (int i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
+        int i00 = i0 + s;
+        if (i00 >= args.ne00) {
+            i00 -= args.ne00;
+        }
+
+        *((device float *)(dst_row + i0*args.nb0)) = *((device const float *)(src_row + i00*args.nb00));
+    }
+}
+
 template<typename args_t>
 void kernel_mul_mv_q2_K_f32_impl(
         args_t args,