Skip to content

Ví dụ: Chưng cất Tri thức (Knowledge Distillation) trong PyTorch trên MNIST

Spread the love

Ví dụ này minh họa về Chưng cất Tri thức (Knowledge Distillation), một kỹ thuật mà một mô hình “học trò” (student) nhỏ được huấn luyện để bắt chước một mô hình “thầy” (teacher) lớn hơn, đã được huấn luyện trước.

Trước tiên, hãy cùng tìm hiểu sơ lược về Chưng cất Tri thức.

🎓 Chưng cất Tri thức là gì?

Chưng cất Tri thức (Knowledge Distillation – KD) là một kỹ thuật học máy dùng để chuyển giao tri thức từ một mô hình lớn, hiệu suất cao (“thầy”) sang một mô hình nhỏ hơn, hiệu quả hơn (“học trò”).

Mục tiêu: Mục tiêu chính là tạo ra một mô hình nhỏ, nhanh (như cho điện thoại thông minh hoặc thiết bị biên) mà đạt được độ chính xác gần bằng với mô hình thầy lớn hơn, chậm hơn nhiều.

Cách thức hoạt động:

Thay vì chỉ huấn luyện mô hình trò trên các câu trả lời đúng (gọi là “nhãn cứng” – hard labels), mô hình trò được huấn luyện để khớp với toàn bộ quá trình suy nghĩ của mô hình thầy (gọi là “nhãn mềm” – soft labels, chính là lớp logits hay còn gọi là lớp softmax).

  • Nhãn cứng: Câu trả lời là “9”.
  • Nhãn mềm (từ Thầy): “Tôi chắc 90% đây là số 9, 8% chắc là số 7, và 2% chắc là số 4.”

Thông tin “mềm” này, thường được gọi là “tri thức tối” (dark knowledge), phong phú hơn nhiều. Nó dạy cho mô hình trò tại sao số 9 trông hơi giống số 7 hoặc số 4, dẫn đến một mô hình trò thông minh hơn nhiều so với việc chỉ huấn luyện từ đầu trên nhãn cứng.

Tóm lại: Bạn sử dụng một mô hình lớn để “dạy” một mô hình nhỏ, chuyển “trí thông minh” của nó vào một gói hiệu quả hơn nhiều.

Python Code

Trong đoạn mã sau, chúng ta sẽ

  • Định nghĩa một mô hình Thầy (Teacher) lớn hơn và một mô hình Trò (Student) nhỏ hơn.
  • Định nghĩa một hàm tiện ích để đếm tham số mô hình.
  • Huấn luyện mô hình Thầy trên MNIST và đánh giá độ chính xác cao của nó.
  • Định nghĩa hàm mất mát chưng cất (distillation loss function).
  • Huấn luyện mô hình Trò sử dụng chưng cất tri thức, học hỏi từ cả mô hình thầy và nhãn thật.
  • Huấn luyện một mô hình Trò y hệt từ đầu (không chưng cất) để so sánh.
  • In ra độ chính xác cuối cùng và số lượng tham số để cho thấy lợi ích.

Python

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

Đây là các thư viện import tiêu chuẩn cho một dự án PyTorch:

  • torch: Thư viện PyTorch cốt lõi.
  • torch.nn: Chứa tất cả các khối xây dựng cho mạng nơ-ron, như các lớp ( nn.Module, nn.Conv2d, nn.Linear) và các hàm mất mát.
  • torch.nn.functional (as F): Cung cấp các hàm không có tham số học được (learnable parameters), chẳng hạn như các hàm kích hoạt (F.relu) và các hàm mất mát (F.cross_entropy).
  • torch.optim: Bao gồm các thuật toán tối ưu hóa như optim.SGD (Tối ưu hóa Gradient Descent Ngẫu nhiên).
  • torchvision: Một thư viện cho các tác vụ thị giác máy tính. Chúng ta sử dụng datasets để tải MNIST và transforms để tiền xử lý hình ảnh.
  • torch.utils.data.DataLoader: Một tiện ích giúp tải dữ liệu theo lô (batch), xáo trộn (shuffle) và tải song song.

Python

