diff --git a/apps/computer-vision/app/_layout.tsx b/apps/computer-vision/app/_layout.tsx index aa4c6d29e..ea47ebdb3 100644 --- a/apps/computer-vision/app/_layout.tsx +++ b/apps/computer-vision/app/_layout.tsx @@ -59,6 +59,14 @@ export default function _layout() { headerTitleStyle: { color: ColorPalette.primary }, }} > + null, + title: 'Main Menu', + drawerItemStyle: { display: 'none' }, + }} + /> - null, - title: 'Main Menu', - drawerItemStyle: { display: 'none' }, - }} - /> ); diff --git a/apps/computer-vision/app/classification/index.tsx b/apps/computer-vision/app/classification/index.tsx index 4a1e473d8..d45e0ddbf 100644 --- a/apps/computer-vision/app/classification/index.tsx +++ b/apps/computer-vision/app/classification/index.tsx @@ -2,8 +2,16 @@ import Spinner from '../../components/Spinner'; import { getImage } from '../../utils'; import { useClassification, + EFFICIENTNET_V2_S, EFFICIENTNET_V2_S_QUANTIZED, + ClassificationModelSources, } from 'react-native-executorch'; +import { ModelPicker, ModelOption } from '../../components/ModelPicker'; + +const MODELS: ModelOption[] = [ + { label: 'EfficientNet V2 S Quantized', value: EFFICIENTNET_V2_S_QUANTIZED }, + { label: 'EfficientNet V2 S', value: EFFICIENTNET_V2_S }, +]; import { View, StyleSheet, Image, Text, ScrollView } from 'react-native'; import { BottomBar } from '../../components/BottomBar'; import React, { useContext, useEffect, useState } from 'react'; @@ -13,6 +21,8 @@ import { StatsBar } from '../../components/StatsBar'; import ErrorBanner from '../../components/ErrorBanner'; export default function ClassificationScreen() { + const [selectedModel, setSelectedModel] = + useState(EFFICIENTNET_V2_S_QUANTIZED); const [results, setResults] = useState<{ label: string; score: number }[]>( [] ); @@ -21,7 +31,7 @@ export default function ClassificationScreen() { const [error, setError] = useState(null); - const model = useClassification({ model: EFFICIENTNET_V2_S_QUANTIZED }); + const model = useClassification({ model: selectedModel }); const { setGlobalGenerating } = useContext(GeneratingContext); useEffect(() => { @@ -82,6 +92,16 @@ export default function ClassificationScreen() { : require('../../assets/icons/executorch_logo.png') } /> + {!imageUri && ( + + Image Classification + + This model analyzes an image and returns the top 10 most likely + labels with confidence scores. Use the gallery or camera icons + below to pick an image, then tap the button to run the model. + + + )} {results.length > 0 && ( Results Top 10 @@ -96,10 +116,21 @@ export default function ClassificationScreen() { )} + { + setSelectedModel(m); + setResults([]); + }} + /> ); @@ -141,4 +172,20 @@ const styles = StyleSheet.create({ flex: 1, marginRight: 4, }, + infoContainer: { + alignItems: 'center', + padding: 16, + gap: 8, + }, + infoTitle: { + fontSize: 18, + fontWeight: '600', + color: 'navy', + }, + infoText: { + fontSize: 14, + color: '#555', + textAlign: 'center', + lineHeight: 20, + }, }); diff --git a/apps/computer-vision/app/instance_segmentation/index.tsx b/apps/computer-vision/app/instance_segmentation/index.tsx index f833ffa24..dba53875e 100644 --- a/apps/computer-vision/app/instance_segmentation/index.tsx +++ b/apps/computer-vision/app/instance_segmentation/index.tsx @@ -70,9 +70,19 @@ export default function InstanceSegmentationScreen() { // Set default input size when model is ready useEffect(() => { - if (isReady && availableInputSizes && availableInputSizes.length > 0) { - setSelectedInputSize(availableInputSizes[0]); + if (!isReady) return; + + if (availableInputSizes && availableInputSizes.length > 0) { + setSelectedInputSize((prev) => { + if (typeof prev === 'number' && availableInputSizes.includes(prev)) { + return prev; + } + return availableInputSizes[0]; + }); + return; } + + setSelectedInputSize(null); }, [isReady, availableInputSizes]); const handleCameraPress = async (isCamera: boolean) => { @@ -90,6 +100,13 @@ export default function InstanceSegmentationScreen() { const runForward = async () => { if (!imageUri || imageSize.width === 0 || imageSize.height === 0) return; + const inputSize = + availableInputSizes && + typeof selectedInputSize === 'number' && + availableInputSizes.includes(selectedInputSize) + ? selectedInputSize + : undefined; + try { const start = Date.now(); const output = await forward(imageUri, { @@ -97,7 +114,7 @@ export default function InstanceSegmentationScreen() { iouThreshold: 0.55, maxInstances: 20, returnMaskAtOriginalResolution: true, - inputSize: selectedInputSize ?? undefined, + inputSize, }); setInferenceTime(Date.now() - start); @@ -144,6 +161,16 @@ export default function InstanceSegmentationScreen() { imageWidth={imageSize.width} imageHeight={imageSize.height} /> + {!imageUri && ( + + Instance Segmentation + + This model detects individual objects and draws a precise mask + over each one. Pick an image from your gallery or take one with + your camera to get started. + + + )} {imageUri && availableInputSizes && availableInputSizes.length > 0 && ( @@ -215,6 +242,8 @@ export default function InstanceSegmentationScreen() { ); @@ -318,4 +347,20 @@ const styles = StyleSheet.create({ color: '#999', fontFamily: 'Courier', }, + infoContainer: { + alignItems: 'center', + padding: 16, + gap: 8, + }, + infoTitle: { + fontSize: 18, + fontWeight: '600', + color: 'navy', + }, + infoText: { + fontSize: 14, + color: '#555', + textAlign: 'center', + lineHeight: 20, + }, }); diff --git a/apps/computer-vision/app/object_detection/index.tsx b/apps/computer-vision/app/object_detection/index.tsx index 2d03c34ce..ea4a9fc7b 100644 --- a/apps/computer-vision/app/object_detection/index.tsx +++ b/apps/computer-vision/app/object_detection/index.tsx @@ -14,7 +14,7 @@ import { YOLO26X, ObjectDetectionModelSources, } from 'react-native-executorch'; -import { View, StyleSheet, Image } from 'react-native'; +import { View, StyleSheet, Image, Text } from 'react-native'; import ImageWithBboxes from '../../components/ImageWithBboxes'; import React, { useContext, useEffect, useState } from 'react'; import { GeneratingContext } from '../../context'; @@ -112,6 +112,16 @@ export default function ObjectDetectionScreen() { /> )} + {!imageUri && ( + + Object Detection + + This model detects objects in an image and draws bounding boxes + around them with class labels and confidence scores. Pick an image + from your gallery or take one with your camera to get started. + + + )} ); @@ -149,4 +161,20 @@ const styles = StyleSheet.create({ width: '100%', height: '100%', }, + infoContainer: { + alignItems: 'center', + padding: 16, + gap: 8, + }, + infoTitle: { + fontSize: 18, + fontWeight: '600', + color: 'navy', + }, + infoText: { + fontSize: 14, + color: '#555', + textAlign: 'center', + lineHeight: 20, + }, }); diff --git a/apps/computer-vision/app/ocr/index.tsx b/apps/computer-vision/app/ocr/index.tsx index e46828798..de4abcd40 100644 --- a/apps/computer-vision/app/ocr/index.tsx +++ b/apps/computer-vision/app/ocr/index.tsx @@ -112,6 +112,17 @@ export default function OCRScreen() { /> )} + {!imageUri && ( + + OCR + + This model reads and extracts text from images, returning each + detected text region with its bounding box and confidence score. + Pick an image from your gallery or take one with your camera to + get started. + + + )} {results.length > 0 && ( Results @@ -142,6 +153,8 @@ export default function OCRScreen() { ); @@ -187,4 +200,20 @@ const styles = StyleSheet.create({ flex: 1, marginRight: 4, }, + infoContainer: { + alignItems: 'center', + padding: 16, + gap: 8, + }, + infoTitle: { + fontSize: 18, + fontWeight: '600', + color: 'navy', + }, + infoText: { + fontSize: 14, + color: '#555', + textAlign: 'center', + lineHeight: 20, + }, }); diff --git a/apps/computer-vision/app/ocr_vertical/index.tsx b/apps/computer-vision/app/ocr_vertical/index.tsx index 90d052d8b..8754ee8b5 100644 --- a/apps/computer-vision/app/ocr_vertical/index.tsx +++ b/apps/computer-vision/app/ocr_vertical/index.tsx @@ -92,6 +92,26 @@ export default function VerticalOCRScreen() { /> )} + {!imageUri && ( + + Vertical OCR + + This model reads vertical text (e.g. Japanese, Korean, Chinese + columns) from images, returning each detected text region with its + bounding box and confidence score. Pick an image from your gallery + or take one with your camera to get started. + + + )} + {imageUri && inferenceTime !== null && results.length === 0 && ( + + No text detected + + The model did not find any vertical text in this image. Try an + image containing vertical Japanese, Korean, or Chinese text. + + + )} {results.length > 0 && ( Results @@ -113,6 +133,8 @@ export default function VerticalOCRScreen() { ); @@ -158,4 +180,20 @@ const styles = StyleSheet.create({ flex: 1, marginRight: 4, }, + infoContainer: { + alignItems: 'center', + padding: 16, + gap: 8, + }, + infoTitle: { + fontSize: 18, + fontWeight: '600', + color: 'navy', + }, + infoText: { + fontSize: 14, + color: '#555', + textAlign: 'center', + lineHeight: 20, + }, }); diff --git a/apps/computer-vision/app/semantic_segmentation/index.tsx b/apps/computer-vision/app/semantic_segmentation/index.tsx index 2e743174f..ba998bca2 100644 --- a/apps/computer-vision/app/semantic_segmentation/index.tsx +++ b/apps/computer-vision/app/semantic_segmentation/index.tsx @@ -9,6 +9,7 @@ import { LRASPP_MOBILENET_V3_LARGE_QUANTIZED, FCN_RESNET50_QUANTIZED, FCN_RESNET101_QUANTIZED, + SELFIE_SEGMENTATION, useSemanticSegmentation, SemanticSegmentationModelSources, } from 'react-native-executorch'; @@ -20,7 +21,7 @@ import { ColorType, SkImage, } from '@shopify/react-native-skia'; -import { View, StyleSheet, Image } from 'react-native'; +import { View, StyleSheet, Image, Text } from 'react-native'; import React, { useContext, useEffect, useState } from 'react'; import { GeneratingContext } from '../../context'; import ScreenWrapper from '../../ScreenWrapper'; @@ -61,6 +62,7 @@ const MODELS: ModelOption[] = [ { label: 'LRASPP MobileNet', value: LRASPP_MOBILENET_V3_LARGE_QUANTIZED }, { label: 'FCN ResNet50', value: FCN_RESNET50_QUANTIZED }, { label: 'FCN ResNet101', value: FCN_RESNET101_QUANTIZED }, + { label: 'Selfie Segmentation', value: SELFIE_SEGMENTATION }, ]; export default function SemanticSegmentationScreen() { @@ -70,8 +72,13 @@ export default function SemanticSegmentationScreen() { DEEPLAB_V3_MOBILENET_V3_LARGE_QUANTIZED ); - const { isReady, isGenerating, downloadProgress, forward, error: modelError } = - useSemanticSegmentation({ model: selectedModel }); + const { + isReady, + isGenerating, + downloadProgress, + forward, + error: modelError, + } = useSemanticSegmentation({ model: selectedModel }); const [imageUri, setImageUri] = useState(''); const [imageSize, setImageSize] = useState({ width: 0, height: 0 }); @@ -158,6 +165,16 @@ export default function SemanticSegmentationScreen() { : require('../../assets/icons/executorch_logo.png') } /> + {!imageUri && ( + + Semantic Segmentation + + This model assigns a class label to every pixel in an image, + painting each region with a distinct color. Pick an image from + your gallery or take one with your camera to get started. + + + )} {segImage && ( ); @@ -212,4 +231,20 @@ const styles = StyleSheet.create({ padding: 4, }, canvas: { width: '100%', height: '100%' }, + infoContainer: { + alignItems: 'center', + padding: 16, + gap: 8, + }, + infoTitle: { + fontSize: 18, + fontWeight: '600', + color: 'navy', + }, + infoText: { + fontSize: 14, + color: '#555', + textAlign: 'center', + lineHeight: 20, + }, }); diff --git a/apps/computer-vision/app/style_transfer/index.tsx b/apps/computer-vision/app/style_transfer/index.tsx index d1ac25112..877008706 100644 --- a/apps/computer-vision/app/style_transfer/index.tsx +++ b/apps/computer-vision/app/style_transfer/index.tsx @@ -12,7 +12,8 @@ import { ResourceSource, } from 'react-native-executorch'; -import { View, StyleSheet, Image } from 'react-native'; +import { View, StyleSheet, Image, Text } from 'react-native'; + import React, { useContext, useEffect, useState } from 'react'; import { GeneratingContext } from '../../context'; import ScreenWrapper from '../../ScreenWrapper'; @@ -98,6 +99,16 @@ export default function StyleTransferScreen() { : require('../../assets/icons/executorch_logo.png') } /> + {!imageUri && ( + + Style Transfer + + This model applies artistic styles to your images, transforming + them to look like famous paintings. Pick an image from your + gallery or take one with your camera to get started. + + + )} ); @@ -120,4 +133,20 @@ export default function StyleTransferScreen() { const styles = StyleSheet.create({ imageContainer: { flex: 6, width: '100%', padding: 16 }, image: { flex: 1, borderRadius: 8, width: '100%' }, + infoContainer: { + alignItems: 'center', + padding: 16, + gap: 8, + }, + infoTitle: { + fontSize: 18, + fontWeight: '600', + color: 'navy', + }, + infoText: { + fontSize: 14, + color: '#555', + textAlign: 'center', + lineHeight: 20, + }, }); diff --git a/apps/computer-vision/app/text_to_image/index.tsx b/apps/computer-vision/app/text_to_image/index.tsx index f8aa179c7..6af71227e 100644 --- a/apps/computer-vision/app/text_to_image/index.tsx +++ b/apps/computer-vision/app/text_to_image/index.tsx @@ -133,16 +133,21 @@ export default function TextToImageScreen() { Generating... - ) : ( + ) : image?.length ? ( + ) : ( + + Text to Image + + This model generates images from text descriptions using a + diffusion process. Type a prompt below and tap the send button + to generate an image. + + )} @@ -194,7 +199,10 @@ export default function TextToImageScreen() { ) : ( @@ -296,4 +304,23 @@ const styles = StyleSheet.create({ alignItems: 'center', justifyContent: 'center', }, + sendButtonDisabled: { + backgroundColor: '#888', + }, + infoContainer: { + alignItems: 'center', + padding: 16, + gap: 8, + }, + infoTitle: { + fontSize: 18, + fontWeight: '600', + color: 'navy', + }, + infoText: { + fontSize: 14, + color: '#555', + textAlign: 'center', + lineHeight: 20, + }, }); diff --git a/apps/computer-vision/app/vision_camera/index.tsx b/apps/computer-vision/app/vision_camera/index.tsx index b7e4de008..dd14be000 100644 --- a/apps/computer-vision/app/vision_camera/index.tsx +++ b/apps/computer-vision/app/vision_camera/index.tsx @@ -25,6 +25,8 @@ import { } from 'react-native-vision-camera'; import { createSynchronizable } from 'react-native-worklets'; import Svg, { Path, Polygon } from 'react-native-svg'; +import { useRouter } from 'expo-router'; +import { Ionicons } from '@expo/vector-icons'; import { GeneratingContext } from '../../context'; import Spinner from '../../components/Spinner'; import ColorPalette from '../../colors'; @@ -121,6 +123,7 @@ const cameraPositionSync = createSynchronizable<'front' | 'back'>('back'); export default function VisionCameraScreen() { const insets = useSafeAreaInsets(); + const router = useRouter(); const [activeTask, setActiveTask] = useState('classification'); const [activeModel, setActiveModel] = useState('classification'); const [canvasSize, setCanvasSize] = useState({ width: 1, height: 1 }); @@ -310,6 +313,13 @@ export default function VisionCameraScreen() { style={[styles.topOverlay, { paddingTop: insets.top + 8 }]} pointerEvents="box-none" > + router.navigate('/')} + > + + + {activeVariantLabel} @@ -525,4 +535,17 @@ const styles = StyleSheet.create({ borderWidth: 1.5, borderColor: 'rgba(255,255,255,0.4)', }, + backButton: { + position: 'absolute', + left: 12, + width: 40, + height: 40, + borderRadius: 20, + backgroundColor: 'rgba(0,0,0,0.45)', + justifyContent: 'center', + alignItems: 'center', + borderWidth: 1, + borderColor: 'rgba(255,255,255,0.25)', + zIndex: 10, + }, }); diff --git a/apps/computer-vision/components/BottomBar.tsx b/apps/computer-vision/components/BottomBar.tsx index c6a6e33fe..b77fdce4c 100644 --- a/apps/computer-vision/components/BottomBar.tsx +++ b/apps/computer-vision/components/BottomBar.tsx @@ -7,11 +7,16 @@ import { useSafeAreaInsets } from 'react-native-safe-area-context'; export const BottomBar = ({ handleCameraPress, runForward, + hasImage = true, + isGenerating = false, }: { handleCameraPress: (isCamera: boolean) => void; runForward: () => void; + hasImage?: boolean; + isGenerating?: boolean; }) => { const { bottom } = useSafeAreaInsets(); + const disabled = !hasImage || isGenerating; return ( @@ -31,8 +36,18 @@ export const BottomBar = ({ /> - - Run model + + + {isGenerating + ? 'Running...' + : hasImage + ? 'Run model' + : 'Pick an image to run the model'} + ); @@ -60,6 +75,9 @@ const styles = StyleSheet.create({ color: '#fff', borderRadius: 8, }, + buttonDisabled: { + backgroundColor: '#888', + }, buttonText: { color: '#fff', fontSize: 16, diff --git a/apps/computer-vision/components/ModelPicker.tsx b/apps/computer-vision/components/ModelPicker.tsx index 9c2deab13..94a848596 100644 --- a/apps/computer-vision/components/ModelPicker.tsx +++ b/apps/computer-vision/components/ModelPicker.tsx @@ -1,6 +1,7 @@ import React, { useEffect, useRef, useState } from 'react'; import { Dimensions, + Modal, ScrollView, StyleSheet, Text, @@ -33,6 +34,7 @@ export function ModelPicker({ const [open, setOpen] = useState(false); const [triggerHeight, setTriggerHeight] = useState(0); const [expandUp, setExpandUp] = useState(false); + const [dropdownTop, setDropdownTop] = useState(0); const triggerRef = useRef>(null); const selected = models.find((m) => m.value === selectedModel); @@ -40,7 +42,12 @@ export function ModelPicker({ if (disabled) setOpen(false); }, [disabled]); - const handleLayout = () => { + const handlePress = () => { + if (disabled) return; + if (open) { + setOpen(false); + return; + } triggerRef.current?.measure( ( _x: number, @@ -53,59 +60,85 @@ export function ModelPicker({ setTriggerHeight(height); const spaceBelow = Dimensions.get('window').height - (pageY + height); setExpandUp(spaceBelow < DROPDOWN_MAX_HEIGHT); + setDropdownTop(pageY); + setOpen(true); } ); }; - const dropdownPosition = expandUp - ? { bottom: triggerHeight + 2 } - : { top: triggerHeight + 2 }; + const dropdownStylePosition = expandUp + ? { + bottom: Dimensions.get('window').height - dropdownTop, + left: 12, + right: 12, + } + : { + top: dropdownTop + triggerHeight + 2, + left: 12, + right: 12, + }; return ( - - !disabled && setOpen((v) => !v)} - activeOpacity={disabled ? 1 : 0.7} - onLayout={handleLayout} - > - {label && {label}} - {selected?.label ?? '—'} - {open ? '▲' : '▼'} - + <> + + + {label && {label}} + {selected?.label ?? '—'} + {open ? '▲' : '▼'} + + {open && ( - setOpen(false)} + animationType="none" > - {models.map((item) => { - const isSelected = item.value === selectedModel; - return ( - { - onSelect(item.value); - setOpen(false); - }} - > - - {item.label} - - - ); - })} - + setOpen(false)} + /> + + + {models.map((item) => { + const isSelected = item.value === selectedModel; + return ( + { + onSelect(item.value); + setOpen(false); + }} + activeOpacity={0.7} + > + + {item.label} + + + ); + })} + + + )} - + ); } @@ -145,21 +178,23 @@ const styles = StyleSheet.create({ color: '#888', marginLeft: 6, }, + modalBackdrop: { + flex: 1, + backgroundColor: 'rgba(0, 0, 0, 0.3)', + }, dropdown: { position: 'absolute', - left: 0, - right: 0, borderWidth: 1, borderColor: '#C1C6E5', borderRadius: 8, backgroundColor: '#fff', maxHeight: DROPDOWN_MAX_HEIGHT, - zIndex: 100, - elevation: 4, + zIndex: 1000, + elevation: 5, shadowColor: '#000', - shadowOffset: { width: 0, height: 2 }, - shadowOpacity: 0.1, - shadowRadius: 4, + shadowOffset: { width: 0, height: 4 }, + shadowOpacity: 0.15, + shadowRadius: 6, }, option: { paddingHorizontal: 12, diff --git a/apps/llm/app/_layout.tsx b/apps/llm/app/_layout.tsx index f04edee9a..f2f49b534 100644 --- a/apps/llm/app/_layout.tsx +++ b/apps/llm/app/_layout.tsx @@ -61,6 +61,14 @@ export default function _layout() { ), }} > + null, + title: 'Main Menu', + drawerItemStyle: { display: 'none' }, + }} + /> - null, - title: 'Main Menu', - drawerItemStyle: { display: 'none' }, - }} - /> ); diff --git a/apps/llm/app/llm/index.tsx b/apps/llm/app/llm/index.tsx index 1e1ce1c54..901b74de4 100644 --- a/apps/llm/app/llm/index.tsx +++ b/apps/llm/app/llm/index.tsx @@ -84,9 +84,7 @@ function LLMScreen() { ) : ( {/* Image picker button */} Hello! 👋 - What can I help you with? + Tap the mic and speak to me. I'll transcribe your voice and + respond using a language model — all on-device. )} @@ -196,7 +197,15 @@ function VoiceChatScreen() { onSelect={(m) => setSelectedSTT(m)} /> - + {DeviceInfo.isEmulatorSync() ? ( @@ -267,7 +276,7 @@ const styles = StyleSheet.create({ color: ColorPalette.primary, }, bottomContainer: { - minHeight: 100, + height: 100, width: '100%', justifyContent: 'center', alignItems: 'center', diff --git a/apps/llm/components/ModelPicker.tsx b/apps/llm/components/ModelPicker.tsx index 9c2deab13..94a848596 100644 --- a/apps/llm/components/ModelPicker.tsx +++ b/apps/llm/components/ModelPicker.tsx @@ -1,6 +1,7 @@ import React, { useEffect, useRef, useState } from 'react'; import { Dimensions, + Modal, ScrollView, StyleSheet, Text, @@ -33,6 +34,7 @@ export function ModelPicker({ const [open, setOpen] = useState(false); const [triggerHeight, setTriggerHeight] = useState(0); const [expandUp, setExpandUp] = useState(false); + const [dropdownTop, setDropdownTop] = useState(0); const triggerRef = useRef>(null); const selected = models.find((m) => m.value === selectedModel); @@ -40,7 +42,12 @@ export function ModelPicker({ if (disabled) setOpen(false); }, [disabled]); - const handleLayout = () => { + const handlePress = () => { + if (disabled) return; + if (open) { + setOpen(false); + return; + } triggerRef.current?.measure( ( _x: number, @@ -53,59 +60,85 @@ export function ModelPicker({ setTriggerHeight(height); const spaceBelow = Dimensions.get('window').height - (pageY + height); setExpandUp(spaceBelow < DROPDOWN_MAX_HEIGHT); + setDropdownTop(pageY); + setOpen(true); } ); }; - const dropdownPosition = expandUp - ? { bottom: triggerHeight + 2 } - : { top: triggerHeight + 2 }; + const dropdownStylePosition = expandUp + ? { + bottom: Dimensions.get('window').height - dropdownTop, + left: 12, + right: 12, + } + : { + top: dropdownTop + triggerHeight + 2, + left: 12, + right: 12, + }; return ( - - !disabled && setOpen((v) => !v)} - activeOpacity={disabled ? 1 : 0.7} - onLayout={handleLayout} - > - {label && {label}} - {selected?.label ?? '—'} - {open ? '▲' : '▼'} - + <> + + + {label && {label}} + {selected?.label ?? '—'} + {open ? '▲' : '▼'} + + {open && ( - setOpen(false)} + animationType="none" > - {models.map((item) => { - const isSelected = item.value === selectedModel; - return ( - { - onSelect(item.value); - setOpen(false); - }} - > - - {item.label} - - - ); - })} - + setOpen(false)} + /> + + + {models.map((item) => { + const isSelected = item.value === selectedModel; + return ( + { + onSelect(item.value); + setOpen(false); + }} + activeOpacity={0.7} + > + + {item.label} + + + ); + })} + + + )} - + ); } @@ -145,21 +178,23 @@ const styles = StyleSheet.create({ color: '#888', marginLeft: 6, }, + modalBackdrop: { + flex: 1, + backgroundColor: 'rgba(0, 0, 0, 0.3)', + }, dropdown: { position: 'absolute', - left: 0, - right: 0, borderWidth: 1, borderColor: '#C1C6E5', borderRadius: 8, backgroundColor: '#fff', maxHeight: DROPDOWN_MAX_HEIGHT, - zIndex: 100, - elevation: 4, + zIndex: 1000, + elevation: 5, shadowColor: '#000', - shadowOffset: { width: 0, height: 2 }, - shadowOpacity: 0.1, - shadowRadius: 4, + shadowOffset: { width: 0, height: 4 }, + shadowOpacity: 0.15, + shadowRadius: 6, }, option: { paddingHorizontal: 12, diff --git a/apps/speech/components/ModelPicker.tsx b/apps/speech/components/ModelPicker.tsx index 9c2deab13..5e8284ee9 100644 --- a/apps/speech/components/ModelPicker.tsx +++ b/apps/speech/components/ModelPicker.tsx @@ -40,7 +40,12 @@ export function ModelPicker({ if (disabled) setOpen(false); }, [disabled]); - const handleLayout = () => { + const handlePress = () => { + if (disabled) return; + if (open) { + setOpen(false); + return; + } triggerRef.current?.measure( ( _x: number, @@ -53,6 +58,7 @@ export function ModelPicker({ setTriggerHeight(height); const spaceBelow = Dimensions.get('window').height - (pageY + height); setExpandUp(spaceBelow < DROPDOWN_MAX_HEIGHT); + setOpen(true); } ); }; @@ -66,9 +72,8 @@ export function ModelPicker({ !disabled && setOpen((v) => !v)} + onPress={handlePress} activeOpacity={disabled ? 1 : 0.7} - onLayout={handleLayout} > {label && {label}} {selected?.label ?? '—'} diff --git a/apps/text-embeddings/app/_layout.tsx b/apps/text-embeddings/app/_layout.tsx index 24952df29..886f158b3 100644 --- a/apps/text-embeddings/app/_layout.tsx +++ b/apps/text-embeddings/app/_layout.tsx @@ -58,6 +58,14 @@ export default function _layout() { headerTitleStyle: { color: ColorPalette.primary }, }} > + null, + title: 'Main Menu', + drawerItemStyle: { display: 'none' }, + }} + /> - null, - title: 'Main Menu', - drawerItemStyle: { display: 'none' }, - }} - /> ); diff --git a/apps/text-embeddings/app/clip-embeddings/index.tsx b/apps/text-embeddings/app/clip-embeddings/index.tsx index e9831c6be..c88220eb4 100644 --- a/apps/text-embeddings/app/clip-embeddings/index.tsx +++ b/apps/text-embeddings/app/clip-embeddings/index.tsx @@ -1,4 +1,4 @@ -import { useEffect, useState } from 'react'; +import { useState } from 'react'; import { StyleSheet, Text, @@ -7,6 +7,7 @@ import { View, SafeAreaView, ScrollView, + Image, KeyboardAvoidingView, Platform, } from 'react-native'; @@ -15,11 +16,29 @@ import { useTextEmbeddings, useImageEmbeddings, CLIP_VIT_BASE_PATCH32_TEXT, + CLIP_VIT_BASE_PATCH32_IMAGE, CLIP_VIT_BASE_PATCH32_IMAGE_QUANTIZED, + ImageEmbeddingsProps, } from 'react-native-executorch'; + +type ImageEmbeddingModel = ImageEmbeddingsProps['model']; + +const IMAGE_MODELS: { label: string; value: ImageEmbeddingModel }[] = [ + { label: 'ViT-B/32 Quantized', value: CLIP_VIT_BASE_PATCH32_IMAGE_QUANTIZED }, + { label: 'ViT-B/32 FP32', value: CLIP_VIT_BASE_PATCH32_IMAGE }, +]; import { launchImageLibrary } from 'react-native-image-picker'; import { useIsFocused } from '@react-navigation/native'; import { dotProduct } from '../../utils/math'; +import { ModelPicker } from '../../components/ModelPicker'; + +const DEFAULT_LABELS = [ + 'a photo of a dog', + 'a photo of a cat', + 'a landscape photo', + 'a photo of food', + 'a photo of people', +]; export default function ClipEmbeddingsScreenWrapper() { const isFocused = useIsFocused(); @@ -28,283 +47,228 @@ export default function ClipEmbeddingsScreenWrapper() { } function ClipEmbeddingsScreen() { + const [selectedImageModel, setSelectedImageModel] = + useState(CLIP_VIT_BASE_PATCH32_IMAGE_QUANTIZED); + const textModel = useTextEmbeddings({ model: CLIP_VIT_BASE_PATCH32_TEXT }); - const imageModel = useImageEmbeddings({ - model: CLIP_VIT_BASE_PATCH32_IMAGE_QUANTIZED, - }); + const imageModel = useImageEmbeddings({ model: selectedImageModel }); - const [inputSentence, setInputSentence] = useState(''); - const [sentencesWithEmbeddings, setSentencesWithEmbeddings] = useState< - { sentence: string; embedding: Float32Array }[] + const [imageUri, setImageUri] = useState(null); + const [newLabel, setNewLabel] = useState(''); + const [labels, setLabels] = useState(DEFAULT_LABELS); + const [results, setResults] = useState< + { label: string; similarity: number }[] >([]); - const [topMatches, setTopMatches] = useState< - { sentence: string; similarity: number }[] - >([]); - - const [textEmbeddingTime, setTextEmbeddingTime] = useState( - null - ); const [imageEmbeddingTime, setImageEmbeddingTime] = useState( null ); - - useEffect( - () => { - const computeEmbeddings = async () => { - if (!textModel.isReady) return; - - const sentences = [ - 'The weather is lovely today.', - 'Night party pictures', - 'Cute animals.', - 'Bike club photos', - ]; - - try { - const start = Date.now(); - const embeddings = []; - - for (const sentence of sentences) { - const embedding = await textModel.forward(sentence); - embeddings.push({ sentence, embedding }); - } - - setTextEmbeddingTime(Date.now() - start); - setSentencesWithEmbeddings(embeddings); - } catch (error) { - console.error('Error generating embeddings:', error); - } - }; - - computeEmbeddings(); - }, - // eslint-disable-next-line react-hooks/exhaustive-deps - [textModel.isReady] + const [textEmbeddingTime, setTextEmbeddingTime] = useState( + null ); - const checkSimilarities = async () => { - if (!textModel.isReady || !inputSentence.trim()) return; - - try { - const start = Date.now(); - const inputEmbedding = await textModel.forward(inputSentence); - setTextEmbeddingTime(Date.now() - start); + const getModelStatusText = (model: typeof textModel | typeof imageModel) => { + if (model.error) return `Oops! ${model.error}`; + if (!model.isReady) + return `Loading ${(model.downloadProgress * 100).toFixed(0)}%`; + return model.isGenerating ? 'Generating…' : 'Ready'; + }; - const matches = sentencesWithEmbeddings.map( - ({ sentence, embedding }) => ({ - sentence, - similarity: dotProduct(inputEmbedding, embedding), - }) - ); - matches.sort((a, b) => b.similarity - a.similarity); - setTopMatches(matches.slice(0, 3)); - } catch (error) { - console.error('Error generating embedding:', error); - } + const pickImage = async () => { + const output = await launchImageLibrary({ mediaType: 'photo' }); + if (!output.assets?.[0]?.uri) return; + setImageUri(output.assets[0].uri); + setResults([]); }; - const addToSentences = async () => { - if (!textModel.isReady || !inputSentence.trim()) return; + const classify = async () => { + if (!imageUri || !imageModel.isReady || !textModel.isReady) return; try { - const start = Date.now(); - const embedding = await textModel.forward(inputSentence); - setTextEmbeddingTime(Date.now() - start); + const imgStart = Date.now(); + const imageEmbedding = await imageModel.forward(imageUri); + setImageEmbeddingTime(Date.now() - imgStart); - setSentencesWithEmbeddings((prev) => [ - ...prev, - { sentence: inputSentence, embedding }, - ]); - } catch (error) { - console.error('Error generating embedding:', error); - } + const txtStart = Date.now(); + const scored: { label: string; similarity: number }[] = []; + for (const label of labels) { + const textEmbedding = await textModel.forward(label); + scored.push({ + label, + similarity: dotProduct(imageEmbedding, textEmbedding), + }); + } + setTextEmbeddingTime(Math.round((Date.now() - txtStart) / labels.length)); - setInputSentence(''); - setTopMatches([]); - }; - - const clearList = async () => { - if (!textModel.isReady) return; - try { - setSentencesWithEmbeddings([]); - } catch (error) { - console.error('Error clearing the list:', error); + scored.sort((a, b) => b.similarity - a.similarity); + setResults(scored); + } catch (e) { + console.error('Error during classification:', e); } }; - const checkImage = async () => { - if (!imageModel.isReady) return; - - const output = await launchImageLibrary({ mediaType: 'photo' }); - - if (!output.assets || output.assets.length === 0 || !output.assets[0].uri) - return; - - try { - const start = Date.now(); - const inputImageEmbedding = await imageModel.forward( - output.assets[0].uri - ); - setImageEmbeddingTime(Date.now() - start); - - const matches = sentencesWithEmbeddings.map( - ({ sentence, embedding }) => ({ - sentence, - similarity: dotProduct(inputImageEmbedding, embedding), - }) - ); - matches.sort((a, b) => b.similarity - a.similarity); - setTopMatches(matches.slice(0, 3)); - } catch (error) { - console.error('Error generating embedding:', error); - } + const addLabel = () => { + const trimmed = newLabel.trim(); + if (!trimmed || labels.includes(trimmed)) return; + setLabels((prev) => [...prev, trimmed]); + setNewLabel(''); + setResults([]); }; - const getModelStatusText = (model: typeof textModel | typeof imageModel) => { - if (model.error) { - return `Oops! ${model.error}`; - } - if (!model.isReady) { - return `Loading model ${(model.downloadProgress * 100).toFixed(2)}%`; - } - return model.isGenerating ? 'Generating...' : 'Model is ready'; + const removeLabel = (label: string) => { + setLabels((prev) => prev.filter((l) => l !== label)); + setResults((prev) => prev.filter((r) => r.label !== label)); }; + const modelsReady = textModel.isReady && imageModel.isReady; + return ( - - Text Embeddings Playground - - Text Model: {getModelStatusText(textModel)} - - - Image Model: {getModelStatusText(imageModel)} - - - List of Existing Sentences - {sentencesWithEmbeddings.map((item, index) => ( - - - {item.sentence} - - ))} + + CLIP Image Embeddings + + + + Text model: {getModelStatusText(textModel)} + + + Image model: {getModelStatusText(imageModel)} + - - Try Your Sentence - - - - - - Find Similar + + { + setSelectedImageModel(m); + setResults([]); + }} + /> + + {/* Image picker */} + + {imageUri ? ( + + ) : ( + + + + Tap to pick an image - - - - + + )} + + + {/* Classify button */} + + + + {!imageUri ? 'Pick an image first' : 'Find best matching label'} + + + + {/* Results */} + {results.length > 0 && ( + + Results + {results.map((item, index) => ( + - Add to List + {index === 0 ? '🥇 ' : ''} + {item.label} - - - - Compare sentences to image + + {item.similarity.toFixed(3)} - - + + ))} + {(imageEmbeddingTime !== null || textEmbeddingTime !== null) && ( + + {imageEmbeddingTime !== null && ( + + Image embedding: {imageEmbeddingTime} ms + + )} + {textEmbeddingTime !== null && ( + + Text embeddings: {textEmbeddingTime} ms + + )} + + )} + + )} + + {/* Labels */} + + Text Labels + {labels.map((label) => ( + + {label} + removeLabel(label)}> - - Clear List - + ))} + + + + + - - {textEmbeddingTime !== null && ( - - Text Embedding time: {textEmbeddingTime} ms - - )} - {imageEmbeddingTime !== null && ( - - Image Embedding time: {imageEmbeddingTime} ms - - )} - - {topMatches.length > 0 && ( - - Top Matches - {topMatches.map((item, index) => ( - - {item.sentence} ({item.similarity.toFixed(2)}) - - ))} - - )} @@ -314,13 +278,38 @@ function ClipEmbeddingsScreen() { const styles = StyleSheet.create({ container: { flex: 1, backgroundColor: '#F8FAFC' }, + flex: { flex: 1 }, scrollContainer: { padding: 20, alignItems: 'center', flexGrow: 1 }, heading: { fontSize: 24, fontWeight: '500', - marginBottom: 20, + marginBottom: 12, color: '#0F172A', }, + statusRow: { + width: '100%', + marginBottom: 16, + gap: 2, + }, + statusText: { fontSize: 13, color: '#64748B' }, + imagePicker: { + width: '100%', + height: 220, + borderRadius: 16, + overflow: 'hidden', + marginBottom: 20, + borderWidth: 2, + borderColor: '#E2E8F0', + backgroundColor: '#fff', + }, + image: { width: '100%', height: '100%' }, + imagePlaceholder: { + flex: 1, + alignItems: 'center', + justifyContent: 'center', + gap: 8, + }, + imagePlaceholderText: { fontSize: 14, color: '#94A3B8' }, card: { backgroundColor: '#FFFFFF', width: '100%', @@ -331,58 +320,67 @@ const styles = StyleSheet.create({ marginBottom: 20, }, sectionTitle: { - fontSize: 18, - fontWeight: '500', + fontSize: 16, + fontWeight: '600', marginBottom: 12, color: '#1E293B', }, - sentenceText: { fontSize: 14, marginBottom: 6, color: '#334155' }, + labelRow: { + flexDirection: 'row', + alignItems: 'center', + justifyContent: 'space-between', + paddingVertical: 6, + borderBottomWidth: 1, + borderBottomColor: '#F1F5F9', + }, + labelText: { fontSize: 14, color: '#334155', flex: 1 }, + addLabelRow: { + flexDirection: 'row', + alignItems: 'center', + gap: 8, + marginTop: 12, + }, input: { + flex: 1, backgroundColor: '#F1F5F9', borderRadius: 10, padding: 10, - marginBottom: 10, - fontSize: 16, + fontSize: 14, color: '#0F172A', - minHeight: 40, - textAlignVertical: 'top', - }, - buttonContainer: { width: '100%', gap: 10 }, - buttonGroup: { - flexDirection: 'row', - justifyContent: 'space-between', - gap: 10, }, - buttonPrimary: { - flex: 1, + addButton: { backgroundColor: 'navy', - padding: 12, borderRadius: 10, - flexDirection: 'row', + width: 40, + height: 40, alignItems: 'center', justifyContent: 'center', }, - buttonSecondary: { - flex: 1, - backgroundColor: 'transparent', - borderWidth: 2, - borderColor: 'navy', - padding: 12, - borderRadius: 10, + classifyButton: { + width: '100%', + backgroundColor: 'navy', + padding: 14, + borderRadius: 12, flexDirection: 'row', alignItems: 'center', justifyContent: 'center', + gap: 8, + marginBottom: 20, }, + classifyButtonText: { color: 'white', fontWeight: '600', fontSize: 15 }, buttonDisabled: { backgroundColor: '#f0f0f0', borderColor: '#d3d3d3' }, - buttonText: { color: 'white', textAlign: 'center', fontWeight: '500' }, - buttonTextOutline: { color: 'navy', textAlign: 'center', fontWeight: '500' }, buttonTextDisabled: { color: 'gray' }, - topMatchesContainer: { marginTop: 20 }, - statsText: { - fontSize: 13, - color: '#64748B', - marginTop: 8, - textAlign: 'center', + resultRow: { + flexDirection: 'row', + justifyContent: 'space-between', + alignItems: 'center', + paddingVertical: 6, + borderBottomWidth: 1, + borderBottomColor: '#F1F5F9', }, - flexContainer: { flex: 1 }, + resultLabel: { fontSize: 14, color: '#334155', flex: 1 }, + topResultLabel: { fontWeight: '700', color: '#0F172A' }, + resultScore: { fontSize: 13, color: '#64748B', marginLeft: 8 }, + statsContainer: { marginTop: 12, gap: 2 }, + statsText: { fontSize: 12, color: '#94A3B8' }, }); diff --git a/apps/text-embeddings/app/text-embeddings/index.tsx b/apps/text-embeddings/app/text-embeddings/index.tsx index de0634964..e31097940 100644 --- a/apps/text-embeddings/app/text-embeddings/index.tsx +++ b/apps/text-embeddings/app/text-embeddings/index.tsx @@ -11,7 +11,24 @@ import { Platform, } from 'react-native'; import { Ionicons } from '@expo/vector-icons'; -import { useTextEmbeddings, ALL_MINILM_L6_V2 } from 'react-native-executorch'; +import { ModelPicker } from '../../components/ModelPicker'; +import { + useTextEmbeddings, + ALL_MINILM_L6_V2, + ALL_MPNET_BASE_V2, + MULTI_QA_MINILM_L6_COS_V1, + MULTI_QA_MPNET_BASE_DOT_V1, + TextEmbeddingsProps, +} from 'react-native-executorch'; + +type TextEmbeddingModel = TextEmbeddingsProps['model']; + +const MODELS: { label: string; value: TextEmbeddingModel }[] = [ + { label: 'MiniLM L6', value: ALL_MINILM_L6_V2 }, + { label: 'MPNet Base', value: ALL_MPNET_BASE_V2 }, + { label: 'MultiQA MiniLM', value: MULTI_QA_MINILM_L6_COS_V1 }, + { label: 'MultiQA MPNet', value: MULTI_QA_MPNET_BASE_DOT_V1 }, +]; import { useIsFocused } from '@react-navigation/native'; import { dotProduct } from '../../utils/math'; import ErrorBanner from '../../components/ErrorBanner'; @@ -23,7 +40,9 @@ export default function TextEmbeddingsScreenWrapper() { } function TextEmbeddingsScreen() { - const model = useTextEmbeddings({ model: ALL_MINILM_L6_V2 }); + const [selectedModel, setSelectedModel] = + useState(ALL_MINILM_L6_V2); + const model = useTextEmbeddings({ model: selectedModel }); const [error, setError] = useState(null); const [inputSentence, setInputSentence] = useState(''); @@ -132,6 +151,15 @@ function TextEmbeddingsScreen() { Text Embeddings Playground {getModelStatusText()} + { + setSelectedModel(m); + setSentencesWithEmbeddings([]); + setTopMatches([]); + }} + /> setError(null)} /> diff --git a/apps/text-embeddings/components/ModelPicker.tsx b/apps/text-embeddings/components/ModelPicker.tsx new file mode 100644 index 000000000..94a848596 --- /dev/null +++ b/apps/text-embeddings/components/ModelPicker.tsx @@ -0,0 +1,216 @@ +import React, { useEffect, useRef, useState } from 'react'; +import { + Dimensions, + Modal, + ScrollView, + StyleSheet, + Text, + TouchableOpacity, + View, +} from 'react-native'; + +export type ModelOption = { + label: string; + value: T; +}; + +type Props = { + models: ModelOption[]; + selectedModel: T; + onSelect: (model: T) => void; + label?: string; + disabled?: boolean; +}; + +const DROPDOWN_MAX_HEIGHT = 200; + +export function ModelPicker({ + models, + selectedModel, + onSelect, + label, + disabled, +}: Props) { + const [open, setOpen] = useState(false); + const [triggerHeight, setTriggerHeight] = useState(0); + const [expandUp, setExpandUp] = useState(false); + const [dropdownTop, setDropdownTop] = useState(0); + const triggerRef = useRef>(null); + const selected = models.find((m) => m.value === selectedModel); + + useEffect(() => { + if (disabled) setOpen(false); + }, [disabled]); + + const handlePress = () => { + if (disabled) return; + if (open) { + setOpen(false); + return; + } + triggerRef.current?.measure( + ( + _x: number, + _y: number, + _width: number, + height: number, + _pageX: number, + pageY: number + ) => { + setTriggerHeight(height); + const spaceBelow = Dimensions.get('window').height - (pageY + height); + setExpandUp(spaceBelow < DROPDOWN_MAX_HEIGHT); + setDropdownTop(pageY); + setOpen(true); + } + ); + }; + + const dropdownStylePosition = expandUp + ? { + bottom: Dimensions.get('window').height - dropdownTop, + left: 12, + right: 12, + } + : { + top: dropdownTop + triggerHeight + 2, + left: 12, + right: 12, + }; + + return ( + <> + + + {label && {label}} + {selected?.label ?? '—'} + {open ? '▲' : '▼'} + + + + {open && ( + setOpen(false)} + animationType="none" + > + setOpen(false)} + /> + + + {models.map((item) => { + const isSelected = item.value === selectedModel; + return ( + { + onSelect(item.value); + setOpen(false); + }} + activeOpacity={0.7} + > + + {item.label} + + + ); + })} + + + + )} + + ); +} + +const styles = StyleSheet.create({ + container: { + marginHorizontal: 12, + marginVertical: 4, + alignSelf: 'stretch', + zIndex: 100, + }, + trigger: { + flexDirection: 'row', + alignItems: 'center', + borderWidth: 1, + borderColor: '#C1C6E5', + borderRadius: 8, + paddingHorizontal: 12, + paddingVertical: 10, + backgroundColor: '#f5f5f5', + }, + triggerDisabled: { + opacity: 0.4, + }, + label: { + fontSize: 12, + color: '#888', + marginRight: 6, + }, + triggerText: { + flex: 1, + fontSize: 14, + color: '#001A72', + fontWeight: '500', + }, + chevron: { + fontSize: 10, + color: '#888', + marginLeft: 6, + }, + modalBackdrop: { + flex: 1, + backgroundColor: 'rgba(0, 0, 0, 0.3)', + }, + dropdown: { + position: 'absolute', + borderWidth: 1, + borderColor: '#C1C6E5', + borderRadius: 8, + backgroundColor: '#fff', + maxHeight: DROPDOWN_MAX_HEIGHT, + zIndex: 1000, + elevation: 5, + shadowColor: '#000', + shadowOffset: { width: 0, height: 4 }, + shadowOpacity: 0.15, + shadowRadius: 6, + }, + option: { + paddingHorizontal: 12, + paddingVertical: 10, + borderBottomWidth: 1, + borderBottomColor: '#f0f0f0', + }, + optionSelected: { + backgroundColor: '#e8ecf8', + }, + optionText: { + fontSize: 14, + color: '#333', + }, + optionTextSelected: { + color: '#001A72', + fontWeight: '600', + }, +});