Skip to content
76 changes: 76 additions & 0 deletions drivers/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package drivers

import (
"context"
"fmt"
"log/slog"
"net"
"net/url"
"strings"

awsConfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/feature/rds/auth"
)

type DatabaseConfiguration struct {
Connection string `json:"connection"`
Address string `json:"addr"`
Database string `json:"database"`
Username string `json:"username"`
Secret string `json:"secret"`
MaxConcurrentSessions int `json:"max_concurrent_sessions"`
EnableRDSIAMAuth bool `json:"enable_rds_iam_auth"`
}

func (s DatabaseConfiguration) defaultPostgreSQLConnectionString() string {
if s.Connection != "" {
return s.Connection
}

return fmt.Sprintf("postgresql://%s:%s@%s/%s", s.Username, url.QueryEscape(s.Secret), s.Address, s.Database)
}

func (s DatabaseConfiguration) RDSIAMAuthConnectionString() string {
slog.Info("Loading RDS Configuration With IAM Auth")

if cfg, err := awsConfig.LoadDefaultConfig(context.TODO()); err != nil {
slog.Error("AWS Config Loading Error", slog.String("err", err.Error()))
} else {
host := s.Address

if hostCName, err := net.LookupCNAME(s.Address); err != nil {
slog.Warn("Error looking up CNAME for DB host. Using original address.", slog.String("err", err.Error()))
} else {
host = hostCName
}

endpoint := strings.TrimSuffix(host, ".") + ":5432"

slog.Info("Requesting RDS IAM Auth Token")

if authenticationToken, err := auth.BuildAuthToken(context.TODO(), endpoint, cfg.Region, s.Username, cfg.Credentials); err != nil {
slog.Error("RDS IAM Auth Token Request Error", slog.String("err", err.Error()))
} else {
slog.Info("RDS IAM Auth Token Created")
return fmt.Sprintf("postgresql://%s:%s@%s/%s", s.Username, url.QueryEscape(authenticationToken), endpoint, s.Database)
}
}

return s.defaultPostgreSQLConnectionString()
}

func (s DatabaseConfiguration) PostgreSQLConnectionString() string {
if s.EnableRDSIAMAuth {
return s.RDSIAMAuthConnectionString()
}

return s.defaultPostgreSQLConnectionString()
}

func (s DatabaseConfiguration) Neo4jConnectionString() string {
if s.Connection == "" {
return fmt.Sprintf("neo4j://%s:%s@%s/%s", s.Username, s.Secret, s.Address, s.Database)
}

return s.Connection
}
23 changes: 18 additions & 5 deletions drivers/pg/pg.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/jackc/pgx/v5/pgxpool"
"github.com/specterops/dawgs"
"github.com/specterops/dawgs/cypher/models/pgsql"
"github.com/specterops/dawgs/drivers"
"github.com/specterops/dawgs/graph"
)

Expand Down Expand Up @@ -50,15 +51,12 @@ func afterPooledConnectionRelease(conn *pgx.Conn) bool {
return true
}

