
fix: Prevent in-place mutation in collective ops and simplify metric sync#3607

Closed
aaishwarymishra wants to merge 1 commit into pytorch:master from aaishwarymishra:fix/avoid-inplace-all-reduce

Conversation

Collaborator

@aaishwarymishra aaishwarymishra commented Feb 26, 2026

This pull request improves the handling of tensors in distributed operations to ensure that backend implementations do not inadvertently modify the original tensors passed by callers. The changes focus on preventing in-place mutations and standardizing the behavior across distributed backends.

Distributed tensor mutation prevention:

  • In ignite/distributed/comp_models/base.py, the tensor is now cloned before applying backend operations, ensuring that any in-place modifications by backend implementations do not affect the original tensor.

Standardization of distributed reduction behavior:

  • In ignite/metrics/metric.py, the wrapper no longer clones tensors before passing them to idist.all_reduce, relying on the distributed layer to avoid in-place mutations and always return the reduced value. This change aligns the behavior across backends and simplifies the code.
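The description above explains the pattern without showing it. Here is a minimal, torch-free sketch of the clone-before-backend-op idea; `FakeTensor`, `add_`, and the `all_reduce` wrapper are hypothetical stand-ins for ignite's internals, not the actual code:

```python
# Sketch: clone before handing the tensor to the backend so an in-place
# backend op cannot mutate the caller's tensor. FakeTensor stands in for
# torch.Tensor; add_ mimics an in-place reduction.

class FakeTensor:
    def __init__(self, values):
        self.values = list(values)

    def clone(self):
        return FakeTensor(self.values)

    def add_(self, other):  # in-place op, like dist.all_reduce on native/xla
        self.values = [a + b for a, b in zip(self.values, other.values)]
        return self


def all_reduce(tensor, backend_op):
    # Clone first: any in-place mutation happens on the copy.
    work = tensor.clone()
    return backend_op(work)


original = FakeTensor([1.0, 2.0])
reduced = all_reduce(original, lambda t: t.add_(FakeTensor([1.0, 2.0])))
print(original.values)  # [1.0, 2.0] -- caller's tensor is untouched
print(reduced.values)   # [2.0, 4.0]
```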

@github-actions github-actions bot added the module: metrics and module: distributed labels Feb 26, 2026
Comment on lines +858 to +861
# Backends should not mutate the caller's tensor; idist implementations
# are normalized to operate on a cloned tensor. We can therefore pass
# the original tensor and rely on the distributed layer to avoid
# in-place modifications and always return the reduced value.
Collaborator

Once we have done that we can remove the comment, IMO


# Work on a cloned tensor so backend implementations that perform in-place
# operations do not mutate the original tensor passed by the caller.
tensor = tensor.clone()
Collaborator


This method is used in collective ops like all_reduce and all_gather. all_gather does not do an in-place op, so it is unnecessary to copy the input. The copy should only be done for all_reduce, IMO.

Comment on lines +179 to +180
if fn is self._do_all_reduce:
tensor = tensor.clone()
Collaborator


I think this is more subtle than that:

  1. idist.all_reduce can work on a float or a torch.Tensor. If the input is a float, a new tensor is created from it, so the copy is redundant.
  2. idist.all_reduce works with various backends: pytorch nccl, gloo, horovod, xla. PyTorch distributed and XLA only do an in-place all-reduce op, while Horovod allreduce is not an in-place op: https://horovod.readthedocs.io/en/stable/api.html#horovod.torch.allreduce
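Point 1 above can be sketched without torch; `normalize_input` is a hypothetical helper, not ignite's actual code, and plain lists stand in for tensors:

```python
# Sketch of the float-vs-tensor distinction: a plain number is converted
# into a freshly created buffer, so no caller data can be aliased; only a
# tensor-like input needs a clone before an in-place backend op.
from numbers import Number


def normalize_input(x):
    if isinstance(x, Number):
        return [float(x)], False  # freshly created: a clone would be redundant
    return x, True                # may alias the caller's buffer: clone needed


buf_from_float, clone_float = normalize_input(3.5)
caller_buf = [1.0, 2.0]
buf_from_list, clone_list = normalize_input(caller_buf)
print(clone_float, clone_list)      # False True
print(buf_from_list is caller_buf)  # True: aliasing, hence the clone
```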

Collaborator Author
@aaishwarymishra aaishwarymishra Feb 27, 2026

I get it. I can just apply this clone logic in the reduce_all method itself: if it's a tensor, clone it; if it's a float, do nothing.

As you mentioned, certain backends do the in-place ops, so I can add an isinstance(self, _XlaDistModel)-style check to clone tensors only for those classes. Does that sound good?

Collaborator
@vfdev-5 vfdev-5 Feb 27, 2026

I think it is more appropriate to clone the tensor inside the dispatched _do_all_reduce methods of Native and XLA backends.

def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM", group: Optional[Any] = None) -> torch.Tensor:
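The suggestion can be pictured with a small stand-in sketch; the classes below are illustrative, not ignite's actual hierarchy, and plain lists stand in for tensors:

```python
# Sketch: put the clone inside the dispatched _do_all_reduce of the
# in-place backends only; the Horovod-style backend already returns a
# new buffer, so it needs no clone.

class HorovodLike:
    def _do_all_reduce(self, values):
        # hvd.allreduce-style: returns a new buffer, input untouched
        return [v * 2 for v in values]


class NativeLike:
    def _do_all_reduce(self, values):
        # dist.all_reduce-style ops are in place, so "clone" first
        work = list(values)
        for i in range(len(work)):
            work[i] *= 2
        return work


data = [1, 2]
out_h = HorovodLike()._do_all_reduce(data)
out_n = NativeLike()._do_all_reduce(data)
print(data, out_h, out_n)  # caller's data survives both paths
```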

Collaborator Author
@aaishwarymishra aaishwarymishra Feb 27, 2026

Yeah, but the tensor passed to _do_all_reduce could be the original tensor, and if a float was sent to reduce_all it gets converted to a tensor first, so cloning inside _do_all_reduce can also create a redundant tensor when the input was a float.

Collaborator Author


Also, to be sure: the comment said only `reduce_all` can modify the tensor in-place on native and xla, right?

Collaborator


I don't think it will be BC breaking, as originally we were sending a cloned tensor to all_reduce, so all the in-place operations were always happening on the cloned tensor. I could be wrong though.

If this is the case, why do we want this PR then?
Please try it locally and see. This is easily emulated on CPU, for example using scripts like:
https://colab.research.google.com/github/pytorch-ignite/playground/blob/main/index.ipynb#scrollTo=ddMXyp4ek1DO

Collaborator Author
@aaishwarymishra aaishwarymishra Mar 3, 2026

In distributed/comp_models/horovod.py at line 195, hvd.allreduce returns a new tensor, while dist.all_reduce in native.py and xm.all_reduce in xla.py do the operation in place. That is dual behaviour, so the main objective here is to pick one convention: either always mutate the tensor in place or always return a new tensor.
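The dual behaviour being described can be demonstrated with pure-Python stand-ins (the two functions below emulate the aliasing difference, not the real collective ops):

```python
# Sketch: a horovod-style op returns a new object, while a native/xla-style
# op mutates its argument and returns it -- same call shape, different
# aliasing, which is exactly the inconsistency discussed in this PR.

def hvd_style_allreduce(buf):
    return [x + x for x in buf]  # new list, input untouched


def native_style_all_reduce(buf):
    for i in range(len(buf)):
        buf[i] += buf[i]         # in place
    return buf


a = [1, 2]
out_a = hvd_style_allreduce(a)
b = [1, 2]
out_b = native_style_all_reduce(b)
print(out_a is a, out_b is b)  # False True
```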

Collaborator Author


Yeah, but the tensor passed to _do_all_reduce could be the original tensor, and if a float was sent to reduce_all it gets converted to a tensor first, so cloning inside _do_all_reduce can also create a redundant tensor when the input was a float.

Yes, you are right about that, my bad.

Also, to be sure: the comment said only `reduce_all` can modify the tensor in-place on native and xla, right?

There is no reduce_all method; all_reduce is the one in-place op for the native and xla backends.

On the other hand, I'm thinking that if we fix this behavior, it will be BC-breaking for people relying on the in-place behavior of the native pytorch distributed backend. We should put a note about that in the docs.

You are right. We would also have to refactor code in the project itself which uses all_reduce. The main question is: should we change the behaviour at all? And if yes, should we return a new tensor or mutate in place?

Sorry, I think we were not on the same page before.

Collaborator
@vfdev-5 vfdev-5 Mar 3, 2026

Personally, I wouldn't change the behavior, for several reasons:

  1. If we adopt either of the two options, it will be BC breaking.
  2. Among the supported backends, torch native is the most useful one, as horovod and xla are mostly sunsetting. I prefer not to change the working part.

Collaborator Author


Understood.

@aaishwarymishra aaishwarymishra force-pushed the fix/avoid-inplace-all-reduce branch 3 times, most recently from 4ad94cf to 048b5c9 Compare March 2, 2026 19:11
@aaishwarymishra aaishwarymishra force-pushed the fix/avoid-inplace-all-reduce branch from b3739c8 to 4bb7da1 Compare March 3, 2026 04:19
…effects in native and XLA distributed models
@aaishwarymishra aaishwarymishra marked this pull request as draft March 3, 2026 05:17
@aaishwarymishra aaishwarymishra deleted the fix/avoid-inplace-all-reduce branch March 3, 2026 15:20

Labels

module: distributed Distributed module module: metrics Metrics module

2 participants