Swap over to unified producer from graph store mode #546

Draft
kmontemayor2-sc wants to merge 1 commit into main from kmonte/unified-gs-producer

Conversation

@kmontemayor2-sc
Collaborator

Scope of work done

Where is the documentation for this feature? N/A

Did you add automated tests or write a test plan?

Updated Changelog.md? NO

Ready for code review? NO

@kmontemayor2-sc
Collaborator Author

/all_test

@github-actions
Contributor

github-actions bot commented Mar 13, 2026

GiGL Automation

@ 24:39:01UTC : 🔄 Lint Test started.

@ 24:45:54UTC : ✅ Workflow completed successfully.

@github-actions
Contributor

github-actions bot commented Mar 13, 2026

GiGL Automation

@ 24:39:02UTC : 🔄 Python Unit Test started.

@ 01:53:40UTC : ✅ Workflow completed successfully.

@github-actions
Contributor

github-actions bot commented Mar 13, 2026

GiGL Automation

@ 24:39:03UTC : 🔄 Scala Unit Test started.

@ 24:47:59UTC : ✅ Workflow completed successfully.

@github-actions
Contributor

github-actions bot commented Mar 13, 2026

GiGL Automation

@ 24:39:03UTC : 🔄 Integration Test started.

@ 01:48:07UTC : ✅ Workflow completed successfully.

@github-actions
Contributor

github-actions bot commented Mar 13, 2026

GiGL Automation

@ 24:39:04UTC : 🔄 E2E Test started.

@ 01:59:21UTC : ✅ Workflow completed successfully.

)
_flush()
all_channel_ids: list[list[int]] = [[] for _ in range(runtime.world_size)]
torch.distributed.all_gather_object(all_channel_ids, channel_id_list)


Semgrep identified an issue in your code:

Using torch.distributed.all_gather_object() to serialize and collect channel_id_list enables arbitrary code execution if an attacker can inject malicious pickle payloads from any rank in the distributed system.

More details about this

The code uses torch.distributed.all_gather_object() to collect channel_id_list from all ranks and broadcast it across the distributed system. This function relies on Python's pickle module under the hood to serialize and deserialize objects across processes.

Exploit scenario:
If an attacker can compromise any rank in the distributed training job (through code injection, environment variable manipulation, or supply chain attack), they can craft a malicious pickled object. When all_gather_object() deserializes this object on other ranks, it will execute arbitrary code embedded in the pickle payload.

For example:

  1. Attacker compromises a worker rank and injects malicious code that modifies channel_id_list before the gather call
  2. The attacker crafts a pickle payload that executes shell commands when deserialized (e.g., using __reduce__ to run os.system())
  3. When rank 0 (or another rank) calls all_gather_object(), it unpickles this malicious object
  4. The arbitrary code executes with the privileges of the process running the distributed training job
  5. Attacker gains access to model weights, training data, credentials, or can pivot to other systems

The vulnerability exists because pickle.loads() (used internally by PyTorch's serialization) treats untrusted data as code.
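To see why deserialization alone is enough, here is a minimal, harmless sketch of the `__reduce__` mechanism the exploit relies on. The callable below only records that it ran; an attacker would return something like `(os.system, ("...",))` instead:

```python
import pickle

# Record of side effects, so we can observe that code ran during loads().
executed = []

def side_effect():
    executed.append("code ran during unpickling")
    return "payload"

class Payload:
    # __reduce__ tells pickle: "to reconstruct me, call side_effect()".
    # pickle.loads() invokes that callable before any type checking happens.
    def __reduce__(self):
        return (side_effect, ())

blob = pickle.dumps(Payload())
result = pickle.loads(blob)  # side_effect() executes here
```

After `loads()`, `executed` contains one entry and `result` is `"payload"`, demonstrating that unpickling attacker-supplied bytes is equivalent to running attacker-supplied code.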

To resolve this comment:

✨ Commit Assistant Fix Suggestion
  1. Replace torch.distributed.all_gather_object(all_channel_ids, channel_id_list) with a safer alternative that does not use object serialization with pickle. You can use torch.distributed.all_gather if the data can be represented as a tensor.
  2. Convert channel_id_list to a tensor using torch.tensor(channel_id_list, dtype=torch.int64) before gathering if all elements are integers.
  3. Create a tensor all_channel_ids_tensor = torch.zeros([runtime.world_size, len(channel_id_list)], dtype=torch.int64, device=...) to hold the results, specifying the correct device if needed.
  4. Call torch.distributed.all_gather to gather the tensors from all processes, for example: torch.distributed.all_gather(list_of_output_tensors, channel_id_tensor). Adjust the output gathering as required.
  5. If you require the results as lists, convert the gathered tensors back to lists with .tolist().

Alternatively, if you must share richer objects across processes, serialize them manually using a safe method such as JSON, and gather strings or encoded byte tensors instead. This approach avoids the use of pickle, which can introduce security risks by allowing arbitrary code execution.
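As a rough sketch of the tensor-based alternative, assuming every rank contributes an integer list of the same length (the single-process `gloo` group below exists only so the snippet runs standalone; in the real job the trainer has already initialized the process group):

```python
import os
import tempfile

import torch
import torch.distributed as dist

def gather_int_lists(local_ids: list[int], world_size: int) -> list[list[int]]:
    # Tensor collective: no pickle is involved, so a compromised rank cannot
    # smuggle an executable payload through the gather.
    local = torch.tensor(local_ids, dtype=torch.int64)
    gathered = [torch.zeros_like(local) for _ in range(world_size)]
    dist.all_gather(gathered, local)
    return [t.tolist() for t in gathered]

# Single-process group purely so this sketch is runnable on its own.
_init_file = os.path.join(tempfile.mkdtemp(), "pg_init")
dist.init_process_group(
    "gloo", init_method=f"file://{_init_file}", rank=0, world_size=1
)

print(gather_int_lists([3, 1, 4], world_size=1))  # [[3, 1, 4]]
```

If the per-rank lists can differ in length, pad them to a common length with a sentinel value first, as the second finding below suggests.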

💬 Ignore this finding

Reply with Semgrep commands to ignore this finding.

  • /fp <comment> for false positive
  • /ar <comment> for acceptable risk
  • /other <comment> for all other reasons

Alternatively, triage in Semgrep AppSec Platform to ignore the finding created by pickles-in-pytorch-distributed.

You can view more details about this finding in the Semgrep AppSec Platform.

torch.distributed.all_gather_object(all_producer_ids, producer_id_list)
self._producer_id_list = all_producer_ids[leader_rank]
all_backend_ids: list[list[int]] = [[] for _ in range(runtime.world_size)]
torch.distributed.all_gather_object(all_backend_ids, backend_id_list)


Semgrep identified an issue in your code:

The torch.distributed.all_gather_object() call deserializes untrusted data from remote ranks using pickle, allowing arbitrary code execution if an attacker controls any rank in the distributed system.

More details about this

The torch.distributed.all_gather_object() call is using pickle deserialization under the hood to share producer_id_list across all ranks in the distributed system. This means untrusted data from other processes gets automatically deserialized and executed.

An attacker with access to any rank in the distributed training job could craft a malicious pickled object and send it during the all-gather phase. When producer_id_list (or other ranks' data) gets deserialized on your rank, the attacker's code runs immediately with the same privileges as your training process.

For example:

  1. Attacker gains write access to rank 1's memory or intercepts its state
  2. They insert a pickled Python object that executes shell commands when unpickled (e.g., os.system('steal_data.sh'))
  3. Your rank calls torch.distributed.all_gather_object(all_producer_ids, producer_id_list)
  4. PyTorch's pickle unpickles all ranks' data, triggering the attacker's code during deserialization on your process
  5. The shell commands run with your process's credentials, potentially exfiltrating model weights or training data

To resolve this comment:

✨ Commit Assistant Fix Suggestion
  1. Avoid using torch.distributed.all_gather_object, as this relies on Python pickling, which may lead to arbitrary code execution if untrusted data is ever deserialized.
  2. Replace the use of all_gather_object with a tensor-based collective, such as torch.distributed.all_gather, by converting your data to a tensor (for example, use torch.tensor(producer_id_list)).
  3. Predefine the size and type of the tensor holding the gathered data, for example: all_producer_ids = torch.empty(runtime.world_size * num_producers, dtype=torch.long), where num_producers is the expected length of producer_id_list for each rank.
  4. Call torch.distributed.all_gather([all_producer_ids], producer_id_list_tensor), where producer_id_list_tensor = torch.tensor(producer_id_list, dtype=torch.long) for each rank.
  5. After gathering, reconstruct the original data structure as needed from all_producer_ids. For example, split the combined tensor into per-rank lists using slicing.
  6. If the number of producers may vary, agree on a fixed length and pad with a sentinel value such as -1 so tensors can be safely communicated.

This change ensures only primitive tensor data is shared between ranks, eliminating pickle-related risks.
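Step 6's sentinel padding can be sketched with plain helpers; the tensor conversion and `all_gather` call are omitted, and `-1` is an assumption that must lie outside the valid id range:

```python
SENTINEL = -1  # assumed to never appear as a real producer/backend id

def pad_ids(ids: list[int], max_len: int) -> list[int]:
    # Pad to a fixed length so every rank sends an identically shaped buffer.
    assert all(i >= 0 for i in ids), "ids must not collide with the sentinel"
    return ids + [SENTINEL] * (max_len - len(ids))

def unpad_ids(padded: list[int]) -> list[int]:
    # Strip the sentinel after gathering to recover the original list.
    return [i for i in padded if i != SENTINEL]

print(pad_ids([7, 2], 4))         # [7, 2, -1, -1]
print(unpad_ids([7, 2, -1, -1]))  # [7, 2]
```

The common `max_len` itself can be agreed on with a `MAX` all-reduce over the per-rank lengths, which also stays pickle-free.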

💬 Ignore this finding

Reply with Semgrep commands to ignore this finding.

  • /fp <comment> for false positive
  • /ar <comment> for acceptable risk
  • /other <comment> for all other reasons

Alternatively, triage in Semgrep AppSec Platform to ignore the finding created by pickles-in-pytorch-distributed.

You can view more details about this finding in the Semgrep AppSec Platform.

) -> list[int]:
backend_key = self._build_graph_store_backend_key()
all_backend_keys: list[Optional[str]] = [None] * runtime.world_size
torch.distributed.all_gather_object(all_backend_keys, backend_key)


Semgrep identified an issue in your code:

Using torch.distributed.all_gather_object() to share my_worker_key creates an arbitrary code execution risk, since pickle deserialization can execute attacker-controlled code if a compromised worker sends malicious data.

More details about this

The torch.distributed.all_gather_object() call on this line uses pickle to serialize and deserialize the my_worker_key string across distributed processes. An attacker who can control the data sent from any worker process could craft a malicious pickle payload that executes arbitrary code when deserialized.

Exploit scenario:

  1. An attacker compromises or spoofs one of the worker processes in the distributed training cluster
  2. They set my_worker_key to a malicious pickle-serialized object instead of a normal string
  3. When this line executes, PyTorch deserializes the pickle object on all receiving ranks
  4. The malicious pickle payload executes arbitrary code on those processes with the same privileges as the training job, potentially stealing model weights, injecting backdoors, or exfiltrating data

To resolve this comment:

✨ Commit Assistant Fix Suggestion
  1. Avoid using torch.distributed.all_gather_object, as it uses pickle internally and can allow arbitrary code execution if untrusted data is deserialized.
  2. If exchanging strings across ranks, switch to using torch.distributed.all_gather, which is safe for tensors. You can do this by encoding your strings to byte tensors before gathering and decoding after.
  3. Replace the vulnerable line with logic similar to:
    • Convert the local string to bytes: my_worker_key_bytes = my_worker_key.encode('utf-8')
    • Find the maximum length of all keys to ensure tensors are the same size across ranks. This usually requires an all_reduce to get the max. For example:
      key_len_tensor = torch.tensor([len(my_worker_key_bytes)], device='cpu')
      max_len_tensor = key_len_tensor.clone()
      torch.distributed.all_reduce(max_len_tensor, op=torch.distributed.ReduceOp.MAX)
    • Pad your byte string to max_len_tensor.item(): padded_bytes = my_worker_key_bytes.ljust(max_len_tensor.item(), b'\x00')
    • Create a tensor: my_worker_key_tensor = torch.ByteTensor(list(padded_bytes))
    • Prepare a gather tensor: all_worker_key_tensors = [torch.empty_like(my_worker_key_tensor) for _ in range(runtime.world_size)]
    • Call torch.distributed.all_gather(all_worker_key_tensors, my_worker_key_tensor)
    • After gathering, decode each tensor: [bytes(t.tolist()).rstrip(b'\x00').decode('utf-8') for t in all_worker_key_tensors]
  4. Replace all uses of all_worker_keys with the decoded string list.

Alternatively, if all ranks already know or can deterministically construct the set of worker keys, you can avoid broadcasting entirely by constructing the list locally.

Using tensors for communication prevents vulnerabilities from deserialization attacks, as tensor operations do not use pickle.
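The encode/pad/decode round trip from step 3 can be sketched on plain bytes; in the real code the `all_reduce`/`all_gather` calls would operate on byte tensors built from these buffers:

```python
def encode_key(key: str, max_len: int) -> bytes:
    # NUL-pad to a common length so every rank sends a fixed-size buffer.
    # UTF-8 never produces 0x00 unless the string contains U+0000 itself,
    # so rstrip-based decoding below is safe.
    raw = key.encode("utf-8")
    assert len(raw) <= max_len and b"\x00" not in raw
    return raw.ljust(max_len, b"\x00")

def decode_key(buf: bytes) -> str:
    # Strip the padding and recover the original string.
    return buf.rstrip(b"\x00").decode("utf-8")

print(decode_key(encode_key("worker-0", 16)))  # worker-0
```

`"worker-0"` and the length `16` are illustrative only; the actual keys come from `_build_graph_store_backend_key()` and the max length from an all-reduce over the ranks' key lengths.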

💬 Ignore this finding

Reply with Semgrep commands to ignore this finding.

  • /fp <comment> for false positive
  • /ar <comment> for acceptable risk
  • /other <comment> for all other reasons

Alternatively, triage in Semgrep AppSec Platform to ignore the finding created by pickles-in-pytorch-distributed.

You can view more details about this finding in the Semgrep AppSec Platform.
