/*
 * Copyright 2019 Red Hat, Inc. and/or its affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 *
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.drools.modelcompiler.builder.generator.declaredtype.generator;

import java.util.List;
import java.util.stream.Collectors;

import com.github.javaparser.ast.Modifier;
import com.github.javaparser.ast.NodeList;
import com.github.javaparser.ast.body.MethodDeclaration;
import com.github.javaparser.ast.stmt.BlockStmt;
import com.github.javaparser.ast.stmt.Statement;
import com.github.javaparser.ast.type.ArrayType;
import com.github.javaparser.ast.type.ClassOrInterfaceType;
import com.github.javaparser.ast.type.PrimitiveType;
import com.github.javaparser.ast.type.Type;

import static com.github.javaparser.StaticJavaParser.parseStatement;
import static com.github.javaparser.StaticJavaParser.parseType;
import static com.github.javaparser.ast.NodeList.nodeList;
import static org.drools.modelcompiler.builder.generator.DrlxParseUtil.toClassOrInterfaceType;
import static org.drools.modelcompiler.builder.generator.declaredtype.generator.GeneratedClassDeclaration.OVERRIDE;
import static org.drools.modelcompiler.builder.generator.declaredtype.generator.GeneratedClassDeclaration.replaceFieldName;

class GeneratedEqualsMethod {

    private static final Statement referenceEquals = parseStatement("if (this == o) { return true; }");
    private static final Statement classCheckEquals = parseStatement("if (o == null || getClass() != o.getClass()) { return false; }");

    private static final String EQUALS = "equals";

    private static Statement classCastStatement(String className) {
        Statement statement = parseStatement("__className that = (__className) o;");
        statement.findAll(ClassOrInterfaceType.class)
                .stream()
                .filter(n1 -> n1.getName().toString().equals("__className"))
                .forEach(n -> n.replace(toClassOrInterfaceType(className)));
        return statement;
    }

    private static Statement generateEqualsForField(GeneratedMethods.PojoField field) {
        Statement statement;
        if (field.type instanceof ClassOrInterfaceType) {
            statement = parseStatement(" if( __fieldName != null ? !__fieldName.equals(that.__fieldName) : that.__fieldName != null) { return false; }");
        } else if (field.type instanceof ArrayType) {
            Type componentType = ((ArrayType) field.type).getComponentType();
            if (componentType instanceof PrimitiveType) {
                statement = parseStatement(" if( !java.util.Arrays.equals((" + componentType + "[])__fieldName, (" + componentType + "[])that.__fieldName)) { return false; }");
            } else {
                statement = parseStatement(" if( !java.util.Arrays.equals((Object[])__fieldName, (Object[])that.__fieldName)) { return false; }");
            }
        } else if (field.type instanceof PrimitiveType) {
            statement = parseStatement(" if( __fieldName != that.__fieldName) { return false; }");
        } else {
            throw new RuntimeException("Unknown type");
        }
        return replaceFieldName(statement, field.name);
    }

    static MethodDeclaration method(List<GeneratedMethods.PojoField> fields, String generatedClassName, boolean hasSuper) {
        List<Statement> equalsFieldStatement = fields.stream()
                .map( GeneratedEqualsMethod::generateEqualsForField )
                .collect( Collectors.toList());

        NodeList<Statement> equalsStatements = nodeList(referenceEquals, classCheckEquals);
        equalsStatements.add(classCastStatement(generatedClassName));
        if (hasSuper) {
            equalsStatements.add(parseStatement("if ( !super.equals( o ) ) return false;"));
        }
        equalsStatements.addAll(equalsFieldStatement);
        equalsStatements.add(parseStatement("return true;"));

        final Type returnType = parseType(boolean.class.getSimpleName());
        final MethodDeclaration equals = new MethodDeclaration(nodeList(Modifier.publicModifier()), returnType, EQUALS);
        equals.addParameter(Object.class, "o");
        equals.addAnnotation(OVERRIDE);
        equals.setBody(new BlockStmt(equalsStatements));
        return equals;
    }
}
