/*
 * Decompiled with CFR 0.152.
 */
package org.apache.dubbo.rpc.protocol.tri.rest.cors;

import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.regex.Pattern;
import org.apache.dubbo.common.extension.Activate;
import org.apache.dubbo.common.utils.ArrayUtils;
import org.apache.dubbo.common.utils.StringUtils;
import org.apache.dubbo.remoting.http12.HttpMethods;
import org.apache.dubbo.remoting.http12.HttpRequest;
import org.apache.dubbo.remoting.http12.HttpResponse;
import org.apache.dubbo.remoting.http12.HttpResult;
import org.apache.dubbo.remoting.http12.HttpStatus;
import org.apache.dubbo.remoting.http12.exception.HttpResultPayloadException;
import org.apache.dubbo.rpc.Invoker;
import org.apache.dubbo.rpc.RpcException;
import org.apache.dubbo.rpc.RpcInvocation;
import org.apache.dubbo.rpc.protocol.tri.rest.RestConstants;
import org.apache.dubbo.rpc.protocol.tri.rest.cors.CorsUtils;
import org.apache.dubbo.rpc.protocol.tri.rest.filter.RestHeaderFilterAdapter;
import org.apache.dubbo.rpc.protocol.tri.rest.mapping.RequestMapping;
import org.apache.dubbo.rpc.protocol.tri.rest.mapping.meta.CorsMeta;

/**
 * Provider-side REST header filter that enforces the CORS policy ({@link CorsMeta}) attached to the
 * matched {@link RequestMapping}.
 *
 * <p>Behaviour per request:
 * <ul>
 *   <li>If the mapping has no CORS metadata, ordinary requests pass through untouched and
 *       pre-flight requests are rejected with {@code 403}.</li>
 *   <li>Otherwise the CORS request headers are added to {@code Vary}, and the origin, method and
 *       (for pre-flight) requested headers are validated against the mapping's {@link CorsMeta}.
 *       A failed check is answered with {@code 403}.</li>
 *   <li>A successful pre-flight is answered directly with {@code 204} by throwing
 *       {@link HttpResultPayloadException}; successful actual requests continue down the chain
 *       with the {@code Access-Control-*} response headers set.</li>
 * </ul>
 */
@Activate(group = {"provider"}, order = 1000)
public class CorsHeaderFilter extends RestHeaderFilterAdapter {
    public static final String VARY = "vary";
    public static final String ORIGIN = "origin";
    public static final String ACCESS_CONTROL_REQUEST_METHOD = "access-control-request-method";
    public static final String ACCESS_CONTROL_REQUEST_HEADERS = "access-control-request-headers";
    public static final String ACCESS_CONTROL_ALLOW_CREDENTIALS = "access-control-allow-credentials";
    public static final String ACCESS_CONTROL_EXPOSE_HEADERS = "access-control-expose-headers";
    public static final String ACCESS_CONTROL_MAX_AGE = "access-control-max-age";
    public static final String ACCESS_CONTROL_ALLOW_ORIGIN = "access-control-allow-origin";
    public static final String ACCESS_CONTROL_ALLOW_METHODS = "access-control-allow-methods";
    public static final String ACCESS_CONTROL_ALLOW_HEADERS = "access-control-allow-headers";
    /** Separator used when joining multi-valued header values. */
    public static final String SEP = ", ";

    /** Wildcard value accepted in the allowed origins / methods / headers configuration. */
    private static final String ANY_VALUE = "*";

    /** Response headers that the CORS decision depends on and that must therefore appear in Vary. */
    private static final String[] CORS_VARY_HEADERS = {
        ORIGIN, ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS
    };

    @Override
    protected void invoke(Invoker<?> invoker, RpcInvocation invocation, HttpRequest request, HttpResponse response)
            throws RpcException {
        RequestMapping mapping = (RequestMapping) request.attribute(RestConstants.MAPPING_ATTRIBUTE);
        CorsMeta cors = mapping.getCors();
        String origin = request.header(ORIGIN);
        if (cors == null) {
            // No CORS policy: ordinary requests pass through, but a pre-flight can never succeed.
            if (isPreFlightRequest(request, origin)) {
                throw new HttpResultPayloadException(HttpResult.builder()
                        .status(HttpStatus.FORBIDDEN)
                        .body("Invalid CORS request")
                        .build());
            }
            return;
        }
        if (!process(cors, request, response, origin)) {
            // Carry the headers written so far (at least Vary) on the rejection.
            throw new HttpResultPayloadException(HttpResult.builder()
                    .status(HttpStatus.FORBIDDEN)
                    .body("Invalid CORS request")
                    .headers(response.headers())
                    .build());
        }
    }

