/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.shiro.web.filter

import org.apache.shiro.web.RestoreSystemProperties
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.parallel.Isolated

import javax.servlet.http.HttpServletRequest

import static org.easymock.EasyMock.expect
import static org.easymock.EasyMock.mock
import static org.easymock.EasyMock.replay
import static org.hamcrest.MatcherAssert.assertThat

/**
 * Tests for {@link InvalidRequestFilter}.
 *
 * <p>Covers the default configuration, the {@code org.apache.shiro.web.ALLOW_BACKSLASH}
 * system property, and the blocked/allowed outcome for backslash, semicolon,
 * {@code \u0019} and path-traversal inputs under each filter setting.</p>
 *
 * <p>The class is isolated from parallel execution because one test temporarily
 * changes JVM system properties.</p>
 */
@Isolated("Uses System Properties")
class InvalidRequestFilterTest {

    /**
     * Paths containing literal {@code .} or {@code ..} segments. Shared by the
     * traversal tests so every block mode is exercised against the same inputs.
     */
    private static final List<String> LITERAL_TRAVERSAL_PATHS = [
            "/something/../",
            "/something/../bar",
            "/something/../bar/",
            "/something/..",
            "/..",
            "..",
            "../",
            "/something/./",
            "/something/./bar",
            "/something/\u002e/bar",
            "/something/./bar/",
            "/something/.",
            "/.",
            "/something/../something/."
    ]

    /**
     * Paths containing a percent-encoded period ({@code %2e}) or forward slash
     * ({@code %2f}). These are expected to be blocked only in STRICT mode.
     */
    private static final List<String> ENCODED_TRAVERSAL_PATHS = [
            "%2E./",
            "%2F./",
            "/something/%2e/bar/",
            "/something/%2f/bar/",
            "/something/http:%2f%2fmydomain.example.com%2foidc/bar/",
            "/something/%2e%2E/bar/",
            "/something/http:%2f%2fmydomain%2eexample%2ecom%2foidc/bar/"
    ]

    /** A freshly constructed filter blocks everything and uses NORMAL traversal blocking. */
    @Test
    void defaultConfig() {
        InvalidRequestFilter filter = new InvalidRequestFilter()
        assertThat "filter.blockBackslash expected to be true", filter.isBlockBackslash()
        assertThat "filter.blockNonAscii expected to be true", filter.isBlockNonAscii()
        assertThat "filter.blockSemicolon expected to be true", filter.isBlockSemicolon()
        assertThat "filter.blockTraversal expected to be NORMAL",
                filter.getPathTraversalBlockMode() == InvalidRequestFilter.PathTraversalBlockMode.NORMAL
        assertThat "filter.blockRewriteTraversal expected to be true", filter.isBlockRewriteTraversal()
    }

    /**
     * Only the value {@code "true"} for the ALLOW_BACKSLASH system property is expected
     * to switch backslash blocking off; an empty value or {@code "false"} keeps it on.
     */
    @Test
    void systemPropertyAllowBackslash() {
        RestoreSystemProperties.withProperties(["org.apache.shiro.web.ALLOW_BACKSLASH": "true"]) {
            InvalidRequestFilter filter = new InvalidRequestFilter()
            assertThat "filter.blockBackslash expected to be false", !filter.isBlockBackslash()
        }

        // Empty value: blocking must stay enabled.
        RestoreSystemProperties.withProperties(["org.apache.shiro.web.ALLOW_BACKSLASH": ""]) {
            InvalidRequestFilter filter = new InvalidRequestFilter()
            assertThat "filter.blockBackslash expected to be true", filter.isBlockBackslash()
        }

        // Explicit "false": blocking must stay enabled.
        RestoreSystemProperties.withProperties(["org.apache.shiro.web.ALLOW_BACKSLASH": "false"]) {
            InvalidRequestFilter filter = new InvalidRequestFilter()
            assertThat "filter.blockBackslash expected to be true", filter.isBlockBackslash()
        }
    }

