
Understanding MACE: Mass Concept Erasure in Diffusion Models, Part III

Machine Unlearning Concept Erasure
MACE - This article is part of a series.
Part 3: This Article
This is a series of blog posts that records in detail an exploration of the paper MACE: Mass Concept Erasure in Diffusion Models. The thumbnail image was generated by Stable Diffusion 3.5 Large using the prompt “Machine Unlearning”.

Here are the links to the paper and its GitHub repository:

Paper: MACE: Mass Concept Erasure in Diffusion Models

Training Process

In this post, we first walk through the training process of the MACE implementation. It follows the same paradigm as the diffusers LoRA fine-tuning script. We skip the standard initialisation of the accelerator and start by loading the pipeline components:

from transformers import AutoTokenizer, CLIPTextModel
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel

# load the tokenizer and the text encoder
tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer",
                                          revision=args.revision, use_fast=False)
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder",
                                             revision=args.revision)

# load the noise scheduler, the VAE and the UNet
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision)

With the models loaded, we freeze them so that they are not updated by accident:

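# move the UNet to the accelerator device in the training precision
unet.to(accelerator.device, dtype=weight_dtype)
# freeze the VAE and the UNet, and the text encoder unless it is being trained
vae.requires_grad_(False)
unet.requires_grad_(False)
if not args.train_text_encoder:
    text_encoder.requires_grad_(False)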

MACEDataset

To gather all the information needed for fine-tuning, the MACE framework defines a dataset class called MACEDataset, which pre-processes the prompts and performs prompt augmentation. First, let’s take a quick look at the relevant entries of a celebrity-erasure configuration file:

multi_concept:
    - [ 
        [melania-trump, object], 
      ]
use_pooler: true  
train_batch_size: 1
mapping_concept: 
- 'a woman'
augment: true
aug_length: 30
prompt_len: 30
use_gpt: false
prior_preservation_cache_path: ./cache/cache_coco.pt
domain_preservation_cache_path: ./cache/cache_cele.pt
input_data_dir: ./data/1cele
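To make the roles of these fields concrete, here is a minimal pure-Python sketch of how each erased concept pairs up with its anchor (mapping_concept) and resolves to the image and mask directories, mirroring the zip(multi_concept, mapping) loop and the path handling inside MACEDataset. It assumes a hand-written dict in place of the parsed YAML and assumes the nested multi_concept list has already been flattened into [concept, type] pairs; resolve_concepts is a hypothetical helper, not part of the MACE repository.

```python
import os

# Hand-written stand-in for the parsed celebrity config. Assumption: the nested
# multi_concept list has already been flattened into [concept, type] pairs.
config = {
    "multi_concept": [["melania-trump", "object"]],
    "mapping_concept": ["a woman"],
    "input_data_dir": "./data/1cele",
}


def resolve_concepts(cfg):
    """Hypothetical helper: pair each erased concept with its anchor and data directories."""
    resolved = []
    for (concept, concept_type), anchor in zip(cfg["multi_concept"], cfg["mapping_concept"]):
        name = concept.replace("-", " ")  # 'melania-trump' -> 'melania trump'
        image_dir = os.path.join(cfg["input_data_dir"], name)
        # only concepts of type 'object' (used here for a celebrity) get a mask directory
        mask_dir = os.path.join(cfg["input_data_dir"], f"{name} mask") if concept_type == "object" else None
        resolved.append({
            "erased": name,
            "type": concept_type,
            "anchor": anchor,
            "image_dir": image_dir,
            "mask_dir": mask_dir,
        })
    return resolved


for entry in resolve_concepts(config):
    print(entry)
```

Note that zip silently stops at the shorter list, so multi_concept and mapping_concept need exactly one entry each per concept; a missing anchor would drop the concept without any error.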

The dataset is defined in the dataset.py file and looks like this:

import os
import random
from pathlib import Path

import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

# prompt_augmentation and text_augmentation are helper functions from the MACE repository


class MACEDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and tokenizes the prompts.
    """

    def __init__(
        self,
        tokenizer,
        size=512,
        center_crop=False,
        use_pooler=False,
        multi_concept=None,
        mapping=None,
        augment=True,
        batch_size=None,
        with_prior_preservation=False,
        preserve_info=None,
        num_class_images=None,
        train_seperate=False,
        aug_length=50,
        prompt_len=250,
        input_data_path=None,
        use_gpt=False,
    ):
        """
        tokenizer: the pretrained tokenizer
        size: the image size
        center_crop: whether to use CenterCrop (otherwise RandomCrop) in the transforms
        use_pooler: whether to also mark the special token "<|endoftext|>" in concept_positions
        multi_concept: the [concept, type] pairs to erase (the unlearning prompts)
        mapping: the corresponding anchor concepts
        augment: passed on to prompt_augmentation (LLM augmentation is controlled by use_gpt)
        batch_size: the training batch size
        with_prior_preservation: whether to add preservation images and prompts
        preserve_info: a dict with the keys 'preserve_data_dir' and 'preserve_prompt'
        num_class_images: the maximum number of preservation images to use
        train_seperate: if False, a concept is re-sampled at the start of every batch
        aug_length: the number of augmented prompts per concept
        prompt_len: the number of prompt templates to sample from
        input_data_path: the input data directory
        use_gpt: whether to use an LLM to generate synonymous prompts
        """

        # will skip the initialisation part

        for concept_idx, (data, mapping_concept) in enumerate(zip(multi_concept, mapping)):
            c, t = data  # c = 'melania-trump', t = 'object', mapping_concept = 'a woman'
            # the unlearning prompt: 'melania-trump' -> 'melania trump'
            erased_concept = c.replace("-", " ")

            # the image data generated with the unlearning prompt
            if input_data_path is not None:
                p = Path(input_data_path) / erased_concept
                if not p.exists():
                    raise ValueError(f"Instance {p} images root doesn't exist.")

                # get the segmentation masks if the concept belongs to the `object` category
                if t == "object":
                    p_mask = Path(input_data_path) / f"{erased_concept} mask"
                    if not p_mask.exists():
                        raise ValueError(f"Instance {p_mask} images root doesn't exist.")
            else:
                raise ValueError("Input data path is not provided.")

            # get all image paths; sorted so that images and masks are paired by index
            # (iterdir() yields entries in arbitrary order)
            single_concept_images_path = sorted(p.iterdir())
            self.all_concept_image_path.append(single_concept_images_path)

            if t == "object":
                # get all mask paths
                single_concept_masks_path = sorted(p_mask.iterdir())
                self.all_concept_mask_path.append(single_concept_masks_path)

            if use_gpt:
                # use an LLM to generate prompts
                class_prompt_collection, mapping_prompt_collection = text_augmentation(
                    erased_concept, mapping_concept, t, num_text_augmentations=self.aug_length)
                self.instance_prompt.append(class_prompt_collection)
                self.target_prompt.append(mapping_prompt_collection)
            else:
                # sample `self.aug_length` template indices and reuse them for the erased
                # and the anchor concept, so that both prompt lists stay aligned
                sampled_indices = random.sample(range(0, prompt_len), self.aug_length)
                self.instance_prompt.append(prompt_augmentation(erased_concept, augment=augment,
                                                                sampled_indices=sampled_indices, concept_type=t))
                self.target_prompt.append(prompt_augmentation(mapping_concept, augment=augment,
                                                              sampled_indices=sampled_indices, concept_type=t))

            self.num_instance_images += len(single_concept_images_path)

            # prepare the e^f ("old") and e^g ("new") entries for the closed-form refinement
            entry = {"old": self.instance_prompt[concept_idx], "new": self.target_prompt[concept_idx]}
            self.dict_for_close_form.append(entry)

        # with prior preservation, also collect the preservation images and their prompts
        if with_prior_preservation:
            class_data_root = Path(preserve_info["preserve_data_dir"])
            if os.path.isdir(class_data_root):
                # a directory of images that all share a single preservation prompt
                class_images_path = list(class_data_root.iterdir())
                class_prompt = [preserve_info["preserve_prompt"] for _ in range(len(class_images_path))]
            else:
                # text files listing one image path / one prompt per line
                with open(class_data_root, "r") as f:
                    class_images_path = f.read().splitlines()
                with open(preserve_info["preserve_prompt"], "r") as f:
                    class_prompt = f.read().splitlines()

            class_img_path = list(zip(class_images_path, class_prompt))
            self.class_images_path.extend(class_img_path[:num_class_images])

        # transforms for the images
        self.image_transforms = transforms.Compose(
            [
                # transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )

        self._concept_num = len(self.instance_prompt)
        self.num_class_images = len(self.class_images_path)
        self._length = max(self.num_instance_images // self._concept_num, self.num_class_images)

    # skip __len__

    def __getitem__(self, index):
        example = {}

        if not self.train_seperate:
            # pick a new random concept at the start of every batch,
            # intended so that each batch holds a single concept
            if self.batch_counter % self.batch_size == 0:
                self.concept_number = random.randint(0, self._concept_num - 1)
            self.batch_counter += 1

        # get the image; wrap around the image count of this concept, because
        # self._length can exceed it (e.g. when there are more preservation images)
        concept_images = self.all_concept_image_path[self.concept_number]
        instance_image = Image.open(concept_images[index % len(concept_images)])

        if len(self.all_concept_mask_path) == 0:
            # artistic style erasure: no masks
            binary_tensor = None
        else:
            # object/celebrity erasure: load the corresponding mask image
            concept_masks = self.all_concept_mask_path[self.concept_number]
            instance_mask = Image.open(concept_masks[index % len(concept_masks)])
            instance_mask = instance_mask.convert("L")
            binary_tensor = transforms.ToTensor()(instance_mask)

        # get a random prompt (and the concept phrase to locate in it) from the list
        prompt_number = random.randint(0, len(self.instance_prompt[self.concept_number]) - 1)
        instance_prompt, target_tokens = self.instance_prompt[self.concept_number][prompt_number]

        if instance_image.mode != "RGB":
            instance_image = instance_image.convert("RGB")
        example["instance_prompt"] = instance_prompt
        example["instance_images"] = self.image_transforms(instance_image)
        example["instance_masks"] = binary_tensor

        # tokenize the prompt (as a tensor for the model, as a list for the search below)
        example["instance_prompt_ids"] = self.tokenizer(
            instance_prompt,
            truncation=True,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids
        prompt_ids = self.tokenizer(
            instance_prompt,
            truncation=True,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
        ).input_ids
        # get the token ids of the concept phrase
        concept_ids = self.tokenizer(
            target_tokens,
            add_special_tokens=False,
        ).input_ids

        pooler_token_id = self.tokenizer(
            "<|endoftext|>",
            add_special_tokens=False,
        ).input_ids[0]

        # find the positions of the concept tokens (and of the pooler token)
        concept_positions = [0] * self.tokenizer.model_max_length
        for i, tok_id in enumerate(prompt_ids):
            if tok_id == concept_ids[0] and prompt_ids[i:i + len(concept_ids)] == concept_ids:
                concept_positions[i:i + len(concept_ids)] = [1] * len(concept_ids)
            if self.use_pooler and tok_id == pooler_token_id:
                concept_positions[i] = 1
        # store concept_positions with a leading batch dimension
        example["concept_positions"] = torch.tensor(concept_positions)[None]

        if self.with_prior_preservation:
            class_image, class_prompt = self.class_images_path[index % self.num_class_images]
            class_image = Image.open(class_image)
            if class_image.mode != "RGB":
                class_image = class_image.convert("RGB")
            example["preserve_images"] = self.image_transforms(class_image)
            example["preserve_prompt_ids"] = self.tokenizer(
                class_prompt,
                padding="max_length",
                truncation=True,
                max_length=self.tokenizer.model_max_length,
                return_tensors="pt",
            ).input_ids

        return example
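One detail of the template branch is worth isolating: the same sampled_indices are passed to prompt_augmentation for both the erased concept and its anchor, so the i-th instance prompt and the i-th target prompt come from the same template. The sketch below reproduces only that pairing idea. TEMPLATES and augment_pair are hypothetical stand-ins (the real templates used by prompt_augmentation are not reproduced here), and returning (prompt, concept) tuples is an assumption based on how __getitem__ unpacks instance_prompt and target_tokens.

```python
import random

# Hypothetical templates; the real list used by prompt_augmentation in the MACE
# repository is not reproduced here.
TEMPLATES = [
    "a photo of {}",
    "a portrait of {}",
    "{} smiling at the camera",
    "a close-up picture of {}",
    "{} at a public event",
]


def augment_pair(erased, anchor, aug_length, seed=None):
    """Build aligned (prompt, concept) lists for the erased concept and its anchor.

    Both lists use the same sampled template indices, so the i-th erased prompt
    and the i-th anchor prompt differ only in the concept phrase.
    """
    rng = random.Random(seed)
    sampled_indices = rng.sample(range(len(TEMPLATES)), aug_length)
    instance = [(TEMPLATES[i].format(erased), erased) for i in sampled_indices]
    target = [(TEMPLATES[i].format(anchor), anchor) for i in sampled_indices]
    return instance, target


instance, target = augment_pair("melania trump", "a woman", aug_length=3, seed=0)
for (src, _), (dst, _) in zip(instance, target):
    print(f"{src!r} -> {dst!r}")
```

The same sampling call also explains a constraint on the config: random.sample raises a ValueError when asked for more items than the population holds, so prompt_len must be at least aug_length (the celebrity config sets both to 30).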

Each example returned by MACEDataset contains the following fields:

"""
example["instance_prompt"]: the unlearning (instance) prompt
example["instance_images"]: the corresponding unlearning instance image
example["instance_masks"]: the corresponding segmentation mask (None for style erasure)
example["instance_prompt_ids"]: the token ids of the unlearning prompt
example["concept_positions"]: a 0/1 mask over token positions marking where the concept (and, with use_pooler, the "<|endoftext|>" token) appears in the prompt
example["preserve_images"]: the preservation image (only with prior preservation)
example["preserve_prompt_ids"]: the token ids of the preservation prompt (only with prior preservation)
"""
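The concept_positions search is easiest to follow on toy token ids. The sketch below re-implements the same loop as a standalone function; the vocabulary (1 as the start token, 2 as the end/padding token, 20 and 21 for "melania trump") is hypothetical and replaces the real CLIP tokenizer.

```python
def concept_positions(prompt_ids, concept_ids, pooler_id=None):
    """Mark every position covered by the concept phrase and, if pooler_id is
    given, every occurrence of the pooler token."""
    positions = [0] * len(prompt_ids)
    n = len(concept_ids)
    for i, tok_id in enumerate(prompt_ids):
        # a match needs the whole concept_ids subsequence to start at position i
        if tok_id == concept_ids[0] and prompt_ids[i:i + n] == concept_ids:
            positions[i:i + n] = [1] * n
        if pooler_id is not None and tok_id == pooler_id:
            positions[i] = 1
    return positions


# Hypothetical toy vocabulary: 1 = start token, 2 = end/padding token,
# 10 = "a", 11 = "photo", 12 = "of", 20 = "melania", 21 = "trump".
prompt = [1, 10, 11, 12, 20, 21, 2, 2]  # "a photo of melania trump", padded to length 8
print(concept_positions(prompt, [20, 21]))               # [0, 0, 0, 0, 1, 1, 0, 0]
print(concept_positions(prompt, [20, 21], pooler_id=2))  # [0, 0, 0, 0, 1, 1, 1, 1]
```

In this toy example the end token doubles as padding, so with a pooler id every padded position is marked as well; whether that happens in practice depends on which token the tokenizer pads with.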

Training Process (continued)

Back to the training process: after instantiating train_dataset and the training data loader and recalculating the number of training epochs, we reach the first stage, the closed-form refinement introduced in the paper.
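It helps to have the objective in mind before reading the code. The derivation below is a sketch in the notation of the code comments: e^f for the embeddings of the erased-concept prompts, e^g for the anchor prompts and e^p for the preserved knowledge, with W one cross-attention projection (to_k or to_v) and W' its refined version. It assumes a single preservation weight λ_1 standing in for the train_preserve_scale / preserve_weight combination, and a ridge term with weight λ (args.lamb) that keeps W' close to W, in the style of UCE-like closed-form editing; both are assumptions rather than a transcription of closed_form_refinement.

```latex
% Objective for one projection matrix W; W' is the refined matrix
\min_{W'} \; \sum_{i} \bigl\lVert W' e^f_i - W e^g_i \bigr\rVert_2^2
  + \lambda_1 \sum_{j} \bigl\lVert W' e^p_j - W e^p_j \bigr\rVert_2^2
  + \lambda \bigl\lVert W' - W \bigr\rVert_F^2

% Setting the gradient with respect to W' to zero:
W' \Bigl( \sum_{i} e^f_i {e^f_i}^{\top} + \lambda_1 \sum_{j} e^p_j {e^p_j}^{\top} + \lambda I \Bigr)
  = \sum_{i} W e^g_i {e^f_i}^{\top} + \lambda_1 \sum_{j} W e^p_j {e^p_j}^{\top} + \lambda W

% Hence the closed form:
W' = \underbrace{\Bigl( \sum_{i} W e^g_i {e^f_i}^{\top}
       + \lambda_1 \sum_{j} W e^p_j {e^p_j}^{\top} + \lambda W \Bigr)}_{\text{mat1}}
     \underbrace{\Bigl( \sum_{i} e^f_i {e^f_i}^{\top}
       + \lambda_1 \sum_{j} e^p_j {e^p_j}^{\top} + \lambda I \Bigr)^{-1}}_{\text{mat2}^{-1}}
```

If the {layer_num}_for_mat1 and {layer_num}_for_mat2 cache entries hold these per-layer sums, that would explain why the prior, domain and refinement contributions can simply be added together into cache_dict.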

import gc

import torch
from tqdm.auto import tqdm

# get_ca_layers, prepare_k_v and closed_form_refinement are helper functions from the MACE repository

# stage 1: closed-form refinement
# get the corresponding cross-attention layers and projection matrices
projection_matrices, ca_layers, og_matrices = get_ca_layers(unet, with_to_k=True)

# to save memory
CFR_dict = {}
max_concept_num = args.max_memory  # the maximum number of concepts processed at once
if len(train_dataset.dict_for_close_form) > max_concept_num:

    for layer_num in tqdm(range(len(projection_matrices))):
        CFR_dict[f'{layer_num}_for_mat1'] = None
        CFR_dict[f'{layer_num}_for_mat2'] = None

    # accumulate the refinement statistics chunk by chunk
    for i in tqdm(range(0, len(train_dataset.dict_for_close_form), max_concept_num)):
        contexts_sub, valuess_sub = prepare_k_v(text_encoder, projection_matrices, ca_layers, og_matrices,
                                                train_dataset.dict_for_close_form[i:i + max_concept_num],
                                                tokenizer, all_words=args.all_words)
        closed_form_refinement(projection_matrices, contexts_sub, valuess_sub, cache_dict=CFR_dict, cache_mode=True)

        del contexts_sub, valuess_sub
        gc.collect()
        torch.cuda.empty_cache()

else:
    for layer_num in tqdm(range(len(projection_matrices))):
        CFR_dict[f'{layer_num}_for_mat1'] = .0
        CFR_dict[f'{layer_num}_for_mat2'] = .0
    # prepare the keys/values for the closed-form refinement in a single pass
    contexts, valuess = prepare_k_v(text_encoder, projection_matrices, ca_layers, og_matrices,
                                    train_dataset.dict_for_close_form, tokenizer, all_words=args.all_words)

del ca_layers, og_matrices

# load the cached prior knowledge for preservation, e^p
if args.prior_preservation_cache_path:
    prior_preservation_cache_dict = torch.load(args.prior_preservation_cache_path, map_location=projection_matrices[0].weight.device)
else:
    prior_preservation_cache_dict = {}
    for layer_num in tqdm(range(len(projection_matrices))):
        prior_preservation_cache_dict[f'{layer_num}_for_mat1'] = .0
        prior_preservation_cache_dict[f'{layer_num}_for_mat2'] = .0

# load the cached domain knowledge for preservation, lambda_3, e^p
if args.domain_preservation_cache_path:
    domain_preservation_cache_dict = torch.load(args.domain_preservation_cache_path, map_location=projection_matrices[0].weight.device)
else:
    domain_preservation_cache_dict = {}
    for layer_num in tqdm(range(len(projection_matrices))):
        domain_preservation_cache_dict[f'{layer_num}_for_mat1'] = .0
        domain_preservation_cache_dict[f'{layer_num}_for_mat2'] = .0

# integrate the prior knowledge, the domain knowledge and the closed-form refinement statistics
cache_dict = {}
for key in CFR_dict:
    cache_dict[key] = (args.train_preserve_scale
                       * (prior_preservation_cache_dict[key]
                          + args.preserve_weight * domain_preservation_cache_dict[key])
                       + CFR_dict[key])

# closed-form refinement: compute the final weights
projection_matrices, _, _ = get_ca_layers(unet, with_to_k=True)

if len(train_dataset.dict_for_close_form) > max_concept_num:
    closed_form_refinement(projection_matrices, lamb=args.lamb, preserve_scale=1, cache_dict=cache_dict)
else:
    closed_form_refinement(projection_matrices, contexts, valuess, lamb=args.lamb,
                           preserve_scale=args.train_preserve_scale, cache_dict=cache_dict)
    # contexts and valuess only exist in this branch
    del contexts, valuess

del cache_dict
gc.collect()
torch.cuda.empty_cache()
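The memory-saving branch relies on the refinement statistics being sums over concepts: adding up per-chunk sums gives the same totals as a single pass over all concepts, as long as the slice width matches the loop step (otherwise concepts are skipped or counted twice). The pure-Python sketch below checks this with small integer vectors, so the comparison is exact. All names are hypothetical; chunk_sums stands in for prepare_k_v plus closed_form_refinement(cache_mode=True), and treating the cached quantities as sums of v f^T and f f^T is an assumption.

```python
def outer(u, v):
    """Outer product u v^T as a list of lists."""
    return [[a * b for b in v] for a in u]


def add_into(acc, m):
    """Element-wise acc + m; acc is None before the first chunk, like CFR_dict."""
    if acc is None:
        return [row[:] for row in m]
    return [[a + b for a, b in zip(ra, rb)] for ra, rb in zip(acc, m)]


def chunk_sums(chunk):
    """Sums of v f^T (mat1) and f f^T (mat2) over the (f, v) pairs of one chunk."""
    mat1 = mat2 = None
    for f, v in chunk:
        mat1 = add_into(mat1, outer(v, f))
        mat2 = add_into(mat2, outer(f, f))
    return mat1, mat2


def accumulate(pairs, chunk_size):
    """Accumulate the per-chunk sums into a cache, one chunk at a time."""
    cache = {"for_mat1": None, "for_mat2": None}
    for start in range(0, len(pairs), chunk_size):
        m1, m2 = chunk_sums(pairs[start:start + chunk_size])
        cache["for_mat1"] = add_into(cache["for_mat1"], m1)
        cache["for_mat2"] = add_into(cache["for_mat2"], m2)
    return cache


# toy (context f, value v) pairs: f has 3 dimensions, v has 2
pairs = [
    ([1, 0, 2], [3, 1]),
    ([0, 1, 1], [2, 2]),
    ([2, 1, 0], [0, 1]),
    ([1, 1, 1], [1, 0]),
    ([0, 2, 1], [1, 1]),
]
full = accumulate(pairs, chunk_size=len(pairs))
chunked = accumulate(pairs, chunk_size=2)
print(full == chunked)  # True
```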

Multi-LoRA Training

The closed-form refinement targets the co-existing words, i.e. the other words in the prompts (generated either by an LLM or from templates) that can still carry information about the concept. The next step is to erase the target concept itself. The authors observe that the cross-attention maps corresponding to the concept's tokens display high activation values in certain image regions.
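As a rough illustration of this observation, the sketch below measures how much attention the concept tokens place inside the concept's segmentation mask, combining the concept_positions and instance_masks fields produced by MACEDataset. It assumes each token's attention map is given as a small 2D grid of floats; masked_attention_energy and the toy values are hypothetical, and this is not the repository's loss implementation. Driving such a quantity down is one natural way to suppress the concept tokens' activations in those regions.

```python
def masked_attention_energy(attn_maps, concept_positions, mask):
    """Sum of squared attention inside the mask, over the marked token positions.

    attn_maps: one H x W grid (list of lists) per token position
    concept_positions: 0/1 list with the same length as attn_maps
    mask: H x W grid of 0/1 values (1 = pixels belonging to the concept)
    """
    total = 0.0
    for attn, is_concept in zip(attn_maps, concept_positions):
        if not is_concept:
            continue
        for attn_row, mask_row in zip(attn, mask):
            total += sum((a * m) ** 2 for a, m in zip(attn_row, mask_row))
    return total


# toy 2x2 maps for a 3-token prompt in which only token 1 is a concept token
attn_maps = [
    [[0.1, 0.1], [0.1, 0.1]],
    [[0.9, 0.2], [0.8, 0.1]],
    [[0.3, 0.3], [0.3, 0.3]],
]
mask = [[1, 0], [1, 0]]  # the concept occupies the left column
print(masked_attention_energy(attn_maps, [0, 1, 0], mask))
```

In practice the maps would come from the UNet's cross-attention layers at some spatial resolution, and the mask would have to be resized to match; those details are outside this sketch.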
