diff --git a/internal/merkle/builder.go b/internal/merkle/builder.go index 9f903a9e8..e517f0fe0 100644 --- a/internal/merkle/builder.go +++ b/internal/merkle/builder.go @@ -144,7 +144,10 @@ func (tree *Tree) GetRootHash() common.Hash { return tree.RootHash } -func (tree *Tree) FindChildByHash(hash common.Hash) *InnerNode { +func (tree *Tree) FindChildByHash(hash common.Hash) *Tree { + if tree.RootHash == hash { + return tree + } if inner := tree.Subtrees; inner != nil { if !inner.Valid() { panic(fmt.Sprintf("invalid InnerNode state: %v\n", inner)) @@ -161,7 +164,7 @@ func (tree *Tree) FindChildByHash(hash common.Hash) *InnerNode { return lhs } - rhs := inner.LHS.FindChildByHash(hash) + rhs := inner.RHS.FindChildByHash(hash) if rhs != nil { return rhs } diff --git a/internal/merkle/builder_test.go b/internal/merkle/builder_test.go index 18a5d1d33..d533ba2b7 100644 --- a/internal/merkle/builder_test.go +++ b/internal/merkle/builder_test.go @@ -240,6 +240,64 @@ func TestBuildRootChildrenAgainstBuilder(t *testing.T) { assert.Equal(t, rootHashProof, crypto.Keccak256Hash(lhsProof[:], rhsProof[:])) } +func TestFindChildByHash(t *testing.T) { + t.Run("simple", func(t *testing.T) { + builder := Builder{} + leaf1 := TreeLeaf(common.HexToHash("0x1")) + leaf2 := TreeLeaf(common.HexToHash("0x2")) + builder.Append(leaf1) + builder.Append(leaf2) + tree := builder.Build() + + child1 := tree.FindChildByHash(leaf1.RootHash) + assert.NotNil(t, child1) + assert.Equal(t, leaf1.RootHash, child1.RootHash) + }) + + t.Run("repetitions", func(t *testing.T) { + builder := Builder{} + leaf1 := TreeLeaf(common.HexToHash("0x1")) + builder.AppendRepeatedUint64(leaf1, 1024) + tree := builder.Build() + + child1 := tree.FindChildByHash(leaf1.RootHash) + assert.NotNil(t, child1) + assert.Equal(t, leaf1.RootHash, child1.RootHash) + }) + + t.Run("deeper", func(t *testing.T) { + builder := Builder{} + leaf1 := TreeLeaf(common.HexToHash("0x1")) + leaf2 := TreeLeaf(common.HexToHash("0x2")) + leaf3 := TreeLeaf(common.HexToHash("0x3")) + leaf4 := TreeLeaf(common.HexToHash("0x4")) + builder.Append(leaf1) + builder.Append(leaf2) + builder.Append(leaf3) + builder.Append(leaf4) + tree := builder.Build() + + child1 := tree.FindChildByHash(leaf1.RootHash) + assert.NotNil(t, child1) + assert.Equal(t, leaf1.RootHash, child1.RootHash) + + child2 := tree.FindChildByHash(leaf2.RootHash) + assert.NotNil(t, child2) + assert.Equal(t, child2.RootHash, leaf2.RootHash) + }) + + t.Run("notfound", func(t *testing.T) { + builder := Builder{} + builder.Append(TreeLeaf(zeroDigest)) + builder.Append(TreeLeaf(oneDigest)) + tree := builder.Build() + + missing := common.HexToHash("0xdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef") + child := tree.FindChildByHash(missing) + assert.Nil(t, child) + }) +} + // repanicked //func TestBuildNotPow2(t *testing.T) { // defer recover()