|
| 1 | +using System; |
| 2 | +using System.Collections.Generic; |
| 3 | +using System.Globalization; |
| 4 | +using Microsoft.Data.Sqlite; |
| 5 | + |
| 6 | +namespace BitNetSharp.Distributed.Coordinator.Persistence; |
| 7 | + |
/// <summary>
/// SQLite-backed data access layer for the coordinator's <c>workers</c>
/// table. Opens and owns its own connection against the coordinator
/// database file so it can coexist with
/// <see cref="SqliteWorkQueueStore"/>; WAL mode + SQLite's internal
/// locking keep writers on <em>separate</em> connections safe.
/// </summary>
/// <remarks>
/// The single <see cref="SqliteConnection"/> owned by this instance is
/// shared by every method, so all operations — reads as well as writes —
/// are serialized through one private gate. Commands and readers created
/// from the same connection are never used from two threads at once.
/// </remarks>
public sealed class SqliteWorkerRegistryStore : IDisposable
{
    /// <summary>Primary SQLite result code for any constraint violation (SQLITE_CONSTRAINT).</summary>
    private const int SqliteConstraintErrorCode = 19;

    /// <summary>
    /// Column list shared by every SELECT that feeds <see cref="MapRow"/>.
    /// The order here must match the ordinals read in <see cref="MapRow"/>.
    /// </summary>
    private const string SelectColumns = @"worker_id, name, bearer_token_hash, cpu_threads, tokens_per_sec,
       recommended_tokens_per_task, process_architecture, os_description,
       registered_at, last_heartbeat, state";

    private readonly SqliteConnection _connection;
    private readonly TimeProvider _time;

    // Guards every use of _connection (reads, writes and disposal).
    private readonly object _connectionGate = new();

    /// <summary>
    /// Opens the connection, applies the connection-level PRAGMAs and
    /// creates the <c>workers</c> table and its indexes if they are missing.
    /// </summary>
    /// <param name="connectionString">Connection string for the coordinator database.</param>
    /// <param name="time">Clock used for heartbeat and staleness calculations; defaults to <see cref="TimeProvider.System"/>.</param>
    /// <exception cref="ArgumentException"><paramref name="connectionString"/> is null, empty or whitespace.</exception>
    public SqliteWorkerRegistryStore(string connectionString, TimeProvider? time = null)
    {
        ArgumentException.ThrowIfNullOrWhiteSpace(connectionString);
        _time = time ?? TimeProvider.System;
        _connection = new SqliteConnection(connectionString);

        try
        {
            _connection.Open();

            ExecuteNonQuery("PRAGMA journal_mode = WAL;");
            ExecuteNonQuery("PRAGMA synchronous = NORMAL;");
            ExecuteNonQuery("PRAGMA busy_timeout = 5000;");
            ExecuteNonQuery("PRAGMA foreign_keys = ON;");

            MigrateSchema();
        }
        catch
        {
            // The caller never receives the instance when construction fails,
            // so nobody else could dispose the connection: release it here.
            _connection.Dispose();
            throw;
        }
    }

