def proxy_one_batch(config, input_wrong, cam):
    """Replace the high-attention pixels of one batch of images.

    The attention map produced by ``cam`` is compared against
    ``config["proxy_threshold"]``; wherever it is larger, the de-normalised
    pixel is swapped for a replacement value, everywhere else the
    de-normalised pixel is kept.

    Args:
        config: Run configuration. Keys read here: ``"device"``,
            ``"pixel_replacement_method"``, ``"proxy_threshold"`` and, for the
            ``"blended"`` method only, ``"proxy_image_weight"``.
        input_wrong: Batch of input images as a tensor.
            NOTE(review): presumably normalised and shaped (batch, 3, H, W),
            given the expansion to 3 channels below -- confirm with caller.
        cam: Callable invoked as ``cam(input_tensor=..., targets=None)`` that
            returns one attention map per image.

    Returns:
        Tensor with the selected pixels replaced, on ``config["device"]``.
    """
    device = config["device"]
    # Move the batch once and reuse it, so the attention maps and the
    # de-normalised images are guaranteed to live on the same device.
    input_wrong = input_wrong.to(device)

    grads = cam(input_tensor=input_wrong, targets=None)
    # Add a channel axis and broadcast the single map over 3 image channels.
    # NOTE(review): assumes ``cam`` returns (batch, H, W) -- TODO confirm.
    grads = (
        torch.as_tensor(grads, dtype=torch.float32, device=device)
        .unsqueeze(1)
        .expand(-1, 3, -1, -1)
    )
    normalized_inps = inv_normalize(input_wrong)

    method = config["pixel_replacement_method"]
    if method == "blended":
        # Scale each pixel down in proportion to its attention value.
        replacement = (1 - config["proxy_image_weight"] * grads) * normalized_inps
    else:
        # The module-level table maps a method name to a callable that builds
        # the replacement values from the attention maps.
        replacement = dict_decide_change[method](grads)

    return torch.where(
        grads > config["proxy_threshold"],
        replacement,
        normalized_inps,
    )
def _stack_and_squeeze(samples):
    """Turn a slice of samples into one tensor with singleton axes removed."""
    if isinstance(samples, torch.Tensor):
        return torch.squeeze(samples)
    # A sequence of per-sample tensors: join them along axis 1, then drop the
    # singleton axes.
    return torch.squeeze(torch.stack(samples, dim=1))


def proxy_callback(config, input_wrong_full, label_wrong_full, cam):
    """Run the proxy step on a subset of samples and save the results as JPEGs.

    The first ``ceil(config["change_subset_attention"] * len(labels))`` samples
    are processed in chunks of ``config["batch_size"]`` through
    ``proxy_one_batch``. Each resulting image is written to
    ``config["ds_path"] / <label name> / proxy-<index>-<run count>.jpeg``.

    Args:
        config: Run configuration. Keys read here: ``"change_subset_attention"``,
            ``"batch_size"``, ``"device"``, ``"label_map"``, ``"ds_path"`` and
            ``"global_run_count"`` (plus whatever ``proxy_one_batch`` reads).
        input_wrong_full: Input images, either a tensor or a sequence of
            per-sample tensors.
        label_wrong_full: Labels matching ``input_wrong_full``, in the same
            container form.
        cam: Attention-map callable forwarded to ``proxy_one_batch``.

    Returns:
        None. The function returns early, writing nothing, when no samples
        are selected.
    """
    # TODO Save Classwise fraction
    chosen_inds = int(np.ceil(config["change_subset_attention"] * len(label_wrong_full)))
    # TODO some sort of decay?
    input_wrong_full = input_wrong_full[:chosen_inds]
    label_wrong_full = label_wrong_full[:chosen_inds]

    batch_size = config["batch_size"]
    image_batches = []
    label_batches = []
    for start in tqdm(range(0, len(input_wrong_full), batch_size), desc="Running proxy"):
        stop = start + batch_size
        input_wrong = _stack_and_squeeze(input_wrong_full[start:stop])
        label_wrong = _stack_and_squeeze(label_wrong_full[start:stop])

        # Squeezing a single-sample chunk also removes the batch axis; put it
        # back so the chunk is handled like any other.
        if input_wrong.dim() == 3:
            input_wrong = input_wrong.unsqueeze(0)
        if label_wrong.dim() == 0:
            label_wrong = label_wrong.unsqueeze(0)

        thresholded_ims = proxy_one_batch(config, input_wrong.to(config["device"]), cam)
        image_batches.append(thresholded_ims.detach().cpu())
        label_batches.append(label_wrong)

    if not image_batches:
        # Nothing was selected, so there is nothing to concatenate or save.
        return

    processed_thresholds = torch.cat(image_batches, dim=0)
    processed_labels = torch.cat(label_batches, dim=0)

    num_images = processed_thresholds.size(0)
    for ind in tqdm(range(num_images), total=num_images, desc="Saving images"):
        label = config["label_map"][processed_labels[ind].item()]
        save_name = (
            config["ds_path"] / label / f"proxy-{ind}-{config['global_run_count']}.jpeg"
        )
        # NOTE(review): ``tfm`` presumably converts a tensor to an image object
        # exposing ``.save`` -- it is defined outside this block.
        tfm(processed_thresholds[ind]).save(save_name)