Add reprompter #30
Merged
Pull request overview
Adds a "Prompt Assistant" (re-prompter) feature that uses a Qwen LLM to classify a user's free-form input and rewrite it into a richer prompt tailored to one of four track types (Music / Instrument / One-shot / SFX), and wires it into the Gradio sampling UI.
Changes:
- Introduces `stable_audio_3/interface/reprompt.py` with hard-coded system prompts, an HF model loader/cache helper, an artifact-detection + retry loop, and a `reprompt` entry point.
- Adds a "Prompt Assistant" / "Download Prompt Assistant" button to the sampling UI that lazily downloads the model on first use, runs the reprompter, and extracts a `Length: N seconds` value into the seconds-total slider.
- Tightens `seconds_total_slider` max from `512` to `sample_size // sample_rate`, and strips two blank lines from the `verbose.py` module docstring.
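
The length-extraction step described above could look roughly like the sketch below. It assumes a pattern for the trailing `Length: N seconds` suffix, because the PR's own `_LENGTH_EXTRACT_RE` is not quoted in this review. The helper name `split_length` is hypothetical.

```python
import re

# Assumed pattern: the PR's _LENGTH_EXTRACT_RE is not shown in this review.
LENGTH_EXTRACT_RE = re.compile(r"\s*Length:\s*(\d+)\s*seconds?\.?\s*$", re.IGNORECASE)


def split_length(prompt: str, max_seconds: int):
    """Return (prompt_without_length_suffix, seconds or None).

    The extracted value is clamped to max_seconds, mirroring the
    min(int(...), sample_size // sample_rate) clamp in the handler.
    """
    m = LENGTH_EXTRACT_RE.search(prompt)
    if m is None:
        return prompt, None
    seconds = min(int(m.group(1)), max_seconds)
    return prompt[:m.start()].rstrip(), seconds
```

Returning `None` when no suffix is found plays the same role as the no-op `gr.update()` in the handler: the slider is left untouched.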
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| stable_audio_3/interface/reprompt.py | New module implementing the LLM-based prompt rewriter, classifier, model cache, and artifact-filter retry loop. |
| stable_audio_3/interface/diffusion_cond.py | Wires the reprompter into the Gradio UI: new button + handler, length extraction, and a reduced seconds_total upper bound. |
| stable_audio_3/verbose.py | Cosmetic: removes blank lines inside the module docstring. |

     with gr.Row(visible = True):
         # Timing controls
    -    seconds_total_slider = gr.Slider(minimum=0, maximum=512, step=1, value=sample_size//sample_rate, label="Seconds total", visible=has_seconds_total)
    +    seconds_total_slider = gr.Slider(minimum=0, maximum=sample_size//sample_rate, step=1, value=sample_size//sample_rate, label="Seconds total", visible=has_seconds_total)


Comment on lines +539 to +552

    def _prompt_assistant_or_download(text, progress=gr.Progress(track_tqdm=True)):
        if not _reprompt_is_model_cached(_reprompt_model_id):
            progress(0.0, desc="Downloading prompt assistant model…")
            _reprompt_get_model(_reprompt_model_id)
            return text, gr.update(), gr.update(value="Prompt Assistant")
        _, result, _ = _reprompt_fn(text, "Auto", "", _reprompt_model_id, 128, 1.11)
        m = _LENGTH_EXTRACT_RE.search(result)
        if m:
            max_seconds = sample_size // sample_rate
            seconds = min(int(m.group(1)), max_seconds)
            result = result[:m.start()]
        else:
            seconds = gr.update()
        return result, seconds, gr.update()


Comment on lines +344 to +350

    def is_model_cached(model_id):
        if _cache["model_id"] == model_id:
            return True
        try:
            return any(repo.repo_id == model_id for repo in scan_cache_dir().repos)
        except Exception:
            return False


Comment on lines +367 to +377

    def _postprocess(prompt, raw, tags_mode=True):
        source = raw.split("</think>", 1)[1].strip() if "</think>" in raw else raw
        if source.startswith("- "):
            source = source[2:]

        result = source.strip()

        result = re.sub(r'((?:^|(?<=\.\s))\w)', lambda m: m.group().upper(), result)
        if result:
            result = result[0].upper() + result[1:]
        return raw, result

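
To make the behavior of the sentence-capitalization regex in `_postprocess` easier to inspect, here is a small standalone sketch. The pattern is copied from the hunk above. The wrapper name `capitalize_sentences` is hypothetical and not part of the PR.

```python
import re

# Pattern copied from the quoted _postprocess hunk: uppercase the first word
# character of the string, and any word character preceded by "." + one whitespace.
SENTENCE_START_RE = re.compile(r'((?:^|(?<=\.\s))\w)')


def capitalize_sentences(text: str) -> str:
    """Apply the same substitution that _postprocess uses (hypothetical wrapper)."""
    return SENTENCE_START_RE.sub(lambda m: m.group().upper(), text)
```

Two properties follow from the pattern. The lookbehind is exactly two characters wide, so a period followed by two spaces is not treated as a sentence boundary. Abbreviations such as `e.g.` are treated as one. For typical text the `^` branch already uppercases the first character, so the later `result[0].upper()` line seems to add little.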

Comment on lines +392 to +396

    def _has_artifacts(text: str) -> bool:
        return (
            bool(_ARTIFACT_RE.search(text))
            or bool(_VOCAL_RE.search(text))
            or not bool(_LENGTH_RE.search(text))

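
A minimal sketch of how an artifact filter like this might drive a retry loop. The vocal and length patterns are copied from fragments quoted elsewhere in this review. `ARTIFACT_RE`, `generate_clean`, and the `max_attempts` default are assumptions, because the PR's `_ARTIFACT_RE` and its actual retry logic are not shown here.

```python
import re
from typing import Callable

# Copied from the fragments quoted in this review.
VOCAL_RE = re.compile(
    r'\b(vocals?|singing|singer|female|male|voice|voices|chorus|rap|rapper|chant(ing)?|lyrics?)\b',
    re.IGNORECASE,
)
LENGTH_RE = re.compile(r'\. Length: \d+ seconds\.?\s*$')
# Assumption: stand-in for the PR's _ARTIFACT_RE (leftover think tags, list/heading markers).
ARTIFACT_RE = re.compile(r'</?think>|^\s*[-*#]', re.MULTILINE)


def has_artifacts(text: str) -> bool:
    """Reject output with markup debris, vocal terms, or no trailing length."""
    return (
        bool(ARTIFACT_RE.search(text))
        or bool(VOCAL_RE.search(text))
        or not bool(LENGTH_RE.search(text))
    )


def generate_clean(generate: Callable[[], str], max_attempts: int = 3) -> str:
    """Call generate() until the output passes the filter.

    Falls back to the last attempt if every attempt is rejected
    (an assumed policy, not necessarily the PR's).
    """
    last = ""
    for _ in range(max_attempts):
        last = generate()
        if not has_artifacts(last):
            break
    return last
```

Because the vocal pattern is anchored with `\b`, words such as "Rapid" do not trip the `rap` alternative.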

Comment on lines +539 to +556

    def _prompt_assistant_or_download(text, progress=gr.Progress(track_tqdm=True)):
        _reprompt_get_model(_reprompt_model_id)
        _, result, _ = _reprompt_fn(text, "Auto", "", _reprompt_model_id, 128, 1.11)
        m = _LENGTH_EXTRACT_RE.search(result)
        if m:
            max_seconds = sample_size // sample_rate
            seconds = min(int(m.group(1)), max_seconds)
            result = result[:m.start()]
        else:
            seconds = gr.update()
        return result, seconds, gr.update(value="Prompt Assistant")

    prompt_assistant_button.click(
        fn=_prompt_assistant_or_download,
        inputs=[prompt],
        outputs=[prompt, seconds_total_slider, prompt_assistant_button],
        concurrency_limit=1,
    )


Comment on lines +380 to +384

    result = source.strip()

    result = re.sub(r'((?:^|(?<=\.\s))\w)', lambda m: m.group().upper(), result)
    if result:
        result = result[0].upper() + result[1:]


Comment on lines +433 to +438

    pool = _extract_examples(SYSTEM_PROMPTS["Music"])
    if not pool:
        return "", "", None
    example = random.choice(pool)
    print(f"[Random] picked: {example}")
    category = "music"


Comment on lines +346 to +358

    def is_model_cached(model_id):
        if _cache["model_id"] == model_id:
            return True
        try:
            for repo in scan_cache_dir().repos:
                if repo.repo_id != model_id:
                    continue
                for rev in repo.revisions:
                    if any(f.file_name.endswith(_WEIGHT_SUFFIXES) for f in rev.files):
                        return True
        except Exception:
            pass
        return False

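
The difference between this version and the earlier repo-only check can be illustrated without touching the real cache. In the sketch below, `SimpleNamespace` objects stand in for the structures returned by `scan_cache_dir()`. The suffix tuple, the `fake_cache` helper, and the model id are assumptions, because the PR's `_WEIGHT_SUFFIXES` value is not quoted in this review.

```python
from types import SimpleNamespace

# Assumed value: the PR's _WEIGHT_SUFFIXES is not shown in this review.
WEIGHT_SUFFIXES = (".safetensors", ".bin")


def has_weights(cache_info, model_id: str) -> bool:
    """True if any cached revision of model_id contains a weight file."""
    for repo in cache_info.repos:
        if repo.repo_id != model_id:
            continue
        for rev in repo.revisions:
            # str.endswith accepts a tuple of suffixes.
            if any(f.file_name.endswith(WEIGHT_SUFFIXES) for f in rev.files):
                return True
    return False


def fake_cache(repo_id, file_names):
    """Build a stand-in for the cache scan result (hypothetical test helper)."""
    files = [SimpleNamespace(file_name=name) for name in file_names]
    rev = SimpleNamespace(files=files)
    return SimpleNamespace(repos=[SimpleNamespace(repo_id=repo_id, revisions=[rev])])


MODEL_ID = "example-org/example-model"  # hypothetical id
partial = fake_cache(MODEL_ID, ["config.json", "tokenizer.json"])
complete = fake_cache(MODEL_ID, ["config.json", "model-00001-of-00002.safetensors"])
```

A repo-only check would report `partial` as cached even though no weights were downloaded. The per-file check reports it as missing.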

        r'\b(vocals?|singing|singer|female|male|voice|voices|chorus|rap|rapper|chant(ing)?|lyrics?)\b',
        re.IGNORECASE,
    )
    _LENGTH_RE = re.compile(r'\. Length: \d+ seconds\.?\s*$')


Comment on lines +375 to +385

    def _postprocess(prompt, raw, tags_mode=True):
        source = raw.split("</think>", 1)[1].strip() if "</think>" in raw else raw
        if source.startswith("- "):
            source = source[2:]

        result = source.strip()

        result = re.sub(r'((?:^|(?<=\.\s))\w)', lambda m: m.group().upper(), result)
        if result:
            result = result[0].upper() + result[1:]
        return raw, result


Comment on lines +539 to +551

    def _prompt_assistant_or_download(text, progress=gr.Progress(track_tqdm=True)):
        if not _reprompt_is_model_cached(_reprompt_model_id):
            _reprompt_get_model(_reprompt_model_id)
            return text, gr.update(), gr.update(value="Prompt Assistant")
        _, result, _ = _reprompt_fn(text, "Auto", "", _reprompt_model_id, 128, 1.11)
        m = _LENGTH_EXTRACT_RE.search(result)
        if m:
            max_seconds = sample_size // sample_rate
            seconds = min(int(m.group(1)), max_seconds)
            result = result[:m.start()]
        else:
            seconds = gr.update()
        return result, seconds, gr.update()


Comment on lines +298 to +302

    prompt_assistant_button = gr.Button(
        "Prompt Assistant" if _reprompt_cached else "Download Prompt Assistant (~4.2 GB)",
        scale=1
    )
    generate_button = gr.Button("Generate", variant='primary', scale=1)
