CommonHerbrandFactory.java

/*******************************************************************************
 * Copyright (c) 2013 Stephen F. Siegel, University of Delaware.
 * 
 * This file is part of SARL.
 * 
 * SARL is free software: you can redistribute it and/or modify it under the
 * terms of the GNU Lesser General Public License as published by the Free
 * Software Foundation, either version 3 of the License, or (at your option) any
 * later version.
 * 
 * SARL is distributed in the hope that it will be useful, but WITHOUT ANY
 * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
 * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
 * details.
 * 
 * You should have received a copy of the GNU Lesser General Public License
 * along with SARL. If not, see <http://www.gnu.org/licenses/>.
 ******************************************************************************/
package edu.udel.cis.vsl.sarl.herbrand.common;

import java.util.Comparator;

import edu.udel.cis.vsl.sarl.IF.expr.BooleanExpression;
import edu.udel.cis.vsl.sarl.IF.expr.NumericExpression;
import edu.udel.cis.vsl.sarl.IF.expr.NumericSymbolicConstant;
import edu.udel.cis.vsl.sarl.IF.expr.SymbolicConstant;
import edu.udel.cis.vsl.sarl.IF.expr.SymbolicExpression.SymbolicOperator;
import edu.udel.cis.vsl.sarl.IF.number.Number;
import edu.udel.cis.vsl.sarl.IF.number.NumberFactory;
import edu.udel.cis.vsl.sarl.IF.object.NumberObject;
import edu.udel.cis.vsl.sarl.IF.object.StringObject;
import edu.udel.cis.vsl.sarl.IF.object.SymbolicObject;
import edu.udel.cis.vsl.sarl.IF.object.SymbolicSequence;
import edu.udel.cis.vsl.sarl.IF.type.SymbolicFunctionType;
import edu.udel.cis.vsl.sarl.IF.type.SymbolicType;
import edu.udel.cis.vsl.sarl.expr.IF.BooleanExpressionFactory;
import edu.udel.cis.vsl.sarl.expr.IF.NumericExpressionFactory;
import edu.udel.cis.vsl.sarl.expr.common.CommonSymbolicConstant;
import edu.udel.cis.vsl.sarl.object.IF.ObjectFactory;
import edu.udel.cis.vsl.sarl.type.IF.SymbolicTypeFactory;

public class CommonHerbrandFactory implements NumericExpressionFactory {

	private NumberFactory numberFactory;

	private BooleanExpressionFactory booleanFactory;

	private ObjectFactory objectFactory;

	private SymbolicTypeFactory typeFactory;

	private Comparator<NumericExpression> comparator;

	private SymbolicType herbrandIntegerType, herbrandRealType, booleanType;

	private HerbrandExpression zeroInt, zeroReal, oneInt, oneReal;

	private SymbolicFunctionType realBinaryOp, realUnaryOp, integerBinaryOp,
			integerUnaryOp, realBinaryPred, integerBinaryPred;

	private SymbolicConstant plusReal, plusInteger, minusReal, minusInteger,
			timesReal, timesInteger, divideReal, divideInteger, negativeReal,
			negativeInteger, modulo, powerInteger, powerReal, lessThanInteger,
			lessThanReal, lteInteger, lteReal;

	public CommonHerbrandFactory(NumberFactory numberFactory,
			ObjectFactory objectFactory, SymbolicTypeFactory typeFactory,
			BooleanExpressionFactory booleanFactory) {
		this.numberFactory = numberFactory;
		this.objectFactory = objectFactory;
		this.typeFactory = typeFactory;
		this.booleanFactory = booleanFactory;
		this.comparator = new HerbrandComparator(objectFactory.comparator(),
				typeFactory.typeComparator());
	}

	@Override
	public void init() {
		this.herbrandIntegerType = typeFactory.herbrandIntegerType();
		this.herbrandRealType = typeFactory.herbrandRealType();
		this.booleanType = typeFactory.booleanType();
		this.oneInt = number(
				objectFactory.numberObject(numberFactory.oneInteger()));
		this.oneReal = number(
				objectFactory.numberObject(numberFactory.oneRational()));
		this.zeroInt = number(
				objectFactory.numberObject(numberFactory.zeroInteger()));
		this.zeroReal = number(
				objectFactory.numberObject(numberFactory.zeroRational()));
	}

	private SymbolicConstant commonSymbolicConstant(StringObject name,
			SymbolicType type) {
		return objectFactory.canonic(new CommonSymbolicConstant(name, type));
	}

	private SymbolicSequence<NumericExpression> sequence(NumericExpression e0,
			NumericExpression e1) {
		return objectFactory.sequence(new NumericExpression[] { e0, e1 });
	}