func NewPool(connectionString string) (*pgxpool.Pool, error) {
if connectionString == "" {
return nil, fmt.Errorf("graph connection requires a connection url to be set")
}
func NewPool(cfg drivers.DatabaseConfiguration) (*pgxpool.Pool, error) {

poolCtx, done := context.WithTimeout(context.Background(), poolInitConnectionTimeout)
defer done()

poolCfg, err := pgxpool.ParseConfig(connectionString)
poolCfg, err := pgxpool.ParseConfig(cfg.PostgreSQLConnectionString())
if err != nil {
return nil, err
}
Expand All @@ -73,6 +71,21 @@ func NewPool(connectionString string) (*pgxpool.Pool, error) {
poolCfg.AfterConnect = afterPooledConnectionEstablished
poolCfg.AfterRelease = afterPooledConnectionRelease

if cfg.EnableRDSIAMAuth {
// Only enable the BeforeConnect handler if RDS IAM Auth is enabled
poolCfg.BeforeConnect = func(ctx context.Context, connCfg *pgx.ConnConfig) error {
slog.Debug("New Connection RDS IAM Auth")

if newPoolCfg, err := pgxpool.ParseConfig(cfg.PostgreSQLConnectionString()); err != nil {
return err
} else {
connCfg.Password = newPoolCfg.ConnConfig.Password
}

return nil
}
}

pool, err := pgxpool.NewWithConfig(poolCtx, poolCfg)
if err != nil {
return nil, err
Expand Down
14 changes: 14 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ require (
cuelang.org/go v0.15.3
github.com/RoaringBitmap/roaring/v2 v2.14.4
github.com/antlr4-go/antlr/v4 v4.13.1
github.com/aws/aws-sdk-go-v2/config v1.31.13
github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.6.10
github.com/axiomhq/hyperloglog v0.2.6
github.com/bits-and-blooms/bitset v1.24.4
github.com/cespare/xxhash/v2 v2.3.0
Expand All @@ -17,6 +19,18 @@ require (
)

require (
github.com/aws/aws-sdk-go-v2 v1.39.3 // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.18.17 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.10 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.10 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.10 // indirect
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.2 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.10 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.29.7 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.2 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.38.7 // indirect
github.com/aws/smithy-go v1.23.1 // indirect
github.com/cockroachdb/apd/v3 v3.2.1 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgryski/go-metro v0.0.0-20250106013310-edb8663e5e33 // indirect
Expand Down
28 changes: 28 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,34 @@ github.com/RoaringBitmap/roaring/v2 v2.14.4 h1:4aKySrrg9G/5oRtJ3TrZLObVqxgQ9f1zn
github.com/RoaringBitmap/roaring/v2 v2.14.4/go.mod h1:oMvV6omPWr+2ifRdeZvVJyaz+aoEUopyv5iH0u/+wbY=
github.com/antlr4-go/antlr/v4 v4.13.1 h1:SqQKkuVZ+zWkMMNkjy5FZe5mr5WURWnlpmOuzYWrPrQ=
github.com/antlr4-go/antlr/v4 v4.13.1/go.mod h1:GKmUxMtwp6ZgGwZSva4eWPC5mS6vUAmOABFgjdkM7Nw=
github.com/aws/aws-sdk-go-v2 v1.39.3 h1:h7xSsanJ4EQJXG5iuW4UqgP7qBopLpj84mpkNx3wPjM=
github.com/aws/aws-sdk-go-v2 v1.39.3/go.mod h1:yWSxrnioGUZ4WVv9TgMrNUeLV3PFESn/v+6T/Su8gnM=
github.com/aws/aws-sdk-go-v2/config v1.31.13 h1:wcqQB3B0PgRPUF5ZE/QL1JVOyB0mbPevHFoAMpemR9k=
github.com/aws/aws-sdk-go-v2/config v1.31.13/go.mod h1:ySB5D5ybwqGbT6c3GszZ+u+3KvrlYCUQNo62+hkKOFk=
github.com/aws/aws-sdk-go-v2/credentials v1.18.17 h1:skpEwzN/+H8cdrrtT8y+rvWJGiWWv0DeNAe+4VTf+Vs=
github.com/aws/aws-sdk-go-v2/credentials v1.18.17/go.mod h1:Ed+nXsaYa5uBINovJhcAWkALvXw2ZLk36opcuiSZfJM=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.10 h1:UuGVOX48oP4vgQ36oiKmW9RuSeT8jlgQgBFQD+HUiHY=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.10/go.mod h1:vM/Ini41PzvudT4YkQyE/+WiQJiQ6jzeDyU8pQKwCac=
github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.6.10 h1:xfgjONWMae6+y//dlhVukwt9N+I++FPuiwcQt7DI7Qg=
github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.6.10/go.mod h1:FO6aarJTHA2N3S8F2A4wKfnX9Jr6MPerJFaqoLgTctU=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.10 h1:mj/bdWleWEh81DtpdHKkw41IrS+r3uw1J/VQtbwYYp8=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.10/go.mod h1:7+oEMxAZWP8gZCyjcm9VicI0M61Sx4DJtcGfKYv2yKQ=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.10 h1:wh+/mn57yhUrFtLIxyFPh2RgxgQz/u+Yrf7hiHGHqKY=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.10/go.mod h1:7zirD+ryp5gitJJ2m1BBux56ai8RIRDykXZrJSp540w=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 h1:WKuaxf++XKWlHWu9ECbMlha8WOEGm0OUEZqm4K/Gcfk=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4/go.mod h1:ZWy7j6v1vWGmPReu0iSGvRiise4YI5SkR3OHKTZ6Wuc=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.2 h1:xtuxji5CS0JknaXoACOunXOYOQzgfTvGAc9s2QdCJA4=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.2/go.mod h1:zxwi0DIR0rcRcgdbl7E2MSOvxDyyXGBlScvBkARFaLQ=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.10 h1:DRND0dkCKtJzCj4Xl4OpVbXZgfttY5q712H9Zj7qc/0=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.10/go.mod h1:tGGNmJKOTernmR2+VJ0fCzQRurcPZj9ut60Zu5Fi6us=
github.com/aws/aws-sdk-go-v2/service/sso v1.29.7 h1:fspVFg6qMx0svs40YgRmE7LZXh9VRZvTT35PfdQR6FM=
github.com/aws/aws-sdk-go-v2/service/sso v1.29.7/go.mod h1:BQTKL3uMECaLaUV3Zc2L4Qybv8C6BIXjuu1dOPyxTQs=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.2 h1:scVnW+NLXasGOhy7HhkdT9AGb6kjgW7fJ5xYkUaqHs0=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.2/go.mod h1:FRNCY3zTEWZXBKm2h5UBUPvCVDOecTad9KhynDyGBc0=
github.com/aws/aws-sdk-go-v2/service/sts v1.38.7 h1:VEO5dqFkMsl8QZ2yHsFDJAIZLAkEbaYDB+xdKi0Feic=
github.com/aws/aws-sdk-go-v2/service/sts v1.38.7/go.mod h1:L1xxV3zAdB+qVrVW/pBIrIAnHFWHo6FBbFe4xOGsG/o=
github.com/aws/smithy-go v1.23.1 h1:sLvcH6dfAFwGkHLZ7dGiYF7aK6mg4CgKA/iDKjLDt9M=
github.com/aws/smithy-go v1.23.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0=
github.com/axiomhq/hyperloglog v0.2.6 h1:sRhvvF3RIXWQgAXaTphLp4yJiX4S0IN3MWTaAgZoRJw=
github.com/axiomhq/hyperloglog v0.2.6/go.mod h1:YjX/dQqCR/7QYX0g8mu8UZAjpIenz1FKM71UEsjFoTo=
github.com/bits-and-blooms/bitset v1.24.4 h1:95H15Og1clikBrKr/DuzMXkQzECs1M6hhoGXLwLQOZE=
Expand Down
Loading