Function.java

package edu.udel.cis.vsl.tass.model.impl;

import java.io.PrintWriter;
import java.util.ArrayDeque;
import java.util.Collection;
import java.util.Deque;
import java.util.HashSet;
import java.util.Vector;

import edu.udel.cis.vsl.tass.model.IF.FunctionIF;
import edu.udel.cis.vsl.tass.model.IF.ModelIF;
import edu.udel.cis.vsl.tass.model.IF.ProcessIF;
import edu.udel.cis.vsl.tass.model.IF.SyntaxException;
import edu.udel.cis.vsl.tass.model.IF.location.LocationIF;
import edu.udel.cis.vsl.tass.model.IF.location.TerminalLocationIF;
import edu.udel.cis.vsl.tass.model.IF.scope.LocalScopeIF;
import edu.udel.cis.vsl.tass.model.IF.scope.ModelScopeIF;
import edu.udel.cis.vsl.tass.model.IF.scope.ProcessScopeIF;
import edu.udel.cis.vsl.tass.model.IF.scope.ScopeIF;
import edu.udel.cis.vsl.tass.model.IF.statement.StatementIF;
import edu.udel.cis.vsl.tass.model.IF.type.TypeIF;
import edu.udel.cis.vsl.tass.model.IF.variable.FormalVariableIF;
import edu.udel.cis.vsl.tass.model.IF.variable.LocalVariableIF;
import edu.udel.cis.vsl.tass.model.IF.variable.VariableIF;
import edu.udel.cis.vsl.tass.model.impl.location.Location;
import edu.udel.cis.vsl.tass.model.impl.scope.LocalScope;
import edu.udel.cis.vsl.tass.util.Source;

/**
 * Represents a function (or "procedure") in the MiniMP intermediate language.
 * Each process in a MiniMP model is composed of functions. A function comprises
 * a sequence of typed formal parameters, a return type, a set of local
 * variables, a set of locations, and a start location.
 * <p>
 * 
 * A function can reference its own formal variables, its own local variables,
 * and also process global variables.
 */
public class Function implements FunctionIF {

	/**
	 * The process containing this function, or null if this function is not
	 * contained in a process
	 */
	protected ProcessIF process;

	/**
	 * The model containing this function, or null if this function is not
	 * contained in a model
	 */
	protected ModelIF model;

	/** The scope in which this function is defined. Must be non-null. */
	protected ScopeIF containingScope;

	/**
	 * The scope containing the formal parameters and outermost local variables
	 * for this function.
	 */
	protected LocalScope outermostScope;

	/**
	 * The set of all local scopes transitively contained in this function. This
	 * includes the outermost scope, its children which are local scopes, their
	 * children, etc. The position of a scope in this list is the index accepted
	 * by {@link #scope(int)}; the outermost scope is always at index 0.
	 */
	protected Vector<LocalScopeIF> localScopes = new Vector<LocalScopeIF>();

	/** Name of function */
	protected String name;

	/** Does this function take a variable number of arguments? */
	private boolean hasVariableArguments = false;

	/**
	 * Id number for this function. Functions within a given scope are numbered
	 * 0,1,etc. The id number is assigned when complete() is called.
	 */
	protected int idInScope = -1;

	/** The number of formal parameters for this function */
	protected int numFormals;

	/**
	 * Return type of function. May be "VoidType". In any case, will not be null
	 */
	private TypeIF returnType;

	/** The start location, where control begins when this function is invoked. */
	private LocationIF startLocation = null;

	/**
	 * A list of all locations in this function. These are the nodes in the
	 * function's location-transition graph. The location with id i is stored at
	 * index i (enforced by {@link #addLocation(LocationIF)}).
	 */
	private Vector<LocationIF> locations = new Vector<LocationIF>();

	/**
	 * A list of all statements in this function. These are the edges in the
	 * function's location-transition graph.
	 */
	private Vector<StatementIF> statements = new Vector<StatementIF>();

	/**
	 * The source code for the definition of this function, but not including
	 * the function body, only the header/type signature.
	 */
	private Source sourceCode = null;

	/** Has complete() been called and nothing been changed since? */
	boolean complete = false;

