Long Refactoring: Completing the Iterator

Now that the algorithm does not need a new test_board every time, we no longer need to treat the Node object as a Flyweight. We can move the board into the Node object and remove it from the parameter list of all the functions that operate on it.

Changing the Node object parameter list requires changing all the places it is called. The unit test helps here by identifying the places I break when I modify the function declaration.

______________________________________________________ ERROR collecting treesudoku/test_tree_soduku.py _______________________________________________________
treesudoku/test_tree_soduku.py:1: in <module>
    from treesudoku import tree_sudoku
treesudoku/tree_sudoku.py:245: in <module>
    solver = SudokuSolver(import_csv())
treesudoku/tree_sudoku.py:34: in __init__
    return_string = self.tree_to_solution_string(value)
treesudoku/tree_sudoku.py:38: in tree_to_solution_string
    head_node = Tree_Node(None, 0)
E   TypeError: __init__() missing 1 required positional argument: 'table'

Here are the places I had to change. Note that the new parameter is not yet used. We just know, logically, that the parameter passed in to the functions is always going to match the member.

diff --git a/treesudoku/test_tree_soduku.py b/treesudoku/test_tree_soduku.py
index be4b95b..1266e11 100644
--- a/treesudoku/test_tree_soduku.py
+++ b/treesudoku/test_tree_soduku.py
@@ -50,7 +50,7 @@ def test_sudoku_solver():
 
 def test_advance():
     test_board = tree_sudoku.build_board(puzzle0)
-    node = tree_sudoku.Tree_Node(None, 0)
+    node = tree_sudoku.Tree_Node(None, 0, test_board)
     node.write(test_board)
     assert(test_board[0][0] == '9')
     node = node.advance(test_board)
diff --git a/treesudoku/tree_sudoku.py b/treesudoku/tree_sudoku.py
index 89096d7..7d288ce 100644
--- a/treesudoku/tree_sudoku.py
+++ b/treesudoku/tree_sudoku.py
@@ -35,9 +35,9 @@ class SudokuSolver:
             self.solved_board_strings[key] = return_string
 
     def tree_to_solution_string(self, original_board):
-        head_node = Tree_Node(None, 0)
-        curr_node = head_node
         test_board = copy.deepcopy(original_board)
+        head_node = Tree_Node(None, 0, test_board)
+        curr_node = head_node
         curr_node.check_solved(test_board)
         while True:
             if not curr_node.solved:
@@ -187,12 +187,13 @@ board_index = BoardIndexTable()
 
 
 class Tree_Node:
-    def __init__(self, last_node, index):
+    def __init__(self, last_node, index, board):
         self.board_spot = board_index.table[index]
         self.last_node = last_node
         self.next_node = None
         self.solved = False
         self.index = index
+        self.board = board
         self.reset()
 
     def reset(self):
@@ -223,7 +224,7 @@ class Tree_Node:
     def advance(self, test_board):
         node = self
         if node.next_node is None:
-            new_node = Tree_Node(node, self.index + 1)
+            new_node = Tree_Node(node, self.index + 1, self.board)
             new_node.check_solved(test_board)
             node.next_node = new_node
         return node.next_node

The next step is to wire up the functions to use the member variable instead of the parameter. There are two ways to do this: change one function body, parameter list, and calling locations all at the same time, or change the function bodies first, and then each of the places that call it. I prefer the latter, as it lets me keep the code running successfully with fewer breaking changes between stable points.

This is a long change.

--- a/treesudoku/test_tree_soduku.py
+++ b/treesudoku/test_tree_soduku.py
@@ -50,19 +50,19 @@ def test_sudoku_solver():
 
 def test_advance():
     test_board = tree_sudoku.build_board(puzzle0)
-    node = tree_sudoku.Tree_Node(None, 0)
-    node.write(test_board)
+    node = tree_sudoku.Tree_Node(None, 0, test_board)
+    node.write()
     assert(test_board[0][0] == '9')
-    node = node.advance(test_board)
-    node = node.advance(test_board)
-    node.write(test_board)
+    node = node.advance()
+    node = node.advance()
+    node.write()
     assert(test_board[0][3] == '0')
-    node = node.advance(test_board)
-    node.write(test_board)
+    node = node.advance()
+    node.write()
     assert(test_board[0][3] == '9')
-    back_node = node.retreat(test_board)
+    back_node = node.retreat()
     assert(test_board[0][3] == '0')
     assert(node.value == "9")
-    back_node.write(test_board)
+    back_node.write()
     assert(test_board[0][2] == '3')
     assert(back_node.board_spot == '02')
diff --git a/treesudoku/tree_sudoku.py b/treesudoku/tree_sudoku.py
index 89096d7..00ee60c 100644
--- a/treesudoku/tree_sudoku.py
+++ b/treesudoku/tree_sudoku.py
@@ -35,24 +35,24 @@ class SudokuSolver:
             self.solved_board_strings[key] = return_string
 
     def tree_to_solution_string(self, original_board):
-        head_node = Tree_Node(None, 0)
-        curr_node = head_node
         test_board = copy.deepcopy(original_board)
-        curr_node.check_solved(test_board)
+        head_node = Tree_Node(None, 0, test_board)
+        curr_node = head_node
+        curr_node.check_solved()
         while True:
             if not curr_node.solved:
-                curr_node.write(test_board)
+                curr_node.write()
             if self.box_index.is_value_valid(test_board, curr_node):
                 if curr_node.index + 1 >= MAX:
                     break
-                curr_node = curr_node.advance(test_board)
-                curr_node.check_solved(test_board)
+                curr_node = curr_node.advance()
+                curr_node.check_solved()
             else:
                 # backtrack
                 while len(curr_node.possible_values) == 0:
