pkl to safetensors

main
Tran Xen 2 years ago
parent be02fdcd7d
commit 31635d369f

@ -1,3 +1,13 @@
# 1.1.2 :
+ BREAKING CHANGE : enforce face checkpoint format from pkl to safetensors
Using pkl files to store faces is dangerous from a security point of view. For the same reason that models are now stored in safetensors, We are switching to safetensors for the storage format.
A script with instructions for converting existing pkl files can be found here:
https://gist.github.com/glucauze/4a3c458541f2278ad801f6625e5b9d3d
## 1.1.1 : ## 1.1.1 :
+ Add settings for default inpainting prompts + Add settings for default inpainting prompts

@ -28,7 +28,7 @@ class FaceSwapUnit(BaseModel):
# The checkpoint file # The checkpoint file
source_face: str = Field( source_face: str = Field(
description="face checkpoint (from models/faceswaplab/faces)", description="face checkpoint (from models/faceswaplab/faces)",
examples=["my_face.pkl"], examples=["my_face.safetensors"],
default=None, default=None,
) )
# base64 batch source images # base64 batch source images

@ -112,7 +112,7 @@ A face checkpoint is a saved embedding of a face, generated from multiple images
The primary advantage of face checkpoints is their size. An embedding is only around 2KB, meaning it's lightweight and can be reused later without requiring additional calculations. The primary advantage of face checkpoints is their size. An embedding is only around 2KB, meaning it's lightweight and can be reused later without requiring additional calculations.
Face checkpoints are saved as `.pkl` files. Please be aware that exchanging `.pkl` files carries potential security risks. These files, by default, are not secure and could potentially execute malicious code when opened. Therefore, extreme caution should be exercised when sharing or receiving this type of file. Face checkpoints are saved as `.safetensors` files. Please be aware that exchanging `.safetensors` files carries potential security risks. These files, by default, are not secure and could potentially execute malicious code when opened. Therefore, extreme caution should be exercised when sharing or receiving this type of file.
#### How is similarity determined? #### How is similarity determined?

@ -8,3 +8,14 @@ def preload(parser: ArgumentParser) -> None:
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
help="Set the log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)", help="Set the log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)",
) )
print("FACESWAPLAB================================================================")
print("BREAKING CHANGE: enforce face checkpoint format from pkl to safetensors\n")
print("Using pkl files to store faces is dangerous from a security point of view.")
print("For the same reason that models are now stored in safetensors,")
print("We are switching to safetensors for the storage format.")
print(
"A script with instructions for converting existing pkl files can be found here:"
)
print("https://gist.github.com/glucauze/4a3c458541f2278ad801f6625e5b9d3d")
print("==========================================================================")

@ -1,5 +1,4 @@
cython cython
dill==0.3.6
ifnude ifnude
insightface==0.7.3 insightface==0.7.3
onnx==1.14.0 onnx==1.14.0

@ -8,7 +8,7 @@ REFERENCE_PATH = os.path.join(
scripts.basedir(), "extensions", "sd-webui-faceswaplab", "references" scripts.basedir(), "extensions", "sd-webui-faceswaplab", "references"
) )
VERSION_FLAG: str = "v1.1.1" VERSION_FLAG: str = "v1.1.2"
EXTENSION_PATH = os.path.join("extensions", "sd-webui-faceswaplab") EXTENSION_PATH = os.path.join("extensions", "sd-webui-faceswaplab")
# The NSFW score threshold. If any part of the image has a score greater than this threshold, the image will be considered NSFW. # The NSFW score threshold. If any part of the image has a score greater than this threshold, the image will be considered NSFW.

@ -1,14 +1,12 @@
import os import os
from pprint import pformat, pprint from pprint import pformat, pprint
from scripts.faceswaplab_utils import face_utils
import dill as pickle
import gradio as gr import gradio as gr
import modules.scripts as scripts import modules.scripts as scripts
import onnx import onnx
import pandas as pd import pandas as pd
from scripts.faceswaplab_ui.faceswaplab_unit_ui import faceswap_unit_ui from scripts.faceswaplab_ui.faceswaplab_unit_ui import faceswap_unit_ui
from scripts.faceswaplab_ui.faceswaplab_postprocessing_ui import postprocessing_ui from scripts.faceswaplab_ui.faceswaplab_postprocessing_ui import postprocessing_ui
from insightface.app.common import Face
from modules import scripts from modules import scripts
from PIL import Image from PIL import Image
from modules.shared import opts from modules.shared import opts
@ -128,10 +126,17 @@ def analyse_faces(image: Image.Image, det_threshold: float = 0.5) -> Optional[st
def sanitize_name(name: str) -> str: def sanitize_name(name: str) -> str:
logger.debug(f"Sanitize name {name}") """
Sanitize the input name by removing special characters and replacing spaces with underscores.
Parameters:
name (str): The input name to be sanitized.
Returns:
str: The sanitized name with special characters removed and spaces replaced by underscores.
"""
name = re.sub("[^A-Za-z0-9_. ]+", "", name) name = re.sub("[^A-Za-z0-9_. ]+", "", name)
name = name.replace(" ", "_") name = name.replace(" ", "_")
logger.debug(f"Sanitized name {name[:255]}")
return name[:255] return name[:255]
@ -185,25 +190,19 @@ def build_face_checkpoint_and_save(
), ),
) )
file_path = os.path.join(faces_path, f"{name}.pkl") file_path = os.path.join(faces_path, f"{name}.safetensors")
file_number = 1 file_number = 1
while os.path.exists(file_path): while os.path.exists(file_path):
file_path = os.path.join(faces_path, f"{name}_{file_number}.pkl") file_path = os.path.join(
faces_path, f"{name}_{file_number}.safetensors"
)
file_number += 1 file_number += 1
result_image.save(file_path + ".png") result_image.save(file_path + ".png")
with open(file_path, "wb") as file:
pickle.dump( face_utils.save_face(filename=file_path, face=blended_face)
{
"embedding": blended_face.embedding,
"gender": blended_face.gender,
"age": blended_face.age,
},
file,
)
try: try:
with open(file_path, "rb") as file: data = face_utils.load_face(filename=file_path)
data = Face(pickle.load(file)) print(data)
print(data)
except Exception as e: except Exception as e:
print(e) print(e)
return result_image return result_image

