Task 4: Training Pipeline Integration
Depends on: #4, #5 (all model classes created)
Modify: `Tree_Matching_Networks/LinguisticTrees/experiments/train_unified.py`
Three changes needed, all backward-compatible with existing models.
Change 1: Model Factory (~line 240)
Add imports at top of file:
```python
from ..models.pretrained_text_embedding import PretrainedTextEmbeddingNet
from ..models.pretrained_text_matching import PretrainedTextMatchingNet
from ..models.pretrained_tree_embedding import PretrainedTreeEmbeddingNet
from ..models.pretrained_tree_matching import PretrainedTreeMatchingNet
from ..models.pretrained_noprop_embedding import PretrainedNoPropEmbeddingNet
from ..models.pretrained_noprop_matching import PretrainedNoPropMatchingNet
```
Replace the model creation block (currently ~lines 241-264) with expanded routing:
```python
model_name = config['model'].get('name', '')
model_type = config['model'].get('model_type', 'matching')

if model_name == 'pretrained_text':
    # Condition A
    from transformers import AutoTokenizer
    hf_model_name = config['model']['pretrained']['model_name']
    tokenizer = AutoTokenizer.from_pretrained(hf_model_name)
    if model_type == 'embedding':
        model = PretrainedTextEmbeddingNet(config, tokenizer).to(config['device'])
    else:
        model = PretrainedTextMatchingNet(config, tokenizer).to(config['device'])
    config['model_name'] = 'bert'  # text data pipeline
elif model_name == 'pretrained_noprop':
    # Condition B
    if model_type == 'embedding':
        model = PretrainedNoPropEmbeddingNet(config).to(config['device'])
    else:
        model = PretrainedNoPropMatchingNet(config).to(config['device'])
    config['model_name'] = 'graph'
elif model_name == 'pretrained_tree':
    # Conditions D, E, F
    if model_type == 'embedding':
        model = PretrainedTreeEmbeddingNet(config).to(config['device'])
    else:
        model = PretrainedTreeMatchingNet(config).to(config['device'])
    config['model_name'] = 'graph'
    # Apply freeze config
    freeze_config = config['model'].get('freeze', {})
    if freeze_config.get('freeze_propagation', False):
        model.freeze_propagation()
        logger.info("Froze GNN propagation (Condition F)")
    if freeze_config.get('freeze_transformer', False):
        model.freeze_transformer()
        logger.info("Froze pretrained transformer (Condition E)")
elif text_mode:
    # EXISTING: unchanged
    ...
else:
    # EXISTING: unchanged
    ...
```
Critical: The existing `elif text_mode:` and `else:` blocks must remain unchanged. The new blocks go BEFORE them as `if`/`elif` checks on `model_name`.
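For reference, a config fragment exercising the new routing might look like the following. The key names match the lookups in the factory code (`model.name`, `model.model_type`, `model.pretrained.model_name`, `model.freeze.*`, and the `train` keys used in Change 2); the specific values are illustrative:

```yaml
model:
  name: pretrained_tree          # routes to the new pretrained_tree branch
  model_type: matching           # or 'embedding'
  pretrained:
    model_name: bert-base-uncased  # HuggingFace model id (illustrative)
  freeze:
    freeze_propagation: false    # Condition F when true
    freeze_transformer: false    # Condition E when true
train:
  learning_rate: 1.0e-4
  pretrained_lr_scale: 0.1       # read in Change 2; defaults to 1.0
  weight_decay: 0.01
```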
Change 2: Optimizer (~line 267)
Replace:
```python
optimizer = torch.optim.Adam(model.parameters(), lr=..., weight_decay=...)
```
With:
```python
base_lr = config['train']['learning_rate']
pretrained_lr_scale = config['train'].get('pretrained_lr_scale', 1.0)
if hasattr(model, 'get_parameter_groups'):
    param_groups = model.get_parameter_groups(base_lr, pretrained_lr_scale)
    optimizer = torch.optim.Adam(param_groups, weight_decay=config['train']['weight_decay'])
    logger.info(f"Optimizer: {len(param_groups)} param groups, pretrained_lr_scale={pretrained_lr_scale}")
else:
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=base_lr,
        weight_decay=config['train']['weight_decay']
    )
```
Note: The `filter(lambda p: p.requires_grad, ...)` is backward-compatible: existing models have all params trainable, so it is equivalent to `model.parameters()`.
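The optimizer branch relies on a duck-typed `get_parameter_groups(base_lr, pretrained_lr_scale)` hook returning standard PyTorch per-parameter-group dicts. A minimal sketch of the contract on a toy model (the `transformer`/`head` split is illustrative, not the actual model code):

```python
import torch
import torch.nn as nn

class ToyPretrainedNet(nn.Module):
    """Illustrative model: a 'pretrained' submodule plus a fresh head."""
    def __init__(self):
        super().__init__()
        self.transformer = nn.Linear(8, 8)  # stands in for the pretrained encoder
        self.head = nn.Linear(8, 2)         # freshly initialized component

    def get_parameter_groups(self, base_lr, pretrained_lr_scale):
        # Pretrained weights train at a scaled-down LR; new layers at base_lr.
        return [
            {'params': self.transformer.parameters(), 'lr': base_lr * pretrained_lr_scale},
            {'params': self.head.parameters(), 'lr': base_lr},
        ]

model = ToyPretrainedNet()
param_groups = model.get_parameter_groups(1e-4, 0.1)
optimizer = torch.optim.Adam(param_groups, weight_decay=0.01)
print(len(optimizer.param_groups))  # → 2
```

Passing dicts that each carry their own `lr` is standard Adam usage, so only the `weight_decay` kwarg needs to be supplied globally.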
Change 3: Checkpoint Loading (~line 273)
Replace:
```python
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
```
With:
```python
strict_load = config['model'].get('strict_checkpoint_load', True)
missing, unexpected = model.load_state_dict(checkpoint['model_state_dict'], strict=strict_load)
if missing:
    logger.info(f"Missing keys (expected for pretrained components): {len(missing)}")
if unexpected:
    logger.info(f"Unexpected keys (from previous aggregator): {len(unexpected)}")
if not missing and not unexpected:
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
else:
    logger.info("Skipping optimizer state load due to architecture mismatch")
```
Why: Condition F loads a GNN checkpoint into a model with a different aggregator. The GNN weights match, the old aggregator keys are "unexpected", and the new pretrained aggregator keys are "missing" (they are loaded from HuggingFace in the constructor). `strict=False` allows this.
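The missing/unexpected behavior of `strict=False` can be reproduced on a toy pair of modules (the `gnn`/`old_agg`/`pretrained_agg` names here are purely illustrative):

```python
import torch.nn as nn

# Old model: shared 'gnn' trunk plus an 'old_agg' head (what the checkpoint holds).
old = nn.ModuleDict({'gnn': nn.Linear(4, 4), 'old_agg': nn.Linear(4, 2)})
# New model: same trunk, but a different 'pretrained_agg' head.
new = nn.ModuleDict({'gnn': nn.Linear(4, 4), 'pretrained_agg': nn.Linear(4, 2)})

# strict=True would raise; strict=False reports the mismatches instead.
result = new.load_state_dict(old.state_dict(), strict=False)
print(sorted(result.missing_keys))     # pretrained_agg.* (absent from checkpoint)
print(sorted(result.unexpected_keys))  # old_agg.* (present only in checkpoint)
```

The shared `gnn.*` weights are copied over; only the aggregator keys show up in the two mismatch lists, which is exactly the situation the logging in Change 3 reports.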
Also: Update `models/__init__.py`
Add exports for all 6 new model classes.
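A sketch of the `models/__init__.py` additions, mirroring the module paths from the imports in Change 1:

```python
from .pretrained_text_embedding import PretrainedTextEmbeddingNet
from .pretrained_text_matching import PretrainedTextMatchingNet
from .pretrained_tree_embedding import PretrainedTreeEmbeddingNet
from .pretrained_tree_matching import PretrainedTreeMatchingNet
from .pretrained_noprop_embedding import PretrainedNoPropEmbeddingNet
from .pretrained_noprop_matching import PretrainedNoPropMatchingNet
```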
Verification
Run existing model creation to confirm backward compatibility:
```bash
cd /home/jlunder/research/Tree-Matching-Networks
python -c "
import yaml
with open('Tree_Matching_Networks/LinguisticTrees/configs/experiment_configs/ttn_embedding_prop_heavy_contrastive_config.yaml') as f:
    config = yaml.safe_load(f)
config['device'] = 'cpu'
from Tree_Matching_Networks.LinguisticTrees.models.tree_embedding import TreeEmbeddingNet
model = TreeEmbeddingNet(config)
print('Existing model still works: PASS')
"
```
See `docs/plans/2026-03-09-pretrained-transformer-aggregation.md` Tasks 6, 7 for details.