# --- 1. Cài đặt và Siêu tham số ---

# Cài đặt huấn luyện
BATCH_SIZE = 64
EPOCHS_TEACHER = 5
EPOCHS_STUDENT = 10
LEARNING_RATE = 0.01
MOMENTUM = 0.9

# Cài đặt Chưng cất Tri thức (KD)
TEMPERATURE = 10  # Nhiệt độ để làm mềm xác suất
ALPHA = 0.1       # Trọng số cho mất mát "cứng" (nhãn thật)

# Đặt thiết bị
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Phần này định nghĩa tất cả các biến cấu hình cho thử nghiệm.

  • Cài đặt Huấn luyện (Training Settings): Kiểm soát quá trình huấn luyện cơ bản. BATCH_SIZE là số lượng hình ảnh được xử lý trong một bước. EPOCHS_... định nghĩa số lần mô hình nhìn thấy toàn bộ tập dữ liệu. LEARNING_RATEMOMENTUM là các tham số cho thuật toán tối ưu hóa SGD.
  • Cài đặt KD (KD Settings): Các cài đặt này đặc thù cho Chưng cất Tri thức.
    • TEMPERATURE (T): (Nhiệt độ) Một nhiệt độ cao sẽ “làm mềm” (softens) các xác suất đầu ra của mô hình. Ví dụ, thay vì [0.1, 0.9, 0.0], nhiệt độ cao có thể tạo ra [0.25, 0.5, 0.25]. Điều này buộc mô hình trò phải học cách mô hình thầy “suy nghĩ” (ví dụ: “nó khá chắc đây là số 7, nhưng nó cũng trông hơi giống số 9”). Đây được gọi là “tri thức tối”.
    • ALPHA: Biến này cân bằng hai hàm mất mát cho mô hình trò. Mô hình trò được huấn luyện với một mất mát kết hợp: (ALPHA * hard_loss) + ((1 - ALPHA) * soft_loss). Với ALPHA = 0.1, mô hình trò chú ý 10% đến các nhãn thật (hard loss) và 90% chú ý đến việc khớp với các đầu ra đã được làm mềm của mô hình thầy (soft loss).
  • Thiết bị (Device): Mã này kiểm tra xem GPU (cuda) có khả dụng hay không. Nếu có, nó sẽ đặt thiết bị là “cuda” để tăng tốc huấn luyện; ngược lại, nó sử dụng “cpu”.

Python

# --- 2. Tải dữ liệu (MNIST) ---

# Các phép biến đổi MNIST tiêu chuẩn
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)) # Trung bình và độ lệch chuẩn của MNIST
])

# Tải bộ dữ liệu
train_dataset = datasets.MNIST(
    root='./data', train=True, download=True, transform=transform
)
test_dataset = datasets.MNIST(
    root='./data', train=False, download=True, transform=transform
)

# Tạo các data loader
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True
)
test_loader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False
)

Khối này xử lý việc tải và chuẩn bị bộ dữ liệu chữ số viết tay MNIST.

  • transform: Định nghĩa một quy trình tiền xử lý.
    • transforms.ToTensor(): Chuyển đổi hình ảnh đầu vào (là ảnh PIL) thành Tensor PyTorch và co giãn các giá trị pixel của chúng từ phạm vi [0, 255] về phạm vi [0.0, 1.0].
    • transforms.Normalize((0.1307,), (0.3081,)): Chuẩn hóa các giá trị của tensor. Nó trừ đi giá trị trung bình (0.1307) và chia cho độ lệch chuẩn (0.3081) của bộ dữ liệu MNIST. Điều này giúp mô hình huấn luyện ổn định và nhanh hơn.
  • datasets.MNIST: Tải về bộ huấn luyện (train=True) và bộ kiểm tra (train=False) vào thư mục ./data nếu chúng chưa tồn tại. Nó áp dụng transform cho mọi hình ảnh khi tải.
  • DataLoader: Bọc các bộ dữ liệu và biến chúng thành các iterator cung cấp dữ liệu theo lô (batch_size=BATCH_SIZE). shuffle=True cho bộ tải huấn luyện là rất quan trọng; nó ngẫu nhiên hóa thứ tự dữ liệu trong mỗi epoch để ngăn mô hình học theo thứ tự của dữ liệu.

