package jwt import ( "crypto/rsa" "crypto/x509" "encoding/pem" "fmt" "github.com/golang-jwt/jwt/v5" "io" "os" "path/filepath" "time" ) func CreateRefreshTokenByAccess(accessClaims JWT, ttl time.Duration) JWT { return JWT{ Type: "refresh", RegisteredClaims: jwt.RegisteredClaims{ ID: accessClaims.ID, ExpiresAt: jwt.NewNumericDate(time.Now().Add(ttl)), IssuedAt: jwt.NewNumericDate(time.Now()), NotBefore: accessClaims.ExpiresAt, }, } } func ParseRefreshToken(token string, publicKey *rsa.PublicKey) (*JWT, error) { refreshDecodedClaims, err := Decode(token, &JWT{}, publicKey) if err != nil { return nil, err } refreshTokenClaims, ok := refreshDecodedClaims.(*JWT) if !ok { return nil, fmt.Errorf("invalid refresh token claims") } return refreshTokenClaims, nil } type JWT struct { Type string `json:"t"` SessionId string `json:"si,omitempty"` AuthorizationInfo string `json:"ai,omitempty"` jwt.RegisteredClaims } func (j *JWT) WithTtl(ttl time.Duration) *JWT { j.ExpiresAt = jwt.NewNumericDate(time.Now().Add(ttl)) j.IssuedAt = jwt.NewNumericDate(time.Now()) j.NotBefore = jwt.NewNumericDate(time.Now()) return j } func (j *JWT) WithId(id string) *JWT { j.ID = id return j } func (j *JWT) WithSessionId(sessionId string) *JWT { j.SessionId = sessionId return j } func (j *JWT) WithAuthorizationInfo(tai TokenAuthorizationInfo, secret string) *JWT { ciphertext, err := tai.WithSecret(secret).Encode() if err != nil { return j } j.AuthorizationInfo = ciphertext return j } func (j *JWT) GetAuthorizationInfo(secret string) *TokenAuthorizationInfo { tai, err := DecodeTokenAuthorizationInfo(j.AuthorizationInfo, secret) if err != nil { return nil } return tai } func (j *JWT) IsRefreshToken() bool { return j.Type == "refresh" } func Encode(j interface{}, privateKey *rsa.PrivateKey) (string, error) { payload, ok := j.(jwt.Claims) if !ok { return "", fmt.Errorf("invalid jwt claims") } token := jwt.NewWithClaims(jwt.SigningMethodRS256, payload) ss, err := token.SignedString(privateKey) return ss, err } func Decode(token string, data jwt.Claims, publicKey *rsa.PublicKey, options ...jwt.ParserOption) (jwt.Claims, error) { t, err := jwt.ParseWithClaims(token, data, func(token *jwt.Token) (interface{}, error) { return publicKey, nil }, options...) if err != nil { return nil, err } if claims, ok := t.Claims.(jwt.Claims); ok { return claims, nil } return nil, fmt.Errorf("unknown claims type, cannot proceed") } func ReadPublicKey(path string) (*rsa.PublicKey, error) { b, err := readPublicKey(path) if err != nil { return nil, fmt.Errorf("reading file: %w", err) } decrypted, err := decodePEMBlock(b) if err != nil { return nil, fmt.Errorf("decoding PEM block failed: %w", err) } parsedKey, err := x509.ParsePKIXPublicKey(decrypted) if err != nil { return nil, fmt.Errorf("parsing decrypted public key failed: %w", err) } publicKey, ok := parsedKey.(*rsa.PublicKey) if !ok { return nil, fmt.Errorf("parsing decrypted public key failed: %w", err) } return publicKey, nil } func readFile(path string) ([]byte, error) { file, err := os.Open(filepath.Clean(path)) if err != nil { return nil, fmt.Errorf("opening file: %w", err) } defer file.Close() b, err := io.ReadAll(file) if err != nil { return nil, fmt.Errorf("reading file: %w", err) } return b, nil } func readPublicKey(path string) ([]byte, error) { b, err := readFile(path) if err != nil { return nil, fmt.Errorf("reading file: %w", err) } return b, nil } func ReadPrivateKey(path string) (*rsa.PrivateKey, error) { b, err := readFile(path) if err != nil { return nil, fmt.Errorf("reading file: %w", err) } decrypted, err := decodePEMBlock(b) if err != nil { return nil, fmt.Errorf("decode PEM block: %w", err) } parsedKey, err := x509.ParsePKCS1PrivateKey(decrypted) if err != nil { return nil, fmt.Errorf("parsing decrypted private key: %w", err) } return parsedKey, nil } func decodePEMBlock(block []byte) ([]byte, error) { decodedKey, _ := pem.Decode(block) if decodedKey == nil { return nil, fmt.Errorf("decoding PEM block failed") } var ( decrypted = decodedKey.Bytes err error ) //nolint:staticcheck,nolintlint if x509.IsEncryptedPEMBlock(decodedKey) { //nolint:staticcheck,nolintlint if decrypted, err = x509.DecryptPEMBlock(decodedKey, []byte("")); err != nil { return nil, fmt.Errorf("decrypting PEM key failed: %s", err) } } return decrypted, nil }