You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

250 lines
9.0 KiB
Python

import glob
import os
from typing import *
from insightface.app.common import Face
from safetensors.torch import save_file, safe_open
import torch
import modules.scripts as scripts
from modules import scripts
from scripts.faceswaplab_swapping.upcaled_inswapper_options import InswappperOptions
from scripts.faceswaplab_utils.faceswaplab_logging import logger
from scripts.faceswaplab_utils.typing import *
from scripts.faceswaplab_utils import imgutils
from scripts.faceswaplab_utils.models_utils import get_swap_models
import traceback
from scripts.faceswaplab_swapping import swapper
from pprint import pformat
import re
from client_api import api_utils
import tempfile
def sanitize_name(name: str) -> str:
    """
    Turn a user-supplied face name into something safe to use as a file stem.

    Every run of characters other than ASCII letters, digits, ``_``, ``.`` and
    space is dropped. Each remaining space becomes an underscore, and the result
    is capped at 255 characters.

    Parameters:
        name (str): The raw name to clean.

    Returns:
        str: The cleaned name (may be empty if nothing survived the filter).
    """
    allowed_only = re.sub("[^A-Za-z0-9_. ]+", "", name)
    # split/join on a single space keeps empty segments, so "a  b" -> "a__b"
    underscored = "_".join(allowed_only.split(" "))
    return underscored[:255]
def _load_reference_image(preview_path: str, gender: int) -> PILImage:
    """Load the bundled reference portrait for *gender* fully into memory and close its file."""
    filename = "woman.png" if gender == 0 else "man.png"
    # Image.open is lazy and keeps the file handle open; copy() forces the
    # pixel data into memory so the handle can be released by the `with`.
    with Image.open(os.path.join(preview_path, filename)) as img:
        return img.copy()


def _unique_checkpoint_path(file_path: str) -> str:
    """Return *file_path*, or the first non-existing ``<stem>_<n><ext>`` sibling (n = 1, 2, ...)."""
    suffix = ".safetensors"
    if file_path.endswith(suffix):
        # Split explicitly so stems made only of dots are still handled like
        # f"{name}_{n}.safetensors" (os.path.splitext ignores leading dots).
        root, ext = file_path[: -len(suffix)], suffix
    else:
        root, ext = os.path.splitext(file_path)
    candidate = file_path
    file_number = 1
    while os.path.exists(candidate):
        candidate = f"{root}_{file_number}{ext}"
        file_number += 1
    return candidate


def build_face_checkpoint_and_save(
    images: List[PILImage],
    name: str,
    overwrite: bool = False,
    path: Optional[str] = None,
) -> Optional[PILImage]:
    """
    Build a face checkpoint from images, save it, and return a preview image.

    Faces detected in *images* are blended into a single face. That face is
    swapped onto a bundled reference portrait (``woman.png`` when the blended
    gender is 0, ``man.png`` otherwise) to produce a preview. The face is saved
    as a safetensors checkpoint, and the preview is written next to it as
    ``<checkpoint>.png``. The checkpoint is then re-read as a sanity check;
    a failure there is only logged.

    Args:
        images (List[PILImage]): Source images. ``None`` is treated as an empty list.
        name (str): Checkpoint name. It is passed through ``sanitize_name``;
            an empty result is replaced by ``"default_name"``.
        overwrite (bool): When False and the target file already exists, the
            first free ``<stem>_<n><ext>`` in the same directory is used instead.
        path (Optional[str]): Explicit output path. Defaults to
            ``<checkpoint dir>/<name>.safetensors``.

    Returns:
        PILImage or None: The preview image on success. None if no face was
        found, the reference portrait yielded no face, or any exception
        occurred (exceptions are logged, not raised).
    """
    try:
        name = sanitize_name(name)
        images = images or []
        logger.info("Build %s with %s images", name, len(images))
        faces: List[Face] = swapper.get_faces_from_img_files(images=images)
        if faces is None or len(faces) == 0:
            logger.error("No source faces found")
            return None

        blended_face: Optional[Face] = swapper.blend_faces(faces)
        if not blended_face:
            logger.error("No face found")
            return None

        preview_path = os.path.join(
            scripts.basedir(), "extensions", "sd-webui-faceswaplab", "references"
        )
        reference_preview_img: PILImage = _load_reference_image(
            preview_path, blended_face["gender"]
        )

        if name == "":
            name = "default_name"
        logger.debug("Face %s", pformat(blended_face))

        target_face = swapper.get_or_default(
            swapper.get_faces(imgutils.pil_to_cv2(reference_preview_img)), 0, None
        )
        if target_face is None:
            logger.error(
                "Failed to open reference image, cannot create preview : That should not happen unless you deleted the references folder or change the detection threshold."
            )
            return None

        result: swapper.ImageResult = swapper.swap_face(
            target_faces=[target_face],
            source_face=blended_face,
            target_img=reference_preview_img,
            model=get_swap_models()[0],
            swapping_options=InswappperOptions(
                face_restorer_name="CodeFormer",
                restorer_visibility=1,
                upscaler_name="Lanczos",
                codeformer_weight=1,
                improved_mask=True,
                color_corrections=False,
                sharpen=True,
            ),
        )
        preview_image = result.image

        file_path = path or os.path.join(get_checkpoint_path(), f"{name}.safetensors")
        if not overwrite:
            # Numbered alternatives stay next to the requested target, so an
            # explicit `path` is never silently redirected elsewhere.
            file_path = _unique_checkpoint_path(file_path)

        save_face(filename=file_path, face=blended_face)
        preview_image.save(file_path + ".png")

        # Sanity check: re-read what was just written. A failure is logged
        # but does not invalidate the preview being returned.
        try:
            data = load_face(file_path)
            logger.debug(data)
        except Exception as e:
            logger.error("Error loading checkpoint, after creation %s", e)
            traceback.print_exc()

        return preview_image
    except Exception as e:
        logger.error("Failed to build checkpoint %s", e)
        traceback.print_exc()
        return None
