/*
 *  Licensed to the Apache Software Foundation (ASF) under one
 *  or more contributor license agreements.  See the NOTICE file
 *  distributed with this work for additional information
 *  regarding copyright ownership.  The ASF licenses this file
 *  to you under the Apache License, Version 2.0 (the
 *  "License"); you may not use this file except in compliance
 *  with the License.  You may obtain a copy of the License at
 *
 *    https://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing,
 *  software distributed under the License is distributed on an
 *  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 *  KIND, either express or implied.  See the License for the
 *  specific language governing permissions and limitations
 *  under the License.
 */
package grails.plugin.springsecurity.rest.token.generation

import com.nimbusds.jose.EncryptionMethod
import com.nimbusds.jose.JWEAlgorithm
import com.nimbusds.jose.JWSAlgorithm
import com.nimbusds.jose.crypto.RSADecrypter
import com.nimbusds.jwt.EncryptedJWT
import com.nimbusds.jwt.JWT
import com.nimbusds.jwt.JWTClaimsSet
import com.nimbusds.jwt.JWTParser
import grails.plugin.springsecurity.rest.JwtService
import grails.plugin.springsecurity.rest.TokenGeneratorSupport
import grails.plugin.springsecurity.rest.token.AccessToken
import grails.plugin.springsecurity.rest.token.generation.jwt.AbstractJwtTokenGenerator
import grails.plugin.springsecurity.rest.token.generation.jwt.CustomClaimProvider
import grails.plugin.springsecurity.rest.token.generation.jwt.DefaultRSAKeyProvider
import grails.plugin.springsecurity.rest.token.generation.jwt.EncryptedJwtTokenGenerator
import grails.plugin.springsecurity.rest.token.generation.jwt.IssuerClaimProvider
import grails.plugin.springsecurity.rest.token.generation.jwt.SignedJwtTokenGenerator
import grails.plugin.springsecurity.rest.token.storage.jwt.JwtTokenStorageService
import grails.spring.BeanBuilder
import org.springframework.security.core.authority.SimpleGrantedAuthority
import org.springframework.security.core.userdetails.User
import org.springframework.security.core.userdetails.UserDetails
import spock.lang.Issue
import spock.lang.Specification
import spock.lang.Unroll

/**
 * Specification for the JWT access/refresh token generators
 * ({@link SignedJwtTokenGenerator} and {@link EncryptedJwtTokenGenerator}).
 *
 * Most features are data-driven over both generator flavours, built by the
 * {@code setupSignedJwtTokenGenerator()} / {@code setupEncryptedJwtTokenGenerator()} helpers,
 * which this spec does not define itself (presumably inherited from {@link TokenGeneratorSupport}).
 */
class JwtTokenGeneratorSpec extends Specification implements TokenGeneratorSupport  {

    @Unroll
    void "#jwtTokenGenerator.class.simpleName generates access tokens with refresh tokens that can be rehydrated back"() {
        given:
        UserDetails userDetails = new User('username', 'password', [new SimpleGrantedAuthority('ROLE_USER')])

        when:
        AccessToken accessToken = jwtTokenGenerator.generateAccessToken(userDetails)

        then:
        accessToken.accessToken
        accessToken.refreshToken

        when:
        UserDetails parsedUserDetails = jwtTokenGenerator.jwtTokenStorageService.loadUserByToken(accessToken.accessToken)

        then:
        parsedUserDetails == userDetails

        where:
        jwtTokenGenerator << [setupSignedJwtTokenGenerator(), setupEncryptedJwtTokenGenerator()]
    }

    @Unroll
    void "refresh tokens generated by #jwtTokenGenerator.class.simpleName doesn't expire by default"() {
        given:
        UserDetails userDetails = new User('username', 'password', [new SimpleGrantedAuthority('ROLE_USER')])

        when:
        AccessToken accessToken = jwtTokenGenerator.generateAccessToken(userDetails)
        JWT accessTokenJwt = parseAndDecrypt(accessToken.accessToken, jwtTokenGenerator)
        JWT refreshTokenJwt = parseAndDecrypt(accessToken.refreshToken, jwtTokenGenerator)

        then:
        accessTokenJwt.JWTClaimsSet.expirationTime
        !refreshTokenJwt.JWTClaimsSet.expirationTime

        where:
        jwtTokenGenerator << [setupSignedJwtTokenGenerator(), setupEncryptedJwtTokenGenerator()]
    }

