diff --git a/internal/api/samlacs.go b/internal/api/samlacs.go index 1452cd5bd9..bcf1efa45b 100644 --- a/internal/api/samlacs.go +++ b/internal/api/samlacs.go @@ -46,10 +46,17 @@ func IsSAMLMetadataStale(idpMetadata *saml.EntityDescriptor, samlProvider models } func (a *API) SamlAcs(w http.ResponseWriter, r *http.Request) error { - if err := a.handleSamlAcs(w, r); err != nil { - u, uerr := url.Parse(a.config.SiteURL) + redirectTo, err := a.handleSamlAcs(w, r) + if err != nil { + // Use redirectTo if valid, otherwise fall back to SiteURL + targetURL := a.config.SiteURL + if redirectTo != "" && utilities.IsRedirectURLValid(a.config, redirectTo) { + targetURL = redirectTo + } + + u, uerr := url.Parse(targetURL) if uerr != nil { - return apierrors.NewInternalServerError("site url is improperly formattted").WithInternalError(err) + return apierrors.NewInternalServerError("redirect url is improperly formatted").WithInternalError(err) } q := getErrorQueryString(err, utilities.GetRequestID(r.Context()), observability.GetLogEntry(r).Entry, u.Query()) @@ -60,7 +67,8 @@ func (a *API) SamlAcs(w http.ResponseWriter, r *http.Request) error { } // handleSamlAcs implements the main Assertion Consumer Service endpoint behavior. -func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error { +// Returns the redirectTo URL (if determined) and any error that occurred. +func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) (string, error) { ctx := r.Context() db := a.db.WithContext(ctx) @@ -82,17 +90,17 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error { relayState, err := models.FindSAMLRelayStateByID(db, relayStateUUID) if models.IsNotFoundError(err) { - return apierrors.NewNotFoundError(apierrors.ErrorCodeSAMLRelayStateNotFound, "SAML RelayState does not exist, try logging in again?") + return "", apierrors.NewNotFoundError(apierrors.ErrorCodeSAMLRelayStateNotFound, "SAML RelayState does not exist, try logging in again?") } else if err != nil { - return err + return "", err } if time.Since(relayState.CreatedAt) >= a.config.SAML.RelayStateValidityPeriod { if err := a.samlDestroyRelayState(ctx, relayState); err != nil { - return apierrors.NewInternalServerError("SAML RelayState has expired and destroying it failed. Try logging in again?").WithInternalError(err) + return relayState.RedirectTo, apierrors.NewInternalServerError("SAML RelayState has expired and destroying it failed. Try logging in again?").WithInternalError(err) } - return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeSAMLRelayStateExpired, "SAML RelayState has expired. Try logging in again?") + return relayState.RedirectTo, apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeSAMLRelayStateExpired, "SAML RelayState has expired. Try logging in again?") } // TODO: add abuse detection to bind the RelayState UUID with a @@ -100,11 +108,11 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error { ssoProvider, err := models.FindSSOProviderByID(db, relayState.SSOProviderID) if err != nil { - return apierrors.NewInternalServerError("Unable to find SSO Provider from SAML RelayState") + return relayState.RedirectTo, apierrors.NewInternalServerError("Unable to find SSO Provider from SAML RelayState") } if !ssoProvider.IsEnabled() { - return apierrors.NewNotFoundError( + return relayState.RedirectTo, apierrors.NewNotFoundError( apierrors.ErrorCodeSSOProviderDisabled, "SSO Provider assigned for this domain is currently disabled") } @@ -118,7 +126,7 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error { } if err := a.samlDestroyRelayState(ctx, relayState); err != nil { - return err + return redirectTo, err } } else if relayStateValue == "" || relayStateURL != nil { // RelayState may be a URL in which case it's the URL where the @@ -128,23 +136,23 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error { // SAML Artifact responses are possible only when // RelayState can be used to identify the Identity // Provider. - return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAML Artifact response can only be used with SP initiated flow") + return "", apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAML Artifact response can only be used with SP initiated flow") } samlResponse := r.FormValue("SAMLResponse") if samlResponse == "" { - return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAMLResponse is missing") + return "", apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAMLResponse is missing") } responseXML, err := base64.StdEncoding.DecodeString(samlResponse) if err != nil { - return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAMLResponse is not a valid Base64 string") + return "", apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAMLResponse is not a valid Base64 string") } var peekResponse saml.Response err = xml.Unmarshal(responseXML, &peekResponse) if err != nil { - return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAMLResponse is not a valid XML SAML assertion").WithInternalError(err) + return "", apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAMLResponse is not a valid XML SAML assertion").WithInternalError(err) } initiatedBy = "idp" @@ -152,25 +160,25 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error { redirectTo = relayStateValue } else { // RelayState can't be identified, so SAML flow can't continue - return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAML RelayState is not a valid UUID or URL") + return "", apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAML RelayState is not a valid UUID or URL") } ssoProvider, err := models.FindSAMLProviderByEntityID(db, entityId) if models.IsNotFoundError(err) { - return apierrors.NewNotFoundError(apierrors.ErrorCodeSAMLIdPNotFound, "A SAML connection has not been established with this Identity Provider") + return redirectTo, apierrors.NewNotFoundError(apierrors.ErrorCodeSAMLIdPNotFound, "A SAML connection has not been established with this Identity Provider") } else if err != nil { - return err + return redirectTo, err } if !ssoProvider.IsEnabled() { - return apierrors.NewNotFoundError( + return redirectTo, apierrors.NewNotFoundError( apierrors.ErrorCodeSSOProviderDisabled, "SSO Provider assigned for this domain is currently disabled") } idpMetadata, err := ssoProvider.SAMLProvider.EntityDescriptor() if err != nil { - return err + return redirectTo, err } samlMetadataModified := false @@ -203,10 +211,10 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error { spAssertion, err := serviceProvider.ParseResponse(r, requestIds) if err != nil { if ire, ok := err.(*saml.InvalidResponseError); ok { - return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAML Assertion is not valid %s", ire.Response).WithInternalError(ire.PrivateErr) + return redirectTo, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAML Assertion is not valid %s", ire.Response).WithInternalError(ire.PrivateErr) } - return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAML Assertion is not valid").WithInternalError(err) + return redirectTo, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAML Assertion is not valid").WithInternalError(err) } assertion := SAMLAssertion{ @@ -215,7 +223,7 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error { userID := assertion.UserID() if userID == "" { - return apierrors.NewBadRequestError(apierrors.ErrorCodeSAMLAssertionNoUserID, "SAML Assertion did not contain a persistent Subject Identifier attribute or Subject NameID uniquely identifying this user") + return redirectTo, apierrors.NewBadRequestError(apierrors.ErrorCodeSAMLAssertionNoUserID, "SAML Assertion did not contain a persistent Subject Identifier attribute or Subject NameID uniquely identifying this user") } claims := assertion.Process(ssoProvider.SAMLProvider.AttributeMapping) @@ -227,19 +235,19 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error { } if email == "" { - return apierrors.NewBadRequestError(apierrors.ErrorCodeSAMLAssertionNoEmail, "SAML Assertion does not contain an email address") + return redirectTo, apierrors.NewBadRequestError(apierrors.ErrorCodeSAMLAssertionNoEmail, "SAML Assertion does not contain an email address") } else { claims["email"] = email } jsonClaims, err := json.Marshal(claims) if err != nil { - return apierrors.NewInternalServerError("Mapped claims from provider could not be serialized into JSON").WithInternalError(err) + return redirectTo, apierrors.NewInternalServerError("Mapped claims from provider could not be serialized into JSON").WithInternalError(err) } providerClaims := &provider.Claims{} if err := json.Unmarshal(jsonClaims, providerClaims); err != nil { - return apierrors.NewInternalServerError("Mapped claims from provider could not be deserialized from JSON").WithInternalError(err) + return redirectTo, apierrors.NewInternalServerError("Mapped claims from provider could not be deserialized from JSON").WithInternalError(err) } providerClaims.Subject = userID @@ -290,14 +298,14 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error { var token *AccessTokenResponse if samlMetadataModified { if err := db.UpdateColumns(&ssoProvider.SAMLProvider, "metadata_xml", "updated_at"); err != nil { - return err + return redirectTo, err } } providerType := "sso:" + ssoProvider.ID.String() if err := a.triggerBeforeUserCreatedExternal( r, db, &userProvidedData, providerType); err != nil { - return err + return redirectTo, err } var createdUser bool @@ -328,11 +336,11 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error { return nil }); err != nil { - return err + return redirectTo, err } if createdUser { if err := a.triggerAfterUserCreated(r, db, user); err != nil { - return err + return redirectTo, err } } @@ -343,10 +351,10 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error { // PKCE flow: redirect with auth code redirectTo, err = a.prepPKCERedirectURL(redirectTo, *flowState.AuthCode) if err != nil { - return err + return redirectTo, err } http.Redirect(w, r, redirectTo, http.StatusFound) - return nil + return "", nil } // Record login for analytics - only when token is issued (not during pkce authorize) @@ -358,5 +366,5 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error { http.Redirect(w, r, token.AsRedirectURL(redirectTo, url.Values{}), http.StatusFound) - return nil + return "", nil }