Task 2: Embedding Model Classes
Depends on: #3 (PretrainedTransformerAggregator)
Files to Create
1. Tree_Matching_Networks/LinguisticTrees/models/pretrained_noprop_embedding.py (Condition B)
Simplest model — start here for testing.
- Raw 804-dim node features → PretrainedTransformerAggregator → graph embedding
- No TreeEncoder, no GNN propagation
- Same forward signature as GraphEmbeddingNet: forward(node_features, edge_features, from_idx, to_idx, graph_idx, n_graphs)
- Ignores edge_features, passes from_idx/to_idx to the aggregator for tree positional encoding
- Delegates freeze_transformer(), get_parameter_groups() to the aggregator (sketched after this item)
Config reads from: config['model']['graph']['pretrained_transformer']
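A minimal sketch of what this wrapper could look like. The import path and the aggregator's constructor/call signature are assumptions pending Task #3's actual API, not the final interface:

```python
import torch.nn as nn

# Import path is an assumption; use wherever Task #3 places the aggregator.
from .pretrained_transformer_aggregator import PretrainedTransformerAggregator


class PretrainedNoPropEmbeddingNet(nn.Module):
    """Condition B: raw 804-dim node features -> aggregator -> graph embedding."""

    def __init__(self, config):
        super().__init__()
        agg_config = config['model']['graph']['pretrained_transformer']
        self._aggregator = PretrainedTransformerAggregator(agg_config)

    def forward(self, node_features, edge_features, from_idx, to_idx,
                graph_idx, n_graphs):
        # edge_features is accepted only for interface parity with
        # GraphEmbeddingNet; from_idx/to_idx feed the aggregator's
        # tree positional encoding.
        del edge_features
        return self._aggregator(node_features, graph_idx, n_graphs,
                                from_idx=from_idx, to_idx=to_idx)

    def freeze_transformer(self):
        self._aggregator.freeze_transformer()

    def get_parameter_groups(self, *args, **kwargs):
        # Pure delegation; the exact signature is whatever Task #3 defines.
        return self._aggregator.get_parameter_groups(*args, **kwargs)
```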
2. Tree_Matching_Networks/LinguisticTrees/models/pretrained_tree_embedding.py (Conditions D/E/F)
Follows TreeEmbeddingNet pattern exactly — see models/tree_embedding.py.
- Extend GraphEmbeddingNet (same parent class as TreeEmbeddingNet)
- Create a TreeEncoder for node encoding (804 → node_state_dim)
- Create a PretrainedTransformerAggregator instead of TransformerTreeAggregator
- Pass both to super().__init__() along with all the GNN config params
- Copy the config parsing from TreeEmbeddingNet.__init__() verbatim (same keys: node_hidden_sizes, edge_hidden_sizes, n_prop_layers, share_prop_params, etc.)
- Add freeze_propagation(): freezes self._encoder + self._prop_layers (Condition F)
- Add freeze_transformer(): delegates to self._aggregator.freeze_transformer() (Condition E)
- Add get_parameter_groups(): encoder params at base_lr * pretrained_lr_scale, rest at base_lr (sketched after this item)
Config reads from: config['model']['graph'] for GNN, config['model']['graph']['pretrained_transformer'] for aggregator
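A sketch of the subclass, assuming a GraphEmbeddingNet constructor that accepts encoder, aggregator, and the GNN keyword arguments listed above; the exact argument layout and TreeEncoder kwargs should be copied from tree_embedding.py rather than from this sketch:

```python
# Import paths and constructor kwargs below are assumptions.
from .graph_embedding import GraphEmbeddingNet
from .tree_encoder import TreeEncoder
from .pretrained_transformer_aggregator import PretrainedTransformerAggregator


class PretrainedTreeEmbeddingNet(GraphEmbeddingNet):
    """Conditions D/E/F: TreeEncoder + GNN propagation + pretrained aggregator."""

    def __init__(self, config):
        graph_config = config['model']['graph']
        # TreeEncoder kwargs are assumptions; mirror TreeEmbeddingNet.__init__().
        encoder = TreeEncoder(input_dim=804,
                              output_dim=graph_config['node_state_dim'])
        aggregator = PretrainedTransformerAggregator(
            graph_config['pretrained_transformer'])
        super().__init__(
            encoder=encoder,
            aggregator=aggregator,
            node_state_dim=graph_config['node_state_dim'],
            edge_hidden_sizes=graph_config['edge_hidden_sizes'],
            node_hidden_sizes=graph_config['node_hidden_sizes'],
            n_prop_layers=graph_config['n_prop_layers'],
            share_prop_params=graph_config['share_prop_params'])

    def freeze_propagation(self):
        # Condition F: stop gradient flow through the node encoder and
        # all GNN propagation layers.
        for module in [self._encoder, *self._prop_layers]:
            for param in module.parameters():
                param.requires_grad = False

    def freeze_transformer(self):
        # Condition E: delegate to the aggregator's pretrained transformer.
        self._aggregator.freeze_transformer()

    def get_parameter_groups(self, base_lr, pretrained_lr_scale):
        # Per the spec above: encoder params at a scaled LR, rest at base_lr.
        encoder_params, rest = [], []
        for name, param in self.named_parameters():
            (encoder_params if name.startswith('_encoder')
             else rest).append(param)
        return [
            {'params': encoder_params, 'lr': base_lr * pretrained_lr_scale},
            {'params': rest, 'lr': base_lr},
        ]
```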
3. Tree_Matching_Networks/LinguisticTrees/models/pretrained_text_embedding.py (Condition A)
Follows BertEmbeddingNet pattern — see models/bert_embedding.py.
- Loads AutoModel.from_pretrained(model_name) (NOT from_config; real pretrained weights)
- Same forward interface as BertEmbeddingNet: forward(batch_encoding=None, input_ids=None, attention_mask=None, ...)
- Uses mean pooling over tokens (not CLS-only; standard for sentence-transformers)
- Projection: Linear(hf_hidden_dim, graph_rep_dim)
- Adds freeze_transformer(), get_parameter_groups() for controlling the pretrained layers (sketched after this item)
Config reads from: config['model']['pretrained']['model_name']
Key difference from BertEmbeddingNet: Uses from_pretrained() not from_config(), and uses its own tokenizer (loaded by train_unified.py from the same model_name).
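A sketch following the BertEmbeddingNet pattern described above. The HF calls (AutoModel.from_pretrained, last_hidden_state) are real API; the class name, the graph_rep_dim config key, and the get_parameter_groups() signature are assumptions:

```python
import torch.nn as nn
from transformers import AutoModel


class PretrainedTextEmbeddingNet(nn.Module):
    """Condition A: pretrained HF transformer over raw text."""

    def __init__(self, config):
        super().__init__()
        model_name = config['model']['pretrained']['model_name']
        # from_pretrained() loads real weights (unlike from_config()).
        self.transformer = AutoModel.from_pretrained(model_name)
        hidden_dim = self.transformer.config.hidden_size
        # The graph_rep_dim config key is an assumption (e.g. 2048).
        self.projection = nn.Linear(
            hidden_dim, config['model']['graph']['graph_rep_dim'])

    def forward(self, batch_encoding=None, input_ids=None,
                attention_mask=None, **kwargs):
        if batch_encoding is not None:
            input_ids = batch_encoding['input_ids']
            attention_mask = batch_encoding['attention_mask']
        out = self.transformer(input_ids=input_ids,
                               attention_mask=attention_mask)
        # Mean pooling over non-padding tokens (sentence-transformers style),
        # rather than taking the [CLS] token alone.
        mask = attention_mask.unsqueeze(-1).float()
        pooled = (out.last_hidden_state * mask).sum(1) \
            / mask.sum(1).clamp(min=1e-9)
        return self.projection(pooled)

    def freeze_transformer(self):
        for param in self.transformer.parameters():
            param.requires_grad = False

    def get_parameter_groups(self, base_lr, pretrained_lr_scale):
        # Pretrained layers train at a scaled LR; the projection at base_lr.
        return [
            {'params': self.transformer.parameters(),
             'lr': base_lr * pretrained_lr_scale},
            {'params': self.projection.parameters(), 'lr': base_lr},
        ]
```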
Verification
Test each model with dummy tensors (a smoke-test sketch follows the list):
- NoProp: torch.randn(25, 804) as node_features → expect [2, 2048] output
- Tree: torch.randn(25, 804) as node_features → expect [2, 2048] output; also verify that freeze_propagation() freezes the encoder params
- Text: tokenize two sentences → expect [2, 2048] output
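A smoke-test sketch for the first two models, assuming the classes from the sketches above and a config loaded elsewhere; the edge feature dimension (70) is a placeholder:

```python
import torch

node_features = torch.randn(25, 804)
edge_features = torch.randn(24, 70)   # edge feature dim is a placeholder
from_idx = torch.randint(0, 25, (24,))
to_idx = torch.randint(0, 25, (24,))
graph_idx = torch.cat([torch.zeros(13), torch.ones(12)]).long()  # 2 graphs

model = PretrainedNoPropEmbeddingNet(config)
out = model(node_features, edge_features, from_idx, to_idx, graph_idx,
            n_graphs=2)
assert out.shape == (2, 2048)

# Condition F check: after freeze_propagation(), no encoder param should
# require gradients.
tree_model = PretrainedTreeEmbeddingNet(config)
tree_model.freeze_propagation()
assert not any(p.requires_grad for p in tree_model._encoder.parameters())
```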
See docs/plans/2026-03-09-pretrained-transformer-aggregation.md Tasks 2, 3, 4 for complete code.