    /** With default settings, backslash, semicolon and {@code \u0019} inputs are all rejected. */
    @Test
    void testFilterBlocks() {
        InvalidRequestFilter filter = new InvalidRequestFilter()
        assertPathBlocked(filter, "/\\something")
        assertPathBlocked(filter, "/%5csomething")
        assertPathBlocked(filter, "/%5Csomething")
        assertPathBlocked(filter, "/;something")
        assertPathBlocked(filter, "/%3bsomething")
        assertPathBlocked(filter, "/%3Bsomething")
        assertPathBlocked(filter, "/\u0019something")

        // Offending character only in the servlet path, then only in the path info.
        assertPathBlocked(filter, "/something", "/;something")
        assertPathBlocked(filter, "/something", "/something", "/;")
    }

    /** Default (NORMAL) mode: literal dot segments are blocked, percent-encoded ones are allowed. */
    @Test
    void testBlocksTraversalNormal() {
        InvalidRequestFilter filter = new InvalidRequestFilter()

        for (String path : LITERAL_TRAVERSAL_PATHS) {
            assertPathBlocked(filter, path)
        }
        for (String path : ENCODED_TRAVERSAL_PATHS) {
            assertPathAllowed(filter, path)
        }
    }

    /** STRICT mode: both literal and percent-encoded traversal inputs are blocked. */
    @Test
    void testBlocksTraversalStrict() {
        InvalidRequestFilter filter = new InvalidRequestFilter()
        filter.setBlockPathTraversal(InvalidRequestFilter.PathTraversalBlockMode.STRICT)
        assertThat "filter.blockEncodedPeriod expected to be true", filter.isBlockEncodedPeriod()
        assertThat "filter.blockEncodedForwardSlash expected to be true", filter.isBlockEncodedForwardSlash()

        for (String path : LITERAL_TRAVERSAL_PATHS) {
            assertPathBlocked(filter, path)
        }
        for (String path : ENCODED_TRAVERSAL_PATHS) {
            assertPathBlocked(filter, path)
        }
    }

    /** Disabling backslash blocking lets backslash inputs through; the other checks still apply. */
    @Test
    void testFilterAllowsBackslash() {
        InvalidRequestFilter filter = new InvalidRequestFilter()
        filter.setBlockBackslash(false)
        assertPathAllowed(filter, "/\\something")
        assertPathAllowed(filter, "/%5csomething")
        assertPathAllowed(filter, "/%5Csomething")
        assertPathBlocked(filter, "/;something")
        assertPathBlocked(filter, "/%3bsomething")
        assertPathBlocked(filter, "/%3Bsomething")
        assertPathBlocked(filter, "/\u0019something")

        // Backslash only in the servlet path, then only in the path info.
        assertPathAllowed(filter, "/something", "/\\something")
        assertPathAllowed(filter, "/something", "/something", "/\\")
    }

    /** Disabling the non-ASCII check lets the {@code \u0019} inputs through; the other checks still apply. */
    @Test
    void testFilterAllowsNonAscii() {
        InvalidRequestFilter filter = new InvalidRequestFilter()
        filter.setBlockNonAscii(false)
        assertPathBlocked(filter, "/\\something")
        assertPathBlocked(filter, "/%5csomething")
        assertPathBlocked(filter, "/%5Csomething")
        assertPathBlocked(filter, "/;something")
        assertPathBlocked(filter, "/%3bsomething")
        assertPathBlocked(filter, "/%3Bsomething")
        assertPathAllowed(filter, "/\u0019something")

        // Character only in the servlet path, then only in the path info.
        assertPathAllowed(filter, "/something", "/\u0019something")
        assertPathAllowed(filter, "/something", "/something", "/\u0019")
    }

