/*
 * Decompiled with CFR 0.152.
 */
package mb.statix.spec;

import io.usethesource.capsule.Set;
import io.usethesource.capsule.SetMultimap;
import jakarta.annotation.Nullable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import mb.nabl2.terms.ITerm;
import mb.nabl2.terms.ITermVar;
import mb.nabl2.terms.build.TermBuild;
import mb.nabl2.terms.matching.Pattern;
import mb.nabl2.terms.matching.TermPattern;
import mb.nabl2.terms.substitution.FreshVars;
import mb.nabl2.terms.substitution.IRenaming;
import mb.nabl2.terms.substitution.ISubstitution;
import mb.nabl2.terms.substitution.PersistentSubstitution;
import mb.nabl2.terms.unification.OccursException;
import mb.nabl2.terms.unification.u.IUnifier;
import mb.nabl2.terms.unification.ud.Diseq;
import mb.nabl2.terms.unification.ud.IUniDisunifier;
import mb.nabl2.terms.unification.ud.PersistentUniDisunifier;
import mb.statix.constraints.CConj;
import mb.statix.constraints.CEqual;
import mb.statix.constraints.CExists;
import mb.statix.constraints.CUser;
import mb.statix.constraints.Constraints;
import mb.statix.solver.IConstraint;
import mb.statix.solver.StateUtil;
import mb.statix.spec.ApplyMode;
import mb.statix.spec.ApplyResult;
import mb.statix.spec.PreSolvedConstraint;
import mb.statix.spec.Rule;
import mb.statix.spec.RuleName;
import mb.statix.spec.RuleSet;
import org.metaborg.util.collection.CapsuleUtil;
import org.metaborg.util.collection.ImList;
import org.metaborg.util.collection.MultiSet;
import org.metaborg.util.functions.Action1;
import org.metaborg.util.functions.PartialFunction1;
import org.metaborg.util.functions.Predicate1;
import org.metaborg.util.functions.Predicate2;
import org.metaborg.util.tuple.Tuple2;
import org.metaborg.util.tuple.Tuple3;

/**
 * Static helpers for applying, inlining, normalising and expanding Statix {@link Rule}s.
 */
public final class RuleUtil {
    private RuleUtil() {
    }

    /**
     * Applies {@code rules} in order to {@code args} and returns the first rule that applies.
     *
     * <p>After a rule applies conditionally, its guard is added to the unifier as a disequality, so later rules are
     * only tried on arguments the earlier rule does not match. The search stops at the second applicable rule, or
     * after a rule applies without a guard.
     *
     * @return empty if no rule applies; otherwise the first applicable rule, its {@link ApplyResult}, and a flag that
     *         is {@code true} iff no second applicable rule was found before the search stopped
     * @throws E if thrown by {@code mode} while applying a rule
     */
    public static <E extends Throwable> Optional<Tuple3<Rule, ApplyResult, Boolean>> applyOrderedOne(IUniDisunifier.Immutable state, ImList.Immutable<Rule> rules, List<? extends ITerm> args, @Nullable IConstraint cause, ApplyMode<E> mode, ApplyMode.Safety safety, boolean trackOrigins) throws E {
        final List<Tuple2<Rule, ApplyResult>> results = applyOrdered(state, rules, args, cause, mode, safety, true, trackOrigins);
        if (results.isEmpty()) {
            return Optional.empty();
        }
        final Tuple2<Rule, ApplyResult> first = results.get(0);
        return Optional.of(Tuple3.of(first._1(), first._2(), results.size() == 1));
    }

    /**
     * Applies {@code rules} in order to {@code args} and returns every rule that applies.
     *
     * <p>As in {@link #applyOrderedOne}, the guard of each conditionally applied rule is added as a disequality for
     * the rules after it. The search stops after the first rule that applies without a guard.
     *
     * @return the applicable rules with their results, in rule order
     * @throws E if thrown by {@code mode} while applying a rule
     */
    public static <E extends Throwable> List<Tuple2<Rule, ApplyResult>> applyOrderedAll(IUniDisunifier.Immutable state, ImList.Immutable<Rule> rules, List<? extends ITerm> args, @Nullable IConstraint cause, ApplyMode<E> mode, ApplyMode.Safety safety, boolean trackOrigins) throws E {
        return applyOrdered(state, rules, args, cause, mode, safety, false, trackOrigins);
    }

