Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions cypher/models/pgsql/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package pgsql
import (
"bytes"
"encoding/json"
"strconv"
"strings"

"reflect"

Expand Down Expand Up @@ -55,6 +57,21 @@ func PropertiesToJSONB(properties *graph.Properties) (pgtype.JSONB, error) {
return MapStringAnyToJSONB(properties.MapOrEmpty())
}

func DeletedPropertiesToString(properties *graph.Properties) string {
if properties == nil {
return "{}"
}

deleted := properties.DeletedProperties()
quoted := make([]string, 0, len(deleted))

for _, prop := range deleted {
quoted = append(quoted, strconv.Quote(prop))
}

return "{" + strings.Join(quoted, ",") + "}"
}
Comment thread
bsheth711 marked this conversation as resolved.

func JSONBToProperties(jsonb pgtype.JSONB) (*graph.Properties, error) {
propertiesMap := make(map[string]any)

Expand Down
80 changes: 80 additions & 0 deletions cypher/models/pgsql/type_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package pgsql_test

import (
"testing"

"github.com/specterops/dawgs/cypher/models/pgsql"
"github.com/specterops/dawgs/graph"
"github.com/stretchr/testify/assert"
)

func TestDeletedPropertiesToString(t *testing.T) {
t.Parallel()

tests := []struct {
name string
setup func() *graph.Properties
check func(t *testing.T, result string)
}{
{
name: "no deleted properties returns empty braces",
setup: func() *graph.Properties {
return graph.NewProperties()
},
check: func(t *testing.T, result string) {
assert.Equal(t, "{}", result)
},
},
{
name: "single deleted property is wrapped in braces",
setup: func() *graph.Properties {
props := graph.NewProperties()
props.Set("mykey", "myvalue")
props.Delete("mykey")
return props
},
check: func(t *testing.T, result string) {
assert.Equal(t, "{\"mykey\"}", result)
},
},
{
name: "multiple deleted properties are all present in output",
setup: func() *graph.Properties {
props := graph.NewProperties()
props.Set("alpha", 1)
props.Set("beta", 2)
props.Delete("alpha")
props.Delete("beta")
return props
},
check: func(t *testing.T, result string) {
// Map iteration order is non-deterministic; accept either ordering.
assert.True(t, result == "{\"alpha\",\"beta\"}" || result == "{\"beta\",\"alpha\"}",
"unexpected result: %s", result)
},
},
{
name: "non-deleted properties are not included",
setup: func() *graph.Properties {
props := graph.NewProperties()
props.Set("active", "yes")
props.Set("removed", "no")
props.Delete("removed")
return props
},
check: func(t *testing.T, result string) {
assert.Equal(t, "{\"removed\"}", result)
assert.NotContains(t, result, "active")
},
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
props := tc.setup()
result := pgsql.DeletedPropertiesToString(props)
tc.check(t, result)
})
}
}
36 changes: 31 additions & 5 deletions drivers/neo4j/batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ type batchTransaction struct {
innerTx *neo4jTransaction
nodeDeletionBuffer []graph.ID
relationshipDeletionBuffer []graph.ID
nodeUpdateBuffer []*graph.Node
nodeUpdateByBuffer []graph.NodeUpdate
relationshipCreateBuffer []createRelationshipByIDs
relationshipUpdateByBuffer []graph.RelationshipUpdate
Expand Down Expand Up @@ -48,8 +49,20 @@ func (s *batchTransaction) Relationships() graph.RelationshipQuery {
}

func (s *batchTransaction) UpdateNodeBy(update graph.NodeUpdate) error {
if s.nodeUpdateByBuffer = append(s.nodeUpdateByBuffer, update); len(s.nodeUpdateByBuffer) >= s.batchWriteSize {
return s.flushNodeUpdates()
s.nodeUpdateByBuffer = append(s.nodeUpdateByBuffer, update)

if len(s.nodeUpdateByBuffer) >= s.batchWriteSize {
return s.flushNodeUpdateByBuffer()
}

return nil
}

func (s *batchTransaction) UpdateNodes(nodes []*graph.Node) error {
s.nodeUpdateBuffer = append(s.nodeUpdateBuffer, nodes...)

if len(s.nodeUpdateBuffer) > s.batchWriteSize {
return s.flushNodeUpdateBuffer()
}

return nil
Comment thread
zinic marked this conversation as resolved.
Expand All @@ -73,7 +86,13 @@ func (s *batchTransaction) DeleteRelationships(ids []graph.ID) error {

func (s *batchTransaction) Commit() error {
if len(s.nodeUpdateByBuffer) > 0 {
if err := s.flushNodeUpdates(); err != nil {
if err := s.flushNodeUpdateByBuffer(); err != nil {
return err
}
}

if len(s.nodeUpdateBuffer) > 0 {
if err := s.flushNodeUpdateBuffer(); err != nil {
return err
}
}
Expand Down Expand Up @@ -110,7 +129,7 @@ func (s *batchTransaction) Close() error {
}

func (s *batchTransaction) UpdateNode(target *graph.Node) error {
return s.innerTx.UpdateNode(target)
return s.UpdateNodes([]*graph.Node{target})
}

func (s *batchTransaction) CreateRelationshipByIDs(startNodeID, endNodeID graph.ID, kind graph.Kind, properties *graph.Properties) error {
Expand Down Expand Up @@ -221,13 +240,20 @@ func (s *batchTransaction) flushRelationshipDeletions() error {
return s.DeleteRelationships(buffer)
}

func (s *batchTransaction) flushNodeUpdates() error {
func (s *batchTransaction) flushNodeUpdateByBuffer() error {
buffer := s.nodeUpdateByBuffer
s.nodeUpdateByBuffer = s.nodeUpdateByBuffer[:0]

return s.innerTx.updateNodesBy(buffer...)
}

func (s *batchTransaction) flushNodeUpdateBuffer() error {
buffer := s.nodeUpdateBuffer
s.nodeUpdateBuffer = s.nodeUpdateBuffer[:0]

return s.innerTx.updateNodeBatch(buffer)
}

func (s *batchTransaction) flushRelationshipUpdates() error {
buffer := s.relationshipUpdateByBuffer
s.relationshipUpdateByBuffer = s.relationshipUpdateByBuffer[:0]
Expand Down
130 changes: 130 additions & 0 deletions drivers/neo4j/batch_integration_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
//go:build integration

package neo4j_test

import (
"context"
"log/slog"
"os"
"strconv"
"testing"
"time"

"github.com/specterops/dawgs"
"github.com/specterops/dawgs/drivers/neo4j"
"github.com/specterops/dawgs/graph"
"github.com/specterops/dawgs/ops"
"github.com/specterops/dawgs/query"
"github.com/stretchr/testify/assert"
)

const Neo4jConnectionStringEnv = "NEO4J_CONNECTION"

var (
NodeKind1 = graph.StringKind("NodeKind1")
NodeKind2 = graph.StringKind("NodeKind2")

NameProperty = "name"
HeatProperty = "heat"
NewProperty = "new"
)

func prepareNode(index int) *graph.Node {
return graph.PrepareNode(
graph.AsProperties(map[string]any{
NameProperty: "Node " + strconv.Itoa(index),
HeatProperty: 10 + index,
}),
NodeKind1,
)
}

func TestBatchTransaction_NodeUpdate(t *testing.T) {
const (
numNodes = 1_000
)

var (
neo4jConnectionStr = os.Getenv(Neo4jConnectionStringEnv)
ctx, done = context.WithCancel(context.Background())
)

defer done()

if neo4jConnectionStr == "" {
t.Fatalf("No Neo4j connection string specified. Test requires a valid Neo4j connection string present in the %s environment variable.", Neo4jConnectionStringEnv)
}

graphDB, err := dawgs.Open(ctx, neo4j.DriverName, dawgs.Config{
ConnectionString: neo4jConnectionStr,
GraphQueryMemoryLimit: 0,
})
assert.NoError(t, err)

// Regsiter a cleanup step to wipe the database
Comment thread
zinic marked this conversation as resolved.
t.Cleanup(func() {
cleanupCtx, done := context.WithTimeout(context.Background(), time.Minute)
defer done()

err := graphDB.WriteTransaction(cleanupCtx, func(tx graph.Transaction) error {
return tx.Nodes().Filter(query.Kind(query.Node(), NodeKind1)).Delete()
})

if err != nil {
slog.Error("Failed to cleanup after test.", slog.String("err", err.Error()))
}
})

// Insert nodes to batch update afterward
assert.NoError(t,
graphDB.BatchOperation(ctx, func(batch graph.Batch) error {
for idx := range numNodes {
if err := batch.CreateNode(prepareNode(idx)); err != nil {
return err
}
}

return nil
}),
)

// Update the nodes in batch
assert.NoError(t,
graphDB.BatchOperation(ctx, func(batch graph.Batch) error {
// Fetch all of the nodes and ensure that they the number of nodes expected matches
if nodes, err := ops.FetchNodes(batch.Nodes().Filter(query.Kind(query.Node(), NodeKind1))); err != nil {
return err
} else {
assert.Equal(t, numNodes, len(nodes))

for _, node := range nodes {
node.AddKinds(NodeKind2)

node.Properties.Set(NewProperty, true)
node.Properties.Delete(HeatProperty)
}

return batch.UpdateNodes(nodes)
}
}),
)

// Assert the node update
assert.NoError(t,
graphDB.ReadTransaction(ctx, func(tx graph.Transaction) error {
// Fetch all of the nodes and ensure that they have been updated
if nodes, err := ops.FetchNodes(tx.Nodes().Filter(query.Kind(query.Node(), NodeKind1))); err != nil {
return err
} else {
assert.Equal(t, numNodes, len(nodes))

for _, node := range nodes {
assert.True(t, node.Properties.Exists(HeatProperty))
assert.Equal(t, true, node.Properties.Get(NewProperty).Any())
}
Comment thread
zinic marked this conversation as resolved.

return nil
}
}),
)
}
Loading
Loading