
[draft] sp for dflash #507

Draft

uygnef wants to merge 10 commits into sgl-project:main from uygnef:sp/dflash

Conversation

@uygnef (Collaborator) commented Mar 19, 2026

Motivation

SP for dflash (Ulysses only). Work in progress; please do not merge yet.

Modifications

Related Issues

Accuracy Test

(screenshot: accuracy test results)

Benchmark & Profiling

Checklist

@gemini-code-assist (Contributor)
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces significant enhancements for DFlash model training and hidden state preparation. It enables the generation of DFlash-specific hidden states and supports an offline training mode that leverages these pre-computed states. A major focus is the integration of Ulysses Sequence Parallelism (USP) into the DFlash attention mechanism, allowing for more scalable and efficient training on distributed systems. The changes span data preprocessing, model architecture, and training scripts, ensuring a robust and flexible framework for DFlash development.

Highlights

  • DFlash Model Type Support: Added a new --model-type dflash argument to scripts/prepare_hidden_states.py to enable generating hidden states specifically for DFlash models, alongside the existing Eagle3 support.
  • Offline DFlash Training: Introduced an offline training mode for DFlash in scripts/train_dflash.py, allowing the model to be trained using pre-computed hidden states, which can significantly reduce training time and resource usage.
  • Sequence Parallelism (USP) Integration: Implemented Sequence Parallelism (USP) for DFlash training, enabling distributed processing of sequences across multiple devices to handle longer contexts more efficiently. This includes new distributed arguments and attention backend choices.
  • Unified Hidden States Generator: Refactored the HiddenStatesGenerator to support both Eagle3 and DFlash model types, including DFlash-specific arguments like --num-draft-layers and --block-size, and filtering logic for samples with insufficient loss tokens.
  • Data Loading and Preprocessing Enhancements: Updated data loading mechanisms to support offline DFlash datasets and integrated USP-specific preprocessing within the OfflineDFlashDataset and DataCollatorWithPadding to handle sharded inputs and global value collection.
  • DFlash Attention Mechanism Update: Modified the DFlash attention mechanism to incorporate USP, including scattering and gathering operations for query, key, and value tensors across distributed ranks, ensuring correctness in a parallel environment.
  • Comprehensive Testing for DFlash Core and USP Parity: Added new unit tests to validate the core functionalities of DFlash, including noise block construction, attention mask generation, and label/weight mask logic. Crucially, new tests confirm parity between USP and non-USP DFlash implementations for loss and accuracy calculations.
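The Ulysses scatter/gather mentioned in the attention highlight above can be illustrated with a single-process sketch: each rank starts with a sequence shard, an all-to-all redistributes the data so each rank holds the full sequence for a subset of heads, and the reverse all-to-all restores sequence sharding after attention. All names and sizes here are illustrative, not taken from the PR, and the collective is simulated with array slicing rather than `torch.distributed`:

```python
import numpy as np

# Hypothetical sizes; head count must be divisible by the SP world size.
P = 4                # sequence-parallel world size
S, H, D = 16, 8, 5   # full sequence length, num heads, head dim
rng = np.random.default_rng(0)
q = rng.standard_normal((S, H, D))

# Each rank starts with a sequence shard of shape (S/P, H, D).
shards = np.split(q, P, axis=0)

def ulysses_scatter(shards, P):
    """Simulated all-to-all: rank p ends up with the full sequence
    for its H/P heads, i.e. shape (S, H/P, D)."""
    out = []
    for p in range(P):
        pieces = [np.split(sh, P, axis=1)[p] for sh in shards]
        out.append(np.concatenate(pieces, axis=0))
    return out

def ulysses_gather(head_shards, P):
    """Reverse all-to-all: back to sequence shards of shape (S/P, H, D)."""
    out = []
    for p in range(P):
        pieces = [np.split(hs, P, axis=0)[p] for hs in head_shards]
        out.append(np.concatenate(pieces, axis=1))
    return out

head_shards = ulysses_scatter(shards, P)
restored = ulysses_gather(head_shards, P)
assert head_shards[0].shape == (S, H // P, D)
assert all(np.array_equal(a, b) for a, b in zip(restored, shards))
```

In a real USP implementation these two reshuffles bracket the attention computation, so each rank can run ordinary full-sequence attention on its head subset.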


@gemini-code-assist bot (Contributor) left a comment
Code Review

This pull request introduces significant new capabilities by adding support for the DFlash model, including both online and offline training modes, as well as Ulysses Sequence Parallelism (USP) for distributed training. The changes are extensive and well-structured, touching data preparation, model definition, core training logic, and data loading. The introduction of dflash as a model-type is handled consistently across the scripts, and offline training via pre-computed hidden states is a great feature for efficiency.

The support for USP is complex but appears correct, with the necessary logic for distributed sampling, communication, and loss calculation. The addition of comprehensive unit and parity tests for these new features is commendable and crucial for ensuring correctness. I have a couple of minor suggestions for improving code clarity and removing a redundant check, but overall this is a solid contribution.

Comment on lines +331 to +335
```python
self.min_loss_tokens = (
    min_loss_tokens
    if min_loss_tokens is not None
    else (2 * block_size if block_size is not None else None)
)
```
Severity: medium

The nested ternary operator for initializing self.min_loss_tokens is a bit dense and can be hard to read. Refactoring it into a more explicit if/elif/else structure would improve clarity and maintainability.

Suggested change:

```python
# Before: nested ternary
self.min_loss_tokens = (
    min_loss_tokens
    if min_loss_tokens is not None
    else (2 * block_size if block_size is not None else None)
)
```

```python
# After: explicit branching
if min_loss_tokens is not None:
    self.min_loss_tokens = min_loss_tokens
elif block_size is not None:
    self.min_loss_tokens = 2 * block_size
else:
    self.min_loss_tokens = None
```

Comment on lines +387 to +388
```python
if dist.get_world_size(draft_sp_group) <= 1:
    raise ValueError("Offline DFlash USP requires draft SP world size > 1")
```
Severity: medium

This check for the draft_sp_group world size appears to be redundant. The size of this group is determined by sp_ulysses_size * sp_ring_size. Given that sp_ring_size is already asserted to be 1, the group size is equal to sp_ulysses_size. Another check args.sp_ulysses_size <= 1 already exists earlier in this function, making this one unnecessary. Removing this redundant check will simplify the code.
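The reviewer's argument can be restated as a small validation sketch, assuming the group size is the product of the Ulysses and ring factors as described above (function name and error messages are hypothetical, not from the PR):

```python
def validate_sp_config(sp_ulysses_size: int, sp_ring_size: int) -> int:
    """Return the draft SP group size implied by the two SP factors,
    mirroring the checks the reviewer says already exist upstream."""
    if sp_ring_size != 1:
        raise ValueError("Offline DFlash USP supports sp_ring_size == 1 only")
    if sp_ulysses_size <= 1:
        raise ValueError("Offline DFlash USP requires sp_ulysses_size > 1")
    # Group size is fully determined by the two factors, so a later
    # `dist.get_world_size(draft_sp_group) <= 1` check can never fire.
    return sp_ulysses_size * sp_ring_size

assert validate_sp_config(4, 1) == 4
```

Once these two checks have passed, the group's world size equals sp_ulysses_size > 1, which is why the later check is unreachable.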
