feat: support meta SSL watermarking
Files changed:
- SSL_watermark.py   +87 -0
- app.py             +23 -6
- dino_r50.pth       +3 -0 (Git LFS)
- image_utils.py     +80 -0
- out2048.pth        +3 -0 (Git LFS)
- requirements.txt   +3 -0
- torch_utils.py     +84 -0
SSL_watermark.py (ADDED)
import numpy as np
import json

import torch
from torchvision import transforms

import torch_utils
import image_utils

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.manual_seed(0)
np.random.seed(0)

print('Building backbone and normalization layer...')
backbone = torch_utils.build_backbone(path='dino_r50.pth')
normlayer = torch_utils.load_normalization_layer(path='out2048.pth')
model = torch_utils.NormLayerWrapper(backbone, normlayer)

print('Building the hypercone...')
FPR = 1e-6
angle = 1.462771101178447  # value for FPR=1e-6 and D=2048
rho = 1 + np.tan(angle)**2  # features are "marked" when rho * <ft,c>^2 > ||ft||^2
carrier = torch.randn(1, 2048)
carrier /= torch.norm(carrier, dim=1, keepdim=True)
carrier = carrier.to(device)  # keep the carrier on the same device as the features

default_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


def encode(image, epochs=10, psnr=44, lambda_w=1, lambda_i=1):
    img_orig = default_transform(image).to(device, non_blocking=True).unsqueeze(0)
    img = img_orig.clone().to(device, non_blocking=True)
    img.requires_grad = True
    optimizer = torch.optim.Adam([img], lr=1e-2)

    for iteration in range(epochs):
        print(f'iteration: {iteration}')
        x = image_utils.ssim_attenuation(img, img_orig)
        x = image_utils.psnr_clip(x, img_orig, psnr)

        ft = model(x)  # BxCxWxH -> BxD

        dot_product = ft @ carrier.T  # BxD @ Dx1 -> Bx1
        norm = torch.norm(ft, dim=-1, keepdim=True)  # Bx1
        cosines = torch.abs(dot_product / norm)
        log10_pvalue = np.log10(torch_utils.cosine_pvalue(cosines.item(), ft.shape[-1]))
        loss_R = -(rho * dot_product**2 - norm**2)  # negative inside the hypercone

        loss_l2_img = torch.norm(x - img_orig)**2  # distortion w.r.t. the original
        loss = lambda_w * loss_R + lambda_i * loss_l2_img

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        logs = {
            "keyword": "img_optim",
            "iteration": iteration,
            "loss": loss.item(),
            "loss_R": loss_R.item(),
            "loss_l2_img": loss_l2_img.item(),
            "log10_pvalue": log10_pvalue.item(),
        }
        print("__log__:%s" % json.dumps(logs))

    img = image_utils.ssim_attenuation(img, img_orig)
    img = image_utils.psnr_clip(img, img_orig, psnr)
    img = image_utils.round_pixel(img)
    # un-normalize before moving to CPU so ToPILImage gets a [0,1] CxHxW tensor
    img = image_utils.unnormalize_img(img).squeeze(0).detach().cpu()
    return transforms.ToPILImage()(img)


def decode(image):
    img = default_transform(image).to(device, non_blocking=True).unsqueeze(0)
    ft = model(img)  # BxCxWxH -> BxD

    dot_product = ft @ carrier.T  # BxD @ Dx1 -> Bx1
    norm = torch.norm(ft, dim=-1, keepdim=True)  # Bx1
    cosines = torch.abs(dot_product / norm)
    log10_pvalue = np.log10(torch_utils.cosine_pvalue(cosines.item(), ft.shape[-1]))
    loss_R = -(rho * dot_product**2 - norm**2)

    text_marked = "marked" if loss_R < 0 else "unmarked"
    return f'Image is {text_marked}, with p-value={10**log10_pvalue}'
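decode declares an image marked exactly when loss_R < 0, i.e. when rho * <ft, c>^2 > ||ft||^2; since rho = 1 + tan(angle)^2 = 1/cos(angle)^2, this is the hypercone test |cos(ft, c)| > cos(angle). A minimal smoke test of the two entry points (a sketch, not part of the commit: it assumes dino_r50.pth and out2048.pth sit in the working directory, and 'input.png' is an illustrative file name):

# Hypothetical smoke test for SSL_watermark.py.
from PIL import Image

import SSL_watermark

original = Image.open('input.png').convert('RGB')            # illustrative path
marked = SSL_watermark.encode(original, epochs=10, psnr=44)  # returns a PIL image
marked.save('input_marked.png')

print(SSL_watermark.decode(marked))    # expected: "Image is marked, with p-value=..."
print(SSL_watermark.decode(original))  # expected: "Image is unmarked, ..." (false positives ~ FPR)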
app.py (CHANGED)

@@ -1,6 +1,7 @@
 import gradio as gr
 from steganography import Steganography
 from utils import draw_multiple_line_text, generate_qr_code
+from SSL_watermark import encode, decode
 
 
 TITLE = """<h2 align="center"> ✍️ Invisible Watermark </h2>"""
@@ -8,20 +9,27 @@ TITLE = """<h2 align="center"> ✍️ Invisible Watermark </h2>"""
 
 def apply_watermark(radio_button, input_image, watermark_image, watermark_text, watermark_url):
     input_image = input_image.convert('RGB')
-
+    print(f'radio_button: {radio_button}')
     if radio_button == "Image":
         watermark_image = watermark_image.resize((input_image.width, input_image.height)).convert('L').convert('RGB')
         return Steganography().merge(input_image, watermark_image, digit=7)
     elif radio_button == "Text":
         watermark_image = draw_multiple_line_text(input_image.size, watermark_text)
         return Steganography().merge(input_image, watermark_image, digit=7)
-
-    size = min(input_image.width, input_image.height)
-    watermark_image = generate_qr_code(watermark_url).resize((size, size)).convert('RGB')
-    return Steganography().merge(input_image, watermark_image, digit=7)
+    elif radio_button == "QRCode":
+        size = min(input_image.width, input_image.height)
+        watermark_image = generate_qr_code(watermark_url).resize((size, size)).convert('RGB')
+        return Steganography().merge(input_image, watermark_image, digit=7)
+    else:
+        print('start encoding ssl watermark...')
+        return encode(input_image, epochs=5)
 
-def extract_watermark(input_image_to_extract):
-    return Steganography().unmerge(input_image_to_extract.convert('RGB'), digit=7).convert('RGBA')
+def extract_watermark(extract_radio_button, input_image_to_extract):
+    if extract_radio_button == 'Steganography':
+        return Steganography().unmerge(input_image_to_extract.convert('RGB'), digit=7).convert('RGBA')
+    else:
+        decoded_info = decode(image=input_image_to_extract)
+        return draw_multiple_line_text(input_image_size=input_image_to_extract.size, text=decoded_info)
 
 
 with gr.Blocks() as demo:
@@ -34,7 +42,7 @@
     with gr.Blocks():
         gr.Markdown("### Which type of watermark you want to apply?")
        radio_button = gr.Radio(
-            choices=["QRCode", "Text", "Image"],
+            choices=["QRCode", "Text", "Image", "SSL Watermark"],
             label="Watermark type",
             value="QRCode",
             # info="Which type of watermark you want to apply?"
@@ -82,6 +90,11 @@
         with gr.Column():
             gr.Markdown("### Image to extract watermark")
             input_image_to_extract = gr.Image(type='pil')
+            extract_radio_button = gr.Radio(
+                choices=["Steganography", "SSL Watermark"],
+                label="Extract methods",
+                value="Steganography"
+            )
         with gr.Column():
             gr.Markdown("### Extracted watermark")
             extracted_watermark = gr.Image(type='pil')
@@ -97,6 +110,10 @@
         inputs=[radio_button, input_image, watermark_image, watermark_text, watermark_url],
         outputs=[output_image]
     )
-    extract_button.click(
+    extract_button.click(
+        fn=extract_watermark,
+        inputs=[extract_radio_button, input_image_to_extract],
+        outputs=[extracted_watermark]
+    )
 
 demo.launch()
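Outside the Gradio UI, the new "SSL Watermark" extraction branch reduces to two calls; a sketch under the same assumptions as above (the Space's utils module provides draw_multiple_line_text; 'suspect.png' is an illustrative file name):

# Hypothetical stand-alone version of extract_watermark's SSL branch.
from PIL import Image

from SSL_watermark import decode
from utils import draw_multiple_line_text

img = Image.open('suspect.png').convert('RGB')   # illustrative path
info = decode(image=img)                         # "Image is marked/unmarked, with p-value=..."
panel = draw_multiple_line_text(input_image_size=img.size, text=info)
panel.save('decoded_info.png')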
dino_r50.pth (ADDED, Git LFS pointer)

version https://git-lfs.github.com/spec/v1
oid sha256:ab26d85d00cb1be8e757cf8820cf0fd8aa729ea7e21b1cf6c44875952ba8eb0f
size 788803344
image_utils.py (ADDED)
import numpy as np

import torch
import torch.nn.functional as F
from torch.autograd.variable import Variable
from torchvision import transforms

NORMALIZE_IMAGENET = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
image_mean = torch.Tensor(NORMALIZE_IMAGENET.mean).view(-1, 1, 1).to(device)
image_std = torch.Tensor(NORMALIZE_IMAGENET.std).view(-1, 1, 1).to(device)


def normalize_img(x):
    return (x.to(device) - image_mean) / image_std


def unnormalize_img(x):
    return (x.to(device) * image_std) + image_mean


def round_pixel(x):
    """ Quantize x to valid 8-bit pixel values, then re-normalize """
    x_pixel = 255 * unnormalize_img(x)
    y = torch.round(x_pixel).clamp(0, 255)
    y = normalize_img(y / 255.0)
    return y


def project_linf(x, y, radius):
    """ Clamp x-y so that Linf(x,y) <= radius (in 8-bit pixel units) """
    delta = x - y
    delta = 255 * (delta * image_std)
    delta = torch.clamp(delta, -radius, radius)
    delta = (delta / 255.0) / image_std
    return y + delta


def psnr_clip(x, y, target_psnr):
    """ Scale x-y so that PSNR(x,y) >= target_psnr """
    delta = x - y
    delta = 255 * (delta * image_std)
    psnr = 20 * np.log10(255) - 10 * torch.log10(torch.mean(delta**2))
    if psnr < target_psnr:
        delta = (torch.sqrt(10 ** ((psnr - target_psnr) / 10))) * delta
    delta = (delta / 255.0) / image_std
    return y + delta


def ssim_heatmap(img1, img2, window_size):
    """ Compute the SSIM heatmap between 2 images """
    _1D_window = torch.Tensor(
        [np.exp(-(x - window_size // 2)**2 / float(2 * 1.5**2)) for x in range(window_size)]
    ).to(device, non_blocking=True)
    _1D_window = (_1D_window / _1D_window.sum()).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(_2D_window.expand(3, 1, window_size, window_size).contiguous())

    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=3)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=3)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=3) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=3) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=3) - mu1_mu2

    C1 = 0.01**2
    C2 = 0.03**2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
    return ssim_map


def ssim_attenuation(x, y):
    """ Attenuate x-y using the SSIM heatmap """
    delta = x - y
    ssim_map = ssim_heatmap(x, y, window_size=17)  # 1xCxHxW
    ssim_map = torch.sum(ssim_map, dim=1, keepdim=True)
    ssim_map = torch.clamp_min(ssim_map, 0)
    delta = delta * ssim_map
    return y + delta
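psnr_clip relies on the identity PSNR = 20 log10(255) - 10 log10(MSE) in 8-bit pixel space: scaling the residual by sqrt(10^((psnr - target)/10)) multiplies the MSE by 10^((psnr - target)/10), landing exactly on the target. A quick self-check (a sketch; shapes and noise level are arbitrary):

# Illustrative check that psnr_clip lands on the target PSNR.
import torch

import image_utils

y = image_utils.normalize_img(torch.rand(1, 3, 64, 64))  # "original" image, normalized
x = y + 0.5 * torch.randn_like(y)                         # heavily perturbed copy (PSNR well below 40)

x_clipped = image_utils.psnr_clip(x, y, target_psnr=40)

delta = 255 * (x_clipped - y) * image_utils.image_std     # residual in 8-bit pixel units
psnr = 20 * torch.log10(torch.tensor(255.0)) - 10 * torch.log10(torch.mean(delta ** 2))
print(psnr.item())  # ~40.0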
out2048.pth (ADDED, Git LFS pointer)

version https://git-lfs.github.com/spec/v1
oid sha256:4b256188454d8f7cf440de048df398e2a3209136a52cd7cdac834f5792f526a3
size 16786561
requirements.txt (CHANGED)

@@ -1,4 +1,7 @@
+torch==1.10.1
+torchvision==0.11.2
 Pillow
 click
 gradio
 qrcode
+scipy
torch_utils.py (ADDED)
import numpy as np

import torch
import torch.nn as nn
from torchvision import models

from scipy.optimize import root_scalar
from scipy.special import betainc

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def build_backbone(path, name='resnet50'):
    """ Builds a ResNet-50 backbone and loads the pretrained (DINO) weights from `path`. """
    model = getattr(models, name)(pretrained=False)
    model.head = nn.Identity()
    model.fc = nn.Identity()
    checkpoint = torch.load(path, map_location=device)
    state_dict = checkpoint
    for ckpt_key in ['state_dict', 'model_state_dict', 'teacher']:
        if ckpt_key in checkpoint:
            state_dict = checkpoint[ckpt_key]
    state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
    msg = model.load_state_dict(state_dict, strict=False)
    return model


def get_linear_layer(weight, bias):
    """ Creates a layer that performs feature whitening or centering """
    dim_out, dim_in = weight.shape
    layer = nn.Linear(dim_in, dim_out)
    layer.weight = nn.Parameter(weight)
    layer.bias = nn.Parameter(bias)
    return layer


def load_normalization_layer(path):
    """ Loads the normalization layer from a checkpoint and returns the layer. """
    checkpoint = torch.load(path, map_location=device)
    if 'whitening' in path or 'out' in path:
        D = checkpoint['weight'].shape[1]
        weight = torch.nn.Parameter(D * checkpoint['weight'])
        bias = torch.nn.Parameter(D * checkpoint['bias'])
    else:
        weight = checkpoint['weight']
        bias = checkpoint['bias']
    return get_linear_layer(weight, bias).to(device, non_blocking=True)


class NormLayerWrapper(nn.Module):
    """ Wraps the backbone model and the normalization layer. """
    def __init__(self, backbone, head):
        super(NormLayerWrapper, self).__init__()
        backbone.eval(), head.eval()
        self.backbone = backbone
        self.head = head

    def forward(self, x):
        output = self.backbone(x)
        return self.head(output)


def cosine_pvalue(c, d, k=1):
    """
    Returns the probability that the absolute cosine between random
    unit vectors is higher than c.
    Args:
        c: cosine value
        d: dimension of the features
        k: number of dimensions of the projection
    """
    assert k > 0
    a = (d - k) / 2.0
    b = k / 2.0
    if c < 0:
        return 1.0
    return betainc(a, b, 1 - c ** 2)


def pvalue_angle(dim, k=1, angle=None, proba=None):
    """ Returns the angle whose cosine p-value equals `proba` in dimension `dim`. """
    def f(a):
        return cosine_pvalue(np.cos(a), dim, k) - proba
    a = root_scalar(f, x0=0.49 * np.pi, bracket=[0, np.pi / 2])
    # a = fsolve(f, x0=0.49*np.pi)[0]
    return a.root