A* pathfinding algorithm is being semi-greedy



So I have been trying to implement the A* pathfinding algorithm for a 2D tilemap in Java from this video: https://www.youtube.com/watch?v=-L-WgKMFuhE. I tried following the pseudocode, but I felt that several details were not explained well, especially the idea of the G cost. So I went into the actual code the video creator wrote and used a lot of his ideas and structure, since I had already spent a lot of time looking at multiple versions of the pseudocode without getting the desired result.

I understand the code, and stepping through it makes sense, but for some reason the algorithm sometimes does not produce the shortest path. Instead, it behaves greedily and heads toward the end destination too early. Here are two examples (the green dot is the start and the red dot is the end; the gray tiles represent the path and the black tiles represent walls):

[Image: greedy path 1]

[Image: greedy path 2]

I have repeatedly looked at the creator’s code and my own and I just can’t find the problem.

Here is my pathfinding method:

Node start, end; // these are initialized through the GUI

List<Node> open = new ArrayList<Node>();
Set<Node> closed = new HashSet<Node>();

boolean pathExists = true;

public void findPath() {
    open.add(start);

    while(!open.isEmpty()) {
        Map<Node, Double> costs = new HashMap<Node, Double>();
        for(Node node: open) {
            costs.put(node, node.cost());
        }
        Node current = Collections.min(costs.entrySet(), Map.Entry.comparingByValue()).getKey();

        open.remove(current);
        closed.add(current);

        if(current.equals(end)) {
            end = current;
            return;
        }

        for(Node adjacent: current.getAdjacentNodes()) {
            if(contains(closed, adjacent) || !inBounds(adjacent.getPoint()))
                continue;

            double newG = current.g + adjacent.getDistanceTo(current);
            if(newG < adjacent.g || !contains(open, adjacent)) {
                adjacent.g = newG;
                adjacent.h = adjacent.getDistanceTo(end);
                adjacent.parent = current;

                if(!contains(open, adjacent))
                    open.add(adjacent);
            }
        }
    }
    pathExists = false;
}

And here is my Node class:

private class Node {
    public Point point;
    public Node parent;

    public double g, h;

    public Node(Point point) {
        this.point = point;
    }

    public Node getParent() {
        return parent;
    }

    public double cost() {
        return g + h;
    }

    public double getDistanceTo(Node node) {
        int x = Math.abs((point.x - node.point.x)/cellWidth); // cellWidth and cellHeight are the dimensions of each tile
        int y = Math.abs((point.y - node.point.y)/cellHeight);

        if(x > y)
            return 14*y + 10*(x - y);
        return 14*x + 10*(y - x);
    }

    public List<Node> getAdjacentNodes() {
        return List.of(
                new Node(new Point(point.x + cellWidth, point.y + cellHeight)),
                new Node(new Point(point.x + cellWidth, point.y)),
                new Node(new Point(point.x + cellWidth, point.y - cellHeight)),
                new Node(new Point(point.x, point.y - cellHeight)),
                new Node(new Point(point.x - cellWidth, point.y - cellHeight)),
                new Node(new Point(point.x - cellWidth, point.y)),
                new Node(new Point(point.x - cellWidth, point.y + cellHeight)),
                new Node(new Point(point.x, point.y + cellHeight))
                );
    }

    public Point getPoint() {
        return point;
    }

    public boolean equals(Node node) {
        return point.x == node.point.x && point.y == node.point.y;
    }

    public String toString() {
        return point.toString();
    }
}
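For reference, getDistanceTo uses the common 10/14 integer cost for 8-directional movement (often called octile distance): 10 per straight step and 14 per diagonal step, because 10·√2 ≈ 14.14. The standalone sketch that follows reproduces that formula on tile counts instead of pixels; the class name OctileDistance and the method octile are hypothetical and not part of the original code.

```java
// Sketch of the cost rule used by getDistanceTo(): a diagonal step costs 14,
// a straight step costs 10. Inputs are tile offsets, not pixel offsets.
class OctileDistance {

    static int octile(int dx, int dy) {
        dx = Math.abs(dx);
        dy = Math.abs(dy);
        int diagonal = Math.min(dx, dy);            // steps that move along both axes at once
        int straight = Math.max(dx, dy) - diagonal; // leftover steps along a single axis
        return 14 * diagonal + 10 * straight;
    }

    public static void main(String[] args) {
        System.out.println(octile(3, 1));  // 14*1 + 10*2 = 34
        System.out.println(octile(-2, 2)); // 14*2 = 28
        System.out.println(octile(0, 5));  // 10*5 = 50
    }
}
```

On an 8-connected grid with these step costs, this value never overestimates the real remaining cost (walls can only make the true path longer), so the heuristic by itself should not make A* return a longer path.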

I also have a GUI set up using Java Swing. The code is pretty messy, but what really matters is the pathfinding method and the Node class together. If you want to tweak these two parts and verify your results, here is the GUI code. (Note: to choose the start tile, select the “Start” radio button and click a tile; do the same with the “End” radio button for the end tile. Right-click a tile to make it a wall. Then click the “Find Path” button. There is no user error handling.)

import javax.swing.*;

import java.awt.Color;
import java.awt.Graphics;
import java.awt.event.MouseListener;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.awt.event.MouseEvent;
import java.awt.Point;

public class Test extends JPanel {
    
    private final int width = 600, height = 450;
    private final int cellWidth = 25, cellHeight = 25;
    
    Map<Point, Color> tiles = new HashMap<Point, Color>();
    
    Node start, end;
    
    public Test() {
        JFrame frame = new JFrame();
        frame.setSize(800, 450);
        frame.setLocationRelativeTo(null);
        frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        frame.setResizable(false);
        
        ButtonGroup group = new ButtonGroup();
        
        JRadioButton startSelect = new JRadioButton("Start");
        JRadioButton endSelect = new JRadioButton("End");
        JRadioButton coord = new JRadioButton("Coord");
        
        startSelect.setSelected(true);
        
        group.add(startSelect);
        group.add(endSelect);
        group.add(coord);
        
        JButton find = new JButton("Find Path");
        find.addActionListener(event -> {
            findPath();
            drawPath();
            repaint();
        });
        
        SpringLayout layout = new SpringLayout();
        this.setLayout(layout);
        layout.putConstraint(SpringLayout.EAST, startSelect, -130, SpringLayout.EAST, this);
        layout.putConstraint(SpringLayout.NORTH, startSelect, 30, SpringLayout.NORTH, this);
        this.add(startSelect);
        layout.putConstraint(SpringLayout.WEST, endSelect, 10, SpringLayout.EAST, startSelect);
        layout.putConstraint(SpringLayout.NORTH, endSelect, 0, SpringLayout.NORTH, startSelect);
        this.add(endSelect);
        layout.putConstraint(SpringLayout.WEST, coord, 10, SpringLayout.EAST, endSelect);
        layout.putConstraint(SpringLayout.NORTH, coord, 0, SpringLayout.NORTH, endSelect);
        this.add(coord);
        layout.putConstraint(SpringLayout.EAST, find, -60, SpringLayout.EAST, this);
        layout.putConstraint(SpringLayout.NORTH, find, 120, SpringLayout.NORTH, this);
        this.add(find);
        
        this.addMouseListener(new MouseListener() {
            @Override
            public void mousePressed(MouseEvent e) {
                double x = e.getX();
                double y = e.getY();
                
                Node tile = new Node(new Point(cellWidth*((int) x/cellWidth), cellHeight*((int) y/cellHeight)));
                
                switch(e.getButton()) {
                case MouseEvent.BUTTON1:
                    if(startSelect.isSelected()) {
                        start = tile;
                        tiles.put(tile.getPoint(), Color.GRAY);
                    }
                    else if(endSelect.isSelected()) {
                        end = tile;
                        tiles.put(tile.getPoint(), Color.GRAY);
                    }
                    else if(coord.isSelected()) {
                        System.out.println(tile.getPoint());
                    }
                    break;
                case MouseEvent.BUTTON3:
                    tiles.put(tile.getPoint(), Color.BLACK);
                    closed.add(tile);
                    break;
                }
                repaint();
            }
            
            @Override
            public void mouseClicked(MouseEvent e) {

            }
            
            @Override
            public void mouseReleased(MouseEvent e) {

            }

            @Override
            public void mouseEntered(MouseEvent e) {
                
            }

            @Override
            public void mouseExited(MouseEvent e) {
                
            }
        });
        
        this.setPreferredSize(frame.getSize());
        frame.add(this);
        frame.pack();
        
        for(int i=0; i<width; i+=cellWidth) {
            for(int j=0; j<height; j+=cellHeight) {
                tiles.put(new Point(i, j), Color.WHITE);
            }
        }
        
        frame.setVisible(true);
    }
    
    public boolean inBounds(Point point) {
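        // Strict upper bounds: the last tile starts at width - cellWidth, so x == width is already off the grid
        return (point.x >= 0 && point.y >= 0) && (point.x < width && point.y < height);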
    }
    
    public boolean contains(List<Node> list, Node node) {
        for(Node n: list) {
            if(n.equals(node))
                return true;
        }
        return false;
    }
    
    public boolean contains(Set<Node> list, Node node) {
        for(Node n: list) {
            if(n.equals(node))
                return true;
        }
        return false;
    }
    
    Set<Node> open = new HashSet<Node>();
    Set<Node> closed = new HashSet<Node>();
    
    boolean pathExists = true;
    
    /*
     * 
     * 
     * 
     * PATHFINDING METHOD*/
    public void findPath() {
        open.add(start);

        while(!open.isEmpty()) {
            Map<Node, Double> costs = new HashMap<Node, Double>();
            for(Node node: open) {
                costs.put(node, node.cost());
            }
            Node current = Collections.min(costs.entrySet(), Map.Entry.comparingByValue()).getKey();
            
            open.remove(current);
            closed.add(current);
            
            if(current.equals(end)) {
                end = current;
                return;
            }
            
            for(Node adjacent: current.getAdjacentNodes()) {
                if(contains(closed, adjacent) || !inBounds(adjacent.getPoint()))
                    continue;
                
                double newG = current.g + adjacent.getDistanceTo(current);
                if(newG < adjacent.g || !contains(open, adjacent)) {
                    adjacent.g = newG;
                    adjacent.h = adjacent.getDistanceTo(end);
                    adjacent.parent = current;
                    
                    if(!contains(open, adjacent))
                        open.add(adjacent);
                }
            }
        }
        pathExists = false;
    }
    
    public void drawPath() {
        if(pathExists) {
            Node current = end;
            
            while(!current.equals(start)) {
                tiles.put(current.getParent().getPoint(), Color.GRAY);
                current = current.getParent();
            }
        }
    }
    
    /*
     * 
     * 
     * 
     * 
     * NODE CLASS*/
    private class Node {
        public Point point;
        public Node parent;
        
        public double g, h;

        public Node(Point point) {
            this.point = point;
        }
        
        public Node getParent() {
            return parent;
        }
        
        public double cost() {
            return g + h;
        }
        
        /*public double getH() {
            return 10*Math.sqrt(Math.pow((end.getPoint().x - point.x)/cellWidth, 2) + Math.pow((end.getPoint().y - point.y)/cellHeight, 2));
        }*/
        
        public double getDistanceTo(Node node) {
            int x = Math.abs((point.x - node.point.x)/cellWidth);
            int y = Math.abs((point.y - node.point.y)/cellHeight);
            
            if(x > y)
                return 14*y + 10*(x - y);
            return 14*x + 10*(y - x);       
        }
        
        /*public double getPositionTo(Node node) {
            if((point.y == node.point.y && (point.x > node.point.x || point.x < node.point.x)) ||
                    (point.x == node.point.x && (point.y < node.point.y || point.y > node.point.y)))
                return 10;
            else
                return 14;
        }*/
        
        public List<Node> getAdjacentNodes() {
            return List.of(
                    new Node(new Point(point.x + cellWidth, point.y + cellHeight)),
                    new Node(new Point(point.x + cellWidth, point.y)),
                    new Node(new Point(point.x + cellWidth, point.y - cellHeight)),
                    new Node(new Point(point.x, point.y - cellHeight)),
                    new Node(new Point(point.x - cellWidth, point.y - cellHeight)),
                    new Node(new Point(point.x - cellWidth, point.y)),
                    new Node(new Point(point.x - cellWidth, point.y + cellHeight)),
                    new Node(new Point(point.x, point.y + cellHeight))
                    );
        }
        
        public Point getPoint() {
            return point;
        }
        
        public boolean equals(Node node) {
            return point.x == node.point.x && point.y == node.point.y;
        }
        
        public String toString() {
            return point.toString();
        }
    }
    
    public void paintComponent(Graphics tool) {
        super.paintComponent(tool);
        
        for(Map.Entry<Point, Color> tile: tiles.entrySet()) {
            tool.setColor(tile.getValue());
            tool.fillRect(tile.getKey().x, tile.getKey().y, cellWidth, cellHeight);
        }
        tool.setColor(Color.BLACK);
        for(int i=0; i<width; i+=cellWidth) {
            tool.drawLine(i, 0, i, height);
        }
        for(int i=0; i<height; i+=cellHeight) {
            tool.drawLine(0, i, width, i);
        }
        
        if(start != null) {
            tool.setColor(Color.GREEN);
            tool.fillOval(start.point.x + 8, start.point.y + 8, 10, 10);
        }
        if(end != null) {
            tool.setColor(Color.RED);
            tool.fillOval(end.point.x + 8, end.point.y + 8, 10, 10);
        }
    }
    
    public static void main(String[] args) {
        new Test();
    }
}

EDIT: I wrote my own contains method for the open and closed sets, just to make sure I wasn't messing something up there:

public boolean contains(Set<Node> list, Node node) {
    for(Node n: list) {
        if(n.equals(node))
            return true;
    }
    return false;
}
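These custom contains methods are needed because Node declares equals(Node), which overloads rather than overrides Object.equals(Object). HashSet.contains, HashSet.remove and HashMap lookups only call equals(Object) and hashCode(), so two Node objects for the same tile count as different elements. A minimal sketch of the override approach, using a hypothetical GridNode class with integer tile coordinates instead of the original Point field:

```java
import java.util.HashSet;
import java.util.Objects;
import java.util.Set;

// Hypothetical GridNode: equality is defined by the tile position only,
// so hash-based collections can recognise two objects for the same tile.
class GridNode {
    final int x, y;   // tile coordinates; final so the hash code never changes
    double g, h;      // search state, deliberately left out of equals/hashCode
    GridNode parent;

    GridNode(int x, int y) {
        this.x = x;
        this.y = y;
    }

    @Override
    public boolean equals(Object o) {   // overrides Object.equals, unlike equals(Node)
        if (this == o) return true;
        if (!(o instanceof GridNode)) return false;
        GridNode other = (GridNode) o;
        return x == other.x && y == other.y;
    }

    @Override
    public int hashCode() {             // must agree with equals
        return Objects.hash(x, y);
    }
}

class EqualsDemo {
    public static void main(String[] args) {
        Set<GridNode> open = new HashSet<>();
        open.add(new GridNode(2, 3));
        // A freshly created node for the same tile is found by the set itself
        System.out.println(open.contains(new GridNode(2, 3))); // true
    }
}
```

Note that contains only answers yes or no; to get back the stored node (and therefore its g value) you would still need a lookup such as a Map keyed by tile position.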

Answer

Thanks to @thatotherguy for the solution.

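If it helps anyone in the future: the problem was that every time I got the adjacent nodes of the current node, I created brand-new Node objects, all with a g value of 0. As a result, the check newG < adjacent.g compared against 0 instead of the g value already stored for that tile, so a node that was already in the open list never had its g and parent updated when a cheaper path to it was found.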
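The effect can be reproduced in isolation. The sketch uses a hypothetical stripped-down Cell class and copies only the update condition from findPath(); the g values 50 and 30 are made-up numbers chosen to show one tile being reached first by a worse path and later by a better one.

```java
// Hypothetical minimal stand-in for Node: only the position and g matter here.
class Cell {
    final int x, y;
    double g;

    Cell(int x, int y, double g) {
        this.x = x;
        this.y = y;
        this.g = g;
    }
}

class RelaxationDemo {
    // Mirrors the condition in findPath(): if(newG < adjacent.g || !contains(open, adjacent))
    static boolean wouldUpdate(double newG, Cell adjacent, boolean inOpen) {
        return newG < adjacent.g || !inOpen;
    }

    public static void main(String[] args) {
        Cell stored = new Cell(4, 4, 50); // already in the open list, reached by a worse path
        Cell fresh = new Cell(4, 4, 0);   // what the old getAdjacentNodes() returned: same tile, g = 0
        double newG = 30;                 // a cheaper route to the same tile was just found

        // Buggy: 30 < 0 is false and the tile is already open, so the improvement is ignored
        System.out.println(wouldUpdate(newG, fresh, true));  // false
        // Fixed: 30 < 50 is true, so g and parent would be updated
        System.out.println(wouldUpdate(newG, stored, true)); // true
    }
}
```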

To correct this, I first check whether each adjacent node already exists in the open list. If it does, I use that existing node instead of creating a new one. If it does not, I create a new node and give it the correct g value.

public List<Node> getAdjacentNodes() {
    // Candidate positions for all 8 neighbours (requires import java.util.LinkedList)
    List<Node> nodes = new LinkedList<Node>();
    nodes.add(new Node(new Point(point.x + cellWidth, point.y + cellHeight)));
    nodes.add(new Node(new Point(point.x + cellWidth, point.y)));
    nodes.add(new Node(new Point(point.x + cellWidth, point.y - cellHeight)));
    nodes.add(new Node(new Point(point.x, point.y - cellHeight)));
    nodes.add(new Node(new Point(point.x - cellWidth, point.y - cellHeight)));
    nodes.add(new Node(new Point(point.x - cellWidth, point.y)));
    nodes.add(new Node(new Point(point.x - cellWidth, point.y + cellHeight)));
    nodes.add(new Node(new Point(point.x, point.y + cellHeight)));

    List<Node> correctNodes = new ArrayList<Node>();
    List<Node> remove = new ArrayList<Node>();

    // Reuse any neighbour that is already in the open list, so its stored g value is compared
    for(Node node: nodes) {
        for(Node openNode: open) {
            if(node.equals(openNode)) {
                correctNodes.add(openNode);
                remove.add(node);
                break;
            }
        }
    }

    // The remaining neighbours are new: give them the g value of a path through this node
    nodes.removeAll(remove);
    for(Node node: nodes) {
        node.g = this.g + getDistanceTo(node);
        correctNodes.add(node);
    }
    return correctNodes;
}
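A different way to avoid stale copies, not the approach taken in this answer, is to keep exactly one g value and one parent per tile (for example in arrays indexed by tile coordinates) and to use a PriorityQueue instead of scanning the open list for the minimum. The following is a sketch under those assumptions: the grid is a boolean[][] of walls indexed [x][y] in tile units, all eight neighbours are allowed (like the original getAdjacentNodes, diagonal moves may cut past wall corners), and every class and method name is hypothetical.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Deque;
import java.util.List;
import java.util.PriorityQueue;

// Sketch: A* that stores g and parent per tile, so a cheaper route always updates the same record.
class GridAStar {
    static final int[][] DIRS = {
        {1, 1}, {1, 0}, {1, -1}, {0, -1}, {-1, -1}, {-1, 0}, {-1, 1}, {0, 1}
    };

    // Same 10/14 octile distance as getDistanceTo(), in tile units
    static int octile(int x1, int y1, int x2, int y2) {
        int dx = Math.abs(x1 - x2), dy = Math.abs(y1 - y2);
        return 14 * Math.min(dx, dy) + 10 * (Math.max(dx, dy) - Math.min(dx, dy));
    }

    // walls[x][y] == true means impassable. Returns tiles from start to end, or an empty list.
    static List<int[]> findPath(boolean[][] walls, int sx, int sy, int ex, int ey) {
        int w = walls.length, h = walls[0].length;
        int[][] g = new int[w][h];
        for (int[] col : g) Arrays.fill(col, Integer.MAX_VALUE);
        int[][][] parent = new int[w][h][];   // null means "no parent yet"
        boolean[][] closed = new boolean[w][h];

        // Queue entries are {f, x, y}; outdated entries are skipped when polled (lazy deletion)
        PriorityQueue<int[]> open = new PriorityQueue<>((a, b) -> Integer.compare(a[0], b[0]));
        g[sx][sy] = 0;
        open.add(new int[] {octile(sx, sy, ex, ey), sx, sy});

        while (!open.isEmpty()) {
            int[] top = open.poll();
            int x = top[1], y = top[2];
            if (closed[x][y]) continue;        // an older, worse entry for this tile
            closed[x][y] = true;
            if (x == ex && y == ey) return rebuild(parent, ex, ey);

            for (int[] d : DIRS) {
                int nx = x + d[0], ny = y + d[1];
                if (nx < 0 || ny < 0 || nx >= w || ny >= h) continue;
                if (walls[nx][ny] || closed[nx][ny]) continue;
                int newG = g[x][y] + (d[0] != 0 && d[1] != 0 ? 14 : 10);
                if (newG < g[nx][ny]) {        // compared against the tile's one stored g
                    g[nx][ny] = newG;
                    parent[nx][ny] = new int[] {x, y};
                    open.add(new int[] {newG + octile(nx, ny, ex, ey), nx, ny});
                }
            }
        }
        return new ArrayList<>();              // end is unreachable
    }

    // Walks parent links back from the end tile to the start tile
    static List<int[]> rebuild(int[][][] parent, int x, int y) {
        Deque<int[]> path = new ArrayDeque<>();
        int[] cur = {x, y};
        while (cur != null) {
            path.addFirst(cur);
            cur = parent[cur[0]][cur[1]];
        }
        return new ArrayList<>(path);
    }

    // Total cost of a path using the same 10/14 step costs
    static int cost(List<int[]> path) {
        int total = 0;
        for (int i = 1; i < path.size(); i++) {
            int[] a = path.get(i - 1), b = path.get(i);
            total += octile(a[0], a[1], b[0], b[1]);
        }
        return total;
    }

    public static void main(String[] args) {
        boolean[][] walls = new boolean[3][3];
        walls[1][0] = true;
        walls[1][1] = true;
        List<int[]> path = findPath(walls, 0, 0, 2, 0);
        System.out.println(cost(path)); // 48: the only gap in the wall is tile (1, 2)
    }
}
```

Because the same tile can be queued more than once with different f values, entries for tiles that are already closed are simply skipped when polled; this avoids having to change an element's priority inside the PriorityQueue, which java.util.PriorityQueue does not support directly.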


Source: stackoverflow