	private SymbolicFunctionType realBinaryOp() {
		if (realBinaryOp == null)
			realBinaryOp = typeFactory
					.functionType(
							typeFactory.sequence(new SymbolicType[] {
									herbrandRealType, herbrandRealType }),
							herbrandRealType);
		return realBinaryOp;
	}

	private SymbolicFunctionType integerBinaryOp() {
		if (integerBinaryOp == null)
			integerBinaryOp = typeFactory.functionType(
					typeFactory.sequence(new SymbolicType[] {
							herbrandIntegerType, herbrandIntegerType }),
					herbrandIntegerType);
		return integerBinaryOp;
	}

	private SymbolicFunctionType realUnaryOp() {
		if (realUnaryOp == null)
			realUnaryOp = typeFactory.functionType(
					typeFactory
							.sequence(new SymbolicType[] { herbrandRealType }),
					herbrandRealType);
		return realUnaryOp;
	}

	private SymbolicFunctionType integerUnaryOp() {
		if (integerUnaryOp == null)
			integerUnaryOp = typeFactory.functionType(
					typeFactory.sequence(
							new SymbolicType[] { herbrandIntegerType }),
					herbrandIntegerType);
		return integerUnaryOp;
	}

	private SymbolicFunctionType realBinaryPred() {
		if (realBinaryPred == null)
			realBinaryPred = typeFactory
					.functionType(
							typeFactory.sequence(new SymbolicType[] {
									herbrandRealType, herbrandRealType }),
							booleanType);
		return realBinaryPred;
	}

	private SymbolicFunctionType integerBinaryPred() {
		if (integerBinaryPred == null)
			integerBinaryPred = typeFactory.functionType(
					typeFactory.sequence(new SymbolicType[] {
							herbrandIntegerType, herbrandIntegerType }),
					booleanType);
		return integerBinaryPred;
	}

	private SymbolicConstant plusReal() {
		if (plusReal == null)
			plusReal = commonSymbolicConstant(
					objectFactory.stringObject("PLUS_REAL"), realBinaryOp());
		return plusReal;
	}

	private SymbolicConstant plusInteger() {
		if (plusInteger == null)
			plusInteger = commonSymbolicConstant(
					objectFactory.stringObject("PLUS_INT"), integerBinaryOp());
		return plusInteger;
	}

	private SymbolicConstant plusOperator(SymbolicType type) {
		return type.isInteger() ? plusInteger() : plusReal();
	}

	private SymbolicConstant minusReal() {
		if (minusReal == null)
			minusReal = commonSymbolicConstant(
					objectFactory.stringObject("MINUS_REAL"), realBinaryOp());
		return minusReal;
	}

	private SymbolicConstant minusInteger() {
		if (minusInteger == null)
			minusInteger = commonSymbolicConstant(
					objectFactory.stringObject("MINUS_INT"), integerBinaryOp());
		return minusInteger;
	}

	private SymbolicConstant minusOperator(SymbolicType type) {
		return type.isInteger() ? minusInteger() : minusReal();
	}

	private SymbolicConstant timesReal() {
		if (timesReal == null)
			timesReal = commonSymbolicConstant(
					objectFactory.stringObject("TIMES_REAL"), realBinaryOp());
		return timesReal;
	}

	private SymbolicConstant timesInteger() {
		if (timesInteger == null)
			timesInteger = commonSymbolicConstant(
					objectFactory.stringObject("TIMES_INT"), integerBinaryOp());
		return timesInteger;
	}

	private SymbolicConstant timesOperator(SymbolicType type) {
		return type.isInteger() ? timesInteger() : timesReal();
	}

	private SymbolicConstant divideReal() {
		if (divideReal == null)
			divideReal = commonSymbolicConstant(
					objectFactory.stringObject("DIVIDE_REAL"), realBinaryOp());
		return divideReal;
	}

	private SymbolicConstant divideInteger() {
		if (divideInteger == null)
			divideInteger = commonSymbolicConstant(
					objectFactory.stringObject("DIVIDE_INT"),
					integerBinaryOp());
		return divideInteger;
	}

	private SymbolicConstant divideOperator(SymbolicType type) {
		return type.isInteger() ? divideInteger() : divideReal();
	}

	private SymbolicConstant moduloOperator() {
		if (modulo == null)
			modulo = commonSymbolicConstant(
					objectFactory.stringObject("MODULO"), integerBinaryOp());
		return modulo;
	}

	private SymbolicConstant negativeInteger() {
		if (negativeInteger == null)
			negativeInteger = commonSymbolicConstant(
					objectFactory.stringObject("NEGATIVE_INT"),
					integerUnaryOp());
		return negativeInteger;
	}

