diff --git a/apps/example/android/gradle.properties b/apps/example/android/gradle.properties index 05d7f3374..cca14f651 100644 --- a/apps/example/android/gradle.properties +++ b/apps/example/android/gradle.properties @@ -44,6 +44,7 @@ hermesEnabled=true # Note that this is incompatible with web debugging. #newArchEnabled=true bridgelessEnabled=true +edgeToEdgeEnabled=true # Uncomment the line below to build React Native from source. #react.buildFromSource=true diff --git a/apps/example/ios/Podfile.lock b/apps/example/ios/Podfile.lock index 64bedb47a..2aee96e5f 100644 --- a/apps/example/ios/Podfile.lock +++ b/apps/example/ios/Podfile.lock @@ -1865,7 +1865,7 @@ PODS: - ReactCommon/turbomodule/core - SocketRocket - Yoga - - react-native-wgpu (0.5.8): + - react-native-wgpu (0.5.9): - boost - DoubleConversion - fast_float @@ -2897,11 +2897,11 @@ EXTERNAL SOURCES: SPEC CHECKSUMS: boost: 7e761d76ca2ce687f7cc98e698152abd03a18f90 - DoubleConversion: cb417026b2400c8f53ae97020b2be961b59470cb + DoubleConversion: 76ab83afb40bddeeee456813d9c04f67f78771b5 fast_float: b32c788ed9c6a8c584d114d0047beda9664e7cc6 FBLazyVector: 941bef1c8eeabd9fe1f501e30a5220beee913886 fmt: a40bb5bd0294ea969aaaba240a927bd33d878cdd - glog: 5683914934d5b6e4240e497e0f4a3b42d1854183 + glog: fdfdfe5479092de0c4bdbebedd9056951f092c4f hermes-engine: 35c763d57c9832d0eef764316ca1c4d043581394 RCT-Folly: 846fda9475e61ec7bcbf8a3fe81edfcaeb090669 RCTDeprecation: c0ed3249a97243002615517dff789bf4666cf585 @@ -2937,8 +2937,8 @@ SPEC CHECKSUMS: React-Mapbuffer: 9d2434a42701d6144ca18f0ca1c4507808ca7696 React-microtasksnativemodule: 75b6604b667d297292345302cc5bfb6b6aeccc1b react-native-safe-area-context: c00143b4823773bba23f2f19f85663ae89ceb460 - react-native-skia: 3dab14f6a3de3c479a73d87cb1f9aa9556dcec13 - react-native-wgpu: 32f373d5a9ee83fad1cdddd288bb0738cb97e0ba + react-native-skia: 888a5bf5ed5008ae9593990e7dc2ea022b9aea11 + react-native-wgpu: 994aa67c411536586859e2790405071120465f75 React-NativeModulesApple: 879fbdc5dcff7136abceb7880fe8a2022a1bd7c3 React-oscompat: 93b5535ea7f7dff46aaee4f78309a70979bdde9d React-perflogger: 5536d2df3d18fe0920263466f7b46a56351c0510 diff --git a/apps/example/src/App.tsx b/apps/example/src/App.tsx index bd4120a78..046606b88 100644 --- a/apps/example/src/App.tsx +++ b/apps/example/src/App.tsx @@ -3,6 +3,7 @@ import "./resolveAssetSourcePolyfill"; import { NavigationContainer } from "@react-navigation/native"; import { createStackNavigator } from "@react-navigation/stack"; import { GestureHandlerRootView } from "react-native-gesture-handler"; +import { SafeAreaProvider } from "react-native-safe-area-context"; import type { Routes } from "./Route"; import { Home } from "./Home"; @@ -36,6 +37,7 @@ import { Reanimated } from "./Reanimated"; import { AsyncStarvation } from "./Diagnostics/AsyncStarvation"; import { DeviceLostHang } from "./Diagnostics/DeviceLostHang"; import { StorageBufferVertices } from "./StorageBufferVertices"; +import { MultiContext } from "./MultiContext"; // The two lines below are needed by three.js import "fast-text-encoding"; @@ -49,56 +51,65 @@ function App() { return null; } return ( - - - - - - - - - - - - - - - - - - - - - - - - - - - - - {(props) => } - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + {(props) => } + + + + + + + + + + + ); } diff --git a/apps/example/src/Cube/Cube.tsx b/apps/example/src/Cube/Cube.tsx index 63c49b5ac..0a6d27714 100644 --- a/apps/example/src/Cube/Cube.tsx +++ b/apps/example/src/Cube/Cube.tsx @@ -17,138 +17,139 @@ import { basicVertWGSL, vertexPositionColorWGSL } from "./Shaders"; export function Cube() { const ref = useWebGPU(({ context, device, presentationFormat, canvas }) => { - // Create a vertex buffer from the cube data. - const verticesBuffer = device.createBuffer({ - size: cubeVertexArray.byteLength, - usage: GPUBufferUsage.VERTEX, - mappedAtCreation: true, - }); - new Float32Array(verticesBuffer.getMappedRange()).set(cubeVertexArray); - verticesBuffer.unmap(); - - const pipeline = device.createRenderPipeline({ - layout: "auto", - vertex: { - module: device.createShaderModule({ - code: basicVertWGSL, - }), - buffers: [ + function frame() { + // Create a vertex buffer from the cube data. + const verticesBuffer = device.createBuffer({ + size: cubeVertexArray.byteLength, + usage: GPUBufferUsage.VERTEX, + mappedAtCreation: true, + }); + new Float32Array(verticesBuffer.getMappedRange()).set(cubeVertexArray); + verticesBuffer.unmap(); + + const pipeline = device.createRenderPipeline({ + layout: "auto", + vertex: { + module: device.createShaderModule({ + code: basicVertWGSL, + }), + buffers: [ + { + arrayStride: cubeVertexSize, + attributes: [ + { + // position + shaderLocation: 0, + offset: cubePositionOffset, + format: "float32x4", + }, + { + // uv + shaderLocation: 1, + offset: cubeUVOffset, + format: "float32x2", + }, + ], + }, + ], + }, + fragment: { + module: device.createShaderModule({ + code: vertexPositionColorWGSL, + }), + targets: [ + { + format: presentationFormat, + }, + ], + }, + primitive: { + topology: "triangle-list", + + // Backface culling since the cube is solid piece of geometry. + // Faces pointing away from the camera will be occluded by faces + // pointing toward the camera. + cullMode: "back", + }, + + // Enable depth testing so that the fragment closest to the camera + // is rendered in front. + depthStencil: { + depthWriteEnabled: true, + depthCompare: "less", + format: "depth24plus", + }, + }); + + console.log("Size: ", canvas.width, canvas.height); + const depthTexture = device.createTexture({ + size: [canvas.width, canvas.height], + format: "depth24plus", + usage: GPUTextureUsage.RENDER_ATTACHMENT, + }); + + const uniformBufferSize = 4 * 16; // 4x4 matrix + const uniformBuffer = device.createBuffer({ + size: uniformBufferSize, + usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST, + }); + + const uniformBindGroup = device.createBindGroup({ + layout: pipeline.getBindGroupLayout(0), + entries: [ { - arrayStride: cubeVertexSize, - attributes: [ - { - // position - shaderLocation: 0, - offset: cubePositionOffset, - format: "float32x4", - }, - { - // uv - shaderLocation: 1, - offset: cubeUVOffset, - format: "float32x2", - }, - ], + binding: 0, + resource: { + buffer: uniformBuffer, + }, }, ], - }, - fragment: { - module: device.createShaderModule({ - code: vertexPositionColorWGSL, - }), - targets: [ + }); + + const renderPassDescriptor: GPURenderPassDescriptor = { + // @ts-expect-error + colorAttachments: [ { - format: presentationFormat, + view: undefined, // Assigned later + clearValue: [0, 0, 0, 0], + loadOp: "clear", + storeOp: "store", }, ], - }, - primitive: { - topology: "triangle-list", - - // Backface culling since the cube is solid piece of geometry. - // Faces pointing away from the camera will be occluded by faces - // pointing toward the camera. - cullMode: "back", - }, - - // Enable depth testing so that the fragment closest to the camera - // is rendered in front. - depthStencil: { - depthWriteEnabled: true, - depthCompare: "less", - format: "depth24plus", - }, - }); - - const depthTexture = device.createTexture({ - size: [canvas.width, canvas.height], - format: "depth24plus", - usage: GPUTextureUsage.RENDER_ATTACHMENT, - }); - - const uniformBufferSize = 4 * 16; // 4x4 matrix - const uniformBuffer = device.createBuffer({ - size: uniformBufferSize, - usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST, - }); - - const uniformBindGroup = device.createBindGroup({ - layout: pipeline.getBindGroupLayout(0), - entries: [ - { - binding: 0, - resource: { - buffer: uniformBuffer, - }, - }, - ], - }); + depthStencilAttachment: { + view: depthTexture.createView(), - const renderPassDescriptor: GPURenderPassDescriptor = { - // @ts-expect-error - colorAttachments: [ - { - view: undefined, // Assigned later - clearValue: [0, 0, 0, 0], - loadOp: "clear", - storeOp: "store", + depthClearValue: 1.0, + depthLoadOp: "clear", + depthStoreOp: "store", }, - ], - depthStencilAttachment: { - view: depthTexture.createView(), - - depthClearValue: 1.0, - depthLoadOp: "clear", - depthStoreOp: "store", - }, - }; - - const aspect = canvas.width / canvas.height; - const projectionMatrix = mat4.perspective( - (2 * Math.PI) / 5, - aspect, - 1, - 100.0, - ); - const modelViewProjectionMatrix = mat4.create(); - - function getTransformationMatrix() { - const viewMatrix = mat4.identity(); - mat4.translate(viewMatrix, vec3.fromValues(0, 0, -4), viewMatrix); - const now = Date.now() / 1000; - mat4.rotate( - viewMatrix, - vec3.fromValues(Math.sin(now), Math.cos(now), 0), + }; + + const aspect = canvas.width / canvas.height; + const projectionMatrix = mat4.perspective( + (2 * Math.PI) / 5, + aspect, 1, - viewMatrix, + 100.0, ); + const modelViewProjectionMatrix = mat4.create(); - mat4.multiply(projectionMatrix, viewMatrix, modelViewProjectionMatrix); + function getTransformationMatrix() { + const viewMatrix = mat4.identity(); + mat4.translate(viewMatrix, vec3.fromValues(0, 0, -4), viewMatrix); + const now = Date.now() / 1000; + mat4.rotate( + viewMatrix, + vec3.fromValues(Math.sin(now), Math.cos(now), 0), + 1, + viewMatrix, + ); - return modelViewProjectionMatrix; - } + mat4.multiply(projectionMatrix, viewMatrix, modelViewProjectionMatrix); + + return modelViewProjectionMatrix; + } - function frame() { const transformationMatrix = getTransformationMatrix(); device.queue.writeBuffer( uniformBuffer, diff --git a/apps/example/src/Home.tsx b/apps/example/src/Home.tsx index 9272dfec9..25d9cada7 100644 --- a/apps/example/src/Home.tsx +++ b/apps/example/src/Home.tsx @@ -2,6 +2,7 @@ import * as React from "react"; import { Platform, ScrollView, StyleSheet, Text, View } from "react-native"; import { useNavigation } from "@react-navigation/native"; import { RectButton } from "react-native-gesture-handler"; +import { useSafeAreaInsets } from "react-native-safe-area-context"; import type { StackNavigationProp } from "@react-navigation/stack"; import type { Routes } from "./Route"; @@ -127,15 +128,16 @@ export const examples = [ screen: "StorageBufferVertices", title: "πŸ’Ύ Storage Buffer Vertices", }, + { + screen: "MultiContext", + title: "πŸ”² Multi Context", + }, ]; const styles = StyleSheet.create({ container: { flex: 1, }, - content: { - marginBottom: 32, - }, thumbnail: { backgroundColor: "white", padding: 32, @@ -146,8 +148,12 @@ const styles = StyleSheet.create({ export const Home = () => { const { navigate } = useNavigation>(); + const insets = useSafeAreaInsets(); return ( - + {examples.map((thumbnail) => ( (null); + + useEffect(() => { + const context = ref.current?.getContext("webgpu"); + if (!context) { + return; + } + + const { pipeline, verticesBuffer, presentationFormat } = sharedResources; + const canvas = context.canvas as HTMLCanvasElement; + + context.configure({ + device, + format: presentationFormat, + alphaMode: "premultiplied", + }); + + // Per-item uniform buffer: mat4x4 (64 bytes) + vec4 tint (16 bytes) = 80 bytes + const uniformBufferSize = 4 * 16 + 4 * 4; + const uniformBuffer = device.createBuffer({ + size: uniformBufferSize, + usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST, + }); + + const depthTexture = device.createTexture({ + size: [canvas.width, canvas.height], + format: "depth24plus", + usage: GPUTextureUsage.RENDER_ATTACHMENT, + }); + + const bindGroup = device.createBindGroup({ + layout: pipeline.getBindGroupLayout(0), + entries: [{ binding: 0, resource: { buffer: uniformBuffer } }], + }); + + // Static MVP matrix with fixed rotation based on index + const aspect = canvas.width / canvas.height; + const projectionMatrix = mat4.perspective( + (2 * Math.PI) / 5, + aspect, + 1, + 100.0, + ); + const viewMatrix = mat4.identity(); + mat4.translate(viewMatrix, vec3.fromValues(0, 0, -4), viewMatrix); + const angle = index * 0.4 + 0.5; + mat4.rotate( + viewMatrix, + vec3.fromValues(Math.sin(angle), Math.cos(angle), 0), + 1, + viewMatrix, + ); + const mvp = mat4.multiply(projectionMatrix, viewMatrix); + + // Write MVP matrix + device.queue.writeBuffer( + uniformBuffer, + 0, + mvp.buffer, + mvp.byteOffset, + mvp.byteLength, + ); + + // Write tint color + const [r, g, b] = indexToColor(index); + device.queue.writeBuffer( + uniformBuffer, + 4 * 16, + new Float32Array([r, g, b, 1.0]), + ); + + // Render one frame + const renderPassDescriptor: GPURenderPassDescriptor = { + colorAttachments: [ + { + view: context.getCurrentTexture().createView(), + clearValue: [0.15, 0.15, 0.15, 1], + loadOp: "clear" as const, + storeOp: "store" as const, + }, + ], + depthStencilAttachment: { + view: depthTexture.createView(), + depthClearValue: 1.0, + depthLoadOp: "clear" as const, + depthStoreOp: "store" as const, + }, + }; + + const commandEncoder = device.createCommandEncoder(); + const passEncoder = commandEncoder.beginRenderPass(renderPassDescriptor); + passEncoder.setPipeline(pipeline); + passEncoder.setBindGroup(0, bindGroup); + passEncoder.setVertexBuffer(0, verticesBuffer); + passEncoder.draw(cubeVertexCount); + passEncoder.end(); + device.queue.submit([commandEncoder.finish()]); + context.present(); + + return () => { + uniformBuffer.destroy(); + depthTexture.destroy(); + }; + }, [ref, device, index, sharedResources]); + + return ( + + Context #{index} + + + ); +} + +const styles = StyleSheet.create({ + item: { + borderBottomWidth: StyleSheet.hairlineWidth, + borderBottomColor: "#333", + backgroundColor: "#1a1a1a", + }, + label: { + color: "#fff", + fontSize: 12, + paddingHorizontal: 12, + paddingTop: 8, + paddingBottom: 4, + }, + canvas: { + flex: 1, + }, +}); diff --git a/apps/example/src/MultiContext/MultiContext.tsx b/apps/example/src/MultiContext/MultiContext.tsx new file mode 100644 index 000000000..3c6d9f043 --- /dev/null +++ b/apps/example/src/MultiContext/MultiContext.tsx @@ -0,0 +1,122 @@ +import React, { useMemo } from "react"; +import { FlatList, StyleSheet, View } from "react-native"; +import { GPUDeviceProvider, useMainDevice } from "react-native-wgpu"; + +import { + cubePositionOffset, + cubeUVOffset, + cubeVertexArray, + cubeVertexSize, +} from "../components/cube"; + +import { CubeItem } from "./CubeItem"; +import type { SharedResources } from "./CubeItem"; +import { tintedFragWGSL, tintedVertWGSL } from "./Shaders"; + +const NUM_ITEMS = 50; +const ITEM_HEIGHT = 250; + +const data = Array.from({ length: NUM_ITEMS }, (_, i) => ({ id: i })); + +function MultiContextList() { + const { device } = useMainDevice(); + + const sharedResources = useMemo(() => { + if (!device) { + return null; + } + + const presentationFormat = navigator.gpu.getPreferredCanvasFormat(); + + const verticesBuffer = device.createBuffer({ + size: cubeVertexArray.byteLength, + usage: GPUBufferUsage.VERTEX, + mappedAtCreation: true, + }); + new Float32Array(verticesBuffer.getMappedRange()).set(cubeVertexArray); + verticesBuffer.unmap(); + + const pipeline = device.createRenderPipeline({ + layout: "auto", + vertex: { + module: device.createShaderModule({ code: tintedVertWGSL }), + buffers: [ + { + arrayStride: cubeVertexSize, + attributes: [ + { + shaderLocation: 0, + offset: cubePositionOffset, + format: "float32x4" as GPUVertexFormat, + }, + { + shaderLocation: 1, + offset: cubeUVOffset, + format: "float32x2" as GPUVertexFormat, + }, + ], + }, + ], + }, + fragment: { + module: device.createShaderModule({ code: tintedFragWGSL }), + targets: [{ format: presentationFormat }], + }, + primitive: { + topology: "triangle-list", + cullMode: "back", + }, + depthStencil: { + depthWriteEnabled: true, + depthCompare: "less", + format: "depth24plus", + }, + }); + + return { pipeline, verticesBuffer, presentationFormat }; + }, [device]); + + if (!sharedResources) { + return null; + } + + return ( + String(item.id)} + renderItem={({ item }) => ( + + )} + getItemLayout={(_, index) => ({ + length: ITEM_HEIGHT, + offset: ITEM_HEIGHT * index, + index, + })} + windowSize={5} + maxToRenderPerBatch={3} + initialNumToRender={6} + /> + ); +} + +export function MultiContext() { + return ( + + + + + + ); +} + +const styles = StyleSheet.create({ + container: { + flex: 1, + backgroundColor: "#111", + }, +}); diff --git a/apps/example/src/MultiContext/Shaders.ts b/apps/example/src/MultiContext/Shaders.ts new file mode 100644 index 000000000..8a8870cb3 --- /dev/null +++ b/apps/example/src/MultiContext/Shaders.ts @@ -0,0 +1,39 @@ +export const tintedVertWGSL = /* wgsl */ `struct Uniforms { + modelViewProjectionMatrix: mat4x4f, + tintColor: vec4f, +} +@binding(0) @group(0) var uniforms: Uniforms; + +struct VertexOutput { + @builtin(position) Position: vec4f, + @location(0) fragUV: vec2f, + @location(1) fragPosition: vec4f, +} + +@vertex +fn main( + @location(0) position: vec4f, + @location(1) uv: vec2f +) -> VertexOutput { + var output: VertexOutput; + output.Position = uniforms.modelViewProjectionMatrix * position; + output.fragUV = uv; + output.fragPosition = 0.5 * (position + vec4(1.0, 1.0, 1.0, 1.0)); + return output; +} +`; + +export const tintedFragWGSL = /* wgsl */ `struct Uniforms { + modelViewProjectionMatrix: mat4x4f, + tintColor: vec4f, +} +@binding(0) @group(0) var uniforms: Uniforms; + +@fragment +fn main( + @location(0) fragUV: vec2f, + @location(1) fragPosition: vec4f +) -> @location(0) vec4f { + return fragPosition * uniforms.tintColor; +} +`; diff --git a/apps/example/src/MultiContext/index.ts b/apps/example/src/MultiContext/index.ts new file mode 100644 index 000000000..303aa986a --- /dev/null +++ b/apps/example/src/MultiContext/index.ts @@ -0,0 +1 @@ +export { MultiContext } from "./MultiContext"; diff --git a/apps/example/src/Resize/Resize.tsx b/apps/example/src/Resize/Resize.tsx index 9a4466fb7..6f83b7f25 100644 --- a/apps/example/src/Resize/Resize.tsx +++ b/apps/example/src/Resize/Resize.tsx @@ -1,5 +1,5 @@ import { useEffect } from "react"; -import { Dimensions, PixelRatio, View } from "react-native"; +import { Dimensions, View } from "react-native"; import { Canvas } from "react-native-wgpu"; import Animated, { cancelAnimation, @@ -67,9 +67,11 @@ export const Resize = () => { }); let renderTarget: GPUTexture | undefined; let renderTargetView: GPUTextureView; + let currentWidth = 0; + let currentHeight = 0; + return () => { - const currentWidth = canvas.clientWidth * PixelRatio.get(); - const currentHeight = canvas.clientHeight * PixelRatio.get(); + ref.current?.measureView(context.canvas); // Update the canvas size // The canvas size is animating via CSS. // When the size changes, we need to reallocate the render target. @@ -78,18 +80,17 @@ export const Resize = () => { (currentWidth !== canvas.width || currentHeight !== canvas.height || !renderTargetView) && - currentWidth && - currentHeight + canvas.width && + canvas.height ) { if (renderTarget !== undefined) { // Destroy the previous render target renderTarget.destroy(); } - // Setting the canvas width and height will automatically resize the textures returned - // when calling getCurrentTexture() on the context. - canvas.width = currentWidth; - canvas.height = currentHeight; + // The renderer fully controls the canvas size, no need to do anything + currentWidth = canvas.width; + currentHeight = canvas.height; // Resize the multisampled render target to match the new canvas size. renderTarget = device.createTexture({ diff --git a/apps/example/src/Route.ts b/apps/example/src/Route.ts index 152923e1e..e472f5b16 100644 --- a/apps/example/src/Route.ts +++ b/apps/example/src/Route.ts @@ -29,4 +29,5 @@ export type Routes = { AsyncStarvation: undefined; DeviceLostHang: undefined; StorageBufferVertices: undefined; + MultiContext: undefined; }; diff --git a/apps/example/src/components/useWebGPU.ts b/apps/example/src/components/useWebGPU.ts index ac8a631ac..fae590813 100644 --- a/apps/example/src/components/useWebGPU.ts +++ b/apps/example/src/components/useWebGPU.ts @@ -56,6 +56,7 @@ export const useWebGPU = (scene: Scene) => { if (typeof renderScene === "function") { const render = () => { const timestamp = Date.now(); + if (!ref.current) return; renderScene(timestamp); context.present(); animationFrameId.current = requestAnimationFrame(render); diff --git a/packages/webgpu/.tool-versions b/packages/webgpu/.tool-versions new file mode 100644 index 000000000..dca845891 --- /dev/null +++ b/packages/webgpu/.tool-versions @@ -0,0 +1,9 @@ +golang 1.23.2 +nodejs 22.4.1 +terraform 1.9.6 +grpcurl 1.9.1 +java openjdk-17.0.2 +ruby 3.3.4 +direnv 2.34.0 +bundler 2.6.6 + diff --git a/packages/webgpu/android/CMakeLists.txt b/packages/webgpu/android/CMakeLists.txt index 6e7488b87..eb4761189 100644 --- a/packages/webgpu/android/CMakeLists.txt +++ b/packages/webgpu/android/CMakeLists.txt @@ -26,6 +26,7 @@ find_package(fbjni REQUIRED CONFIG) add_library(${PACKAGE_NAME} SHARED ./cpp/cpp-adapter.cpp + ./cpp/AndroidSurfaceBridge.cpp ../cpp/rnwgpu/api/GPU.cpp ../cpp/rnwgpu/api/GPUAdapter.cpp ../cpp/rnwgpu/api/GPUSupportedLimits.cpp @@ -66,6 +67,7 @@ target_include_directories( "${NODE_MODULES_DIR}/react-native/ReactAndroid/src/main/java/com/facebook/react/turbomodule/core/jni" "${NODE_MODULES_DIR}/react-native/ReactAndroid/src/main/jni/react/turbomodule" + ./cpp ../cpp ../cpp/rnwgpu ../cpp/rnwgpu/api diff --git a/packages/webgpu/android/cpp/AndroidPlatformContext.h b/packages/webgpu/android/cpp/AndroidPlatformContext.h index 0dfe9a24e..6305077bb 100644 --- a/packages/webgpu/android/cpp/AndroidPlatformContext.h +++ b/packages/webgpu/android/cpp/AndroidPlatformContext.h @@ -71,15 +71,6 @@ class AndroidPlatformContext : public PlatformContext { } } - wgpu::Surface makeSurface(wgpu::Instance instance, void *window, int width, - int height) override { - wgpu::SurfaceSourceAndroidNativeWindow androidSurfaceDesc; - androidSurfaceDesc.window = reinterpret_cast(window); - wgpu::SurfaceDescriptor surfaceDescriptor; - surfaceDescriptor.nextInChain = &androidSurfaceDesc; - return instance.CreateSurface(&surfaceDescriptor); - } - ImageData createImageBitmap(std::string blobId, double offset, double size) override { jni::Environment::ensureCurrentThreadIsAttached(); diff --git a/packages/webgpu/android/cpp/AndroidSurfaceBridge.cpp b/packages/webgpu/android/cpp/AndroidSurfaceBridge.cpp new file mode 100644 index 000000000..a0c352497 --- /dev/null +++ b/packages/webgpu/android/cpp/AndroidSurfaceBridge.cpp @@ -0,0 +1,108 @@ +#include "AndroidSurfaceBridge.h" + +namespace rnwgpu { + +AndroidSurfaceBridge::AndroidSurfaceBridge(GPUWithLock gpu) + : _gpu(std::move(gpu.gpu)), SurfaceBridge(gpu.lock), _width(0), _height(0) { + _config.width = 0; + _config.height = 0; +} + +AndroidSurfaceBridge::~AndroidSurfaceBridge() { _surface = nullptr; } + +// ─── JS thread ─────────────────────────────────────────────────── + +void AndroidSurfaceBridge::configure(wgpu::SurfaceConfiguration &newConfig) { + std::lock_guard lock(_mutex); + + _config = newConfig; + _config.presentMode = wgpu::PresentMode::Fifo; + _config.usage = _config.usage | wgpu::TextureUsage::CopyDst; +} + +wgpu::Texture AndroidSurfaceBridge::getCurrentTexture(int width, int height) { + std::lock_guard lock(_mutex); + _width = width; + _height = height; + + wgpu::TextureDescriptor desc; + desc.format = _config.format; + desc.size.width = width; + desc.size.height = height; + desc.usage = wgpu::TextureUsage::RenderAttachment | + wgpu::TextureUsage::CopySrc | wgpu::TextureUsage::CopyDst | + wgpu::TextureUsage::TextureBinding; + _texture = _config.device.CreateTexture(&desc); + + return _texture; +} + +bool AndroidSurfaceBridge::present() { + std::lock_guard lock(_mutex); + + if (_texture) { + _presentedTexture = _texture; + _texture = nullptr; + } + if (_surface) { + _copyToSurfaceAndPresent(); + } + + return true; +} + +// ─── UI thread (Android-specific) ─────────────────────────────── + +void AndroidSurfaceBridge::switchToOnscreen(ANativeWindow *nativeWindow, + wgpu::Surface surface) { + std::lock_guard gpuLock(_gpuLock->mutex); + std::unique_lock lock(_mutex); + _nativeWindow = nativeWindow; + _surface = std::move(surface); + _copyToSurfaceAndPresent(); +} + +ANativeWindow *AndroidSurfaceBridge::switchToOffscreen() { + std::lock_guard gpuLock(_gpuLock->mutex); + std::unique_lock lock(_mutex); + auto res = _nativeWindow; + if (_surface) { + _surface.Unconfigure(); + } + _surface = nullptr; + _nativeWindow = nullptr; + return res; +} + +NativeInfo AndroidSurfaceBridge::getNativeInfo() { + std::lock_guard lock(_mutex); + return {.nativeSurface = static_cast(_nativeWindow), + .width = _width, + .height = _height}; +} + +void AndroidSurfaceBridge::_copyToSurfaceAndPresent() { + if (!_config.device || !_presentedTexture) { + return; + } + + auto queue = _config.device.GetQueue(); + auto future = queue.OnSubmittedWorkDone( + wgpu::CallbackMode::WaitAnyOnly, + [](wgpu::QueueWorkDoneStatus, wgpu::StringView) {}); + wgpu::FutureWaitInfo waitInfo{future}; + _gpu.WaitAny(1, &waitInfo, UINT64_MAX); + + _config.width = _presentedTexture.GetWidth(); + _config.height = _presentedTexture.GetHeight(); + _surface.Configure(&_config); + + copyTextureToSurfaceAndPresent(_config.device, _presentedTexture, _surface); +} + +// Factory +std::shared_ptr createSurfaceBridge(GPUWithLock gpu) { + return std::make_shared(std::move(gpu)); +} + +} // namespace rnwgpu diff --git a/packages/webgpu/android/cpp/AndroidSurfaceBridge.h b/packages/webgpu/android/cpp/AndroidSurfaceBridge.h new file mode 100644 index 000000000..e5c689cb8 --- /dev/null +++ b/packages/webgpu/android/cpp/AndroidSurfaceBridge.h @@ -0,0 +1,41 @@ +#pragma once + +#include + +#include "rnwgpu/SurfaceBridge.h" + +namespace rnwgpu { + +class AndroidSurfaceBridge : public SurfaceBridge { +public: + AndroidSurfaceBridge(GPUWithLock gpu); + ~AndroidSurfaceBridge() override; + + // JS thread + void configure(wgpu::SurfaceConfiguration &config) override; + wgpu::Texture getCurrentTexture(int width, int height) override; + bool present() override; + + // Android UI thread + void switchToOnscreen(ANativeWindow *nativeWindow, wgpu::Surface surface); + ANativeWindow *switchToOffscreen(); + + // Read-only + NativeInfo getNativeInfo() override; + +private: + void _copyToSurfaceAndPresent(); + + wgpu::Instance _gpu; + wgpu::SurfaceConfiguration _config; + wgpu::Surface _surface = nullptr; + wgpu::Texture _texture = nullptr; + wgpu::Texture _presentedTexture = nullptr; + ANativeWindow *_nativeWindow = nullptr; + + mutable std::mutex _mutex; + int _width; + int _height; +}; + +} // namespace rnwgpu diff --git a/packages/webgpu/android/cpp/cpp-adapter.cpp b/packages/webgpu/android/cpp/cpp-adapter.cpp index 2a441c218..4dc52fdc0 100644 --- a/packages/webgpu/android/cpp/cpp-adapter.cpp +++ b/packages/webgpu/android/cpp/cpp-adapter.cpp @@ -10,13 +10,22 @@ #include #include "AndroidPlatformContext.h" +#include "AndroidSurfaceBridge.h" #include "GPUCanvasContext.h" #include "RNWebGPUManager.h" +#include "WGPULogger.h" #define LOG_TAG "WebGPUModule" std::shared_ptr manager; +// Helper to get the AndroidSurfaceBridge for a given contextId. +static std::shared_ptr getAndroidBridge(int contextId) { + auto ®istry = rnwgpu::SurfaceRegistry::getInstance(); + return std::static_pointer_cast( + registry.getSurfaceInfo(contextId)); +} + extern "C" JNIEXPORT void JNICALL Java_com_webgpu_WebGPUModule_initializeNative( JNIEnv *env, jobject /* this */, jlong jsRuntime, jobject jsCallInvokerHolder, jobject blobModule) { @@ -33,39 +42,51 @@ extern "C" JNIEXPORT void JNICALL Java_com_webgpu_WebGPUModule_initializeNative( platformContext); } -extern "C" JNIEXPORT void JNICALL Java_com_webgpu_WebGPUView_onSurfaceChanged( - JNIEnv *env, jobject thiz, jobject surface, jint contextId, jfloat width, - jfloat height) { - auto ®istry = rnwgpu::SurfaceRegistry::getInstance(); - registry.getSurfaceInfo(contextId)->resize(static_cast(width), - static_cast(height)); +static wgpu::Surface makeSurface(wgpu::Instance instance, void *window) { + wgpu::SurfaceSourceAndroidNativeWindow androidSurfaceDesc; + androidSurfaceDesc.window = reinterpret_cast(window); + wgpu::SurfaceDescriptor surfaceDescriptor; + surfaceDescriptor.nextInChain = &androidSurfaceDesc; + return instance.CreateSurface(&surfaceDescriptor); } extern "C" JNIEXPORT void JNICALL Java_com_webgpu_WebGPUView_onSurfaceCreate( JNIEnv *env, jobject thiz, jobject jSurface, jint contextId, jfloat width, jfloat height) { - auto window = ANativeWindow_fromSurface(env, jSurface); - // ANativeWindow_acquire(window); auto ®istry = rnwgpu::SurfaceRegistry::getInstance(); - auto gpu = manager->_gpu; - auto surface = manager->_platformContext->makeSurface( - gpu, window, static_cast(width), static_cast(height)); - registry - .getSurfaceInfoOrCreate(contextId, gpu, static_cast(width), - static_cast(height)) - ->switchToOnscreen(window, surface); + + auto bridge = std::static_pointer_cast( + registry.getSurfaceInfoOrCreate(contextId, manager->_gpu)); + if (bridge) { + auto old = bridge->switchToOffscreen(); + if (old) ANativeWindow_release(old); + } + + // It runs ANativeWindow_acquire() internally + auto window = ANativeWindow_fromSurface(env, jSurface); + auto surface = makeSurface(manager->_gpu.gpu, window); + bridge->switchToOnscreen(window, surface); } -extern "C" JNIEXPORT void JNICALL -Java_com_webgpu_WebGPUView_switchToOffscreenSurface(JNIEnv *env, jobject thiz, - jint contextId) { - auto ®istry = rnwgpu::SurfaceRegistry::getInstance(); - auto nativeSurface = registry.getSurfaceInfo(contextId)->switchToOffscreen(); - // ANativeWindow_release(reinterpret_cast(nativeSurface)); +extern "C" JNIEXPORT void JNICALL Java_com_webgpu_WebGPUView_onSurfaceChanged( + JNIEnv *env, jobject thiz, jobject jSurface, jint contextId, jfloat width, + jfloat height) { + // No-op for now } -extern "C" JNIEXPORT void JNICALL Java_com_webgpu_WebGPUView_onSurfaceDestroy( +extern "C" JNIEXPORT void JNICALL Java_com_webgpu_WebGPUView_switchToOffscreenSurface( + JNIEnv *env, jobject thiz, jint contextId) { + + auto bridge = getAndroidBridge(contextId); + if (bridge) { + auto *window = bridge->switchToOffscreen(); + ANativeWindow_release(window); + } +} + +extern "C" JNIEXPORT void JNICALL Java_com_webgpu_WebGPUView_onViewDetached( JNIEnv *env, jobject thiz, jint contextId) { + // Called from onDropViewInstance when the React component is permanently unmounted. auto ®istry = rnwgpu::SurfaceRegistry::getInstance(); registry.removeSurfaceInfo(contextId); -} \ No newline at end of file +} diff --git a/packages/webgpu/android/src/main/java/com/webgpu/WebGPUAHBView.java b/packages/webgpu/android/src/main/java/com/webgpu/WebGPUAHBView.java index 46139c84e..e2b03332e 100644 --- a/packages/webgpu/android/src/main/java/com/webgpu/WebGPUAHBView.java +++ b/packages/webgpu/android/src/main/java/com/webgpu/WebGPUAHBView.java @@ -256,7 +256,7 @@ protected void onDetachedFromWindow() { // Notify WebGPU that surface is being destroyed if (mSurfaceCreated) { - mApi.surfaceDestroyed(); + mApi.surfaceOffscreen(); mSurfaceCreated = false; } diff --git a/packages/webgpu/android/src/main/java/com/webgpu/WebGPUAPI.java b/packages/webgpu/android/src/main/java/com/webgpu/WebGPUAPI.java index db4e1f861..95911c64b 100644 --- a/packages/webgpu/android/src/main/java/com/webgpu/WebGPUAPI.java +++ b/packages/webgpu/android/src/main/java/com/webgpu/WebGPUAPI.java @@ -14,7 +14,5 @@ void surfaceChanged( Surface surface ); - void surfaceDestroyed(); - void surfaceOffscreen(); } diff --git a/packages/webgpu/android/src/main/java/com/webgpu/WebGPUSurfaceView.java b/packages/webgpu/android/src/main/java/com/webgpu/WebGPUSurfaceView.java index 83cf943cd..1c61147db 100644 --- a/packages/webgpu/android/src/main/java/com/webgpu/WebGPUSurfaceView.java +++ b/packages/webgpu/android/src/main/java/com/webgpu/WebGPUSurfaceView.java @@ -18,12 +18,6 @@ public WebGPUSurfaceView(Context context, WebGPUAPI api) { getHolder().addCallback(this); } - @Override - protected void onDetachedFromWindow() { - super.onDetachedFromWindow(); - mApi.surfaceDestroyed(); - } - @Override public void surfaceCreated(@NonNull SurfaceHolder holder) { mApi.surfaceCreated(holder.getSurface()); diff --git a/packages/webgpu/android/src/main/java/com/webgpu/WebGPUTextureView.java b/packages/webgpu/android/src/main/java/com/webgpu/WebGPUTextureView.java index 23d5cb01f..357681769 100644 --- a/packages/webgpu/android/src/main/java/com/webgpu/WebGPUTextureView.java +++ b/packages/webgpu/android/src/main/java/com/webgpu/WebGPUTextureView.java @@ -35,12 +35,13 @@ public void onSurfaceTextureSizeChanged(@NonNull SurfaceTexture surfaceTexture, @Override public boolean onSurfaceTextureDestroyed(@NonNull SurfaceTexture surfaceTexture) { - mApi.surfaceDestroyed(); + mApi.surfaceOffscreen(); return true; } @Override public void onSurfaceTextureUpdated(@NonNull SurfaceTexture surfaceTexture) { // No implementation needed + } } diff --git a/packages/webgpu/android/src/main/java/com/webgpu/WebGPUView.java b/packages/webgpu/android/src/main/java/com/webgpu/WebGPUView.java index 3f73a1066..7fe75bc2f 100644 --- a/packages/webgpu/android/src/main/java/com/webgpu/WebGPUView.java +++ b/packages/webgpu/android/src/main/java/com/webgpu/WebGPUView.java @@ -73,16 +73,16 @@ public void surfaceChanged(Surface surface) { onSurfaceChanged(surface, mContextId, width, height); } - @Override - public void surfaceDestroyed() { - onSurfaceDestroy(mContextId); - } - @Override public void surfaceOffscreen() { switchToOffscreenSurface(mContextId); } + public void destroy() { + // Permanent cleanup β€” called from onDropViewInstance + onViewDetached(mContextId); + } + @DoNotStrip private native void onSurfaceCreate( Surface surface, @@ -100,9 +100,9 @@ private native void onSurfaceChanged( ); @DoNotStrip - private native void onSurfaceDestroy(int contextId); + private native void switchToOffscreenSurface(int contextId); @DoNotStrip - private native void switchToOffscreenSurface(int contextId); + private native void onViewDetached(int contextId); } diff --git a/packages/webgpu/android/src/main/java/com/webgpu/WebGPUViewManager.java b/packages/webgpu/android/src/main/java/com/webgpu/WebGPUViewManager.java index 71041b326..5b14128df 100644 --- a/packages/webgpu/android/src/main/java/com/webgpu/WebGPUViewManager.java +++ b/packages/webgpu/android/src/main/java/com/webgpu/WebGPUViewManager.java @@ -35,4 +35,10 @@ public void setTransparent(WebGPUView view, boolean value) { public void setContextId(WebGPUView view, int value) { view.setContextId(value); } + + @Override + public void onDropViewInstance(@NonNull WebGPUView view) { + view.destroy(); + super.onDropViewInstance(view); + } } diff --git a/packages/webgpu/apple/ApplePlatformContext.h b/packages/webgpu/apple/ApplePlatformContext.h index 24d27c76f..e355da6c6 100644 --- a/packages/webgpu/apple/ApplePlatformContext.h +++ b/packages/webgpu/apple/ApplePlatformContext.h @@ -10,9 +10,6 @@ class ApplePlatformContext : public PlatformContext { ApplePlatformContext(); ~ApplePlatformContext() = default; - wgpu::Surface makeSurface(wgpu::Instance instance, void *surface, int width, - int height) override; - ImageData createImageBitmap(std::string blobId, double offset, double size) override; diff --git a/packages/webgpu/apple/ApplePlatformContext.mm b/packages/webgpu/apple/ApplePlatformContext.mm index 3cf4ac32d..5d77fe415 100644 --- a/packages/webgpu/apple/ApplePlatformContext.mm +++ b/packages/webgpu/apple/ApplePlatformContext.mm @@ -29,16 +29,6 @@ void checkIfUsingSimulatorWithAPIValidation() { checkIfUsingSimulatorWithAPIValidation(); } -wgpu::Surface ApplePlatformContext::makeSurface(wgpu::Instance instance, - void *surface, int width, - int height) { - wgpu::SurfaceSourceMetalLayer metalSurfaceDesc; - metalSurfaceDesc.layer = surface; - wgpu::SurfaceDescriptor surfaceDescriptor; - surfaceDescriptor.nextInChain = &metalSurfaceDesc; - return instance.CreateSurface(&surfaceDescriptor); -} - static std::span nsDataToSpan(NSData *data) { return {static_cast(data.bytes), data.length}; } diff --git a/packages/webgpu/apple/AppleSurfaceBridge.h b/packages/webgpu/apple/AppleSurfaceBridge.h new file mode 100644 index 000000000..15e6b9030 --- /dev/null +++ b/packages/webgpu/apple/AppleSurfaceBridge.h @@ -0,0 +1,45 @@ +#pragma once + +#include "rnwgpu/SurfaceBridge.h" + +#include +#include + +namespace rnwgpu { + +class AppleSurfaceBridge : public SurfaceBridge, public std::enable_shared_from_this { +public: + AppleSurfaceBridge(GPUWithLock gpu); + ~AppleSurfaceBridge() override {}; + + // JS thread + void configure(wgpu::SurfaceConfiguration &config) override; + wgpu::Texture getCurrentTexture(int width, int height) override; + bool present() override; + + // Called by the UI thread once from MetalView when it's ready. + // The UI thread must hold the GPU device lock. + void prepareToDisplay(void *nativeSurface, wgpu::Surface surface); + + NativeInfo getNativeInfo() override; +private: + void _resizeSurface(int width, int height); + void _doSurfaceConfiguration(int width, int height); + + wgpu::Instance _gpu; + wgpu::SurfaceConfiguration _config; + wgpu::Surface _surface = nullptr; + + // It's possible that the JS thread accesses the getCurrentTexture + // before the UI thread attaches the native Metal layer. In this case + // we render to an offscreen texture. + wgpu::Texture _renderTargetTexture = nullptr; + wgpu::Texture _presentedTexture = nullptr; + void *_nativeSurface = nullptr; + + std::mutex _mutex; + int _width; + int _height; +}; + +} // namespace rnwgpu diff --git a/packages/webgpu/apple/AppleSurfaceBridge.mm b/packages/webgpu/apple/AppleSurfaceBridge.mm new file mode 100644 index 000000000..94eb90fda --- /dev/null +++ b/packages/webgpu/apple/AppleSurfaceBridge.mm @@ -0,0 +1,171 @@ +#import +#include "AppleSurfaceBridge.h" +#include "WGPULogger.h" + +namespace dawn::native::metal { +void WaitForCommandsToBeScheduled(WGPUDevice device); +} + +namespace rnwgpu { + +AppleSurfaceBridge::AppleSurfaceBridge(GPUWithLock gpu) + : _gpu(std::move(gpu.gpu)), SurfaceBridge(gpu.lock) { + _config.width = 0; + _config.height = 0; +} + +void AppleSurfaceBridge::configure(wgpu::SurfaceConfiguration &newConfig) { + std::lock_guard lock(_mutex); + // We might need to copy from the offscreen buffer to the surface + // on resize, or if .present() runs before the native layer is ready. + newConfig.usage = newConfig.usage | wgpu::TextureUsage::CopyDst; + newConfig.width = _config.width; + newConfig.height = _config.height; + _config = newConfig; +} + +NativeInfo AppleSurfaceBridge::getNativeInfo() { + std::lock_guard lock(_mutex); + return {.nativeSurface = _nativeSurface, .width = _width, .height = _height}; +} + +wgpu::Texture AppleSurfaceBridge::getCurrentTexture(int width, int height) { + std::lock_guard lock(_mutex); + _width = width; + _height = height; + + if (!_config.device) { + // The user needs to call configure() before calling the getCurrentTexture(). + return nullptr; + } + + if (_surface) { + // If our surface is sized correctly, just use it! + if (_config.width == width && _config.height == height) { + // It's safe to update non-size-related properties on the surface. I think. + // TODO: use other surface properties to determine if we need to reconfigure in the UI thread + _surface.Configure(&_config); + wgpu::SurfaceTexture surfTex; + // It's safe to get the texture without the UI thread roundtrip, only reconfiguration + // needs to be delegated to the UI thread. + _surface.GetCurrentTexture(&surfTex); + return surfTex.texture; + } + _resizeSurface(width, height); // Kick off the surface resize in background + } + + // This can happen if the getCurrentTexture() runs before the UI thread + // calls prepareToDisplay(). + wgpu::TextureDescriptor textureDesc; + textureDesc.format = _config.format; + textureDesc.size.width = width; + textureDesc.size.height = height; + textureDesc.usage = wgpu::TextureUsage::RenderAttachment | + wgpu::TextureUsage::CopySrc | + wgpu::TextureUsage::TextureBinding; + _renderTargetTexture = _config.device.CreateTexture(&textureDesc); + return _renderTargetTexture; +} + +bool AppleSurfaceBridge::present() { + std::lock_guard lock(_mutex); + if (!_config.device) { + return false; + } + dawn::native::metal::WaitForCommandsToBeScheduled(_config.device.Get()); + // Barrier... +// auto queue = _config.device.GetQueue(); +// auto future = queue.OnSubmittedWorkDone( +// wgpu::CallbackMode::WaitAnyOnly, +// [](wgpu::QueueWorkDoneStatus, wgpu::StringView) {}); +// wgpu::FutureWaitInfo waitInfo{future}; +// _gpu.WaitAny(1, &waitInfo, UINT64_MAX); + if (_renderTargetTexture) { + _presentedTexture = _renderTargetTexture; + _renderTargetTexture = nullptr; + if (_surface) { + // We were rendering into the texture because the surface was not ready yet + // or it needed resizing. Check if the current size is compatible with the direct + // copy. + int textureWidth = _presentedTexture.GetWidth(); + int textureHeight = _presentedTexture.GetHeight(); + if (_config.width == textureWidth && _config.height == textureHeight) { + copyTextureToSurfaceAndPresent(_config.device, _presentedTexture, _surface); + } else { + // Run the texture resizing in the UI thread asynchronously. It will use the + // presented texture's dimensions for the size. + _resizeSurface(0, 0); + } + } + } else if (_surface) { + // Happy path: rendered onto the surface, no need to keep the presented texture anymore + _presentedTexture = nullptr; + _surface.Present(); + } + return true; +} + +void AppleSurfaceBridge::prepareToDisplay(void *nativeSurface, wgpu::Surface surface) { + // Make sure we prevent the JS thread from racing with this method + std::lock_guard gpuLock(_gpuLock->mutex); + std::lock_guard lock(_mutex); + + if (_surface) { + Logger::logToConsole("Surface assigned multiple times, should never happen"); + return; + } + + _nativeSurface = nativeSurface; // For nativeInfo only + _surface = std::move(surface); + if (_presentedTexture) { + _doSurfaceConfiguration(0, 0); // Use the presented texture's dimensions + } +} + +void AppleSurfaceBridge::_doSurfaceConfiguration(int width, int height) { + if (!_config.device || !_surface) { + return; + } + if (_presentedTexture && (width == 0 || height == 0)) { + width = _presentedTexture.GetWidth(); + height = _presentedTexture.GetHeight(); + } + if (width <= 0 || height <= 0) { + // The presented surface has disappeared since the time we were scheduled. + // It's perfectly fine! This means that the bridge switched into backing surface mode. + return; + } + if (_config.width == width && _config.height == height) { + return; + } + + dawn::native::metal::WaitForCommandsToBeScheduled(_config.device.Get()); + + _config.width = width; + _config.height = height; + _surface.Configure(&_config); // We're in the UI thread, it's safe. + if (_presentedTexture && _presentedTexture.GetWidth() == width && + _presentedTexture.GetHeight() == height) { + // We have a compatible texture. So copy it to the surface. + copyTextureToSurfaceAndPresent(_config.device, _presentedTexture, _surface); + // Don't delete the backing texture in case we want to redisplay it + } +} + +void AppleSurfaceBridge::_resizeSurface(int width, int height) { + // Make sure that we live long enough for the dispatch to run + auto self = this->shared_from_this(); + + dispatch_async(dispatch_get_main_queue(), ^{ + std::lock_guard gpuGuard(self->_gpuLock->mutex); + std::lock_guard lock(self->_mutex); + self->_doSurfaceConfiguration(width, height); + }); +} + +// Factory +std::shared_ptr createSurfaceBridge(GPUWithLock gpu) { + return std::make_shared(std::move(gpu)); +} + +} // namespace rnwgpu diff --git a/packages/webgpu/apple/MetalView.h b/packages/webgpu/apple/MetalView.h index a563db974..b825e512f 100644 --- a/packages/webgpu/apple/MetalView.h +++ b/packages/webgpu/apple/MetalView.h @@ -6,6 +6,7 @@ @interface MetalView : RNWGPlatformView @property NSNumber *contextId; +@property BOOL isAttached; - (void)configure; - (void)update; diff --git a/packages/webgpu/apple/MetalView.mm b/packages/webgpu/apple/MetalView.mm index ccff1245c..20b72da24 100644 --- a/packages/webgpu/apple/MetalView.mm +++ b/packages/webgpu/apple/MetalView.mm @@ -1,8 +1,12 @@ #import "MetalView.h" #import "webgpu/webgpu_cpp.h" +#include "AppleSurfaceBridge.h" +#include "SurfaceRegistry.h" + @implementation MetalView { BOOL _isConfigured; + BOOL _isAttached; } #if !TARGET_OS_OSX @@ -21,29 +25,37 @@ - (instancetype)init { #endif // !TARGET_OS_OSX - (void)configure { - auto size = self.frame.size; + // Delay the configuration until we have a valid size std::shared_ptr manager = [WebGPUModule getManager]; - void *nativeSurface = (__bridge void *)self.layer; + auto gpuWithLock = manager->_gpu; + auto ®istry = rnwgpu::SurfaceRegistry::getInstance(); - auto gpu = manager->_gpu; - auto surface = manager->_platformContext->makeSurface( - gpu, nativeSurface, size.width, size.height); - registry - .getSurfaceInfoOrCreate([_contextId intValue], gpu, size.width, - size.height) - ->switchToOnscreen(nativeSurface, surface); + + wgpu::SurfaceSourceMetalLayer metalSurfaceDesc; + metalSurfaceDesc.layer = (__bridge void *)self.layer; + wgpu::SurfaceDescriptor surfaceDescriptor; + surfaceDescriptor.nextInChain = &metalSurfaceDesc; + // This is safe to call without holding a GPU lock + wgpu::Surface surface = gpuWithLock.gpu.CreateSurface(&surfaceDescriptor); + + // Get or create the bridge. + int ctxId = [_contextId intValue]; + + // Create the bridge and attach the surface. + // Safe to take the GPU lock here: prepareToDisplay runs on the UI thread + // and never dispatch_sync's back to it, so no deadlock. + auto bridge = std::static_pointer_cast( + registry.getSurfaceInfoOrCreate(ctxId, gpuWithLock)); + + void *nativeSurface = (__bridge void *)self.layer; + bridge->prepareToDisplay(nativeSurface, surface); } - (void)update { - auto size = self.frame.size; - auto ®istry = rnwgpu::SurfaceRegistry::getInstance(); - registry.getSurfaceInfo([_contextId intValue]) - ->resize(size.width, size.height); } - (void)dealloc { auto ®istry = rnwgpu::SurfaceRegistry::getInstance(); - // Remove the surface info from the registry registry.removeSurfaceInfo([_contextId intValue]); } diff --git a/packages/webgpu/cpp/jsi/NativeObject.h b/packages/webgpu/cpp/jsi/NativeObject.h index a90927721..e14af7a91 100644 --- a/packages/webgpu/cpp/jsi/NativeObject.h +++ b/packages/webgpu/cpp/jsi/NativeObject.h @@ -16,6 +16,7 @@ #include "RuntimeAwareCache.h" #include "WGPULogger.h" +#include "rnwgpu/GPULockInfo.h" // Forward declare to avoid circular dependency namespace rnwgpu { @@ -401,6 +402,17 @@ class NativeObject : public jsi::NativeState, */ jsi::Runtime *getCreationRuntime() const { return _creationRuntime; } + /** + * Set the GPU lock for this object. All JSβ†’native method calls will + * acquire this lock before invoking the native method. + */ + void setGPULock(std::shared_ptr lock) { _gpuLock = std::move(lock); } + + /** + * Get the GPU lock (for propagation to child objects). + */ + std::shared_ptr getGPULock() const { return _gpuLock; } + protected: explicit NativeObject(const char *name) : _name(name) { #if DEBUG && RNF_ENABLE_LOGS @@ -416,6 +428,19 @@ class NativeObject : public jsi::NativeState, const char *_name; jsi::Runtime *_creationRuntime = nullptr; + std::shared_ptr _gpuLock; + + /** + * Acquire the GPU lock if one is set. Returns a unique_lock that + * releases automatically when it goes out of scope. If no lock is + * set, returns a no-op lock (not owning any mutex). + */ + std::unique_lock acquireGPULock() { + if (_gpuLock) { + return std::unique_lock(_gpuLock->mutex); + } + return std::unique_lock(); // no-op + } // ============================================================ // Helper methods for definePrototype() implementations @@ -433,6 +458,7 @@ class NativeObject : public jsi::NativeState, [method](jsi::Runtime &rt, const jsi::Value &thisVal, const jsi::Value *args, size_t count) -> jsi::Value { auto native = Derived::fromValue(rt, thisVal); + auto lockGuard = native->acquireGPULock(); return callMethod(native.get(), method, rt, args, std::index_sequence_for{}, count); }); @@ -452,6 +478,7 @@ class NativeObject : public jsi::NativeState, [getter](jsi::Runtime &rt, const jsi::Value &thisVal, const jsi::Value *args, size_t count) -> jsi::Value { auto native = Derived::fromValue(rt, thisVal); + auto lockGuard = native->acquireGPULock(); if constexpr (std::is_same_v) { (native.get()->*getter)(); return jsi::Value::undefined(); @@ -492,6 +519,7 @@ class NativeObject : public jsi::NativeState, throw jsi::JSError(rt, "Setter requires a value argument"); } auto native = Derived::fromValue(rt, thisVal); + auto lockGuard = native->acquireGPULock(); auto value = rnwgpu::JSIConverter>::fromJSI(rt, args[0], false); (native.get()->*setter)(std::move(value)); @@ -539,6 +567,7 @@ class NativeObject : public jsi::NativeState, [getter](jsi::Runtime &rt, const jsi::Value &thisVal, const jsi::Value *args, size_t count) -> jsi::Value { auto native = Derived::fromValue(rt, thisVal); + auto lockGuard = native->acquireGPULock(); ReturnType result = (native.get()->*getter)(); return rnwgpu::JSIConverter>::toJSI(rt, std::move(result)); @@ -553,6 +582,7 @@ class NativeObject : public jsi::NativeState, throw jsi::JSError(rt, "Setter requires a value argument"); } auto native = Derived::fromValue(rt, thisVal); + auto lockGuard = native->acquireGPULock(); auto value = rnwgpu::JSIConverter>::fromJSI(rt, args[0], false); (native.get()->*setter)(std::move(value)); diff --git a/packages/webgpu/cpp/rnwgpu/GPULockInfo.h b/packages/webgpu/cpp/rnwgpu/GPULockInfo.h new file mode 100644 index 000000000..e61e35d7f --- /dev/null +++ b/packages/webgpu/cpp/rnwgpu/GPULockInfo.h @@ -0,0 +1,38 @@ +#pragma once + +#include +#include +#include "webgpu/webgpu_cpp.h" + +namespace rnwgpu { + +/** + * GPU-level lock that serializes all Dawn API calls. + * + * Dawn's wgpu::Device is not thread-safe. All calls that touch the device + * (texture creation, command encoding, queue submit, surface operations) + * must be serialized. This lock is shared by ALL objects derived from the + * same GPU device. + * + * Created when the GPU instance is initialized, propagated to every + * downstream object (adapter, device, textures, buffers, encoders, etc.) + * via constructor arguments. + * + * The NativeObject base class acquires this lock automatically around + * every JSβ†’native method call. The UI thread (SurfaceInfo) also acquires + * it before touching Dawn APIs. + */ +struct GPULockInfo { + // Recursive because JSβ†’native calls can trigger descriptor conversions + // that round-trip back through JS, re-entering native methods on the + // same thread. + std::recursive_mutex mutex; +}; + +// A helper struct to store the Dawn GPU object with a lock +struct GPUWithLock { + wgpu::Instance gpu; + std::shared_ptr lock; +}; + +} // namespace rnwgpu diff --git a/packages/webgpu/cpp/rnwgpu/PlatformContext.h b/packages/webgpu/cpp/rnwgpu/PlatformContext.h index e7a272476..a1d1b731e 100644 --- a/packages/webgpu/cpp/rnwgpu/PlatformContext.h +++ b/packages/webgpu/cpp/rnwgpu/PlatformContext.h @@ -22,8 +22,6 @@ class PlatformContext { PlatformContext() = default; virtual ~PlatformContext() = default; - virtual wgpu::Surface makeSurface(wgpu::Instance instance, void *surface, - int width, int height) = 0; virtual ImageData createImageBitmap(std::string blobId, double offset, double size) = 0; diff --git a/packages/webgpu/cpp/rnwgpu/RNWebGPUManager.h b/packages/webgpu/cpp/rnwgpu/RNWebGPUManager.h index 2043c9658..2d5920c9b 100644 --- a/packages/webgpu/cpp/rnwgpu/RNWebGPUManager.h +++ b/packages/webgpu/cpp/rnwgpu/RNWebGPUManager.h @@ -2,9 +2,8 @@ #include -#include "GPU.h" #include "PlatformContext.h" -#include "SurfaceRegistry.h" +#include "GPULockInfo.h" namespace facebook { namespace jsi { @@ -39,7 +38,7 @@ class RNWebGPUManager { std::shared_ptr _jsCallInvoker; public: - wgpu::Instance _gpu; + GPUWithLock _gpu; std::shared_ptr _platformContext; }; diff --git a/packages/webgpu/cpp/rnwgpu/SurfaceBridge.h b/packages/webgpu/cpp/rnwgpu/SurfaceBridge.h new file mode 100644 index 000000000..770f27932 --- /dev/null +++ b/packages/webgpu/cpp/rnwgpu/SurfaceBridge.h @@ -0,0 +1,106 @@ +#pragma once + +#include +#include + +#include "GPULockInfo.h" +#include "WGPULogger.h" +#include "webgpu/webgpu_cpp.h" + +namespace rnwgpu { + +struct Size { + int width; + int height; +}; + +struct NativeInfo { + void *nativeSurface; + int width; + int height; +}; + +/** + * The abstract bridge between the OS-specific surface and GPUCanvasContext. + * It is created initially without the UI-level object in the JS thread, and + * registered in the global SurfaceRegistry. + * + * It's then looked up by its ID by the UI thread and connected to the proper + * system-level surface. + * + * JS-thread methods (called under GPU lock via NativeObject): + * configure, unconfigure, getCurrentTexture, present + * + * UI-thread lifecycle methods are platform-specific and live in the + * concrete subclasses (AndroidSurfaceBridge, AppleSurfaceBridge). + */ +class SurfaceBridge { +public: + SurfaceBridge(std::shared_ptr gpuLock) : _gpuLock(std::move(gpuLock)) {} + virtual ~SurfaceBridge() = default; + + // Accessors + virtual NativeInfo getNativeInfo() = 0; + std::shared_ptr getGPULock() const { return _gpuLock; } + + //////////////////////////////////////////////////////////// + ///// The JS thread operations (called under GPU lock) //// + //////////////////////////////////////////////////////////// + virtual void configure(wgpu::SurfaceConfiguration &config) = 0; + // Returns the texture to render into. Handles resize if width/height + // differ from the current configuration. + // Can return nullptr if the UI object is not yet attached. + virtual wgpu::Texture getCurrentTexture(int width, int height) = 0; + virtual bool present() = 0; + +protected: + + void copyTextureToSurfaceAndPresent(wgpu::Device device, wgpu::Texture texture, + wgpu::Surface surface) { + + if (!device || !texture || !surface) { + return; + } + + wgpu::SurfaceTexture surfTex; + surface.GetCurrentTexture(&surfTex); + if (surfTex.status != wgpu::SurfaceGetCurrentTextureStatus::SuccessOptimal && + surfTex.status != wgpu::SurfaceGetCurrentTextureStatus::SuccessSuboptimal) { + Logger::logToConsole("SurfaceBridge", + "GetCurrentTexture failed: status=%d, src texture=%dx%d", + (int)surfTex.status, texture.GetWidth(), texture.GetHeight()); + return; + } + + // Copy the overlapping region β€” handles mismatched sizes gracefully. + uint32_t copyWidth = std::min(texture.GetWidth(), surfTex.texture.GetWidth()); + uint32_t copyHeight = std::min(texture.GetHeight(), surfTex.texture.GetHeight()); + if (copyWidth == 0 || copyHeight == 0) { + surface.Present(); + return; + } + + wgpu::CommandEncoderDescriptor encDesc; + auto encoder = device.CreateCommandEncoder(&encDesc); + + wgpu::TexelCopyTextureInfo src = {}; + src.texture = texture; + wgpu::TexelCopyTextureInfo dst = {}; + dst.texture = surfTex.texture; + wgpu::Extent3D size = {copyWidth, copyHeight, 1}; + + encoder.CopyTextureToTexture(&src, &dst, &size); + auto cmds = encoder.Finish(); + device.GetQueue().Submit(1, &cmds); + surface.Present(); + } + + std::shared_ptr _gpuLock; +}; + +// Platform-specific factory. Implemented in: +// android/cpp/AndroidSurfaceBridge.cpp +// apple/AppleSurfaceBridge.mm +std::shared_ptr createSurfaceBridge(GPUWithLock gpu); + +} // namespace rnwgpu diff --git a/packages/webgpu/cpp/rnwgpu/SurfaceRegistry.h b/packages/webgpu/cpp/rnwgpu/SurfaceRegistry.h index 110a45d44..454c2ddf4 100644 --- a/packages/webgpu/cpp/rnwgpu/SurfaceRegistry.h +++ b/packages/webgpu/cpp/rnwgpu/SurfaceRegistry.h @@ -3,180 +3,11 @@ #include #include #include -#include -#include "webgpu/webgpu_cpp.h" +#include "SurfaceBridge.h" namespace rnwgpu { -struct NativeInfo { - void *nativeSurface; - int width; - int height; -}; - -struct Size { - int width; - int height; -}; - -class SurfaceInfo { -public: - SurfaceInfo(wgpu::Instance gpu, int width, int height) - : gpu(std::move(gpu)), width(width), height(height) {} - - ~SurfaceInfo() { surface = nullptr; } - - void reconfigure(int newWidth, int newHeight) { - std::unique_lock lock(_mutex); - config.width = newWidth; - config.height = newHeight; - _configure(); - } - - void configure(wgpu::SurfaceConfiguration &newConfig) { - std::unique_lock lock(_mutex); - config = newConfig; - config.width = width; - config.height = height; - config.presentMode = wgpu::PresentMode::Fifo; - _configure(); - } - - void unconfigure() { - std::unique_lock lock(_mutex); - if (surface) { - surface.Unconfigure(); - } else { - texture = nullptr; - } - } - - void *switchToOffscreen() { - std::unique_lock lock(_mutex); - // We only do this if the onscreen surface is configured. - auto isConfigured = config.device != nullptr; - if (isConfigured) { - wgpu::TextureDescriptor textureDesc; - textureDesc.usage = wgpu::TextureUsage::RenderAttachment | - wgpu::TextureUsage::CopySrc | - wgpu::TextureUsage::TextureBinding; - textureDesc.format = config.format; - textureDesc.size.width = config.width; - textureDesc.size.height = config.height; - texture = config.device.CreateTexture(&textureDesc); - } - surface = nullptr; - return nativeSurface; - } - - void switchToOnscreen(void *newNativeSurface, wgpu::Surface newSurface) { - std::unique_lock lock(_mutex); - nativeSurface = newNativeSurface; - surface = std::move(newSurface); - // If we are comming from an offscreen context, we need to configure the new - // surface - if (texture != nullptr) { - config.usage = config.usage | wgpu::TextureUsage::CopyDst; - _configure(); - // We flush the offscreen texture to the onscreen one - // TODO: there is a faster way to do this without validation? - wgpu::CommandEncoderDescriptor encoderDesc; - auto device = config.device; - wgpu::CommandEncoder encoder = device.CreateCommandEncoder(&encoderDesc); - - wgpu::TexelCopyTextureInfo sourceTexture = {}; - sourceTexture.texture = texture; - - wgpu::TexelCopyTextureInfo destinationTexture = {}; - wgpu::SurfaceTexture surfaceTexture; - surface.GetCurrentTexture(&surfaceTexture); - destinationTexture.texture = surfaceTexture.texture; - - wgpu::Extent3D size = {sourceTexture.texture.GetWidth(), - sourceTexture.texture.GetHeight(), - sourceTexture.texture.GetDepthOrArrayLayers()}; - - encoder.CopyTextureToTexture(&sourceTexture, &destinationTexture, &size); - - wgpu::CommandBuffer commands = encoder.Finish(); - wgpu::Queue queue = device.GetQueue(); - queue.Submit(1, &commands); - surface.Present(); - texture = nullptr; - } - } - - void resize(int newWidth, int newHeight) { - std::unique_lock lock(_mutex); - width = newWidth; - height = newHeight; - } - - void present() { - std::unique_lock lock(_mutex); - if (surface) { - surface.Present(); - } - } - - wgpu::Texture getCurrentTexture() { - std::shared_lock lock(_mutex); - if (surface) { - wgpu::SurfaceTexture surfaceTexture; - surface.GetCurrentTexture(&surfaceTexture); - return surfaceTexture.texture; - } else { - return texture; - } - } - - NativeInfo getNativeInfo() { - std::shared_lock lock(_mutex); - return {.nativeSurface = nativeSurface, .width = width, .height = height}; - } - - Size getSize() { - std::shared_lock lock(_mutex); - return {.width = width, .height = height}; - } - - wgpu::SurfaceConfiguration getConfig() { - std::shared_lock lock(_mutex); - return config; - } - - wgpu::Device getDevice() { - std::shared_lock lock(_mutex); - return config.device; - } - -private: - void _configure() { - if (surface) { - surface.Configure(&config); - } else { - wgpu::TextureDescriptor textureDesc; - textureDesc.format = config.format; - textureDesc.size.width = config.width; - textureDesc.size.height = config.height; - textureDesc.usage = wgpu::TextureUsage::RenderAttachment | - wgpu::TextureUsage::CopySrc | - wgpu::TextureUsage::TextureBinding; - texture = config.device.CreateTexture(&textureDesc); - } - } - - mutable std::shared_mutex _mutex; - void *nativeSurface = nullptr; - wgpu::Surface surface = nullptr; - wgpu::Texture texture = nullptr; - wgpu::Instance gpu; - wgpu::SurfaceConfiguration config; - int width; - int height; -}; - class SurfaceRegistry { public: static SurfaceRegistry &getInstance() { @@ -187,7 +18,7 @@ class SurfaceRegistry { SurfaceRegistry(const SurfaceRegistry &) = delete; SurfaceRegistry &operator=(const SurfaceRegistry &) = delete; - std::shared_ptr getSurfaceInfo(int id) { + std::shared_ptr getSurfaceInfo(int id) { std::shared_lock lock(_mutex); auto it = _registry.find(id); if (it != _registry.end()) { @@ -201,30 +32,23 @@ class SurfaceRegistry { _registry.erase(id); } - std::shared_ptr addSurfaceInfo(int id, wgpu::Instance gpu, - int width, int height) { - std::unique_lock lock(_mutex); - auto info = std::make_shared(gpu, width, height); - _registry[id] = info; - return info; - } + std::shared_ptr + getSurfaceInfoOrCreate(int id, GPUWithLock gpu) { - std::shared_ptr - getSurfaceInfoOrCreate(int id, wgpu::Instance gpu, int width, int height) { std::unique_lock lock(_mutex); auto it = _registry.find(id); if (it != _registry.end()) { return it->second; } - auto info = std::make_shared(gpu, width, height); - _registry[id] = info; - return info; + auto bridge = createSurfaceBridge(gpu); + _registry[id] = bridge; + return bridge; } private: SurfaceRegistry() = default; mutable std::shared_mutex _mutex; - std::unordered_map> _registry; + std::unordered_map> _registry; }; } // namespace rnwgpu diff --git a/packages/webgpu/cpp/rnwgpu/api/Canvas.h b/packages/webgpu/cpp/rnwgpu/api/Canvas.h index b84ec6929..102c60a86 100644 --- a/packages/webgpu/cpp/rnwgpu/api/Canvas.h +++ b/packages/webgpu/cpp/rnwgpu/api/Canvas.h @@ -17,22 +17,23 @@ class Canvas : public NativeObject { public: static constexpr const char *CLASS_NAME = "Canvas"; - explicit Canvas(void *surface, const int width, const int height) - : NativeObject(CLASS_NAME), _surface(surface), _width(width), - _height(height), _clientWidth(width), _clientHeight(height) {} + explicit Canvas(void *surface, const float width, const float height, + const float pixelRatio) + : NativeObject(CLASS_NAME), _surface(surface), _width(width * pixelRatio), + _height(height * pixelRatio), _clientWidth(width), _clientHeight(height) {} - int getWidth() { return _width; } - int getHeight() { return _height; } + float getWidth() { return _width; } + float getHeight() { return _height; } void setWidth(const int width) { _width = width; } void setHeight(const int height) { _height = height; } - int getClientWidth() { return _clientWidth; } - int getClientHeight() { return _clientHeight; } + float getClientWidth() { return _clientWidth; } + float getClientHeight() { return _clientHeight; } - void setClientWidth(const int width) { _clientWidth = width; } + void setClientWidth(const float width) { _clientWidth = width; } - void setClientHeight(const int height) { _clientHeight = height; } + void setClientHeight(const float height) { _clientHeight = height; } void *getSurface() { return _surface; } @@ -42,16 +43,18 @@ class Canvas : public NativeObject { &Canvas::setWidth); installGetterSetter(runtime, prototype, "height", &Canvas::getHeight, &Canvas::setHeight); - installGetter(runtime, prototype, "clientWidth", &Canvas::getClientWidth); - installGetter(runtime, prototype, "clientHeight", &Canvas::getClientHeight); + installGetterSetter(runtime, prototype, "clientWidth", &Canvas::getClientWidth, + &Canvas::setClientWidth); + installGetterSetter(runtime, prototype, "clientHeight", &Canvas::getClientHeight, + &Canvas::setClientHeight); } private: void *_surface; - int _width; - int _height; - int _clientWidth; - int _clientHeight; + float _width; + float _height; + float _clientWidth; + float _clientHeight; }; } // namespace rnwgpu diff --git a/packages/webgpu/cpp/rnwgpu/api/GPU.cpp b/packages/webgpu/cpp/rnwgpu/api/GPU.cpp index 764a9aa32..79a43e332 100644 --- a/packages/webgpu/cpp/rnwgpu/api/GPU.cpp +++ b/packages/webgpu/cpp/rnwgpu/api/GPU.cpp @@ -24,6 +24,9 @@ GPU::GPU(jsi::Runtime &runtime) : NativeObject(CLASS_NAME) { auto dispatcher = std::make_shared(runtime); _async = async::AsyncRunner::getOrCreate(runtime, _instance, dispatcher); + + // Create the GPU-level lock that serializes all Dawn API calls + setGPULock(std::make_shared()); } async::AsyncTaskHandle GPU::requestAdapter( @@ -40,11 +43,12 @@ async::AsyncTaskHandle GPU::requestAdapter( #endif aOptions.backendType = kDefaultBackendType; return _async->postTask( - [this, aOptions](const async::AsyncTaskHandle::ResolveFunction &resolve, - const async::AsyncTaskHandle::RejectFunction &reject) { + [this, aOptions, gpuLock = getGPULock()]( + const async::AsyncTaskHandle::ResolveFunction &resolve, + const async::AsyncTaskHandle::RejectFunction &reject) { _instance.RequestAdapter( &aOptions, wgpu::CallbackMode::AllowProcessEvents, - [asyncRunner = _async, resolve, + [asyncRunner = _async, gpuLock, resolve, reject](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, wgpu::StringView message) { if (message.length) { @@ -54,6 +58,7 @@ async::AsyncTaskHandle GPU::requestAdapter( if (status == wgpu::RequestAdapterStatus::Success && adapter) { auto adapterHost = std::make_shared( std::move(adapter), asyncRunner); + adapterHost->setGPULock(gpuLock); auto result = std::variant>( adapterHost); diff --git a/packages/webgpu/cpp/rnwgpu/api/GPU.h b/packages/webgpu/cpp/rnwgpu/api/GPU.h index f6bb4ede3..9259ede2f 100644 --- a/packages/webgpu/cpp/rnwgpu/api/GPU.h +++ b/packages/webgpu/cpp/rnwgpu/api/GPU.h @@ -47,7 +47,7 @@ class GPU : public NativeObject { &GPU::getWgslLanguageFeatures); } - inline const wgpu::Instance get() { return _instance; } + inline const GPUWithLock get() { return GPUWithLock{ .gpu = _instance, .lock = _gpuLock }; } inline std::shared_ptr getAsyncRunner() { return _async; } private: diff --git a/packages/webgpu/cpp/rnwgpu/api/GPUAdapter.cpp b/packages/webgpu/cpp/rnwgpu/api/GPUAdapter.cpp index 27bf14f8b..b8c14b443 100644 --- a/packages/webgpu/cpp/rnwgpu/api/GPUAdapter.cpp +++ b/packages/webgpu/cpp/rnwgpu/api/GPUAdapter.cpp @@ -76,7 +76,7 @@ async::AsyncTaskHandle GPUAdapter::requestDevice( message.length > 0 ? std::string(message.data, message.length) : ""; std::string fullMessage = msg.length() > 0 ? std::string(errorType) + ": " + msg : "no message"; - fprintf(stderr, "%s\n", fullMessage.c_str()); + Logger::logToConsole("%s\n", fullMessage.c_str()); // Look up the GPUDevice from the registry and notify it if (auto gpuDevice = GPUDevice::lookupDevice(device.Get())) { @@ -89,18 +89,20 @@ async::AsyncTaskHandle GPUAdapter::requestDevice( auto creationRuntime = getCreationRuntime(); return _async->postTask( [this, aDescriptor, descriptor, label = std::move(label), - deviceLostBinding, - creationRuntime](const async::AsyncTaskHandle::ResolveFunction &resolve, - const async::AsyncTaskHandle::RejectFunction &reject) { + deviceLostBinding, creationRuntime, + gpuLock = getGPULock()]( + const async::AsyncTaskHandle::ResolveFunction &resolve, + const async::AsyncTaskHandle::RejectFunction &reject) { (void)descriptor; _instance.RequestDevice( &aDescriptor, wgpu::CallbackMode::AllowProcessEvents, [asyncRunner = _async, resolve, reject, label, creationRuntime, - deviceLostBinding](wgpu::RequestDeviceStatus status, + deviceLostBinding, + gpuLock](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) mutable { if (message.length) { - fprintf(stderr, "%s", message.data); + Logger::logToConsole("%s", message.data); } if (status != wgpu::RequestDeviceStatus::Success || !device) { @@ -144,6 +146,7 @@ async::AsyncTaskHandle GPUAdapter::requestDevice( auto deviceHost = std::make_shared(std::move(device), asyncRunner, label); + deviceHost->setGPULock(gpuLock); *deviceLostBinding = deviceHost; // Register the device in the static registry so the uncaptured diff --git a/packages/webgpu/cpp/rnwgpu/api/GPUCanvasContext.cpp b/packages/webgpu/cpp/rnwgpu/api/GPUCanvasContext.cpp index 15e7f4acd..2b3630aa1 100644 --- a/packages/webgpu/cpp/rnwgpu/api/GPUCanvasContext.cpp +++ b/packages/webgpu/cpp/rnwgpu/api/GPUCanvasContext.cpp @@ -3,18 +3,18 @@ #include "RNWebGPUManager.h" #include -#ifdef __APPLE__ -namespace dawn::native::metal { +namespace rnwgpu { -void WaitForCommandsToBeScheduled(WGPUDevice device); +GPUCanvasContext::GPUCanvasContext(std::shared_ptr gpu, int contextId, + float width, float height, float pixelRatio) + : NativeObject(CLASS_NAME), _gpu(std::move(gpu)), _contextId(contextId) { + _canvas = std::make_shared(nullptr, width, height, pixelRatio); + auto ®istry = SurfaceRegistry::getInstance(); + _bridge = registry.getSurfaceInfoOrCreate(contextId, _gpu->get()); } -#endif - -namespace rnwgpu { -void GPUCanvasContext::configure( - std::shared_ptr configuration) { +void GPUCanvasContext::configure(std::shared_ptr configuration) { Convertor conv; wgpu::SurfaceConfiguration surfaceConfiguration; surfaceConfiguration.device = configuration->device->get(); @@ -34,32 +34,27 @@ void GPUCanvasContext::configure( surfaceConfiguration.alphaMode = configuration->alphaMode; #endif surfaceConfiguration.presentMode = wgpu::PresentMode::Fifo; - _surfaceInfo->configure(surfaceConfiguration); + _bridge->configure(surfaceConfiguration); } void GPUCanvasContext::unconfigure() {} std::shared_ptr GPUCanvasContext::getCurrentTexture() { - auto prevSize = _surfaceInfo->getConfig(); - auto width = _canvas->getWidth(); - auto height = _canvas->getHeight(); - auto sizeHasChanged = prevSize.width != width || prevSize.height != height; - if (sizeHasChanged) { - _surfaceInfo->reconfigure(width, height); + auto texture = _bridge->getCurrentTexture(_canvas->getWidth(), _canvas->getHeight()); + if (!texture) { + return nullptr; } - auto texture = _surfaceInfo->getCurrentTexture(); - return std::make_shared(texture, ""); + auto result = std::make_shared(texture, ""); + result->setGPULock(getGPULock()); + _startedFrame = true; + return result; } void GPUCanvasContext::present() { -#ifdef __APPLE__ - dawn::native::metal::WaitForCommandsToBeScheduled( - _surfaceInfo->getDevice().Get()); -#endif - auto size = _surfaceInfo->getSize(); - _canvas->setClientWidth(size.width); - _canvas->setClientHeight(size.height); - _surfaceInfo->present(); + if (_startedFrame) { + _bridge->present(); + } + _startedFrame = false; } } // namespace rnwgpu diff --git a/packages/webgpu/cpp/rnwgpu/api/GPUCanvasContext.h b/packages/webgpu/cpp/rnwgpu/api/GPUCanvasContext.h index 4b97a7887..93daaac05 100644 --- a/packages/webgpu/cpp/rnwgpu/api/GPUCanvasContext.h +++ b/packages/webgpu/cpp/rnwgpu/api/GPUCanvasContext.h @@ -4,6 +4,8 @@ #include #include +#include + #include "Unions.h" #include "webgpu/webgpu_cpp.h" @@ -24,19 +26,15 @@ class GPUCanvasContext : public NativeObject { public: static constexpr const char *CLASS_NAME = "GPUCanvasContext"; - GPUCanvasContext(std::shared_ptr gpu, int contextId, int width, - int height) - : NativeObject(CLASS_NAME), _gpu(std::move(gpu)) { - _canvas = std::make_shared(nullptr, width, height); - auto ®istry = rnwgpu::SurfaceRegistry::getInstance(); - _surfaceInfo = - registry.getSurfaceInfoOrCreate(contextId, _gpu->get(), width, height); - } + GPUCanvasContext(std::shared_ptr gpu, int contextId, + float width, float height, float pixelRatio); public: std::string getBrand() { return CLASS_NAME; } - std::shared_ptr getCanvas() { return _canvas; } + std::shared_ptr getCanvas() { + return _canvas; + } static void definePrototype(jsi::Runtime &runtime, jsi::Object &prototype) { installGetter(runtime, prototype, "__brand", &GPUCanvasContext::getBrand); @@ -58,9 +56,13 @@ class GPUCanvasContext : public NativeObject { void present(); private: - std::shared_ptr _canvas; - std::shared_ptr _surfaceInfo; + int _contextId; + bool _startedFrame = false; + std::shared_ptr _bridge; std::shared_ptr _gpu; + std::shared_ptr _canvas; + std::shared_ptr _measureCallback; + std::shared_ptr _callInvoker; }; } // namespace rnwgpu diff --git a/packages/webgpu/cpp/rnwgpu/api/GPUCommandEncoder.cpp b/packages/webgpu/cpp/rnwgpu/api/GPUCommandEncoder.cpp index 7ef1ce064..d2e943d01 100644 --- a/packages/webgpu/cpp/rnwgpu/api/GPUCommandEncoder.cpp +++ b/packages/webgpu/cpp/rnwgpu/api/GPUCommandEncoder.cpp @@ -32,9 +32,11 @@ std::shared_ptr GPUCommandEncoder::finish( "GPUCommandEncoder::finish(): error with GPUCommandBufferDescriptor"); } auto commandBuffer = _instance.Finish(&desc); - return std::make_shared( + auto result = std::make_shared( commandBuffer, descriptor.has_value() ? descriptor.value()->label.value_or("") : ""); + result->setGPULock(getGPULock()); + return result; } std::shared_ptr GPUCommandEncoder::beginRenderPass( @@ -57,8 +59,10 @@ std::shared_ptr GPUCommandEncoder::beginRenderPass( "get GPURenderPassDescriptor"); } auto renderPass = _instance.BeginRenderPass(&desc); - return std::make_shared(renderPass, - descriptor->label.value_or("")); + auto result = std::make_shared(renderPass, + descriptor->label.value_or("")); + result->setGPULock(getGPULock()); + return result; } void GPUCommandEncoder::copyTextureToBuffer( @@ -104,9 +108,11 @@ std::shared_ptr GPUCommandEncoder::beginComputePass( "access GPUComputePassDescriptor."); } auto computePass = _instance.BeginComputePass(&desc); - return std::make_shared( + auto result = std::make_shared( computePass, descriptor.has_value() ? descriptor.value()->label.value_or("") : ""); + result->setGPULock(getGPULock()); + return result; } void GPUCommandEncoder::resolveQuerySet(std::shared_ptr querySet, diff --git a/packages/webgpu/cpp/rnwgpu/api/GPUComputePipeline.cpp b/packages/webgpu/cpp/rnwgpu/api/GPUComputePipeline.cpp index 4a8be1d72..7496a88ee 100644 --- a/packages/webgpu/cpp/rnwgpu/api/GPUComputePipeline.cpp +++ b/packages/webgpu/cpp/rnwgpu/api/GPUComputePipeline.cpp @@ -6,7 +6,9 @@ namespace rnwgpu { std::shared_ptr GPUComputePipeline::getBindGroupLayout(uint32_t groupIndex) { auto bindGroup = _instance.GetBindGroupLayout(groupIndex); - return std::make_shared(bindGroup, ""); + auto result = std::make_shared(bindGroup, ""); + result->setGPULock(getGPULock()); + return result; } } // namespace rnwgpu \ No newline at end of file diff --git a/packages/webgpu/cpp/rnwgpu/api/GPUDevice.cpp b/packages/webgpu/cpp/rnwgpu/api/GPUDevice.cpp index 909c4555a..57da32dd0 100644 --- a/packages/webgpu/cpp/rnwgpu/api/GPUDevice.cpp +++ b/packages/webgpu/cpp/rnwgpu/api/GPUDevice.cpp @@ -52,8 +52,8 @@ GPUDevice::createBuffer(std::shared_ptr descriptor) { "GPUDevice::createBuffer(): Error with GPUBufferDescriptor"); } auto result = _instance.CreateBuffer(&desc); - return std::make_shared(result, _async, - descriptor->label.value_or("")); + return makeChild(result, _async, + descriptor->label.value_or("")); } std::shared_ptr GPUDevice::getLimits() { @@ -66,7 +66,7 @@ std::shared_ptr GPUDevice::getLimits() { std::shared_ptr GPUDevice::getQueue() { auto result = _instance.GetQueue(); - return std::make_shared(result, _async, _label); + return makeChild(result, _async, _label); } std::shared_ptr GPUDevice::createCommandEncoder( @@ -77,7 +77,7 @@ std::shared_ptr GPUDevice::createCommandEncoder( throw std::runtime_error("Error with GPUCommandEncoderDescriptor"); } auto result = _instance.CreateCommandEncoder(&desc); - return std::make_shared( + return makeChild( result, descriptor.has_value() ? descriptor.value()->label.value_or("") : ""); } @@ -95,7 +95,7 @@ GPUDevice::createTexture(std::shared_ptr descriptor) { throw std::runtime_error("Error with GPUTextureDescriptor"); } auto texture = _instance.CreateTexture(&desc); - return std::make_shared(texture, descriptor->label.value_or("")); + return makeChild(texture, descriptor->label.value_or("")); } std::shared_ptr GPUDevice::createShaderModule( @@ -111,11 +111,11 @@ std::shared_ptr GPUDevice::createShaderModule( if (descriptor->code.find('\0') != std::string::npos) { auto mod = _instance.CreateErrorShaderModule( &sm_desc, "The WGSL shader contains an illegal character '\\0'"); - return std::make_shared(mod, _async, sm_desc.label.data); + return makeChild(mod, _async, sm_desc.label.data); } auto module = _instance.CreateShaderModule(&sm_desc); - return std::make_shared(module, _async, - descriptor->label.value_or("")); + return makeChild(module, _async, + descriptor->label.value_or("")); } std::shared_ptr GPUDevice::createRenderPipeline( @@ -127,8 +127,8 @@ std::shared_ptr GPUDevice::createRenderPipeline( } // assert(desc.fragment != nullptr && "Fragment state must not be null"); auto renderPipeline = _instance.CreateRenderPipeline(&desc); - return std::make_shared(renderPipeline, - descriptor->label.value_or("")); + return makeChild(renderPipeline, + descriptor->label.value_or("")); } std::shared_ptr @@ -142,8 +142,8 @@ GPUDevice::createBindGroup(std::shared_ptr descriptor) { "GPUBindGroup::createBindGroup(): Error with GPUBindGroupDescriptor"); } auto bindGroup = _instance.CreateBindGroup(&desc); - return std::make_shared(bindGroup, - descriptor->label.value_or("")); + return makeChild(bindGroup, + descriptor->label.value_or("")); } std::shared_ptr GPUDevice::createSampler( @@ -155,7 +155,7 @@ std::shared_ptr GPUDevice::createSampler( "GPUSamplerDescriptor"); } auto sampler = _instance.CreateSampler(&desc); - return std::make_shared( + return makeChild( sampler, descriptor.has_value() ? descriptor.value()->label.value_or("") : ""); } @@ -169,8 +169,8 @@ std::shared_ptr GPUDevice::createComputePipeline( "GPUComputePipelineDescriptor"); } auto computePipeline = _instance.CreateComputePipeline(&desc); - return std::make_shared(computePipeline, - descriptor->label.value_or("")); + return makeChild(computePipeline, + descriptor->label.value_or("")); } std::shared_ptr @@ -182,8 +182,8 @@ GPUDevice::createQuerySet(std::shared_ptr descriptor) { "GPUQuerySetDescriptor"); } auto querySet = _instance.CreateQuerySet(&desc); - return std::make_shared(querySet, - descriptor->label.value_or("")); + return makeChild(querySet, + descriptor->label.value_or("")); } std::shared_ptr GPUDevice::createRenderBundleEncoder( @@ -200,7 +200,7 @@ std::shared_ptr GPUDevice::createRenderBundleEncoder( !conv(desc.stencilReadOnly, descriptor->stencilReadOnly)) { return {}; } - return std::make_shared( + return makeChild( _instance.CreateRenderBundleEncoder(&desc), descriptor->label.value_or("")); } @@ -214,7 +214,7 @@ std::shared_ptr GPUDevice::createBindGroupLayout( !conv(desc.entries, desc.entryCount, descriptor->entries)) { return {}; } - return std::make_shared( + return makeChild( _instance.CreateBindGroupLayout(&desc), descriptor->label.value_or("")); } @@ -228,7 +228,7 @@ std::shared_ptr GPUDevice::createPipelineLayout( descriptor->bindGroupLayouts)) { return {}; } - return std::make_shared( + return makeChild( _instance.CreatePipelineLayout(&desc), descriptor->label.value_or("")); } @@ -249,7 +249,7 @@ async::AsyncTaskHandle GPUDevice::createComputePipelineAsync( auto label = std::string( descriptor->label.has_value() ? descriptor->label.value() : ""); - auto pipelineHolder = std::make_shared(nullptr, label); + auto pipelineHolder = makeChild(nullptr, label); return _async->postTask([device = _instance, desc, descriptor, pipelineHolder]( @@ -290,7 +290,7 @@ async::AsyncTaskHandle GPUDevice::createRenderPipelineAsync( auto label = std::string( descriptor->label.has_value() ? descriptor->label.value() : ""); - auto pipelineHolder = std::make_shared(nullptr, label); + auto pipelineHolder = makeChild(nullptr, label); return _async->postTask([device = _instance, desc, descriptor, pipelineHolder]( diff --git a/packages/webgpu/cpp/rnwgpu/api/GPUDevice.h b/packages/webgpu/cpp/rnwgpu/api/GPUDevice.h index 9b0681c2f..cd237999a 100644 --- a/packages/webgpu/cpp/rnwgpu/api/GPUDevice.h +++ b/packages/webgpu/cpp/rnwgpu/api/GPUDevice.h @@ -239,6 +239,16 @@ class GPUDevice : public NativeObject { inline const wgpu::Device get() { return _instance; } + /** + * Create a child NativeObject and propagate the GPU lock to it. + */ + template + std::shared_ptr makeChild(Args&&... args) { + auto child = std::make_shared(std::forward(args)...); + child->setGPULock(getGPULock()); + return child; + } + private: friend class GPUAdapter; diff --git a/packages/webgpu/cpp/rnwgpu/api/GPURenderBundleEncoder.cpp b/packages/webgpu/cpp/rnwgpu/api/GPURenderBundleEncoder.cpp index ece7257ea..94e8833fc 100644 --- a/packages/webgpu/cpp/rnwgpu/api/GPURenderBundleEncoder.cpp +++ b/packages/webgpu/cpp/rnwgpu/api/GPURenderBundleEncoder.cpp @@ -16,9 +16,11 @@ std::shared_ptr GPURenderBundleEncoder::finish( "GPURenderBundleDescriptor"); } auto bundle = _instance.Finish(&desc); - return std::make_shared( + auto result = std::make_shared( bundle, descriptor.has_value() ? descriptor.value()->label.value_or("") : ""); + result->setGPULock(getGPULock()); + return result; } void GPURenderBundleEncoder::setPipeline( diff --git a/packages/webgpu/cpp/rnwgpu/api/GPURenderPipeline.cpp b/packages/webgpu/cpp/rnwgpu/api/GPURenderPipeline.cpp index 5c9bd1153..c4a520311 100644 --- a/packages/webgpu/cpp/rnwgpu/api/GPURenderPipeline.cpp +++ b/packages/webgpu/cpp/rnwgpu/api/GPURenderPipeline.cpp @@ -7,7 +7,9 @@ namespace rnwgpu { std::shared_ptr GPURenderPipeline::getBindGroupLayout(uint32_t groupIndex) { auto bindGroupLayout = _instance.GetBindGroupLayout(groupIndex); - return std::make_shared(bindGroupLayout, ""); + auto result = std::make_shared(bindGroupLayout, ""); + result->setGPULock(getGPULock()); + return result; } } // namespace rnwgpu \ No newline at end of file diff --git a/packages/webgpu/cpp/rnwgpu/api/GPUTexture.cpp b/packages/webgpu/cpp/rnwgpu/api/GPUTexture.cpp index f1d84b99c..ca2eb7822 100644 --- a/packages/webgpu/cpp/rnwgpu/api/GPUTexture.cpp +++ b/packages/webgpu/cpp/rnwgpu/api/GPUTexture.cpp @@ -17,9 +17,11 @@ std::shared_ptr GPUTexture::createView( "GPUTextureViewDescriptor"); } auto view = _instance.CreateView(&desc); - return std::make_shared( + auto result = std::make_shared( view, descriptor.has_value() ? descriptor.value()->label.value_or("") : ""); + result->setGPULock(getGPULock()); + return result; } uint32_t GPUTexture::getWidth() { return _instance.GetWidth(); } diff --git a/packages/webgpu/cpp/rnwgpu/api/RNWebGPU.h b/packages/webgpu/cpp/rnwgpu/api/RNWebGPU.h index 59fe14bd8..96ba80f14 100644 --- a/packages/webgpu/cpp/rnwgpu/api/RNWebGPU.h +++ b/packages/webgpu/cpp/rnwgpu/api/RNWebGPU.h @@ -69,9 +69,10 @@ class RNWebGPU : public NativeObject { bool getFabric() { return true; } std::shared_ptr - MakeWebGPUCanvasContext(int contextId, float width, float height) { + MakeWebGPUCanvasContext(int contextId, float width, float height, float pixelRatio) { auto ctx = - std::make_shared(_gpu, contextId, width, height); + std::make_shared(_gpu, contextId, width, height, pixelRatio); + ctx->setGPULock(_gpu->getGPULock()); return ctx; } @@ -176,11 +177,11 @@ class RNWebGPU : public NativeObject { auto ®istry = rnwgpu::SurfaceRegistry::getInstance(); auto info = registry.getSurfaceInfo(contextId); if (info == nullptr) { - return std::make_shared(nullptr, 0, 0); + return std::make_shared(nullptr, 0, 0, 0); } auto nativeInfo = info->getNativeInfo(); return std::make_shared(nativeInfo.nativeSurface, nativeInfo.width, - nativeInfo.height); + nativeInfo.height, 1.f); } static void definePrototype(jsi::Runtime &runtime, jsi::Object &prototype) { diff --git a/packages/webgpu/react-native-wgpu.podspec b/packages/webgpu/react-native-wgpu.podspec index ac01a3b66..f1cf2772b 100644 --- a/packages/webgpu/react-native-wgpu.podspec +++ b/packages/webgpu/react-native-wgpu.podspec @@ -16,13 +16,13 @@ Pod::Spec.new do |s| s.source_files = [ "apple/**/*.{h,c,cc,cpp,m,mm,swift}", - "cpp/**/*.{h,cpp}" + "cpp/**/*.{h,cpp,mm}" ] s.vendored_frameworks = 'libs/apple/libwebgpu_dawn.xcframework' s.pod_target_xcconfig = { - 'HEADER_SEARCH_PATHS' => '$(PODS_TARGET_SRCROOT)/cpp', + 'HEADER_SEARCH_PATHS' => '$(PODS_TARGET_SRCROOT)/cpp $(PODS_TARGET_SRCROOT)/apple', } # Use install_modules_dependencies helper to install the dependencies if React Native version >=0.71.0. diff --git a/packages/webgpu/src/Canvas.tsx b/packages/webgpu/src/Canvas.tsx index 142e5de2c..e56f6b238 100644 --- a/packages/webgpu/src/Canvas.tsx +++ b/packages/webgpu/src/Canvas.tsx @@ -1,6 +1,6 @@ import React, { useImperativeHandle, useRef, useState } from "react"; import type { ViewProps } from "react-native"; -import { View } from "react-native"; +import { PixelRatio, View } from "react-native"; import WebGPUNativeView from "./WebGPUViewNativeComponent"; @@ -18,6 +18,7 @@ declare global { contextId: number, width: number, height: number, + pixelRatio: number, ) => RNCanvasContext; DecodeToUTF8: (buffer: NodeJS.ArrayBufferView | ArrayBuffer) => string; createImageBitmap: typeof createImageBitmap; @@ -26,20 +27,24 @@ declare global { type SurfacePointer = bigint; -export interface NativeCanvas { - surface: SurfacePointer; +export interface CanvasSize { width: number; height: number; clientWidth: number; clientHeight: number; } +export type NativeCanvas = CanvasSize & { + surface: SurfacePointer; +}; + export type RNCanvasContext = GPUCanvasContext & { present: () => void; }; export interface CanvasRef { getContextId: () => number; + measureView: (canvasTarget: unknown) => CanvasSize; getContext(contextName: "webgpu"): RNCanvasContext | null; getNativeSurface: () => NativeCanvas; } @@ -49,38 +54,80 @@ interface CanvasProps extends ViewProps { ref?: React.Ref; } +function getViewSize(view: View): { width: number; height: number } { + // let widthRes = 0, heightRes = 0; + // view.measure((x, y, width, height, pageX, pageY) => { + // widthRes = width; + // heightRes = height; + // }); + // + // console.log(`Size: ${widthRes}x${heightRes}`); + // return { width: widthRes, height: heightRes }; + // getBoundingClientRect became stable in RN 0.83 + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const viewAny = view as any; + const size = + "getBoundingClientRect" in viewAny + ? viewAny.getBoundingClientRect() + : viewAny.unstable_getBoundingClientRect(); + return size; +} + export const Canvas = ({ transparent, ref, ...props }: CanvasProps) => { - const viewRef = useRef(null); + const viewRef = useRef(null); const [contextId, _] = useState(() => generateContextId()); - useImperativeHandle(ref, () => ({ - getContextId: () => contextId, - getNativeSurface: () => { - return RNWebGPU.getNativeSurface(contextId); - }, - getContext(contextName: "webgpu"): RNCanvasContext | null { - if (contextName !== "webgpu") { - throw new Error(`[WebGPU] Unsupported context: ${contextName}`); - } - if (!viewRef.current) { - throw new Error("[WebGPU] Cannot get context before mount"); - } - // getBoundingClientRect became stable in RN 0.83 - // eslint-disable-next-line @typescript-eslint/no-explicit-any - const view = viewRef.current as any; - const size = - "getBoundingClientRect" in view - ? view.getBoundingClientRect() - : view.unstable_getBoundingClientRect(); - return RNWebGPU.MakeWebGPUCanvasContext( - contextId, - size.width, - size.height, - ); - }, - })); + useImperativeHandle(ref, () => { + return { + getContextId: () => contextId, + getNativeSurface: () => { + return RNWebGPU.getNativeSurface(contextId); + }, + measureView: (canvasTarget: unknown): CanvasSize => { + if (!viewRef.current) { + throw new Error("[WebGPU] Cannot get context before mount"); + } + + const sz = getViewSize(viewRef.current); + const pixelRatio = PixelRatio.get(); + const res = { + width: sz.width * pixelRatio, + height: sz.height * pixelRatio, + clientWidth: sz.width, + clientHeight: sz.height, + }; + if (canvasTarget) { + const canvas = canvasTarget as NativeCanvas; + canvas.width = res.width; + canvas.height = res.height; + canvas.clientWidth = res.clientWidth; + canvas.clientHeight = res.clientHeight; + } + return res; + }, + getContext(contextName: "webgpu"): RNCanvasContext | null { + if (contextName !== "webgpu") { + throw new Error(`[WebGPU] Unsupported context: ${contextName}`); + } + if (!viewRef.current) { + throw new Error("[WebGPU] Cannot get context before mount"); + } + + const pixelRatio = PixelRatio.get(); + const sz = getViewSize(viewRef.current); + + return RNWebGPU.MakeWebGPUCanvasContext( + contextId, + sz.width, + sz.height, + pixelRatio, + ); + }, + } satisfies CanvasRef; + }); + const withNativeId = { ...props, nativeID: `webgpu-container-${contextId}` }; return ( - +