33using DatabaseBenchmark . Generators . Options ;
44using DatabaseBenchmark . Plugins . Interfaces ;
55using DatabaseBenchmark . Plugins . TextEmbedding ;
6+ using System . Collections . Concurrent ;
67
78namespace DatabaseBenchmark . Generators
89{
@@ -11,6 +12,8 @@ public class EmbeddingGenerator : IGenerator
1112 private readonly IGenerator _sourceGenerator ;
1213 private readonly ITextEmbeddingModel _embeddingModel ;
1314 private readonly int ? _dimensions ;
15+ private readonly bool _cache ;
16+ private readonly ConcurrentDictionary < string , float [ ] > _embeddingsCache ;
1417
1518 public object Current { get ; private set ; }
1619
@@ -30,6 +33,12 @@ public EmbeddingGenerator(EmbeddingGeneratorOptions options, IGenerator sourceGe
3033 //TODO: refactor resolution and interfaces when non-text embedding models are added
3134 _embeddingModel = pluginRepository . GetPlugin < ITextEmbeddingModel > ( options . ModelName , PluginType . TextEmbeddingModel ) ;
3235 _dimensions = options . Dimensions ;
36+ _cache = options . Cache ;
37+
38+ if ( _cache )
39+ {
40+ _embeddingsCache = new ConcurrentDictionary < string , float [ ] > ( ) ;
41+ }
3342 }
3443
3544 public bool Next ( )
@@ -45,8 +54,7 @@ public bool Next()
4554 {
4655 //TODO: refactor when non-text embedding models are added
4756 var text = ( string ) sourceValue ;
48- var embedding = _embeddingModel . GenerateEmbedding ( text , _dimensions ) ;
49- Current = embedding ;
57+ Current = GetOrComputeEmbedding ( text ) ;
5058 }
5159 else
5260 {
@@ -55,5 +63,10 @@ public bool Next()
5563
5664 return true ;
5765 }
66+
/// <summary>
/// Produces the embedding for <paramref name="text"/> using the configured embedding model,
/// going through the embeddings cache when caching is switched on.
/// </summary>
/// <param name="text">Source text to embed; also used as the cache key.</param>
/// <returns>The embedding returned by the model, or the entry already stored in the cache for this text.</returns>
private float[] GetOrComputeEmbedding(string text)
{
    // Caching disabled: always ask the model directly.
    if (!_cache)
    {
        return _embeddingModel.GenerateEmbedding(text, _dimensions);
    }

    // Caching enabled: the model is only consulted when no entry exists for this text yet.
    return _embeddingsCache.GetOrAdd(
        text,
        key => _embeddingModel.GenerateEmbedding(key, _dimensions));
}
5871 }
5972}
0 commit comments