You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
176 lines
6.2 KiB
Python
176 lines
6.2 KiB
Python
from scripts.faceswaplab_swapping import swapper
|
|
import numpy as np
|
|
import base64
|
|
import io
|
|
from dataclasses import dataclass, fields
|
|
from typing import List, Union
|
|
import dill as pickle
|
|
import gradio as gr
|
|
from insightface.app.common import Face
|
|
from PIL import Image
|
|
from scripts.faceswaplab_utils.imgutils import pil_to_cv2, check_against_nsfw
|
|
from scripts.faceswaplab_utils.faceswaplab_logging import logger
|
|
|
|
|
|
@dataclass
class FaceSwapUnitSettings:
    """Settings of one face-swap unit, built from the UI component values.

    The fields are filled positionally (see ``get_unit_configuration``), so
    their declaration order must match the order of the values produced by
    ``faceswap_unit_ui``.

    The ``reference_face``, ``faces`` and ``blended_faces`` properties are
    computed on first access and then stored on the instance in private
    attributes (``_reference_face``, ``_faces``, ``_blended_faces``).
    """

    # ORDER of parameters is IMPORTANT. It should match the result of faceswap_unit_ui

    # The image given in reference (a PIL image or a base64 encoded string)
    source_img: Union[Image.Image, str]
    # The checkpoint file
    source_face: str
    # The batch source images
    _batch_files: Union[gr.components.File, List[Image.Image]]
    # Will blend faces if True
    blend_faces: bool
    # Enable this unit
    enable: bool
    # Use same gender filtering
    same_gender: bool
    # Sort faces by their size (from larger to smaller)
    sort_by_size: bool
    # If True, discard images with low similarity
    check_similarity: bool
    # if True will compute similarity and add it to the image info
    _compute_similarity: bool

    # Minimum similarity against the used face (reference, batch or checkpoint)
    min_sim: float
    # Minimum similarity against the reference (reference or checkpoint if checkpoint is given)
    min_ref_sim: float
    # The face index to use for swapping (comma separated list of integers)
    _faces_index: str
    # The face index to get image from source
    reference_face_index: int

    # Swap in the source image in img2img (before processing)
    swap_in_source: bool
    # Swap in the generated image in img2img (always on for txt2img)
    swap_in_generated: bool

    @staticmethod
    def get_unit_configuration(unit: int, components):
        """Build the settings of unit number ``unit`` from a flat component list.

        ``components`` holds the values of all units one after the other, each
        unit using as many consecutive values as this dataclass has fields.
        """
        fields_count = len(fields(FaceSwapUnitSettings))
        start = unit * fields_count
        return FaceSwapUnitSettings(*components[start : start + fields_count])

    @property
    def faces_index(self) -> set:
        """
        Convert _faces_index from str to a set of int.

        Entries are comma separated; surrounding whitespace is ignored and
        anything that is not a plain decimal number is skipped. Returns ``{0}``
        when no valid index is found.
        """
        # ``_faces_index`` may be None/empty if the UI field was left blank.
        tokens = (token.strip() for token in (self._faces_index or "").split(","))
        # isdecimal() only accepts characters int() can parse, whereas
        # isnumeric() also accepts e.g. superscripts that make int() raise.
        faces_index = {int(token) for token in tokens if token.isdecimal()}
        if not faces_index:
            return {0}

        logger.debug("FACES INDEX : %s", faces_index)

        return faces_index

    @property
    def compute_similarity(self) -> bool:
        """
        True if similarity must be computed, either because it was requested
        explicitly or because it is needed by the similarity check.
        """
        return self._compute_similarity or self.check_similarity

    @property
    def batch_files(self):
        """
        Return empty array instead of None for batch files
        """
        return self._batch_files or []

    @property
    def reference_face(self):
        """
        Extract reference face (only once and store it for the rest of processing).
        Reference face is the checkpoint or the source image or the first image in the batch in that order.

        Returns None (and logs an error) when no reference face is available,
        including when the checkpoint cannot be unpickled. Errors raised while
        opening the checkpoint file or decoding the source image propagate to
        the caller and nothing is cached in that case.
        """
        if not hasattr(self, "_reference_face"):
            # Work on a local and store it only at the end, so that every
            # non-raising path (including a failed checkpoint load) leaves
            # the cache attribute defined.
            reference = None

            if self.source_face and self.source_face != "None":
                with open(self.source_face, "rb") as file:
                    try:
                        logger.info("loading pickle %s", file.name)
                        # SECURITY: unpickling can execute arbitrary code.
                        # Only load checkpoint files from trusted sources.
                        reference = Face(pickle.load(file))
                    except Exception as e:
                        logger.error("Failed to load checkpoint : %s", e)
            elif self.source_img is not None:
                if isinstance(self.source_img, str):  # source_img is a base64 string
                    # Drop the data URL scheme if there is one; without it the
                    # split returns the whole string unchanged.
                    base64_data = self.source_img.split("base64,")[-1]
                    img_bytes = base64.b64decode(base64_data)
                    self.source_img = Image.open(io.BytesIO(img_bytes))

                source_img = pil_to_cv2(self.source_img)
                reference = swapper.get_or_default(
                    swapper.get_faces(source_img), self.reference_face_index, None
                )
                if reference is None:
                    logger.error("Face not found in reference image")

            self._reference_face = reference

        if self._reference_face is None:
            logger.error("You need at least one reference face")

        return self._reference_face

    @property
    def faces(self) -> list:
        """
        Extract all faces (including reference face) to provide an array of faces.
        The reference face, when there is one, comes first, followed by the
        first face found in each batch file (files without a face are skipped).
        Only processed once.
        """
        if not hasattr(self, "_faces"):
            reference = self.reference_face
            # Build the list locally so that an error while reading a batch
            # file does not leave a partially filled list cached.
            found = [reference] if reference is not None else []
            for file in self.batch_files:
                # Batch entries are either PIL images or file objects with a path.
                img = file if isinstance(file, Image.Image) else Image.open(file.name)
                face = swapper.get_or_default(
                    swapper.get_faces(pil_to_cv2(img)), 0, None
                )
                if face is not None:
                    found.append(face)
            self._faces = found
        return self._faces

    @property
    def blended_faces(self):
        """
        Blend the faces using the mean of all embeddings.
        Only processed once; the result is cached after the sanity checks pass.
        """
        if not hasattr(self, "_blended_faces"):
            faces = self.faces
            blended = swapper.blend_faces(faces)

            # Sanity checks: they only make sense when several faces were blended.
            if len(faces) > 1:
                assert all(
                    not np.array_equal(blended.embedding, face.embedding)
                    for face in faces
                ), "Blended faces cannot be the same as one of the face if len(face)>0"

                # There may be no reference face (batch images only).
                reference = self.reference_face
                assert reference is None or not np.array_equal(
                    blended.embedding, reference.embedding
                ), "Blended faces cannot be the same as reference face if len(face)>0"

            self._blended_faces = blended

        return self._blended_faces
|