Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 41 additions & 33 deletions internal/api/samlacs.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,17 @@ func IsSAMLMetadataStale(idpMetadata *saml.EntityDescriptor, samlProvider models
}

func (a *API) SamlAcs(w http.ResponseWriter, r *http.Request) error {
if err := a.handleSamlAcs(w, r); err != nil {
u, uerr := url.Parse(a.config.SiteURL)
redirectTo, err := a.handleSamlAcs(w, r)
if err != nil {
// Use redirectTo if valid, otherwise fall back to SiteURL
targetURL := a.config.SiteURL
if redirectTo != "" && utilities.IsRedirectURLValid(a.config, redirectTo) {
targetURL = redirectTo
}

u, uerr := url.Parse(targetURL)
if uerr != nil {
return apierrors.NewInternalServerError("site url is improperly formattted").WithInternalError(err)
return apierrors.NewInternalServerError("redirect url is improperly formatted").WithInternalError(err)
}

q := getErrorQueryString(err, utilities.GetRequestID(r.Context()), observability.GetLogEntry(r).Entry, u.Query())
Expand All @@ -60,7 +67,8 @@ func (a *API) SamlAcs(w http.ResponseWriter, r *http.Request) error {
}

// handleSamlAcs implements the main Assertion Consumer Service endpoint behavior.
func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error {
// Returns the redirectTo URL (if determined) and any error that occurred.
func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) (string, error) {
ctx := r.Context()

db := a.db.WithContext(ctx)
Expand All @@ -82,29 +90,29 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error {

relayState, err := models.FindSAMLRelayStateByID(db, relayStateUUID)
if models.IsNotFoundError(err) {
return apierrors.NewNotFoundError(apierrors.ErrorCodeSAMLRelayStateNotFound, "SAML RelayState does not exist, try logging in again?")
return "", apierrors.NewNotFoundError(apierrors.ErrorCodeSAMLRelayStateNotFound, "SAML RelayState does not exist, try logging in again?")
} else if err != nil {
return err
return "", err
}

if time.Since(relayState.CreatedAt) >= a.config.SAML.RelayStateValidityPeriod {
if err := a.samlDestroyRelayState(ctx, relayState); err != nil {
return apierrors.NewInternalServerError("SAML RelayState has expired and destroying it failed. Try logging in again?").WithInternalError(err)
return relayState.RedirectTo, apierrors.NewInternalServerError("SAML RelayState has expired and destroying it failed. Try logging in again?").WithInternalError(err)
}

return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeSAMLRelayStateExpired, "SAML RelayState has expired. Try logging in again?")
return relayState.RedirectTo, apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeSAMLRelayStateExpired, "SAML RelayState has expired. Try logging in again?")
}

// TODO: add abuse detection to bind the RelayState UUID with a
// HTTP-Only cookie

ssoProvider, err := models.FindSSOProviderByID(db, relayState.SSOProviderID)
if err != nil {
return apierrors.NewInternalServerError("Unable to find SSO Provider from SAML RelayState")
return relayState.RedirectTo, apierrors.NewInternalServerError("Unable to find SSO Provider from SAML RelayState")
}

if !ssoProvider.IsEnabled() {
return apierrors.NewNotFoundError(
return relayState.RedirectTo, apierrors.NewNotFoundError(
apierrors.ErrorCodeSSOProviderDisabled,
"SSO Provider assigned for this domain is currently disabled")
}
Expand All @@ -118,7 +126,7 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error {
}

if err := a.samlDestroyRelayState(ctx, relayState); err != nil {
return err
return redirectTo, err
}
} else if relayStateValue == "" || relayStateURL != nil {
// RelayState may be a URL in which case it's the URL where the
Expand All @@ -128,49 +136,49 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error {
// SAML Artifact responses are possible only when
// RelayState can be used to identify the Identity
// Provider.
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAML Artifact response can only be used with SP initiated flow")
return "", apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAML Artifact response can only be used with SP initiated flow")
}

samlResponse := r.FormValue("SAMLResponse")
if samlResponse == "" {
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAMLResponse is missing")
return "", apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAMLResponse is missing")
}

responseXML, err := base64.StdEncoding.DecodeString(samlResponse)
if err != nil {
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAMLResponse is not a valid Base64 string")
return "", apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAMLResponse is not a valid Base64 string")
}

var peekResponse saml.Response
err = xml.Unmarshal(responseXML, &peekResponse)
if err != nil {
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAMLResponse is not a valid XML SAML assertion").WithInternalError(err)
return "", apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAMLResponse is not a valid XML SAML assertion").WithInternalError(err)
}

initiatedBy = "idp"
entityId = peekResponse.Issuer.Value
redirectTo = relayStateValue
} else {
// RelayState can't be identified, so SAML flow can't continue
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAML RelayState is not a valid UUID or URL")
return "", apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAML RelayState is not a valid UUID or URL")
}

ssoProvider, err := models.FindSAMLProviderByEntityID(db, entityId)
if models.IsNotFoundError(err) {
return apierrors.NewNotFoundError(apierrors.ErrorCodeSAMLIdPNotFound, "A SAML connection has not been established with this Identity Provider")
return redirectTo, apierrors.NewNotFoundError(apierrors.ErrorCodeSAMLIdPNotFound, "A SAML connection has not been established with this Identity Provider")
} else if err != nil {
return err
return redirectTo, err
}

if !ssoProvider.IsEnabled() {
return apierrors.NewNotFoundError(
return redirectTo, apierrors.NewNotFoundError(
apierrors.ErrorCodeSSOProviderDisabled,
"SSO Provider assigned for this domain is currently disabled")
}

idpMetadata, err := ssoProvider.SAMLProvider.EntityDescriptor()
if err != nil {
return err
return redirectTo, err
}

samlMetadataModified := false
Expand Down Expand Up @@ -203,10 +211,10 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error {
spAssertion, err := serviceProvider.ParseResponse(r, requestIds)
if err != nil {
if ire, ok := err.(*saml.InvalidResponseError); ok {
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAML Assertion is not valid %s", ire.Response).WithInternalError(ire.PrivateErr)
return redirectTo, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAML Assertion is not valid %s", ire.Response).WithInternalError(ire.PrivateErr)
}

return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAML Assertion is not valid").WithInternalError(err)
return redirectTo, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAML Assertion is not valid").WithInternalError(err)
}

assertion := SAMLAssertion{
Expand All @@ -215,7 +223,7 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error {

userID := assertion.UserID()
if userID == "" {
return apierrors.NewBadRequestError(apierrors.ErrorCodeSAMLAssertionNoUserID, "SAML Assertion did not contain a persistent Subject Identifier attribute or Subject NameID uniquely identifying this user")
return redirectTo, apierrors.NewBadRequestError(apierrors.ErrorCodeSAMLAssertionNoUserID, "SAML Assertion did not contain a persistent Subject Identifier attribute or Subject NameID uniquely identifying this user")
}

claims := assertion.Process(ssoProvider.SAMLProvider.AttributeMapping)
Expand All @@ -227,19 +235,19 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error {
}

if email == "" {
return apierrors.NewBadRequestError(apierrors.ErrorCodeSAMLAssertionNoEmail, "SAML Assertion does not contain an email address")
return redirectTo, apierrors.NewBadRequestError(apierrors.ErrorCodeSAMLAssertionNoEmail, "SAML Assertion does not contain an email address")
} else {
claims["email"] = email
}

jsonClaims, err := json.Marshal(claims)
if err != nil {
return apierrors.NewInternalServerError("Mapped claims from provider could not be serialized into JSON").WithInternalError(err)
return redirectTo, apierrors.NewInternalServerError("Mapped claims from provider could not be serialized into JSON").WithInternalError(err)
}

providerClaims := &provider.Claims{}
if err := json.Unmarshal(jsonClaims, providerClaims); err != nil {
return apierrors.NewInternalServerError("Mapped claims from provider could not be deserialized from JSON").WithInternalError(err)
return redirectTo, apierrors.NewInternalServerError("Mapped claims from provider could not be deserialized from JSON").WithInternalError(err)
}

providerClaims.Subject = userID
Expand Down Expand Up @@ -290,14 +298,14 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error {
var token *AccessTokenResponse
if samlMetadataModified {
if err := db.UpdateColumns(&ssoProvider.SAMLProvider, "metadata_xml", "updated_at"); err != nil {
return err
return redirectTo, err
}
}

providerType := "sso:" + ssoProvider.ID.String()
if err := a.triggerBeforeUserCreatedExternal(
r, db, &userProvidedData, providerType); err != nil {
return err
return redirectTo, err
}

var createdUser bool
Expand Down Expand Up @@ -328,11 +336,11 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error {

return nil
}); err != nil {
return err
return redirectTo, err
}
if createdUser {
if err := a.triggerAfterUserCreated(r, db, user); err != nil {
return err
return redirectTo, err
}
}

Expand All @@ -343,10 +351,10 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error {
// PKCE flow: redirect with auth code
redirectTo, err = a.prepPKCERedirectURL(redirectTo, *flowState.AuthCode)
if err != nil {
return err
return redirectTo, err
}
http.Redirect(w, r, redirectTo, http.StatusFound)
return nil
return "", nil
}

// Record login for analytics - only when token is issued (not during pkce authorize)
Expand All @@ -358,5 +366,5 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error {

http.Redirect(w, r, token.AsRedirectURL(redirectTo, url.Values{}), http.StatusFound)

return nil
return "", nil
}