A Model Inversion Attack is a method for creating a model with roughly the same functionality as a target model whose architecture the attacker does not know (a so-called black-box model) by the outputs o
import numpy as np
from collections import namedtuple
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import EMNIST, MNIST
from tqdm.notebook import tqdm, trange
import matplotlib.pyplot as plt
2. Set Hyperparameters of Each Model
Next, we prepare the hyperparameters for each model. These values are used for training, splitting the dataset, and so on.
hyperparams = namedtuple("hyperparams", "batch_size,epochs,learning_rate,n_data")

# Hyperparameters for the victim model
victim_hyperparams = hyperparams(
    batch_size=256,
    epochs=10,
    learning_rate=1e-4,
    n_data=20_000,  # the full dataset is not required
)

# Hyperparameters for the evil model used in the attack
evil_hyperparams = hyperparams(
    batch_size=32,
    epochs=10,
    learning_rate=1e-4,
    n_data=500,
)
3. Load/Preprocess Dataset and Create DataLoader
We use the MNIST dataset for this walkthrough.
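As a rough illustration of how a dataset can be limited to `n_data` samples and wrapped in a data loader, here is a minimal sketch. The helper `build_loader` and the name `victim_loader` are hypothetical, and random tensors stand in for MNIST so the example runs without a download; with torchvision, the dataset would instead be an `MNIST(...)` instance created with your own preprocessing transform.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset


def build_loader(dataset, n_data, batch_size, shuffle=False):
    """Keep only the first n_data samples and wrap them in a DataLoader.

    Hypothetical helper for illustration; not part of the original article.
    """
    n_kept = min(n_data, len(dataset))
    subset = Subset(dataset, range(n_kept))
    return DataLoader(subset, batch_size=batch_size, shuffle=shuffle)


# Stand-in for MNIST (an assumption): 1,000 random 1x28x28 "images"
# with integer labels in the range 0-9.
torch.manual_seed(0)
fake_images = torch.rand(1000, 1, 28, 28)
fake_labels = torch.randint(0, 10, (1000,))
dataset = TensorDataset(fake_images, fake_labels)

# Keep 600 samples and read them in batches of 256.
victim_loader = build_loader(dataset, n_data=600, batch_size=256, shuffle=True)
images, labels = next(iter(victim_loader))
print(images.shape, labels.shape)
```

Using `Subset` leaves the underlying dataset untouched, which can be convenient when the same dataset object is reused for several splits.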
Since this article is for educational purposes, we first need to create the target model that will be inverted. In practice, the attacker does not have access to the architecture of the target model.
Here we create a neural network named VictimNet as an example.
Its layers are split into two stages. Later, we will intercept the output of stage1.
class VictimNet(nn.Module):
    def __init__(self, first_network, second_network) -> None:
        super().__init__()
        self.stage1 = first_network
        self.stage2 = second_network

    def mobile_stage(self, x):
        return self.stage1(x)

    def forward(self, x):
        out = self.mobile_stage(x)
        out = out.view(out.size(0), -1)
        return self.stage2(out)
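To make the two-stage split concrete, the following sketch builds a VictimNet from two small networks and compares what an eavesdropper sees after stage1 with the final prediction. The layer choices for `first_network` and `second_network` are assumptions made for illustration only; the article's actual layers may differ.

```python
import torch
import torch.nn as nn


class VictimNet(nn.Module):
    def __init__(self, first_network, second_network) -> None:
        super().__init__()
        self.stage1 = first_network
        self.stage2 = second_network

    def mobile_stage(self, x):
        return self.stage1(x)

    def forward(self, x):
        out = self.mobile_stage(x)
        out = out.view(out.size(0), -1)
        return self.stage2(out)


torch.manual_seed(0)

# Hypothetical stage 1: 1x28x28 image -> 8x12x12 feature map.
first_network = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=5),  # 28x28 -> 24x24
    nn.ReLU(),
    nn.MaxPool2d(2),                 # 24x24 -> 12x12
)
# Hypothetical stage 2: flattened features -> 10 class scores.
second_network = nn.Sequential(nn.Linear(8 * 12 * 12, 10))

victim_model = VictimNet(first_network, second_network)

images = torch.rand(4, 1, 28, 28)  # random stand-in for a batch of digits
with torch.no_grad():
    intercepted = victim_model.mobile_stage(images)  # what the attacker observes
    logits = victim_model(images)                    # the full prediction

print(intercepted.shape, logits.shape)
```

The intercepted tensor still has a spatial layout, which is what gives a reconstruction model something to work with.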
In addition, we need to prepare a dataset and a data loader for this evil model.
evil_dataset = EMNIST("emnist", "letters", download=True, train=False, transform=preprocess)

# Use the first n_data images in the test set to train the evil model
evil_dataset.data = evil_dataset.data[:evil_hyperparams.n_data]
evil_dataset.targets = evil_dataset.targets[:evil_hyperparams.n_data]

# Dataloader
evil_loader = DataLoader(evil_dataset, batch_size=evil_hyperparams.batch_size)
To train, execute the following script.
# Optimizer
evil_optim = torch.optim.Adam(evil_model.parameters(), lr=evil_hyperparams.learning_rate)

# Train for each epoch
for epoch in trange(evil_hyperparams.epochs):
    for data, targets in evil_loader:
        # .float() returns a new tensor, so the result must be assigned.
        # The targets are not used by the attack.
        data = data.float()

        # Reset the gradients accumulated in the previous step
        evil_optim.zero_grad()

        # Intercept the output of the mobile device's model.
        # This is the input of the evil model.
        with torch.no_grad():
            evil_input = victim_model.mobile_stage(data)
        output = evil_model(evil_input)

        # Calculate the mean squared loss between the predicted output
        # and the original input data
        loss = ((output - data) ** 2).mean()
        loss.backward()
        evil_optim.step()
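For readers who want to run a single update in isolation, here is a self-contained sketch of one training step. Both networks are hypothetical stand-ins: `stage1` plays the role of the victim's first stage, and `evil_model` is a single transposed convolution chosen only because it maps the 8x12x12 features back to a 1x28x28 image. The data is random rather than EMNIST, and the learning rate is arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the victim's first stage. It is frozen:
# the attacker only observes its outputs and never updates it.
stage1 = nn.Sequential(nn.Conv2d(1, 8, kernel_size=5), nn.ReLU(), nn.MaxPool2d(2))
for param in stage1.parameters():
    param.requires_grad_(False)

# Hypothetical evil model: (N, 8, 12, 12) features -> (N, 1, 28, 28) images.
# Output size: (12 - 1) * 2 + 6 = 28.
evil_model = nn.Sequential(nn.ConvTranspose2d(8, 1, kernel_size=6, stride=2))
evil_optim = torch.optim.Adam(evil_model.parameters(), lr=1e-3)

data = torch.rand(32, 1, 28, 28)  # random stand-in for one batch of images
weights_before = evil_model[0].weight.detach().clone()

# One training step
evil_optim.zero_grad()
with torch.no_grad():
    evil_input = stage1(data)         # intercepted intermediate output
output = evil_model(evil_input)       # attempted reconstruction of the input
loss = ((output - data) ** 2).mean()  # mean squared reconstruction error
loss.backward()
evil_optim.step()

print(output.shape, loss.item())
```

Only the evil model's parameters receive gradients here, which mirrors the setting where the attacker cannot modify the victim model.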
6. Attack
Now that everything is in place, we can start inverting the target model and generate images that approximate the original inputs to the target model.
First, we create a function to plot the generated images.