I wasn't yet able to create a good benchmark for a model where batching actually helps. It needs a model that is complex enough that the compiler can't optimize it well, but that still runs reasonably quickly. Whether batching actually helps seems to depend heavily on the model: with Marvin's retirement model it nearly cut the memory usage in half, while for the Mahler & Yum model it does very little. I also fixed an error in the MY model input creation and removed one of the tests, because
Problem

Sometimes running a model is not possible because of memory restrictions. The nested `vmap`s can lead to JAX creating large arrays for intermediate results, which are temporarily held in GPU memory. These arrays usually have the dimensions of the State-Action-Space of the model, so looping over batches of half the grid size along one of its dimensions can already halve the peak memory usage. Batching comes at a cost, though: execution time gets progressively worse as the batch size shrinks. For big batches, the slowdown is larger than I would have expected, given that not all of the computations can happen at the same time anyway.

New feature
This PR implements a batched version of `productmap`. The user can specify a batch size for each state's grid. Instead of using `vmap` to map the `Q_and_F_Function` along this grid, `jax.lax.map` will be used, which either loops over the batches of grid points or, if `batch_size=0`, works like `vmap`. The batched version is only used during the solution, since the State-Action-Space for the simulation is already much smaller, as it only depends on the number of simulated subjects.

Tasks
- `_base_productmap` are needed