Refactoring

Motivation

Software development often involves rethinking earlier decisions.

This ultimately leads to reworking portions of the code.

Code is never final; it needs to evolve.

Why does it evolve?

  • Integrating new features
  • Optimizing current functionalities
  • Extending tests
  • Commenting your code?!

Common scenario

Almost every time you change the source code, you end up rewriting, reworking, or restructuring it. Is that refactoring?

Definition

“A disciplined technique for restructuring an existing body of code, altering its internal structure without changing its external behavior.” (Martin Fowler)

Source: Martin Fowler, Refactoring: Improving the Design of Existing Code, 2nd ed. Addison-Wesley, Boston, MA, 2019.

Why is it needed?

  • Improves software design
    • Easier to understand
    • More generalizable
    • Easier to maintain
  • Helps prevent bugs

It is easier to clean a room when only a few things are out of place.

When should we refactor?

Rule of Three.

  • First time: just make it work
  • Second time: do it again, even if the duplication makes you wince 🤨
  • Third time: refactor it 👷
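A minimal sketch of the Rule of Three, using a hypothetical name-cleaning snippet (normalize_name is an illustrative name, not part of the course code): the same logic is written inline twice, and on the third occurrence it is extracted into a function.

```python
# Hypothetical example: the same clean-up logic is needed in several places.

# First time: just make it work.
first = "  Alice ".strip().title()

# Second time: the duplication is noticeable, but we tolerate it.
second = " bob".strip().title()


# Third time: refactor the repeated logic into one named function.
def normalize_name(raw: str) -> str:
    """Strip surrounding whitespace and capitalize each word."""
    return raw.strip().title()


# The earlier call sites can now reuse the helper as well.
names = [normalize_name(raw) for raw in ("  Alice ", " bob", "CAROL  ")]
print(names)  # ['Alice', 'Bob', 'Carol']
```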

How do we know it's refactoring?

  • External behaviour must not change!
  • (Unit) tests have to be in place!
    • otherwise it is a gamble 😵
  • It happens naturally, as part of everyday development
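A minimal sketch of what "tests in place" can look like, assuming pytest and two hypothetical versions of a mean function: the same test must pass before and after the refactoring, because only the internal structure changes.

```python
def mean_before(values):
    # Original version: manual accumulation.
    total = 0
    for value in values:
        total += value
    return total / len(values)


def mean_after(values):
    # Refactored version: same external behaviour, simpler internals.
    return sum(values) / len(values)


def test_mean_behaviour_is_unchanged():
    # Run with pytest; any mismatch means the change was not a pure refactoring.
    for values in ([1, 2, 3], [5], [-1, 1], [0.5, 0.25]):
        assert mean_before(values) == mean_after(values)
```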

What to refactor?

Code Smells

  • Poor naming
  • Duplicated code
  • Mixed responsibilities

Types of refactoring

  • Rename (function, variable,…)
  • Replace
  • Extract
  • Inline
  • Move

Rename

Variables and functions with meaningful names make the code easier to understand.

def hw():
    print("Hello world!")


t = 3
for i in range(t):
    hw()

After renaming:

def helloworld():  # renamed function
    print("Hello world!")


num_times = 3  # renamed variable
for i in range(num_times):
    helloworld()

Replace

Replace a block of code with an equivalent one, e.g. repeated calls with a loop.

def hello(something: str | int | None = None):
    if something:
        print(f"Hello {something}!")
    else:
        print("Hello!")


hello()
hello("world")
hello("everyone")
hello("class")
hello("humans")
hello(2025)

Replaced with a loop:

def hello(something: str | int | None = None):
    if something:
        print(f"Hello {something}!")
    else:
        print("Hello!")


for word in (None, "world", "everyone", "class", "humans", 2025):
    hello(word)
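One way to check that this replacement keeps the external behaviour is to capture what both versions print and compare the text. The sketch below uses the standard-library contextlib.redirect_stdout; the wrapper names calls_before, calls_after and captured_output are illustrative, not part of the course code.

```python
import io
from contextlib import redirect_stdout


def hello(something=None):
    if something:
        print(f"Hello {something}!")
    else:
        print("Hello!")


def calls_before():
    # Original: one explicit call per argument.
    hello()
    hello("world")
    hello("everyone")
    hello("class")
    hello("humans")
    hello(2025)


def calls_after():
    # Replacement: a loop over the same arguments.
    for word in (None, "world", "everyone", "class", "humans", 2025):
        hello(word)


def captured_output(func) -> str:
    """Run func and return everything it printed to stdout."""
    buffer = io.StringIO()
    with redirect_stdout(buffer):
        func()
    return buffer.getvalue()


assert captured_output(calls_before) == captured_output(calls_after)
```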

Extract Refactoring

Divide and conquer

prompt = """Welcome to the GSS course!
Today's lecture is about refactoring.
What is your name? """
reply = input(prompt)
print(f"Hello {reply.strip()}!")

Example run:

Welcome to the GSS course!
Today's lecture is about refactoring.
What is your name? Bob
Hello Bob!

The script mixes three responsibilities:

  • Welcome
  • Question
  • Greeting

Each one is extracted into its own function:

def welcome():
    message = """Welcome to the GSS course!
Today's lecture is about refactoring."""
    print(message)


def ask_name():
    name = input("What is your name? ")
    return name.strip()


def say_hello(name=""):
    print(f"Hello {name}!")


welcome()
name = ask_name()
say_hello(name=name)
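A side benefit of extracting is that each piece can now be tested on its own. A minimal sketch, assuming pytest-style test functions (the test names are hypothetical), the standard-library unittest.mock.patch to fake the keyboard input, and contextlib.redirect_stdout to capture the printed greeting:

```python
import io
from contextlib import redirect_stdout
from unittest.mock import patch


def ask_name():
    name = input("What is your name? ")
    return name.strip()


def say_hello(name=""):
    print(f"Hello {name}!")


def test_ask_name_strips_whitespace():
    # Replace the built-in input() so the test does not wait for a user.
    with patch("builtins.input", return_value="  Bob  "):
        assert ask_name() == "Bob"


def test_say_hello_greets_by_name():
    # Capture what say_hello prints instead of checking the terminal.
    buffer = io.StringIO()
    with redirect_stdout(buffer):
        say_hello("Bob")
    assert buffer.getvalue() == "Hello Bob!\n"
```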

Inline Refactoring

It’s the opposite of extract! If a piece of code is only used once, there may be little point in keeping it wrapped in its own function.

def welcome():
    message = """Welcome to the GSS course!
Today's lecture is about refactoring."""
    print(message)


def ask_name():
    name = input("What is your name? ")
    return name.strip()


def say_hello(name=""):
    print(f"Hello {name}!")


welcome()
name = ask_name()
say_hello(name=name)

After inlining welcome and ask_name into a single function:

def welcome_ask_name():
    message = """Welcome to the GSS course!
Today's lecture is about refactoring."""
    print(message)
    name = input("What is your name? ")
    return name.strip()


def say_hello(name=""):
    print(f"Hello {name}!")


name = welcome_ask_name()
say_hello(name=name)
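Inlining also applies to variables (Fowler calls this Inline Variable): in welcome_ask_name, message is used only once, so its value can be passed to print directly. A minimal sketch, with a check that uses the standard-library unittest.mock.patch and contextlib.redirect_stdout to confirm the printed text and the return value stay the same:

```python
import io
from contextlib import redirect_stdout
from unittest.mock import patch


def welcome_ask_name():
    # Inline variable: the single-use `message` is replaced by its value.
    print("Welcome to the GSS course!\nToday's lecture is about refactoring.")
    name = input("What is your name? ")
    return name.strip()


# Simulate a user typing " Bob " and capture what the function prints.
buffer = io.StringIO()
with patch("builtins.input", return_value=" Bob "), redirect_stdout(buffer):
    name = welcome_ask_name()

assert name == "Bob"
assert buffer.getvalue() == (
    "Welcome to the GSS course!\nToday's lecture is about refactoring.\n"
)
```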

Moving

Group related functionality in a module, e.g. move is_prime to its own file.

myscript.py
def is_prime(n: int):
    if n < 2:  # 0 and 1 are not prime
        return False
    for i in range(2, n):
        if n % i == 0:
            return False
    return True

limit = int(
    input("""Printing first N primes.
Enter N:""")
)
count, number = (0, 0)
while count < limit:
    if is_prime(number):
        count += 1
        print(number)
    number += 1

After moving is_prime to its own module:

primes.py
def is_prime(n: int):
    if n < 2:  # 0 and 1 are not prime
        return False
    for i in range(2, n):
        if n % i == 0:
            return False
    return True

myscript.py
from primes import is_prime

limit = int(
    input("""Printing first N primes.
Enter N:""")
)
count, number = (0, 0)
while count < limit:
    if is_prime(number):
        count += 1
        print(number)
    number += 1
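Moving code often goes hand in hand with extracting it. As a sketch, the prime-collecting loop could be extracted into a hypothetical first_primes function and moved to primes.py as well, leaving myscript.py responsible only for input and output. The is_prime below includes a guard for n < 2, since 0 and 1 are not prime:

```python
def is_prime(n: int) -> bool:
    """Return True if n is a prime number."""
    if n < 2:  # 0 and 1 are not prime
        return False
    for i in range(2, n):
        if n % i == 0:
            return False
    return True


def first_primes(limit: int) -> list[int]:
    """Return the first `limit` primes (hypothetical helper for primes.py)."""
    primes = []
    number = 2
    while len(primes) < limit:
        if is_prime(number):
            primes.append(number)
        number += 1
    return primes


print(first_primes(5))  # [2, 3, 5, 7, 11]
```

With this split, myscript.py would only need to import first_primes from primes, read N, and print each element of first_primes(N).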

Refactoring examples

def fibonacci(n):
    if n <= 0:
        return 0
    elif n == 1:
        return 1
    else:
        previous = 0
        current = 1
        for i in range(n - 1):
            ith = previous + current
            previous = current
            current = ith
        return current

The Fibonacci sequence is defined by the recurrence relation:

\[ F(n) = F(n-1) + F(n-2) \]

with initial values:

\[ F(0) = 0,\quad F(1) = 1 \]

which gives the sequence 0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, …

Refactoring examples

def fibonacci(n):
    if n <= 0:
        return 0
    elif n == 1:
        return 1
    else:
        previous = 0
        current = 1
        for i in range(n - 1):
            ith = previous + current
            previous = current
            current = ith
        return current

Without the redundant else (every branch above it already returns):

def fibonacci(n):
    if n <= 0:
        return 0
    previous = 0
    current = 1
    for i in range(n - 1):
        ith = previous + current
        previous = current
        current = ith
    return current

Refactoring examples

def fibonacci(n):
    if n <= 0:
        return 0
    previous = 0
    current = 1
    for i in range(n - 1):
        ith = previous + current
        previous = current
        current = ith
    return current

With tuple assignment instead of the temporary variable ith:

def fibonacci(n):
    if n <= 0:
        return 0
    previous, current = 0, 1
    for i in range(n - 1):
        previous, current = current, previous + current
    return current

Refactoring examples

def fibonacci(n):
    """
    Computes the Nth number of the Fibonacci Sequence
    """
    if n <= 0:
        return 0
    previous, current = 0, 1
    for i in range(n - 1):
        previous, current = current, previous + current
    return current

>>> help(fibonacci)
Help on function fibonacci in module fib:

fibonacci(n)
    Computes the Nth number of the Fibonacci Sequence

Refactoring examples

def fibonacci(n):
    if n < 2:
        return n
    return fibonacci(n - 2) + fibonacci(n - 1)

What was missing?

Development approach

  • Write software that solves current needs
    • Only add functionality when you need it
    • YAGNI: you aren’t going to need it
  • New code? New tests!!!
    • Test-driven development (TDD) is strongly encouraged
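A minimal sketch of the test-driven cycle, using a hypothetical is_leap_year function: the test is written first and fails, then just enough code is written to make it pass, and only then is the code refactored, re-running the test after each step.

```python
# Step 1 (red): write the test first, describing only the behaviour needed now.
def test_is_leap_year():
    assert is_leap_year(2024)
    assert not is_leap_year(2023)
    assert not is_leap_year(1900)  # divisible by 100 but not by 400
    assert is_leap_year(2000)      # divisible by 400


# Step 2 (green): write just enough code to make the test pass (YAGNI).
def is_leap_year(year: int) -> bool:
    return year % 4 == 0 and (year % 100 != 0 or year % 400 == 0)


# Step 3 (refactor): clean up freely, re-running the test after every change.
```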

Refactoring on a team

  • Each member should be able to refactor independently
  • Continuous Integration (CI) shows results by:
    • automatically running the (unit) tests
    • pointing out where the problems are

Hands-on

  • Find code smells in report.py.
  • Refactor the function grades_report.
  • Run test_grades_report.py (with pytest) to make sure the behaviour is unchanged.

Hands-on

report.py
def get_grades():
    students = list()
    grades = list()
    while True:
        name = input("Insert student name (quit to finish):")
        if "quit" in name:
            return students, grades
        grade = int(input(f"Insert {name}'s grade:"))
        students.append(name)
        grades.append(grade)


def grades_report(students, grades):
    average = 0
    maxgrade = 0
    mingrade = 100
    counter = 0  # do we need this variable?
    for grade in grades:  # can this be improved?
        counter += 1  # len(grades) will give the number of students

    # do we need this? we can just return avg, max and min
    report = {"students": students, "grades": grades, "size": counter}
    for i in range(counter):
        grade = grades[i]

        average += grade

        if maxgrade < grade:
            maxgrade = grade
        if mingrade > grade:
            mingrade = grade

    if counter > 0:
        average = average / counter
    else:
        average = None
        maxgrade = None
        mingrade = None

    return average, maxgrade, mingrade


def format_report(students, grades, average, maxgrade, mingrade):
    report = "Report of " + str(len(students)) + " students:\n"
    for i in range(len(students)):
        name = students[i]
        grade = grades[i]
        report += str(name) + " -> " + str(grade) + "\n"
    report += "  Average: " + str(average) + "\n"
    report += "  Maximum: " + str(maxgrade) + "\n"
    report += "  Minimum: " + str(mingrade) + "\n"
    return report


if __name__ == "__main__":
    students, grades = get_grades()
    average, maxgrade, mingrade = grades_report(students, grades)
    report = format_report(students, grades, average, maxgrade, mingrade)
    print(report)

test_grades_report.py
from report import grades_report, format_report


def test_empty_report():
    students = list()
    grades = list()
    assert (None, None, None) == grades_report(students, grades)


def test_grades_report():
    students = ["Alice", "Bob"]
    grades = [1, 2]
    assert len(students) == len(grades)
    size = len(students)
    assert (
        1.5,
        2,
        1,
    ) == grades_report(students, grades)


def test_format_report():
    students = ["Alice", "Bob"]
    grades = [1, 2]

    expected = (
        "Report of 2 students:\n"
        "Alice -> 1\n"
        "Bob -> 2\n"
        "  Average: 1.5\n"
        "  Maximum: 2\n"
        "  Minimum: 1\n"
    )

    report = format_report(students, grades, 1.5, 2, 1)
    assert report == expected

Hands-on

$ python3 report.py
Insert student name (quit to finish):Alice
Insert Alice's grade:2
Insert student name (quit to finish):Bob
Insert Bob's grade:1
Insert student name (quit to finish):quit
Report of 2 students:
Alice -> 2
Bob -> 1
  Average: 1.5
  Maximum: 2
  Minimum: 1

$ pytest -v test_grades_report.py
================== test session starts ==================
platform linux -- Python 3.10.12, pytest-8.3.5
plugins: anyio-4.4.0
collected 3 items
test_grades_report.py::test_empty_report PASSED    [ 33%]
test_grades_report.py::test_grades_report PASSED   [ 66%]
test_grades_report.py::test_format_report PASSED   [100%]

=================== 3 passed in 0.01s ===================
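One possible refactoring of grades_report (and, as a bonus, format_report), sketched under the assumption that the function signatures must stay the same so that test_grades_report.py keeps passing unchanged. It is not the only valid answer:

```python
def grades_report(students, grades):
    """Return (average, maximum, minimum) of grades, or three Nones if empty."""
    # `students` is unused but kept so existing callers and tests still work.
    if not grades:
        return None, None, None
    return sum(grades) / len(grades), max(grades), min(grades)


def format_report(students, grades, average, maxgrade, mingrade):
    lines = [f"Report of {len(students)} students:"]
    lines += [f"{name} -> {grade}" for name, grade in zip(students, grades)]
    lines += [
        f"  Average: {average}",
        f"  Maximum: {maxgrade}",
        f"  Minimum: {mingrade}",
    ]
    return "\n".join(lines) + "\n"
```

Two caveats: the original grades_report starts from 0 and 100, so it reports a wrong maximum or minimum for grades outside that range; using max and min fixes that, which is strictly a bug fix rather than a refactoring. Also, zip stops at the shorter list, whereas the original format_report raises an IndexError when grades is shorter than students.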

Legacy Code

  • Dramatic: code becomes legacy code as soon as it’s written
  • General: code inherited from someone else
  • Complement: code without tests (especially dangerous!!!)

How to deal with legacy code

  • Read it
  • Create tests (extensively)
    • Characterization tests are crucial!
  • Refactor gradually
  • Write a new interface around it (if the existing one cannot change)
  • Deprecate or remove dead code
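A characterization test records what legacy code currently does, not what it should do, so that any later change in behaviour is detected. A minimal sketch, assuming pytest and using a version of is_prime without a guard for small n (renamed legacy_is_prime here, a hypothetical stand-in for inherited code): its quirk of reporting 0 and 1 as prime is pinned down deliberately until someone decides to fix it.

```python
def legacy_is_prime(n):
    # Stand-in for inherited code: note that it reports 0 and 1 as prime.
    for i in range(2, n):
        if n % i == 0:
            return False
    return True


def test_characterize_legacy_is_prime():
    # Characterization test: the expected values were obtained by running
    # the current code, quirks included, not from a specification.
    observed = {n: legacy_is_prime(n) for n in range(10)}
    assert observed == {
        0: True, 1: True, 2: True, 3: True, 4: False,
        5: True, 6: False, 7: True, 8: False, 9: False,
    }
```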

Take home messages

  • You can only safely refactor tested codebases
    • Write tests before touching legacy code
  • Test-driven development & YAGNI are refactor-friendly
  • Refactor small bits, often

Further Reading