feat(jwt): add helpers

This commit is contained in:
Sipachev Igor 2025-02-24 20:52:14 +07:00
parent d2f7826eb3
commit e21bf927b5

71
jwt.go
View File

@ -12,9 +12,14 @@ import (
"time" "time"
) )
const (
AccessType = "access"
RefreshType = "refresh"
)
func CreateRefreshTokenByAccess(accessClaims JWT, ttl time.Duration) JWT { func CreateRefreshTokenByAccess(accessClaims JWT, ttl time.Duration) JWT {
return JWT{ return JWT{
Type: "refresh", Type: RefreshType,
RegisteredClaims: jwt.RegisteredClaims{ RegisteredClaims: jwt.RegisteredClaims{
ID: accessClaims.ID, ID: accessClaims.ID,
ExpiresAt: jwt.NewNumericDate(time.Now().Add(ttl)), ExpiresAt: jwt.NewNumericDate(time.Now().Add(ttl)),
@ -70,6 +75,11 @@ func (j *JWT) WithAuthorizationInfo(tai TokenAuthorizationInfo, secret string) *
return j return j
} }
func (j *JWT) WithTFAStep(tai TokenAuthorizationInfo, secret string) *JWT {
j.Step = "tfa"
return j
}
func (j *JWT) GetAuthorizationInfo(secret string) *TokenAuthorizationInfo { func (j *JWT) GetAuthorizationInfo(secret string) *TokenAuthorizationInfo {
tai, err := DecodeTokenAuthorizationInfo(j.AuthorizationInfo, secret) tai, err := DecodeTokenAuthorizationInfo(j.AuthorizationInfo, secret)
if err != nil { if err != nil {
@ -80,7 +90,26 @@ func (j *JWT) GetAuthorizationInfo(secret string) *TokenAuthorizationInfo {
} }
func (j *JWT) IsRefreshToken() bool { func (j *JWT) IsRefreshToken() bool {
return j.Type == "refresh" return j.Type == RefreshType
}
func (j *JWT) IsTFAStep() bool {
return j.Step == "tfa"
}
func (j *JWT) CheckAccess() bool {
if j.Type != AccessType {
return false
}
if j.Step != "" {
return false
}
return true
}
func (j *JWT) CheckAuthorizationInfo(secret string, tai TokenAuthorizationInfo) bool {
return j.GetAuthorizationInfo(secret).Equal(tai)
} }
func Encode(j interface{}, privateKey *rsa.PrivateKey) (string, error) { func Encode(j interface{}, privateKey *rsa.PrivateKey) (string, error) {
@ -132,6 +161,25 @@ func ReadPublicKey(path string) (*rsa.PublicKey, error) {
return publicKey, nil return publicKey, nil
} }
func ReadPrivateKey(path string) (*rsa.PrivateKey, error) {
b, err := readFile(path)
if err != nil {
return nil, fmt.Errorf("reading file: %w", err)
}
decrypted, err := decodePEMBlock(b)
if err != nil {
return nil, fmt.Errorf("decode PEM block: %w", err)
}
parsedKey, err := x509.ParsePKCS1PrivateKey(decrypted)
if err != nil {
return nil, fmt.Errorf("parsing decrypted private key: %w", err)
}
return parsedKey, nil
}
func readFile(path string) ([]byte, error) { func readFile(path string) ([]byte, error) {
file, err := os.Open(filepath.Clean(path)) file, err := os.Open(filepath.Clean(path))
if err != nil { if err != nil {
@ -154,25 +202,6 @@ func readPublicKey(path string) ([]byte, error) {
return b, nil return b, nil
} }
func ReadPrivateKey(path string) (*rsa.PrivateKey, error) {
b, err := readFile(path)
if err != nil {
return nil, fmt.Errorf("reading file: %w", err)
}
decrypted, err := decodePEMBlock(b)
if err != nil {
return nil, fmt.Errorf("decode PEM block: %w", err)
}
parsedKey, err := x509.ParsePKCS1PrivateKey(decrypted)
if err != nil {
return nil, fmt.Errorf("parsing decrypted private key: %w", err)
}
return parsedKey, nil
}
func decodePEMBlock(block []byte) ([]byte, error) { func decodePEMBlock(block []byte) ([]byte, error) {
decodedKey, _ := pem.Decode(block) decodedKey, _ := pem.Decode(block)
if decodedKey == nil { if decodedKey == nil {