diff --git a/parquet/pqarrow/file_writer_test.go b/parquet/pqarrow/file_writer_test.go index d8503ccb5..32713cd2f 100644 --- a/parquet/pqarrow/file_writer_test.go +++ b/parquet/pqarrow/file_writer_test.go @@ -172,7 +172,7 @@ func TestFileWriterTotalBytes(t *testing.T) { // Verify total bytes & compressed bytes are correct assert.Equal(t, int64(408), writer.TotalCompressedBytes()) - assert.Equal(t, int64(912), writer.TotalBytesWritten()) + assert.Equal(t, int64(910), writer.TotalBytesWritten()) } func TestFileWriterTotalBytesBuffered(t *testing.T) { @@ -206,5 +206,5 @@ func TestFileWriterTotalBytesBuffered(t *testing.T) { // Verify total bytes & compressed bytes are correct assert.Equal(t, int64(596), writer.TotalCompressedBytes()) - assert.Equal(t, int64(1308), writer.TotalBytesWritten()) + assert.Equal(t, int64(1306), writer.TotalBytesWritten()) } diff --git a/parquet/schema/schema.go b/parquet/schema/schema.go index 6d124eb17..3ff376890 100644 --- a/parquet/schema/schema.go +++ b/parquet/schema/schema.go @@ -272,6 +272,7 @@ func (t *toThriftVisitor) VisitPost(Node) {} func ToThrift(schema *GroupNode) []*format.SchemaElement { t := &toThriftVisitor{make([]*format.SchemaElement, 0)} schema.Visit(t) + t.elements[0].RepetitionType = nil return t.elements } diff --git a/parquet/schema/schema_flatten_test.go b/parquet/schema/schema_flatten_test.go index ecbb431c2..a39391610 100644 --- a/parquet/schema/schema_flatten_test.go +++ b/parquet/schema/schema_flatten_test.go @@ -92,8 +92,10 @@ func (s *SchemaFlattenSuite) TestDecimalMetadata() { func (s *SchemaFlattenSuite) TestNestedExample() { elements := make([]*format.SchemaElement, 0) + root := NewGroup(s.name, format.FieldRepetitionType_REPEATED, 2 /* numChildren */, 0 /* fieldID */) + root.RepetitionType = nil elements = append(elements, - NewGroup(s.name, format.FieldRepetitionType_REPEATED, 2 /* numChildren */, 0 /* fieldID */), + root, NewPrimitive("a" /* name */, format.FieldRepetitionType_REQUIRED, format.Type_INT32, 1 /* fieldID */), NewGroup("bag" /* name */, format.FieldRepetitionType_OPTIONAL, 1 /* numChildren */, 2 /* fieldID */)) @@ -120,6 +122,23 @@ func TestSchemaFlatten(t *testing.T) { suite.Run(t, new(SchemaFlattenSuite)) } +func TestToThriftRootRepetitionStripped(t *testing.T) { + for _, rep := range []parquet.Repetition{ + parquet.Repetitions.Repeated, + parquet.Repetitions.Required, + parquet.Repetitions.Optional, + } { + group := MustGroup(NewGroupNode("schema", rep, FieldList{ + NewInt32Node("a", parquet.Repetitions.Required, -1), + }, -1)) + elements := ToThrift(group) + assert.False(t, elements[0].IsSetRepetitionType(), + "root element should not have repetition_type set (was %v)", rep) + assert.True(t, elements[1].IsSetRepetitionType(), + "non-root element must have repetition_type set") + } +} + func TestInvalidConvertedTypeInDeserialize(t *testing.T) { n := MustPrimitive(NewPrimitiveNodeLogical("string" /* name */, parquet.Repetitions.Required, StringLogicalType{}, parquet.Types.ByteArray, -1 /* type len */, -1 /* fieldID */))