I started the implementation for my continual learning research by thoroughly studying the codebase of HAT (Serra et al., 2018), a classic paper in the field. In this article, I will walk through the details of the code, aiming to help beginners understand how to write a continual learning (or, more generally, deep learning) project and how to read others' code. I will explain the code from the outside in: starting from the overall structure of the project, then the main function, and finally the details of specific functions and classes.
Some parts of the code in this project are messy or not well-structured. Focus on learning the key ideas, and be aware of where the code is not well-designed.
To understand this article, you should have a good command of Python and PyTorch, as well as basic knowledge of Linux system usage, deep learning, and continual learning. Please refer to my related articles:
- Shawn’s Python Tutorial [Chinese]
- Shawn’s PyTorch Study Notes Series [Chinese]
- Linux Study Notes: From the Perspective of Research Usage [Chinese]
- A Beginner’s Guide to Continual Learning
Engineering Logic
Starting from the root directory, the src/ folder contains the actual code, which we’ll discuss later. The other files in the root directory are not directly related to the code:
- `LICENSE`: A text file (notice it has no extension). Opening it reveals text declaring copyright, telling others how they may use this project and what is prohibited (otherwise legal action may be taken). Most projects place such a LICENSE file in the root directory. The declarative text doesn't need to be written from scratch: you can look up common options online (such as CC 4.0 or the MIT License) and copy an appropriate one. When creating a project on GitHub, there's an option to select a license, which automatically creates a LICENSE file in the root directory, which is very convenient. This is the common practice for code projects. For other contexts like articles and blogs, there are other ways to declare copyright with the same effect; for example, at the end of every article on my blog there is the statement: "©️ 2025 Pengxiang Wang. All rights reserved."
- `readme.md`: As the name suggests, "please read me first": descriptive text written by the author about the project for users to read. You can write anything you want to tell users here, such as usage instructions. GitHub project homepages also display this text by default (when creating a project on GitHub, you can choose whether to add a README file). Markdown is well known to be a convenient format for taking notes, and nowadays code projects also use it to write readmes (please learn the syntax yourself) rather than Word files, because it can be conveniently rendered as web pages (for example, every article on my blog is written in Markdown). Since this project is code for a research paper, the author mainly wrote paper information here, along with brief installation and running instructions.
- `requirements.txt`: A text file listing the environment dependencies required by the code, placed in the project root to tell others which third-party libraries must be installed to run this project.
Note this file can be generated automatically by pip or conda following a certain syntax (though auto-generated files are sometimes too detailed and inappropriate, in which case only the important packages should be listed manually; this project doesn't do this well: it includes packages from other, unrelated environments). Here are some related commands:
- `pip freeze > requirements.txt` or `conda list -e > requirements.txt`: export the current environment's package list to `requirements.txt`;
- `pip install -r requirements.txt`, `conda install --file requirements.txt`, or `conda create --name XXX --file requirements.txt`: install the `requirements.txt` dependencies into the current environment, or create a new environment from them.
- `.gitignore`: A text file indicating which files should be ignored when uploading to GitHub (please learn the syntax yourself). Note that it starts with `.`, which marks hidden files on Linux and Mac systems (same below). Ignored files are usually temporary, runtime, non-code, non-essential files that others don't need to see. When creating a project on GitHub, you can choose whether to add a `.gitignore`. This project ignores the following files (some of which only appear after running):
  - `logs/` folder: stores saved experimental result files (see Section 6.2); non-code data files that don't need uploading;
  - `dat/` folder: stores deep learning datasets; non-code data files that occupy too much space to upload;
  - `res/` folder: stores other experimental results (see Section 6.2);
  - `old/`, `temp/` folders: unknown purpose, but presumably temporary files, judging by their names;
  - Other unimportant files: the `src/.idea/` folder (PyCharm IDE configuration files), the `src/__pycache__/` folder (Python cache files), `.pyc` files (binary files compiled from `.py` files), the `.DS_Store` file (Apple file system configuration), `.png` files, and two script files (`src/work.sh`, `src/immalpha.sh`), possibly written by the author for testing.
Now let’s look at the code’s src/ folder. Four files serve as program entries:
- `run.py`: the main program for running a single deep learning experiment;
- `run_multi.py`: rewritten from `run.py` for running multiple deep learning experiments;
- `work.py`: another program written by the author that can run multiple deep learning experiments;
- `run_compression.sh`: for running the model compression experiments, see Section 4.4 of the paper. It's a Linux shell script (i.e., a batch of individual commands packaged together). You can see it contains multiple commands that run `run.py`; running `run_compression.sh` is equivalent to running these commands.
The others are code for specific modules of the project:
- `approaches/` folder: defines the various continual learning algorithms, one `.py` file per algorithm;
- `dataloaders/` folder: defines the datasets and data preprocessing methods, one `.py` file per dataset. Since continual learning datasets are generally constructed from existing datasets, this also defines how the datasets are constructed;
- `networks/` folder: defines the neural network structures, one `.py` file per structure;
- `plot_results.py`: a tool for visualizing experimental results (see Section 6.2). This project first saves results, then visualizes them separately when needed; visualization is thus separated from the core program, and this `.py` file is a standalone visualization program;
- `utils.py`: various utility functions (print functions, calculation functions, etc.) kept here to avoid making the main code too long.
Main Program
run.py is the core of the entire project, completing the entire workflow of a single deep learning experiment.
Parsing Command Line Arguments
A deep learning experiment needs many things specified: the dataset, the network structure, the learning algorithm, various hyperparameters, plus detailed configuration such as the random seed and output format. This information generally doesn't appear in the code; instead, the user specifies it when running the program, i.e., as command-line arguments. For how to use command-line arguments in Python, see my detailed discussion in this article.
The part defining command-line arguments in run.py is lines 9-20, parsing command-line arguments is lines 29-97. You can see it defines the following 7 command-line arguments:
- `--seed`: see Section 6.3;
- `--experiment`: options are selected manually through if statements during parsing. Based on the chosen option, a module from the `dataloaders/` folder is uniformly parsed into a variable named `dataloader` to be called below;
- `--approach`: likewise selected through if statements during parsing; a module from the `approaches/` folder is uniformly parsed into a variable named `approach` to be called below;
- `--nepochs`: number of training epochs, an important hyperparameter the user needs to specify;
- `--lr`: learning rate, an important hyperparameter the user needs to specify;
- `--parameter`: a slot reserved for other hyperparameters (since each approach can have different hyperparameters); how many and which hyperparameters it holds depends on the specific approach's definition;
- `--output`: the output result file path; see Section 6.2.
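The argument definitions above can be sketched with `argparse` as follows. This is a minimal sketch: the names follow the article, while the defaults and `choices` lists here are illustrative assumptions, not the repository's exact values.

```python
import argparse

# Minimal sketch of run.py's command-line interface (defaults are assumptions).
parser = argparse.ArgumentParser(description='Continual learning experiment')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--experiment', type=str, default='pmnist',
                    choices=['pmnist', 'mnist2', 'cifar', 'mixture'])
parser.add_argument('--approach', type=str, default='sgd',
                    choices=['sgd', 'hat', 'ewc', 'lwf'])
parser.add_argument('--nepochs', type=int, default=200)
parser.add_argument('--lr', type=float, default=0.05)
parser.add_argument('--parameter', type=str, default='')
parser.add_argument('--output', type=str, default='')

# Parse an explicit argument list here instead of sys.argv, for illustration:
args = parser.parse_args(['--experiment', 'pmnist', '--approach', 'hat'])

# run.py then dispatches with if statements, roughly:
# if args.experiment == 'pmnist': from dataloaders import pmnist as dataloader
# if args.approach == 'hat':      from approaches import hat as approach
```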
Deep Learning Pipeline
The following code corresponds to the deep learning pipeline:
- Read the dataset (lines 99-102): all modules in `dataloaders/` have a `get()` function, called uniformly here to obtain the dataset (including the training, validation, and test sets; the author's approach is to first package everything into a `data` variable and then extract from it during training or testing, see lines 125, 154, etc.), as well as information such as how many classes each task has and the input dimensions (needed for defining the network);
- Instantiate the network structure (lines 104-107): all modules in `networks/` have a `Net` class, instantiated uniformly here as the network to be trained. Instantiation requires the number of classes per task, input dimensions, etc., taken from the return values of the dataset `get()` function above;
- Define the learning algorithm (lines 109-112): all modules in `approaches/` have an `Appr` class, instantiated uniformly here as the learning algorithm. If you look at an `Appr` class, it:
  - Not only defines the continual learning mechanism, so instantiation requires passing the continual-learning-related hyperparameters. The author's approach is to pass the entire `args`; for example, lines 27-31 of `approaches/hat.py` parse `args.parameter` as the \(\lambda\) and \(s_\text{max}\) hyperparameters, so users know that `--parameter` represents these two hyperparameters when passing it;
  - But also bundles the optimizer and loss function, so instantiation requires the optimizer hyperparameters, number of training epochs, etc., all taken from the command-line arguments `args`;
  - Also bundles the network as its attribute, so the network variable `net` no longer appears in the program from here on;
- Training (lines 148-149): uniformly calls the `Appr` class's `train` method, which accepts the training and validation sets extracted above, plus the task number. Note that there is no need to pass the network: it lives inside `Appr`, and this training method essentially modifies and updates it;
- Testing (lines 152-159): uniformly calls the `Appr` class's `eval` method, which accepts the test set extracted above, plus the task number. Note the outer loop over `u` here to test all tasks. Again, no need to pass the network.
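The whole pipeline above can be condensed into a runnable miniature. `Net` and `Appr` here are tiny stand-ins I made up to show the calling pattern (the `get()`-style data, the per-task heads, and the bundled optimizer/loss), not the repository's real classes:

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    """Stand-in network: one output head per task, as in the project."""
    def __init__(self, inputsize, taskcla):
        super().__init__()
        self.fc = nn.Linear(inputsize, 16)
        self.last = nn.ModuleList(nn.Linear(16, n) for _, n in taskcla)
    def forward(self, x):
        h = torch.relu(self.fc(x))
        return [head(h) for head in self.last]   # outputs of all heads

class Appr:
    """Stand-in Appr: bundles the net, the optimizer, and the loss."""
    def __init__(self, net, nepochs=2, lr=0.1):
        self.net, self.nepochs = net, nepochs
        self.optimizer = torch.optim.SGD(net.parameters(), lr=lr)
        self.criterion = nn.CrossEntropyLoss()
    def train(self, t, xtrain, ytrain, xvalid, yvalid):
        for _ in range(self.nepochs):
            self.optimizer.zero_grad()
            loss = self.criterion(self.net(xtrain)[t], ytrain)  # head t only
            loss.backward()
            self.optimizer.step()
    def eval(self, u, xtest, ytest):
        with torch.no_grad():
            out = self.net(xtest)[u]
            loss = self.criterion(out, ytest)
            acc = (out.argmax(1) == ytest).float().mean()
        return loss.item(), acc.item()

# What dataloader.get() would return, faked with random tensors:
taskcla, inputsize = [(0, 5), (1, 5)], 12
data = {t: {s: {'x': torch.randn(20, inputsize), 'y': torch.randint(0, n, (20,))}
            for s in ('train', 'valid', 'test')}
        for t, n in taskcla}

net = Net(inputsize, taskcla)
appr = Appr(net)                                  # net lives inside Appr now
acc = torch.zeros(len(taskcla), len(taskcla))     # the triangular accuracy matrix
for t, _ in taskcla:
    appr.train(t, data[t]['train']['x'], data[t]['train']['y'],
               data[t]['valid']['x'], data[t]['valid']['y'])
    for u in range(t + 1):                        # test all tasks seen so far
        _, acc[t, u] = appr.eval(u, data[u]['test']['x'], data[u]['test']['y'])
```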
Dataloaders
The project defines datasets, preprocessing methods, and code for constructing continual learning tasks in the dataloaders/ folder, with each dataset being a .py file. Each file defines only one get() function. We take the classic and simple pmnist.py (Permuted MNIST) as an example. The get() function returns the following content:
- Dataset variable `data`: a nested dictionary, i.e., a dictionary whose values are themselves dictionaries.
  - First layer (line 11), indexed by task: keys are task IDs, plus an additional key `'ncla'` storing the total number of classes over all tasks (line 80);
  - Second layer (line 34), the task metadata, including:
    - Task name `'name'`: the author names it 'pmnist-task ID' (line 35);
    - Number of classes `'ncla'`: in Permuted MNIST each task has a fixed 10 classes (line 36);
    - Training, validation, and test data `'train'`, `'valid'`, `'test'`;
  - Third layer (line 39), inside each of the datasets `'train'`, `'valid'`, `'test'`:
    - Input `'x'`: one large Tensor. In fact this project doesn't use PyTorch's DataLoader to feed the model; the author divides batches manually, see for example lines 81-82 of `approaches/sgd.py`;
    - Labels `'y'`: one large Tensor;
- Number of classes per task, `taskcla` (line 78): a list; for Permuted MNIST it is fixed as `[10,...,10]`;
- Input dimensions `size` (line 13): directly defined as the constant `[1,28,28]`.
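The nested structure can be sketched with placeholder tensors. Shapes follow Permuted MNIST; the tensor contents here are dummy placeholders, not real data:

```python
import torch

# Dummy reconstruction of the nested `data` dictionary returned by get().
ntasks, ncla = 2, 10
data = {'ncla': ntasks * ncla}               # extra key: total classes over all tasks
for t in range(ntasks):
    data[t] = {'name': 'pmnist-{}'.format(t),   # task name
               'ncla': ncla}                    # 10 classes per task
    for split in ('train', 'valid', 'test'):
        data[t][split] = {
            'x': torch.zeros(4, 1, 28, 28),     # one big input tensor, no DataLoader
            'y': torch.zeros(4, dtype=torch.long),
        }

taskcla = [(t, data[t]['ncla']) for t in range(ntasks)]
size = [1, 28, 28]
```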
From the `get()` function parameters, you can see the author doesn't give users much choice. A Permuted MNIST dataset is basically fixed; users can only set:

- `seed` and `fixed_order`, controlling the random seed (for the permutation operation): see Section 6.3;
- `pc_valid`: the proportion of data held out for validation.
Now let’s see how the dataset in the third layer of the data variable is constructed step by step:
- First download the original MNIST dataset into the `dat` variable through `torchvision.datasets` (lines 27-30), then parse it step by step into `data`;
- Use the original dataset `dat` to construct a DataLoader with batch size 1 (line 38, probably for convenience of writing the loop), apply the permute operation to each image (lines 41-43), and add it to `data`. Note that the data fields `'x'`, `'y'` in `data` are lists at this point;
- Save this list data (lines 20-21, 51-52) so it can be read directly later (lines 55-67). The reason is that the per-image processing above is too slow; even saving and reloading saves time;
- Convert the lists into Tensors that can be fed to `nn.Module` (lines 48-52);
- Note that the above only produces training and test sets; the validation set `'valid'` still needs to be split from the training set `'train'` (lines 70-73). In `pmnist.py` the author directly copies the training set as the validation set, meaning model selection is based on best performance on the training set. I don't know whether the author was being lazy, but this easily leads to overfitting, especially on smaller datasets.
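The permutation step (step 2 above) can be sketched on random stand-in images. All images in a task share one fixed pixel permutation; here the whole batch is permuted at once rather than image by image:

```python
import torch

torch.manual_seed(0)
x = torch.rand(8, 1, 28, 28)        # stand-in for 8 MNIST images
perm = torch.randperm(28 * 28)      # one fixed pixel permutation per task

# Flatten pixels, reorder them by `perm`, and restore the image shape:
x_perm = x.view(8, 1, -1)[:, :, perm].view(8, 1, 28, 28)
```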
The other datasets are similar; I'll briefly introduce them, focusing on the differences:
- `mnist2.py`: 2-task Split MNIST. Since it doesn't involve per-image permutation, the author didn't implement saving/reloading;
- `cifar.py`: 10-task Split CIFAR; the first 5 tasks use CIFAR-10 with 2 classes each, and the last 5 tasks use CIFAR-100 with 20 classes each. For this dataset the author finally splits a random validation set from the training set (lines 79-90), with proportion `pc_valid` specified by the user in the `get()` parameters;
- `mixture.py`: a mixture of many datasets, 8 tasks in total, each task being one dataset: CIFAR10, CIFAR100, MNIST, SVHN, FashionMNIST, TrafficSigns, Facescrub, notMNIST (not in this order, but a fixed random shuffle). Some of these datasets aren't in `torchvision.datasets`, so the author defined the corresponding dataset classes below (equivalent to custom Dataset classes).
Let’s briefly look at how the author defines custom datasets in mixture.py:
- FashionMNIST: `torchvision.datasets` actually does have this dataset; the author probably encountered bugs with its API and rewrote it (lines 249-257);
- TrafficSigns, FaceScrub, NotMNIST: all inherit from the `Dataset` class, following the custom-dataset method explained in this article. In `__init__()`, the entire dataset is read from local files into `data` and `labels` variables, and `__getitem__()` then simply indexes into them. As with MNIST, the `download` and `train` parameters control downloading and selecting the training vs. test set. The download operation requires complex network communication and error-correction mechanisms, also packaged into a `download()` function; an if statement on `train=True` determines whether the training or test set is read.
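A minimal in-memory Dataset in the style described above might look like this. This is a sketch with random stand-in data, not the repository's TrafficSigns/FaceScrub code; the `download`/`train` parameters mimic the MNIST-style interface:

```python
import torch
from torch.utils.data import Dataset

class InMemoryDataset(Dataset):
    """Load everything into memory in __init__; __getitem__ just indexes."""
    def __init__(self, root='.', train=True, download=False):
        if download:
            self._download(root)            # network I/O would go here
        # Stand-in for "read the whole dataset from local files":
        n = 100 if train else 20            # the train flag selects the split
        self.data = torch.randn(n, 3, 32, 32)
        self.labels = torch.randint(0, 10, (n,))
    def _download(self, root):
        pass                                # real code: fetch + error handling
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

ds = InMemoryDataset(train=False)
```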
Networks
The code in the networks/ folder defines network structures, with each model being a .py file. Each file defines one nn.Module class named Net. Writing these nn.Module classes is fundamental to deep learning, please refer to this article. Deep learning projects are similar, mostly using MLP, AlexNet, ResNet and other networks with similar nn.Module structures.
For this continual learning project, we need to look at how network structures adapt to continual learning scenarios or methods. There are two points:
- How to handle new classes from new tasks (output heads) in continual learning scenarios;
- For architecture-based continual learning approaches like HAT, which involve modifying the network structure, how to implement those modifications.
We take the simpler MLP as an example to see how the author handles this, see mlp.py and mlp_hat.py.
For question 1, the author predefines the output heads of all tasks (lines 19-21), and the forward pass outputs the concatenated results of all heads (lines 30-32), rather than dynamically adding an output head for each new task; this project runs a fixed experiment, and this method is more convenient. When and which task uses which output head is defined in the training and testing methods of the continual learning approach class `Appr`, see below. Additionally, even for datasets like Permuted MNIST where all tasks have the same classes, each task gets its own output head rather than sharing one.
For question 2, each network structure inevitably needs a modified version: for example, in this project `mlp.py` derives `mlp_hat.py`, and `alexnet.py` derives `alexnet_hat.py`, `alexnet_pathnet.py`, `alexnet_progressive.py`, `alexnet_lfl.py`, etc. These all serve architecture-based approaches that modify the network structure. When using an architecture-based approach, we need to select the corresponding network structure.
Let’s look at mlp_hat.py, which defines the network structure for the HAT method, i.e., MLP with masks added. mlp.py provides a 3-hidden-layer MLP with the same number of neurons per layer; while mlp_hat.py provides MLPs with 1, 2, or 3 hidden layers (specified by the nlayers parameter in __init__()). Taking 3-layer MLP as example:
- In `__init__()`, compared to a regular MLP, there are three additional layers `efc1`, `efc2`, `efc3` (lines 20, 23, 26), i.e., the task embeddings for each layer of neurons. They are implemented as `nn.Embedding`, a class representing a set of model parameters of the same length (called embeddings); the first argument `num_embeddings` is the number of embeddings, and the second argument `embedding_dim` is the embedding length. This class is generally used for word vectors (a vocabulary where each word is represented by an embedding), but here the author uses it to hold the task embedding vectors of different tasks. Note that each of `efc1`, `efc2`, `efc3` predefines the embeddings of all tasks for its layer, not just an individual task's;
- Following Eq. (1) in the paper, a task embedding is multiplied by the scale parameter \(s\) and passed through the gate function (Sigmoid) to obtain the mask. This process is packaged into a `mask` function (lines 65-71);
- In the `forward()` method, compared to a regular MLP, there is an additional mask-applying step: each neuron's activation is multiplied by its mask (lines 53, 56, 59). Note that `forward()` accepts not only the input \(x\) but also the task \(t\) and \(s\) for computing the mask. That is, the model provides an interface here for training and testing to specify which task is running: the `Net` class is predefined for all tasks and distinguishes them through this task argument.
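The mask mechanism can be sketched as follows. The embedding layer, the sigmoid gate, and the scale \(s\) follow the description above (Eq. (1) in the paper), while the sizes and the value of \(s\) here are illustrative:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
ntasks, nhid = 5, 100
efc1 = nn.Embedding(ntasks, nhid)       # one task embedding per layer of neurons

def mask(t, s):
    """Scale task t's embedding by s, then gate it with a sigmoid."""
    return torch.sigmoid(s * efc1(torch.tensor([t])))

gfc1 = mask(t=0, s=400.0)               # large s pushes the mask towards binary
h = torch.relu(torch.randn(4, nhid))    # stand-in for a layer's activations
h = h * gfc1.expand_as(h)               # apply the mask elementwise
```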
There are a few other details. `alexnet_hat.py` and `alexnet_hat_test.py` differ in lines 43-50: whether normalized initialization is applied to the task embeddings, which is related to the model compression experiments. Additionally, some HAT `Net` classes have a `get_view_for()` function, which uses PyTorch's reshaping operation (`view`) on a mask Tensor. It is a utility function for the HAT algorithm, called in HAT's training method, see below.
Approaches
The code in the approaches/ folder defines the various continual learning algorithms, with each algorithm being a `.py` file. Each file defines one class named `Appr`, all of which have training and testing functions `train` and `eval`, as well as an optimizer `self.optimizer` and a loss function `self.criterion`. Let's first look at the fine-tuning algorithm `sgd.py`, which has no continual learning anti-forgetting mechanism, to understand the basic training and testing pipeline:
- The loss function is defined in `self.criterion`. For simple approaches like `sgd.py`, it is directly set to `nn.CrossEntropyLoss()` in `__init__()`; for approaches like `hat.py`, a custom `criterion()` method is defined in the class;
- The optimizer is defined in `self.optimizer`, constructed from optimizer hyperparameters like `lr` passed to `__init__()` (there is also a `_get_optimizer()` method, probably because the optimizer-defining code was too long);
- The `train` method trains one task. Its core is line 40 calling the `train_epoch` method defined at line 72, which trains one epoch of task \(t\). The pipeline is no different from an ordinary deep learning training loop; the only thing to note is that it trains task \(t\) (passed as a method argument), reflected in line 88 where the result is truncated to output head \(t\). The rest of the code does one thing: learning rate scheduling, a deep learning training technique;
- The `eval` method tests the current model's accuracy (after training task \(t\)) on one task's test set; note that `run.py` has an outer loop testing all tasks. The pipeline is no different from ordinary deep learning testing, returning the test loss and accuracy at the end.
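The manual batching inside `train_epoch` can be sketched like this. It is a generic single-head example (the real code additionally truncates to output head \(t\)); `sbatch` and the tensor sizes are illustrative:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Linear(10, 3)
optimizer = torch.optim.SGD(net.parameters(), lr=0.05)
criterion = nn.CrossEntropyLoss()

# The dataset is one big tensor pair, as in the project's `data` variable:
x = torch.randn(100, 10)
y = torch.randint(0, 3, (100,))
sbatch = 32

r = torch.randperm(x.size(0))           # shuffle sample indices once per epoch
for i in range(0, len(r), sbatch):
    b = r[i:i + sbatch]                 # indices of this mini-batch
    outputs = net(x[b])                 # forward (head t would be picked here)
    loss = criterion(outputs, y[b])
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```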
Next let’s look at the HAT algorithm hat.py, which is modified based on sgd.py and requires the passed model to be the ’_hat’ version:
- The `criterion()` method (line 196) defines the mask sparsity regularization term (Eq. (5) in the HAT paper) plus the classification loss. The regularization term involves the merged masks of the previous tasks \(<t\), stored in `self.mask_pre`. That's why `criterion()` takes not only the model outputs `outputs` and labels `targets` but also `masks`. Note the if statement distinguishing the first task, which has no previous tasks and hence no regularization term;
- The core `train_epoch` method calls the HAT-version `forward()` with the mask, computes the loss defined by the `criterion()` above, and backpropagates to compute gradients. Before updating the model parameters:
  - First, parameters masked by previous tasks are filtered out to prevent their update, implemented by setting their gradients to 0. In line 135, the `mask_back` that the gradients are multiplied by is the inverse of the previous tasks' `mask_pre` computed in the previous step (lines 97-102);
  - Then the gradient compensation mechanism of Sec. 2.5 of the HAT paper is applied, multiplying the gradients by a compensation factor;
  - After the update, the trained task embedding is clamped to a smaller range, as in Sec. 2.5 of the HAT paper;
- The `train()` method for training one task includes converting the task embedding to a mask (lines 87-90) and computing the previous tasks' `mask_pre` (lines 91-95). Both the task embeddings and the binary masks are stored as floats;
- The `eval()` method adds an output for the regularization term; otherwise there is no big difference from `sgd.py`.
Besides the above algorithms, other architecture-based methods similar to HAT, such as PathNet and Progressive NN, are also implemented for baseline comparison, as well as methods from other categories such as EWC, IMM, and LwF. Their differences lie in the training, testing, loss functions, optimizers, etc.
Note that among these methods, only HAT has a `_test` version, meaning it received extra attention. We can see the code inside is more complete: it defines `logs` and `logpath` in `__init__()` (in `hat_test.py` lines 29-59, parsed from the command-line argument `--parameter`). At the end of `run.py`, more detailed test results from HAT are stored for later processing (see Section 6.2).
Other Details
Print Debug Information
The author interspersed many logging messages throughout the code, most of which use `print()`. Many are wrapped in functions in `utils.py` whose names start with `print`, to avoid making the main function too long:
- Lines 23-27: Print command-line arguments for user confirmation;
- Line 102: Print dataset information and continual learning task information;
- Line 107: Print model information;
- Lines 110-111: Print loss function and optimizer information;
- Lines 119, 176: Print training process and time information;
- Lines 157, 166-174: Print test results;
- Line 162: Print result saving information.
There are many other details like these in the code that deserve our attention; they are often quite important.
Processing Experimental Results
Throughout the code, the author saved the following experimental results to local files:
- The `--output` argument: the path it defines stores the test accuracy triangular matrix `acc` (lines 161-163), where row \(t\), column \(u\) holds the model's accuracy on task \(u\) after training the \(t\)-th task. This is the main metric of continual learning (see my article about continual learning metrics). If it is not specified, the author builds a default file path (lines 21-22), which you can see is named from meta-information like `--experiment`, `--approach`, etc., to distinguish different experiments;
- The `logpath` argument of the `Appr` class: appears only in `hat_test.py`. If it is included in `--parameter`, `run.py` saves some result information using pickle after line 178 (pickle is a built-in Python library that can fully save and restore arbitrary Python variables), to be restored and used by `plot_results.py` when plotting is needed.
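Saving and reloading such an accuracy matrix might look like the following sketch; whether `run.py` calls exactly these numpy functions is an assumption, but the idea (a plain-text triangular matrix, reloaded later for plotting) is the same:

```python
import os
import tempfile

import numpy as np

# A lower-triangular accuracy matrix like `acc` in run.py:
acc = np.tril(np.random.rand(3, 3))
path = os.path.join(tempfile.gettempdir(), 'pmnist_hat.txt')  # illustrative path
np.savetxt(path, acc, '%.4f')       # human-readable text file, 4 decimals
loaded = np.loadtxt(path)           # reload later, e.g. for plotting
```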
Random Seed
Global seed setting is in run.py (lines 31-33). There are also some local random variable seed settings, such as task order in line 18 of pmnist.py.
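Global seeding typically looks like the following sketch (the repository's exact calls may differ slightly); setting all the seeds makes two runs produce the same random numbers:

```python
import random

import numpy as np
import torch

def set_seed(seed):
    """Seed every random number generator the experiment touches."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(0)
a = torch.randn(3)
set_seed(0)
b = torch.randn(3)   # identical to `a` because the seed was reset
```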
Using GPU
Deep learning experiments of this scale have to be run on GPU. In the code:
- Lines 34-35 check if GPU can be used, force exit if not;
- Line 106 puts the model on GPU;
- Lines 142-145, 154-155 put datasets on GPU.
Experimental Details
There are some details in the code that are often done in deep learning experiments:
- Data normalization: first manually compute the mean and variance, then apply the `transforms.Normalize()` transformation when constructing the dataset, e.g., lines 24-30 in `pmnist.py`.
The following are some deep learning training techniques used in the code:
- Learning rate scheduling: if the validation loss doesn't decrease for `lr_patience` consecutive epochs, reduce the learning rate a bit: divide it by `lr_factor`;
- Dropout layers to prevent overfitting: e.g., line 15 in `mlp.py`.
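The patience-based schedule can be sketched as follows, with illustrative values for `lr_factor`, `lr_patience`, and the loss sequence:

```python
# Patience-based learning rate schedule: divide lr by lr_factor whenever the
# validation loss fails to improve for lr_patience consecutive epochs.
lr, lr_factor, lr_patience, lr_min = 0.05, 3.0, 5, 1e-4
best_loss, patience = float('inf'), lr_patience

def step_schedule(valid_loss):
    global lr, best_loss, patience
    if valid_loss < best_loss:
        best_loss, patience = valid_loss, lr_patience   # reset on improvement
    else:
        patience -= 1
        if patience <= 0:                               # plateau: decay lr
            lr = max(lr / lr_factor, lr_min)
            patience = lr_patience
    return lr

losses = [1.0, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9]   # plateaus after epoch 2
for valid_loss in losses:
    step_schedule(valid_loss)
```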
Summary
This is the complete picture of a research-oriented deep learning project. After reading the code, we can also notice some shortcomings, such as:
- The calling formats are not very consistent;
- Users can specify too few things;
- Some hyperparameters are buried too deep (for example, the `--parameter` parsing rules, and especially `appr.logs`); users must read the code very carefully to know how to use them.
Of course, research-oriented code serves the research; this is fine as long as it's convenient for the author himself to understand. It also means the author doesn't have to write more robust and polished code, because it's not a product to be presented to users. The author just writes what is needed, and for us, it's enough to understand the core logic and ideas without chasing every detail.