    @Unroll
    void "refresh tokens generated by #jwtTokenGenerator.class.simpleName can be configured to expire"() {
        given:
        UserDetails userDetails = new User('username', 'password', [new SimpleGrantedAuthority('ROLE_USER')])

        when: "access and refresh tokens are requested with explicit, different expirations"
        AccessToken accessToken = jwtTokenGenerator.generateAccessToken(userDetails, true, 3600, 7200)
        JWT accessTokenJwt = parseAndDecrypt(accessToken.accessToken, jwtTokenGenerator)
        JWT refreshTokenJwt = parseAndDecrypt(accessToken.refreshToken, jwtTokenGenerator)

        then:
        accessTokenJwt.JWTClaimsSet.expirationTime
        refreshTokenJwt.JWTClaimsSet.expirationTime
        accessTokenJwt.JWTClaimsSet.expirationTime != refreshTokenJwt.JWTClaimsSet.expirationTime

        where:
        jwtTokenGenerator << [setupSignedJwtTokenGenerator(), setupEncryptedJwtTokenGenerator()]
    }

    @Unroll
    void "refresh tokens generated by #jwtTokenGenerator.class.simpleName have an identifying claim"() {
        given:
        UserDetails userDetails = new User('username', 'password', [new SimpleGrantedAuthority('ROLE_USER')])

        when:
        AccessToken accessToken = jwtTokenGenerator.generateAccessToken(userDetails)
        JWT accessTokenJwt = parseAndDecrypt(accessToken.accessToken, jwtTokenGenerator)
        JWT refreshTokenJwt = parseAndDecrypt(accessToken.refreshToken, jwtTokenGenerator)

        then: "refresh token has custom claim"
        refreshTokenJwt.JWTClaimsSet.getBooleanClaim(AbstractJwtTokenGenerator.REFRESH_ONLY_CLAIM)

        and: "access token does not"
        // getBooleanClaim returns null when the claim is absent, which Groovy treats as false
        !accessTokenJwt.JWTClaimsSet.getBooleanClaim(AbstractJwtTokenGenerator.REFRESH_ONLY_CLAIM)

        where:
        jwtTokenGenerator << [setupSignedJwtTokenGenerator(), setupEncryptedJwtTokenGenerator()]
    }

    @Issue("https://github.com/grails/grails-spring-security-rest/issues/295")
    void "custom claims can be added"() {
        given: "a signed generator with an extra claim provider appended"
        SignedJwtTokenGenerator tokenGenerator = setupSignedJwtTokenGenerator()
        CustomClaimProvider claimProvider = [
            provideCustomClaims: { JWTClaimsSet.Builder builder, UserDetails details, String principal, Integer expiration ->
                builder.claim("favouriteTeam", "Real Madrid")
            }
        ] as CustomClaimProvider
        tokenGenerator.customClaimProviders << claimProvider
        UserDetails userDetails = new User('username', 'password', [new SimpleGrantedAuthority('ROLE_USER')])

        when:
        AccessToken accessToken = tokenGenerator.generateAccessToken(userDetails)
        JWT jwt = tokenGenerator.jwtTokenStorageService.jwtService.parse(accessToken.accessToken)

        then:
        jwt.JWTClaimsSet.getClaim('favouriteTeam') == 'Real Madrid'
    }

    @Unroll
    void "generated access tokens contain the JWT object"() {
        given:
        UserDetails userDetails = new User('username', 'password', [new SimpleGrantedAuthority('ROLE_USER')])

        when:
        AccessToken accessToken = jwtTokenGenerator.generateAccessToken(userDetails)

        then: "the attached JWT objects serialize back to the exact token strings"
        accessToken.accessTokenJwt
        accessToken.accessTokenJwt.serialize() == accessToken.accessToken
        accessToken.refreshTokenJwt
        accessToken.refreshTokenJwt.serialize() == accessToken.refreshToken

        where:
        jwtTokenGenerator << [setupSignedJwtTokenGenerator(), setupEncryptedJwtTokenGenerator()]
    }

