Here are the paper and its GitHub repository:
Paper: MACE: Mass Concept Erasure in Diffusion Models
Code: Shilin-LU/MACE — [CVPR 2024] “MACE: Mass Concept Erasure in Diffusion Models” (Official Implementation)
Training Process#
In this blog post, we will first discuss the training process of the MACE implementation. It follows the same paradigm as the diffusers LoRA fine-tuning scripts. We skip the standard initialisation of the accelerator and start by loading the pipeline components:
# import tokenizer and text encoder
tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer",
                                          revision=args.revision, use_fast=False)
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision)

# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision)
Now that we have the models, we freeze them so they are not updated by accident:
unet.to(accelerator.device, dtype=weight_dtype)
vae.requires_grad_(False)
unet.requires_grad_(False)
if not args.train_text_encoder:
    text_encoder.requires_grad_(False)
MACEDataset#
In the MACE framework, in order to gather all the information needed for fine-tuning, the authors construct a dataset class called MACEDataset that pre-processes the prompts and performs augmentation. First, let's take a quick look at the relevant fields of a celebrity erasure configuration file:
multi_concept:
- [
[melania-trump, object],
]
use_pooler: true
train_batch_size: 1
mapping_concept:
- 'a woman'
augment: true
aug_length: 30
prompt_len: 30
use_gpt: false
prior_preservation_cache_path: ./cache/cache_coco.pt
domain_preservation_cache_path: ./cache/cache_cele.pt
input_data_dir: ./data/1cele
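To make these fields concrete, here is a rough sketch (not the repository's exact wiring) of how they would feed the MACEDataset constructor introduced next; cfg is a hypothetical namespace holding the YAML above, and tokenizer is the one loaded earlier:
# sketch only: assumes the YAML above is loaded into a namespace `cfg`
train_dataset = MACEDataset(
    tokenizer=tokenizer,
    size=512,
    use_pooler=cfg.use_pooler,           # also treat the <|endoftext|> pooler token as part of the concept
    multi_concept=cfg.multi_concept,     # [['melania-trump', 'object']]: concepts to erase
    mapping=cfg.mapping_concept,         # ['a woman']: the anchoring concepts
    augment=cfg.augment,                 # augment the prompts with templates
    batch_size=cfg.train_batch_size,
    aug_length=cfg.aug_length,           # number of augmented prompts per concept
    prompt_len=cfg.prompt_len,           # number of templates available for sampling
    use_gpt=cfg.use_gpt,                 # false: sample templates instead of calling an LLM
    input_data_path=cfg.input_data_dir,  # ./data/1cele: generated images (and masks)
)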
The definition of the dataset is located in dataset.py and looks like this:
class MACEDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and tokenizes the prompts.
    """

    def __init__(
        self,
        tokenizer,
        size=512,
        center_crop=False,
        use_pooler=False,
        multi_concept=None,
        mapping=None,
        augment=True,
        batch_size=None,
        with_prior_preservation=False,
        preserve_info=None,
        num_class_images=None,
        train_seperate=False,
        aug_length=50,
        prompt_len=250,
        input_data_path=None,
        use_gpt=False,
    ):
        """
        tokenizer: the pretrained tokenizer
        size: the image size
        center_crop: whether to use center cropping in the transform
        use_pooler: whether to also mark the pooler token "<|endoftext|>"
        multi_concept: the corresponding unlearning prompts
        mapping: the corresponding anchoring prompts
        augment: whether to augment the prompts using templates or an LLM
        batch_size: the training batch size
        with_prior_preservation: whether to add preservation information
        preserve_info: the preservation data directory and prompt
        num_class_images: how many preservation images to keep
        train_seperate: whether each concept is trained separately
        aug_length: the number of augmented prompts
        prompt_len: the number of available prompt templates
        input_data_path: the input data directory
        use_gpt: whether to use an LLM to generate synonymous prompts
        """

        # will skip the initialisation part

        for concept_idx, (data, mapping_concept) in enumerate(zip(multi_concept, mapping)):
            c, t = data  # c = 'melania-trump', t = 'object', mapping_concept = 'a woman'

            # the image data generated with the unlearning prompt 'melania-trump'
            if input_data_path is not None:
                # change 'melania-trump' to 'melania trump' and get the images
                p = Path(os.path.join(input_data_path, c.replace("-", " ")))
                if not p.exists():
                    raise ValueError(f"Instance {p} images root doesn't exists.")

                # get the segmentation masks if the concept belongs to the `object` category
                if t == "object":
                    p_mask = Path(os.path.join(input_data_path, c.replace("-", " ")).replace(f'{c.replace("-", " ")}', f'{c.replace("-", " ")} mask'))
                    if not p_mask.exists():
                        raise ValueError(f"Instance {p_mask} images root doesn't exists.")
            else:
                raise ValueError(f"Input data path is not provided.")

            image_paths = list(p.iterdir())
            single_concept_images_path = []
            # get all image paths
            single_concept_images_path += image_paths
            self.all_concept_image_path.append(single_concept_images_path)

            if t == "object":
                mask_paths = list(p_mask.iterdir())
                single_concept_masks_path = []
                # get all mask paths
                single_concept_masks_path += mask_paths
                self.all_concept_mask_path.append(single_concept_masks_path)

            # the unlearning prompt
            erased_concept = c.replace("-", " ")

            # use an LLM to generate prompts
            if use_gpt:
                class_prompt_collection, mapping_prompt_collection = text_augmentation(erased_concept, mapping_concept, t, num_text_augmentations=self.aug_length)
                self.instance_prompt.append(class_prompt_collection)
                self.target_prompt.append(mapping_prompt_collection)
            else:
                # sample `self.aug_length` prompts from the templates
                sampled_indices = random.sample(range(0, prompt_len), self.aug_length)
                self.instance_prompt.append(prompt_augmentation(erased_concept, augment=augment, sampled_indices=sampled_indices, concept_type=t))
                self.target_prompt.append(prompt_augmentation(mapping_concept, augment=augment, sampled_indices=sampled_indices, concept_type=t))

            self.num_instance_images += len(single_concept_images_path)

            # prepare the e^f and e^g entries
            entry = {"old": self.instance_prompt[concept_idx], "new": self.target_prompt[concept_idx]}
            self.dict_for_close_form.append(entry)

        # with prior preservation, also prepare the prompts and images for the preservation data
        if with_prior_preservation:
            class_data_root = Path(preserve_info['preserve_data_dir'])
            if os.path.isdir(class_data_root):
                class_images_path = list(class_data_root.iterdir())
                class_prompt = [preserve_info["preserve_prompt"] for _ in range(len(class_images_path))]
            else:
                with open(class_data_root, "r") as f:
                    class_images_path = f.read().splitlines()
                with open(preserve_info["preserve_prompt"], "r") as f:
                    class_prompt = f.read().splitlines()

            class_img_path = [(x, y) for (x, y) in zip(class_images_path, class_prompt)]
            self.class_images_path.extend(class_img_path[:num_class_images])

        # transforms for the images
        self.image_transforms = transforms.Compose(
            [
                # transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )

        self._concept_num = len(self.instance_prompt)
        self.num_class_images = len(self.class_images_path)
        self._length = max(self.num_instance_images // self._concept_num, self.num_class_images)

    # skip __len__

    def __getitem__(self, index):
        example = {}

        if not self.train_seperate:
            if self.batch_counter % self.batch_size == 0:
                self.concept_number = random.randint(0, self._concept_num - 1)
            self.batch_counter += 1

        # get the image
        instance_image = Image.open(self.all_concept_image_path[self.concept_number][index % self._length])

        if len(self.all_concept_mask_path) == 0:
            # artistic style erasure
            binary_tensor = None
        else:
            # object/celebrity erasure, get the corresponding mask image
            instance_mask = Image.open(self.all_concept_mask_path[self.concept_number][index % self._length])
            instance_mask = instance_mask.convert('L')
            trans = transforms.ToTensor()
            binary_tensor = trans(instance_mask)

        # get a random prompt from the list
        prompt_number = random.randint(0, len(self.instance_prompt[self.concept_number]) - 1)
        instance_prompt, target_tokens = self.instance_prompt[self.concept_number][prompt_number]

        if not instance_image.mode == "RGB":
            instance_image = instance_image.convert("RGB")
        example["instance_prompt"] = instance_prompt
        example["instance_images"] = self.image_transforms(instance_image)
        example["instance_masks"] = binary_tensor

        # tokenize prompts
        example["instance_prompt_ids"] = self.tokenizer(
            instance_prompt,
            truncation=True,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids
        prompt_ids = self.tokenizer(
            instance_prompt,
            truncation=True,
            padding="max_length",
            max_length=self.tokenizer.model_max_length
        ).input_ids
        # get the anchor tokens
        concept_ids = self.tokenizer(
            target_tokens,
            add_special_tokens=False
        ).input_ids

        pooler_token_id = self.tokenizer(
            "<|endoftext|>",
            add_special_tokens=False
        ).input_ids[0]

        concept_positions = [0] * self.tokenizer.model_max_length
        # loop to find the positions of the concept tokens and the pooler token
        for i, tok_id in enumerate(prompt_ids):
            if tok_id == concept_ids[0] and prompt_ids[i:i + len(concept_ids)] == concept_ids:
                concept_positions[i:i + len(concept_ids)] = [1] * len(concept_ids)
            if self.use_pooler and tok_id == pooler_token_id:
                concept_positions[i] = 1
        # store concept_positions
        example["concept_positions"] = torch.tensor(concept_positions)[None]

        if self.with_prior_preservation:
            class_image, class_prompt = self.class_images_path[index % self.num_class_images]
            class_image = Image.open(class_image)
            if not class_image.mode == "RGB":
                class_image = class_image.convert("RGB")
            example["preserve_images"] = self.image_transforms(class_image)
            example["preserve_prompt_ids"] = self.tokenizer(
                class_prompt,
                padding="max_length",
                truncation=True,
                max_length=self.tokenizer.model_max_length,
                return_tensors="pt",
            ).input_ids

        return example
The instance example of the MACEDataset contains several useful elements:
1"""
2example["instance_prompt"]: the unlearning instance prompt
3example["instance_images"]: the corresponding unlearning instance image
4example["instance_masks"]: the corresponding unlearning instance image mask
5example["instance_prompt_ids"]: the unlearning tokenized id
6example["concept_positions"]: where the core concept located in the prompt
7example["preserve_images"]: the preservation image
8example["preserve_prompt_ids"]: the preservation tokenized id
9"""
Training Process (continued)#
Let’s get back to the training process. After instantiating the train_dataset and train_loader and recalculating the number of training epochs, we come to the first stage: the introduced closed-form refinement.
# stage 1: closed-form refinement
# get the corresponding layers and modules
projection_matrices, ca_layers, og_matrices = get_ca_layers(unet, with_to_k=True)

# to save memory
CFR_dict = {}
max_concept_num = args.max_memory  # the maximum number of concepts that can be processed at once
if len(train_dataset.dict_for_close_form) > max_concept_num:

    for layer_num in tqdm(range(len(projection_matrices))):
        CFR_dict[f'{layer_num}_for_mat1'] = None
        CFR_dict[f'{layer_num}_for_mat2'] = None

    for i in tqdm(range(0, len(train_dataset.dict_for_close_form), max_concept_num)):
        contexts_sub, valuess_sub = prepare_k_v(text_encoder, projection_matrices, ca_layers, og_matrices,
                                                train_dataset.dict_for_close_form[i:i+5], tokenizer, all_words=args.all_words)
        closed_form_refinement(projection_matrices, contexts_sub, valuess_sub, cache_dict=CFR_dict, cache_mode=True)

        del contexts_sub, valuess_sub
        gc.collect()
        torch.cuda.empty_cache()

else:
    for layer_num in tqdm(range(len(projection_matrices))):
        CFR_dict[f'{layer_num}_for_mat1'] = .0
        CFR_dict[f'{layer_num}_for_mat2'] = .0
    # prepare for the closed-form refinement
    contexts, valuess = prepare_k_v(text_encoder, projection_matrices, ca_layers, og_matrices,
                                    train_dataset.dict_for_close_form, tokenizer, all_words=args.all_words)

del ca_layers, og_matrices

# Load cached prior knowledge for preserving, e^p
if args.prior_preservation_cache_path:
    prior_preservation_cache_dict = torch.load(args.prior_preservation_cache_path, map_location=projection_matrices[0].weight.device)
else:
    prior_preservation_cache_dict = {}
    for layer_num in tqdm(range(len(projection_matrices))):
        prior_preservation_cache_dict[f'{layer_num}_for_mat1'] = .0
        prior_preservation_cache_dict[f'{layer_num}_for_mat2'] = .0

# Load cached domain knowledge for preserving, lambda_3, e^p
if args.domain_preservation_cache_path:
    domain_preservation_cache_dict = torch.load(args.domain_preservation_cache_path, map_location=projection_matrices[0].weight.device)
else:
    domain_preservation_cache_dict = {}
    for layer_num in tqdm(range(len(projection_matrices))):
        domain_preservation_cache_dict[f'{layer_num}_for_mat1'] = .0
        domain_preservation_cache_dict[f'{layer_num}_for_mat2'] = .0

# integrate the prior knowledge, domain knowledge and closed-form refinement
cache_dict = {}
for key in CFR_dict:
    cache_dict[key] = args.train_preserve_scale * (prior_preservation_cache_dict[key] \
                      + args.preserve_weight * domain_preservation_cache_dict[key]) \
                      + CFR_dict[key]

# closed-form refinement
projection_matrices, _, _ = get_ca_layers(unet, with_to_k=True)

if len(train_dataset.dict_for_close_form) > max_concept_num:
    closed_form_refinement(projection_matrices, lamb=args.lamb, preserve_scale=1, cache_dict=cache_dict)
else:
    # get the final weights
    closed_form_refinement(projection_matrices, contexts, valuess, lamb=args.lamb,
                           preserve_scale=args.train_preserve_scale, cache_dict=cache_dict)

del contexts, valuess, cache_dict
gc.collect()
torch.cuda.empty_cache()
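For intuition, the quantities cached and combined above correspond to the paper's closed-form solution for each cross-attention K/V projection. Below is a self-contained sketch of the per-layer update (my own, following the paper's notation, not the repository's closed_form_refinement): the embeddings of words co-existing with the target concept (e^f) are remapped to what the frozen weights produce for the anchor prompts (e^g), while the prior/domain embeddings (e^p) keep their original outputs; the _for_mat1 / _for_mat2 caches presumably accumulate the two sums below.
import torch

@torch.no_grad()
def closed_form_update(W_old, e_f, e_g, e_p, lamb=0.1, lambda1=1.0):
    """
    W_old:   [out_dim, emb_dim] frozen cross-attention K or V projection
    e_f:     [n, emb_dim] embeddings of co-existing words in prompts containing the target concept
    e_g:     [n, emb_dim] embeddings of the same words in the anchoring (mapping) prompts
    e_p:     [m, emb_dim] embeddings whose outputs should be preserved (prior / domain knowledge)
    lamb:    regularisation towards the old weights (cf. args.lamb)
    lambda1: weight on the preservation terms (cf. preserve_scale)
    """
    emb_dim = W_old.shape[1]
    # mat1 ~ sum_i (W_old e_g_i) e_f_i^T + lambda1 * sum_j (W_old e_p_j) e_p_j^T
    mat1 = (W_old @ e_g.T) @ e_f + lambda1 * (W_old @ e_p.T) @ e_p
    # mat2 ~ sum_i e_f_i e_f_i^T + lambda1 * sum_j e_p_j e_p_j^T
    mat2 = e_f.T @ e_f + lambda1 * (e_p.T @ e_p)
    # Tikhonov term keeps the refined weights close to W_old and makes mat2 invertible
    mat1 = mat1 + lamb * W_old
    mat2 = mat2 + lamb * torch.eye(emb_dim, device=W_old.device)
    return mat1 @ torch.linalg.inv(mat2)  # the refined projection weight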
Multi-LoRA Training#
The closed-form refinement focuses on removing the co-existing words, which are either generated by an LLM or drawn from the templates. The next step is to erase the target concept itself. The authors observe that the attention maps corresponding to the tokens of the concept display high activation values in certain regions.
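In the paper, LoRA modules are then tuned to suppress this activation within the masked concept region. As a preview, here is a minimal sketch of what such a masked attention-suppression objective could look like (my own sketch, not the repository's exact LoRA loss): the cross-attention probabilities received by the concept tokens (selected via concept_positions) are penalised inside the segmentation mask produced by MACEDataset.
import torch
import torch.nn.functional as F

def attention_suppression_loss(attn_probs_per_layer, concept_positions, masks):
    """
    attn_probs_per_layer: list of cross-attention maps, each [B, heads, H*W, 77]
    concept_positions:    [B, 77] binary, 1 at the concept (and pooler) token positions
    masks:                [B, 1, H_img, W_img] binary segmentation masks of the concept
    """
    loss = 0.0
    for attn in attn_probs_per_layer:
        b, heads, hw, n_tokens = attn.shape
        res = int(hw ** 0.5)
        # downsample the segmentation mask to the attention resolution
        m = F.interpolate(masks, size=(res, res), mode="nearest").reshape(b, 1, hw, 1)
        # keep only the attention columns belonging to the concept tokens
        token_sel = concept_positions.reshape(b, 1, 1, n_tokens).float()
        # penalise high activation of the concept tokens inside the concept region
        loss = loss + ((attn * token_sel * m) ** 2).mean()
    return loss / len(attn_probs_per_layer)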