def save_face(face: Face, filename: str) -> None:
    """
    Write the identity fields of a face to a safetensors file.

    Only the ``embedding``, ``gender`` and ``age`` entries are stored. Each is
    converted with ``torch.tensor`` and written with ``save_file``.

    Args:
        face (Face): Face providing ``embedding``, ``gender`` and ``age``.
        filename (str): Destination path of the safetensors file.

    Raises:
        Exception: Any error while building the tensors or writing the file is
            logged (with its traceback) and re-raised unchanged.
    """
    try:
        tensors = {
            "embedding": torch.tensor(face["embedding"]),
            "gender": torch.tensor(face["gender"]),
            "age": torch.tensor(face["age"]),
        }
        save_file(tensors, filename)
    except Exception as e:
        traceback.print_exc()  # must be called, not just referenced
        logger.error("Failed to save checkpoint %s", e)
        raise  # bare raise preserves the original traceback
def _read_face_safetensors(filename: str) -> Face:
    """Read every tensor in a safetensors file (on CPU) into a ``Face`` of NumPy arrays."""
    face = {}
    with safe_open(filename, framework="pt", device="cpu") as f:
        for k in f.keys():
            logger.debug("load key %s", k)
            face[k] = f.get_tensor(k).numpy()
    return Face(face)


def load_face(name: str) -> Optional[Face]:
    """
    Load a face checkpoint.

    Args:
        name (str): Either an inline checkpoint given as a
            ``data:application/face;base64,`` URI, or a checkpoint name/path
            resolved through ``matching_checkpoint``.

    Returns:
        Optional[Face]: The loaded face. None if no matching checkpoint was
        found or the checkpoint is a deprecated ``.pkl`` file.

    Raises:
        NotImplementedError: If the resolved file is neither ``.pkl`` nor
            ``.safetensors``.
    """
    if name.startswith("data:application/face;base64,"):
        # Decode into a closed temp file that is reopened by name. Reopening a
        # still-open NamedTemporaryFile by name does not work on Windows.
        fd, temp_path = tempfile.mkstemp(suffix=".safetensors")
        os.close(fd)
        try:
            api_utils.base64_to_safetensors(name, temp_path)
            return _read_face_safetensors(temp_path)
        finally:
            try:
                os.remove(temp_path)
            except OSError as e:
                # Best effort: a leftover temp file must not break loading.
                logger.warning("Could not remove temporary face file %s: %s", temp_path, e)

    filename = matching_checkpoint(name)
    if filename is None:
        return None

    if filename.endswith(".pkl"):
        # Pickle files are refused for safety (unpickling can execute code).
        # NOTE(review): nothing in this function performs the conversion the
        # message below announces; it only returns None — confirm intent.
        logger.warning(
            "Pkl files for faces are deprecated to enhance safety, you need to convert them"
        )
        logger.warning("The file will be converted to .safetensors")
        logger.warning(
            "You can also use this script https://gist.github.com/glucauze/4a3c458541f2278ad801f6625e5b9d3d"
        )
        return None

    if filename.endswith(".safetensors"):
        return _read_face_safetensors(filename)

    raise NotImplementedError("Unknown file type, face extraction not implemented")
def get_checkpoint_path() -> str:
    """
    Return the directory that holds face checkpoints, creating it if missing.

    The directory is ``<scripts.basedir()>/models/faceswaplab/faces``.

    Returns:
        str: Path of the (now existing) faces directory.
    """
    base_dir = scripts.basedir()
    faces_dir = os.path.join(base_dir, "models", "faceswaplab", "faces")
    os.makedirs(faces_dir, exist_ok=True)
    return faces_dir
def matching_checkpoint(name: str) -> Optional[str]:
    """
    Resolve a checkpoint name to a file path.

    Rules, applied in order:

    1. A name containing ``os.path.sep`` is treated as a path and returned
       unchanged (existence is not checked).
    2. A name already ending in ``.safetensors`` or ``.pkl`` is joined onto the
       checkpoint directory (existence is not checked).
    3. Otherwise ``<name>.safetensors`` and then ``<name>.pkl`` are tried in the
       checkpoint directory, and the first one that exists is returned.

    Args:
        name (str): Checkpoint name or path.

    Returns:
        Optional[str]: The resolved path, or None when rule 3 finds no file.
    """
    if os.path.sep in name:
        return name

    known_extensions = (".safetensors", ".pkl")
    if name.endswith(known_extensions):
        return os.path.join(get_checkpoint_path(), name)

    # Lazily build candidates so lookup stops at the first existing file.
    candidates = (
        os.path.join(get_checkpoint_path(), name + ext) for ext in known_extensions
    )
    return next((candidate for candidate in candidates if os.path.exists(candidate)), None)
def get_face_checkpoints() -> List[str]:
    """
    List the face checkpoint files available in the checkpoint directory.

    Both ``*.safetensors`` files and legacy ``*.pkl`` files are listed.
    ``load_face`` refuses ``.pkl`` files with a deprecation warning.

    Returns:
        List[str]: ``"None"`` followed by the sorted file names (basenames,
        not full paths) of the matching files.
    """
    checkpoint_dir = get_checkpoint_path()  # resolve (and create) once
    faces: List[str] = []
    for pattern in ("*.safetensors", "*.pkl"):
        faces += glob.glob(os.path.join(checkpoint_dir, pattern))
    return ["None"] + [os.path.basename(face) for face in sorted(faces)]