TaylorSubstituter.java

package edu.udel.cis.vsl.sarl.reason.common;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

import edu.udel.cis.vsl.sarl.IF.Reasoner;
import edu.udel.cis.vsl.sarl.IF.UnaryOperator;
import edu.udel.cis.vsl.sarl.IF.expr.BooleanExpression;
import edu.udel.cis.vsl.sarl.IF.expr.NumericExpression;
import edu.udel.cis.vsl.sarl.IF.expr.NumericSymbolicConstant;
import edu.udel.cis.vsl.sarl.IF.expr.SymbolicConstant;
import edu.udel.cis.vsl.sarl.IF.expr.SymbolicExpression;
import edu.udel.cis.vsl.sarl.IF.expr.SymbolicExpression.SymbolicOperator;
import edu.udel.cis.vsl.sarl.IF.object.IntObject;
import edu.udel.cis.vsl.sarl.IF.object.SymbolicObject;
import edu.udel.cis.vsl.sarl.IF.type.SymbolicFunctionType;
import edu.udel.cis.vsl.sarl.IF.type.SymbolicType;
import edu.udel.cis.vsl.sarl.IF.type.SymbolicType.SymbolicTypeKind;
import edu.udel.cis.vsl.sarl.IF.type.SymbolicTypeSequence;
import edu.udel.cis.vsl.sarl.object.IF.ObjectFactory;
import edu.udel.cis.vsl.sarl.preuniverse.IF.PreUniverse;
import edu.udel.cis.vsl.sarl.preuniverse.common.ExpressionSubstituter;
import edu.udel.cis.vsl.sarl.type.IF.SymbolicTypeFactory;

/**
 * A substituter that replaces certain function calls with their truncated
 * Taylor polynomial expansions.
 * 
 * @author siegel
 */
public class TaylorSubstituter extends ExpressionSubstituter {

	/**
	 * Reasoner containing expanded context, which includes assumptions on index
	 * variables.
	 */
	private Reasoner reasoner;

	/**
	 * The variables (h0, h1, ...) tending to 0. These are essentially the bound
	 * variables of a big-O expression. The length of this array is n, where n
	 * is the dimension of the domain.
	 */
	private NumericSymbolicConstant[] limitVars;

	/**
	 * The orders of the corresponding bound variables. Also an array of length
	 * n, where n is the dimension of the domain and the length of
	 * {@link #limitVars}. These are nonnegative integers. If orders[i] is k,
	 * that means we are trying to prove that something is O(hi^k).
	 */
	private int[] orders;

	/**
	 * Not sure if this is really needed, but the super-class requires some
	 * state.
	 */
	private SubstituterState trivialState = new SubstituterState() {
		@Override
		public boolean isInitial() {
			return true;
		}
	};

	/**
	 * Creates new {@link TaylorSubstituter}.
	 * 
	 * @param universe
	 *            the symbolic universe to be used to create new
	 *            {@link SymbolicExpression}s
	 * @param objectFactory
	 *            the {@link SymbolicObjectFactory} used to create new
	 *            {@link SymbolicObject}s
	 * @param typeFactory
	 *            the {@link SymbolicTypeFactory} used to create new
	 *            {@link SymbolicType}s
	 * @param reasoner
	 *            the {@link Reasoner} with the expanded context that includes
	 *            assumptions on the integer bound variables that index the grid
	 *            points
	 * @param limitVars
	 *            the real variables that are tending to 0: h0, h1, ...; an
	 *            array of length n, where n is the dimension of the domain
	 * @param orders
	 *            the "orders" of the corresponding <code>limitVars</code>;
	 *            i.e., the nonnegative integers n0, n1, ..., where we are
	 *            trying to prove O(h0^n0)+O(h1^n1)+...
	 * @param lowerBounds
	 *            lower bounds of rectangular domain (length n)
	 * @param upperBounds
	 *            upper bounds of rectangular domain (length n)
	 */
	public TaylorSubstituter(PreUniverse universe, ObjectFactory objectFactory,
			SymbolicTypeFactory typeFactory, Reasoner reasoner,
			NumericSymbolicConstant[] limitVars, int[] orders) {
		super(universe, objectFactory, typeFactory);
		this.reasoner = reasoner;
		this.limitVars = limitVars;
		this.orders = orders;
	}

