Skip to content

Add reprompter#30

Merged
mhrice merged 8 commits into
main from
reprompter
May 19, 2026
Merged

Add reprompter#30
mhrice merged 8 commits into
main from
reprompter

Conversation

@mhrice
Copy link
Copy Markdown
Collaborator

@mhrice mhrice commented May 19, 2026

No description provided.

Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds a "Prompt Assistant" (re-prompter) feature that uses a Qwen LLM to classify a user's free-form input and rewrite it into a richer prompt tailored to one of four track types (Music / Instrument / One-shot / SFX), and wires it into the Gradio sampling UI.

Changes:

  • Introduces stable_audio_3/interface/reprompt.py with hard-coded system prompts, an HF model loader/cache helper, an artifact-detection + retry loop, and a reprompt entry point.
  • Adds a "Prompt Assistant" / "Download Prompt Assistant" button to the sampling UI that lazily downloads the model on first use, runs the reprompter, and extracts a Length: N seconds value into the seconds-total slider.
  • Tightens seconds_total_slider max from 512 to sample_size // sample_rate, and strips two blank lines from the verbose.py module docstring.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.

File Description
stable_audio_3/interface/reprompt.py New module implementing the LLM-based prompt rewriter, classifier, model cache, and artifact-filter retry loop.
stable_audio_3/interface/diffusion_cond.py Wires the reprompter into the Gradio UI: new button + handler, length extraction, and a reduced seconds_total upper bound.
stable_audio_3/verbose.py Cosmetic: removes blank lines inside the module docstring.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

