262 lines
5.7 KiB
Go
262 lines
5.7 KiB
Go
package jwt
|
|
|
|
import (
|
|
"crypto/rsa"
|
|
"crypto/x509"
|
|
"encoding/pem"
|
|
"fmt"
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"io"
|
|
"os"
|
|
"path/filepath"
|
|
"time"
|
|
)
|
|
|
|
// Values stored in the "t" (type) claim of JWT.
const (
	// AccessType is the default type assigned by Encode when Type is empty;
	// CheckAccess requires it.
	AccessType = "access"
	// RefreshType is set by CreateRefreshTokenByAccess and required by
	// ParseRefreshToken.
	RefreshType = "refresh"
)

// TFAStep is the "step" claim value set by WithTFAStep. Any token carrying a
// non-empty step fails CheckAccess.
const TFAStep = "tfa"
|
|
|
|
func CreateRefreshTokenByAccess(accessClaims JWT, ttl time.Duration) JWT {
|
|
return JWT{
|
|
Type: RefreshType,
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
ID: accessClaims.ID,
|
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(ttl)),
|
|
IssuedAt: jwt.NewNumericDate(time.Now()),
|
|
NotBefore: accessClaims.ExpiresAt,
|
|
},
|
|
}
|
|
}
|
|
|
|
func ParseRefreshToken(token string, publicKey *rsa.PublicKey) (*JWT, error) {
|
|
refreshDecodedClaims, err := Decode(token, &JWT{}, publicKey)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
refreshTokenClaims, ok := refreshDecodedClaims.(*JWT)
|
|
if !ok {
|
|
return nil, fmt.Errorf("invalid refresh token claims")
|
|
}
|
|
if refreshTokenClaims.Type != RefreshType {
|
|
return nil, fmt.Errorf("invalid token type")
|
|
}
|
|
return refreshTokenClaims, nil
|
|
}
|
|
|
|
// JWT is the claim set carried by tokens from this package: a few custom,
// short-keyed claims plus the standard registered claims (exp, iat, nbf,
// jti, ...) embedded from jwt.RegisteredClaims.
//
// NOTE: encoding/json emits struct fields in declaration order, so reordering
// these fields changes the serialized token payload.
type JWT struct {
	// Type ("t") is AccessType or RefreshType. Encode fills in AccessType when
	// it is empty.
	Type string `json:"t"`
	// SessionId ("si") is set via WithSessionId.
	SessionId string `json:"si,omitempty"`
	// AuthorizationInfo ("ai") holds the string produced by
	// TokenAuthorizationInfo's Encode (see WithAuthorizationInfo). It is read
	// back with DecodeTokenAuthorizationInfo.
	AuthorizationInfo string `json:"ai,omitempty"`
	// Step ("step") is TFAStep while a two-factor step is pending. A non-empty
	// value makes CheckAccess fail.
	Step string `json:"step,omitempty"`
	jwt.RegisteredClaims
}
|
|
|
|
func (j *JWT) GetJWT() *JWT {
|
|
return j
|
|
}
|
|
|
|
func (j *JWT) WithTtl(ttl time.Duration) *JWT {
|
|
j.ExpiresAt = jwt.NewNumericDate(time.Now().Add(ttl))
|
|
j.IssuedAt = jwt.NewNumericDate(time.Now())
|
|
j.NotBefore = jwt.NewNumericDate(time.Now())
|
|
return j
|
|
}
|
|
|
|
func (j *JWT) WithId(id string) *JWT {
|
|
j.ID = id
|
|
return j
|
|
}
|
|
|
|
func (j *JWT) WithSessionId(sessionId string) *JWT {
|
|
j.SessionId = sessionId
|
|
return j
|
|
}
|
|
|
|
func (j *JWT) WithAuthorizationInfo(tai TokenAuthorizationInfo, secret string) *JWT {
|
|
ciphertext, err := tai.WithSecret(secret).Encode()
|
|
if err != nil {
|
|
return j
|
|
}
|
|
j.AuthorizationInfo = ciphertext
|
|
return j
|
|
}
|
|
|
|
func (j *JWT) WithTFAStep() *JWT {
|
|
j.Step = TFAStep
|
|
return j
|
|
}
|
|
|
|
func (j *JWT) GetAuthorizationInfo(secret string) *TokenAuthorizationInfo {
|
|
tai, err := DecodeTokenAuthorizationInfo(j.AuthorizationInfo, secret)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
|
|
return tai
|
|
}
|
|
|
|
func (j *JWT) IsRefreshToken() bool {
|
|
return j.Type == RefreshType
|
|
}
|
|
|
|
func (j *JWT) IsTFAStep() bool {
|
|
return j.Step == TFAStep
|
|
}
|
|
|
|
func (j *JWT) CheckAccess() bool {
|
|
if j.Type != AccessType {
|
|
return false
|
|
}
|
|
if j.Step != "" {
|
|
return false
|
|
}
|
|
|
|
return true
|
|
}
|
|
|
|
func (j *JWT) CheckAuthorizationInfo(secret string, tai TokenAuthorizationInfo) bool {
|
|
return j.GetAuthorizationInfo(secret).Equal(tai)
|
|
}
|
|
|
|
// JWTInterface is implemented by *JWT. It is also implemented by pointers to
// structs that embed JWT, through the promoted GetJWT method. Encode and ReNew
// use it to reach the embedded claims of a generic value.
type JWTInterface interface {
	// GetJWT returns the embedded claims, which callers may modify in place.
	GetJWT() *JWT
}
|
|
|
|
func ReNew[T interface{}](j T) (*T, error) {
|
|
newJwt, ok := interface{}(&j).(JWTInterface)
|
|
if !ok {
|
|
return nil, fmt.Errorf("invalid jwt claims")
|
|
}
|
|
|
|
internalJwt := newJwt.GetJWT()
|
|
exp, _ := internalJwt.GetExpirationTime()
|
|
iss, _ := internalJwt.GetIssuedAt()
|
|
ttl := exp.Sub(iss.Time)
|
|
internalJwt.WithTtl(ttl)
|
|
|
|
return &j, nil
|
|
}
|
|
|
|
func Encode[T interface{}](j T, privateKey *rsa.PrivateKey) (string, error) {
|
|
newJwt, ok := interface{}(&j).(JWTInterface)
|
|
if !ok {
|
|
return "", fmt.Errorf("invalid jwt claims")
|
|
}
|
|
|
|
internalJwt := newJwt.GetJWT()
|
|
if internalJwt.Type == "" {
|
|
internalJwt.Type = AccessType
|
|
}
|
|
|
|
payload, ok := interface{}(&j).(jwt.Claims)
|
|
if !ok {
|
|
return "", fmt.Errorf("invalid jwt claims")
|
|
}
|
|
token := jwt.NewWithClaims(jwt.SigningMethodRS256, payload)
|
|
ss, err := token.SignedString(privateKey)
|
|
return ss, err
|
|
}
|
|
|
|
func Decode(token string, data jwt.Claims, publicKey *rsa.PublicKey, options ...jwt.ParserOption) (jwt.Claims, error) {
|
|
t, err := jwt.ParseWithClaims(token, data, func(token *jwt.Token) (interface{}, error) {
|
|
return publicKey, nil
|
|
}, options...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if claims, ok := t.Claims.(jwt.Claims); ok {
|
|
return claims, nil
|
|
}
|
|
|
|
return nil, fmt.Errorf("unknown claims type, cannot proceed")
|
|
}
|
|
|
|
func ReadPublicKey(path string) (*rsa.PublicKey, error) {
|
|
b, err := readPublicKey(path)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("reading file: %w", err)
|
|
}
|
|
|
|
decrypted, err := decodePEMBlock(b)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("decoding PEM block failed: %w", err)
|
|
}
|
|
|
|
parsedKey, err := x509.ParsePKIXPublicKey(decrypted)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parsing decrypted public key failed: %w", err)
|
|
}
|
|
|
|
publicKey, ok := parsedKey.(*rsa.PublicKey)
|
|
if !ok {
|
|
return nil, fmt.Errorf("parsing decrypted public key failed: %w", err)
|
|
}
|
|
|
|
return publicKey, nil
|
|
}
|
|
|
|
func ReadPrivateKey(path string) (*rsa.PrivateKey, error) {
|
|
b, err := readFile(path)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("reading file: %w", err)
|
|
}
|
|
|
|
decrypted, err := decodePEMBlock(b)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("decode PEM block: %w", err)
|
|
}
|
|
|
|
parsedKey, err := x509.ParsePKCS1PrivateKey(decrypted)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parsing decrypted private key: %w", err)
|
|
}
|
|
|
|
return parsedKey, nil
|
|
}
|
|
|
|
// readFile returns the full contents of the file at path. The path is passed
// through filepath.Clean first. Open failures are wrapped with
// "opening file" and read failures with "reading file".
func readFile(path string) ([]byte, error) {
	f, openErr := os.Open(filepath.Clean(path))
	if openErr != nil {
		return nil, fmt.Errorf("opening file: %w", openErr)
	}
	defer f.Close()

	contents, readErr := io.ReadAll(f)
	if readErr != nil {
		return nil, fmt.Errorf("reading file: %w", readErr)
	}
	return contents, nil
}
|
|
|
|
func readPublicKey(path string) ([]byte, error) {
|
|
b, err := readFile(path)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("reading file: %w", err)
|
|
}
|
|
return b, nil
|
|
}
|
|
|
|
// decodePEMBlock returns the DER bytes of the first PEM block in block. Any
// data after that first block is ignored.
//
// If the block carries a legacy RFC 1423 encryption header (DEK-Info), it is
// decrypted with an empty password. A key encrypted with a non-empty password
// cannot be loaded this way: decryption either errors or yields bytes that
// fail to parse later.
//
// The decryption error is wrapped with %w, so callers can inspect it with
// errors.Is or errors.As (e.g. x509.IncorrectPasswordError).
func decodePEMBlock(block []byte) ([]byte, error) {
	decodedKey, _ := pem.Decode(block)
	if decodedKey == nil {
		return nil, fmt.Errorf("decoding PEM block failed")
	}

	decrypted := decodedKey.Bytes

	// RFC 1423 PEM encryption is deprecated and insecure. It is supported
	// here only for compatibility with existing key files.
	//nolint:staticcheck,nolintlint
	if x509.IsEncryptedPEMBlock(decodedKey) {
		var err error
		//nolint:staticcheck,nolintlint
		if decrypted, err = x509.DecryptPEMBlock(decodedKey, []byte("")); err != nil {
			return nil, fmt.Errorf("decrypting PEM key failed: %w", err)
		}
	}

	return decrypted, nil
}
|