    /**
     * Shared implementation of {@link #applyOrderedOne} and {@link #applyOrderedAll}.
     *
     * @param onlyOne if {@code true}, stop as soon as a second applicable rule has been found
     * @throws E if thrown by {@code mode} while applying a rule
     * @throws IllegalStateException if the guard of an applied rule cannot be added to the unifier
     */
    private static <E extends Throwable> List<Tuple2<Rule, ApplyResult>> applyOrdered(IUniDisunifier.Immutable unifier, ImList.Immutable<Rule> rules, List<? extends ITerm> args, @Nullable IConstraint cause, ApplyMode<E> mode, ApplyMode.Safety safety, boolean onlyOne, boolean trackOrigins) throws E {
        final ImList.Mutable<Tuple2<Rule, ApplyResult>> results = ImList.Mutable.of();
        boolean foundOne = false;
        for (Rule rule : rules) {
            final ApplyResult applyResult = apply(unifier, rule, args, cause, mode, safety, trackOrigins).orElse(null);
            if (applyResult == null) {
                // This rule does not apply; try the next one.
                continue;
            }
            results.add(Tuple2.of(rule, applyResult));
            if (onlyOne) {
                if (foundOne) {
                    // A second applicable rule exists, so the match is not unique; no need to look further.
                    break;
                }
                foundOne = true;
            }
            final Optional<Diseq> guard = applyResult.guard();
            if (!guard.isPresent()) {
                // The rule applied unconditionally, so no later rule can be reached.
                break;
            }
            // Later rules must only apply to arguments that do not satisfy this rule's guard.
            final IUniDisunifier.Immutable current = unifier;
            unifier = guard.map(Diseq::toTuple)
                    .map(g -> current.disunify(g._1(), g._2(), g._3()).map(r -> r.unifier()).orElse(null))
                    .orElseThrow(() -> new IllegalStateException("Unexpected incompatible guard."));
        }
        return results.freeze();
    }

    /**
     * Applies a single rule without tracking origins.
     *
     * <p>Equivalent to {@code apply(unifier, rule, args, cause, mode, safety, false)}.
     *
     * @throws E if thrown by {@code mode}
     */
    public static <E extends Throwable> Optional<ApplyResult> apply(IUniDisunifier.Immutable unifier, Rule rule, List<? extends ITerm> args, @Nullable IConstraint cause, ApplyMode<E> mode, ApplyMode.Safety safety) throws E {
        return apply(unifier, rule, args, cause, mode, safety, false);
    }

    /**
     * Applies a single rule to {@code args} by delegating to {@code mode}.
     *
     * @return the result of the application, or empty if the rule does not apply
     * @throws E if thrown by {@code mode}
     */
    public static <E extends Throwable> Optional<ApplyResult> apply(IUniDisunifier.Immutable unifier, Rule rule, List<? extends ITerm> args, @Nullable IConstraint cause, ApplyMode<E> mode, ApplyMode.Safety safety, boolean trackOrigins) throws E {
        return mode.apply(unifier, rule, args, cause, safety, trackOrigins);
    }

    /**
     * Applies every rule independently to {@code args}.
     *
     * <p>Unlike the ordered variants, no guards are propagated between rules.
     *
     * @return the applicable rules with their results, in the iteration order of {@code rules}
     * @throws E if thrown by {@code mode}
     */
    public static <E extends Throwable> List<Tuple2<Rule, ApplyResult>> applyAll(IUniDisunifier.Immutable state, Collection<Rule> rules, List<? extends ITerm> args, @Nullable IConstraint cause, ApplyMode<E> mode, ApplyMode.Safety safety, boolean trackOrigins) throws E {
        final ImList.Mutable<Tuple2<Rule, ApplyResult>> results = ImList.Mutable.of();
        for (Rule rule : rules) {
            apply(state, rule, args, cause, mode, safety, trackOrigins).ifPresent(result -> results.add(Tuple2.of(rule, result)));
        }
        return results.freeze();
    }

