Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions engine/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,7 @@ dependencies {
implementation(libs.ktor.client.java)
implementation(libs.ktor.client.content.negotiation)
implementation(libs.ktor.serialization.json)
testImplementation(kotlin("test"))
testImplementation(libs.kotlinx.serialization.json)
testImplementation(libs.kotlinx.coroutines.core)
}
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
package de.tuda.stg.securecoder.engine.llm
package de.tuda.stg.securecoder.engine.file

import de.tuda.stg.securecoder.filesystem.FileSystem


object FilesInContextPromptBuilder {
suspend fun build(files: Iterable<FileSystem.File>, edit: Boolean = false) = buildString {
if (files.count() == 0) {
appendLine("You have no files in the context.")
appendLine("If you saw files they are only part of the prompt and dont exists yet!")
if (edit) {
appendLine("You may create new files (keep in mind that searchedText needs to be empty in this case!)")
appendLine("You may create new files (keep in mind that searched text needs to be empty in this case!)")
}
return@buildString
}
Expand All @@ -20,4 +19,4 @@ object FilesInContextPromptBuilder {
appendLine("<<<END FILE>>>")
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@ import de.tuda.stg.securecoder.engine.llm.ChatMessage.Role
import de.tuda.stg.securecoder.engine.llm.LlmClient
import de.tuda.stg.securecoder.filesystem.FileSystem
import de.tuda.stg.securecoder.engine.llm.ChatExchange
import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.toList
import kotlin.collections.plusAssign

class EditFilesLlmWrapper(
Expand Down Expand Up @@ -76,7 +73,7 @@ class EditFilesLlmWrapper(
appendLine("It violated the required format.")
appendLine("Errors:")
messages.forEach { appendLine(it) }
appendLine("Respond again with ONLY <EDITN> blocks that strictly follow the rules. Do NOT include prose, markdown, or explanations.")
appendLine("Respond again with ONLY edit blocks that strictly follow the rules. Do NOT include prose, markdown, or explanations.")
appendLine("IMPORTANT: Resend the COMPLETE set of edits you intend to apply from your previous message")
}
}
Expand Down Expand Up @@ -155,12 +152,7 @@ class EditFilesLlmWrapper(
return ParseResult.Err(allErrors)
}

val seen = HashSet<Pair<String, String>>()
val deduped = results.filter { sr ->
seen.add(sr.fileName to sr.searchedText.text)
}

return ParseResult.Ok(Changes(deduped))
return ParseResult.Ok(Changes(results))
}

private fun getTextByXMLTag(container: String, tag: String): String? {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
package de.tuda.stg.securecoder.engine.file.edit

import de.tuda.stg.securecoder.engine.file.edit.Changes.SearchedText
import de.tuda.stg.securecoder.engine.llm.ChatMessage
import de.tuda.stg.securecoder.engine.llm.ChatMessage.Role
import de.tuda.stg.securecoder.engine.llm.LlmClient
import de.tuda.stg.securecoder.engine.llm.LLMDescription
import de.tuda.stg.securecoder.engine.llm.chatStructured
import de.tuda.stg.securecoder.filesystem.FileSystem
import de.tuda.stg.securecoder.engine.llm.ChatExchange
import kotlinx.serialization.Serializable
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json
import kotlin.collections.plusAssign

class StructuredEditFilesLlmWrapper(
private val llmClient: LlmClient
) {
//TODO path => **uri** ; EditFilesLlmWrapper should be separate from the filesystem implementation
private val prompt = """
Your task it is to produce code. The agent will just parse the code you produce. So dont do a extensive review in your final answer!

It's acceptable to add multiple *search/REPLACE* sections if you need to change multiple parts of the file.
To create a file: search must be empty and replace must contain the entire file content
Each *search* pattern must match the existing source code exactly once, line for line, character for character, including all comments, docstrings, etc.
Do not use a part of the line as *search* pattern. You must use full lines.
Include enough lines to make code inside *search* pattern uniquely identifiable. A *search* pattern that produces multiple matches in the source code will be rejected as an error.
Do not add backslashes to escape special characters. Write the code exactly as it should appear in the intended programming language.
Do not use git diff style (+ and - at the beginning of the line) for *search/REPLACE* blocks.
Do not use line numbers in *search/REPLACE* blocks. Do not enclose the *search/REPLACE* block or any of its components in triple quotes. Use only tags to separate the parameters.
Do not use the same value for *search* and *REPLACE* parameters, as this will make no changes.

If you need to edit a file again after making changes, use the latest version of the code that includes all your modifications applied during **current session**.
""".trimIndent()


suspend fun chat(
messages: List<ChatMessage>,
fileSystem: FileSystem,
params: LlmClient.GenerationParams = LlmClient.GenerationParams(),
onParseError: suspend (parseErrors: List<String>, llm: ChatExchange) -> Unit = { _, _ -> },
attempts: Int = 3
): ChatResult {
val messages = messages.toMutableList()
appendPromptToLastSystem(messages)
repeat(attempts) {
val llmInput = messages.toList()
val structured = llmClient.chatStructured<StructuredEdits>(llmInput, params)
messages += ChatMessage(Role.Assistant, Json.encodeToString(structured))
when (val result = validateAndConvert(structured, fileSystem)) {
is ParseResult.Ok -> return ChatResult(messages, result.value)
is ParseResult.Err -> {
messages += ChatMessage(Role.User, result.buildMessage())
onParseError(result.messages, ChatExchange(llmInput, messages.last().content))
}
}
}
return ChatResult(messages, null)
}

data class ChatResult(val messages: List<ChatMessage>, val changes: Changes?) {
fun changesMessage() = messages.last { it.role == Role.Assistant }
}

sealed interface ParseResult {
data class Ok(val value: Changes) : ParseResult
data class Err(val messages: List<String>) : ParseResult {
fun buildMessage() = buildString {
appendLine("Your previous output could not be applied.")
appendLine("It violated the required format.")
appendLine("Errors:")
messages.forEach { appendLine(it) }
appendLine("Respond again with ONLY edit blocks that strictly follow the rules. Do NOT include prose, markdown, or explanations.")
appendLine("IMPORTANT: Resend the COMPLETE set of edits you intend to apply from your previous message")
}
}
}

private suspend fun validateAndConvert(structured: StructuredEdits, fileSystem: FileSystem): ParseResult {
val results = mutableListOf<Changes.SearchReplace>()
val allErrors = mutableListOf<String>()
if (structured.edits.isEmpty()) {
allErrors += "No edits provided. Provide at least one edit block."
return ParseResult.Err(allErrors)
}
for (e in structured.edits) {
val file = e.filePath.trim()
val searchPart = e.search
val replacePart = e.replace
if (file.isEmpty()) {
allErrors += "`filePath` should not be empty"
continue
}
if (searchPart == replacePart) {
allErrors += "`search` and `replace` parameters are the same"
continue
}
val replace = Changes.SearchReplace(file, SearchedText(searchPart), replacePart)
val content = fileSystem.getFile(file)?.content()
val match = ApplyChanges.match(content, replace.searchedText)
if (match is Matcher.MatchResult.Error) {
allErrors += ApplyChanges.buildErrorMessage(file, searchPart, match)
continue
}
results += replace
}
if (results.isEmpty()) return ParseResult.Err(allErrors)
return ParseResult.Ok(Changes(results))
}

private fun appendPromptToLastSystem(messages: MutableList<ChatMessage>) {
val lastSystemIndex = messages.indexOfLast { it.role == Role.System }
if (lastSystemIndex >= 0) {
val existing = messages[lastSystemIndex]
val combined = "${existing.content}\n\n$prompt\n\nRespond ONLY with a JSON object that matches the provided schema. Do not include explanations."
messages[lastSystemIndex] = ChatMessage(Role.System, combined)
} else {
messages += ChatMessage(Role.System, "$prompt\n\nRespond ONLY with a JSON object that matches the provided schema. Do not include explanations.")
}
}

@Serializable
data class StructuredEdits(
@LLMDescription("List of edit operations to apply")
val edits: List<EditOperation>
)

@Serializable
data class EditOperation(
@LLMDescription("The full **uri** of the file that will be modified")
val filePath: String,
@LLMDescription("A continuous, yet concise block of lines to search for in the existing source code (*search* pattern). If this section is empty, the lines from `replace` will be added to the end of the file.")
val search: String,
@LLMDescription("The lines to replace the existing code found using `search`. If this section is empty, the lines specified in `search` will be removed.")
val replace: String,
)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package de.tuda.stg.securecoder.engine.llm

import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.KSerializer
import kotlinx.serialization.descriptors.PolymorphicKind
import kotlinx.serialization.descriptors.PrimitiveKind
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.descriptors.SerialKind
import kotlinx.serialization.descriptors.StructureKind
import kotlinx.serialization.json.JsonArray
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.JsonObjectBuilder
import kotlinx.serialization.json.JsonPrimitive
import kotlinx.serialization.json.buildJsonArray
import kotlinx.serialization.json.buildJsonObject

@OptIn(ExperimentalSerializationApi::class)
class KxJsonSchemaFormat {
fun <T> format(serializer: KSerializer<T>): JsonObject =
schemaForDescriptor(serializer.descriptor, seen = HashSet())

private fun schemaForDescriptor(desc: SerialDescriptor, seen: MutableSet<String>): JsonObject {
val key = desc.serialName
if (!seen.add(key)) {
throw IllegalStateException("Recursive type detected: $key")
}
val jsonType = when (desc.kind) {
PrimitiveKind.BOOLEAN -> type("boolean")
PrimitiveKind.BYTE, PrimitiveKind.SHORT, PrimitiveKind.INT, PrimitiveKind.LONG -> type("integer")
PrimitiveKind.FLOAT, PrimitiveKind.DOUBLE -> type("number")
PrimitiveKind.CHAR, PrimitiveKind.STRING -> type("string")
SerialKind.ENUM -> type("string") {
put("enum", buildJsonArray {
for (i in 0 until desc.elementsCount) {
add(JsonPrimitive(desc.getElementName(i)))
}
})
}
StructureKind.LIST -> type("array") {
put("items", schemaForDescriptor(desc.getElementDescriptor(0), seen))
}
StructureKind.MAP -> type("object") {
val keyDesc = desc.getElementDescriptor(0)
if (keyDesc.kind != PrimitiveKind.STRING) {
throw IllegalStateException("Map keys must be strings, but was ${keyDesc.serialName}")
}
put("additionalProperties", schemaForDescriptor(desc.getElementDescriptor(1), seen))
}
StructureKind.CLASS, StructureKind.OBJECT -> type("object") {
put("properties", buildJsonObject {
for (i in 0 until desc.elementsCount) {
val name = desc.getElementName(i)
val childDesc = desc.getElementDescriptor(i)
val childSchema = schemaForDescriptor(childDesc, seen)
val propDesc = getDescription(desc.getElementAnnotations(i))
put(name, if (propDesc != null) addDescription(childSchema, propDesc) else childSchema)
}
})
val required = JsonArray(desc.requiredElements().map { name -> JsonPrimitive(name) })
if (required.isNotEmpty()) put("required", required)
put("additionalProperties", JsonPrimitive(false))
}
PolymorphicKind.SEALED, PolymorphicKind.OPEN, SerialKind.CONTEXTUAL
-> throw IllegalStateException("Polymorphic types are not supported")
}
seen.remove(key)
if (desc.isNullable) {
throw IllegalStateException("Nullable types are not supported")
}
val selfDesc = getDescription(desc.annotations)
return if (selfDesc != null) addDescription(jsonType, selfDesc) else jsonType
}

private fun type(name: String, builderAction: JsonObjectBuilder.() -> Unit = {}): JsonObject =
buildJsonObject {
put("type", JsonPrimitive(name))
builderAction()
}

private fun SerialDescriptor.requiredElements(): List<String> = (0 until elementsCount)
.filter { !isElementOptional(it) }
.map { getElementName(it) }

private fun JsonArray.isNotEmpty(): Boolean = this.size > 0

private fun getDescription(annotations: List<Annotation>): String? =
annotations.filterIsInstance<LLMDescription>().firstOrNull()?.text

private fun addDescription(obj: JsonObject, text: String): JsonObject =
buildJsonObject {
obj.forEach { (k, v) -> put(k, v) }
put("description", JsonPrimitive(text))
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package de.tuda.stg.securecoder.engine.llm

import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.SerialInfo

@OptIn(ExperimentalSerializationApi::class)
@SerialInfo
@Target(AnnotationTarget.CLASS, AnnotationTarget.PROPERTY)
@Retention(AnnotationRetention.RUNTIME)
annotation class LLMDescription(val text: String)
Original file line number Diff line number Diff line change
@@ -1,13 +1,27 @@
package de.tuda.stg.securecoder.engine.llm

import kotlinx.serialization.KSerializer
import kotlinx.serialization.serializer

interface LlmClient : AutoCloseable {
suspend fun chat(
messages: List<ChatMessage>,
params: GenerationParams = GenerationParams(),
): String

suspend fun <T> chatStructured(
messages: List<ChatMessage>,
serializer: KSerializer<T>,
params: GenerationParams = GenerationParams(),
): T

data class GenerationParams(
val temperature: Double? = null,
val maxTokens: Int? = null
)
}

suspend inline fun <reified T> LlmClient.chatStructured(
messages: List<ChatMessage>,
params: LlmClient.GenerationParams = LlmClient.GenerationParams(),
): T = this.chatStructured(messages, serializer(), params)
Loading