go-jwt/jwt.go

194 lines
4.4 KiB
Go
Raw Normal View History

2025-02-14 19:46:57 +03:00
package jwt
import (
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"fmt"
"github.com/golang-jwt/jwt/v5"
"io"
"os"
"path/filepath"
"time"
)
func CreateRefreshTokenByAccess(accessClaims JWT, ttl time.Duration) JWT {
return JWT{
Type: "refresh",
RegisteredClaims: jwt.RegisteredClaims{
ID: accessClaims.ID,
ExpiresAt: jwt.NewNumericDate(time.Now().Add(ttl)),
IssuedAt: jwt.NewNumericDate(time.Now()),
NotBefore: accessClaims.ExpiresAt,
2025-02-14 19:46:57 +03:00
},
}
}
func ParseRefreshToken(token string, publicKey *rsa.PublicKey) (*JWT, error) {
refreshDecodedClaims, err := Decode(token, &JWT{}, publicKey)
2025-02-14 19:46:57 +03:00
if err != nil {
return nil, err
}
refreshTokenClaims, ok := refreshDecodedClaims.(*JWT)
2025-02-14 19:46:57 +03:00
if !ok {
return nil, fmt.Errorf("invalid refresh token claims")
}
return refreshTokenClaims, nil
}
type JWT struct {
Type string `json:"t"`
SessionId string `json:"si,omitempty"`
AuthorizationInfo string `json:"ai,omitempty"`
2025-02-14 19:46:57 +03:00
jwt.RegisteredClaims
}
// WithTtl makes the token valid from now for the duration ttl.
//
// IssuedAt and NotBefore are set to the current time and ExpiresAt to the
// current time plus ttl. j is modified in place and returned for chaining.
func (j *JWT) WithTtl(ttl time.Duration) *JWT {
	// A single clock read keeps iat, nbf and exp consistent; separate
	// time.Now() calls could land in different (truncated) seconds.
	now := time.Now()
	j.IssuedAt = jwt.NewNumericDate(now)
	j.NotBefore = jwt.NewNumericDate(now)
	j.ExpiresAt = jwt.NewNumericDate(now.Add(ttl))
	return j
}
// WithId sets the registered token ID claim to id and returns j so calls
// can be chained.
func (j *JWT) WithId(id string) *JWT {
	j.RegisteredClaims.ID = id
	return j
}
// WithSessionId stores sessionId in the SessionId claim and returns j so
// calls can be chained.
func (j *JWT) WithSessionId(sessionId string) *JWT {
	claims := j
	claims.SessionId = sessionId
	return claims
}
// WithAuthorizationInfo encodes tai with secret and stores the result in the
// AuthorizationInfo claim, returning j for chaining.
//
// Encoding errors are swallowed: on failure j is returned unchanged, so any
// previously stored AuthorizationInfo is kept.
// NOTE(review): callers cannot detect the failure; consider an
// error-returning variant.
func (j *JWT) WithAuthorizationInfo(tai TokenAuthorizationInfo, secret string) *JWT {
	if encoded, err := tai.WithSecret(secret).Encode(); err == nil {
		j.AuthorizationInfo = encoded
	}
	return j
}
// GetAuthorizationInfo decodes the AuthorizationInfo claim using secret.
//
// It returns nil if decoding fails; the underlying error is discarded.
func (j *JWT) GetAuthorizationInfo(secret string) *TokenAuthorizationInfo {
	if info, err := DecodeTokenAuthorizationInfo(j.AuthorizationInfo, secret); err == nil {
		return info
	}
	return nil
}
// IsRefreshToken reports whether the claims carry Type "refresh", the value
// set by CreateRefreshTokenByAccess.
func (j *JWT) IsRefreshToken() bool {
	switch j.Type {
	case "refresh":
		return true
	default:
		return false
	}
}
// Encode signs the claims j with privateKey using RS256 and returns the
// signed token string.
//
// j must implement jwt.Claims (JWT and *JWT do, via the embedded
// RegisteredClaims); any other value yields an "invalid jwt claims" error.
// Signing errors from the jwt library are returned unchanged.
func Encode(j interface{}, privateKey *rsa.PrivateKey) (string, error) {
	claims, isClaims := j.(jwt.Claims)
	if !isClaims {
		return "", fmt.Errorf("invalid jwt claims")
	}
	return jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(privateKey)
}
// Decode parses token, verifies its signature with publicKey and unmarshals
// the payload into data, returning the resulting claims.
//
// Only RSA PKCS#1 v1.5 signing methods (RS256/RS384/RS512) are accepted,
// matching the RS256 tokens produced by Encode; tokens whose "alg" header
// names any other method are rejected before verification. A nil publicKey
// yields an error instead of a panic. options are forwarded to
// jwt.ParseWithClaims (e.g. jwt.WithLeeway).
func Decode(token string, data jwt.Claims, publicKey *rsa.PublicKey, options ...jwt.ParserOption) (jwt.Claims, error) {
	keyFunc := func(parsed *jwt.Token) (interface{}, error) {
		// Pin the algorithm family to the key type we hold rather than
		// trusting the token's own header.
		if _, ok := parsed.Method.(*jwt.SigningMethodRSA); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", parsed.Header["alg"])
		}
		// A nil *rsa.PublicKey boxed in interface{} is a non-nil interface;
		// the RSA verifier would dereference it and panic.
		if publicKey == nil {
			return nil, fmt.Errorf("nil RSA public key")
		}
		return publicKey, nil
	}

	t, err := jwt.ParseWithClaims(token, data, keyFunc, options...)
	if err != nil {
		return nil, err
	}
	if claims, ok := t.Claims.(jwt.Claims); ok {
		return claims, nil
	}
	return nil, fmt.Errorf("unknown claims type, cannot proceed")
}
// ReadPublicKey loads an RSA public key from the PEM file at path.
//
// The PEM body may be PKIX ("PUBLIC KEY") or bare PKCS#1 ("RSA PUBLIC KEY").
// A legacy-encrypted PEM block is decrypted with an empty password by
// decodePEMBlock. Keys that are not RSA are rejected.
func ReadPublicKey(path string) (*rsa.PublicKey, error) {
	b, err := readPublicKey(path)
	if err != nil {
		return nil, fmt.Errorf("reading file: %w", err)
	}
	der, err := decodePEMBlock(b)
	if err != nil {
		return nil, fmt.Errorf("decoding PEM block failed: %w", err)
	}
	parsedKey, err := x509.ParsePKIXPublicKey(der)
	if err != nil {
		// ParsePKIXPublicKey rejects the bare PKCS#1 RSAPublicKey encoding;
		// accept it too, and report the PKIX error if neither parses.
		if pkcs1Key, pkcs1Err := x509.ParsePKCS1PublicKey(der); pkcs1Err == nil {
			return pkcs1Key, nil
		}
		return nil, fmt.Errorf("parsing decrypted public key failed: %w", err)
	}
	publicKey, ok := parsedKey.(*rsa.PublicKey)
	if !ok {
		// err is nil on this path, so report the unexpected key type rather
		// than wrapping a nil error (which printed "%!w(<nil>)").
		return nil, fmt.Errorf("parsing decrypted public key failed: expected *rsa.PublicKey, got %T", parsedKey)
	}
	return publicKey, nil
}
// readFile returns the full contents of the file at path, which is cleaned
// with filepath.Clean first. Open failures are wrapped with "opening file"
// and read failures with "reading file".
func readFile(path string) ([]byte, error) {
	f, openErr := os.Open(filepath.Clean(path))
	if openErr != nil {
		return nil, fmt.Errorf("opening file: %w", openErr)
	}
	defer f.Close()

	data, readErr := io.ReadAll(f)
	if readErr != nil {
		return nil, fmt.Errorf("reading file: %w", readErr)
	}
	return data, nil
}
// readPublicKey returns the raw bytes of the public key file at path,
// wrapping any failure from readFile with "reading file".
func readPublicKey(path string) ([]byte, error) {
	contents, readErr := readFile(path)
	if readErr == nil {
		return contents, nil
	}
	return nil, fmt.Errorf("reading file: %w", readErr)
}
// ReadPrivateKey loads an RSA private key from the PEM file at path.
//
// The PEM body may be PKCS#1 ("RSA PRIVATE KEY") or PKCS#8 ("PRIVATE KEY")
// as long as the key is RSA. A legacy-encrypted PEM block is decrypted with
// an empty password by decodePEMBlock.
func ReadPrivateKey(path string) (*rsa.PrivateKey, error) {
	b, err := readFile(path)
	if err != nil {
		return nil, fmt.Errorf("reading file: %w", err)
	}
	der, err := decodePEMBlock(b)
	if err != nil {
		return nil, fmt.Errorf("decode PEM block: %w", err)
	}
	parsedKey, err := x509.ParsePKCS1PrivateKey(der)
	if err == nil {
		return parsedKey, nil
	}
	// Fall back to PKCS#8, a common modern encoding, but only accept it
	// when it wraps an RSA key.
	if pkcs8Key, pkcs8Err := x509.ParsePKCS8PrivateKey(der); pkcs8Err == nil {
		rsaKey, ok := pkcs8Key.(*rsa.PrivateKey)
		if !ok {
			return nil, fmt.Errorf("parsing decrypted private key: expected RSA key, got %T", pkcs8Key)
		}
		return rsaKey, nil
	}
	return nil, fmt.Errorf("parsing decrypted private key: %w", err)
}
// decodePEMBlock returns the DER bytes of the first PEM block in block.
// Any data after that block is ignored.
//
// If the block uses legacy PEM encryption (x509.IsEncryptedPEMBlock), it is
// decrypted with an empty password. Decryption errors are wrapped with %w,
// so callers can inspect them with errors.Is/errors.As.
// NOTE(review): x509.DecryptPEMBlock is deprecated because legacy RFC 1423
// PEM encryption is insecure by design; prefer unencrypted or PKCS#8 keys.
func decodePEMBlock(block []byte) ([]byte, error) {
	decodedKey, _ := pem.Decode(block)
	if decodedKey == nil {
		return nil, fmt.Errorf("decoding PEM block failed")
	}
	//nolint:staticcheck,nolintlint
	if !x509.IsEncryptedPEMBlock(decodedKey) {
		return decodedKey.Bytes, nil
	}
	//nolint:staticcheck,nolintlint
	decrypted, err := x509.DecryptPEMBlock(decodedKey, []byte(""))
	if err != nil {
		return nil, fmt.Errorf("decrypting PEM key failed: %w", err)
	}
	return decrypted, nil
}