From 0a032a299cd7f1ce5d66c3831fd17ab68dd8439f Mon Sep 17 00:00:00 2001 From: Thijs Vogels Date: Thu, 19 Mar 2026 10:50:22 +0000 Subject: [PATCH] Fix OneDFT MPI double-counting of EXC energy In the OneDFT integrator path, only rank 0 computes the XC energy from the neural network model. The energy is then allreduced with Sum across all MPI ranks. On repeated calls, non-rank-0 ranks still hold the correct energy value from the previous allreduce, causing the Sum to yield 2x the correct value. Fix by zeroing EXC on non-rank-0 processes before the allreduce, so only rank 0's contribution is summed. This affects both the host and device integrator paths. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../device/incore_replicated_xc_device_integrator_onedft.hpp | 2 ++ .../host/reference_replicated_xc_host_integrator_onedft.hpp | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/xc_integrator/replicated/device/incore_replicated_xc_device_integrator_onedft.hpp b/src/xc_integrator/replicated/device/incore_replicated_xc_device_integrator_onedft.hpp index b04dbf5b..b742baa3 100644 --- a/src/xc_integrator/replicated/device/incore_replicated_xc_device_integrator_onedft.hpp +++ b/src/xc_integrator/replicated/device/incore_replicated_xc_device_integrator_onedft.hpp @@ -220,6 +220,8 @@ eval_exc_vxc_onedft_( int64_t m, int64_t n, c10::cuda::CUDACachingAllocator::emptyCache(); EXC[0] = exc.item(); // std::cout << "EXC: " << EXC[0] << std::endl; + } else { + EXC[0] = 0.0; } if ( world_size == 1 ) { diff --git a/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator_onedft.hpp b/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator_onedft.hpp index 3fea8596..498f852d 100644 --- a/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator_onedft.hpp +++ b/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator_onedft.hpp @@ -129,8 +129,9 @@ void ReferenceReplicatedXCHostIntegrator:: auto exc = (exc_on_grid * features_dict.at(feat_map.at(ONEDFT_FEATURE::WEIGHTS))).sum(); exc.backward(); EXC[0] = exc.item().to(); + } else { + EXC[0] = 0.0; } - // MPI_Bcast(EXC, 1, MPI_DOUBLE, 0, rt.comm()); // TODO: stop here if only exc send_buffer_onedft_outputs(2/*ndm*/, features_dict, tasks, rt, sendcounts, displs);