Summary
The calibration pipeline currently runs all major steps sequentially. Since the Modal container has 8 CPUs available, we can use ProcessPoolExecutor within the single container to parallelize the three most compute-heavy stages, reducing total pipeline time from ~30-60 min to ~10-20 min (~3x speedup).
Current Bottlenecks
| Component | Current (sequential) | Time |
|---|---|---|
| Per-state precompute | 51 Microsimulation calls, one at a time | ~5-10 min |
| Per-county precompute | 51 state groups, one at a time | ~10-20 min |
| Clone loop | 436 clones, one at a time | ~7-15 min |
| L0 optimization | Torch training (not parallelizable) | ~5-10 min |
| **Total** | | ~30-60 min |
All three parallelizable stages are perfectly independent — states don't depend on each other, county groups don't depend on each other, and clones don't depend on each other.
Proposed Approach: ProcessPoolExecutor (8 workers)
Use Python's ProcessPoolExecutor within the single Modal function (which already has 8 CPUs). This avoids Modal Volume sync overhead and keeps results in memory.
Why ProcessPoolExecutor and not threads or Modal-level parallelism?
- **Not threads:** `Microsimulation.calculate()` is CPU-bound Python/NumPy — the GIL serializes threads. `ProcessPoolExecutor` spawns separate OS processes with independent GILs.
- **Not separate Modal containers:** Per-state tasks only take ~5-10 seconds each, so Modal's cold start + serialization overhead (~5-10 s) would negate the gains. `ProcessPoolExecutor` within one container is simpler and has no sync overhead.
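As a minimal self-contained sketch of the proposed pattern (the `_square` task and `run_parallel` helper are hypothetical stand-ins, not the real pipeline code):

```python
from concurrent.futures import ProcessPoolExecutor, as_completed

def _square(n):
    # Hypothetical CPU-bound stand-in for one per-state computation.
    return n, n * n

def run_parallel(items, workers=4):
    """Fan independent tasks out across worker processes and collect
    results as they complete (completion order is not guaranteed)."""
    results = {}
    with ProcessPoolExecutor(max_workers=workers) as pool:
        futures = [pool.submit(_square, n) for n in items]
        for future in as_completed(futures):
            key, value = future.result()
            results[key] = value
    return results

if __name__ == "__main__":
    print(run_parallel(range(5), workers=2))
```

Because results arrive out of order, keying them by task identity (here `n`, in the pipeline the state code) keeps the output independent of scheduling.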
Code Changes Required
1. Per-state precompute (_build_state_values)
Extract the loop body into a top-level standalone function (not a method: pickling a bound method drags the whole owning instance across the process boundary, which fails or is prohibitively slow for large objects):
```python
def _compute_single_state(state, dataset_path, time_period,
                          target_vars, constraint_vars, n_hh,
                          rerandomize_takeup, affected_targets):
    """Runs in a worker process."""
    state_sim = Microsimulation(dataset=dataset_path)
    # ... exact same logic as current loop body ...
    return state, {"hh": hh, "person": person, "entity": entity_vals}

# In _build_state_values:
with ProcessPoolExecutor(max_workers=workers) as pool:
    futures = [pool.submit(_compute_single_state, state, ...)
               for state in unique_states]
    for future in as_completed(futures):
        state, vals = future.result()
        state_values[state] = vals
```

2. Per-county precompute (_build_county_values)
Same pattern — extract per-state-group body into a top-level function.
3. Clone loop (436 clones)
Each clone already saves a clone_XXXX.npz file, so the disk cache pattern naturally supports parallelization. The challenge is that state_values is a large dict (~2GB). Two options:
- **Option A (preferred on Modal/Linux):** Use the `multiprocessing` `fork` start method. Child processes inherit parent memory via copy-on-write — `state_values` is shared without copying.
- **Option B (portable):** Serialize `state_values` to disk (one `.npz` per state); each worker loads only the states it needs.
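A sketch of Option A under the fork assumption (`STATE_VALUES` and `_build_clone` are illustrative names, not the real pipeline code):

```python
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor

# Illustrative stand-in for the ~2 GB state_values dict. Under the "fork"
# start method each child inherits this mapping via copy-on-write, so it
# is never pickled or duplicated unless a child writes to it.
STATE_VALUES = {"CA": 10, "WY": 1}

def _build_clone(clone_id):
    # Reads the inherited parent memory directly.
    return clone_id, sum(STATE_VALUES.values())

def run_clones(n_clones, workers=8):
    ctx = mp.get_context("fork")  # available on Linux (Modal); not on Windows
    with ProcessPoolExecutor(max_workers=workers, mp_context=ctx) as pool:
        return dict(pool.map(_build_clone, range(n_clones)))

if __name__ == "__main__":
    print(run_clones(3, workers=2))
```

The key constraint is ordering: `STATE_VALUES` must be fully populated before the pool forks, since writes in the parent after forking are not visible to the children.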
4. Add `workers` parameter
Add a `workers` parameter to `build_matrix()` and `run_calibration()`, defaulting to 1 (preserving current sequential behavior). The Modal container would pass `workers=8`.
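A sketch of the opt-in shape (simplified; `_task` and `build_values` are hypothetical stand-ins for the real `build_matrix()` / `run_calibration()` internals):

```python
from concurrent.futures import ProcessPoolExecutor

def _task(x):
    # Hypothetical per-unit work standing in for one state/county/clone.
    return x * 2

def build_values(items, workers=1):
    """workers=1 takes the exact current sequential path (local dev, CI);
    workers>1 opts in to the process pool (Modal would pass workers=8)."""
    if workers <= 1:
        return [_task(x) for x in items]
    with ProcessPoolExecutor(max_workers=workers) as pool:
        return list(pool.map(_task, items))
```

Keeping `workers=1` as a literal sequential branch (rather than a one-worker pool) means local runs and CI exercise byte-for-byte the same code path as today.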
Estimated Speedup
| Component | Sequential | 8 Workers | Speedup |
|---|---|---|---|
| Per-state precompute | ~5-10 min | ~1-2 min | ~6x |
| Per-county precompute | ~10-20 min | ~2-4 min | ~5x |
| Clone loop | ~7-15 min | ~1-3 min | ~6x |
| L0 optimization | ~5-10 min | ~5-10 min | 1x |
| Total | ~30-60 min | ~10-20 min | ~3x |
Speedup isn't a full 8x because of worker startup overhead (~2-3s per process for importing policyengine_us), unbalanced state sizes (California >> Wyoming), and the non-parallelizable optimization step.
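As a back-of-envelope check, the ~3x total is consistent with an Amdahl-style estimate over the table's range midpoints, assuming roughly 70% parallel efficiency to absorb the overheads above (the 0.7 figure is an assumption for illustration, not a measurement):

```python
def estimated_total(parallel_min, serial_min, workers, efficiency=0.7):
    # Amdahl-style estimate: only the parallel portion shrinks, discounted
    # by an efficiency factor for startup cost and unbalanced work.
    return serial_min + parallel_min / (workers * efficiency)

# Midpoints of the table's ranges: 7.5 + 15 + 11 = 33.5 min parallelizable,
# 7.5 min serial (L0 optimization).
sequential = 33.5 + 7.5
parallel = estimated_total(33.5, 7.5, workers=8)
print(round(sequential / parallel, 1))  # prints 3.0
```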
Files to Modify
- `policyengine_us_data/calibration/unified_matrix_builder.py` — states, counties, clone loop
- `policyengine_us_data/calibration/unified_calibration.py` — `workers` parameter, pass-through
- `policyengine_us_data/storage/upload_completed_datasets.py` — pass `workers` from Modal config
Notes
- This does not change any outputs — same math, same results, just faster
- Default `workers=1` means no change for local development or CI
- All 121 existing tests should pass without modification