add api for face building, add tests

main
Tran Xen 2 years ago
parent 4533750c49
commit 02d88bac91

@ -187,7 +187,7 @@ class FaceSwapRequest(BaseModel):
default=None, default=None,
) )
units: List[FaceSwapUnit] units: List[FaceSwapUnit]
postprocessing: Optional[PostProcessingOptions] postprocessing: Optional[PostProcessingOptions] = None
class FaceSwapResponse(BaseModel): class FaceSwapResponse(BaseModel):

@ -1,7 +1,9 @@
from typing import List
import requests import requests
from api_utils import ( from api_utils import (
FaceSwapUnit, FaceSwapUnit,
InswappperOptions, InswappperOptions,
base64_to_safetensors,
pil_to_base64, pil_to_base64,
PostProcessingOptions, PostProcessingOptions,
InpaintingWhen, InpaintingWhen,
@ -98,12 +100,30 @@ for img in response.pil_images:
img.show() img.show()
#############################
# Build checkpoint
source_images: List[str] = [
pil_to_base64("../references/man.png"),
pil_to_base64("../references/woman.png"),
]
result = requests.post(
url=f"{address}/faceswaplab/build",
json=source_images,
headers={"Content-Type": "application/json; charset=utf-8"},
)
base64_to_safetensors(result.json(), output_path="test.safetensors")
############################# #############################
# FaceSwap with local safetensors # FaceSwap with local safetensors
# First face unit : # First face unit :
unit1 = FaceSwapUnit( unit1 = FaceSwapUnit(
source_face=safetensors_to_base64("test.safetensors"), source_face=safetensors_to_base64(
"test.safetensors"
), # convert the checkpoint to base64
faces_index=(0,), # Replace first face faces_index=(0,), # Replace first face
swapping_options=InswappperOptions( swapping_options=InswappperOptions(
face_restorer_name="CodeFormer", face_restorer_name="CodeFormer",

Binary file not shown.

@ -1,3 +1,4 @@
import tempfile
from PIL import Image from PIL import Image
import numpy as np import numpy as np
from fastapi import FastAPI from fastapi import FastAPI
@ -17,6 +18,9 @@ from scripts.faceswaplab_postprocessing.postprocessing_options import (
PostProcessingOptions, PostProcessingOptions,
) )
from client_api import api_utils from client_api import api_utils
from scripts.faceswaplab_utils.face_checkpoints_utils import (
build_face_checkpoint_and_save,
)
def encode_to_base64(image: Union[str, Image.Image, np.ndarray]) -> str: # type: ignore def encode_to_base64(image: Union[str, Image.Image, np.ndarray]) -> str: # type: ignore
@ -135,3 +139,23 @@ def faceswaplab_api(_: gr.Blocks, app: FastAPI) -> None:
result_images = [encode_to_base64(img) for img in faces] result_images = [encode_to_base64(img) for img in faces]
response = api_utils.FaceSwapExtractResponse(images=result_images) response = api_utils.FaceSwapExtractResponse(images=result_images)
return response return response
@app.post(
"/faceswaplab/build",
tags=["faceswaplab"],
description="Build a face checkpoint using base64 images, return base64 satetensors",
)
async def build(base64_images: List[str]) -> Optional[str]:
if len(base64_images) > 0:
pil_images = [base64_to_pil(img) for img in base64_images]
with tempfile.NamedTemporaryFile(
delete=True, suffix=".safetensors"
) as temp_file:
build_face_checkpoint_and_save(
images=pil_images,
name="api_ckpt",
overwrite=True,
path=temp_file.name,
)
return api_utils.safetensors_to_base64(temp_file.name)
return None

@ -468,12 +468,12 @@ def get_or_default(l: List[Any], index: int, default: Any) -> Any:
return l[index] if index < len(l) else default return l[index] if index < len(l) else default
def get_faces_from_img_files(files: List[str]) -> List[Optional[CV2ImgU8]]: def get_faces_from_img_files(images: List[PILImage]) -> List[Optional[CV2ImgU8]]:
""" """
Extracts faces from a list of image files. Extracts faces from a list of image files.
Args: Args:
files (list): A list of file objects representing image files. images (list): A list of PILImage objects representing image files.
Returns: Returns:
list: A list of detected faces. list: A list of detected faces.
@ -482,9 +482,8 @@ def get_faces_from_img_files(files: List[str]) -> List[Optional[CV2ImgU8]]:
faces = [] faces = []
if len(files) > 0: if len(images) > 0:
for file in files: for img in images:
img = Image.open(file) # Open the image file
face = get_or_default( face = get_or_default(
get_faces(pil_to_cv2(img)), 0, None get_faces(pil_to_cv2(img)), 0, None
) # Extract faces from the image ) # Extract faces from the image

@ -153,9 +153,9 @@ def build_face_checkpoint_and_save(
if not batch_files: if not batch_files:
logger.error("No face found") logger.error("No face found")
return None return None
filenames = [x.name for x in batch_files] images = [Image.open(file.name) for file in batch_files]
preview_image = face_checkpoints_utils.build_face_checkpoint_and_save( preview_image = face_checkpoints_utils.build_face_checkpoint_and_save(
filenames, name, overwrite=overwrite images, name, overwrite=overwrite
) )
except Exception as e: except Exception as e:
logger.error("Failed to build checkpoint %s", e) logger.error("Failed to build checkpoint %s", e)

@ -38,7 +38,7 @@ def sanitize_name(name: str) -> str:
def build_face_checkpoint_and_save( def build_face_checkpoint_and_save(
batch_files: List[str], name: str, overwrite: bool = False images: List[PILImage], name: str, overwrite: bool = False, path: str = None
) -> PILImage: ) -> PILImage:
""" """
Builds a face checkpoint using the provided image files, performs face swapping, Builds a face checkpoint using the provided image files, performs face swapping,
@ -55,9 +55,9 @@ def build_face_checkpoint_and_save(
try: try:
name = sanitize_name(name) name = sanitize_name(name)
batch_files = batch_files or [] images = images or []
logger.info("Build %s %s", name, [x for x in batch_files]) logger.info("Build %s with %s images", name, len(images))
faces = swapper.get_faces_from_img_files(batch_files) faces = swapper.get_faces_from_img_files(images)
blended_face = swapper.blend_faces(faces) blended_face = swapper.blend_faces(faces)
preview_path = os.path.join( preview_path = os.path.join(
scripts.basedir(), "extensions", "sd-webui-faceswaplab", "references" scripts.basedir(), "extensions", "sd-webui-faceswaplab", "references"
@ -95,14 +95,17 @@ def build_face_checkpoint_and_save(
) )
preview_image = result.image preview_image = result.image
file_path = os.path.join(get_checkpoint_path(), f"{name}.safetensors") if path:
if not overwrite: file_path = path
file_number = 1 else:
while os.path.exists(file_path): file_path = os.path.join(get_checkpoint_path(), f"{name}.safetensors")
file_path = os.path.join( if not overwrite:
get_checkpoint_path(), f"{name}_{file_number}.safetensors" file_number = 1
) while os.path.exists(file_path):
file_number += 1 file_path = os.path.join(
get_checkpoint_path(), f"{name}_{file_number}.safetensors"
)
file_number += 1
save_face(filename=file_path, face=blended_face) save_face(filename=file_path, face=blended_face)
preview_image.save(file_path + ".png") preview_image.save(file_path + ".png")
try: try:

@ -2,22 +2,28 @@ from typing import List
import pytest import pytest
import requests import requests
import sys import sys
import tempfile
import safetensors
sys.path.append(".") sys.path.append(".")
import requests
from client_api.api_utils import ( from client_api.api_utils import (
FaceSwapUnit, FaceSwapUnit,
FaceSwapResponse, InswappperOptions,
PostProcessingOptions,
FaceSwapRequest,
base64_to_pil,
pil_to_base64, pil_to_base64,
PostProcessingOptions,
InpaintingWhen, InpaintingWhen,
FaceSwapCompareRequest, InpaintingOptions,
FaceSwapRequest,
FaceSwapResponse,
FaceSwapExtractRequest, FaceSwapExtractRequest,
FaceSwapCompareRequest,
FaceSwapExtractResponse, FaceSwapExtractResponse,
compare_faces, compare_faces,
InpaintingOptions, base64_to_pil,
base64_to_safetensors,
safetensors_to_base64,
) )
from PIL import Image from PIL import Image
@ -37,6 +43,13 @@ def face_swap_request() -> FaceSwapRequest:
source_img=pil_to_base64("references/woman.png"), # The face you want to use source_img=pil_to_base64("references/woman.png"), # The face you want to use
same_gender=True, same_gender=True,
faces_index=(0,), # Replace first woman since same gender is on faces_index=(0,), # Replace first woman since same gender is on
swapping_options=InswappperOptions(
face_restorer_name="CodeFormer",
upscaler_name="LDSR",
improved_mask=True,
sharpen=True,
color_corrections=True,
),
) )
# Post-processing config # Post-processing config
@ -179,3 +192,86 @@ def test_faceswap_inpainting(face_swap_request: FaceSwapRequest) -> None:
data = response.json() data = response.json()
assert "images" in data assert "images" in data
assert "infos" in data assert "infos" in data
def test_faceswap_checkpoint_building() -> None:
source_images: List[str] = [
pil_to_base64("references/man.png"),
pil_to_base64("references/woman.png"),
]
response = requests.post(
url=f"{base_url}/faceswaplab/build",
json=source_images,
headers={"Content-Type": "application/json; charset=utf-8"},
)
assert response.status_code == 200
with tempfile.NamedTemporaryFile(delete=True) as temp_file:
base64_to_safetensors(response.json(), output_path=temp_file.name)
with safetensors.safe_open(temp_file.name, framework="pt") as f:
assert "age" in f.keys()
assert "gender" in f.keys()
assert "embedding" in f.keys()
def test_faceswap_checkpoint_building_and_using() -> None:
source_images: List[str] = [
pil_to_base64("references/man.png"),
]
response = requests.post(
url=f"{base_url}/faceswaplab/build",
json=source_images,
headers={"Content-Type": "application/json; charset=utf-8"},
)
assert response.status_code == 200
with tempfile.NamedTemporaryFile(delete=True) as temp_file:
base64_to_safetensors(response.json(), output_path=temp_file.name)
with safetensors.safe_open(temp_file.name, framework="pt") as f:
assert "age" in f.keys()
assert "gender" in f.keys()
assert "embedding" in f.keys()
# First face unit :
unit1 = FaceSwapUnit(
source_face=safetensors_to_base64(
temp_file.name
), # convert the checkpoint to base64
faces_index=(0,), # Replace first face
swapping_options=InswappperOptions(
face_restorer_name="CodeFormer",
upscaler_name="LDSR",
improved_mask=True,
sharpen=True,
color_corrections=True,
),
)
# Prepare the request
request = FaceSwapRequest(
image=pil_to_base64("tests/test_image.png"), units=[unit1]
)
# Face Swap
response = requests.post(
url=f"{base_url}/faceswaplab/swap_face",
data=request.json(),
headers={"Content-Type": "application/json; charset=utf-8"},
)
assert response.status_code == 200
fsr = FaceSwapResponse.parse_obj(response.json())
data = response.json()
assert "images" in data
assert "infos" in data
# First face is the man
assert (
compare_faces(
fsr.pil_images[0], Image.open("references/man.png"), base_url=base_url
)
> 0.5
)

Loading…
Cancel
Save