    /** Disabling semicolon blocking lets semicolon inputs through; the other checks still apply. */
    @Test
    void testFilterAllowsSemicolon() {
        InvalidRequestFilter filter = new InvalidRequestFilter()
        filter.setBlockSemicolon(false)
        assertPathBlocked(filter, "/\\something")
        assertPathBlocked(filter, "/%5csomething")
        assertPathBlocked(filter, "/%5Csomething")
        assertPathAllowed(filter, "/;something")
        assertPathAllowed(filter, "/%3bsomething")
        assertPathAllowed(filter, "/%3Bsomething")
        assertPathBlocked(filter, "/\u0019something")

        // Semicolon only in the servlet path, then only in the path info.
        assertPathAllowed(filter, "/something", "/;something")
        assertPathAllowed(filter, "/something", "/something", "/;")
    }

    /** NO_BLOCK mode: every literal and percent-encoded traversal input is allowed. */
    @Test
    void testAllowTraversal() {
        InvalidRequestFilter filter = new InvalidRequestFilter()
        filter.setBlockPathTraversal(InvalidRequestFilter.PathTraversalBlockMode.NO_BLOCK)

        for (String path : LITERAL_TRAVERSAL_PATHS) {
            assertPathAllowed(filter, path)
        }
        // Additional input checked only in this mode: a forward slash written as a unicode escape.
        assertPathAllowed(filter, "/something\u002fbar")

        for (String path : ENCODED_TRAVERSAL_PATHS) {
            assertPathAllowed(filter, path)
        }
    }

    /**
     * Asserts that the filter denies a request built from the given parts.
     *
     * @param filter      the filter under test
     * @param requestUri  value returned by the mocked {@code getRequestURI()}
     * @param servletPath value returned by the mocked {@code getServletPath()}; defaults to {@code requestUri}
     * @param pathInfo    value returned by the mocked {@code getPathInfo()}; defaults to {@code null}
     */
    static void assertPathBlocked(InvalidRequestFilter filter, String requestUri, String servletPath = requestUri, String pathInfo = null) {
        HttpServletRequest request = mockRequest(requestUri, servletPath, pathInfo)
        boolean allowed = filter.isAccessAllowed(request, null, null)
        assertThat "Expected path '${requestUri}', to be blocked", !allowed
    }

    /**
     * Asserts that the filter permits a request built from the given parts.
     *
     * @param filter      the filter under test
     * @param requestUri  value returned by the mocked {@code getRequestURI()}
     * @param servletPath value returned by the mocked {@code getServletPath()}; defaults to {@code requestUri}
     * @param pathInfo    value returned by the mocked {@code getPathInfo()}; defaults to {@code null}
     */
    static void assertPathAllowed(InvalidRequestFilter filter, String requestUri, String servletPath = requestUri, String pathInfo = null) {
        HttpServletRequest request = mockRequest(requestUri, servletPath, pathInfo)
        boolean allowed = filter.isAccessAllowed(request, null, null)
        assertThat "Expected requestUri '${requestUri}', to be allowed", allowed
    }

    /**
     * Builds an EasyMock {@link HttpServletRequest} already switched to replay mode.
     *
     * <p>{@code getServletPath()} and {@code getPathInfo()} are stubbed with
     * {@code anyTimes()}; {@code getRequestURI()} and the two
     * {@code javax.servlet.include.*} attribute lookups are recorded without a
     * repeat count. The mock is never verified, so unused expectations do not
     * fail the test.</p>
     *
     * @param requestUri  value to return from {@code getRequestURI()}
     * @param servletPath value to return from {@code getServletPath()} and the include servlet-path attribute
     * @param pathInfo    value to return from {@code getPathInfo()} and the include path-info attribute
     * @return the configured mock request
     */
    static HttpServletRequest mockRequest(String requestUri, String servletPath, String pathInfo) {
        HttpServletRequest request = mock(HttpServletRequest)
        expect(request.getRequestURI()).andReturn(requestUri)
        expect(request.getServletPath()).andReturn(servletPath).anyTimes()
        expect(request.getPathInfo()).andReturn(pathInfo).anyTimes()
        expect(request.getAttribute("javax.servlet.include.servlet_path")).andReturn(servletPath)
        expect(request.getAttribute("javax.servlet.include.path_info")).andReturn(pathInfo)
        replay(request)
        return request
    }
}
