
In-Depth Walkthrough: Hands-On Notes for Hung-yi Lee's Spring 2025 Machine Learning Assignment ML2025_Spring_HW4 on Kaggle

Training Transformer

TA’s Slide

Slide

Description

In this assignment, we pretrain a decoder-only Transformer with a next-token prediction objective, applied to Pokémon images.
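The sketch below shows one way an image can be treated as a token sequence. Each pixel is a color index into a palette, the pixels are read row by row, and the model learns to predict each index from the ones before it. The palette, the 3x3 grid, the row-major order, and the `decode` helper are illustrative assumptions, not the assignment's actual data or code.

```python
import numpy as np

# Hypothetical palette: each entry is an (R, G, B) tuple. The real colormap
# for this assignment is loaded from the Hugging Face Hub further down.
palette = np.array([
    (255, 255, 255),  # index 0: white
    (0, 0, 0),        # index 1: black
    (255, 0, 0),      # index 2: red
], dtype=np.uint8)

# A toy 3x3 "image" stored as color indices rather than RGB values.
index_grid = np.array([
    [0, 1, 0],
    [1, 2, 1],
    [0, 1, 0],
])

# Flatten row by row: the decoder-only model sees a 1-D token sequence.
sequence = index_grid.flatten().tolist()

# Next-token prediction pairs: the token at t+1 is predicted from tokens 0..t.
inputs, targets = sequence[:-1], sequence[1:]


def decode(seq, width, palette):
    """Map a flat sequence of color indices back to an H x W x 3 RGB array."""
    grid = np.asarray(seq).reshape(-1, width)
    return palette[grid]  # fancy indexing looks up one RGB triple per pixel


rgb = decode(sequence, width=3, palette=palette)
print(rgb.shape)  # (3, 3, 3)
```

Treating colors as discrete indices instead of raw RGB values turns image generation into a classification problem over a fixed vocabulary, which is the same setup used in text language modeling.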

Please feel free to email us if you have any questions.

ntu-ml-2025-spring-ta@googlegroups.com

Utilities

Download packages

!pip install datasets==3.3.2
Collecting datasets==3.3.2
  Using cached datasets-3.3.2-py3-none-any.whl.metadata (19 kB)
Requirement already satisfied: filelock in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (3.17.0)
Requirement already satisfied: numpy>=1.17 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (2.0.1)
Requirement already satisfied: pyarrow>=15.0.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (21.0.0)
Collecting dill<0.3.9,>=0.3.0 (from datasets==3.3.2)
  Using cached dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Requirement already satisfied: pandas in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (2.3.1)
Requirement already satisfied: requests>=2.32.2 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (2.32.5)
Requirement already satisfied: tqdm>=4.66.3 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (4.67.1)
Requirement already satisfied: xxhash in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (3.6.0)
Requirement already satisfied: multiprocess<0.70.17 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (0.70.16)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets==3.3.2)
  Using cached fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Requirement already satisfied: aiohttp in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (3.13.0)
Requirement already satisfied: huggingface-hub>=0.24.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (0.35.3)
Requirement already satisfied: packaging in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (25.0)
Requirement already satisfied: pyyaml>=5.1 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (6.0.2)
Requirement already satisfied: aiohappyeyeballs>=2.5.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (2.6.1)
Requirement already satisfied: aiosignal>=1.4.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (1.4.0)
Requirement already satisfied: async-timeout<6.0,>=4.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (5.0.1)
Requirement already satisfied: attrs>=17.3.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (25.3.0)
Requirement already satisfied: frozenlist>=1.1.1 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (1.8.0)
Requirement already satisfied: multidict<7.0,>=4.5 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (6.7.0)
Requirement already satisfied: propcache>=0.2.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (0.4.0)
Requirement already satisfied: yarl<2.0,>=1.17.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (1.22.0)
Requirement already satisfied: typing-extensions>=4.1.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from multidict<7.0,>=4.5->aiohttp->datasets==3.3.2) (4.15.0)
Requirement already satisfied: idna>=2.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from yarl<2.0,>=1.17.0->aiohttp->datasets==3.3.2) (3.7)
Requirement already satisfied: charset_normalizer<4,>=2 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from requests>=2.32.2->datasets==3.3.2) (3.3.2)
Requirement already satisfied: urllib3<3,>=1.21.1 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from requests>=2.32.2->datasets==3.3.2) (2.5.0)
Requirement already satisfied: certifi>=2017.4.17 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from requests>=2.32.2->datasets==3.3.2) (2025.10.5)
Requirement already satisfied: colorama in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from tqdm>=4.66.3->datasets==3.3.2) (0.4.6)
Requirement already satisfied: python-dateutil>=2.8.2 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from pandas->datasets==3.3.2) (2.9.0.post0)
Requirement already satisfied: pytz>=2020.1 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from pandas->datasets==3.3.2) (2025.2)
Requirement already satisfied: tzdata>=2022.7 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from pandas->datasets==3.3.2) (2025.2)
Requirement already satisfied: six>=1.5 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from python-dateutil>=2.8.2->pandas->datasets==3.3.2) (1.17.0)
Using cached datasets-3.3.2-py3-none-any.whl (485 kB)
Using cached dill-0.3.8-py3-none-any.whl (116 kB)
Using cached fsspec-2024.12.0-py3-none-any.whl (183 kB)
Installing collected packages: fsspec, dill, datasets
  Attempting uninstall: fsspec
    Found existing installation: fsspec 2025.9.0
    Uninstalling fsspec-2025.9.0:
      Successfully uninstalled fsspec-2025.9.0
  Attempting uninstall: dill
    Found existing installation: dill 0.4.0
    Uninstalling dill-0.4.0:
      Successfully uninstalled dill-0.4.0
  Attempting uninstall: datasets
    Found existing installation: datasets 4.1.1
    Uninstalling datasets-4.1.1:
      Successfully uninstalled datasets-4.1.1
Successfully installed datasets-3.3.2 dill-0.3.8 fsspec-2024.12.0
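The log shows pip downgrading `datasets` from 4.1.1 to 3.3.2, and `fsspec` and `dill` along with it. The sketch below uses only the standard library's `importlib.metadata` to confirm that the pin took effect. `check_pin` is a hypothetical helper. It reads the installed package metadata on disk, so if `datasets` had already been imported in the running kernel before the install, the module in memory may still be the old version until the kernel is restarted.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def check_pin(package: str, expected: str) -> Optional[bool]:
    """Return True if the installed version equals `expected`,
    False if a different version is installed, None if the package is missing."""
    try:
        installed = version(package)
    except PackageNotFoundError:
        return None
    return installed == expected


print(check_pin("datasets", "3.3.2"))
```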

Import Packages

import os
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, GPT2Config, set_seed
from datasets import load_dataset
from typing import Dict, Any, Optional

Check Devices

!nvidia-smi
Wed Oct  8 18:50:06 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.97                 Driver Version: 580.97         CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                  Driver-Model | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 3090 Ti   WDDM  |   00000000:07:00.0  On |                  Off |
| 47%   42C    P8             25W /  450W |   12684MiB /  24564MiB |      2%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A            2292    C+G   C:\Windows\System32\dwm.exe           N/A      |
|    0   N/A  N/A            5552    C+G   ...8bbwe\PhoneExperienceHost.exe      N/A      |
|    0   N/A  N/A            9928    C+G   C:\Windows\explorer.exe               N/A      |
|    0   N/A  N/A           10036    C+G   ..._cw5n1h2txyewy\SearchHost.exe      N/A      |
|    0   N/A  N/A           10264    C+G   ...y\StartMenuExperienceHost.exe      N/A      |
|    0   N/A  N/A           10632    C+G   ...ogram Files\ToDesk\ToDesk.exe      N/A      |
|    0   N/A  N/A           14304    C+G   ...xyewy\ShellExperienceHost.exe      N/A      |
|    0   N/A  N/A           15600    C+G   ...5n1h2txyewy\TextInputHost.exe      N/A      |
|    0   N/A  N/A           15812    C+G   ...ouryDevice\asus_framework.exe      N/A      |
|    0   N/A  N/A           18660    C+G   ...crosoft\OneDrive\OneDrive.exe      N/A      |
|    0   N/A  N/A           18668    C+G   ...Chrome\Application\chrome.exe      N/A      |
|    0   N/A  N/A           21724    C+G   ....0.3537.57\msedgewebview2.exe      N/A      |
|    0   N/A  N/A           22748    C+G   ...s\TencentDocs\TencentDocs.exe      N/A      |
|    0   N/A  N/A           25412    C+G   ...ram Files\Tencent\QQNT\QQ.exe      N/A      |
|    0   N/A  N/A           25872    C+G   ...Chrome\Application\chrome.exe      N/A      |
|    0   N/A  N/A           26600    C+G   ...ocal\Programs\Quark\quark.exe      N/A      |
|    0   N/A  N/A           28688    C+G   ...ntrolPanel\SystemSettings.exe      N/A      |
|    0   N/A  N/A           30104    C+G   ...de\Microsoft VS Code\Code.exe      N/A      |
|    0   N/A  N/A           31500    C+G   ....0.3537.57\msedgewebview2.exe      N/A      |
|    0   N/A  N/A           39276    C+G   ...t\Edge\Application\msedge.exe      N/A      |
|    0   N/A  N/A           41696    C+G   ...PotPlayer\PotPlayerMini64.exe      N/A      |
|    0   N/A  N/A           44176    C+G   ...ffice6\promecefpluginhost.exe      N/A      |
|    0   N/A  N/A           72652      C   ...2025-Spring-Hw1\python.exe.c~      N/A      |
|    0   N/A  N/A          115660    C+G   ...ef.win7x64\steamwebhelper.exe      N/A      |
|    0   N/A  N/A          124396    C+G   ...yb3d8bbwe\WindowsTerminal.exe      N/A      |
+-----------------------------------------------------------------------------------------+
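The `nvidia-smi` output above reports 12684MiB of 24564MiB already in use on the RTX 3090 Ti, so the memory actually available for training is lower than the card's total. A minimal sketch for picking the device in PyTorch and checking free memory from inside the notebook follows. The CPU fallback is an assumption added so the code also runs on machines without a GPU.

```python
import torch

# Use the GPU listed by nvidia-smi if PyTorch can see it; otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

if device.type == "cuda":
    idx = torch.cuda.current_device()
    name = torch.cuda.get_device_name(idx)
    # mem_get_info() returns (free, total) in bytes for the current device.
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    print(f"{name}: {free_bytes // 2**20} MiB free of {total_bytes // 2**20} MiB")
```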

Set Random Seed

set_seed(0)
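`transformers.set_seed` seeds Python's `random`, NumPy, and PyTorch (including CUDA) in a single call. The sketch below is a rough manual equivalent, included to show what is being seeded. `seed_everything` is a hypothetical name. Seeding alone does not make every GPU kernel deterministic.

```python
import random

import numpy as np
import torch


def seed_everything(seed: int) -> None:
    """Hypothetical helper: seed the RNGs that typical training code touches."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds every visible GPU; safe to call when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)


seed_everything(0)
a = torch.rand(3)
seed_everything(0)
b = torch.rand(3)
print(torch.equal(a, b))  # True
```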

Prepare Data

Define Dataset

from typing import List, Tuple, Union
import torch
from torch.utils.data import Dataset
class PixelSequenceDataset(Dataset):
    def __init__(self, data: List[List[int]], mode: str = "train"):
        """
        A dataset class for handling pixel sequences.
        Args:
            data (List[List[int]]): A list of sequences, where each sequence is a list of integers.
            mode (str): The mode of operation, either "train", "dev", or "test".
                - "train": Returns (input_ids, labels) where input_ids are sequence[:-1] and labels are sequence[1:].
                - "dev": Returns (input_ids, labels) where input_ids are sequence[:-160] and labels are sequence[-160:].
                - "test": Returns only input_ids, as labels are not available.
        """
        self.data = data
        self.mode = mode
    def __len__(self) -> int:
        """Returns the total number of sequences in the dataset."""
        return len(self.data)
    def __getitem__(self, idx: int) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        """
        Fetches a sequence from the dataset and processes it based on the mode.
        Args:
            idx (int): The index of the sequence.
        Returns:
            - If mode == "train": Tuple[torch.Tensor, torch.Tensor] -> (input_ids, labels)
            - If mode == "dev": Tuple[torch.Tensor, torch.Tensor] -> (input_ids, labels)
            - If mode == "test": torch.Tensor -> input_ids
        """
        sequence = self.data[idx]
        if self.mode == "train":
            # Shift by one: labels[t] is the token that follows input_ids[t]
            input_ids = torch.tensor(sequence[:-1], dtype=torch.long)
            labels = torch.tensor(sequence[1:], dtype=torch.long)
            return input_ids, labels
        elif self.mode == "dev":
            # Hold out the last 160 tokens as the continuation to predict
            input_ids = torch.tensor(sequence[:-160], dtype=torch.long)
            labels = torch.tensor(sequence[-160:], dtype=torch.long)
            return input_ids, labels
        elif self.mode == "test":
            input_ids = torch.tensor(sequence, dtype=torch.long)
            return input_ids
        raise ValueError(f"Invalid mode: {self.mode}. Choose from 'train', 'dev', or 'test'.")
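The three modes differ only in how each sequence is sliced. The standalone sketch below replays the same slicing on a toy sequence so the resulting shapes are easy to check. The sequence length of 400 and its values are assumptions for illustration. The real length depends on the dataset.

```python
import torch

# Toy stand-in for one Pokémon sequence: assumed length 400, fake color indices.
sequence = list(range(400))

# "train": shift by one so labels[t] is the token that follows input_ids[t].
train_inputs = torch.tensor(sequence[:-1], dtype=torch.long)
train_labels = torch.tensor(sequence[1:], dtype=torch.long)

# "dev": the last 160 tokens are held out as the continuation to predict.
dev_inputs = torch.tensor(sequence[:-160], dtype=torch.long)
dev_labels = torch.tensor(sequence[-160:], dtype=torch.long)

print(train_inputs.shape, train_labels.shape)  # torch.Size([399]) torch.Size([399])
print(dev_inputs.shape, dev_labels.shape)      # torch.Size([240]) torch.Size([160])

# Sanity checks: train labels are the inputs shifted left by one, and the dev
# prompt plus its held-out continuation rebuild the full sequence.
assert torch.equal(train_inputs[1:], train_labels[:-1])
assert torch.equal(torch.cat([dev_inputs, dev_labels]), torch.tensor(sequence))
```

The `DataLoader` default collate function stacks the per-item tensors with `torch.stack`. Batching this dataset without a custom `collate_fn` therefore assumes that every sequence in a split has the same length.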

Download Dataset & Prepare Dataloader

# Load the pokemon dataset from Hugging Face Hub
pokemon_dataset = load_dataset("lca0503/ml2025-hw4-pokemon")
# Load the colormap from Hugging Face Hub
colormap = list(load_dataset("lca0503/ml2025-hw4-colormap")["train"]["color"])
# Define number of classes
num_classes = len(colormap)
# Define batch size
batch_size = 16
# === Prepare Dataset and DataLoader for Training ===
train_dataset: PixelSequenceDataset = PixelSequenceDataset(
pokemon_dataset["train"]["pixel_color"], mode="train"
)
train_dataloader: DataLoader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True
)
# === Prepare Dataset and DataLoader for Validation ===
dev_dataset: PixelSequenceDataset = PixelSequenceDataset(
pokemon_dataset["dev"]["pixel_color"], mode="dev"
)
dev_dataloader: DataLoader = DataLoader(
dev_dataset, batch_size=batch_size, shuffle=False
)
# === Prepare Dataset and DataLoader for Testing ===
test_dataset: PixelSequenceDataset = PixelSequenceDataset(
pokemon_dataset["test"][
