diff --git a/Cargo.lock b/Cargo.lock index 3d4635d..a03ed52 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2918,7 +2918,7 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "symbiont" -version = "0.4.0" +version = "0.7.0" dependencies = [ "criterion", "libloading", @@ -2933,11 +2933,12 @@ dependencies = [ "tokio", "tracing", "tracing-subscriber", + "tracing-test", ] [[package]] name = "symbiont-macros" -version = "0.4.0" +version = "0.7.0" dependencies = [ "prettyplease", "proc-macro2", @@ -3323,6 +3324,27 @@ dependencies = [ "tracing-log", ] +[[package]] +name = "tracing-test" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19a4c448db514d4f24c5ddb9f73f2ee71bfb24c526cf0c570ba142d1119e0051" +dependencies = [ + "tracing-core", + "tracing-subscriber", + "tracing-test-macro", +] + +[[package]] +name = "tracing-test-macro" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad06847b7afb65c7866a36664b75c40b895e318cea4f71299f013fb22965329d" +dependencies = [ + "quote", + "syn 2.0.117", +] + [[package]] name = "try-lock" version = "0.2.5" diff --git a/README.md b/README.md index 56e7280..089b659 100644 --- a/README.md +++ b/README.md @@ -226,4 +226,4 @@ Also **checkout** the [TODOs](TODO.md) file for what might come next for `symbio Copyright (C) 2026 MathisWellmann -This project is licensed under the **Mozilla Public License 2.0** — see [LICENSE](LICENSE) for details. +This project is licensed under the **Mozilla Public License 2.0** — see [LICENSE](LICENSE) for details diff --git a/examples/quantize/src/main.rs b/examples/quantize/src/main.rs index 998a610..5d6c403 100644 --- a/examples/quantize/src/main.rs +++ b/examples/quantize/src/main.rs @@ -486,9 +486,7 @@ async fn main() -> symbiont::Result<()> { .await .expect("evolution should succeed"); - prev_code = runtime - .read_clean_code() - .expect("failed to read generated code"); + prev_code = runtime.current_code(); result = evaluate(runtime, &dist_data, distr); println!("{result}"); diff --git a/examples/rastrigin/src/main.rs b/examples/rastrigin/src/main.rs index 873f8b4..1e699eb 100644 --- a/examples/rastrigin/src/main.rs +++ b/examples/rastrigin/src/main.rs @@ -207,9 +207,7 @@ async fn main() -> symbiont::Result<()> { if new_mse < mse_threshold { println!("Exact formula found after {round} evolution round(s)!\n"); - let code = runtime - .read_clean_code() - .expect("failed to read generated code"); + let code = runtime.current_code(); println!("Generated code:\n```rust\n{code}```"); return Ok(()); } diff --git a/examples/sort/src/main.rs b/examples/sort/src/main.rs index b774029..293cb0a 100644 --- a/examples/sort/src/main.rs +++ b/examples/sort/src/main.rs @@ -304,9 +304,7 @@ async fn main() -> symbiont::Result<()> { .await .expect("evolution should succeed"); - prev_code = runtime - .read_clean_code() - .expect("failed to read generated code"); + prev_code = runtime.current_code(); results = run_benchmarks(runtime, &benches); let new_report = format_report(&results); diff --git a/examples/tictactoe/src/main.rs b/examples/tictactoe/src/main.rs index 51500d1..076550d 100644 --- a/examples/tictactoe/src/main.rs +++ b/examples/tictactoe/src/main.rs @@ -385,7 +385,7 @@ async fn main() -> symbiont::Result<()> { // Update best code if this round improved. if score > best_score { best_score = score; - best_code = runtime.read_clean_code().ok(); + best_code = Some(runtime.current_code()); info!("New best score: {:.0}%", best_score * 100.0); } diff --git a/flake.nix b/flake.nix index 7fa5dc4..ef6d217 100644 --- a/flake.nix +++ b/flake.nix @@ -64,8 +64,6 @@ statix # Highlights nix antipatterns ]; in { - nixosModules.zola-serve = import ./symbiont/modules/nixos/zola-serve.nix; - devShells.${system} = { default = pkgs.mkShell { buildInputs = diff --git a/symbiont-macros/Cargo.toml b/symbiont-macros/Cargo.toml index eee5534..52f1f3b 100644 --- a/symbiont-macros/Cargo.toml +++ b/symbiont-macros/Cargo.toml @@ -7,7 +7,7 @@ keywords = ["agent", "dylib", "dynamic", "harness", "llm"] license = "MPL-2.0" name = "symbiont-macros" repository = "https://github.com/MathisWellmann/symbiont" -version = "0.4.0" +version = "0.7.0" [lints] workspace = true diff --git a/symbiont/Cargo.toml b/symbiont/Cargo.toml index 9b2b553..b385f63 100644 --- a/symbiont/Cargo.toml +++ b/symbiont/Cargo.toml @@ -9,13 +9,13 @@ license = "MPL-2.0" name = "symbiont" readme = "README.md" repository = "https://github.com/MathisWellmann/symbiont" -version = "0.4.0" +version = "0.7.0" [lints] workspace = true [dependencies] -symbiont-macros = { version = "0.4.0", path = "../symbiont-macros" } +symbiont-macros = { version = "0.7.0", path = "../symbiont-macros" } rig-core.workspace = true tokio.workspace = true @@ -32,6 +32,7 @@ tracing-subscriber = { version = "0.3", features = ["env-filter"] } [dev-dependencies] criterion = { version = "0.5", features = ["html_reports"] } +tracing-test = "0.2" [[bench]] harness = false diff --git a/symbiont/modules/nixos/zola-serve.nix b/symbiont/modules/nixos/zola-serve.nix deleted file mode 100644 index e5eb83c..0000000 --- a/symbiont/modules/nixos/zola-serve.nix +++ /dev/null @@ -1,52 +0,0 @@ -{ - config, - lib, - pkgs, - ... -}: let - cfg = config.services.zola-serve; -in { - options.services.zola-serve = { - enable = lib.mkEnableOption "Zola static site server"; - - root = lib.mkOption { - type = lib.types.path; - default = "/home/m/MathisWellmann/symbiont/website"; - description = "Path to the Zola website root directory."; - }; - - port = lib.mkOption { - type = lib.types.port; - default = 1111; - description = "Port to serve the site on."; - }; - - hostname = lib.mkOption { - type = lib.types.str; - default = "127.0.0.1"; - description = "Hostname to bind the server to."; - }; - }; - - config = lib.mkIf cfg.enable { - systemd.services.zola-serve = { - description = "Zola static site server"; - wantedBy = ["multi-user.target"]; - wants = ["network-online.target"]; - after = ["network-online.target"]; - - serviceConfig = { - Type = "simple"; - ExecStart = "${pkgs.zola}/bin/zola serve --port ${toString cfg.port} --interface ${cfg.hostname}"; - Restart = "on-failure"; - RestartSec = "5"; - WorkingDirectory = cfg.root; - StandardOutput = "journal"; - StandardError = "journal"; - }; - }; - networking.firewall.allowedUDPPorts = [ - cfg.port - ]; - }; -} diff --git a/symbiont/src/compiler.rs b/symbiont/src/compiler.rs index c6b621a..dddc074 100644 --- a/symbiont/src/compiler.rs +++ b/symbiont/src/compiler.rs @@ -4,10 +4,7 @@ use std::path::Path; use minstant::Instant; use prettyplease::unparse; use tokio::process::Command; -use tracing::{ - debug, - info, -}; +use tracing::info; use crate::{ error::{ @@ -36,9 +33,10 @@ pub enum Profile { impl std::fmt::Display for Profile { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + use Profile::*; match self { - Profile::Debug => f.write_str("debug"), - Profile::Release => f.write_str("release"), + Debug => f.write_str("debug"), + Release => f.write_str("release"), } } } @@ -52,15 +50,10 @@ pub(crate) async fn compile_dylib( crate_dir: &Path, profile: Profile, clean_ast: &mut syn::File, + clean_ast_str: &str, ) -> Result<()> { let t0 = Instant::now(); - let clean_code = unparse(clean_ast); - debug!("clean_code: {clean_code}"); - let clean_path = crate_dir.join("src").join("clean.rs"); - std::fs::write(&clean_path, &clean_code) - .map_err(|e| Error::WriteLib(format!("Failed to write clean.rs: {e}")))?; - // Wrap function bodies in catch_unwind so panics stay inside the dylib. wrap_bodies_in_catch_unwind(clean_ast); @@ -85,7 +78,7 @@ pub(crate) async fn compile_dylib( .output() .await .map_err(|e| Error::CompilationFailed { - code: clean_code.clone(), + code: clean_ast_str.to_string(), err: format!("Failed to spawn cargo: {e}"), })?; @@ -98,7 +91,7 @@ pub(crate) async fn compile_dylib( } else { let err = String::from_utf8_lossy(&output.stderr).to_string(); Err(Error::CompilationFailed { - code: clean_code.clone(), + code: clean_ast_str.to_string(), err, }) } diff --git a/symbiont/src/error.rs b/symbiont/src/error.rs index ad253e5..6796d88 100644 --- a/symbiont/src/error.rs +++ b/symbiont/src/error.rs @@ -40,6 +40,12 @@ pub enum Error { #[error("Failed to load dylib: {0}")] DylibLoad(String), + + #[error("Evolution failed after {attempts} attempts. Last error: {last_error}")] + MaxRetriesExceeded { + attempts: usize, + last_error: Box, + }, } /// Result type alias for symbiont operations. diff --git a/symbiont/src/parser.rs b/symbiont/src/parser.rs index 8985354..d5f172f 100644 --- a/symbiont/src/parser.rs +++ b/symbiont/src/parser.rs @@ -128,6 +128,27 @@ fn no_lang_marker(x: i32) -> i32 { x } assert_eq!(code.trim(), "fn no_lang_marker(x: i32) -> i32 { x }"); } + #[test] + fn test_extract_rust_code_with_prefix_and_extra_whitespace() { + // Prefix text ensures `start > 0` and extra whitespace after the fence + // ensures the whitespace count is `> 1`, so that `code_start + count` + // differs from `code_start * count` in a way `trim()` cannot recover. + let prefix = "Here is the code you requested:\n"; + let input = format!("{prefix}```rust\n\n fn foo() -> i32 {{ 42 }}\n```"); + let code = extract_rust_code(&input).expect("can extract"); + assert_eq!(code, "fn foo() -> i32 { 42 }"); + } + + #[test] + fn test_extract_rust_code_generic_fence_with_prefix() { + // Prefix ensures `start > 0` for the generic-fence branch so that + // mutations of `+ start` to `- start` or `* start` produce a wrong + // (or panicking) result. + let input = "Some explanation here:\n```\nfn no_lang(x: i32) -> i32 { x }\n```"; + let code = extract_rust_code(input).expect("can extract"); + assert_eq!(code, "fn no_lang(x: i32) -> i32 { x }"); + } + #[test] fn test_parse_rust_code_from_block() { let input = "```rust diff --git a/symbiont/src/runtime.rs b/symbiont/src/runtime.rs index 99d25b5..e9c62de 100644 --- a/symbiont/src/runtime.rs +++ b/symbiont/src/runtime.rs @@ -24,6 +24,7 @@ use std::{ Ordering, }, }, + time::Duration, }; use libloading::{ @@ -32,14 +33,10 @@ use libloading::{ }; use minstant::Instant; use owo_colors::OwoColorize; -use rig::{ - agent::Agent, - completion::{ - CompletionModel, - Prompt, - }, -}; +use prettyplease::unparse; +use rig::completion::Prompt; use tracing::{ + debug, info, warn, }; @@ -61,6 +58,7 @@ use crate::{ find_so, generate_cargo_toml, generate_lib_rs, + is_transient_http_error, }, validation::validate_generated_ast, }; @@ -133,6 +131,8 @@ pub struct Runtime { decls: &'static [EvolvableDecl], /// Compilation profile (`debug` or `release`). profile: Profile, + /// The currently active AST of the agent code, in String form, to make it `Send` + current_clean_ast: Mutex, } /// Look up all declared symbols in `lib` and store their addresses @@ -165,6 +165,11 @@ unsafe fn update_fn_ptrs(lib: &Library, decls: &[EvolvableDecl]) -> Result<()> { } impl Runtime { + /// Maximum number of attempts [`Runtime::evolve`] will make before giving + /// up and returning [`Error::MaxRetriesExceeded`]. Prevents a misbehaving + /// agent from hanging the runtime indefinitely. + pub const MAX_EVOLVE_ATTEMPTS: usize = 10; + /// Initialize the symbiont runtime. /// /// Creates a temporary dylib crate from the declarations generated by `evolvable!`, @@ -205,7 +210,7 @@ impl Runtime { let mut ast = syn::parse_str(&lib_rs)?; // Compile - compile_dylib(&crate_dir, profile, &mut ast).await?; + compile_dylib(&crate_dir, profile, &mut ast, &lib_rs).await?; // Find and load the .so let so_path = find_so(&crate_dir, profile)?; @@ -226,6 +231,7 @@ impl Runtime { library: Mutex::new(Some(lib)), decls, profile, + current_clean_ast: Mutex::new(lib_rs), }; RUNTIME @@ -234,19 +240,6 @@ impl Runtime { Ok(RUNTIME.get().expect("just set")) } - /// Get the function signature strings for all evolvable functions. - pub fn fn_sigs(&self) -> &[String] { - &self.fn_sigs - } - - /// Get the full function signatures, including doc comments and default function body. - /// - /// Returns each source wrapped in [`FullSource`], which preserves real line - /// breaks when pretty-printed (`{:#?}`) so logs stay readable. - pub fn fn_full_sources(&self) -> Vec> { - Vec::from_iter(self.decls.iter().map(|d| FullSource(d.full_source))) - } - /// Generate LLM response, then parse, validate, compile, and hot-swap. /// It does not catch validation errors and feed it back to the LLM, allowing the user to customize prompting behaviour. /// @@ -255,13 +248,9 @@ impl Runtime { /// All evolvable function calls must have returned before this is called. /// In debug builds this is enforced with an assertion; in release it is /// the caller's responsibility. - async fn evolve_no_backpressure( - &self, - agent: &Agent, - prompt: &str, - ) -> Result<()> + async fn evolve_no_backpressure(&self, agent: &AgentT, prompt: &str) -> Result<()> where - CompletionModelT: CompletionModel + 'static, + AgentT: Prompt, { #[cfg(debug_assertions)] { @@ -287,7 +276,15 @@ impl Runtime { // Recompile let t0 = Instant::now(); - compile_dylib(&self.crate_dir, self.profile, &mut ast).await?; + let clean_ast_str = unparse(&ast); + debug!("clean_ast_str: {clean_ast_str}"); + compile_dylib(&self.crate_dir, self.profile, &mut ast, &clean_ast_str).await?; + { + *self + .current_clean_ast + .lock() + .expect("Can lock the clean ast mutex") = clean_ast_str; + } let compile_time = t0.elapsed().as_millis(); // Copy .so to versioned path to defeat dlopen caching @@ -319,63 +316,117 @@ impl Runtime { Ok(()) } + /// Maximum number of retries for transient HTTP errors (429, 5xx, 529). + /// + /// These are retried with exponential backoff and do not count against + /// [`Self::MAX_EVOLVE_ATTEMPTS`]. + pub const MAX_TRANSIENT_RETRIES: usize = 6; + + /// Exponential backoff (capped at 30s) for transient retry attempt `n`. + fn transient_backoff(n: usize) -> Duration { + let secs = 1u64 << n.min(5); + Duration::from_secs(secs.min(30)) + } + /// Prompt the LLM, validate the response, compile, and hot-swap. /// /// If the constrained generation fails (parse error, signature mismatch, /// compilation failure), the error is appended to the prompt and the LLM - /// retries until it produces valid code. + /// retries until it produces valid code, up to [`Self::MAX_EVOLVE_ATTEMPTS`] + /// attempts. After that, [`Error::MaxRetriesExceeded`] is returned so a + /// misbehaving agent cannot hang the runtime indefinitely. + /// + /// Transient HTTP errors from the LLM provider (HTTP 429, 5xx, 529 + /// "overloaded") are retried separately with exponential backoff up to + /// [`Self::MAX_TRANSIENT_RETRIES`] times, and do not count against the + /// self-healing attempt budget. /// /// # Contract /// /// All evolvable function calls must have returned before this is called. /// This is the natural shape of the feedback loop: run functions, collect /// results, evolve, repeat. - pub async fn evolve( - &self, - agent: &Agent, - base_prompt: &str, - ) -> Result<()> + pub async fn evolve(&self, agent: &AgentT, base_prompt: &str) -> Result<()> where - CompletionModelT: CompletionModel + 'static, + AgentT: Prompt, { let mut prompt = base_prompt.to_string(); - - while let Err(e) = self.evolve_no_backpressure(agent, &prompt).await { - info!("Function evolution error: {e}.\nSelf-healing from error..."); - - prompt = base_prompt.to_string(); - - use Error::*; - match e { - NoRustCode => prompt.push_str( - "Your response did not contain a rust code block. Please try again and make sure its wrapped like this: ```CODE```", - ), - CouldNotParseRust => prompt.push_str( - "Your response did not contain valid Rust code. Please try again", - ), - WriteLib(_) => todo!(), - SignatureMismatch { - code, - expected, - } => write!(prompt, - " Generated function signature miss-match. Expected ```{expected}```, Got Code ```{code}```", - ).expect("Can write to prompt"), - CompilationFailed{code, err} => write!(prompt, - " Your generated code ```{}``` failed to compile. Compiler output:\n```\n{}\n```\nPlease fix the compilation errors.", code.blue(), err.red() - ).expect("Can write to prompt"), - e => { - warn!("Unhandled error: {e}"); - return Err(e) - }, + let mut attempts: usize = 0; + let mut transient_attempts: usize = 0; + + loop { + attempts += 1; + match self.evolve_no_backpressure(agent, &prompt).await { + Ok(()) => return Ok(()), + Err(e) => { + // Transient HTTP errors (rate limits, overload, gateway + // failures) are not the LLM's fault: retry with + // exponential backoff and don't count against the + // self-healing attempt budget. + if is_transient_http_error(&e) { + if transient_attempts >= Self::MAX_TRANSIENT_RETRIES { + warn!( + "Transient HTTP error retry budget exhausted ({transient_attempts}/{}); giving up. Last error: {e}", + Self::MAX_TRANSIENT_RETRIES + ); + return Err(e); + } + let backoff = Self::transient_backoff(transient_attempts); + transient_attempts += 1; + warn!( + "Transient HTTP error from LLM provider (retry {transient_attempts}/{} in {:?}): {e}", + Self::MAX_TRANSIENT_RETRIES, + backoff, + ); + // Don't count this against the self-healing budget. + attempts -= 1; + tokio::time::sleep(backoff).await; + continue; + } + + if attempts >= Self::MAX_EVOLVE_ATTEMPTS { + warn!( + "Evolution failed after {attempts} attempts; giving up. Last error: {e}" + ); + return Err(MaxRetriesExceeded { + attempts, + last_error: Box::new(e), + }); + } + + info!( + "Function evolution error (attempt {attempts}/{}): {e}.\nSelf-healing from error...", + Self::MAX_EVOLVE_ATTEMPTS + ); + + prompt = base_prompt.to_string(); + + use Error::*; + match e { + NoRustCode => prompt.push_str( + "Your response did not contain a rust code block. Please try again and make sure its wrapped like this: ```CODE```", + ), + CouldNotParseRust => prompt.push_str( + "Your response did not contain valid Rust code. Please try again", + ), + WriteLib(_) => todo!(), + SignatureMismatch { + code, + expected, + } => write!(prompt, + " Generated function signature miss-match. Expected ```{expected}```, Got Code ```{code}```", + ).expect("Can write to prompt"), + CompilationFailed{code, err} => write!(prompt, + " Your generated code ```{}``` failed to compile. Compiler output:\n```\n{}\n```\nPlease fix the compilation errors.", code.blue(), err.red() + ).expect("Can write to prompt"), + e => { + warn!("Unhandled error: {e}"); + return Err(e) + }, + } } + } } - - Ok(()) - } - - /// Path to the temporary crate directory. - pub fn crate_dir(&self) -> &Path { - &self.crate_dir } /// Retrieve and clear the last panic message from the loaded dylib. @@ -401,10 +452,40 @@ impl Runtime { } } - /// Read the clean LLM-generated code (without panic-catching wrappers - /// or preamble). Suitable for feeding back into the LLM prompt or - /// displaying to the user. - pub fn read_clean_code(&self) -> std::io::Result { - std::fs::read_to_string(self.crate_dir.join("src").join("clean.rs")) + /// Path to the temporary crate directory. + pub fn crate_dir(&self) -> &Path { + &self.crate_dir + } + + /// Get the function signature strings for all evolvable functions. + pub fn fn_sigs(&self) -> &[String] { + &self.fn_sigs + } + + /// Get the full function signatures, including doc comments and default function body. + /// + /// Returns each source wrapped in [`FullSource`], which preserves real line + /// breaks when pretty-printed (`{:#?}`) so logs stay readable. + pub fn fn_full_sources(&self) -> Vec> { + Vec::from_iter(self.decls.iter().map(|d| FullSource(d.full_source))) + } + + /// Get the current, ,clean LLM-generated code (without panic-catching wrappers or preamble). + /// Suitable for feeding back into the LLM prompt or displaying to the user. + pub fn current_code(&self) -> String { + self.current_clean_ast + .lock() + .expect("Can lock the mutex to get clean AST") + .clone() + } + + /// Return the function signature and body for a single function base on its `fn_name` + pub fn current_function(&self, fn_name: &str) -> Option { + let code = self.current_code(); + let file: syn::File = syn::parse_str(&code).ok()?; + file.items.into_iter().find_map(|item| match item { + syn::Item::Fn(f) if f.sig.ident == fn_name => Some(f), + _ => None, + }) } } diff --git a/symbiont/src/utils.rs b/symbiont/src/utils.rs index d86c42c..d51d874 100644 --- a/symbiont/src/utils.rs +++ b/symbiont/src/utils.rs @@ -95,6 +95,35 @@ pub(crate) fn find_so(crate_dir: &Path, profile: Profile) -> Result { ))) } } + +/// Return `true` for [`Error`] values that represent transient failures of +/// the LLM provider (rate-limits, server overload, gateway errors) and are +/// safe to retry without modifying the prompt. +pub(crate) fn is_transient_http_error(err: &Error) -> bool { + let http_err = match err { + Error::RigPrompt(rig::completion::PromptError::CompletionError( + rig::completion::CompletionError::HttpError(http_err), + )) => http_err, + Error::RigHttp(http_err) => http_err, + _ => return false, + }; + + use rig::http_client::Error::*; + let status = match http_err { + InvalidStatusCode(s) => *s, + InvalidStatusCodeWithMessage(s, _) => *s, + // Connection-level errors (timeouts, resets, DNS, etc.) are also + // transient by nature. + Instance(_) => return true, + _ => return false, + }; + + let code = status.as_u16(); + // 408 Request Timeout, 425 Too Early, 429 Too Many Requests, + // 5xx Server errors (incl. 529 Site Overloaded used by Anthropic). + matches!(code, 408 | 425 | 429 | 500..=599) +} + #[cfg(test)] mod tests { use super::*; diff --git a/symbiont/src/validation.rs b/symbiont/src/validation.rs index 941ff11..a3caa65 100644 --- a/symbiont/src/validation.rs +++ b/symbiont/src/validation.rs @@ -224,6 +224,7 @@ pub fn step(counter: &mut usize) { } #[test] + #[tracing_test::traced_test] fn test_validate_with_return_type() { let input = "```rust #[unsafe(no_mangle)] diff --git a/symbiont/tests/runtime.rs b/symbiont/tests/runtime.rs new file mode 100644 index 0000000..03c6f0a --- /dev/null +++ b/symbiont/tests/runtime.rs @@ -0,0 +1,62 @@ +#![expect( + unused_crate_dependencies, + missing_docs, + reason = "Integration tests don't use them all" +)] + +use rig::completion::Prompt; +use symbiont::{ + Profile, + Runtime, +}; + +#[tokio::test] +#[tracing_test::traced_test] +async fn runtime() { + symbiont::evolvable! { + /// Should increment the counter by a value in the range 5..20 + fn step(counter: &mut usize) { + *counter += 1; + } + }; + let rt = Runtime::init(SYMBIONT_DECLS, Profile::Debug) + .await + .expect("Can init"); + assert_eq!( + &rt.current_code(), + "/// Should increment the counter by a value in the range 5..20\n#[unsafe(no_mangle)]\npub fn step(counter: &mut usize) {\n *counter += 1;\n}\n\n\n" + ); + let mut counter = 0; + step(&mut counter); + assert_eq!(counter, 1); + + let agent = MockAgent; + let prompt = format!("Implement this function in rust: ```{}```", rt.fn_sigs()[0]); + rt.evolve(&agent, &prompt).await.expect("Can evolve"); + assert_eq!( + &rt.current_code(), + "#[unsafe(no_mangle)]\npub fn step(counter: &mut usize) {\n *counter += 5;\n}\n", + "Code has evolved" + ); + step(&mut counter); + assert_eq!(counter, 6); +} + +struct MockAgent; + +impl Prompt for MockAgent { + fn prompt( + &self, + _prompt: impl Into + rig::wasm_compat::WasmCompatSend, + ) -> impl IntoFuture< + Output = Result, + IntoFuture: rig::wasm_compat::WasmCompatSend, + > { + async { + Ok("``` + pub fn step(counter: &mut usize) { *counter += 5; } + ```" + .to_string()) + } + } +}