/*
 * Copyright (c) 2015, 2016, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandleInfo;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.invoke.VarHandle;
import java.lang.invoke.WrongMethodTypeException;
import java.lang.reflect.Method;
import java.nio.ReadOnlyBufferException;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;

import static java.util.stream.Collectors.toList;
import static org.testng.Assert.*;

/**
 * Common support code shared by the {@code VarHandle} access tests.
 * <p>
 * It provides:
 * <ul>
 * <li>assertion helpers that run an action and expect a given throwable;</li>
 * <li>an enumeration of the tested {@link VarHandle} access modes, grouped by
 *     the kind of access they perform;</li>
 * <li>several strategies for obtaining a {@link MethodHandle} that invokes a
 *     {@code VarHandle} access mode;</li>
 * <li>test-case holders that drive an access action either directly on a
 *     {@code VarHandle} or through such method handles.</li>
 * </ul>
 */
abstract class VarHandleBaseTest {
    /**
     * Value of the {@code iters} integer system property, defaulting to 1.
     * Presumably the number of repetitions for test cases that
     * {@linkplain AccessTestCase#requiresLoop() require a loop}; confirm
     * against the concrete test classes.
     */
    static final int ITERS = Integer.getInteger("iters", 1);

    /**
     * A {@link Runnable}-like action whose body may throw any {@link Throwable}.
     */
    interface ThrowingRunnable {
        void run() throws Throwable;
    }

    /** Asserts that {@code r} throws an {@link UnsupportedOperationException}. */
    static void checkUOE(ThrowingRunnable r) {
        checkWithThrowable(UnsupportedOperationException.class, null, r);
    }

    /**
     * Asserts that {@code r} throws an {@link UnsupportedOperationException};
     * {@code message} is prepended to any failure message.
     */
    static void checkUOE(Object message, ThrowingRunnable r) {
        checkWithThrowable(UnsupportedOperationException.class, message, r);
    }

    /** Asserts that {@code r} throws a {@link ReadOnlyBufferException}. */
    static void checkROBE(ThrowingRunnable r) {
        checkWithThrowable(ReadOnlyBufferException.class, null, r);
    }

    /**
     * Asserts that {@code r} throws a {@link ReadOnlyBufferException};
     * {@code message} is prepended to any failure message.
     */
    static void checkROBE(Object message, ThrowingRunnable r) {
        checkWithThrowable(ReadOnlyBufferException.class, message, r);
    }

    /** Asserts that {@code r} throws an {@link IndexOutOfBoundsException}. */
    static void checkIOOBE(ThrowingRunnable r) {
        checkWithThrowable(IndexOutOfBoundsException.class, null, r);
    }

    /**
     * Asserts that {@code r} throws an {@link IndexOutOfBoundsException};
     * {@code message} is prepended to any failure message.
     */
    static void checkIOOBE(Object message, ThrowingRunnable r) {
        checkWithThrowable(IndexOutOfBoundsException.class, message, r);
    }

    /** Asserts that {@code r} throws an {@link IllegalStateException}. */
    static void checkISE(ThrowingRunnable r) {
        checkWithThrowable(IllegalStateException.class, null, r);
    }

    /**
     * Asserts that {@code r} throws an {@link IllegalStateException};
     * {@code message} is prepended to any failure message.
     */
    static void checkISE(Object message, ThrowingRunnable r) {
        checkWithThrowable(IllegalStateException.class, message, r);
    }

    /**
     * Asserts that {@code r} throws an {@link IllegalAccessException}.
     * Note: "IAE" here means IllegalAccessException, not
     * IllegalArgumentException.
     */
    static void checkIAE(ThrowingRunnable r) {
        checkWithThrowable(IllegalAccessException.class, null, r);
    }

    /**
     * Asserts that {@code r} throws an {@link IllegalAccessException} (not
     * IllegalArgumentException); {@code message} is prepended to any failure
     * message.
     */
    static void checkIAE(Object message, ThrowingRunnable r) {
        checkWithThrowable(IllegalAccessException.class, message, r);
    }

