From 540443d7ffd27c512be519a79658a72f87dee01f Mon Sep 17 00:00:00 2001 From: Yuval Karmi Date: Sun, 25 Jan 2026 21:16:01 +0200 Subject: [PATCH] fix(saml): respect redirectTo URL on ACS error redirects Previously, SAML ACS errors always redirected to SiteURL, ignoring the redirectTo parameter stored in RelayState. This made it difficult to test SSO in development environments with different domains. Now handleSamlAcs returns the redirectTo URL along with errors, and SamlAcs uses it for error redirects when valid (falling back to SiteURL otherwise). --- internal/api/samlacs.go | 74 +++++++++++++++++++++++------------------ 1 file changed, 41 insertions(+), 33 deletions(-) diff --git a/internal/api/samlacs.go b/internal/api/samlacs.go index 1452cd5bd9..bcf1efa45b 100644 --- a/internal/api/samlacs.go +++ b/internal/api/samlacs.go @@ -46,10 +46,17 @@ func IsSAMLMetadataStale(idpMetadata *saml.EntityDescriptor, samlProvider models } func (a *API) SamlAcs(w http.ResponseWriter, r *http.Request) error { - if err := a.handleSamlAcs(w, r); err != nil { - u, uerr := url.Parse(a.config.SiteURL) + redirectTo, err := a.handleSamlAcs(w, r) + if err != nil { + // Use redirectTo if valid, otherwise fall back to SiteURL + targetURL := a.config.SiteURL + if redirectTo != "" && utilities.IsRedirectURLValid(a.config, redirectTo) { + targetURL = redirectTo + } + + u, uerr := url.Parse(targetURL) if uerr != nil { - return apierrors.NewInternalServerError("site url is improperly formattted").WithInternalError(err) + return apierrors.NewInternalServerError("redirect url is improperly formatted").WithInternalError(err) } q := getErrorQueryString(err, utilities.GetRequestID(r.Context()), observability.GetLogEntry(r).Entry, u.Query()) @@ -60,7 +67,8 @@ func (a *API) SamlAcs(w http.ResponseWriter, r *http.Request) error { } // handleSamlAcs implements the main Assertion Consumer Service endpoint behavior. -func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error { +// Returns the redirectTo URL (if determined) and any error that occurred. +func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) (string, error) { ctx := r.Context() db := a.db.WithContext(ctx) @@ -82,17 +90,17 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error { relayState, err := models.FindSAMLRelayStateByID(db, relayStateUUID) if models.IsNotFoundError(err) { - return apierrors.NewNotFoundError(apierrors.ErrorCodeSAMLRelayStateNotFound, "SAML RelayState does not exist, try logging in again?") + return "", apierrors.NewNotFoundError(apierrors.ErrorCodeSAMLRelayStateNotFound, "SAML RelayState does not exist, try logging in again?") } else if err != nil { - return err + return "", err } if time.Since(relayState.CreatedAt) >= a.config.SAML.RelayStateValidityPeriod { if err := a.samlDestroyRelayState(ctx, relayState); err != nil { - return apierrors.NewInternalServerError("SAML RelayState has expired and destroying it failed. Try logging in again?").WithInternalError(err) + return relayState.RedirectTo, apierrors.NewInternalServerError("SAML RelayState has expired and destroying it failed. Try logging in again?").WithInternalError(err) } - return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeSAMLRelayStateExpired, "SAML RelayState has expired. Try logging in again?") + return relayState.RedirectTo, apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeSAMLRelayStateExpired, "SAML RelayState has expired. Try logging in again?") } // TODO: add abuse detection to bind the RelayState UUID with a @@ -100,11 +108,11 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error { ssoProvider, err := models.FindSSOProviderByID(db, relayState.SSOProviderID) if err != nil { - return apierrors.NewInternalServerError("Unable to find SSO Provider from SAML RelayState") + return relayState.RedirectTo, apierrors.NewInternalServerError("Unable to find SSO Provider from SAML RelayState") } if !ssoProvider.IsEnabled() { - return apierrors.NewNotFoundError( + return relayState.RedirectTo, apierrors.NewNotFoundError( apierrors.ErrorCodeSSOProviderDisabled, "SSO Provider assigned for this domain is currently disabled") } @@ -118,7 +126,7 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error { } if err := a.samlDestroyRelayState(ctx, relayState); err != nil { - return err + return redirectTo, err } } else if relayStateValue == "" || relayStateURL != nil { // RelayState may be a URL in which case it's the URL where the @@ -128,23 +136,23 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error { // SAML Artifact responses are possible only when // RelayState can be used to identify the Identity // Provider. - return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAML Artifact response can only be used with SP initiated flow") + return "", apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAML Artifact response can only be used with SP initiated flow") } samlResponse := r.FormValue("SAMLResponse") if samlResponse == "" { - return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAMLResponse is missing") + return "", apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAMLResponse is missing") } responseXML, err := base64.StdEncoding.DecodeString(samlResponse) if err != nil { - return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAMLResponse is not a valid Base64 string") + return "", apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAMLResponse is not a valid Base64 string") } var peekResponse saml.Response err = xml.Unmarshal(responseXML, &peekResponse) if err != nil { - return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAMLResponse is not a valid XML SAML assertion").WithInternalError(err) + return "", apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAMLResponse is not a valid XML SAML assertion").WithInternalError(err) } initiatedBy = "idp" @@ -152,25 +160,25 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error { redirectTo = relayStateValue } else { // RelayState can't be identified, so SAML flow can't continue - return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAML RelayState is not a valid UUID or URL") + return "", apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAML RelayState is not a valid UUID or URL") } ssoProvider, err := models.FindSAMLProviderByEntityID(db, entityId) if models.IsNotFoundError(err) { - return apierrors.NewNotFoundError(apierrors.ErrorCodeSAMLIdPNotFound, "A SAML connection has not been established with this Identity Provider") + return redirectTo, apierrors.NewNotFoundError(apierrors.ErrorCodeSAMLIdPNotFound, "A SAML connection has not been established with this Identity Provider") } else if err != nil { - return err + return redirectTo, err } if !ssoProvider.IsEnabled() { - return apierrors.NewNotFoundError( + return redirectTo, apierrors.NewNotFoundError( apierrors.ErrorCodeSSOProviderDisabled, "SSO Provider assigned for this domain is currently disabled") } idpMetadata, err := ssoProvider.SAMLProvider.EntityDescriptor() if err != nil { - return err + return redirectTo, err } samlMetadataModified := false @@ -203,10 +211,10 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error { spAssertion, err := serviceProvider.ParseResponse(r, requestIds) if err != nil { if ire, ok := err.(*saml.InvalidResponseError); ok { - return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAML Assertion is not valid %s", ire.Response).WithInternalError(ire.PrivateErr) + return redirectTo, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAML Assertion is not valid %s", ire.Response).WithInternalError(ire.PrivateErr) } - return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAML Assertion is not valid").WithInternalError(err) + return redirectTo, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAML Assertion is not valid").WithInternalError(err) } assertion := SAMLAssertion{ @@ -215,7 +223,7 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error { userID := assertion.UserID() if userID == "" { - return apierrors.NewBadRequestError(apierrors.ErrorCodeSAMLAssertionNoUserID, "SAML Assertion did not contain a persistent Subject Identifier attribute or Subject NameID uniquely identifying this user") + return redirectTo, apierrors.NewBadRequestError(apierrors.ErrorCodeSAMLAssertionNoUserID, "SAML Assertion did not contain a persistent Subject Identifier attribute or Subject NameID uniquely identifying this user") } claims := assertion.Process(ssoProvider.SAMLProvider.AttributeMapping) @@ -227,19 +235,19 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error { } if email == "" { - return apierrors.NewBadRequestError(apierrors.ErrorCodeSAMLAssertionNoEmail, "SAML Assertion does not contain an email address") + return redirectTo, apierrors.NewBadRequestError(apierrors.ErrorCodeSAMLAssertionNoEmail, "SAML Assertion does not contain an email address") } else { claims["email"] = email } jsonClaims, err := json.Marshal(claims) if err != nil { - return apierrors.NewInternalServerError("Mapped claims from provider could not be serialized into JSON").WithInternalError(err) + return redirectTo, apierrors.NewInternalServerError("Mapped claims from provider could not be serialized into JSON").WithInternalError(err) } providerClaims := &provider.Claims{} if err := json.Unmarshal(jsonClaims, providerClaims); err != nil { - return apierrors.NewInternalServerError("Mapped claims from provider could not be deserialized from JSON").WithInternalError(err) + return redirectTo, apierrors.NewInternalServerError("Mapped claims from provider could not be deserialized from JSON").WithInternalError(err) } providerClaims.Subject = userID @@ -290,14 +298,14 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error { var token *AccessTokenResponse if samlMetadataModified { if err := db.UpdateColumns(&ssoProvider.SAMLProvider, "metadata_xml", "updated_at"); err != nil { - return err + return redirectTo, err } } providerType := "sso:" + ssoProvider.ID.String() if err := a.triggerBeforeUserCreatedExternal( r, db, &userProvidedData, providerType); err != nil { - return err + return redirectTo, err } var createdUser bool @@ -328,11 +336,11 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error { return nil }); err != nil { - return err + return redirectTo, err } if createdUser { if err := a.triggerAfterUserCreated(r, db, user); err != nil { - return err + return redirectTo, err } } @@ -343,10 +351,10 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error { // PKCE flow: redirect with auth code redirectTo, err = a.prepPKCERedirectURL(redirectTo, *flowState.AuthCode) if err != nil { - return err + return redirectTo, err } http.Redirect(w, r, redirectTo, http.StatusFound) - return nil + return "", nil } // Record login for analytics - only when token is issued (not during pkce authorize) @@ -358,5 +366,5 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error { http.Redirect(w, r, token.AsRedirectURL(redirectTo, url.Values{}), http.StatusFound) - return nil + return "", nil }