	@Override
	protected SubstituterState newState() {
		return trivialState;
	}

	/**
	 * Is the expression one of : h, C*h, or h*C where C is a concrete real?
	 * 
	 * @param expr
	 *            the expression you want to check to determine if it is a
	 *            constant multiple of <code>h</code>
	 * @param h
	 *            a symbolic constant of real type
	 * @return <code>true</code> iff <code>expr</code> is h, C*h, or h*C.
	 */
	private boolean isConstantMultiple(NumericExpression expr,
			NumericSymbolicConstant h) {
		if (expr.equals(h))
			return true;
		if (expr.operator() == SymbolicOperator.MULTIPLY) {
			int numArgs = expr.numArguments();

			if (numArgs == 2) {
				NumericExpression arg0 = (NumericExpression) expr.argument(0),
						arg1 = (NumericExpression) expr.argument(1);

				if (arg0.operator() == SymbolicOperator.CONCRETE
						&& arg1.equals(h))
					return true;
				if (arg1.operator() == SymbolicOperator.CONCRETE
						&& arg0.equals(h))
					return true;
			} else { // numArgs == 1
				@SuppressWarnings("unchecked")
				Iterable<? extends SymbolicExpression> args = (Iterable<? extends SymbolicExpression>) expr
						.argument(0);
				boolean foundH = false, foundC = false;

				for (SymbolicExpression arg : args) {
					if (arg.operator() == SymbolicOperator.CONCRETE && !foundC)
						foundC = true;
					else if (arg.equals(h) && !foundH)
						foundH = true;
					else
						return false;
				}
				if (foundH)
					return true;
			}
		}
		return false;
	}

	class ExpansionSpec {
		/**
		 * Index of the component in <code>point</code> (the function call
		 * argument list). Indexed from 0
		 */
		int argumentIndex;

		/**
		 * Index in {@link #limitVars}, specifying the limiting variable that
		 * occurs in the argument.
		 */
		int limitVarIndex;

		/**
		 * The actual summand in the argument which is a multiple of the
		 * limiting variable h.
		 */
		NumericExpression hTerm;

		/**
		 * The argument minus <code>hTerm</code>. Hence the argument is the sum
		 * of <code>remains</code> and <code>hTerm</code>.
		 */
		NumericExpression remains;
	}

	private ExpansionSpec findExpansionPoint(NumericExpression[] point,
			int maxDegree) {
		int n = point.length;
		ExpansionSpec result;

		for (int i = 0; i < n; i++) {// look for an index i that can be expanded
			NumericExpression arg = point[i];
			NumericExpression[] terms = universe.expand(arg);

			for (NumericExpression term : terms) {
				for (int j = 0; j < limitVars.length; j++) {
					if (orders[j] <= maxDegree
							&& isConstantMultiple(term, limitVars[j])) {
						result = new ExpansionSpec();
						result.argumentIndex = i;
						result.limitVarIndex = j;
						result.hTerm = term;
						result.remains = universe.subtract(arg, term);
						return result;
					}
				}
			}
			// } else if (op == SymbolicOperator.SUBTRACT) {
			// NumericExpression arg0 = (NumericExpression) arg.argument(0),
			// arg1 = (NumericExpression) arg.argument(1);
			//
			// for (int j = 0; j < limitVars.length; j++) {
			// if (orders[j] > maxDegree)
			// continue;
			//
			// NumericSymbolicConstant h = limitVars[j];
			//
			// if (isConstantMultiple(arg0, h)) {
			// result = new ExpansionSpec();
			// result.argumentIndex = i;
			// result.limitVarIndex = j;
			// result.hTerm = arg0;
			// result.remains = universe.minus(arg1);
			// return result;
			// } else if (isConstantMultiple(arg1, h)) {
			// result = new ExpansionSpec();
			// result.argumentIndex = i;
			// result.limitVarIndex = j;
			// result.hTerm = universe.minus(arg1);
			// result.remains = arg0;
			// return result;
			// }
			// }
			// }
		} // end loop over arguments
		return null;
	}