    /** Asserts that {@code r} throws a {@link WrongMethodTypeException}. */
    static void checkWMTE(ThrowingRunnable r) {
        checkWithThrowable(WrongMethodTypeException.class, null, r);
    }

    /**
     * Asserts that {@code r} throws a {@link WrongMethodTypeException};
     * {@code message} is prepended to any failure message.
     */
    static void checkWMTE(Object message, ThrowingRunnable r) {
        checkWithThrowable(WrongMethodTypeException.class, message, r);
    }

    /** Asserts that {@code r} throws a {@link ClassCastException}. */
    static void checkCCE(ThrowingRunnable r) {
        checkWithThrowable(ClassCastException.class, null, r);
    }

    /**
     * Asserts that {@code r} throws a {@link ClassCastException};
     * {@code message} is prepended to any failure message.
     */
    static void checkCCE(Object message, ThrowingRunnable r) {
        checkWithThrowable(ClassCastException.class, message, r);
    }

    /** Asserts that {@code r} throws a {@link NullPointerException}. */
    static void checkNPE(ThrowingRunnable r) {
        checkWithThrowable(NullPointerException.class, null, r);
    }

    /**
     * Asserts that {@code r} throws a {@link NullPointerException};
     * {@code message} is prepended to any failure message.
     */
    static void checkNPE(Object message, ThrowingRunnable r) {
        checkWithThrowable(NullPointerException.class, message, r);
    }

    /**
     * Runs {@code r} and asserts that it throws an instance of {@code re}.
     *
     * @param re      the expected throwable type; subclasses are accepted,
     *                since the check uses {@link Class#isInstance}
     * @param message optional context prepended (followed by ". ") to the
     *                assertion failure message; may be {@code null}
     * @param r       the action expected to throw
     */
    static void checkWithThrowable(Class<? extends Throwable> re,
                                   Object message,
                                   ThrowingRunnable r) {
        Throwable _e = null;
        try {
            r.run();
        }
        catch (Throwable e) {
            // Capture anything, including Errors, so the type check below
            // can report an unexpected throwable instead of propagating it.
            _e = e;
        }
        message = message == null ? "" : message + ". ";
        assertNotNull(_e, String.format("%sNo throwable thrown. Expected %s", message, re));
        assertTrue(re.isInstance(_e), String.format("%sIncorrect throwable thrown, %s. Expected %s", message, _e, re));
    }


    /**
     * Kinds of access. Access modes of the same kind share a method-type
     * shape, as checked by {@link #testTypes(VarHandle)}.
     */
    enum TestAccessType {
        get,
        set,
        compareAndSet,
        compareAndExchange,
        getAndSet,
        getAndAdd;
    }

    /**
     * The tested access modes. Each constant's name must match both a
     * signature-polymorphic {@link VarHandle} method name and a
     * {@link VarHandle.AccessMode} constant name; initialization fails with
     * an {@link Error} if the method is not found.
     */
    enum TestAccessMode {
        get(TestAccessType.get),
        set(TestAccessType.set),
        getVolatile(TestAccessType.get),
        setVolatile(TestAccessType.set),
        getAcquire(TestAccessType.get),
        setRelease(TestAccessType.set),
        getOpaque(TestAccessType.get),
        setOpaque(TestAccessType.set),
        compareAndSet(TestAccessType.compareAndSet),
        compareAndExchangeVolatile(TestAccessType.compareAndExchange),
        compareAndExchangeAcquire(TestAccessType.compareAndExchange),
        compareAndExchangeRelease(TestAccessType.compareAndExchange),
        weakCompareAndSet(TestAccessType.compareAndSet),
        weakCompareAndSetAcquire(TestAccessType.compareAndSet),
        weakCompareAndSetRelease(TestAccessType.compareAndSet),
        getAndSet(TestAccessType.getAndSet),
        getAndAdd(TestAccessType.getAndAdd),
        addAndGet(TestAccessType.getAndAdd),;

