Conversation

long8v left a comment
(24.12.04) torch distributed study — reviewed the PP side.
```python
def pipeline_llama(
    model: nn.Module,
    pp_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
    device: DeviceType,
    model_config: ModelArgs,
    loss_fn: Callable[..., torch.Tensor],
):
    stages, models = pipeline_llama_manual_split(
        model, pp_mesh, parallel_dims, job_config, device, model_config
    )

    pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)

    return pp_schedule, models
```
PP recap -- notion

The main PP entry point splits into two pieces: 1) `pipeline_llama_manual_split`, which carves the model into pipeline stages, and 2) `build_pipeline_schedule`, which builds the pipeline schedule (micro-batching and so on).
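For intuition, here is a minimal sketch of that same two-phase shape (split, then schedule) using the public `torch.distributed.pipelining` API. The hand slice of an `nn.Sequential` is a crude stand-in for `pipeline_llama_manual_split`, and exact `PipelineStage` arguments vary slightly across PyTorch versions; launch with `torchrun --nproc-per-node=2`:

```python
# Sketch only: a 2-rank GPipe pipeline on CPU, not torchtitan's actual code.
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.pipelining import PipelineStage, ScheduleGPipe

dist.init_process_group("gloo")  # CPU backend just for illustration
rank, world = dist.get_rank(), dist.get_world_size()
device = torch.device("cpu")

full = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
local = full[:2] if rank == 0 else full[2:]  # phase 1: manual split per rank
stage = PipelineStage(local, stage_index=rank, num_stages=world, device=device)

# phase 2: wrap the local stage in a schedule (torchtitan picks the schedule
# class, e.g. 1F1B, from job_config inside build_pipeline_schedule)
schedule = ScheduleGPipe(stage, n_microbatches=4, loss_fn=nn.MSELoss())

if rank == 0:
    schedule.step(torch.randn(8, 16))  # first stage feeds the inputs
else:
    losses = []  # last stage collects one loss per microbatch
    schedule.step(target=torch.randn(8, 16), losses=losses)
```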
```python
def pipeline_llama_manual_split(
    whole_model: nn.Module,
    pp_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
    device: DeviceType,
    model_config: ModelArgs,
):
```
The function that splits the Llama model. `DeviceMesh` is PyTorch-native, while `ParallelDims` is torchtitan's own dataclass.
```python
@dataclass
class ParallelDims:
    dp_replicate: int
    dp_shard: int
    cp: int
    tp: int
    pp: int
    world_size: int
    enable_loss_parallel: bool
```
A dataclass defining dp_replicate, dp_shard, cp, tp, pp, and so on.
```python
dp = dp_replicate * dp_shard
if dp < 0:
    dp = self.world_size // (cp * tp * pp)
    self.dp_shard = dp_shard = dp // dp_replicate
```
dp_shard * dp_replicate must equal world_size divided by (cp * tp * pp), i.e., DP runs over whatever ranks remain after model parallelism.
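A quick worked example of that constraint (numbers are illustrative):

```python
# 64 ranks total; model parallelism uses cp * tp * pp = 1 * 4 * 2 = 8 of them
world_size, cp, tp, pp = 64, 1, 4, 2
dp = world_size // (cp * tp * pp)   # 8 ranks remain for data parallelism
dp_replicate = 2
dp_shard = dp // dp_replicate       # 4 -> HSDP: replicate x2, shard x4
assert dp_replicate * dp_shard * cp * tp * pp == world_size
```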
```python
def build_mesh(self, device_type):
    dims = []
    names = []
    for d, name in zip(
        [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp],
        ["pp", "dp_replicate", "dp_shard", "cp", "tp"],
    ):
        if d > 1:
            dims.append(d)
            if (name == "dp_replicate" and self.dp_shard == 1) or (
                name == "dp_shard" and self.dp_replicate == 1
            ):
                names.append("dp")
            else:
                names.append(name)

    logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
    names = tuple(names)
    mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)
```
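For a concrete config, the resulting mesh looks like this (a sketch to be run with 16 ranks under torchrun; note how the data-parallel dimension collapses to the plain name "dp" because dp_replicate == 1, matching the branch above):

```python
from torch.distributed.device_mesh import init_device_mesh

# pp=2, dp_shard=4, tp=2 on 16 GPUs; dp_replicate=1 and cp=1 are size-1 dims,
# so they are dropped, and the lone data-parallel dim is named "dp"
world_mesh = init_device_mesh("cuda", (2, 4, 2), mesh_dim_names=("pp", "dp", "tp"))
dp_mesh = world_mesh["dp"]  # 1-D sub-mesh handed to FSDP later
```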
```python
# init distributed
world_size = int(os.environ["WORLD_SIZE"])
parallel_dims = ParallelDims(
    dp_shard=job_config.training.data_parallel_shard_degree,
    dp_replicate=job_config.training.data_parallel_replicate_degree,
    cp=job_config.experimental.context_parallel_degree,
    tp=job_config.training.tensor_parallel_degree,
    pp=job_config.experimental.pipeline_parallel_degree,
    world_size=world_size,
    enable_loss_parallel=job_config.training.enable_loss_parallel,
)
```
| logger.info(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}") | ||
|
|
||
| # build meshes | ||
| world_mesh = parallel_dims.build_mesh(device_type=device_type) |
```python
if parallel_dims.pp_enabled:
    # apply PT-D Pipeline Parallel
    pp_schedule, model_parts = models_pipelining_fns[model_name](
        model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn
    )

    # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
    # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
    # optimizer, and checkpointing
    for m in model_parts:
        # apply SPMD-style PT-D techniques
        models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
        m.to_empty(device=init_device)
        m.init_weights(buffer_device=buffer_device)
        m.train()
```
This is where `models_pipelining_fns` is looked up and called.
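The lookup is a plain registry pattern: a dict keyed by model family, mapping to the pipelining function. A minimal sketch (the `"llama3"` key is an assumption; torchtitan keeps a similar per-model dict):

```python
# hypothetical registry: model-family name -> pipelining function
models_pipelining_fns = {"llama3": pipeline_llama}

pp_schedule, model_parts = models_pipelining_fns["llama3"](
    model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn
)
```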
```python
with maybe_enable_profiling(
    job_config, global_step=train_state.step
) as torch_profiler, maybe_enable_memory_snapshot(
    job_config, global_step=train_state.step
) as memory_profiler:
    while train_state.step < job_config.training.steps:
        train_state.step += 1
        gc_handler.run(train_state.step)

        # get batch
        data_load_start = time.perf_counter()
        batch = next(data_iterator)
```
```python
if parallel_dims.pp_enabled:
    # Pipeline Parallel forward / backward inside step() call
    is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1

    with train_context(optional_context_parallel_ctx):
        if pp_mesh.get_local_rank() == 0:
            pp_schedule.step(input_ids)
        elif is_last_stage:
            losses = []
            pp_schedule.step(target=labels, losses=losses)
        else:
            pp_schedule.step()

    # accumulate losses across pipeline microbatches
    loss = (
        torch.mean(torch.stack(losses))
        if is_last_stage
        else torch.Tensor([-1.0])
    )
```
It looks like `pp_schedule.step()` drives the micro-batch forward (and backward) passes.
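The per-step loss is then just the mean over the per-microbatch losses the schedule appended; e.g. with 4 microbatches (numbers made up):

```python
import torch

losses = [torch.tensor(2.0), torch.tensor(1.5), torch.tensor(1.0), torch.tensor(0.5)]
loss = torch.mean(torch.stack(losses))  # tensor(1.2500) -- the loss this step reports
```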
```python
def parallelize_llama(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
```
The Llama-specific parallelization logic is collected here.
```python
if parallel_dims.tp_enabled:
    if (
        job_config.experimental.enable_async_tensor_parallel
        and not job_config.training.compile
    ):
        raise RuntimeError("Async TP requires --training.compile")
    apply_tp(
        # ... (TP plan arguments elided in the quoted diff)
    )

if job_config.activation_checkpoint.mode != "none":
    apply_ac(model, job_config.activation_checkpoint)
```
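For reference, `apply_ac` in "full" mode essentially wraps each block with the composable checkpoint wrapper. A runnable sketch of the idea, using the public `checkpoint_wrapper` (torchtitan's real `apply_ac` also supports selective AC):

```python
import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
for name, block in model.named_children():
    # replace each child with a checkpointed version of itself
    model.register_module(name, checkpoint_wrapper(block))

out = model(torch.randn(2, 8))
out.sum().backward()  # activations are recomputed in backward instead of stored
```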
```python
def parallelize_llama(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    """
    Apply tensor parallelism, activation checkpointing, torch.compile, and data
    parallelism to the model.

    NOTE: The passed-in model preferably should be on meta device. Otherwise,
    the model must fit on GPU or CPU memory.
    """
```
This is the part that parallelizes Llama.
```python
if parallel_dims.tp_enabled:
    if (
        job_config.experimental.enable_async_tensor_parallel
        and not job_config.training.compile
    ):
        raise RuntimeError("Async TP requires --training.compile")
    apply_tp(
```
```python
module (Union[nn.Module, List[nn.Module]]): The module or modules to
    shard with FSDP and group together for communication.
mesh (Optional[DeviceMesh]): This data parallel mesh defines the
    sharding and device. If 1D, then parameters are fully sharded
    across the 1D mesh (FSDP) with ``(Shard(0),)`` placement. If 2D,
    then parameters are sharded across the 1st dim and replicated
    across the 0th dim (HSDP) with ``(Replicate(), Shard(0))``
    placement. The mesh's device type gives the device type used for
    communication; if a CUDA or CUDA-like device type, then we use the
    current device.
```
```python
mesh = mesh or _init_default_fully_shard_mesh()
if mesh.ndim not in (1, 2):
    raise ValueError(f"fully_shard expects a 1D or 2D DeviceMesh but got {mesh}")
elif mesh.ndim == 1:
    mesh_info = FSDPMeshInfo(mesh, shard_mesh_dim=0)
else:
    if mesh.mesh_dim_names is None:
        raise AssertionError(
            "Please init the 2D mesh for HSDP with mesh_dim_names specified"
        )
    mesh_info = HSDPMeshInfo(mesh, shard_mesh_dim=1, replicate_mesh_dim=0)
device = _get_device_from_mesh(mesh)
post_forward_mesh_info = _get_post_forward_mesh_info(
    reshard_after_forward, mesh_info
)
```
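In mesh terms, the two cases look like this (a sketch, assuming 8 ranks under torchrun; note the `fully_shard` import path has moved across PyTorch versions, e.g. `torch.distributed._composable.fsdp` in older releases):

```python
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard  # FSDP2 entry point

model = nn.Linear(16, 16)

mesh_1d = init_device_mesh("cuda", (8,))  # 1D -> plain FSDP, (Shard(0),)
fully_shard(model, mesh=mesh_1d)

# 2D alternative (HSDP): replicate over dim 0, shard over dim 1;
# mesh_dim_names are required, matching the AssertionError above
# mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("replicate", "shard"))
# fully_shard(model, mesh=mesh_2d)
```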
```python
arg_module = module
modules = (
    (module,) if isinstance(module, nn.Module) else tuple(_get_root_modules(module))
)
state = fully_shard.state(modules[0])
state.init(modules, device, mp_policy)
```
Sets up the FSDP state; it appears to register forward hooks and so on.
```python
class FSDPState(_State):
    def __init__(self) -> None:
        super().__init__()
        self._fsdp_param_group: Optional[FSDPParamGroup] = None
        self._is_root: Optional[bool] = None  # root set during lazy init
        self._state_ctx = FSDPStateContext()
        self._comm_ctx = FSDPCommContext()
        self._training_state: TrainingState = TrainingState.IDLE
        self._states_to_forward_prefetch: List[FSDPState] = []
        self._states_to_backward_prefetch: List[FSDPState] = []
        self._modules_to_run_forward: Set[nn.Module] = set()

    # Define a separate init since `__init__` is called in the contract
    def init(
        self,
        modules: Tuple[nn.Module, ...],
        device: torch.device,
        mp_policy: MixedPrecisionPolicy,
    ) -> None:
        for module in modules:
            _insert_module_state(module, self)
        self._modules = modules
        self._device = device
        self._device_handle = _get_device_handle(device.type)
        self._mp_policy = mp_policy
        if len(modules) == 1:
            self._pre_forward_hook_handle = modules[0].register_forward_pre_hook(
                self._pre_forward, prepend=True, with_kwargs=True
            )
            self._post_forward_hook_handle = modules[0].register_forward_hook(
                self._post_forward, prepend=False
            )
        else:
            hook_handle = _register_group_forward_hooks(
                modules,
                self._pre_forward,
                self._post_forward,
                self._modules_to_run_forward,
            )
            self._pre_forward_hook_handle = hook_handle
            self._post_forward_hook_handle = hook_handle
```

```python
if params:
    state._fsdp_param_group = FSDPParamGroup(
        params,
        modules,
        mesh_info,
        post_forward_mesh_info,
        device,
        shard_placement_fn,
        mp_policy,
        offload_policy,
    )
```
```python
# Place FSDP leftmost for highest priority in the method resolution order
for module in modules:
    cls = module.__class__
    new_cls = cls_to_fsdp_cls.get(cls, None)
    if not new_cls:
        dct = {"__deepcopy__": _unimplemented_deepcopy}
        new_cls = type(f"FSDP{cls.__name__}", (FSDPModule, cls), dct)
        cls_to_fsdp_cls[cls] = new_cls
    module.__class__ = new_cls
return arg_module
```
There is a part that wraps the class (dynamic subclassing). This part seems similar to how it was done before.
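The trick in isolation: swapping `module.__class__` to a dynamically created subclass puts the FSDP class first in the MRO, so its methods shadow the original's without reallocating or copying any weights. A standalone toy version (`FSDPMixin` and `unshard` are made up for illustration):

```python
import torch.nn as nn

class FSDPMixin:
    def unshard(self):  # hypothetical stand-in for FSDPModule's methods
        return f"unshard() on {type(self).__name__}"

lin = nn.Linear(4, 4)
FSDPLinear = type("FSDPLinear", (FSDPMixin, nn.Linear), {})  # mixin leftmost in MRO
lin.__class__ = FSDPLinear  # in-place class swap; parameters untouched

print(lin.unshard())               # unshard() on FSDPLinear
print(isinstance(lin, nn.Linear))  # True -- still usable as a Linear
```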