go-jwt/jwt.go
2025-03-01 16:46:24 +07:00

262 lines
5.7 KiB
Go

package jwt
import (
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"fmt"
"github.com/golang-jwt/jwt/v5"
"io"
"os"
"path/filepath"
"time"
)
// Token kinds stored in the JWT.Type claim, plus the marker stored in
// JWT.Step while a token is parked in the two-factor step.
const (
	// AccessType marks an access token; Encode assigns it when Type is empty.
	AccessType  = "access"
	// RefreshType marks a refresh token (see CreateRefreshTokenByAccess).
	RefreshType = "refresh"
	// TFAStep is the Step value set by WithTFAStep.
	TFAStep     = "tfa"
)
// CreateRefreshTokenByAccess builds refresh-token claims paired with the given
// access-token claims.
//
// The refresh token reuses the access token's ID, expires ttl from now, and
// has NotBefore set to the access token's ExpiresAt. Validators that enforce
// nbf therefore reject it until the access token has expired. Only Type and
// the registered claims are populated; SessionId, AuthorizationInfo and Step
// stay empty.
func CreateRefreshTokenByAccess(accessClaims JWT, ttl time.Duration) JWT {
	// Derive ExpiresAt and IssuedAt from one instant. ReNew recomputes the
	// lifetime as ExpiresAt-IssuedAt, and two separate time.Now calls could
	// straddle a second boundary once NumericDate truncates them.
	now := time.Now()
	return JWT{
		Type: RefreshType,
		RegisteredClaims: jwt.RegisteredClaims{
			ID:        accessClaims.ID,
			ExpiresAt: jwt.NewNumericDate(now.Add(ttl)),
			IssuedAt:  jwt.NewNumericDate(now),
			NotBefore: accessClaims.ExpiresAt,
		},
	}
}
// ParseRefreshToken verifies token with publicKey (via Decode) and returns its
// claims. Tokens whose Type is not RefreshType are rejected, so an access
// token cannot be used as a refresh token.
func ParseRefreshToken(token string, publicKey *rsa.PublicKey) (*JWT, error) {
	decoded, err := Decode(token, &JWT{}, publicKey)
	if err != nil {
		return nil, err
	}

	claims, isJWT := decoded.(*JWT)
	switch {
	case !isJWT:
		return nil, fmt.Errorf("invalid refresh token claims")
	case claims.Type != RefreshType:
		return nil, fmt.Errorf("invalid token type")
	default:
		return claims, nil
	}
}
// JWT is the claim set used by this package. It holds short-keyed custom
// claims plus the embedded standard registered claims. The registered claims
// are embedded, so encoding/json serializes them at the top level next to the
// custom keys.
type JWT struct {
	// Type ("t") is the token kind: AccessType or RefreshType in this file.
	Type              string `json:"t"`
	// SessionId ("si") is set via WithSessionId; omitted from JSON when empty.
	SessionId         string `json:"si,omitempty"`
	// AuthorizationInfo ("ai") holds the string produced by
	// TokenAuthorizationInfo.Encode in WithAuthorizationInfo; omitted when empty.
	AuthorizationInfo string `json:"ai,omitempty"`
	// Step ("step") marks an intermediate state. TFAStep is the only value set
	// in this file, and CheckAccess rejects any non-empty Step.
	Step              string `json:"step,omitempty"`
	jwt.RegisteredClaims
}
// GetJWT returns the receiver itself. It makes *JWT, and any struct that
// embeds JWT, satisfy JWTInterface, which Encode and ReNew use to reach the
// inner claims of a generic value.
func (j *JWT) GetJWT() *JWT {
	return j
}
// WithTtl makes the token valid from now for ttl. It sets IssuedAt and
// NotBefore to the current time and ExpiresAt to now+ttl, then returns j for
// chaining.
func (j *JWT) WithTtl(ttl time.Duration) *JWT {
	// Use a single instant so exp, iat and nbf stay mutually consistent.
	// ReNew derives the lifetime from ExpiresAt-IssuedAt.
	now := time.Now()
	j.ExpiresAt = jwt.NewNumericDate(now.Add(ttl))
	j.IssuedAt = jwt.NewNumericDate(now)
	j.NotBefore = jwt.NewNumericDate(now)
	return j
}
// WithId sets the registered ID claim and returns j for chaining.
func (j *JWT) WithId(id string) *JWT {
	j.ID = id
	return j
}

