diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp
index e8548b053e8..aabfb7228f0 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -67,6 +67,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base(ggml
     switch (op) {
         case GGML_OP_ADD_ID: op_str = "add_id"; break;
         case GGML_OP_CONCAT: op_str = "concat"; break;
+        case GGML_OP_ROLL:   op_str = "roll";   break;
         default: GGML_ABORT("fatal error");
     };
 
diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m
index 40cacb46520..5adf04ab21a 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.m
+++ b/ggml/src/ggml-metal/ggml-metal-device.m
@@ -1065,6 +1065,9 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_TRANSPOSE:
         case GGML_OP_PERMUTE:
         case GGML_OP_CONCAT:
             return true;
+        case GGML_OP_ROLL:
+            // kernel_roll copies elements through float pointers, so only F32 is handled correctly
+            return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_ADD:
         case GGML_OP_SUB:
diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
index 62b028f4a4a..1a846c38b6f 100644
--- a/ggml/src/ggml-metal/ggml-metal-impl.h
+++ b/ggml/src/ggml-metal/ggml-metal-impl.h
@@ -166,6 +166,27 @@ typedef struct {
     int32_t  dim;
 } ggml_metal_kargs_concat;
 
+// arguments for kernel_roll
+// ne0x/nb0x: element counts and byte strides of src0, nbx: byte strides of dst, sx: shift per dimension
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+    int32_t  s0;
+    int32_t  s1;
+    int32_t  s2;
+    int32_t  s3;
+} ggml_metal_kargs_roll;
+
 typedef struct {
     int32_t  ne00;
     int32_t  ne01;
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
index 846225d9077..e3d6792c5f3 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -267,6 +267,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = ggml_metal_op_concat(ctx, idx);
             } break;
+        case GGML_OP_ROLL:
+            {
+                n_fuse = ggml_metal_op_roll(ctx, idx);
+            } break;
         case GGML_OP_ADD:
         case GGML_OP_SUB:
         case GGML_OP_MUL:
@@ -567,6 +571,61 @@ int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) {
     return 1;
 }
 
+// encode GGML_OP_ROLL: copies src[0] into the op tensor, circularly shifted by the
+// per-dimension amounts stored in op params 0..3; always returns 1
+int ggml_metal_op_roll(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+    // kernel_roll reads and writes elements as float
+    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(op->type         == GGML_TYPE_F32);
+
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
+
+    const int32_t s0 = ggml_get_op_params_i32(op, 0);
+    const int32_t s1 = ggml_get_op_params_i32(op, 1);
+    const int32_t s2 = ggml_get_op_params_i32(op, 2);
+    const int32_t s3 = ggml_get_op_params_i32(op, 3);
+
+    ggml_metal_kargs_roll args = {
+        /*.ne00 =*/ ne00,
+        /*.ne01 =*/ ne01,
+        /*.ne02 =*/ ne02,
+        /*.ne03 =*/ ne03,
+        /*.nb00 =*/ nb00,
+        /*.nb01 =*/ nb01,
+        /*.nb02 =*/ nb02,
+        /*.nb03 =*/ nb03,
+        /*.nb0  =*/ nb0,
+        /*.nb1  =*/ nb1,
+        /*.nb2  =*/ nb2,
+        /*.nb3  =*/ nb3,
+        /*.s0   =*/ s0,
+        /*.s1   =*/ s1,
+        /*.s2   =*/ s2,
+        /*.s3   =*/ s3,
+    };
+
+    auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_ROLL);
+
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);
+
+    // one threadgroup per row (i1, i2, i3); its threads stride over the ne00 elements of the row
+    const int nth = std::min(1024, ne00);
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
+
+    return 1;
+}
+
 int ggml_metal_op_repeat(ggml_metal_op_t ctx, int idx) {
     ggml_tensor * op = ctx->node(idx);
 
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h
index 50e3c5c77a1..8559603e8f1 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.h
+++ b/ggml/src/ggml-metal/ggml-metal-ops.h
@@ -44,6 +44,7 @@ size_t ggml_metal_op_flash_attn_ext_extra_blk(const struct ggml_tensor * op);
 size_t ggml_metal_op_flash_attn_ext_extra_tmp(const struct ggml_tensor * op);
 
 int ggml_metal_op_concat            (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_roll              (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_repeat            (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_acc               (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_unary             (ggml_metal_op_t ctx, int idx);
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index f67c5cd8a1d..06c05f492ff 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -7399,6 +7399,48 @@ kernel void kernel_concat(
     }
 }
 
+// circular shift (roll) along all dimensions: dst[i0, i1, i2, i3] = src0[i0 - s0, i1 - s1, i2 - s2, i3 - s3],
+// with every source index wrapped into [0, ne0x)
+// - one threadgroup per row (i1, i2, i3), threads stride over i0
+// - elements are copied as float (F32 only)
+// - each index is wrapped with a single add/subtract, which is only correct for |sx| < ne0x
+kernel void kernel_roll(
+        constant ggml_metal_kargs_roll & args,
+        device const char * src0,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+
+    const int i3 = tgpig.z;
+    const int i2 = tgpig.y;
+    const int i1 = tgpig.x;
+
+    // wrap source indices for dims 1-3
+    int i01 = i1 - args.s1;
+    i01 = i01 < 0 ? i01 + args.ne01 : (i01 >= args.ne01 ? i01 - args.ne01 : i01);
+    int i02 = i2 - args.s2;
+    i02 = i02 < 0 ? i02 + args.ne02 : (i02 >= args.ne02 ? i02 - args.ne02 : i02);
+    int i03 = i3 - args.s3;
+    i03 = i03 < 0 ? i03 + args.ne03 : (i03 >= args.ne03 ? i03 - args.ne03 : i03);
+
+    device const char * src_row = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
+    device       char * dst_row = dst  + i3 *args.nb3  + i2 *args.nb2  + i1 *args.nb1;
+
+    // wrap source index for dim 0: s is the non-negative offset added to i0 (hoisted out of the loop)
+    int s = -args.s0;
+    s = s < 0 ? s + args.ne00 : (s >= args.ne00 ? s - args.ne00 : s);
+
+    for (int i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
+        int i00 = i0 + s;
+        if (i00 >= args.ne00) {
+            i00 -= args.ne00;
+        }
+
+        *((device float *)(dst_row + i0*args.nb0)) = *((device const float *)(src_row + i00*args.nb00));
+    }
+}
+
 template<typename args_t>
 void kernel_mul_mv_q2_K_f32_impl(
         args_t args,