Multiple forward per backward #81
Open
Labels
core: Improves core model while keeping core idea intact
engineering: Software-engineering problems that don't require ML-Expertise
research: Creative project that might fail but could give high returns
Currently, our model runs one forward pass and reuses its intermediate states for one backward pass. However, a backward pass is over 3x as expensive as a forward pass, so changing the ratio of forward to backward passes could speed up training.
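To make the trade-off concrete, here is a rough cost sketch. It assumes that cost scales linearly with the number of samples and that a backward pass costs exactly 3x a forward pass (the text above only says "over 3x", so this is a lower bound). The helper `step_cost` and the batch sizes are hypothetical and chosen only for illustration.

```python
def step_cost(forward_samples: int, backward_samples: int,
              backward_ratio: float = 3.0) -> float:
    """Cost of one training step, in units of one single-sample forward pass.

    backward_ratio is the assumed cost of a backward pass relative to a
    forward pass (3.0 here, taken as a lower bound from the issue text).
    """
    return forward_samples + backward_ratio * backward_samples


# Baseline: every sample gets one forward and one backward pass.
baseline = step_cost(forward_samples=32, backward_samples=32)

# Hypothetical variant: score 10x as many samples with forward passes only,
# then run the backward pass on a subset of 32.
selective = step_cost(forward_samples=320, backward_samples=32)

# Steps under the variant are more expensive, so it only pays off if the
# number of steps needed shrinks by more than this factor.
break_even = selective / baseline
```

Under these assumptions a step that scores ten times as many samples costs 3.25x as much as a baseline step, so any forward-heavy scheme has to cut the number of training steps by more than that to be a net win.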
One such approach is MESA, which adds `KL(model(x), ema_model(x))` to the loss. Another is RHO-Loss, which prioritizes some samples over others by running `(model(x) - oracle(x)).topk()`. Both methods claim to improve sample efficiency by up to 18x.
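The two expressions above can be sketched in plain Python, independent of any framework. This is only an illustration under assumptions, not either paper's implementation. `kl_divergence` and `select_top_k` are hypothetical helper names. The KL argument order follows the notation in this issue. `model(x)` and `oracle(x)` in the RHO-Loss expression are read as per-sample losses.

```python
import math
from typing import List, Sequence


def kl_divergence(p: Sequence[float], q: Sequence[float],
                  eps: float = 1e-12) -> float:
    """KL(p || q) for two categorical distributions over the same classes."""
    return sum(pi * math.log((pi + eps) / (qi + eps))
               for pi, qi in zip(p, q) if pi > 0)


def select_top_k(model_losses: Sequence[float],
                 oracle_losses: Sequence[float], k: int) -> List[int]:
    """Indices of the k samples with the largest (model loss - oracle loss).

    Only forward passes are needed to compute both loss lists; the backward
    pass would then run on the returned subset alone.
    """
    gaps = [m - o for m, o in zip(model_losses, oracle_losses)]
    order = sorted(range(len(gaps)), key=lambda i: gaps[i], reverse=True)
    return order[:k]


# MESA-style term: penalize disagreement between the current model's
# predicted distribution and that of an EMA copy (toy values).
model_probs = [0.7, 0.2, 0.1]
ema_probs = [0.6, 0.3, 0.1]
consistency_penalty = kl_divergence(model_probs, ema_probs)

# RHO-Loss-style selection: keep the samples the model still gets wrong
# but the oracle finds easy (toy per-sample losses).
model_losses = [2.0, 0.5, 1.5, 3.0]
oracle_losses = [1.9, 0.4, 0.2, 1.0]
selected = select_top_k(model_losses, oracle_losses, k=2)  # -> [3, 2]
```

In both cases the extra work is forward-only: the EMA model and the oracle never need gradients, which is what shifts the forward-to-backward ratio.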