Pre-trained HF Aggregation: Update train_unified.py model factory, optimizer, checkpoint loading #6

@jlunder00

Description

Task 4: Training Pipeline Integration

Depends on: #4, #5 (all model classes created)

Modify: Tree_Matching_Networks/LinguisticTrees/experiments/train_unified.py

Three changes needed, all backward-compatible with existing models.

Change 1: Model Factory (~line 240)

Add imports at top of file:

from ..models.pretrained_text_embedding import PretrainedTextEmbeddingNet
from ..models.pretrained_text_matching import PretrainedTextMatchingNet
from ..models.pretrained_tree_embedding import PretrainedTreeEmbeddingNet
from ..models.pretrained_tree_matching import PretrainedTreeMatchingNet
from ..models.pretrained_noprop_embedding import PretrainedNoPropEmbeddingNet
from ..models.pretrained_noprop_matching import PretrainedNoPropMatchingNet

Replace the model creation block (currently ~lines 241-264) with expanded routing:

model_name = config['model'].get('name', '')
model_type = config['model'].get('model_type', 'matching')

if model_name == 'pretrained_text':
    # Condition A
    from transformers import AutoTokenizer
    hf_model_name = config['model']['pretrained']['model_name']
    tokenizer = AutoTokenizer.from_pretrained(hf_model_name)
    if model_type == 'embedding':
        model = PretrainedTextEmbeddingNet(config, tokenizer).to(config['device'])
    else:
        model = PretrainedTextMatchingNet(config, tokenizer).to(config['device'])
    config['model_name'] = 'bert'  # text data pipeline

elif model_name == 'pretrained_noprop':
    # Condition B
    if model_type == 'embedding':
        model = PretrainedNoPropEmbeddingNet(config).to(config['device'])
    else:
        model = PretrainedNoPropMatchingNet(config).to(config['device'])
    config['model_name'] = 'graph'

elif model_name == 'pretrained_tree':
    # Conditions D, E, F
    if model_type == 'embedding':
        model = PretrainedTreeEmbeddingNet(config).to(config['device'])
    else:
        model = PretrainedTreeMatchingNet(config).to(config['device'])
    config['model_name'] = 'graph'

    # Apply freeze config
    freeze_config = config['model'].get('freeze', {})
    if freeze_config.get('freeze_propagation', False):
        model.freeze_propagation()
        logger.info("Froze GNN propagation (Condition F)")
    if freeze_config.get('freeze_transformer', False):
        model.freeze_transformer()
        logger.info("Froze pretrained transformer (Condition E)")

elif text_mode:
    # EXISTING: unchanged
    ...
else:
    # EXISTING: unchanged
    ...

Critical: The existing elif text_mode: and else: blocks must remain unchanged. The new blocks go BEFORE them as if/elif checks on model_name.
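For reference, a config fragment that would route into the new `pretrained_tree` branch (Condition F). The keys match what the routing block above reads; the specific values (HF model id, learning rates) are illustrative only:

```yaml
model:
  name: pretrained_tree            # routes into the new pretrained_tree branch
  model_type: matching             # or 'embedding'
  pretrained:
    model_name: bert-base-uncased  # hypothetical HF model id
  freeze:
    freeze_propagation: true       # Condition F: freeze GNN, train new aggregator
    freeze_transformer: false
  strict_checkpoint_load: false    # needed when resuming from a GNN-only checkpoint (Change 3)
train:
  learning_rate: 1.0e-4
  weight_decay: 0.0
  pretrained_lr_scale: 0.1         # see Change 2
```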

Change 2: Optimizer (~line 267)

Replace:

optimizer = torch.optim.Adam(model.parameters(), lr=..., weight_decay=...)

With:

base_lr = config['train']['learning_rate']
pretrained_lr_scale = config['train'].get('pretrained_lr_scale', 1.0)

if hasattr(model, 'get_parameter_groups'):
    param_groups = model.get_parameter_groups(base_lr, pretrained_lr_scale)
    optimizer = torch.optim.Adam(param_groups, weight_decay=config['train']['weight_decay'])
    logger.info(f"Optimizer: {len(param_groups)} param groups, pretrained_lr_scale={pretrained_lr_scale}")
else:
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=base_lr,
        weight_decay=config['train']['weight_decay']
    )

Note: The filter(lambda p: p.requires_grad, ...) is backward-compatible — existing models have all params trainable so it's equivalent to model.parameters().
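The `get_parameter_groups` contract assumed above can be sketched as follows. The `'transformer.'` name prefix and the two-way split are assumptions for illustration; each model class defines its own grouping. The returned list of dicts is exactly what `torch.optim.Adam` accepts as per-parameter-group options:

```python
def get_parameter_groups(named_params, base_lr, pretrained_lr_scale):
    """Split parameters into 'new' vs 'pretrained' groups with scaled LRs.

    Sketch only: assumes pretrained weights live under a 'transformer.' prefix.
    """
    pretrained, new = [], []
    for name, p in named_params:
        (pretrained if name.startswith('transformer.') else new).append(p)
    groups = []
    if new:
        groups.append({'params': new, 'lr': base_lr})
    if pretrained:
        # Pretrained weights get a smaller LR so fine-tuning stays gentle
        groups.append({'params': pretrained, 'lr': base_lr * pretrained_lr_scale})
    return groups

groups = get_parameter_groups(
    [('transformer.layer.0.weight', 'w0'), ('aggregator.weight', 'w1')],
    base_lr=1e-4, pretrained_lr_scale=0.1,
)
# groups[0]: new params at base_lr; groups[1]: pretrained params at base_lr * 0.1
```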

Change 3: Checkpoint Loading (~line 273)

Replace:

model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

With:

strict_load = config['model'].get('strict_checkpoint_load', True)
missing, unexpected = model.load_state_dict(checkpoint['model_state_dict'], strict=strict_load)
if missing:
    logger.info(f"Missing keys (expected for pretrained components): {len(missing)}")
if unexpected:
    logger.info(f"Unexpected keys (from previous aggregator): {len(unexpected)}")

if not missing and not unexpected:
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
else:
    logger.info("Skipping optimizer state load due to architecture mismatch")

Why: Condition F loads a GNN checkpoint into a model with a different aggregator. The GNN weights match, the old aggregator keys are "unexpected", the new pretrained aggregator keys are "missing" (loaded from HuggingFace in constructor). strict=False allows this.
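To make the key bookkeeping concrete, here is a minimal stand-in for what `load_state_dict(strict=False)` reports in the Condition F case. The key names are illustrative, not the real state-dict keys:

```python
def partial_load_report(model_keys, checkpoint_keys):
    """Mimic load_state_dict(strict=False) reporting: keys the model expects
    but the checkpoint lacks are 'missing'; checkpoint keys the model has no
    slot for are 'unexpected'."""
    missing = [k for k in model_keys if k not in checkpoint_keys]
    unexpected = [k for k in checkpoint_keys if k not in model_keys]
    return missing, unexpected

# GNN checkpoint trained with the old aggregator...
ckpt = {'gnn.layer0.weight', 'old_aggregator.weight'}
# ...loaded into a model whose aggregator was replaced by a pretrained one.
model = {'gnn.layer0.weight', 'pretrained_aggregator.weight'}
missing, unexpected = partial_load_report(model, ckpt)
# The GNN weights transfer silently; only the swapped aggregator shows up
# in the two lists, which is why the optimizer state is skipped above.
```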

Also: Update models/__init__.py

Add exports for all 6 new model classes.
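A sketch of the additions to models/__init__.py, mirroring the module paths used by the imports in Change 1:

```python
from .pretrained_text_embedding import PretrainedTextEmbeddingNet
from .pretrained_text_matching import PretrainedTextMatchingNet
from .pretrained_tree_embedding import PretrainedTreeEmbeddingNet
from .pretrained_tree_matching import PretrainedTreeMatchingNet
from .pretrained_noprop_embedding import PretrainedNoPropEmbeddingNet
from .pretrained_noprop_matching import PretrainedNoPropMatchingNet
```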

Verification

Run existing model creation to confirm backward compatibility:

cd /home/jlunder/research/Tree-Matching-Networks
python -c "
import yaml
with open('Tree_Matching_Networks/LinguisticTrees/configs/experiment_configs/ttn_embedding_prop_heavy_contrastive_config.yaml') as f:
    config = yaml.safe_load(f)
config['device'] = 'cpu'
from Tree_Matching_Networks.LinguisticTrees.models.tree_embedding import TreeEmbeddingNet
model = TreeEmbeddingNet(config)
print('Existing model still works: PASS')
"

See docs/plans/2026-03-09-pretrained-transformer-aggregation.md Tasks 6, 7 for details.

Metadata

Labels: enhancement (New feature or request)