    /// <summary>
    /// Creates the <c>workers</c> table and its secondary indexes when they
    /// do not exist yet. Every statement is <c>IF NOT EXISTS</c>, so running
    /// it against an already-initialized database is a no-op.
    /// </summary>
    private void MigrateSchema()
    {
        // NOTE(review): ix_workers_bearer_hash looks redundant with the UNIQUE
        // constraint on bearer_token_hash (SQLite backs UNIQUE with its own
        // index). Left in place because dropping it is a schema change.
        ExecuteNonQuery(@"
CREATE TABLE IF NOT EXISTS workers (
    worker_id                   TEXT PRIMARY KEY,
    name                        TEXT NOT NULL,
    bearer_token_hash           TEXT NOT NULL UNIQUE,
    cpu_threads                 INTEGER NOT NULL,
    tokens_per_sec              REAL NOT NULL,
    recommended_tokens_per_task INTEGER NOT NULL,
    process_architecture        TEXT,
    os_description              TEXT,
    registered_at               INTEGER NOT NULL,
    last_heartbeat              INTEGER NOT NULL,
    state                       TEXT NOT NULL
);

CREATE INDEX IF NOT EXISTS ix_workers_bearer_hash
    ON workers(bearer_token_hash);

CREATE INDEX IF NOT EXISTS ix_workers_heartbeat
    ON workers(last_heartbeat);

CREATE INDEX IF NOT EXISTS ix_workers_state
    ON workers(state);
");
    }

    /// <summary>
    /// Inserts a newly registered worker. Callers must generate a fresh
    /// opaque id per registration so idempotent re-registration is always
    /// a new row.
    /// </summary>
    /// <param name="worker">The worker row to persist.</param>
    /// <exception cref="ArgumentNullException"><paramref name="worker"/> is <c>null</c>.</exception>
    /// <exception cref="InvalidOperationException">
    /// SQLite rejected the row with a constraint violation — in practice the
    /// worker id or the bearer-token hash is already registered. Any other
    /// constraint failure on the row is reported the same way; the original
    /// <see cref="SqliteException"/> is attached as the inner exception.
    /// </exception>
    public void Insert(WorkerRecord worker)
    {
        ArgumentNullException.ThrowIfNull(worker);

        lock (_connectionGate)
        {
            using var cmd = _connection.CreateCommand();
            cmd.CommandText = @"
INSERT INTO workers (
    worker_id, name, bearer_token_hash, cpu_threads, tokens_per_sec,
    recommended_tokens_per_task, process_architecture, os_description,
    registered_at, last_heartbeat, state
) VALUES (
    $worker_id, $name, $bearer_token_hash, $cpu_threads, $tokens_per_sec,
    $recommended_tokens_per_task, $process_architecture, $os_description,
    $registered_at, $last_heartbeat, $state
);";
            cmd.Parameters.AddWithValue("$worker_id", worker.WorkerId);
            cmd.Parameters.AddWithValue("$name", worker.Name);
            cmd.Parameters.AddWithValue("$bearer_token_hash", worker.BearerTokenHash);
            cmd.Parameters.AddWithValue("$cpu_threads", worker.CpuThreads);
            cmd.Parameters.AddWithValue("$tokens_per_sec", worker.TokensPerSecond);
            cmd.Parameters.AddWithValue("$recommended_tokens_per_task", worker.RecommendedTokensPerTask);
            // Optional columns: a null property is stored as SQL NULL.
            cmd.Parameters.AddWithValue("$process_architecture", (object?)worker.ProcessArchitecture ?? DBNull.Value);
            cmd.Parameters.AddWithValue("$os_description", (object?)worker.OsDescription ?? DBNull.Value);
            // Timestamps are persisted as whole Unix seconds.
            cmd.Parameters.AddWithValue("$registered_at", worker.RegisteredAtUtc.ToUnixTimeSeconds());
            cmd.Parameters.AddWithValue("$last_heartbeat", worker.LastHeartbeatUtc.ToUnixTimeSeconds());
            cmd.Parameters.AddWithValue("$state", worker.State.ToString());

            try
            {
                cmd.ExecuteNonQuery();
            }
            catch (SqliteException ex) when (ex.SqliteErrorCode == SqliteConstraintErrorCode)
            {
                throw new InvalidOperationException(
                    $"Worker '{worker.WorkerId}' or its bearer-token hash is already registered.",
                    ex);
            }
        }
    }

    /// <summary>
    /// Finds a worker by its opaque id.
    /// </summary>
    /// <param name="workerId">The worker id to look up.</param>
    /// <returns>The matching worker, or <c>null</c> if no such worker exists.</returns>
    /// <exception cref="ArgumentException"><paramref name="workerId"/> is null, empty or whitespace.</exception>
    public WorkerRecord? FindById(string workerId)
    {
        ArgumentException.ThrowIfNullOrWhiteSpace(workerId);
        return Load("worker_id = $id", ("$id", workerId));
    }

    /// <summary>
    /// Finds a worker by the SHA-256 hash of its bearer token. Used by
    /// the bearer-auth middleware on every authenticated request.
    /// </summary>
    /// <param name="bearerTokenHash">The stored hash to look up.</param>
    /// <returns>The matching worker, or <c>null</c> if no worker has that hash.</returns>
    /// <exception cref="ArgumentException"><paramref name="bearerTokenHash"/> is null, empty or whitespace.</exception>
    public WorkerRecord? FindByBearerTokenHash(string bearerTokenHash)
    {
        ArgumentException.ThrowIfNullOrWhiteSpace(bearerTokenHash);
        return Load("bearer_token_hash = $hash", ("$hash", bearerTokenHash));
    }

    /// <summary>
    /// Updates the worker's last-heartbeat timestamp to "now" from the
    /// injected <see cref="TimeProvider"/>.
    /// </summary>
    /// <param name="workerId">The worker whose heartbeat is refreshed.</param>
    /// <returns><c>true</c> if a row was updated, <c>false</c> if the worker id does not exist.</returns>
    /// <exception cref="ArgumentException"><paramref name="workerId"/> is null, empty or whitespace.</exception>
    public bool TouchHeartbeat(string workerId)
    {
        ArgumentException.ThrowIfNullOrWhiteSpace(workerId);

        lock (_connectionGate)
        {
            using var cmd = _connection.CreateCommand();
            cmd.CommandText = @"
UPDATE workers
SET last_heartbeat = $now
WHERE worker_id = $id;";
            cmd.Parameters.AddWithValue("$now", _time.GetUtcNow().ToUnixTimeSeconds());
            cmd.Parameters.AddWithValue("$id", workerId);
            return cmd.ExecuteNonQuery() == 1;
        }
    }

    /// <summary>
    /// Transitions a worker to <see cref="WorkerState.Draining"/> so
    /// the coordinator stops assigning it new work. Safe to re-run.
    /// </summary>
    /// <param name="workerId">The worker to transition.</param>
    /// <returns><c>true</c> if a row was updated, <c>false</c> if the worker id does not exist.</returns>
    public bool MarkDraining(string workerId) => UpdateState(workerId, WorkerState.Draining);

    /// <summary>
    /// Transitions a worker to <see cref="WorkerState.Gone"/>. Invoked
    /// when heartbeats go silent past the deadline or when a worker
    /// explicitly deregisters.
    /// </summary>
    /// <param name="workerId">The worker to transition.</param>
    /// <returns><c>true</c> if a row was updated, <c>false</c> if the worker id does not exist.</returns>
    public bool MarkGone(string workerId) => UpdateState(workerId, WorkerState.Gone);

    /// <summary>
    /// Counts workers currently in the given lifecycle state. Used by
    /// the <c>/status</c> dashboard.
    /// </summary>
    /// <param name="state">The lifecycle state to count.</param>
    /// <returns>The number of rows whose stored state equals <paramref name="state"/>.</returns>
    public int CountByState(WorkerState state)
    {
        lock (_connectionGate)
        {
            using var cmd = _connection.CreateCommand();
            cmd.CommandText = "SELECT COUNT(1) FROM workers WHERE state = $state;";
            cmd.Parameters.AddWithValue("$state", state.ToString());
            var result = cmd.ExecuteScalar();
            return result is null or DBNull ? 0 : Convert.ToInt32(result, CultureInfo.InvariantCulture);
        }
    }

    /// <summary>
    /// Finds every <see cref="WorkerState.Active"/> worker whose last
    /// heartbeat is older than <paramref name="staleAfter"/> relative to
    /// the injected clock and transitions them to
    /// <see cref="WorkerState.Gone"/>.
    /// </summary>
    /// <param name="staleAfter">How long a worker may stay silent before it is considered gone; must be positive.</param>
    /// <returns>The count of transitions.</returns>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="staleAfter"/> is zero or negative.</exception>
    public int SweepStaleWorkers(TimeSpan staleAfter)
    {
        if (staleAfter <= TimeSpan.Zero)
        {
            throw new ArgumentOutOfRangeException(nameof(staleAfter), "Stale threshold must be positive.");
        }

        lock (_connectionGate)
        {
            using var cmd = _connection.CreateCommand();
            cmd.CommandText = @"
UPDATE workers
SET state = $gone
WHERE state = $active
  AND last_heartbeat < $cutoff;";
            var cutoff = _time.GetUtcNow().Subtract(staleAfter).ToUnixTimeSeconds();
            // State names come from the enum, exactly as every other method
            // writes them, so the stored text cannot drift from WorkerState.
            cmd.Parameters.AddWithValue("$gone", WorkerState.Gone.ToString());
            cmd.Parameters.AddWithValue("$active", WorkerState.Active.ToString());
            cmd.Parameters.AddWithValue("$cutoff", cutoff);
            return cmd.ExecuteNonQuery();
        }
    }

    /// <summary>
    /// Enumerates every row in the <c>workers</c> table in registration
    /// order (ties broken by worker id). Kept simple for the v1 status
    /// dashboard; larger fleets will add pagination later.
    /// </summary>
    /// <returns>A fully materialized snapshot of all workers.</returns>
    public IReadOnlyList<WorkerRecord> ListAll()
    {
        var results = new List<WorkerRecord>();

        lock (_connectionGate)
        {
            using var cmd = _connection.CreateCommand();
            cmd.CommandText = $@"
SELECT {SelectColumns}
FROM workers
ORDER BY registered_at ASC, worker_id ASC;";
            using var reader = cmd.ExecuteReader();
            while (reader.Read())
            {
                results.Add(MapRow(reader));
            }
        }

        return results;
    }

    /// <summary>
    /// Sets the stored state of one worker.
    /// </summary>
    /// <returns><c>true</c> if a row was updated, <c>false</c> if the worker id does not exist.</returns>
    /// <exception cref="ArgumentException"><paramref name="workerId"/> is null, empty or whitespace.</exception>
    private bool UpdateState(string workerId, WorkerState state)
    {
        ArgumentException.ThrowIfNullOrWhiteSpace(workerId);

        lock (_connectionGate)
        {
            using var cmd = _connection.CreateCommand();
            cmd.CommandText = "UPDATE workers SET state = $state WHERE worker_id = $id;";
            cmd.Parameters.AddWithValue("$state", state.ToString());
            cmd.Parameters.AddWithValue("$id", workerId);
            return cmd.ExecuteNonQuery() == 1;
        }
    }

    /// <summary>
    /// Loads the first worker matching <paramref name="whereClause"/>, or
    /// <c>null</c> when nothing matches.
    /// </summary>
    /// <param name="whereClause">
    /// SQL predicate spliced verbatim into the query. It must be a
    /// compile-time constant written in this class — never caller input;
    /// values go through <paramref name="parameters"/>.
    /// </param>
    /// <param name="parameters">Named parameter values referenced by the predicate.</param>
    private WorkerRecord? Load(string whereClause, params (string name, object value)[] parameters)
    {
        lock (_connectionGate)
        {
            using var cmd = _connection.CreateCommand();
            cmd.CommandText = $@"
SELECT {SelectColumns}
FROM workers WHERE {whereClause};";
            foreach (var (name, value) in parameters)
            {
                cmd.Parameters.AddWithValue(name, value);
            }

            using var reader = cmd.ExecuteReader();
            return reader.Read() ? MapRow(reader) : null;
        }
    }

    /// <summary>
    /// Materializes the reader's current row. Ordinals follow
    /// <see cref="SelectColumns"/>.
    /// </summary>
    private static WorkerRecord MapRow(SqliteDataReader reader)
    {
        return new WorkerRecord(
            WorkerId: reader.GetString(0),
            Name: reader.GetString(1),
            BearerTokenHash: reader.GetString(2),
            CpuThreads: reader.GetInt32(3),
            TokensPerSecond: reader.GetDouble(4),
            RecommendedTokensPerTask: reader.GetInt64(5),
            // The two descriptive columns are nullable in the schema.
            ProcessArchitecture: reader.IsDBNull(6) ? null : reader.GetString(6),
            OsDescription: reader.IsDBNull(7) ? null : reader.GetString(7),
            // Timestamps are stored as Unix seconds.
            RegisteredAtUtc: DateTimeOffset.FromUnixTimeSeconds(reader.GetInt64(8)),
            LastHeartbeatUtc: DateTimeOffset.FromUnixTimeSeconds(reader.GetInt64(9)),
            // State is stored as the enum member name.
            State: Enum.Parse<WorkerState>(reader.GetString(10)));
    }

    /// <summary>
    /// Runs a parameterless statement. Only called while the constructor is
    /// still running, before the instance is visible to any other thread,
    /// so it does not take the gate.
    /// </summary>
    private void ExecuteNonQuery(string sql)
    {
        using var cmd = _connection.CreateCommand();
        cmd.CommandText = sql;
        cmd.ExecuteNonQuery();
    }

    /// <summary>
    /// Closes and releases the owned connection. Waits for any operation
    /// currently holding the gate so a command is never torn down mid-flight.
    /// </summary>
    public void Dispose()
    {
        lock (_connectionGate)
        {
            _connection.Dispose();
        }
    }
}
0 commit comments