    /**
     * Rewrites an ordered list of rules into a set of rules whose applicability no longer depends on their order.
     *
     * <p>Each rule's head patterns are instantiated as terms, with fresh variables for wildcards. Its body is then
     * extended with the pattern equalities and with disequalities that exclude arguments matched by earlier retained
     * rules. Two kinds of rule are dropped:
     * <ul>
     * <li>rules whose pattern equalities are not unifiable;</li>
     * <li>rules for which such a disequality cannot be added (presumably because an earlier rule matches every
     * argument they match).</li>
     * </ul>
     *
     * @param rules the rules, in priority order
     * @return the rewritten rules
     */
    public static Set.Immutable<Rule> computeOrderIndependentRules(ImList.Immutable<Rule> rules) {
        final Set.Transient<Rule> newRules = CapsuleUtil.transientSet();
        // Per retained rule: its parameter variables, its parameter tuple and the unifier of its pattern equalities.
        final List<Tuple3<Set.Immutable<ITermVar>, ITerm, IUniDisunifier.Immutable>> guards = new ArrayList<>();
        for (Rule rule : rules) {
            // 1. Instantiate the head patterns as terms.
            final Set.Immutable<ITermVar> ruleParamVars = rule.paramVars();
            final FreshVars fresh = new FreshVars(rule.freeVars(), ruleParamVars);
            final List<ITerm> paramTerms = new ArrayList<>();
            final IUniDisunifier.Transient transientParamsUnifier = PersistentUniDisunifier.Immutable.of().melt();
            if (!instantiateParams(rule, fresh, paramTerms, transientParamsUnifier)) {
                // The pattern equalities are inconsistent, so the rule can never apply.
                continue;
            }
            final Set.Immutable<ITermVar> paramVars = fresh.fix().__insertAll(ruleParamVars);
            final ITerm paramsTerm = TermBuild.B.newTuple(paramTerms);
            final IUniDisunifier.Immutable paramsUnifier = transientParamsUnifier.freeze();

            // 2. Exclude the arguments matched by earlier rules.
            final IUniDisunifier.Transient transientUnifier = paramsUnifier.melt();
            if (!excludeEarlierMatches(fresh, paramsTerm, guards, transientUnifier)) {
                continue;
            }
            final IUniDisunifier.Immutable unifier = transientUnifier.freeze();

            // 3. Record this rule's guard for the rules that follow.
            final Tuple3<Set.Immutable<ITermVar>, ITerm, IUniDisunifier.Immutable> guard = Tuple3.of(paramVars, paramsTerm, paramsUnifier);
            guards.add(guard);

            // 4. Build the rewritten rule: plain parameters, with the constraints moved into the body.
            final ImList.Immutable<Pattern> params = paramTerms.stream().map(TermPattern.P::fromTerm).collect(ImList.Immutable.toImmutableList());
            final Set.Immutable<ITermVar> newBodyVars = paramVars.__removeAll(paramsTerm.getVars());
            final IConstraint body = Constraints.exists(newBodyVars, Constraints.conjoin(StateUtil.asConstraint(unifier), rule.body()));
            final Rule newRule = Rule.builder().from(rule).params(params).body(body).bodyCriticalEdges(rule.bodyCriticalEdges()).build();
            newRules.__insert(newRule);
        }
        return newRules.freeze();
    }

    /**
     * Converts each head pattern of {@code rule} into a term and collects the pattern equalities.
     *
     * <p>Wildcards are replaced by fresh variables from {@code fresh}. Each resulting term is appended to
     * {@code paramTerms}, and the pattern's equalities are unified into {@code paramsUnifier}.
     *
     * @return {@code false} as soon as some equalities fail to unify or fail the occurs check; {@code true} otherwise
     */
    private static boolean instantiateParams(Rule rule, FreshVars fresh, List<ITerm> paramTerms, IUniDisunifier.Transient paramsUnifier) {
        for (Pattern param : rule.params()) {
            final Tuple2<ITerm, List<Tuple2<ITermVar, ITerm>>> paramTerm = param.asTerm(v -> v.orElseGet(() -> fresh.fresh("_")));
            paramTerms.add(paramTerm._1());
            try {
                if (!paramsUnifier.unify(paramTerm._2()).isPresent()) {
                    return false;
                }
            } catch (OccursException ignored) {
                // Cyclic equalities have no finite solution, so treat them like a unification failure.
                return false;
            }
        }
        return true;
    }

