Swap over to unified producer from graph store mode #546
kmontemayor2-sc wants to merge 1 commit into main from
Conversation
/all_test
GiGL Automation @ 24:39:01 UTC : 🔄 @ 24:45:54 UTC : ✅ Workflow completed successfully.
GiGL Automation @ 24:39:02 UTC : 🔄 @ 01:53:40 UTC : ✅ Workflow completed successfully.
GiGL Automation @ 24:39:03 UTC : 🔄 @ 24:47:59 UTC : ✅ Workflow completed successfully.
GiGL Automation @ 24:39:03 UTC : 🔄 @ 01:48:07 UTC : ✅ Workflow completed successfully.
GiGL Automation @ 24:39:04 UTC : 🔄 @ 01:59:21 UTC : ✅ Workflow completed successfully.
```python
)
_flush()
all_channel_ids: list[list[int]] = [[] for _ in range(runtime.world_size)]
torch.distributed.all_gather_object(all_channel_ids, channel_id_list)
```
Semgrep identified an issue in your code:
Using torch.distributed.all_gather_object() to serialize and collect channel_id_list enables arbitrary code execution if an attacker can inject malicious pickle payloads from any rank in the distributed system.
More details about this
The code uses torch.distributed.all_gather_object() to collect channel_id_list from all ranks and broadcast it across the distributed system. This function relies on Python's pickle module under the hood to serialize and deserialize objects across processes.
Exploit scenario:
If an attacker can compromise any rank in the distributed training job (through code injection, environment variable manipulation, or supply chain attack), they can craft a malicious pickled object. When all_gather_object() deserializes this object on other ranks, it will execute arbitrary code embedded in the pickle payload.
For example:
- The attacker compromises a worker rank and injects malicious code that modifies `channel_id_list` before the gather call.
- The attacker crafts a pickle payload that executes shell commands when deserialized (e.g., using `__reduce__` to run `os.system()`).
- When rank 0 (or another rank) calls `all_gather_object()`, it unpickles this malicious object.
- The arbitrary code executes with the privileges of the process running the distributed training job.
- The attacker gains access to model weights, training data, credentials, or can pivot to other systems.
The vulnerability exists because `pickle.loads()` (used internally by PyTorch's serialization) treats untrusted data as code.
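To see why deserialization alone is the dangerous step, here is a small self-contained demonstration (illustrative only, not code from this PR): a class whose `__reduce__` makes `pickle.loads()` invoke an arbitrary callable. A harmless `print` stands in for the `os.system` call an attacker would use.

```python
import pickle

class Malicious:
    def __reduce__(self):
        # pickle records (callable, args) and invokes it on load; an attacker
        # would return (os.system, ("...",)) here. print keeps the demo benign.
        return (print, ("code ran during unpickling",))

payload = pickle.dumps(Malicious())
result = pickle.loads(payload)  # the callable fires before any type check can run
```

The receiving process never gets a chance to validate the object: the payload executes during `pickle.loads()` itself, which is exactly what happens inside `all_gather_object()`.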
To resolve this comment:
✨ Commit Assistant Fix Suggestion
- Replace `torch.distributed.all_gather_object(all_channel_ids, channel_id_list)` with a safer alternative that does not use object serialization with pickle. You can use `torch.distributed.all_gather` if the data can be represented as a tensor.
- Convert `channel_id_list` to a tensor using `torch.tensor(channel_id_list, dtype=torch.int64)` before gathering, if all elements are integers.
- Create a tensor `all_channel_ids_tensor = torch.zeros([runtime.world_size, len(channel_id_list)], dtype=torch.int64, device=...)` to hold the results, specifying the correct device if needed.
- Call `torch.distributed.all_gather` to gather the tensors from all processes, for example: `torch.distributed.all_gather(list_of_output_tensors, channel_id_tensor)`. Adjust the output gathering as required.
- If you require the results as lists, convert the gathered tensors back to lists with `.tolist()`.
Alternatively, if you must share richer objects across processes, serialize them manually using a safe method such as JSON, and gather strings or encoded byte tensors instead. This approach avoids the use of pickle, which can introduce security risks by allowing arbitrary code execution.
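The tensor-based replacement suggested above could look roughly like the sketch below. This is a minimal illustration, not the PR's code: it assumes the process group is already initialized and that every rank's `channel_id_list` has the same length (otherwise see the padding or JSON alternatives).

```python
import torch
import torch.distributed as dist

def gather_channel_ids(channel_id_list: list[int], world_size: int) -> list[list[int]]:
    """Pickle-free replacement for all_gather_object on equal-length int lists."""
    # A tensor is transmitted as a raw numeric buffer, so no attacker-controlled
    # object is ever deserialized on the receiving ranks.
    local = torch.tensor(channel_id_list, dtype=torch.int64)
    gathered = [torch.zeros_like(local) for _ in range(world_size)]
    dist.all_gather(gathered, local)  # collective: every rank receives all tensors
    return [t.tolist() for t in gathered]
```

With `runtime.world_size` passed as `world_size`, the call site keeps the same shape as before: a list of per-rank integer lists.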
💬 Ignore this finding
Reply with Semgrep commands to ignore this finding.
- /fp <comment> for false positive
- /ar <comment> for acceptable risk
- /other <comment> for all other reasons
Alternatively, triage in Semgrep AppSec Platform to ignore the finding created by pickles-in-pytorch-distributed.
You can view more details about this finding in the Semgrep AppSec Platform.
```python
torch.distributed.all_gather_object(all_producer_ids, producer_id_list)
self._producer_id_list = all_producer_ids[leader_rank]
all_backend_ids: list[list[int]] = [[] for _ in range(runtime.world_size)]
torch.distributed.all_gather_object(all_backend_ids, backend_id_list)
```
Semgrep identified an issue in your code:
The torch.distributed.all_gather_object() call deserializes untrusted data from remote ranks using pickle, allowing arbitrary code execution if an attacker controls any rank in the distributed system.
More details about this
The torch.distributed.all_gather_object() call is using pickle deserialization under the hood to share producer_id_list across all ranks in the distributed system. This means untrusted data from other processes gets automatically deserialized and executed.
An attacker with access to any rank in the distributed training job could craft a malicious pickled object and send it during the all-gather phase. When producer_id_list (or other ranks' data) gets deserialized on your rank, the attacker's code runs immediately with the same privileges as your training process.
For example:
- The attacker gains write access to rank 1's memory or intercepts its state.
- They insert a pickled Python object that executes shell commands when unpickled (e.g., `os.system('steal_data.sh')`).
- Your rank calls `torch.distributed.all_gather_object(all_producer_ids, producer_id_list)`.
- PyTorch's pickle unpickles all ranks' data, triggering the attacker's code during deserialization on your process.
- The shell commands run with your process's credentials, potentially exfiltrating model weights or training data.
To resolve this comment:
✨ Commit Assistant Fix Suggestion
- Avoid using `torch.distributed.all_gather_object`, as this relies on Python pickling, which may lead to arbitrary code execution if untrusted data is ever deserialized.
- Replace the use of `all_gather_object` with a tensor-based collective, such as `torch.distributed.all_gather`, by converting your data to a tensor (for example, use `torch.tensor(producer_id_list)`).
- Predefine the size and type of the tensor holding the gathered data, for example: `all_producer_ids = torch.empty(runtime.world_size * num_producers, dtype=torch.long)`, where `num_producers` is the expected length of `producer_id_list` for each rank.
- Call `torch.distributed.all_gather([all_producer_ids], producer_id_list_tensor)`, where `producer_id_list_tensor = torch.tensor(producer_id_list, dtype=torch.long)` for each rank.
- After gathering, reconstruct the original data structure as needed from `all_producer_ids`. For example, split the combined tensor into per-rank lists using slicing.
- If the number of producers may vary, agree on a fixed length and pad with a sentinel value such as `-1` so tensors can be safely communicated.
This change ensures only primitive tensor data is shared between ranks, eliminating pickle-related risks.
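The padding idea from the last bullet combines with the tensor collective as sketched below. Again illustrative only: it assumes an initialized process group and that `-1` never appears as a real id, so it can serve as the sentinel.

```python
import torch
import torch.distributed as dist

def gather_id_lists(ids: list[int], world_size: int, pad: int = -1) -> list[list[int]]:
    """Gather possibly unequal-length int lists across ranks without pickle."""
    # Step 1: agree on a common tensor length (the max list length across ranks).
    length = torch.tensor([len(ids)], dtype=torch.int64)
    dist.all_reduce(length, op=dist.ReduceOp.MAX)
    max_len = int(length.item())
    # Step 2: pad the local list with the sentinel so every rank sends the same shape.
    local = torch.full((max_len,), pad, dtype=torch.int64)
    local[: len(ids)] = torch.tensor(ids, dtype=torch.int64)
    # Step 3: tensor-only collective, then strip the sentinel on the way out.
    gathered = [torch.empty_like(local) for _ in range(world_size)]
    dist.all_gather(gathered, local)
    return [[v for v in t.tolist() if v != pad] for t in gathered]
```

The same helper would serve both the `producer_id_list` and `backend_id_list` gathers in the diff, since both exchange plain integer lists.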
💬 Ignore this finding
Reply with Semgrep commands to ignore this finding.
- /fp <comment> for false positive
- /ar <comment> for acceptable risk
- /other <comment> for all other reasons
Alternatively, triage in Semgrep AppSec Platform to ignore the finding created by pickles-in-pytorch-distributed.
You can view more details about this finding in the Semgrep AppSec Platform.
```python
) -> list[int]:
    backend_key = self._build_graph_store_backend_key()
    all_backend_keys: list[Optional[str]] = [None] * runtime.world_size
    torch.distributed.all_gather_object(all_backend_keys, backend_key)
```
Semgrep identified an issue in your code:
Using torch.distributed.all_gather_object() to share my_worker_key creates an arbitrary code execution risk, since pickle deserialization can execute attacker-controlled code if a compromised worker sends malicious data.
More details about this
The torch.distributed.all_gather_object() call on this line uses pickle to serialize and deserialize the my_worker_key string across distributed processes. An attacker who can control the data sent from any worker process could craft a malicious pickle payload that executes arbitrary code when deserialized.
Exploit scenario:
- An attacker compromises or spoofs one of the worker processes in the distributed training cluster.
- They set `my_worker_key` to a malicious pickle-serialized object instead of a normal string.
- When this line executes, PyTorch deserializes the pickle object on all receiving ranks.
- The malicious pickle payload executes arbitrary code on those processes with the same privileges as the training job, potentially stealing model weights, injecting backdoors, or exfiltrating data.
To resolve this comment:
✨ Commit Assistant Fix Suggestion
- Avoid using `torch.distributed.all_gather_object`, as it uses pickle internally and can allow arbitrary code execution if untrusted data is deserialized.
- If exchanging strings across ranks, switch to using `torch.distributed.all_gather`, which is safe for tensors. You can do this by encoding your strings to byte tensors before gathering and decoding after.
- Replace the vulnerable line with logic similar to:
  - Convert the local string to bytes: `my_worker_key_bytes = my_worker_key.encode('utf-8')`
  - Find the maximum length of all keys to ensure tensors are the same size across ranks. This usually requires an `all_reduce` to get the max. For example:
    `key_len_tensor = torch.tensor([len(my_worker_key_bytes)], device='cpu')`
    `max_len_tensor = key_len_tensor.clone()`
    `torch.distributed.all_reduce(max_len_tensor, op=torch.distributed.ReduceOp.MAX)`
  - Pad your byte string to `max_len_tensor.item()`: `padded_bytes = my_worker_key_bytes.ljust(max_len_tensor.item(), b'\x00')`
  - Create a tensor: `my_worker_key_tensor = torch.ByteTensor(list(padded_bytes))`
  - Prepare a gather tensor: `all_worker_key_tensors = [torch.empty_like(my_worker_key_tensor) for _ in range(runtime.world_size)]`
  - Call `torch.distributed.all_gather(all_worker_key_tensors, my_worker_key_tensor)`
  - After gathering, decode each tensor: `[bytes(t.tolist()).rstrip(b'\x00').decode('utf-8') for t in all_worker_key_tensors]`
- Replace all uses of `all_worker_keys` with the decoded string list.
Alternatively, if all ranks already know or can deterministically construct the set of worker keys, you can avoid broadcasting entirely by constructing the list locally.
Using tensors for communication prevents vulnerabilities from deserialization attacks, as tensor operations do not use pickle.
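The encode-pad-gather-decode steps above can be collected into one helper; a sketch under the same assumptions (initialized process group, UTF-8 keys that contain no embedded NUL bytes, since `\x00` is used as padding):

```python
import torch
import torch.distributed as dist

def gather_strings(s: str, world_size: int) -> list[str]:
    """Exchange one string per rank via byte tensors, avoiding pickle."""
    data = s.encode("utf-8")
    # Agree on a common byte length so every rank sends an equal-sized tensor.
    n = torch.tensor([len(data)], dtype=torch.int64)
    dist.all_reduce(n, op=dist.ReduceOp.MAX)
    max_len = int(n.item())
    # NUL-pad, ship as uint8, then strip the padding and decode on receipt.
    padded = data.ljust(max_len, b"\x00")
    local = torch.tensor(list(padded), dtype=torch.uint8)
    out = [torch.empty_like(local) for _ in range(world_size)]
    dist.all_gather(out, local)
    return [bytes(t.tolist()).rstrip(b"\x00").decode("utf-8") for t in out]
```

In the diff above, `all_backend_keys` could then be built as `gather_strings(backend_key, runtime.world_size)`, dropping the `Optional[str]` placeholders entirely.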
💬 Ignore this finding
Reply with Semgrep commands to ignore this finding.
- /fp <comment> for false positive
- /ar <comment> for acceptable risk
- /other <comment> for all other reasons
Alternatively, triage in Semgrep AppSec Platform to ignore the finding created by pickles-in-pytorch-distributed.
You can view more details about this finding in the Semgrep AppSec Platform.
Scope of work done
- Where is the documentation for this feature?: N/A
- Did you add automated tests or write a test plan?
- Updated Changelog.md? NO
- Ready for code review?: NO