Skip to content

Commit f961f3e

Browse files
sharpninjaclaude
andcommitted
Add /work /heartbeat /gradient /weights endpoints + weight store
Phase D-1 step 2d. Lights up the rest of the worker-facing REST surface the distributed training loop needs, plus the filesystem weight-versioned blob store that backs /weights. FileSystemWeightStore: - Flat directory with v{version:D10}.bin + v{version:D10}.bin.sha256 sidecar files. SaveVersion writes atomically through a .tmp staging file + rename so a crash mid-write cannot expose a half-written blob. TryGetManifest + TryOpenReadStream serve workers without any in-memory caching. ListVersions returns versions in ascending numeric order; GetLatestVersion returns the highest. Refuses to overwrite an existing version — coordinator weight versions are immutable by contract. - 9 xunit tests locking empty-store behavior, save + round trip, overwrite rejection, manifest consistency, stream contents, ordering, and auto-directory-creation. Worker API endpoints (all JWT-guarded by BitNetWorkerPolicy): - GET /work Reads client_id from JWT, calls SqliteWorkQueueStore .TryClaimNextPending with a lease = 2x the target task duration (safety margin over the 10-minute goal), and maps the WorkTaskRecord to a WorkTaskAssignment DTO with a fully-qualified {BaseUrl}/weights/{version} WeightUrl so the worker knows where to download the weight blob. Returns 204 No Content when the queue is empty so workers can long-poll without allocating an object per empty poll. - POST /heartbeat Validates body + client_id, calls SqliteWorkerRegistryStore.TouchHeartbeat. Returns HeartbeatResponse with ShouldDrain=false for now; the coordinator dashboard will set it later when the operator requests a rolling drain. Returns 410 Gone for an unknown worker so the worker knows to re-/register. - POST /gradient D-1 stub: validates that the submission's worker_id matches the JWT client_id, then calls SqliteWorkQueueStore .MarkCompleted to flip the task Done. Logs the gradient format + payload size + tokens seen + loss + staleness for visibility. 
Does NOT yet decode the gradient or apply it to the global weight copy — that lands in the D-4 commit that introduces the int8-ef gradient decoder. - GET /weights/{version:long} Uses FileSystemWeightStore.TryOpenReadStream to stream the blob with Results.File + enableRangeProcessing so workers over ngrok can resume partial downloads. Program.cs wires FileSystemWeightStore as a DI singleton whose root directory is derived from the database path (sibling `weights/` directory). Fast-lane regression 232/232 passing on net10 slice; no new tests for the endpoints yet — they need a TestServer harness that covers the full JWT issuance flow end-to-end, which lands together with the Worker HttpClient wire-up. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 9371dc6 commit f961f3e

3 files changed

Lines changed: 516 additions & 1 deletion

File tree

Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.IO;
4+
using System.Linq;
5+
using System.Security.Cryptography;
6+
using System.Text;
7+
8+
namespace BitNetSharp.Distributed.Coordinator.Persistence;
9+
10+
/// <summary>
/// Filesystem-backed versioned blob store for the coordinator's
/// global training weights. The work queue stores
/// <c>weight_version</c> integers only; actual ternary-packed model
/// bytes live on disk under a flat directory so workers can stream
/// them via <c>GET /weights/{version}</c> with range support.
///
/// <para>
/// Layout under <c>{rootDirectory}</c>:
/// <code>
///   v0000000001.bin
///   v0000000001.bin.sha256
///   v0000000002.bin
///   v0000000002.bin.sha256
///   ...
/// </code>
/// Each <c>.bin</c> file is the immutable blob payload; the matching
/// <c>.bin.sha256</c> sidecar holds the lowercase hex digest computed
/// at save time so later readers can verify integrity without
/// rehashing. A version counts as present only when both files exist.
/// </para>
/// </summary>
public sealed class FileSystemWeightStore
{
    private const string DigestSuffix = ".sha256";
    private const string StagingSuffix = ".tmp";

    private readonly string _rootDirectory;
    private readonly object _writeGate = new();

