diff --git a/Cargo.lock b/Cargo.lock index 373a807cc20..ce5252d0d2d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9850,6 +9850,7 @@ dependencies = [ "tracing-subscriber", "url", "vortex", + "vortex-parquet-variant", ] [[package]] diff --git a/java/vortex-jni/build.gradle.kts b/java/vortex-jni/build.gradle.kts index 09892fa6f44..c8f0192babc 100644 --- a/java/vortex-jni/build.gradle.kts +++ b/java/vortex-jni/build.gradle.kts @@ -48,7 +48,9 @@ mavenPublishing { coordinates(groupId = "dev.vortex", artifactId = "vortex-jni", version = "${rootProject.version}") publishToMavenCentral() - signAllPublications() + if (!project.hasProperty("skip.signing")) { + signAllPublications() + } pom { name = "vortex-jni" diff --git a/java/vortex-jni/src/test/java/dev/vortex/jni/JNIWriterTest.java b/java/vortex-jni/src/test/java/dev/vortex/jni/JNIWriterTest.java index 58cdbbf5315..03869ae8598 100644 --- a/java/vortex-jni/src/test/java/dev/vortex/jni/JNIWriterTest.java +++ b/java/vortex-jni/src/test/java/dev/vortex/jni/JNIWriterTest.java @@ -4,6 +4,7 @@ package dev.vortex.jni; import static java.nio.charset.StandardCharsets.UTF_8; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -26,17 +27,25 @@ import org.apache.arrow.c.Data; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VarBinaryVector; import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.StructVector; import org.apache.arrow.vector.ipc.ArrowReader; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; public final class JNIWriterTest { + private static final String ARROW_EXTENSION_NAME = "ARROW:extension:name"; + private static final String PARQUET_VARIANT_EXTENSION_NAME = "arrow.parquet.variant"; + private static final byte[] VARIANT_METADATA = new byte[] {0x01, 0x00}; + private static final byte[] VARIANT_INT8_42 = new byte[] {0x0c, 0x2a}; + private static final byte[] VARIANT_TRUE = new byte[] {0x04}; @TempDir Path tempDir; @@ -52,6 +61,45 @@ private static Schema personSchema() { Field.notNullable("age", new ArrowType.Int(32, true)))); } + private static Schema parquetVariantSchema() { + Field variant = new Field( + "variant", + new FieldType( + true, + ArrowType.Struct.INSTANCE, + null, + Map.of(ARROW_EXTENSION_NAME, PARQUET_VARIANT_EXTENSION_NAME)), + List.of( + Field.notNullable("metadata", new ArrowType.Binary()), + Field.nullable("value", new ArrowType.Binary()))); + return new Schema(List.of(variant)); + } + + private static void populateParquetVariantRoot(VectorSchemaRoot root) { + StructVector variant = (StructVector) root.getVector("variant"); + VarBinaryVector metadata = variant.getChild("metadata", VarBinaryVector.class); + VarBinaryVector value = variant.getChild("value", VarBinaryVector.class); + + variant.allocateNew(); + metadata.allocateNew(3); + value.allocateNew(3); + + metadata.setSafe(0, VARIANT_METADATA); + metadata.setSafe(1, VARIANT_METADATA); + metadata.setSafe(2, VARIANT_METADATA); + value.setSafe(0, VARIANT_INT8_42); + value.setSafe(1, VARIANT_TRUE); + value.setNull(2); + variant.setIndexDefined(0); + variant.setIndexDefined(1); + variant.setNull(2); + + metadata.setValueCount(3); + value.setValueCount(3); + variant.setValueCount(3); + root.setRowCount(3); + } + @Test public void testCreateWriter() throws IOException { Path outputPath = tempDir.resolve("test_create.vortex"); @@ -155,4 +203,53 @@ public void testWriteBatch() throws IOException { } } } + + @Test + public void testParquetVariantRoundTrip() throws IOException { + Path outputPath = tempDir.resolve("test_parquet_variant.vortex"); + String writePath = outputPath.toAbsolutePath().toUri().toString(); + + BufferAllocator allocator = ArrowAllocation.rootAllocator(); + Schema schema = parquetVariantSchema(); + + Session session = Session.create(); + try (VortexWriter writer = VortexWriter.create(session, writePath, schema, new HashMap<>(), allocator); + VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + populateParquetVariantRoot(root); + + try (ArrowArray arrowArray = ArrowArray.allocateNew(allocator); + ArrowSchema arrowSchemaFfi = ArrowSchema.allocateNew(allocator)) { + Data.exportVectorSchemaRoot(allocator, root, null, arrowArray, arrowSchemaFfi); + writer.writeBatch(arrowArray.memoryAddress(), arrowSchemaFfi.memoryAddress()); + } + } + + assertTrue(Files.exists(outputPath), "output file should exist"); + + DataSource ds = DataSource.open(session, writePath); + Field dataSourceField = ds.arrowSchema(allocator).findField("variant"); + assertEquals( + PARQUET_VARIANT_EXTENSION_NAME, dataSourceField.getMetadata().get(ARROW_EXTENSION_NAME)); + + Scan scan = ds.scan(ScanOptions.of()); + Field scanField = scan.arrowSchema(allocator).findField("variant"); + assertEquals(PARQUET_VARIANT_EXTENSION_NAME, scanField.getMetadata().get(ARROW_EXTENSION_NAME)); + + while (scan.hasNext()) { + Partition p = scan.next(); + try (ArrowReader reader = p.scanArrow(allocator)) { + assertTrue(reader.loadNextBatch()); + VectorSchemaRoot resultRoot = reader.getVectorSchemaRoot(); + StructVector variant = (StructVector) resultRoot.getVector("variant"); + VarBinaryVector metadata = variant.getChild("metadata", VarBinaryVector.class); + VarBinaryVector value = variant.getChild("value", VarBinaryVector.class); + + assertArrayEquals(VARIANT_METADATA, metadata.get(0)); + assertArrayEquals(VARIANT_INT8_42, value.get(0)); + assertArrayEquals(VARIANT_METADATA, metadata.get(1)); + assertArrayEquals(VARIANT_TRUE, value.get(1)); + assertTrue(variant.isNull(2)); + } + } + } } diff --git a/vortex-jni/Cargo.toml b/vortex-jni/Cargo.toml index 628c2c89d43..62b89e61fc8 100644 --- a/vortex-jni/Cargo.toml +++ b/vortex-jni/Cargo.toml @@ -29,6 +29,7 @@ tracing = { workspace = true, features = ["std", "log"] } tracing-subscriber = { workspace = true, features = ["env-filter"] } url = { workspace = true } vortex = { workspace = true, features = ["object_store", "files"] } +vortex-parquet-variant = { workspace = true } [dev-dependencies] jni = { workspace = true, features = ["invocation"] } diff --git a/vortex-jni/src/dtype.rs b/vortex-jni/src/dtype.rs index f748135ade1..5a0b54d2c14 100644 --- a/vortex-jni/src/dtype.rs +++ b/vortex-jni/src/dtype.rs @@ -10,8 +10,8 @@ use arrow_array::ffi::FFI_ArrowSchema; use arrow_schema::DataType; use arrow_schema::FieldRef; use arrow_schema::Fields; +use arrow_schema::Schema; use vortex::dtype::DType; -use vortex::dtype::arrow::FromArrowType; use vortex::error::VortexResult; /// Export a Vortex [`DType`] to the Arrow C Data Interface struct at `schema_addr`. Views @@ -24,7 +24,7 @@ pub(crate) fn export_dtype_to_arrow(dtype: &DType, schema_addr: i64) -> VortexRe DataType::Struct(fields) => fields, _ => unreachable!("Vortex DType always exports as a struct"), }; - let schema = arrow_schema::Schema::new(fields); + let schema = Schema::new(fields); let ffi_schema = FFI_ArrowSchema::try_from(&schema)?; unsafe { ptr::write(schema_addr as *mut FFI_ArrowSchema, ffi_schema); @@ -70,9 +70,8 @@ pub(crate) fn strip_views(data_type: DataType) -> DataType { } } -/// Decode an [`FFI_ArrowSchema`] pointed to by `schema_addr` into a Vortex [`DType`]. -pub(crate) fn import_dtype_from_arrow(schema_addr: i64) -> VortexResult { +/// Decode an [`FFI_ArrowSchema`] pointed to by `schema_addr` into an Arrow [`Schema`]. +pub(crate) fn import_arrow_schema(schema_addr: i64) -> VortexResult { let ffi_schema = unsafe { &*(schema_addr as *const FFI_ArrowSchema) }; - let arrow_schema = arrow_schema::Schema::try_from(ffi_schema)?; - Ok(DType::from_arrow(&arrow_schema)) + Ok(Schema::try_from(ffi_schema)?) } diff --git a/vortex-jni/src/session.rs b/vortex-jni/src/session.rs index 9f75d24f564..9adaf544431 100644 --- a/vortex-jni/src/session.rs +++ b/vortex-jni/src/session.rs @@ -16,7 +16,9 @@ use crate::RUNTIME; /// Constructs a fresh [`VortexSession`] bound to the JNI-shared tokio runtime and returns /// an opaque pointer that Java must pass to [`Java_dev_vortex_jni_NativeSession_free`]. pub(crate) fn new_session() -> Box { - Box::new(VortexSession::default().with_handle(RUNTIME.handle())) + let session = VortexSession::default().with_handle(RUNTIME.handle()); + vortex_parquet_variant::initialize(&session); + Box::new(session) } /// SAFETY: caller must pass a pointer previously returned by [`new_session`]. diff --git a/vortex-jni/src/writer.rs b/vortex-jni/src/writer.rs index da30952d87b..a7d1e6432f2 100644 --- a/vortex-jni/src/writer.rs +++ b/vortex-jni/src/writer.rs @@ -13,6 +13,7 @@ use arrow_array::RecordBatch; use arrow_array::StructArray; use arrow_array::ffi::FFI_ArrowArray; use arrow_array::ffi::FFI_ArrowSchema; +use arrow_schema::SchemaRef; use async_fs::File; use futures::SinkExt; use futures::channel::mpsc; @@ -28,12 +29,16 @@ use object_store::ObjectStore; use object_store::path::Path as ObjectStorePath; use url::Url; use vortex::array::ArrayRef; -use vortex::array::arrow::FromArrowArray; +use vortex::array::VTable; +use vortex::array::arrow::ArrowSessionExt; use vortex::array::stream::ArrayStreamAdapter; use vortex::dtype::DType; +use vortex::dtype::Field as DTypeField; +use vortex::dtype::FieldPath; use vortex::error::VortexResult; use vortex::error::vortex_err; use vortex::file::WriteOptionsSessionExt; +use vortex::file::WriteStrategyBuilder; use vortex::file::WriteSummary; use vortex::io::VortexWrite; use vortex::io::compat::Compat; @@ -41,10 +46,14 @@ use vortex::io::object_store::ObjectStoreWrite; use vortex::io::runtime::BlockingRuntime; use vortex::io::runtime::Task; use vortex::io::session::RuntimeSessionExt; +use vortex::layout::LayoutStrategy; +use vortex::layout::layouts::flat::writer::FlatLayoutStrategy; +use vortex::session::VortexSession; use vortex::utils::aliases::hash_map::HashMap; +use vortex_parquet_variant::ParquetVariant; use crate::RUNTIME; -use crate::dtype::import_dtype_from_arrow; +use crate::dtype::import_arrow_schema; use crate::errors::JNIError; use crate::errors::try_or_throw; use crate::file::extract_properties; @@ -81,21 +90,71 @@ fn resolve_store( } } +fn write_options_for_schema( + session: &VortexSession, + write_schema: &DType, +) -> vortex::file::VortexWriteOptions { + let variant_paths = variant_field_paths(write_schema); + if variant_paths.is_empty() { + return session.write_options(); + } + + let mut allowed = vortex::file::ALLOWED_ENCODINGS.clone(); + allowed.insert(ParquetVariant.id()); + let flat: Arc = + Arc::new(FlatLayoutStrategy::default().with_allow_encodings(allowed)); + + let mut strategy = WriteStrategyBuilder::default(); + for path in variant_paths { + strategy = strategy.with_field_writer(path, Arc::clone(&flat)); + } + + session.write_options().with_strategy(strategy.build()) +} + +fn variant_field_paths(dtype: &DType) -> Vec { + let mut paths = Vec::new(); + collect_variant_field_paths(dtype, FieldPath::root(), &mut paths); + paths +} + +fn collect_variant_field_paths(dtype: &DType, path: FieldPath, paths: &mut Vec) { + match dtype { + DType::Variant(_) => paths.push(path), + DType::Struct(fields, _) => { + for (name, field_dtype) in fields.names().iter().zip(fields.fields()) { + collect_variant_field_paths( + &field_dtype, + path.clone().push(DTypeField::from(name.clone())), + paths, + ); + } + } + _ => {} + } +} + /// Native writer holding a write-task handle and a sender that Java pushes batches into. pub struct NativeWriter { handle: Option>>, + session: VortexSession, + arrow_schema: SchemaRef, write_schema: DType, sender: mpsc::Sender>, } impl NativeWriter { pub fn new( + session: VortexSession, + arrow_schema: SchemaRef, write_schema: DType, handle: Task>, sender: mpsc::Sender>, ) -> Self { Self { handle: Some(handle), + session, + arrow_schema, write_schema, sender, } @@ -117,7 +176,10 @@ impl NativeWriter { } fn write_record_batch(&self, batch: RecordBatch) -> VortexResult<()> { - let vortex_batch = ArrayRef::from_arrow(batch, false)?; + let vortex_batch = self + .session + .arrow() + .from_arrow_record_batch(batch, self.arrow_schema.as_ref())?; if !vortex_batch.dtype().eq(&self.write_schema) { return Err(vortex_err!( "write schema mismatch: expected {}, got {}", @@ -162,13 +224,15 @@ pub extern "system" fn Java_dev_vortex_jni_NativeWriter_create( } let session = unsafe { session_ref(session_ptr) }; - let write_schema = import_dtype_from_arrow(arrow_schema_addr)?; + let arrow_schema = Arc::new(import_arrow_schema(arrow_schema_addr)?); + let write_schema = session.arrow().from_arrow_schema(arrow_schema.as_ref())?; let file_path: String = uri.try_to_string(env)?; let properties: HashMap = extract_properties(env, &options)?; let resolved = resolve_store(&file_path, &properties)?; let (tx, rx) = mpsc::channel(WRITE_CHANNEL_CAPACITY); let stream = ArrayStreamAdapter::new(write_schema.clone(), rx); + let write_options = write_options_for_schema(session, &write_schema); let handle = session.handle().spawn(async move { match resolved { @@ -177,21 +241,28 @@ pub extern "system" fn Java_dev_vortex_jni_NativeWriter_create( async_fs::create_dir_all(parent).await?; } let mut file = File::create(path).await?; - let summary = session.write_options().write(&mut file, stream).await?; + let summary = write_options.write(&mut file, stream).await?; file.shutdown().await?; Ok(summary) } ResolvedStore::ObjectStore(store, path) => { let mut write = ObjectStoreWrite::new(Arc::new(Compat::new(store)), &path).await?; - let summary = session.write_options().write(&mut write, stream).await?; + let summary = write_options.write(&mut write, stream).await?; write.shutdown().await?; Ok(summary) } } }); - Ok(Box::new(NativeWriter::new(write_schema, handle, tx)).into_raw()) + Ok(Box::new(NativeWriter::new( + session.clone(), + arrow_schema, + write_schema, + handle, + tx, + )) + .into_raw()) }) }