Skip to content

Commit c3a826c

Browse files
authored
Merge pull request #134 from Runware/fix-exits
fixed exits and small refactor of server
2 parents 00dea27 + aa29d39 commit c3a826c

2 files changed

Lines changed: 60 additions & 84 deletions

File tree

runware/base.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -235,9 +235,8 @@ def check(resolve: callable, reject: callable, *args: Any) -> bool:
235235

236236
except Exception as e:
237237
if retry_count >= 2:
238-
self.logger.error(f"Error in photoMaker request: {e}")
239-
exit()
240-
return self.handle_incomplete_images(task_uuids=task_uuids, error=e)
238+
self.logger.error(f"Error in photoMaker request:", exc_info=e)
239+
raise RunwareAPIError({"message": f"PhotoMaker failed after retries: {str(e)}"})
241240
else:
242241
raise e
243242

@@ -464,9 +463,8 @@ async def imageInference(
464463
)
465464
except Exception as e:
466465
if retry_count >= 2:
467-
self.logger.error(f"Error in requestImages: {e}")
468-
exit()
469-
return self.handle_incomplete_images(task_uuids=task_uuids, error=e)
466+
self.logger.error(f"Error in requestImages:", exc_info=e)
467+
raise RunwareAPIError({"message": f"Image inference failed after retries: {str(e)}"})
470468
else:
471469
raise e
472470

runware/server.py

Lines changed: 56 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,11 @@
33
import logging
44
import websockets
55
from websockets.protocol import State
6-
import inspect
7-
import pprint
8-
from typing import Any, Callable, Dict, List, Union, Optional, TypeVar
6+
from typing import Any, Dict, Optional
97

108

11-
from .types import RunwareBaseType, SdkType
9+
from .types import SdkType
1210
from .utils import (
13-
delay,
14-
getUUID,
15-
removeListener,
1611
BASE_RUNWARE_URLS,
1712
PING_INTERVAL,
1813
PING_TIMEOUT_DURATION,
@@ -21,12 +16,7 @@
2116
from .base import RunwareBase
2217
from .types import (
2318
Environment,
24-
EPreProcessor,
25-
EPreProcessorGroup,
2619
ListenerType,
27-
IControlNet,
28-
File,
29-
GetWithPromiseCallBackType,
3020
)
3121

3222
from .logging_config import configure_logging
@@ -50,6 +40,7 @@ def __init__(
5040
self._apiKey: str = api_key
5141
self._message_handler_task: Optional[asyncio.Task] = None
5242
self._last_pong_time: float = 0.0
43+
self._is_shutting_down: bool = False
5344

5445
# Configure logging
5546
configure_logging(log_level)
@@ -160,29 +151,19 @@ def pong_lis(m):
160151
async def on_message(self, ws, message):
161152
if not message:
162153
return
163-
m = json.loads(message)
164-
# print(
165-
# f"\n\n\n================================================ Received message ============================================================"
166-
# )
167-
# print(f"{m}")
168-
169-
# print(f"Listenerse:")
170-
# for lis in self._listeners:
171-
# print(lis, "\n")
172-
# print(
173-
# f"============================================= End received message ============================================================\n\n\n"
174-
# )
154+
155+
try:
156+
m = json.loads(message)
157+
except json.JSONDecodeError as e:
158+
self.logger.error(f"Failed to parse JSON message:", exc_info=e)
159+
return
160+
175161
for lis in self._listeners:
176162
try:
177-
# result = True
178163
result = lis.listener(m)
179164
except Exception as e:
180-
print(f"Unexpected error in on_message: {e}")
181-
print(dir(lis))
182-
print(f"Listeners: {self._listeners}")
183-
for lis in self._listeners:
184-
print(dir(lis), "\n")
185-
return
165+
self.logger.error(f"Error in listener {lis.key}:", exc_info=e)
166+
continue
186167
if result:
187168
return
188169

@@ -192,31 +173,25 @@ async def _handle_messages(self):
192173
f"Starting message handler task {self._message_handler_task}"
193174
)
194175
async for message in self._ws:
176+
if self._is_shutting_down:
177+
break
195178
try:
196179
await self.on_message(self._ws, message)
197180
except Exception as e:
198-
print(f"Unexpected error in async loop: {e}")
199-
print(self.on_message)
200-
exit()
181+
self.logger.error(f"Error in on_message:", exc_info=e)
182+
continue
201183
except websockets.exceptions.ConnectionClosedError as e:
202-
self.logger.error(f"Connection Closed Error: {e}")
203-
await self.handleClose()
184+
if not self._is_shutting_down:
185+
self.logger.error(f"Connection Closed Error:", exc_info=e)
186+
await self.handleClose()
204187
except Exception as e:
205-
print(f"Unexpected error in _handle_messages: {e}")
206-
print(self.on_message)
207-
exit()
208-
await self._ws.close()
188+
self.logger.error(f"Critical error in _handle_messages:", exc_info=e)
189+
if not self._is_shutting_down:
190+
await self.handleClose()
209191

210192
async def send(self, msg: Dict[str, Any]):
211193
self.logger.debug(f"Sending message: {msg}")
212-
# print(
213-
# f"\n\n\n================================================= Sending message ================================================================="
214-
# )
215-
# print(f"{msg}")
216-
# print(
217-
# f"=============================================== End sending message ===============================================================\n\n\n"
218-
# )
219-
if self._ws and self._ws.state is State.OPEN:
194+
if self._ws and self._ws.state is State.OPEN and not self._is_shutting_down:
220195
await self._ws.send(json.dumps(msg))
221196

222197
def _get_task_by_name(self, name):
@@ -240,7 +215,7 @@ async def handleClose(self):
240215
try:
241216
reconnecting_task.cancel()
242217
except Exception as e:
243-
self.logger.error(f"Error while cancelling Task_Reconnecting: {e}")
218+
self.logger.error(f"Error while cancelling Task_Reconnecting:", exc_info=e)
244219

245220
message_handler_task = self._get_task_by_name("Task_Message_Handler")
246221
if message_handler_task is not None:
@@ -252,7 +227,7 @@ async def handleClose(self):
252227
message_handler_task.cancel()
253228
except Exception as e:
254229
self.logger.error(
255-
f"Error while cancelling Task_Message_Handler: {e}"
230+
f"Error while cancelling Task_Message_Handler:", exc_info=e
256231
)
257232

258233
heartbeat_task = self._get_task_by_name("Task_Heartbeat")
@@ -262,12 +237,15 @@ async def handleClose(self):
262237
try:
263238
heartbeat_task.cancel()
264239
except Exception as e:
265-
self.logger.error(f"Error while cancelling Task_Heartbeat: {e}")
240+
self.logger.error(f"Error while cancelling Task_Heartbeat:", exc_info=e)
266241

267242
async def reconnect():
268-
while True:
269-
self.logger.info("Reconnecting...")
270-
await asyncio.sleep(1)
243+
reconnect_attempts = 0
244+
max_reconnect_attempts = 5
245+
246+
while reconnect_attempts < max_reconnect_attempts and not self._is_shutting_down:
247+
self.logger.info(f"Reconnecting... (attempt {reconnect_attempts + 1})")
248+
await asyncio.sleep(min(reconnect_attempts * 2 + 1, 10))
271249
try:
272250
await self.connect()
273251
if self.isWebsocketReadyState():
@@ -278,43 +256,43 @@ async def reconnect():
278256
"WebSocket connection is not in a ready state after reconnecting"
279257
)
280258
except Exception as e:
281-
self.logger.error(f"Error while reconnecting: {e}")
259+
self.logger.error(f"Error while reconnecting:", exc_info=e)
260+
261+
reconnect_attempts += 1
262+
263+
if reconnect_attempts >= max_reconnect_attempts:
264+
self.logger.error("Max reconnection attempts reached. Giving up.")
265+
self._is_shutting_down = True
282266