    /// <summary>
    /// Creates a store rooted at <paramref name="rootDirectory"/>,
    /// creating the directory if it does not exist yet.
    /// </summary>
    /// <param name="rootDirectory">Directory that holds the version files.</param>
    /// <exception cref="ArgumentException">
    /// <paramref name="rootDirectory"/> is null, empty, or whitespace.
    /// </exception>
    public FileSystemWeightStore(string rootDirectory)
    {
        ArgumentException.ThrowIfNullOrWhiteSpace(rootDirectory);
        _rootDirectory = rootDirectory;
        Directory.CreateDirectory(_rootDirectory);
    }

    /// <summary>
    /// Persists the given blob as the contents of the specified
    /// weight version. Overwriting an existing version is disallowed
    /// — the whole point of versioning is that each
    /// <c>weight_version</c> is immutable.
    /// </summary>
    /// <param name="version">Positive version number to create.</param>
    /// <param name="payload">Raw bytes of the weight blob.</param>
    /// <returns>Manifest describing the blob that was just written.</returns>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="version"/> is zero or negative.
    /// </exception>
    /// <exception cref="InvalidOperationException">
    /// A blob for <paramref name="version"/> already exists.
    /// </exception>
    public WeightVersionManifest SaveVersion(long version, ReadOnlySpan<byte> payload)
    {
        if (version <= 0)
        {
            throw new ArgumentOutOfRangeException(nameof(version), "Weight version must be positive.");
        }

        var binPath = PathForVersion(version);
        var shaPath = binPath + DigestSuffix;
        var stagingPath = binPath + StagingSuffix;

        lock (_writeGate)
        {
            if (File.Exists(binPath))
            {
                throw new InvalidOperationException(
                    $"Weight version {version} already exists at {binPath}. Coordinator weight versions are immutable.");
            }

            try
            {
                using (var stream = File.Create(stagingPath))
                {
                    stream.Write(payload);
                }

                // The digest sidecar is written BEFORE the blob is
                // renamed into place. The rename is therefore the
                // single publishing step: once the .bin is visible
                // its sidecar already exists, and a crash or I/O
                // failure before the rename leaves no .bin behind, so
                // the same version can simply be saved again.
                var hash = Sha256Hex(payload);
                File.WriteAllText(shaPath, hash, Encoding.ASCII);

                File.Move(stagingPath, binPath);

                return new WeightVersionManifest(version, binPath, payload.Length, hash);
            }
            catch
            {
                DeleteQuietly(stagingPath);

                // Only drop the sidecar when the blob was never
                // published; a sidecar next to a visible blob belongs
                // to that blob and must stay.
                if (!File.Exists(binPath))
                {
                    DeleteQuietly(shaPath);
                }

                throw;
            }
        }
    }

    /// <summary>
    /// Returns the manifest (path, size, hash) for the requested
    /// version if both its blob and digest sidecar exist on disk, or
    /// <c>null</c> otherwise. Non-positive versions always yield
    /// <c>null</c>.
    /// </summary>
    public WeightVersionManifest? TryGetManifest(long version)
    {
        if (version <= 0 || !IsPublished(version))
        {
            return null;
        }

        var binPath = PathForVersion(version);
        var length = new FileInfo(binPath).Length;
        var hash = File.ReadAllText(binPath + DigestSuffix, Encoding.ASCII).Trim();
        return new WeightVersionManifest(version, binPath, length, hash);
    }

    /// <summary>
    /// Opens a read-only stream over the requested version's blob,
    /// or <c>null</c> if the version does not exist on disk. Caller
    /// owns the returned stream and must dispose it.
    /// </summary>
    public Stream? TryOpenReadStream(long version)
    {
        var manifest = TryGetManifest(version);
        if (manifest is null)
        {
            return null;
        }

        return new FileStream(
            manifest.PhysicalPath,
            FileMode.Open,
            FileAccess.Read,
            FileShare.Read,
            bufferSize: 64 * 1024,
            useAsync: true);
    }

    /// <summary>
    /// Enumerates every weight version currently on disk in ascending
    /// numeric order. Handy for the /status dashboard. Only versions
    /// that <see cref="TryGetManifest"/> can serve are reported: files
    /// whose name is not <c>v</c> followed by digits, and blobs that
    /// lack a digest sidecar, are skipped.
    /// </summary>
    public IReadOnlyList<long> ListVersions()
    {
        return Directory
            .EnumerateFiles(_rootDirectory, "v*.bin")
            .Select(ParseVersionOrZero)
            .Where(v => v > 0 && IsPublished(v))
            .Distinct()
            .OrderBy(v => v)
            .ToList();
    }

