From f8b3e7692a56dd9c294672a4c5a459aaa54980c4 Mon Sep 17 00:00:00 2001 From: Daniel Bennett Date: Mon, 24 Feb 2025 14:29:08 -0500 Subject: [PATCH 01/27] add oidc/clientassertion package for signing JWTs to send as client_assertions ref: https://oauth.net/private-key-jwt/ --- oidc/clientassertion/client_assertion.go | 200 ++++++++++++ oidc/clientassertion/client_assertion_test.go | 287 ++++++++++++++++++ oidc/clientassertion/doc.go | 13 + 3 files changed, 500 insertions(+) create mode 100644 oidc/clientassertion/client_assertion.go create mode 100644 oidc/clientassertion/client_assertion_test.go create mode 100644 oidc/clientassertion/doc.go diff --git a/oidc/clientassertion/client_assertion.go b/oidc/clientassertion/client_assertion.go new file mode 100644 index 0000000..b204800 --- /dev/null +++ b/oidc/clientassertion/client_assertion.go @@ -0,0 +1,200 @@ +package clientassertion + +import ( + "crypto/rsa" + "errors" + "fmt" + "time" + + "github.com/go-jose/go-jose/v3" + "github.com/go-jose/go-jose/v3/jwt" + "github.com/hashicorp/go-uuid" +) + +var ( + // these may happen due to user error + ErrMissingClientID = errors.New("missing client ID") + ErrMissingAudience = errors.New("missing audience") + ErrMissingAlgorithm = errors.New("missing signing algorithm") + ErrMissingKeyOrSecret = errors.New("missing private key or client secret") + ErrBothKeyAndSecret = errors.New("both private key and client secret provided") + // if these happen, either the user directly instantiated &ClientAssertion{} + // or there's a bug somewhere. + ErrMissingFuncIDGenerator = errors.New("missing IDgen func; please use New()") + ErrMissingFuncNow = errors.New("missing now func; please use New()") +) + +// New sets up a new ClientAssertion to sign private key JWTs +func New(clientID string, audience []string, opts ...Option) (*ClientAssertion, error) { + a := &ClientAssertion{ + clientID: clientID, + audience: audience, + headers: make(map[string]string), + genID: uuid.GenerateUUID, + now: time.Now, + } + for _, opt := range opts { + opt(a) + } + if err := a.Validate(); err != nil { + return nil, fmt.Errorf("new client assertion validation error: %w", err) + } + return a, nil +} + +// ClientAssertion signs a JWT with either a private key or a secret +type ClientAssertion struct { + // for JWT claims + clientID string + audience []string + headers map[string]string + + // for signer + alg jose.SignatureAlgorithm + // key may be any key type that jose.SigningKey accepts for its Key + key any + // secret may be used instead of key + secret string + + // these are overwritten for testing + genID func() (string, error) + now func() time.Time +} + +// Validate validates the expected fields +func (c *ClientAssertion) Validate() error { + var errs []error + if c.genID == nil { + errs = append(errs, ErrMissingFuncIDGenerator) + } + if c.now == nil { + errs = append(errs, ErrMissingFuncNow) + } + // bail early if any internal func errors + if len(errs) > 0 { + return errors.Join(errs...) + } + + if c.clientID == "" { + errs = append(errs, ErrMissingClientID) + } + if len(c.audience) == 0 { + errs = append(errs, ErrMissingAudience) + } + if c.alg == "" { + errs = append(errs, ErrMissingAlgorithm) + } + if c.key == nil && c.secret == "" { + errs = append(errs, ErrMissingKeyOrSecret) + } + if c.key != nil && c.secret != "" { + errs = append(errs, ErrBothKeyAndSecret) + } + return errors.Join(errs...) +} + +// SignedToken returns a signed JWT in the compact serialization format +func (c *ClientAssertion) SignedToken() (string, error) { + if err := c.Validate(); err != nil { + return "", err + } + builder, err := c.builder() + if err != nil { + return "", err + } + token, err := builder.CompactSerialize() + if err != nil { + return "", fmt.Errorf("failed to sign token: %w", err) + } + return token, nil +} + +func (c *ClientAssertion) builder() (jwt.Builder, error) { + signer, err := c.signer() + if err != nil { + return nil, err + } + id, err := c.genID() + if err != nil { + return nil, fmt.Errorf("failed to generate token id: %w", err) + } + claims := c.claims(id) + return jwt.Signed(signer).Claims(claims), nil +} + +func (c *ClientAssertion) signer() (jose.Signer, error) { + sKey := jose.SigningKey{ + Algorithm: c.alg, + } + + // Validate() ensures these are mutually exclusive + if c.secret != "" { + sKey.Key = []byte(c.secret) + } + if c.key != nil { + sKey.Key = c.key + } + + sOpts := &jose.SignerOptions{ + ExtraHeaders: make(map[jose.HeaderKey]interface{}, len(c.headers)), + } + // note: extra headers can override "kid" + for k, v := range c.headers { + sOpts.ExtraHeaders[jose.HeaderKey(k)] = v + } + + signer, err := jose.NewSigner(sKey, sOpts.WithType("JWT")) + if err != nil { + return nil, fmt.Errorf("failed to create jwt signer: %w", err) + } + return signer, nil +} + +func (c *ClientAssertion) claims(id string) *jwt.Claims { + now := c.now().UTC() + return &jwt.Claims{ + Issuer: c.clientID, + Subject: c.clientID, + Audience: c.audience, + Expiry: jwt.NewNumericDate(now.Add(5 * time.Minute)), + NotBefore: jwt.NewNumericDate(now.Add(-1 * time.Second)), + IssuedAt: jwt.NewNumericDate(now), + ID: id, + } +} + +// Options configure the ClientAssertion +type Option func(*ClientAssertion) + +// WithClientSecret sets a secret and algorithm to sign the JWT with +func WithClientSecret(secret string, alg string) Option { + return func(c *ClientAssertion) { + c.secret = secret + c.alg = jose.SignatureAlgorithm(alg) + } +} + +// WithRSAKey sets a private key to sign the JWT with +func WithRSAKey(key *rsa.PrivateKey, alg string) Option { + return func(c *ClientAssertion) { + c.key = key + c.alg = jose.SignatureAlgorithm(alg) + } +} + +// WithKeyID sets the "kid" header that OIDC providers use to look up the +// public key to check the signed JWT +func WithKeyID(keyID string) Option { + return func(c *ClientAssertion) { + c.headers["kid"] = keyID + } +} + +// WithHeaders sets extra JWT headers +func WithHeaders(h map[string]string) Option { + return func(c *ClientAssertion) { + for k, v := range h { + c.headers[k] = v + } + } +} diff --git a/oidc/clientassertion/client_assertion_test.go b/oidc/clientassertion/client_assertion_test.go new file mode 100644 index 0000000..cc63166 --- /dev/null +++ b/oidc/clientassertion/client_assertion_test.go @@ -0,0 +1,287 @@ +package clientassertion + +import ( + "crypto/rand" + "crypto/rsa" + "errors" + "testing" + "time" + + "github.com/go-jose/go-jose/v3" + "github.com/go-jose/go-jose/v3/jwt" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// any non-nil error from New()/Validate() will be errors.Join()ed. +// this is so we can assert each error within. +type joinedErrs interface { + Unwrap() []error +} + +func assertJoinedErrs(t *testing.T, expect []error, actual error) { + t.Helper() + joined, ok := actual.(joinedErrs) // Validate() error is errors.Join()ed + require.True(t, ok, "expected Join()ed errors from Validate()") + unwrapped := joined.Unwrap() + require.ElementsMatch(t, expect, unwrapped) +} + +// TestClientAssertionBare tests what errors we expect if &ClientAssertion{} +// is instantiated directly, rather than using the constructor New(). +func TestClientAssertionBare(t *testing.T) { + ca := &ClientAssertion{} + + // all public methods should return the same error(s) + expect := []error{ErrMissingFuncIDGenerator, ErrMissingFuncNow} + + actual := ca.Validate() + assertJoinedErrs(t, expect, actual) + + tokenStr, err := ca.SignedToken() + assertJoinedErrs(t, expect, err) + + assert.Equal(t, "", tokenStr) +} + +func TestNew(t *testing.T) { + t.Run("should run validate", func(t *testing.T) { + ca, err := New("", nil) + require.ErrorContains(t, err, "validation error:") + assert.Nil(t, ca) + }) + + tCid := "test-client-id" + tAud := []string{"test-audience"} + + cases := []struct { + name string + cid string + aud []string + opts []Option + check func(*testing.T, *ClientAssertion) + }{ + { + name: "with private key", + cid: tCid, aud: tAud, + opts: []Option{WithRSAKey(&rsa.PrivateKey{}, "test-alg")}, + check: func(t *testing.T, ca *ClientAssertion) { + require.NotNil(t, ca.key) + require.Equal(t, jose.SignatureAlgorithm("test-alg"), ca.alg) + }, + }, + { + name: "with client secret", + cid: tCid, aud: tAud, + opts: []Option{WithClientSecret("ssshhhh", "test-alg")}, + check: func(t *testing.T, ca *ClientAssertion) { + require.Equal(t, "ssshhhh", ca.secret) + require.Equal(t, jose.SignatureAlgorithm("test-alg"), ca.alg) + }, + }, + { + name: "with key id", + cid: tCid, aud: tAud, + opts: []Option{ + WithKeyID("kid"), + WithClientSecret("ssshhhh", "blah"), + }, + check: func(t *testing.T, ca *ClientAssertion) { + require.Equal(t, "kid", ca.headers["kid"]) + }, + }, + { + name: "with headers", + cid: tCid, aud: tAud, + opts: []Option{ + WithHeaders(map[string]string{"h1": "v1", "h2": "v2"}), + WithClientSecret("ssshhhh", "test-alg"), + }, + check: func(t *testing.T, ca *ClientAssertion) { + require.Equal(t, map[string]string{"h1": "v1", "h2": "v2"}, ca.headers) + }, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + + ca, err := New(tc.cid, tc.aud, tc.opts...) + + require.NoError(t, err) + require.NotNil(t, ca) + require.Equal(t, tc.cid, ca.clientID) + require.Equal(t, tc.aud, ca.audience) + + if tc.check != nil { + tc.check(t, ca) + } + + }) + } +} + +func TestValidate(t *testing.T) { + tCid := "test-client-id" + tAud := []string{"test-audience"} + cases := []struct { + name string + cid string + aud []string + opts []Option + errs []error + }{ + { + name: "missing everything", + errs: []error{ErrMissingClientID, ErrMissingAudience, ErrMissingAlgorithm, ErrMissingKeyOrSecret}, + }, + { + name: "missing client id", + aud: tAud, + errs: []error{ErrMissingClientID}, + opts: []Option{ + WithRSAKey(&rsa.PrivateKey{}, "algo"), + }, + }, + { + name: "missing audience", + cid: tCid, + errs: []error{ErrMissingAudience}, + opts: []Option{ + WithRSAKey(&rsa.PrivateKey{}, "algo"), + }, + }, + { + name: "missing client and secret", + cid: tCid, aud: tAud, + errs: []error{ErrMissingAlgorithm, ErrMissingKeyOrSecret}, + }, + { + name: "both client and secret", + cid: tCid, aud: tAud, + opts: []Option{ + WithRSAKey(&rsa.PrivateKey{}, "algo"), + WithClientSecret("ssshhhh", "algo"), + }, + errs: []error{ErrBothKeyAndSecret}, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + + // New() runs Validate() + ca, err := New(tc.cid, tc.aud, tc.opts...) + + require.NotNil(t, err) + require.ErrorContains(t, err, "validation error:") + + err = errors.Unwrap(err) // New wraps the error from Validate() with fmt.Errorf("%w") + assertJoinedErrs(t, tc.errs, err) + + require.Nil(t, ca) + + }) + } +} + +func TestSignedToken(t *testing.T) { + key, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + pub, ok := key.Public().(*rsa.PublicKey) + require.True(t, ok, "couldn't get rsa.PublicKey from PrivateKey") + + cases := []struct { + name string + claimKey any // []byte or pubkey; we'll use this to check the signature + opts []Option + err error + }{ + { + name: "valid secret", + claimKey: []byte("ssshhhh"), + opts: []Option{ + WithClientSecret("ssshhhh", "HS256"), + WithKeyID("test-key-id"), + WithHeaders(map[string]string{"xtra": "headies"}), + }, + }, + { + name: "valid key", + claimKey: pub, + opts: []Option{ + WithRSAKey(key, "RS256"), + WithKeyID("test-key-id"), + WithHeaders(map[string]string{"xtra": "headies"}), + }, + }, + { + name: "invalid alg", + claimKey: pub, + opts: []Option{ + WithRSAKey(key, "ruh-roh"), + }, + err: jose.ErrUnsupportedAlgorithm, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ca, err := New("test-client-id", []string{"test-aud"}, tc.opts...) + require.NoError(t, err) + + now := time.Now() + ca.now = func() time.Time { return now } + ca.genID = func() (string, error) { return "test-claim-id", nil } + + // method under test + tokenString, err := ca.SignedToken() + + if tc.err != nil { + require.ErrorIs(t, err, tc.err) + require.Equal(t, "", tokenString) + return + } + require.NoError(t, err) + + // extract the token from the signed string + token, err := jwt.ParseSigned(tokenString) + require.NoError(t, err) + + // check headers + expectHeaders := jose.Header{ + Algorithm: string(ca.alg), + KeyID: "test-key-id", + ExtraHeaders: map[jose.HeaderKey]any{ + "typ": "JWT", + "xtra": "headies", + }, + } + require.Len(t, token.Headers, 1) + actualHeaders := token.Headers[0] + require.Equal(t, expectHeaders, actualHeaders) + + // check claims + expectClaims := jwt.Expected{ + Issuer: "test-client-id", + Subject: "test-client-id", + Audience: []string{"test-aud"}, + ID: "test-claim-id", + Time: now, + } + var actualClaims jwt.Claims + err = token.Claims(tc.claimKey, &actualClaims) + require.NoError(t, err) + err = actualClaims.Validate(expectClaims) + require.NoError(t, err) + }) + } + + t.Run("error generating token id", func(t *testing.T) { + genIDErr := errors.New("failed to generate test id") + ca, err := New("a", []string{"a"}, WithClientSecret("ssshhhh", "HS256")) + require.NoError(t, err) + ca.genID = func() (string, error) { return "", genIDErr } + tokenString, err := ca.SignedToken() + require.ErrorIs(t, err, genIDErr) + require.Equal(t, "", tokenString) + }) +} diff --git a/oidc/clientassertion/doc.go b/oidc/clientassertion/doc.go new file mode 100644 index 0000000..f01b0a5 --- /dev/null +++ b/oidc/clientassertion/doc.go @@ -0,0 +1,13 @@ +package clientassertion + +// clientassertion signs JWTs with a Private Key or Client Secret for use +// in OIDC client_assertion requests, A.K.A. private_key_jwt. reference: +// https://oauth.net/private-key-jwt/ +// +// Example usage: +// +// cass, err := clientassertion.New("client-id", []string{"audience"}, +// clientassertion.WithRSAKey(rsaPrivateKey, "RS256"), +// clientassertion.WithKeyID("jwks-key-id-or-x5t-etc"), +// ) +// jwtString, err := cass.SignedToken() From d64063310208f8b8d135d7a58a3df2f31c942288 Mon Sep 17 00:00:00 2001 From: Daniel Bennett Date: Mon, 24 Feb 2025 15:07:08 -0500 Subject: [PATCH 02/27] add WithClientAssertionJWT Option to send a signed JWT as request's client_assertion ref: https://oauth.net/private-key-jwt/ --- oidc/provider.go | 8 ++++++++ oidc/request.go | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/oidc/provider.go b/oidc/provider.go index f8b342e..4f68f4e 100644 --- a/oidc/provider.go +++ b/oidc/provider.go @@ -306,6 +306,14 @@ func (p *Provider) Exchange(ctx context.Context, oidcRequest Request, authorizat if oidcRequest.PKCEVerifier() != nil { authCodeOpts = append(authCodeOpts, oauth2.SetAuthURLParam("code_verifier", oidcRequest.PKCEVerifier().Verifier())) } + if oidcRequest.ClientAssertionJWT() != "" { + authCodeOpts = append(authCodeOpts, + // client_assertion_type is *always* this value. + // https://www.rfc-editor.org/rfc/rfc7523.html#section-2.2 + oauth2.SetAuthURLParam("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"), + oauth2.SetAuthURLParam("client_assertion", oidcRequest.ClientAssertionJWT()), + ) + } oauth2Token, err := oauth2Config.Exchange(oidcCtx, authorizationCode, authCodeOpts...) if err != nil { return nil, fmt.Errorf("%s: unable to exchange auth code with provider: %w", op, p.convertError(err)) diff --git a/oidc/request.go b/oidc/request.go index 38a11f0..f5dc632 100644 --- a/oidc/request.go +++ b/oidc/request.go @@ -83,6 +83,12 @@ type Request interface { // See: https://tools.ietf.org/html/rfc7636 PKCEVerifier() CodeVerifier + // ClientAssertionJWT optionally specifies a signed JWT to be used in a + // client_assertion token request. + // + // See: https://oauth.net/private-key-jwt/ + ClientAssertionJWT() string + // MaxAge: when authAfter is not a zero value (authTime.IsZero()) then the // id_token's auth_time claim must be after the specified time. // @@ -175,6 +181,8 @@ type Req struct { // with PKCE. It suppies the required CodeVerifier for PKCE. withVerifier CodeVerifier + withClientJWT string + // withMaxAge: when withMaxAge.authAfter is not a zero value // (authTime.IsZero()) then the id_token's auth_time claim must be after the // specified time. @@ -216,6 +224,7 @@ var _ Request = (*Req)(nil) // * WithScopes // * WithImplicit // * WithPKCE +// * WithClientJWT // * WithMaxAge // * WithPrompts // * WithDisplay @@ -267,6 +276,7 @@ func NewRequest(expireIn time.Duration, redirectURL string, opt ...Option) (*Req scopes: opts.withScopes, withImplicit: opts.withImplicitFlow, withVerifier: opts.withVerifier, + withClientJWT: opts.withClientJWT, withPrompts: opts.withPrompts, withDisplay: opts.withDisplay, withUILocales: opts.withUILocales, @@ -321,6 +331,12 @@ func (r *Req) PKCEVerifier() CodeVerifier { return r.withVerifier.Copy() } +// ClientAssertionJWT implements the Request.ClientAssertionJWT() interface +// function and returns the JWT string. +func (r *Req) ClientAssertionJWT() string { + return r.withClientJWT +} + // Prompts() implements the Request.Prompts() interface function and returns a // copy of the prompts. func (r *Req) Prompts() []Prompt { @@ -442,6 +458,7 @@ type reqOptions struct { withACRValues []string withState string withNonce string + withClientJWT string } // reqDefaults is a handy way to get the defaults at runtime and during unit @@ -502,6 +519,21 @@ func WithPKCE(v CodeVerifier) Option { } } +// WithClientAssertionJWT provides an option to send a signed JWT as a +// client_assertion. +// +// Option is valid for: Request +// +// See: https://oauth.net/private-key-jwt/ +func WithClientAssertionJWT(jwt string) Option { + return func(o interface{}) { + switch v := o.(type) { + case *reqOptions: + v.withClientJWT = jwt + } + } +} + // WithMaxAge provides an optional maximum authentication age, which is the // allowable elapsed time in seconds since the last time the user was actively // authenticated by the provider. When a max age is specified, the provider From ca292a6c9a1708ac443a9a813716703464f30ef1 Mon Sep 17 00:00:00 2001 From: Daniel Bennett Date: Mon, 24 Feb 2025 17:07:19 -0500 Subject: [PATCH 03/27] test WithClientAssertionJWT --- oidc/clientassertion/client_assertion.go | 6 +++ oidc/provider.go | 5 +-- oidc/provider_test.go | 20 +++++++++ oidc/request_test.go | 54 +++++++++++++++++------- oidc/testing_provider.go | 33 +++++++++++++++ 5 files changed, 99 insertions(+), 19 deletions(-) diff --git a/oidc/clientassertion/client_assertion.go b/oidc/clientassertion/client_assertion.go index b204800..8bcab61 100644 --- a/oidc/clientassertion/client_assertion.go +++ b/oidc/clientassertion/client_assertion.go @@ -11,6 +11,12 @@ import ( "github.com/hashicorp/go-uuid" ) +const ( + // ClientAssertionJWTType is the proper value for client_assertion_type. + // https://www.rfc-editor.org/rfc/rfc7523.html#section-2.2 + ClientAssertionJWTType = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" +) + var ( // these may happen due to user error ErrMissingClientID = errors.New("missing client ID") diff --git a/oidc/provider.go b/oidc/provider.go index 4f68f4e..df78d14 100644 --- a/oidc/provider.go +++ b/oidc/provider.go @@ -20,6 +20,7 @@ import ( "time" "github.com/coreos/go-oidc/v3/oidc" + cass "github.com/hashicorp/cap/oidc/clientassertion" "github.com/hashicorp/cap/oidc/internal/strutils" "github.com/hashicorp/go-cleanhttp" "golang.org/x/oauth2" @@ -308,9 +309,7 @@ func (p *Provider) Exchange(ctx context.Context, oidcRequest Request, authorizat } if oidcRequest.ClientAssertionJWT() != "" { authCodeOpts = append(authCodeOpts, - // client_assertion_type is *always* this value. - // https://www.rfc-editor.org/rfc/rfc7523.html#section-2.2 - oauth2.SetAuthURLParam("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"), + oauth2.SetAuthURLParam("client_assertion_type", cass.ClientAssertionJWTType), oauth2.SetAuthURLParam("client_assertion", oidcRequest.ClientAssertionJWT()), ) } diff --git a/oidc/provider_test.go b/oidc/provider_test.go index b07b128..2894738 100644 --- a/oidc/provider_test.go +++ b/oidc/provider_test.go @@ -501,6 +501,14 @@ func TestProvider_Exchange(t *testing.T) { ) require.NoError(t, err) + cassJWT := "test-client-assertion-jwt" + reqWithClientAssertion, err := NewRequest( + 1*time.Minute, + redirect, + WithClientAssertionJWT(cassJWT), + ) + require.NoError(t, err) + type args struct { ctx context.Context r Request @@ -548,6 +556,17 @@ func TestProvider_Exchange(t *testing.T) { expectedAudiences: []string{"state-override"}, }, }, + { + name: "client-assertion-jwt", + p: p, + args: args{ + ctx: ctx, + r: reqWithClientAssertion, + authRequest: reqWithClientAssertion.State(), + authCode: "test-code", + expectedAudiences: []string{"state-override"}, + }, + }, { name: "nil-config", p: &Provider{}, @@ -590,6 +609,7 @@ func TestProvider_Exchange(t *testing.T) { if tt.args.r.PKCEVerifier() != nil { tp.SetPKCEVerifier(tt.args.r.PKCEVerifier()) } + tp.SetClientAssertionJWT(tt.args.r.ClientAssertionJWT()) } if tt.args.expectedNonce != "" { tp.SetExpectedAuthNonce(tt.args.expectedNonce) diff --git a/oidc/request_test.go b/oidc/request_test.go index 1119b87..83ebb99 100644 --- a/oidc/request_test.go +++ b/oidc/request_test.go @@ -28,17 +28,18 @@ func TestNewRequest(t *testing.T) { require.NoError(t, err) tests := []struct { - name string - expireIn time.Duration - redirectURL string - opts []Option - wantNowFunc func() time.Time - wantRedirectURL string - wantAudiences []string - wantScopes []string - wantVerifier CodeVerifier - wantErr bool - wantIsErr error + name string + expireIn time.Duration + redirectURL string + opts []Option + wantNowFunc func() time.Time + wantRedirectURL string + wantAudiences []string + wantScopes []string + wantVerifier CodeVerifier + wantClientAssertion string + wantErr bool + wantIsErr error }{ { name: "valid-with-all-options", @@ -49,12 +50,14 @@ func TestNewRequest(t *testing.T) { WithAudiences("bob", "alice"), WithScopes("email", "profile"), WithPKCE(testVerifier), + WithClientAssertionJWT("test-client-assertion-jwt"), }, - wantNowFunc: testNow, - wantRedirectURL: "https://bob.com", - wantAudiences: []string{"bob", "alice"}, - wantScopes: []string{oidc.ScopeOpenID, "email", "profile"}, - wantVerifier: testVerifier, + wantNowFunc: testNow, + wantRedirectURL: "https://bob.com", + wantAudiences: []string{"bob", "alice"}, + wantScopes: []string{oidc.ScopeOpenID, "email", "profile"}, + wantVerifier: testVerifier, + wantClientAssertion: "test-client-assertion-jwt", }, { name: "valid-no-opt", @@ -90,6 +93,7 @@ func TestNewRequest(t *testing.T) { assert.Equalf(got.Audiences(), tt.wantAudiences, "wanted \"%s\" but got \"%s\"", tt.wantAudiences, got.Audiences()) assert.Equalf(got.Scopes(), tt.wantScopes, "wanted \"%s\" but got \"%s\"", tt.wantScopes, got.Scopes()) assert.Equalf(got.PKCEVerifier(), tt.wantVerifier, "wanted \"%s\" but got \"%s\"", tt.wantVerifier, got.PKCEVerifier()) + assert.Equal(got.ClientAssertionJWT(), tt.wantClientAssertion) }) } } @@ -150,6 +154,24 @@ func Test_WithPKCE(t *testing.T) { }) } +func Test_WithClientAssertionJWT(t *testing.T) { + t.Parallel() + t.Run("reqOptions", func(t *testing.T) { + t.Parallel() + assert := assert.New(t) + opts := getReqOpts() + testOpts := reqDefaults() + assert.Equal(opts, testOpts) + assert.Empty(testOpts.withClientJWT) + + j := "test-jwt" + opts = getReqOpts(WithClientAssertionJWT(j)) + testOpts = reqDefaults() + testOpts.withClientJWT = j + assert.Equal(opts, testOpts) + }) +} + func Test_WithMaxAge(t *testing.T) { t.Parallel() t.Run("reqOptions", func(t *testing.T) { diff --git a/oidc/testing_provider.go b/oidc/testing_provider.go index a8c7e5c..48fe2be 100644 --- a/oidc/testing_provider.go +++ b/oidc/testing_provider.go @@ -33,6 +33,7 @@ import ( "time" "github.com/go-jose/go-jose/v3" + cass "github.com/hashicorp/cap/oidc/clientassertion" "github.com/hashicorp/go-cleanhttp" "github.com/hashicorp/go-hclog" "github.com/stretchr/testify/require" @@ -176,6 +177,8 @@ type TestProvider struct { nowFunc func() time.Time pkceVerifier CodeVerifier + clientAssertionJWT string + // privKey *ecdsa.PrivateKey privKey crypto.PrivateKey pubKey crypto.PublicKey @@ -234,6 +237,7 @@ func StartTestProvider(t TestingT, opt ...Option) *TestProvider { replyExpiry: *opts.withDefaults.Expiry, nowFunc: opts.withDefaults.NowFunc, pkceVerifier: opts.withDefaults.PKCEVerifier, + clientAssertionJWT: opts.withDefaults.ClientAssertionJWT, replySubject: *opts.withDefaults.ExpectedSubject, subjectInfo: opts.withDefaults.SubjectInfo, // default is not to use a login form, so no passwords required for subjects codes: map[string]*codeState{}, @@ -440,6 +444,9 @@ type TestProviderDefaults struct { // PKCEVerifier(oidc.CodeVerifier) configures the PKCE code_verifier PKCEVerifier CodeVerifier + // ClientAssertionJWT includes a client_assertion JWT in token requests + ClientAssertionJWT string + // OmitAuthTime turn on/off the omitting of an auth_time claim from // id_tokens from the /token endpoint. If set to true, the test provider will // not include the auth_time claim in issued id_tokens from the /token @@ -534,6 +541,9 @@ func WithTestDefaults(defaults *TestProviderDefaults) Option { if defaults.PKCEVerifier != nil { o.withDefaults.PKCEVerifier = defaults.PKCEVerifier } + if defaults.ClientAssertionJWT != "" { + o.withDefaults.ClientAssertionJWT = defaults.ClientAssertionJWT + } if defaults.ExpectedSubject != nil { o.withDefaults.ExpectedSubject = defaults.ExpectedSubject } @@ -894,6 +904,17 @@ func (p *TestProvider) PKCEVerifier() CodeVerifier { return p.pkceVerifier } +// SetClientAssertionJWT sets the client assertion JWT +func (p *TestProvider) SetClientAssertionJWT(jwt string) { + p.mu.Lock() + defer p.mu.Unlock() + if v, ok := interface{}(p.t).(HelperT); ok { + v.Helper() + } + require.NotEqual(p.t, "", jwt) + p.clientAssertionJWT = jwt +} + // SetUserInfoReply sets the UserInfo endpoint response. func (p *TestProvider) SetUserInfoReply(resp map[string]interface{}) { p.mu.Lock() @@ -1359,6 +1380,18 @@ func (p *TestProvider) ServeHTTP(w http.ResponseWriter, req *http.Request) { return } + if t := req.FormValue("client_assertion_type"); t != "" { + // assume client_assertion_type, if set, is always the magic jwt string + if t != cass.ClientAssertionJWTType { + _ = p.writeTokenErrorResponse(w, http.StatusBadRequest, "invalid_client", "unknown client assertion type") + return + } + if req.FormValue("client_assertion") != p.clientAssertionJWT { + _ = p.writeTokenErrorResponse(w, http.StatusUnauthorized, "invalid_client", "bad client assertion value") + return + } + } + var sub string var nonce string switch { From 8fca05fd7b494f416ef73b1d8a96712c10279992 Mon Sep 17 00:00:00 2001 From: Daniel Bennett Date: Tue, 25 Feb 2025 12:33:01 -0500 Subject: [PATCH 04/27] add changelog, fix test --- CHANGELOG.md | 1 + oidc/provider_test.go | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cb7bceb..4b042f1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ Canonical reference for changes, improvements, and bugfixes for cap. ## Next +* feat (oidc): add WithClientAssertionJWT to enable "private key JWT" ([PR #155](https://github.com/hashicorp/cap/pull/155)) * feat (oidc): add WithVerifier ([PR #141](https://github.com/hashicorp/cap/pull/141)) * feat (ldap): add an option to enable sAMAccountname logins when upndomain is set ([PR #146](https://github.com/hashicorp/cap/pull/146)) * feat (saml): enhancing signature validation in SAML Response ([PR #144](https://github.com/hashicorp/cap/pull/144)) diff --git a/oidc/provider_test.go b/oidc/provider_test.go index 2894738..7420500 100644 --- a/oidc/provider_test.go +++ b/oidc/provider_test.go @@ -609,7 +609,9 @@ func TestProvider_Exchange(t *testing.T) { if tt.args.r.PKCEVerifier() != nil { tp.SetPKCEVerifier(tt.args.r.PKCEVerifier()) } - tp.SetClientAssertionJWT(tt.args.r.ClientAssertionJWT()) + if jot := tt.args.r.ClientAssertionJWT(); jot != "" { + tp.SetClientAssertionJWT(jot) + } } if tt.args.expectedNonce != "" { tp.SetExpectedAuthNonce(tt.args.expectedNonce) From 2fe4caee7355a8228ac0e3df1ead6820941bf200 Mon Sep 17 00:00:00 2001 From: Daniel Bennett Date: Tue, 25 Feb 2025 16:42:58 -0500 Subject: [PATCH 05/27] s/ClientAssertion/JWT/g in clientassertion pkg --- oidc/clientassertion/client_assertion.go | 102 +++++++++--------- oidc/clientassertion/client_assertion_test.go | 64 +++++------ oidc/clientassertion/doc.go | 2 +- 3 files changed, 84 insertions(+), 84 deletions(-) diff --git a/oidc/clientassertion/client_assertion.go b/oidc/clientassertion/client_assertion.go index 8bcab61..269d8e8 100644 --- a/oidc/clientassertion/client_assertion.go +++ b/oidc/clientassertion/client_assertion.go @@ -24,15 +24,15 @@ var ( ErrMissingAlgorithm = errors.New("missing signing algorithm") ErrMissingKeyOrSecret = errors.New("missing private key or client secret") ErrBothKeyAndSecret = errors.New("both private key and client secret provided") - // if these happen, either the user directly instantiated &ClientAssertion{} + // if these happen, either the user directly instantiated &JWT{} // or there's a bug somewhere. - ErrMissingFuncIDGenerator = errors.New("missing IDgen func; please use New()") - ErrMissingFuncNow = errors.New("missing now func; please use New()") + ErrMissingFuncIDGenerator = errors.New("missing IDgen func; please use NewJWT()") + ErrMissingFuncNow = errors.New("missing now func; please use NewJWT()") ) -// New sets up a new ClientAssertion to sign private key JWTs -func New(clientID string, audience []string, opts ...Option) (*ClientAssertion, error) { - a := &ClientAssertion{ +// NewJWT sets up a new JWT to sign with a private key or client secret +func NewJWT(clientID string, audience []string, opts ...Option) (*JWT, error) { + j := &JWT{ clientID: clientID, audience: audience, headers: make(map[string]string), @@ -40,16 +40,16 @@ func New(clientID string, audience []string, opts ...Option) (*ClientAssertion, now: time.Now, } for _, opt := range opts { - opt(a) + opt(j) } - if err := a.Validate(); err != nil { + if err := j.Validate(); err != nil { return nil, fmt.Errorf("new client assertion validation error: %w", err) } - return a, nil + return j, nil } -// ClientAssertion signs a JWT with either a private key or a secret -type ClientAssertion struct { +// JWT signs a JWT with either a private key or a secret +type JWT struct { // for JWT claims clientID string audience []string @@ -68,12 +68,12 @@ type ClientAssertion struct { } // Validate validates the expected fields -func (c *ClientAssertion) Validate() error { +func (j *JWT) Validate() error { var errs []error - if c.genID == nil { + if j.genID == nil { errs = append(errs, ErrMissingFuncIDGenerator) } - if c.now == nil { + if j.now == nil { errs = append(errs, ErrMissingFuncNow) } // bail early if any internal func errors @@ -81,30 +81,30 @@ func (c *ClientAssertion) Validate() error { return errors.Join(errs...) } - if c.clientID == "" { + if j.clientID == "" { errs = append(errs, ErrMissingClientID) } - if len(c.audience) == 0 { + if len(j.audience) == 0 { errs = append(errs, ErrMissingAudience) } - if c.alg == "" { + if j.alg == "" { errs = append(errs, ErrMissingAlgorithm) } - if c.key == nil && c.secret == "" { + if j.key == nil && j.secret == "" { errs = append(errs, ErrMissingKeyOrSecret) } - if c.key != nil && c.secret != "" { + if j.key != nil && j.secret != "" { errs = append(errs, ErrBothKeyAndSecret) } return errors.Join(errs...) } // SignedToken returns a signed JWT in the compact serialization format -func (c *ClientAssertion) SignedToken() (string, error) { - if err := c.Validate(); err != nil { +func (j *JWT) SignedToken() (string, error) { + if err := j.Validate(); err != nil { return "", err } - builder, err := c.builder() + builder, err := j.builder() if err != nil { return "", err } @@ -115,37 +115,37 @@ func (c *ClientAssertion) SignedToken() (string, error) { return token, nil } -func (c *ClientAssertion) builder() (jwt.Builder, error) { - signer, err := c.signer() +func (j *JWT) builder() (jwt.Builder, error) { + signer, err := j.signer() if err != nil { return nil, err } - id, err := c.genID() + id, err := j.genID() if err != nil { return nil, fmt.Errorf("failed to generate token id: %w", err) } - claims := c.claims(id) + claims := j.claims(id) return jwt.Signed(signer).Claims(claims), nil } -func (c *ClientAssertion) signer() (jose.Signer, error) { +func (j *JWT) signer() (jose.Signer, error) { sKey := jose.SigningKey{ - Algorithm: c.alg, + Algorithm: j.alg, } // Validate() ensures these are mutually exclusive - if c.secret != "" { - sKey.Key = []byte(c.secret) + if j.secret != "" { + sKey.Key = []byte(j.secret) } - if c.key != nil { - sKey.Key = c.key + if j.key != nil { + sKey.Key = j.key } sOpts := &jose.SignerOptions{ - ExtraHeaders: make(map[jose.HeaderKey]interface{}, len(c.headers)), + ExtraHeaders: make(map[jose.HeaderKey]interface{}, len(j.headers)), } // note: extra headers can override "kid" - for k, v := range c.headers { + for k, v := range j.headers { sOpts.ExtraHeaders[jose.HeaderKey(k)] = v } @@ -156,12 +156,12 @@ func (c *ClientAssertion) signer() (jose.Signer, error) { return signer, nil } -func (c *ClientAssertion) claims(id string) *jwt.Claims { - now := c.now().UTC() +func (j *JWT) claims(id string) *jwt.Claims { + now := j.now().UTC() return &jwt.Claims{ - Issuer: c.clientID, - Subject: c.clientID, - Audience: c.audience, + Issuer: j.clientID, + Subject: j.clientID, + Audience: j.audience, Expiry: jwt.NewNumericDate(now.Add(5 * time.Minute)), NotBefore: jwt.NewNumericDate(now.Add(-1 * time.Second)), IssuedAt: jwt.NewNumericDate(now), @@ -169,38 +169,38 @@ func (c *ClientAssertion) claims(id string) *jwt.Claims { } } -// Options configure the ClientAssertion -type Option func(*ClientAssertion) +// Options configure the JWT +type Option func(*JWT) // WithClientSecret sets a secret and algorithm to sign the JWT with func WithClientSecret(secret string, alg string) Option { - return func(c *ClientAssertion) { - c.secret = secret - c.alg = jose.SignatureAlgorithm(alg) + return func(j *JWT) { + j.secret = secret + j.alg = jose.SignatureAlgorithm(alg) } } // WithRSAKey sets a private key to sign the JWT with func WithRSAKey(key *rsa.PrivateKey, alg string) Option { - return func(c *ClientAssertion) { - c.key = key - c.alg = jose.SignatureAlgorithm(alg) + return func(j *JWT) { + j.key = key + j.alg = jose.SignatureAlgorithm(alg) } } // WithKeyID sets the "kid" header that OIDC providers use to look up the // public key to check the signed JWT func WithKeyID(keyID string) Option { - return func(c *ClientAssertion) { - c.headers["kid"] = keyID + return func(j *JWT) { + j.headers["kid"] = keyID } } // WithHeaders sets extra JWT headers func WithHeaders(h map[string]string) Option { - return func(c *ClientAssertion) { + return func(j *JWT) { for k, v := range h { - c.headers[k] = v + j.headers[k] = v } } } diff --git a/oidc/clientassertion/client_assertion_test.go b/oidc/clientassertion/client_assertion_test.go index cc63166..bad9eae 100644 --- a/oidc/clientassertion/client_assertion_test.go +++ b/oidc/clientassertion/client_assertion_test.go @@ -13,7 +13,7 @@ import ( "github.com/stretchr/testify/require" ) -// any non-nil error from New()/Validate() will be errors.Join()ed. +// any non-nil error from NewJWT()/Validate() will be errors.Join()ed. // this is so we can assert each error within. type joinedErrs interface { Unwrap() []error @@ -27,28 +27,28 @@ func assertJoinedErrs(t *testing.T, expect []error, actual error) { require.ElementsMatch(t, expect, unwrapped) } -// TestClientAssertionBare tests what errors we expect if &ClientAssertion{} -// is instantiated directly, rather than using the constructor New(). -func TestClientAssertionBare(t *testing.T) { - ca := &ClientAssertion{} +// TestJWTBare tests what errors we expect if &JWT{} +// is instantiated directly, rather than using the constructor NewJWT(). +func TestJWTBare(t *testing.T) { + j := &JWT{} // all public methods should return the same error(s) expect := []error{ErrMissingFuncIDGenerator, ErrMissingFuncNow} - actual := ca.Validate() + actual := j.Validate() assertJoinedErrs(t, expect, actual) - tokenStr, err := ca.SignedToken() + tokenStr, err := j.SignedToken() assertJoinedErrs(t, expect, err) assert.Equal(t, "", tokenStr) } -func TestNew(t *testing.T) { +func TestNewJWT(t *testing.T) { t.Run("should run validate", func(t *testing.T) { - ca, err := New("", nil) + j, err := NewJWT("", nil) require.ErrorContains(t, err, "validation error:") - assert.Nil(t, ca) + assert.Nil(t, j) }) tCid := "test-client-id" @@ -59,13 +59,13 @@ func TestNew(t *testing.T) { cid string aud []string opts []Option - check func(*testing.T, *ClientAssertion) + check func(*testing.T, *JWT) }{ { name: "with private key", cid: tCid, aud: tAud, opts: []Option{WithRSAKey(&rsa.PrivateKey{}, "test-alg")}, - check: func(t *testing.T, ca *ClientAssertion) { + check: func(t *testing.T, ca *JWT) { require.NotNil(t, ca.key) require.Equal(t, jose.SignatureAlgorithm("test-alg"), ca.alg) }, @@ -74,7 +74,7 @@ func TestNew(t *testing.T) { name: "with client secret", cid: tCid, aud: tAud, opts: []Option{WithClientSecret("ssshhhh", "test-alg")}, - check: func(t *testing.T, ca *ClientAssertion) { + check: func(t *testing.T, ca *JWT) { require.Equal(t, "ssshhhh", ca.secret) require.Equal(t, jose.SignatureAlgorithm("test-alg"), ca.alg) }, @@ -86,7 +86,7 @@ func TestNew(t *testing.T) { WithKeyID("kid"), WithClientSecret("ssshhhh", "blah"), }, - check: func(t *testing.T, ca *ClientAssertion) { + check: func(t *testing.T, ca *JWT) { require.Equal(t, "kid", ca.headers["kid"]) }, }, @@ -97,7 +97,7 @@ func TestNew(t *testing.T) { WithHeaders(map[string]string{"h1": "v1", "h2": "v2"}), WithClientSecret("ssshhhh", "test-alg"), }, - check: func(t *testing.T, ca *ClientAssertion) { + check: func(t *testing.T, ca *JWT) { require.Equal(t, map[string]string{"h1": "v1", "h2": "v2"}, ca.headers) }, }, @@ -105,15 +105,15 @@ func TestNew(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - ca, err := New(tc.cid, tc.aud, tc.opts...) + j, err := NewJWT(tc.cid, tc.aud, tc.opts...) require.NoError(t, err) - require.NotNil(t, ca) - require.Equal(t, tc.cid, ca.clientID) - require.Equal(t, tc.aud, ca.audience) + require.NotNil(t, j) + require.Equal(t, tc.cid, j.clientID) + require.Equal(t, tc.aud, j.audience) if tc.check != nil { - tc.check(t, ca) + tc.check(t, j) } }) @@ -168,16 +168,16 @@ func TestValidate(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - // New() runs Validate() - ca, err := New(tc.cid, tc.aud, tc.opts...) + // NewJWT() runs Validate() + j, err := NewJWT(tc.cid, tc.aud, tc.opts...) require.NotNil(t, err) require.ErrorContains(t, err, "validation error:") - err = errors.Unwrap(err) // New wraps the error from Validate() with fmt.Errorf("%w") + err = errors.Unwrap(err) // NewJWT wraps the error from Validate() with fmt.Errorf("%w") assertJoinedErrs(t, tc.errs, err) - require.Nil(t, ca) + require.Nil(t, j) }) } @@ -225,15 +225,15 @@ func TestSignedToken(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - ca, err := New("test-client-id", []string{"test-aud"}, tc.opts...) + j, err := NewJWT("test-client-id", []string{"test-aud"}, tc.opts...) require.NoError(t, err) now := time.Now() - ca.now = func() time.Time { return now } - ca.genID = func() (string, error) { return "test-claim-id", nil } + j.now = func() time.Time { return now } + j.genID = func() (string, error) { return "test-claim-id", nil } // method under test - tokenString, err := ca.SignedToken() + tokenString, err := j.SignedToken() if tc.err != nil { require.ErrorIs(t, err, tc.err) @@ -248,7 +248,7 @@ func TestSignedToken(t *testing.T) { // check headers expectHeaders := jose.Header{ - Algorithm: string(ca.alg), + Algorithm: string(j.alg), KeyID: "test-key-id", ExtraHeaders: map[jose.HeaderKey]any{ "typ": "JWT", @@ -277,10 +277,10 @@ func TestSignedToken(t *testing.T) { t.Run("error generating token id", func(t *testing.T) { genIDErr := errors.New("failed to generate test id") - ca, err := New("a", []string{"a"}, WithClientSecret("ssshhhh", "HS256")) + j, err := NewJWT("a", []string{"a"}, WithClientSecret("ssshhhh", "HS256")) require.NoError(t, err) - ca.genID = func() (string, error) { return "", genIDErr } - tokenString, err := ca.SignedToken() + j.genID = func() (string, error) { return "", genIDErr } + tokenString, err := j.SignedToken() require.ErrorIs(t, err, genIDErr) require.Equal(t, "", tokenString) }) diff --git a/oidc/clientassertion/doc.go b/oidc/clientassertion/doc.go index f01b0a5..a28fc65 100644 --- a/oidc/clientassertion/doc.go +++ b/oidc/clientassertion/doc.go @@ -6,7 +6,7 @@ package clientassertion // // Example usage: // -// cass, err := clientassertion.New("client-id", []string{"audience"}, +// cass, err := clientassertion.NewJWT("client-id", []string{"audience"}, // clientassertion.WithRSAKey(rsaPrivateKey, "RS256"), // clientassertion.WithKeyID("jwks-key-id-or-x5t-etc"), // ) From b40fd48f05bbe83bf1c529bfa674d8ae9b0ffd3c Mon Sep 17 00:00:00 2001 From: Daniel Bennett Date: Tue, 25 Feb 2025 16:46:46 -0500 Subject: [PATCH 06/27] move options to options.go --- oidc/clientassertion/client_assertion.go | 37 -------------------- oidc/clientassertion/options.go | 43 ++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 37 deletions(-) create mode 100644 oidc/clientassertion/options.go diff --git a/oidc/clientassertion/client_assertion.go b/oidc/clientassertion/client_assertion.go index 269d8e8..c1fc524 100644 --- a/oidc/clientassertion/client_assertion.go +++ b/oidc/clientassertion/client_assertion.go @@ -1,7 +1,6 @@ package clientassertion import ( - "crypto/rsa" "errors" "fmt" "time" @@ -168,39 +167,3 @@ func (j *JWT) claims(id string) *jwt.Claims { ID: id, } } - -// Options configure the JWT -type Option func(*JWT) - -// WithClientSecret sets a secret and algorithm to sign the JWT with -func WithClientSecret(secret string, alg string) Option { - return func(j *JWT) { - j.secret = secret - j.alg = jose.SignatureAlgorithm(alg) - } -} - -// WithRSAKey sets a private key to sign the JWT with -func WithRSAKey(key *rsa.PrivateKey, alg string) Option { - return func(j *JWT) { - j.key = key - j.alg = jose.SignatureAlgorithm(alg) - } -} - -// WithKeyID sets the "kid" header that OIDC providers use to look up the -// public key to check the signed JWT -func WithKeyID(keyID string) Option { - return func(j *JWT) { - j.headers["kid"] = keyID - } -} - -// WithHeaders sets extra JWT headers -func WithHeaders(h map[string]string) Option { - return func(j *JWT) { - for k, v := range h { - j.headers[k] = v - } - } -} diff --git a/oidc/clientassertion/options.go b/oidc/clientassertion/options.go new file mode 100644 index 0000000..ce2b34e --- /dev/null +++ b/oidc/clientassertion/options.go @@ -0,0 +1,43 @@ +package clientassertion + +import ( + "crypto/rsa" + + "github.com/go-jose/go-jose/v3" +) + +// Option configures the JWT +type Option func(*JWT) + +// WithClientSecret sets a secret and algorithm to sign the JWT with +func WithClientSecret(secret string, alg string) Option { + return func(j *JWT) { + j.secret = secret + j.alg = jose.SignatureAlgorithm(alg) + } +} + +// WithRSAKey sets a private key to sign the JWT with +func WithRSAKey(key *rsa.PrivateKey, alg string) Option { + return func(j *JWT) { + j.key = key + j.alg = jose.SignatureAlgorithm(alg) + } +} + +// WithKeyID sets the "kid" header that OIDC providers use to look up the +// public key to check the signed JWT +func WithKeyID(keyID string) Option { + return func(j *JWT) { + j.headers["kid"] = keyID + } +} + +// WithHeaders sets extra JWT headers +func WithHeaders(h map[string]string) Option { + return func(j *JWT) { + for k, v := range h { + j.headers[k] = v + } + } +} From 4dfbc63078dccea033f91111302732892dd681f8 Mon Sep 17 00:00:00 2001 From: Daniel Bennett Date: Tue, 25 Feb 2025 16:55:14 -0500 Subject: [PATCH 07/27] go-jose/v3->v4 v4 is more sensitive to HMAC length --- go.mod | 2 +- oidc/clientassertion/client_assertion.go | 6 +++--- oidc/clientassertion/client_assertion_test.go | 21 ++++++++++--------- oidc/clientassertion/options.go | 2 +- 4 files changed, 16 insertions(+), 15 deletions(-) diff --git a/go.mod b/go.mod index 2c30530..a3e2a3e 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.21 require ( github.com/coreos/go-oidc/v3 v3.11.0 github.com/go-jose/go-jose/v3 v3.0.3 + github.com/go-jose/go-jose/v4 v4.0.4 github.com/hashicorp/go-cleanhttp v0.5.2 github.com/hashicorp/go-hclog v1.6.3 github.com/hashicorp/go-multierror v1.1.1 @@ -20,7 +21,6 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/fatih/color v1.17.0 // indirect - github.com/go-jose/go-jose/v4 v4.0.4 // indirect github.com/google/go-cmp v0.6.0 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/mattn/go-colorable v0.1.13 // indirect diff --git a/oidc/clientassertion/client_assertion.go b/oidc/clientassertion/client_assertion.go index c1fc524..affdef2 100644 --- a/oidc/clientassertion/client_assertion.go +++ b/oidc/clientassertion/client_assertion.go @@ -5,8 +5,8 @@ import ( "fmt" "time" - "github.com/go-jose/go-jose/v3" - "github.com/go-jose/go-jose/v3/jwt" + "github.com/go-jose/go-jose/v4" + "github.com/go-jose/go-jose/v4/jwt" "github.com/hashicorp/go-uuid" ) @@ -107,7 +107,7 @@ func (j *JWT) SignedToken() (string, error) { if err != nil { return "", err } - token, err := builder.CompactSerialize() + token, err := builder.Serialize() if err != nil { return "", fmt.Errorf("failed to sign token: %w", err) } diff --git a/oidc/clientassertion/client_assertion_test.go b/oidc/clientassertion/client_assertion_test.go index bad9eae..7b3013a 100644 --- a/oidc/clientassertion/client_assertion_test.go +++ b/oidc/clientassertion/client_assertion_test.go @@ -7,8 +7,8 @@ import ( "testing" "time" - "github.com/go-jose/go-jose/v3" - "github.com/go-jose/go-jose/v3/jwt" + "github.com/go-jose/go-jose/v4" + "github.com/go-jose/go-jose/v4/jwt" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -188,6 +188,7 @@ func TestSignedToken(t *testing.T) { require.NoError(t, err) pub, ok := key.Public().(*rsa.PublicKey) require.True(t, ok, "couldn't get rsa.PublicKey from PrivateKey") + validSecret := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" // 32 bytes for HS256 cases := []struct { name string @@ -197,9 +198,9 @@ func TestSignedToken(t *testing.T) { }{ { name: "valid secret", - claimKey: []byte("ssshhhh"), + claimKey: []byte(validSecret), opts: []Option{ - WithClientSecret("ssshhhh", "HS256"), + WithClientSecret(validSecret, "HS256"), WithKeyID("test-key-id"), WithHeaders(map[string]string{"xtra": "headies"}), }, @@ -243,7 +244,7 @@ func TestSignedToken(t *testing.T) { require.NoError(t, err) // extract the token from the signed string - token, err := jwt.ParseSigned(tokenString) + token, err := jwt.ParseSigned(tokenString, []jose.SignatureAlgorithm{j.alg}) require.NoError(t, err) // check headers @@ -261,11 +262,11 @@ func TestSignedToken(t *testing.T) { // check claims expectClaims := jwt.Expected{ - Issuer: "test-client-id", - Subject: "test-client-id", - Audience: []string{"test-aud"}, - ID: "test-claim-id", - Time: now, + Issuer: "test-client-id", + Subject: "test-client-id", + AnyAudience: []string{"test-aud"}, + ID: "test-claim-id", + Time: now, } var actualClaims jwt.Claims err = token.Claims(tc.claimKey, &actualClaims) diff --git a/oidc/clientassertion/options.go b/oidc/clientassertion/options.go index ce2b34e..6276ca3 100644 --- a/oidc/clientassertion/options.go +++ b/oidc/clientassertion/options.go @@ -3,7 +3,7 @@ package clientassertion import ( "crypto/rsa" - "github.com/go-jose/go-jose/v3" + "github.com/go-jose/go-jose/v4" ) // Option configures the JWT From d8b89d1fb7f40e6b9a65486b15303fef4add4a16 Mon Sep 17 00:00:00 2001 From: Daniel Bennett Date: Tue, 25 Feb 2025 19:38:31 -0500 Subject: [PATCH 08/27] more and different option validation --- oidc/clientassertion/algorithms.go | 55 ++++++++++++ oidc/clientassertion/client_assertion.go | 10 ++- oidc/clientassertion/client_assertion_test.go | 84 +++++++++++++------ oidc/clientassertion/options.go | 36 ++++++-- 4 files changed, 151 insertions(+), 34 deletions(-) create mode 100644 oidc/clientassertion/algorithms.go diff --git a/oidc/clientassertion/algorithms.go b/oidc/clientassertion/algorithms.go new file mode 100644 index 0000000..c70c29e --- /dev/null +++ b/oidc/clientassertion/algorithms.go @@ -0,0 +1,55 @@ +package clientassertion + +import ( + "crypto/rsa" + "errors" + "fmt" +) + +type ( + SignatureAlgorithm string + HSAlgorithm SignatureAlgorithm + RSAlgorithm SignatureAlgorithm +) + +const ( + HS256 HSAlgorithm = "HS256" + HS384 HSAlgorithm = "HS384" + HS512 HSAlgorithm = "HS512" + RS256 RSAlgorithm = "RS256" + RS384 RSAlgorithm = "RS384" + RS512 RSAlgorithm = "RS512" +) + +var ( + ErrUnsupportedAlgorithm = errors.New("unsupported algorithm") + ErrInvalidSecretLength = errors.New("invalid secret length for algorithm") +) + +func (a HSAlgorithm) Validate(secret string) error { + // verify secret length based on alg + var expectLen int + switch a { + case HS256: + expectLen = 32 + case HS384: + expectLen = 48 + case HS512: + expectLen = 64 + default: + return fmt.Errorf("%w %q for client secret", ErrUnsupportedAlgorithm, a) + } + if len(secret) < expectLen { + return fmt.Errorf("%w: %q must be %d bytes long", ErrInvalidSecretLength, a, expectLen) + } + return nil +} + +func (a RSAlgorithm) Validate(key *rsa.PrivateKey) error { + switch a { + case RS256, RS384, RS512: + return key.Validate() + default: + return fmt.Errorf("%w %q for for RSA key", ErrUnsupportedAlgorithm, a) + } +} diff --git a/oidc/clientassertion/client_assertion.go b/oidc/clientassertion/client_assertion.go index affdef2..462ada1 100644 --- a/oidc/clientassertion/client_assertion.go +++ b/oidc/clientassertion/client_assertion.go @@ -38,9 +38,17 @@ func NewJWT(clientID string, audience []string, opts ...Option) (*JWT, error) { genID: uuid.GenerateUUID, now: time.Now, } + + var errs []error for _, opt := range opts { - opt(j) + if err := opt(j); err != nil { + errs = append(errs, err) + } } + if len(errs) > 0 { + return nil, errors.Join(errs...) + } + if err := j.Validate(); err != nil { return nil, fmt.Errorf("new client assertion validation error: %w", err) } diff --git a/oidc/clientassertion/client_assertion_test.go b/oidc/clientassertion/client_assertion_test.go index 7b3013a..2e281f4 100644 --- a/oidc/clientassertion/client_assertion_test.go +++ b/oidc/clientassertion/client_assertion_test.go @@ -53,6 +53,9 @@ func TestNewJWT(t *testing.T) { tCid := "test-client-id" tAud := []string{"test-audience"} + validKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + validSecret := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" // 32 bytes for HS256 cases := []struct { name string @@ -60,23 +63,24 @@ func TestNewJWT(t *testing.T) { aud []string opts []Option check func(*testing.T, *JWT) + err string }{ { name: "with private key", cid: tCid, aud: tAud, - opts: []Option{WithRSAKey(&rsa.PrivateKey{}, "test-alg")}, + opts: []Option{WithRSAKey(validKey, RS256)}, check: func(t *testing.T, ca *JWT) { require.NotNil(t, ca.key) - require.Equal(t, jose.SignatureAlgorithm("test-alg"), ca.alg) + require.Equal(t, jose.SignatureAlgorithm("RS256"), ca.alg) }, }, { name: "with client secret", cid: tCid, aud: tAud, - opts: []Option{WithClientSecret("ssshhhh", "test-alg")}, + opts: []Option{WithClientSecret(validSecret, HS256)}, check: func(t *testing.T, ca *JWT) { - require.Equal(t, "ssshhhh", ca.secret) - require.Equal(t, jose.SignatureAlgorithm("test-alg"), ca.alg) + require.Equal(t, validSecret, ca.secret) + require.Equal(t, jose.SignatureAlgorithm(HS256), ca.alg) }, }, { @@ -84,7 +88,7 @@ func TestNewJWT(t *testing.T) { cid: tCid, aud: tAud, opts: []Option{ WithKeyID("kid"), - WithClientSecret("ssshhhh", "blah"), + WithClientSecret(validSecret, HS256), }, check: func(t *testing.T, ca *JWT) { require.Equal(t, "kid", ca.headers["kid"]) @@ -95,22 +99,59 @@ func TestNewJWT(t *testing.T) { cid: tCid, aud: tAud, opts: []Option{ WithHeaders(map[string]string{"h1": "v1", "h2": "v2"}), - WithClientSecret("ssshhhh", "test-alg"), + WithClientSecret(validSecret, HS256), }, check: func(t *testing.T, ca *JWT) { require.Equal(t, map[string]string{"h1": "v1", "h2": "v2"}, ca.headers) }, }, + { + name: "invalid alg for secret", + cid: tCid, aud: tAud, + opts: []Option{ + WithClientSecret(validSecret, "ruh-roh"), + }, + err: ErrUnsupportedAlgorithm.Error(), + }, + { + name: "invalid alg for key", + cid: tCid, aud: tAud, + opts: []Option{ + WithRSAKey(validKey, "ruh-roh"), + }, + err: ErrUnsupportedAlgorithm.Error(), + }, + { + name: "invalid client secret", + cid: tCid, aud: tAud, + opts: []Option{ + WithClientSecret("invalid secret", HS256), + }, + err: ErrInvalidSecretLength.Error(), + }, + { + name: "invalid key", + cid: tCid, aud: tAud, + opts: []Option{ + WithRSAKey(&rsa.PrivateKey{}, RS256), + }, + err: "crypto/rsa: missing public modulus", + }, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { j, err := NewJWT(tc.cid, tc.aud, tc.opts...) - require.NoError(t, err) - require.NotNil(t, j) - require.Equal(t, tc.cid, j.clientID) - require.Equal(t, tc.aud, j.audience) + if tc.err == "" { + require.NoError(t, err) + require.NotNil(t, j) + require.Equal(t, tc.cid, j.clientID) + require.Equal(t, tc.aud, j.audience) + } else { + require.Error(t, err) + require.ErrorContains(t, err, tc.err) + } if tc.check != nil { tc.check(t, j) @@ -123,6 +164,9 @@ func TestNewJWT(t *testing.T) { func TestValidate(t *testing.T) { tCid := "test-client-id" tAud := []string{"test-audience"} + validKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + validSecret := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" // 32 bytes for HS256 cases := []struct { name string cid string @@ -139,7 +183,7 @@ func TestValidate(t *testing.T) { aud: tAud, errs: []error{ErrMissingClientID}, opts: []Option{ - WithRSAKey(&rsa.PrivateKey{}, "algo"), + WithRSAKey(validKey, RS256), }, }, { @@ -147,7 +191,7 @@ func TestValidate(t *testing.T) { cid: tCid, errs: []error{ErrMissingAudience}, opts: []Option{ - WithRSAKey(&rsa.PrivateKey{}, "algo"), + WithRSAKey(validKey, RS256), }, }, { @@ -159,8 +203,8 @@ func TestValidate(t *testing.T) { name: "both client and secret", cid: tCid, aud: tAud, opts: []Option{ - WithRSAKey(&rsa.PrivateKey{}, "algo"), - WithClientSecret("ssshhhh", "algo"), + WithRSAKey(validKey, RS256), + WithClientSecret(validSecret, HS256), }, errs: []error{ErrBothKeyAndSecret}, }, @@ -214,14 +258,6 @@ func TestSignedToken(t *testing.T) { WithHeaders(map[string]string{"xtra": "headies"}), }, }, - { - name: "invalid alg", - claimKey: pub, - opts: []Option{ - WithRSAKey(key, "ruh-roh"), - }, - err: jose.ErrUnsupportedAlgorithm, - }, } for _, tc := range cases { @@ -278,7 +314,7 @@ func TestSignedToken(t *testing.T) { t.Run("error generating token id", func(t *testing.T) { genIDErr := errors.New("failed to generate test id") - j, err := NewJWT("a", []string{"a"}, WithClientSecret("ssshhhh", "HS256")) + j, err := NewJWT("a", []string{"a"}, WithClientSecret(validSecret, HS256)) require.NoError(t, err) j.genID = func() (string, error) { return "", genIDErr } tokenString, err := j.SignedToken() diff --git a/oidc/clientassertion/options.go b/oidc/clientassertion/options.go index 6276ca3..e57c674 100644 --- a/oidc/clientassertion/options.go +++ b/oidc/clientassertion/options.go @@ -7,37 +7,55 @@ import ( ) // Option configures the JWT -type Option func(*JWT) +type Option func(*JWT) error -// WithClientSecret sets a secret and algorithm to sign the JWT with -func WithClientSecret(secret string, alg string) Option { - return func(j *JWT) { +// WithClientSecret sets a secret and algorithm to sign the JWT with. +// alg must be one of: +// * HS256 with a >= 32 byte secret +// * HS384 with a >= 48 byte secret +// * HS512 with a >= 64 byte secret +func WithClientSecret(secret string, alg HSAlgorithm) Option { + return func(j *JWT) error { + if err := alg.Validate(secret); err != nil { + return err + } j.secret = secret j.alg = jose.SignatureAlgorithm(alg) + return nil } } -// WithRSAKey sets a private key to sign the JWT with -func WithRSAKey(key *rsa.PrivateKey, alg string) Option { - return func(j *JWT) { +// WithRSAKey sets a private key to sign the JWT with. +// alg must be one of: +// * RS256 +// * RS384 +// * RS512 +func WithRSAKey(key *rsa.PrivateKey, alg RSAlgorithm) Option { + return func(j *JWT) error { + if err := alg.Validate(key); err != nil { + return err + } j.key = key j.alg = jose.SignatureAlgorithm(alg) + return nil } } // WithKeyID sets the "kid" header that OIDC providers use to look up the // public key to check the signed JWT func WithKeyID(keyID string) Option { - return func(j *JWT) { + return func(j *JWT) error { j.headers["kid"] = keyID + return nil } } // WithHeaders sets extra JWT headers func WithHeaders(h map[string]string) Option { - return func(j *JWT) { + return func(j *JWT) error { for k, v := range h { j.headers[k] = v } + return nil } } From a71134d9908af2656a95b75e18b68a10b657451b Mon Sep 17 00:00:00 2001 From: Daniel Bennett Date: Tue, 25 Feb 2025 20:28:10 -0500 Subject: [PATCH 09/27] WithClientAssertionJWT accepts a Serializer instead of a string, so it's harder to use incorrectly. JWT implements Serializer, but folks may provider their own. --- oidc/clientassertion/client_assertion.go | 32 +++++++++++++++---- oidc/clientassertion/client_assertion_test.go | 12 +++---- oidc/provider.go | 10 ++++-- oidc/provider_test.go | 19 +++++++++-- oidc/request.go | 24 ++++++++------ oidc/request_test.go | 9 +++--- 6 files changed, 74 insertions(+), 32 deletions(-) diff --git a/oidc/clientassertion/client_assertion.go b/oidc/clientassertion/client_assertion.go index 462ada1..e780293 100644 --- a/oidc/clientassertion/client_assertion.go +++ b/oidc/clientassertion/client_assertion.go @@ -27,6 +27,7 @@ var ( // or there's a bug somewhere. ErrMissingFuncIDGenerator = errors.New("missing IDgen func; please use NewJWT()") ErrMissingFuncNow = errors.New("missing now func; please use NewJWT()") + ErrCreatingSigner = errors.New("error creating jwt signer") ) // NewJWT sets up a new JWT to sign with a private key or client secret @@ -103,14 +104,22 @@ func (j *JWT) Validate() error { if j.key != nil && j.secret != "" { errs = append(errs, ErrBothKeyAndSecret) } - return errors.Join(errs...) -} + // if any of those fail, we have no hope. + if len(errs) > 0 { + return errors.Join(errs...) + } -// SignedToken returns a signed JWT in the compact serialization format -func (j *JWT) SignedToken() (string, error) { - if err := j.Validate(); err != nil { - return "", err + // finally, make sure Serialize() works; we can't pre-validate everything, + // and this whole thing is useless if it can't Serialize() + if _, err := j.Serialize(); err != nil { + return fmt.Errorf("serialization error during validate: %w", err) } + + return nil +} + +// Serialize returns a signed JWT string +func (j *JWT) Serialize() (string, error) { builder, err := j.builder() if err != nil { return "", err @@ -158,7 +167,7 @@ func (j *JWT) signer() (jose.Signer, error) { signer, err := jose.NewSigner(sKey, sOpts.WithType("JWT")) if err != nil { - return nil, fmt.Errorf("failed to create jwt signer: %w", err) + return nil, fmt.Errorf("%w: %w", ErrCreatingSigner, err) } return signer, nil } @@ -175,3 +184,12 @@ func (j *JWT) claims(id string) *jwt.Claims { ID: id, } } + +// Serializer is the primary interface impelmented by JWT. +type Serializer interface { + Serialize() (string, error) +} + +// ensure JWT implements Serializer, which is accepted by the oidc option +// oidc.WithClientAssertionJWT. +var _ Serializer = &JWT{} diff --git a/oidc/clientassertion/client_assertion_test.go b/oidc/clientassertion/client_assertion_test.go index 2e281f4..5af21e5 100644 --- a/oidc/clientassertion/client_assertion_test.go +++ b/oidc/clientassertion/client_assertion_test.go @@ -22,7 +22,7 @@ type joinedErrs interface { func assertJoinedErrs(t *testing.T, expect []error, actual error) { t.Helper() joined, ok := actual.(joinedErrs) // Validate() error is errors.Join()ed - require.True(t, ok, "expected Join()ed errors from Validate()") + require.True(t, ok, "expected Join()ed errors from Validate(); got: %v", actual) unwrapped := joined.Unwrap() require.ElementsMatch(t, expect, unwrapped) } @@ -32,14 +32,12 @@ func assertJoinedErrs(t *testing.T, expect []error, actual error) { func TestJWTBare(t *testing.T) { j := &JWT{} - // all public methods should return the same error(s) expect := []error{ErrMissingFuncIDGenerator, ErrMissingFuncNow} - actual := j.Validate() assertJoinedErrs(t, expect, actual) - tokenStr, err := j.SignedToken() - assertJoinedErrs(t, expect, err) + tokenStr, err := j.Serialize() + require.ErrorIs(t, err, ErrCreatingSigner) assert.Equal(t, "", tokenStr) } @@ -270,7 +268,7 @@ func TestSignedToken(t *testing.T) { j.genID = func() (string, error) { return "test-claim-id", nil } // method under test - tokenString, err := j.SignedToken() + tokenString, err := j.Serialize() if tc.err != nil { require.ErrorIs(t, err, tc.err) @@ -317,7 +315,7 @@ func TestSignedToken(t *testing.T) { j, err := NewJWT("a", []string{"a"}, WithClientSecret(validSecret, HS256)) require.NoError(t, err) j.genID = func() (string, error) { return "", genIDErr } - tokenString, err := j.SignedToken() + tokenString, err := j.Serialize() require.ErrorIs(t, err, genIDErr) require.Equal(t, "", tokenString) }) diff --git a/oidc/provider.go b/oidc/provider.go index df78d14..3b999f3 100644 --- a/oidc/provider.go +++ b/oidc/provider.go @@ -307,10 +307,16 @@ func (p *Provider) Exchange(ctx context.Context, oidcRequest Request, authorizat if oidcRequest.PKCEVerifier() != nil { authCodeOpts = append(authCodeOpts, oauth2.SetAuthURLParam("code_verifier", oidcRequest.PKCEVerifier().Verifier())) } - if oidcRequest.ClientAssertionJWT() != "" { + if oidcRequest.ClientAssertionJWT() != nil { + // by now, sufficient validation should have already occurred to prevent + // errors here, but err check again just in case. + token, err := oidcRequest.ClientAssertionJWT().Serialize() + if err != nil { + return nil, err + } authCodeOpts = append(authCodeOpts, oauth2.SetAuthURLParam("client_assertion_type", cass.ClientAssertionJWTType), - oauth2.SetAuthURLParam("client_assertion", oidcRequest.ClientAssertionJWT()), + oauth2.SetAuthURLParam("client_assertion", token), ) } oauth2Token, err := oauth2Config.Exchange(oidcCtx, authorizationCode, authCodeOpts...) diff --git a/oidc/provider_test.go b/oidc/provider_test.go index 7420500..546f2a2 100644 --- a/oidc/provider_test.go +++ b/oidc/provider_test.go @@ -501,11 +501,11 @@ func TestProvider_Exchange(t *testing.T) { ) require.NoError(t, err) - cassJWT := "test-client-assertion-jwt" + testJWT := &mockSerializer{s: "test-client-assertion-jwt"} reqWithClientAssertion, err := NewRequest( 1*time.Minute, redirect, - WithClientAssertionJWT(cassJWT), + WithClientAssertionJWT(testJWT), ) require.NoError(t, err) @@ -609,7 +609,9 @@ func TestProvider_Exchange(t *testing.T) { if tt.args.r.PKCEVerifier() != nil { tp.SetPKCEVerifier(tt.args.r.PKCEVerifier()) } - if jot := tt.args.r.ClientAssertionJWT(); jot != "" { + if j := tt.args.r.ClientAssertionJWT(); j != nil { + jot, err := j.Serialize() + require.NoError(err) tp.SetClientAssertionJWT(jot) } } @@ -1753,3 +1755,14 @@ func TestProvider_DiscoveryInfo(t *testing.T) { }) } } + +var _ JWTSerializer = &mockSerializer{} + +type mockSerializer struct { + s string + err error +} + +func (ms *mockSerializer) Serialize() (string, error) { + return ms.s, ms.err +} diff --git a/oidc/request.go b/oidc/request.go index f5dc632..f191997 100644 --- a/oidc/request.go +++ b/oidc/request.go @@ -83,11 +83,11 @@ type Request interface { // See: https://tools.ietf.org/html/rfc7636 PKCEVerifier() CodeVerifier - // ClientAssertionJWT optionally specifies a signed JWT to be used in a + // ClientAssertionJWT optionally specifies a JWT Serializer to be used in a // client_assertion token request. // // See: https://oauth.net/private-key-jwt/ - ClientAssertionJWT() string + ClientAssertionJWT() JWTSerializer // MaxAge: when authAfter is not a zero value (authTime.IsZero()) then the // id_token's auth_time claim must be after the specified time. @@ -181,7 +181,8 @@ type Req struct { // with PKCE. It suppies the required CodeVerifier for PKCE. withVerifier CodeVerifier - withClientJWT string + // withClientJWT optionally sends a JWT as a client_assertion. + withClientJWT JWTSerializer // withMaxAge: when withMaxAge.authAfter is not a zero value // (authTime.IsZero()) then the id_token's auth_time claim must be after the @@ -332,8 +333,8 @@ func (r *Req) PKCEVerifier() CodeVerifier { } // ClientAssertionJWT implements the Request.ClientAssertionJWT() interface -// function and returns the JWT string. -func (r *Req) ClientAssertionJWT() string { +// function and returns a JWTSerializer. +func (r *Req) ClientAssertionJWT() JWTSerializer { return r.withClientJWT } @@ -458,7 +459,7 @@ type reqOptions struct { withACRValues []string withState string withNonce string - withClientJWT string + withClientJWT JWTSerializer } // reqDefaults is a handy way to get the defaults at runtime and during unit @@ -519,13 +520,18 @@ func WithPKCE(v CodeVerifier) Option { } } -// WithClientAssertionJWT provides an option to send a signed JWT as a -// client_assertion. +// JWTSerializer's Serialize method returns a signed JWT or an error. +// clientassertion.JWT is a useful implementation of this interface. +type JWTSerializer interface { + Serialize() (string, error) +} + +// WithClientAssertionJWT will send a JWT as a client_assertion. // // Option is valid for: Request // // See: https://oauth.net/private-key-jwt/ -func WithClientAssertionJWT(jwt string) Option { +func WithClientAssertionJWT(jwt JWTSerializer) Option { return func(o interface{}) { switch v := o.(type) { case *reqOptions: diff --git a/oidc/request_test.go b/oidc/request_test.go index 83ebb99..0d3efa3 100644 --- a/oidc/request_test.go +++ b/oidc/request_test.go @@ -26,6 +26,7 @@ func TestNewRequest(t *testing.T) { testVerifier, err := NewCodeVerifier() require.NoError(t, err) + testJWT := &mockSerializer{s: "test-client-assertion-jwt"} tests := []struct { name string @@ -37,7 +38,7 @@ func TestNewRequest(t *testing.T) { wantAudiences []string wantScopes []string wantVerifier CodeVerifier - wantClientAssertion string + wantClientAssertion JWTSerializer wantErr bool wantIsErr error }{ @@ -50,14 +51,14 @@ func TestNewRequest(t *testing.T) { WithAudiences("bob", "alice"), WithScopes("email", "profile"), WithPKCE(testVerifier), - WithClientAssertionJWT("test-client-assertion-jwt"), + WithClientAssertionJWT(testJWT), }, wantNowFunc: testNow, wantRedirectURL: "https://bob.com", wantAudiences: []string{"bob", "alice"}, wantScopes: []string{oidc.ScopeOpenID, "email", "profile"}, wantVerifier: testVerifier, - wantClientAssertion: "test-client-assertion-jwt", + wantClientAssertion: testJWT, }, { name: "valid-no-opt", @@ -164,7 +165,7 @@ func Test_WithClientAssertionJWT(t *testing.T) { assert.Equal(opts, testOpts) assert.Empty(testOpts.withClientJWT) - j := "test-jwt" + j := &mockSerializer{s: "test-jwt"} opts = getReqOpts(WithClientAssertionJWT(j)) testOpts = reqDefaults() testOpts.withClientJWT = j From 25a0e5c33e7ed90b5bfc493113465805e81e4a99 Mon Sep 17 00:00:00 2001 From: Daniel Bennett Date: Wed, 26 Feb 2025 00:39:22 -0500 Subject: [PATCH 10/27] protect kid from headers --- oidc/clientassertion/options.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/oidc/clientassertion/options.go b/oidc/clientassertion/options.go index e57c674..210382f 100644 --- a/oidc/clientassertion/options.go +++ b/oidc/clientassertion/options.go @@ -2,6 +2,7 @@ package clientassertion import ( "crypto/rsa" + "errors" "github.com/go-jose/go-jose/v4" ) @@ -54,6 +55,11 @@ func WithKeyID(keyID string) Option { func WithHeaders(h map[string]string) Option { return func(j *JWT) error { for k, v := range h { + // disallow potential confusion arising from the "kid" header + // being set both by this and WithKeyID() + if k == "kid" { + return errors.New(`"kid" header not allowed in WithHeaders; use WithKeyID instead`) + } j.headers[k] = v } return nil From 8f9aa91bb42e49f4f298cbf80a61201e771c98f0 Mon Sep 17 00:00:00 2001 From: Daniel Bennett Date: Wed, 26 Feb 2025 13:18:23 -0500 Subject: [PATCH 11/27] add ExampleJWT test --- oidc/clientassertion/example_test.go | 99 ++++++++++++++++++++++++++++ 1 file changed, 99 insertions(+) create mode 100644 oidc/clientassertion/example_test.go diff --git a/oidc/clientassertion/example_test.go b/oidc/clientassertion/example_test.go new file mode 100644 index 0000000..fbfd8cc --- /dev/null +++ b/oidc/clientassertion/example_test.go @@ -0,0 +1,99 @@ +package clientassertion + +import ( + "crypto/rand" + "crypto/rsa" + "fmt" + "log" + + "github.com/go-jose/go-jose/v4" + "github.com/go-jose/go-jose/v4/jwt" +) + +func ExampleJWT() { + // With an HMAC client secret + secret := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" // 32 bytes for HS256 + j, err := NewJWT("client-id", []string{"audience"}, + WithClientSecret(secret, HS256), + ) + if err != nil { + log.Fatal(err) + } + signed, err := j.Serialize() + if err != nil { + log.Fatal(err) + } + + { + // decode and inspect the JWT -- this is the IDP's job, + // but it illustrates the example. + token, err := jwt.ParseSigned(signed, []jose.SignatureAlgorithm{"HS256"}) + if err != nil { + log.Fatal(err) + } + headers := token.Headers[0] + fmt.Printf("ClientSecret\n Headers - Algorithm: %s; typ: %s\n", + headers.Algorithm, headers.ExtraHeaders["typ"]) + var claim jwt.Claims + err = token.Claims([]byte(secret), &claim) + if err != nil { + log.Fatal(err) + } + fmt.Printf(" Claims - Issuer: %s; Subject: %s; Audience: %v\n", + claim.Issuer, claim.Subject, claim.Audience) + } + + // With an RSA key + privKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + log.Fatal(err) + } + pubKey, ok := privKey.Public().(*rsa.PublicKey) + if !ok { + log.Fatal("couldn't get rsa.PublicKey from PrivateKey") + } + j, err = NewJWT("client-id", []string{"audience"}, + WithRSAKey(privKey, RS256), + // note: for some providers, they key ID may be an x5t derivation + // of a cert generated from the private key. + // if your key has an associated JWKS endpoint, it will be the "kid" + // for the public key at /.well-known/jwks.json + WithKeyID("some-key-id"), + // extra headers, like x5t, are optional + WithHeaders(map[string]string{ + "x5t": "should-be-derived-from-a-cert", + }), + ) + if err != nil { + log.Fatal(err) + } + signed, err = j.Serialize() + if err != nil { + log.Fatal(err) + } + + { // decode and inspect the JWT -- this is the IDP's job + token, err := jwt.ParseSigned(signed, []jose.SignatureAlgorithm{"RS256"}) + if err != nil { + log.Fatal(err) + } + h := token.Headers[0] + fmt.Printf("PrivateKey\n Headers - KeyID: %s; Algorithm: %s; typ: %s; x5t: %s\n", + h.KeyID, h.Algorithm, h.ExtraHeaders["typ"], h.ExtraHeaders["x5t"]) + var claim jwt.Claims + err = token.Claims(pubKey, &claim) + if err != nil { + log.Fatal(err) + } + fmt.Printf(" Claims - Issuer: %s; Subject: %s; Audience: %v\n", + claim.Issuer, claim.Subject, claim.Audience) + } + + // Output: + // ClientSecret + // Headers - Algorithm: HS256; typ: JWT + // Claims - Issuer: client-id; Subject: client-id; Audience: [audience] + // PrivateKey + // Headers - KeyID: some-key-id; Algorithm: RS256; typ: JWT; x5t: should-be-derived-from-a-cert + // Claims - Issuer: client-id; Subject: client-id; Audience: [audience] +} From 1d1813c4d366ec36591d48394374f2a4f6d94460 Mon Sep 17 00:00:00 2001 From: Daniel Bennett Date: Wed, 26 Feb 2025 13:33:00 -0500 Subject: [PATCH 12/27] brief package docstring, rename a const --- oidc/clientassertion/client_assertion.go | 7 +++++-- oidc/clientassertion/doc.go | 12 ------------ oidc/provider.go | 2 +- 3 files changed, 6 insertions(+), 15 deletions(-) diff --git a/oidc/clientassertion/client_assertion.go b/oidc/clientassertion/client_assertion.go index e780293..01f4fa7 100644 --- a/oidc/clientassertion/client_assertion.go +++ b/oidc/clientassertion/client_assertion.go @@ -1,3 +1,6 @@ +// Package clientassertion signs JWTs with a Private Key or Client Secret +// for use in OIDC client_assertion requests, A.K.A. private_key_jwt. +// reference: https://oauth.net/private-key-jwt/ package clientassertion import ( @@ -11,9 +14,9 @@ import ( ) const ( - // ClientAssertionJWTType is the proper value for client_assertion_type. + // JWTTypeParam is the proper value for client_assertion_type. // https://www.rfc-editor.org/rfc/rfc7523.html#section-2.2 - ClientAssertionJWTType = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" + JWTTypeParam = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" ) var ( diff --git a/oidc/clientassertion/doc.go b/oidc/clientassertion/doc.go index a28fc65..0217da9 100644 --- a/oidc/clientassertion/doc.go +++ b/oidc/clientassertion/doc.go @@ -1,13 +1 @@ package clientassertion - -// clientassertion signs JWTs with a Private Key or Client Secret for use -// in OIDC client_assertion requests, A.K.A. private_key_jwt. reference: -// https://oauth.net/private-key-jwt/ -// -// Example usage: -// -// cass, err := clientassertion.NewJWT("client-id", []string{"audience"}, -// clientassertion.WithRSAKey(rsaPrivateKey, "RS256"), -// clientassertion.WithKeyID("jwks-key-id-or-x5t-etc"), -// ) -// jwtString, err := cass.SignedToken() diff --git a/oidc/provider.go b/oidc/provider.go index 3b999f3..81bbac5 100644 --- a/oidc/provider.go +++ b/oidc/provider.go @@ -315,7 +315,7 @@ func (p *Provider) Exchange(ctx context.Context, oidcRequest Request, authorizat return nil, err } authCodeOpts = append(authCodeOpts, - oauth2.SetAuthURLParam("client_assertion_type", cass.ClientAssertionJWTType), + oauth2.SetAuthURLParam("client_assertion_type", cass.JWTTypeParam), oauth2.SetAuthURLParam("client_assertion", token), ) } From cf24de8c52dd6015e0c110fdaf6c6f44656f2bba Mon Sep 17 00:00:00 2001 From: Daniel Bennett Date: Wed, 26 Feb 2025 13:34:34 -0500 Subject: [PATCH 13/27] add copyright headers --- oidc/clientassertion/algorithms.go | 3 +++ oidc/clientassertion/client_assertion.go | 3 +++ oidc/clientassertion/client_assertion_test.go | 3 +++ oidc/clientassertion/doc.go | 1 - oidc/clientassertion/example_test.go | 3 +++ oidc/clientassertion/options.go | 3 +++ 6 files changed, 15 insertions(+), 1 deletion(-) delete mode 100644 oidc/clientassertion/doc.go diff --git a/oidc/clientassertion/algorithms.go b/oidc/clientassertion/algorithms.go index c70c29e..a172f38 100644 --- a/oidc/clientassertion/algorithms.go +++ b/oidc/clientassertion/algorithms.go @@ -1,3 +1,6 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + package clientassertion import ( diff --git a/oidc/clientassertion/client_assertion.go b/oidc/clientassertion/client_assertion.go index 01f4fa7..cd2c94e 100644 --- a/oidc/clientassertion/client_assertion.go +++ b/oidc/clientassertion/client_assertion.go @@ -1,3 +1,6 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + // Package clientassertion signs JWTs with a Private Key or Client Secret // for use in OIDC client_assertion requests, A.K.A. private_key_jwt. // reference: https://oauth.net/private-key-jwt/ diff --git a/oidc/clientassertion/client_assertion_test.go b/oidc/clientassertion/client_assertion_test.go index 5af21e5..bd5e3fb 100644 --- a/oidc/clientassertion/client_assertion_test.go +++ b/oidc/clientassertion/client_assertion_test.go @@ -1,3 +1,6 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + package clientassertion import ( diff --git a/oidc/clientassertion/doc.go b/oidc/clientassertion/doc.go deleted file mode 100644 index 0217da9..0000000 --- a/oidc/clientassertion/doc.go +++ /dev/null @@ -1 +0,0 @@ -package clientassertion diff --git a/oidc/clientassertion/example_test.go b/oidc/clientassertion/example_test.go index fbfd8cc..df9db1c 100644 --- a/oidc/clientassertion/example_test.go +++ b/oidc/clientassertion/example_test.go @@ -1,3 +1,6 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + package clientassertion import ( diff --git a/oidc/clientassertion/options.go b/oidc/clientassertion/options.go index 210382f..4a52198 100644 --- a/oidc/clientassertion/options.go +++ b/oidc/clientassertion/options.go @@ -1,3 +1,6 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + package clientassertion import ( From 9f0b3938c25ee561d8920b1567b1b293cdbd9a82 Mon Sep 17 00:00:00 2001 From: Daniel Bennett Date: Wed, 26 Feb 2025 13:36:40 -0500 Subject: [PATCH 14/27] missed a rename --- oidc/testing_provider.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oidc/testing_provider.go b/oidc/testing_provider.go index 48fe2be..3125eb8 100644 --- a/oidc/testing_provider.go +++ b/oidc/testing_provider.go @@ -1382,7 +1382,7 @@ func (p *TestProvider) ServeHTTP(w http.ResponseWriter, req *http.Request) { if t := req.FormValue("client_assertion_type"); t != "" { // assume client_assertion_type, if set, is always the magic jwt string - if t != cass.ClientAssertionJWTType { + if t != cass.JWTTypeParam { _ = p.writeTokenErrorResponse(w, http.StatusBadRequest, "invalid_client", "unknown client assertion type") return } From 95b5c034c61cdf1bb372fc04bc17ea6e4114ffde Mon Sep 17 00:00:00 2001 From: Daniel Bennett Date: Wed, 26 Feb 2025 15:54:05 -0500 Subject: [PATCH 15/27] test the new method on test provider --- oidc/testing_provider_test.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/oidc/testing_provider_test.go b/oidc/testing_provider_test.go index 40bab95..aa06204 100644 --- a/oidc/testing_provider_test.go +++ b/oidc/testing_provider_test.go @@ -419,6 +419,14 @@ func TestTestProvider_SetPKCEVerifier(t *testing.T) { }) } +func TestTestProvider_SetClientAssertionJWT(t *testing.T) { + t.Run("simple", func(t *testing.T) { + tp := StartTestProvider(t) + tp.SetClientAssertionJWT("expected") + assert.Equal(t, "expected", tp.clientAssertionJWT) + }) +} + func TestTestProvider_SetUserInfoReply(t *testing.T) { t.Run("simple", func(t *testing.T) { assert := assert.New(t) From c06d55938e75ec03c7291c668f358c5e34c3e54c Mon Sep 17 00:00:00 2001 From: Daniel Bennett Date: Wed, 26 Feb 2025 16:45:14 -0500 Subject: [PATCH 16/27] godocs and ops everywhere! --- oidc/clientassertion/algorithms.go | 52 +++++++++++++++-------- oidc/clientassertion/client_assertion.go | 54 ++++++++++++------------ oidc/clientassertion/options.go | 13 +++--- 3 files changed, 69 insertions(+), 50 deletions(-) diff --git a/oidc/clientassertion/algorithms.go b/oidc/clientassertion/algorithms.go index a172f38..dc256ac 100644 --- a/oidc/clientassertion/algorithms.go +++ b/oidc/clientassertion/algorithms.go @@ -5,31 +5,38 @@ package clientassertion import ( "crypto/rsa" - "errors" "fmt" ) type ( - SignatureAlgorithm string - HSAlgorithm SignatureAlgorithm - RSAlgorithm SignatureAlgorithm + // HSAlgorithm is an HMAC signature algorithm + HSAlgorithm string + // RSAlgorithm is an RSA signature algorithm + RSAlgorithm string ) const ( - HS256 HSAlgorithm = "HS256" - HS384 HSAlgorithm = "HS384" - HS512 HSAlgorithm = "HS512" - RS256 RSAlgorithm = "RS256" - RS384 RSAlgorithm = "RS384" - RS512 RSAlgorithm = "RS512" -) + // JOSE asymmetric signing algorithm values as defined by RFC 7518. + // See: https://tools.ietf.org/html/rfc7518#section-3.1 -var ( - ErrUnsupportedAlgorithm = errors.New("unsupported algorithm") - ErrInvalidSecretLength = errors.New("invalid secret length for algorithm") + HS256 HSAlgorithm = "HS256" // HMAC using SHA-256 + HS384 HSAlgorithm = "HS384" // HMAC using SHA-384 + HS512 HSAlgorithm = "HS512" // HMAC using SHA-512 + RS256 RSAlgorithm = "RS256" // RSASSA-PKCS-v1.5 using SHA-256 + RS384 RSAlgorithm = "RS384" // RSASSA-PKCS-v1.5 using SHA-384 + RS512 RSAlgorithm = "RS512" // RSASSA-PKCS-v1.5 using SHA-512 ) +// Validate checks that the secret is a supported algorithm and that it's +// the proper length for the HSAlgorithm: +// - HS256: >= 32 bytes +// - HS384: >= 48 bytes +// - HS512: >= 64 bytes func (a HSAlgorithm) Validate(secret string) error { + const op = "HSAlgorithm.Validate" + if secret == "" { + return fmt.Errorf("%w: empty", ErrInvalidSecretLength) + } // verify secret length based on alg var expectLen int switch a { @@ -40,19 +47,28 @@ func (a HSAlgorithm) Validate(secret string) error { case HS512: expectLen = 64 default: - return fmt.Errorf("%w %q for client secret", ErrUnsupportedAlgorithm, a) + return fmt.Errorf("%s: %w %q for client secret", op, ErrUnsupportedAlgorithm, a) } if len(secret) < expectLen { - return fmt.Errorf("%w: %q must be %d bytes long", ErrInvalidSecretLength, a, expectLen) + return fmt.Errorf("%s: %w: %q must be %d bytes long", op, ErrInvalidSecretLength, a, expectLen) } return nil } +// Validate checks that the key is a supported algorithm and is valid per +// rsa.PrivateKey's Validate() method. func (a RSAlgorithm) Validate(key *rsa.PrivateKey) error { + const op = "RSAlgorithm.Validate" + if key == nil { + return fmt.Errorf("%s: %w", op, ErrNilPrivateKey) + } switch a { case RS256, RS384, RS512: - return key.Validate() + if err := key.Validate(); err != nil { + return fmt.Errorf("%s: %w", op, err) + } + return nil default: - return fmt.Errorf("%w %q for for RSA key", ErrUnsupportedAlgorithm, a) + return fmt.Errorf("%s: %w %q for for RSA key", op, ErrUnsupportedAlgorithm, a) } } diff --git a/oidc/clientassertion/client_assertion.go b/oidc/clientassertion/client_assertion.go index cd2c94e..b0d58d4 100644 --- a/oidc/clientassertion/client_assertion.go +++ b/oidc/clientassertion/client_assertion.go @@ -22,22 +22,18 @@ const ( JWTTypeParam = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" ) -var ( - // these may happen due to user error - ErrMissingClientID = errors.New("missing client ID") - ErrMissingAudience = errors.New("missing audience") - ErrMissingAlgorithm = errors.New("missing signing algorithm") - ErrMissingKeyOrSecret = errors.New("missing private key or client secret") - ErrBothKeyAndSecret = errors.New("both private key and client secret provided") - // if these happen, either the user directly instantiated &JWT{} - // or there's a bug somewhere. - ErrMissingFuncIDGenerator = errors.New("missing IDgen func; please use NewJWT()") - ErrMissingFuncNow = errors.New("missing now func; please use NewJWT()") - ErrCreatingSigner = errors.New("error creating jwt signer") -) - -// NewJWT sets up a new JWT to sign with a private key or client secret +// NewJWT creates a new JWT which will be signed with either a private key or +// client secret. +// +// Supported Options: +// * WithClientSecret +// * WithRSAKey +// * WithKeyID +// * WithHeaders +// +// Either WithRSAKey or WithClientSecret must be used, but not both. func NewJWT(clientID string, audience []string, opts ...Option) (*JWT, error) { + const op = "NewJWT" j := &JWT{ clientID: clientID, audience: audience, @@ -57,12 +53,13 @@ func NewJWT(clientID string, audience []string, opts ...Option) (*JWT, error) { } if err := j.Validate(); err != nil { - return nil, fmt.Errorf("new client assertion validation error: %w", err) + return nil, fmt.Errorf("%s: %w", op, err) } return j, nil } -// JWT signs a JWT with either a private key or a secret +// JWT is used to create a client assertion JWT, a special JWT used by an OAuth +// 2.0 or OIDC client to authenticate themselves to an authorization server type JWT struct { // for JWT claims clientID string @@ -83,6 +80,7 @@ type JWT struct { // Validate validates the expected fields func (j *JWT) Validate() error { + const op = "JWT.Validate" var errs []error if j.genID == nil { errs = append(errs, ErrMissingFuncIDGenerator) @@ -92,7 +90,7 @@ func (j *JWT) Validate() error { } // bail early if any internal func errors if len(errs) > 0 { - return errors.Join(errs...) + return fmt.Errorf("%s: %w", op, errors.Join(errs...)) } if j.clientID == "" { @@ -112,45 +110,49 @@ func (j *JWT) Validate() error { } // if any of those fail, we have no hope. if len(errs) > 0 { - return errors.Join(errs...) + return fmt.Errorf("%s: %w", op, errors.Join(errs...)) } // finally, make sure Serialize() works; we can't pre-validate everything, // and this whole thing is useless if it can't Serialize() if _, err := j.Serialize(); err != nil { - return fmt.Errorf("serialization error during validate: %w", err) + return fmt.Errorf("%s: serialization error during validate: %w", op, err) } return nil } -// Serialize returns a signed JWT string +// Serialize returns client assertion JWT which can be used by an OAuth 2.0 or +// OIDC client to authenticate themselves to an authorization server func (j *JWT) Serialize() (string, error) { + const op = "JWT.Serialize" builder, err := j.builder() if err != nil { - return "", err + return "", fmt.Errorf("%s: %w", op, err) } token, err := builder.Serialize() if err != nil { - return "", fmt.Errorf("failed to sign token: %w", err) + return "", fmt.Errorf("%s: failed to sign token: %w", op, err) } return token, nil } func (j *JWT) builder() (jwt.Builder, error) { + const op = "builder" signer, err := j.signer() if err != nil { - return nil, err + return nil, fmt.Errorf("%s: %w", op, err) } id, err := j.genID() if err != nil { - return nil, fmt.Errorf("failed to generate token id: %w", err) + return nil, fmt.Errorf("%s: failed to generate token id: %w", op, err) } claims := j.claims(id) return jwt.Signed(signer).Claims(claims), nil } func (j *JWT) signer() (jose.Signer, error) { + const op = "signer" sKey := jose.SigningKey{ Algorithm: j.alg, } @@ -173,7 +175,7 @@ func (j *JWT) signer() (jose.Signer, error) { signer, err := jose.NewSigner(sKey, sOpts.WithType("JWT")) if err != nil { - return nil, fmt.Errorf("%w: %w", ErrCreatingSigner, err) + return nil, fmt.Errorf("%s: %w: %w", op, ErrCreatingSigner, err) } return signer, nil } diff --git a/oidc/clientassertion/options.go b/oidc/clientassertion/options.go index 4a52198..0cb04bd 100644 --- a/oidc/clientassertion/options.go +++ b/oidc/clientassertion/options.go @@ -5,7 +5,7 @@ package clientassertion import ( "crypto/rsa" - "errors" + "fmt" "github.com/go-jose/go-jose/v4" ) @@ -35,9 +35,10 @@ func WithClientSecret(secret string, alg HSAlgorithm) Option { // * RS384 // * RS512 func WithRSAKey(key *rsa.PrivateKey, alg RSAlgorithm) Option { + const op = "WithRSAKey" return func(j *JWT) error { if err := alg.Validate(key); err != nil { - return err + return fmt.Errorf("%s: %w", op, err) } j.key = key j.alg = jose.SignatureAlgorithm(alg) @@ -54,14 +55,14 @@ func WithKeyID(keyID string) Option { } } -// WithHeaders sets extra JWT headers +// WithHeaders sets extra JWT headers. +// Do not set a "kid" header here; instead use WithKeyID. func WithHeaders(h map[string]string) Option { + const op = "WithHeaders" return func(j *JWT) error { for k, v := range h { - // disallow potential confusion arising from the "kid" header - // being set both by this and WithKeyID() if k == "kid" { - return errors.New(`"kid" header not allowed in WithHeaders; use WithKeyID instead`) + return fmt.Errorf(`%s: "kid" header not allowed; use WithKeyID instead`, op) } j.headers[k] = v } From 70ea140f8c1cb611aed8ab783f713b079fe21d17 Mon Sep 17 00:00:00 2001 From: Daniel Bennett Date: Wed, 26 Feb 2025 16:50:28 -0500 Subject: [PATCH 17/27] privatize a couple things, const KeyIDHeader --- oidc/clientassertion/client_assertion.go | 57 ++++++++++++------------ oidc/clientassertion/options.go | 6 ++- 2 files changed, 33 insertions(+), 30 deletions(-) diff --git a/oidc/clientassertion/client_assertion.go b/oidc/clientassertion/client_assertion.go index b0d58d4..3865c68 100644 --- a/oidc/clientassertion/client_assertion.go +++ b/oidc/clientassertion/client_assertion.go @@ -52,9 +52,16 @@ func NewJWT(clientID string, audience []string, opts ...Option) (*JWT, error) { return nil, errors.Join(errs...) } - if err := j.Validate(); err != nil { + if err := j.validate(); err != nil { return nil, fmt.Errorf("%s: %w", op, err) } + + // finally, make sure Serialize() works; we can't pre-validate everything, + // and this whole thing is useless if it can't Serialize() + if _, err := j.Serialize(); err != nil { + return nil, fmt.Errorf("%s: %w", op, err) + } + return j, nil } @@ -78,9 +85,23 @@ type JWT struct { now func() time.Time } -// Validate validates the expected fields -func (j *JWT) Validate() error { - const op = "JWT.Validate" +// Serialize returns client assertion JWT which can be used by an OAuth 2.0 or +// OIDC client to authenticate themselves to an authorization server +func (j *JWT) Serialize() (string, error) { + const op = "JWT.Serialize" + builder, err := j.builder() + if err != nil { + return "", fmt.Errorf("%s: %w", op, err) + } + token, err := builder.Serialize() + if err != nil { + return "", fmt.Errorf("%s: failed to sign token: %w", op, err) + } + return token, nil +} + +func (j *JWT) validate() error { + const op = "JWT.validate" var errs []error if j.genID == nil { errs = append(errs, ErrMissingFuncIDGenerator) @@ -113,30 +134,9 @@ func (j *JWT) Validate() error { return fmt.Errorf("%s: %w", op, errors.Join(errs...)) } - // finally, make sure Serialize() works; we can't pre-validate everything, - // and this whole thing is useless if it can't Serialize() - if _, err := j.Serialize(); err != nil { - return fmt.Errorf("%s: serialization error during validate: %w", op, err) - } - return nil } -// Serialize returns client assertion JWT which can be used by an OAuth 2.0 or -// OIDC client to authenticate themselves to an authorization server -func (j *JWT) Serialize() (string, error) { - const op = "JWT.Serialize" - builder, err := j.builder() - if err != nil { - return "", fmt.Errorf("%s: %w", op, err) - } - token, err := builder.Serialize() - if err != nil { - return "", fmt.Errorf("%s: failed to sign token: %w", op, err) - } - return token, nil -} - func (j *JWT) builder() (jwt.Builder, error) { const op = "builder" signer, err := j.signer() @@ -168,7 +168,6 @@ func (j *JWT) signer() (jose.Signer, error) { sOpts := &jose.SignerOptions{ ExtraHeaders: make(map[jose.HeaderKey]interface{}, len(j.headers)), } - // note: extra headers can override "kid" for k, v := range j.headers { sOpts.ExtraHeaders[jose.HeaderKey(k)] = v } @@ -193,11 +192,11 @@ func (j *JWT) claims(id string) *jwt.Claims { } } -// Serializer is the primary interface impelmented by JWT. -type Serializer interface { +// serializer is the primary interface impelmented by JWT. +type serializer interface { Serialize() (string, error) } // ensure JWT implements Serializer, which is accepted by the oidc option // oidc.WithClientAssertionJWT. -var _ Serializer = &JWT{} +var _ serializer = &JWT{} diff --git a/oidc/clientassertion/options.go b/oidc/clientassertion/options.go index 0cb04bd..71c20cd 100644 --- a/oidc/clientassertion/options.go +++ b/oidc/clientassertion/options.go @@ -10,6 +10,10 @@ import ( "github.com/go-jose/go-jose/v4" ) +// KeyIDHeader is the "kid" header on a JWT, which providers use to look up +// the right public key to verify the JWT. +const KeyIDHeader = "kid" + // Option configures the JWT type Option func(*JWT) error @@ -61,7 +65,7 @@ func WithHeaders(h map[string]string) Option { const op = "WithHeaders" return func(j *JWT) error { for k, v := range h { - if k == "kid" { + if k == KeyIDHeader { return fmt.Errorf(`%s: "kid" header not allowed; use WithKeyID instead`, op) } j.headers[k] = v From c871e25a552bb97952b08a0ffd18168aa10204b8 Mon Sep 17 00:00:00 2001 From: Daniel Bennett Date: Wed, 26 Feb 2025 16:51:49 -0500 Subject: [PATCH 18/27] errors go in error.go --- oidc/clientassertion/error.go | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 oidc/clientassertion/error.go diff --git a/oidc/clientassertion/error.go b/oidc/clientassertion/error.go new file mode 100644 index 0000000..0fb5546 --- /dev/null +++ b/oidc/clientassertion/error.go @@ -0,0 +1,29 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package clientassertion + +import "errors" + +var ( + // these may happen due to user error + + ErrMissingClientID = errors.New("missing client ID") + ErrMissingAudience = errors.New("missing audience") + ErrMissingAlgorithm = errors.New("missing signing algorithm") + ErrMissingKeyOrSecret = errors.New("missing private key or client secret") + ErrBothKeyAndSecret = errors.New("both private key and client secret provided") + + // if these happen, either the user directly instantiated &JWT{} + // or there's a bug somewhere. + + ErrMissingFuncIDGenerator = errors.New("missing IDgen func; please use NewJWT()") + ErrMissingFuncNow = errors.New("missing now func; please use NewJWT()") + ErrCreatingSigner = errors.New("error creating jwt signer") + + // algorithm errors + + ErrUnsupportedAlgorithm = errors.New("unsupported algorithm") + ErrInvalidSecretLength = errors.New("invalid secret length for algorithm") + ErrNilPrivateKey = errors.New("nil private key") +) From e6d416e9c1fd229ea6d3c3e0f21c62e6c601a83c Mon Sep 17 00:00:00 2001 From: Daniel Bennett Date: Wed, 26 Feb 2025 16:58:13 -0500 Subject: [PATCH 19/27] error if missing kid --- oidc/clientassertion/error.go | 1 + oidc/clientassertion/options.go | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/oidc/clientassertion/error.go b/oidc/clientassertion/error.go index 0fb5546..f7ba929 100644 --- a/oidc/clientassertion/error.go +++ b/oidc/clientassertion/error.go @@ -11,6 +11,7 @@ var ( ErrMissingClientID = errors.New("missing client ID") ErrMissingAudience = errors.New("missing audience") ErrMissingAlgorithm = errors.New("missing signing algorithm") + ErrMissingKeyID = errors.New("missing key ID") ErrMissingKeyOrSecret = errors.New("missing private key or client secret") ErrBothKeyAndSecret = errors.New("both private key and client secret provided") diff --git a/oidc/clientassertion/options.go b/oidc/clientassertion/options.go index 71c20cd..7d9e91f 100644 --- a/oidc/clientassertion/options.go +++ b/oidc/clientassertion/options.go @@ -53,7 +53,11 @@ func WithRSAKey(key *rsa.PrivateKey, alg RSAlgorithm) Option { // WithKeyID sets the "kid" header that OIDC providers use to look up the // public key to check the signed JWT func WithKeyID(keyID string) Option { + const op = "WithKeyID" return func(j *JWT) error { + if keyID == "" { + return fmt.Errorf("%s: %w", op, ErrMissingKeyID) + } j.headers["kid"] = keyID return nil } From 62aaf9787b13f572b0b23f6128826108d513e5ca Mon Sep 17 00:00:00 2001 From: Daniel Bennett Date: Wed, 26 Feb 2025 17:14:07 -0500 Subject: [PATCH 20/27] fix obvious mistake it is getting late in the day! --- oidc/clientassertion/client_assertion_test.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/oidc/clientassertion/client_assertion_test.go b/oidc/clientassertion/client_assertion_test.go index bd5e3fb..5b742a4 100644 --- a/oidc/clientassertion/client_assertion_test.go +++ b/oidc/clientassertion/client_assertion_test.go @@ -35,10 +35,6 @@ func assertJoinedErrs(t *testing.T, expect []error, actual error) { func TestJWTBare(t *testing.T) { j := &JWT{} - expect := []error{ErrMissingFuncIDGenerator, ErrMissingFuncNow} - actual := j.Validate() - assertJoinedErrs(t, expect, actual) - tokenStr, err := j.Serialize() require.ErrorIs(t, err, ErrCreatingSigner) From 9981bcee40e2e17b905852e7c645a5dc6a899f2d Mon Sep 17 00:00:00 2001 From: Daniel Bennett Date: Wed, 26 Feb 2025 17:21:30 -0500 Subject: [PATCH 21/27] fix less obvious but still pretty obvious error if i'd just run my own tests (: --- oidc/clientassertion/client_assertion_test.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/oidc/clientassertion/client_assertion_test.go b/oidc/clientassertion/client_assertion_test.go index 5b742a4..b695568 100644 --- a/oidc/clientassertion/client_assertion_test.go +++ b/oidc/clientassertion/client_assertion_test.go @@ -24,7 +24,9 @@ type joinedErrs interface { func assertJoinedErrs(t *testing.T, expect []error, actual error) { t.Helper() - joined, ok := actual.(joinedErrs) // Validate() error is errors.Join()ed + // validate() error is wrapped, joined, wrapped errors + err := errors.Unwrap(actual) + joined, ok := err.(joinedErrs) require.True(t, ok, "expected Join()ed errors from Validate(); got: %v", actual) unwrapped := joined.Unwrap() require.ElementsMatch(t, expect, unwrapped) @@ -44,7 +46,7 @@ func TestJWTBare(t *testing.T) { func TestNewJWT(t *testing.T) { t.Run("should run validate", func(t *testing.T) { j, err := NewJWT("", nil) - require.ErrorContains(t, err, "validation error:") + require.ErrorContains(t, err, "validate:") assert.Nil(t, j) }) @@ -213,7 +215,7 @@ func TestValidate(t *testing.T) { j, err := NewJWT(tc.cid, tc.aud, tc.opts...) require.NotNil(t, err) - require.ErrorContains(t, err, "validation error:") + require.ErrorContains(t, err, "validate:") err = errors.Unwrap(err) // NewJWT wraps the error from Validate() with fmt.Errorf("%w") assertJoinedErrs(t, tc.errs, err) From d6932e084fe881c1fae4209a254c2957887dfe51 Mon Sep 17 00:00:00 2001 From: Daniel Bennett Date: Thu, 27 Feb 2025 12:12:53 -0500 Subject: [PATCH 22/27] add a couple missing err ops --- oidc/clientassertion/algorithms.go | 2 +- oidc/clientassertion/client_assertion.go | 2 +- oidc/provider.go | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/oidc/clientassertion/algorithms.go b/oidc/clientassertion/algorithms.go index dc256ac..c6fc38b 100644 --- a/oidc/clientassertion/algorithms.go +++ b/oidc/clientassertion/algorithms.go @@ -35,7 +35,7 @@ const ( func (a HSAlgorithm) Validate(secret string) error { const op = "HSAlgorithm.Validate" if secret == "" { - return fmt.Errorf("%w: empty", ErrInvalidSecretLength) + return fmt.Errorf("%s: %w: empty", op, ErrInvalidSecretLength) } // verify secret length based on alg var expectLen int diff --git a/oidc/clientassertion/client_assertion.go b/oidc/clientassertion/client_assertion.go index 3865c68..a774a67 100644 --- a/oidc/clientassertion/client_assertion.go +++ b/oidc/clientassertion/client_assertion.go @@ -95,7 +95,7 @@ func (j *JWT) Serialize() (string, error) { } token, err := builder.Serialize() if err != nil { - return "", fmt.Errorf("%s: failed to sign token: %w", op, err) + return "", fmt.Errorf("%s: failed to serialize token: %w", op, err) } return token, nil } diff --git a/oidc/provider.go b/oidc/provider.go index 81bbac5..7116a46 100644 --- a/oidc/provider.go +++ b/oidc/provider.go @@ -312,7 +312,7 @@ func (p *Provider) Exchange(ctx context.Context, oidcRequest Request, authorizat // errors here, but err check again just in case. token, err := oidcRequest.ClientAssertionJWT().Serialize() if err != nil { - return nil, err + return nil, fmt.Errorf("%s: %w", op, err) } authCodeOpts = append(authCodeOpts, oauth2.SetAuthURLParam("client_assertion_type", cass.JWTTypeParam), From cd236a574be118ff391aa5e596d8c3c2c0c1e4f7 Mon Sep 17 00:00:00 2001 From: Daniel Bennett Date: Thu, 27 Feb 2025 16:19:44 -0500 Subject: [PATCH 23/27] pr feedback bits --- oidc/clientassertion/algorithms.go | 5 ++--- oidc/clientassertion/client_assertion.go | 2 +- oidc/clientassertion/options.go | 2 +- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/oidc/clientassertion/algorithms.go b/oidc/clientassertion/algorithms.go index c6fc38b..390f6a6 100644 --- a/oidc/clientassertion/algorithms.go +++ b/oidc/clientassertion/algorithms.go @@ -15,10 +15,9 @@ type ( RSAlgorithm string ) +// JOSE asymmetric signing algorithm values as defined by RFC 7518. +// See: https://tools.ietf.org/html/rfc7518#section-3.1 const ( - // JOSE asymmetric signing algorithm values as defined by RFC 7518. - // See: https://tools.ietf.org/html/rfc7518#section-3.1 - HS256 HSAlgorithm = "HS256" // HMAC using SHA-256 HS384 HSAlgorithm = "HS384" // HMAC using SHA-384 HS512 HSAlgorithm = "HS512" // HMAC using SHA-512 diff --git a/oidc/clientassertion/client_assertion.go b/oidc/clientassertion/client_assertion.go index a774a67..0c1e53e 100644 --- a/oidc/clientassertion/client_assertion.go +++ b/oidc/clientassertion/client_assertion.go @@ -192,7 +192,7 @@ func (j *JWT) claims(id string) *jwt.Claims { } } -// serializer is the primary interface impelmented by JWT. +// serializer is the primary interface implemented by JWT. type serializer interface { Serialize() (string, error) } diff --git a/oidc/clientassertion/options.go b/oidc/clientassertion/options.go index 7d9e91f..21744e5 100644 --- a/oidc/clientassertion/options.go +++ b/oidc/clientassertion/options.go @@ -58,7 +58,7 @@ func WithKeyID(keyID string) Option { if keyID == "" { return fmt.Errorf("%s: %w", op, ErrMissingKeyID) } - j.headers["kid"] = keyID + j.headers[KeyIDHeader] = keyID return nil } } From 97b358809d4306a7ea2c1ae305cbf025a2aa7bf1 Mon Sep 17 00:00:00 2001 From: Daniel Bennett Date: Thu, 27 Feb 2025 17:43:02 -0500 Subject: [PATCH 24/27] clarify hmac min len --- oidc/clientassertion/algorithms.go | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/oidc/clientassertion/algorithms.go b/oidc/clientassertion/algorithms.go index 390f6a6..19d46a1 100644 --- a/oidc/clientassertion/algorithms.go +++ b/oidc/clientassertion/algorithms.go @@ -36,20 +36,24 @@ func (a HSAlgorithm) Validate(secret string) error { if secret == "" { return fmt.Errorf("%s: %w: empty", op, ErrInvalidSecretLength) } - // verify secret length based on alg - var expectLen int + // rfc7518 https://datatracker.ietf.org/doc/html/rfc7518#section-3.2 + // states: + // A key of the same size as the hash output (for instance, 256 bits + // for "HS256") or larger MUST be used + // e.g. 256 / 8 = 32 bytes + var minLen int switch a { case HS256: - expectLen = 32 + minLen = 32 case HS384: - expectLen = 48 + minLen = 48 case HS512: - expectLen = 64 + minLen = 64 default: return fmt.Errorf("%s: %w %q for client secret", op, ErrUnsupportedAlgorithm, a) } - if len(secret) < expectLen { - return fmt.Errorf("%s: %w: %q must be %d bytes long", op, ErrInvalidSecretLength, a, expectLen) + if len(secret) < minLen { + return fmt.Errorf("%s: %w: %q must be at least %d bytes long", op, ErrInvalidSecretLength, a, minLen) } return nil } From e8dfe1983ebefb36957b082f01a62950e1d9e4c5 Mon Sep 17 00:00:00 2001 From: Daniel Bennett Date: Thu, 27 Feb 2025 21:16:23 -0500 Subject: [PATCH 25/27] refactor to NewJWTWithRSAKey and NewJWTWithHMAC --- oidc/clientassertion/client_assertion.go | 143 ++++--- oidc/clientassertion/client_assertion_test.go | 393 +++++++----------- oidc/clientassertion/error.go | 21 +- oidc/clientassertion/example_test.go | 10 +- oidc/clientassertion/options.go | 38 +- 5 files changed, 241 insertions(+), 364 deletions(-) diff --git a/oidc/clientassertion/client_assertion.go b/oidc/clientassertion/client_assertion.go index 0c1e53e..0d948e6 100644 --- a/oidc/clientassertion/client_assertion.go +++ b/oidc/clientassertion/client_assertion.go @@ -7,6 +7,7 @@ package clientassertion import ( + "crypto/rsa" "errors" "fmt" "time" @@ -22,44 +23,112 @@ const ( JWTTypeParam = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" ) -// NewJWT creates a new JWT which will be signed with either a private key or -// client secret. +// NewJWTWithRSAKey creates a new JWT which will be signed with a private key. +// +// alg must be one of: +// * RS256 +// * RS384 +// * RS512 // // Supported Options: -// * WithClientSecret -// * WithRSAKey // * WithKeyID // * WithHeaders -// -// Either WithRSAKey or WithClientSecret must be used, but not both. -func NewJWT(clientID string, audience []string, opts ...Option) (*JWT, error) { - const op = "NewJWT" +func NewJWTWithRSAKey(clientID string, audience []string, + alg RSAlgorithm, key *rsa.PrivateKey, opts ...Option) (*JWT, error) { + const op = "NewJWTWithRSAKey" + j := &JWT{ clientID: clientID, audience: audience, + alg: jose.SignatureAlgorithm(alg), + key: key, headers: make(map[string]string), genID: uuid.GenerateUUID, now: time.Now, } var errs []error + if clientID == "" { + errs = append(errs, ErrMissingClientID) + } + if len(audience) == 0 { + errs = append(errs, ErrMissingAudience) + } + if alg == "" { + errs = append(errs, ErrMissingAlgorithm) + } + + // rsa-specific + if key == nil { + errs = append(errs, ErrMissingKey) + } else { + if err := alg.Validate(key); err != nil { + errs = append(errs, err) + } + } + for _, opt := range opts { if err := opt(j); err != nil { errs = append(errs, err) } } if len(errs) > 0 { - return nil, errors.Join(errs...) + return nil, fmt.Errorf("%s: %w", op, errors.Join(errs...)) } - if err := j.validate(); err != nil { - return nil, fmt.Errorf("%s: %w", op, err) + return j, nil +} + +// NewJWTWithHMAC creates a new JWT which will be signed with an HMAC secret. +// +// alg must be one of: +// * HS256 with a >= 32 byte secret +// * HS384 with a >= 48 byte secret +// * HS512 with a >= 64 byte secret +// +// Supported Options: +// * WithKeyID +// * WithHeaders +func NewJWTWithHMAC(clientID string, audience []string, + alg HSAlgorithm, secret string, opts ...Option) (*JWT, error) { + const op = "NewJWTWithHMAC" + j := &JWT{ + clientID: clientID, + audience: audience, + alg: jose.SignatureAlgorithm(alg), + secret: secret, + headers: make(map[string]string), + genID: uuid.GenerateUUID, + now: time.Now, } - // finally, make sure Serialize() works; we can't pre-validate everything, - // and this whole thing is useless if it can't Serialize() - if _, err := j.Serialize(); err != nil { - return nil, fmt.Errorf("%s: %w", op, err) + var errs []error + if clientID == "" { + errs = append(errs, ErrMissingClientID) + } + if len(audience) == 0 { + errs = append(errs, ErrMissingAudience) + } + if alg == "" { + errs = append(errs, ErrMissingAlgorithm) + } + + // hmac-specific + if secret == "" { + errs = append(errs, ErrMissingSecret) + } else { + if err := alg.Validate(secret); err != nil { + errs = append(errs, err) + } + } + + for _, opt := range opts { + if err := opt(j); err != nil { + errs = append(errs, err) + } + } + if len(errs) > 0 { + return nil, fmt.Errorf("%s: %w", op, errors.Join(errs...)) } return j, nil @@ -75,8 +144,9 @@ type JWT struct { // for signer alg jose.SignatureAlgorithm - // key may be any key type that jose.SigningKey accepts for its Key - key any + // key may be any type that jose.SigningKey accepts for its Key, + // but today we only support RSA keys. + key *rsa.PrivateKey // secret may be used instead of key secret string @@ -100,43 +170,6 @@ func (j *JWT) Serialize() (string, error) { return token, nil } -func (j *JWT) validate() error { - const op = "JWT.validate" - var errs []error - if j.genID == nil { - errs = append(errs, ErrMissingFuncIDGenerator) - } - if j.now == nil { - errs = append(errs, ErrMissingFuncNow) - } - // bail early if any internal func errors - if len(errs) > 0 { - return fmt.Errorf("%s: %w", op, errors.Join(errs...)) - } - - if j.clientID == "" { - errs = append(errs, ErrMissingClientID) - } - if len(j.audience) == 0 { - errs = append(errs, ErrMissingAudience) - } - if j.alg == "" { - errs = append(errs, ErrMissingAlgorithm) - } - if j.key == nil && j.secret == "" { - errs = append(errs, ErrMissingKeyOrSecret) - } - if j.key != nil && j.secret != "" { - errs = append(errs, ErrBothKeyAndSecret) - } - // if any of those fail, we have no hope. - if len(errs) > 0 { - return fmt.Errorf("%s: %w", op, errors.Join(errs...)) - } - - return nil -} - func (j *JWT) builder() (jwt.Builder, error) { const op = "builder" signer, err := j.signer() @@ -157,7 +190,7 @@ func (j *JWT) signer() (jose.Signer, error) { Algorithm: j.alg, } - // Validate() ensures these are mutually exclusive + // the different New* constructors ensure these are mutually exclusive. if j.secret != "" { sKey.Key = []byte(j.secret) } diff --git a/oidc/clientassertion/client_assertion_test.go b/oidc/clientassertion/client_assertion_test.go index b695568..9ccfea2 100644 --- a/oidc/clientassertion/client_assertion_test.go +++ b/oidc/clientassertion/client_assertion_test.go @@ -16,24 +16,24 @@ import ( "github.com/stretchr/testify/require" ) -// any non-nil error from NewJWT()/Validate() will be errors.Join()ed. +// any non-nil error from NewJWT*() will be errors.Join()ed. // this is so we can assert each error within. type joinedErrs interface { Unwrap() []error } -func assertJoinedErrs(t *testing.T, expect []error, actual error) { +func assertJoinedErrs(t *testing.T, actual error, expect []error) { t.Helper() - // validate() error is wrapped, joined, wrapped errors + // New* error is wrapped, joined, wrapped errors err := errors.Unwrap(actual) joined, ok := err.(joinedErrs) - require.True(t, ok, "expected Join()ed errors from Validate(); got: %v", actual) + require.True(t, ok, "expected Join()ed errors; got: %v", actual) unwrapped := joined.Unwrap() require.ElementsMatch(t, expect, unwrapped) } // TestJWTBare tests what errors we expect if &JWT{} -// is instantiated directly, rather than using the constructor NewJWT(). +// is instantiated directly, rather than using a constructor. func TestJWTBare(t *testing.T) { j := &JWT{} @@ -43,277 +43,162 @@ func TestJWTBare(t *testing.T) { assert.Equal(t, "", tokenStr) } -func TestNewJWT(t *testing.T) { - t.Run("should run validate", func(t *testing.T) { - j, err := NewJWT("", nil) - require.ErrorContains(t, err, "validate:") - assert.Nil(t, j) - }) - - tCid := "test-client-id" - tAud := []string{"test-audience"} +func TestNewJWTWithRSAKey(t *testing.T) { + cid := "test-client-id" + aud := []string{"test-audience"} validKey, err := rsa.GenerateKey(rand.Reader, 2048) require.NoError(t, err) - validSecret := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" // 32 bytes for HS256 - - cases := []struct { - name string - cid string - aud []string - opts []Option - check func(*testing.T, *JWT) - err string - }{ - { - name: "with private key", - cid: tCid, aud: tAud, - opts: []Option{WithRSAKey(validKey, RS256)}, - check: func(t *testing.T, ca *JWT) { - require.NotNil(t, ca.key) - require.Equal(t, jose.SignatureAlgorithm("RS256"), ca.alg) - }, - }, - { - name: "with client secret", - cid: tCid, aud: tAud, - opts: []Option{WithClientSecret(validSecret, HS256)}, - check: func(t *testing.T, ca *JWT) { - require.Equal(t, validSecret, ca.secret) - require.Equal(t, jose.SignatureAlgorithm(HS256), ca.alg) - }, - }, - { - name: "with key id", - cid: tCid, aud: tAud, - opts: []Option{ - WithKeyID("kid"), - WithClientSecret(validSecret, HS256), - }, - check: func(t *testing.T, ca *JWT) { - require.Equal(t, "kid", ca.headers["kid"]) - }, - }, - { - name: "with headers", - cid: tCid, aud: tAud, - opts: []Option{ - WithHeaders(map[string]string{"h1": "v1", "h2": "v2"}), - WithClientSecret(validSecret, HS256), - }, - check: func(t *testing.T, ca *JWT) { - require.Equal(t, map[string]string{"h1": "v1", "h2": "v2"}, ca.headers) - }, - }, - { - name: "invalid alg for secret", - cid: tCid, aud: tAud, - opts: []Option{ - WithClientSecret(validSecret, "ruh-roh"), - }, - err: ErrUnsupportedAlgorithm.Error(), - }, - { - name: "invalid alg for key", - cid: tCid, aud: tAud, - opts: []Option{ - WithRSAKey(validKey, "ruh-roh"), - }, - err: ErrUnsupportedAlgorithm.Error(), - }, - { - name: "invalid client secret", - cid: tCid, aud: tAud, - opts: []Option{ - WithClientSecret("invalid secret", HS256), - }, - err: ErrInvalidSecretLength.Error(), - }, - { - name: "invalid key", - cid: tCid, aud: tAud, - opts: []Option{ - WithRSAKey(&rsa.PrivateKey{}, RS256), - }, - err: "crypto/rsa: missing public modulus", - }, - } - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - - j, err := NewJWT(tc.cid, tc.aud, tc.opts...) - - if tc.err == "" { - require.NoError(t, err) - require.NotNil(t, j) - require.Equal(t, tc.cid, j.clientID) - require.Equal(t, tc.aud, j.audience) - } else { - require.Error(t, err) - require.ErrorContains(t, err, tc.err) - } - if tc.check != nil { - tc.check(t, j) - } + // happy path + j, err := NewJWTWithRSAKey(cid, aud, RS256, validKey, + WithKeyID("key-id"), WithHeaders(map[string]string{"foo": "bar"})) + assert.NoError(t, err) + assert.NotNil(t, j) - }) - } + // errors + j, err = NewJWTWithRSAKey("", []string{}, "", nil) + assertJoinedErrs(t, err, []error{ + ErrMissingClientID, ErrMissingAudience, ErrMissingAlgorithm, ErrMissingKey, + }) + assert.Nil(t, j) + // bad algorithm + j, err = NewJWTWithRSAKey(cid, aud, "bad-alg", &rsa.PrivateKey{}) + assert.ErrorIs(t, err, ErrUnsupportedAlgorithm) + // bad key; only checked if good alg + j, err = NewJWTWithRSAKey(cid, aud, RS256, &rsa.PrivateKey{}) + assert.ErrorContains(t, err, "RSAlgorithm.Validate: crypto/rsa") + assert.Nil(t, j) + // bad With*s + j, err = NewJWTWithRSAKey(cid, aud, RS256, validKey, + WithKeyID(""), WithHeaders(map[string]string{"kid": "baz"})) + assert.ErrorIs(t, err, ErrMissingKeyID) + assert.ErrorIs(t, err, ErrKidHeader) } -func TestValidate(t *testing.T) { - tCid := "test-client-id" - tAud := []string{"test-audience"} - validKey, err := rsa.GenerateKey(rand.Reader, 2048) - require.NoError(t, err) +func TestNewJWTWithHMAC(t *testing.T) { + cid := "test-client-id" + aud := []string{"test-audience"} validSecret := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" // 32 bytes for HS256 - cases := []struct { - name string - cid string - aud []string - opts []Option - errs []error - }{ - { - name: "missing everything", - errs: []error{ErrMissingClientID, ErrMissingAudience, ErrMissingAlgorithm, ErrMissingKeyOrSecret}, - }, - { - name: "missing client id", - aud: tAud, - errs: []error{ErrMissingClientID}, - opts: []Option{ - WithRSAKey(validKey, RS256), - }, - }, - { - name: "missing audience", - cid: tCid, - errs: []error{ErrMissingAudience}, - opts: []Option{ - WithRSAKey(validKey, RS256), - }, - }, - { - name: "missing client and secret", - cid: tCid, aud: tAud, - errs: []error{ErrMissingAlgorithm, ErrMissingKeyOrSecret}, - }, - { - name: "both client and secret", - cid: tCid, aud: tAud, - opts: []Option{ - WithRSAKey(validKey, RS256), - WithClientSecret(validSecret, HS256), - }, - errs: []error{ErrBothKeyAndSecret}, - }, - } - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - - // NewJWT() runs Validate() - j, err := NewJWT(tc.cid, tc.aud, tc.opts...) - require.NotNil(t, err) - require.ErrorContains(t, err, "validate:") + // happy path + j, err := NewJWTWithHMAC(cid, aud, HS256, validSecret, + WithKeyID("key-id"), WithHeaders(map[string]string{"foo": "bar"})) + assert.NoError(t, err) + assert.NotNil(t, j) - err = errors.Unwrap(err) // NewJWT wraps the error from Validate() with fmt.Errorf("%w") - assertJoinedErrs(t, tc.errs, err) + // errors + j, err = NewJWTWithHMAC("", []string{}, "", "") + assertJoinedErrs(t, err, []error{ + ErrMissingClientID, ErrMissingAudience, ErrMissingAlgorithm, ErrMissingSecret, + }) + assert.Nil(t, j) + // bad algorithm + j, err = NewJWTWithHMAC(cid, aud, "bad-alg", validSecret) + assert.ErrorIs(t, err, ErrUnsupportedAlgorithm) + // bad secret; only checked if good alg + j, err = NewJWTWithHMAC(cid, aud, HS256, "not-very-good") + assert.ErrorIs(t, err, ErrInvalidSecretLength) + assert.Nil(t, j) + // bad With*s + j, err = NewJWTWithHMAC(cid, aud, HS256, validSecret, + WithKeyID(""), WithHeaders(map[string]string{"kid": "baz"})) + assert.ErrorIs(t, err, ErrMissingKeyID) + assert.ErrorIs(t, err, ErrKidHeader) +} - require.Nil(t, j) +func TestJWT_Serialize(t *testing.T) { + cid := "test-client-id" + aud := []string{"test-audience"} - }) - } -} + // make the world more predictable + now := time.Now() + nowF := func() time.Time { return now } + genIDF := func() (string, error) { return "test-claim-id", nil } -func TestSignedToken(t *testing.T) { + // for rsa key, err := rsa.GenerateKey(rand.Reader, 2048) require.NoError(t, err) pub, ok := key.Public().(*rsa.PublicKey) require.True(t, ok, "couldn't get rsa.PublicKey from PrivateKey") - validSecret := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" // 32 bytes for HS256 - - cases := []struct { - name string - claimKey any // []byte or pubkey; we'll use this to check the signature - opts []Option - err error - }{ - { - name: "valid secret", - claimKey: []byte(validSecret), - opts: []Option{ - WithClientSecret(validSecret, "HS256"), - WithKeyID("test-key-id"), - WithHeaders(map[string]string{"xtra": "headies"}), - }, - }, - { - name: "valid key", - claimKey: pub, - opts: []Option{ - WithRSAKey(key, "RS256"), - WithKeyID("test-key-id"), - WithHeaders(map[string]string{"xtra": "headies"}), - }, - }, + // for hmac + secret := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" // 32 bytes for HS256 + + // this specific shape is what this whole library is oriented around + assertClaims := func(t *testing.T, token *jwt.JSONWebToken, key any) { + t.Helper() + expectClaims := jwt.Expected{ + Issuer: "test-client-id", // = cid + Subject: "test-client-id", // = cid + AnyAudience: []string{"test-audience"}, // = aud + ID: "test-claim-id", // = genIDf() + Time: now, // = nowF() + } + var actualClaims jwt.Claims + err := token.Claims(key, &actualClaims) + require.NoError(t, err) + err = actualClaims.Validate(expectClaims) + require.NoError(t, err) } - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - j, err := NewJWT("test-client-id", []string{"test-aud"}, tc.opts...) - require.NoError(t, err) - - now := time.Now() - j.now = func() time.Time { return now } - j.genID = func() (string, error) { return "test-claim-id", nil } - - // method under test - tokenString, err := j.Serialize() - - if tc.err != nil { - require.ErrorIs(t, err, tc.err) - require.Equal(t, "", tokenString) - return - } - require.NoError(t, err) - - // extract the token from the signed string - token, err := jwt.ParseSigned(tokenString, []jose.SignatureAlgorithm{j.alg}) - require.NoError(t, err) - - // check headers - expectHeaders := jose.Header{ - Algorithm: string(j.alg), - KeyID: "test-key-id", - ExtraHeaders: map[jose.HeaderKey]any{ - "typ": "JWT", - "xtra": "headies", - }, - } - require.Len(t, token.Headers, 1) - actualHeaders := token.Headers[0] - require.Equal(t, expectHeaders, actualHeaders) + t.Run("rsa", func(t *testing.T) { + j, err := NewJWTWithRSAKey(cid, aud, RS256, key, + WithKeyID("key-id"), WithHeaders(map[string]string{"foo": "bar"})) + require.NoError(t, err) + require.NotNil(t, j) + j.now = nowF + j.genID = genIDF + token, err := j.Serialize() // method under test + require.NoError(t, err) + require.NotEmpty(t, token) + // make sure we made what we intended to + parsed, err := jwt.ParseSigned(token, []jose.SignatureAlgorithm{jose.RS256}) + require.NoError(t, err) + require.NotNil(t, parsed) + expectHeaders := jose.Header{ + Algorithm: string(RS256), + KeyID: "key-id", + ExtraHeaders: map[jose.HeaderKey]any{ + "typ": "JWT", + "foo": "bar", + }, + } + require.Len(t, parsed.Headers, 1) + actualHeaders := parsed.Headers[0] + require.Equal(t, expectHeaders, actualHeaders) + assertClaims(t, parsed, pub) + }) - // check claims - expectClaims := jwt.Expected{ - Issuer: "test-client-id", - Subject: "test-client-id", - AnyAudience: []string{"test-aud"}, - ID: "test-claim-id", - Time: now, - } - var actualClaims jwt.Claims - err = token.Claims(tc.claimKey, &actualClaims) - require.NoError(t, err) - err = actualClaims.Validate(expectClaims) - require.NoError(t, err) - }) - } + t.Run("hmac", func(t *testing.T) { + j, err := NewJWTWithHMAC(cid, aud, HS256, secret, + WithKeyID("key-id"), WithHeaders(map[string]string{"foo": "bar"})) + require.NoError(t, err) + require.NotNil(t, j) + j.now = nowF + j.genID = genIDF + token, err := j.Serialize() // method under test + require.NoError(t, err) + require.NotEmpty(t, token) + // make sure we made what we intended to + parsed, err := jwt.ParseSigned(token, []jose.SignatureAlgorithm{jose.HS256}) + require.NoError(t, err) + require.NotNil(t, parsed) + expectHeaders := jose.Header{ + Algorithm: string(HS256), + KeyID: "key-id", + ExtraHeaders: map[jose.HeaderKey]any{ + "typ": "JWT", + "foo": "bar", + }, + } + require.Len(t, parsed.Headers, 1) + actualHeaders := parsed.Headers[0] + require.Equal(t, expectHeaders, actualHeaders) + assertClaims(t, parsed, []byte(secret)) + }) t.Run("error generating token id", func(t *testing.T) { genIDErr := errors.New("failed to generate test id") - j, err := NewJWT("a", []string{"a"}, WithClientSecret(validSecret, HS256)) + j, err := NewJWTWithHMAC("a", []string{"a"}, HS256, secret) require.NoError(t, err) j.genID = func() (string, error) { return "", genIDErr } tokenString, err := j.Serialize() diff --git a/oidc/clientassertion/error.go b/oidc/clientassertion/error.go index f7ba929..9f9c240 100644 --- a/oidc/clientassertion/error.go +++ b/oidc/clientassertion/error.go @@ -8,19 +8,14 @@ import "errors" var ( // these may happen due to user error - ErrMissingClientID = errors.New("missing client ID") - ErrMissingAudience = errors.New("missing audience") - ErrMissingAlgorithm = errors.New("missing signing algorithm") - ErrMissingKeyID = errors.New("missing key ID") - ErrMissingKeyOrSecret = errors.New("missing private key or client secret") - ErrBothKeyAndSecret = errors.New("both private key and client secret provided") - - // if these happen, either the user directly instantiated &JWT{} - // or there's a bug somewhere. - - ErrMissingFuncIDGenerator = errors.New("missing IDgen func; please use NewJWT()") - ErrMissingFuncNow = errors.New("missing now func; please use NewJWT()") - ErrCreatingSigner = errors.New("error creating jwt signer") + ErrMissingClientID = errors.New("missing client ID") + ErrMissingAudience = errors.New("missing audience") + ErrMissingAlgorithm = errors.New("missing signing algorithm") + ErrMissingKeyID = errors.New("missing key ID") + ErrMissingKey = errors.New("missing private key") + ErrMissingSecret = errors.New("missing client secret") + ErrKidHeader = errors.New(`"kid" not allowed in WithHeaders; use WithKeyID instead`) + ErrCreatingSigner = errors.New("error creating jwt signer") // algorithm errors diff --git a/oidc/clientassertion/example_test.go b/oidc/clientassertion/example_test.go index df9db1c..fe62d49 100644 --- a/oidc/clientassertion/example_test.go +++ b/oidc/clientassertion/example_test.go @@ -14,11 +14,12 @@ import ( ) func ExampleJWT() { + cid := "client-id" + aud := []string{"audience"} + // With an HMAC client secret secret := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" // 32 bytes for HS256 - j, err := NewJWT("client-id", []string{"audience"}, - WithClientSecret(secret, HS256), - ) + j, err := NewJWTWithHMAC(cid, aud, HS256, secret) if err != nil { log.Fatal(err) } @@ -55,8 +56,7 @@ func ExampleJWT() { if !ok { log.Fatal("couldn't get rsa.PublicKey from PrivateKey") } - j, err = NewJWT("client-id", []string{"audience"}, - WithRSAKey(privKey, RS256), + j, err = NewJWTWithRSAKey(cid, aud, RS256, privKey, // note: for some providers, they key ID may be an x5t derivation // of a cert generated from the private key. // if your key has an associated JWKS endpoint, it will be the "kid" diff --git a/oidc/clientassertion/options.go b/oidc/clientassertion/options.go index 21744e5..30b2529 100644 --- a/oidc/clientassertion/options.go +++ b/oidc/clientassertion/options.go @@ -4,10 +4,7 @@ package clientassertion import ( - "crypto/rsa" "fmt" - - "github.com/go-jose/go-jose/v4" ) // KeyIDHeader is the "kid" header on a JWT, which providers use to look up @@ -17,39 +14,6 @@ const KeyIDHeader = "kid" // Option configures the JWT type Option func(*JWT) error -// WithClientSecret sets a secret and algorithm to sign the JWT with. -// alg must be one of: -// * HS256 with a >= 32 byte secret -// * HS384 with a >= 48 byte secret -// * HS512 with a >= 64 byte secret -func WithClientSecret(secret string, alg HSAlgorithm) Option { - return func(j *JWT) error { - if err := alg.Validate(secret); err != nil { - return err - } - j.secret = secret - j.alg = jose.SignatureAlgorithm(alg) - return nil - } -} - -// WithRSAKey sets a private key to sign the JWT with. -// alg must be one of: -// * RS256 -// * RS384 -// * RS512 -func WithRSAKey(key *rsa.PrivateKey, alg RSAlgorithm) Option { - const op = "WithRSAKey" - return func(j *JWT) error { - if err := alg.Validate(key); err != nil { - return fmt.Errorf("%s: %w", op, err) - } - j.key = key - j.alg = jose.SignatureAlgorithm(alg) - return nil - } -} - // WithKeyID sets the "kid" header that OIDC providers use to look up the // public key to check the signed JWT func WithKeyID(keyID string) Option { @@ -70,7 +34,7 @@ func WithHeaders(h map[string]string) Option { return func(j *JWT) error { for k, v := range h { if k == KeyIDHeader { - return fmt.Errorf(`%s: "kid" header not allowed; use WithKeyID instead`, op) + return fmt.Errorf("%s: %w", op, ErrKidHeader) } j.headers[k] = v } From a94da547016a7a5b029880a7408f12390e3e48c7 Mon Sep 17 00:00:00 2001 From: Daniel Bennett Date: Thu, 27 Feb 2025 21:22:30 -0500 Subject: [PATCH 26/27] remove most single-use private methods i was a bit overzealous with my tiny methods. i couldn't quite bring myself to delete signer() :P --- oidc/clientassertion/client_assertion.go | 44 +++++++++--------------- 1 file changed, 16 insertions(+), 28 deletions(-) diff --git a/oidc/clientassertion/client_assertion.go b/oidc/clientassertion/client_assertion.go index 0d948e6..7001cad 100644 --- a/oidc/clientassertion/client_assertion.go +++ b/oidc/clientassertion/client_assertion.go @@ -159,29 +159,30 @@ type JWT struct { // OIDC client to authenticate themselves to an authorization server func (j *JWT) Serialize() (string, error) { const op = "JWT.Serialize" - builder, err := j.builder() + signer, err := j.signer() if err != nil { return "", fmt.Errorf("%s: %w", op, err) } - token, err := builder.Serialize() + id, err := j.genID() if err != nil { - return "", fmt.Errorf("%s: failed to serialize token: %w", op, err) + return "", fmt.Errorf("%s: failed to generate token id: %w", op, err) } - return token, nil -} - -func (j *JWT) builder() (jwt.Builder, error) { - const op = "builder" - signer, err := j.signer() - if err != nil { - return nil, fmt.Errorf("%s: %w", op, err) + now := j.now().UTC() + claims := &jwt.Claims{ + Issuer: j.clientID, + Subject: j.clientID, + Audience: j.audience, + Expiry: jwt.NewNumericDate(now.Add(5 * time.Minute)), + NotBefore: jwt.NewNumericDate(now.Add(-1 * time.Second)), + IssuedAt: jwt.NewNumericDate(now), + ID: id, } - id, err := j.genID() + builder := jwt.Signed(signer).Claims(claims) + token, err := builder.Serialize() if err != nil { - return nil, fmt.Errorf("%s: failed to generate token id: %w", op, err) + return "", fmt.Errorf("%s: failed to serialize token: %w", op, err) } - claims := j.claims(id) - return jwt.Signed(signer).Claims(claims), nil + return token, nil } func (j *JWT) signer() (jose.Signer, error) { @@ -212,19 +213,6 @@ func (j *JWT) signer() (jose.Signer, error) { return signer, nil } -func (j *JWT) claims(id string) *jwt.Claims { - now := j.now().UTC() - return &jwt.Claims{ - Issuer: j.clientID, - Subject: j.clientID, - Audience: j.audience, - Expiry: jwt.NewNumericDate(now.Add(5 * time.Minute)), - NotBefore: jwt.NewNumericDate(now.Add(-1 * time.Second)), - IssuedAt: jwt.NewNumericDate(now), - ID: id, - } -} - // serializer is the primary interface implemented by JWT. type serializer interface { Serialize() (string, error) From 80514102ada6c41df3e50bd02e12212c9e496c14 Mon Sep 17 00:00:00 2001 From: Daniel Bennett Date: Fri, 28 Feb 2025 12:53:19 -0500 Subject: [PATCH 27/27] more lovely PR feedback --- oidc/clientassertion/client_assertion.go | 6 +- oidc/clientassertion/client_assertion_test.go | 104 ++++++++++-------- 2 files changed, 63 insertions(+), 47 deletions(-) diff --git a/oidc/clientassertion/client_assertion.go b/oidc/clientassertion/client_assertion.go index 7001cad..920b7af 100644 --- a/oidc/clientassertion/client_assertion.go +++ b/oidc/clientassertion/client_assertion.go @@ -35,7 +35,7 @@ const ( // * WithHeaders func NewJWTWithRSAKey(clientID string, audience []string, alg RSAlgorithm, key *rsa.PrivateKey, opts ...Option) (*JWT, error) { - const op = "NewJWTWithRSAKey" + const op = "clientassertion.NewJWTWithRSAKey" j := &JWT{ clientID: clientID, @@ -91,7 +91,7 @@ func NewJWTWithRSAKey(clientID string, audience []string, // * WithHeaders func NewJWTWithHMAC(clientID string, audience []string, alg HSAlgorithm, secret string, opts ...Option) (*JWT, error) { - const op = "NewJWTWithHMAC" + const op = "clientassertion.NewJWTWithHMAC" j := &JWT{ clientID: clientID, audience: audience, @@ -200,7 +200,7 @@ func (j *JWT) signer() (jose.Signer, error) { } sOpts := &jose.SignerOptions{ - ExtraHeaders: make(map[jose.HeaderKey]interface{}, len(j.headers)), + ExtraHeaders: make(map[jose.HeaderKey]any, len(j.headers)), } for k, v := range j.headers { sOpts.ExtraHeaders[jose.HeaderKey(k)] = v diff --git a/oidc/clientassertion/client_assertion_test.go b/oidc/clientassertion/client_assertion_test.go index 9ccfea2..4ba65ea 100644 --- a/oidc/clientassertion/client_assertion_test.go +++ b/oidc/clientassertion/client_assertion_test.go @@ -35,6 +35,8 @@ func assertJoinedErrs(t *testing.T, actual error, expect []error) { // TestJWTBare tests what errors we expect if &JWT{} // is instantiated directly, rather than using a constructor. func TestJWTBare(t *testing.T) { + t.Parallel() + j := &JWT{} tokenStr, err := j.Serialize() @@ -44,69 +46,83 @@ func TestJWTBare(t *testing.T) { } func TestNewJWTWithRSAKey(t *testing.T) { + t.Parallel() + cid := "test-client-id" aud := []string{"test-audience"} validKey, err := rsa.GenerateKey(rand.Reader, 2048) require.NoError(t, err) - // happy path - j, err := NewJWTWithRSAKey(cid, aud, RS256, validKey, - WithKeyID("key-id"), WithHeaders(map[string]string{"foo": "bar"})) - assert.NoError(t, err) - assert.NotNil(t, j) + t.Run("happy path", func(t *testing.T) { + j, err := NewJWTWithRSAKey(cid, aud, RS256, validKey, + WithKeyID("key-id"), WithHeaders(map[string]string{"foo": "bar"})) + assert.NoError(t, err) + assert.NotNil(t, j) + }) - // errors - j, err = NewJWTWithRSAKey("", []string{}, "", nil) - assertJoinedErrs(t, err, []error{ - ErrMissingClientID, ErrMissingAudience, ErrMissingAlgorithm, ErrMissingKey, + t.Run("multiple errors", func(t *testing.T) { + j, err := NewJWTWithRSAKey("", []string{}, "", nil) + assertJoinedErrs(t, err, []error{ + ErrMissingClientID, ErrMissingAudience, ErrMissingAlgorithm, ErrMissingKey, + }) + assert.Nil(t, j) + }) + t.Run("bad algorithm", func(t *testing.T) { + _, err := NewJWTWithRSAKey(cid, aud, "bad-alg", &rsa.PrivateKey{}) + assert.ErrorIs(t, err, ErrUnsupportedAlgorithm) + }) + t.Run("bad key", func(t *testing.T) { + _, err = NewJWTWithRSAKey(cid, aud, RS256, &rsa.PrivateKey{}) + assert.ErrorContains(t, err, "RSAlgorithm.Validate: crypto/rsa") + }) + t.Run("bad Options", func(t *testing.T) { + _, err = NewJWTWithRSAKey(cid, aud, RS256, validKey, + WithKeyID(""), WithHeaders(map[string]string{"kid": "baz"})) + assert.ErrorIs(t, err, ErrMissingKeyID) + assert.ErrorIs(t, err, ErrKidHeader) }) - assert.Nil(t, j) - // bad algorithm - j, err = NewJWTWithRSAKey(cid, aud, "bad-alg", &rsa.PrivateKey{}) - assert.ErrorIs(t, err, ErrUnsupportedAlgorithm) - // bad key; only checked if good alg - j, err = NewJWTWithRSAKey(cid, aud, RS256, &rsa.PrivateKey{}) - assert.ErrorContains(t, err, "RSAlgorithm.Validate: crypto/rsa") - assert.Nil(t, j) - // bad With*s - j, err = NewJWTWithRSAKey(cid, aud, RS256, validKey, - WithKeyID(""), WithHeaders(map[string]string{"kid": "baz"})) - assert.ErrorIs(t, err, ErrMissingKeyID) - assert.ErrorIs(t, err, ErrKidHeader) } func TestNewJWTWithHMAC(t *testing.T) { + t.Parallel() + cid := "test-client-id" aud := []string{"test-audience"} validSecret := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" // 32 bytes for HS256 - // happy path - j, err := NewJWTWithHMAC(cid, aud, HS256, validSecret, - WithKeyID("key-id"), WithHeaders(map[string]string{"foo": "bar"})) - assert.NoError(t, err) - assert.NotNil(t, j) + t.Run("happy path", func(t *testing.T) { + j, err := NewJWTWithHMAC(cid, aud, HS256, validSecret, + WithKeyID("key-id"), WithHeaders(map[string]string{"foo": "bar"})) + assert.NoError(t, err) + assert.NotNil(t, j) + }) - // errors - j, err = NewJWTWithHMAC("", []string{}, "", "") - assertJoinedErrs(t, err, []error{ - ErrMissingClientID, ErrMissingAudience, ErrMissingAlgorithm, ErrMissingSecret, + t.Run("errors", func(t *testing.T) { + j, err := NewJWTWithHMAC("", []string{}, "", "") + assertJoinedErrs(t, err, []error{ + ErrMissingClientID, ErrMissingAudience, ErrMissingAlgorithm, ErrMissingSecret, + }) + assert.Nil(t, j) + }) + t.Run("bad algorithm", func(t *testing.T) { + _, err := NewJWTWithHMAC(cid, aud, "bad-alg", validSecret) + assert.ErrorIs(t, err, ErrUnsupportedAlgorithm) + }) + t.Run("bad secret", func(t *testing.T) { + _, err := NewJWTWithHMAC(cid, aud, HS256, "not-very-good") + assert.ErrorIs(t, err, ErrInvalidSecretLength) + }) + t.Run("bad Options", func(t *testing.T) { + _, err := NewJWTWithHMAC(cid, aud, HS256, validSecret, + WithKeyID(""), WithHeaders(map[string]string{"kid": "baz"})) + assert.ErrorIs(t, err, ErrMissingKeyID) + assert.ErrorIs(t, err, ErrKidHeader) }) - assert.Nil(t, j) - // bad algorithm - j, err = NewJWTWithHMAC(cid, aud, "bad-alg", validSecret) - assert.ErrorIs(t, err, ErrUnsupportedAlgorithm) - // bad secret; only checked if good alg - j, err = NewJWTWithHMAC(cid, aud, HS256, "not-very-good") - assert.ErrorIs(t, err, ErrInvalidSecretLength) - assert.Nil(t, j) - // bad With*s - j, err = NewJWTWithHMAC(cid, aud, HS256, validSecret, - WithKeyID(""), WithHeaders(map[string]string{"kid": "baz"})) - assert.ErrorIs(t, err, ErrMissingKeyID) - assert.ErrorIs(t, err, ErrKidHeader) } func TestJWT_Serialize(t *testing.T) { + t.Parallel() + cid := "test-client-id" aud := []string{"test-audience"}