	/**
	 * Creates a new function object, together with a new outermost local scope
	 * for the function.
	 * <p>
	 * The containing process and model are derived from the kind of the
	 * containing scope: a local scope takes them from its enclosing function, a
	 * process scope supplies its process (and that process's model), a model
	 * scope supplies only the model, and every other kind leaves both null.
	 * 
	 * @param containingScope
	 *            the scope in which this function is defined; must be non-null
	 * @param name
	 *            the function name; must be non-null
	 * @param returnType
	 *            the return type; must be non-null (use the void type for
	 *            functions that return nothing)
	 * @param numFormals
	 *            the number of formal parameters; must be nonnegative
	 * @throws NullPointerException
	 *             if containingScope, name, or returnType is null
	 * @throws IllegalArgumentException
	 *             if numFormals is negative
	 */
	public Function(ScopeIF containingScope, String name, TypeIF returnType,
			int numFormals) {
		if (containingScope == null)
			throw new NullPointerException("null containingScope");
		if (name == null)
			throw new NullPointerException("null function name");
		if (returnType == null)
			throw new NullPointerException(
					"null return type: did you mean void type?");
		if (numFormals < 0)
			throw new IllegalArgumentException(
					"numFormals must be nonnegative, not " + numFormals);
		this.containingScope = containingScope;
		this.name = name;
		this.returnType = returnType;
		this.numFormals = numFormals;
		switch (containingScope.kind()) {
		case BOUND:
			process = null;
			model = null;
			// break is required: falling through into LOCAL would cast a
			// bound scope to LocalScopeIF.
			break;
		case LOCAL:
			process = ((LocalScopeIF) containingScope).function().process();
			model = (process == null ? null : process.model());
			break;
		case PROCESS:
			process = ((ProcessScopeIF) containingScope).process();
			model = process.model();
			break;
		case MODEL:
			process = null;
			model = ((ModelScopeIF) containingScope).model();
			break;
		case SYSTEM:
		default:
			process = null;
			model = null;
		}
		outermostScope = new LocalScope(containingScope, this);
		localScopes.add(outermostScope);
	}

	/**
	 * Sets the id of this function. This is the id number which is unique among
	 * all functions defined within the containing scope.
	 */
	public void setIdInScope(int id) {
		this.idInScope = id;
	}

	/**
	 * Creates a new local scope within this function and adds it to this
	 * function's set of scopes.
	 * 
	 * @param parentScope
	 *            the local scope that will be the parent of the new scope
	 * @return the newly created scope
	 */
	public LocalScope newScope(LocalScopeIF parentScope) {
		LocalScope newScope = new LocalScope(parentScope);

		localScopes.add(newScope);
		return newScope;
	}

	@Override
	public String name() {
		return name;
	}

	@Override
	public TypeIF returnType() {
		return returnType;
	}

	/**
	 * Sets the formal of index variable.idInScope() for this function to the
	 * given variable.
	 * 
	 * @throws SyntaxException
	 *             propagated from the outermost scope's setFormal
	 */
	public void setFormal(FormalVariableIF variable) throws SyntaxException {
		outermostScope.setFormal(variable);
		complete = false;
	}

	@Override
	public int numFormals() {
		return numFormals;
	}

	/**
	 * Returns the formal parameter with the given index. Formals are looked up
	 * by id in the outermost scope.
	 */
	@Override
	public FormalVariableIF formal(int index) {
		assert index >= 0 && index < numFormals;
		return (FormalVariableIF) outermostScope.variableWithId(index);
	}

	/**
	 * Adds a location to this function and to the local scope that contains
	 * it.
	 * 
	 * @param location
	 *            a location belonging to this function whose id equals the
	 *            number of locations added so far
	 * @throws IllegalArgumentException
	 *             if the location belongs to another function, or its id does
	 *             not match its position in this function's location list
	 */
	public void addLocation(LocationIF location) {
		// Validate before touching any state so a rejected location leaves
		// this function unchanged.
		if (!this.equals(location.function()))
			throw new IllegalArgumentException(
					"attempt to add location from another function to this function: "
							+ location);
		if (location.id() != locations.size())
			throw new IllegalArgumentException("Location id mismatch: "
					+ location.id() + " != " + locations.size());

		LocalScope scope = (LocalScope) location.scope();

		assert scope != null;
		assert scope.function() == this;
		complete = false;
		locations.add(location);
		scope.addLocation(location);
	}

	@Override
	public ProcessIF process() {
		return process;
	}

	/** Returns the start location, or null if it has not been set. */
	public LocationIF startLocation() {
		return startLocation;
	}

	/** Sets the start location and marks this function as not complete. */
	public void setStartLocation(LocationIF startLocation) {
		complete = false;
		this.startLocation = startLocation;
	}

	/** Returns the source of this function's header, or null if not set. */
	public Source getSource() {
		return sourceCode;
	}

