diff --git a/internal/integration_tests/applications_test.go b/internal/integration_tests/applications_test.go new file mode 100644 index 00000000..f36b74ad --- /dev/null +++ b/internal/integration_tests/applications_test.go @@ -0,0 +1,183 @@ +package integration_tests + +import ( + "context" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/authorizerdev/authorizer/internal/graph/model" + "github.com/authorizerdev/authorizer/internal/storage/schemas" +) + +// TestApplications tests the M2M Application CRUD operations at the storage layer +func TestApplications(t *testing.T) { + cfg := getTestConfig() + ts := initTestSetup(t, cfg) + + // Use background context for storage calls + storageCtx := context.Background() + + t.Run("should create application", func(t *testing.T) { + application := &schemas.Application{ + Name: "test-m2m-app-" + uuid.New().String(), + Description: "Test M2M application", + ClientID: uuid.New().String(), + ClientSecret: "test-secret", + Scopes: "read write", + Roles: "admin", + IsActive: true, + CreatedBy: uuid.New().String(), + } + + err := ts.StorageProvider.CreateApplication(storageCtx, application) + require.NoError(t, err) + assert.NotEmpty(t, application.ID) + assert.NotZero(t, application.CreatedAt) + assert.NotZero(t, application.UpdatedAt) + }) + + t.Run("should get application by ID", func(t *testing.T) { + application := &schemas.Application{ + Name: "test-m2m-app-by-id-" + uuid.New().String(), + Description: "Test M2M application get by ID", + ClientID: uuid.New().String(), + ClientSecret: "test-secret", + Scopes: "read", + Roles: "user", + IsActive: true, + CreatedBy: uuid.New().String(), + } + + err := ts.StorageProvider.CreateApplication(storageCtx, application) + require.NoError(t, err) + require.NotEmpty(t, application.ID) + + retrieved, err := ts.StorageProvider.GetApplicationByID(storageCtx, application.ID) + require.NoError(t, err) + assert.NotNil(t, retrieved) + assert.Equal(t, application.Name, retrieved.Name) + assert.Equal(t, application.ClientID, retrieved.ClientID) + assert.Equal(t, application.Description, retrieved.Description) + assert.Equal(t, application.IsActive, retrieved.IsActive) + }) + + t.Run("should fail to get application with non-existent ID", func(t *testing.T) { + retrieved, err := ts.StorageProvider.GetApplicationByID(storageCtx, uuid.New().String()) + assert.Error(t, err) + assert.Nil(t, retrieved) + }) + + t.Run("should get application by client ID", func(t *testing.T) { + clientID := uuid.New().String() + application := &schemas.Application{ + Name: "test-m2m-app-by-clientid-" + uuid.New().String(), + Description: "Test M2M application get by client ID", + ClientID: clientID, + ClientSecret: "test-secret", + Scopes: "read write", + Roles: "user", + IsActive: true, + CreatedBy: uuid.New().String(), + } + + err := ts.StorageProvider.CreateApplication(storageCtx, application) + require.NoError(t, err) + + retrieved, err := ts.StorageProvider.GetApplicationByClientID(storageCtx, clientID) + require.NoError(t, err) + assert.NotNil(t, retrieved) + assert.Equal(t, clientID, retrieved.ClientID) + assert.Equal(t, application.Name, retrieved.Name) + }) + + t.Run("should fail to get application with non-existent client ID", func(t *testing.T) { + retrieved, err := ts.StorageProvider.GetApplicationByClientID(storageCtx, uuid.New().String()) + assert.Error(t, err) + assert.Nil(t, retrieved) + }) + + t.Run("should list applications with pagination", func(t *testing.T) { + // Create two applications to ensure list returns results + for i := 0; i < 2; i++ { + app := &schemas.Application{ + Name: "test-m2m-app-list-" + uuid.New().String(), + Description: "Test M2M application for list", + ClientID: uuid.New().String(), + ClientSecret: "test-secret", + Scopes: "read", + Roles: "user", + IsActive: true, + CreatedBy: uuid.New().String(), + } + err := ts.StorageProvider.CreateApplication(storageCtx, app) + require.NoError(t, err) + } + + pagination := &model.Pagination{ + Limit: 10, + Offset: 0, + } + applications, paginationResult, err := ts.StorageProvider.ListApplications(storageCtx, pagination) + require.NoError(t, err) + assert.NotNil(t, paginationResult) + assert.GreaterOrEqual(t, len(applications), 2) + assert.GreaterOrEqual(t, paginationResult.Total, int64(2)) + }) + + t.Run("should update application", func(t *testing.T) { + application := &schemas.Application{ + Name: "test-m2m-app-update-" + uuid.New().String(), + Description: "Test M2M application before update", + ClientID: uuid.New().String(), + ClientSecret: "test-secret", + Scopes: "read", + Roles: "user", + IsActive: true, + CreatedBy: uuid.New().String(), + } + + err := ts.StorageProvider.CreateApplication(storageCtx, application) + require.NoError(t, err) + require.NotEmpty(t, application.ID) + + application.Description = "Test M2M application after update" + application.Scopes = "read write" + application.IsActive = false + + err = ts.StorageProvider.UpdateApplication(storageCtx, application) + require.NoError(t, err) + + retrieved, err := ts.StorageProvider.GetApplicationByID(storageCtx, application.ID) + require.NoError(t, err) + assert.Equal(t, "Test M2M application after update", retrieved.Description) + assert.Equal(t, "read write", retrieved.Scopes) + assert.False(t, retrieved.IsActive) + }) + + t.Run("should delete application", func(t *testing.T) { + application := &schemas.Application{ + Name: "test-m2m-app-delete-" + uuid.New().String(), + Description: "Test M2M application for deletion", + ClientID: uuid.New().String(), + ClientSecret: "test-secret", + Scopes: "read", + Roles: "user", + IsActive: true, + CreatedBy: uuid.New().String(), + } + + err := ts.StorageProvider.CreateApplication(storageCtx, application) + require.NoError(t, err) + require.NotEmpty(t, application.ID) + + err = ts.StorageProvider.DeleteApplication(storageCtx, application.ID) + require.NoError(t, err) + + retrieved, err := ts.StorageProvider.GetApplicationByID(storageCtx, application.ID) + assert.Error(t, err) + assert.Nil(t, retrieved) + }) +} diff --git a/internal/storage/db/arangodb/application.go b/internal/storage/db/arangodb/application.go new file mode 100644 index 00000000..76be87a8 --- /dev/null +++ b/internal/storage/db/arangodb/application.go @@ -0,0 +1,139 @@ +package arangodb + +import ( + "context" + "fmt" + "time" + + arangoDriver "github.com/arangodb/go-driver" + "github.com/google/uuid" + + "github.com/authorizerdev/authorizer/internal/graph/model" + "github.com/authorizerdev/authorizer/internal/storage/schemas" +) + +// CreateApplication creates a new M2M application +func (p *provider) CreateApplication(ctx context.Context, application *schemas.Application) error { + if application.ID == "" { + application.ID = uuid.New().String() + } + application.Key = application.ID + application.CreatedAt = time.Now().Unix() + application.UpdatedAt = time.Now().Unix() + applicationCollection, _ := p.db.Collection(ctx, schemas.Collections.Application) + meta, err := applicationCollection.CreateDocument(ctx, application) + if err != nil { + return err + } + application.Key = meta.Key + application.ID = meta.ID.String() + return nil +} + +// GetApplicationByID retrieves an application by ID +func (p *provider) GetApplicationByID(ctx context.Context, id string) (*schemas.Application, error) { + var application schemas.Application + query := fmt.Sprintf("FOR d in %s FILTER d._id == @id RETURN d", schemas.Collections.Application) + bindVars := map[string]interface{}{ + "id": id, + } + cursor, err := p.db.Query(ctx, query, bindVars) + if err != nil { + return nil, err + } + defer cursor.Close() + for { + if !cursor.HasMore() { + if application.ID == "" { + return nil, fmt.Errorf("application not found") + } + break + } + _, err := cursor.ReadDocument(ctx, &application) + if err != nil { + return nil, err + } + } + return &application, nil +} + +// GetApplicationByClientID retrieves an application by client ID +func (p *provider) GetApplicationByClientID(ctx context.Context, clientID string) (*schemas.Application, error) { + var application schemas.Application + query := fmt.Sprintf("FOR d in %s FILTER d.client_id == @client_id RETURN d", schemas.Collections.Application) + bindVars := map[string]interface{}{ + "client_id": clientID, + } + cursor, err := p.db.Query(ctx, query, bindVars) + if err != nil { + return nil, err + } + defer cursor.Close() + for { + if !cursor.HasMore() { + if application.ID == "" { + return nil, fmt.Errorf("application not found") + } + break + } + _, err := cursor.ReadDocument(ctx, &application) + if err != nil { + return nil, err + } + } + return &application, nil +} + +// ListApplications lists all applications with pagination +func (p *provider) ListApplications(ctx context.Context, pagination *model.Pagination) ([]*schemas.Application, *model.Pagination, error) { + applications := []*schemas.Application{} + query := fmt.Sprintf("FOR d in %s SORT d.created_at DESC LIMIT %d, %d RETURN d", schemas.Collections.Application, pagination.Offset, pagination.Limit) + sctx := arangoDriver.WithQueryFullCount(ctx) + cursor, err := p.db.Query(sctx, query, nil) + if err != nil { + return nil, nil, err + } + defer cursor.Close() + paginationClone := pagination + paginationClone.Total = cursor.Statistics().FullCount() + for { + var application schemas.Application + meta, err := cursor.ReadDocument(ctx, &application) + if arangoDriver.IsNoMoreDocuments(err) { + break + } else if err != nil { + return nil, nil, err + } + if meta.Key != "" { + applications = append(applications, &application) + } + } + return applications, paginationClone, nil +} + +// UpdateApplication updates an application +func (p *provider) UpdateApplication(ctx context.Context, application *schemas.Application) error { + application.UpdatedAt = time.Now().Unix() + applicationCollection, _ := p.db.Collection(ctx, schemas.Collections.Application) + meta, err := applicationCollection.UpdateDocument(ctx, application.Key, application) + if err != nil { + return err + } + application.Key = meta.Key + application.ID = meta.ID.String() + return nil +} + +// DeleteApplication deletes an application by ID +func (p *provider) DeleteApplication(ctx context.Context, id string) error { + application, err := p.GetApplicationByID(ctx, id) + if err != nil { + return err + } + applicationCollection, _ := p.db.Collection(ctx, schemas.Collections.Application) + _, err = applicationCollection.RemoveDocument(ctx, application.Key) + if err != nil { + return err + } + return nil +} diff --git a/internal/storage/db/arangodb/provider.go b/internal/storage/db/arangodb/provider.go index ea3db9be..575ec5f0 100644 --- a/internal/storage/db/arangodb/provider.go +++ b/internal/storage/db/arangodb/provider.go @@ -323,6 +323,30 @@ func NewProvider(cfg *config.Config, deps *Dependencies) (*provider, error) { Sparse: true, }) + // Application collection and indexes + applicationCollectionExists, err := arangodb.CollectionExists(ctx, schemas.Collections.Application) + if err != nil { + return nil, err + } + if !applicationCollectionExists { + _, err = arangodb.CreateCollection(ctx, schemas.Collections.Application, nil) + if err != nil { + return nil, err + } + } + applicationCollection, err := arangodb.Collection(ctx, schemas.Collections.Application) + if err != nil { + return nil, err + } + applicationCollection.EnsureHashIndex(ctx, []string{"name"}, &arangoDriver.EnsureHashIndexOptions{ + Unique: true, + Sparse: true, + }) + applicationCollection.EnsureHashIndex(ctx, []string{"client_id"}, &arangoDriver.EnsureHashIndexOptions{ + Unique: true, + Sparse: true, + }) + return &provider{ config: cfg, dependencies: deps, diff --git a/internal/storage/db/cassandradb/application.go b/internal/storage/db/cassandradb/application.go new file mode 100644 index 00000000..801045ec --- /dev/null +++ b/internal/storage/db/cassandradb/application.go @@ -0,0 +1,135 @@ +package cassandradb + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/gocql/gocql" + "github.com/google/uuid" + + "github.com/authorizerdev/authorizer/internal/graph/model" + "github.com/authorizerdev/authorizer/internal/storage/schemas" +) + +// CreateApplication creates a new M2M application +func (p *provider) CreateApplication(ctx context.Context, application *schemas.Application) error { + if application.ID == "" { + application.ID = uuid.New().String() + } + application.Key = application.ID + application.CreatedAt = time.Now().Unix() + application.UpdatedAt = time.Now().Unix() + insertQuery := fmt.Sprintf("INSERT INTO %s (id, name, description, client_id, client_secret, scopes, roles, is_active, created_by, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", KeySpace+"."+schemas.Collections.Application) + err := p.db.Query(insertQuery, application.ID, application.Name, application.Description, application.ClientID, application.ClientSecret, application.Scopes, application.Roles, application.IsActive, application.CreatedBy, application.CreatedAt, application.UpdatedAt).Exec() + if err != nil { + return err + } + return nil +} + +// GetApplicationByID retrieves an application by ID +func (p *provider) GetApplicationByID(ctx context.Context, id string) (*schemas.Application, error) { + var application schemas.Application + query := fmt.Sprintf(`SELECT id, name, description, client_id, client_secret, scopes, roles, is_active, created_by, created_at, updated_at FROM %s WHERE id = ? LIMIT 1`, KeySpace+"."+schemas.Collections.Application) + err := p.db.Query(query, id).Consistency(gocql.One).Scan(&application.ID, &application.Name, &application.Description, &application.ClientID, &application.ClientSecret, &application.Scopes, &application.Roles, &application.IsActive, &application.CreatedBy, &application.CreatedAt, &application.UpdatedAt) + if err != nil { + return nil, err + } + return &application, nil +} + +// GetApplicationByClientID retrieves an application by client ID +func (p *provider) GetApplicationByClientID(ctx context.Context, clientID string) (*schemas.Application, error) { + var application schemas.Application + query := fmt.Sprintf(`SELECT id, name, description, client_id, client_secret, scopes, roles, is_active, created_by, created_at, updated_at FROM %s WHERE client_id = ? LIMIT 1 ALLOW FILTERING`, KeySpace+"."+schemas.Collections.Application) + err := p.db.Query(query, clientID).Consistency(gocql.One).Scan(&application.ID, &application.Name, &application.Description, &application.ClientID, &application.ClientSecret, &application.Scopes, &application.Roles, &application.IsActive, &application.CreatedBy, &application.CreatedAt, &application.UpdatedAt) + if err != nil { + return nil, err + } + return &application, nil +} + +// ListApplications lists all applications with pagination +func (p *provider) ListApplications(ctx context.Context, pagination *model.Pagination) ([]*schemas.Application, *model.Pagination, error) { + applications := []*schemas.Application{} + paginationClone := pagination + totalCountQuery := fmt.Sprintf(`SELECT COUNT(*) FROM %s`, KeySpace+"."+schemas.Collections.Application) + err := p.db.Query(totalCountQuery).Consistency(gocql.One).Scan(&paginationClone.Total) + if err != nil { + return nil, nil, err + } + // there is no offset in cassandra + // so we fetch till limit + offset + // and return the results from offset to limit + query := fmt.Sprintf("SELECT id, name, description, client_id, client_secret, scopes, roles, is_active, created_by, created_at, updated_at FROM %s LIMIT %d", KeySpace+"."+schemas.Collections.Application, pagination.Limit+pagination.Offset) + scanner := p.db.Query(query).Iter().Scanner() + counter := int64(0) + for scanner.Next() { + if counter >= pagination.Offset { + var application schemas.Application + err := scanner.Scan(&application.ID, &application.Name, &application.Description, &application.ClientID, &application.ClientSecret, &application.Scopes, &application.Roles, &application.IsActive, &application.CreatedBy, &application.CreatedAt, &application.UpdatedAt) + if err != nil { + return nil, nil, err + } + applications = append(applications, &application) + } + counter++ + } + return applications, paginationClone, nil +} + +// UpdateApplication updates an application +func (p *provider) UpdateApplication(ctx context.Context, application *schemas.Application) error { + application.UpdatedAt = time.Now().Unix() + bytes, err := json.Marshal(application) + if err != nil { + return err + } + // use decoder instead of json.Unmarshall, because it converts int64 -> float64 after unmarshalling + decoder := json.NewDecoder(strings.NewReader(string(bytes))) + decoder.UseNumber() + applicationMap := map[string]interface{}{} + err = decoder.Decode(&applicationMap) + if err != nil { + return err + } + convertMapValues(applicationMap) + updateFields := "" + var updateValues []interface{} + for key, value := range applicationMap { + if key == "_id" { + continue + } + if key == "_key" { + continue + } + if value == nil { + updateFields += fmt.Sprintf("%s = null,", key) + continue + } + updateFields += fmt.Sprintf("%s = ?, ", key) + updateValues = append(updateValues, value) + } + updateFields = strings.Trim(updateFields, " ") + updateFields = strings.TrimSuffix(updateFields, ",") + updateValues = append(updateValues, application.ID) + query := fmt.Sprintf("UPDATE %s SET %s WHERE id = ?", KeySpace+"."+schemas.Collections.Application, updateFields) + err = p.db.Query(query, updateValues...).Exec() + if err != nil { + return err + } + return nil +} + +// DeleteApplication deletes an application by ID +func (p *provider) DeleteApplication(ctx context.Context, id string) error { + query := fmt.Sprintf("DELETE FROM %s WHERE id = ?", KeySpace+"."+schemas.Collections.Application) + err := p.db.Query(query, id).Exec() + if err != nil { + return err + } + return nil +} diff --git a/internal/storage/db/cassandradb/provider.go b/internal/storage/db/cassandradb/provider.go index d7198122..541c4ec8 100644 --- a/internal/storage/db/cassandradb/provider.go +++ b/internal/storage/db/cassandradb/provider.go @@ -345,6 +345,23 @@ func NewProvider(cfg *config.Config, deps *Dependencies) (*provider, error) { return nil, err } + // Application table and indexes + applicationCollectionQuery := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (id text, name text, description text, client_id text, client_secret text, scopes text, roles text, is_active boolean, created_by text, created_at bigint, updated_at bigint, PRIMARY KEY (id))", KeySpace, schemas.Collections.Application) + err = session.Query(applicationCollectionQuery).Exec() + if err != nil { + return nil, err + } + applicationNameIndex := fmt.Sprintf("CREATE INDEX IF NOT EXISTS authorizer_application_name ON %s.%s (name)", KeySpace, schemas.Collections.Application) + err = session.Query(applicationNameIndex).Exec() + if err != nil { + return nil, err + } + applicationClientIDIndex := fmt.Sprintf("CREATE INDEX IF NOT EXISTS authorizer_application_client_id ON %s.%s (client_id)", KeySpace, schemas.Collections.Application) + err = session.Query(applicationClientIDIndex).Exec() + if err != nil { + return nil, err + } + return &provider{ config: cfg, dependencies: deps, diff --git a/internal/storage/db/couchbase/application.go b/internal/storage/db/couchbase/application.go new file mode 100644 index 00000000..f986c106 --- /dev/null +++ b/internal/storage/db/couchbase/application.go @@ -0,0 +1,152 @@ +package couchbase + +import ( + "context" + "encoding/json" + "fmt" + "log" + "strings" + "time" + + "github.com/couchbase/gocb/v2" + "github.com/google/uuid" + + "github.com/authorizerdev/authorizer/internal/graph/model" + "github.com/authorizerdev/authorizer/internal/storage/schemas" +) + +// CreateApplication creates a new M2M application +func (p *provider) CreateApplication(ctx context.Context, application *schemas.Application) error { + if application.ID == "" { + application.ID = uuid.New().String() + } + application.Key = application.ID + application.CreatedAt = time.Now().Unix() + application.UpdatedAt = time.Now().Unix() + insertOpt := gocb.InsertOptions{ + Context: ctx, + } + _, err := p.db.Collection(schemas.Collections.Application).Insert(application.ID, application, &insertOpt) + if err != nil { + return err + } + return nil +} + +// GetApplicationByID retrieves an application by ID +func (p *provider) GetApplicationByID(ctx context.Context, id string) (*schemas.Application, error) { + var application schemas.Application + params := make(map[string]interface{}, 1) + params["_id"] = id + query := fmt.Sprintf(`SELECT _id, name, description, client_id, client_secret, scopes, roles, is_active, created_by, created_at, updated_at FROM %s.%s WHERE _id=$_id LIMIT 1`, p.scopeName, schemas.Collections.Application) + q, err := p.db.Query(query, &gocb.QueryOptions{ + Context: ctx, + ScanConsistency: gocb.QueryScanConsistencyRequestPlus, + NamedParameters: params, + }) + if err != nil { + return nil, err + } + err = q.One(&application) + if err != nil { + return nil, err + } + return &application, nil +} + +// GetApplicationByClientID retrieves an application by client ID +func (p *provider) GetApplicationByClientID(ctx context.Context, clientID string) (*schemas.Application, error) { + var application schemas.Application + params := make(map[string]interface{}, 1) + params["client_id"] = clientID + query := fmt.Sprintf(`SELECT _id, name, description, client_id, client_secret, scopes, roles, is_active, created_by, created_at, updated_at FROM %s.%s WHERE client_id=$client_id LIMIT 1`, p.scopeName, schemas.Collections.Application) + q, err := p.db.Query(query, &gocb.QueryOptions{ + Context: ctx, + ScanConsistency: gocb.QueryScanConsistencyRequestPlus, + NamedParameters: params, + }) + if err != nil { + return nil, err + } + err = q.One(&application) + if err != nil { + return nil, err + } + return &application, nil +} + +// ListApplications lists all applications with pagination +func (p *provider) ListApplications(ctx context.Context, pagination *model.Pagination) ([]*schemas.Application, *model.Pagination, error) { + applications := []*schemas.Application{} + paginationClone := pagination + params := make(map[string]interface{}, 1) + params["offset"] = paginationClone.Offset + params["limit"] = paginationClone.Limit + total, err := p.GetTotalDocs(ctx, schemas.Collections.Application) + if err != nil { + return nil, nil, err + } + paginationClone.Total = total + query := fmt.Sprintf("SELECT _id, name, description, client_id, client_secret, scopes, roles, is_active, created_by, created_at, updated_at FROM %s.%s OFFSET $offset LIMIT $limit", p.scopeName, schemas.Collections.Application) + queryResult, err := p.db.Query(query, &gocb.QueryOptions{ + Context: ctx, + ScanConsistency: gocb.QueryScanConsistencyRequestPlus, + NamedParameters: params, + }) + if err != nil { + return nil, nil, err + } + for queryResult.Next() { + var application schemas.Application + err := queryResult.Row(&application) + if err != nil { + log.Fatal(err) + } + applications = append(applications, &application) + } + if err := queryResult.Err(); err != nil { + return nil, nil, err + } + return applications, paginationClone, nil +} + +// UpdateApplication updates an application +func (p *provider) UpdateApplication(ctx context.Context, application *schemas.Application) error { + application.UpdatedAt = time.Now().Unix() + bytes, err := json.Marshal(application) + if err != nil { + return err + } + // use decoder instead of json.Unmarshall, because it converts int64 -> float64 after unmarshalling + decoder := json.NewDecoder(strings.NewReader(string(bytes))) + decoder.UseNumber() + applicationMap := map[string]interface{}{} + err = decoder.Decode(&applicationMap) + if err != nil { + return err + } + updateFields, params := GetSetFields(applicationMap) + params["_id"] = application.ID + query := fmt.Sprintf(`UPDATE %s.%s SET %s WHERE _id=$_id`, p.scopeName, schemas.Collections.Application, updateFields) + _, err = p.db.Query(query, &gocb.QueryOptions{ + Context: ctx, + ScanConsistency: gocb.QueryScanConsistencyRequestPlus, + NamedParameters: params, + }) + if err != nil { + return err + } + return nil +} + +// DeleteApplication deletes an application by ID +func (p *provider) DeleteApplication(ctx context.Context, id string) error { + removeOpt := gocb.RemoveOptions{ + Context: ctx, + } + _, err := p.db.Collection(schemas.Collections.Application).Remove(id, &removeOpt) + if err != nil { + return err + } + return nil +} diff --git a/internal/storage/db/couchbase/provider.go b/internal/storage/db/couchbase/provider.go index 85aee384..156fba44 100644 --- a/internal/storage/db/couchbase/provider.go +++ b/internal/storage/db/couchbase/provider.go @@ -220,5 +220,10 @@ func getIndex(scopeName string) map[string][]string { oauthStateIndex1 := fmt.Sprintf("CREATE INDEX OAuthStateKeyIndex ON %s.%s(state_key)", scopeName, schemas.Collections.OAuthState) indices[schemas.Collections.OAuthState] = []string{oauthStateIndex1} + // Application indexes + applicationIndex1 := fmt.Sprintf("CREATE INDEX ApplicationNameIndex ON %s.%s(name)", scopeName, schemas.Collections.Application) + applicationIndex2 := fmt.Sprintf("CREATE INDEX ApplicationClientIDIndex ON %s.%s(client_id)", scopeName, schemas.Collections.Application) + indices[schemas.Collections.Application] = []string{applicationIndex1, applicationIndex2} + return indices } diff --git a/internal/storage/db/dynamodb/application.go b/internal/storage/db/dynamodb/application.go new file mode 100644 index 00000000..60ba02df --- /dev/null +++ b/internal/storage/db/dynamodb/application.go @@ -0,0 +1,111 @@ +package dynamodb + +import ( + "context" + "errors" + "time" + + "github.com/google/uuid" + "github.com/guregu/dynamo" + + "github.com/authorizerdev/authorizer/internal/graph/model" + "github.com/authorizerdev/authorizer/internal/storage/schemas" +) + +// CreateApplication creates a new M2M application +func (p *provider) CreateApplication(ctx context.Context, application *schemas.Application) error { + collection := p.db.Table(schemas.Collections.Application) + if application.ID == "" { + application.ID = uuid.New().String() + } + application.Key = application.ID + application.CreatedAt = time.Now().Unix() + application.UpdatedAt = time.Now().Unix() + err := collection.Put(application).RunWithContext(ctx) + if err != nil { + return err + } + return nil +} + +// GetApplicationByID retrieves an application by ID +func (p *provider) GetApplicationByID(ctx context.Context, id string) (*schemas.Application, error) { + collection := p.db.Table(schemas.Collections.Application) + var application schemas.Application + err := collection.Get("id", id).OneWithContext(ctx, &application) + if err != nil { + return nil, err + } + if application.ID == "" { + return nil, errors.New("no document found") + } + return &application, nil +} + +// GetApplicationByClientID retrieves an application by client ID +func (p *provider) GetApplicationByClientID(ctx context.Context, clientID string) (*schemas.Application, error) { + collection := p.db.Table(schemas.Collections.Application) + var applications []schemas.Application + err := collection.Scan().Filter("client_id = ?", clientID).AllWithContext(ctx, &applications) + if err != nil { + return nil, err + } + if len(applications) == 0 { + return nil, errors.New("no document found") + } + return &applications[0], nil +} + +// ListApplications lists all applications with pagination +func (p *provider) ListApplications(ctx context.Context, pagination *model.Pagination) ([]*schemas.Application, *model.Pagination, error) { + applications := []*schemas.Application{} + var application schemas.Application + var lastEval dynamo.PagingKey + var iter dynamo.PagingIter + var iteration int64 = 0 + collection := p.db.Table(schemas.Collections.Application) + paginationClone := pagination + scanner := collection.Scan() + count, err := scanner.Count() + if err != nil { + return nil, nil, err + } + for (paginationClone.Offset + paginationClone.Limit) > iteration { + iter = scanner.StartFrom(lastEval).Limit(paginationClone.Limit).Iter() + for iter.NextWithContext(ctx, &application) { + if paginationClone.Offset == iteration { + a := application + applications = append(applications, &a) + } + } + err = iter.Err() + if err != nil { + return nil, nil, err + } + lastEval = iter.LastEvaluatedKey() + iteration += paginationClone.Limit + } + paginationClone.Total = count + return applications, paginationClone, nil +} + +// UpdateApplication updates an application +func (p *provider) UpdateApplication(ctx context.Context, application *schemas.Application) error { + application.UpdatedAt = time.Now().Unix() + collection := p.db.Table(schemas.Collections.Application) + err := UpdateByHashKey(collection, "id", application.ID, application) + if err != nil { + return err + } + return nil +} + +// DeleteApplication deletes an application by ID +func (p *provider) DeleteApplication(ctx context.Context, id string) error { + collection := p.db.Table(schemas.Collections.Application) + err := collection.Delete("id", id).RunWithContext(ctx) + if err != nil { + return err + } + return nil +} diff --git a/internal/storage/db/dynamodb/provider.go b/internal/storage/db/dynamodb/provider.go index 88fae327..e227ee6b 100644 --- a/internal/storage/db/dynamodb/provider.go +++ b/internal/storage/db/dynamodb/provider.go @@ -62,6 +62,7 @@ func NewProvider(cfg *config.Config, deps *Dependencies) (*provider, error) { db.CreateTable(schemas.Collections.SessionToken, schemas.SessionToken{}).Wait() db.CreateTable(schemas.Collections.MFASession, schemas.MFASession{}).Wait() db.CreateTable(schemas.Collections.OAuthState, schemas.OAuthState{}).Wait() + db.CreateTable(schemas.Collections.Application, schemas.Application{}).Wait() return &provider{ db: db, config: cfg, diff --git a/internal/storage/db/mongodb/application.go b/internal/storage/db/mongodb/application.go new file mode 100644 index 00000000..cae03a10 --- /dev/null +++ b/internal/storage/db/mongodb/application.go @@ -0,0 +1,102 @@ +package mongodb + +import ( + "context" + "time" + + "github.com/google/uuid" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo/options" + + "github.com/authorizerdev/authorizer/internal/graph/model" + "github.com/authorizerdev/authorizer/internal/storage/schemas" +) + +// CreateApplication creates a new M2M application +func (p *provider) CreateApplication(ctx context.Context, application *schemas.Application) error { + if application.ID == "" { + application.ID = uuid.New().String() + } + application.Key = application.ID + application.CreatedAt = time.Now().Unix() + application.UpdatedAt = time.Now().Unix() + applicationCollection := p.db.Collection(schemas.Collections.Application, options.Collection()) + _, err := applicationCollection.InsertOne(ctx, application) + if err != nil { + return err + } + return nil +} + +// GetApplicationByID retrieves an application by ID +func (p *provider) GetApplicationByID(ctx context.Context, id string) (*schemas.Application, error) { + var application schemas.Application + applicationCollection := p.db.Collection(schemas.Collections.Application, options.Collection()) + err := applicationCollection.FindOne(ctx, bson.M{"_id": id}).Decode(&application) + if err != nil { + return nil, err + } + return &application, nil +} + +// GetApplicationByClientID retrieves an application by client ID +func (p *provider) GetApplicationByClientID(ctx context.Context, clientID string) (*schemas.Application, error) { + var application schemas.Application + applicationCollection := p.db.Collection(schemas.Collections.Application, options.Collection()) + err := applicationCollection.FindOne(ctx, bson.M{"client_id": clientID}).Decode(&application) + if err != nil { + return nil, err + } + return &application, nil +} + +// ListApplications lists all applications with pagination +func (p *provider) ListApplications(ctx context.Context, pagination *model.Pagination) ([]*schemas.Application, *model.Pagination, error) { + applications := []*schemas.Application{} + opts := options.Find() + opts.SetLimit(pagination.Limit) + opts.SetSkip(pagination.Offset) + opts.SetSort(bson.M{"created_at": -1}) + paginationClone := pagination + applicationCollection := p.db.Collection(schemas.Collections.Application, options.Collection()) + count, err := applicationCollection.CountDocuments(ctx, bson.M{}, options.Count()) + if err != nil { + return nil, nil, err + } + paginationClone.Total = count + cursor, err := applicationCollection.Find(ctx, bson.M{}, opts) + if err != nil { + return nil, nil, err + } + defer cursor.Close(ctx) + for cursor.Next(ctx) { + var application schemas.Application + err := cursor.Decode(&application) + if err != nil { + return nil, nil, err + } + applications = append(applications, &application) + } + return applications, paginationClone, nil +} + +// UpdateApplication updates an application +func (p *provider) UpdateApplication(ctx context.Context, application *schemas.Application) error { + application.UpdatedAt = time.Now().Unix() + applicationCollection := p.db.Collection(schemas.Collections.Application, options.Collection()) + _, err := applicationCollection.ReplaceOne(ctx, bson.M{"_id": bson.M{"$eq": application.ID}}, application, options.Replace()) + if err != nil { + return err + } + return nil +} + +// DeleteApplication deletes an application by ID +func (p *provider) DeleteApplication(ctx context.Context, id string) error { + applicationCollection := p.db.Collection(schemas.Collections.Application, options.Collection()) + _, err := applicationCollection.DeleteOne(ctx, bson.M{"_id": id}, options.Delete()) + if err != nil { + return err + } + return nil +} diff --git a/internal/storage/db/mongodb/provider.go b/internal/storage/db/mongodb/provider.go index e8ecfdd2..c5345c62 100644 --- a/internal/storage/db/mongodb/provider.go +++ b/internal/storage/db/mongodb/provider.go @@ -178,6 +178,20 @@ func NewProvider(config *config.Config, deps *Dependencies) (*provider, error) { }, }, options.CreateIndexes()) + // Application collection and indexes + mongodb.CreateCollection(ctx, schemas.Collections.Application, options.CreateCollection()) + applicationCollection := mongodb.Collection(schemas.Collections.Application, options.Collection()) + applicationCollection.Indexes().CreateMany(ctx, []mongo.IndexModel{ + { + Keys: bson.M{"name": 1}, + Options: options.Index().SetUnique(true).SetSparse(true), + }, + { + Keys: bson.M{"client_id": 1}, + Options: options.Index().SetUnique(true).SetSparse(true), + }, + }, options.CreateIndexes()) + return &provider{ config: config, dependencies: deps, diff --git a/internal/storage/db/sql/application.go b/internal/storage/db/sql/application.go new file mode 100644 index 00000000..37c1fc18 --- /dev/null +++ b/internal/storage/db/sql/application.go @@ -0,0 +1,84 @@ +package sql + +import ( + "context" + "time" + + "github.com/google/uuid" + + "github.com/authorizerdev/authorizer/internal/graph/model" + "github.com/authorizerdev/authorizer/internal/storage/schemas" +) + +// CreateApplication creates a new M2M application +func (p *provider) CreateApplication(ctx context.Context, application *schemas.Application) error { + if application.ID == "" { + application.ID = uuid.New().String() + } + application.Key = application.ID + application.CreatedAt = time.Now().Unix() + application.UpdatedAt = time.Now().Unix() + result := p.db.Create(&application) + if result.Error != nil { + return result.Error + } + return nil +} + +// GetApplicationByID retrieves an application by ID +func (p *provider) GetApplicationByID(ctx context.Context, id string) (*schemas.Application, error) { + var application schemas.Application + result := p.db.Where("id = ?", id).First(&application) + if result.Error != nil { + return nil, result.Error + } + return &application, nil +} + +// GetApplicationByClientID retrieves an application by client ID +func (p *provider) GetApplicationByClientID(ctx context.Context, clientID string) (*schemas.Application, error) { + var application schemas.Application + result := p.db.Where("client_id = ?", clientID).First(&application) + if result.Error != nil { + return nil, result.Error + } + return &application, nil +} + +// ListApplications lists all applications with pagination +func (p *provider) ListApplications(ctx context.Context, pagination *model.Pagination) ([]*schemas.Application, *model.Pagination, error) { + var applications []*schemas.Application + result := p.db.Limit(int(pagination.Limit)).Offset(int(pagination.Offset)).Order("created_at DESC").Find(&applications) + if result.Error != nil { + return nil, nil, result.Error + } + var total int64 + totalRes := p.db.Model(&schemas.Application{}).Count(&total) + if totalRes.Error != nil { + return nil, nil, totalRes.Error + } + paginationClone := pagination + paginationClone.Total = total + return applications, paginationClone, nil +} + +// UpdateApplication updates an application +func (p *provider) UpdateApplication(ctx context.Context, application *schemas.Application) error { + application.UpdatedAt = time.Now().Unix() + result := p.db.Save(&application) + if result.Error != nil { + return result.Error + } + return nil +} + +// DeleteApplication deletes an application by ID +func (p *provider) DeleteApplication(ctx context.Context, id string) error { + result := p.db.Delete(&schemas.Application{ + ID: id, + }) + if result.Error != nil { + return result.Error + } + return nil +} diff --git a/internal/storage/db/sql/provider.go b/internal/storage/db/sql/provider.go index 2a7dbe1f..b2a21f1d 100644 --- a/internal/storage/db/sql/provider.go +++ b/internal/storage/db/sql/provider.go @@ -83,7 +83,7 @@ func NewProvider( } } - err = sqlDB.AutoMigrate(&schemas.User{}, &schemas.VerificationRequest{}, &schemas.Session{}, &schemas.Env{}, &schemas.Webhook{}, &schemas.WebhookLog{}, &schemas.EmailTemplate{}, &schemas.OTP{}, &schemas.Authenticator{}, &schemas.SessionToken{}, &schemas.MFASession{}, &schemas.OAuthState{}) + err = sqlDB.AutoMigrate(&schemas.User{}, &schemas.VerificationRequest{}, &schemas.Session{}, &schemas.Env{}, &schemas.Webhook{}, &schemas.WebhookLog{}, &schemas.EmailTemplate{}, &schemas.OTP{}, &schemas.Authenticator{}, &schemas.SessionToken{}, &schemas.MFASession{}, &schemas.OAuthState{}, &schemas.Application{}) if err != nil { return nil, err } diff --git a/internal/storage/provider.go b/internal/storage/provider.go index afb7643c..3d9d5270 100644 --- a/internal/storage/provider.go +++ b/internal/storage/provider.go @@ -160,6 +160,19 @@ type Provider interface { DeleteOAuthStateByKey(ctx context.Context, key string) error // GetAllOAuthStates retrieves all OAuth states (for testing) GetAllOAuthStates(ctx context.Context) ([]*schemas.OAuthState, error) + + // CreateApplication creates a new M2M application + CreateApplication(ctx context.Context, application *schemas.Application) error + // GetApplicationByID retrieves an application by ID + GetApplicationByID(ctx context.Context, id string) (*schemas.Application, error) + // GetApplicationByClientID retrieves an application by client ID + GetApplicationByClientID(ctx context.Context, clientID string) (*schemas.Application, error) + // ListApplications lists all applications with pagination + ListApplications(ctx context.Context, pagination *model.Pagination) ([]*schemas.Application, *model.Pagination, error) + // UpdateApplication updates an application + UpdateApplication(ctx context.Context, application *schemas.Application) error + // DeleteApplication deletes an application by ID + DeleteApplication(ctx context.Context, id string) error } // New creates a new database provider based on the configuration diff --git a/internal/storage/schemas/application.go b/internal/storage/schemas/application.go new file mode 100644 index 00000000..86b87a62 --- /dev/null +++ b/internal/storage/schemas/application.go @@ -0,0 +1,19 @@ +package schemas + +// Note: any change here should be reflected in providers/cassandra/provider.go as it does not have model support in collection creation + +// Application represents a machine-to-machine (M2M) application / service account +type Application struct { + Key string `json:"_key,omitempty" bson:"_key,omitempty" cql:"_key,omitempty" dynamo:"key,omitempty"` // for arangodb + ID string `gorm:"primaryKey;type:char(36)" json:"_id" bson:"_id" cql:"id" dynamo:"id,hash"` + Name string `gorm:"type:varchar(256);uniqueIndex" json:"name" bson:"name" cql:"name" dynamo:"name"` + Description string `gorm:"type:text" json:"description" bson:"description" cql:"description" dynamo:"description"` + ClientID string `gorm:"type:char(36);uniqueIndex" json:"client_id" bson:"client_id" cql:"client_id" dynamo:"client_id"` + ClientSecret string `gorm:"type:text" json:"client_secret" bson:"client_secret" cql:"client_secret" dynamo:"client_secret"` + Scopes string `gorm:"type:text" json:"scopes" bson:"scopes" cql:"scopes" dynamo:"scopes"` + Roles string `gorm:"type:text" json:"roles" bson:"roles" cql:"roles" dynamo:"roles"` + IsActive bool `json:"is_active" bson:"is_active" cql:"is_active" dynamo:"is_active"` + CreatedBy string `gorm:"type:char(36)" json:"created_by" bson:"created_by" cql:"created_by" dynamo:"created_by"` + CreatedAt int64 `json:"created_at" bson:"created_at" cql:"created_at" dynamo:"created_at"` + UpdatedAt int64 `json:"updated_at" bson:"updated_at" cql:"updated_at" dynamo:"updated_at"` +} diff --git a/internal/storage/schemas/model.go b/internal/storage/schemas/model.go index 70d55123..9970b395 100644 --- a/internal/storage/schemas/model.go +++ b/internal/storage/schemas/model.go @@ -15,6 +15,7 @@ type CollectionList struct { SessionToken string MFASession string OAuthState string + Application string } var ( @@ -35,5 +36,6 @@ var ( SessionToken: Prefix + "session_tokens", MFASession: Prefix + "mfa_sessions", OAuthState: Prefix + "oauth_states", + Application: Prefix + "applications", } )