    /**
     * Extends {@code unifier} so that the current rule no longer matches arguments matched by earlier rules.
     *
     * <p>A disequality is added for every earlier guard whose parameter tuple can be unified with
     * {@code paramsTerm}.
     *
     * @param fresh generator used to rename each guard's variables apart from the current rule's
     * @param paramsTerm tuple of the current rule's instantiated parameters
     * @param guards guards of the earlier retained rules
     * @param unifier unifier of the current rule; it is extended in place
     * @return {@code false} if some disequality cannot be added, meaning the current rule must be dropped;
     *         {@code true} otherwise
     */
    private static boolean excludeEarlierMatches(FreshVars fresh, ITerm paramsTerm, List<Tuple3<Set.Immutable<ITermVar>, ITerm, IUniDisunifier.Immutable>> guards, IUniDisunifier.Transient unifier) {
        for (Tuple3<Set.Immutable<ITermVar>, ITerm, IUniDisunifier.Immutable> guard : guards) {
            final IRenaming guardRen = fresh.fresh(guard._1());
            final Set.Immutable<ITermVar> guardVars = fresh.reset();
            final ITerm guardTerm = guardRen.apply(guard._2());
            final IUnifier.Immutable renamedGuardUnifier = guard._3().rename(guardRen);
            final IUnifier.Immutable overlap;
            try {
                overlap = renamedGuardUnifier.unify(paramsTerm, guardTerm).map(r -> r.unifier()).orElse(null);
            } catch (OccursException ignored) {
                // No finite unifier exists, so the two rules never match the same arguments.
                continue;
            }
            if (overlap == null) {
                // The parameter tuples do not unify, so the two rules never match the same arguments.
                continue;
            }
            if (!unifier.disunify(guardVars, overlap).isPresent()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Inlines {@code rule} at one call site in the body of {@code into}.
     *
     * <p>Calls whose name equals {@code rule.name()} are counted in the traversal order of {@link Constraints#map}.
     * The {@code ith} such call (0-based) is replaced by {@code rule}'s body. That body is renamed apart and
     * existentially wrapped, together with equalities between the call arguments and the rule's head patterns.
     *
     * @return the rewritten rule, or empty if the body contains at most {@code ith} matching calls
     */
    public static Optional<Rule> inline(Rule rule, int ith, Rule into) {
        final FreshVars fresh = new FreshVars(vars(into));
        final AtomicInteger occurrence = new AtomicInteger(0);
        final IConstraint newBody = Constraints.map(c -> {
            if (c instanceof CUser) {
                final CUser call = (CUser) c;
                // The counter advances for every matching call, so its final value is the number of matching calls.
                if (call.name().equals(rule.name()) && occurrence.getAndIncrement() == ith) {
                    return applyToConstraint(fresh, rule, call.args());
                }
            }
            return c;
        }, false).apply(into.body());
        if (occurrence.get() <= ith) {
            return Optional.empty();
        }
        return Optional.of(into.withBody(newBody));
    }

    /**
     * Instantiates {@code rule} for a call with arguments {@code args}.
     *
     * <p>The rule's parameter variables, and the fresh variables introduced for wildcards, are renamed apart using
     * {@code fresh}. The result has the form {@code exists newVars. (patternEqs /\ args == params /\ body)}.
     */
    private static IConstraint applyToConstraint(FreshVars fresh, Rule rule, List<? extends ITerm> args) {
        final IRenaming swap = fresh.fresh(rule.paramVars());
        final Pattern rulePatterns = TermPattern.P.newTuple(rule.params()).eliminateWld(() -> fresh.fresh("_"));
        // Wildcards were eliminated above, so every pattern variable is present.
        final Tuple2<ITerm, List<Tuple2<ITermVar, ITerm>>> patternTerm = rulePatterns.asTerm(v -> v.get());
        final ITerm params = swap.apply(patternTerm._1());
        final ITerm call = TermBuild.B.newTuple(args);
        final List<IConstraint> patternEqs = patternTerm._2().stream()
                .<IConstraint>map(e -> new CEqual(e._1(), swap.apply(e._2())))
                .collect(Collectors.toList());
        final IConstraint eq = Constraints.conjoin(patternEqs, new CEqual(call, params));
        final Set.Immutable<ITermVar> newVars = fresh.reset();
        return Constraints.exists(newVars, new CConj(eq, rule.body().apply(swap)));
    }

    /**
     * Normalises the body of {@code rule} with {@code PreSolvedConstraint.of(body).cleanup()}.
     *
     * <p>Presumably this hoists existentials and pre-solvable constraints; see {@link PreSolvedConstraint}.
     */
    public static Rule hoist(Rule rule) {
        final PreSolvedConstraint preSolvedBody = PreSolvedConstraint.of(rule.body()).cleanup();
        return rule.withBody(preSolvedBody.toConstraint());
    }

    /**
     * Pushes the equalities implied by the head patterns of {@code rule} into its parameters and body.
     *
     * <p>The steps are:
     * <ol>
     * <li>Instantiate the head patterns as terms and unify their equalities.</li>
     * <li>Intern the body with that unifier.</li>
     * <li>Extern the parameter variables, then clean up the body.</li>
     * <li>Rebuild the parameters from the resulting substitution.</li>
     * </ol>
     * A parameter variable that is not free in the new body and occurs at most once in the new parameters is passed
     * to {@code fromTerm}'s predicate as {@code true}; presumably such variables become wildcards.
     *
     * @return the rewritten rule, or {@code rule} itself if its pattern equalities are not unifiable
     */
    public static Rule instantiateHeadPatterns(Rule rule) {
        final Set.Immutable<ITermVar> paramVars = rule.paramVars();
        final FreshVars fresh = new FreshVars(rule.freeVars(), paramVars);
        final List<ITerm> paramTerms = new ArrayList<>();
        final IUniDisunifier.Transient transientParamsUnifier = PersistentUniDisunifier.Immutable.of().melt();
        if (!instantiateParams(rule, fresh, paramTerms, transientParamsUnifier)) {
            return rule;
        }
        fresh.fix();
        final IUniDisunifier.Immutable paramsUnifier = transientParamsUnifier.freeze();

        final PreSolvedConstraint internedBody = PreSolvedConstraint.of(rule.body()).intern(CapsuleUtil.immutableSet(), paramsUnifier);
        final Tuple2<ISubstitution.Immutable, PreSolvedConstraint> externResult = internedBody.extern(paramVars);
        final PreSolvedConstraint finalBody = externResult._2().cleanup();

        final List<ITerm> newParamTerms = new ArrayList<>();
        final MultiSet.Transient<ITermVar> newParamVars = MultiSet.Transient.of();
        for (ITerm paramTerm : paramTerms) {
            final ITerm newParamTerm = externResult._1().apply(paramTerm);
            newParamTerms.add(newParamTerm);
            newParamTerm.visitVars(newParamVars::add);
        }
        final ImList.Immutable<Pattern> params = newParamTerms.stream()
                .map(t -> TermPattern.P.fromTerm(t, v -> !finalBody.freeVars().contains(v) && newParamVars.count(v) <= 1))
                .collect(ImList.Immutable.toImmutableList());
        return Rule.builder().from(rule).params(params).body(finalBody.toConstraint()).build();
    }

    /**
     * Substitutes every free variable of {@code rule} with its value in {@code unifier}, then hoists the result.
     *
     * <p>Each value is obtained with {@code findRecursive}. {@code unsafeApply} is used when {@code safety} is
     * {@code UNSAFE}, and {@code apply} otherwise.
     */
    public static Rule closeInUnifier(Rule rule, IUnifier.Immutable unifier, ApplyMode.Safety safety) {
        ISubstitution.Immutable subst = PersistentSubstitution.Immutable.of();
        for (ITermVar var : rule.freeVars()) {
            subst = subst.put(var, unifier.findRecursive(var));
        }
        final Rule newRule = safety.equals(ApplyMode.Safety.UNSAFE) ? rule.unsafeApply(subst) : rule.apply(subst);
        return hoist(newRule);
    }

    /**
     * Builds rule fragments for the predicates selected by {@code includePredicate}.
     *
     * <p>The order-independent rules of each included predicate are filtered with {@code includeRule}. Rules whose
     * bodies contain no expandable calls become the initial fragments. Then, for {@code generations} rounds, the
     * expandable calls in each remaining rule are replaced by the fragments of the called predicate. This yields one
     * rule per alternative produced by {@link Constraints#flatMap}. Those rules are unlabelled, hoisted and added to
     * the fragments.
     *
     * @return the fragments, keyed by predicate name
     */
    public static SetMultimap.Immutable<String, Rule> makeFragments(RuleSet rules, Predicate1<String> includePredicate, Predicate2<String, RuleName> includeRule, int generations) {
        // 1. Collect the selected order-independent rules.
        final SetMultimap.Transient<String, Rule> newRules = SetMultimap.Transient.of();
        for (String ruleName : rules.getRuleNames()) {
            if (!includePredicate.test(ruleName)) {
                continue;
            }
            for (Rule r : rules.getOrderIndependentRules(ruleName)) {
                if (includeRule.test(r.name(), r.label())) {
                    newRules.__insert(ruleName, r);
                }
            }
        }

        // NOTE(review): `expandable` reads `newRules` live. After step 2 removes the fragment rules, a predicate whose
        // rules are all fragments is presumably no longer a key of `newRules`, so calls to it are not expanded in
        // step 3. Confirm this is intended.
        final PartialFunction1<IConstraint, CUser> expandable = c -> {
            if (c instanceof CUser && newRules.containsKey(((CUser) c).name())) {
                return Optional.of((CUser) c);
            }
            return Optional.empty();
        };

        // 2. Rules without expandable calls are fragments already.
        final SetMultimap.Transient<String, Rule> fragments = SetMultimap.Transient.of();
        for (Map.Entry<String, Rule> entry : newRules.entrySet()) {
            final Rule r = entry.getValue();
            if (Constraints.collectBase(expandable, false).apply(r.body()).isEmpty()) {
                fragments.__insert(entry.getKey(), r);
            }
        }
        fragments.entrySet().forEach(entry -> newRules.__remove(entry.getKey(), entry.getValue()));

        // 3. Expand the remaining rules with the fragments found so far, one generation at a time.
        for (int g = 0; g < generations; g++) {
            final SetMultimap.Transient<String, Rule> generation = SetMultimap.Transient.of();
            for (Map.Entry<String, Rule> entry : newRules.entrySet()) {
                final String name = entry.getKey();
                final Rule r = entry.getValue();
                final FreshVars fresh = new FreshVars(vars(r));
                final List<IConstraint> bodies = Constraints.flatMap(c -> {
                    final Optional<CUser> call = expandable.apply(c);
                    if (!call.isPresent()) {
                        return Stream.of(c);
                    }
                    final CUser user = call.get();
                    return fragments.get(user.name()).stream().map(f -> applyToConstraint(fresh, f, user.args()));
                }, false).apply(r.body()).collect(Collectors.toList());
                for (IConstraint body : bodies) {
                    final Rule fragment = r.withLabel(RuleName.empty()).withBody(new CExists(CapsuleUtil.<ITermVar>immutableSet(), body));
                    generation.__insert(name, hoist(fragment));
                }
            }
            CapsuleUtil.putAll(fragments, generation);
        }
        return fragments.freeze();
    }

    /**
     * Collects the parameter variables of {@code rule} and the variables reported by {@link Constraints#vars} for
     * its body.
     */
    public static Set.Immutable<ITermVar> vars(Rule rule) {
        final Set.Transient<ITermVar> result = CapsuleUtil.transientSet();
        vars(rule, result::__insert);
        return result.freeze();
    }

    /**
     * Calls {@code onVar} for every parameter variable of {@code rule}, then for every body variable reported by
     * {@link Constraints#vars}.
     */
    public static void vars(Rule rule, Action1<ITermVar> onVar) {
        rule.paramVars().forEach(onVar::apply);
        Constraints.vars(rule.body(), onVar);
    }
}