    /// <summary>
    /// Returns the highest-numbered version currently on disk, or
    /// <c>null</c> if the store is empty.
    /// </summary>
    public long? GetLatestVersion()
    {
        var versions = ListVersions();
        return versions.Count == 0 ? null : versions[^1];
    }

    /// <summary>
    /// Computes the filesystem path for the given version using a
    /// 10-digit zero-padded name so alphabetical enumeration matches
    /// numerical order for the first ten billion versions.
    /// </summary>
    private string PathForVersion(long version) =>
        Path.Combine(_rootDirectory, $"v{version:D10}.bin");

    /// <summary>
    /// True when both the blob and its digest sidecar exist at the
    /// canonical path for <paramref name="version"/>.
    /// </summary>
    private bool IsPublished(long version)
    {
        var binPath = PathForVersion(version);
        return File.Exists(binPath) && File.Exists(binPath + DigestSuffix);
    }

    /// <summary>
    /// Extracts the version number from a <c>v{digits}.bin</c> file
    /// path, or returns 0 when the name does not have that shape.
    /// Parsing is culture-invariant and accepts digits only (no sign,
    /// whitespace, or repeated <c>v</c> prefix).
    /// </summary>
    private static long ParseVersionOrZero(string path)
    {
        var name = Path.GetFileNameWithoutExtension(path);
        if (name.Length < 2 || name[0] != 'v')
        {
            return 0L;
        }

        return long.TryParse(
            name.AsSpan(1),
            System.Globalization.NumberStyles.None,
            System.Globalization.CultureInfo.InvariantCulture,
            out var parsed)
            ? parsed
            : 0L;
    }

    /// <summary>
    /// Best-effort delete used during failure cleanup; I/O and
    /// permission errors are deliberately ignored so they cannot
    /// replace the exception that triggered the cleanup.
    /// </summary>
    private static void DeleteQuietly(string path)
    {
        try
        {
            File.Delete(path);
        }
        catch (IOException)
        {
            // Cleanup only; the caller rethrows the original failure.
        }
        catch (UnauthorizedAccessException)
        {
            // Cleanup only; the caller rethrows the original failure.
        }
    }

    /// <summary>
    /// Returns the SHA-256 digest of <paramref name="payload"/> as a
    /// 64-character lowercase hex string.
    /// </summary>
    private static string Sha256Hex(ReadOnlySpan<byte> payload) =>
        Convert.ToHexString(SHA256.HashData(payload)).ToLowerInvariant();
}
187+
188+
/// <summary>
/// Immutable description of a single weight version stored on disk,
/// as produced by <see cref="FileSystemWeightStore.SaveVersion"/> and
/// <see cref="FileSystemWeightStore.TryGetManifest"/>.
/// </summary>
/// <param name="Version">Weight version number the manifest describes.</param>
/// <param name="PhysicalPath">Path of the version's <c>.bin</c> blob file.</param>
/// <param name="SizeBytes">Length of the blob in bytes.</param>
/// <param name="Sha256Hex">Hex-encoded SHA-256 digest recorded for the blob.</param>
public sealed record WeightVersionManifest(long Version, string PhysicalPath, long SizeBytes, string Sha256Hex);

src/BitNetSharp.Distributed.Coordinator/Program.cs

Lines changed: 164 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.IO;
34
using System.Linq;
45
using System.Security.Claims;
56
using BitNetSharp.Distributed.Contracts;
@@ -79,6 +80,14 @@
7980
var time = sp.GetRequiredService<TimeProvider>();
8081
return new SqliteClientRevocationStore(BuildConnectionString(coord), time);
8182
});
83+
// Weight blobs live in a "weights" directory that sits next to the
// coordinator's database file.
builder.Services.AddSingleton(sp =>
{
    var coordinatorOptions = sp.GetRequiredService<CoordinatorOptions>();
    var databaseFullPath = Path.GetFullPath(coordinatorOptions.DatabasePath);

    // GetDirectoryName yields null for a root path; fall back to the
    // current directory in that case.
    var parentDirectory = Path.GetDirectoryName(databaseFullPath) ?? ".";

    return new FileSystemWeightStore(Path.Combine(parentDirectory, "weights"));
});
8291

