Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions ggml/src/ggml-webgpu/ggml-webgpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@

#include <webgpu/webgpu_cpp.h>

#ifdef _WIN32
# ifndef WIN32_LEAN_AND_MEAN
# define WIN32_LEAN_AND_MEAN
# endif
# include <windows.h>
# include <dxgi1_6.h>
# pragma comment(lib, "dxgi.lib")
#endif

#include <atomic>
#include <cstdint>
#include <cstring>
Expand Down Expand Up @@ -3443,6 +3452,53 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
UINT64_MAX);
GGML_ASSERT(ctx->webgpu_global_ctx->adapter != nullptr);

#ifdef _WIN32
    // If the default adapter lacks ShaderF16 (e.g., a primary Pascal display GPU),
    // gracefully fall back to the first adapter in the system that supports it.
    // Adapters are visited in DXGI "high performance" preference order. The default
    // adapter is kept when DXGI is unavailable or no enumerated adapter reports ShaderF16.
    if (!ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16)) {
        // Chained request option that carries the DXGI adapter LUID to the WebGPU implementation.
        // NOTE(review): 0x0005000C is presumably the implementation's SType for LUID-based
        // adapter requests - confirm against the WebGPU headers in use.
        struct LUIDOpts : wgpu::ChainedStruct {
            ::LUID adapterLUID;
        };

        IDXGIFactory6 * factory = nullptr;
        if (SUCCEEDED(CreateDXGIFactory1(__uuidof(IDXGIFactory6), reinterpret_cast<void **>(&factory))) &&
            factory != nullptr) {
            bool replaced = false;
            for (UINT index = 0; !replaced; ++index) {
                IDXGIAdapter1 * dxgi_adapter = nullptr;
                const HRESULT   hr           = factory->EnumAdapterByGpuPreference(
                    index, DXGI_GPU_PREFERENCE_HIGH_PERFORMANCE, __uuidof(IDXGIAdapter1),
                    reinterpret_cast<void **>(&dxgi_adapter));
                // Stop on ANY failure, not only DXGI_ERROR_NOT_FOUND: other error codes leave
                // dxgi_adapter null, and continuing would dereference it (and could loop forever).
                if (FAILED(hr) || dxgi_adapter == nullptr) {
                    break;
                }

                DXGI_ADAPTER_DESC1 desc{};
                if (SUCCEEDED(dxgi_adapter->GetDesc1(&desc))) {
                    LUIDOpts lo{};
                    lo.sType       = static_cast<wgpu::SType>(0x0005000C);
                    lo.nextInChain = nullptr;
                    lo.adapterLUID = desc.AdapterLuid;

                    wgpu::RequestAdapterOptions luid_opts;
                    luid_opts.backendType = wgpu::BackendType::D3D12;
                    luid_opts.nextInChain = &lo;

                    // The callback captures a local by reference; this is safe because WaitAny
                    // blocks until the request has completed.
                    wgpu::Adapter candidate_adapter = nullptr;
                    ctx->webgpu_global_ctx->instance.WaitAny(
                        ctx->webgpu_global_ctx->instance.RequestAdapter(
                            &luid_opts, wgpu::CallbackMode::AllowSpontaneous,
                            [&candidate_adapter](wgpu::RequestAdapterStatus status, wgpu::Adapter a,
                                                 const char * /* message */) {
                                if (status == wgpu::RequestAdapterStatus::Success) {
                                    candidate_adapter = std::move(a);
                                }
                            }),
                        UINT64_MAX);

                    if (candidate_adapter != nullptr &&
                        candidate_adapter.HasFeature(wgpu::FeatureName::ShaderF16)) {
                        // Narrow the wide adapter description for logging; truncated to fit the buffer.
                        char   name[256]{};
                        size_t converted = 0;
                        wcstombs_s(&converted, name, desc.Description, _TRUNCATE);
                        GGML_LOG_INFO("ggml_webgpu: default adapter lacks ShaderF16 - falling back to %s\n", name);
                        ctx->webgpu_global_ctx->adapter = std::move(candidate_adapter);
                        replaced                        = true;
                    }
                }

                // Single release point for every successfully enumerated adapter.
                dxgi_adapter->Release();
            }
            factory->Release();
        }
    }
#endif

ctx->webgpu_global_ctx->adapter.GetLimits(&ctx->webgpu_global_ctx->capabilities.limits);

wgpu::AdapterInfo info{};
Expand Down Expand Up @@ -4085,6 +4141,23 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
instance_descriptor.nextInChain = &instanceTogglesDesc;
#endif

#ifdef _WIN32
    // Load DXC by absolute path so it populates the process module list
    // before Dawn's restricted LoadLibraryEx search path executes.
    // dxil.dll must be loaded before dxcompiler.dll.
    //
    // Best effort: if the executable path cannot be determined or a DLL is missing,
    // nothing is loaded here and no error is raised.
    // Wide-character APIs are used so directories containing characters outside the
    // active ANSI code page still resolve, and the buffer grows instead of silently
    // truncating paths longer than MAX_PATH.
    {
        constexpr size_t max_extended_path = 32768;  // upper bound for extended-length Windows paths

        std::wstring exe_path(MAX_PATH, L'\0');
        DWORD        len = 0;
        for (;;) {
            len = GetModuleFileNameW(nullptr, &exe_path[0], static_cast<DWORD>(exe_path.size()));
            // 0 means failure; a value below the buffer size means the full path fit.
            if (len == 0 || len < exe_path.size()) {
                break;
            }
            // A return equal to the buffer size signals truncation: retry with a larger buffer.
            if (exe_path.size() >= max_extended_path) {
                len = 0;
                break;
            }
            exe_path.resize(exe_path.size() * 2);
        }

        if (len != 0) {
            exe_path.resize(len);
            const size_t last_slash = exe_path.find_last_of(L"\\/");
            if (last_slash != std::wstring::npos) {
                // Keep the trailing separator so the DLL file names can be appended directly.
                const std::wstring dir = exe_path.substr(0, last_slash + 1);
                LoadLibraryW((dir + L"dxil.dll").c_str());
                LoadLibraryW((dir + L"dxcompiler.dll").c_str());
            }
        }
    }
#endif

wgpu::Instance inst = wgpu::CreateInstance(&instance_descriptor);
ctx.webgpu_global_ctx = webgpu_global_context(new webgpu_global_context_struct());
ctx.webgpu_global_ctx->instance = std::move(inst);
Expand Down Expand Up @@ -4121,3 +4194,4 @@ ggml_backend_t ggml_backend_webgpu_init(void) {
}

GGML_BACKEND_DL_IMPL(ggml_backend_webgpu_reg)