Python

# --- 3. Định nghĩa Mô hình ---

class TeacherNet(nn.Module):
    """Một CNN lớn hơn cho MNIST ('Thầy')"""
    def __init__(self):
        super(TeacherNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        
        # Tính kích thước sau khi làm phẳng (flatten) sau các lớp conv và pool
        # Đầu vào 28x28 -> Conv1 -> 28x28 -> Pool1 -> 14x14
        # -> Conv2 -> 14x14 -> Pool2 -> 7x7
        # Kích thước phẳng = 64 * 7 * 7
        self.fc1 = nn.Linear(64 * 7 * 7, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 7 * 7) # Làm phẳng
        x = F.relu(self.fc1(x))
        x = self.fc2(x) # Xuất ra logits thô
        return x

Đoạn mã này định nghĩa mô hình Thầy. Đây là một Mạng nơ-ron tích chập (CNN) tương đối lớn và phức tạp.

  • __init__: Hàm khởi tạo định nghĩa các lớp.
    • nn.Conv2d: Hai lớp tích chập. Lớp đầu tiên nhận 1 kênh (ảnh xám) và xuất ra 32 kênh. Lớp thứ hai nhận 32 và xuất ra 64. padding=1 giữ nguyên kích thước 28×28.
    • nn.MaxPool2d: Một lớp gộp (pooling) làm giảm một nửa kích thước hình ảnh (28×28 -> 14×14, sau đó 14×14 -> 7×7).
    • nn.Linear: Hai lớp kết nối đầy đủ (dense). Lớp đầu tiên nhận hình ảnh 7×7 đã được làm phẳng với 64 kênh (64 * 7 * 7) và ánh xạ nó tới 256 đặc trưng. Lớp thứ hai ánh xạ 256 đặc trưng tới 10 giá trị đầu ra (gọi là logits), một cho mỗi chữ số (0-9).
  • forward: Phương thức này định nghĩa cách dữ liệu chảy qua các lớp. Đó là một trình tự (Conv -> ReLU -> Pool) hai lần, sau đó là thao tác view để làm phẳng bản đồ đặc trưng 3D thành một vector 1D, tiếp theo là hai lớp tuyến tính.

Python

class StudentNet(nn.Module):
    """Một CNN nhỏ hơn nhiều cho MNIST ('Trò')"""
    def __init__(self):
        super(StudentNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        
        # Đầu vào 28x28 -> Conv1 -> 28x28 -> Pool1 -> 14x14
        # Kích thước phẳng = 16 * 14 * 14
        self.fc1 = nn.Linear(16 * 14 * 14, 32)
        self.fc2 = nn.Linear(32, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = x.view(-1, 16 * 14 * 14) # Làm phẳng
        x = F.relu(self.fc1(x))
        x = self.fc2(x) # Xuất ra logits thô
        return x

def count_parameters(model):
    """Hàm tiện ích để đếm tham số mô hình"""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

Phần này định nghĩa mô hình Trò, là mô hình chúng ta muốn sử dụng cuối cùng (ví dụ: trên thiết bị di động) vì nó nhỏ và nhanh.

  • StudentNet: Mô hình này có cấu trúc tương tự như mô hình thầy nhưng nhỏ hơn và đơn giản hơn nhiều. Nó chỉ có một lớp tích chập (với 16 kênh, so với 32 và 64 của thầy) và các lớp tuyến tính nhỏ hơn (ánh xạ tới 32 đặc trưng, so với 256). Điều này có nghĩa là nó có ít tham số hơn nhiều và sẽ nhanh hơn nhiều.
  • count_parameters: Đây là một hàm trợ giúp (helper function) chỉ đơn giản là đếm tổng số tham số có thể học được (trọng số và bias) trong một mô hình. Hàm này được sử dụng ở cuối để chứng minh rằng mô hình trò thực sự nhỏ hơn nhiều so với mô hình thầy.

Python

# --- 4. Các hàm Huấn luyện và Đánh giá Tiêu chuẩn ---

def train_standard(model, train_loader, optimizer, epoch):
    """Vòng lặp huấn luyện tiêu chuẩn cho một bộ phân loại."""
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()
        output = model(data)
        
        # Mất mát Cross-Entropy Tiêu chuẩn
        loss = F.cross_entropy(output, target)
        
        loss.backward()
        optimizer.step()
        
    print(f"Train Epoch: {epoch} \tLoss: {loss.item():.6f}")

Đây là một hàm huấn luyện tiêu chuẩn. Nó sẽ được sử dụng để huấn luyện mô hình Thầy và mô hình Trò “cơ sở” (baseline) (mô hình được huấn luyện mà không có chưng cất).

  • model.train(): Đặt mô hình ở chế độ “training” (kích hoạt các lớp như Dropout, nếu có).
  • data, target = data.to(device), target.to(device): Di chuyển lô hình ảnh và nhãn sang GPU/CPU.
  • optimizer.zero_grad(): Xóa mọi gradient cũ từ bước trước đó.
  • output = model(data): Thực hiện một lượt truyền xuôi (forward pass) để lấy các dự đoán (logits) của mô hình.
  • loss = F.cross_entropy(output, target): Tính toán mất mát cross-entropy. Mất mát này so sánh logits của mô hình với các nhãn thật (ví dụ: “ảnh này là số 7”).
  • loss.backward(): Thực hiện lan truyền ngược (backpropagation), tính toán gradient của mất mát đối với mọi tham số mô hình.
  • optimizer.step(): Cập nhật các tham số (trọng số) của mô hình bằng cách sử dụng gradient và thuật toán của trình tối ưu hóa (SGD).

Python

def evaluate(model, test_loader):
    """Vòng lặp đánh giá tiêu chuẩn."""
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    print(f'Test set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.2f}%)')
    return accuracy

Hàm này đánh giá hiệu suất của mô hình trên tập dữ liệu kiểm tra.

  • model.eval(): Đặt mô hình ở chế độ “đánh giá” (tắt Dropout, v.v.).
  • with torch.no_grad(): Đây là một tối ưu hóa quan trọng. Nó báo cho PyTorch không theo dõi gradient, giúp tiết kiệm bộ nhớ và tăng tốc độ tính toán vì chúng ta chỉ đang kiểm tra, không huấn luyện.
  • Bên trong vòng lặp, nó tính tổng mất mát và đếm số lượng dự đoán đúng.
  • pred = output.argmax(dim=1, ...): Lệnh này tìm chỉ số (index) của giá trị logit cao nhất cho mỗi hình ảnh. Chỉ số này là dự đoán cuối cùng của mô hình (ví dụ: 0, 1, 2… 9).
  • correct += ...: Nó so sánh các dự đoán của mô hình (pred) với các nhãn thật (target) và cộng dồn các dự đoán đúng.
  • Cuối cùng, nó in ra mất mát trung bình và tỷ lệ phần trăm chính xác tổng thể.

Python

# --- 5. Mất mát Chưng cất Tri thức và Huấn luyện ---

def distillation_loss(student_logits, labels, teacher_logits, T, alpha):
    """
    Tính toán mất mát chưng cất tri thức.
    ...
    """
    
    # 1. Mất mát Chưng cất (Soft Loss)
    # Sử dụng KLDivLoss, yêu cầu log-probabilities (log_softmax) làm đầu vào
    # và probabilities (softmax) làm mục tiêu.
    
    soft_loss = nn.KLDivLoss(reduction='batchmean')(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1)
    ) * (T * T) # Nhân mất mát với T^2 như đề xuất trong bài báo gốc

    # 2. Mất mát của Học trò (Hard Loss)
    # Mất mát cross-entropy tiêu chuẩn giữa logits của trò và nhãn thật
    hard_loss = F.cross_entropy(student_logits, labels)

    # 3. Kết hợp các mất mát
    combined_loss = alpha * hard_loss + (1 - alpha) * soft_loss
    return combined_loss

Đây là phần cốt lõi của logic Chưng cất Tri thức. Hàm mất mát tùy chỉnh này kết hợp hai mất mát riêng biệt:

  1. Soft Loss (Mất mát Mềm): Đây là phần “chưng cất”.
    • student_logits / Tteacher_logits / T: Đầu ra của cả hai mô hình được chia cho TEMPERATURE (T=10) để “làm mềm” chúng.
    • F.log_softmaxF.softmax: Chúng ta tạo ra các phân phối xác suất từ các logits đã được làm mềm.
    • nn.KLDivLoss: Mất mát Phân kỳ Kullback-Leibler (Kullback-Leibler Divergence Loss). Nó đo lường mức độ khác biệt giữa phân phối xác suất của mô hình trò và của mô hình thầy. Mục tiêu là giảm thiểu sự khác biệt này, buộc mô hình trò phải bắt chước “quá trình suy nghĩ” của mô hình thầy.
    • * (T * T): Hệ số nhân này là một phần của bài báo chưng cất gốc. Nó cần thiết để điều chỉnh tỷ lệ gradient một cách chính xác, vốn bị giảm đi do quá trình làm mềm.
  2. Hard Loss (Mất mát Cứng): Đây là mất mát F.cross_entropy chúng ta đã sử dụng trước đó. Nó so sánh logits của mô hình trò với các nhãn thật, thực tế.
  3. Combined Loss (Mất mát Kết hợp): Mất mát cuối cùng là tổng có trọng số, được kiểm soát bởi ALPHA. Vì ALPHA = 0.1, mất mát là 10% từ nhãn thật và 90% từ việc bắt chước mô hình thầy.

Python

def train_distillation(student, teacher, train_loader, optimizer, epoch, T, alpha):
    """Vòng lặp huấn luyện cho chưng cất tri thức."""
    student.train()
    teacher.eval() # Thầy ở chế độ eval và trọng số của nó bị đóng băng
    
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()
        
        # Lấy đầu ra của trò
        student_logits = student(data)
        
        # Lấy đầu ra của thầy (với no_grad để đóng băng nó)
        with torch.no_grad():
            teacher_logits = teacher(data)
        
        # Tính toán mất mát chưng cất
        loss = distillation_loss(student_logits, target, teacher_logits, T, alpha)
        
        loss.backward()
        optimizer.step()
        
    print(f"Train Epoch: {epoch} \tKD Loss: {loss.item():.6f}")

Đây là vòng lặp huấn luyện đặc biệt cho mô hình trò được chưng cất.

  • student.train(): Mô hình trò ở chế độ huấn luyện.
  • teacher.eval(): Mô hình thầy ở chế độ đánh giá. Điều này rất quan trọng; chúng ta không muốn huấn luyện mô hình thầy nữa.
  • with torch.no_grad(): Bối cảnh này chỉ bao bọc quanh lượt truyền xuôi (forward pass) của mô hình thầy (teacher(data)). Điều này báo cho PyTorch không tính toán gradient cho mô hình thầy, tiết kiệm tính toán và đảm bảo trọng số của nó được giữ nguyên (đóng băng).
  • loss = distillation_loss(...): Gọi hàm mất mát tùy chỉnh đã định nghĩa ở trên, sử dụng đầu ra của cả hai mô hình.
  • loss.backward(): Lệnh này tính toán gradient, nhưng chỉ cho mô hình trò (vì mô hình thầy nằm trong chế độ no_grad).
  • optimizer.step(): Lệnh này chỉ cập nhật trọng số của mô hình trò.

Python

# --- 6. Thực thi chính ---

# --- Bước A: Huấn luyện mô hình Thầy ---
print("--- 1. Training Teacher Model ---")
teacher_model = TeacherNet().to(device)
optimizer_teacher = optim.SGD(teacher_model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

for epoch in range(1, EPOCHS_TEACHER + 1):
    train_standard(teacher_model, train_loader, optimizer_teacher, epoch)

print("\n--- Evaluating Teacher Model ---")
teacher_acc = evaluate(teacher_model, test_loader)

Đây là Bước A của kịch bản chính.

  1. Một thể hiện của TeacherNet (lớn) được tạo ra và chuyển đến thiết bị.
  2. Một trình tối ưu hóa được tạo ra cho các tham số của mô hình thầy.
  3. Nó huấn luyện mô hình thầy trong EPOCHS_TEACHER (5 epochs) bằng cách sử dụng hàm train_standard (tức là chỉ sử dụng nhãn thật và mất mát cross-entropy tiêu chuẩn).
  4. Nó đánh giá mô hình thầy đã được huấn luyện đầy đủ và lưu độ chính xác cuối cùng của nó vào teacher_acc.

Python

# --- Bước B: Huấn luyện mô hình Trò với Chưng cất Tri thức ---
print("\n--- 2. Training Student with Distillation ---")
student_model_kd = StudentNet().to(device)
optimizer_student_kd = optim.SGD(student_model_kd.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

for epoch in range(1, EPOCHS_STUDENT + 1):
    train_distillation(
        student=student_model_kd,
        teacher=teacher_model,
        train_loader=train_loader,
        optimizer=optimizer_student_kd,
        epoch=epoch,
        T=TEMPERATURE,
        alpha=ALPHA
    )

print("\n--- Evaluating Distilled Student Model ---")
student_kd_acc = evaluate(student_model_kd, test_loader)

Đây là Bước B, phần cốt lõi của thử nghiệm.

  1. Một thể hiện mới của StudentNet (nhỏ) được tạo ra (student_model_kd).
  2. Một trình tối ưu hóa được tạo ra cho các tham số của mô hình trò này.
  3. Nó huấn luyện mô hình trò trong EPOCHS_STUDENT (10 epochs) bằng cách sử dụng hàm đặc biệt train_distillation.
  4. Hàm này truyền vào mô hình thầy đã được huấn luyện, bị đóng băng để hoạt động như người hướng dẫn.
  5. Nó đánh giá mô hình trò “được chưng cất” và lưu độ chính xác của nó vào student_kd_acc.

Python

# --- Bước C: Huấn luyện mô hình Trò từ đầu (để so sánh) ---
print("\n--- 3. Training Student from Scratch (Baseline) ---")
student_model_scratch = StudentNet().to(device)
optimizer_student_scratch = optim.SGD(student_model_scratch.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

for epoch in range(1, EPOCHS_STUDENT + 1):
    train_standard(student_model_scratch, train_loader, optimizer_student_scratch, epoch)

print("\n--- Evaluating 'Scratch' Student Model ---")
student_scratch_acc = evaluate(student_model_scratch, test_loader)

Đây là Bước C, “kiểm soát” hoặc “cơ sở” (baseline) để so sánh.

  1. Một mô hình thứ ba, một thể hiện mới khác của StudentNet, được tạo ra (student_model_scratch).
  2. Một trình tối ưu hóa được tạo ra cho các tham số của nó.
  3. Nó huấn luyện mô hình trò này trong cùng số epochs (EPOCHS_STUDENT) như mô hình trò được chưng cất.
  4. Tuy nhiên, nó sử dụng hàm train_standard, có nghĩa là nó chỉ học từ các nhãn thật, không có sự trợ giúp từ mô hình thầy.
  5. Nó đánh giá mô hình trò “từ đầu” này và lưu độ chính xác của nó vào student_scratch_acc.

Python

# --- 4. So sánh cuối cùng ---
print("\n" + "="*30)
print("--- Final Results ---")
print(f"Teacher Model:\t\tParams: {count_parameters(teacher_model):,}\tAccuracy: {teacher_acc:.2f}%")
print(f"Student (Distilled):\tParams: {count_parameters(student_model_kd):,}\tAccuracy: {student_kd_acc:.2f}%")
print(f"Student (Scratch):\tParams: {count_parameters(student_model_scratch):,}\tAccuracy: {student_scratch_acc:.2f}%")
print("="*30)

Khối cuối cùng này in ra kết quả của toàn bộ thử nghiệm.

  • Nó sử dụng count_parameters để chỉ ra rằng mô hình Thầy lớn hơn nhiều (nhiều tham số hơn) so với hai mô hình Trò (có kích thước giống hệt nhau).
  • Nó in độ chính xác cuối cùng cho cả ba mô hình.
  • Kết quả mong đợi là student_kd_acc (Được chưng cất) sẽ cao hơn student_scratch_acc (Huấn luyện từ đầu). Điều này chứng tỏ rằng mô hình trò đã học hiệu quả hơn bằng cách bắt chước các xác suất “mềm” của mô hình thầy so với việc chỉ học từ các nhãn “cứng” thật, chuyển giao thành công tri thức từ mô hình lớn sang mô hình nhỏ.

Mã nguồn đầy đủ:

Python

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# --- 1. Cài đặt và Siêu tham số ---

# Cài đặt huấn luyện
BATCH_SIZE = 64
EPOCHS_TEACHER = 5
EPOCHS_STUDENT = 10
LEARNING_RATE = 0.01
MOMENTUM = 0.9

# Cài đặt Chưng cất Tri thức (KD)
TEMPERATURE = 10  # Nhiệt độ để làm mềm xác suất
ALPHA = 0.1       # Trọng số cho mất mát "cứng" (nhãn thật)

# Đặt thiết bị
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- 2. Tải dữ liệu (MNIST) ---

# Các phép biến đổi MNIST tiêu chuẩn
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)) # Trung bình và độ lệch chuẩn của MNIST
])

# Tải bộ dữ liệu
train_dataset = datasets.MNIST(
    root='./data', train=True, download=True, transform=transform
)
test_dataset = datasets.MNIST(
    root='./data', train=False, download=True, transform=transform
)

# Tạo các data loader
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True
)
test_loader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False
)