	private SymbolicConstant negativeReal() {
		if (negativeReal == null)
			negativeReal = commonSymbolicConstant(
					objectFactory.stringObject("NEGATIVE_REAL"), realUnaryOp());
		return negativeReal;
	}

	private SymbolicConstant negativeOperator(SymbolicType type) {
		return type.isInteger() ? negativeInteger() : negativeReal();
	}

	private SymbolicConstant powerReal() {
		if (powerReal == null)
			powerReal = commonSymbolicConstant(
					objectFactory.stringObject("POWER_REAL"), realBinaryOp());
		return powerReal;
	}

	private SymbolicConstant powerInteger() {
		if (powerInteger == null)
			powerInteger = commonSymbolicConstant(
					objectFactory.stringObject("POWER_INT"), integerBinaryOp());
		return powerInteger;
	}

	private SymbolicConstant lessThanInteger() {
		if (lessThanInteger == null)
			lessThanInteger = commonSymbolicConstant(
					objectFactory.stringObject("LT_INT"), integerBinaryPred());
		return lessThanInteger;
	}

	private SymbolicConstant lessThanReal() {
		if (lessThanReal == null)
			lessThanReal = commonSymbolicConstant(
					objectFactory.stringObject("LT_REAL"), realBinaryPred());
		return lessThanReal;
	}

	private SymbolicConstant lessThanOperator(SymbolicType type) {
		return type.isInteger() ? lessThanInteger() : lessThanReal();
	}

	private SymbolicConstant lteInteger() {
		if (lteInteger == null)
			lteInteger = commonSymbolicConstant(
					objectFactory.stringObject("LTE_INT"), integerBinaryPred());
		return lteInteger;
	}

	private SymbolicConstant lteReal() {
		if (lteReal == null)
			lteReal = commonSymbolicConstant(
					objectFactory.stringObject("LTE_REAL"), realBinaryPred());
		return lteReal;
	}

	private SymbolicConstant lteOperator(SymbolicType type) {
		return type.isInteger() ? lteInteger() : lteReal();
	}

	@Override
	public BooleanExpressionFactory booleanFactory() {
		return booleanFactory;
	}

	@Override
	public NumberFactory numberFactory() {
		return numberFactory;
	}

	@Override
	public ObjectFactory objectFactory() {
		return objectFactory;
	}

	@Override
	public SymbolicTypeFactory typeFactory() {
		return typeFactory;
	}

	@Override
	public Comparator<NumericExpression> comparator() {
		return comparator;
	}

	@Override
	public HerbrandExpression number(NumberObject numberObject) {
		return expression(SymbolicOperator.CONCRETE,
				numberObject.isReal() ? herbrandRealType : herbrandIntegerType,
				new SymbolicObject[] { numberObject });
	}

	@Override
	public NumericSymbolicConstant symbolicConstant(StringObject name,
			SymbolicType type) {
		assert type.isNumeric();
		return objectFactory.canonic(new HerbrandSymbolicConstant(name, type));
	}

	@Override
	public HerbrandExpression expression(SymbolicOperator operator,
			SymbolicType numericType, SymbolicObject... arguments) {
		return objectFactory.canonic(
				new HerbrandExpression(operator, numericType, arguments));
	}

	@Override
	public HerbrandExpression zeroInt() {
		return zeroInt;
	}

	@Override
	public HerbrandExpression zeroReal() {
		return zeroReal;
	}

	@Override
	public HerbrandExpression oneInt() {
		return oneInt;
	}

	@Override
	public HerbrandExpression oneReal() {
		return oneReal;
	}

	@Override
	public NumericExpression add(NumericExpression arg0,
			NumericExpression arg1) {
		SymbolicType t0 = arg0.type();

		return expression(SymbolicOperator.APPLY, t0, plusOperator(t0),
				sequence(arg0, arg1));
	}

	@Override
	public NumericExpression subtract(NumericExpression arg0,
			NumericExpression arg1) {
		SymbolicType t0 = arg0.type();

		return expression(SymbolicOperator.APPLY, t0, minusOperator(t0),
				sequence(arg0, arg1));
	}

	@Override
	public NumericExpression multiply(NumericExpression arg0,
			NumericExpression arg1) {
		SymbolicType t0 = arg0.type();

		return expression(SymbolicOperator.APPLY, t0, timesOperator(t0),
				sequence(arg0, arg1));
	}

	@Override
	public NumericExpression divide(NumericExpression arg0,
			NumericExpression arg1) {
		SymbolicType t0 = arg0.type();

		return expression(SymbolicOperator.APPLY, t0, divideOperator(t0),
				sequence(arg0, arg1));
	}

