Extend CheckpointFunction to track all tensor input/output#1148
Extend CheckpointFunction to track all tensor input/output#1148000Justin000 wants to merge 3 commits intofacebookresearch:mainfrom
Conversation
|
Hi @000Justin000! Thank you for your pull request. We require contributors to sign our Contributor License Agreement, and yours needs attention. You currently have a record in our system, but the CLA is no longer valid, and will need to be resubmitted. ProcessIn order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks! |
What does this PR do?
The current activation checkpointing implementation would require the input/output argument to be a Tensor to be properly tracked by the auto grad, however for pyspeech nn layers we often use aux_input as a dict and output state which is a list.
This diff enables serialization of a python container: given an input that could be any python "container" (tuple, list, dict), perform a (depth first search) DFS to extract the pytorch tensors from the container and serialize the output to a tuple of tensors. At the original location replace with a index to the serialized list of tensors. As such, the original input can be easily reconstructed.
Before checkpointed_forward, the serialization happens and the tuple of tensors is use as input to forward (thus tracked); during checkpointed_forward, the original input is reconstructed by deserialization and pass in the original forward; the output of the original forward is serialized in the same manner and returned (so that the output is also tracked). After checkpointed_forward, the serialized output is deserialized to the desired format.
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.