8392
// ── Worker client registry + Duende IdentityServer ────────────────
8493
var workerRegistry = new WorkerClientRegistry();
@@ -200,10 +209,11 @@
200209

201210
var app = builder.Build();
202211

203-
// Ensure all three stores create their schema on startup.
212+
// Ensure all stores create their schema / directories on startup.
204213
_ = app.Services.GetRequiredService<SqliteWorkQueueStore>();
205214
_ = app.Services.GetRequiredService<SqliteWorkerRegistryStore>();
206215
_ = app.Services.GetRequiredService<SqliteClientRevocationStore>();
216+
_ = app.Services.GetRequiredService<FileSystemWeightStore>();
207217

208218
app.UseAuthentication();
209219
app.UseAuthorization();
@@ -295,6 +305,159 @@
295305
return Results.Ok(response);
296306
}).RequireAuthorization(IdentityServerResources.WorkerPolicyName);
297307

308+
// ── /work — claim the next pending task for this worker ──────────
// 200 + WorkTaskAssignment when a task was claimed, 204 when nothing
// is pending, 401 when the JWT carries no client_id claim.
app.MapGet("/work", (
    HttpContext http,
    SqliteWorkQueueStore workQueue,
    CoordinatorOptions options,
    TimeProvider time) =>
{
    var clientId = http.User.FindFirst("client_id")?.Value;
    if (string.IsNullOrWhiteSpace(clientId))
    {
        return Results.Json(
            new ErrorResponse("unknown_client", "JWT did not carry a client_id claim."),
            statusCode: StatusCodes.Status401Unauthorized);
    }

    // The lease is twice the target task duration, leaving a safety
    // margin before the claim expires.
    var leaseDuration = TimeSpan.FromSeconds(options.TargetTaskDurationSeconds * 2);
    var claimed = workQueue.TryClaimNextPending(clientId, leaseDuration);
    if (claimed is null)
    {
        return Results.NoContent();
    }

    var baseUrl = options.BaseUrl.TrimEnd('/');
    var assignment = new WorkTaskAssignment(
        TaskId: claimed.TaskId,
        WeightVersion: claimed.WeightVersion,
        WeightUrl: $"{baseUrl}/weights/{claimed.WeightVersion}",
        ShardId: claimed.ShardId,
        ShardOffset: claimed.ShardOffset,
        ShardLength: claimed.ShardLength,
        TokensPerTask: claimed.TokensPerTask,
        KLocalSteps: claimed.KLocalSteps,
        HyperparametersJson: claimed.HyperparametersJson,
        // Fall back to "now + lease" when the record carries no
        // deadline; the clock comes from the injected TimeProvider,
        // like every other timestamp this file produces.
        DeadlineUtc: claimed.DeadlineUtc ?? time.GetUtcNow().Add(leaseDuration));

    return Results.Ok(assignment);
}).RequireAuthorization(IdentityServerResources.WorkerPolicyName);
344+
345+
// ── /heartbeat — worker pings the coordinator periodically ───────
// 200 + HeartbeatResponse on success, 400 for a missing body, 401 when
// the JWT has no client_id claim, 410 when the heartbeat could not be
// recorded for that client id.
app.MapPost("/heartbeat", (
    [FromBody] HeartbeatRequest request,
    HttpContext http,
    SqliteWorkerRegistryStore workerStore,
    CoordinatorOptions options,
    TimeProvider time) =>
{
    if (request is null)
    {
        var missingBody = new ErrorResponse("invalid_request", "Heartbeat body is missing.");
        return Results.Json(missingBody, statusCode: StatusCodes.Status400BadRequest);
    }

    var callerId = http.User.FindFirst("client_id")?.Value;
    if (string.IsNullOrWhiteSpace(callerId))
    {
        var missingClaim = new ErrorResponse("unknown_client", "JWT did not carry a client_id claim.");
        return Results.Json(missingClaim, statusCode: StatusCodes.Status401Unauthorized);
    }

    if (workerStore.TouchHeartbeat(callerId))
    {
        // Drain and token-override hints are not populated yet; the
        // response only carries the server clock.
        var reply = new HeartbeatResponse(
            ShouldDrain: false,
            RecommendedTokensPerTaskOverride: null,
            ServerTime: time.GetUtcNow());
        return Results.Ok(reply);
    }

    // The store did not record the heartbeat: tell the worker to register again.
    var unregistered = new ErrorResponse("unregistered", "Worker must POST /register before heartbeating.");
    return Results.Json(unregistered, statusCode: StatusCodes.Status410Gone);
}).RequireAuthorization(IdentityServerResources.WorkerPolicyName);
381+
382+
// ── /gradient — worker reports task completion ──────────────────
// D-1 stub: validates ownership, marks the task Done, does NOT yet
// apply the gradient to the global weights. Phase D-4 introduces the
// gradient decoder + weight updater.
//
// 200 when the task was marked completed, 400 for a missing body,
// 401 when the JWT has no client_id claim, 403 when the submission's
// WorkerId differs from the JWT client_id, 409 when the work queue
// refuses the completion.
app.MapPost("/gradient", (
    [FromBody] GradientSubmission submission,
    HttpContext http,
    SqliteWorkQueueStore workQueue,
    ILogger<Program> logger) =>
{
    if (submission is null)
    {
        return Results.Json(
            new ErrorResponse("invalid_request", "Gradient body is missing."),
            statusCode: StatusCodes.Status400BadRequest);
    }

    // A missing claim is an authentication problem, reported the same
    // way /work and /heartbeat report it (401 unknown_client); it is
    // kept separate from the ownership check below.
    var clientId = http.User.FindFirst("client_id")?.Value;
    if (string.IsNullOrWhiteSpace(clientId))
    {
        return Results.Json(
            new ErrorResponse("unknown_client", "JWT did not carry a client_id claim."),
            statusCode: StatusCodes.Status401Unauthorized);
    }

    if (clientId != submission.WorkerId)
    {
        return Results.Json(
            new ErrorResponse("worker_mismatch", "Gradient workerId must match the JWT client_id."),
            statusCode: StatusCodes.Status403Forbidden);
    }

    var completed = workQueue.MarkCompleted(submission.TaskId, clientId);
    if (!completed)
    {
        return Results.Json(
            new ErrorResponse("task_not_assigned", "Task is not currently assigned to this worker."),
            statusCode: StatusCodes.Status409Conflict);
    }

    // Staleness is logged as a constant 0: this stub does not compute it.
    logger.LogInformation(
        "Accepted gradient for task {TaskId} from worker {ClientId}: format={Format}, bytes={Size}, tokens={Tokens}, loss={Loss}, staleness={Staleness}",
        submission.TaskId,
        clientId,
        submission.GradientFormat,
        submission.GradientPayload?.Length ?? 0,
        submission.TokensSeen,
        submission.LossAfter,
        0);

    return Results.Ok(new
    {
        accepted = true,
        task_id = submission.TaskId,
        worker_id = clientId
    });
}).RequireAuthorization(IdentityServerResources.WorkerPolicyName);
432+
433+
// ── /weights/{version} — streams a weight blob to the worker ─────
// 200 (or a ranged response) with the blob bytes, 404 when the store
// has no such version.
app.MapGet("/weights/{version:long}", (
    long version,
    FileSystemWeightStore weights) =>
{
    // TryOpenReadStream performs the manifest lookup itself and
    // returns null for an unknown version, so one call covers both
    // the existence check and the open.
    var stream = weights.TryOpenReadStream(version);
    if (stream is null)
    {
        return Results.Json(
            new ErrorResponse("unknown_version", $"Weight version {version} is not available."),
            statusCode: StatusCodes.Status404NotFound);
    }

    // Range processing lets a worker request only part of the blob,
    // e.g. to resume an interrupted download.
    return Results.File(
        fileStream: stream,
        contentType: "application/octet-stream",
        fileDownloadName: $"bitnet-weights-v{version}.bin",
        enableRangeProcessing: true);
}).RequireAuthorization(IdentityServerResources.WorkerPolicyName);
460+
298461
// ── Admin Blazor UI (cookie + OIDC) ───────────────────────────────
299462
app.MapRazorComponents<App>();
300463

0 commit comments

Comments
 (0)