diff --git a/codegen/codegen-core/src/main/java/software/amazon/smithy/java/codegen/generators/StructureGenerator.java b/codegen/codegen-core/src/main/java/software/amazon/smithy/java/codegen/generators/StructureGenerator.java index 386aefcba..79f238578 100644 --- a/codegen/codegen-core/src/main/java/software/amazon/smithy/java/codegen/generators/StructureGenerator.java +++ b/codegen/codegen-core/src/main/java/software/amazon/smithy/java/codegen/generators/StructureGenerator.java @@ -5,6 +5,7 @@ package software.amazon.smithy.java.codegen.generators; +import java.io.Closeable; import java.math.BigDecimal; import java.math.BigInteger; import java.nio.ByteBuffer; @@ -18,6 +19,7 @@ import java.util.List; import java.util.Objects; import java.util.function.Consumer; +import software.amazon.smithy.codegen.core.Symbol; import software.amazon.smithy.codegen.core.SymbolProvider; import software.amazon.smithy.codegen.core.directed.ShapeDirective; import software.amazon.smithy.java.codegen.CodeGenerationContext; @@ -35,6 +37,7 @@ import software.amazon.smithy.java.core.serde.document.Document; import software.amazon.smithy.java.io.datastream.DataStream; import software.amazon.smithy.model.Model; +import software.amazon.smithy.model.knowledge.OperationIndex; import software.amazon.smithy.model.node.ArrayNode; import software.amazon.smithy.model.node.BooleanNode; import software.amazon.smithy.model.node.Node; @@ -80,6 +83,8 @@ public final class StructureGenerator< T extends ShapeDirective> implements Consumer { + private MemberShape streamingMember; + @Override public void accept(T directive) { if (directive.shape().hasTrait(UnitTypeTrait.class) || directive.symbol() @@ -89,11 +94,12 @@ public void accept(T directive) { return; } var shape = directive.shape(); + setStreamingMember(directive.model(), shape); directive.context().writerDelegator().useShapeWriter(shape, writer -> { writer.pushState(new ClassSection(shape)); var template = """ - public final class ${shape:T} ${^isError}implements ${serializableStruct:T}${/isError}${?isError}extends ${sdkException:T}${/isError} { + public final class ${shape:T}${classModifiers:C} { ${schemas:C|} @@ -115,17 +121,25 @@ public final class ${shape:T} ${^isError}implements ${serializableStruct:T}${/is ${getMemberValue:C|} + ${?isCloseable}${close:C|}${/isCloseable} + ${toBuilder:C|} ${builder:C|} } """; + var sdkError = CodegenUtils.tryGetServiceProperty(directive, SymbolProperties.SERVICE_EXCEPTION); + var isCloseable = streamingMember != null; + writer.putContext("classModifiers", + new ClassModifiers(writer, + shape, + sdkError, + directive.symbolProvider(), + directive.model(), + isCloseable)); writer.putContext("isError", shape.hasTrait(ErrorTrait.class)); + writer.putContext("isCloseable", isCloseable); writer.putContext("shape", directive.symbol()); - writer.putContext("serializableStruct", SerializableStruct.class); - - var sdkError = CodegenUtils.tryGetServiceProperty(directive, SymbolProperties.SERVICE_EXCEPTION); - writer.putContext("sdkException", sdkError == null ? ModeledException.class : sdkError); writer.putContext("id", new IdStringGenerator(writer, shape)); writer.putContext( @@ -150,6 +164,7 @@ public final class ${shape:T} ${^isError}implements ${serializableStruct:T}${/is "hashCode", new HashCodeGenerator(writer, shape, directive.symbolProvider(), directive.model())); writer.putContext("toString", new ToStringGenerator(writer)); + writer.putContext("close", new CloseGenerator(writer, shape, directive.symbolProvider(), streamingMember)); writer.putContext( "serializer", new StructureSerializerGenerator( @@ -175,6 +190,47 @@ public final class ${shape:T} ${^isError}implements ${serializableStruct:T}${/is }); } + private void setStreamingMember(Model model, StructureShape shape) { + if (!model.isTraitApplied(StreamingTrait.class)) { + return; + } + var operationIndex = OperationIndex.of(model); + var isInputStructure = operationIndex.isInputStructure(shape); + var isOutputStructure = operationIndex.isOutputStructure(shape); + if (!isInputStructure && !isOutputStructure) { + return; + } + for (var member : shape.members()) { + var target = model.expectShape(member.getTarget()); + if (target.hasTrait(StreamingTrait.class)) { + streamingMember = member; + } + } + } + + private record ClassModifiers( + JavaWriter writer, + Shape shape, + Symbol sdkError, + SymbolProvider symbolProvider, + Model model, + boolean isCloseable) + implements + Runnable { + @Override + public void run() { + if (shape.hasTrait(ErrorTrait.class)) { + writer.writeInline(" extends $T", sdkError == null ? ModeledException.class : sdkError); + return; + } + if (isCloseable) { + writer.writeInline(" implements $T, $T", SerializableStruct.class, Closeable.class); + } else { + writer.writeInline(" implements $T", SerializableStruct.class); + } + } + } + private record PropertyGenerator(JavaWriter writer, Shape shape, SymbolProvider symbolProvider, Model model) implements Runnable { @@ -484,6 +540,33 @@ private void writeMemberHash(JavaWriter writer, MemberShape member) { } } + private record CloseGenerator( + JavaWriter writer, + StructureShape shape, + SymbolProvider symbolProvider, + MemberShape streamingMember) implements Runnable { + @Override + public void run() { + var memberName = symbolProvider.toMemberName(streamingMember); + + writer.pushState(); + writer.putContext("memberName", memberName); + writer.write( + """ + /** + * Closes the underlying stream. + */ + @Override + public void close() { + if (${memberName:L} != null) { + ${memberName:L}.close(); + } + } + """); + writer.popState(); + } + } + private record ToBuilderGenerator( JavaWriter writer, StructureShape shape, diff --git a/fuzz-test-harness/smithy-build.json b/fuzz-test-harness/smithy-build.json index 1310ce1be..02ec88aac 100644 --- a/fuzz-test-harness/smithy-build.json +++ b/fuzz-test-harness/smithy-build.json @@ -3,7 +3,6 @@ "plugins": { "java-codegen": { "namespace": "software.smithy.fuzz.test", - "useExternalTypes": true, "modes": ["types"] } }