with gr.Row(visible = True):
# Timing controls
seconds_total_slider = gr.Slider(minimum=0, maximum=512, step=1, value=sample_size//sample_rate, label="Seconds total", visible=has_seconds_total)
seconds_total_slider = gr.Slider(minimum=0, maximum=sample_size//sample_rate, step=1, value=sample_size//sample_rate, label="Seconds total", visible=has_seconds_total)
Comment on lines +539 to +552
def _prompt_assistant_or_download(text, progress=gr.Progress(track_tqdm=True)):
    """Download the prompt-assistant model on first use, otherwise rewrite `text`.

    Returns a 3-tuple of values for the prompt box, the seconds-total slider
    and the assistant button, in that order.

    * Model not cached: report download progress, fetch the model, and return
      the untouched `text`, a no-op update, and a button relabelled
      "Prompt Assistant". No rewrite happens on this call.
    * Model cached: run the reprompter; if `_LENGTH_EXTRACT_RE` matches the
      rewritten prompt, return the text before the match and the captured
      number of seconds capped at `sample_size // sample_rate`.
    """
    if not _reprompt_is_model_cached(_reprompt_model_id):
        progress(0.0, desc="Downloading prompt assistant model…")
        _reprompt_get_model(_reprompt_model_id)
        return text, gr.update(), gr.update(value="Prompt Assistant")

    # NOTE(review): 128 and 1.11 are passed positionally; their meaning is set
    # by the reprompt function's signature, which is not visible here — confirm.
    _raw, rewritten, _extra = _reprompt_fn(
        text, "Auto", "", _reprompt_model_id, 128, 1.11
    )

    length_match = _LENGTH_EXTRACT_RE.search(rewritten)
    if not length_match:
        # No length found: leave the slider and the button as they are.
        return rewritten, gr.update(), gr.update()

    longest_allowed = sample_size // sample_rate
    seconds = min(int(length_match.group(1)), longest_allowed)
    # Keep only the text that precedes the matched length clause.
    return rewritten[:length_match.start()], seconds, gr.update()
Comment thread stable_audio_3/interface/reprompt.py Outdated
Comment on lines +344 to +350
def is_model_cached(model_id):
    """Return True when `model_id` is already loaded or present in the cache scan.

    The in-memory `_cache` entry is consulted first. Otherwise the repos
    reported by `scan_cache_dir()` are searched for a matching `repo_id`.
    Any exception raised while scanning is treated as "not cached".
    """
    if _cache["model_id"] == model_id:
        return True
    try:
        for cached_repo in scan_cache_dir().repos:
            if cached_repo.repo_id == model_id:
                return True
    except Exception:
        # Best effort: a failed scan means the model is reported as absent.
        return False
    return False
Comment on lines +367 to +377
def _postprocess(prompt, raw, tags_mode=True):
source = raw.split("</think>", 1)[1].strip() if "</think>" in raw else raw
if source.startswith("- "):
source = source[2:]

result = source.strip()

result = re.sub(r'((?:^|(?<=\.\s))\w)', lambda m: m.group().upper(), result)
if result:
result = result[0].upper() + result[1:]
return raw, result
Comment on lines +392 to +396
def _has_artifacts(text: str) -> bool:
return (
bool(_ARTIFACT_RE.search(text))
or bool(_VOCAL_RE.search(text))
or not bool(_LENGTH_RE.search(text))
Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 5 out of 5 changed files in this pull request and generated 5 comments.

Comment on lines +539 to +556
def _prompt_assistant_or_download(text, progress=gr.Progress(track_tqdm=True)):
    """Ensure the prompt-assistant model is loaded, then rewrite `text`.

    Returns a 3-tuple of values for the prompt box, the seconds-total slider
    and the assistant button, in that order. The button is always relabelled
    "Prompt Assistant". When `_LENGTH_EXTRACT_RE` matches the rewritten
    prompt, the text before the match is returned together with the captured
    number of seconds capped at `sample_size // sample_rate`; otherwise the
    slider receives a no-op update.

    `progress` is not called in the body.
    # NOTE(review): presumably declared so Gradio tracks tqdm progress — confirm.
    """
    _reprompt_get_model(_reprompt_model_id)

    # NOTE(review): 128 and 1.11 are passed positionally; their meaning is set
    # by the reprompt function's signature, which is not visible here — confirm.
    _raw, rewritten, _extra = _reprompt_fn(
        text, "Auto", "", _reprompt_model_id, 128, 1.11
    )

    length_match = _LENGTH_EXTRACT_RE.search(rewritten)
    if length_match:
        longest_allowed = sample_size // sample_rate
        seconds = min(int(length_match.group(1)), longest_allowed)
        # Keep only the text that precedes the matched length clause.
        rewritten = rewritten[:length_match.start()]
    else:
        seconds = gr.update()

    button_update = gr.update(value="Prompt Assistant")
    return rewritten, seconds, button_update

prompt_assistant_button.click(
fn=_prompt_assistant_or_download,
inputs=[prompt],
outputs=[prompt, seconds_total_slider, prompt_assistant_button],
concurrency_limit=1,
)
with gr.Row(visible = True):
# Timing controls
seconds_total_slider = gr.Slider(minimum=0, maximum=512, step=1, value=sample_size//sample_rate, label="Seconds total", visible=has_seconds_total)
seconds_total_slider = gr.Slider(minimum=0, maximum=sample_size//sample_rate, step=1, value=sample_size//sample_rate, label="Seconds total", visible=has_seconds_total)
Comment on lines +380 to +384
result = source.strip()

result = re.sub(r'((?:^|(?<=\.\s))\w)', lambda m: m.group().upper(), result)
if result:
result = result[0].upper() + result[1:]
Comment on lines +433 to +438
pool = _extract_examples(SYSTEM_PROMPTS["Music"])
if not pool:
return "", "", None
example = random.choice(pool)
print(f"[Random] picked: {example}")
category = "music"
Comment on lines +346 to +358
def is_model_cached(model_id):
    """Return True when `model_id` is loaded or has weight files in the cache scan.

    The in-memory `_cache` entry is consulted first. Otherwise the repos
    reported by `scan_cache_dir()` are searched for one whose `repo_id`
    equals `model_id` and that has, in any revision, a file whose name ends
    with one of `_WEIGHT_SUFFIXES`. Any exception raised while scanning is
    treated as "not cached".
    """
    if _cache["model_id"] == model_id:
        return True
    try:
        # A repo entry alone is not enough; at least one weight file must
        # exist in some revision of the matching repo.
        return any(
            cached_file.file_name.endswith(_WEIGHT_SUFFIXES)
            for cached_repo in scan_cache_dir().repos
            if cached_repo.repo_id == model_id
            for revision in cached_repo.revisions
            for cached_file in revision.files
        )
    except Exception:
        # Best effort: a failed scan means the model is reported as absent.
        return False
Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 5 out of 5 changed files in this pull request and generated 7 comments.

r'\b(vocals?|singing|singer|female|male|voice|voices|chorus|rap|rapper|chant(ing)?|lyrics?)\b',
re.IGNORECASE,
)
_LENGTH_RE = re.compile(r'\. Length: \d+ seconds\.?\s*$')
Comment on lines +375 to +385
def _postprocess(prompt, raw, tags_mode=True):
source = raw.split("</think>", 1)[1].strip() if "</think>" in raw else raw
if source.startswith("- "):
source = source[2:]

result = source.strip()

result = re.sub(r'((?:^|(?<=\.\s))\w)', lambda m: m.group().upper(), result)
if result:
result = result[0].upper() + result[1:]
return raw, result
Comment on lines +539 to +551
def _prompt_assistant_or_download(text, progress=gr.Progress(track_tqdm=True)):
    """Download the prompt-assistant model on first use, otherwise rewrite `text`.

    Returns a 3-tuple of values for the prompt box, the seconds-total slider
    and the assistant button, in that order.

    * Model not cached: fetch the model and return the untouched `text`, a
      no-op update, and a button relabelled "Prompt Assistant". No rewrite
      happens on this call.
    * Model cached: run the reprompter; if `_LENGTH_EXTRACT_RE` matches the
      rewritten prompt, return the text before the match and the captured
      number of seconds capped at `sample_size // sample_rate`.

    `progress` is not called in the body.
    # NOTE(review): presumably declared so Gradio tracks tqdm progress — confirm.
    """
    model_ready = _reprompt_is_model_cached(_reprompt_model_id)
    if not model_ready:
        _reprompt_get_model(_reprompt_model_id)
        return text, gr.update(), gr.update(value="Prompt Assistant")

    # NOTE(review): 128 and 1.11 are passed positionally; their meaning is set
    # by the reprompt function's signature, which is not visible here — confirm.
    _raw, rewritten, _extra = _reprompt_fn(
        text, "Auto", "", _reprompt_model_id, 128, 1.11
    )

    length_match = _LENGTH_EXTRACT_RE.search(rewritten)
    if not length_match:
        # No length found: leave the slider and the button as they are.
        return rewritten, gr.update(), gr.update()

    longest_allowed = sample_size // sample_rate
    seconds = min(int(length_match.group(1)), longest_allowed)
    # Keep only the text that precedes the matched length clause.
    return rewritten[:length_match.start()], seconds, gr.update()
Comment on lines +346 to +358
def is_model_cached(model_id):
    """Tell whether `model_id` is loaded or has weight files in the cache scan.

    Returns True immediately when the in-memory `_cache` entry names this
    model. Otherwise returns True only if `scan_cache_dir()` reports a repo
    whose `repo_id` equals `model_id` and one of its revisions holds a file
    whose name ends with one of `_WEIGHT_SUFFIXES`. Any exception raised
    while scanning yields False.
    """
    if _cache["model_id"] == model_id:
        return True

    def _revision_has_weights(revision):
        # True when any file in this revision carries a weight-file suffix.
        return any(
            cached_file.file_name.endswith(_WEIGHT_SUFFIXES)
            for cached_file in revision.files
        )

    try:
        for cached_repo in scan_cache_dir().repos:
            if cached_repo.repo_id == model_id:
                if any(_revision_has_weights(rev) for rev in cached_repo.revisions):
                    return True
    except Exception:
        # Best effort: a failed scan means the model is reported as absent.
        return False
    return False
Comment on lines +298 to +302
prompt_assistant_button = gr.Button(
"Prompt Assistant" if _reprompt_cached else "Download Prompt Assistant (~4.2 GB)",
scale=1
)
generate_button = gr.Button("Generate", variant='primary', scale=1)
Comment on lines +298 to +302
prompt_assistant_button = gr.Button(
"Prompt Assistant" if _reprompt_cached else "Download Prompt Assistant (~4.2 GB)",
scale=1
)
generate_button = gr.Button("Generate", variant='primary', scale=1)
with gr.Row(visible = True):
# Timing controls
seconds_total_slider = gr.Slider(minimum=0, maximum=512, step=1, value=sample_size//sample_rate, label="Seconds total", visible=has_seconds_total)
seconds_total_slider = gr.Slider(minimum=0, maximum=sample_size//sample_rate, step=1, value=sample_size//sample_rate, label="Seconds total", visible=has_seconds_total)
@mhrice mhrice merged commit c72fba8 into main May 19, 2026
5 checks passed
@mhrice mhrice deleted the reprompter branch May 19, 2026 23:26
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants