This repository implements a two-phase deep learning pipeline for modeling longitudinal Electronic Medical Records (EMRs). The architecture combines temporal embeddings, patient context, and Transformer-based sequence modeling to predict or impute patient events over time.
This repo is part of an unpublished thesis and will be finalized post-submission. Please do not reuse without permission.
The results shown here (in `evaluation.ipynb`) are on random data, as my research dataset is private. The model is intended to run on actual EMR data stored in a closed environment; to support that, the repository is organized as an installable package:
```
transform-emr/
│
├── transform_emr/                      # Core Python package
│   ├── config/                         # Configuration modules
│   │   ├── __init__.py
│   │   ├── tak-repo-portable.json      # TAKRepository object from Mediator (see related project)
│   │   ├── dataset_config.py
│   │   └── model_config.py
│   ├── __init__.py
│   ├── dataset.py                      # Dataset, DataPreprocess and Tokenizer
│   ├── embedder.py                     # Embedding model (EMREmbedding) + training
│   ├── transformer.py                  # Transformer architecture (GPT) + training
│   ├── train.py                        # Full training pipeline (3-phase)
│   ├── inference.py                    # Inference pipeline
│   ├── loss.py                         # Utility module for special loss criteria
│   ├── schedulers.py                   # Utility module for training schedulers (LR & aux tasks)
│   ├── utils.py                        # Utility functions for the package (plots + penalties + masks)
│   └── diagnose.py                     # Debug reports on trained-model health
├── data/                               # External data folder (for synthetic or real EMR)
│   ├── generate_synthetic_data.ipynb   # Generates synthetic data structured like Mediator's output (for tests)
│   ├── source/                         # The notebook points here and auto-generates the train-test splits
│   ├── train/
│   └── test/
├── unittests/                          # Unit and integration tests (dataset / model / utils)
├── evaluation.ipynb                    # Main research and experiments notebook
├── README.md
├── .gitignore
├── requirements.txt
├── LICENCE
├── CITATION.cff
├── setup.py
└── pyproject.toml
```

As noted, this model feeds off the output of the Mediator temporal abstraction engine. It can work with any temporal-interval dataset, but note that the embedding has a knowledge-base component, so a `tak-repo-portable.json`-like object is mandatory.
Install the project as an editable package from the root directory:

```bash
# Make sure your working directory is the root of this repository
# and that paths are set properly in your local environment.
pip install -e .
```

```python
import pandas as pd

from transform_emr.dataset import DataProcessor, EMRTokenizer, EMRDataset
from transform_emr.config.dataset_config import *
from transform_emr.config.model_config import *

# Load data (verify your paths are properly defined)
temporal_df = pd.read_csv(TRAIN_TEMPORAL_DATA_FILE, low_memory=False)
ctx_df = pd.read_csv(TRAIN_CTX_DATA_FILE)

print("[Pre-processing]: Building tokenizer...")
processor = DataProcessor(temporal_df, ctx_df, tak_repo_path=TAK_REPO_PATH, scaler=None)
temporal_df, ctx_df = processor.run()
tokenizer = EMRTokenizer.from_processed_df(temporal_df)

train_ds = EMRDataset(temporal_df, ctx_df, tokenizer=tokenizer)
MODEL_CONFIG['ctx_dim'] = int(train_ds.context_df.shape[1])  # Dynamically update the context dimension

from transform_emr.train import run_training

run_training()
```

Model checkpoints are saved under `checkpoints/phase1/`, `checkpoints/phase2/`, and `checkpoints/phase3/`.
You can also run each phase individually by calling `prepare_data()`, `phase_one()`, `phase_two()`, and `phase_three()` separately, as sketched below. `prepare_data()` returns `(train_ds, val_ds, tokenizer)`; `run_training()` owns all `DataLoader` creation. Phase 2 uses a weighted oversampled loader; phases 1 and 3 share a bucket-batched natural-distribution loader. See `train.py` for reference.
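A minimal sketch of the phase-by-phase route; the argument lists here are assumptions, not the exact `train.py` signatures:

```python
# Sketch only: see train.py for the real signatures and DataLoader wiring.
from transform_emr.train import prepare_data, phase_one, phase_two, phase_three

train_ds, val_ds, tokenizer = prepare_data()

embedder = phase_one(train_ds, val_ds, tokenizer)   # Phase 1: train EMREmbedding
model = phase_two(embedder, train_ds, val_ds)       # Phase 2: pretrain the transformer
model = phase_three(model, train_ds, val_ds)        # Phase 3: fine-tune the outcome head
```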
The primary inference task is complication risk prediction: for each patient, generate a single free-running trajectory and read the outcome head at every step to produce a probability curve per complication over time. Use `generate()` with `collect_risk_scores=True` for this purpose.
```python
import joblib
from pathlib import Path

from transform_emr.embedder import EMREmbedding
from transform_emr.transformer import GPT
from transform_emr.dataset import DataProcessor, EMRTokenizer, EMRDataset
from transform_emr.inference import generate
from transform_emr.config.model_config import *

# Load tokenizer and scaler
tokenizer = EMRTokenizer.load(Path(CHECKPOINT_PATH) / "tokenizer.pt")
scaler = joblib.load(Path(CHECKPOINT_PATH) / "scaler.pkl")

# Preprocess test data (df, ctx_df: raw test temporal and context tables, loaded
# as in the training example), truncated to the same input window used during
# Phase-3 alignment
processor = DataProcessor(df, ctx_df, scaler=scaler, tak_repo_path=TAK_REPO_PATH, max_input_days=5)
df, ctx_df = processor.run()
dataset_input = EMRDataset(df, ctx_df, tokenizer=tokenizer)

# Load the best available checkpoint (Phase 3 if available, otherwise Phase 2)
embedder_model, *_ = EMREmbedding.load(PHASE1_CHECKPOINT, tokenizer=tokenizer)
p3_ckpt = Path(PHASE3_CHECKPOINT)
p2_ckpt = Path(PHASE2_CHECKPOINT)
ckpt_path = p3_ckpt if p3_ckpt.exists() else p2_ckpt
model, *_ = GPT.load(ckpt_path, embedder=embedder_model)
model.eval()

# Generate risk curves: one row per generated step, P_<outcome> columns per complication
risk_df = generate(model, dataset_input, max_len=500, temperature=1.0, rep_decay=0.6,
                   collect_risk_scores=True)

# Raw event stream only (no risk scores, faster)
event_df = generate(model, dataset_input, max_len=500, temperature=1.0, rep_decay=0.6)
```

The returned `risk_df` has columns `{PatientId, Step, Token, IsInput, IsOutcome, IsTerminal, TimePoint, P_<outcome_name>, ...}`.
Rows with `IsInput == 0` are generated steps; the `P_*` columns hold sigmoid outcome-head probabilities at that step. Evaluate using time-stratified AUC (see `evaluation.ipynb`).

Patients that exhaust `max_len` without generating a terminal token receive a forced `DEATH` or `RELEASE` token (chosen by highest logit), clamped to <= 336 h. The fallback rate is printed after generation.
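For example, pulling one patient's generated risk curve out of `risk_df` looks like this (a sketch; `P_SEPSIS` is a hypothetical outcome column name):

```python
# Keep only generated steps, then slice one patient's probability-over-time curve.
gen = risk_df[risk_df["IsInput"] == 0]
pid = gen["PatientId"].iloc[0]                               # pick one patient
curve = gen.loc[gen["PatientId"] == pid, ["TimePoint", "P_SEPSIS"]]
curve.plot(x="TimePoint", y="P_SEPSIS")                      # risk curve over generated time
```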
You can run local tests (not unit tests) by executing the `.py` files as modules, as long as the file you run has a `__main__` section. For example, run this from the root:

```bash
python -m transform_emr.train
# Or
python -m transform_emr.inference
# Both modules have a __main__ section to train / run inference with a trained model
```

Run all tests, without validation prints:

```bash
python -m pytest unittests/
```

With validation prints:

```bash
python -m pytest -q -s unittests/
```

To package without data/checkpoints:
```powershell
# Clean up any existing temp folder
Remove-Item -Recurse -Force .\transform_emr_temp -ErrorAction SilentlyContinue

# Recreate the temp folder
New-Item -ItemType Directory -Path .\transform_emr_temp | Out-Null

# Copy only what's needed
Copy-Item -Path .\transform_emr -Destination .\transform_emr_temp -Recurse
Copy-Item -Path .\setup.py, .\evaluation.ipynb, .\README.md, .\requirements.txt -Destination .\transform_emr_temp

# Remove __pycache__ folders (platform-specific bytecode, not for distribution)
Get-ChildItem -Path .\transform_emr_temp -Filter __pycache__ -Recurse -Directory | Remove-Item -Recurse -Force

# Zip it
Compress-Archive -Path .\transform_emr_temp\* -DestinationPath .\emr_model.zip -Force

# Clean up
Remove-Item -Recurse -Force .\transform_emr_temp
```

- This project uses synthetic EMR data (`data/train/` and `data/test/`).
- For best results, ensure consistent preprocessing when saving/loading models.
```
Raw EMR Tables
      │
      ▼
Per-patient Event Tokenization (with normalized absolute timestamps)
      │
      ▼
🧠 Phase 1: Train EMREmbedding (token + time + patient context)
      │
      ▼
🔁 Phase 2: Pretrain a Transformer decoder over learned embeddings (next-token prediction + outcome auxiliary task)
      │
      ▼
🎯 Phase 3: Outcome-head fine-tuning: freeze the backbone, fine-tune only the outcome head on natural-distribution
            batches (oversample=False + pos_weight), analogous to BERT head fine-tuning
      │
      ▼
✅ Predict next medical events (token + time) and read complication risk curves from the outcome head (in evaluation.ipynb)
```
| Component | Role |
|---|---|
| `DataProcessor` | Performs all necessary data processing, from input data to `tokens_df`. |
| `EMRTokenizer` | Builds vocabulary and per-outcome prevalence ratios from a processed `temporal_df`; filters outcomes below `OUTCOME_RARE_THRESHOLD_PCT`; saves/loads with `BucketBatchSampler` / `WeightedBucketBatchSampler` support. |
| `EMRDataset` | Converts raw EMR tables into per-patient token sequences with relative time. |
| `collate_emr()` | Pads sequences and returns tensors. |
📌 Why it matters: medical data varies in density and structure across patients. This dynamic preprocessing handles irregularity while preserving medically relevant sequencing via START/END logic and relative timing.

This module assumes the existence of a prepared `tak-repo-portable.json` file, output by the Mediator as a hierarchy mapper of the different concepts.
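Wiring these pieces together, a minimal sketch (assuming `collate_emr` works as a standard `collate_fn`; the real pipeline in `train.py` owns loader creation and uses bucket batch samplers instead of this plain loader):

```python
# Sketch only: run_training() builds the real loaders with bucket batch samplers.
from torch.utils.data import DataLoader
from transform_emr.dataset import collate_emr

loader = DataLoader(train_ds, batch_size=16, shuffle=True, collate_fn=collate_emr)
batch = next(iter(loader))  # padded tensors for a variable-length patient batch
```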
| Component | Role |
|---|---|
| `Time2Vec` | Learns a periodic + trend encoding from inter-event durations. |
| `EMREmbedding` | Combines token, time, and patient-context embeddings to create the token representation. |
| `train_embedder()` | Trains the embedding model with teacher-forced next-token prediction (temporal BCE), plus MSE on time prediction and an MLM task as auxiliary goals. |
⚙️ Phase 1: Learning Event Representations

Phase 1 learns a robust, patient-aware representation of event sequences. It isolates the core structure of patient timelines without being confounded by the autoregressive depth of Transformers.

The embedder uses:
- 4 levels of tokens: the event token is separated into 4 hierarchical components to impose similarity between tokens of the same domain: `GLUCOSE -> GLUCOSE_TREND -> GLUCOSE_TREND_Inc -> GLUCOSE_TREND_Inc_START`.
- 1 level of time: absolute time from ADMISSION, to capture global patterns and relationships between non-sequential events.
This architecture constructs event representations by concatenating five hierarchical levels: Raw Concept, Concept, Value, Position, and Absolute Time. This creates a dense vector that captures the intrinsic hierarchy of medical concepts (e.g., Glucose_High is a child of Glucose) while explicitly binding them to their timestamp.
We choose concatenation (early fusion) for the temporal component, unlike the standard additive approach, to preserve the integrity of the medical signal. By keeping the time dimensions separate from the concept dimensions in the input, the model can clearly distinguish the "what" from the "when". This ensures that the core identity of a pathology (e.g., Hyperglycemia) remains stable and recognizable ("Hyperglycemia is Hyperglycemia") regardless of its timing, while allowing the projection layer to learn how time modifies its clinical significance (e.g., morning vs. evening).
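As an illustration of this early-fusion design, here is a self-contained sketch (not the package's actual `EMREmbedding` code; dimensions and class names are assumptions) that concatenates the four token-hierarchy embeddings with a Time2Vec encoding before a shared projection:

```python
import torch
import torch.nn as nn

class Time2Vec(nn.Module):
    """Time2Vec (Kazemi et al., 2019): one linear 'trend' dim + sinusoidal 'periodic' dims."""
    def __init__(self, dim: int):
        super().__init__()
        self.w0 = nn.Parameter(torch.randn(1))
        self.b0 = nn.Parameter(torch.randn(1))
        self.w = nn.Parameter(torch.randn(dim - 1))
        self.b = nn.Parameter(torch.randn(dim - 1))

    def forward(self, t: torch.Tensor) -> torch.Tensor:   # t: (B, T) absolute times
        t = t.unsqueeze(-1)                                # (B, T, 1)
        trend = self.w0 * t + self.b0                      # linear trend component
        periodic = torch.sin(self.w * t + self.b)          # periodic components
        return torch.cat([trend, periodic], dim=-1)        # (B, T, dim)

class EarlyFusionEventEmbedding(nn.Module):
    """Concatenate 4 hierarchy-level embeddings + time encoding, then project."""
    def __init__(self, vocab_sizes, level_dim=32, time_dim=16, d_model=256):
        super().__init__()
        self.levels = nn.ModuleList(nn.Embedding(v, level_dim) for v in vocab_sizes)
        self.time2vec = Time2Vec(time_dim)
        self.proj = nn.Linear(len(vocab_sizes) * level_dim + time_dim, d_model)

    def forward(self, level_ids, t_abs):                   # level_ids: (B, T, 4)
        parts = [emb(level_ids[..., i]) for i, emb in enumerate(self.levels)]
        parts.append(self.time2vec(t_abs))                 # "when" stays separate from "what"
        return self.proj(torch.cat(parts, dim=-1))         # (B, T, d_model)
```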
Context handling: to condition these embeddings on static patient attributes (e.g., age, sex), we project the patient context vector and add it to the event sequence. This acts as a global bias, shifting the entire event manifold into a patient-specific subspace, so that even before the Transformer layers, the event representations are already calibrated to the patient's demographic risk profile. Since inference handles the context projection and the event embedding separately, we use context dropout (passing p% of the trajectories with no context) so that the embedder learns to work with or without it, while still pushing the context projection layer toward the shared latent space.
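A hedged sketch of this conditioning step (all names and the dropout rate here are assumptions):

```python
import torch
import torch.nn as nn

B, T, d_model, ctx_dim, p_ctx_drop = 8, 64, 256, 12, 0.1
ctx_proj = nn.Linear(ctx_dim, d_model)

event_emb = torch.randn(B, T, d_model)   # output of the event embedding above
ctx_vec = torch.randn(B, ctx_dim)        # static patient attributes (age, sex, ...)

ctx = ctx_proj(ctx_vec)                                   # (B, d_model)
keep = (torch.rand(B, 1) > p_ctx_drop).float()            # context dropout mask
x = event_emb + (keep * ctx).unsqueeze(1)                 # global bias over all steps
```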
The training uses a next-token prediction loss (temporal-window BCE) + a time-prediction MSE (Δt) + an MLM loss. The MLM avoids masking tokens that would damage the broader meaning, such as ADMISSION and the terminal-outcome tokens.
| Component | Role |
|---|---|
| `GPT` | Transformer decoder stack over learned embeddings for next-token prediction, with an additional head for Δt prediction. Takes a trained embedder as input. |
| `CausalSelfAttention` | Multi-head attention using a causal mask to enforce chronology. Uses temporal RoPE to inject absolute time into attention scores. |
| `MLP` | SwiGLU MLP (SiLU gating), based on common LLM optimizations. |
| `AdaLNBlock` | Transformer block with AdaLN-Zero conditioning (adaptive layer norm), to bias prediction based on the patient context. |
| `pretrain_transformer()` | Complete Phase-2 training logic using legality-masked temporal multi-hot BCE (focal), masked set-CE, Δt loss, and an outcome-BCE auxiliary loss. |
| `finetune_transformer()` | Phase-3 outcome-head fine-tuning: freezes the backbone and fine-tunes only the outcome head on natural-distribution batches (`oversample=False`), so `pos_weight` in `BCEWithLogitsLoss` correctly compensates for class imbalance without double-counting. Uses the same soft-label targets as Phase 2 but with gradient isolation on the head only. Saves full-model checkpoints loadable with `GPT.load()`. |
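The temporal RoPE used by `CausalSelfAttention` can be sketched as follows; this is an illustration under the assumption that RoFormer's integer positions are simply replaced by absolute timestamps (the package's exact frequency schedule may differ):

```python
import torch

def temporal_rope(q, k, t_abs, base=10000.0):
    """Rotate Q/K pairs by angles derived from absolute time instead of token index.

    q, k: (B, H, T, Dh) with Dh even; t_abs: (B, T) absolute timestamps.
    """
    Dh = q.size(-1)
    freqs = base ** (-torch.arange(0, Dh, 2, device=q.device, dtype=q.dtype) / Dh)
    ang = t_abs[:, None, :, None] * freqs        # (B, 1, T, Dh/2), broadcasts over heads
    cos, sin = ang.cos(), ang.sin()

    def rotate(x):
        x1, x2 = x[..., 0::2], x[..., 1::2]      # interleaved (even, odd) pairs
        return torch.stack((x1 * cos - x2 * sin,
                            x1 * sin + x2 * cos), dim=-1).flatten(-2)

    return rotate(q), rotate(k)
```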
⚙️ Phase 2: Learning Sequence Dependencies

Once the EMR structure is captured, the transformer learns to model sequential dependencies in event progression:
- What tends to follow a certain event?
- How does timing affect outcomes?
- How does patient context modulate the trajectory?

The training uses a next-token prediction loss (temporal-window masked BCE focal loss + masked CE loss) + a time-prediction MSE (Δt) + an outcome-prediction BCE auxiliary task. Training is guided by teacher forcing, showing the model the correct context at every step (exposing [0, t-1] at step t, for t up to the block_size T), while also masking logits for illegal predictions based on the true trajectory. As training progresses, the model's input ([0, t-1]) is partially masked (curriculum-based masking, CBM) to teach the model to handle inaccuracies during generation, while sparing the same tokens protected in the EMREmbedding MLM plus the MEAL, _START, and _END tokens, so as not to clash with the legal set of next tokens the model can use (see the sketch after the next paragraph).

The training flow uses warmup/curriculum scheduling (LR warmup, a BCE-only phase, and staged auxiliary losses). The embedder is trainable during Phase 2, but is updated with a lower learning rate than the transformer blocks.
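A small sketch of the curriculum-masking (CBM) idea; the linear ramp, rate, and function names here are assumptions:

```python
import torch

def curriculum_mask(tokens, protected_ids, step, total_steps, mask_id, max_rate=0.15):
    """Mask a growing fraction of teacher-forced inputs, sparing protected tokens.

    tokens: (B, T) long tensor; protected_ids: 1-D long tensor of token ids
    (ADMISSION, terminal outcomes, MEAL, _START/_END, ...).
    """
    rate = max_rate * min(1.0, step / total_steps)     # ramp masking over training
    maskable = ~torch.isin(tokens, protected_ids)      # never mask protected tokens
    drop = (torch.rand(tokens.shape, device=tokens.device) < rate) & maskable
    return torch.where(drop, torch.full_like(tokens, mask_id), tokens)
```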
| Component | Role |
|---|---|
| `generate()` | Primary inference function. Generates one autoregressive trajectory per patient. With `collect_risk_scores=True`, reads the outcome head at every step and returns per-step complication probabilities (`P_*` columns). Patients that reach `max_len` receive a forced terminal token (`DEATH` or `RELEASE`, by highest logit), clamped to <= 336 h. |
| `get_token_embedding()` | Returns the embedding vector of a specific token from a trained embedder. |
NOTE: Inference is step-by-step (autoregressive), so it is significantly slower than training. That said, the model uses batched inference (multiple patients at the same time), a KV cache (reducing per-step attention work from O(T·d²) to O(d²)), and FP16 quantization, which together speed up inference significantly.
Runs the full end-to-end evaluation pipeline: data loading, three-phase training, risk-curve generation, and statistical analysis. The primary evaluation metric is time-stratified AUC.
| Component | Role |
|---|---|
| `extract_ground_truth()` | Builds a `{patient_id → {outcome → first_occurrence_hours}}` dict from the full (untruncated) test dataset. |
| `time_stratified_auc()` | For each 24 h window: score = max outcome-head probability in the window, label = complication occurred in the same window. Computes AUROC and AUPRC per complication, averaged across windows. |
| `time_accuracy()` | For patients where a complication occurred: MAE between the generated step with peak probability and the actual onset time. |
| `calibrate_temperature()` | Learns a per-outcome temperature scalar via LBFGS (NLL minimisation). Does not affect rank order (AUC unchanged); improves probability calibration for direct interpretation. |
| `reliability_diagram()` | Plots calibration curves before and after temperature scaling. |
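For intuition, a simplified sketch of the time-stratified AUROC computation for a single complication (the real implementation lives in `evaluation.ipynb`; the window size, horizon, and function name here are assumptions):

```python
import numpy as np
from sklearn.metrics import roc_auc_score

def time_stratified_auroc(risk_df, ground_truth, outcome, window_h=24, horizon_h=336):
    """Per window: score = max P_<outcome> in window, label = onset in same window."""
    gen = risk_df[risk_df["IsInput"] == 0]                 # generated steps only
    aucs = []
    for lo in range(0, horizon_h, window_h):
        hi, scores, labels = lo + window_h, [], []
        for pid, g in gen.groupby("PatientId"):
            in_win = g[(g["TimePoint"] >= lo) & (g["TimePoint"] < hi)]
            if in_win.empty:
                continue
            scores.append(in_win[f"P_{outcome}"].max())    # peak probability in window
            onset = ground_truth.get(pid, {}).get(outcome)
            labels.append(int(onset is not None and lo <= onset < hi))
        if len(set(labels)) == 2:                          # AUROC needs both classes
            aucs.append(roc_auc_score(labels, scores))
    return float(np.mean(aucs)) if aucs else float("nan")
```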
- ✔️ Handles irregular time-series data using relative deltas and Time2Vec.
- ✔️ Captures both short- and long-range dependencies with deep transformer blocks.
- ✔️ Supports variable-length patient histories using custom collate and attention masks.
- ✔️ Imputes and predicts events in structured EMR timelines.
This work builds on and adapts ideas from the following sources:

- Time2Vec (Kazemi et al., 2019): the temporal embedding design is adapted from the Time2Vec formulation.
  📄 S. M. Kazemi et al. "Time2Vec: Learning a Vector Representation of Time." NeurIPS 2019 Time Series Workshop. arXiv:1907.05321.
- nanoGPT (Karpathy, 2023): the training loop and transformer backbone are adapted from nanoGPT, with modifications for multi-stream EMR inputs, multiple embeddings, and a k-step prediction loss.
- RoPE / RoFormer (Su et al., 2021): the attention module uses rotary position embeddings adapted to continuous/absolute timestamps (temporal RoPE) to inject time into the Q/K rotations.
  📄 J. Su, Y. Lu, S. Pan, A. Murtadha, B. Wen. "RoFormer: Enhanced Transformer with Rotary Position Embedding." arXiv:2104.09864.
- AdaLN-Zero (Peebles & Xie, 2023): inspired by "Scalable Diffusion Models with Transformers", I added a customized block to the transformer designed to let static context influence all generation steps. The paper uses this method to inform the diffusion model of the label of the image it should generate.
- And more...
