diff --git a/src/psa_aead.c b/src/psa_aead.c index b94d85c..31bcac8 100644 --- a/src/psa_aead.c +++ b/src/psa_aead.c @@ -554,7 +554,7 @@ static psa_status_t wolfpsa_aead_encrypt_final(wolfpsa_aead_ctx_t *ctx, if (PSA_ALG_AEAD_EQUAL(ctx->alg, PSA_ALG_GCM)) { #ifdef HAVE_AESGCM Aes aes; - ret = wc_AesInit(&aes, NULL, INVALID_DEVID); + ret = wc_AesInit(&aes, NULL, wolfPSA_GetDefaultDevID()); if (ret == 0) { ret = wc_AesGcmSetKey(&aes, ctx->key, (word32)ctx->key_length); } @@ -580,7 +580,7 @@ static psa_status_t wolfpsa_aead_encrypt_final(wolfpsa_aead_ctx_t *ctx, if (wc_AesCcmCheckTagSize((int)ctx->tag_length) != 0) { return PSA_ERROR_NOT_SUPPORTED; } - ret = wc_AesInit(&aes, NULL, INVALID_DEVID); + ret = wc_AesInit(&aes, NULL, wolfPSA_GetDefaultDevID()); if (ret == 0) { ret = wc_AesCcmSetKey(&aes, ctx->key, (word32)ctx->key_length); } @@ -691,7 +691,7 @@ static psa_status_t wolfpsa_aead_decrypt_final(wolfpsa_aead_ctx_t *ctx, if (PSA_ALG_AEAD_EQUAL(ctx->alg, PSA_ALG_GCM)) { #ifdef HAVE_AESGCM Aes aes; - ret = wc_AesInit(&aes, NULL, INVALID_DEVID); + ret = wc_AesInit(&aes, NULL, wolfPSA_GetDefaultDevID()); if (ret == 0) { ret = wc_AesGcmSetKey(&aes, ctx->key, (word32)ctx->key_length); } @@ -720,7 +720,7 @@ static psa_status_t wolfpsa_aead_decrypt_final(wolfpsa_aead_ctx_t *ctx, if (wc_AesCcmCheckTagSize((int)tag_length) != 0) { return PSA_ERROR_INVALID_SIGNATURE; } - ret = wc_AesInit(&aes, NULL, INVALID_DEVID); + ret = wc_AesInit(&aes, NULL, wolfPSA_GetDefaultDevID()); if (ret == 0) { ret = wc_AesCcmSetKey(&aes, ctx->key, (word32)ctx->key_length); } diff --git a/src/psa_cipher.c b/src/psa_cipher.c index 7fd9c37..273714c 100644 --- a/src/psa_cipher.c +++ b/src/psa_cipher.c @@ -375,7 +375,7 @@ psa_status_t psa_cipher_encrypt_setup(psa_cipher_operation_t *operation, XMEMCPY(des_key, key_data, DES3_KEY_SIZE); - ret = wc_Des3Init(&ctx->des3, NULL, INVALID_DEVID); + ret = wc_Des3Init(&ctx->des3, NULL, wolfPSA_GetDefaultDevID()); if (ret != 0) { wc_ForceZero(des_key, sizeof(des_key)); wolfpsa_forcezero_free_key_data(key_data, key_data_length); @@ -392,7 +392,7 @@ psa_status_t psa_cipher_encrypt_setup(psa_cipher_operation_t *operation, #endif } else { - ret = wc_AesInit(&ctx->aes, NULL, INVALID_DEVID); + ret = wc_AesInit(&ctx->aes, NULL, wolfPSA_GetDefaultDevID()); if (ret != 0) { wolfpsa_forcezero_free_key_data(key_data, key_data_length); XFREE(ctx, NULL, DYNAMIC_TYPE_TMP_BUFFER); @@ -518,7 +518,7 @@ psa_status_t psa_cipher_decrypt_setup(psa_cipher_operation_t *operation, XMEMCPY(des_key, key_data, DES3_KEY_SIZE); - ret = wc_Des3Init(&ctx->des3, NULL, INVALID_DEVID); + ret = wc_Des3Init(&ctx->des3, NULL, wolfPSA_GetDefaultDevID()); if (ret != 0) { wc_ForceZero(des_key, sizeof(des_key)); wolfpsa_forcezero_free_key_data(key_data, key_data_length); @@ -535,7 +535,7 @@ psa_status_t psa_cipher_decrypt_setup(psa_cipher_operation_t *operation, #endif } else { - ret = wc_AesInit(&ctx->aes, NULL, INVALID_DEVID); + ret = wc_AesInit(&ctx->aes, NULL, wolfPSA_GetDefaultDevID()); if (ret != 0) { wolfpsa_forcezero_free_key_data(key_data, key_data_length); XFREE(ctx, NULL, DYNAMIC_TYPE_TMP_BUFFER); diff --git a/src/psa_engine.c b/src/psa_engine.c index 7f80d65..70c3546 100644 --- a/src/psa_engine.c +++ b/src/psa_engine.c @@ -31,6 +31,24 @@ #include #include #include +#include + +/* Runtime-settable devId threaded through every wolfPSA-internal + * wc_*Init()/wc_NewRsaKey() call. INVALID_DEVID (the default) keeps + * the original behaviour: wolfCrypt runs the operation locally. */ +static int wolfPSA_default_devid = INVALID_DEVID; + +int wolfPSA_SetDefaultDevID(int devId) +{ + wolfPSA_default_devid = devId; + return 0; +} + +int wolfPSA_GetDefaultDevID(void) +{ + return wolfPSA_default_devid; +} + /* wolfCrypt error code to PSA status code conversion */ psa_status_t wc_error_to_psa_status(int ret) { diff --git a/src/psa_hash_engine.c b/src/psa_hash_engine.c index 71f7b2a..09cfd96 100644 --- a/src/psa_hash_engine.c +++ b/src/psa_hash_engine.c @@ -402,16 +402,16 @@ psa_status_t psa_hash_setup(psa_hash_operation_t *operation, #endif #ifdef WOLFSSL_SHA3 case PSA_ALG_SHA3_224: - ret = wc_InitSha3_224(&ctx->ctx.sha3, NULL, INVALID_DEVID); + ret = wc_InitSha3_224(&ctx->ctx.sha3, NULL, wolfPSA_GetDefaultDevID()); break; case PSA_ALG_SHA3_256: - ret = wc_InitSha3_256(&ctx->ctx.sha3, NULL, INVALID_DEVID); + ret = wc_InitSha3_256(&ctx->ctx.sha3, NULL, wolfPSA_GetDefaultDevID()); break; case PSA_ALG_SHA3_384: - ret = wc_InitSha3_384(&ctx->ctx.sha3, NULL, INVALID_DEVID); + ret = wc_InitSha3_384(&ctx->ctx.sha3, NULL, wolfPSA_GetDefaultDevID()); break; case PSA_ALG_SHA3_512: - ret = wc_InitSha3_512(&ctx->ctx.sha3, NULL, INVALID_DEVID); + ret = wc_InitSha3_512(&ctx->ctx.sha3, NULL, wolfPSA_GetDefaultDevID()); break; #endif default: diff --git a/src/psa_key_derivation.c b/src/psa_key_derivation.c index d574997..1e0fef6 100644 --- a/src/psa_key_derivation.c +++ b/src/psa_key_derivation.c @@ -902,7 +902,7 @@ static psa_status_t wolfpsa_kdf_tls12_prf(wolfpsa_kdf_ctx_t *ctx, ctx->secret, (word32)ctx->secret_length, ctx->label, (word32)ctx->label_length, ctx->seed, (word32)ctx->seed_length, - 1, hash_type, NULL, INVALID_DEVID); + 1, hash_type, NULL, wolfPSA_GetDefaultDevID()); if (ret != 0) { return wc_error_to_psa_status(ret); } @@ -968,7 +968,7 @@ static psa_status_t wolfpsa_kdf_tls12_psk_to_ms(wolfpsa_kdf_ctx_t *ctx, premaster, (word32)premaster_len, (const byte *)"master secret", 13u, ctx->seed, (word32)ctx->seed_length, - 1, hash_type, NULL, INVALID_DEVID); + 1, hash_type, NULL, wolfPSA_GetDefaultDevID()); if (ret != 0) { status = wc_error_to_psa_status(ret); } diff --git a/src/psa_key_storage.c b/src/psa_key_storage.c index 4de0b83..9f13ced 100644 --- a/src/psa_key_storage.c +++ b/src/psa_key_storage.c @@ -28,6 +28,7 @@ #if defined(WOLFSSL_PSA_ENGINE) #include +#include #include #include #include "psa_trace.h" @@ -1493,7 +1494,7 @@ psa_status_t psa_export_public_key( size_t total_len; uint8_t* out = data; - rsa = wc_NewRsaKey(NULL, INVALID_DEVID, &ret); + rsa = wc_NewRsaKey(NULL, wolfPSA_GetDefaultDevID(), &ret); if (rsa == NULL) { if (ret == 0) { ret = MEMORY_E; diff --git a/src/psa_lms_xmss.c b/src/psa_lms_xmss.c index b4454e1..ce24939 100644 --- a/src/psa_lms_xmss.c +++ b/src/psa_lms_xmss.c @@ -60,7 +60,7 @@ psa_status_t psa_lms_generate_key(uint8_t *private_key, } /* Initialize LMS key */ - ret = wc_LmsKey_Init(&key, NULL, INVALID_DEVID); + ret = wc_LmsKey_Init(&key, NULL, wolfPSA_GetDefaultDevID()); if (ret != 0) { return wc_error_to_psa_status(ret); } @@ -134,7 +134,7 @@ psa_status_t psa_lms_sign(const uint8_t *private_key, } /* Initialize LMS key */ - ret = wc_LmsKey_Init(&key, NULL, INVALID_DEVID); + ret = wc_LmsKey_Init(&key, NULL, wolfPSA_GetDefaultDevID()); if (ret != 0) { return wc_error_to_psa_status(ret); } @@ -181,7 +181,7 @@ psa_status_t psa_lms_verify(const uint8_t *public_key, } /* Initialize LMS key */ - ret = wc_LmsKey_Init(&key, NULL, INVALID_DEVID); + ret = wc_LmsKey_Init(&key, NULL, wolfPSA_GetDefaultDevID()); if (ret != 0) { return wc_error_to_psa_status(ret); } @@ -234,7 +234,7 @@ psa_status_t psa_xmss_generate_key(uint8_t *private_key, } /* Initialize XMSS key */ - ret = wc_XmssKey_Init(&key, NULL, INVALID_DEVID); + ret = wc_XmssKey_Init(&key, NULL, wolfPSA_GetDefaultDevID()); if (ret != 0) { return wc_error_to_psa_status(ret); } @@ -308,7 +308,7 @@ psa_status_t psa_xmss_sign(const uint8_t *private_key, } /* Initialize XMSS key */ - ret = wc_XmssKey_Init(&key, NULL, INVALID_DEVID); + ret = wc_XmssKey_Init(&key, NULL, wolfPSA_GetDefaultDevID()); if (ret != 0) { return wc_error_to_psa_status(ret); } @@ -355,7 +355,7 @@ psa_status_t psa_xmss_verify(const uint8_t *public_key, } /* Initialize XMSS key */ - ret = wc_XmssKey_Init(&key, NULL, INVALID_DEVID); + ret = wc_XmssKey_Init(&key, NULL, wolfPSA_GetDefaultDevID()); if (ret != 0) { return wc_error_to_psa_status(ret); } diff --git a/src/psa_mldsa.c b/src/psa_mldsa.c index f5ff891..7734d81 100644 --- a/src/psa_mldsa.c +++ b/src/psa_mldsa.c @@ -80,7 +80,7 @@ psa_status_t psa_ml_dsa_generate_key(psa_ml_dsa_parameter_t parameter, } /* Initialize ML-DSA key */ - ret = wc_MlDsaKey_Init(&key, NULL, INVALID_DEVID); + ret = wc_MlDsaKey_Init(&key, NULL, wolfPSA_GetDefaultDevID()); if (ret != 0) { return wc_error_to_psa_status(ret); } @@ -158,7 +158,7 @@ psa_status_t psa_ml_dsa_sign(psa_ml_dsa_parameter_t parameter, } /* Initialize ML-DSA key */ - ret = wc_MlDsaKey_Init(&key, NULL, INVALID_DEVID); + ret = wc_MlDsaKey_Init(&key, NULL, wolfPSA_GetDefaultDevID()); if (ret != 0) { return wc_error_to_psa_status(ret); } @@ -236,7 +236,7 @@ psa_status_t psa_ml_dsa_verify(psa_ml_dsa_parameter_t parameter, } /* Initialize ML-DSA key */ - ret = wc_MlDsaKey_Init(&key, NULL, INVALID_DEVID); + ret = wc_MlDsaKey_Init(&key, NULL, wolfPSA_GetDefaultDevID()); if (ret != 0) { return wc_error_to_psa_status(ret); } diff --git a/src/psa_mlkem.c b/src/psa_mlkem.c index e8e5d83..34508aa 100644 --- a/src/psa_mlkem.c +++ b/src/psa_mlkem.c @@ -77,7 +77,7 @@ psa_status_t psa_ml_kem_generate_key(psa_ml_kem_parameter_t parameter, } /* Initialize ML-KEM key */ - ret = wc_MlKemKey_Init(&key, type, NULL, INVALID_DEVID); + ret = wc_MlKemKey_Init(&key, type, NULL, wolfPSA_GetDefaultDevID()); if (ret != 0) { return wc_error_to_psa_status(ret); } @@ -170,7 +170,7 @@ psa_status_t psa_ml_kem_encapsulate(psa_ml_kem_parameter_t parameter, } /* Initialize ML-KEM key */ - ret = wc_MlKemKey_Init(&key, type, NULL, INVALID_DEVID); + ret = wc_MlKemKey_Init(&key, type, NULL, wolfPSA_GetDefaultDevID()); if (ret != 0) { return wc_error_to_psa_status(ret); } @@ -251,7 +251,7 @@ psa_status_t psa_ml_kem_decapsulate(psa_ml_kem_parameter_t parameter, } /* Initialize ML-KEM key */ - ret = wc_MlKemKey_Init(&key, type, NULL, INVALID_DEVID); + ret = wc_MlKemKey_Init(&key, type, NULL, wolfPSA_GetDefaultDevID()); if (ret != 0) { return wc_error_to_psa_status(ret); } diff --git a/wolfpsa.map b/wolfpsa.map index c33caa2..8609099 100644 --- a/wolfpsa.map +++ b/wolfpsa.map @@ -94,6 +94,8 @@ WOLFPSA_1.0 { wolfpsa_get_key_data; wolfpsa_test_get_next_key_id; wolfpsa_test_set_next_key_id; + wolfPSA_SetDefaultDevID; + wolfPSA_GetDefaultDevID; local: *; }; diff --git a/wolfpsa/psa_engine.h b/wolfpsa/psa_engine.h index d34a51d..28e33eb 100644 --- a/wolfpsa/psa_engine.h +++ b/wolfpsa/psa_engine.h @@ -47,8 +47,37 @@ #include #endif +#ifdef __cplusplus +extern "C" { +#endif + /* wolfCrypt error code to PSA status code conversion */ WOLFSSL_LOCAL psa_status_t wc_error_to_psa_status(int ret); +/* Default wolfCrypt devId threaded through wolfPSA's internal wc_*Init() + * calls. Defaults to INVALID_DEVID so that operations execute locally. + * Set to a registered crypto_cb devId (e.g. via wc_CryptoCb_RegisterDevice) + * to route every wolfPSA-issued wolfCrypt call through that callback — + * this is the integration hook for crypto offload backends such as + * wolfHSM or a hardware accelerator. Safe to call before psa_crypto_init(). + * + * Threading: the default devId is held in a process-global variable read + * by every wolfPSA-internal wc_*Init() invocation. Callers must set it + * during single-threaded initialisation (before any PSA operation is + * issued) or otherwise serialise the setter with external synchronisation; + * concurrent calls to wolfPSA_SetDefaultDevID() while PSA operations are + * in flight are not supported. + * + * Returns 0 on success. */ +WOLFSSL_API int wolfPSA_SetDefaultDevID(int devId); + +/* Returns the devId previously set with wolfPSA_SetDefaultDevID() or + * INVALID_DEVID if none has been set. */ +WOLFSSL_API int wolfPSA_GetDefaultDevID(void); + +#ifdef __cplusplus +} +#endif + #endif /* WOLFSSL_PSA_ENGINE */ #endif /* WOLFSSL_PSA_ENGINE_H */