Pre-trained HF Aggregation: Matching model classes (Conditions A/B/D/E/F) #5

@jlunder00

Description

Task 3: Matching Model Classes

Depends on: #3 (PretrainedTransformerAggregator), #4 (embedding models for reference)

Files to Create

1. Tree_Matching_Networks/LinguisticTrees/models/pretrained_tree_matching.py (Conditions D/E/F)

Follows TreeMatchingNet pattern exactly — see models/tree_matching.py.

  • TreeMatchingNet does NOT extend GraphEmbeddingNet — it wraps GraphMatchingNet via self.gmn
  • Copy the config parsing from TreeMatchingNet.__init__() — same keys
  • Create TreeEncoder + PretrainedTransformerAggregator
  • Pass to GraphMatchingNet(encoder=encoder, aggregator=aggregator, ..., prop_type='matching')
  • freeze_propagation(): freeze self.gmn._encoder + self.gmn._prop_layers
  • freeze_transformer(): self.gmn._aggregator.freeze_transformer()
  • get_parameter_groups(): check 'gmn._aggregator.encoder.' in name for pretrained params
  • forward(): just return self.gmn(...)

Key difference from embedding: Uses GraphMatchingNet (has cross-graph attention in propagation) instead of GraphEmbeddingNet.
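The wrapper pattern, freezing, and parameter-group logic above can be sketched roughly as follows. This is an illustrative skeleton, not the actual implementation: `_StubGMN` is a toy stand-in for the real `GraphMatchingNet` wiring, the learning rates are placeholders, and `freeze_transformer()` here freezes parameters directly rather than delegating to the aggregator's own method.

```python
import torch.nn as nn

class _StubGMN(nn.Module):
    """Toy stand-in for GraphMatchingNet so the sketch runs standalone."""
    def __init__(self):
        super().__init__()
        self._encoder = nn.Linear(4, 4)                       # node/edge encoder
        self._prop_layers = nn.ModuleList([nn.Linear(4, 4)])  # message-passing layers
        self._aggregator = nn.Module()
        self._aggregator.encoder = nn.Linear(4, 4)            # pretrained transformer slot

class PretrainedTreeMatchingNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.gmn = _StubGMN()  # wraps GraphMatchingNet, as TreeMatchingNet does

    def freeze_propagation(self):
        # freeze the GNN encoder and the propagation layers
        for module in (self.gmn._encoder, *self.gmn._prop_layers):
            for p in module.parameters():
                p.requires_grad = False

    def freeze_transformer(self):
        # real code would call self.gmn._aggregator.freeze_transformer()
        for p in self.gmn._aggregator.encoder.parameters():
            p.requires_grad = False

    def get_parameter_groups(self, pretrained_lr=1e-5, base_lr=1e-3):
        # route pretrained-transformer params to a lower learning rate
        pretrained, rest = [], []
        for name, p in self.named_parameters():
            (pretrained if 'gmn._aggregator.encoder.' in name else rest).append(p)
        return [{'params': pretrained, 'lr': pretrained_lr},
                {'params': rest, 'lr': base_lr}]

    def forward(self, *args, **kwargs):
        return self.gmn(*args, **kwargs)
```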

2. Tree_Matching_Networks/LinguisticTrees/models/pretrained_text_matching.py (Condition A)

Follows BertMatchingNet pattern — see models/bert_matching.py.

  • Loads AutoModel.from_pretrained(model_name)
  • Manual forward pass through encoder layers with interleaved BertGMNStyleCrossAttention
  • Import BertGMNStyleCrossAttention from bert_matching.py (no learnable params, just similarity + softmax)
  • Cross-attention at every layer (matching existing behavior)
  • Need to handle different model families: BERT has model.encoder.layer, DistilBERT has model.transformer.layer
  • Projection: Linear(hf_hidden_dim, graph_rep_dim)

Note: The manual forward pass intercepts between encoder layers. For BERT-style:

hidden_states = self.transformer.embeddings(input_ids=..., token_type_ids=...)
extended_mask = self.transformer.get_extended_attention_mask(attention_mask, input_ids.shape)
for i, layer in enumerate(self.transformer.encoder.layer):
    layer_outputs = layer(hidden_states, extended_mask)
    hidden_states = layer_outputs[0]
    # Insert cross-attention here

3. Tree_Matching_Networks/LinguisticTrees/models/pretrained_noprop_matching.py (Condition B)

  • Same as pretrained_noprop_embedding.py, but note: without GNN propagation there is no cross-graph attention at all
  • The aggregator processes each graph independently; comparison happens in the loss function
  • This is acceptable for the experimental design (Condition B tests pretrained aggregation on raw features; the matching-paradigm comparison happens at the loss level)
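The loss-level comparison can be sketched minimally as below, assuming (as in the GMN reference code) that consecutive graphs in the batch form a pair; the function names are illustrative:

```python
import torch
import torch.nn.functional as F

def split_pairs(graph_vectors):
    # graph_vectors: [2N, D]; rows (0, 1), (2, 3), ... are the paired graphs
    assert graph_vectors.shape[0] % 2 == 0, "matching needs an even batch"
    return graph_vectors[0::2], graph_vectors[1::2]

def pair_similarity(graph_vectors):
    # each graph was aggregated independently; comparison happens only here
    x, y = split_pairs(graph_vectors)
    return F.cosine_similarity(x, y, dim=-1)  # [N]
```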

Verification

  • Tree matching: needs an even number of graphs (pairs). Test with 4 graphs → expect [4, 2048]
  • Text matching: tokenize 4 sentences → expect [4, 2048]; verify the batch_size % 2 == 0 check
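The shape checks above can be wired into a small smoke test; `model_fn` is a placeholder for whichever forward pass is being verified, and the random-tensor stand-in below only exercises the checks themselves:

```python
import torch

def smoke_test(model_fn, n_inputs=4, rep_dim=2048):
    # matching models consume inputs in pairs, so the batch must be even
    assert n_inputs % 2 == 0, "matching requires batch_size % 2 == 0"
    out = model_fn(n_inputs)
    assert out.shape == (n_inputs, rep_dim), f"unexpected shape {tuple(out.shape)}"
    return out

# stand-in forward pass while the real models are being wired up
out = smoke_test(lambda n: torch.randn(n, 2048))
```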

See docs/plans/2026-03-09-pretrained-transformer-aggregation.md Task 5 for complete code.
