diff --git a/runware/base.py b/runware/base.py index 481b273..bd5d0ad 100644 --- a/runware/base.py +++ b/runware/base.py @@ -17,6 +17,8 @@ from .reconnection import ConnectionState, ReconnectionManager from .types import ( Environment, + IInputReference, + IInputs, IImageInference, IPhotoMaker, IImageCaption, @@ -49,6 +51,7 @@ IAudio, IAudioInference, IFrameImage, + IVideoInputs, IAsyncTaskResponse, IVectorize, OperationState, @@ -110,6 +113,21 @@ class RunwareBase: + async def _process_media_list( + self, + items: List[Any], + object_attr: Optional[str] = None, + ) -> List[Any]: + + processed: List[Any] = [] + for item in items: + if object_attr and hasattr(item, object_attr): + setattr(item, object_attr, await process_image(getattr(item, object_attr))) + processed.append(item) + else: + processed.append(await process_image(item)) + return processed + def __init__( self, api_key: str, @@ -704,6 +722,16 @@ async def _imageInference( if requestImage.referenceImages: requestImage.referenceImages = await process_image(requestImage.referenceImages) + if requestImage.inputs: + if isinstance(requestImage.inputs, dict): + requestImage.inputs = IInputs(**requestImage.inputs) + + if requestImage.inputs.referenceImages: + requestImage.inputs.referenceImages = await self._process_media_list( + requestImage.inputs.referenceImages, + object_attr="image", + ) + if requestImage.controlNet: for control_data in requestImage.controlNet: image_uploaded = await self.uploadImage(control_data.guideImage) @@ -2050,7 +2078,41 @@ async def _getResponse( ) async def _requestVideo(self, requestVideo: "IVideoInference") -> "Union[List[IVideo], IAsyncTaskResponse]": - await self._processVideoImages(requestVideo) + if requestVideo.frameImages: + requestVideo.frameImages = await self._process_media_list( + requestVideo.frameImages, + object_attr="inputImage", + ) + + if requestVideo.referenceImages: + requestVideo.referenceImages = await self._process_media_list( + requestVideo.referenceImages, + ) + + if requestVideo.inputs: + inputs = requestVideo.inputs + if isinstance(inputs, dict): + inputs = IVideoInputs(**inputs) + requestVideo.inputs = inputs + + if inputs.image: + inputs.image = await process_image(inputs.image) + + if inputs.images: + inputs.images = await self._process_media_list(inputs.images) + + if inputs.mask: + inputs.mask = await process_image(inputs.mask) + + if inputs.referenceImages: + inputs.referenceImages = await self._process_media_list(inputs.referenceImages) + + if inputs.frameImages: + inputs.frameImages = await self._process_media_list( + inputs.frameImages, + object_attr="image", + ) + requestVideo.taskUUID = requestVideo.taskUUID or getUUID() request_object = self._buildVideoRequest(requestVideo) @@ -2066,39 +2128,6 @@ async def _requestVideo(self, requestVideo: "IVideoInference") -> "Union[List[IV debug_key="video-inference-initial" ) - async def _processVideoImages(self, requestVideo: IVideoInference) -> None: - frame_tasks = [] - reference_tasks = [] - - if requestVideo.frameImages: - frame_tasks = [ - process_image(frame_item.inputImage) - for frame_item in requestVideo.frameImages - if isinstance(frame_item, IFrameImage) - ] - - if requestVideo.referenceImages: - reference_tasks = [ - process_image(reference_item) - for reference_item in requestVideo.referenceImages - ] - - frame_results = await gather(*frame_tasks) if frame_tasks else [] - reference_results = await gather(*reference_tasks) if reference_tasks else [] - - if requestVideo.frameImages and frame_results: - processed_frame_images = [] - result_index = 0 - for frame_item in requestVideo.frameImages: - if isinstance(frame_item, IFrameImage): - frame_item.inputImages = frame_results[result_index] - result_index += 1 - processed_frame_images.append(frame_item) - requestVideo.frameImages = processed_frame_images - - if requestVideo.referenceImages and reference_results: - requestVideo.referenceImages = reference_results - def _buildVideoRequest(self, requestVideo: IVideoInference) -> Dict[str, Any]: request_object = { "deliveryMethod": requestVideo.deliveryMethod, @@ -2345,7 +2374,6 @@ def _buildImageRequest(self, requestImage: IImageInference, prompt: Optional[str self._addOptionalField(request_object, requestImage.safety) self._addOptionalField(request_object, requestImage.settings) - return request_object def _addImageSpecialFields(self, request_object: Dict[str, Any], requestImage: IImageInference, control_net_data_dicts: List[Dict], instant_id_data: Optional[Dict], ip_adapters_data: Optional[List[Dict]], ace_plus_plus_data: Optional[Dict], pulid_data: Optional[Dict]) -> None: diff --git a/runware/types.py b/runware/types.py index 0d59706..7a4c6cd 100644 --- a/runware/types.py +++ b/runware/types.py @@ -870,6 +870,15 @@ class IInputFrame(SerializableMixin): class IInputReference(SerializableMixin): image: Union[str, File] tag: Optional[str] = None + refType: Optional[str] = None + strength: Optional[float] = None + + def serialize(self) -> Dict[str, Any]: + data = super().serialize() + if self.refType is not None: + data["type"] = self.refType + data.pop("refType", None) + return data @dataclass @@ -879,11 +888,11 @@ class IInputs(SerializableMixin): image: Optional[Union[str, File]] = None mask: Optional[Union[str, File]] = None superResolutionReferences: Optional[List[Union[str, File]]] = None - + @property def request_key(self) -> str: return "inputs" - + def __post_init__(self): if self.references: warnings.warn(