Skip to content

Commit ee49d2c

Browse files
committed
Improve logging readability
1 parent 05ce507 commit ee49d2c

3 files changed

Lines changed: 35 additions & 23 deletions

File tree

opto/features/async_search/async_priority_search.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ async def _initialize_search_parameters(self, *,
199199
"""
200200
# Validate and adjust num_candidates
201201
if num_candidates < len(self._optimizers):
202-
print(f"Warning: num_candidates {num_candidates} is less than the number of optimizers {len(self._optimizers)}. Setting num_candidates to {len(self._optimizers)}.")
202+
print(f"[AsyncPrioritySearch] Warning: num_candidates {num_candidates} is less than the number of optimizers {len(self._optimizers)}. Setting num_candidates to {len(self._optimizers)}.")
203203
num_candidates = len(self._optimizers)
204204

205205
# Set core parameters
@@ -230,15 +230,15 @@ async def _initialize_search_parameters(self, *,
230230

231231
# Initialize memory structures
232232
if memory_update_frequency is None:
233-
print("AsyncPrioritySearch initialized with only short-term memory.")
233+
print("[AsyncPrioritySearch] AsyncPrioritySearch initialized with only short-term memory.")
234234
assert short_term_memory_size is None or short_term_memory_size > 0, \
235235
"short_term_memory_size must be None or greater than 0 when memory_update_frequency is None."
236236
elif memory_update_frequency == 0:
237-
print("AsyncPrioritySearch initialized with only long-term memory.")
237+
print("[AsyncPrioritySearch] AsyncPrioritySearch initialized with only long-term memory.")
238238
assert long_term_memory_size is None or long_term_memory_size > 0, \
239239
"long_term_memory_size must be None or greater than 0 when memory_update_frequency is 0."
240240
else:
241-
print(f"AsyncPrioritySearch initialized with both short-term and long-term memory. Candidates will be merged into long-term memory every {memory_update_frequency} iterations.")
241+
print(f"[AsyncPrioritySearch] AsyncPrioritySearch initialized with both short-term and long-term memory. Candidates will be merged into long-term memory every {memory_update_frequency} iterations.")
242242

243243
self.long_term_memory = HeapMemory(size=long_term_memory_size, processing_fun=self.compress_candidate_memory)
244244
self.short_term_memory = HeapMemory(size=short_term_memory_size)
@@ -270,7 +270,7 @@ def memory(self):
270270
if self.n_iters % self.memory_update_frequency == 0:
271271
# Merge short-term memory into long-term memory
272272
if len(self.short_term_memory) > 0:
273-
print('Merging short-term memory into long-term memory.')
273+
print('[AsyncPrioritySearch] Merging short-term memory into long-term memory.')
274274
self.long_term_memory.append(self.short_term_memory)
275275
self.short_term_memory.reset()
276276
return self.long_term_memory
@@ -371,7 +371,7 @@ async def propose(self, samples: Samples, exploration_candidates: List[ModuleCan
371371
Returns:
372372
List of proposed ModuleCandidate objects
373373
"""
374-
print("--- Proposing new parameters...") if verbose else None
374+
print("[AsyncPrioritySearch] --- Proposing new parameters...") if verbose else None
375375
assert isinstance(samples, Samples), "samples must be an instance of Samples."
376376
samples_list = samples.samples # list of BatchRollout objects
377377
n_proposals = self.num_proposals
@@ -405,7 +405,7 @@ async def _backward(n):
405405
return optimizer
406406

407407
# Run backward passes concurrently
408-
print(f"Running backward on {n_batches} batches...") if verbose else None
408+
print(f"[AsyncPrioritySearch] Running backward on {n_batches} batches...") if verbose else None
409409
optimizers = await asyncio.gather(*[_backward(n) for n in range(n_batches)])
410410
assert len(optimizers) == n_batches, "Number of optimizers must match number of batch rollouts."
411411

@@ -432,7 +432,7 @@ async def _step(n):
432432
return update_dict
433433

434434
# Run step operations concurrently
435-
print(f"Generating {n_proposals} proposals for each of {n_batches} batches...") if verbose else None
435+
print(f"[AsyncPrioritySearch] Generating {n_proposals} proposals for each of {n_batches} batches...") if verbose else None
436436
update_dicts = await asyncio.gather(*[_step(n) for n in range(n_batches * n_proposals)])
437437

438438
# Clear optimizer state TODO is this for preventing memory leak?
@@ -467,7 +467,7 @@ async def validate(self,
467467
Returns:
468468
Dictionary mapping ModuleCandidate to list of rollouts
469469
"""
470-
print("--- Validating candidates...") if verbose else None
470+
print("[AsyncPrioritySearch] --- Validating candidates...") if verbose else None
471471
assert isinstance(samples, Samples), "samples must be an instance of Samples."
472472
assert exploration_candidates is not None, "exploration_candidate must be set."
473473

@@ -537,7 +537,7 @@ async def update_memory(self, validate_results: Dict[ModuleCandidate, List[Dict[
537537
verbose: Whether to print verbose output
538538
**kwargs: Additional keyword arguments
539539
"""
540-
print("--- Updating memory with validation results...") if verbose else None
540+
print("[AsyncPrioritySearch] --- Updating memory with validation results...") if verbose else None
541541
for candidate, rollouts in validate_results.items():
542542
candidate.add_rollouts(rollouts)
543543
priority = self.compute_exploration_priority(candidate)
@@ -560,7 +560,7 @@ async def explore(self,
560560
priorities: List of priorities for candidates
561561
info_dict: Dictionary with logging information
562562
"""
563-
print(f"--- Generating {min(len(self.memory), num_candidates)} exploration candidates...") if verbose else None
563+
print(f"[AsyncPrioritySearch] --- Generating {min(len(self.memory), num_candidates)} exploration candidates...") if verbose else None
564564

565565
top_candidates = [self._best_candidate] if use_best_candidate_to_explore else []
566566
priorities = [self._best_candidate_priority] if use_best_candidate_to_explore else []
@@ -599,7 +599,7 @@ async def exploit(self, verbose: bool = False, **kwargs) -> Tuple[ModuleCandidat
599599
priority: Priority of best candidate
600600
info_dict: Dictionary with logging information
601601
"""
602-
print("--- Exploiting the best candidate...") if verbose else None
602+
print("[AsyncPrioritySearch] --- Exploiting the best candidate...") if verbose else None
603603
if not self.memory:
604604
raise ValueError("The priority queue is empty. Cannot exploit.")
605605

opto/features/async_search/async_search.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ async def create_task(self) -> Optional[Coroutine[Any, Any, Any]]:
245245
return self.eval_task(task_id=task_id)
246246

247247
# Otherwise, run worker task
248-
print(f"Epoch: {self.n_epochs}. Iteration: {self.n_iters}")
248+
print(f"[AsyncSearch] Epoch: {self.n_epochs}. Iteration: {self.n_iters}")
249249
task_id = WORKER_TASK + ':' + str(self.n_iters) + ':' + uuid.uuid4().hex
250250
proposal, task_state = await self.get_task_state()
251251
self._running_worker_tasks[task_id] = (proposal, task_state)
@@ -283,9 +283,11 @@ async def process_result(self, result: Tuple[str, Any]) -> None:
283283

284284
task_id, result = result
285285

286+
print(f"[AsyncSearch] Processing result for task: {task_id}")
287+
286288
if task_id.startswith(EVAL_TASK):
287289
info_test = result
288-
self.log(info_test, prefix="Test/")
290+
self.log(info_test, prefix="Test/", color='green')
289291
return
290292

291293
# Else, worker task is done
@@ -356,7 +358,7 @@ async def sample(self, agents, verbose=False, **kwargs):
356358
}
357359
# check if the scores are within the score range
358360
if hasattr(self, '_score_range') and not (self.min_score <= log_info['mean_score'] <= self.max_score):
359-
print(f"Warning: Mean score {log_info['mean_score']} is out of the range {self._score_range}.")
361+
print(f"[AsyncSearch] Warning: Mean score {log_info['mean_score']} is out of the range {self._score_range}.")
360362
return samples, log_info
361363

362364

@@ -370,21 +372,21 @@ async def evaluate_agent(self, test_dataset, guide):
370372
test_score = safe_mean(test_scores)
371373
# check if the test_score is within the score range
372374
if hasattr(self, '_score_range') and not (self.min_score <= test_score <= self.max_score):
373-
print(f"Warning: Test score {test_score} is out of the range {self._score_range}.")
375+
print(f"[AsyncSearch] Warning: Test score {test_score} is out of the range {self._score_range}.")
374376
return {'test_score': test_score}
375377

376378

377-
def log(self, info_log: Dict[str, Any], prefix=""):
379+
def log(self, info_log: Dict[str, Any], prefix="", color=None) -> None:
378380
"""Log information from the algorithm."""
379381
for key, value in info_log.items():
380382
if value is not None and self.logger:
381383
try:
382-
self.logger.log(f"{prefix}{key}", value, self.n_iters)
384+
self.logger.log(f"{prefix}{key}", value, self.n_iters, color=color)
383385
except Exception as e:
384-
print(f"Logging failed for key {key}: {e}")
386+
print(f"[AsyncSearch] Logging failed for key {key}: {e}")
385387

386388
def save(self, save_path):
387-
print(f"Saving algorithm state to {save_path} at iteration {self.n_iters}.")
389+
print(f"[AsyncSearch] Saving algorithm state to {save_path} at iteration {self.n_iters}.")
388390
if not os.path.exists(save_path):
389391
os.makedirs(save_path)
390392
obj = copy.deepcopy(self) # to detach nodes from the computation graph
@@ -430,7 +432,7 @@ def resume(self, *,
430432
last_train_kwargs['train_dataset'] = train_dataset
431433
last_train_kwargs['validate_dataset'] = validate_dataset
432434
last_train_kwargs['test_dataset'] = test_dataset
433-
print(f"Resuming training with parameters: {last_train_kwargs}")
435+
print(f"[AsyncSearch] Resuming training with parameters: {last_train_kwargs}")
434436
return self.train(*last_train_args, **last_train_kwargs)
435437

436438

opto/features/async_search/controller.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import time
23
from typing import Any, Coroutine, Optional
34

45

@@ -30,7 +31,8 @@ async def run(self, *args: Any, **kwargs: Any) -> None:
3031
done, pending = await asyncio.wait(worker_tasks, return_when=asyncio.FIRST_COMPLETED)
3132
for task in done:
3233
worker_tasks.remove(task)
33-
result = task.result()
34+
result, elapsed_time = task.result()
35+
print(f"[Controller] Task completed in {elapsed_time:.4f} seconds")
3436
await self.process_result(result)
3537
if self.should_stop():
3638
self.post_process()
@@ -47,7 +49,15 @@ async def _create_new_task(self) -> Optional[asyncio.Task[Any]]:
4749
new_task = await self.create_task()
4850
if new_task is None:
4951
return None
50-
return asyncio.create_task(new_task)
52+
53+
# Wrap the task with timing
54+
async def timed_task() -> tuple[Any, float]:
55+
start_time = time.time()
56+
result = await new_task
57+
elapsed_time = time.time() - start_time
58+
return result, elapsed_time
59+
60+
return asyncio.create_task(timed_task())
5161

5262
#### Methods to be implemented by subclasses ####
5363

0 commit comments

Comments
 (0)