-                    curr_node = curr_node.retreat(test_board)
+                    curr_node = curr_node.retreat()
                 curr_node.next()
-                curr_node.write(test_board)
+                curr_node.write()
         return self.build_solution_string(head_node)
 
     def build_solution_string(self, head_node):
@@ -187,12 +187,13 @@ board_index = BoardIndexTable()
 
 
 class Tree_Node:
-    def __init__(self, last_node, index):
+    def __init__(self, last_node, index, board):
         self.board_spot = board_index.table[index]
         self.last_node = last_node
         self.next_node = None
         self.solved = False
         self.index = index
+        self.board = board
         self.reset()
 
     def reset(self):
@@ -205,36 +206,36 @@ class Tree_Node:
     def __str__(self):
         return self.value
 
-    def write(self, board):
+    def write(self):
         row = int(self.board_spot[0])
         col = int(self.board_spot[1])
-        board[row][col] = self.value
+        self.board[row][col] = self.value
 
-    def check_solved(self, board):
+    def check_solved(self):
         if self.solved:
             return
         row = int(self.board_spot[0])
         col = int(self.board_spot[1])
-        if board[row][col] != '0':
-            self.value = board[row][col]
+        if self.board[row][col] != '0':
+            self.value = self.board[row][col]
             self.possible_values = []
             self.solved = True
 
-    def advance(self, test_board):
+    def advance(self):
         node = self
         if node.next_node is None:
-            new_node = Tree_Node(node, self.index + 1)
-            new_node.check_solved(test_board)
+            new_node = Tree_Node(node, self.index + 1, self.board)
+            new_node.check_solved()
             node.next_node = new_node
         return node.next_node
 
-    def retreat(self, test_board):
+    def retreat(self):
         node = self
         node.reset()
         if not self.solved:
             row = int(self.board_spot[0])
             col = int(self.board_spot[1])
-            test_board[row][col] = '0'
+            self.board[row][col] = '0'
         node = self.last_node
         node.next_node = None
         return node

Here is our current algorithmic function:

    def tree_to_solution_string(self, original_board):
        """Solve a board by backtracking over a linked chain of nodes.

        Works on a deep copy of ``original_board`` so the caller's board is
        left untouched.  Starting from a head node at index 0, each node
        writes its current value to the working board; when the value is
        valid the search advances to the next node, otherwise it backtracks.

        Returns whatever ``self.build_solution_string`` produces for the
        head node once the last index has been reached.
        """
        test_board = copy.deepcopy(original_board)
        head_node = Tree_Node(None, 0, test_board)
        curr_node = head_node
        curr_node.check_solved()
        while True:
            # Nodes flagged as solved are never written to the board.
            if not curr_node.solved:
                curr_node.write()
            if self.box_index.is_value_valid(test_board, curr_node):
                # A valid value on the last index means the board is complete.
                if curr_node.index + 1 >= MAX:
                    break
                curr_node = curr_node.advance()
                curr_node.check_solved()
            else:
                # backtrack: retreat past every node with no values left to
                # try, then move the surviving node on to its next candidate.
                while len(curr_node.possible_values) == 0:
                    curr_node = curr_node.retreat()
                curr_node.next()
                curr_node.write()
        return self.build_solution_string(head_node)

One thing that occurred to me when I reviewed this is that whenever we create a new node, we immediately check whether it is solved. Node creation is done at the top of the function and in the advance function. If we move the check_solved call into the constructor, we can treat it as an invariant: every node knows whether it is solved as soon as it exists.

--- a/treesudoku/tree_sudoku.py
+++ b/treesudoku/tree_sudoku.py
@@ -38,7 +38,6 @@ class SudokuSolver:
         test_board = copy.deepcopy(original_board)
         head_node = Tree_Node(None, 0, test_board)
         curr_node = head_node
-        curr_node.check_solved()
         while True:
             if not curr_node.solved:
                 curr_node.write()
@@ -46,7 +45,6 @@ class SudokuSolver:
                 if curr_node.index + 1 >= MAX:
                     break
                 curr_node = curr_node.advance()
-                curr_node.check_solved()
             else:
                 # backtrack
                 while len(curr_node.possible_values) == 0:
@@ -195,6 +193,7 @@ class Tree_Node:
         self.index = index
         self.board = board
         self.reset()
+        self.check_solved()
 
     def reset(self):
         self.value = '9'
@@ -225,7 +224,6 @@ class Tree_Node:
         node = self
         if node.next_node is None:
             new_node = Tree_Node(node, self.index + 1, self.board)
-            new_node.check_solved()
             node.next_node = new_node
         return node.next_node

Something else that is now apparent: we are writing the value both at the bottom and the top of the loop, and that is redundant. We should be able to remove the write at the bottom.

diff --git a/treesudoku/tree_sudoku.py b/treesudoku/tree_sudoku.py
index eadd374..0c7b16c 100644
--- a/treesudoku/tree_sudoku.py
+++ b/treesudoku/tree_sudoku.py
@@ -50,7 +50,6 @@ class SudokuSolver:
                 while len(curr_node.possible_values) == 0:
                     curr_node = curr_node.retreat()
                 curr_node.next()
-                curr_node.write()
         return self.build_solution_string(head_node)
 
     def build_solution_string(self, head_node):

This is the common pattern of a long refactoring. You tease apart, and then you simplify: extract a method in one place, and inline a method somewhere else if applicable. The algorithm is now fairly straightforward and understandable, and I would feel that the effort put in thus far would be justified in a production project. But more can certainly be done.

Leave a Reply

Your email address will not be published. Required fields are marked *

This site uses Akismet to reduce spam. Learn how your comment data is processed.