// WithSessionId sets the SessionId ("si") claim and returns j for chaining.
func (j *JWT) WithSessionId(sessionId string) *JWT {
	j.SessionId = sessionId
	return j
}
// WithAuthorizationInfo encodes tai with secret and stores the result in the
// AuthorizationInfo ("ai") claim. It returns j for chaining.
//
// NOTE(review): if encoding fails, the error is swallowed and any previously
// stored AuthorizationInfo is left unchanged; callers cannot detect this.
func (j *JWT) WithAuthorizationInfo(tai TokenAuthorizationInfo, secret string) *JWT {
	encoded, encodeErr := tai.WithSecret(secret).Encode()
	if encodeErr == nil {
		j.AuthorizationInfo = encoded
	}
	return j
}
// WithTFAStep marks the token as being in the two-factor step (Step =
// TFAStep) and returns j for chaining. While Step is set, CheckAccess
// returns false.
func (j *JWT) WithTFAStep() *JWT {
	j.Step = TFAStep
	return j
}
// GetAuthorizationInfo decodes the AuthorizationInfo claim using secret. It
// returns nil if DecodeTokenAuthorizationInfo reports an error, and that
// error is discarded.
func (j *JWT) GetAuthorizationInfo(secret string) *TokenAuthorizationInfo {
	if decoded, err := DecodeTokenAuthorizationInfo(j.AuthorizationInfo, secret); err == nil {
		return decoded
	}
	return nil
}
// IsRefreshToken reports whether the token's Type is RefreshType.
func (j *JWT) IsRefreshToken() bool {
	return j.Type == RefreshType
}

