Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ggml/src/ggml-metal/ggml-metal-device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base(ggml
switch (op) {
case GGML_OP_ADD_ID: op_str = "add_id"; break;
case GGML_OP_CONCAT: op_str = "concat"; break;
case GGML_OP_ROLL: op_str = "roll"; break;
default: GGML_ABORT("fatal error");
};

Expand Down
1 change: 1 addition & 0 deletions ggml/src/ggml-metal/ggml-metal-device.m
Original file line number Diff line number Diff line change
Expand Up @@ -1065,6 +1065,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_TRANSPOSE:
case GGML_OP_PERMUTE:
case GGML_OP_CONCAT:
case GGML_OP_ROLL:
return true;
case GGML_OP_ADD:
case GGML_OP_SUB:
Expand Down
19 changes: 19 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,25 @@ typedef struct {
int32_t dim;
} ggml_metal_kargs_concat;

typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
int32_t s0;
int32_t s1;
int32_t s2;
int32_t s3;
} ggml_metal_kargs_roll;

typedef struct {
int32_t ne00;
int32_t ne01;
Expand Down
52 changes: 52 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_concat(ctx, idx);
} break;
case GGML_OP_ROLL:
{
n_fuse = ggml_metal_op_roll(ctx, idx);
} break;
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_MUL:
Expand Down Expand Up @@ -567,6 +571,54 @@ int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) {
return 1;
}

int ggml_metal_op_roll(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);

ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;

GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);

const int32_t s0 = ggml_get_op_params_i32(op, 0);
const int32_t s1 = ggml_get_op_params_i32(op, 1);
const int32_t s2 = ggml_get_op_params_i32(op, 2);
const int32_t s3 = ggml_get_op_params_i32(op, 3);

ggml_metal_kargs_roll args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
/*.s0 =*/ s0,
/*.s1 =*/ s1,
/*.s2 =*/ s2,
/*.s3 =*/ s3,
};

auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_ROLL);

ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);

const int nth = std::min(1024, ne00);

ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);

return 1;
}

int ggml_metal_op_repeat(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);

Expand Down
1 change: 1 addition & 0 deletions ggml/src/ggml-metal/ggml-metal-ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ size_t ggml_metal_op_flash_attn_ext_extra_blk(const struct ggml_tensor * op);
size_t ggml_metal_op_flash_attn_ext_extra_tmp(const struct ggml_tensor * op);

int ggml_metal_op_concat (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_roll (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_repeat (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_acc (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx);
Expand Down
38 changes: 38 additions & 0 deletions ggml/src/ggml-metal/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -7399,6 +7399,44 @@ kernel void kernel_concat(
}
}

// circular shift (roll) along all dimensions
kernel void kernel_roll(
constant ggml_metal_kargs_roll & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {

const int i3 = tgpig.z;
const int i2 = tgpig.y;
const int i1 = tgpig.x;

// wrap source indices for dims 1-3
int i01 = i1 - args.s1;
i01 = i01 < 0 ? i01 + args.ne01 : (i01 >= args.ne01 ? i01 - args.ne01 : i01);
int i02 = i2 - args.s2;
i02 = i02 < 0 ? i02 + args.ne02 : (i02 >= args.ne02 ? i02 - args.ne02 : i02);
int i03 = i3 - args.s3;
i03 = i03 < 0 ? i03 + args.ne03 : (i03 >= args.ne03 ? i03 - args.ne03 : i03);

device const char * src_row = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
device char * dst_row = dst + i3 *args.nb3 + i2 *args.nb2 + i1 *args.nb1;

// wrap source index for dim 0
int s = -args.s0;
s = s < 0 ? s + args.ne00 : (s >= args.ne00 ? s - args.ne00 : s);

for (int i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
int i00 = i0 + s;
if (i00 >= args.ne00) {
i00 -= args.ne00;
}

*((device float *)(dst_row + i0*args.nb0)) = *((device const float *)(src_row + i00*args.nb00));
}
}

template<int nr0, typename args_t>
void kernel_mul_mv_q2_K_f32_impl(
args_t args,
Expand Down