	/**
	 * Attempts to perform a Taylor expansion of the given function evaluated at
	 * a point in R^n. This method looks for an appropriate component to expand.
	 * If i-th argument is a sum in which a term is a constant multiple of one
	 * of the accuracy variables (limitVars), it satisfies the heuristic. The
	 * least such is chosen. The expansion is truncated according to the degree
	 * of that limit variable. {@link #orders}. If no such i is found, returns
	 * <code>null</code>.
	 * 
	 * @param function
	 *            a function which accepts n real inputs (for some positive
	 *            integer n) and returns real
	 * @param maxDegree
	 *            the maximum number of derivatives that can be taken of
	 *            <code>function</code>
	 * @param point
	 *            an array of length n consisting of the arguments to
	 *            <code>function</code>; this is the "point" in R^n at which
	 *            <code>function</code> is evaluated
	 * @return a truncated Taylor expansion or <code>null</code>
	 */
	private NumericExpression taylorExpansion(SymbolicExpression function,
			int maxDegree, NumericExpression[] point) {
		ExpansionSpec spec = findExpansionPoint(point, maxDegree);

		if (spec == null)
			return null;

		int order = orders[spec.limitVarIndex];
		NumericExpression result = universe.zeroReal();
		int n = point.length;

		if (order >= 1) {
			IntObject indexObj = universe.intObject(spec.argumentIndex);
			int j = 0;
			NumericExpression hPower = universe.oneReal(); // hTerm^j
			int jFactorial = 1; // j!
			List<NumericExpression> newArgs = new LinkedList<>();

			for (int i = 0; i < n; i++)
				newArgs.add(i == spec.argumentIndex ? spec.remains : point[i]);
			while (true) {
				SymbolicExpression deriv = j == 0 ? function
						: universe.derivative(function, indexObj,
								universe.intObject(j));
				NumericExpression derivApplication = (NumericExpression) universe
						.apply(deriv, newArgs);

				result = universe.add(result,
						universe.divide(
								universe.multiply(derivApplication, hPower),
								universe.rational(jFactorial)));
				j++;
				if (j == order)
					break;
				hPower = universe.multiply(hPower, spec.hTerm);
				jFactorial *= j;
			}
		}
		return result;
	}

	/**
	 * Attempts to find, in the context of the {@link #reasoner}, a clause which
	 * states the differentiability of the given <code>function</code>. This is
	 * a clause with operator {@link SymbolicOperator#DIFFERENTIABLE} and with
	 * the function argument (argument 0) equal to <code>function</code>.
	 * 
	 * @param function
	 *            the function for which a differentiability claim is sought
	 * @return a clause in the context dealing with the differentiability of
	 *         <code>function</code>, or <code>null</code> if no such clause is
	 *         found.
	 */
	private BooleanExpression findDifferentiableClaim(
			SymbolicExpression function) {
		for (BooleanExpression clause : reasoner.getReducedContext()
				.getClauses()) {
			if (clause.operator() != SymbolicOperator.DIFFERENTIABLE)
				continue;

			if (clause.argument(0).equals(function))
				return clause;
		}
		return null;
	}

	private NumericExpression[] toArray(SymbolicObject sequence, int length) {
		int count = 0;
		@SuppressWarnings("unchecked")
		Iterable<? extends NumericExpression> iterable = (Iterable<? extends NumericExpression>) sequence;
		NumericExpression[] result = new NumericExpression[length];

		for (NumericExpression x : iterable) {
			result[count] = x;
			count++;
		}
		return result;
	}

