diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 0ab081f7..032eb767 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -13,16 +13,15 @@ jobs: strategy: fail-fast: false matrix: - go: ["1.18.x", "1.19.x", "1.20.x"] + go: ["1.21", "1.22", "1.23"] steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Setup Go - uses: actions/setup-go@v4 + uses: actions/setup-go@v5 with: go-version: "${{ matrix.go }}" check-latest: true - cache: true - name: Check Go code formatting run: | if [ "$(gofmt -s -l . | wc -l)" -gt 0 ]; then @@ -34,16 +33,16 @@ jobs: run: | go install github.com/mfridman/tparse@latest go vet ./... - go test -v -race -count=1 -json -coverpkg=$(go list ./...) ./... | tee output.json | tparse -follow -notests || true + go test -v -race -count=1 -json -cover ./... | tee output.json | tparse -follow -notests || true tparse -format markdown -file output.json -all > $GITHUB_STEP_SUMMARY go build ./... coverage: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Setup Go - uses: actions/setup-go@v4 + uses: actions/setup-go@v5 - name: Coverage run: | go test -v -covermode=count -coverprofile=coverage.cov ./... diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index 6210b814..37081a2c 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -37,11 +37,11 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL - uses: github/codeql-action/init@v2 + uses: github/codeql-action/init@v3 with: languages: ${{ matrix.language }} # If you wish to specify custom queries, you can do so here or in a config file. @@ -52,7 +52,7 @@ jobs: # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). # If this step fails, then you should remove it and run the build manually (see below) - name: Autobuild - uses: github/codeql-action/autobuild@v2 + uses: github/codeql-action/autobuild@v3 # â„šī¸ Command-line programs to run using the OS shell. # 📚 https://git.io/JvXDl @@ -66,4 +66,4 @@ jobs: # make release - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v2 + uses: github/codeql-action/analyze@v3 diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index e6ef00ba..cec3b92b 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -10,15 +10,14 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Setup Go - uses: actions/setup-go@v4 + uses: actions/setup-go@v5 with: - go-version: "1.20.x" + go-version: "1.23" check-latest: true - cache: true - name: golangci-lint - uses: golangci/golangci-lint-action@v3 + uses: golangci/golangci-lint-action@v6 with: # Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version version: latest diff --git a/MIGRATION_GUIDE.md b/MIGRATION_GUIDE.md index 6ad1c22b..ff9c57e1 100644 --- a/MIGRATION_GUIDE.md +++ b/MIGRATION_GUIDE.md @@ -17,7 +17,7 @@ and corresponding updates for existing programs. ## Parsing and Validation Options -Under the hood, a new `validator` struct takes care of validating the claims. A +Under the hood, a new `Validator` struct takes care of validating the claims. A long awaited feature has been the option to fine-tune the validation of tokens. This is now possible with several `ParserOption` functions that can be appended to most `Parse` functions, such as `ParseWithClaims`. The most important options @@ -68,6 +68,16 @@ type Claims interface { } ``` +Users that previously directly called the `Valid` function on their claims, +e.g., to perform validation independently of parsing/verifying a token, can now +use the `jwt.NewValidator` function to create a `Validator` independently of the +`Parser`. + +```go +var v = jwt.NewValidator(jwt.WithLeeway(5*time.Second)) +v.Validate(myClaims) +``` + ### Supported Claim Types and Removal of `StandardClaims` The two standard claim types supported by this library, `MapClaims` and @@ -169,7 +179,7 @@ be a drop-in replacement, if you're having troubles migrating, please open an issue. You can replace all occurrences of `github.com/dgrijalva/jwt-go` or -`github.com/golang-jwt/jwt` with `github.com/golang-jwt/jwt/v5`, either manually +`github.com/golang-jwt/jwt` with `github.com/golang-jwt/jwt/v4`, either manually or by using tools such as `sed` or `gofmt`. And then you'd typically run: diff --git a/README.md b/README.md index 964598a3..0bb636f2 100644 --- a/README.md +++ b/README.md @@ -10,11 +10,11 @@ implementation of [JSON Web Tokens](https://datatracker.ietf.org/doc/html/rfc7519). Starting with [v4.0.0](https://github.com/golang-jwt/jwt/releases/tag/v4.0.0) -this project adds Go module support, but maintains backwards compatibility with +this project adds Go module support, but maintains backward compatibility with older `v3.x.y` tags and upstream `github.com/dgrijalva/jwt-go`. See the [`MIGRATION_GUIDE.md`](./MIGRATION_GUIDE.md) for more information. Version v5.0.0 introduces major improvements to the validation of tokens, but is not -entirely backwards compatible. +entirely backward compatible. > After the original author of the library suggested migrating the maintenance > of `jwt-go`, a dedicated team of open source maintainers decided to clone the @@ -24,7 +24,7 @@ entirely backwards compatible. **SECURITY NOTICE:** Some older versions of Go have a security issue in the -crypto/elliptic. Recommendation is to upgrade to at least 1.15 See issue +crypto/elliptic. The recommendation is to upgrade to at least 1.15 See issue [dgrijalva/jwt-go#216](https://github.com/dgrijalva/jwt-go/issues/216) for more detail. @@ -32,7 +32,7 @@ detail. what you expect](https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/). This library attempts to make it easy to do the right thing by requiring key -types match the expected alg, but you should take the extra step to verify it in +types to match the expected alg, but you should take the extra step to verify it in your usage. See the examples provided. ### Supported Go versions @@ -41,7 +41,7 @@ Our support of Go versions is aligned with Go's [version release policy](https://golang.org/doc/devel/release#policy). So we will support a major version of Go until there are two newer major releases. We no longer support building jwt-go with unsupported Go versions, as these contain security -vulnerabilities which will not be fixed. +vulnerabilities that will not be fixed. ## What the heck is a JWT? @@ -117,7 +117,7 @@ notable differences: This library is considered production ready. Feedback and feature requests are appreciated. The API should be considered stable. There should be very few -backwards-incompatible changes outside of major version updates (and only with +backward-incompatible changes outside of major version updates (and only with good reason). This project uses [Semantic Versioning 2.0.0](http://semver.org). Accepted pull @@ -125,8 +125,8 @@ requests will land on `main`. Periodically, versions will be tagged from `main`. You can find all the releases on [the project releases page](https://github.com/golang-jwt/jwt/releases). -**BREAKING CHANGES:*** A full list of breaking changes is available in -`VERSION_HISTORY.md`. See `MIGRATION_GUIDE.md` for more information on updating +**BREAKING CHANGES:** A full list of breaking changes is available in +`VERSION_HISTORY.md`. See [`MIGRATION_GUIDE.md`](./MIGRATION_GUIDE.md) for more information on updating your code. ## Extensions diff --git a/SECURITY.md b/SECURITY.md index b08402c3..2740597f 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -2,11 +2,11 @@ ## Supported Versions -As of February 2022 (and until this document is updated), the latest version `v4` is supported. +As of November 2024 (and until this document is updated), the latest version `v5` is supported. In critical cases, we might supply back-ported patches for `v4`. ## Reporting a Vulnerability -If you think you found a vulnerability, and even if you are not sure, please report it to jwt-go-security@googlegroups.com or one of the other [golang-jwt maintainers](https://github.com/orgs/golang-jwt/people). Please try be explicit, describe steps to reproduce the security issue with code example(s). +If you think you found a vulnerability, and even if you are not sure, please report it a [GitHub Security Advisory](https://github.com/golang-jwt/jwt/security/advisories/new). Please try be explicit, describe steps to reproduce the security issue with code example(s). You will receive a response within a timely manner. If the issue is confirmed, we will do our best to release a patch as soon as possible given the complexity of the problem. diff --git a/cmd/jwt/main.go b/cmd/jwt/main.go index f1e49a90..22031ca2 100644 --- a/cmd/jwt/main.go +++ b/cmd/jwt/main.go @@ -30,9 +30,9 @@ var ( flagHead = make(ArgList) // Modes - exactly one of these is required - flagSign = flag.String("sign", "", "path to claims object to sign, '-' to read from stdin, or '+' to use only -claim args") - flagVerify = flag.String("verify", "", "path to JWT token to verify or '-' to read from stdin") - flagShow = flag.String("show", "", "path to JWT file or '-' to read from stdin") + flagSign = flag.String("sign", "", "path to claims file to sign, '-' to read from stdin, or '+' to use only -claim args") + flagVerify = flag.String("verify", "", "path to JWT token file to verify or '-' to read from stdin") + flagShow = flag.String("show", "", "path to JWT token file to show without verification or '-' to read from stdin") ) func main() { @@ -43,7 +43,7 @@ func main() { // Usage message if you ask for -help or if you mess up inputs. flag.Usage = func() { fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0]) - fmt.Fprintf(os.Stderr, " One of the following flags is required: sign, verify\n") + fmt.Fprintf(os.Stderr, " One of the following flags is required: sign, verify or show\n") flag.PrintDefaults() } @@ -60,15 +60,16 @@ func main() { // Figure out which thing to do and then do that func start() error { - if *flagSign != "" { + switch { + case *flagSign != "": return signToken() - } else if *flagVerify != "" { + case *flagVerify != "": return verifyToken() - } else if *flagShow != "" { + case *flagShow != "": return showToken() - } else { + default: flag.Usage() - return fmt.Errorf("none of the required flags are present. What do you want me to do?") + return fmt.Errorf("none of the required flags are present. What do you want me to do?") } } @@ -79,17 +80,18 @@ func loadData(p string) ([]byte, error) { } var rdr io.Reader - if p == "-" { + switch p { + case "-": rdr = os.Stdin - } else if p == "+" { + case "+": return []byte("{}"), nil - } else { - if f, err := os.Open(p); err == nil { - rdr = f - defer f.Close() - } else { + default: + f, err := os.Open(p) + if err != nil { return nil, err } + rdr = f + defer f.Close() } return io.ReadAll(rdr) } @@ -136,30 +138,27 @@ func verifyToken() error { if err != nil { return nil, err } - if isEs() { + switch { + case isEs(): return jwt.ParseECPublicKeyFromPEM(data) - } else if isRs() { + case isRs(): return jwt.ParseRSAPublicKeyFromPEM(data) - } else if isEd() { + case isEd(): return jwt.ParseEdPublicKeyFromPEM(data) + default: + return data, nil } - return data, nil }) - // Print some debug data - if *flagDebug && token != nil { - fmt.Fprintf(os.Stderr, "Header:\n%v\n", token.Header) - fmt.Fprintf(os.Stderr, "Claims:\n%v\n", token.Claims) - } - // Print an error if we can't parse for some reason if err != nil { return fmt.Errorf("couldn't parse token: %w", err) } - // Is token invalid? - if !token.Valid { - return fmt.Errorf("token is invalid") + // Print some debug data + if *flagDebug { + fmt.Fprintf(os.Stderr, "Header:\n%v\n", token.Header) + fmt.Fprintf(os.Stderr, "Claims:\n%v\n", token.Claims) } // Print the token details @@ -221,40 +220,41 @@ func signToken() error { } } - if isEs() { - if k, ok := key.([]byte); !ok { + switch { + case isEs(): + k, ok := key.([]byte) + if !ok { return fmt.Errorf("couldn't convert key data to key") - } else { - key, err = jwt.ParseECPrivateKeyFromPEM(k) - if err != nil { - return err - } } - } else if isRs() { - if k, ok := key.([]byte); !ok { + key, err = jwt.ParseECPrivateKeyFromPEM(k) + if err != nil { + return err + } + case isRs(): + k, ok := key.([]byte) + if !ok { return fmt.Errorf("couldn't convert key data to key") - } else { - key, err = jwt.ParseRSAPrivateKeyFromPEM(k) - if err != nil { - return err - } } - } else if isEd() { - if k, ok := key.([]byte); !ok { + key, err = jwt.ParseRSAPrivateKeyFromPEM(k) + if err != nil { + return err + } + case isEd(): + k, ok := key.([]byte) + if !ok { return fmt.Errorf("couldn't convert key data to key") - } else { - key, err = jwt.ParseEdPrivateKeyFromPEM(k) - if err != nil { - return err - } + } + key, err = jwt.ParseEdPrivateKeyFromPEM(k) + if err != nil { + return err } } - if out, err := token.SignedString(key); err == nil { - fmt.Println(out) - } else { + out, err := token.SignedString(key) + if err != nil { return fmt.Errorf("error signing token: %w", err) } + fmt.Println(out) return nil } @@ -273,8 +273,8 @@ func showToken() error { fmt.Fprintf(os.Stderr, "Token len: %v bytes\n", len(tokData)) } - token, err := jwt.Parse(string(tokData), nil) - if token == nil { + token, _, err := jwt.NewParser().ParseUnverified(string(tokData), make(jwt.MapClaims)) + if err != nil { return fmt.Errorf("malformed token: %w", err) } diff --git a/ecdsa.go b/ecdsa.go index 4ccae2a8..c929e4a0 100644 --- a/ecdsa.go +++ b/ecdsa.go @@ -62,7 +62,7 @@ func (m *SigningMethodECDSA) Verify(signingString string, sig []byte, key interf case *ecdsa.PublicKey: ecdsaKey = k default: - return ErrInvalidKeyType + return newError("ECDSA verify expects *ecdsa.PublicKey", ErrInvalidKeyType) } if len(sig) != 2*m.KeySize { @@ -96,7 +96,7 @@ func (m *SigningMethodECDSA) Sign(signingString string, key interface{}) ([]byte case *ecdsa.PrivateKey: ecdsaKey = k default: - return nil, ErrInvalidKeyType + return nil, newError("ECDSA sign expects *ecdsa.PrivateKey", ErrInvalidKeyType) } // Create the hasher diff --git a/ed25519.go b/ed25519.go index 3db00e4a..c2138119 100644 --- a/ed25519.go +++ b/ed25519.go @@ -1,11 +1,10 @@ package jwt import ( - "errors" - "crypto" "crypto/ed25519" "crypto/rand" + "errors" ) var ( @@ -39,7 +38,7 @@ func (m *SigningMethodEd25519) Verify(signingString string, sig []byte, key inte var ok bool if ed25519Key, ok = key.(ed25519.PublicKey); !ok { - return ErrInvalidKeyType + return newError("Ed25519 verify expects ed25519.PublicKey", ErrInvalidKeyType) } if len(ed25519Key) != ed25519.PublicKeySize { @@ -61,7 +60,7 @@ func (m *SigningMethodEd25519) Sign(signingString string, key interface{}) ([]by var ok bool if ed25519Key, ok = key.(crypto.Signer); !ok { - return nil, ErrInvalidKeyType + return nil, newError("Ed25519 sign expects crypto.Signer", ErrInvalidKeyType) } if _, ok := ed25519Key.Public().(ed25519.PublicKey); !ok { diff --git a/errors_go_other.go b/errors_go_other.go index 3afb04e6..2ad542f0 100644 --- a/errors_go_other.go +++ b/errors_go_other.go @@ -22,7 +22,7 @@ func (je joinedError) Is(err error) bool { // wrappedErrors is a workaround for wrapping multiple errors in environments // where Go 1.20 is not available. It basically uses the already implemented -// functionatlity of joinedError to handle multiple errors with supplies a +// functionality of joinedError to handle multiple errors with supplies a // custom error message that is identical to the one we produce in Go 1.20 using // multiple %w directives. type wrappedErrors struct { diff --git a/example_test.go b/example_test.go index f677d7c0..651841de 100644 --- a/example_test.go +++ b/example_test.go @@ -3,6 +3,7 @@ package jwt_test import ( "errors" "fmt" + "log" "time" "github.com/golang-jwt/jwt/v5" @@ -24,8 +25,8 @@ func ExampleNewWithClaims_registeredClaims() { token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) ss, err := token.SignedString(mySigningKey) - fmt.Printf("%v %v", ss, err) - //Output: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJ0ZXN0IiwiZXhwIjoxNTE2MjM5MDIyfQ.0XN_1Tpp9FszFOonIBpwha0c_SfnNI22DhTnjMshPg8 + fmt.Println(ss, err) + // Output: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJ0ZXN0IiwiZXhwIjoxNTE2MjM5MDIyfQ.0XN_1Tpp9FszFOonIBpwha0c_SfnNI22DhTnjMshPg8 } // Example creating a token using a custom claims type. The RegisteredClaims is embedded @@ -67,10 +68,10 @@ func ExampleNewWithClaims_customClaimsType() { token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) ss, err := token.SignedString(mySigningKey) - fmt.Printf("%v %v", ss, err) + fmt.Println(ss, err) - //Output: foo: bar - //eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJ0ZXN0IiwiZXhwIjoxNTE2MjM5MDIyfQ.xVuY2FZ_MRXMIEgVQ7J-TFtaucVFRXUzHm9LmV41goM + // Output: foo: bar + // eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJ0ZXN0IiwiZXhwIjoxNTE2MjM5MDIyfQ.xVuY2FZ_MRXMIEgVQ7J-TFtaucVFRXUzHm9LmV41goM } // Example creating a token using a custom claims type. The RegisteredClaims is embedded @@ -86,11 +87,12 @@ func ExampleParseWithClaims_customClaimsType() { token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) { return []byte("AllYourBase"), nil }) - - if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid { - fmt.Printf("%v %v", claims.Foo, claims.RegisteredClaims.Issuer) + if err != nil { + log.Fatal(err) + } else if claims, ok := token.Claims.(*MyCustomClaims); ok { + fmt.Println(claims.Foo, claims.RegisteredClaims.Issuer) } else { - fmt.Println(err) + log.Fatal("unknown claims type, cannot proceed") } // Output: bar test @@ -109,11 +111,12 @@ func ExampleParseWithClaims_validationOptions() { token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) { return []byte("AllYourBase"), nil }, jwt.WithLeeway(5*time.Second)) - - if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid { - fmt.Printf("%v %v", claims.Foo, claims.RegisteredClaims.Issuer) + if err != nil { + log.Fatal(err) + } else if claims, ok := token.Claims.(*MyCustomClaims); ok { + fmt.Println(claims.Foo, claims.RegisteredClaims.Issuer) } else { - fmt.Println(err) + log.Fatal("unknown claims type, cannot proceed") } // Output: bar test @@ -124,6 +127,9 @@ type MyCustomClaims struct { jwt.RegisteredClaims } +// Ensure we implement [jwt.ClaimsValidator] at compile time so we know our custom Validate method is used. +var _ jwt.ClaimsValidator = (*MyCustomClaims)(nil) + // Validate can be used to execute additional application-specific claims // validation. func (m MyCustomClaims) Validate() error { @@ -144,11 +150,12 @@ func ExampleParseWithClaims_customValidation() { token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) { return []byte("AllYourBase"), nil }, jwt.WithLeeway(5*time.Second)) - - if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid { - fmt.Printf("%v %v", claims.Foo, claims.RegisteredClaims.Issuer) + if err != nil { + log.Fatal(err) + } else if claims, ok := token.Claims.(*MyCustomClaims); ok { + fmt.Println(claims.Foo, claims.RegisteredClaims.Issuer) } else { - fmt.Println(err) + log.Fatal("unknown claims type, cannot proceed") } // Output: bar test @@ -163,17 +170,18 @@ func ExampleParse_errorChecking() { return []byte("AllYourBase"), nil }) - if token.Valid { + switch { + case token.Valid: fmt.Println("You look nice today") - } else if errors.Is(err, jwt.ErrTokenMalformed) { + case errors.Is(err, jwt.ErrTokenMalformed): fmt.Println("That's not even a token") - } else if errors.Is(err, jwt.ErrTokenSignatureInvalid) { + case errors.Is(err, jwt.ErrTokenSignatureInvalid): // Invalid signature fmt.Println("Invalid signature") - } else if errors.Is(err, jwt.ErrTokenExpired) || errors.Is(err, jwt.ErrTokenNotValidYet) { + case errors.Is(err, jwt.ErrTokenExpired) || errors.Is(err, jwt.ErrTokenNotValidYet): // Token is either expired or not active yet fmt.Println("Timing is everything") - } else { + default: fmt.Println("Couldn't handle this token:", err) } diff --git a/hmac.go b/hmac.go index 91b688ba..aca600ce 100644 --- a/hmac.go +++ b/hmac.go @@ -59,7 +59,7 @@ func (m *SigningMethodHMAC) Verify(signingString string, sig []byte, key interfa // Verify the key is the right type keyBytes, ok := key.([]byte) if !ok { - return ErrInvalidKeyType + return newError("HMAC verify expects []byte", ErrInvalidKeyType) } // Can we use the specified hashing method? @@ -100,5 +100,5 @@ func (m *SigningMethodHMAC) Sign(signingString string, key interface{}) ([]byte, return hasher.Sum(nil), nil } - return nil, ErrInvalidKeyType + return nil, newError("HMAC sign expects []byte", ErrInvalidKeyType) } diff --git a/hmac_example_test.go b/hmac_example_test.go index 4b2ff08a..f8f8c26b 100644 --- a/hmac_example_test.go +++ b/hmac_example_test.go @@ -2,6 +2,7 @@ package jwt_test import ( "fmt" + "log" "os" "time" @@ -48,16 +49,14 @@ func ExampleParse_hmac() { // head of the token to identify which key to use, but the parsed token (head and claims) is provided // to the callback, providing flexibility. token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { - // Don't forget to validate the alg is what you expect: - if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"]) - } - // hmacSampleSecret is a []byte containing your secret, e.g. []byte("my_secret_key") return hmacSampleSecret, nil - }) + }, jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()})) + if err != nil { + log.Fatal(err) + } - if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid { + if claims, ok := token.Claims.(jwt.MapClaims); ok { fmt.Println(claims["foo"], claims["nbf"]) } else { fmt.Println(err) diff --git a/hmac_test.go b/hmac_test.go index 264a2a42..3eb03804 100644 --- a/hmac_test.go +++ b/hmac_test.go @@ -66,16 +66,17 @@ func TestHMACVerify(t *testing.T) { func TestHMACSign(t *testing.T) { for _, data := range hmacTestData { - if data.valid { - parts := strings.Split(data.tokenString, ".") - method := jwt.GetSigningMethod(data.alg) - sig, err := method.Sign(strings.Join(parts[0:2], "."), hmacTestKey) - if err != nil { - t.Errorf("[%v] Error signing token: %v", data.name, err) - } - if !reflect.DeepEqual(sig, decodeSegment(t, parts[2])) { - t.Errorf("[%v] Incorrect signature.\nwas:\n%v\nexpecting:\n%v", data.name, sig, parts[2]) - } + if !data.valid { + continue + } + parts := strings.Split(data.tokenString, ".") + method := jwt.GetSigningMethod(data.alg) + sig, err := method.Sign(strings.Join(parts[0:2], "."), hmacTestKey) + if err != nil { + t.Errorf("[%v] Error signing token: %v", data.name, err) + } + if !reflect.DeepEqual(sig, decodeSegment(t, parts[2])) { + t.Errorf("[%v] Incorrect signature.\nwas:\n%v\nexpecting:\n%v", data.name, sig, parts[2]) } } } diff --git a/http_example_test.go b/http_example_test.go index c09cc367..0b22af93 100644 --- a/http_example_test.go +++ b/http_example_test.go @@ -4,7 +4,6 @@ package jwt_test // This is based on a (now outdated) example at https://gist.github.com/cryptix/45c33ecf0ae54828e63b import ( - "bytes" "crypto/rsa" "fmt" "io" @@ -84,20 +83,17 @@ func Example_getTokenViaHTTP() { "user": {"test"}, "pass": {"known"}, }) - if err != nil { - fatal(err) - } + fatal(err) if res.StatusCode != 200 { fmt.Println("Unexpected status code", res.StatusCode) } // Read the token out of the response body - buf := new(bytes.Buffer) - _, err = io.Copy(buf, res.Body) + buf, err := io.ReadAll(res.Body) fatal(err) res.Body.Close() - tokenString := strings.TrimSpace(buf.String()) + tokenString := strings.TrimSpace(string(buf)) // Parse the token token, err := jwt.ParseWithClaims(tokenString, &CustomClaimsExample{}, func(token *jwt.Token) (interface{}, error) { @@ -110,11 +106,10 @@ func Example_getTokenViaHTTP() { claims := token.Claims.(*CustomClaimsExample) fmt.Println(claims.CustomerInfo.Name) - //Output: test + // Output: test } func Example_useTokenViaHTTP() { - // Make a sample token // In a real world situation, this token will have been acquired from // some other API call (see Example_getTokenViaHTTP) @@ -129,11 +124,10 @@ func Example_useTokenViaHTTP() { fatal(err) // Read the response body - buf := new(bytes.Buffer) - _, err = io.Copy(buf, res.Body) + buf, err := io.ReadAll(res.Body) fatal(err) res.Body.Close() - fmt.Println(buf.String()) + fmt.Printf("%s", buf) // Output: Welcome, foo } diff --git a/jwt_test.go b/jwt_test.go new file mode 100644 index 00000000..b01e899d --- /dev/null +++ b/jwt_test.go @@ -0,0 +1,89 @@ +package jwt + +import ( + "testing" +) + +func TestSplitToken(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + expected []string + isValid bool + }{ + { + name: "valid token with three parts", + input: "header.claims.signature", + expected: []string{"header", "claims", "signature"}, + isValid: true, + }, + { + name: "invalid token with two parts only", + input: "header.claims", + expected: nil, + isValid: false, + }, + { + name: "invalid token with one part only", + input: "header", + expected: nil, + isValid: false, + }, + { + name: "invalid token with extra delimiter", + input: "header.claims.signature.extra", + expected: nil, + isValid: false, + }, + { + name: "invalid empty token", + input: "", + expected: nil, + isValid: false, + }, + { + name: "valid token with empty parts", + input: "..signature", + expected: []string{"", "", "signature"}, + isValid: true, + }, + { + // We are just splitting the token into parts, so we don't care about the actual values. + // It is up to the caller to validate the parts. + name: "valid token with all parts empty", + input: "..", + expected: []string{"", "", ""}, + isValid: true, + }, + { + name: "invalid token with just delimiters and extra part", + input: "...", + expected: nil, + isValid: false, + }, + { + name: "invalid token with many delimiters", + input: "header.claims.signature..................", + expected: nil, + isValid: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parts, ok := splitToken(tt.input) + if ok != tt.isValid { + t.Errorf("expected %t, got %t", tt.isValid, ok) + } + if ok { + for i, part := range tt.expected { + if parts[i] != part { + t.Errorf("expected %s, got %s", part, parts[i]) + } + } + } + }) + } +} diff --git a/map_claims_test.go b/map_claims_test.go index 83065d5b..034173d2 100644 --- a/map_claims_test.go +++ b/map_claims_test.go @@ -46,9 +46,9 @@ func TestVerifyAud(t *testing.T) { // []interface{} {Name: "Empty []interface{} Aud without match required", MapClaims: MapClaims{"aud": nilListInterface}, Expected: true, Required: false, Comparison: "example.com"}, - {Name: "[]interface{} Aud wit match required", MapClaims: MapClaims{"aud": []interface{}{"a", "foo", "example.com"}}, Expected: true, Required: true, Comparison: "example.com"}, - {Name: "[]interface{} Aud wit match but invalid types", MapClaims: MapClaims{"aud": []interface{}{"a", 5, "example.com"}}, Expected: false, Required: true, Comparison: "example.com"}, - {Name: "[]interface{} Aud int wit match required", MapClaims: MapClaims{"aud": intListInterface}, Expected: false, Required: true, Comparison: "example.com"}, + {Name: "[]interface{} Aud with match required", MapClaims: MapClaims{"aud": []interface{}{"a", "foo", "example.com"}}, Expected: true, Required: true, Comparison: "example.com"}, + {Name: "[]interface{} Aud with match but invalid types", MapClaims: MapClaims{"aud": []interface{}{"a", 5, "example.com"}}, Expected: false, Required: true, Comparison: "example.com"}, + {Name: "[]interface{} Aud int with match required", MapClaims: MapClaims{"aud": intListInterface}, Expected: false, Required: true, Comparison: "example.com"}, // interface{} {Name: "Empty interface{} Aud without match not required", MapClaims: MapClaims{"aud": nilInterface}, Expected: true, Required: false, Comparison: "example.com"}, @@ -62,7 +62,7 @@ func TestVerifyAud(t *testing.T) { opts = append(opts, WithAudience(test.Comparison)) } - validator := newValidator(opts...) + validator := NewValidator(opts...) got := validator.Validate(test.MapClaims) if (got == nil) != test.Expected { @@ -77,7 +77,7 @@ func TestMapclaimsVerifyIssuedAtInvalidTypeString(t *testing.T) { "iat": "foo", } want := false - got := newValidator(WithIssuedAt()).Validate(mapClaims) + got := NewValidator(WithIssuedAt()).Validate(mapClaims) if want != (got == nil) { t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil)) } @@ -88,7 +88,7 @@ func TestMapclaimsVerifyNotBeforeInvalidTypeString(t *testing.T) { "nbf": "foo", } want := false - got := newValidator().Validate(mapClaims) + got := NewValidator().Validate(mapClaims) if want != (got == nil) { t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil)) } @@ -99,7 +99,7 @@ func TestMapclaimsVerifyExpiresAtInvalidTypeString(t *testing.T) { "exp": "foo", } want := false - got := newValidator().Validate(mapClaims) + got := NewValidator().Validate(mapClaims) if want != (got == nil) { t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil)) @@ -112,14 +112,14 @@ func TestMapClaimsVerifyExpiresAtExpire(t *testing.T) { "exp": float64(exp.Unix()), } want := false - got := newValidator(WithTimeFunc(func() time.Time { + got := NewValidator(WithTimeFunc(func() time.Time { return exp })).Validate(mapClaims) if want != (got == nil) { t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil)) } - got = newValidator(WithTimeFunc(func() time.Time { + got = NewValidator(WithTimeFunc(func() time.Time { return exp.Add(1 * time.Second) })).Validate(mapClaims) if want != (got == nil) { @@ -127,7 +127,7 @@ func TestMapClaimsVerifyExpiresAtExpire(t *testing.T) { } want = true - got = newValidator(WithTimeFunc(func() time.Time { + got = NewValidator(WithTimeFunc(func() time.Time { return exp.Add(-1 * time.Second) })).Validate(mapClaims) if want != (got == nil) { diff --git a/none.go b/none.go index c93daa58..685c2ea3 100644 --- a/none.go +++ b/none.go @@ -32,7 +32,7 @@ func (m *signingMethodNone) Verify(signingString string, sig []byte, key interfa return NoneSignatureTypeDisallowedError } // If signing method is none, signature must be an empty string - if string(sig) != "" { + if len(sig) != 0 { return newError("'none' signing method with non-empty signature", ErrTokenUnverifiable) } diff --git a/none_test.go b/none_test.go index d370cf8c..f126b14a 100644 --- a/none_test.go +++ b/none_test.go @@ -59,16 +59,17 @@ func TestNoneVerify(t *testing.T) { func TestNoneSign(t *testing.T) { for _, data := range noneTestData { - if data.valid { - parts := strings.Split(data.tokenString, ".") - method := jwt.GetSigningMethod(data.alg) - sig, err := method.Sign(strings.Join(parts[0:2], "."), data.key) - if err != nil { - t.Errorf("[%v] Error signing token: %v", data.name, err) - } - if !reflect.DeepEqual(sig, decodeSegment(t, parts[2])) { - t.Errorf("[%v] Incorrect signature.\nwas:\n%v\nexpecting:\n%v", data.name, sig, parts[2]) - } + if !data.valid { + continue + } + parts := strings.Split(data.tokenString, ".") + method := jwt.GetSigningMethod(data.alg) + sig, err := method.Sign(strings.Join(parts[0:2], "."), data.key) + if err != nil { + t.Errorf("[%v] Error signing token: %v", data.name, err) + } + if !reflect.DeepEqual(sig, decodeSegment(t, parts[2])) { + t.Errorf("[%v] Incorrect signature.\nwas:\n%v\nexpecting:\n%v", data.name, sig, parts[2]) } } } diff --git a/parser.go b/parser.go index f4386fba..054c7eb6 100644 --- a/parser.go +++ b/parser.go @@ -8,6 +8,8 @@ import ( "strings" ) +const tokenDelimiter = "." + type Parser struct { // If populated, only these methods will be considered valid. validMethods []string @@ -18,7 +20,7 @@ type Parser struct { // Skip claims validation during token parsing. skipClaimsValidation bool - validator *validator + validator *Validator decodeStrict bool @@ -28,7 +30,7 @@ type Parser struct { // NewParser creates a new Parser with the specified options func NewParser(options ...ParserOption) *Parser { p := &Parser{ - validator: &validator{}, + validator: &Validator{}, } // Loop through our parsing options and apply them @@ -74,24 +76,40 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf } } - // Lookup key - var key interface{} + // Decode signature + token.Signature, err = p.DecodeSegment(parts[2]) + if err != nil { + return token, newError("could not base64 decode signature", ErrTokenMalformed, err) + } + text := strings.Join(parts[0:2], ".") + + // Lookup key(s) if keyFunc == nil { // keyFunc was not provided. short circuiting validation return token, newError("no keyfunc was provided", ErrTokenUnverifiable) } - if key, err = keyFunc(token); err != nil { - return token, newError("error while executing keyfunc", ErrTokenUnverifiable, err) - } - // Decode signature - token.Signature, err = p.DecodeSegment(parts[2]) + got, err := keyFunc(token) if err != nil { - return token, newError("could not base64 decode signature", ErrTokenMalformed, err) + return token, newError("error while executing keyfunc", ErrTokenUnverifiable, err) } - // Perform signature validation - if err = token.Method.Verify(strings.Join(parts[0:2], "."), token.Signature, key); err != nil { + switch have := got.(type) { + case VerificationKeySet: + if len(have.Keys) == 0 { + return token, newError("keyfunc returned empty verification key set", ErrTokenUnverifiable) + } + // Iterate through keys and verify signature, skipping the rest when a match is found. + // Return the last error if no match is found. + for _, key := range have.Keys { + if err = token.Method.Verify(text, token.Signature, key); err == nil { + break + } + } + default: + err = token.Method.Verify(text, token.Signature, have) + } + if err != nil { return token, newError("", ErrTokenSignatureInvalid, err) } @@ -99,7 +117,7 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf if !p.skipClaimsValidation { // Make sure we have at least a default validator if p.validator == nil { - p.validator = newValidator() + p.validator = NewValidator() } if err := p.validator.Validate(claims); err != nil { @@ -117,12 +135,13 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf // // WARNING: Don't use this method unless you know what you're doing. // -// It's only ever useful in cases where you know the signature is valid (because it has -// been checked previously in the stack) and you want to extract values from it. +// It's only ever useful in cases where you know the signature is valid (since it has already +// been or will be checked elsewhere in the stack) and you want to extract values from it. func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Token, parts []string, err error) { - parts = strings.Split(tokenString, ".") - if len(parts) != 3 { - return nil, parts, newError("token contains an invalid number of segments", ErrTokenMalformed) + var ok bool + parts, ok = splitToken(tokenString) + if !ok { + return nil, nil, newError("token contains an invalid number of segments", ErrTokenMalformed) } token = &Token{Raw: tokenString} @@ -130,9 +149,6 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke // parse Header var headerBytes []byte if headerBytes, err = p.DecodeSegment(parts[0]); err != nil { - if strings.HasPrefix(strings.ToLower(tokenString), "bearer ") { - return token, parts, newError("tokenstring should not contain 'bearer '", ErrTokenMalformed) - } return token, parts, newError("could not base64 decode header", ErrTokenMalformed, err) } if err = json.Unmarshal(headerBytes, &token.Header); err != nil { @@ -140,23 +156,33 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke } // parse Claims - var claimBytes []byte token.Claims = claims - if claimBytes, err = p.DecodeSegment(parts[1]); err != nil { + claimBytes, err := p.DecodeSegment(parts[1]) + if err != nil { return token, parts, newError("could not base64 decode claim", ErrTokenMalformed, err) } - dec := json.NewDecoder(bytes.NewBuffer(claimBytes)) - if p.useJSONNumber { - dec.UseNumber() - } - // JSON Decode. Special case for map type to avoid weird pointer behavior - if c, ok := token.Claims.(MapClaims); ok { - err = dec.Decode(&c) + + // If `useJSONNumber` is enabled then we must use *json.Decoder to decode + // the claims. However, this comes with a performance penalty so only use + // it if we must and, otherwise, simple use json.Unmarshal. + if !p.useJSONNumber { + // JSON Unmarshal. Special case for map type to avoid weird pointer behavior. + if c, ok := token.Claims.(MapClaims); ok { + err = json.Unmarshal(claimBytes, &c) + } else { + err = json.Unmarshal(claimBytes, &claims) + } } else { - err = dec.Decode(&claims) + dec := json.NewDecoder(bytes.NewBuffer(claimBytes)) + dec.UseNumber() + // JSON Decode. Special case for map type to avoid weird pointer behavior. + if c, ok := token.Claims.(MapClaims); ok { + err = dec.Decode(&c) + } else { + err = dec.Decode(&claims) + } } - // Handle decode error if err != nil { return token, parts, newError("could not JSON decode claim", ErrTokenMalformed, err) } @@ -173,6 +199,33 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke return token, parts, nil } +// splitToken splits a token string into three parts: header, claims, and signature. It will only +// return true if the token contains exactly two delimiters and three parts. In all other cases, it +// will return nil parts and false. +func splitToken(token string) ([]string, bool) { + parts := make([]string, 3) + header, remain, ok := strings.Cut(token, tokenDelimiter) + if !ok { + return nil, false + } + parts[0] = header + claims, remain, ok := strings.Cut(remain, tokenDelimiter) + if !ok { + return nil, false + } + parts[1] = claims + // One more cut to ensure the signature is the last part of the token and there are no more + // delimiters. This avoids an issue where malicious input could contain additional delimiters + // causing unecessary overhead parsing tokens. + signature, _, unexpected := strings.Cut(remain, tokenDelimiter) + if unexpected { + return nil, false + } + parts[2] = signature + + return parts, true +} + // DecodeSegment decodes a JWT specific base64url encoding. This function will // take into account whether the [Parser] is configured with additional options, // such as [WithStrictDecoding] or [WithPaddingAllowed]. diff --git a/parser_option.go b/parser_option.go index 1b5af970..88a780fb 100644 --- a/parser_option.go +++ b/parser_option.go @@ -58,6 +58,14 @@ func WithIssuedAt() ParserOption { } } +// WithExpirationRequired returns the ParserOption to make exp claim required. +// By default exp claim is optional. +func WithExpirationRequired() ParserOption { + return func(p *Parser) { + p.validator.requireExp = true + } +} + // WithAudience configures the validator to require the specified audience in // the `aud` claim. Validation will fail if the audience is not listed in the // token or the `aud` claim is missing. diff --git a/parser_test.go b/parser_test.go index 5b912b15..c0f81711 100644 --- a/parser_test.go +++ b/parser_test.go @@ -28,6 +28,25 @@ var ( emptyKeyFunc jwt.Keyfunc = func(t *jwt.Token) (interface{}, error) { return nil, nil } errorKeyFunc jwt.Keyfunc = func(t *jwt.Token) (interface{}, error) { return nil, errKeyFuncError } nilKeyFunc jwt.Keyfunc = nil + multipleZeroKeyFunc jwt.Keyfunc = func(t *jwt.Token) (interface{}, error) { return []interface{}{}, nil } + multipleEmptyKeyFunc jwt.Keyfunc = func(t *jwt.Token) (interface{}, error) { + return jwt.VerificationKeySet{Keys: []jwt.VerificationKey{nil, nil}}, nil + } + multipleVerificationKeysFunc jwt.Keyfunc = func(t *jwt.Token) (interface{}, error) { + return []jwt.VerificationKey{jwtTestDefaultKey, jwtTestEC256PublicKey}, nil + } + multipleLastKeyFunc jwt.Keyfunc = func(t *jwt.Token) (interface{}, error) { + return jwt.VerificationKeySet{Keys: []jwt.VerificationKey{jwtTestEC256PublicKey, jwtTestDefaultKey}}, nil + } + multipleFirstKeyFunc jwt.Keyfunc = func(t *jwt.Token) (interface{}, error) { + return jwt.VerificationKeySet{Keys: []jwt.VerificationKey{jwtTestDefaultKey, jwtTestEC256PublicKey}}, nil + } + multipleAltTypedKeyFunc jwt.Keyfunc = func(t *jwt.Token) (interface{}, error) { + return jwt.VerificationKeySet{Keys: []jwt.VerificationKey{jwtTestDefaultKey, jwtTestDefaultKey}}, nil + } + emptyVerificationKeySetFunc jwt.Keyfunc = func(t *jwt.Token) (interface{}, error) { + return jwt.VerificationKeySet{}, nil + } ) func init() { @@ -42,7 +61,6 @@ func init() { // Load private keys jwtTestRSAPrivateKey = test.LoadRSAPrivateKeyFromDisk("test/sample_key") jwtTestEC256PrivateKey = test.LoadECPrivateKeyFromDisk("test/ec256-private.pem") - } var jwtTestData = []struct { @@ -95,6 +113,46 @@ var jwtTestData = []struct { nil, jwt.SigningMethodRS256, }, + { + "multiple keys, last matches", + "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg", + multipleLastKeyFunc, + jwt.MapClaims{"foo": "bar"}, + true, + nil, + nil, + jwt.SigningMethodRS256, + }, + { + "multiple keys not []interface{} type, all match", + "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg", + multipleAltTypedKeyFunc, + jwt.MapClaims{"foo": "bar"}, + true, + nil, + nil, + jwt.SigningMethodRS256, + }, + { + "multiple keys, first matches", + "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg", + multipleFirstKeyFunc, + jwt.MapClaims{"foo": "bar"}, + true, + nil, + nil, + jwt.SigningMethodRS256, + }, + { + "public keys slice, not allowed", + "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg", + multipleVerificationKeysFunc, + jwt.MapClaims{"foo": "bar"}, + false, + []error{jwt.ErrTokenSignatureInvalid}, + nil, + jwt.SigningMethodRS256, + }, { "basic expired", "", // autogen @@ -155,6 +213,36 @@ var jwtTestData = []struct { nil, jwt.SigningMethodRS256, }, + { + "multiple nokey", + "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg", + multipleEmptyKeyFunc, + jwt.MapClaims{"foo": "bar"}, + false, + []error{jwt.ErrTokenSignatureInvalid}, + nil, + jwt.SigningMethodRS256, + }, + { + "empty verification key set", + "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg", + emptyVerificationKeySetFunc, + jwt.MapClaims{"foo": "bar"}, + false, + []error{jwt.ErrTokenUnverifiable}, + nil, + jwt.SigningMethodRS256, + }, + { + "zero length key list", + "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg", + multipleZeroKeyFunc, + jwt.MapClaims{"foo": "bar"}, + false, + []error{jwt.ErrTokenSignatureInvalid}, + nil, + jwt.SigningMethodRS256, + }, { "basic errorkey", "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg", @@ -335,6 +423,16 @@ var jwtTestData = []struct { jwt.NewParser(jwt.WithLeeway(2 * time.Minute)), jwt.SigningMethodRS256, }, + { + "rejects if exp is required but missing", + "", // autogen + defaultKeyFunc, + &jwt.RegisteredClaims{}, + false, + []error{jwt.ErrTokenInvalidClaims}, + jwt.NewParser(jwt.WithExpirationRequired()), + jwt.SigningMethodRS256, + }, } // signToken creates and returns a signed JWT token using signingMethod. @@ -352,11 +450,9 @@ func signToken(claims jwt.Claims, signingMethod jwt.SigningMethod) string { } func TestParser_Parse(t *testing.T) { - // Iterate over test data set and run tests for _, data := range jwtTestData { t.Run(data.name, func(t *testing.T) { - // If the token string is blank, use helper function to generate string if data.tokenString == "" { data.tokenString = signToken(data.claims, data.signingMethod) @@ -428,7 +524,6 @@ func TestParser_Parse(t *testing.T) { } func TestParser_ParseUnverified(t *testing.T) { - // Iterate over test data set and run tests for _, data := range jwtTestData { // Skip test data, that intentionally contains malformed tokens, as they would lead to an error @@ -670,13 +765,11 @@ func TestSetPadding(t *testing.T) { err, ) } - }) } } func BenchmarkParseUnverified(b *testing.B) { - // Iterate over test data set and run tests for _, data := range jwtTestData { // If the token string is blank, use helper function to generate string diff --git a/request/extractor.go b/request/extractor.go index 780721b6..9ef5f066 100644 --- a/request/extractor.go +++ b/request/extractor.go @@ -90,7 +90,7 @@ func (e BearerExtractor) ExtractToken(req *http.Request) (string, error) { tokenHeader := req.Header.Get("Authorization") // The usual convention is for "Bearer" to be title-cased. However, there's no // strict rule around this, and it's best to follow the robustness principle here. - if tokenHeader == "" || !strings.HasPrefix(strings.ToLower(tokenHeader), "bearer ") { + if len(tokenHeader) < 7 || !strings.EqualFold(tokenHeader[:7], "bearer ") { return "", ErrNoTokenInRequest } return tokenHeader[7:], nil diff --git a/request/extractor_example_test.go b/request/extractor_example_test.go index a994ffe5..2d058cab 100644 --- a/request/extractor_example_test.go +++ b/request/extractor_example_test.go @@ -17,7 +17,7 @@ func ExampleHeaderExtractor() { } else { fmt.Println(err) } - //Output: A + // Output: A } func ExampleArgumentExtractor() { @@ -28,5 +28,5 @@ func ExampleArgumentExtractor() { } else { fmt.Println(err) } - //Output: A + // Output: A } diff --git a/request/oauth2.go b/request/oauth2.go index 5860a53f..9f88c3e9 100644 --- a/request/oauth2.go +++ b/request/oauth2.go @@ -7,7 +7,7 @@ import ( // Strips 'Bearer ' prefix from bearer token string func stripBearerPrefixFromTokenString(tok string) (string, error) { // Should be a bearer token - if len(tok) > 6 && strings.ToUpper(tok[0:7]) == "BEARER " { + if len(tok) > 6 && strings.EqualFold(tok[:7], "bearer ") { return tok[7:], nil } return tok, nil diff --git a/rsa.go b/rsa.go index daff0943..83cbee6a 100644 --- a/rsa.go +++ b/rsa.go @@ -51,7 +51,7 @@ func (m *SigningMethodRSA) Verify(signingString string, sig []byte, key interfac var ok bool if rsaKey, ok = key.(*rsa.PublicKey); !ok { - return ErrInvalidKeyType + return newError("RSA verify expects *rsa.PublicKey", ErrInvalidKeyType) } // Create hasher @@ -73,7 +73,7 @@ func (m *SigningMethodRSA) Sign(signingString string, key interface{}) ([]byte, // Validate type of key if rsaKey, ok = key.(*rsa.PrivateKey); !ok { - return nil, ErrInvalidKey + return nil, newError("RSA sign expects *rsa.PrivateKey", ErrInvalidKeyType) } // Create the hasher diff --git a/rsa_pss.go b/rsa_pss.go index 9599f0a4..28c386ec 100644 --- a/rsa_pss.go +++ b/rsa_pss.go @@ -88,7 +88,7 @@ func (m *SigningMethodRSAPSS) Verify(signingString string, sig []byte, key inter case *rsa.PublicKey: rsaKey = k default: - return ErrInvalidKey + return newError("RSA-PSS verify expects *rsa.PublicKey", ErrInvalidKeyType) } // Create hasher @@ -115,7 +115,7 @@ func (m *SigningMethodRSAPSS) Sign(signingString string, key interface{}) ([]byt case *rsa.PrivateKey: rsaKey = k default: - return nil, ErrInvalidKeyType + return nil, newError("RSA-PSS sign expects *rsa.PrivateKey", ErrInvalidKeyType) } // Create the hasher diff --git a/rsa_pss_test.go b/rsa_pss_test.go index 9707a755..536cde61 100644 --- a/rsa_pss_test.go +++ b/rsa_pss_test.go @@ -84,18 +84,19 @@ func TestRSAPSSSign(t *testing.T) { } for _, data := range rsaPSSTestData { - if data.valid { - parts := strings.Split(data.tokenString, ".") - method := jwt.GetSigningMethod(data.alg) - sig, err := method.Sign(strings.Join(parts[0:2], "."), rsaPSSKey) - if err != nil { - t.Errorf("[%v] Error signing token: %v", data.name, err) - } - - ssig := encodeSegment(sig) - if ssig == parts[2] { - t.Errorf("[%v] Signatures shouldn't match\nnew:\n%v\noriginal:\n%v", data.name, ssig, parts[2]) - } + if !data.valid { + continue + } + parts := strings.Split(data.tokenString, ".") + method := jwt.GetSigningMethod(data.alg) + sig, err := method.Sign(strings.Join(parts[0:2], "."), rsaPSSKey) + if err != nil { + t.Errorf("[%v] Error signing token: %v", data.name, err) + } + + ssig := encodeSegment(sig) + if ssig == parts[2] { + t.Errorf("[%v] Signatures shouldn't match\nnew:\n%v\noriginal:\n%v", data.name, ssig, parts[2]) } } } diff --git a/token.go b/token.go index c8ad7c78..9c7f4ab0 100644 --- a/token.go +++ b/token.go @@ -1,6 +1,7 @@ package jwt import ( + "crypto" "encoding/base64" "encoding/json" ) @@ -9,8 +10,21 @@ import ( // the key for verification. The function receives the parsed, but unverified // Token. This allows you to use properties in the Header of the token (such as // `kid`) to identify which key to use. +// +// The returned interface{} may be a single key or a VerificationKeySet containing +// multiple keys. type Keyfunc func(*Token) (interface{}, error) +// VerificationKey represents a public or secret key for verifying a token's signature. +type VerificationKey interface { + crypto.PublicKey | []uint8 +} + +// VerificationKeySet is a set of public or secret keys. It is used by the parser to verify a token. +type VerificationKeySet struct { + Keys []VerificationKey +} + // Token represents a JWT Token. Different fields will be used depending on // whether you're creating or parsing/verifying a token. type Token struct { @@ -61,7 +75,7 @@ func (t *Token) SignedString(key interface{}) (string, error) { } // SigningString generates the signing string. This is the most expensive part -// of the whole deal. Unless you need this for something special, just go +// of the whole deal. Unless you need this for something special, just go // straight for the SignedString. func (t *Token) SigningString() (string, error) { h, err := json.Marshal(t.Header) diff --git a/types.go b/types.go index b82b3886..b2655a9e 100644 --- a/types.go +++ b/types.go @@ -4,7 +4,6 @@ import ( "encoding/json" "fmt" "math" - "reflect" "strconv" "time" ) @@ -121,14 +120,14 @@ func (s *ClaimStrings) UnmarshalJSON(data []byte) (err error) { for _, vv := range v { vs, ok := vv.(string) if !ok { - return &json.UnsupportedTypeError{Type: reflect.TypeOf(vv)} + return ErrInvalidType } aud = append(aud, vs) } case nil: return nil default: - return &json.UnsupportedTypeError{Type: reflect.TypeOf(v)} + return ErrInvalidType } *s = aud diff --git a/types_test.go b/types_test.go index d07f5586..bd7b139f 100644 --- a/types_test.go +++ b/types_test.go @@ -28,7 +28,7 @@ func TestNumericDate(t *testing.T) { b, _ := json.Marshal(s) if raw != string(b) { - t.Errorf("Serialized format of numeric date mismatch. Expecting: %s Got: %s", string(raw), string(b)) + t.Errorf("Serialized format of numeric date mismatch. Expecting: %s Got: %s", raw, string(b)) } jwt.TimePrecision = oldPrecision @@ -41,13 +41,12 @@ func TestSingleArrayMarshal(t *testing.T) { expected := `"test"` b, err := json.Marshal(s) - if err != nil { t.Errorf("Unexpected error: %s", err) } if expected != string(b) { - t.Errorf("Serialized format of string array mismatch. Expecting: %s Got: %s", string(expected), string(b)) + t.Errorf("Serialized format of string array mismatch. Expecting: %s Got: %s", expected, string(b)) } jwt.MarshalSingleStringAsArray = true @@ -61,7 +60,7 @@ func TestSingleArrayMarshal(t *testing.T) { } if expected != string(b) { - t.Errorf("Serialized format of string array mismatch. Expecting: %s Got: %s", string(expected), string(b)) + t.Errorf("Serialized format of string array mismatch. Expecting: %s Got: %s", expected, string(b)) } } diff --git a/validator.go b/validator.go index 38504389..008ecd87 100644 --- a/validator.go +++ b/validator.go @@ -28,13 +28,12 @@ type ClaimsValidator interface { Validate() error } -// validator is the core of the new Validation API. It is automatically used by +// Validator is the core of the new Validation API. It is automatically used by // a [Parser] during parsing and can be modified with various parser options. // -// Note: This struct is intentionally not exported (yet) as we want to -// internally finalize its API. In the future, we might make it publicly -// available. -type validator struct { +// The [NewValidator] function should be used to create an instance of this +// struct. +type Validator struct { // leeway is an optional leeway that can be provided to account for clock skew. leeway time.Duration @@ -42,6 +41,9 @@ type validator struct { // validation. If unspecified, this defaults to time.Now. timeFunc func() time.Time + // requireExp specifies whether the exp claim is required + requireExp bool + // verifyIat specifies whether the iat (Issued At) claim will be verified. // According to https://www.rfc-editor.org/rfc/rfc7519#section-4.1.6 this // only specifies the age of the token, but no validation check is @@ -62,16 +64,28 @@ type validator struct { expectedSub string } -// newValidator can be used to create a stand-alone validator with the supplied +// NewValidator can be used to create a stand-alone validator with the supplied // options. This validator can then be used to validate already parsed claims. -func newValidator(opts ...ParserOption) *validator { +// +// Note: Under normal circumstances, explicitly creating a validator is not +// needed and can potentially be dangerous; instead functions of the [Parser] +// class should be used. +// +// The [Validator] is only checking the *validity* of the claims, such as its +// expiration time, but it does NOT perform *signature verification* of the +// token. +func NewValidator(opts ...ParserOption) *Validator { p := NewParser(opts...) return p.validator } // Validate validates the given claims. It will also perform any custom // validation if claims implements the [ClaimsValidator] interface. -func (v *validator) Validate(claims Claims) error { +// +// Note: It will NOT perform any *signature verification* on the token that +// contains the claims and expects that the [Claim] was already successfully +// verified. +func (v *Validator) Validate(claims Claims) error { var ( now time.Time errs []error = make([]error, 0, 6) @@ -86,8 +100,9 @@ func (v *validator) Validate(claims Claims) error { } // We always need to check the expiration time, but usage of the claim - // itself is OPTIONAL. - if err = v.verifyExpiresAt(claims, now, false); err != nil { + // itself is OPTIONAL by default. requireExp overrides this behavior + // and makes the exp claim mandatory. + if err = v.verifyExpiresAt(claims, now, v.requireExp); err != nil { errs = append(errs, err) } @@ -149,7 +164,7 @@ func (v *validator) Validate(claims Claims) error { // // Additionally, if any error occurs while retrieving the claim, e.g., when its // the wrong type, an ErrTokenUnverifiable error will be returned. -func (v *validator) verifyExpiresAt(claims Claims, cmp time.Time, required bool) error { +func (v *Validator) verifyExpiresAt(claims Claims, cmp time.Time, required bool) error { exp, err := claims.GetExpirationTime() if err != nil { return err @@ -170,7 +185,7 @@ func (v *validator) verifyExpiresAt(claims Claims, cmp time.Time, required bool) // // Additionally, if any error occurs while retrieving the claim, e.g., when its // the wrong type, an ErrTokenUnverifiable error will be returned. -func (v *validator) verifyIssuedAt(claims Claims, cmp time.Time, required bool) error { +func (v *Validator) verifyIssuedAt(claims Claims, cmp time.Time, required bool) error { iat, err := claims.GetIssuedAt() if err != nil { return err @@ -191,7 +206,7 @@ func (v *validator) verifyIssuedAt(claims Claims, cmp time.Time, required bool) // // Additionally, if any error occurs while retrieving the claim, e.g., when its // the wrong type, an ErrTokenUnverifiable error will be returned. -func (v *validator) verifyNotBefore(claims Claims, cmp time.Time, required bool) error { +func (v *Validator) verifyNotBefore(claims Claims, cmp time.Time, required bool) error { nbf, err := claims.GetNotBefore() if err != nil { return err @@ -211,7 +226,7 @@ func (v *validator) verifyNotBefore(claims Claims, cmp time.Time, required bool) // // Additionally, if any error occurs while retrieving the claim, e.g., when its // the wrong type, an ErrTokenUnverifiable error will be returned. -func (v *validator) verifyAudience(claims Claims, cmp string, required bool) error { +func (v *Validator) verifyAudience(claims Claims, cmp string, required bool) error { aud, err := claims.GetAudience() if err != nil { return err @@ -247,7 +262,7 @@ func (v *validator) verifyAudience(claims Claims, cmp string, required bool) err // // Additionally, if any error occurs while retrieving the claim, e.g., when its // the wrong type, an ErrTokenUnverifiable error will be returned. -func (v *validator) verifyIssuer(claims Claims, cmp string, required bool) error { +func (v *Validator) verifyIssuer(claims Claims, cmp string, required bool) error { iss, err := claims.GetIssuer() if err != nil { return err @@ -267,7 +282,7 @@ func (v *validator) verifyIssuer(claims Claims, cmp string, required bool) error // // Additionally, if any error occurs while retrieving the claim, e.g., when its // the wrong type, an ErrTokenUnverifiable error will be returned. -func (v *validator) verifySubject(claims Claims, cmp string, required bool) error { +func (v *Validator) verifySubject(claims Claims, cmp string, required bool) error { sub, err := claims.GetSubject() if err != nil { return err diff --git a/validator_test.go b/validator_test.go index 869b0507..08a6bd71 100644 --- a/validator_test.go +++ b/validator_test.go @@ -20,7 +20,7 @@ func (m MyCustomClaims) Validate() error { return nil } -func Test_validator_Validate(t *testing.T) { +func Test_Validator_Validate(t *testing.T) { type fields struct { leeway time.Duration timeFunc func() time.Time @@ -71,7 +71,7 @@ func Test_validator_Validate(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - v := &validator{ + v := &Validator{ leeway: tt.fields.leeway, timeFunc: tt.fields.timeFunc, verifyIat: tt.fields.verifyIat, @@ -86,7 +86,7 @@ func Test_validator_Validate(t *testing.T) { } } -func Test_validator_verifyExpiresAt(t *testing.T) { +func Test_Validator_verifyExpiresAt(t *testing.T) { type fields struct { leeway time.Duration timeFunc func() time.Time @@ -117,7 +117,7 @@ func Test_validator_verifyExpiresAt(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - v := &validator{ + v := &Validator{ leeway: tt.fields.leeway, timeFunc: tt.fields.timeFunc, } @@ -130,7 +130,7 @@ func Test_validator_verifyExpiresAt(t *testing.T) { } } -func Test_validator_verifyIssuer(t *testing.T) { +func Test_Validator_verifyIssuer(t *testing.T) { type fields struct { expectedIss string } @@ -160,7 +160,7 @@ func Test_validator_verifyIssuer(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - v := &validator{ + v := &Validator{ expectedIss: tt.fields.expectedIss, } err := v.verifyIssuer(tt.args.claims, tt.args.cmp, tt.args.required) @@ -171,7 +171,7 @@ func Test_validator_verifyIssuer(t *testing.T) { } } -func Test_validator_verifySubject(t *testing.T) { +func Test_Validator_verifySubject(t *testing.T) { type fields struct { expectedSub string } @@ -201,7 +201,7 @@ func Test_validator_verifySubject(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - v := &validator{ + v := &Validator{ expectedSub: tt.fields.expectedSub, } err := v.verifySubject(tt.args.claims, tt.args.cmp, tt.args.required) @@ -212,7 +212,7 @@ func Test_validator_verifySubject(t *testing.T) { } } -func Test_validator_verifyIssuedAt(t *testing.T) { +func Test_Validator_verifyIssuedAt(t *testing.T) { type fields struct { leeway time.Duration timeFunc func() time.Time @@ -248,7 +248,7 @@ func Test_validator_verifyIssuedAt(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - v := &validator{ + v := &Validator{ leeway: tt.fields.leeway, timeFunc: tt.fields.timeFunc, verifyIat: tt.fields.verifyIat,