diff --git a/app/intellij-plugin/src/main/java/de/tuda/stg/securecoder/plugin/engine/DemoEngine.kt b/app/intellij-plugin/src/main/java/de/tuda/stg/securecoder/plugin/engine/DemoEngine.kt new file mode 100644 index 0000000..91707ff --- /dev/null +++ b/app/intellij-plugin/src/main/java/de/tuda/stg/securecoder/plugin/engine/DemoEngine.kt @@ -0,0 +1,141 @@ +package de.tuda.stg.securecoder.plugin.engine + +import de.tuda.stg.securecoder.engine.Engine +import de.tuda.stg.securecoder.engine.file.edit.Changes +import de.tuda.stg.securecoder.engine.file.edit.Changes.SearchReplace +import de.tuda.stg.securecoder.engine.file.edit.Changes.SearchedText +import de.tuda.stg.securecoder.engine.stream.ProposalId +import de.tuda.stg.securecoder.engine.stream.StreamEvent +import de.tuda.stg.securecoder.engine.workflow.GuardianExecutor.GuardianResult +import de.tuda.stg.securecoder.filesystem.FileSystem +import de.tuda.stg.securecoder.guardian.File +import de.tuda.stg.securecoder.guardian.Location +import de.tuda.stg.securecoder.guardian.RuleRef +import de.tuda.stg.securecoder.guardian.Violation +import kotlinx.coroutines.delay +import kotlinx.coroutines.flow.toList + +class DemoEngine : Engine { + + override suspend fun run( + prompt: String, + filesystem: FileSystem, + onEvent: suspend (StreamEvent) -> Unit, + context: Engine.Context? + ): Engine.EngineResult { + print(filesystem.allFiles().toList().map { it.name() }) + + val targetFile = "file:///Users/david/IdeaProjects/untitled/src/ArchiveUtils.kt" + val unsafeProposalId = ProposalId.newId() + + val unsafeCode = """ + fun unzip(zipFile: File, destDir: File) { + ZipInputStream(FileInputStream(zipFile)).use { zis -> + var entry = zis.nextEntry + while (entry != null) { + val newFile = File(destDir, entry.name) + if (entry.isDirectory) { + newFile.mkdirs() + } else { + newFile.parentFile.mkdirs() + FileOutputStream(newFile).use { fos -> + zis.copyTo(fos) + } + } + entry = zis.nextEntry + } + } + } + """.trimIndent() + + val unsafeChanges = Changes( + searchReplaces = listOf( + SearchReplace( + fileName = targetFile, + searchedText = SearchedText.append(), + replaceText = unsafeCode + ) + ) + ) + delay(4800) + + onEvent(StreamEvent.ProposedEdits(unsafeProposalId, unsafeChanges)) + onEvent(StreamEvent.ValidationStarted(unsafeProposalId)) + + delay(2800) + + val zipSlipViolation = Violation( + rule = RuleRef( + id = "S6096", + name = "Zip Slip Vulnerability", + description = "Extracting archives without validating the destination path can allow arbitrary file overwrite.", + cwe = "CWE-22", + owasp = "A01:2021-Broken Access Control" + ), + message = "Unsafe zip extraction. The code uses 'entry.name' directly without verifying the resulting path is within 'destDir'.", + location = Location(targetFile, 5, 6), + hardReject = true, + raw = "val newFile = File(destDir, entry.name)" + ) + + val guardianResult = GuardianResult( + violations = listOf(zipSlipViolation), + files = listOf(File(targetFile, unsafeCode)) + ) + + onEvent(StreamEvent.GuardianWarning(unsafeProposalId, guardianResult)) + + // Simulate LLM "thinking" about the fix + delay(800) + + // --- 3. Generate Second Proposal (Safe Fix) --- + val safeProposalId = ProposalId.newId() + + val safeCode = """ + fun unzip(zipFile: File, destDir: File) { + val destDirPath = destDir.canonicalPath + ZipInputStream(FileInputStream(zipFile)).use { zis -> + var entry = zis.nextEntry + while (entry != null) { + val newFile = File(destDir, entry.name) + + if (!newFile.canonicalPath.startsWith(destDirPath)) { + throw SecurityException("Zip entry is outside of the target dir: " + entry.name) + } + + if (entry.isDirectory) { + newFile.mkdirs() + } else { + newFile.parentFile.mkdirs() + FileOutputStream(newFile).use { fos -> + zis.copyTo(fos) + } + } + entry = zis.nextEntry + } + } + } + """.trimIndent() + + val safeChanges = Changes( + searchReplaces = listOf( + SearchReplace( + fileName = targetFile, + searchedText = SearchedText.append(), + replaceText = safeCode + ) + ) + ) + delay(3200) + + onEvent(StreamEvent.ProposedEdits(safeProposalId, safeChanges)) + onEvent(StreamEvent.ValidationStarted(safeProposalId)) + + // Simulate successful validation + delay(1200) + onEvent(StreamEvent.ValidationSucceeded(safeProposalId)) + + // --- 4. Return Final Result --- + return Engine.EngineResult.Success(safeChanges) + } +} \ No newline at end of file diff --git a/app/intellij-plugin/src/main/java/de/tuda/stg/securecoder/plugin/engine/EngineRunnerService.kt b/app/intellij-plugin/src/main/java/de/tuda/stg/securecoder/plugin/engine/EngineRunnerService.kt index fc6aca7..5c309a8 100644 --- a/app/intellij-plugin/src/main/java/de/tuda/stg/securecoder/plugin/engine/EngineRunnerService.kt +++ b/app/intellij-plugin/src/main/java/de/tuda/stg/securecoder/plugin/engine/EngineRunnerService.kt @@ -7,6 +7,8 @@ import com.intellij.openapi.fileEditor.FileEditorManager import com.intellij.openapi.project.Project import com.intellij.platform.ide.progress.withBackgroundProgress import de.tuda.stg.securecoder.engine.Engine +import de.tuda.stg.securecoder.engine.guardian.LlmGuardian +import de.tuda.stg.securecoder.engine.llm.LlmClient import de.tuda.stg.securecoder.engine.llm.OllamaClient import de.tuda.stg.securecoder.engine.llm.OpenRouterClient import de.tuda.stg.securecoder.engine.workflow.WorkflowEngine @@ -18,6 +20,7 @@ import de.tuda.stg.securecoder.plugin.engine.event.EngineResultMapper import de.tuda.stg.securecoder.plugin.engine.event.StreamEventMapper import de.tuda.stg.securecoder.plugin.engine.event.UiStreamEvent import de.tuda.stg.securecoder.plugin.settings.SecureCoderSettingsState +import de.tuda.stg.securecoder.plugin.settings.SecureCoderSettingsState.LlmConfig import de.tuda.stg.securecoder.plugin.settings.SecureCoderSettingsState.LlmProvider import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers @@ -35,16 +38,24 @@ class EngineRunnerService( val close: () -> Unit, ) - private fun buildEngine(): EngineHandle { - val settings = settings.state - val llm = when (settings.llmProvider) { + private fun buildLlmClient( + config: LlmConfig, + clientName: String + ): LlmClient { + return when (config.provider) { LlmProvider.OPENROUTER -> OpenRouterClient( - settings.openrouterApiKey, - settings.openrouterModel, - "securecoder" + config.openrouterApiKey, + config.openrouterModel, + clientName ) - LlmProvider.OLLAMA -> OllamaClient(settings.ollamaModel) + LlmProvider.OLLAMA -> OllamaClient(config.ollamaModel) } + } + + private fun buildEngine(): EngineHandle { + val settings = settings.state + val llm = buildLlmClient(settings.mainLlm, "securecoder") + val guardianLlmConfig = if (settings.useMainLlmForGuardian) settings.mainLlm else settings.guardianLlm val enricher = if (settings.enablePromptEnriching) { EnricherClient(settings.enricherUrl) @@ -53,10 +64,11 @@ class EngineRunnerService( } val guardians = listOfNotNull( if (settings.enableDummyGuardian) DummyGuardian(sleepMillis = 2000) else null, - if (settings.enableCodeQLGuardian) CodeQLGuardian(settings.codeqlBinary) else null + if (settings.enableCodeQLGuardian) CodeQLGuardian(settings.codeqlBinary) else null, + if (settings.enableLlmGuardian) LlmGuardian(buildLlmClient(guardianLlmConfig, "securecoder guardian")) else null ) - //return EngineHandle(DummyAgentStreamer(), {}) + //return EngineHandle(DemoEngine(), {}) return EngineHandle( WorkflowEngine(enricher, llm, guardians), { diff --git a/app/intellij-plugin/src/main/java/de/tuda/stg/securecoder/plugin/engine/IntelliJProjectFileSystem.kt b/app/intellij-plugin/src/main/java/de/tuda/stg/securecoder/plugin/engine/IntelliJProjectFileSystem.kt index 10a5781..54b69d6 100644 --- a/app/intellij-plugin/src/main/java/de/tuda/stg/securecoder/plugin/engine/IntelliJProjectFileSystem.kt +++ b/app/intellij-plugin/src/main/java/de/tuda/stg/securecoder/plugin/engine/IntelliJProjectFileSystem.kt @@ -61,6 +61,9 @@ class IntelliJProjectFileSystem( return@writeAction } val parentVf = VfsUtil.createDirectories(parentDirPath) + if (parentVf == null) { + throw IOException("Could not create parent directory for $parentDirPath") + } parentVf.refresh(false, true) vf = parentVf.findChild(ioFile.name) ?: parentVf.createChildData(this, ioFile.name) } else if (vf.isDirectory) { diff --git a/app/intellij-plugin/src/main/java/de/tuda/stg/securecoder/plugin/settings/SecureCoderSettingsConfigurable.kt b/app/intellij-plugin/src/main/java/de/tuda/stg/securecoder/plugin/settings/SecureCoderSettingsConfigurable.kt index 7bc59a8..98e151a 100644 --- a/app/intellij-plugin/src/main/java/de/tuda/stg/securecoder/plugin/settings/SecureCoderSettingsConfigurable.kt +++ b/app/intellij-plugin/src/main/java/de/tuda/stg/securecoder/plugin/settings/SecureCoderSettingsConfigurable.kt @@ -17,10 +17,13 @@ import com.intellij.ui.EnumComboBoxModel import com.intellij.ui.awt.RelativePoint import com.intellij.ui.components.JBCheckBox import com.intellij.ui.dsl.builder.* +import com.intellij.ui.layout.and +import com.intellij.ui.layout.not import com.intellij.ui.layout.selected import com.intellij.ui.layout.selectedValueMatches import de.tuda.stg.securecoder.guardian.CodeQLRunner import de.tuda.stg.securecoder.plugin.SecureCoderBundle +import de.tuda.stg.securecoder.plugin.settings.SecureCoderSettingsState.LlmConfig import de.tuda.stg.securecoder.plugin.settings.SecureCoderSettingsState.LlmProvider import java.io.IOException import java.nio.file.Path @@ -32,142 +35,209 @@ class SecureCoderSettingsConfigurable : BoundConfigurable(SecureCoderBundle.mess private val settings = service() override fun createPanel() = panel { - group(SecureCoderBundle.message("settings.group.llmProvider")) { + createLlmConfigSection( + SecureCoderBundle.message("settings.group.llmProvider"), + settings.state.mainLlm + ).bottomGap(BottomGap.MEDIUM) + + group(SecureCoderBundle.message("settings.group.security")) { + group(SecureCoderBundle.message("settings.group.enricher")) { + val enricher = JBCheckBox(SecureCoderBundle.message("settings.enricher.enabled")) + row { + cell(enricher).bindSelected(settings.state::enablePromptEnriching) + } + row(SecureCoderBundle.message("settings.enricher.url")) { + textField() + .bindText(settings.state::enricherUrl) + .columns(COLUMNS_MEDIUM) + }.enabledIf(enricher.selected) + }.bottomGap(BottomGap.MEDIUM) + group(SecureCoderBundle.message("settings.group.guardians")) { + group(SecureCoderBundle.message("settings.group.guardian.dummy")) { + row { + checkBox(SecureCoderBundle.message("settings.guardian.dummy")) + .bindSelected(settings.state::enableDummyGuardian) + } + }.bottomGap(BottomGap.MEDIUM) + group(SecureCoderBundle.message("settings.group.guardian.codeql")) { + val codeql = JBCheckBox(SecureCoderBundle.message("settings.guardian.codeql.enable")) + row { + cell(codeql).bindSelected(settings.state::enableCodeQLGuardian) + } + row(SecureCoderBundle.message("settings.codeql.binary")) { + val codeqlPathCell = textFieldWithBrowseButton( + FileChooserDescriptorFactory.singleFile() + ) + .bindText(settings.state::codeqlBinary) + .columns(COLUMNS_MEDIUM) + + val codeqlPathField = codeqlPathCell.component + + button(SecureCoderBundle.message("settings.codeql.test")) { event -> + val loadingBalloon = JBPopupFactory.getInstance() + .createHtmlTextBalloonBuilder( + SecureCoderBundle.message("settings.codeql.checking"), + AnimatedIcon.Default.INSTANCE, + null, null, null + ) + .createBalloon() + + loadingBalloon.show( + RelativePoint.getSouthOf(event.source as JComponent), + Balloon.Position.below + ) + + ApplicationManager.getApplication().executeOnPooledThread { + val bin = settings.state.codeqlBinary.ifBlank { "codeql" } + val (message, type) = try { + SecureCoderBundle.message( + "settings.codeql.found", + CodeQLRunner(bin).getToolVersion() + ) to MessageType.INFO + } catch (e: Exception) { + SecureCoderBundle.message( + "settings.codeql.error", + (e.message ?: e.toString()) + ) to MessageType.ERROR + } + + ApplicationManager.getApplication().invokeLater( + { + loadingBalloon.hide() + JBPopupFactory.getInstance() + .createHtmlTextBalloonBuilder(message, type, null) + .createBalloon() + .show( + RelativePoint.getSouthOf(event.source as JComponent), + Balloon.Position.below + ) + }, + ModalityState.any() + ) + } + } + + button(SecureCoderBundle.message("settings.codeql.download")) { event -> + val button = event.source as JButton + button.isEnabled = false + + ProgressManager.getInstance().run(object : Task.Backgroundable( + null, + SecureCoderBundle.message("settings.codeql.installing"), + true + ) { + private var resultPath: Path? = null + private var exception: Exception? = null + + override fun run(indicator: ProgressIndicator) { + val installer = CodeQLInstaller() + try { + resultPath = installer.getOrInstallCodeQL(indicator) + } catch (e: IOException) { + exception = e + } + } + + override fun onSuccess() { + button.isEnabled = true + when { + exception != null -> { + JBPopupFactory.getInstance() + .createHtmlTextBalloonBuilder( + SecureCoderBundle.message( + "settings.codeql.install.failed", + exception!!.message ?: exception!!.toString() + ), + MessageType.ERROR, + null + ) + .createBalloon() + .show(RelativePoint.getSouthOf(button), Balloon.Position.below) + } + resultPath != null -> { + val path = resultPath.toString() + settings.state.codeqlBinary = path + codeqlPathField.text = path + JBPopupFactory.getInstance() + .createHtmlTextBalloonBuilder( + SecureCoderBundle.message("settings.codeql.downloaded"), + MessageType.INFO, + null + ) + .createBalloon() + .show(RelativePoint.getSouthOf(button), Balloon.Position.below) + } + } + } + + override fun onCancel() { + button.isEnabled = true + } + }) + } + }.enabledIf(codeql.selected) + }.bottomGap(BottomGap.MEDIUM) + + group(SecureCoderBundle.message("settings.group.guardian.llm")) { + val llmGuardian = JBCheckBox(SecureCoderBundle.message("settings.guardian.llm.enable")) + val useMainLlmForGuardian = JBCheckBox(SecureCoderBundle.message("settings.guardian.llm.use.main")) + row { + cell(llmGuardian).bindSelected(settings.state::enableLlmGuardian) + } + row { + cell(useMainLlmForGuardian).bindSelected(settings.state::useMainLlmForGuardian) + }.enabledIf(llmGuardian.selected) + + createLlmConfigSection( + SecureCoderBundle.message("settings.group.llmGuardian"), + settings.state.guardianLlm + ).enabledIf(llmGuardian.selected.and(useMainLlmForGuardian.selected.not())) + } + } + } + } + + + override fun apply() { + super.apply() + ApplicationManager.getApplication() + .messageBus + .syncPublisher(SecureCoderSettingsState.topic) + .settingsChanged(settings.state) + } + + private fun Panel.createLlmConfigSection( + title: String, + config: LlmConfig, + ): Row { + return group(title) { val providerBox = ComboBox(EnumComboBoxModel(LlmProvider::class.java)) row(SecureCoderBundle.message("settings.provider")) { val providerBinding: MutableProperty = MutableProperty( - { settings.state.llmProvider }, - { settings.state.llmProvider = it ?: LlmProvider.OLLAMA } + { config.provider }, + { config.provider = it ?: LlmProvider.OLLAMA } ) cell(providerBox).bindItem(providerBinding) } rowsRange { row(SecureCoderBundle.message("settings.ollama.model")) { textField() - .bindText(settings.state::ollamaModel) + .bindText(config::ollamaModel) .columns(COLUMNS_MEDIUM) } }.visibleIf(providerBox.selectedValueMatches { it == LlmProvider.OLLAMA }) rowsRange { row(SecureCoderBundle.message("settings.openrouter.api.key")) { passwordField() - .bindText(settings.state::openrouterApiKey) + .bindText(config::openrouterApiKey) .columns(COLUMNS_MEDIUM) } row(SecureCoderBundle.message("settings.openrouter.model")) { textField() - .bindText(settings.state::openrouterModel) + .bindText(config ::openrouterModel) .columns(COLUMNS_MEDIUM) } }.visibleIf(providerBox.selectedValueMatches { it == LlmProvider.OPENROUTER }) } - group(SecureCoderBundle.message("settings.group.security")) { - val enricher = JBCheckBox(SecureCoderBundle.message("settings.enricher.enabled")) - row { - cell(enricher).bindSelected(settings.state::enablePromptEnriching) - } - row(SecureCoderBundle.message("settings.enricher.url")) { - textField() - .bindText(settings.state::enricherUrl) - .columns(COLUMNS_MEDIUM) - }.enabledIf(enricher.selected).bottomGap(BottomGap.SMALL) - row { - checkBox(SecureCoderBundle.message("settings.guardian.dummy")).bindSelected(settings.state::enableDummyGuardian) - } - val codeql = JBCheckBox(SecureCoderBundle.message("settings.guardian.codeql.enable")) - row { - cell(codeql).bindSelected(settings.state::enableCodeQLGuardian) - } - row(SecureCoderBundle.message("settings.codeql.binary")) { - val codeqlPathCell = textFieldWithBrowseButton( - FileChooserDescriptorFactory.singleFile() - ) - .bindText(settings.state::codeqlBinary) - .columns(COLUMNS_MEDIUM) - val codeqlPathField = codeqlPathCell.component - button(SecureCoderBundle.message("settings.codeql.test")) { event -> - val loadingBalloon = JBPopupFactory.getInstance() - .createHtmlTextBalloonBuilder(SecureCoderBundle.message("settings.codeql.checking"), AnimatedIcon.Default.INSTANCE, null, null, null) - .createBalloon() - - loadingBalloon.show( - RelativePoint.getSouthOf(event.source as JComponent), - Balloon.Position.below - ) - ApplicationManager.getApplication().executeOnPooledThread { - val bin = settings.state.codeqlBinary.ifBlank { "codeql" } - val (message, type) = try { - SecureCoderBundle.message("settings.codeql.found", CodeQLRunner(bin).getToolVersion()) to MessageType.INFO - } catch (e: Exception) { - SecureCoderBundle.message("settings.codeql.error", (e.message ?: e.toString())) to MessageType.ERROR - } - ApplicationManager.getApplication().invokeLater( - { - loadingBalloon.hide() - val balloon = JBPopupFactory.getInstance() - .createHtmlTextBalloonBuilder(message, type, null) - .createBalloon() - balloon.show( - RelativePoint.getSouthOf(event.source as JComponent), - Balloon.Position.below - ) - }, - ModalityState.any() - ) - } - } - button(SecureCoderBundle.message("settings.codeql.download")) { event -> - val button = event.source as JButton - button.setEnabled(false) - ProgressManager.getInstance().run(object : Task.Backgroundable(null, SecureCoderBundle.message("settings.codeql.installing"), true) { - private var resultPath: Path? = null - private var exception: Exception? = null - - override fun run(indicator: ProgressIndicator) { - val installer = CodeQLInstaller() - try { - resultPath = installer.getOrInstallCodeQL(indicator) - } catch (e: IOException) { - exception = e - } - } - - override fun onSuccess() { - button.setEnabled(true) - if (exception != null) { - JBPopupFactory.getInstance() - .createHtmlTextBalloonBuilder( - SecureCoderBundle.message("settings.codeql.install.failed", exception!!.message ?: exception!!.toString()), - MessageType.ERROR, - null - ) - .createBalloon() - .show(RelativePoint.getSouthOf(button), Balloon.Position.below) - } else if (resultPath != null) { - val path = resultPath.toString() - settings.state.codeqlBinary = path - codeqlPathField.text = path - JBPopupFactory.getInstance() - .createHtmlTextBalloonBuilder(SecureCoderBundle.message("settings.codeql.downloaded"), MessageType.INFO, null) - .createBalloon() - .show(RelativePoint.getSouthOf(button), Balloon.Position.below) - } - } - - override fun onCancel() { - button.setEnabled(true) - } - }) - } - }.enabledIf(codeql.selected) - } - } - - override fun apply() { - super.apply() - ApplicationManager.getApplication() - .messageBus - .syncPublisher(SecureCoderSettingsState.topic) - .settingsChanged(settings.state) } } diff --git a/app/intellij-plugin/src/main/java/de/tuda/stg/securecoder/plugin/settings/SecureCoderSettingsState.kt b/app/intellij-plugin/src/main/java/de/tuda/stg/securecoder/plugin/settings/SecureCoderSettingsState.kt index b6c4516..2aaea1c 100644 --- a/app/intellij-plugin/src/main/java/de/tuda/stg/securecoder/plugin/settings/SecureCoderSettingsState.kt +++ b/app/intellij-plugin/src/main/java/de/tuda/stg/securecoder/plugin/settings/SecureCoderSettingsState.kt @@ -21,18 +21,27 @@ class SecureCoderSettingsState : PersistentStateComponent) : Error } sealed interface Success : MatchResult { diff --git a/engine/src/main/kotlin/de/tuda/stg/securecoder/engine/guardian/LLMModels.kt b/engine/src/main/kotlin/de/tuda/stg/securecoder/engine/guardian/LLMModels.kt new file mode 100644 index 0000000..92ad0c1 --- /dev/null +++ b/engine/src/main/kotlin/de/tuda/stg/securecoder/engine/guardian/LLMModels.kt @@ -0,0 +1,49 @@ +package de.tuda.stg.securecoder.engine.guardian + +import de.tuda.stg.securecoder.engine.llm.LLMDescription +import de.tuda.stg.securecoder.guardian.AnalyzeResponse +import de.tuda.stg.securecoder.guardian.Location +import de.tuda.stg.securecoder.guardian.RuleRef +import de.tuda.stg.securecoder.guardian.Violation +import kotlinx.serialization.Serializable + +@Serializable +@LLMDescription("Response containing security analysis results") +data class LlmAnalyzeResponse( + @LLMDescription("List of security findings discovered during analysis") + val findings: List = emptyList() +) { + fun toApi(): AnalyzeResponse = AnalyzeResponse( + violations = findings.map { it.toApi() } + ) + + @Serializable + @LLMDescription("Details of a single security finding") + data class Finding( + @LLMDescription("Line number where the issue starts, null if not applicable") + val shortName: String, + + @LLMDescription("Brief description of the security issue") + val description: String, + + @LLMDescription("The name of the file where the issue was found") + val fileName: String, + + @LLMDescription("Line number where the issue starts, null if not applicable") + val line: Int? = null, + + @LLMDescription("Indicates whether this finding make it impossible to apply the changes even with manuel approval") + val hardReject: Boolean, + + @LLMDescription("The estimated likelihood that this finding is a true positive (e.g., High, Medium, Low)") + val confidence: String? + ) { + fun toApi(): Violation = Violation( + rule = RuleRef("llm", shortName ), + message = description, + location = Location(fileName, line), + hardReject = hardReject, + confidence = confidence + ) + } +} \ No newline at end of file diff --git a/engine/src/main/kotlin/de/tuda/stg/securecoder/engine/guardian/LlmGuardian.kt b/engine/src/main/kotlin/de/tuda/stg/securecoder/engine/guardian/LlmGuardian.kt new file mode 100644 index 0000000..7e60d5b --- /dev/null +++ b/engine/src/main/kotlin/de/tuda/stg/securecoder/engine/guardian/LlmGuardian.kt @@ -0,0 +1,61 @@ +package de.tuda.stg.securecoder.engine.guardian + +import de.tuda.stg.securecoder.engine.llm.ChatMessage +import de.tuda.stg.securecoder.engine.llm.LlmClient +import de.tuda.stg.securecoder.engine.llm.chatStructured +import de.tuda.stg.securecoder.guardian.AnalyzeRequest +import de.tuda.stg.securecoder.guardian.AnalyzeResponse +import de.tuda.stg.securecoder.guardian.Guardian + +class LlmGuardian( + private val client: LlmClient, + private val systemPrompt: String = DEFAULT_SYSTEM_PROMPT, +) : Guardian { + override suspend fun run(req: AnalyzeRequest): AnalyzeResponse { + val messages = buildMessages(req) + val llmResp = client.chatStructured( + messages = messages, + params = LlmClient.GenerationParams( + temperature = 0.0 + ) + ) + return llmResp.toApi() + } + + private fun buildMessages(req: AnalyzeRequest): List { + return listOf( + ChatMessage(ChatMessage.Role.System, systemPrompt), + ChatMessage(ChatMessage.Role.User, buildString { + appendLine("You are given a set of source files to analyze for security issues.") + appendLine("Only consider the provided files; do not assume hidden context.") + appendLine() + req.files.forEach { file -> + appendLine("===== FILE: ${file.name} =====") + appendLine(withLineNumbers(file.content)) + appendLine("===== END FILE: ${file.name} =====") + appendLine() + } + appendLine("Return your analysis strictly using the structured schema provided by the tool.") + }) + ) + } + + private fun withLineNumbers(text: String): String = buildString { + text.lineSequence().forEachIndexed { idx, line -> + append(idx + 1) + append(": ") + append(line) + append('\n') + } + }.removeSuffix("\n") + + companion object { + private const val DEFAULT_SYSTEM_PROMPT: String = + """ + You are SecureCoder Guardian. Analyze code for security vulnerabilities. + Use conservative judgment; highlight clear issues or suspicious patterns. + Provide precise file and line locations when possible. If unsure, leave + optional fields null. Do not include any prose outside the structured result. + """ + } +} diff --git a/engine/src/main/kotlin/de/tuda/stg/securecoder/engine/llm/KxJsonSchemaFormat.kt b/engine/src/main/kotlin/de/tuda/stg/securecoder/engine/llm/KxJsonSchemaFormat.kt index b71f180..636c378 100644 --- a/engine/src/main/kotlin/de/tuda/stg/securecoder/engine/llm/KxJsonSchemaFormat.kt +++ b/engine/src/main/kotlin/de/tuda/stg/securecoder/engine/llm/KxJsonSchemaFormat.kt @@ -13,6 +13,7 @@ import kotlinx.serialization.json.JsonObjectBuilder import kotlinx.serialization.json.JsonPrimitive import kotlinx.serialization.json.buildJsonArray import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.putJsonArray @OptIn(ExperimentalSerializationApi::class) class KxJsonSchemaFormat { @@ -24,7 +25,7 @@ class KxJsonSchemaFormat { if (!seen.add(key)) { throw IllegalStateException("Recursive type detected: $key") } - val jsonType = when (desc.kind) { + var jsonType = when (desc.kind) { PrimitiveKind.BOOLEAN -> type("boolean") PrimitiveKind.BYTE, PrimitiveKind.SHORT, PrimitiveKind.INT, PrimitiveKind.LONG -> type("integer") PrimitiveKind.FLOAT, PrimitiveKind.DOUBLE -> type("number") @@ -65,12 +66,31 @@ class KxJsonSchemaFormat { } seen.remove(key) if (desc.isNullable) { - throw IllegalStateException("Nullable types are not supported") + jsonType = makeNullable(jsonType) } val selfDesc = getDescription(desc.annotations) return if (selfDesc != null) addDescription(jsonType, selfDesc) else jsonType } + private fun makeNullable(schema: JsonObject): JsonObject { + val type = schema["type"] + if (type is JsonPrimitive && type.isString) { + return buildJsonObject { + schema.forEach(::put) + putJsonArray("type") { + add(type) + add(JsonPrimitive("null")) + } + } + } + return buildJsonObject { + put("anyOf", buildJsonArray { + add(schema) + add(type("null")) + }) + } + } + private fun type(name: String, builderAction: JsonObjectBuilder.() -> Unit = {}): JsonObject = buildJsonObject { put("type", JsonPrimitive(name)) @@ -88,7 +108,7 @@ class KxJsonSchemaFormat { private fun addDescription(obj: JsonObject, text: String): JsonObject = buildJsonObject { - obj.forEach { (k, v) -> put(k, v) } + obj.forEach(::put) put("description", JsonPrimitive(text)) } } diff --git a/engine/src/main/kotlin/de/tuda/stg/securecoder/engine/llm/OpenRouterClient.kt b/engine/src/main/kotlin/de/tuda/stg/securecoder/engine/llm/OpenRouterClient.kt index 4149a82..e685e5d 100644 --- a/engine/src/main/kotlin/de/tuda/stg/securecoder/engine/llm/OpenRouterClient.kt +++ b/engine/src/main/kotlin/de/tuda/stg/securecoder/engine/llm/OpenRouterClient.kt @@ -21,6 +21,8 @@ import kotlinx.serialization.Serializable import kotlinx.serialization.json.Json import kotlinx.serialization.SerializationException import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.buildJsonArray import kotlinx.serialization.json.buildJsonObject import kotlinx.serialization.json.put import org.slf4j.LoggerFactory @@ -54,6 +56,7 @@ class OpenRouterClient ( val stream: Boolean = false, val metadata: JsonObject = buildJsonObject {}, @SerialName("response_format") val responseFormat: JsonObject? = null, + val provider: JsonObject? = null, ) @Serializable @@ -138,7 +141,10 @@ class OpenRouterClient ( messages = mapped, temperature = params.temperature, maxTokens = params.maxTokens, - responseFormat = responseFormat + responseFormat = responseFormat, + provider = buildJsonObject { + put("require_parameters", JsonPrimitive(true)) + } ) val obj = performRequest(req) val content = obj.choices.firstOrNull()?.message?.content @@ -146,7 +152,7 @@ class OpenRouterClient ( return try { json.decodeFromString(serializer, content) } catch (e: Exception) { - throw RuntimeException("Failed to decode OpenRouter structured content into ${'$'}{serializer.descriptor.serialName}. Content: ${'$'}content", e) + throw RuntimeException("Failed to decode OpenRouter structured content into ${serializer.descriptor.serialName}. Content: $content", e) } }