Add Synthetic EHR Generation Support -- GPT Baseline #879

Open
ethanrasmussen wants to merge 24 commits into sunlabuiuc:master from ethanrasmussen:implement_baseline_final

@ethanrasmussen

Overview

This PR adds synthetic EHR generation to PyHealth so that researchers can train generative models that produce realistic synthetic patient histories. The implementation follows PyHealth conventions and covers the full pipeline: data processing, model training, and evaluation.

Changes

Core Functionality

1. New Tasks: SyntheticEHRGenerationMIMIC3 and SyntheticEHRGenerationMIMIC4

File: pyhealth/tasks/synthetic_ehr_generation.py

  • Implements the PyHealth BaseTask interface to prepare patient visit sequences for generative modeling
  • Processes patient records into nested sequences suitable for autoregressive generation
  • Supports configurable visit filtering (min_visits, max_visits)
  • Compatible with both MIMIC-III and MIMIC-IV datasets
  • Creates samples with input/output schema for teacher forcing during training
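
To illustrate the input/output schema used for teacher forcing, here is a minimal sketch rather than the task's actual code. The special tokens (`<bos>`, `<eos>`, `<v>`) and the helper names `flatten_visits` and `make_teacher_forcing_pair` are assumptions made for illustration; the task in this PR may tokenize and delimit visits differently.

```python
from typing import List, Tuple

# Hypothetical special tokens; the PR's actual vocabulary may differ.
BOS, EOS, VISIT_SEP = "<bos>", "<eos>", "<v>"


def flatten_visits(visits: List[List[str]]) -> List[str]:
    """Flatten nested visits into one token list with visit separators."""
    tokens = [BOS]
    for i, visit in enumerate(visits):
        if i > 0:
            tokens.append(VISIT_SEP)  # mark the boundary between visits
        tokens.extend(visit)
    tokens.append(EOS)
    return tokens


def make_teacher_forcing_pair(visits: List[List[str]]) -> Tuple[List[str], List[str]]:
    """Return (inputs, targets), where targets are inputs shifted left by one."""
    tokens = flatten_visits(visits)
    return tokens[:-1], tokens[1:]


# One patient with two visits of diagnosis codes.
patient = [["401.9", "250.00"], ["428.0"]]
inputs, targets = make_teacher_forcing_pair(patient)
for x, y in zip(inputs, targets):
    print(f"{x:>8} -> {y}")
```

At each position the model sees the ground-truth prefix and is trained to predict the next token, which is what lets an autoregressive model learn from whole sequences in one pass.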

2. New Model: TransformerEHRGenerator

File: pyhealth/models/synthetic_ehr.py

  • Decoder-only transformer architecture (GPT-style) for autoregressive EHR generation
  • Learns to model patient visit sequences and generate synthetic patient histories
  • Supports sampling with temperature, top-k, and nucleus (top-p) filtering
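
The three sampling controls can be combined as in the sketch below. It uses only the Python standard library and is an assumption about how such filtering is typically done, not the code in `pyhealth/models/synthetic_ehr.py`; the function name `sample_next_token` and its parameters are hypothetical.

```python
import math
import random
from typing import List, Optional


def sample_next_token(
    logits: List[float],
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    rng: Optional[random.Random] = None,
) -> int:
    """Sample a token id after temperature scaling and top-k / top-p filtering."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    rng = rng or random.Random()

    # Temperature scaling followed by a numerically stable softmax.
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Rank token ids from most to least probable.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)

    # Top-k: keep only the k most probable tokens.
    if top_k is not None:
        order = order[:top_k]

    # Top-p (nucleus): keep the smallest prefix whose mass reaches top_p.
    if top_p is not None:
        kept, cumulative = [], 0.0
        for i in order:
            kept.append(i)
            cumulative += probs[i]
            if cumulative >= top_p:
                break
        order = kept

    # random.choices renormalizes the surviving weights.
    weights = [probs[i] for i in order]
    return rng.choices(order, weights=weights, k=1)[0]


token = sample_next_token(
    [0.1, 2.0, -1.0], temperature=0.8, top_k=2, top_p=0.9, rng=random.Random(0)
)
```

Lower temperatures concentrate probability on the most likely codes, while top-k and top-p cut off the low-probability tail before sampling.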

3. Utility Module: synthetic_ehr_utils

File: pyhealth/synthetic_ehr_utils/synthetic_ehr_utils.py

Provides data conversion utilities for working with different EHR representations:

Core Functions:

  • tabular_to_sequences(): Converts long-form DataFrames to text sequences
  • sequences_to_tabular(): Converts text sequences back to DataFrames
  • nested_codes_to_sequences(): Converts PyHealth nested structure to text
  • sequences_to_nested_codes(): Converts text sequences to nested structure
  • create_flattened_representation(): Creates patient-level count matrices
  • process_mimic_for_generation(): End-to-end MIMIC data processing
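
The sketch below shows the kind of conversions these utilities describe: nested codes to a text sequence and back, plus a patient-level count matrix. It does not call the PR's functions, whose signatures are not shown here. The helper names (`nested_to_text`, `text_to_nested`, `flatten_counts`) and the `VISIT_DELIM` separator are hypothetical.

```python
from collections import Counter
from typing import Dict, List, Tuple

VISIT_DELIM = "VISIT_DELIM"  # hypothetical visit separator token


def nested_to_text(visits: List[List[str]]) -> str:
    """Join codes with spaces and visits with the delimiter token."""
    return f" {VISIT_DELIM} ".join(" ".join(visit) for visit in visits)


def text_to_nested(text: str) -> List[List[str]]:
    """Invert nested_to_text; visits with no codes are dropped."""
    return [chunk.split() for chunk in text.split(VISIT_DELIM) if chunk.strip()]


def flatten_counts(
    patients: Dict[str, List[List[str]]]
) -> Tuple[List[str], Dict[str, List[int]]]:
    """Build one row of per-code counts for each patient."""
    vocab = sorted({c for visits in patients.values() for v in visits for c in v})
    rows = {}
    for pid, visits in patients.items():
        counts = Counter(c for v in visits for c in v)
        rows[pid] = [counts[c] for c in vocab]
    return vocab, rows


patients = {
    "p1": [["401.9", "250.00"], ["401.9"]],
    "p2": [["428.0"]],
}
text = nested_to_text(patients["p1"])
vocab, rows = flatten_counts(patients)
```

The flattened form discards visit order, which is why it suits tabular baselines, whereas the text form preserves the sequence needed by an autoregressive model.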

Example Scripts

4. Baseline Models Script

File: examples/synthetic_ehr_generation/synthetic_ehr_baselines.py

Demonstrates integration with popular generative model baselines:

Supported Models:

  • GReaT: Language model-based tabular generation
  • CTGAN: Conditional GAN for tabular data
  • TVAE: Variational autoencoder for tabular data
  • TransformerEHRGenerator: PyHealth's transformer-based sequential model

Features:

  • Command-line interface for easy model selection
  • Configurable training parameters (epochs, batch size, learning rate)
  • Automatic model saving and synthetic data generation
  • Handles both flattened (tabular) and sequential (transformer) representations

5. Transformer Example

File: examples/synthetic_ehr_generation/synthetic_ehr_mimic3_transformer.py

Complete end-to-end example demonstrating:

  1. Loading MIMIC-III dataset
  2. Applying SyntheticEHRGenerationMIMIC3 task
  3. Splitting data by patient, so that no patient appears in more than one split (prevents data leakage)
  4. Training TransformerEHRGenerator model
  5. Generating synthetic patient histories
  6. Converting outputs to multiple formats (CSV, text sequences)
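
Step 3 is worth spelling out. The sketch below splits samples by patient ID with the standard library only; it is not necessarily how the example script performs the split. The function name `split_samples_by_patient`, the `patient_id` key, and the default ratios are assumptions for illustration.

```python
import random
from typing import Dict, List, Sequence, Tuple


def split_samples_by_patient(
    samples: List[Dict],
    ratios: Sequence[float] = (0.8, 0.1, 0.1),
    seed: int = 0,
) -> Tuple[List[Dict], List[Dict], List[Dict]]:
    """Assign whole patients to train/val/test so no patient spans two splits."""
    patient_ids = sorted({s["patient_id"] for s in samples})
    random.Random(seed).shuffle(patient_ids)  # seeded for reproducibility

    n_train = int(len(patient_ids) * ratios[0])
    n_val = int(len(patient_ids) * ratios[1])
    train_ids = set(patient_ids[:n_train])
    val_ids = set(patient_ids[n_train:n_train + n_val])
    test_ids = set(patient_ids[n_train + n_val:])  # remainder goes to test

    train = [s for s in samples if s["patient_id"] in train_ids]
    val = [s for s in samples if s["patient_id"] in val_ids]
    test = [s for s in samples if s["patient_id"] in test_ids]
    return train, val, test


# Ten hypothetical patients with two samples each.
samples = [{"patient_id": f"p{i}", "sample": j} for i in range(10) for j in range(2)]
train, val, test = split_samples_by_patient(samples)
```

Splitting on samples instead of patients would let one patient's records land in both training and test sets, which inflates evaluation results.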

Configurable Parameters:

  • Model architecture (embedding dim, layers, heads, feedforward dim)
  • Training hyperparameters (epochs, batch size, learning rate)
  • Generation settings (temperature, top-k, top-p sampling)
  • Dataset filtering (min/max visits, CCS code mapping)

Integration Updates

6. Module Exports

Files Modified:

  • pyhealth/models/__init__.py: Added TransformerEHRGenerator import
  • pyhealth/tasks/__init__.py: Added SyntheticEHRGenerationMIMIC3 and SyntheticEHRGenerationMIMIC4 imports
  • pyhealth/synthetic_ehr_utils/__init__.py: Added utility function exports

Examples

Training a Transformer Model

python synthetic_ehr_baselines.py \
    --mimic_root /path/to/mimic3 \
    --output_dir ./output \
    --mode transformer_baseline \
    --epochs 50 \
    --batch_size 64

Generating Synthetic Data

from pyhealth.datasets import MIMIC3Dataset
from pyhealth.tasks import SyntheticEHRGenerationMIMIC3
from pyhealth.models import TransformerEHRGenerator

# Load dataset
dataset = MIMIC3Dataset(root="/path/to/mimic3", tables=["DIAGNOSES_ICD"])
task = SyntheticEHRGenerationMIMIC3(min_visits=2)
sample_dataset = dataset.set_task(task)

# Train model
model = TransformerEHRGenerator(dataset=sample_dataset)
# ... training code ...

# Generate synthetic patients
synthetic_codes = model.generate(num_samples=1000, max_visits=10)

Dependencies

New optional dependencies for baseline models:

  • be-great: For GReaT model support
  • sdv: For CTGAN and TVAE model support

Core PyHealth dependencies remain unchanged.
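
A script can check for the optional packages before selecting a baseline, as in the hedged sketch below. The import names `be_great` and `sdv` are assumed to correspond to the `be-great` and `sdv` distributions, and the helper `has_module` is hypothetical; the PR's script may handle missing dependencies differently.

```python
import importlib.util


def has_module(name: str) -> bool:
    """Return True if a top-level module can be found without importing it."""
    return importlib.util.find_spec(name) is not None


# Assumed import names for the optional baseline dependencies.
OPTIONAL_BASELINES = {"GReaT": "be_great", "CTGAN": "sdv", "TVAE": "sdv"}

available = {model: has_module(mod) for model, mod in OPTIONAL_BASELINES.items()}
for model, ok in available.items():
    print(f"{model}: {'available' if ok else 'missing optional dependency'}")
```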

Breaking Changes

None. This PR only adds new functionality.

@ethanrasmussen ethanrasmussen marked this pull request as ready for review March 3, 2026 03:50