Skip to content

Commit e04fd72

Browse files
committed
Added an option to cache embedding calculation results for the Embedding Generator
1 parent 168043a commit e04fd72

2 files changed

Lines changed: 17 additions & 2 deletions

File tree

src/DatabaseBenchmark/Generators/EmbeddingGenerator.cs

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using DatabaseBenchmark.Generators.Options;
44
using DatabaseBenchmark.Plugins.Interfaces;
55
using DatabaseBenchmark.Plugins.TextEmbedding;
6+
using System.Collections.Concurrent;
67

78
namespace DatabaseBenchmark.Generators
89
{
@@ -11,6 +12,8 @@ public class EmbeddingGenerator : IGenerator
1112
private readonly IGenerator _sourceGenerator;
1213
private readonly ITextEmbeddingModel _embeddingModel;
1314
private readonly int? _dimensions;
15+
private readonly bool _cache;
16+
private readonly ConcurrentDictionary<string, float[]> _embeddingsCache;
1417

1518
public object Current { get; private set; }
1619

@@ -30,6 +33,12 @@ public EmbeddingGenerator(EmbeddingGeneratorOptions options, IGenerator sourceGe
3033
//TODO: refactor resolution and interfaces when non-text embedding models are added
3134
_embeddingModel = pluginRepository.GetPlugin<ITextEmbeddingModel>(options.ModelName, PluginType.TextEmbeddingModel);
3235
_dimensions = options.Dimensions;
36+
_cache = options.Cache;
37+
38+
if (_cache)
39+
{
40+
_embeddingsCache = new ConcurrentDictionary<string, float[]>();
41+
}
3342
}
3443

3544
public bool Next()
@@ -45,8 +54,7 @@ public bool Next()
4554
{
4655
//TODO: refactor when non-text embedding models are added
4756
var text = (string)sourceValue;
48-
var embedding = _embeddingModel.GenerateEmbedding(text, _dimensions);
49-
Current = embedding;
57+
Current = GetOrComputeEmbedding(text);
5058
}
5159
else
5260
{
@@ -55,5 +63,10 @@ public bool Next()
5563

5664
return true;
5765
}
66+
67+
/// <summary>
/// Returns the embedding for <paramref name="text"/>. When <c>_cache</c> is set, the result is
/// looked up in <c>_embeddingsCache</c> (keyed by the text) and the embedding model is invoked
/// only if no entry exists for that text; otherwise the embedding model is invoked on every call.
/// </summary>
/// <param name="text">Text passed to the embedding model and, when caching, used as the cache key.</param>
/// <returns>
/// The array produced by <c>_embeddingModel.GenerateEmbedding</c>, or the array previously stored
/// in the cache for the same text.
/// </returns>
// NOTE(review): with caching on, repeated texts receive the same float[] instance that is stored
// in the cache - presumably consumers treat it as read-only; confirm.
// NOTE(review): ConcurrentDictionary.GetOrAdd rejects a null key and may run the value factory
// more than once under concurrent calls for the same key; verify Next() never passes a null text.
private float[] GetOrComputeEmbedding(string text) =>
68+
_cache
69+
? _embeddingsCache.GetOrAdd(text, t => _embeddingModel.GenerateEmbedding(t, _dimensions))
70+
: _embeddingModel.GenerateEmbedding(text, _dimensions);
5871
}
5972
}

src/DatabaseBenchmark/Generators/Options/EmbeddingGeneratorOptions.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ public class EmbeddingGeneratorOptions : GeneratorOptionsBase
1414

1515
public int? Dimensions { get; set; }
1616

17+
/// <summary>
/// Enables caching of embedding calculation results in the Embedding Generator.
/// Defaults to <c>false</c>.
/// </summary>
public bool Cache { get; set; } = false;
18+
1719
public IGeneratorOptions SourceGeneratorOptions { get; set; }
1820

1921
public enum GeneratorKind

0 commit comments

Comments
 (0)