
Pre-trained HF Aggregation: Embedding model classes (Conditions A/B/D/E/F) #4

@jlunder00

Description


Task 2: Embedding Model Classes

Depends on: #3 (PretrainedTransformerAggregator)

Files to Create

1. Tree_Matching_Networks/LinguisticTrees/models/pretrained_noprop_embedding.py (Condition B)

Simplest model — start here for testing.

  • Raw 804-dim node features → PretrainedTransformerAggregator → graph embedding
  • No TreeEncoder, no GNN propagation
  • Same forward signature as GraphEmbeddingNet: forward(node_features, edge_features, from_idx, to_idx, graph_idx, n_graphs)
  • Ignores edge_features, passes from_idx/to_idx to aggregator for tree positional encoding
  • Delegates freeze_transformer(), get_parameter_groups() to aggregator

Config reads from: config['model']['graph']['pretrained_transformer']
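
A minimal sketch of Condition B, assuming the PretrainedTransformerAggregator from #3 accepts (node_features, from_idx, to_idx, graph_idx, n_graphs) and returns [n_graphs, graph_rep_dim] embeddings; the import path, constructor signature, and get_parameter_groups interface are placeholders to match whatever #3 actually ships:

```python
import torch.nn as nn

# Import path is an assumption; adjust to wherever #3 places the aggregator.
# from .pretrained_aggregator import PretrainedTransformerAggregator


class PretrainedNoPropEmbeddingNet(nn.Module):
    """Condition B: raw 804-dim node features -> pretrained aggregator, no GNN propagation."""

    def __init__(self, config):
        super().__init__()
        agg_config = config['model']['graph']['pretrained_transformer']
        self._aggregator = PretrainedTransformerAggregator(agg_config)  # hypothetical ctor

    def forward(self, node_features, edge_features, from_idx, to_idx, graph_idx, n_graphs):
        # edge_features is accepted only for interface compatibility with GraphEmbeddingNet.
        return self._aggregator(node_features, from_idx, to_idx, graph_idx, n_graphs)

    def freeze_transformer(self):
        self._aggregator.freeze_transformer()

    def get_parameter_groups(self, base_lr):
        # Delegate; assumes the aggregator exposes the same (base_lr) signature.
        return self._aggregator.get_parameter_groups(base_lr)
```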

2. Tree_Matching_Networks/LinguisticTrees/models/pretrained_tree_embedding.py (Conditions D/E/F)

Follows TreeEmbeddingNet pattern exactly — see models/tree_embedding.py.

  • Extends GraphEmbeddingNet (same parent class as TreeEmbeddingNet)
  • Creates TreeEncoder for node encoding (804 → node_state_dim)
  • Creates PretrainedTransformerAggregator instead of TransformerTreeAggregator
  • Passes both to super().__init__() along with all the GNN config params
  • Copy the config parsing from TreeEmbeddingNet.__init__() verbatim — same keys: node_hidden_sizes, edge_hidden_sizes, n_prop_layers, share_prop_params, etc.
  • Add freeze_propagation(): freezes self._encoder + self._prop_layers (Condition F)
  • Add freeze_transformer(): delegates to self._aggregator.freeze_transformer() (Condition E)
  • Add get_parameter_groups(): encoder params at base_lr * pretrained_lr_scale, rest at base_lr

Config reads from: config['model']['graph'] for the GNN layers, config['model']['graph']['pretrained_transformer'] for the aggregator
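
A sketch of the additions beyond the TreeEmbeddingNet boilerplate; attribute names (self._encoder, self._prop_layers, self._aggregator) follow the GraphEmbeddingNet convention and self._pretrained_lr_scale is assumed to be read from config in __init__, so check them against the real parent class:

```python
class PretrainedTreeEmbeddingNet(GraphEmbeddingNet):
    # __init__ mirrors TreeEmbeddingNet, swapping in PretrainedTransformerAggregator (omitted here).

    def freeze_propagation(self):
        # Condition F: freeze the TreeEncoder and all GNN propagation layers.
        for module in [self._encoder, *self._prop_layers]:
            for p in module.parameters():
                p.requires_grad = False

    def freeze_transformer(self):
        # Condition E: delegate to the pretrained aggregator.
        self._aggregator.freeze_transformer()

    def get_parameter_groups(self, base_lr):
        # Per the bullets above: encoder params at base_lr * pretrained_lr_scale, rest at base_lr.
        encoder_params = list(self._encoder.parameters())
        encoder_ids = {id(p) for p in encoder_params}
        rest = [p for p in self.parameters() if id(p) not in encoder_ids]
        return [
            {'params': encoder_params, 'lr': base_lr * self._pretrained_lr_scale},
            {'params': rest, 'lr': base_lr},
        ]
```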

3. Tree_Matching_Networks/LinguisticTrees/models/pretrained_text_embedding.py (Condition A)

Follows BertEmbeddingNet pattern — see models/bert_embedding.py.

  • Loads AutoModel.from_pretrained(model_name) (NOT from_config — real pretrained weights)
  • Same forward interface as BertEmbeddingNet: forward(batch_encoding=None, input_ids=None, attention_mask=None, ...)
  • Uses mean pooling over tokens (not CLS-only — standard for sentence-transformers)
  • Projection: Linear(hf_hidden_dim, graph_rep_dim)
  • freeze_transformer(), get_parameter_groups() for controlling pre-trained layers

Config reads from: config['model']['pretrained']['model_name']

Key difference from BertEmbeddingNet: Uses from_pretrained() not from_config(), and uses its own tokenizer (loaded by train_unified.py from the same model_name).
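
The pooling-and-projection core is standard; here is a self-contained sketch using the Hugging Face transformers API (class name, constructor arguments, and config wiring are illustrative, not final):

```python
import torch.nn as nn
from transformers import AutoModel


class PretrainedTextEmbeddingNet(nn.Module):
    """Condition A: tokenized text -> pretrained transformer -> mean-pooled, projected embedding."""

    def __init__(self, model_name, graph_rep_dim):
        super().__init__()
        self.transformer = AutoModel.from_pretrained(model_name)  # real pretrained weights
        self.projection = nn.Linear(self.transformer.config.hidden_size, graph_rep_dim)

    def forward(self, batch_encoding=None, input_ids=None, attention_mask=None, **kwargs):
        if batch_encoding is not None:
            input_ids = batch_encoding['input_ids']
            attention_mask = batch_encoding['attention_mask']
        out = self.transformer(input_ids=input_ids, attention_mask=attention_mask)
        hidden = out.last_hidden_state                        # [batch, seq_len, hidden]
        mask = attention_mask.unsqueeze(-1).float()           # [batch, seq_len, 1]
        # Mean pooling over non-padding tokens (sentence-transformers style), not CLS-only.
        pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        return self.projection(pooled)                        # [batch, graph_rep_dim]
```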

Verification

Test each model with dummy tensors:

  • NoProp: torch.randn(25, 804) as node_features → expect [2, 2048] output
  • Tree: torch.randn(25, 804) as node_features → expect [2, 2048] output, test freeze_propagation() freezes encoder params
  • Text: tokenize two sentences → expect [2, 2048] output
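
For example, a rough smoke test for the NoProp and Tree models (edge counts, feature dims, and the config object below are placeholders; use the real config):

```python
import torch

node_features = torch.randn(25, 804)                              # 25 nodes across both trees
edge_features = torch.randn(23, 70)                               # ignored by NoProp; dim is a placeholder
from_idx = torch.arange(23)                                       # toy tree edges
to_idx = torch.arange(1, 24)
graph_idx = torch.cat([torch.zeros(12), torch.ones(13)]).long()   # which graph each node belongs to

model = PretrainedNoPropEmbeddingNet(config)                      # config as described in section 1
emb = model(node_features, edge_features, from_idx, to_idx, graph_idx, n_graphs=2)
assert emb.shape == (2, 2048)

# Condition F check: freeze_propagation() should leave encoder params with requires_grad == False.
tree_model = PretrainedTreeEmbeddingNet(config)
tree_model.freeze_propagation()
assert all(not p.requires_grad for p in tree_model._encoder.parameters())
```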

See docs/plans/2026-03-09-pretrained-transformer-aggregation.md Tasks 2, 3, 4 for complete code.