        /** The kind of access this mode performs. */
        final TestAccessType at;
        /**
         * True when the declared return type of the corresponding
         * {@code VarHandle} method is not {@code Object} (e.g. boolean or
         * void). NOTE(review): the name suggests the opposite sense; kept
         * as-is because other test classes may depend on it.
         */
        final boolean isPolyMorphicInReturnType;
        /** Declared return type of the corresponding {@code VarHandle} method. */
        final Class<?> returnType;

        TestAccessMode(TestAccessType at) {
            this.at = at;

            try {
                Method m = VarHandle.class.getMethod(name(), Object[].class);
                this.returnType = m.getReturnType();
                isPolyMorphicInReturnType = returnType != Object.class;
            }
            catch (Exception e) {
                throw new Error(e);
            }
        }

        /** Returns whether this mode performs the given kind of access. */
        boolean isOfType(TestAccessType at) {
            return this.at == at;
        }

        /** Returns the {@link VarHandle.AccessMode} with the same name. */
        VarHandle.AccessMode toAccessMode() {
            return VarHandle.AccessMode.valueOf(name());
        }
    }

    /** Returns all test access modes, in declaration order. */
    static List<TestAccessMode> testAccessModes() {
        return Stream.of(TestAccessMode.values()).collect(toList());
    }

    /**
     * Returns the test access modes whose access type is any of {@code ats},
     * in declaration order. If no types are given, all modes are returned.
     *
     * @param ats the access types to select
     * @return the matching test access modes
     */
    static List<TestAccessMode> testAccessModesOfType(TestAccessType... ats) {
        Stream<TestAccessMode> s = Stream.of(TestAccessMode.values());
        if (ats.length > 0) {
            // Each mode has exactly one access type, so the requested types
            // must be combined as a union. Chaining one filter per type would
            // intersect them and select nothing for two or more types.
            s = s.filter(e -> Stream.of(ats).anyMatch(e::isOfType));
        }
        return s.collect(toList());
    }

    /** Returns all {@link VarHandle.AccessMode} values, in declaration order. */
    static List<VarHandle.AccessMode> accessModes() {
        return Stream.of(VarHandle.AccessMode.values()).collect(toList());
    }

    /**
     * Returns the {@link VarHandle.AccessMode}s corresponding to the test
     * access modes whose access type is any of {@code ats}. If no types are
     * given, every tested mode is returned.
     *
     * @param ats the access types to select
     * @return the matching access modes
     */
    static List<VarHandle.AccessMode> accessModesOfType(TestAccessType... ats) {
        return testAccessModesOfType(ats).stream()
                .map(TestAccessMode::toAccessMode)
                .collect(toList());
    }

    /**
     * Obtains a method handle through {@link VarHandle#toMethodHandle}.
     * {@code mt} is ignored; the handle is already bound to {@code vh}.
     */
    static MethodHandle toMethodHandle(VarHandle vh, TestAccessMode tam, MethodType mt) {
        return vh.toMethodHandle(tam.toAccessMode());
    }