// IsTFAStep reports whether the token is parked in the two-factor step
// (Step == TFAStep).
func (j *JWT) IsTFAStep() bool {
	return j.Step == TFAStep
}
// CheckAccess reports whether the token may be used for access. That requires
// an access token (Type == AccessType) with no pending intermediate Step, such
// as the two-factor step.
func (j *JWT) CheckAccess() bool {
	return j.Type == AccessType && j.Step == ""
}
// CheckAuthorizationInfo reports whether the token's AuthorizationInfo claim,
// decoded with secret, equals tai.
//
// If the claim cannot be decoded (GetAuthorizationInfo returns nil), the
// check fails closed and returns false instead of calling Equal on a nil
// pointer.
func (j *JWT) CheckAuthorizationInfo(secret string, tai TokenAuthorizationInfo) bool {
	decoded := j.GetAuthorizationInfo(secret)
	if decoded == nil {
		return false
	}
	return decoded.Equal(tai)
}
// JWTInterface is implemented by *JWT and by pointers to structs that embed
// JWT. Encode and ReNew use it to reach the inner *JWT of a generic claims
// value.
type JWTInterface interface {
	GetJWT() *JWT
}
// ReNew returns a renewed copy of j. IssuedAt and NotBefore are reset to now,
// and ExpiresAt is moved so the token keeps its original lifetime
// (ExpiresAt - IssuedAt).
//
// T must be a type whose pointer implements JWTInterface, such as JWT or a
// struct embedding JWT by value; otherwise an error is returned. j is received
// by value, so the caller's claims are not modified. The exception is a T
// that embeds *JWT: the copy shares that pointer, so the original changes too.
//
// An error is returned if the claims lack either the exp or the iat claim,
// because the lifetime cannot be derived without both.
func ReNew[T any](j T) (*T, error) {
	holder, ok := any(&j).(JWTInterface)
	if !ok {
		return nil, fmt.Errorf("invalid jwt claims")
	}
	internalJwt := holder.GetJWT()
	if internalJwt == nil {
		// Happens when T embeds a nil *JWT.
		return nil, fmt.Errorf("invalid jwt claims")
	}

	exp, err := internalJwt.GetExpirationTime()
	if err != nil {
		return nil, fmt.Errorf("reading expiration time: %w", err)
	}
	iss, err := internalJwt.GetIssuedAt()
	if err != nil {
		return nil, fmt.Errorf("reading issued-at time: %w", err)
	}
	// Unset claims come back as nil pointers; dereferencing them would panic.
	if exp == nil || iss == nil {
		return nil, fmt.Errorf("cannot renew jwt without exp and iat claims")
	}

	internalJwt.WithTtl(exp.Sub(iss.Time))
	return &j, nil
}
// Encode signs j with privateKey using RS256 and returns the compact token.
//
// *T must implement both JWTInterface and jwt.Claims, e.g. JWT or a struct
// embedding JWT; otherwise an "invalid jwt claims" error is returned. If the
// inner Type claim is empty, it defaults to AccessType. Only the local copy of
// j is changed.
func Encode[T interface{}](j T, privateKey *rsa.PrivateKey) (string, error) {
	ptr := interface{}(&j)

	holder, isJWT := ptr.(JWTInterface)
	if !isJWT {
		return "", fmt.Errorf("invalid jwt claims")
	}
	if inner := holder.GetJWT(); inner.Type == "" {
		inner.Type = AccessType
	}

	claims, isClaims := ptr.(jwt.Claims)
	if !isClaims {
		return "", fmt.Errorf("invalid jwt claims")
	}
	return jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(privateKey)
}
// Decode parses and verifies token with publicKey and unmarshals its claims
// into data. The claims are returned on success. Extra parser options are
// forwarded to jwt.ParseWithClaims.
//
// Only RSA PKCS#1 v1.5 signatures (RS256/RS384/RS512) are accepted. This
// matches Encode, which always signs with RS256, and rejects other algorithms
// such as PSS or HMAC before any key is handed out. A nil publicKey is
// rejected up front instead of panicking inside signature verification.
func Decode(token string, data jwt.Claims, publicKey *rsa.PublicKey, options ...jwt.ParserOption) (jwt.Claims, error) {
	if publicKey == nil {
		return nil, fmt.Errorf("nil public key")
	}
	t, err := jwt.ParseWithClaims(token, data, func(tok *jwt.Token) (interface{}, error) {
		if _, ok := tok.Method.(*jwt.SigningMethodRSA); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", tok.Header["alg"])
		}
		return publicKey, nil
	}, options...)
	if err != nil {
		return nil, err
	}
	// t.Claims already has type jwt.Claims, so only a nil value needs rejecting.
	if t.Claims == nil {
		return nil, fmt.Errorf("unknown claims type, cannot proceed")
	}
	return t.Claims, nil
}
// ReadPublicKey loads an RSA public key from the PEM file at path.
//
// The PEM payload is parsed as PKIX ("PUBLIC KEY") first. If that fails,
// PKCS#1 ("RSA PUBLIC KEY") is tried. Encrypted PEM blocks are handled by
// decodePEMBlock. An error is returned if the file cannot be read, the PEM is
// invalid, or the key is not RSA.
func ReadPublicKey(path string) (*rsa.PublicKey, error) {
	b, err := readPublicKey(path)
	if err != nil {
		return nil, fmt.Errorf("reading file: %w", err)
	}
	der, err := decodePEMBlock(b)
	if err != nil {
		return nil, fmt.Errorf("decoding PEM block failed: %w", err)
	}
	parsedKey, err := x509.ParsePKIXPublicKey(der)
	if err != nil {
		// Fallback for bare PKCS#1 keys; report the PKIX error if both fail.
		if pkcs1Key, pkcs1Err := x509.ParsePKCS1PublicKey(der); pkcs1Err == nil {
			return pkcs1Key, nil
		}
		return nil, fmt.Errorf("parsing decrypted public key failed: %w", err)
	}
	publicKey, ok := parsedKey.(*rsa.PublicKey)
	if !ok {
		// err is nil here, so describe the actual key type instead of wrapping it.
		return nil, fmt.Errorf("parsing decrypted public key failed: expected *rsa.PublicKey, got %T", parsedKey)
	}
	return publicKey, nil
}
// ReadPrivateKey loads an RSA private key from the PEM file at path.
//
// The PEM payload is parsed as PKCS#1 ("RSA PRIVATE KEY") first. If that
// fails, it falls back to PKCS#8 ("PRIVATE KEY", the default OpenSSL 3
// output), provided the contained key is RSA. Encrypted PEM blocks are
// handled by decodePEMBlock. On failure, the PKCS#1 parse error is returned
// wrapped.
func ReadPrivateKey(path string) (*rsa.PrivateKey, error) {
	b, err := readFile(path)
	if err != nil {
		return nil, fmt.Errorf("reading file: %w", err)
	}
	der, err := decodePEMBlock(b)
	if err != nil {
		return nil, fmt.Errorf("decode PEM block: %w", err)
	}
	parsedKey, err := x509.ParsePKCS1PrivateKey(der)
	if err == nil {
		return parsedKey, nil
	}
	if anyKey, pkcs8Err := x509.ParsePKCS8PrivateKey(der); pkcs8Err == nil {
		if rsaKey, ok := anyKey.(*rsa.PrivateKey); ok {
			return rsaKey, nil
		}
	}
	return nil, fmt.Errorf("parsing decrypted private key: %w", err)
}
// readFile returns the full contents of the file at path after cleaning it
// with filepath.Clean. A failure to open is wrapped as "opening file: ..."
// and a failure to read as "reading file: ...".
func readFile(path string) ([]byte, error) {
	f, openErr := os.Open(filepath.Clean(path))
	if openErr != nil {
		return nil, fmt.Errorf("opening file: %w", openErr)
	}
	// The handle is read-only, so a Close error says nothing about the data.
	defer func() { _ = f.Close() }()

	content, readErr := io.ReadAll(f)
	if readErr != nil {
		return nil, fmt.Errorf("reading file: %w", readErr)
	}
	return content, nil
}
// readPublicKey returns the raw bytes of the public-key file at path. It is a
// thin wrapper around readFile that adds a "reading file: " prefix to errors.
func readPublicKey(path string) ([]byte, error) {
	data, err := readFile(path)
	if err == nil {
		return data, nil
	}
	return nil, fmt.Errorf("reading file: %w", err)
}
// decodePEMBlock returns the DER payload of the first PEM block in block. Any
// data after that first block is ignored.
//
// Legacy RFC 1423 encrypted blocks (detected by x509.IsEncryptedPEMBlock) are
// decrypted with an empty password. Such keys therefore only load if they
// were encrypted with an empty password. Decryption errors are wrapped with
// %w so callers can inspect them, e.g.
// errors.Is(err, x509.IncorrectPasswordError).
func decodePEMBlock(block []byte) ([]byte, error) {
	decodedKey, _ := pem.Decode(block)
	if decodedKey == nil {
		return nil, fmt.Errorf("decoding PEM block failed")
	}
	//nolint:staticcheck,nolintlint // deprecated API kept for legacy encrypted PEM support
	if !x509.IsEncryptedPEMBlock(decodedKey) {
		return decodedKey.Bytes, nil
	}
	//nolint:staticcheck,nolintlint // deprecated API kept for legacy encrypted PEM support
	decrypted, err := x509.DecryptPEMBlock(decodedKey, []byte(""))
	if err != nil {
		return nil, fmt.Errorf("decrypting PEM key failed: %w", err)
	}
	return decrypted, nil
}