	/** Sets the header source code and marks this function as not complete. */
	public void setSource(Source sourceCode) {
		complete = false;
		this.sourceCode = sourceCode;
	}

	/**
	 * Returns this function's locations, in id order. The returned collection
	 * is the internal list, not a copy.
	 */
	public Collection<LocationIF> locations() {
		return locations;
	}

	/** Records a statement of this function and marks it as not complete. */
	public void addStatement(StatementIF statement) {
		complete = false;
		statements.add(statement);
	}

	/**
	 * Returns this function's statements. The returned collection is the
	 * internal list, not a copy.
	 */
	public Collection<StatementIF> statements() {
		return statements;
	}

	/**
	 * Returns all locations whose label equals the given label, in id order.
	 * 
	 * @throws NullPointerException
	 *             if label is null
	 */
	public Collection<LocationIF> locationsWithLabel(String label) {
		if (label == null)
			throw new NullPointerException("null label");

		Collection<LocationIF> result = new Vector<LocationIF>();

		for (LocationIF location : locations) {
			if (label.equals(location.label()))
				result.add(location);
		}
		return result;
	}

	/**
	 * Adds to reachedLocations every location reachable from the given location
	 * by following statements forward (source to target).
	 * <p>
	 * Uses an explicit stack rather than recursion so that long chains of
	 * locations cannot overflow the call stack.
	 * 
	 * @throws SyntaxException
	 *             if a visited statement has no target location
	 */
	private void dfs(LocationIF location,
			Collection<LocationIF> reachedLocations) throws SyntaxException {
		Deque<LocationIF> pending = new ArrayDeque<LocationIF>();

		pending.push(location);
		while (!pending.isEmpty()) {
			LocationIF current = pending.pop();

			if (reachedLocations.contains(current))
				continue;
			reachedLocations.add(current);
			for (StatementIF statement : current.statements()) {
				LocationIF nextLocation = statement.targetLocation();

				if (nextLocation == null)
					throw new SyntaxException(statement,
							"Statement does not have target location");
				if (!reachedLocations.contains(nextLocation))
					pending.push(nextLocation);
			}
		}
	}

	/**
	 * Adds to reachedLocations every location from which the given location can
	 * be reached, by following incoming statements backward (target to source).
	 * <p>
	 * Uses an explicit stack rather than recursion so that long chains of
	 * locations cannot overflow the call stack.
	 * 
	 * @throws SyntaxException
	 *             if a visited incoming statement has no source location
	 */
	private void backwards_dfs(LocationIF location,
			Collection<LocationIF> reachedLocations) throws SyntaxException {
		Deque<LocationIF> pending = new ArrayDeque<LocationIF>();

		pending.push(location);
		while (!pending.isEmpty()) {
			LocationIF current = pending.pop();

			if (reachedLocations.contains(current))
				continue;
			reachedLocations.add(current);
			for (StatementIF statement : current.incomingStatements()) {
				LocationIF previousLocation = statement.sourceLocation();

				if (previousLocation == null)
					throw new SyntaxException(statement,
							"Statement does not have source location");
				if (!reachedLocations.contains(previousLocation))
					pending.push(previousLocation);
			}
		}
	}

	/**
	 * Checks the location-transition graph and then completes each location.
	 * Every location must be reachable from the start location, and every
	 * location must be able to reach some terminal location.
	 * 
	 * @throws SyntaxException
	 *             if the start location is unset, a location is unreachable, a
	 *             location cannot reach a terminal location, or a statement is
	 *             missing its source/target location
	 */
	protected void completeLocations() throws SyntaxException {
		Collection<LocationIF> reachedLocations;

		if (startLocation == null)
			throw new SyntaxException(this, "start location has not been set");
		/* check every location is reachable from the start location */
		reachedLocations = new HashSet<LocationIF>();
		dfs(startLocation, reachedLocations);
		for (LocationIF location : locations) {
			if (!reachedLocations.contains(location))
				throw new SyntaxException(location, "Unreachable location");
		}
		/*
		 * now do backwards reachability from all terminal locations to make
		 * sure every location can get to a return statement
		 */
		reachedLocations = new HashSet<LocationIF>();
		for (LocationIF location : locations)
			if (location instanceof TerminalLocationIF)
				backwards_dfs(location, reachedLocations);
		for (LocationIF location : locations) {
			if (!reachedLocations.contains(location))
				throw new SyntaxException(location,
						"Cannot reach return statement");
		}
		/* Now complete all locations */
		for (LocationIF location : locations) {
			((Location) location).complete();
		}
	}

