Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions config/sample_webconfig.conf
Original file line number Diff line number Diff line change
@@ -1,6 +1,23 @@
webconfig {
security {
// If encryption_mechanism is aws_kms, AWS KMS will be used to encrypt data in database, else
// encryption_key_env_name will be used.
encryption_mechanism = ""
encryption_key_env_name = "WEBCONFIG_KEY"
kms {
aws_region = ""
endpoint = ""
secret_key = ""
encryption_algorithm = ""

role_based_access_enabled = false

// If role_based_access_enabled is true, access_key_id, secret_access_key and session_token will be fetched using IAM temporary credentials
access_key_id = ""
secret_access_key = ""
// Token is only required for temporary security credentials retrieved via STS, otherwise an empty string can be passed for this parameter.
session_token = ""
}
}

panic_exit_enabled = false
Expand Down Expand Up @@ -136,6 +153,14 @@ webconfig {
public_key_file = /tmp/sat-themis-201701.pub
}

kid {
// Public key will be fetched using the url if provided.
url = ""

// If url is not provided, public_key_file will be used.
public_key_file = "/tmp/sat-themis-201701.pub"
}

sat-prod-k1-1024 {
public_key_file = /tmp/sat-prod-k1-1024.pub
}
Expand All @@ -156,6 +181,7 @@ webconfig {
unittest_db_file = "/tmp/test_webconfig.sqlite"
concurrent_queries = 5
}

cassandra {
encrypted_password = ""
hosts = [
Expand All @@ -171,6 +197,17 @@ webconfig {
user = "dbuser"
test_keyspace = "test_webconfig"
is_ssl_enabled = true
port = 9042

//Config to create database client to AWS Keyspace using IAM temporary credentials
aws_keyspace_enabled = false
role_based_access_enabled = false
aws_region = ""
aws_keyspace_ca_path = "path_to_file/sf-class2-root.crt"

//If role_based_access_enabled is true, access_key_id and secret_access_key will be fetched using IAM temporary credentials
access_key_id = ""
secret_access_key = ""
}

yugabyte {
Expand Down
145 changes: 145 additions & 0 deletions db/cassandra/aws_keyspace.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
package cassandra

import (
"fmt"
"os"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sigv4-auth-cassandra-gocql-driver-plugin/sigv4"
"github.com/go-akka/configuration"
"github.com/gocql/gocql"
"github.com/rdkcentral/webconfig/common"
"github.com/rdkcentral/webconfig/security"
)

func awsKeyspaceClient(conf *configuration.Config, testOnly bool) (*CassandraClient, error) {
var codec *security.AesCodec
var err error

// build codec
if testOnly {
codec = security.NewTestCodec(conf)
} else {
codec, err = security.NewAesCodec(conf)
if err != nil {
return nil, common.NewError(err)
}
}

dbconf := conf.GetConfig("webconfig.database.cassandra")

// init
hosts := dbconf.GetStringList("hosts")
cluster := gocql.NewCluster(hosts...)

cluster.Consistency = gocql.LocalQuorum
cluster.ProtoVersion = ProtocolVersion
cluster.DisableInitialHostLookup = DisableInitialHostLookup
cluster.Timeout = time.Duration(dbconf.GetInt32("timeout_in_sec", 1)) * time.Second
cluster.ConnectTimeout = time.Duration(dbconf.GetInt32("connect_timeout_in_sec", 1)) * time.Second
cluster.NumConns = int(dbconf.GetInt32("connections", DefaultConnections))
cluster.Port = int(dbconf.GetInt64("port", DefaultPort))

cluster.RetryPolicy = &gocql.DowngradingConsistencyRetryPolicy{
ConsistencyLevelsToTry: []gocql.Consistency{
gocql.LocalQuorum,
gocql.LocalOne,
gocql.One,
},
}

localDc := dbconf.GetString("local_dc")
if len(localDc) > 0 {
cluster.PoolConfig.HostSelectionPolicy = gocql.DCAwareRoundRobinPolicy(localDc)
}

awsRegion, err := getAwsRegionForCassandra(dbconf)
if err != nil {
return nil, err
}

var auth sigv4.AwsAuthenticator = sigv4.NewAwsAuthenticator()
auth.Region = awsRegion

isRoleBasedAccessEnabled := dbconf.GetBoolean("role_based_access_enabled")
if isRoleBasedAccessEnabled {
sess, err := session.NewSession(&aws.Config{
Region: aws.String(awsRegion)},
)
if err != nil {
return nil, err
}

// Set up the callback to refresh credentials
auth.CredentialsCallback = func() (sigv4.SigV4Credentials, error) {
creds, err := sess.Config.Credentials.Get()
if err != nil {
return sigv4.SigV4Credentials{}, err
}

return sigv4.SigV4Credentials{
AccessKeyId: creds.AccessKeyID,
SecretAccessKey: creds.SecretAccessKey,
SessionToken: creds.SessionToken,
}, nil
}
} else {
auth.AccessKeyId = dbconf.GetString("access_key_id")
auth.SecretAccessKey = dbconf.GetString("secret_access_key")
}
cluster.Authenticator = auth

awsKeySpaceCaPath := dbconf.GetString("aws_keyspace_ca_path")
cluster.SslOpts = &gocql.SslOptions{
CaPath: awsKeySpaceCaPath,
EnableHostVerification: false,
}

// check and create test_keyspace
if testOnly {
cluster.Keyspace = dbconf.GetString("test_keyspace", DefaultTestKeyspace)
} else {
cluster.Keyspace = dbconf.GetString("keyspace", DefaultKeyspace)
}

// now point to the real keyspace
session, err := cluster.CreateSession()
if err != nil {
return nil, common.NewError(err)
}
session.SetPageSize(int(dbconf.GetInt32("page_size", DefaultPageSize)))

blockedSubdocIds := conf.GetStringList("webconfig.blocked_subdoc_ids")
encryptedSubdocIds := conf.GetStringList("webconfig.encrypted_subdoc_ids")
stateCorrectionEnabled := conf.GetBoolean("webconfig.state_correction_enabled")
lockRootDocumentEnabled := conf.GetBoolean("webconfig.lock_root_document_enabled")

return &CassandraClient{
Session: session,
ClusterConfig: cluster,
AesCodec: codec,
concurrentQueries: make(chan bool, dbconf.GetInt32("concurrent_queries", 500)),
localDc: localDc,
blockedSubdocIds: blockedSubdocIds,
encryptedSubdocIds: encryptedSubdocIds,
stateCorrectionEnabled: stateCorrectionEnabled,
lockRootDocumentEnabled: lockRootDocumentEnabled,
awsKeyspaceEnabled: true,
}, nil
}

func getAwsRegionForCassandra(dbconf *configuration.Config) (string, error) {
awsRegion := dbconf.GetString("aws_region")

if len(awsRegion) == 0 {
awsRegion = os.Getenv("AWS_REGION")
}

if len(awsRegion) == 0 {
return "", fmt.Errorf("%s", "Aws region is not provided")
}

return awsRegion, nil
}
16 changes: 16 additions & 0 deletions db/cassandra/cassandra_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ const (
DefaultSleepTimeInMillisecond = 10
DefaultConnections = 2
DefaultPageSize = 50
DefaultPort = 9042
)

// if 'wifi_schema_v2_enabled'=true, v1.3 is also supported
Expand All @@ -54,6 +55,7 @@ type CassandraClient struct {
encryptedSubdocIds []string
stateCorrectionEnabled bool
lockRootDocumentEnabled bool
awsKeyspaceEnabled bool
}

/*
Expand All @@ -68,6 +70,19 @@ current column types:
*/

func NewCassandraClient(conf *configuration.Config, testOnly bool) (*CassandraClient, error) {
if isAwsKeyspaceEnabled(conf) {
return awsKeyspaceClient(conf, testOnly)
} else {
return cassandraClient(conf, testOnly)
}
}

func isAwsKeyspaceEnabled(conf *configuration.Config) bool {
return (conf.GetString("webconfig.database.active_driver") == "cassandra") &&
conf.GetBoolean("webconfig.database.cassandra.aws_keyspace_enabled")
}

func cassandraClient(conf *configuration.Config, testOnly bool) (*CassandraClient, error) {
var codec *security.AesCodec
var err error

Expand Down Expand Up @@ -101,6 +116,7 @@ func NewCassandraClient(conf *configuration.Config, testOnly bool) (*CassandraCl
cluster.Timeout = time.Duration(dbconf.GetInt32("timeout_in_sec", 1)) * time.Second
cluster.ConnectTimeout = time.Duration(dbconf.GetInt32("connect_timeout_in_sec", 1)) * time.Second
cluster.NumConns = int(dbconf.GetInt32("connections", DefaultConnections))
cluster.Port = int(dbconf.GetInt64("port", DefaultPort))

cluster.RetryPolicy = &gocql.DowngradingConsistencyRetryPolicy{
ConsistencyLevelsToTry: []gocql.Consistency{
Expand Down
53 changes: 38 additions & 15 deletions db/cassandra/document.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,24 @@
* limitations under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
*/
package cassandra

import (
"fmt"
"time"

"github.com/gocql/gocql"
"github.com/prometheus/client_golang/prometheus"
"github.com/rdkcentral/webconfig/common"
"github.com/rdkcentral/webconfig/db"
"github.com/rdkcentral/webconfig/util"
"github.com/gocql/gocql"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
)

func (c *CassandraClient) GetSubDocument(cpeMac string, groupId string) (*common.SubDocument, error) {
var err error
var payload []byte
var payload, kmsRemoteDataKey []byte
var version, errorDetails string
var state, errorCode int
var updatedTime, expiry time.Time
Expand All @@ -40,8 +40,8 @@ func (c *CassandraClient) GetSubDocument(cpeMac string, groupId string) (*common
c.concurrentQueries <- true
defer func() { <-c.concurrentQueries }()

stmt := "SELECT payload,version,state,updated_time,error_code,error_details,expiry FROM xpc_group_config WHERE cpe_mac=? AND group_id=?"
if err := c.Query(stmt, cpeMac, groupId).Scan(&payload, &version, &state, &updatedTime, &errorCode, &errorDetails, &expiry); err != nil {
stmt := "SELECT payload,version,state,updated_time,error_code,error_details,expiry,kms_remote_data_key FROM xpc_group_config WHERE cpe_mac=? AND group_id=?"
if err := c.Query(stmt, cpeMac, groupId).Scan(&payload, &version, &state, &updatedTime, &errorCode, &errorDetails, &expiry, &kmsRemoteDataKey); err != nil {
return nil, common.NewError(err)
}

Expand All @@ -50,7 +50,7 @@ func (c *CassandraClient) GetSubDocument(cpeMac string, groupId string) (*common
}

if c.IsEncryptedGroup(groupId) {
payload, err = c.DecryptBytes(payload)
payload, err = c.DecryptBytes(payload, kmsRemoteDataKey)
if err != nil {
return nil, common.NewError(err)
}
Expand Down Expand Up @@ -113,12 +113,17 @@ func (c *CassandraClient) SetSubDocument(cpeMac string, groupId string, subdoc *
columns = append(columns, "payload")
// TODO evel if it is necessary use a list of groupIds that need encryption
if c.IsEncryptedGroup(groupId) {
encbytes, err := c.EncryptBytes(subdoc.Payload())
encbytes, kmsRemoteDataKey, err := c.EncryptBytes(subdoc.Payload())
if err != nil {
return common.NewError(err)
}
values = append(values, encbytes)
columnMap["payload_len"] = len(encbytes)

if kmsRemoteDataKey != nil {
columns = append(columns, "kms_remote_data_key")
values = append(values, kmsRemoteDataKey)
}
} else {
values = append(values, subdoc.Payload())
columnMap["payload_len"] = len(subdoc.Payload())
Expand Down Expand Up @@ -199,9 +204,24 @@ func (c *CassandraClient) DeleteDocument(cpeMac string) error {
c.concurrentQueries <- true
defer func() { <-c.concurrentQueries }()

stmt := "DELETE FROM xpc_group_config WHERE cpe_mac=?"
if err := c.Query(stmt, cpeMac).Exec(); err != nil {
return common.NewError(err)
if c.awsKeyspaceEnabled {
stmt := "SELECT group_id FROM xpc_group_config WHERE cpe_mac=? ALLOW FILTERING"
iter := c.Query(stmt, cpeMac).Iter()
for {
var groupId string
if !iter.Scan(&groupId) {
break
}
stmt := "DELETE FROM xpc_group_config WHERE cpe_mac=? AND group_id=?"
if err := c.Query(stmt, cpeMac, groupId).Exec(); err != nil {
return common.NewError(err)
}
}
} else {
stmt := "DELETE FROM xpc_group_config WHERE cpe_mac=?"
if err := c.Query(stmt, cpeMac).Exec(); err != nil {
return common.NewError(err)
}
}

return nil
Expand All @@ -226,7 +246,10 @@ func (c *CassandraClient) GetDocument(cpeMac string, xargs ...interface{}) (fndo
c.concurrentQueries <- true
defer func() { <-c.concurrentQueries }()

stmt := "SELECT group_id,payload,version,state,updated_time,error_code,error_details,expiry FROM xpc_group_config WHERE cpe_mac=?"
stmt := "SELECT group_id,payload,version,state,updated_time,error_code,error_details,expiry,kms_remote_data_key FROM xpc_group_config WHERE cpe_mac=?"
if c.awsKeyspaceEnabled {
stmt += " ALLOW FILTERING"
}
iter := c.Query(stmt, cpeMac).Iter()
rmap := make(util.Dict)
defer func() {
Expand All @@ -250,13 +273,13 @@ func (c *CassandraClient) GetDocument(cpeMac string, xargs ...interface{}) (fndo
now := time.Now()
for {
var err error
var payload []byte
var payload, kmsRemoteDataKey []byte
var groupId, version, errorDetails string
var state, errorCode int
var updatedTime, expiry time.Time
var updatedTimeTsPtr *int

if !iter.Scan(&groupId, &payload, &version, &state, &updatedTime, &errorCode, &errorDetails, &expiry) {
if !iter.Scan(&groupId, &payload, &version, &state, &updatedTime, &errorCode, &errorDetails, &expiry, &kmsRemoteDataKey) {
break
}

Expand All @@ -280,7 +303,7 @@ func (c *CassandraClient) GetDocument(cpeMac string, xargs ...interface{}) (fndo
}

if c.IsEncryptedGroup(groupId) {
payload, err = c.DecryptBytes(payload)
payload, err = c.DecryptBytes(payload, kmsRemoteDataKey)
if err != nil {
tfields := common.FilterLogFields(fields)
tfields["logger"] = "subdoc"
Expand Down
Loading