	/**
	 * Determines whether this is a function from R^n (for some n) to R.
	 * 
	 * @param function
	 *            a symbolic expression, non-<code>null</code>
	 * @return the number of inputs n, if <code>function</code> is a function
	 *         from R^n to R, else -1
	 */
	private int getNumRealFunctionInputs(SymbolicExpression function) {
		if (function.type().typeKind() != SymbolicTypeKind.FUNCTION)
			return -1;

		SymbolicFunctionType functionType = (SymbolicFunctionType) function
				.type();
		SymbolicTypeSequence inputTypes = functionType.inputTypes();
		int n = inputTypes.numTypes();

		if (!functionType.outputType().isReal())
			return -1;
		for (int i = 0; i < n; i++) {
			if (!inputTypes.getType(i).isReal())
				return -1;
		}
		return n;
	}

	/**
	 * Checks that the point is in the domain defined by the given bounds.
	 * 
	 * @param point
	 *            a point (x_i) in R^n
	 * @param lowerBounds
	 *            a point (a_i) in R^n
	 * @param upperBounds
	 *            a point (b_i) in R^n
	 * @return true if it can be proved a_i<x_i<b_i for all i.
	 */
	private boolean checkDomain(NumericExpression[] point,
			NumericExpression[] lowerBounds, NumericExpression[] upperBounds) {
		// need to substitute 0 for all limitVars in the argArray
		// before checking the arguments are in range. Otherwise, without
		// any restriction on limitVars, who knows.
		// Actually you should be taking limit as h->0, but for now...
		int n = point.length;
		Map<SymbolicExpression, SymbolicExpression> zeroMap = new HashMap<>();
		NumericExpression zero = universe.zeroReal();

		for (NumericSymbolicConstant limitVar : limitVars) {
			zeroMap.put(limitVar, zero);
		}

		UnaryOperator<SymbolicExpression> zeroSubber = universe
				.mapSubstituter(zeroMap);

		for (int i = 0; i < n; i++) {
			NumericExpression arg = (NumericExpression) zeroSubber
					.apply(point[i]);
			BooleanExpression inDomain = universe.and(
					universe.lessThan(lowerBounds[i], arg),
					universe.lessThan(arg, upperBounds[i]));

			if (!reasoner.isValid(inDomain))
				return false;
		}
		return true;
	}

	/**
	 * Given a function from R^n to R for some n, and a nonnegative integer d,
	 * and point x in R^n, determines if the application of a d-th derivative of
	 * f to x is bounded.
	 * 
	 * @param function
	 * @param degree
	 * @param point
	 * @return
	 */
	private boolean isBoundedApplicationOfDeriv(SymbolicExpression function,
			int degree, NumericExpression[] point) {
		BooleanExpression claim = this.findDifferentiableClaim(function);

		if (claim == null)
			return false;

		int degree1 = ((IntObject) claim.argument(1)).getInt();

		if (degree > degree1)
			return false;

		int n = point.length;
		NumericExpression[] lowerBounds = toArray(claim.argument(2), n);
		NumericExpression[] upperBounds = toArray(claim.argument(3), n);

		return checkDomain(point, lowerBounds, upperBounds);
	}

	/**
	 * Given a function from R^n to R, and a point in R^n, determines if this is
	 * a bounded application of the function.
	 * 
	 * @param function
	 *            function from R^n to R
	 * @param point
	 *            in R^n
	 * @return
	 */
	private boolean isBoundedApplication(SymbolicExpression function,
			NumericExpression[] point) {
		SymbolicOperator op = function.operator();

		if (op == SymbolicOperator.DERIV) {
			SymbolicExpression f0 = (SymbolicExpression) function.argument(0);
			int degree0 = ((IntObject) function.argument(2)).getInt();

			return isBoundedApplicationOfDeriv(f0, degree0, point);
		} else if (op == SymbolicOperator.SYMBOLIC_CONSTANT) {
			return isBoundedApplicationOfDeriv(function, 0, point);
		} else {
			return false;
		}
	}

