Add a new classifier
Code for this section:
- `core/model/abstract_model.py`
- `core/model/meta/*`
- `core/model/metric/*`
- `core/model/pretrain/*`
We select one representative method from the metric based methods, the meta learning methods and the fine-tuning methods, respectively, and describe how to add new methods of these three categories. Before that, we need to introduce the parent class of all methods: `abstract_model`.
```python
class AbstractModel(nn.Module):
    def __init__(self, ...):
        # base info
        ...

    @abstractmethod
    def set_forward(self, ...):
        # inference phase
        pass

    @abstractmethod
    def set_forward_loss(self, ...):
        # training phase
        pass

    def forward(self, x):
        out = self.emb_func(x)
        return out

    def train(self, ...):
        # override super's function
        ...

    def eval(self, ...):
        # override super's function
        ...

    def _init_network(self, ...):
        # init all layers
        ...

    def _generate_local_targets(self, ...):
        # format the few shot labels
        ...

    def split_by_episode(self, ...):
        # split batch by way, shot and query
        ...

    def reset_base_info(self, ...):
        # change way, shot and query
        ...
```
- `__init__`: initialization function, used to initialize the few shot learning settings such as way, shot, query and other training parameters.
- `set_forward`: called in the inference phase; returns the classifier's output and accuracy.
- `set_forward_loss`: called in the training phase; returns the classifier's output, accuracy and loss.
- `forward`: overrides the `forward` function of PyTorch's `Module`; returns the output of the `backbone`.
- `train`: overrides the `train` function of PyTorch's `Module`; used to unfix the `BatchNorm` layer parameters.
- `eval`: overrides the `eval` function of PyTorch's `Module`; used to fix the `BatchNorm` layer parameters.
- `_init_network`: used to initialize all network parameters.
- `_generate_local_targets`: used to generate the `target` for few shot learning.
- `split_by_episode`: used to split a batch into the shape `[episode_size, way, shot+query, ...]`. It has several split modes.
- `reset_base_info`: used to change the few shot learning settings.
New methods must override the `set_forward` and `set_forward_loss` functions, and all other functions can be called according to the needs of the implemented method.

Note that in order for the newly added method to be called through reflection, add a line to the `__init__.py` file in the directory of the corresponding method type:

```python
from NewMethodFileName import *
```
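For example, assuming the new metric based method below is implemented in `dn4.py`, the line added to `core/model/metric/__init__.py` would look like this:

```python
# core/model/metric/__init__.py
from .dn4 import *
```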
metric based
Using DN4 as an example, we will describe how to add a new metric based classifier to LibFewShot.

Metric based methods have a common parent class `MetricModel`, which inherits from `AbstractModel`.
```python
class MetricModel(AbstractModel):
    def __init__(self,):
        super(MetricModel, self).__init__()

    @abstractmethod
    def set_forward(self, *args, **kwargs):
        pass

    @abstractmethod
    def set_forward_loss(self, *args, **kwargs):
        pass

    def forward(self, x):
        out = self.emb_func(x)
        return out
```
Since the pipeline of most metric based methods is simple, `MetricModel` simply inherits from `AbstractModel` and makes no other changes.
build model
First, create the `DN4` model class by adding a file `dn4.py` under `core/model/metric/` (this code differs slightly from the source code):
```python
class DN4(MetricModel):
    def __init__(self, n_k=3, **kwargs):
        # base info
        super(DN4, self).__init__(**kwargs)
        self.n_k = n_k
        self.loss_func = nn.CrossEntropyLoss()

    def set_forward(self, batch):
        """Inference phase.

        :param batch: (images, labels)
        :param batch.images: shape: [episodeSize*way*(shot*augment_times+query*augment_times_query),C,H,W]
        :param batch.labels: shape: [episodeSize*way*(shot*augment_times+query*augment_times_query), ]
        :return: net output and accuracy
        """
        image, global_target = batch
        image = image.to(self.device)
        episode_size = image.size(0) // (
            self.way_num * (self.shot_num + self.query_num)
        )
        feat = self.emb_func(image)
        support_feat, query_feat, support_target, query_target = self.split_by_episode(
            feat, mode=2
        )
        t, wq, c, h, w = query_feat.size()
        _, ws, _, _, _ = support_feat.size()

        # t, wq, c, hw -> t, wq, hw, c -> t, wq, 1, hw, c
        query_feat = query_feat.view(
            t, self.way_num * self.query_num, c, h * w
        ).permute(0, 1, 3, 2)
        query_feat = F.normalize(query_feat, p=2, dim=2).unsqueeze(2)

        # t, ws, c, h, w -> t, w, s, c, hw -> t, 1, w, c, shw
        support_feat = (
            support_feat.view(t, self.way_num, self.shot_num, c, h * w)
            .permute(0, 1, 3, 2, 4)
            .contiguous()
            .view(t, self.way_num, c, self.shot_num * h * w)
        )
        support_feat = F.normalize(support_feat, p=2, dim=2).unsqueeze(1)

        # t, wq, w, hw, shw -> t, wq, w, hw, n_k -> t, wq, w
        relation = torch.matmul(query_feat, support_feat)
        topk_value, _ = torch.topk(relation, self.n_k, dim=-1)
        score = torch.sum(topk_value, dim=[3, 4])

        output = score.view(episode_size * self.way_num * self.query_num, self.way_num)
        acc = accuracy(output, query_target)

        return output, acc

    def set_forward_loss(self, batch):
        """Training phase.

        :param batch: (images, labels)
        :param batch.images: shape: [episodeSize*way*(shot*augment_times+query*augment_times_query),C,H,W]
        :param batch.labels: shape: [episodeSize*way*(shot*augment_times+query*augment_times_query), ]
        :return: net output, accuracy and train loss
        """
        image, global_target = batch
        image = image.to(self.device)
        episode_size = image.size(0) // (
            self.way_num * (self.shot_num + self.query_num)
        )
        emb = self.emb_func(image)
        support_feat, query_feat, support_target, query_target = self.split_by_episode(
            emb, mode=2
        )
        t, wq, c, h, w = query_feat.size()
        _, ws, _, _, _ = support_feat.size()

        # t, wq, c, hw -> t, wq, hw, c -> t, wq, 1, hw, c
        query_feat = query_feat.view(
            t, self.way_num * self.query_num, c, h * w
        ).permute(0, 1, 3, 2)
        query_feat = F.normalize(query_feat, p=2, dim=2).unsqueeze(2)

        # t, ws, c, h, w -> t, w, s, c, hw -> t, 1, w, c, shw
        support_feat = (
            support_feat.view(t, self.way_num, self.shot_num, c, h * w)
            .permute(0, 1, 3, 2, 4)
            .contiguous()
            .view(t, self.way_num, c, self.shot_num * h * w)
        )
        support_feat = F.normalize(support_feat, p=2, dim=2).unsqueeze(1)

        # t, wq, w, hw, shw -> t, wq, w, hw, n_k -> t, wq, w
        relation = torch.matmul(query_feat, support_feat)
        topk_value, _ = torch.topk(relation, self.n_k, dim=-1)
        score = torch.sum(topk_value, dim=[3, 4])

        output = score.view(episode_size * self.way_num * self.query_num, self.way_num)
        loss = self.loss_func(output, query_target)
        acc = accuracy(output, query_target)

        return output, acc, loss
```
The `__init__` function calls `super().__init__()` to initialize the few shot learning settings, and then initializes the `DN4` method's own hyperparameter `n_k`.

Please notice the code in both `set_forward` and `set_forward_loss` that runs from the call to `split_by_episode` through the reshaping and normalization of `query_feat` and `support_feat`: these lines split the batch of feature vectors into the shapes that fit the few shot learning setting. In detail, in order to maximize the usage of computing resources, we first extract the feature vectors of all images at once, and then divide them into the support set and the query set. The lines that compute `relation`, `topk_value` and `score` calculate the DN4 method's output. Finally, `set_forward` returns `output` with shape `[episode_size*way*query, way]` and `acc` as a `float`, while `set_forward_loss` returns `output` with shape `[episode_size*way*query, way]`, `acc` as a `float` and `loss` as a `tensor`. Here `output` needs to be calculated according to the method; `acc` can be obtained by calling the `accuracy` function provided by LibFewShot with `output` and `target` as inputs; and `loss` can be computed with the loss function that the user initializes at the start of the method, which is used in `set_forward_loss` to get the classification loss.

A metric based method simply needs to process the input images into the corresponding form according to the method, and then training can begin.
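To make this contract concrete, here is a deliberately simplified sketch of another metric based classifier: a hypothetical prototype-style cosine classifier. It is not the library's ProtoNet implementation; the class name `ProtoCosine` and the `temperature` hyperparameter are made up for illustration, and the same imports as in the DN4 example (`torch`, `torch.nn as nn`, `torch.nn.functional as F` and the `accuracy` utility) are assumed.

```python
class ProtoCosine(MetricModel):
    """A hypothetical prototype-style cosine classifier, shown only to
    illustrate the set_forward / set_forward_loss contract above."""

    def __init__(self, temperature=10.0, **kwargs):
        super(ProtoCosine, self).__init__(**kwargs)
        self.temperature = temperature  # made-up scaling hyperparameter
        self.loss_func = nn.CrossEntropyLoss()

    def _score(self, batch):
        image, global_target = batch
        image = image.to(self.device)
        feat = self.emb_func(image)
        # mode=2 gives 5-D episode tensors, as in the DN4 example above:
        # support_feat: [t, way*shot, c, h, w], query_feat: [t, way*query, c, h, w]
        support_feat, query_feat, support_target, query_target = self.split_by_episode(
            feat, mode=2
        )
        t, _, c, h, w = support_feat.size()  # t = episode_size

        # global average pooling over the spatial dims -> [t, way*shot, c] / [t, way*query, c]
        support_feat = support_feat.mean(dim=[3, 4])
        query_feat = query_feat.mean(dim=[3, 4])

        # class prototypes: mean over the shot dimension -> [t, way, c]
        proto = support_feat.view(t, self.way_num, self.shot_num, c).mean(dim=2)
        proto = F.normalize(proto, p=2, dim=-1)
        query_feat = F.normalize(query_feat, p=2, dim=-1)

        # cosine similarity between every query and every prototype -> [t, way*query, way]
        score = torch.matmul(query_feat, proto.permute(0, 2, 1)) * self.temperature
        output = score.reshape(t * self.way_num * self.query_num, self.way_num)
        return output, query_target.reshape(-1)

    def set_forward(self, batch):
        output, query_target = self._score(batch)
        acc = accuracy(output, query_target)
        return output, acc

    def set_forward_loss(self, batch):
        output, query_target = self._score(batch)
        loss = self.loss_func(output, query_target)
        acc = accuracy(output, query_target)
        return output, acc, loss
```

The shared `_score` helper keeps `set_forward` and `set_forward_loss` consistent; only the returned tuple differs between the inference and training phases.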
meta learning
Using MAML as an example, we will describe how to add a new meta learning classifier to LibFewShot.

Meta learning methods have a common parent class `MetaModel`, which inherits from `AbstractModel`.
```python
class MetaModel(AbstractModel):
    def __init__(self, init_type, **kwargs):
        super(MetaModel, self).__init__(init_type, ModelType.META, **kwargs)

    @abstractmethod
    def set_forward(self, *args, **kwargs):
        pass

    @abstractmethod
    def set_forward_loss(self, *args, **kwargs):
        pass

    def forward(self, x):
        out = self.emb_func(x)
        return out

    @abstractmethod
    def set_forward_adaptation(self, *args, **kwargs):
        pass

    def sub_optimizer(self, parameters, config):
        kwargs = dict()
        if config["kwargs"] is not None:
            kwargs.update(config["kwargs"])
        return getattr(torch.optim, config["name"])(parameters, **kwargs)
```
The meta learning parent class adds two new functions, `set_forward_adaptation` and `sub_optimizer`. `set_forward_adaptation` implements the logic for fine-tuning the network during the classification process, and `sub_optimizer` provides a new sub-optimizer for this fine-tuning.
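For example, the `config` passed to `sub_optimizer` is simply a dict with the optimizer class name and its keyword arguments. A minimal, self-contained sketch (the name `inner_optim_config` and the concrete values are placeholders):

```python
import torch
import torch.nn as nn

def sub_optimizer(parameters, config):
    # same lookup logic as MetaModel.sub_optimizer above
    kwargs = dict()
    if config["kwargs"] is not None:
        kwargs.update(config["kwargs"])
    return getattr(torch.optim, config["name"])(parameters, **kwargs)

# Placeholder config: any class name from torch.optim can be used.
inner_optim_config = {"name": "SGD", "kwargs": {"lr": 0.01, "momentum": 0.9}}

classifier = nn.Linear(64, 5)  # a hypothetical 5-way classifier head
optimizer = sub_optimizer(classifier.parameters(), inner_optim_config)
```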
build model
First, create the `MAML` model class by adding a file `maml.py` under `core/model/meta/` (this code differs slightly from the source code):
```python
from ..backbone.utils import convert_maml_module

class MAML(MetaModel):
    def __init__(self, inner_param, feat_dim, **kwargs):
        super(MAML, self).__init__(**kwargs)
        self.loss_func = nn.CrossEntropyLoss()
        self.classifier = nn.Sequential(nn.Linear(feat_dim, self.way_num))
        self.inner_param = inner_param

        convert_maml_module(self)

    def forward_output(self, x):
        """
        :param x: images, shape: [batch, C, H, W]
        :return: classification scores
        """
        out1 = self.emb_func(x)
        out2 = self.classifier(out1)
        return out2

    def set_forward(self, batch):
        """
        :param batch: (images, labels)
        :param batch.images: shape: [episodeSize*way*(shot*augment_times+query*augment_times_query),C,H,W]
        :param batch.labels: shape: [episodeSize*way*(shot*augment_times+query*augment_times_query), ]
        :return: net output and accuracy
        """
        image, global_target = batch  # global_target is unused
        image = image.to(self.device)
        support_image, query_image, support_target, query_target = self.split_by_episode(
            image, mode=2
        )
        episode_size, _, c, h, w = support_image.size()

        output_list = []
        for i in range(episode_size):
            episode_support_image = support_image[i].contiguous().reshape(-1, c, h, w)
            episode_query_image = query_image[i].contiguous().reshape(-1, c, h, w)
            episode_support_target = support_target[i].reshape(-1)

            self.set_forward_adaptation(episode_support_image, episode_support_target)

            output = self.forward_output(episode_query_image)
            output_list.append(output)

        output = torch.cat(output_list, dim=0)
        acc = accuracy(output, query_target.contiguous().view(-1))
        return output, acc

    def set_forward_loss(self, batch):
        """
        :param batch: (images, labels)
        :param batch.images: shape: [episodeSize*way*(shot*augment_times+query*augment_times_query),C,H,W]
        :param batch.labels: shape: [episodeSize*way*(shot*augment_times+query*augment_times_query), ]
        :return: net output, accuracy and train loss
        """
        image, global_target = batch  # global_target is unused
        image = image.to(self.device)
        support_image, query_image, support_target, query_target = self.split_by_episode(
            image, mode=2
        )
        episode_size, _, c, h, w = support_image.size()

        output_list = []
        for i in range(episode_size):
            episode_support_image = support_image[i].contiguous().reshape(-1, c, h, w)
            episode_query_image = query_image[i].contiguous().reshape(-1, c, h, w)
            episode_support_target = support_target[i].reshape(-1)

            self.set_forward_adaptation(episode_support_image, episode_support_target)

            output = self.forward_output(episode_query_image)
            output_list.append(output)

        output = torch.cat(output_list, dim=0)
        loss = self.loss_func(output, query_target.contiguous().view(-1))
        acc = accuracy(output, query_target.contiguous().view(-1))
        return output, acc, loss

    def set_forward_adaptation(self, support_set, support_target):
        lr = self.inner_param["lr"]
        fast_parameters = list(self.parameters())
        for parameter in self.parameters():
            parameter.fast = None

        self.emb_func.train()
        self.classifier.train()
        for i in range(self.inner_param["iter"]):
            output = self.forward_output(support_set)
            loss = self.loss_func(output, support_target)
            grad = torch.autograd.grad(loss, fast_parameters, create_graph=True)
            fast_parameters = []

            for k, weight in enumerate(self.parameters()):
                if weight.fast is None:
                    weight.fast = weight - lr * grad[k]
                else:
                    weight.fast = weight.fast - lr * grad[k]
                fast_parameters.append(weight.fast)
```
`MAML` has two especially important parts. The first is the call to `convert_maml_module` in `__init__`, which changes all the layers in the network into MAML-style layers so that their parameters can be updated conveniently. The other is the `set_forward_adaptation` function, which updates the fast parameters of the network.
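To give an intuition of what `convert_maml_module` produces, here is a rough, self-contained sketch of a MAML-style linear layer. This is not the actual implementation in `core/model/backbone/utils`; it only illustrates the idea that a converted layer forwards with the fast weight written by `set_forward_adaptation` when one exists, and with its ordinary parameters otherwise.

```python
import torch.nn as nn
import torch.nn.functional as F

class LinearWithFastWeight(nn.Linear):
    """Sketch of a MAML-style layer: use `weight.fast` / `bias.fast`
    (written during inner-loop adaptation) when they are set, otherwise
    fall back to the ordinary parameters."""

    def __init__(self, in_features, out_features):
        super(LinearWithFastWeight, self).__init__(in_features, out_features)
        self.weight.fast = None
        self.bias.fast = None

    def forward(self, x):
        if self.weight.fast is not None and self.bias.fast is not None:
            return F.linear(x, self.weight.fast, self.bias.fast)
        return F.linear(x, self.weight, self.bias)
```

Because the fast weights are built with `torch.autograd.grad(..., create_graph=True)`, the outer-loop loss computed on the query set can backpropagate through the inner-loop updates.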
fine-tuning
Using Baseline as an example, we will describe how to add a new fine-tuning classifier to LibFewShot.

Fine-tuning methods have a common parent class `FinetuningModel`, which inherits from `AbstractModel`.
```python
class FinetuningModel(AbstractModel):
    def __init__(self,):
        super(FinetuningModel, self).__init__()
        # ...

    @abstractmethod
    def set_forward(self, *args, **kwargs):
        pass

    @abstractmethod
    def set_forward_loss(self, *args, **kwargs):
        pass

    def forward(self, x):
        out = self.emb_func(x)
        return out

    @abstractmethod
    def set_forward_adaptation(self, *args, **kwargs):
        pass

    def sub_optimizer(self, model, config):
        kwargs = dict()
        if config["kwargs"] is not None:
            kwargs.update(config["kwargs"])
        return getattr(torch.optim, config["name"])(model.parameters(), **kwargs)
```
The main aim of the training phase of a fine-tuning method is to train a good feature extractor; in the test phase, the few shot learning setting is used to fine-tune the model on the support set. Another variant fine-tunes the whole model under the few shot learning setting after the feature extractor has been trained. In line with the meta learning methods, a `set_forward_adaptation` abstract function is added to handle the forward process during the test phase. In addition, since some fine-tuning methods need to train a classifier, a `sub_optimizer` method is added; it takes the model whose parameters are to be optimized together with the optimizer configuration, and returns an optimizer for convenient use.
build model
First, create the `Baseline` model class by adding a file `baseline.py` under `core/model/finetuning/` (this code differs slightly from the source code):
```python
class Baseline(FinetuningModel):
    def __init__(self, feat_dim, num_class, inner_param, **kwargs):
        super(Baseline, self).__init__(**kwargs)
        self.feat_dim = feat_dim
        self.num_class = num_class
        self.inner_param = inner_param

        self.classifier = nn.Linear(self.feat_dim, self.num_class)
        self.loss_func = nn.CrossEntropyLoss()

    def set_forward(self, batch):
        """
        :param batch: (images, labels)
        :param batch.images: shape: [episodeSize*way*(shot*augment_times+query*augment_times_query),C,H,W]
        :param batch.labels: shape: [episodeSize*way*(shot*augment_times+query*augment_times_query), ]
        :return: net output and accuracy
        """
        image, global_target = batch
        image = image.to(self.device)

        feat = self.emb_func(image)
        support_feat, query_feat, support_target, query_target = self.split_by_episode(
            feat, mode=1
        )
        episode_size = support_feat.size(0)
        support_target = support_target.reshape(episode_size, self.way_num * self.shot_num)
        query_target = query_target.reshape(episode_size, self.way_num * self.query_num)

        output_list = []
        for i in range(episode_size):
            output = self.set_forward_adaptation(
                support_feat[i], support_target[i], query_feat[i]
            )
            output_list.append(output)

        output = torch.cat(output_list, dim=0)
        acc = accuracy(output, query_target.reshape(-1))
        return output, acc

    def set_forward_loss(self, batch):
        """
        :param batch: (images, labels)
        :param batch.images: shape: [batch_size,C,H,W]
        :param batch.labels: shape: [batch_size, ]
        :return: net output, accuracy and train loss
        """
        image, target = batch
        image = image.to(self.device)
        target = target.to(self.device)

        feat = self.emb_func(image)
        output = self.classifier(feat)

        loss = self.loss_func(output, target)
        acc = accuracy(output, target)
        return output, acc, loss

    def set_forward_adaptation(self, support_feat, support_target, query_feat):
        """
        :param support_feat: shape: [way_num*shot_num, C]
        :param support_target: shape: [way_num*shot_num, ]
        :param query_feat: shape: [way_num*query_num, C]
        """
        classifier = nn.Linear(self.feat_dim, self.way_num)
        optimizer = self.sub_optimizer(classifier, self.inner_param["inner_optim"])
        classifier = classifier.to(self.device)
        classifier.train()

        support_size = support_feat.size(0)
        for epoch in range(self.inner_param["inner_train_iter"]):
            rand_id = torch.randperm(support_size)
            for i in range(0, support_size, self.inner_param["inner_batch_size"]):
                select_id = rand_id[
                    i : min(i + self.inner_param["inner_batch_size"], support_size)
                ]
                batch = support_feat[select_id]
                target = support_target[select_id]

                output = classifier(batch)
                loss = self.loss_func(output, target)

                optimizer.zero_grad()
                loss.backward(retain_graph=True)
                optimizer.step()

        output = classifier(query_feat)
        return output
```
`set_forward_loss` works the same way as classical supervised classification, while `set_forward` works the same way as in the meta learning methods. The `set_forward_adaptation` function contains the main part of the test phase: the feature vectors of the support set extracted by the backbone are used to train a new classifier, and the feature vectors of the query set are then classified by that classifier.
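For completeness, the `inner_param` dict that `Baseline` expects bundles the test-time fine-tuning settings read in `set_forward_adaptation`. A sketch with placeholder values (in practice these settings usually come from the method's config):

```python
# Keys follow the code above; the concrete values are placeholders.
inner_param = {
    "inner_optim": {"name": "SGD", "kwargs": {"lr": 0.01, "momentum": 0.9}},
    "inner_train_iter": 100,  # epochs of classifier fine-tuning on the support set
    "inner_batch_size": 4,    # mini-batch size sampled from the support features
}
```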