
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import java.util.TreeSet;

/**
 * A finite state machine: a directed graph whose states are labelled with
 * {@code StateLabel} values and whose edges are labelled with {@code EdgeLabel}
 * values. Each state has at most one outgoing edge per edge label; adding an
 * edge for an existing (start state, edge label) pair replaces its end state.
 *
 * <p>{@link #getAllAdjacentStates} collects states into a {@link TreeSet}, so that
 * method (and {@link #findPath}, which is built on it) expects {@code StateLabel}
 * to have a natural ordering ({@link Comparable}); otherwise the {@code TreeSet}
 * throws {@link ClassCastException}.
 *
 * @author davereed
 */
public class FiniteStateMachine<StateLabel, EdgeLabel> {
    /** Maps each start state to its outgoing edges (edge label to end state). */
    private final HashMap<StateLabel, HashMap<EdgeLabel, StateLabel>> fsm;

    /**
     * Creates an empty finite state machine (with no states or edges).
     */
    public FiniteStateMachine() {
        this.fsm = new HashMap<StateLabel, HashMap<EdgeLabel, StateLabel>>();
    }

    /**
     * Adds an edge to the finite state machine. If {@code start} already has an
     * edge labelled {@code edge}, its end state is replaced by {@code end}.
     *   @param start the label of the starting state
     *   @param edge the label of the edge
     *   @param end the label of the ending state
     */
    public void addEdge(StateLabel start, EdgeLabel edge, StateLabel end) {
        HashMap<EdgeLabel, StateLabel> outgoing = this.fsm.get(start);
        if (outgoing == null) {
            outgoing = new HashMap<EdgeLabel, StateLabel>();
            this.fsm.put(start, outgoing);
        }
        outgoing.put(edge, end);
    }

    /**
     * Returns a String representation of the finite state machine, in the
     * form of the underlying map from start state to (edge label to end state).
     *   @return the String representation
     */
    @Override
    public String toString() {
        return this.fsm.toString();
    }

    /**
     * Gets the adjacent state given the start state and edge labels.
     *   @param startState the label of the start state
     *   @param edge the label of the edge
     *   @return the label of the ending state for that edge (or null if it doesn't exist)
     */
    public StateLabel getAdjacentState(StateLabel startState, EdgeLabel edge) {
        HashMap<EdgeLabel, StateLabel> outgoing = this.fsm.get(startState);
        return (outgoing == null) ? null : outgoing.get(edge);
    }

    /**
     * Gets a list of all edge labels that connect two states.
     *   @param startState the label of the start state
     *   @param endState the label of the ending state (may be null, matching
     *          edges that were added with a null end state)
     *   @return a list of all edge labels that connect startState and endState
     *          (empty if there are none)
     */
    public List<EdgeLabel> getIncidentEdges(StateLabel startState, StateLabel endState) {
        List<EdgeLabel> edges = new ArrayList<EdgeLabel>();
        HashMap<EdgeLabel, StateLabel> outgoing = this.fsm.get(startState);
        if (outgoing != null) {
            for (Map.Entry<EdgeLabel, StateLabel> entry : outgoing.entrySet()) {
                // Null-safe: an edge stored with a null end state must not throw.
                if (sameState(entry.getValue(), endState)) {
                    edges.add(entry.getKey());
                }
            }
        }
        return edges;
    }

    /**
     * Returns a Set of all state labels that are adjacent to a given state.
     * The set is a {@link TreeSet}, so the states come back in natural order;
     * this requires {@code StateLabel} to be {@link Comparable}.
     *   @param startState the label of the start state
     *   @return a sorted Set of all state labels adjacent to startState
     *          (empty if startState has no outgoing edges)
     */
    public Set<StateLabel> getAllAdjacentStates(StateLabel startState) {
        Set<StateLabel> states = new TreeSet<StateLabel>();
        HashMap<EdgeLabel, StateLabel> outgoing = this.fsm.get(startState);
        if (outgoing != null) {
            states.addAll(outgoing.values());
        }
        return states;
    }

    /**
     * Finds the ending state from a start state, given a sequence of edges.
     *   @param startState the label of the start state
     *   @param edgeSeq a list of edges to be followed from the start state
     *   @return the ending state after following the edge sequence, or null if
     *           some edge in the sequence does not exist along the way
     */
    public StateLabel findEndState(StateLabel startState, List<EdgeLabel> edgeSeq) {
        StateLabel current = startState;
        // for-each avoids O(n) get(i) calls when edgeSeq is a LinkedList
        for (EdgeLabel edge : edgeSeq) {
            current = this.getAdjacentState(current, edge);
            if (current == null) {
                break;
            }
        }
        return current;
    }

    /**
     * Finds a shortest path (fewest edges) between two states using
     * breadth-first search. Neighbors are explored in the order given by
     * {@link #getAllAdjacentStates}, so ties between equally short paths are
     * broken deterministically.
     *   @param startState the label of the start state
     *   @param endState the label of the target state
     *   @return the sequence of states from startState to endState (inclusive),
     *           or null if endState is not reachable from startState
     */
    public List<StateLabel> findPath(StateLabel startState, StateLabel endState) {
        List<StateLabel> startPath = new ArrayList<StateLabel>();
        startPath.add(startState);

        // States already placed on some queued path. In BFS the first path to
        // reach a state is a shortest one, so later paths to that state can be
        // discarded. Without this the search enumerates every simple path,
        // which is exponential on machines with many alternative routes.
        Set<StateLabel> seen = new HashSet<StateLabel>();
        seen.add(startState);

        Queue<List<StateLabel>> paths = new LinkedList<List<StateLabel>>();
        paths.add(startPath);

        while (!paths.isEmpty()) {
            List<StateLabel> path = paths.remove();
            StateLabel current = path.get(path.size() - 1);
            if (sameState(current, endState)) {
                return path;
            }
            for (StateLabel next : this.getAllAdjacentStates(current)) {
                if (seen.add(next)) {
                    List<StateLabel> extended = new ArrayList<StateLabel>(path);
                    extended.add(next);
                    paths.add(extended);
                }
            }
        }
        return null;
    }

    /**
     * Converts a sequence of states into the sequence of edge labels that
     * connects each consecutive pair. When several edges connect a pair, the
     * first one returned by {@link #getIncidentEdges} is used.
     *   @param stateSeq the sequence of states to follow
     *   @return the connecting edge labels (one fewer than the number of
     *           states; empty for zero or one state), or null if some
     *           consecutive pair of states is not connected by any edge
     */
    public List<EdgeLabel> findPathEdges(List<StateLabel> stateSeq) {
        List<EdgeLabel> edges = new ArrayList<EdgeLabel>();
        for (int i = 1; i < stateSeq.size(); i++) {
            List<EdgeLabel> possEdges = this.getIncidentEdges(stateSeq.get(i - 1), stateSeq.get(i));
            if (possEdges.isEmpty()) {
                return null;
            }
            edges.add(possEdges.get(0));
        }
        return edges;
    }

    /**
     * Null-safe equality test; equivalent to {@code a.equals(b)} when
     * {@code a} is non-null, and true when both are null.
     */
    private static boolean sameState(Object a, Object b) {
        return (a == null) ? (b == null) : a.equals(b);
    }
}
