Adds image_features parameter to predict for pre-computed embeddings #169
NetZissou wants to merge 1 commit into Imageomics:main
Allows passing pre-computed image embeddings directly to predict() on TreeOfLifeClassifier, CustomLabelsClassifier, and CustomLabelsBinningClassifier, avoiding redundant image encoding when embeddings are already available.

Validates input: checks that the tensor is 2D and that embedding_dim matches the model's expected dimension (model.visual.output_dim), and applies L2 normalization only if the input is not already normalized, to avoid floating point drift.

Closes Imageomics#167

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
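For readers following along, the validation described in the commit message might look roughly like the sketch below. This is not the code from the PR: `validate_image_features`, its `tol` parameter, and the zero-vector check are hypothetical, nested Python lists stand in for a 2D tensor so the example runs without dependencies, and `expected_dim` stands in for model.visual.output_dim.

```python
import math
from collections.abc import Sequence
from typing import List


def validate_image_features(
    features, expected_dim: int, tol: float = 1e-4
) -> List[List[float]]:
    """Check shape and L2-normalize rows that are not already unit length.

    Hypothetical helper: nested sequences stand in for a 2D tensor, and
    expected_dim stands in for model.visual.output_dim.
    """
    validated = []
    for row in features:
        # 2D check: every row must itself be a sequence of plain numbers.
        if not isinstance(row, Sequence) or not all(
            isinstance(x, (int, float)) for x in row
        ):
            raise ValueError(
                "image_features must be 2D with shape (N, embedding_dim)"
            )
        # Dimension check against the model's expected embedding size.
        if len(row) != expected_dim:
            raise ValueError(
                f"expected embedding_dim {expected_dim}, got {len(row)}"
            )
        norm = math.sqrt(sum(x * x for x in row))
        if norm == 0.0:
            # Assumption (not from the PR): reject vectors that cannot be normalized.
            raise ValueError("cannot normalize a zero-length embedding")
        if math.isclose(norm, 1.0, abs_tol=tol):
            # Already unit length: leave values untouched to avoid drift.
            validated.append([float(x) for x in row])
        else:
            validated.append([x / norm for x in row])
    return validated
```

Skipping the division for rows that are already unit length is what keeps repeated round trips from nudging the values, which is the floating point drift the commit message refers to.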
hlapp left a comment
Thanks @NetZissou. The part of the implementation approach that I don't like here is that now creating probabilities is taking place redundantly in two different functions. This also creates more code noise than I think should be needed in the predict() method.
Instead, shouldn't the clean way to handle this in the predict() method be to check whether image_features are already provided? If they are, apply basic checks such as correct dimensions. If they are not, create them (as they are being created now from images). Then proceed with creating probabilities from image_features.
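To make the suggested control flow concrete, here is a minimal sketch under stated assumptions. `predict_sketch` and its three injected callables (`encode_images`, `check_features`, `create_probabilities`) are hypothetical names for illustration, not the package's actual methods; the point is only that both input paths converge on a single probability step.

```python
from typing import Callable, List, Optional, Sequence

Features = List[List[float]]


def predict_sketch(
    images: Optional[Sequence[object]] = None,
    image_features: Optional[Features] = None,
    *,
    encode_images: Callable[[Sequence[object]], Features],
    check_features: Callable[[Features], Features],
    create_probabilities: Callable[[Features], List[List[float]]],
) -> List[List[float]]:
    """Single code path: obtain features once, then compute probabilities."""
    if image_features is not None:
        # Pre-computed embeddings: only run the basic checks.
        image_features = check_features(image_features)
    elif images is not None:
        # No embeddings supplied: create them from the images.
        image_features = encode_images(images)
    else:
        raise ValueError("provide either images or image_features")
    # Probabilities are created in exactly one place for both branches.
    return create_probabilities(image_features)
```

With this shape, the probability logic lives in one function, and predict() only gains a short branch for the new parameter.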
@NetZissou just FYI, it might be advisable to rebase on main to bring in the changes from #179. It's well possible your changes so far are not in conflict at all, but some stuff did get moved around.
Disclaimer: This PR was developed with assistance from Claude Opus 4.6 (1M context). The author has reviewed all code changes and test additions. CI has been executed successfully in the forked repo. Opening this PR to request review from the package maintainers for further feedback and iteration.

Summary
This PR adds an optional image_features parameter to predict() on TreeOfLifeClassifier and CustomLabelsClassifier (CustomLabelsBinningClassifier inherits this through CustomLabelsClassifier). When provided, the method skips image encoding and computes classification directly from pre-computed embeddings.

Embedding validation
The method validates input embeddings before classification:
- Tensor is 2D, with shape (N, embedding_dim)
- embedding_dim matches the model's expected dimension (model.visual.output_dim)

Test plan
New tests in TestPredictFromEmbeddings:
- … file_name)
- … CustomLabelsClassifier and CustomLabelsBinningClassifier

Closes #167