Conversation
… to plot (simple) interactive graph
…yers. heatmaps for other metrics.
This file should probably be in mergekit/scripts.
Also would be good to use click to turn the hardcoded values into arguments.
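A minimal sketch of what the click suggestion could look like; the option names, defaults, and the `plot-metrics` command name are illustrative, not the actual script's interface:

```python
import click


@click.command("plot-metrics")
@click.option(
    "--results-path",
    type=click.Path(exists=True),
    required=True,
    help="Path to the saved metrics file.",
)
@click.option(
    "--metric",
    "-m",
    default="cossim",
    show_default=True,
    help="Which metric to plot.",
)
def main(results_path: str, metric: str) -> None:
    """Plot per-layer metric heatmaps (illustrative stub)."""
    click.echo(f"plotting {metric} from {results_path}")


if __name__ == "__main__":
    main()
```

Each previously hardcoded value becomes one `@click.option`, so the script stays runnable with no arguments changed while new values can be passed on the command line.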
| from typing import List, Dict, Optional, Any, Tuple
| from mergekit.graph import Task
| import networkx as nx
| import plotly.graph_objects as go
We should capture these new dependencies in pyproject.toml. Probably under a feature, so headless installs don't need to bring them in.
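One way this could look in pyproject.toml, assuming a PEP 621-style `[project]` table; the extra name and the unpinned versions are illustrative:

```toml
[project.optional-dependencies]
interactive_plot = [
    "networkx",
    "plotly",
]
```

Users who want the plots would then install with `pip install mergekit[interactive_plot]`, while headless installs skip both packages.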
| ) | ||
| tensor_task | ||
| ) | ||
| finalize = FinalizeModel( |
Totally fine to skip the finalize task when we're computing metrics, but it's needed for merges - as written, I think this stops merges from being written out correctly.
| **_kwargs,
| ) -> Task:
|
| if 'self_attn' in output_weight.name:
Down the line we probably want this split to be done based on new fields in ArchitectureInfo but this is good for now!
| res = {}
|
| scale_diff = torch.abs(norm_0 - norm_1) / ((norm_0 + norm_1) / 2)
Should we be doing something here to guard against dividing by zero?
yep - norms are non-negative, so adding a small epsilon to the denominator will be fine
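A sketch of the epsilon guard being discussed; the function name and epsilon value are illustrative:

```python
import torch


def scale_diff(
    norm_0: torch.Tensor, norm_1: torch.Tensor, eps: float = 1e-8
) -> torch.Tensor:
    # norms are non-negative, so the denominator can only vanish when both
    # norms are exactly zero; a small epsilon keeps the division finite there
    return torch.abs(norm_0 - norm_1) / ((norm_0 + norm_1) / 2 + eps)
```

For any nonzero pair of norms the epsilon perturbs the result negligibly; for the all-zero case it turns 0/0 into 0.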
| aliases: Optional[Tuple[str, ...]] = None
| force_dtype: Optional[str] = None
|
| GQA_groups: Optional[int] = None # None if not GQA, 1 if MQA, >1 if GQA
| num_heads=32 # hard-coded for now
| )
| self.block_count += 1
| return AttnTask(weights=weights, weight_infos=infos, weight_info=weight_info)
Does this end up creating N AttnTasks for each block? I don't think it's actually a problem as the tasks will be deduplicated downstream - should be fine
Should only be one AttnTask for each block - the if statement on line 351 is only satisfied once all the tensors (K,Q,V,O) have been collected. Then self.attn_weight_dict is reset to {} and the (one) AttnTask is created. I might also add individual tensor metrics for comparing just the Qs, Vs etc, which would be simpler.
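A simplified sketch of the collect-then-emit pattern described above - accumulate the four attention projections for a block, and only create the one task once all are present. The class and field names here are illustrative stand-ins, not the actual mergekit implementation:

```python
from typing import Dict, Optional

# the four attention projections collected per block
ATTN_PARTS = {"k_proj", "q_proj", "v_proj", "o_proj"}


class AttnCollector:
    """Accumulates attention tensors per block; emits one task per block."""

    def __init__(self) -> None:
        self.attn_weight_dict: Dict[str, str] = {}
        self.block_count = 0

    def add(self, part: str, tensor_name: str) -> Optional[dict]:
        self.attn_weight_dict[part] = tensor_name
        # the task is only built once all four projections have arrived,
        # so each block yields exactly one task
        if set(self.attn_weight_dict) == ATTN_PARTS:
            task = {
                "block": self.block_count,
                "weights": dict(self.attn_weight_dict),
            }
            self.attn_weight_dict = {}  # reset for the next block
            self.block_count += 1
            return task
        return None
```

Because the dict is reset after each emit, three of the four `add` calls per block return `None` and exactly one returns a task.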
| self._method = merge_methods.get(config.merge_method)
| if getattr(config, "merge_method", None):
|     self._method = merge_methods.get(config.merge_method)
| elif getattr(config, "metric_method", None):
Would be good to add a validator to MergeConfig that checks that exactly one of these fields is set.
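A sketch of what that validator could look like, assuming Pydantic v2's `model_validator`; the cut-down `MergeConfig` here shows only the two fields in question:

```python
from typing import Optional

from pydantic import BaseModel, model_validator


class MergeConfig(BaseModel):
    # simplified stand-in for the real config; only the method fields shown
    merge_method: Optional[str] = None
    metric_method: Optional[str] = None

    @model_validator(mode="after")
    def exactly_one_method(self) -> "MergeConfig":
        # both set or both unset are equally invalid
        if bool(self.merge_method) == bool(self.metric_method):
            raise ValueError(
                "exactly one of merge_method or metric_method must be set"
            )
        return self
```

This fails fast at config parse time instead of letting the `if`/`elif` dispatch silently fall through when neither (or both) fields are set.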
| ) | ||
|
|
||
| res = [] | ||
| for _task, value in exec.run(quiet=options.quiet): |
Looking this over, I kinda think we might not need a separate file here - maybe it should just early out in merge.py if there's a metric_method set instead of merge_method?
| Abstract base class representing a task in a computational graph.
|
| This class should be extended to define specific tasks. Each task can have arguments (dependencies) and a defined execution strategy.
| Note that PyDantic BaseModel requires that all attributes are defined in the class initialisation, and cannot be changed after.
Super nitpick here: I think the official capitalization is Pydantic, not PyDantic.
…rge OR Metri, not both.
… to separate case
…eralised substitute function in architecture
Implemented:
Not Implemented: