Task 3: Matching Model Classes
Depends on: #3 (PretrainedTransformerAggregator), #4 (embedding models for reference)
Files to Create
1. Tree_Matching_Networks/LinguisticTrees/models/pretrained_tree_matching.py (Conditions D/E/F)
Follows the `TreeMatchingNet` pattern exactly — see `models/tree_matching.py`. `TreeMatchingNet` does NOT extend `GraphEmbeddingNet` — it wraps `GraphMatchingNet` via `self.gmn`.
- Copy the config parsing from `TreeMatchingNet.__init__()` — same keys
- Create `TreeEncoder` + `PretrainedTransformerAggregator`
- Pass to `GraphMatchingNet(encoder=encoder, aggregator=aggregator, ..., prop_type='matching')`
- `freeze_propagation()`: freeze `self.gmn._encoder` + `self.gmn._prop_layers`
- `freeze_transformer()`: `self.gmn._aggregator.freeze_transformer()`
- `get_parameter_groups()`: check `'gmn._aggregator.encoder.'` in name for pretrained params
- `forward()`: just `return self.gmn(...)`
Key difference from embedding: Uses GraphMatchingNet (has cross-graph attention in propagation) instead of GraphEmbeddingNet.
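The freeze and parameter-group methods above can be sketched with plain `nn.Module` stand-ins (the real `GraphMatchingNet`, `TreeEncoder`, and `PretrainedTransformerAggregator` are project-specific; everything here besides the attribute names `gmn._encoder`, `gmn._prop_layers`, and `gmn._aggregator.encoder` is an illustrative assumption):

```python
import torch.nn as nn

class PretrainedTreeMatchingNet(nn.Module):
    """Sketch of the wrapper pattern: a GMN held as self.gmn.

    Linear layers stand in for the real encoder/propagation/aggregator
    modules; only the attribute paths mirror the plan above.
    """
    def __init__(self):
        super().__init__()
        self.gmn = nn.Module()
        self.gmn._encoder = nn.Linear(8, 8)                 # TreeEncoder stand-in
        self.gmn._prop_layers = nn.ModuleList(
            [nn.Linear(8, 8) for _ in range(2)])            # propagation stand-ins
        aggregator = nn.Module()
        aggregator.encoder = nn.Linear(8, 8)                # pretrained transformer stand-in
        aggregator.head = nn.Linear(8, 8)                   # non-pretrained aggregator params
        self.gmn._aggregator = aggregator

    def freeze_propagation(self):
        # Freeze the node encoder and all propagation layers.
        for module in [self.gmn._encoder, *self.gmn._prop_layers]:
            for p in module.parameters():
                p.requires_grad = False

    def get_parameter_groups(self, pretrained_lr=1e-5, base_lr=1e-4):
        # Route the pretrained transformer's params to a lower learning rate.
        pretrained, rest = [], []
        for name, p in self.named_parameters():
            (pretrained if 'gmn._aggregator.encoder.' in name else rest).append(p)
        return [{'params': pretrained, 'lr': pretrained_lr},
                {'params': rest, 'lr': base_lr}]
```

The returned groups plug directly into a PyTorch optimizer, e.g. `torch.optim.AdamW(net.get_parameter_groups())`.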
2. Tree_Matching_Networks/LinguisticTrees/models/pretrained_text_matching.py (Condition A)
Follows the `BertMatchingNet` pattern — see `models/bert_matching.py`.
- Loads `AutoModel.from_pretrained(model_name)`
- Manual forward pass through encoder layers with interleaved `BertGMNStyleCrossAttention`
- Import `BertGMNStyleCrossAttention` from `bert_matching.py` (no learnable params, just similarity + softmax)
- Cross-attention at every layer (matching existing behavior)
- Need to handle different model families: BERT has `model.encoder.layer`, DistilBERT has `model.transformer.layer`
- Projection: `Linear(hf_hidden_dim, graph_rep_dim)`
Note: The manual forward pass intercepts between encoder layers. For BERT-style models:

```python
hidden_states = self.transformer.embeddings(input_ids=..., token_type_ids=...)
extended_mask = self.transformer.get_extended_attention_mask(attention_mask, input_ids.shape)
for i, layer in enumerate(self.transformer.encoder.layer):
    layer_outputs = layer(hidden_states, extended_mask)
    hidden_states = layer_outputs[0]
    # Insert cross-attention here
```
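Because the layer stack lives at different attribute paths across model families, a small accessor keeps the manual forward pass family-agnostic (the helper name `get_encoder_layers` is an assumption; the attribute paths it probes are the real Hugging Face layouts named above):

```python
def get_encoder_layers(model):
    """Locate the transformer layer stack across HF model families.

    BERT/RoBERTa-style models expose model.encoder.layer; DistilBERT
    exposes model.transformer.layer. Raises on unrecognised layouts
    rather than silently skipping cross-attention insertion.
    """
    if hasattr(model, 'encoder') and hasattr(model.encoder, 'layer'):
        return model.encoder.layer        # BERT, RoBERTa, ELECTRA, ...
    if hasattr(model, 'transformer') and hasattr(model.transformer, 'layer'):
        return model.transformer.layer    # DistilBERT
    raise ValueError(f"Unsupported model family: {type(model).__name__}")
```

The loop in the snippet above would then iterate `get_encoder_layers(self.transformer)` instead of hard-coding `self.transformer.encoder.layer`.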
3. Tree_Matching_Networks/LinguisticTrees/models/pretrained_noprop_matching.py (Condition B)
- Same as `pretrained_noprop_embedding.py`, but note: without GNN propagation there is no cross-graph attention during propagation
- The aggregator processes each graph independently; comparison happens in the loss function
- This is acceptable for the experimental design (Condition B tests pretrained aggregation on raw features; the matching-paradigm comparison happens at the loss level)
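The loss-level comparison for Condition B can be sketched as a pairing of an even batch of independently produced graph vectors (the function name `pairwise_from_batch` and the cosine-similarity scoring are illustrative assumptions, not the project's actual loss):

```python
import torch
import torch.nn.functional as F

def pairwise_from_batch(graph_vectors):
    """Split an even batch of graph vectors [2B, D] into B pairs and
    score each pair with cosine similarity.

    Comparison happens here, at the loss level, because Condition B
    has no cross-graph attention during propagation.
    """
    assert graph_vectors.shape[0] % 2 == 0, "matching needs an even batch (pairs)"
    x = graph_vectors[0::2]  # first graph of each pair
    y = graph_vectors[1::2]  # second graph of each pair
    return F.cosine_similarity(x, y, dim=1)  # shape [B]
```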
Verification
- Tree matching: needs an even number of graphs (pairs). Test with 4 graphs → expect `[4, 2048]`
- Text matching: tokenize 4 sentences → expect `[4, 2048]`; verify the `batch_size % 2 == 0` check
See `docs/plans/2026-03-09-pretrained-transformer-aggregation.md` Task 5 for complete code.