You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
73 lines
2.4 KiB
Python
73 lines
2.4 KiB
Python
import glob
|
|
import os
|
|
from typing import List
|
|
from insightface.app.common import Face
|
|
from safetensors.torch import save_file, safe_open
|
|
import torch
|
|
|
|
import modules.scripts as scripts
|
|
from modules import scripts
|
|
from scripts.faceswaplab_utils.faceswaplab_logging import logger
|
|
import dill as pickle # will be removed in future versions
|
|
|
|
|
|
def save_face(face: Face, filename: str) -> None:
    """
    Write selected attributes of a face to a safetensors file.

    The "embedding", "gender" and "age" entries of ``face`` are each turned
    into a torch tensor (in that order) and the resulting mapping is written
    to ``filename`` with ``safetensors.torch.save_file``.

    Args:
        face: Face providing the "embedding", "gender" and "age" entries.
        filename: Destination path of the safetensors file.
    """
    stored_keys = ("embedding", "gender", "age")
    save_file({key: torch.tensor(face[key]) for key in stored_keys}, filename)
|
|
|
|
|
|
def load_face(filename: str) -> Face:
    """
    Load a face from a checkpoint file.

    The format is chosen from the file extension:

    * ``.pkl`` (deprecated): the file is unpickled and wrapped in a ``Face``.
      A ``.safetensors`` copy is then written next to it with ``save_face``,
      so later loads can use the safe format.
    * ``.safetensors``: every tensor in the file is read on the CPU,
      converted with ``.numpy()`` and stored under its original key in a
      ``Face``.

    Args:
        filename: Path to the face checkpoint.

    Returns:
        Face: The loaded face.

    Raises:
        NotImplementedError: If ``filename`` ends with neither ``.pkl`` nor
            ``.safetensors``.
    """
    if filename.endswith(".pkl"):
        logger.warning(
            "Pkl files for faces are deprecated to enhance safety, they will be unsupported in future versions."
        )
        logger.warning("The file will be converted to .safetensors")
        logger.warning(
            "You can also use this script https://gist.github.com/glucauze/4a3c458541f2278ad801f6625e5b9d3d"
        )
        # SECURITY: unpickling (dill) can execute arbitrary code embedded in
        # the file. Only load .pkl faces that come from a trusted source.
        with open(filename, "rb") as file:
            logger.info("Load pkl")
            face = Face(pickle.load(file))
        logger.warning(
            "Convert to safetensors, you can remove the pkl version once you have ensured that the safetensor is working"
        )
        # Replace only the trailing extension. str.replace would also rewrite
        # any ".pkl" occurring earlier in the path (e.g. in a directory name).
        save_face(face, filename[: -len(".pkl")] + ".safetensors")
        return face

    if filename.endswith(".safetensors"):
        face = {}
        with safe_open(filename, framework="pt", device="cpu") as f:
            for k in f.keys():
                logger.debug("load key %s", k)
                face[k] = f.get_tensor(k).numpy()
        return Face(face)

    raise NotImplementedError("Unknown file type, face extraction not implemented")
|
|
|
|
|
|
def get_face_checkpoints() -> List[str]:
    """
    Retrieve a list of face checkpoint paths.

    Searches ``<basedir>/models/faceswaplab/faces`` for files ending in
    ``.safetensors`` or ``.pkl`` (the deprecated pickle format) and returns
    their paths.

    Returns:
        list: The string "None" as the first element, followed by all found
        face paths in sorted order.
    """
    faces_dir = os.path.join(scripts.basedir(), "models", "faceswaplab", "faces")
    # Escape the directory so glob metacharacters in the install path
    # (e.g. "[", "*", "?") are matched literally rather than as patterns.
    pattern_root = glob.escape(faces_dir)

    faces: List[str] = []
    for extension_pattern in ("*.safetensors", "*.pkl"):
        faces += glob.glob(os.path.join(pattern_root, extension_pattern))

    return ["None"] + sorted(faces)
|