283-
# TODO: I don't need to close self._ws here, as it will be cleaned by sockets library based on it's interrnal ping mechanism
284267
# Attempting to reconnect...
285-
self._reconnecting_task = asyncio.create_task(
286-
reconnect(), name="Task_Reconnecting"
287-
)
268+
if not self._is_shutting_down:
269+
self._reconnecting_task = asyncio.create_task(
270+
reconnect(), name="Task_Reconnecting"
271+
)
288272

289273
async def heartBeat(self):
290-
# TODO: Not sure if we need this, as the websocket server responds to default PING messages
291-
# 2024-04-29 10:46:23,193 - websockets.client - DEBUG - % sending keepalive ping
292-
# 2024-04-29 10:46:23,194 - websockets.client - DEBUG - > PING f2 0b eb 3d [binary, 4 bytes]
293-
# 2024-04-29 10:46:23,197 - runware.server - DEBUG - Sending ping
294-
# 2024-04-29 10:46:23,197 - runware.server - DEBUG - Sending message: {'ping': True}
295-
# 2024-04-29 10:46:23,197 - websockets.client - DEBUG - > TEXT '{"ping": true}' [14 bytes]
296-
# 2024-04-29 10:46:23,241 - websockets.client - DEBUG - < PONG f2 0b eb 3d [binary, 4 bytes]
297-
# 2024-04-29 10:46:23,241 - websockets.client - DEBUG - % received keepalive pong
298-
# 2024-04-29 10:46:23,244 - websockets.client - DEBUG - < TEXT '{"pong":true}' [13 bytes]
299-
while True:
274+
while not self._is_shutting_down:
300275
if self.isWebsocketReadyState():
301276
self.logger.debug("Sending ping")
302277
try:
303278
await self.send([{"taskType": "ping", "ping": True}])
304279
except websockets.exceptions.ConnectionClosedError as e:
305280
self.logger.error(
306-
f"Error sending ping: {e}. Connection likely closed."
281+
f"Error sending ping. Connection likely closed.", exc_info=e
307282
)
308-
# Potentially handle reconnection here
309-
except Exception as e: # Catch other potential exceptions
310-
self.logger.error(f"Unexpected error sending ping: {e}")
311-
# Handle unexpected errors appropriately
283+
break
284+
except Exception as e:
285+
self.logger.error(f"Unexpected error sending ping", exc_info=e)
286+
break
287+
312288
await asyncio.sleep(PING_INTERVAL / 1000)
313289

314290
if (
315-
asyncio.get_event_loop().time() - self._last_pong_time
316-
> PING_TIMEOUT_DURATION / 1000
317-
): # No pong received within the timeout period
291+
asyncio.get_event_loop().time() - self._last_pong_time
292+
> PING_TIMEOUT_DURATION / 1000
293+
):
318294
self.logger.warning("No pong received. Connection may be lost.")
319-
# Initiate a reconnection
320295
await self.handleClose()
296+
break
297+
else:
298+
break

0 commit comments

Comments
 (0)