@ -4,12 +4,12 @@ import base64
import io import io
from dataclasses import dataclass, fields from dataclasses import dataclass, fields
from typing import Any, List, Optional, Set, Union from typing import Any, List, Optional, Set, Union
import dill as pickle
import gradio as gr import gradio as gr
from insightface.app.common import Face from insightface.app.common import Face
from PIL import Image from PIL import Image
from scripts.faceswaplab_utils.imgutils import pil_to_cv2 from scripts.faceswaplab_utils.imgutils import pil_to_cv2
from scripts.faceswaplab_utils.faceswaplab_logging import logger from scripts.faceswaplab_utils.faceswaplab_logging import logger
from scripts.faceswaplab_utils import face_utils
@dataclass @dataclass
@ -94,8 +94,8 @@ class FaceSwapUnitSettings:
if self.source_face and self.source_face != "None": if self.source_face and self.source_face != "None":
with open(self.source_face, "rb") as file: with open(self.source_face, "rb") as file:
try: try:
logger.info(f"loading pickle {file.name}") logger.info(f"loading face {file.name}")
face = Face(pickle.load(file)) face = face_utils.load_face(file.name)
self._reference_face = face self._reference_face = face
except Exception as e: except Exception as e:
logger.error("Failed to load checkpoint : %s", e) logger.error("Failed to load checkpoint : %s", e)

@ -1,5 +1,5 @@
from typing import List from typing import List
from scripts.faceswaplab_utils.models_utils import get_face_checkpoints from scripts.faceswaplab_utils.face_utils import get_face_checkpoints
import gradio as gr import gradio as gr

@ -0,0 +1,48 @@
import glob
import os
from typing import List
from insightface.app.common import Face
from safetensors.torch import save_file, safe_open
import torch
import modules.scripts as scripts
from modules import scripts
from scripts.faceswaplab_utils.faceswaplab_logging import logger
def save_face(face: Face, filename: str) -> None:
tensors = {
"embedding": torch.tensor(face["embedding"]),
"gender": torch.tensor(face["gender"]),
"age": torch.tensor(face["age"]),
}
save_file(tensors, filename)
def load_face(filename: str) -> Face:
face = {}
logger.debug("Try to load face from %s", filename)
with safe_open(filename, framework="pt", device="cpu") as f:
logger.debug("File contains %s keys", f.keys())
for k in f.keys():
logger.debug("load key %s", k)
face[k] = f.get_tensor(k).numpy()
logger.debug("face : %s", face)
return Face(face)
def get_face_checkpoints() -> List[str]:
"""
Retrieve a list of face checkpoint paths.
This function searches for face files with the extension ".safetensors" in the specified directory and returns a list
containing the paths of those files.
Returns:
list: A list of face paths, including the string "None" as the first element.
"""
faces_path = os.path.join(
scripts.basedir(), "models", "faceswaplab", "faces", "*.safetensors"
)
faces = glob.glob(faces_path)
return ["None"] + faces

@ -43,20 +43,3 @@ def get_current_model() -> str:
"No faceswap model found. Please add it to the faceswaplab directory." "No faceswap model found. Please add it to the faceswaplab directory."
) )
return model return model
def get_face_checkpoints() -> List[str]:
"""
Retrieve a list of face checkpoint paths.
This function searches for face files with the extension ".pkl" in the specified directory and returns a list
containing the paths of those files.
Returns:
list: A list of face paths, including the string "None" as the first element.
"""
faces_path = os.path.join(
scripts.basedir(), "models", "faceswaplab", "faces", "*.pkl"
)
faces = glob.glob(faces_path)
return ["None"] + faces

Loading…
Cancel
Save