fix similarity, add checksum for swapper, fix minor bugs
parent
ee7f7d09d2
commit
b773bda19f
@ -0,0 +1,236 @@
|
|||||||
|
import glob
|
||||||
|
import os
|
||||||
|
from typing import *
|
||||||
|
from insightface.app.common import Face
|
||||||
|
from safetensors.torch import save_file, safe_open
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import modules.scripts as scripts
|
||||||
|
from modules import scripts
|
||||||
|
from scripts.faceswaplab_utils.faceswaplab_logging import logger
|
||||||
|
from scripts.faceswaplab_utils.typing import *
|
||||||
|
from scripts.faceswaplab_utils import imgutils
|
||||||
|
from scripts.faceswaplab_postprocessing.postprocessing import enhance_image
|
||||||
|
from scripts.faceswaplab_postprocessing.postprocessing_options import (
|
||||||
|
PostProcessingOptions,
|
||||||
|
)
|
||||||
|
from scripts.faceswaplab_utils.models_utils import get_models
|
||||||
|
from modules.shared import opts
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
import dill as pickle # will be removed in future versions
|
||||||
|
from scripts.faceswaplab_swapping import swapper
|
||||||
|
from pprint import pformat
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_name(name: str) -> str:
|
||||||
|
"""
|
||||||
|
Sanitize the input name by removing special characters and replacing spaces with underscores.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
name (str): The input name to be sanitized.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The sanitized name with special characters removed and spaces replaced by underscores.
|
||||||
|
"""
|
||||||
|
name = re.sub("[^A-Za-z0-9_. ]+", "", name)
|
||||||
|
name = name.replace(" ", "_")
|
||||||
|
return name[:255]
|
||||||
|
|
||||||
|
|
||||||
|
def build_face_checkpoint_and_save(
|
||||||
|
batch_files: List[str], name: str, overwrite: bool = False
|
||||||
|
) -> PILImage:
|
||||||
|
"""
|
||||||
|
Builds a face checkpoint using the provided image files, performs face swapping,
|
||||||
|
and saves the result to a file. If a blended face is successfully obtained and the face swapping
|
||||||
|
process succeeds, the resulting image is returned. Otherwise, None is returned.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch_files (list): List of image file paths used to create the face checkpoint.
|
||||||
|
name (str): The name assigned to the face checkpoint.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PIL.PILImage or None: The resulting swapped face image if the process is successful; None otherwise.
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
name = sanitize_name(name)
|
||||||
|
batch_files = batch_files or []
|
||||||
|
logger.info("Build %s %s", name, [x for x in batch_files])
|
||||||
|
faces = swapper.get_faces_from_img_files(batch_files)
|
||||||
|
blended_face = swapper.blend_faces(faces)
|
||||||
|
preview_path = os.path.join(
|
||||||
|
scripts.basedir(), "extensions", "sd-webui-faceswaplab", "references"
|
||||||
|
)
|
||||||
|
|
||||||
|
reference_preview_img: PILImage = None
|
||||||
|
if blended_face:
|
||||||
|
if blended_face["gender"] == 0:
|
||||||
|
reference_preview_img = Image.open(
|
||||||
|
os.path.join(preview_path, "woman.png")
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
reference_preview_img = Image.open(
|
||||||
|
os.path.join(preview_path, "man.png")
|
||||||
|
)
|
||||||
|
|
||||||
|
if name == "":
|
||||||
|
name = "default_name"
|
||||||
|
logger.debug("Face %s", pformat(blended_face))
|
||||||
|
target_face = swapper.get_or_default(
|
||||||
|
swapper.get_faces(imgutils.pil_to_cv2(reference_preview_img)), 0, None
|
||||||
|
)
|
||||||
|
if target_face is None:
|
||||||
|
logger.error(
|
||||||
|
"Failed to open reference image, cannot create preview : That should not happen unless you deleted the references folder or change the detection threshold."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
result = swapper.swap_face(
|
||||||
|
reference_face=blended_face,
|
||||||
|
target_faces=[target_face],
|
||||||
|
source_face=blended_face,
|
||||||
|
target_img=reference_preview_img,
|
||||||
|
model=get_models()[0],
|
||||||
|
upscaled_swapper=opts.data.get(
|
||||||
|
"faceswaplab_upscaled_swapper", False
|
||||||
|
),
|
||||||
|
)
|
||||||
|
preview_image = enhance_image(
|
||||||
|
result.image,
|
||||||
|
PostProcessingOptions(
|
||||||
|
face_restorer_name="CodeFormer", restorer_visibility=1
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
file_path = os.path.join(get_checkpoint_path(), f"{name}.safetensors")
|
||||||
|
if not overwrite:
|
||||||
|
file_number = 1
|
||||||
|
while os.path.exists(file_path):
|
||||||
|
file_path = os.path.join(
|
||||||
|
get_checkpoint_path(), f"{name}_{file_number}.safetensors"
|
||||||
|
)
|
||||||
|
file_number += 1
|
||||||
|
save_face(filename=file_path, face=blended_face)
|
||||||
|
preview_image.save(file_path + ".png")
|
||||||
|
try:
|
||||||
|
data = load_face(file_path)
|
||||||
|
logger.debug(data)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error loading checkpoint, after creation %s", e)
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
return preview_image
|
||||||
|
|
||||||
|
else:
|
||||||
|
logger.error("No face found")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Failed to build checkpoint %s", e)
|
||||||
|
traceback.print_exc()
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def save_face(face: Face, filename: str) -> None:
|
||||||
|
try:
|
||||||
|
tensors = {
|
||||||
|
"embedding": torch.tensor(face["embedding"]),
|
||||||
|
"gender": torch.tensor(face["gender"]),
|
||||||
|
"age": torch.tensor(face["age"]),
|
||||||
|
}
|
||||||
|
save_file(tensors, filename)
|
||||||
|
except Exception as e:
|
||||||
|
traceback.print_exc
|
||||||
|
logger.error("Failed to save checkpoint %s", e)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
|
||||||
|
def load_face(name: str) -> Face:
|
||||||
|
filename = matching_checkpoint(name)
|
||||||
|
if filename is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if filename.endswith(".pkl"):
|
||||||
|
logger.warning(
|
||||||
|
"Pkl files for faces are deprecated to enhance safety, they will be unsupported in future versions."
|
||||||
|
)
|
||||||
|
logger.warning("The file will be converted to .safetensors")
|
||||||
|
logger.warning(
|
||||||
|
"You can also use this script https://gist.github.com/glucauze/4a3c458541f2278ad801f6625e5b9d3d"
|
||||||
|
)
|
||||||
|
with open(filename, "rb") as file:
|
||||||
|
logger.info("Load pkl")
|
||||||
|
face = Face(pickle.load(file))
|
||||||
|
logger.warning(
|
||||||
|
"Convert to safetensors, you can remove the pkl version once you have ensured that the safetensor is working"
|
||||||
|
)
|
||||||
|
save_face(face, filename.replace(".pkl", ".safetensors"))
|
||||||
|
return face
|
||||||
|
|
||||||
|
elif filename.endswith(".safetensors"):
|
||||||
|
face = {}
|
||||||
|
with safe_open(filename, framework="pt", device="cpu") as f:
|
||||||
|
for k in f.keys():
|
||||||
|
logger.debug("load key %s", k)
|
||||||
|
face[k] = f.get_tensor(k).numpy()
|
||||||
|
return Face(face)
|
||||||
|
|
||||||
|
raise NotImplementedError("Unknown file type, face extraction not implemented")
|
||||||
|
|
||||||
|
|
||||||
|
def get_checkpoint_path() -> str:
|
||||||
|
checkpoint_path = os.path.join(scripts.basedir(), "models", "faceswaplab", "faces")
|
||||||
|
os.makedirs(checkpoint_path, exist_ok=True)
|
||||||
|
return checkpoint_path
|
||||||
|
|
||||||
|
|
||||||
|
def matching_checkpoint(name: str) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Retrieve the full path of a checkpoint file matching the given name.
|
||||||
|
|
||||||
|
If the name already includes a path separator, it is returned as-is. Otherwise, the function looks for a matching
|
||||||
|
file with the extensions ".safetensors" or ".pkl" in the checkpoint directory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): The name or path of the checkpoint file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[str]: The full path of the matching checkpoint file, or None if no match is found.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# If the name already includes a path separator, return it as is
|
||||||
|
if os.path.sep in name:
|
||||||
|
return name
|
||||||
|
|
||||||
|
# If the name doesn't end with the specified extensions, look for a matching file
|
||||||
|
if not (name.endswith(".safetensors") or name.endswith(".pkl")):
|
||||||
|
# Try appending each extension and check if the file exists in the checkpoint path
|
||||||
|
for ext in [".safetensors", ".pkl"]:
|
||||||
|
full_path = os.path.join(get_checkpoint_path(), name + ext)
|
||||||
|
if os.path.exists(full_path):
|
||||||
|
return full_path
|
||||||
|
# If no matching file is found, return None
|
||||||
|
return None
|
||||||
|
|
||||||
|
# If the name already ends with the specified extensions, simply complete the path
|
||||||
|
return os.path.join(get_checkpoint_path(), name)
|
||||||
|
|
||||||
|
|
||||||
|
def get_face_checkpoints() -> List[str]:
|
||||||
|
"""
|
||||||
|
Retrieve a list of face checkpoint paths.
|
||||||
|
|
||||||
|
This function searches for face files with the extension ".safetensors" in the specified directory and returns a list
|
||||||
|
containing the paths of those files.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: A list of face paths, including the string "None" as the first element.
|
||||||
|
"""
|
||||||
|
faces_path = os.path.join(get_checkpoint_path(), "*.safetensors")
|
||||||
|
faces = glob.glob(faces_path)
|
||||||
|
|
||||||
|
faces_path = os.path.join(get_checkpoint_path(), "*.pkl")
|
||||||
|
faces += glob.glob(faces_path)
|
||||||
|
|
||||||
|
return ["None"] + [os.path.basename(face) for face in sorted(faces)]
|
@ -1,72 +0,0 @@
|
|||||||
import glob
|
|
||||||
import os
|
|
||||||
from typing import List
|
|
||||||
from insightface.app.common import Face
|
|
||||||
from safetensors.torch import save_file, safe_open
|
|
||||||
import torch
|
|
||||||
|
|
||||||
import modules.scripts as scripts
|
|
||||||
from modules import scripts
|
|
||||||
from scripts.faceswaplab_utils.faceswaplab_logging import logger
|
|
||||||
import dill as pickle # will be removed in future versions
|
|
||||||
|
|
||||||
|
|
||||||
def save_face(face: Face, filename: str) -> None:
|
|
||||||
tensors = {
|
|
||||||
"embedding": torch.tensor(face["embedding"]),
|
|
||||||
"gender": torch.tensor(face["gender"]),
|
|
||||||
"age": torch.tensor(face["age"]),
|
|
||||||
}
|
|
||||||
save_file(tensors, filename)
|
|
||||||
|
|
||||||
|
|
||||||
def load_face(filename: str) -> Face:
|
|
||||||
if filename.endswith(".pkl"):
|
|
||||||
logger.warning(
|
|
||||||
"Pkl files for faces are deprecated to enhance safety, they will be unsupported in future versions."
|
|
||||||
)
|
|
||||||
logger.warning("The file will be converted to .safetensors")
|
|
||||||
logger.warning(
|
|
||||||
"You can also use this script https://gist.github.com/glucauze/4a3c458541f2278ad801f6625e5b9d3d"
|
|
||||||
)
|
|
||||||
with open(filename, "rb") as file:
|
|
||||||
logger.info("Load pkl")
|
|
||||||
face = Face(pickle.load(file))
|
|
||||||
logger.warning(
|
|
||||||
"Convert to safetensors, you can remove the pkl version once you have ensured that the safetensor is working"
|
|
||||||
)
|
|
||||||
save_face(face, filename.replace(".pkl", ".safetensors"))
|
|
||||||
return face
|
|
||||||
|
|
||||||
elif filename.endswith(".safetensors"):
|
|
||||||
face = {}
|
|
||||||
with safe_open(filename, framework="pt", device="cpu") as f:
|
|
||||||
for k in f.keys():
|
|
||||||
logger.debug("load key %s", k)
|
|
||||||
face[k] = f.get_tensor(k).numpy()
|
|
||||||
return Face(face)
|
|
||||||
|
|
||||||
raise NotImplementedError("Unknown file type, face extraction not implemented")
|
|
||||||
|
|
||||||
|
|
||||||
def get_face_checkpoints() -> List[str]:
|
|
||||||
"""
|
|
||||||
Retrieve a list of face checkpoint paths.
|
|
||||||
|
|
||||||
This function searches for face files with the extension ".safetensors" in the specified directory and returns a list
|
|
||||||
containing the paths of those files.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list: A list of face paths, including the string "None" as the first element.
|
|
||||||
"""
|
|
||||||
faces_path = os.path.join(
|
|
||||||
scripts.basedir(), "models", "faceswaplab", "faces", "*.safetensors"
|
|
||||||
)
|
|
||||||
faces = glob.glob(faces_path)
|
|
||||||
|
|
||||||
faces_path = os.path.join(
|
|
||||||
scripts.basedir(), "models", "faceswaplab", "faces", "*.pkl"
|
|
||||||
)
|
|
||||||
faces += glob.glob(faces_path)
|
|
||||||
|
|
||||||
return ["None"] + sorted(faces)
|
|
Loading…
Reference in New Issue