	@Override
	public NumericExpression modulo(NumericExpression arg0,
			NumericExpression arg1) {
		SymbolicType t0 = arg0.type();

		return expression(SymbolicOperator.APPLY, t0, moduloOperator(),
				sequence(arg0, arg1));
	}

	@Override
	public NumericExpression minus(NumericExpression arg) {
		SymbolicType type = arg.type();

		return expression(SymbolicOperator.APPLY, type, negativeOperator(type),
				objectFactory.singletonSequence(arg));
	}

	@Override
	public NumericExpression power(NumericExpression base,
			NumberObject exponent) {
		SymbolicType type = base.type();

		if (type.isInteger())
			return expression(SymbolicOperator.APPLY, type, powerInteger(),
					sequence(base, number(exponent)));
		else
			return expression(SymbolicOperator.APPLY, type, powerReal(),
					sequence(base, number(objectFactory.numberObject(
							numberFactory.rational(exponent.getNumber())))));
	}

	@Override
	public NumericExpression power(NumericExpression base,
			NumericExpression exponent) {
		SymbolicType t1 = base.type(), t2 = exponent.type();

		if (t1.isInteger() && t2.isInteger()) {
			return expression(SymbolicOperator.APPLY, herbrandIntegerType,
					powerInteger(), sequence(base, exponent));
		} else {
			if (t1.isInteger())
				base = cast(base, herbrandRealType);
			if (t2.isInteger())
				exponent = cast(exponent, herbrandRealType);
			return expression(SymbolicOperator.APPLY, herbrandRealType,
					powerReal(), sequence(base, exponent));
		}
	}

	@Override
	public BooleanExpression lessThan(NumericExpression arg0,
			NumericExpression arg1) {
		SymbolicType t0 = arg0.type();

		return booleanFactory.booleanExpression(SymbolicOperator.APPLY,
				lessThanOperator(t0), sequence(arg0, arg1));
	}

	@Override
	public BooleanExpression lessThanEquals(NumericExpression arg0,
			NumericExpression arg1) {
		SymbolicType t0 = arg0.type();

		return booleanFactory.booleanExpression(SymbolicOperator.APPLY,
				lteOperator(t0), sequence(arg0, arg1));
	}

	@Override
	public BooleanExpression notLessThan(NumericExpression arg0,
			NumericExpression arg1) {
		return booleanFactory.booleanExpression(SymbolicOperator.NOT,
				lessThan(arg0, arg1));
	}

	@Override
	public BooleanExpression notLessThanEquals(NumericExpression arg0,
			NumericExpression arg1) {
		return booleanFactory.booleanExpression(SymbolicOperator.NOT,
				lessThanEquals(arg0, arg1));
	}

	@Override
	public NumericExpression cast(NumericExpression numericExpression,
			SymbolicType newType) {
		SymbolicType oldType = numericExpression.type();

		if (newType.equals(oldType))
			return numericExpression;
		if (newType.isHerbrand()
				&& numericExpression.operator() == SymbolicOperator.CAST
				&& oldType.isIdeal()) {
			// if numericExpression is a cast from herbrand to ideal,
			// and the new type is herbrand, the two casts cancel...
			NumericExpression originalExpression = (NumericExpression) numericExpression
					.argument(0);
			SymbolicType originalType = originalExpression.type();

			if (originalType.equals(newType)
					&& oldType.isInteger() == newType.isInteger())
				return originalExpression;
		}
		return expression(SymbolicOperator.CAST, newType, numericExpression);
	}

	@Override
	public Number extractNumber(NumericExpression expression) {
		if (expression.operator() == SymbolicOperator.CONCRETE) {
			SymbolicObject arg = expression.argument(0);

			if (arg instanceof NumberObject) {
				return ((NumberObject) arg).getNumber();
			}
		}
		return null;
	}

	@Override
	public BooleanExpression equals(NumericExpression arg0,
			NumericExpression arg1) {
		return booleanFactory.booleanExpression(SymbolicOperator.EQUALS, arg0,
				arg1);
	}

	@Override
	public BooleanExpression neq(NumericExpression arg0,
			NumericExpression arg1) {
		return booleanFactory.booleanExpression(SymbolicOperator.NEQ, arg0,
				arg1);
	}

	@Override
	public NumericExpression[] expand(NumericExpression expr) {
		return new NumericExpression[] { expr };
	}

	@Override
	public NumericExpression floor(NumericExpression expr) {
		// TODO Auto-generated method stub
		// create uninterpreted functions for floor, ceil, roundToZero?
		return null;
	}

	@Override
	public NumericExpression ceil(NumericExpression expr) {
		// TODO Auto-generated method stub
		return null;
	}

	@Override
	public NumericExpression roundToZero(NumericExpression expr) {
		// TODO Auto-generated method stub
		return null;
	}
}