	private boolean isBounded(NumericExpression expr) {
		switch (expr.operator()) {
		case ADD:
		case SUBTRACT:
		case MULTIPLY:
		case NEGATIVE: {
			int n = expr.numArguments();

			for (int i = 0; i < n; i++) {
				if (!isBounded((NumericExpression) expr.argument(i)))
					return false;
			}
			return true;
		}
		case DIVIDE: {
			if (!isBounded((NumericExpression) expr.argument(0)))
				return false;
			return ((NumericExpression) expr.argument(1))
					.operator() == SymbolicOperator.CONCRETE;
		}
		case CONCRETE:
			return true;
		case SYMBOLIC_CONSTANT: {
			for (SymbolicConstant h : limitVars)
				if (h.equals(expr))
					return true;
		}
		case APPLY: {
			// is this a function from R^n to R?
			SymbolicExpression f = (SymbolicExpression) expr.argument(0);
			int n = getNumRealFunctionInputs(f);

			if (n < 0)
				return false;
			return isBoundedApplication(f, toArray(expr.argument(1), n));
		}
		default:
			return false;
		}
	}

	public NumericExpression reduceModLimits(NumericExpression expr) {
		NumericExpression[] terms = universe.expand(expr);
		int numVars = limitVars.length;
		boolean change = false;

		for (NumericExpression term : terms) {
			for (int i = 0; i < numVars; i++) {
				NumericSymbolicConstant h = limitVars[i];
				int order = orders[i];
				NumericExpression hton = universe.power(h, order);
				NumericExpression q = universe.divide(term, hton);

				if (isBounded(q)) {
					change = true;
					terms[i] = null;
					break;
				}
			}
		}
		if (change) {
			NumericExpression result = universe.zeroReal();

			for (NumericExpression term : terms) {
				if (term != null)
					result = universe.add(result, term);
			}
			return result;
		} else {
			return expr;
		}
	}

	private SymbolicExpression tryToExpand(SymbolicExpression expression,
			SubstituterState state) {
		if (expression.operator() != SymbolicOperator.APPLY)
			return null;

		SymbolicExpression function = (SymbolicExpression) expression
				.argument(0);

		if (function.operator() != SymbolicOperator.SYMBOLIC_CONSTANT)
			return null;

		SymbolicFunctionType functionType = (SymbolicFunctionType) function
				.type();
		SymbolicTypeSequence inputTypes = functionType.inputTypes();
		int n = inputTypes.numTypes();

		if (!functionType.outputType().isReal())
			return null;
		for (int i = 0; i < n; i++) {
			if (!inputTypes.getType(i).isReal())
				return null;
		}

		BooleanExpression diffClaim = findDifferentiableClaim(function);

		// Arg0 is a function from R^n to R for some positive
		// integer n. Arg1 is the degree, a nonnegative integer
		// {@link IntObject} which tells the number of partial derivatives (of
		// any combination) that exist and are continuous. Arg2 is a sequence of
		// real-valued expressions which are the lower bounds of the intervals
		// in the domain; the length is n. Arg3 is a similar sequence of upper
		// bounds.

		if (diffClaim == null)
			return null;

		NumericExpression[] argArray = toArray(expression.argument(1), n);
		NumericExpression[] lowerBounds = toArray(diffClaim.argument(2), n);
		NumericExpression[] upperBounds = toArray(diffClaim.argument(3), n);

		if (!checkDomain(argArray, lowerBounds, upperBounds))
			return null;

		int maxDegree = ((IntObject) diffClaim.argument(1)).getInt();
		SymbolicExpression result = taylorExpansion(function, maxDegree,
				argArray);

		// TODO: debugging:
		System.out.println("Taylor: expression   : " + expression);
		System.out.println("Taylor: result       : " + result);
		System.out.println();
		System.out.flush();

		return result;
	}

	@Override
	protected SymbolicExpression substituteExpression(
			SymbolicExpression expression, SubstituterState state) {
		SymbolicExpression result = tryToExpand(expression, state);

		if (result == null)
			result = super.substituteExpression(expression, state);
		return result;
	}

}