    /**
     * Applies {@code cors} to the request and writes the resulting CORS response headers.
     *
     * @param cors     CORS configuration of the matched mapping, non-null
     * @param request  current request
     * @param response current response; headers are written into it
     * @param origin   value of the request's {@code Origin} header, may be {@code null}
     * @return {@code true} when the request may proceed, {@code false} when it must be rejected
     * @throws HttpResultPayloadException with status 204 when a valid pre-flight has been answered
     */
    private boolean process(CorsMeta cors, HttpRequest request, HttpResponse response, String origin) {
        setVaryHeader(response);
        if (isNotCorsRequest(request, origin)) {
            return true;
        }
        // An Access-Control-Allow-Origin header is already present; leave it untouched.
        if (response.header(ACCESS_CONTROL_ALLOW_ORIGIN) != null) {
            return true;
        }

        String allowOrigin = checkOrigin(cors, origin);
        if (allowOrigin == null) {
            return false;
        }

        boolean preFlight = isPreFlightRequest(request, origin);
        // A pre-flight asks about the method of the *future* request, not the OPTIONS itself.
        String method = preFlight ? request.header(ACCESS_CONTROL_REQUEST_METHOD) : request.method();
        List<String> allowMethods = checkMethods(cors, method);
        if (allowMethods == null) {
            return false;
        }

        List<String> allowHeaders = Collections.emptyList();
        if (preFlight) {
            allowHeaders = checkHeaders(cors, request.headerValues(ACCESS_CONTROL_REQUEST_HEADERS));
            if (allowHeaders == null) {
                return false;
            }
        }

        response.setHeader((CharSequence) ACCESS_CONTROL_ALLOW_ORIGIN, allowOrigin);
        if (ArrayUtils.isNotEmpty(cors.getExposedHeaders())) {
            response.setHeader((CharSequence) ACCESS_CONTROL_EXPOSE_HEADERS, StringUtils.join(cors.getExposedHeaders(), SEP));
        }
        if (Boolean.TRUE.equals(cors.getAllowCredentials())) {
            response.setHeader((CharSequence) ACCESS_CONTROL_ALLOW_CREDENTIALS, Boolean.TRUE.toString());
        }
        if (!preFlight) {
            return true;
        }

        response.setHeader((CharSequence) ACCESS_CONTROL_ALLOW_METHODS, StringUtils.join(allowMethods, SEP));
        if (!allowHeaders.isEmpty()) {
            response.setHeader((CharSequence) ACCESS_CONTROL_ALLOW_HEADERS, StringUtils.join(allowHeaders, SEP));
        }
        if (cors.getMaxAge() != null) {
            response.setHeader((CharSequence) ACCESS_CONTROL_MAX_AGE, cors.getMaxAge().toString());
        }
        // Pre-flight is fully answered here; short-circuit the invocation with an empty 204.
        throw new HttpResultPayloadException(HttpResult.builder()
                .status(HttpStatus.NO_CONTENT)
                .headers(response.headers())
                .build());
    }

    /**
     * Merges the CORS request headers into the response's {@code Vary} header. Existing values are
     * tokenized on commas and de-duplicated case-insensitively, keeping the first spelling seen.
     */
    private static void setVaryHeader(HttpResponse response) {
        List<String> vary = new ArrayList<>();
        for (String token : splitTokens(response.headerValues(VARY))) {
            addIfAbsentIgnoreCase(vary, token);
        }
        for (String name : CORS_VARY_HEADERS) {
            addIfAbsentIgnoreCase(vary, name);
        }
        response.setHeader((CharSequence) VARY, StringUtils.join(vary, SEP));
    }

    /**
     * Resolves the value for {@code Access-Control-Allow-Origin}.
     *
     * @return {@code "*"} when any origin is allowed, the formatted origin when it matches an
     *         allowed origin or origin pattern, or {@code null} when it is not allowed
     * @throws IllegalArgumentException if {@code "*"} is configured together with allowCredentials
     */
    private static String checkOrigin(CorsMeta cors, String origin) {
        if (StringUtils.isBlank(origin)) {
            return null;
        }
        origin = CorsUtils.formatOrigin(origin);
        String[] allowedOrigins = cors.getAllowedOrigins();
        if (ArrayUtils.isNotEmpty(allowedOrigins)) {
            if (ArrayUtils.contains(allowedOrigins, ANY_VALUE)) {
                if (Boolean.TRUE.equals(cors.getAllowCredentials())) {
                    throw new IllegalArgumentException(
                            "When allowCredentials is true, allowedOrigins cannot contain the special value \"*\"");
                }
                return ANY_VALUE;
            }
            for (String allowedOrigin : allowedOrigins) {
                if (origin.equalsIgnoreCase(allowedOrigin)) {
                    return origin;
                }
            }
        }
        if (ArrayUtils.isNotEmpty(cors.getAllowedOriginsPatterns())) {
            for (Pattern pattern : cors.getAllowedOriginsPatterns()) {
                if (pattern.matcher(origin).matches()) {
                    return origin;
                }
            }
        }
        return null;
    }