# --- 3. Định nghĩa Mô hình ---

class TeacherNet(nn.Module):
    """Một CNN lớn hơn cho MNIST ('Thầy')"""
    def __init__(self):
        super(TeacherNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        
        # Tính kích thước sau khi làm phẳng (flatten) sau các lớp conv và pool
        # Đầu vào 28x28 -> Conv1 -> 28x28 -> Pool1 -> 14x14
        # -> Conv2 -> 14x14 -> Pool2 -> 7x7
        # Kích thước phẳng = 64 * 7 * 7
        self.fc1 = nn.Linear(64 * 7 * 7, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 7 * 7) # Làm phẳng
        x = F.relu(self.fc1(x))
        x = self.fc2(x) # Xuất ra logits thô
        return x

class StudentNet(nn.Module):
    """Một CNN nhỏ hơn nhiều cho MNIST ('Trò')"""
    def __init__(self):
        super(StudentNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        
        # Đầu vào 28x28 -> Conv1 -> 28x28 -> Pool1 -> 14x14
        # Kích thước phẳng = 16 * 14 * 14
        self.fc1 = nn.Linear(16 * 14 * 14, 32)
        self.fc2 = nn.Linear(32, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = x.view(-1, 16 * 14 * 14) # Làm phẳng
        x = F.relu(self.fc1(x))
        x = self.fc2(x) # Xuất ra logits thô
        return x

def count_parameters(model):
    """Hàm tiện ích để đếm tham số mô hình"""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# --- 4. Các hàm Huấn luyện và Đánh giá Tiêu chuẩn ---

def train_standard(model, train_loader, optimizer, epoch):
    """Vòng lặp huấn luyện tiêu chuẩn cho một bộ phân loại."""
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()
        output = model(data)
        
        # Mất mát Cross-Entropy Tiêu chuẩn
        loss = F.cross_entropy(output, target)
        
        loss.backward()
        optimizer.step()
        
    print(f"Train Epoch: {epoch} \tLoss: {loss.item():.6f}")

def evaluate(model, test_loader):
    """Vòng lặp đánh giá tiêu chuẩn."""
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    print(f'Test set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.2f}%)')
    return accuracy

# --- 5. Mất mát Chưng cất Tri thức và Huấn luyện ---

def distillation_loss(student_logits, labels, teacher_logits, T, alpha):
    """
    Tính toán mất mát chưng cất tri thức.
    
    :param student_logits: Logits thô từ mô hình trò
    :param labels: Nhãn thật (cho hard loss)
    :param teacher_logits: Logits thô từ mô hình thầy
    :param T: Nhiệt độ
    :param alpha: Hệ số trọng số
    :return: Mất mát chưng cất kết hợp
    """
    
    # 1. Mất mát Chưng cất (Soft Loss)
    # Sử dụng KLDivLoss, yêu cầu log-probabilities (log_softmax) làm đầu vào
    # và probabilities (softmax) làm mục tiêu.
    soft_loss = nn.KLDivLoss(reduction='batchmean', log_target=True)(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1)
    ) * (T * T) # Nhân mất mát với T^2 như đề xuất trong bài báo gốc

    # 2. Mất mát của Học trò (Hard Loss)
    # Mất mát cross-entropy tiêu chuẩn giữa logits của trò và nhãn thật
    hard_loss = F.cross_entropy(student_logits, labels)

    # 3. Kết hợp các mất mát
    combined_loss = alpha * hard_loss + (1 - alpha) * soft_loss
    return combined_loss

def train_distillation(student, teacher, train_loader, optimizer, epoch, T, alpha):
    """Vòng lặp huấn luyện cho chưng cất tri thức."""
    student.train()
    teacher.eval() # Thầy ở chế độ eval và trọng số của nó bị đóng băng
    
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()
        
        # Lấy đầu ra của trò
        student_logits = student(data)
        
        # Lấy đầu ra của thầy (với no_grad để đóng băng nó)
        with torch.no_grad():
            teacher_logits = teacher(data)
        
        # Tính toán mất mát chưng cất
        loss = distillation_loss(student_logits, target, teacher_logits, T, alpha)
        
        loss.backward()
        optimizer.step()
        
    print(f"Train Epoch: {epoch} \tKD Loss: {loss.item():.6f}")

# --- 6. Thực thi chính ---

# --- Bước A: Huấn luyện mô hình Thầy ---
print("--- 1. Training Teacher Model ---")
teacher_model = TeacherNet().to(device)
optimizer_teacher = optim.SGD(teacher_model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

for epoch in range(1, EPOCHS_TEACHER + 1):
    train_standard(teacher_model, train_loader, optimizer_teacher, epoch)

print("\n--- Evaluating Teacher Model ---")
teacher_acc = evaluate(teacher_model, test_loader)

# --- Bước B: Huấn luyện mô hình Trò với Chưng cất Tri thức ---
print("\n--- 2. Training Student with Distillation ---")
student_model_kd = StudentNet().to(device)
optimizer_student_kd = optim.SGD(student_model_kd.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

for epoch in range(1, EPOCHS_STUDENT + 1):
    train_distillation(
        student=student_model_kd,
        teacher=teacher_model,
        train_loader=train_loader,
        optimizer=optimizer_student_kd,
        epoch=epoch,
        T=TEMPERATURE,
        alpha=ALPHA
    )

print("\n--- Evaluating Distilled Student Model ---")
student_kd_acc = evaluate(student_model_kd, test_loader)

# --- Bước C: Huấn luyện mô hình Trò từ đầu (để so sánh) ---
print("\n--- 3. Training Student from Scratch (Baseline) ---")
student_model_scratch = StudentNet().to(device)
optimizer_student_scratch = optim.SGD(student_model_scratch.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

for epoch in range(1, EPOCHS_STUDENT + 1):
    train_standard(student_model_scratch, train_loader, optimizer_student_scratch, epoch)

print("\n--- Evaluating 'Scratch' Student Model ---")
student_scratch_acc = evaluate(student_model_scratch, test_loader)

# --- 4. So sánh cuối cùng ---
print("\n" + "="*30)
print("--- Final Results ---")
print(f"Teacher Model:\t\tParams: {count_parameters(teacher_model):,}\tAccuracy: {teacher_acc:.2f}%")
print(f"Student (Distilled):\tParams: {count_parameters(student_model_kd):,}\tAccuracy: {student_kd_acc:.2f}%")
print(f"Student (Scratch):\tParams: {count_parameters(student_model_scratch):,}\tAccuracy: {student_scratch_acc:.2f}%")
print("="*30)

Chạy trong Colab

Theo


Discover more from Cùng Học Cùng Mơ

Subscribe to get the latest posts sent to your email.

Leave a Reply

Your email address will not be published. Required fields are marked *

error: Content is protected !!