Guidance
Please follow this guidance to get started.
You can download our framework in the $\textbf{Download}$ section. You need to edit trainnet.py and server.py. The structure of these core functions is shown below.
```
server.init_server_func
            ↓
server.server_aggregate_func
            ↑
(experiments.local_train_net)
            ↑
trainnet.train_net_func
```
General structure of our framework.
1. trainnet.train_net_func
This function is where you implement your local training algorithm. We use **kwargs so that you can add extra parameters to this function.
```python
import arguments

args = arguments.get_args()

def train_net_func(net_id, net, train_dataloader, test_dataloader, epochs, lr, optimizer, device="cpu", **kwargs):
    # add your train_net_func code here
    extra_info = []
    # return extra information if necessary; it's okay to return nothing
    return extra_info
```
Here you only need to handle the case of a single client; experiments.local_train_net coordinates every client's situation in each round. The return value of train_net_func is gathered into local_info and aggregated into global_info as a Python dictionary, together with the corresponding net_id.
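As a rough illustration (a hypothetical sketch only, not the framework's actual bookkeeping inside experiments.local_train_net), the per-client returns end up collected per net_id roughly like this:

```python
# Hypothetical sketch: how each client's extra_info could be collected by net_id.
# The real coordination is done by experiments.local_train_net.
def train_net_func(net_id, net, train_dataloader, test_dataloader, epochs, lr, optimizer, device="cpu", **kwargs):
    # your local training would happen here
    extra_info = ["anything your algorithm needs to report back"]
    return extra_info

global_info = {}
for net_id in [0, 1, 2]:                                  # clients selected in this round
    local_info = train_net_func(net_id, None, None, None, epochs=5, lr=0.01, optimizer=None)
    global_info[net_id] = local_info                      # keyed by net_id

print(global_info)
```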
2. server.init_server_func
This function is for you to define extra server-side variables if necessary; otherwise you can skip this part. As a precaution against memory overuse, you are advised not to use deepcopy inside for loops. You can only use the parameters listed below and cannot add new parameters.
```python
import arguments
import experiments

args = arguments.get_args()

# The first part is to initialize your server-side variables before local training.
# The variables are therefore kept within the function to avoid memory overuse.
def init_server_func(args, nets, local_model_meta_data, layer_type,
                     global_models, global_model_meta_data, global_layer_type,
                     global_model, global_para):
    print('use this function to initialize the server')
    extra_info = []
    return extra_info
```
If you have multiple outputs, pack them into a single variable such as a list.
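For instance, a sketch with hypothetical variable names (the SCAFFOLD demo below does the same thing when it returns [c_nets, c_global]):

```python
# Hypothetical example: pack several server-side variables into one list,
# which is then handed to server_aggregate_func as extra_info.
def init_server_func(args, nets, local_model_meta_data, layer_type,
                     global_models, global_model_meta_data, global_layer_type,
                     global_model, global_para):
    round_counter = 0        # hypothetical extra variable
    momentum_buffer = {}     # hypothetical extra variable
    return [round_counter, momentum_buffer]
```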
3. server.server_aggregate_func
Use this function for server-side aggregation. You will receive the extra_info returned by server.init_server_func, and experiments.local_train_net is also called here. To see how this function calls train_net_func, see the examples in the $\mathbf{Demo}$ section.
You cannot define extra parameters here, but you can pack them into extra_info, which is returned by server.init_server_func.
Please note that calling local_train_net requires a parameter 'alg'. This is only for us to know the name of your algorithm; make sure it is not the same as any of the five examples in our demo.
```python
def server_aggregate_func(nets, selected, args, net_dataidx_map, test_dl_global,
                          device, global_model, global_para, extra_info=None):
    print('use this function to perform server aggregation in every communication round')
```
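As a hypothetical sketch (following the call pattern of the demo algorithms below, with a made-up algorithm name 'my_alg'), a custom aggregation function might look like this:

```python
import experiments

# Hypothetical sketch: unpack extra_info from init_server_func and run local
# training under a custom algorithm name ('my_alg' must differ from the five demos).
def server_aggregate_func(nets, selected, args, net_dataidx_map, test_dl_global,
                          device, global_model, global_para, extra_info=None):
    my_server_state = extra_info[0] if extra_info else None   # packed by init_server_func
    local_returns = experiments.local_train_net('my_alg', nets, selected, args,
                                                net_dataidx_map, test_dl_global,
                                                device, 'extend')
    # ... aggregate the selected clients' parameters into global_para here ...
    global_model.load_state_dict(global_para)
```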
4. Function 2 and function 3 are called directly in the main function, and function 1 is called inside function 3 (through experiments.local_train_net). For your reference, here is how it works:
```python
if __name__ == '__main__':
    ......
    else:  # if not a demo alg
        ...... = init_nets(......)
        ......
        # init_server_func is called here
        extra_info = server.init_server_func(......)
        ......
        for round in range(args.comm_round):
            ......
            # server_aggregate_func is called here
            server.server_aggregate_func(......)
```
5. run.sh and accuracy.xlsx
run.sh is the script for running your code. To ensure uniformity, please write the script according to the parameters in accuracy.xlsx. An example shell script looks like this:
```bash
python experiments.py --model=simple-cnn \
    --dataset=cifar10 \
    --alg=fedavg \
    --lr=0.01 \
    --batch-size=64 \
    --epochs=10 \
    --n_parties=10 \
    --rho=0.9 \
    --comm_round=200 \
    --partition=noniid-labeldir \
    --beta=0.5 \
    --device='cpu' \
    --datadir='./data/' \
    --logdir='./logs/' \
    --noise=0 \
    --init_seed=0 \
    --sample=1
```
6. Submit the results in the proper format. Please report the average of the results from the last 10 rounds as the outcome of each experiment. You can use the example code below for quick extraction.
```python
import numpy as np

pre = []
with open('log.txt', 'r') as f:
    for st in f.readlines():
        if "Global Model Test" in st:
            acc = eval(st[-9:])  # the accuracy value printed at the end of the line
            pre.append(acc)

print(np.mean(pre[-10:]))  # average over the last 10 rounds
```
Please ensure that your algorithm runs successfully with edits to trainnet.py and server.py only.
To see how an algorithm can be implemented in our framework, see the examples in the $\mathbf{Demo}$ section.
Demo
In the framework we implemented five algorithms: FedAvg, FedProx, SCAFFOLD, FedNova and MOON*. To run our demo, use sh run.sh.
Parameter information

| Parameter | Description |
| --- | --- |
| model | The model architecture. Options: simple-cnn, vgg, resnet, mlp. Default = mlp. |
| dataset | Dataset to use. Options: mnist, cifar10, fmnist, svhn, generated, femnist, a9a, rcv1, covtype. Default = mnist. |
| alg | The training algorithm. Options: fedavg, fedprox, scaffold, fednova, moon. Only for running our demo algorithms. Default = none. |
| lr | Learning rate for the local models. Default = 0.01. |
| batch-size | Batch size. Default = 64. |
| epochs | Number of local training epochs. Default = 5. |
| n_parties | Number of parties. Default = 2. |
| mu | The proximal term parameter for FedProx. Default = 0.001. |
| rho | The parameter controlling momentum SGD. Default = 0. |
| comm_round | Number of communication rounds. Default = 50. |
| partition | The data partitioning strategy. Options: homo, noniid-labeldir, noniid-#label1 (or 2, 3, ..., i.e. the fixed number of labels each party owns), real, iid-diff-quantity. Default = homo. |
| beta | The concentration parameter of the Dirichlet distribution for heterogeneous partitioning. Default = 0.5. |
| device | The device to run the program on. Default = cpu. |
| datadir | The path of the dataset. Default = ./data/. |
| logdir | The path to store the logs. Default = ./logs/. |
| noise | Maximum variance of the Gaussian noise added to local parties (for feature skew). Default = 0. |
| sample | Ratio of parties that participate in each communication round. Default = 1. |
| init_seed | The initial random seed. Default = 0. |
(run.sh documents some example settings. For more information, please refer to our GitHub repository.)
You will see the output in NIID-Bench/logs.
Here we present the examples of fedavg (which is simple) and of scaffold and fednova (which are more complicated). Please compare them to see what is added for scaffold and fednova, and you will understand how we handle the different cases.
1. fedavg
train_net_fedavg:
```python
def train_net_fedavg(net_id, net, train_dataloader, test_dataloader, epochs, lr, optimizer, device="cpu"):
    logger.info('Training network %s' % str(net_id))
    criterion = nn.CrossEntropyLoss().to(device)
    cnt = 0
    if type(train_dataloader) == type([1]):
        pass
    else:
        train_dataloader = [train_dataloader]

    for epoch in range(epochs):
        epoch_loss_collector = []
        for tmp in train_dataloader:
            for batch_idx, (x, target) in enumerate(tmp):
                x, target = x.to(device), target.to(device)

                optimizer.zero_grad()
                x.requires_grad = True
                target.requires_grad = False
                target = target.long()

                out = net(x)
                loss = criterion(out, target)

                loss.backward()
                optimizer.step()

                cnt += 1
                epoch_loss_collector.append(loss.item())

        epoch_loss = sum(epoch_loss_collector) / len(epoch_loss_collector)
        logger.info('Epoch: %d Loss: %f' % (epoch, epoch_loss))

    net.to('cpu')
    logger.info(' ** Training complete **')
```
server_aggregate_fedavg:
```python
def server_aggregate_fedavg(nets, selected, args, net_dataidx_map, test_dl_global, device, global_model, global_para, extra_info=None):
    experiments.local_train_net('fedavg', nets, selected, args, net_dataidx_map, test_dl_global, device, 'extend')

    # update global model
    total_data_points = sum([len(net_dataidx_map[r]) for r in selected])
    fed_avg_freqs = [len(net_dataidx_map[r]) / total_data_points for r in selected]
    for idx in range(len(selected)):
        net_para = nets[selected[idx]].cpu().state_dict()
        if idx == 0:
            for key in net_para:
                global_para[key] = net_para[key] * fed_avg_freqs[idx]
        else:
            for key in net_para:
                global_para[key] += net_para[key] * fed_avg_freqs[idx]
    global_model.load_state_dict(global_para)
```
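In effect, server_aggregate_fedavg performs the standard FedAvg weighted average: with $n_k$ the number of samples held by selected client $k$ and $n = \sum_{k \in S} n_k$, the new global parameters are $w_{\text{global}} = \sum_{k \in S} \frac{n_k}{n} w_k$.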
2. scaffold
train_net_scaffold:
```python
def train_net_scaffold(net_id, net, train_dataloader, test_dataloader, epochs, lr, optimizer, device="cpu",
                       global_model=None, c_local=None, c_global=None):
    logger.info('Training network %s' % str(net_id))
    c_local = c_local[net_id]  # select this client's control variate (moved here from local_train_net_scaffold)
    criterion = nn.CrossEntropyLoss().to(device)
    cnt = 0
    if type(train_dataloader) == type([1]):
        pass
    else:
        train_dataloader = [train_dataloader]

    c_local.to(device)
    c_global.to(device)
    global_model.to(device)

    c_global_para = c_global.state_dict()
    c_local_para = c_local.state_dict()

    for epoch in range(epochs):
        epoch_loss_collector = []
        for tmp in train_dataloader:
            for batch_idx, (x, target) in enumerate(tmp):
                x, target = x.to(device), target.to(device)

                optimizer.zero_grad()
                x.requires_grad = True
                target.requires_grad = False
                target = target.long()

                out = net(x)
                loss = criterion(out, target)

                loss.backward()
                optimizer.step()

                net_para = net.state_dict()
                for key in net_para:
                    net_para[key] = net_para[key] - args.lr * (c_global_para[key] - c_local_para[key])
                net.load_state_dict(net_para)

                cnt += 1
                epoch_loss_collector.append(loss.item())

        epoch_loss = sum(epoch_loss_collector) / len(epoch_loss_collector)
        logger.info('Epoch: %d Loss: %f' % (epoch, epoch_loss))

    c_new_para = c_local.state_dict()
    c_delta_para = copy.deepcopy(c_local.state_dict())
    global_model_para = global_model.state_dict()
    net_para = net.state_dict()
    for key in net_para:
        c_new_para[key] = c_new_para[key] - c_global_para[key] + (global_model_para[key] - net_para[key]) / (cnt * args.lr)
        c_delta_para[key] = c_new_para[key] - c_local_para[key]
    c_local.load_state_dict(c_new_para)

    net.to('cpu')
    logger.info(' ** Training complete **')
    return c_delta_para
```
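The block after the training loop is the SCAFFOLD control-variate update: with $\tau$ the number of local steps (cnt), $\eta$ the learning rate (args.lr), $x$ the global model and $y_i$ the local model, it computes $c_i^{+} = c_i - c + \frac{x - y_i}{\tau \eta}$ and returns $\Delta c_i = c_i^{+} - c_i$ to the server.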
init_server_scaffold:
```python
def init_server_scaffold(args, nets, local_model_meta_data, layer_type, global_models, global_model_meta_data,
                         global_layer_type, global_model, global_para):
    c_nets, _, _ = experiments.init_nets(args.net_config, args.dropout_p, args.n_parties, args)
    c_globals, _, _ = experiments.init_nets(args.net_config, 0, 1, args)
    c_global = c_globals[0]
    c_global_para = c_global.state_dict()
    for net_id, net in c_nets.items():
        net.load_state_dict(c_global_para)
    return [c_nets, c_global]
```
server_aggregate_scaffold:
```python
def server_aggregate_scaffold(nets, selected, args, net_dataidx_map, test_dl_global, device, global_model, global_para, extra_info=None):
    c_nets, c_global = extra_info[0], extra_info[1]
    c_delta_para_all = experiments.local_train_net('scaffold', nets, selected, args, net_dataidx_map, test_dl_global, device, 'append',
                                                   global_model=global_model, c_local=c_nets, c_global=c_global)

    total_delta = copy.deepcopy(global_model.state_dict())
    for key in total_delta:
        total_delta[key] = 0.0
    c_global.to(device)
    global_model.to(device)
    for c_delta_para in c_delta_para_all:
        for key in total_delta:
            total_delta[key] += c_delta_para[key]
    for key in total_delta:
        total_delta[key] /= args.n_parties

    c_global_para = c_global.state_dict()
    for key in c_global_para:
        if c_global_para[key].type() == 'torch.LongTensor':
            c_global_para[key] += total_delta[key].type(torch.LongTensor)
        elif c_global_para[key].type() == 'torch.cuda.LongTensor':
            c_global_para[key] += total_delta[key].type(torch.cuda.LongTensor)
        else:
            c_global_para[key] += total_delta[key]
    c_global.load_state_dict(c_global_para)

    # update global model
    total_data_points = sum([len(net_dataidx_map[r]) for r in selected])
    fed_avg_freqs = [len(net_dataidx_map[r]) / total_data_points for r in selected]
    for idx in range(len(selected)):
        net_para = nets[selected[idx]].cpu().state_dict()
        if idx == 0:
            for key in net_para:
                global_para[key] = net_para[key] * fed_avg_freqs[idx]
        else:
            for key in net_para:
                global_para[key] += net_para[key] * fed_avg_freqs[idx]
    global_model.load_state_dict(global_para)
```
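Here the server first updates the global control variate with the average of the returned deltas over all parties, $c \leftarrow c + \frac{1}{N} \sum_{i \in S} \Delta c_i$ (with $N$ = args.n_parties), and then performs the same data-size-weighted model averaging as FedAvg.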
3. fednova
train_net_fednova:
```python
def train_net_fednova(net_id, net, train_dataloader, test_dataloader, epochs, lr, optimizer, device="cpu",
                      arguments=None, global_model=None, net_dataidx_map_in_train=None):
    criterion = nn.CrossEntropyLoss().to(device)
    if type(train_dataloader) == type([1]):
        pass
    else:
        train_dataloader = [train_dataloader]

    # writer = SummaryWriter()

    tau = 0

    for epoch in range(epochs):
        epoch_loss_collector = []
        for tmp in train_dataloader:
            for batch_idx, (x, target) in enumerate(tmp):
                x, target = x.to(device), target.to(device)

                optimizer.zero_grad()
                x.requires_grad = True
                target.requires_grad = False
                target = target.long()

                out = net(x)
                loss = criterion(out, target)

                loss.backward()
                optimizer.step()

                tau = tau + 1
                epoch_loss_collector.append(loss.item())

        epoch_loss = sum(epoch_loss_collector) / len(epoch_loss_collector)
        logger.info('Epoch: %d Loss: %f' % (epoch, epoch_loss))

    global_model.to(device)
    a_i = (tau - args.rho * (1 - pow(args.rho, tau)) / (1 - args.rho)) / (1 - args.rho)
    global_model_para = global_model.state_dict()
    net_para = net.state_dict()
    norm_grad = copy.deepcopy(global_model.state_dict())
    for key in norm_grad:
        # norm_grad[key] = (global_model_para[key] - net_para[key]) / a_i
        norm_grad[key] = torch.true_divide(global_model_para[key] - net_para[key], a_i)
    net.to('cpu')
    logger.info(' ** Training complete **')

    # the computation of len(train_dl_local.dataset) is moved here; "args" is renamed "arguments"
    dataidxs = net_dataidx_map_in_train[net_id]
    noise_level = arguments.noise
    if arguments.noise_type == 'space':
        train_dl_local, test_dl_local, _, _ = get_dataloader(arguments.dataset, arguments.datadir, arguments.batch_size, 32,
                                                             dataidxs, noise_level, net_id, arguments.n_parties - 1)
    else:
        noise_level = arguments.noise / (arguments.n_parties - 1) * net_id
        train_dl_local, test_dl_local, _, _ = get_dataloader(arguments.dataset, arguments.datadir, arguments.batch_size, 32,
                                                             dataidxs, noise_level)
    train_dataloader, test_dataloader, _, _ = get_dataloader(arguments.dataset, arguments.datadir, arguments.batch_size, 32)
    epochs = arguments.epochs
    return [a_i, norm_grad, len(train_dl_local.dataset)]
```
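The returned a_i is FedNova's normalization factor for momentum SGD with factor $\rho$, $a_i = \frac{1}{1-\rho}\left(\tau - \frac{\rho\,(1-\rho^{\tau})}{1-\rho}\right)$, and norm_grad is the normalized local update $d_i = \frac{x - y_i}{a_i}$, where $\tau$ is the number of local steps, $x$ the global model and $y_i$ the local model.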
init_server_fednova:
```python
def init_server_fednova(args, nets, local_model_meta_data, layer_type, global_models, global_model_meta_data,
                        global_layer_type, global_model, global_para):
    d_list = [copy.deepcopy(global_model.state_dict()) for i in range(args.n_parties)]
    d_total_round = copy.deepcopy(global_model.state_dict())
    for i in range(args.n_parties):
        for key in d_list[i]:
            d_list[i][key] = 0
    for key in d_total_round:
        d_total_round[key] = 0
    return [d_list, d_total_round]
```
server_aggregate_fednova:
```python
def server_aggregate_fednova(nets, selected, args, net_dataidx_map, test_dl_global, device, global_model, global_para, extra_info=None):
    all_lists = experiments.local_train_net('fednova', nets, selected, args, net_dataidx_map, test_dl_global, device, 'extend',
                                            arguments=args, global_model=global_model, net_dataidx_map_in_train=net_dataidx_map)

    # pick out the sublists from all_lists
    a_list = [all_lists[3 * i] for i in range(len(all_lists) // 3)]
    d_list = [all_lists[3 * i + 1] for i in range(len(all_lists) // 3)]
    n_list = [all_lists[3 * i + 2] for i in range(len(all_lists) // 3)]
    total_n = sum(n_list)

    d_total_round = copy.deepcopy(global_model.state_dict())
    for key in d_total_round:
        d_total_round[key] = 0.0
    for i in range(len(selected)):
        d_para = d_list[i]
        for key in d_para:
            d_total_round[key] += d_para[key] * n_list[i] / total_n

    # update global model
    coeff = 0.0
    for i in range(len(selected)):
        coeff = coeff + a_list[i] * n_list[i] / total_n

    updated_model = global_model.state_dict()
    for key in updated_model:
        # print(updated_model[key])
        if updated_model[key].type() == 'torch.LongTensor':
            updated_model[key] -= (coeff * d_total_round[key]).type(torch.LongTensor)
        elif updated_model[key].type() == 'torch.cuda.LongTensor':
            updated_model[key] -= (coeff * d_total_round[key]).type(torch.cuda.LongTensor)
        else:
            updated_model[key] -= coeff * d_total_round[key]
    global_model.load_state_dict(updated_model)
```
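This matches FedNova's server-side update rule: with $p_i = n_i / n$ the data fraction of client $i$, the server forms $d = \sum_{i \in S} p_i d_i$ and $\tau_{\text{eff}} = \sum_{i \in S} p_i a_i$, and updates the global model as $x \leftarrow x - \tau_{\text{eff}}\, d$.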
Results of our demo
In all four tables, the $\# C$ columns are quantity-based label imbalance, the $p_k \sim \textbf{Dir}(\beta)$ columns are distribution-based label imbalance, $\textbf{IID}$ is the homogeneous setting, $\hat{x} \sim \textbf{Gau}(0.1)$ is noise-based feature skew**, and $q \sim \textbf{Dir}(0.5)$ is quantity skew.

**cifar10 / simple-cnn (10 parties, 200 rounds, sample = 1)**

| Algorithm | $\# C=1$ | $\# C=2$ | $\# C=3$ | $p_k\sim \textbf{Dir}(0.5)$ | $p_k\sim \textbf{Dir}(0.1)$ | $\textbf{IID}$ | $\hat{x} \sim \textbf{Gau}(0.1)$ | $q \sim \textbf{Dir}(0.5)$ |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Fedavg | 10.0% | 49.6% | 60.6% | 67.4% | 63.2% | 72.8% | 71.5% | 75.0% |
| Fedprox (mu=0.001) | 10.0% | 48.4% | 60.9% | 66.4% | 63.3% | 73.4% | 71.6% | 73.5% |
| Scaffold | 10.0% | 48.6% | 59.0% | 71.5% | 67.6% | 74.6% | 72.4% | 10.0% |
| Fednova | 10.0% | 51.5% | 60.0% | 67.3% | 65.4% | 74.1% | 71.7% | 30.9% |
| Moon | 10.0% | 47.8% | 61.5% | 66.8% | 63.7% | 73.2% | 71.7% | 34.6% |
**cifar100 / resnet (10 parties, 200 rounds, sample = 1)**

| Algorithm | $\# C=1$ | $\# C=2$ | $\# C=3$ | $p_k\sim \textbf{Dir}(0.5)$ | $p_k\sim \textbf{Dir}(0.1)$ | $\textbf{IID}$ | $\hat{x} \sim \textbf{Gau}(0.1)$ | $q \sim \textbf{Dir}(0.5)$ |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Fedavg | N/A | N/A | N/A | 66.4% | 65.1% | 67.1% | 66.3% | 69.4% |
| Fedprox (mu=0.001) | N/A | N/A | N/A | 66.5% | 65.6% | 67.5% | 67.4% | 67.4% |
| Scaffold | N/A | N/A | N/A | 70.3% | 67.4% | 71.8% | 71.8% | 1.0% |
| Fednova | N/A | N/A | N/A | 65.8% | 65.2% | 67.9% | 67.4% | 1.0% |
| Moon | N/A | N/A | N/A | 66.9% | 65.3% | 67.9% | 67.6% | 69.9% |
**cifar10 / simple-cnn (200 parties, 1000 rounds, sample = 0.1)**

| Algorithm | $\# C=1$ | $\# C=2$ | $\# C=3$ | $p_k\sim \textbf{Dir}(0.5)$ | $p_k\sim \textbf{Dir}(0.1)$ | $\textbf{IID}$ | $\hat{x} \sim \textbf{Gau}(0.1)$ | $q \sim \textbf{Dir}(0.5)$ |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Fedavg | 10.0% | 48.6% | 57.4% | 60.3% | 45.7% | 62.5% | 60.8% | 67.5% |
| Fedprox (mu=0.001) | 10.0% | 52.0% | 57.2% | 61.3% | 46.8% | 62.3% | 60.7% | 67.7% |
| Scaffold | 10.0% | 10.0% | 10.0% | 10.0% | 10.0% | 72.5% | 70.9% | 10.0% |
| Fednova | 10.0% | 52.4% | 58.2% | 61.4% | 10.0% | 62.6% | 60.8% | 10.0% |
| Moon | 10.0% | 51.3% | 58.5% | 60.0% | 46.9% | 62.1% | 60.8% | 67.9% |
**cifar100 / resnet (200 parties, 1000 rounds, sample = 0.1)**

| Algorithm | $\# C=1$ | $\# C=2$ | $\# C=3$ | $p_k\sim \textbf{Dir}(0.5)$ | $p_k\sim \textbf{Dir}(0.1)$ | $\textbf{IID}$ | $\hat{x} \sim \textbf{Gau}(0.1)$ | $q \sim \textbf{Dir}(0.5)$ |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Fedavg | 1.0% | 10.3% | 22.3% | 51.2% | 50.1% | 48.0% | 49.6%*** | 57.3% |
| Fedprox (mu=0.001) | 1.0% | 10.5% | 17.7% | 50.0% | 53.4% | 53.5% | 47.3%*** | 55.8% |
| Scaffold | 1.0% | 9.8% | 12.6% | 34.6% | 20.3% | 62.9% | 53.2%*** | 57.5% |
| Fednova | 1.0% | 10.4% | 22.8% | 52.8% | 53.5% | 47.9% | 47.3%*** | 1.0% |
| Moon | 1.0% | 11.9% | 22.9% | 53.4% | 56.5% | 48.8% | 47.0%*** | 56.7% |
*: MOON adds projection layers to the model and, in its own experiments, compared the other algorithms using the same expanded model. In our experiments, we did not add additional layers and used the features before the last layer as the representations for MOON.
**: Because the results are similar to the IID setting and the time cost is prohibitive, the experiments on feature skew are no longer required.
***: Result after 500 rounds due to the prohibitive time cost.