diff --git a/jwt.go b/jwt.go index 89eafb6..7a6ec5a 100644 --- a/jwt.go +++ b/jwt.go @@ -40,6 +40,9 @@ func ParseRefreshToken(token string, publicKey *rsa.PublicKey) (*JWT, error) { if !ok { return nil, fmt.Errorf("invalid refresh token claims") } + if refreshTokenClaims.Type != RefreshType { + return nil, fmt.Errorf("invalid token type") + } return refreshTokenClaims, nil } @@ -51,6 +54,10 @@ type JWT struct { jwt.RegisteredClaims } +func (j *JWT) GetJWT() *JWT { + return j +} + func (j *JWT) WithTtl(ttl time.Duration) *JWT { j.ExpiresAt = jwt.NewNumericDate(time.Now().Add(ttl)) j.IssuedAt = jwt.NewNumericDate(time.Now()) @@ -114,8 +121,37 @@ func (j *JWT) CheckAuthorizationInfo(secret string, tai TokenAuthorizationInfo) return j.GetAuthorizationInfo(secret).Equal(tai) } -func Encode(j interface{}, privateKey *rsa.PrivateKey) (string, error) { - payload, ok := j.(jwt.Claims) +type JWTInterface interface { + GetJWT() *JWT +} + +func ReNew[T interface{}](j T) (*T, error) { + newJwt, ok := interface{}(&j).(JWTInterface) + if !ok { + return nil, fmt.Errorf("invalid jwt claims") + } + + internalJwt := newJwt.GetJWT() + exp, _ := internalJwt.GetExpirationTime() + iss, _ := internalJwt.GetIssuedAt() + ttl := exp.Sub(iss.Time) + internalJwt.WithTtl(ttl) + + return &j, nil +} + +func Encode[T interface{}](j T, privateKey *rsa.PrivateKey) (string, error) { + newJwt, ok := interface{}(&j).(JWTInterface) + if !ok { + return "", fmt.Errorf("invalid jwt claims") + } + + internalJwt := newJwt.GetJWT() + if internalJwt.Type == "" { + internalJwt.Type = AccessType + } + + payload, ok := interface{}(&j).(jwt.Claims) if !ok { return "", fmt.Errorf("invalid jwt claims") } diff --git a/jwt_test.go b/jwt_test.go index 9d583c4..1b97319 100644 --- a/jwt_test.go +++ b/jwt_test.go @@ -1,6 +1,7 @@ package jwt import ( + "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/assert" "path/filepath" "testing" @@ -38,3 +39,21 @@ func TestJWT(t *testing.T) { assert.Equal(t, f.FirstName, "Igor") assert.Equal(t, f.LastName, "Sypachev") } + +func TestReNewJWT(t *testing.T) { + accessClaims := AccessTokenClaims{ + UserId: "123", + FirstName: "Igor", + LastName: "Sypachev", + } + accessClaims.ExpiresAt = jwt.NewNumericDate(time.Now()) + accessClaims.IssuedAt = jwt.NewNumericDate(time.Now().Add(-1 * time.Hour)) + accessClaims.NotBefore = jwt.NewNumericDate(time.Now()) + + assert.Equal(t, true, accessClaims.ExpiresAt.Time.Unix() <= time.Now().Unix()) + + newAccessClaims, err := ReNew(accessClaims) + assert.Equal(t, true, err == nil) + + assert.Equal(t, true, newAccessClaims.ExpiresAt.Time.Unix() > accessClaims.ExpiresAt.Time.Unix()) +}