	/**
	 * Completes the outermost scope and all locations, then marks this function
	 * as complete.
	 * 
	 * @throws SyntaxException
	 *             if scope or location completion fails
	 */
	public void complete() throws SyntaxException {
		outermostScope.complete();
		completeLocations();
		complete = true;
	}

	/** Returns the model containing this function, or null if none. */
	public ModelIF model() {
		return model;
	}

	/**
	 * Formats a single variable declaration line for {@link #print}. Formal
	 * parameters are prefixed with "formal <id>: ", and an initializer is
	 * appended when present.
	 */
	private String variableDeclaration(VariableIF variable) {
		String result;

		if (variable instanceof FormalVariableIF) {
			result = "formal " + variable.idInScope() + ": ";
		} else {
			result = "";
		}
		result += variable + " : " + variable.decl();
		if (variable.initializationExpression() != null) {
			result += " = " + variable.initializationExpression();
		}
		result += ";";
		return result;
	}

	/** Prints this function without source information. */
	public void print(String prefix, PrintWriter out) {
		print(prefix, out, false);
	}

	/**
	 * A hook for subclasses to add additional information to the print method
	 * specific to the subclass. They do so by overriding this method
	 */
	protected void printAdditionalData(String prefix, PrintWriter out,
			boolean withSource) {

	}

	/**
	 * Prints a human-readable description of this function: its formals, the
	 * start location, and for each local scope its parent, variables, and
	 * locations. Flushes the writer when done.
	 * 
	 * @param prefix
	 *            string prepended to every output line
	 * @param out
	 *            destination writer
	 * @param withSource
	 *            whether to also print source information
	 */
	public void print(String prefix, PrintWriter out, boolean withSource) {
		int numScopes = numScopes();

		out.println(prefix + "begin function " + name);
		if (withSource) {
			out.println(prefix + "| " + this.sourceCode);
		}
		if (numFormals() > 0) {
			out.println(prefix + "| begin formal parameters");
			for (int i = 0; i < numFormals(); i++) {
				out.println(prefix + "| | " + formal(i) + " : "
						+ formal(i).type() + ";");
				if (withSource) {
					out.println(prefix + "| | | " + formal(i).getSource());
				}
			}
			out.println(prefix + "| end formal parameters;");
		}
		// NOTE(review): subclasses get "  " rather than the "| " indent used
		// elsewhere; kept as-is to preserve existing output.
		printAdditionalData(prefix + "  ", out, withSource);
		if (startLocation != null)
			out.println(prefix + "| start location : " + startLocation().id()
					+ ";");
		for (int i = 0; i < numScopes; i++) {
			LocalScopeIF scope = scope(i);

			out.println(prefix + "| begin scope " + i);
			// scope 0 is the outermost scope, whose parent is not local
			if (i != 0) {
				ScopeIF parent = scope.parent();
				int parentId = ((LocalScopeIF) parent).localId();

				out.println(prefix + "| | parent scope: " + parentId);
			}
			out.println(prefix + "| | begin local variables");
			for (VariableIF variable : scope.variables()) {
				out.println(prefix + "| | | " + variableDeclaration(variable));
				if (withSource) {
					out.println(prefix + "| | | | " + variable.getSource());
				}
			}
			out.println(prefix + "| | end local variables");
			out.println(prefix + "| | begin locations");
			for (LocationIF location : scope.locations()) {
				((Location) location).print(prefix + "| | | ", out, withSource);
			}
			out.println(prefix + "| | end locations");
			out.println(prefix + "| end scope " + i);
		}
		out.println(prefix + "end function " + name + ";");
		out.flush();
	}

	@Override
	public String toString() {
		return name;
	}

	@Override
	public int idInScope() {
		return idInScope;
	}

	/**
	 * Returns the formal parameter with the given name, or null if the
	 * outermost scope has no variable of that name or it is not a formal.
	 */
	@Override
	public FormalVariableIF getFormal(String name) {
		LocalVariableIF result = outermostScope.variableWithName(name);

		if (result instanceof FormalVariableIF)
			return (FormalVariableIF) result;
		return null;
	}

	@Override
	public int numScopes() {
		return localScopes.size();
	}

	@Override
	public LocalScopeIF outermostScope() {
		return outermostScope;
	}

	@Override
	public LocalScopeIF scope(int localScopeId) {
		return localScopes.get(localScopeId);
	}

	@Override
	public boolean hasVariableArguments() {
		return hasVariableArguments;
	}

	@Override
	public void setVariableArguments(boolean value) {
		this.hasVariableArguments = value;
	}

}