@@ -272,6 +272,272 @@ namespace chatllm::deepseek::v2_light
272272 norm->load (path + " q_norm." , loader);
273273 }
274274 }
275+
276+ BaseMLAttention::BaseMLAttention (InitContext *ctx, int hidden_size, int num_attention_heads, int num_kv_heads, int max_length,
277+ int q_lora_rank, int kv_lora_rank, int rope_dim, int qk_nope_head_dim, int v_head_dim,
278+ bool use_bias,
279+ int cache_length)
280+ : KVCacheAttention(ctx,
281+ num_attention_heads, num_kv_heads,
282+ BlockParams::Optimization::speed ? (qk_nope_head_dim + rope_dim) * num_kv_heads : rope_dim * 1,
283+ BlockParams::Optimization::speed ? v_head_dim * num_kv_heads : kv_lora_rank,
284+ max_length,
285+ cache_length),
286+ opt_speed(BlockParams::Optimization::speed),
287+ kv_lora_rank(kv_lora_rank),
288+ rope_dim(rope_dim),
289+ qk_nope_head_dim(qk_nope_head_dim),
290+ v_head_dim(v_head_dim),
291+ d_kv_proj(ctx, hidden_size, kv_lora_rank, nullptr , use_bias),
292+ k_pe_proj(ctx, hidden_size, rope_dim, nullptr , use_bias),
293+ u_k_nope_proj(ctx, kv_lora_rank, qk_nope_head_dim * num_kv_heads, nullptr , false ),
294+ u_v_proj(ctx, kv_lora_rank, v_head_dim * num_kv_heads, nullptr , false ),
295+ q_proj(ctx, hidden_size, num_attention_heads, q_lora_rank, rope_dim, qk_nope_head_dim, use_bias),
296+ o_proj(ctx, v_head_dim * num_attention_heads, hidden_size, use_bias),
297+ kv_norm(ctx, kv_lora_rank)
298+ {
299+ }
300+
301+ BaseMLAttention::BaseMLAttention (InitContext *ctx, int hidden_size, int num_attention_heads, int num_kv_heads, int max_length,
302+ int q_lora_rank, int kv_lora_rank, int rope_dim, int qk_nope_head_dim, int v_head_dim,
303+ bool use_bias)
304+ : BaseMLAttention(ctx, hidden_size, num_attention_heads, num_kv_heads, max_length,
305+ q_lora_rank, kv_lora_rank, rope_dim, qk_nope_head_dim, v_head_dim,
306+ use_bias,
307+ max_length)
308+ {}
309+
310+ int64_t BaseMLAttention::get_param_num (bool effective_only) const
311+ {
312+ int64_t r = 0 ;
313+ r += d_kv_proj.get_param_num (effective_only);
314+ r += k_pe_proj.get_param_num (effective_only);
315+ r += u_k_nope_proj.get_param_num (effective_only);
316+ r += u_v_proj.get_param_num (effective_only);
317+ r += q_proj.get_param_num (effective_only);
318+ r += kv_norm.get_param_num (effective_only);
319+ r += o_proj.get_param_num (effective_only);
320+ return r;
321+ }
322+
323+ ggml::tensor *BaseMLAttention::forward (ComputeContext *ctx, ggml::tensor *hidden_states, int n_past)
324+ {
325+ if (opt_speed)
326+ return forward_speed (ctx, hidden_states, n_past);
327+ else
328+ return forward_memory (ctx, hidden_states, n_past);
329+ }
330+
331+ void BaseMLAttention::load (const std::string &path, TensorLoader *loader)
332+ {
333+ KVCacheAttention::load (path, loader);
334+ q_proj.load (path + " " , loader);
335+
336+ d_kv_proj.load (path + " d_kv_proj." , loader);
337+ k_pe_proj.load (path + " k_pe_proj." , loader);
338+ u_k_nope_proj.load (path + " u_k_nope_proj." , loader);
339+ u_v_proj.load (path + " u_v_proj." , loader);
340+ o_proj.load (path + " o_proj." , loader);
341+ kv_norm.load (path + " kv_norm." , loader);
342+ }
343+
344+ ggml::tensor *BaseMLAttention::forward_speed (ComputeContext *ctx, ggml::tensor *hidden_states, int n_past)
345+ {
346+ const int hidden_size = o_proj.in_features ();
347+ const int qlen = (int )hidden_states->ne [1 ];
348+
349+ KVCacheAttention::before_forward (ctx, n_past, qlen);
350+
351+ ggml::tensor *kv_lora = d_kv_proj.forward (ctx, hidden_states);
352+ kv_lora = kv_norm.forward (ctx, kv_lora);
353+
354+ ggml::tensor *tmpv = u_v_proj.forward (ctx, kv_lora);
355+
356+ ggml::tensor *k_nope = u_k_nope_proj.forward (ctx, kv_lora);
357+ ggml::tensor *k_pe = k_pe_proj.forward (ctx, hidden_states);
358+
359+ ggml::tensor *tmpq = q_proj.forward (ctx, hidden_states);
360+
361+ ggml::tensor *scores = cross_attention_speed (ctx, hidden_size, n_past, qlen, tmpq, k_nope, k_pe, tmpv);
362+
363+ ggml::tensor *attn_output = o_proj.forward (ctx, scores);
364+
365+ return attn_output;
366+ }
367+
368+ ggml::tensor *BaseMLAttention::cross_attention_speed (ComputeContext *ctx, const int hidden_size, const int n_past, const int qlen,
369+ ggml::tensor *q, ggml::tensor *k_nope, ggml::tensor *k_pe, ggml::tensor *v)
370+ {
371+ // [qlen, heads, head_size]
372+ k_pe = ggml::reshape_3d (ctx, k_pe, rope_dim, 1 , qlen);
373+ k_pe = apply_pos_embedding_k (ctx, k_pe, rope_dim * 1 , qlen, pos);
374+
375+ k_nope = ggml::reshape_3d (ctx, k_nope, qk_nope_head_dim, num_kv_heads, qlen);
376+ k_pe = ggml::repeat (ctx, k_pe, rope_dim, num_kv_heads, qlen);
377+
378+ auto key_layer = ggml::concat (ctx, k_nope, k_pe, 0 );
379+
380+ // [qlen, heads, head_size]
381+ ggml::tensor * query_layer = ggml::reshape_3d (ctx, q, qk_nope_head_dim + rope_dim, num_attention_heads, qlen);
382+ ggml::tensor * q_pe = ggml::view_3d (ctx, query_layer,
383+ rope_dim, num_attention_heads, qlen,
384+ query_layer->nb [1 ], query_layer->nb [2 ],
385+ qk_nope_head_dim * ggml::element_size (query_layer));
386+
387+ if (ctx->is_using_gpu ())
388+ {
389+ // TODO: optimize (GPU rope requires continuous)
390+ ggml::tensor * q_pe_cont = ggml::cont (ctx, q_pe);
391+ q_pe_cont = apply_pos_embedding_q (ctx, q_pe_cont, rope_dim * num_attention_heads, qlen, pos);
392+ q_pe = ggml::cpy (ctx, q_pe_cont, q_pe);
393+ }
394+ else
395+ {
396+ q_pe = apply_pos_embedding_q (ctx, q_pe, rope_dim * num_attention_heads, qlen, pos);
397+ }
398+
399+ ggml::build_forward_expand (ctx, q_pe);
400+
401+ ggml::tensor *attn_scores = cross_attention_after_pe (ctx, hidden_size, n_past, qlen, query_layer, key_layer, v);
402+
403+ return attn_scores;
404+ }
405+
406+ ggml::tensor *BaseMLAttention::forward_memory (ComputeContext *ctx, ggml::tensor *hidden_states, int n_past)
407+ {
408+ const int hidden_size = o_proj.in_features ();
409+ const int qlen = (int )hidden_states->ne [1 ];
410+
411+ KVCacheAttention::before_forward (ctx, n_past, qlen);
412+
413+ ggml::tensor *kv_lora = d_kv_proj.forward (ctx, hidden_states);
414+ kv_lora = kv_norm.forward (ctx, kv_lora);
415+
416+ ggml::tensor *k_pe = k_pe_proj.forward (ctx, hidden_states);
417+
418+ ggml::tensor *tmpq = q_proj.forward (ctx, hidden_states);
419+
420+ ggml::tensor *scores = cross_attention_memory (ctx, hidden_size, n_past, qlen, tmpq, k_pe, kv_lora);
421+
422+ ggml::tensor *attn_output = o_proj.forward (ctx, scores);
423+
424+ return attn_output;
425+ }
426+
427+ ggml::tensor *BaseMLAttention::cross_attention_memory (ComputeContext *ctx, const int hidden_size, const int n_past, const int qlen,
428+ ggml::tensor *q, ggml::tensor *k_pe, ggml::tensor *kv_lora)
429+ {
430+ // [qlen, heads, head_size]
431+ k_pe = ggml::reshape_3d (ctx, k_pe, rope_dim, 1 , qlen);
432+ k_pe = apply_pos_embedding_k (ctx, k_pe, rope_dim * 1 , qlen, pos);
433+ k_pe = ggml::reshape_1d (ctx, k_pe, rope_dim * 1 * qlen);
434+
435+ // [qlen, heads, head_size]
436+ ggml::tensor * query_layer = ggml::reshape_3d (ctx, q, qk_nope_head_dim + rope_dim, num_attention_heads, qlen);
437+ ggml::tensor * q_pe = ggml::view_3d (ctx, query_layer,
438+ rope_dim, num_attention_heads, qlen,
439+ query_layer->nb [1 ], query_layer->nb [2 ],
440+ qk_nope_head_dim * ggml::element_size (query_layer));
441+
442+ if (ctx->is_using_gpu ())
443+ {
444+ // TODO: optimize (GPU rope requires continuous)
445+ ggml::tensor * q_pe_cont = ggml::cont (ctx, q_pe);
446+ q_pe_cont = apply_pos_embedding_q (ctx, q_pe_cont, rope_dim * num_attention_heads, qlen, pos);
447+ q_pe = ggml::cpy (ctx, q_pe_cont, q_pe);
448+ }
449+ else
450+ {
451+ q_pe = apply_pos_embedding_q (ctx, q_pe, rope_dim * num_attention_heads, qlen, pos);
452+ }
453+
454+ ggml::build_forward_expand (ctx, q_pe);
455+
456+ ggml::tensor *attn_scores = cross_attention_after_pe_memory (ctx, hidden_size, n_past, qlen, query_layer, k_pe, kv_lora);
457+
458+ return attn_scores;
459+ }
460+
461+ ggml::tensor *BaseMLAttention::get_k_pe_from_cache (ComputeContext *ctx, const int n_past, const int qlen)
462+ {
463+ ggml::tensor *k_pe = nullptr ;
464+
465+ k_pe = ggml::view_2d (ctx, k_cache, k_hidden_size, n_past + qlen,
466+ ggml::row_size (k_cache),
467+ 0 );
468+
469+ return k_pe;
470+ }
471+
472+ ggml::tensor *BaseMLAttention::get_kv_lora_from_cache (ComputeContext *ctx, const int n_past, const int qlen)
473+ {
474+ ggml::tensor *kv_lora = nullptr ;
475+
476+ kv_lora = ggml::view_2d (ctx, v_cache, v_hidden_size, n_past + qlen,
477+ v_hidden_size * ggml::element_size (v_cache),
478+ 0 );
479+
480+ return kv_lora;
481+ }
482+
483+ void BaseMLAttention::save_lora_to_cache (ComputeContext *ctx, const int n_past, const int qlen,
484+ ggml::tensor *k_pe, ggml::tensor *kv_lora)
485+ {
486+ ggml::tensor * pe_cache_view = ggml::view_1d (ctx, k_cache, qlen * k_hidden_size,
487+ ggml::row_size (k_cache) * n_past);
488+
489+ ggml::tensor * kv_cache_view = ggml::view_1d (ctx, v_cache, qlen * v_hidden_size,
490+ ggml::element_size (v_cache) * v_hidden_size * n_past);
491+
492+ ggml::tensor * pe_view = ggml::view_1d (ctx, k_pe, qlen * k_hidden_size, 0 );
493+ ggml::tensor * kv_view = ggml::view_1d (ctx, kv_lora, qlen * v_hidden_size, 0 );
494+
495+ // important: storing RoPE-ed version of K in the KV cache!
496+ ggml::build_forward_expand (ctx, ggml::cpy (ctx, pe_view, pe_cache_view));
497+ ggml::build_forward_expand (ctx, ggml::cpy (ctx, kv_view, kv_cache_view));
498+ }
499+
500+ ggml::tensor *BaseMLAttention::cross_attention_after_pe_memory (ComputeContext *ctx, const int hidden_size, const int n_past, const int qlen0,
501+ ggml::tensor *query_layer, ggml::tensor *k_pe, ggml::tensor *kv_lora)
502+ {
503+ const int head_size = qk_nope_head_dim + rope_dim;
504+
505+ if (!attn_scaling)
506+ query_layer = ggml::scale (ctx, query_layer, 1 .f / sqrtf ((float )head_size));
507+
508+ query_layer = ggml::permute (ctx, query_layer, 0 , 2 , 1 , 3 ); // [heads, qlen, head_size]
509+
510+ // store key and value to memory
511+ save_lora_to_cache (ctx, n_past, qlen0, k_pe, kv_lora);
512+
513+ ggml::tensor *k_pe_all = get_k_pe_from_cache (ctx, n_past, qlen0);
514+ ggml::tensor *kv_lora_all = get_kv_lora_from_cache (ctx, n_past, qlen0);
515+
516+ // make ggml ops happy
517+ kv_lora_all = ggml::cast (ctx, kv_lora_all, ggml::type::GGML_TYPE_F32);
518+
519+ const int qlen = n_past + qlen0;
520+
521+ ggml::tensor *k_nope = u_k_nope_proj.forward (ctx, kv_lora_all);
522+
523+ ggml::tensor *key_layer = nullptr ;
524+
525+ k_nope = ggml::reshape_3d (ctx, k_nope, qk_nope_head_dim, num_kv_heads, qlen);
526+ k_pe_all = ggml::reshape_3d (ctx, k_pe_all, rope_dim, 1 , qlen);
527+ k_pe_all = ggml::cast (ctx, k_pe_all, ggml::type::GGML_TYPE_F32);
528+ k_pe_all = ggml::repeat (ctx, k_pe_all, rope_dim, num_kv_heads, qlen);
529+
530+ key_layer = ggml::concat (ctx, k_nope, k_pe_all, 0 );
531+ key_layer = ggml::permute (ctx, key_layer, 0 , 2 , 1 , 3 ); // [qlen, heads, head_size] -> [heads, qlen, head_size]
532+
533+ ggml::tensor *value_layer = u_v_proj.forward (ctx, kv_lora_all);
534+ value_layer = ggml::reshape_3d (ctx, value_layer, v_head_dim, num_kv_heads, qlen); // [qlen, heads, head_size]
535+ value_layer = ggml::permute (ctx, value_layer, 1 , 2 , 0 , 3 ); // [qlen, heads, head_size] -> [heads, head_size, qlen]
536+ value_layer = ggml::cont (ctx, value_layer);
537+
538+ ggml::tensor *attn_scores = calc_attn_scores (ctx, hidden_size, n_past, qlen0, key_layer, query_layer, value_layer);
539+ return attn_scores;
540+ }
275541}
276542
277543namespace chatllm ::deepseek::v2
0 commit comments