Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 96 additions & 28 deletions consumer/session/session_storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,48 @@ type Storage struct {

mu sync.RWMutex
sessionsActive map[session_node.ID]History

writeQueue chan func()
stopQueue chan struct{}
}

// NewSessionStorage creates session repository with given dependencies.
func NewSessionStorage(storage *boltdb.Bolt) *Storage {
return &Storage{
s := &Storage{
storage: storage,
timeGetter: time.Now,

sessionsActive: make(map[session_node.ID]History),
writeQueue: make(chan func(), 100),
stopQueue: make(chan struct{}),
}
go s.writeWorker()
return s
}

func (repo *Storage) writeWorker() {
for {
select {
case fn := <-repo.writeQueue:
fn()
case <-repo.stopQueue:
return
}
}
}

// Close stops the background write worker.
func (repo *Storage) Close() {
close(repo.stopQueue)
}

// waitForQueue blocks until all queued writes have been processed.
func (repo *Storage) waitForQueue() {
done := make(chan struct{})
repo.writeQueue <- func() { close(done) }
<-done
}

// Subscribe subscribes to relevant events of event bus.
func (repo *Storage) Subscribe(bus eventbus.Subscriber) error {
if err := bus.Subscribe(session_event.AppTopicSession, repo.consumeServiceSessionEvent); err != nil {
Expand All @@ -69,13 +99,13 @@ func (repo *Storage) Subscribe(bus eventbus.Subscriber) error {
if err := bus.SubscribeAsync(session_event.AppTopicTokensEarned, repo.consumeServiceSessionEarningsEvent); err != nil {
return err
}
if err := bus.Subscribe(connectionstate.AppTopicConnectionSession, repo.consumeConnectionSessionEvent); err != nil {
if err := bus.SubscribeAsync(connectionstate.AppTopicConnectionSession, repo.consumeConnectionSessionEvent); err != nil {
return err
}
if err := bus.Subscribe(connectionstate.AppTopicConnectionStatistics, repo.consumeConnectionStatisticsEvent); err != nil {
return err
}
return bus.Subscribe(pingpong_event.AppTopicInvoicePaid, repo.consumeConnectionSpendingEvent)
return bus.SubscribeAsync(pingpong_event.AppTopicInvoicePaid, repo.consumeConnectionSpendingEvent)
}

// GetAll returns array of all sessions.
Expand Down Expand Up @@ -274,65 +304,103 @@ func (repo *Storage) consumeConnectionStatisticsEvent(e connectionstate.AppEvent

func (repo *Storage) consumeConnectionSpendingEvent(e pingpong_event.AppEventInvoicePaid) {
repo.mu.Lock()
defer repo.mu.Unlock()

sessionID := session_node.ID(e.SessionID)

row, ok := repo.activeSession(sessionID)
if !ok {
repo.mu.Unlock()
return
}
row.Updated = repo.timeGetter().UTC()
row.Tokens = e.Invoice.AgreementTotal
repo.mu.Unlock()

err := repo.storage.Update(sessionStorageBucketName, &row)
if err != nil {
log.Error().Err(err).Msgf("Session %v update failed", sessionID)
return
}
fn := func() {
err := repo.storage.Update(sessionStorageBucketName, &row)
if err != nil {
log.Error().Err(err).Msgf("Session %v update failed", sessionID)
return
}

repo.sessionsActive[sessionID] = row
log.Debug().Msgf("Session %v updated", sessionID)
repo.mu.Lock()
repo.sessionsActive[sessionID] = row
repo.mu.Unlock()
log.Debug().Msgf("Session %v updated", sessionID)
}
select {
case repo.writeQueue <- fn:
case <-repo.stopQueue:
log.Warn().Msgf("Session storage closed, dropping spending update for %v", sessionID)
default:
log.Warn().Msgf("Session write queue full, dropping spending update for %v", sessionID)
}
}

func (repo *Storage) handleEndedEvent(sessionID session_node.ID) {
repo.mu.Lock()
defer repo.mu.Unlock()

row, ok := repo.sessionsActive[sessionID]
if !ok {
repo.mu.Unlock()
log.Warn().Msgf("Can't find session %v to update", sessionID)
return
}

row.Updated = repo.timeGetter().UTC()
row.Status = StatusCompleted
repo.mu.Unlock()

err := repo.storage.Update(sessionStorageBucketName, &row)
if err != nil {
log.Error().Err(err).Msgf("Session %v update failed", sessionID)
return
}
fn := func() {
err := repo.storage.Update(sessionStorageBucketName, &row)
if err != nil {
log.Error().Err(err).Msgf("Session %v update failed", sessionID)
return
}

delete(repo.sessionsActive, sessionID)
log.Debug().Msgf("Session %v updated with final data", sessionID)
repo.mu.Lock()
delete(repo.sessionsActive, sessionID)
repo.mu.Unlock()
log.Debug().Msgf("Session %v updated with final data", sessionID)
}
select {
case repo.writeQueue <- fn:
case <-repo.stopQueue:
log.Warn().Msgf("Session storage closed, dropping end-event update for %v", sessionID)
default:
log.Warn().Msgf("Session write queue full, dropping end-event update for %v", sessionID)
}
}

func (repo *Storage) handleCreatedEvent(sessionID session_node.ID) {
repo.mu.Lock()
defer repo.mu.Unlock()

row, ok := repo.sessionsActive[sessionID]
if !ok {
repo.mu.Unlock()
log.Warn().Msgf("Can't find session %v to store", sessionID)
return
}

row.Status = StatusNew
repo.mu.Unlock()

err := repo.storage.Store(sessionStorageBucketName, &row)
if err != nil {
log.Error().Err(err).Msgf("Session %v insert failed", row.SessionID)
return
}
fn := func() {
err := repo.storage.Store(sessionStorageBucketName, &row)
if err != nil {
log.Error().Err(err).Msgf("Session %v insert failed", row.SessionID)
return
}

repo.sessionsActive[sessionID] = row
log.Debug().Msgf("Session %v saved", row.SessionID)
repo.mu.Lock()
repo.sessionsActive[sessionID] = row
repo.mu.Unlock()
log.Debug().Msgf("Session %v saved", row.SessionID)
}
select {
case repo.writeQueue <- fn:
case <-repo.stopQueue:
log.Warn().Msgf("Session storage closed, dropping create-event for %v", sessionID)
default:
log.Warn().Msgf("Session write queue full, dropping create-event for %v", sessionID)
}
}
9 changes: 8 additions & 1 deletion consumer/session/session_storage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import (
"time"

"github.com/ethereum/go-ethereum/common"
"github.com/stretchr/testify/assert"

"github.com/mysteriumnetwork/node/core/connection/connectionstate"
"github.com/mysteriumnetwork/node/core/discovery/proposal"
"github.com/mysteriumnetwork/node/core/storage/boltdb"
Expand All @@ -33,7 +35,6 @@ import (
session_event "github.com/mysteriumnetwork/node/session/event"
"github.com/mysteriumnetwork/node/session/pingpong/event"
"github.com/mysteriumnetwork/payments/crypto"
"github.com/stretchr/testify/assert"
)

var (
Expand Down Expand Up @@ -265,6 +266,7 @@ func TestSessionStorage_consumeServiceSessionsEvent(t *testing.T) {
Status: session_event.CreatedStatus,
Session: serviceSessionMock,
})
storage.waitForQueue()
// then
sessions, err := storage.GetAll()
assert.Nil(t, err)
Expand Down Expand Up @@ -301,6 +303,7 @@ func TestSessionStorage_consumeServiceSessionsEvent(t *testing.T) {
Status: session_event.RemovedStatus,
Session: serviceSessionMock,
})
storage.waitForQueue()
// then
sessions, err = storage.GetAll()
assert.Nil(t, err)
Expand Down Expand Up @@ -348,6 +351,7 @@ func TestSessionStorage_consumeEventEndedOK(t *testing.T) {
Status: connectionstate.SessionEndedStatus,
SessionInfo: connectionSessionMock,
})
storage.waitForQueue()

// then
sessions, err := storage.GetAll()
Expand Down Expand Up @@ -385,6 +389,7 @@ func TestSessionStorage_consumeEventConnectedOK(t *testing.T) {
Status: connectionstate.SessionCreatedStatus,
SessionInfo: connectionSessionMock,
})
storage.waitForQueue()

// then
sessions, err := storage.GetAll()
Expand Down Expand Up @@ -428,6 +433,7 @@ func TestSessionStorage_consumeSessionSpendingEvent(t *testing.T) {
SessionID: "unknown",
Invoice: connectionInvoiceMock,
})
storage.waitForQueue()
// then
sessions, err := storage.GetAll()
assert.Nil(t, err)
Expand Down Expand Up @@ -457,6 +463,7 @@ func TestSessionStorage_consumeSessionSpendingEvent(t *testing.T) {
SessionID: "sessionID",
Invoice: connectionInvoiceMock,
})
storage.waitForQueue()
// then
sessions, err = storage.GetAll()
assert.Nil(t, err)
Expand Down
25 changes: 20 additions & 5 deletions core/connection/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ import (
)

const (
p2pDialTimeout = 60 * time.Second
p2pDialTimeout = 20 * time.Second
)

var (
Expand Down Expand Up @@ -809,11 +809,22 @@ func (m *connectionManager) Cancel() {
}

func (m *connectionManager) Disconnect() error {
if m.Status().State == connectionstate.NotConnected {
m.statusLock.Lock()
stateWas := m.status.State
if stateWas == connectionstate.NotConnected {
m.statusLock.Unlock()
return ErrNoConnection
}
if stateWas == connectionstate.Disconnecting {
m.statusLock.Unlock()
return nil
}
m.status.State = connectionstate.Disconnecting
m.statusLock.Unlock()

log.Info().Msgf("Connection state: %v -> %v", stateWas, connectionstate.Disconnecting)
m.publishStateEvent(connectionstate.Disconnecting)

m.statusDisconnecting()
m.disconnect()

return nil
Expand Down Expand Up @@ -1025,9 +1036,13 @@ func (m *connectionManager) Reconnect() {
}
log.Info().Msg("Waiting for previous session to cleanup")

// Capture the channel reference while holding the lock, then release the lock
// before waiting. This prevents a deadlock where Connect() fails and its deferred
// disconnect() tries to re-acquire cleanupFinishedLock.
m.cleanupFinishedLock.Lock()
defer m.cleanupFinishedLock.Unlock()
<-m.cleanupFinished
ch := m.cleanupFinished
m.cleanupFinishedLock.Unlock()
<-ch
err = m.Connect(m.connectOptions.ConsumerID, m.connectOptions.HermesID, m.connectOptions.ProposalLookup, m.connectOptions.Params)
if err != nil {
log.Error().Err(err).Msgf("Failed to reconnect")
Expand Down
23 changes: 16 additions & 7 deletions core/connection/multi.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,19 @@ func NewMultiConnectionManager(newConnectionManager func() Manager) *multiConnec

// Connect creates new connection from given consumer to provider, reports error if connection already exists.
func (mcm *multiConnectionManager) Connect(consumerID identity.Identity, hermesID common.Address, proposalLookup ProposalLookup, params ConnectParams) error {
mcm.mu.Lock()

mcm.mu.RLock()
m, ok := mcm.cms[params.ProxyPort]
mcm.mu.RUnlock()

if !ok {
m = mcm.newConnectionManager()
mcm.cms[params.ProxyPort] = m
mcm.mu.Lock()
m, ok = mcm.cms[params.ProxyPort]
if !ok {
m = mcm.newConnectionManager()
mcm.cms[params.ProxyPort] = m
}
mcm.mu.Unlock()
}
mcm.mu.Unlock()

return m.Connect(consumerID, hermesID, proposalLookup, params)
}
Expand Down Expand Up @@ -97,9 +102,13 @@ func (mcm *multiConnectionManager) Disconnect(id int) error {

if id < 0 {
mcm.mu.RLock()
defer mcm.mu.RUnlock()

managers := make([]Manager, 0, len(mcm.cms))
for _, m := range mcm.cms {
managers = append(managers, m)
}
mcm.mu.RUnlock()

for _, m := range managers {
if err := m.Disconnect(); err != nil {
log.Error().Err(err).Msg("Failed to disconnect active connection")
}
Expand Down
10 changes: 5 additions & 5 deletions identity/registry/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ var ErrNotFound = errors.New("no info for provided identity available in storage

// RegistrationStatusStorage allows for storing of registration statuses.
type RegistrationStatusStorage struct {
lock sync.Mutex
lock sync.RWMutex
bolt persistentStorage
}

Expand Down Expand Up @@ -105,15 +105,15 @@ func (rss *RegistrationStatusStorage) get(chainID int64, identity identity.Ident

// Get fetches the promise for the given hermes.
func (rss *RegistrationStatusStorage) Get(chainID int64, identity identity.Identity) (StoredRegistrationStatus, error) {
rss.lock.Lock()
defer rss.lock.Unlock()
rss.lock.RLock()
defer rss.lock.RUnlock()
return rss.get(chainID, identity)
}

// GetAll fetches all the registration statuses
func (rss *RegistrationStatusStorage) GetAll() ([]StoredRegistrationStatus, error) {
rss.lock.Lock()
defer rss.lock.Unlock()
rss.lock.RLock()
defer rss.lock.RUnlock()

list := []storedRegistrationStatus{}
err := rss.bolt.GetAllFrom(registrationStatusBucket, &list)
Expand Down
3 changes: 2 additions & 1 deletion p2p/stun.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ func stunPorts(identity identity.Identity, eventBus eventbus.Publisher, localPor
resp := multiServerSTUN(serverList, p, 2)

mu.Lock()
defer mu.Unlock()

natType := "unknown"

Expand All @@ -85,6 +84,8 @@ func stunPorts(identity identity.Identity, eventBus eventbus.Publisher, localPor
delete(m, p)
}

mu.Unlock()

if eventBus != nil {
eventBus.Publish(AppTopicSTUN, STUNDetectionStatus{
Identity: identity.Address,
Expand Down
Loading
Loading