diff --git a/agentcontrolSubAgent.go b/agentcontrolSubAgent.go index 33974ad..79f66df 100644 --- a/agentcontrolSubAgent.go +++ b/agentcontrolSubAgent.go @@ -19,6 +19,9 @@ type SubAgent struct { // OIDs for Read/Write actions OIDs []*PDUValueControlItem + OIDsGenerator func() ([]*PDUValueControlItem, error) + OIDsSyncOnWalkStart bool + oidLock sync.Mutex // UserErrorMarkPacket decides if shll treat user returned error as generr @@ -34,6 +37,16 @@ func (t *SubAgent) SyncConfig() error { t.oidLock.Lock() defer t.oidLock.Unlock() + if t.OIDsGenerator != nil { + newoids, oiderr := t.OIDsGenerator() + if oiderr != nil { + t.Logger.Warnf("error generating oids, will not update main table: %s", oiderr.Error()) + } else { + t.Logger.Info("updating oid table") + t.OIDs = newoids + } + } + t.Logger.Debugf("Total OIDs of %v: %v", t.CommunityIDs, len(t.OIDs)) for _, oid := range t.OIDs { @@ -379,6 +392,9 @@ func (t *SubAgent) serveGetNextRequest(i *gosnmp.SnmpPacket) (*gosnmp.SnmpPacket break } } else { + if t.OIDsSyncOnWalkStart && t.OIDsGenerator != nil { + t.SyncConfig() + } before = false } diff --git a/snmpserver_test.go b/snmpserver_test.go index 7baf6f9..566a0be 100644 --- a/snmpserver_test.go +++ b/snmpserver_test.go @@ -235,6 +235,7 @@ func (suite *ServerTests) TestErrors() { } func (suite *ServerTests) TestGetSetOids() { + oids, _ := suite.getTestGetSetOIDS() master := MasterAgent{ Logger: suite.Logger, SecurityConfig: SecurityConfig{ @@ -252,7 +253,7 @@ func (suite *ServerTests) TestGetSetOids() { SubAgents: []*SubAgent{ { CommunityIDs: []string{"public"}, - OIDs: suite.getTestGetSetOIDS(), + OIDs: oids, }, }, } @@ -430,7 +431,66 @@ func (suite *ServerTests) TestGetSetOids() { <-stopWaitChain } -func (suite *ServerTests) getTestGetSetOIDS() []*PDUValueControlItem { +func (suite *ServerTests) TestGetWalkGeneratorOids() { + master := MasterAgent{ + Logger: suite.Logger, + SecurityConfig: SecurityConfig{ + AuthoritativeEngineBoots: 1, + Users: []gosnmp.UsmSecurityParameters{ + { + UserName: "testUser", + AuthenticationProtocol: gosnmp.MD5, + PrivacyProtocol: gosnmp.DES, + AuthenticationPassphrase: "testAuth", + PrivacyPassphrase: "testPriv", + }, + }, + }, + SubAgents: []*SubAgent{ + { + CommunityIDs: []string{"public"}, + OIDsGenerator: suite.getTestGetSetOIDS, + OIDsSyncOnWalkStart: true, + }, + }, + } + shandle := NewSNMPServer(master) + shandle.ListenUDP("udp4", ":0") + var stopWaitChain = make(chan int) + go func() { + err := shandle.ServeForever() + if err != nil { + suite.Logger.Errorf("error in ServeForever: %v", err) + } else { + suite.Logger.Info("ServeForever Stoped.") + } + stopWaitChain <- 1 + + }() + + serverAddress := shandle.Address().(*net.UDPAddr) + suite.Run("SNMPGetNext", func() { + result, err := getCmdOutput("snmpgetnext", "-v2c", "-c", "public", + serverAddress.String(), "1") + if err != nil { + suite.T().Errorf("cmd meet error: %+v", err) + } + lines := bytes.Split(bytes.TrimSpace(result), []byte("\n")) + assert.NotEqual(suite.T(), []byte{}, result, "data SNMPGetNext gets: \n%v", string(result)) + assert.Equalf(suite.T(), 1, len(lines), "data SNMPGetNext gets: \n%v", string(result)) + }) + suite.Run("SNMPWalk", func() { + result, err := getCmdOutput("snmpwalk", "-v2c", "-c", "public", + serverAddress.String(), "1") + if err != nil { + suite.T().Errorf("cmd meet error: %+v", err) + } + lines := bytes.Split(bytes.TrimSpace(result), []byte("\n")) + assert.Equalf(suite.T(), len(master.SubAgents[0].OIDs)+1, len(lines), "data snmpwalk gets: \n%v", string(result)) + }) +} + +func (suite *ServerTests) getTestGetSetOIDS() ([]*PDUValueControlItem, error) { baseTestSuite := suite return []*PDUValueControlItem{ { @@ -601,7 +661,7 @@ func (suite *ServerTests) getTestGetSetOIDS() []*PDUValueControlItem { }, Document: "TestTypeIPAddress", }, - } + }, nil } func (suite *ServerTests) TearDownSuite() {