    /**
     * Checks {@code method} against the configured allowed methods.
     *
     * @return the methods to advertise in {@code Access-Control-Allow-Methods}, or {@code null}
     *         when the method is missing, nothing is configured, or the method is not allowed
     */
    private static List<String> checkMethods(CorsMeta cors, String method) {
        if (method == null) {
            return null;
        }
        String[] allowedMethods = cors.getAllowedMethods();
        // Missing configuration previously caused an NPE in the loop below; treat it as "deny".
        if (ArrayUtils.isEmpty(allowedMethods)) {
            return null;
        }
        if (ArrayUtils.contains(allowedMethods, ANY_VALUE)) {
            return Collections.singletonList(method);
        }
        for (String allowedMethod : allowedMethods) {
            if (method.equalsIgnoreCase(allowedMethod)) {
                return Arrays.asList(allowedMethods);
            }
        }
        return null;
    }

    /**
     * Checks the headers requested by a pre-flight against the configured allowed headers.
     *
     * <p>{@code Access-Control-Request-Headers} is a comma-separated list, so every header value is
     * split into trimmed tokens before matching.
     *
     * @param headers raw values of {@code Access-Control-Request-Headers}, may be {@code null}
     * @return an empty list when no headers were requested, the allowed subset of the requested
     *         headers, or {@code null} when none of them is allowed
     */
    private static List<String> checkHeaders(CorsMeta cors, Collection<String> headers) {
        List<String> requested = splitTokens(headers);
        if (requested.isEmpty()) {
            return Collections.emptyList();
        }
        String[] allowedHeaders = cors.getAllowedHeaders();
        if (ArrayUtils.isEmpty(allowedHeaders)) {
            return null;
        }
        boolean allowAny = ArrayUtils.contains(allowedHeaders, ANY_VALUE);
        List<String> allowed = Arrays.asList(allowedHeaders);
        List<String> result = new ArrayList<>(requested.size());
        for (String header : requested) {
            if (allowAny || containsIgnoreCase(allowed, header)) {
                result.add(header);
            }
        }
        return result.isEmpty() ? null : result;
    }

    /**
     * Returns {@code true} when the request carries no {@code Origin} or when the origin equals the
     * request's own scheme, host and port (a same-origin request). Scheme and host are compared
     * case-insensitively. An unparsable origin is treated as a CORS request.
     */
    private static boolean isNotCorsRequest(HttpRequest request, String origin) {
        if (origin == null) {
            return true;
        }
        URI uri;
        try {
            uri = new URI(origin);
        } catch (URISyntaxException ignored) {
            // Cannot be proven same-origin; let the CORS checks decide.
            return false;
        }
        String scheme = request.scheme();
        String host = request.serverName();
        return scheme != null
                && scheme.equalsIgnoreCase(uri.getScheme())
                && host != null
                && host.equalsIgnoreCase(uri.getHost())
                && getPort(scheme, request.serverPort()) == getPort(uri.getScheme(), uri.getPort());
    }

    /** A pre-flight is an OPTIONS request with an Origin and an Access-Control-Request-Method header. */
    private static boolean isPreFlightRequest(HttpRequest request, String origin) {
        return HttpMethods.OPTIONS.is(request.method()) && origin != null && request.hasHeader(ACCESS_CONTROL_REQUEST_METHOD);
    }

    /** Replaces an unspecified port ({@code -1}) with the default port for http/https. */
    private static int getPort(String scheme, int port) {
        if (port != -1) {
            return port;
        }
        if ("http".equalsIgnoreCase(scheme)) {
            return 80;
        }
        if ("https".equalsIgnoreCase(scheme)) {
            return 443;
        }
        return port;
    }

    /** Splits comma-separated header values into trimmed, non-blank tokens; always returns a new mutable list. */
    private static List<String> splitTokens(Collection<String> values) {
        List<String> tokens = new ArrayList<>();
        if (values == null) {
            return tokens;
        }
        for (String value : values) {
            if (value == null) {
                continue;
            }
            for (String token : value.split(",")) {
                String trimmed = token.trim();
                if (!trimmed.isEmpty()) {
                    tokens.add(trimmed);
                }
            }
        }
        return tokens;
    }

    /** Returns {@code true} if {@code values} contains {@code value}, ignoring case. */
    private static boolean containsIgnoreCase(Collection<String> values, String value) {
        for (String candidate : values) {
            if (value.equalsIgnoreCase(candidate)) {
                return true;
            }
        }
        return false;
    }

    /** Appends {@code value} to {@code list} unless an entry equal to it ignoring case is already present. */
    private static void addIfAbsentIgnoreCase(List<String> list, String value) {
        if (!containsIgnoreCase(list, value)) {
            list.add(value);
        }
    }
}