    /**
     * Obtains a method handle by looking up the signature-polymorphic
     * {@code VarHandle} method with the access mode type of {@code vh}, then
     * binding it to {@code vh}. The supplied {@code mt} is replaced by that
     * access mode type.
     *
     * @throws RuntimeException wrapping any lookup failure
     */
    static MethodHandle findVirtual(VarHandle vh, TestAccessMode tam, MethodType mt) {
        mt = vh.accessModeType(tam.toAccessMode());
        MethodHandle mh;
        try {
            mh = MethodHandles.publicLookup().
                    findVirtual(VarHandle.class,
                                tam.name(),
                                mt);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        return bind(vh, tam, mh, mt);
    }

    /**
     * Obtains a {@link MethodHandles#varHandleInvoker} for the access mode
     * type of {@code vh} and binds it to {@code vh}. The supplied {@code mt}
     * is replaced by that access mode type.
     */
    static MethodHandle varHandleInvokerWithAccessModeType(VarHandle vh, TestAccessMode tam, MethodType mt) {
        mt = vh.accessModeType(tam.toAccessMode());
        MethodHandle mh = MethodHandles.varHandleInvoker(
                tam.toAccessMode(),
                mt);

        return bind(vh, tam, mh, mt);
    }

    /**
     * Obtains a {@link MethodHandles#varHandleInvoker} for the caller-supplied
     * symbolic type {@code mt} and binds it to {@code vh}.
     */
    static MethodHandle varHandleInvokerWithSymbolicTypeDescriptor(VarHandle vh, TestAccessMode tam, MethodType mt) {
        MethodHandle mh = MethodHandles.varHandleInvoker(
                tam.toAccessMode(),
                mt);

        return bind(vh, tam, mh, mt);
    }

    /**
     * Obtains a {@link MethodHandles#varHandleExactInvoker} for the access
     * mode type of {@code vh} and binds it to {@code vh}. The supplied
     * {@code mt} is replaced by that access mode type.
     */
    static MethodHandle varHandleExactInvokerWithAccessModeType(VarHandle vh, TestAccessMode tam, MethodType mt) {
        mt = vh.accessModeType(tam.toAccessMode());
        MethodHandle mh = MethodHandles.varHandleExactInvoker(
                tam.toAccessMode(),
                mt);

        return bind(vh, tam, mh, mt);
    }

    /**
     * Checks that {@code mh} has type {@code (VarHandle, emt...)}, and that
     * its revealed direct-method type equals {@code emt}. Then binds
     * {@code mh} to {@code vh}.
     *
     * @return {@code mh} with its leading {@code VarHandle} argument bound to
     *         {@code vh}
     */
    private static MethodHandle bind(VarHandle vh, TestAccessMode testAccessMode, MethodHandle mh, MethodType emt) {
        assertEquals(mh.type(), emt.insertParameterTypes(0, VarHandle.class),
                     "MethodHandle type differs from access mode type");

        MethodHandleInfo info = MethodHandles.lookup().revealDirect(mh);
        assertEquals(info.getMethodType(), emt,
                     "MethodHandleInfo method type differs from access mode type");

        return mh.bindTo(vh);
    }

    /** A three-argument function. */
    private interface TriFunction<T, U, V, R> {
        R apply(T t, U u, V v);
    }

    /**
     * Strategies for turning a {@code VarHandle} access mode into a
     * {@code MethodHandle} bound to that {@code VarHandle}.
     */
    enum VarHandleToMethodHandle {
        VAR_HANDLE_TO_METHOD_HANDLE(
                "VarHandle.toMethodHandle",
                VarHandleBaseTest::toMethodHandle),
        METHOD_HANDLES_LOOKUP_FIND_VIRTUAL(
                "Lookup.findVirtual",
                VarHandleBaseTest::findVirtual),
        METHOD_HANDLES_VAR_HANDLE_INVOKER_WITH_ACCESS_MODE_TYPE(
                "MethodHandles.varHandleInvoker(accessModeType)",
                VarHandleBaseTest::varHandleInvokerWithAccessModeType),
        METHOD_HANDLES_VAR_HANDLE_INVOKER_WITH_SYMBOLIC_TYPE_DESCRIPTOR(
                "MethodHandles.varHandleInvoker(symbolicTypeDescriptor)",
                VarHandleBaseTest::varHandleInvokerWithSymbolicTypeDescriptor),
        METHOD_HANDLES_VAR_HANDLE_EXACT_INVOKER_WITH_ACCESS_MODE_TYPE(
                "MethodHandles.varHandleExactInvoker(accessModeType)",
                VarHandleBaseTest::varHandleExactInvokerWithAccessModeType);

        /** Human-readable description, also returned by {@link #toString()}. */
        final String desc;
        /** The function producing the bound method handle. */
        final TriFunction<VarHandle, TestAccessMode, MethodType, MethodHandle> f;
        /**
         * Flag recorded per strategy; every constant above uses the
         * two-argument constructor, so it is {@code false} for all of them.
         */
        final boolean exact;

        VarHandleToMethodHandle(String desc, TriFunction<VarHandle, TestAccessMode, MethodType, MethodHandle> f) {
            this(desc, f, false);
        }

        VarHandleToMethodHandle(String desc, TriFunction<VarHandle, TestAccessMode, MethodType, MethodHandle> f,
                                boolean exact) {
            this.desc = desc;
            this.f = f;
            this.exact = exact;
        }

        /** Produces a method handle for access mode {@code am} of {@code vh}. */
        MethodHandle apply(VarHandle vh, TestAccessMode am, MethodType mt) {
            return f.apply(vh, am, mt);
        }

        @Override
        public String toString() {
            return desc;
        }
    }

    /**
     * Lazily creates and caches method handles for one {@code VarHandle},
     * produced with one {@link VarHandleToMethodHandle} strategy. Handles are
     * keyed by access mode and method type. The cache is a plain
     * {@link HashMap} with no synchronization.
     */
    static class Handles {
        /** Cache key: an access mode paired with a method type. */
        static class AccessModeAndType {
            final TestAccessMode tam;
            final MethodType t;

            public AccessModeAndType(TestAccessMode tam, MethodType t) {
                this.tam = tam;
                this.t = t;
            }

            @Override
            public boolean equals(Object o) {
                if (this == o) return true;
                if (o == null || getClass() != o.getClass()) return false;

                AccessModeAndType x = (AccessModeAndType) o;

                if (tam != x.tam) return false;
                if (t != null ? !t.equals(x.t) : x.t != null) return false;

                return true;
            }

            @Override
            public int hashCode() {
                int result = tam != null ? tam.hashCode() : 0;
                result = 31 * result + (t != null ? t.hashCode() : 0);
                return result;
            }
        }

        final VarHandle vh;
        final VarHandleToMethodHandle f;
        /** The access mode type of {@code vh} for every tested access mode. */
        final EnumMap<TestAccessMode, MethodType> amToType;
        final Map<AccessModeAndType, MethodHandle> amToHandle;

        Handles(VarHandle vh, VarHandleToMethodHandle f) throws Exception {
            this.vh = vh;
            this.f = f;
            this.amToHandle = new HashMap<>();

            amToType = new EnumMap<>(TestAccessMode.class);
            for (TestAccessMode am : testAccessModes()) {
                amToType.put(am, vh.accessModeType(am.toAccessMode()));
            }
        }

        /** Returns the cached handle for {@code am} at its access mode type. */
        MethodHandle get(TestAccessMode am) {
            return get(am, amToType.get(am));
        }

        /** Returns the cached handle for {@code am} at type {@code mt}, creating it if absent. */
        MethodHandle get(TestAccessMode am, MethodType mt) {
            AccessModeAndType amt = new AccessModeAndType(am, mt);
            return amToHandle.computeIfAbsent(
                    amt, k -> f.apply(vh, am, mt));
        }
    }

    /** A test action performed against a subject of type {@code T}. */
    interface AccessTestAction<T> {
        void action(T t) throws Throwable;
    }

    /**
     * A described test case that supplies a subject via {@link #get()} and
     * runs its {@link AccessTestAction} against it.
     */
    static abstract class AccessTestCase<T> {
        final String desc;
        final AccessTestAction<T> ata;
        final boolean loop;

        AccessTestCase(String desc, AccessTestAction<T> ata, boolean loop) {
            this.desc = desc;
            this.ata = ata;
            this.loop = loop;
        }

        /** Returns the {@code loop} flag given at construction. */
        boolean requiresLoop() {
            return loop;
        }

        /** Supplies the subject the action is run against. */
        abstract T get() throws Exception;

        /** Runs the action against {@code t}. */
        void testAccess(T t) throws Throwable {
            ata.action(t);
        }

        @Override
        public String toString() {
            return desc;
        }
    }

    /** A test case whose subject is a {@code VarHandle} used directly. */
    static class VarHandleAccessTestCase extends AccessTestCase<VarHandle> {
        final VarHandle vh;

        VarHandleAccessTestCase(String desc, VarHandle vh, AccessTestAction<VarHandle> ata) {
            this(desc, vh, ata, true);
        }

        VarHandleAccessTestCase(String desc, VarHandle vh, AccessTestAction<VarHandle> ata, boolean loop) {
            super("VarHandle -> " + desc, ata, loop);
            this.vh = vh;
        }

        @Override
        VarHandle get() {
            return vh;
        }
    }

    /**
     * A test case whose subject is a fresh {@link Handles} over a
     * {@code VarHandle}, so that accesses go through method handles.
     */
    static class MethodHandleAccessTestCase extends AccessTestCase<Handles> {
        final VarHandle vh;
        final VarHandleToMethodHandle f;

        MethodHandleAccessTestCase(String desc, VarHandle vh, VarHandleToMethodHandle f, AccessTestAction<Handles> ata) {
            this(desc, vh, f, ata, true);
        }

        MethodHandleAccessTestCase(String desc, VarHandle vh, VarHandleToMethodHandle f, AccessTestAction<Handles> ata, boolean loop) {
            super("VarHandle -> " + f.toString() + " -> " + desc, ata, loop);
            this.vh = vh;
            this.f = f;
        }

        @Override
        Handles get() throws Exception {
            return new Handles(vh, f);
        }
    }

    /**
     * Checks the access mode types of {@code vh} against its coordinate types
     * and variable type:
     * <ul>
     * <li>every mode's parameters start with the coordinate types;</li>
     * <li>get modes return the variable type and take only coordinates;</li>
     * <li>set modes return void and end with a variable-type value;</li>
     * <li>compare-and-set modes return boolean and end with two
     *     variable-type values;</li>
     * <li>compare-and-exchange modes return the variable type and end with
     *     two variable-type values;</li>
     * <li>get-and-set and get-and-add modes return the variable type and end
     *     with a variable-type value.</li>
     * </ul>
     */
    static void testTypes(VarHandle vh) {
        List<Class<?>> pts = vh.coordinateTypes();

        for (TestAccessMode accessMode : testAccessModes()) {
            MethodType amt = vh.accessModeType(accessMode.toAccessMode());

            assertEquals(amt.parameterList().subList(0, pts.size()), pts);
        }

        for (TestAccessMode testAccessMode : testAccessModesOfType(TestAccessType.get)) {
            MethodType mt = vh.accessModeType(testAccessMode.toAccessMode());
            assertEquals(mt.returnType(), vh.varType());
            assertEquals(mt.parameterList(), pts);
        }

        for (TestAccessMode testAccessMode : testAccessModesOfType(TestAccessType.set)) {
            MethodType mt = vh.accessModeType(testAccessMode.toAccessMode());
            assertEquals(mt.returnType(), void.class);
            assertEquals(mt.parameterType(mt.parameterCount() - 1), vh.varType());
        }

        for (TestAccessMode testAccessMode : testAccessModesOfType(TestAccessType.compareAndSet)) {
            MethodType mt = vh.accessModeType(testAccessMode.toAccessMode());
            assertEquals(mt.returnType(), boolean.class);
            assertEquals(mt.parameterType(mt.parameterCount() - 1), vh.varType());
            assertEquals(mt.parameterType(mt.parameterCount() - 2), vh.varType());
        }

        for (TestAccessMode testAccessMode : testAccessModesOfType(TestAccessType.compareAndExchange)) {
            MethodType mt = vh.accessModeType(testAccessMode.toAccessMode());
            assertEquals(mt.returnType(), vh.varType());
            assertEquals(mt.parameterType(mt.parameterCount() - 1), vh.varType());
            assertEquals(mt.parameterType(mt.parameterCount() - 2), vh.varType());
        }

        for (TestAccessMode testAccessMode : testAccessModesOfType(TestAccessType.getAndSet, TestAccessType.getAndAdd)) {
            MethodType mt = vh.accessModeType(testAccessMode.toAccessMode());
            assertEquals(mt.returnType(), vh.varType());
            assertEquals(mt.parameterType(mt.parameterCount() - 1), vh.varType());
        }
    }
}