3636#include < cmath>
3737#include < cstdio>
3838#include < cstring>
39+ #include < memory>
3940#include < mutex>
4041#include < optional>
4142#include < queue>
43+ #include < unordered_map>
4244#include < unordered_set>
45+ #include < vector>
4346
4447#define GGML_COMMON_DECL_C
4548
@@ -770,6 +773,21 @@ std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(i
770773}
771774
772775// cann buffer
776+
/**
 * @brief Tracks multi-threaded write progress for a single tensor.
 *
 * When multiple threads call set_tensor on different chunks of the same tensor,
 * this tracker accumulates progress so that post-processing (quantized format
 * transform or ND-to-NZ conversion) can be deferred until all data has been
 * written.
 *
 * Because it contains a std::mutex, the struct is neither copyable nor movable;
 * owners are expected to keep it behind a pointer (e.g. std::unique_ptr).
 * All fields other than @c mtx must only be read or written while @c mtx is held.
 */
struct TensorSetTracker {
    std::mutex           mtx;                ///< Protects concurrent access to this tracker
    size_t               bytes_written = 0;  ///< Accumulated bytes written so far
    size_t               total_bytes   = 0;  ///< Target size (full tensor)
    std::vector<uint8_t> host_buffer;        ///< Host staging buffer for quantized tensors
};
790+
773791/* *
774792 * @brief Context for managing a CANN buffer associated with a specific device.
775793 *
@@ -780,6 +798,9 @@ struct ggml_backend_cann_buffer_context {
780798 int32_t device; // /< The device ID associated with this buffer context.
781799 void * dev_ptr = nullptr ; // /< Pointer to the device memory allocated for the buffer.
782800
801+ std::mutex tracker_mutex; // /< Protects the trackers map
802+ std::unordered_map<void *, std::unique_ptr<TensorSetTracker>> trackers;
803+
783804 /* *
784805 * @brief Constructor to initialize the CANN buffer context.
785806 *
@@ -792,6 +813,31 @@ struct ggml_backend_cann_buffer_context {
792813 * @brief Destructor to free the device memory allocated for the buffer.
793814 */
794815 ~ggml_backend_cann_buffer_context () { ACL_CHECK (aclrtFree (dev_ptr)); }
816+
817+ /* *
818+ * @brief Get or create a tracker for the given tensor.
819+ */
820+ TensorSetTracker * get_or_create_tracker (ggml_tensor * tensor) {
821+ std::lock_guard<std::mutex> lock (tracker_mutex);
822+ auto key = tensor->data ;
823+ auto it = trackers.find (key);
824+ if (it == trackers.end ()) {
825+ auto tracker = std::make_unique<TensorSetTracker>();
826+ tracker->total_bytes = ggml_nbytes (tensor);
827+ auto * ptr = tracker.get ();
828+ trackers[key] = std::move (tracker);
829+ return ptr;
830+ }
831+ return it->second .get ();
832+ }
833+
834+ /* *
835+ * @brief Remove the tracker for the given tensor.
836+ */
837+ void remove_tracker (ggml_tensor * tensor) {
838+ std::lock_guard<std::mutex> lock (tracker_mutex);
839+ trackers.erase (tensor->data );
840+ }
795841};
796842
797843// cann buffer type
@@ -1124,6 +1170,7 @@ static enum ggml_status ggml_backend_cann_buffer_init_tensor(ggml_backend_buffer
11241170 * designed to be used with a global array, one per device.
11251171 */
11261172struct ggml_cann_nz_workspace {
1173+ std::mutex mtx; // Protects ptr/allocated from concurrent access
11271174 void * ptr; // Pointer to allocated device buffer
11281175 size_t allocated; // Size of currently allocated buffer in bytes
11291176
@@ -1190,13 +1237,15 @@ static ggml_cann_nz_workspace g_nz_workspaces[GGML_CANN_MAX_DEVICES];
11901237 * @note The workspace buffer used in this function is managed globally and reused
11911238 * across calls. This reduces overhead from repeated memory allocation and deallocation.
11921239 */
1193- static void weight_format_to_nz (ggml_tensor * tensor, size_t offset, int device) {
1194- acl_tensor_ptr weightTransposed = ggml_cann_create_tensor (tensor, tensor->ne , tensor->nb , 2 , ACL_FORMAT_ND, offset );
1240+ static void weight_format_to_nz (ggml_tensor * tensor, int device) {
1241+ acl_tensor_ptr weightTransposed = ggml_cann_create_tensor (tensor, tensor->ne , tensor->nb , 2 , ACL_FORMAT_ND, 0 );
11951242 uint64_t workspaceSize = 0 ;
11961243 aclOpExecutor * executor;
11971244
11981245 // TransMatmulWeight
11991246 ACL_CHECK (aclnnTransMatmulWeightGetWorkspaceSize (weightTransposed.get (), &workspaceSize, &executor));
1247+
1248+ std::lock_guard<std::mutex> lock (g_nz_workspaces[device].mtx );
12001249 // Avoid frequent malloc/free of the workspace.
12011250 g_nz_workspaces[device].realloc (workspaceSize);
12021251
@@ -1210,7 +1259,13 @@ static void weight_format_to_nz(ggml_tensor * tensor, size_t offset, int device)
12101259 * @brief Set tensor data in a CANN buffer.
12111260 *
12121261 * This function sets tensor data in a CANN buffer, handling transformations
1213- * if needed based on the tensor's type.
1262+ * if needed based on the tensor's type. It supports multi-threaded calls
1263+ * where different threads write different chunks of the same tensor.
1264+ *
1265+ * For quantized tensors (Q4_0/Q8_0), data is staged in a host buffer and
1266+ * the format transform is deferred until all chunks are written.
1267+ * For NZ weight tensors, chunks are uploaded directly but the ND-to-NZ
1268+ * conversion is deferred until all chunks are written.
12141269 *
12151270 * @param buffer The CANN buffer where the tensor data will be set.
12161271 * @param tensor Pointer to the tensor whose data will be set.
@@ -1226,26 +1281,72 @@ static void ggml_backend_cann_buffer_set_tensor(ggml_backend_buffer_t buffer,
12261281 ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context ;
12271282
12281283 ggml_cann_set_device (ctx->device );
1229- // TODO: refer to cann(#6017), it use thread's default stream.
1230- // For acl, synchronous functions use this default stream.
1231- // Why aclrtSynchronizeDevice?
12321284
12331285 // Only check env once.
12341286 static bool weight_to_nz = parse_bool (get_env_as_lowercase (" GGML_CANN_WEIGHT_NZ" ).value_or (" on" ));
1235- if (!need_transform (tensor->type )) {
1287+
1288+ bool is_quantized = need_transform (tensor->type );
1289+ bool is_nz = !is_quantized && tensor->type != GGML_TYPE_BF16 && weight_to_nz &&
1290+ is_matmul_weight ((const ggml_tensor *) tensor);
1291+
1292+ // Plain tensor (not quantized, not NZ): direct copy, no tracking needed
1293+ if (!is_quantized && !is_nz) {
12361294 ACL_CHECK (aclrtMemcpy ((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE));
1237- if (weight_to_nz && tensor->type != GGML_TYPE_BF16
1238- && is_matmul_weight ((const ggml_tensor *) tensor)) {
1295+ return ;
1296+ }
1297+
1298+ // Single-shot write (full tensor at once): handle directly without tracking overhead
1299+ if (offset == 0 && size == ggml_nbytes (tensor)) {
1300+ if (is_quantized) {
1301+ void * transform_buffer = malloc (size);
1302+ ggml_backend_cann_transform (tensor, data, transform_buffer);
1303+ ACL_CHECK (aclrtMemcpy (tensor->data , size, transform_buffer, size, ACL_MEMCPY_HOST_TO_DEVICE));
1304+ free (transform_buffer);
1305+ } else {
1306+ // NZ weight
12391307 GGML_ASSERT (tensor->ne [2 ] == 1 );
12401308 GGML_ASSERT (tensor->ne [3 ] == 1 );
1241- weight_format_to_nz (tensor, offset, ctx->device );
1309+ ACL_CHECK (aclrtMemcpy (tensor->data , size, data, size, ACL_MEMCPY_HOST_TO_DEVICE));
1310+ weight_format_to_nz (tensor, ctx->device );
12421311 }
1312+ return ;
1313+ }
1314+
1315+ // Chunked write: use tracker to accumulate progress and defer transform/conversion
1316+ TensorSetTracker * tracker = ctx->get_or_create_tracker (tensor);
1317+ std::unique_lock<std::mutex> lock (tracker->mtx );
1318+
1319+ if (is_quantized) {
1320+ // Stage data in host buffer; transform requires full tensor data
1321+ if (tracker->host_buffer .empty ()) {
1322+ tracker->host_buffer .resize (tracker->total_bytes );
1323+ }
1324+ memcpy (tracker->host_buffer .data () + offset, data, size);
12431325 } else {
1244- void * transform_buffer = malloc (size);
1245- ggml_backend_cann_transform (tensor, data, transform_buffer);
1326+ // NZ weight: upload chunk to device immediately, defer conversion
1327+ ACL_CHECK (aclrtMemcpy ((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE));
1328+ }
12461329
1247- ACL_CHECK (aclrtMemcpy ((char *) tensor->data + offset, size, transform_buffer, size, ACL_MEMCPY_HOST_TO_DEVICE));
1248- free (transform_buffer);
1330+ tracker->bytes_written += size;
1331+
1332+ // All chunks received: perform deferred transform/conversion
1333+ if (tracker->bytes_written >= tracker->total_bytes ) {
1334+ if (is_quantized) {
1335+ void * transform_buffer = malloc (tracker->total_bytes );
1336+ ggml_backend_cann_transform (tensor, tracker->host_buffer .data (), transform_buffer);
1337+ ACL_CHECK (aclrtMemcpy (tensor->data , tracker->total_bytes , transform_buffer, tracker->total_bytes , ACL_MEMCPY_HOST_TO_DEVICE));
1338+ free (transform_buffer);
1339+ }
1340+
1341+ if (is_nz) {
1342+ GGML_ASSERT (tensor->ne [2 ] == 1 );
1343+ GGML_ASSERT (tensor->ne [3 ] == 1 );
1344+ weight_format_to_nz (tensor, ctx->device );
1345+ }
1346+
1347+ // Unlock before removing tracker, as remove_tracker destroys the mutex
1348+ lock.unlock ();
1349+ ctx->remove_tracker (tensor);
12491350 }
12501351}
12511352
0 commit comments