194 lines
4.4 KiB
Go
194 lines
4.4 KiB
Go
package jwt
|
|
|
|
import (
|
|
"crypto/rsa"
|
|
"crypto/x509"
|
|
"encoding/pem"
|
|
"fmt"
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"io"
|
|
"os"
|
|
"path/filepath"
|
|
"time"
|
|
)
|
|
|
|
func CreateRefreshTokenByAccess(accessClaims JWT, ttl time.Duration) JWT {
|
|
return JWT{
|
|
Type: "refresh",
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
ID: accessClaims.ID,
|
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(ttl)),
|
|
IssuedAt: jwt.NewNumericDate(time.Now()),
|
|
NotBefore: accessClaims.ExpiresAt,
|
|
},
|
|
}
|
|
}
|
|
|
|
func ParseRefreshToken(token string, publicKey *rsa.PublicKey) (*JWT, error) {
|
|
refreshDecodedClaims, err := Decode(token, &JWT{}, publicKey)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
refreshTokenClaims, ok := refreshDecodedClaims.(*JWT)
|
|
if !ok {
|
|
return nil, fmt.Errorf("invalid refresh token claims")
|
|
}
|
|
return refreshTokenClaims, nil
|
|
}
|
|
|
|
// JWT is the claim set used by the tokens this package encodes and decodes.
// It embeds jwt.RegisteredClaims, so a *JWT can be passed wherever jwt.Claims
// is expected (e.g. as the target of Decode).
//
// Field order is significant: encoding/json emits object keys in declaration
// order, so reordering fields would change the serialized (and signed) payload.
type JWT struct {
	// Type names the kind of token; "refresh" marks a refresh token
	// (see IsRefreshToken and CreateRefreshTokenByAccess).
	Type string `json:"t"`
	// SessionId optionally associates the token with a session; omitted from
	// the JSON payload when empty.
	SessionId string `json:"si,omitempty"`
	// AuthorizationInfo holds the encoded TokenAuthorizationInfo written by
	// WithAuthorizationInfo and read back by GetAuthorizationInfo; omitted from
	// the JSON payload when empty.
	AuthorizationInfo string `json:"ai,omitempty"`
	jwt.RegisteredClaims
}
|
|
|
|
func (j *JWT) WithTtl(ttl time.Duration) *JWT {
|
|
j.ExpiresAt = jwt.NewNumericDate(time.Now().Add(ttl))
|
|
j.IssuedAt = jwt.NewNumericDate(time.Now())
|
|
j.NotBefore = jwt.NewNumericDate(time.Now())
|
|
return j
|
|
}
|
|
|
|
func (j *JWT) WithId(id string) *JWT {
|
|
j.ID = id
|
|
return j
|
|
}
|
|
|
|
func (j *JWT) WithSessionId(sessionId string) *JWT {
|
|
j.SessionId = sessionId
|
|
return j
|
|
}
|
|
|
|
func (j *JWT) WithAuthorizationInfo(tai TokenAuthorizationInfo, secret string) *JWT {
|
|
ciphertext, err := tai.WithSecret(secret).Encode()
|
|
if err != nil {
|
|
return j
|
|
}
|
|
j.AuthorizationInfo = ciphertext
|
|
return j
|
|
}
|
|
|
|
func (j *JWT) GetAuthorizationInfo(secret string) *TokenAuthorizationInfo {
|
|
tai, err := DecodeTokenAuthorizationInfo(j.AuthorizationInfo, secret)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
|
|
return tai
|
|
}
|
|
|
|
func (j *JWT) IsRefreshToken() bool {
|
|
return j.Type == "refresh"
|
|
}
|
|
|
|
func Encode(j interface{}, privateKey *rsa.PrivateKey) (string, error) {
|
|
payload, ok := j.(jwt.Claims)
|
|
if !ok {
|
|
return "", fmt.Errorf("invalid jwt claims")
|
|
}
|
|
token := jwt.NewWithClaims(jwt.SigningMethodRS256, payload)
|
|
ss, err := token.SignedString(privateKey)
|
|
return ss, err
|
|
}
|
|
|
|
func Decode(token string, data jwt.Claims, publicKey *rsa.PublicKey, options ...jwt.ParserOption) (jwt.Claims, error) {
|
|
t, err := jwt.ParseWithClaims(token, data, func(token *jwt.Token) (interface{}, error) {
|
|
return publicKey, nil
|
|
}, options...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if claims, ok := t.Claims.(jwt.Claims); ok {
|
|
return claims, nil
|
|
}
|
|
|
|
return nil, fmt.Errorf("unknown claims type, cannot proceed")
|
|
}
|
|
|
|
func ReadPublicKey(path string) (*rsa.PublicKey, error) {
|
|
b, err := readPublicKey(path)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("reading file: %w", err)
|
|
}
|
|
|
|
decrypted, err := decodePEMBlock(b)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("decoding PEM block failed: %w", err)
|
|
}
|
|
|
|
parsedKey, err := x509.ParsePKIXPublicKey(decrypted)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parsing decrypted public key failed: %w", err)
|
|
}
|
|
|
|
publicKey, ok := parsedKey.(*rsa.PublicKey)
|
|
if !ok {
|
|
return nil, fmt.Errorf("parsing decrypted public key failed: %w", err)
|
|
}
|
|
|
|
return publicKey, nil
|
|
}
|
|
|
|
// readFile returns the entire contents of the file at path, after passing the
// path through filepath.Clean.
//
// Open and read failures are wrapped with distinct prefixes ("opening file",
// "reading file") so the failing step is visible in the error chain.
func readFile(path string) ([]byte, error) {
	f, openErr := os.Open(filepath.Clean(path))
	if openErr != nil {
		return nil, fmt.Errorf("opening file: %w", openErr)
	}
	defer func() {
		// The file is only read, so a Close error cannot lose data; it is
		// deliberately ignored.
		_ = f.Close()
	}()

	data, readErr := io.ReadAll(f)
	if readErr != nil {
		return nil, fmt.Errorf("reading file: %w", readErr)
	}
	return data, nil
}
|
|
|
|
func readPublicKey(path string) ([]byte, error) {
|
|
b, err := readFile(path)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("reading file: %w", err)
|
|
}
|
|
return b, nil
|
|
}
|
|
|
|
func ReadPrivateKey(path string) (*rsa.PrivateKey, error) {
|
|
b, err := readFile(path)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("reading file: %w", err)
|
|
}
|
|
|
|
decrypted, err := decodePEMBlock(b)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("decode PEM block: %w", err)
|
|
}
|
|
|
|
parsedKey, err := x509.ParsePKCS1PrivateKey(decrypted)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parsing decrypted private key: %w", err)
|
|
}
|
|
|
|
return parsedKey, nil
|
|
}
|
|
|
|
// decodePEMBlock returns the DER payload of the first PEM block in block;
// any data after that block is ignored.
//
// If the block carries legacy PEM encryption headers (as detected by
// x509.IsEncryptedPEMBlock), it is decrypted with an empty passphrase.
// Decryption errors are wrapped with %w rather than flattened with %s, so
// callers can still inspect them (e.g. errors.Is(err,
// x509.IncorrectPasswordError)).
func decodePEMBlock(block []byte) ([]byte, error) {
	decodedKey, _ := pem.Decode(block)
	if decodedKey == nil {
		return nil, fmt.Errorf("decoding PEM block failed")
	}

	//nolint:staticcheck,nolintlint
	if !x509.IsEncryptedPEMBlock(decodedKey) {
		return decodedKey.Bytes, nil
	}

	// NOTE(review): the passphrase is hard-coded to empty; encrypted keys with
	// a real passphrase will fail here — confirm this is intended.
	//nolint:staticcheck,nolintlint
	decrypted, err := x509.DecryptPEMBlock(decodedKey, []byte(""))
	if err != nil {
		return nil, fmt.Errorf("decrypting PEM key failed: %w", err)
	}
	return decrypted, nil
}
|