From e4d3a0941c3b946db24769499f274c63803d775a Mon Sep 17 00:00:00 2001 From: lucasven Date: Sun, 22 Mar 2026 14:56:18 -0300 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat(memory):=20add=20post-conversa?= =?UTF-8?q?tion=20extraction=20and=20relevance=20decay?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Opus 4.6 --- Lis.Agent/AgentSetup.cs | 3 + Lis.Agent/ConversationService.cs | 13 + Lis.Agent/MemoryExtractionService.cs | 162 +++ Lis.Api/Program.cs | 4 + Lis.Persistence/Entities/MemoryEntity.cs | 13 +- ...00_add_memory_relevance_fields.Designer.cs | 1009 +++++++++++++++++ ...60322150000_add_memory_relevance_fields.cs | 41 + .../Migrations/LisDbContextModelSnapshot.cs | 12 + Lis.Tests/Agent/MemoryExtractionTests.cs | 295 +++++ Lis.Tools/MemoryPlugin.cs | 22 +- global.json | 3 +- 11 files changed, 1569 insertions(+), 8 deletions(-) create mode 100644 Lis.Agent/MemoryExtractionService.cs create mode 100644 Lis.Persistence/Migrations/20260322150000_add_memory_relevance_fields.Designer.cs create mode 100644 Lis.Persistence/Migrations/20260322150000_add_memory_relevance_fields.cs create mode 100644 Lis.Tests/Agent/MemoryExtractionTests.cs diff --git a/Lis.Agent/AgentSetup.cs b/Lis.Agent/AgentSetup.cs index 97782da..37dc62c 100644 --- a/Lis.Agent/AgentSetup.cs +++ b/Lis.Agent/AgentSetup.cs @@ -73,6 +73,9 @@ public static IServiceCollection AddLisAgent(this IServiceCollection services) { // Compaction services.AddSingleton(); + // Memory extraction + services.AddSingleton(); + return services; } } diff --git a/Lis.Agent/ConversationService.cs b/Lis.Agent/ConversationService.cs index 719c23b..f974fe3 100644 --- a/Lis.Agent/ConversationService.cs +++ b/Lis.Agent/ConversationService.cs @@ -31,6 +31,7 @@ public sealed class ConversationService( IMediaProcessor mediaProcessor, IApprovalService approvalService, ToolPolicyService toolPolicyService, + IMemoryExtractionService memoryExtraction, IOptions lisOptions, ILogger logger, ITokenCounter? tokenCounter = null) : IConversationService { @@ -297,6 +298,18 @@ public async Task RespondAsync(IncomingMessage message, CancellationToken ct) { await db.SaveChangesAsync(ct); await this.CheckCompactionTriggersAsync(db, session, agent, lastUsage, message.ChatId, ct); + + // Fire-and-forget memory extraction from conversation + List conversationForExtraction = recentMessages + .Select(m => $"{(m.IsFromMe ? "Assistant" : "User")}: {m.Body ?? "[media]"}") + .ToList(); + _ = Task.Run(async () => { + try { + await memoryExtraction.ExtractAsync(conversationForExtraction, CancellationToken.None); + } catch (Exception ex) { + logger.LogWarning(ex, "Memory extraction failed"); + } + }, CancellationToken.None); } } diff --git a/Lis.Agent/MemoryExtractionService.cs b/Lis.Agent/MemoryExtractionService.cs new file mode 100644 index 0000000..208d863 --- /dev/null +++ b/Lis.Agent/MemoryExtractionService.cs @@ -0,0 +1,162 @@ +using System.Text.Json; +using System.Text.RegularExpressions; + +using Lis.Core.Util; +using Lis.Persistence; +using Lis.Persistence.Entities; + +using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; + +using Pgvector; + +namespace Lis.Agent; + +public interface IMemoryExtractionService { + Task ExtractAsync(List conversationMessages, CancellationToken ct); +} + +public sealed class MemoryExtractionService( + [FromKeyedServices("extraction")] IChatClient extractionClient, + IServiceScopeFactory scopeFactory, + ILogger logger) : IMemoryExtractionService { + + private const int MaxMemoriesPerExtraction = 5; + + private const string ExtractionPrompt = """ + Analyze the following conversation and extract 0-5 factual memories worth remembering long-term. + Focus on: personal preferences, facts about people, decisions made, important dates, commitments. + Skip: greetings, small talk, transient information, tool call details. + + Return a JSON array of objects with: + - "content": the fact to remember (concise, standalone sentence) + - "contact_name": person's name this is about (optional, null if general) + + If nothing worth remembering, return an empty array: [] + + Conversation: + """; + + [Trace("MemoryExtractionService > ExtractAsync")] + public async Task ExtractAsync(List conversationMessages, CancellationToken ct) { + try { + string conversation = string.Join("\n", conversationMessages); + string prompt = ExtractionPrompt + conversation; + + ChatOptions options = new() { MaxOutputTokens = 512, Temperature = 0.1f }; + ChatResponse response = await extractionClient.GetResponseAsync( + [new ChatMessage(ChatRole.User, prompt)], options, ct); + + string? text = response.Text?.Trim(); + if (string.IsNullOrWhiteSpace(text)) return; + + // Strip markdown code fences if present + text = StripCodeFences(text); + + List? memories; + try { + memories = JsonSerializer.Deserialize>(text, JsonOpts); + } catch (JsonException ex) { + logger.LogWarning(ex, "Failed to parse extraction response: {Text}", text[..Math.Min(text.Length, 200)]); + return; + } + + if (memories is null || memories.Count == 0) return; + + // Cap at max + if (memories.Count > MaxMemoriesPerExtraction) + memories = memories.Take(MaxMemoriesPerExtraction).ToList(); + + // Filter out empty/null content + memories = memories.Where(m => !string.IsNullOrWhiteSpace(m.Content)).ToList(); + if (memories.Count == 0) return; + + using IServiceScope scope = scopeFactory.CreateScope(); + LisDbContext db = scope.ServiceProvider.GetRequiredService(); + IEmbeddingGenerator>? embeddingGen = + scope.ServiceProvider.GetService>>(); + + foreach (ExtractedMemory mem in memories) { + long? contactId = await ResolveOrCreateContactAsync(db, mem.ContactName); + Vector? embedding = await GenerateEmbeddingAsync(embeddingGen, mem.Content!); + + MemoryEntity entity = new() { + Content = mem.Content!.Trim(), + ContactId = contactId, + Embedding = embedding, + RelevanceScore = 1.0f, + CreatedAt = DateTimeOffset.UtcNow, + UpdatedAt = DateTimeOffset.UtcNow, + }; + + db.Memories.Add(entity); + } + + await db.SaveChangesAsync(ct); + + if (logger.IsEnabled(LogLevel.Information)) + logger.LogInformation("Extracted {Count} memories from conversation", memories.Count); + } catch (Exception ex) { + logger.LogWarning(ex, "Memory extraction failed"); + } + } + + /// + /// Calculates relevance score based on last access time. + /// Formula: max(0.1, 1.0 - (days_since_access / 30.0) * 0.5) + /// + public static float CalculateRelevanceScore(DateTimeOffset? lastAccessedAt) { + if (lastAccessedAt is null) return 1.0f; + + double daysSinceAccess = (DateTimeOffset.UtcNow - lastAccessedAt.Value).TotalDays; + float score = (float)(1.0 - daysSinceAccess / 30.0 * 0.5); + return Math.Max(0.1f, score); + } + + private static string StripCodeFences(string text) { + // Remove ```json ... ``` or ``` ... ``` + Match match = Regex.Match(text, @"```(?:json)?\s*([\s\S]*?)\s*```", RegexOptions.IgnoreCase); + return match.Success ? match.Groups[1].Value.Trim() : text; + } + + private static async Task ResolveOrCreateContactAsync(LisDbContext db, string? contactName) { + if (string.IsNullOrWhiteSpace(contactName)) return null; + + ContactEntity? contact = await db.Contacts + .FirstOrDefaultAsync(c => c.Name != null + && c.Name.Equals(contactName.Trim(), StringComparison.OrdinalIgnoreCase)); + + if (contact is not null) return contact.Id; + + contact = new ContactEntity { + Name = contactName.Trim(), + CreatedAt = DateTimeOffset.UtcNow, + UpdatedAt = DateTimeOffset.UtcNow, + }; + + db.Contacts.Add(contact); + await db.SaveChangesAsync(); + + return contact.Id; + } + + private static async Task GenerateEmbeddingAsync( + IEmbeddingGenerator>? embeddingGen, string content) { + if (embeddingGen is null) return null; + + GeneratedEmbeddings> result = await embeddingGen.GenerateAsync([content]); + return new Vector(result[0].Vector); + } + + private static readonly JsonSerializerOptions JsonOpts = new() { + PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower, + PropertyNameCaseInsensitive = true, + }; + + private sealed class ExtractedMemory { + public string? Content { get; set; } + public string? ContactName { get; set; } + } +} diff --git a/Lis.Api/Program.cs b/Lis.Api/Program.cs index a951b07..f1ad709 100644 --- a/Lis.Api/Program.cs +++ b/Lis.Api/Program.cs @@ -101,6 +101,10 @@ (sp, _) => sp.GetRequiredService()); } +// Extraction client (keyed IChatClient for memory extraction — reuses compaction client) +builder.Services.AddKeyedSingleton("extraction", + (sp, _) => sp.GetRequiredKeyedService("compaction")); + // Embedding (optional — enables vector search for memories) if (Env("MEMORIES_EMBEDDING_ENABLED") == "true") builder.Services.AddEmbedding(); diff --git a/Lis.Persistence/Entities/MemoryEntity.cs b/Lis.Persistence/Entities/MemoryEntity.cs index 2a31aa8..c460889 100644 --- a/Lis.Persistence/Entities/MemoryEntity.cs +++ b/Lis.Persistence/Entities/MemoryEntity.cs @@ -38,9 +38,17 @@ public sealed class MemoryEntity { [Column("updated_at")] [JsonPropertyName("updated_at")] public DateTimeOffset UpdatedAt { get; set; } + + [Column("last_accessed_at")] + [JsonPropertyName("last_accessed_at")] + public DateTimeOffset? LastAccessedAt { get; set; } + + [Column("relevance_score")] + [JsonPropertyName("relevance_score")] + public float RelevanceScore { get; set; } = 1.0f; } -public class MemoryEntityConfiguration :IEntityTypeConfiguration { +public class MemoryEntityConfiguration : IEntityTypeConfiguration { public void Configure(EntityTypeBuilder builder) { builder.HasIndex(e => e.ContactId); @@ -52,5 +60,8 @@ public void Configure(EntityTypeBuilder builder) { builder.HasIndex(e => e.Embedding) .HasMethod("hnsw") .HasOperators("vector_cosine_ops"); + + builder.Property(e => e.RelevanceScore) + .HasDefaultValue(1f); } } diff --git a/Lis.Persistence/Migrations/20260322150000_add_memory_relevance_fields.Designer.cs b/Lis.Persistence/Migrations/20260322150000_add_memory_relevance_fields.Designer.cs new file mode 100644 index 0000000..ac72702 --- /dev/null +++ b/Lis.Persistence/Migrations/20260322150000_add_memory_relevance_fields.Designer.cs @@ -0,0 +1,1009 @@ +// +using System; +using Lis.Persistence; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Migrations; +using Microsoft.EntityFrameworkCore.Storage.ValueConversion; +using Npgsql.EntityFrameworkCore.PostgreSQL.Metadata; +using Pgvector; + +#nullable disable + +namespace Lis.Persistence.Migrations +{ + [DbContext(typeof(LisDbContext))] + [Migration("20260322150000_add_memory_relevance_fields")] + partial class add_memory_relevance_fields + { + protected override void BuildTargetModel(ModelBuilder modelBuilder) + { +#pragma warning disable 612, 618 + modelBuilder + .HasAnnotation("ProductVersion", "10.0.3") + .HasAnnotation("Relational:MaxIdentifierLength", 63); + + NpgsqlModelBuilderExtensions.HasPostgresExtension(modelBuilder, "vector"); + NpgsqlModelBuilderExtensions.UseIdentityByDefaultColumns(modelBuilder); + + modelBuilder.Entity("Lis.Persistence.Entities.AgentEntity", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("bigint") + .HasColumnName("id") + .HasJsonPropertyName("id"); + + NpgsqlPropertyBuilderExtensions.UseIdentityByDefaultColumn(b.Property("Id")); + + b.Property("CompactionThreshold") + .HasColumnType("integer") + .HasColumnName("compaction_threshold") + .HasJsonPropertyName("compaction_threshold"); + + b.Property("ContextBudget") + .HasColumnType("integer") + .HasColumnName("context_budget") + .HasJsonPropertyName("context_budget"); + + b.Property("CreatedAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("created_at") + .HasJsonPropertyName("created_at"); + + b.Property("DisplayName") + .HasMaxLength(128) + .HasColumnType("varchar(128)") + .HasColumnName("display_name") + .HasJsonPropertyName("display_name"); + + b.Property("ExecSecurity") + .IsRequired() + .HasMaxLength(16) + .HasColumnType("varchar(16)") + .HasColumnName("exec_security") + .HasJsonPropertyName("exec_security"); + + b.Property("ExecTimeoutSeconds") + .HasColumnType("integer") + .HasColumnName("exec_timeout_seconds") + .HasJsonPropertyName("exec_timeout_seconds"); + + b.Property("GroupContextPrompt") + .HasColumnType("text") + .HasColumnName("group_context_prompt") + .HasJsonPropertyName("group_context_prompt"); + + b.Property("IsDefault") + .HasColumnType("boolean") + .HasColumnName("is_default") + .HasJsonPropertyName("is_default"); + + b.Property("KeepRecentTokens") + .HasColumnType("integer") + .HasColumnName("keep_recent_tokens") + .HasJsonPropertyName("keep_recent_tokens"); + + b.Property("MaxTokens") + .HasColumnType("integer") + .HasColumnName("max_tokens") + .HasJsonPropertyName("max_tokens"); + + b.Property("MentionTriggers") + .HasMaxLength(256) + .HasColumnType("varchar(256)") + .HasColumnName("mention_triggers") + .HasJsonPropertyName("mention_triggers"); + + b.Property("Model") + .IsRequired() + .HasMaxLength(128) + .HasColumnType("varchar(128)") + .HasColumnName("model") + .HasJsonPropertyName("model"); + + b.Property("Name") + .IsRequired() + .HasMaxLength(64) + .HasColumnType("varchar(64)") + .HasColumnName("name") + .HasJsonPropertyName("name"); + + b.Property("ThinkingEffort") + .HasMaxLength(16) + .HasColumnType("varchar(16)") + .HasColumnName("thinking_effort") + .HasJsonPropertyName("thinking_effort"); + + b.Property("ToolKeepThreshold") + .HasColumnType("integer") + .HasColumnName("tool_keep_threshold") + .HasJsonPropertyName("tool_keep_threshold"); + + b.Property("ToolNotifications") + .HasColumnType("boolean") + .HasColumnName("tool_notifications") + .HasJsonPropertyName("tool_notifications"); + + b.Property("ToolProfile") + .HasMaxLength(32) + .HasColumnType("varchar(32)") + .HasColumnName("tool_profile") + .HasJsonPropertyName("tool_profile"); + + b.Property("ToolPruneThreshold") + .HasColumnType("integer") + .HasColumnName("tool_prune_threshold") + .HasJsonPropertyName("tool_prune_threshold"); + + b.Property("ToolSummarizationPolicy") + .HasMaxLength(16) + .HasColumnType("varchar(16)") + .HasColumnName("tool_summarization_policy") + .HasJsonPropertyName("tool_summarization_policy"); + + b.Property("ToolsAllow") + .HasColumnType("text") + .HasColumnName("tools_allow") + .HasJsonPropertyName("tools_allow"); + + b.Property("ToolsDeny") + .HasColumnType("text") + .HasColumnName("tools_deny") + .HasJsonPropertyName("tools_deny"); + + b.Property("UpdatedAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("updated_at") + .HasJsonPropertyName("updated_at"); + + b.Property("WorkspacePath") + .HasColumnType("text") + .HasColumnName("workspace_path") + .HasJsonPropertyName("workspace_path"); + + b.HasKey("Id"); + + b.HasIndex("Name") + .IsUnique(); + + b.ToTable("agent"); + }); + + modelBuilder.Entity("Lis.Persistence.Entities.ChatAllowedSenderEntity", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("bigint") + .HasColumnName("id") + .HasJsonPropertyName("id"); + + NpgsqlPropertyBuilderExtensions.UseIdentityByDefaultColumn(b.Property("Id")); + + b.Property("ChatId") + .HasColumnType("bigint") + .HasColumnName("chat_id") + .HasJsonPropertyName("chat_id"); + + b.Property("SenderId") + .IsRequired() + .HasMaxLength(64) + .HasColumnType("varchar(64)") + .HasColumnName("sender_id") + .HasJsonPropertyName("sender_id"); + + b.HasKey("Id"); + + b.HasIndex("ChatId"); + + b.HasIndex("ChatId", "SenderId") + .IsUnique(); + + b.ToTable("chat_allowed_sender"); + }); + + modelBuilder.Entity("Lis.Persistence.Entities.ChatEntity", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("bigint") + .HasColumnName("id") + .HasJsonPropertyName("id"); + + NpgsqlPropertyBuilderExtensions.UseIdentityByDefaultColumn(b.Property("Id")); + + b.Property("AgentId") + .HasColumnType("bigint") + .HasColumnName("agent_id") + .HasJsonPropertyName("agent_id"); + + b.Property("CreatedAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("created_at") + .HasJsonPropertyName("created_at"); + + b.Property("CurrentSessionId") + .HasColumnType("bigint") + .HasColumnName("current_session_id") + .HasJsonPropertyName("current_session_id"); + + b.Property("DebounceMs") + .HasColumnType("integer") + .HasColumnName("debounce_ms") + .HasJsonPropertyName("debounce_ms"); + + b.Property("Enabled") + .HasColumnType("boolean") + .HasColumnName("enabled") + .HasJsonPropertyName("enabled"); + + b.Property("ExternalId") + .IsRequired() + .HasMaxLength(64) + .HasColumnType("varchar(64)") + .HasColumnName("external_id") + .HasJsonPropertyName("external_id"); + + b.Property("GroupContextMessages") + .HasColumnType("integer") + .HasColumnName("group_context_messages") + .HasJsonPropertyName("group_context_messages"); + + b.Property("GroupTopic") + .HasMaxLength(512) + .HasColumnType("varchar(512)") + .HasColumnName("group_topic") + .HasJsonPropertyName("group_topic"); + + b.Property("IsGroup") + .HasColumnType("boolean") + .HasColumnName("is_group") + .HasJsonPropertyName("is_group"); + + b.Property("Name") + .HasMaxLength(256) + .HasColumnType("varchar(256)") + .HasColumnName("name") + .HasJsonPropertyName("name"); + + b.Property("OpenGroup") + .HasColumnType("boolean") + .HasColumnName("open_group") + .HasJsonPropertyName("open_group"); + + b.Property("RequireMention") + .HasColumnType("boolean") + .HasColumnName("require_mention") + .HasJsonPropertyName("require_mention"); + + b.Property("UpdatedAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("updated_at") + .HasJsonPropertyName("updated_at"); + + b.HasKey("Id"); + + b.HasIndex("AgentId"); + + b.HasIndex("CurrentSessionId"); + + b.HasIndex("ExternalId") + .IsUnique(); + + b.ToTable("chat"); + }); + + modelBuilder.Entity("Lis.Persistence.Entities.ContactEntity", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("bigint") + .HasColumnName("id") + .HasJsonPropertyName("id"); + + NpgsqlPropertyBuilderExtensions.UseIdentityByDefaultColumn(b.Property("Id")); + + b.Property("CreatedAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("created_at") + .HasJsonPropertyName("created_at"); + + b.Property("Name") + .HasMaxLength(256) + .HasColumnType("varchar(256)") + .HasColumnName("name") + .HasJsonPropertyName("name"); + + b.Property("UpdatedAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("updated_at") + .HasJsonPropertyName("updated_at"); + + b.HasKey("Id"); + + b.ToTable("contact"); + }); + + modelBuilder.Entity("Lis.Persistence.Entities.ContactIdentifierEntity", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("bigint") + .HasColumnName("id") + .HasJsonPropertyName("id"); + + NpgsqlPropertyBuilderExtensions.UseIdentityByDefaultColumn(b.Property("Id")); + + b.Property("Channel") + .IsRequired() + .HasMaxLength(32) + .HasColumnType("varchar(32)") + .HasColumnName("channel") + .HasJsonPropertyName("channel"); + + b.Property("ContactId") + .HasColumnType("bigint") + .HasColumnName("contact_id") + .HasJsonPropertyName("contact_id"); + + b.Property("ExternalId") + .IsRequired() + .HasMaxLength(64) + .HasColumnType("varchar(64)") + .HasColumnName("external_id") + .HasJsonPropertyName("external_id"); + + b.HasKey("Id"); + + b.HasIndex("ContactId"); + + b.HasIndex("Channel", "ExternalId") + .IsUnique(); + + b.ToTable("contact_identifier"); + }); + + modelBuilder.Entity("Lis.Persistence.Entities.ExecAllowlistEntity", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("bigint") + .HasColumnName("id") + .HasJsonPropertyName("id"); + + NpgsqlPropertyBuilderExtensions.UseIdentityByDefaultColumn(b.Property("Id")); + + b.Property("AgentId") + .HasColumnType("bigint") + .HasColumnName("agent_id") + .HasJsonPropertyName("agent_id"); + + b.Property("CreatedAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("created_at") + .HasJsonPropertyName("created_at"); + + b.Property("LastCommand") + .HasColumnType("text") + .HasColumnName("last_command") + .HasJsonPropertyName("last_command"); + + b.Property("LastUsedAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("last_used_at") + .HasJsonPropertyName("last_used_at"); + + b.Property("Pattern") + .IsRequired() + .HasColumnType("text") + .HasColumnName("pattern") + .HasJsonPropertyName("pattern"); + + b.HasKey("Id"); + + b.HasIndex("AgentId", "Pattern") + .IsUnique(); + + b.ToTable("exec_allowlist"); + }); + + modelBuilder.Entity("Lis.Persistence.Entities.ExecApprovalEntity", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("bigint") + .HasColumnName("id") + .HasJsonPropertyName("id"); + + NpgsqlPropertyBuilderExtensions.UseIdentityByDefaultColumn(b.Property("Id")); + + b.Property("AgentId") + .HasColumnType("bigint") + .HasColumnName("agent_id") + .HasJsonPropertyName("agent_id"); + + b.Property("ApprovalId") + .IsRequired() + .HasMaxLength(16) + .HasColumnType("varchar(16)") + .HasColumnName("approval_id") + .HasJsonPropertyName("approval_id"); + + b.Property("ChatId") + .HasColumnType("bigint") + .HasColumnName("chat_id") + .HasJsonPropertyName("chat_id"); + + b.Property("Command") + .IsRequired() + .HasColumnType("text") + .HasColumnName("command") + .HasJsonPropertyName("command"); + + b.Property("CreatedAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("created_at") + .HasJsonPropertyName("created_at"); + + b.Property("Cwd") + .HasColumnType("text") + .HasColumnName("cwd") + .HasJsonPropertyName("cwd"); + + b.Property("Decision") + .HasMaxLength(16) + .HasColumnType("varchar(16)") + .HasColumnName("decision") + .HasJsonPropertyName("decision"); + + b.Property("ExpiresAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("expires_at") + .HasJsonPropertyName("expires_at"); + + b.Property("MessageExternalId") + .HasMaxLength(128) + .HasColumnType("varchar(128)") + .HasColumnName("message_external_id") + .HasJsonPropertyName("message_external_id"); + + b.Property("ResolvedAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("resolved_at") + .HasJsonPropertyName("resolved_at"); + + b.Property("ResolvedBy") + .HasMaxLength(64) + .HasColumnType("varchar(64)") + .HasColumnName("resolved_by") + .HasJsonPropertyName("resolved_by"); + + b.Property("Status") + .IsRequired() + .HasMaxLength(16) + .HasColumnType("varchar(16)") + .HasColumnName("status") + .HasJsonPropertyName("status"); + + b.HasKey("Id"); + + b.HasIndex("AgentId"); + + b.HasIndex("ApprovalId") + .IsUnique(); + + b.HasIndex("ChatId"); + + b.HasIndex("MessageExternalId"); + + b.HasIndex("Status") + .HasFilter("status = 'pending'"); + + b.ToTable("exec_approval"); + }); + + modelBuilder.Entity("Lis.Persistence.Entities.MemoryEntity", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("bigint") + .HasColumnName("id") + .HasJsonPropertyName("id"); + + NpgsqlPropertyBuilderExtensions.UseIdentityByDefaultColumn(b.Property("Id")); + + b.Property("ContactId") + .HasColumnType("bigint") + .HasColumnName("contact_id") + .HasJsonPropertyName("contact_id"); + + b.Property("Content") + .IsRequired() + .HasColumnType("text") + .HasColumnName("content") + .HasJsonPropertyName("content"); + + b.Property("CreatedAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("created_at") + .HasJsonPropertyName("created_at"); + + b.Property("Embedding") + .HasColumnType("vector(1536)") + .HasColumnName("embedding"); + + b.Property("LastAccessedAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("last_accessed_at") + .HasJsonPropertyName("last_accessed_at"); + + b.Property("RelevanceScore") + .ValueGeneratedOnAdd() + .HasColumnType("real") + .HasDefaultValue(1f) + .HasColumnName("relevance_score") + .HasJsonPropertyName("relevance_score"); + + b.Property("UpdatedAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("updated_at") + .HasJsonPropertyName("updated_at"); + + b.HasKey("Id"); + + b.HasIndex("ContactId"); + + b.HasIndex("Embedding"); + + NpgsqlIndexBuilderExtensions.HasMethod(b.HasIndex("Embedding"), "hnsw"); + NpgsqlIndexBuilderExtensions.HasOperators(b.HasIndex("Embedding"), new[] { "vector_cosine_ops" }); + + b.ToTable("memory"); + }); + + modelBuilder.Entity("Lis.Persistence.Entities.MessageEntity", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("bigint") + .HasColumnName("id") + .HasJsonPropertyName("id"); + + NpgsqlPropertyBuilderExtensions.UseIdentityByDefaultColumn(b.Property("Id")); + + b.Property("Body") + .HasColumnType("text") + .HasColumnName("body") + .HasJsonPropertyName("body"); + + b.Property("CacheCreationTokens") + .HasColumnType("integer") + .HasColumnName("cache_creation_tokens") + .HasJsonPropertyName("cache_creation_tokens"); + + b.Property("CacheReadTokens") + .HasColumnType("integer") + .HasColumnName("cache_read_tokens") + .HasJsonPropertyName("cache_read_tokens"); + + b.Property("ChatId") + .HasColumnType("bigint") + .HasColumnName("chat_id") + .HasJsonPropertyName("chat_id"); + + b.Property("CreatedAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("created_at") + .HasJsonPropertyName("created_at"); + + b.Property("ExternalId") + .HasMaxLength(128) + .HasColumnType("varchar(128)") + .HasColumnName("external_id") + .HasJsonPropertyName("external_id"); + + b.Property("InputTokens") + .HasColumnType("integer") + .HasColumnName("input_tokens") + .HasJsonPropertyName("input_tokens"); + + b.Property("IsFromMe") + .HasColumnType("boolean") + .HasColumnName("is_from_me") + .HasJsonPropertyName("is_from_me"); + + b.Property("MediaCaption") + .HasColumnType("text") + .HasColumnName("media_caption") + .HasJsonPropertyName("media_caption"); + + b.Property("MediaData") + .HasColumnType("bytea") + .HasColumnName("media_data"); + + b.Property("MediaMimeType") + .HasMaxLength(64) + .HasColumnType("varchar(64)") + .HasColumnName("media_mime_type") + .HasJsonPropertyName("media_mime_type"); + + b.Property("MediaType") + .HasMaxLength(32) + .HasColumnType("varchar(32)") + .HasColumnName("media_type") + .HasJsonPropertyName("media_type"); + + b.Property("OutputTokens") + .HasColumnType("integer") + .HasColumnName("output_tokens") + .HasJsonPropertyName("output_tokens"); + + b.Property("Queued") + .HasColumnType("boolean") + .HasColumnName("queued") + .HasJsonPropertyName("queued"); + + b.Property("ReplyContent") + .HasColumnType("text") + .HasColumnName("reply_content") + .HasJsonPropertyName("reply_content"); + + b.Property("ReplyToId") + .HasMaxLength(128) + .HasColumnType("varchar(128)") + .HasColumnName("reply_to_id") + .HasJsonPropertyName("reply_to_id"); + + b.Property("Role") + .HasMaxLength(16) + .HasColumnType("varchar(16)") + .HasColumnName("role") + .HasJsonPropertyName("role"); + + b.Property("SenderId") + .IsRequired() + .HasMaxLength(64) + .HasColumnType("varchar(64)") + .HasColumnName("sender_id") + .HasJsonPropertyName("sender_id"); + + b.Property("SenderName") + .HasMaxLength(256) + .HasColumnType("varchar(256)") + .HasColumnName("sender_name") + .HasJsonPropertyName("sender_name"); + + b.Property("SessionId") + .HasColumnType("bigint") + .HasColumnName("session_id") + .HasJsonPropertyName("session_id"); + + b.Property("SkContent") + .HasColumnType("jsonb") + .HasColumnName("sk_content") + .HasJsonPropertyName("sk_content"); + + b.Property("ThinkingTokens") + .HasColumnType("integer") + .HasColumnName("thinking_tokens") + .HasJsonPropertyName("thinking_tokens"); + + b.Property("Timestamp") + .HasColumnType("timestamp with time zone") + .HasColumnName("timestamp") + .HasJsonPropertyName("timestamp"); + + b.HasKey("Id"); + + b.HasIndex("ExternalId") + .IsUnique() + .HasFilter("\"external_id\" IS NOT NULL"); + + b.HasIndex("ChatId", "Timestamp"); + + b.HasIndex("SenderId", "Id"); + + b.HasIndex("SenderName", "Id"); + + b.HasIndex("SessionId", "Timestamp"); + + b.ToTable("message"); + }); + + modelBuilder.Entity("Lis.Persistence.Entities.PromptSectionEntity", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("bigint") + .HasColumnName("id") + .HasJsonPropertyName("id"); + + NpgsqlPropertyBuilderExtensions.UseIdentityByDefaultColumn(b.Property("Id")); + + b.Property("AgentId") + .HasColumnType("bigint") + .HasColumnName("agent_id") + .HasJsonPropertyName("agent_id"); + + b.Property("Content") + .IsRequired() + .HasColumnType("text") + .HasColumnName("content") + .HasJsonPropertyName("content"); + + b.Property("CreatedAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("created_at") + .HasJsonPropertyName("created_at"); + + b.Property("IsEnabled") + .HasColumnType("boolean") + .HasColumnName("is_enabled") + .HasJsonPropertyName("is_enabled"); + + b.Property("Name") + .IsRequired() + .HasMaxLength(50) + .HasColumnType("varchar(50)") + .HasColumnName("name") + .HasJsonPropertyName("name"); + + b.Property("SortOrder") + .HasColumnType("integer") + .HasColumnName("sort_order") + .HasJsonPropertyName("sort_order"); + + b.Property("UpdatedAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("updated_at") + .HasJsonPropertyName("updated_at"); + + b.HasKey("Id"); + + b.HasIndex("AgentId", "Name") + .IsUnique(); + + b.ToTable("prompt_section"); + }); + + modelBuilder.Entity("Lis.Persistence.Entities.SessionEntity", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("bigint") + .HasColumnName("id") + .HasJsonPropertyName("id"); + + NpgsqlPropertyBuilderExtensions.UseIdentityByDefaultColumn(b.Property("Id")); + + b.Property("ChatId") + .HasColumnType("bigint") + .HasColumnName("chat_id") + .HasJsonPropertyName("chat_id"); + + b.Property("ContextTokens") + .HasColumnType("bigint") + .HasColumnName("context_tokens") + .HasJsonPropertyName("context_tokens"); + + b.Property("CreatedAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("created_at") + .HasJsonPropertyName("created_at"); + + b.Property("IsCompacting") + .HasColumnType("boolean") + .HasColumnName("is_compacting") + .HasJsonPropertyName("is_compacting"); + + b.Property("ParentSessionId") + .HasColumnType("bigint") + .HasColumnName("parent_session_id") + .HasJsonPropertyName("parent_session_id"); + + b.Property("Summary") + .HasColumnType("text") + .HasColumnName("summary") + .HasJsonPropertyName("summary"); + + b.Property("SummaryEmbedding") + .HasColumnType("vector(1536)") + .HasColumnName("summary_embedding") + .HasJsonPropertyName("summary_embedding"); + + b.Property("ToolsPrunedThroughId") + .HasColumnType("bigint") + .HasColumnName("tools_pruned_through_id") + .HasJsonPropertyName("tools_pruned_through_id"); + + b.Property("TotalCacheCreationTokens") + .HasColumnType("bigint") + .HasColumnName("total_cache_creation_tokens") + .HasJsonPropertyName("total_cache_creation_tokens"); + + b.Property("TotalCacheReadTokens") + .HasColumnType("bigint") + .HasColumnName("total_cache_read_tokens") + .HasJsonPropertyName("total_cache_read_tokens"); + + b.Property("TotalInputTokens") + .HasColumnType("bigint") + .HasColumnName("total_input_tokens") + .HasJsonPropertyName("total_input_tokens"); + + b.Property("TotalOutputTokens") + .HasColumnType("bigint") + .HasColumnName("total_output_tokens") + .HasJsonPropertyName("total_output_tokens"); + + b.Property("TotalThinkingTokens") + .HasColumnType("bigint") + .HasColumnName("total_thinking_tokens") + .HasJsonPropertyName("total_thinking_tokens"); + + b.Property("UpdatedAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("updated_at") + .HasJsonPropertyName("updated_at"); + + b.HasKey("Id"); + + b.HasIndex("ChatId"); + + b.HasIndex("ParentSessionId"); + + b.HasIndex("SummaryEmbedding"); + + NpgsqlIndexBuilderExtensions.HasMethod(b.HasIndex("SummaryEmbedding"), "hnsw"); + NpgsqlIndexBuilderExtensions.HasOperators(b.HasIndex("SummaryEmbedding"), new[] { "vector_cosine_ops" }); + + b.ToTable("session"); + }); + + modelBuilder.Entity("Lis.Persistence.Entities.ChatAllowedSenderEntity", b => + { + b.HasOne("Lis.Persistence.Entities.ChatEntity", "Chat") + .WithMany("AllowedSenders") + .HasForeignKey("ChatId") + .OnDelete(DeleteBehavior.Cascade) + .IsRequired(); + + b.Navigation("Chat"); + }); + + modelBuilder.Entity("Lis.Persistence.Entities.ChatEntity", b => + { + b.HasOne("Lis.Persistence.Entities.AgentEntity", "Agent") + .WithMany() + .HasForeignKey("AgentId") + .OnDelete(DeleteBehavior.SetNull); + + b.HasOne("Lis.Persistence.Entities.SessionEntity", "CurrentSession") + .WithMany() + .HasForeignKey("CurrentSessionId") + .OnDelete(DeleteBehavior.SetNull); + + b.Navigation("Agent"); + + b.Navigation("CurrentSession"); + }); + + modelBuilder.Entity("Lis.Persistence.Entities.ContactIdentifierEntity", b => + { + b.HasOne("Lis.Persistence.Entities.ContactEntity", "Contact") + .WithMany("Identifiers") + .HasForeignKey("ContactId") + .OnDelete(DeleteBehavior.Cascade) + .IsRequired(); + + b.Navigation("Contact"); + }); + + modelBuilder.Entity("Lis.Persistence.Entities.ExecAllowlistEntity", b => + { + b.HasOne("Lis.Persistence.Entities.AgentEntity", "Agent") + .WithMany() + .HasForeignKey("AgentId") + .OnDelete(DeleteBehavior.SetNull); + + b.Navigation("Agent"); + }); + + modelBuilder.Entity("Lis.Persistence.Entities.ExecApprovalEntity", b => + { + b.HasOne("Lis.Persistence.Entities.AgentEntity", "Agent") + .WithMany() + .HasForeignKey("AgentId") + .OnDelete(DeleteBehavior.SetNull); + + b.HasOne("Lis.Persistence.Entities.ChatEntity", "Chat") + .WithMany() + .HasForeignKey("ChatId") + .OnDelete(DeleteBehavior.Cascade) + .IsRequired(); + + b.Navigation("Agent"); + + b.Navigation("Chat"); + }); + + modelBuilder.Entity("Lis.Persistence.Entities.MemoryEntity", b => + { + b.HasOne("Lis.Persistence.Entities.ContactEntity", "Contact") + .WithMany("Memories") + .HasForeignKey("ContactId") + .OnDelete(DeleteBehavior.SetNull); + + b.Navigation("Contact"); + }); + + modelBuilder.Entity("Lis.Persistence.Entities.MessageEntity", b => + { + b.HasOne("Lis.Persistence.Entities.ChatEntity", "Chat") + .WithMany("Messages") + .HasForeignKey("ChatId") + .OnDelete(DeleteBehavior.Cascade) + .IsRequired(); + + b.HasOne("Lis.Persistence.Entities.SessionEntity", "Session") + .WithMany() + .HasForeignKey("SessionId") + .OnDelete(DeleteBehavior.Cascade) + .IsRequired(); + + b.Navigation("Chat"); + + b.Navigation("Session"); + }); + + modelBuilder.Entity("Lis.Persistence.Entities.PromptSectionEntity", b => + { + b.HasOne("Lis.Persistence.Entities.AgentEntity", "Agent") + .WithMany("PromptSections") + .HasForeignKey("AgentId") + .OnDelete(DeleteBehavior.Cascade) + .IsRequired(); + + b.Navigation("Agent"); + }); + + modelBuilder.Entity("Lis.Persistence.Entities.SessionEntity", b => + { + b.HasOne("Lis.Persistence.Entities.ChatEntity", "Chat") + .WithMany() + .HasForeignKey("ChatId") + .OnDelete(DeleteBehavior.Cascade) + .IsRequired(); + + b.HasOne("Lis.Persistence.Entities.SessionEntity", "ParentSession") + .WithMany() + .HasForeignKey("ParentSessionId") + .OnDelete(DeleteBehavior.SetNull); + + b.Navigation("Chat"); + + b.Navigation("ParentSession"); + }); + + modelBuilder.Entity("Lis.Persistence.Entities.AgentEntity", b => + { + b.Navigation("PromptSections"); + }); + + modelBuilder.Entity("Lis.Persistence.Entities.ChatEntity", b => + { + b.Navigation("AllowedSenders"); + + b.Navigation("Messages"); + }); + + modelBuilder.Entity("Lis.Persistence.Entities.ContactEntity", b => + { + b.Navigation("Identifiers"); + + b.Navigation("Memories"); + }); +#pragma warning restore 612, 618 + } + } +} diff --git a/Lis.Persistence/Migrations/20260322150000_add_memory_relevance_fields.cs b/Lis.Persistence/Migrations/20260322150000_add_memory_relevance_fields.cs new file mode 100644 index 0000000..77e8d18 --- /dev/null +++ b/Lis.Persistence/Migrations/20260322150000_add_memory_relevance_fields.cs @@ -0,0 +1,41 @@ +using System; + +using Microsoft.EntityFrameworkCore.Migrations; + +#nullable disable + +namespace Lis.Persistence.Migrations +{ + /// + public partial class add_memory_relevance_fields : Migration + { + /// + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.AddColumn( + name: "last_accessed_at", + table: "memory", + type: "timestamp with time zone", + nullable: true); + + migrationBuilder.AddColumn( + name: "relevance_score", + table: "memory", + type: "real", + nullable: false, + defaultValue: 1f); + } + + /// + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropColumn( + name: "last_accessed_at", + table: "memory"); + + migrationBuilder.DropColumn( + name: "relevance_score", + table: "memory"); + } + } +} diff --git a/Lis.Persistence/Migrations/LisDbContextModelSnapshot.cs b/Lis.Persistence/Migrations/LisDbContextModelSnapshot.cs index de02528..4dc0a00 100644 --- a/Lis.Persistence/Migrations/LisDbContextModelSnapshot.cs +++ b/Lis.Persistence/Migrations/LisDbContextModelSnapshot.cs @@ -530,6 +530,18 @@ protected override void BuildModel(ModelBuilder modelBuilder) .HasColumnType("vector(1536)") .HasColumnName("embedding"); + b.Property("LastAccessedAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("last_accessed_at") + .HasJsonPropertyName("last_accessed_at"); + + b.Property("RelevanceScore") + .ValueGeneratedOnAdd() + .HasColumnType("real") + .HasDefaultValue(1f) + .HasColumnName("relevance_score") + .HasJsonPropertyName("relevance_score"); + b.Property("UpdatedAt") .HasColumnType("timestamp with time zone") .HasColumnName("updated_at") diff --git a/Lis.Tests/Agent/MemoryExtractionTests.cs b/Lis.Tests/Agent/MemoryExtractionTests.cs new file mode 100644 index 0000000..440a6e6 --- /dev/null +++ b/Lis.Tests/Agent/MemoryExtractionTests.cs @@ -0,0 +1,295 @@ +using Lis.Agent; +using Lis.Persistence; +using Lis.Persistence.Entities; + +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Diagnostics; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging.Abstractions; + +using Moq; + +namespace Lis.Tests.Agent; + +public class MemoryExtractionTests : IDisposable { + private readonly LisDbContext _db; + private readonly Mock _chatClientMock; + + public MemoryExtractionTests() { + DbContextOptions options = new DbContextOptionsBuilder() + .UseInMemoryDatabase(databaseName: Guid.NewGuid().ToString()) + .ConfigureWarnings(w => w.Ignore(InMemoryEventId.TransactionIgnoredWarning)) + .Options; + this._db = new TestDbContext(options); + this._chatClientMock = new Mock(); + } + + /// Ignores pgvector-specific model config that InMemory doesn't support. + private sealed class TestDbContext(DbContextOptions options) : LisDbContext(options) { + protected override void OnModelCreating(ModelBuilder modelBuilder) { + base.OnModelCreating(modelBuilder); + modelBuilder.Entity().Ignore(e => e.Embedding); + modelBuilder.Entity().Ignore(e => e.SummaryEmbedding); + } + } + + public void Dispose() { + this._db.Dispose(); + GC.SuppressFinalize(this); + } + + private MemoryExtractionService CreateService(IChatClient? chatClient = null) { + ServiceCollection services = new(); + services.AddSingleton(this._db); + services.AddScoped(sp => sp.GetRequiredService()); + ServiceProvider sp = services.BuildServiceProvider(); + + Mock scopeFactoryMock = new(); + Mock scopeMock = new(); + Mock spMock = new(); + spMock.Setup(x => x.GetService(typeof(LisDbContext))).Returns(this._db); + spMock.Setup(x => x.GetService(typeof(IEmbeddingGenerator>))).Returns((object?)null); + scopeMock.Setup(x => x.ServiceProvider).Returns(spMock.Object); + scopeFactoryMock.Setup(x => x.CreateScope()).Returns(scopeMock.Object); + + return new MemoryExtractionService( + chatClient ?? this._chatClientMock.Object, + scopeFactoryMock.Object, + NullLogger.Instance + ); + } + + private void SetupChatClientResponse(string responseText) { + ChatResponse response = new(new ChatMessage(ChatRole.Assistant, responseText)); + this._chatClientMock + .Setup(c => c.GetResponseAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())) + .ReturnsAsync(response); + } + + // ── Extraction tests ──────────────────────────────────────────── + + [Fact] + public async Task ExtractAsync_ValidJsonArray_CreatesMemories() { + this.SetupChatClientResponse("""[{"content": "Lucas likes coffee", "contact_name": "Lucas"}]"""); + + MemoryExtractionService sut = this.CreateService(); + List messages = ["User: I really love coffee", "Assistant: Noted!"]; + + await sut.ExtractAsync(messages, CancellationToken.None); + + List memories = await this._db.Memories.ToListAsync(); + Assert.Single(memories); + Assert.Equal("Lucas likes coffee", memories[0].Content); + Assert.Equal(1.0f, memories[0].RelevanceScore); + Assert.NotNull(memories[0].ContactId); + } + + [Fact] + public async Task ExtractAsync_MultipleMemories_CreatesAll() { + this.SetupChatClientResponse(""" + [ + {"content": "Lucas works at Acme"}, + {"content": "Meeting scheduled for Friday", "contact_name": "Bob"} + ] + """); + + MemoryExtractionService sut = this.CreateService(); + List messages = ["User: I work at Acme, meeting with Bob on Friday"]; + + await sut.ExtractAsync(messages, CancellationToken.None); + + List memories = await this._db.Memories.ToListAsync(); + Assert.Equal(2, memories.Count); + } + + [Fact] + public async Task ExtractAsync_EmptyArray_CreatesNoMemories() { + this.SetupChatClientResponse("[]"); + + MemoryExtractionService sut = this.CreateService(); + List messages = ["User: hello", "Assistant: hi"]; + + await sut.ExtractAsync(messages, CancellationToken.None); + + List memories = await this._db.Memories.ToListAsync(); + Assert.Empty(memories); + } + + [Fact] + public async Task ExtractAsync_MalformedJson_DoesNotCrash() { + this.SetupChatClientResponse("this is not json at all"); + + MemoryExtractionService sut = this.CreateService(); + List messages = ["User: whatever"]; + + await sut.ExtractAsync(messages, CancellationToken.None); + + List memories = await this._db.Memories.ToListAsync(); + Assert.Empty(memories); + } + + [Fact] + public async Task ExtractAsync_JsonInMarkdownFences_ParsesCorrectly() { + this.SetupChatClientResponse(""" + ```json + [{"content": "Prefers dark mode"}] + ``` + """); + + MemoryExtractionService sut = this.CreateService(); + List messages = ["User: I prefer dark mode"]; + + await sut.ExtractAsync(messages, CancellationToken.None); + + List memories = await this._db.Memories.ToListAsync(); + Assert.Single(memories); + Assert.Equal("Prefers dark mode", memories[0].Content); + } + + [Fact] + public async Task ExtractAsync_NullContent_SkipsEntry() { + this.SetupChatClientResponse("""[{"content": null}, {"content": "Valid memory"}]"""); + + MemoryExtractionService sut = this.CreateService(); + List messages = ["User: test"]; + + await sut.ExtractAsync(messages, CancellationToken.None); + + List memories = await this._db.Memories.ToListAsync(); + Assert.Single(memories); + Assert.Equal("Valid memory", memories[0].Content); + } + + [Fact] + public async Task ExtractAsync_EmptyContent_SkipsEntry() { + this.SetupChatClientResponse("""[{"content": ""}, {"content": " "}, {"content": "Valid"}]"""); + + MemoryExtractionService sut = this.CreateService(); + List messages = ["User: test"]; + + await sut.ExtractAsync(messages, CancellationToken.None); + + List memories = await this._db.Memories.ToListAsync(); + Assert.Single(memories); + Assert.Equal("Valid", memories[0].Content); + } + + [Fact] + public async Task ExtractAsync_LlmThrows_DoesNotCrash() { + this._chatClientMock + .Setup(c => c.GetResponseAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())) + .ThrowsAsync(new HttpRequestException("API down")); + + MemoryExtractionService sut = this.CreateService(); + List messages = ["User: test"]; + + // Should not throw + await sut.ExtractAsync(messages, CancellationToken.None); + + List memories = await this._db.Memories.ToListAsync(); + Assert.Empty(memories); + } + + [Fact] + public async Task ExtractAsync_ContactCreatedWhenMissing() { + this.SetupChatClientResponse("""[{"content": "Ana likes tea", "contact_name": "Ana"}]"""); + + MemoryExtractionService sut = this.CreateService(); + List messages = ["User: Ana told me she likes tea"]; + + await sut.ExtractAsync(messages, CancellationToken.None); + + ContactEntity? contact = await this._db.Contacts.FirstOrDefaultAsync(c => c.Name == "Ana"); + Assert.NotNull(contact); + + MemoryEntity memory = await this._db.Memories.Include(m => m.Contact).FirstAsync(); + Assert.Equal("Ana", memory.Contact!.Name); + } + + [Fact] + public async Task ExtractAsync_ExistingContactReused() { + this._db.Contacts.Add(new ContactEntity { + Name = "Bob", + CreatedAt = DateTimeOffset.UtcNow, + UpdatedAt = DateTimeOffset.UtcNow + }); + await this._db.SaveChangesAsync(); + + this.SetupChatClientResponse("""[{"content": "Bob runs daily", "contact_name": "Bob"}]"""); + + MemoryExtractionService sut = this.CreateService(); + List messages = ["User: Bob told me he runs every day"]; + + await sut.ExtractAsync(messages, CancellationToken.None); + + Assert.Equal(1, await this._db.Contacts.CountAsync()); + MemoryEntity memory = await this._db.Memories.FirstAsync(); + Assert.NotNull(memory.ContactId); + } + + [Fact] + public async Task ExtractAsync_MoreThan5_CapsAt5() { + string json = """ + [ + {"content": "Fact 1"}, + {"content": "Fact 2"}, + {"content": "Fact 3"}, + {"content": "Fact 4"}, + {"content": "Fact 5"}, + {"content": "Fact 6"}, + {"content": "Fact 7"} + ] + """; + this.SetupChatClientResponse(json); + + MemoryExtractionService sut = this.CreateService(); + List messages = ["User: lots of facts"]; + + await sut.ExtractAsync(messages, CancellationToken.None); + + List memories = await this._db.Memories.ToListAsync(); + Assert.Equal(5, memories.Count); + } + + // ── Relevance decay tests ─────────────────────────────────────── + + [Fact] + public void CalculateRelevanceScore_JustAccessed_Returns1() { + float score = MemoryExtractionService.CalculateRelevanceScore(DateTimeOffset.UtcNow); + Assert.Equal(1.0f, score); + } + + [Fact] + public void CalculateRelevanceScore_15DaysAgo_ReturnsHalf() { + DateTimeOffset accessed = DateTimeOffset.UtcNow.AddDays(-15); + float score = MemoryExtractionService.CalculateRelevanceScore(accessed); + Assert.Equal(0.75f, score, 0.01f); + } + + [Fact] + public void CalculateRelevanceScore_30DaysAgo_Returns05() { + DateTimeOffset accessed = DateTimeOffset.UtcNow.AddDays(-30); + float score = MemoryExtractionService.CalculateRelevanceScore(accessed); + Assert.Equal(0.5f, score, 0.01f); + } + + [Fact] + public void CalculateRelevanceScore_60DaysAgo_ClampedToMin() { + DateTimeOffset accessed = DateTimeOffset.UtcNow.AddDays(-60); + float score = MemoryExtractionService.CalculateRelevanceScore(accessed); + Assert.Equal(0.1f, score, 0.01f); + } + + [Fact] + public void CalculateRelevanceScore_NeverAccessed_Returns1() { + float score = MemoryExtractionService.CalculateRelevanceScore(null); + Assert.Equal(1.0f, score); + } +} diff --git a/Lis.Tools/MemoryPlugin.cs b/Lis.Tools/MemoryPlugin.cs index e9f90a4..23b857c 100644 --- a/Lis.Tools/MemoryPlugin.cs +++ b/Lis.Tools/MemoryPlugin.cs @@ -32,11 +32,12 @@ public async Task CreateMemoryAsync( Vector? embedding = await GenerateEmbeddingAsync(scope.ServiceProvider, content); MemoryEntity memory = new() { - Content = content.Trim(), - ContactId = contactId, - Embedding = embedding, - CreatedAt = DateTimeOffset.UtcNow, - UpdatedAt = DateTimeOffset.UtcNow, + Content = content.Trim(), + ContactId = contactId, + Embedding = embedding, + RelevanceScore = 1.0f, + CreatedAt = DateTimeOffset.UtcNow, + UpdatedAt = DateTimeOffset.UtcNow, }; db.Memories.Add(memory); @@ -78,6 +79,13 @@ public async Task SearchMemoriesAsync( if (results.Count == 0) return "No memories found."; + // Update last_accessed_at for returned results + DateTimeOffset now = DateTimeOffset.UtcNow; + foreach (MemoryEntity mem in results) { + mem.LastAccessedAt = now; + } + await db.SaveChangesAsync(); + StringBuilder sb = new(); foreach (MemoryEntity mem in results) { string prefix = mem.Contact is not null ? $"[{mem.Contact.Name}] " : ""; @@ -174,8 +182,10 @@ private static async Task> VectorSearchAsync( if (contactId is not null) q = q.Where(m => m.ContactId == contactId); + // Apply relevance decay as sort boost: cosine_distance * (2 - relevance_score) + // Higher relevance_score → lower multiplier → ranked higher return await q - .OrderBy(m => m.Embedding!.CosineDistance(queryVector)) + .OrderBy(m => m.Embedding!.CosineDistance(queryVector) * (2.0 - m.RelevanceScore)) .Take(10) .ToListAsync(); } diff --git a/global.json b/global.json index 058bafa..816dd81 100644 --- a/global.json +++ b/global.json @@ -1,5 +1,6 @@ { "sdk": { - "version": "10.0.103" + "version": "10.0.103", + "rollForward": "latestMinor" } }