    void "generated tokens contains the issuer claim"() {
        given:
        SignedJwtTokenGenerator tokenGenerator = getTokenGenerator(false) as SignedJwtTokenGenerator
        UserDetails userDetails = new User('username', 'password', [new SimpleGrantedAuthority('ROLE_USER')])

        when:
        AccessToken accessToken = tokenGenerator.generateAccessToken(userDetails)
        JWT jwt = tokenGenerator.jwtTokenStorageService.jwtService.parse(accessToken.accessToken)

        then: "the issuer configured on the wired IssuerClaimProvider is present"
        jwt.JWTClaimsSet.issuer == 'Spring Security REST test'
    }

    /**
     * Parses a serialized JWT and makes its claims readable.
     *
     * When the token is an {@link EncryptedJWT}, it is decrypted in place with the RSA private key of
     * the generator's key provider. In that case the generator must be an {@link EncryptedJwtTokenGenerator}.
     * Signed/plain tokens are returned exactly as parsed.
     *
     * @param serializedJwt the compact serialization produced by the generator
     * @param generator the generator that produced the token (used only for its key when decrypting)
     * @return the parsed (and, if needed, decrypted) JWT
     */
    private static JWT parseAndDecrypt(String serializedJwt, AbstractJwtTokenGenerator generator) {
        JWT jwt = JWTParser.parse(serializedJwt)
        if (jwt instanceof EncryptedJWT) {
            RSADecrypter decrypter = new RSADecrypter((generator as EncryptedJwtTokenGenerator).keyProvider.privateKey)
            (jwt as EncryptedJWT).decrypt(decrypter)
        }
        return jwt
    }

    /**
     * Builds a token generator through a Spring {@link BeanBuilder} context.
     *
     * The wiring includes an {@link IssuerClaimProvider} whose issuer is 'Spring Security REST test'.
     * NOTE(review): the created application context is never closed.
     *
     * @param useEncryptedJwt {@code true} for an {@link EncryptedJwtTokenGenerator} (RSA-OAEP / A128GCM),
     *        {@code false} for a {@link SignedJwtTokenGenerator} (HS256)
     * @return the 'tokenGenerator' bean
     */
    private AbstractJwtTokenGenerator getTokenGenerator(boolean useEncryptedJwt) {
        BeanBuilder beanBuilder = new BeanBuilder()
        beanBuilder.beans {
            keyProvider(DefaultRSAKeyProvider)

            issuerClaimProvider(IssuerClaimProvider) {
                issuerName = 'Spring Security REST test'
            }

            def customClaimProviderList = [ref('issuerClaimProvider')]

            jwtService(JwtService) {
                keyProvider = ref('keyProvider')
                jwtSecret = 'foo123'*8
            }
            tokenStorageService(JwtTokenStorageService) {
                jwtService = ref('jwtService')
            }

            if (useEncryptedJwt) {
                tokenGenerator(EncryptedJwtTokenGenerator) {
                    jwtTokenStorageService = ref('tokenStorageService')
                    keyProvider = ref('keyProvider')
                    defaultExpiration = 3600
                    customClaimProviders = customClaimProviderList
                    jweAlgorithm = JWEAlgorithm.RSA_OAEP
                    encryptionMethod = EncryptionMethod.A128GCM
                }
            } else {
                tokenGenerator(SignedJwtTokenGenerator) {
                    jwtTokenStorageService = ref('tokenStorageService')
                    // same secret as jwtService so the storage service can verify what is signed here
                    jwtSecret = 'foo123'*8
                    defaultExpiration = 3600
                    customClaimProviders = customClaimProviderList
                    jwsAlgorithm = JWSAlgorithm.HS256
                }
            }
        }

        return beanBuilder.createApplicationContext().getBean(AbstractJwtTokenGenerator, 'tokenGenerator')
    }

}
