improve test, add extract to api

main
Tran Xen 2 years ago
parent b6add28267
commit be505f4086

@ -8,6 +8,7 @@ import base64, io
from io import BytesIO from io import BytesIO
from typing import List, Tuple, Optional from typing import List, Tuple, Optional
import numpy as np import numpy as np
import requests
class InpaintingWhen(Enum): class InpaintingWhen(Enum):
@ -151,7 +152,7 @@ class FaceSwapRequest(BaseModel):
class FaceSwapResponse(BaseModel): class FaceSwapResponse(BaseModel):
images: List[str] = Field(description="base64 swapped image", default=None) images: List[str] = Field(description="base64 swapped image", default=None)
infos: List[str] infos: Optional[List[str]] # not really used atm
@property @property
def pil_images(self) -> Image.Image: def pil_images(self) -> Image.Image:
@ -171,6 +172,23 @@ class FaceSwapCompareRequest(BaseModel):
) )
class FaceSwapExtractRequest(BaseModel):
images: List[str] = Field(
description="base64 reference image",
examples=["...."],
default=None,
)
postprocessing: Optional[PostProcessingOptions]
class FaceSwapExtractResponse(BaseModel):
images: List[str] = Field(description="base64 face images", default=None)
@property
def pil_images(self) -> Image.Image:
return [base64_to_pil(img) for img in self.images]
def pil_to_base64(img: Image.Image) -> np.array: # type:ignore def pil_to_base64(img: Image.Image) -> np.array: # type:ignore
if isinstance(img, str): if isinstance(img, str):
img = Image.open(img) img = Image.open(img)
@ -192,3 +210,20 @@ def base64_to_pil(base64str: Optional[str]) -> Optional[Image.Image]:
# if no data URL scheme, just decode # if no data URL scheme, just decode
img_bytes = base64.b64decode(base64str) img_bytes = base64.b64decode(base64str)
return Image.open(io.BytesIO(img_bytes)) return Image.open(io.BytesIO(img_bytes))
def compare_faces(
image1: Image.Image, image2: Image.Image, base_url: str = "http://localhost:7860"
) -> float:
request = FaceSwapCompareRequest(
image1=pil_to_base64(image1),
image2=pil_to_base64(image2),
)
result = requests.post(
url=f"{base_url}/faceswaplab/compare",
data=request.json(),
headers={"Content-Type": "application/json; charset=utf-8"},
)
return float(result.text)

@ -7,10 +7,16 @@ from api_utils import (
pil_to_base64, pil_to_base64,
InpaintingWhen, InpaintingWhen,
FaceSwapCompareRequest, FaceSwapCompareRequest,
FaceSwapExtractRequest,
FaceSwapExtractResponse,
) )
address = "http://127.0.0.1:7860" address = "http://127.0.0.1:7860"
#############################
# FaceSwap
# First face unit : # First face unit :
unit1 = FaceSwapUnit( unit1 = FaceSwapUnit(
source_img=pil_to_base64("../references/man.png"), # The face you want to use source_img=pil_to_base64("../references/man.png"), # The face you want to use
@ -41,7 +47,7 @@ request = FaceSwapRequest(
image=pil_to_base64("test_image.png"), units=[unit1, unit2], postprocessing=pp image=pil_to_base64("test_image.png"), units=[unit1, unit2], postprocessing=pp
) )
# Face Swap
result = requests.post( result = requests.post(
url=f"{address}/faceswaplab/swap_face", url=f"{address}/faceswaplab/swap_face",
data=request.json(), data=request.json(),
@ -52,6 +58,8 @@ response = FaceSwapResponse.parse_obj(result.json())
for img in response.pil_images: for img in response.pil_images:
img.show() img.show()
#############################
# Comparison
request = FaceSwapCompareRequest( request = FaceSwapCompareRequest(
image1=pil_to_base64("../references/man.png"), image1=pil_to_base64("../references/man.png"),
@ -65,3 +73,21 @@ result = requests.post(
) )
print("similarity", result.text) print("similarity", result.text)
#############################
# Extraction
# Prepare the request
request = FaceSwapExtractRequest(
images=[pil_to_base64(response.pil_images[0])], postprocessing=pp
)
result = requests.post(
url=f"{address}/faceswaplab/extract",
data=request.json(),
headers={"Content-Type": "application/json; charset=utf-8"},
)
response = FaceSwapExtractResponse.parse_obj(result.json())
for img in response.pil_images:
img.show()

@ -110,7 +110,7 @@ class FaceSwapScript(scripts.Script):
components = [] components = []
for i in range(1, self.units_count + 1): for i in range(1, self.units_count + 1):
components += faceswaplab_unit_ui.faceswap_unit_ui(is_img2img, i) components += faceswaplab_unit_ui.faceswap_unit_ui(is_img2img, i)
upscaler = faceswaplab_tab.upscaler_ui() upscaler = faceswaplab_tab.postprocessing_ui()
# If the order is modified, the before_process should be changed accordingly. # If the order is modified, the before_process should be changed accordingly.
return components + upscaler return components + upscaler

@ -161,3 +161,22 @@ def faceswaplab_api(_: gr.Blocks, app: FastAPI) -> None:
return swapper.compare_faces( return swapper.compare_faces(
base64_to_pil(request.image1), base64_to_pil(request.image2) base64_to_pil(request.image1), base64_to_pil(request.image2)
) )
@app.post(
"/faceswaplab/extract",
tags=["faceswaplab"],
description="Extract faces of each images",
)
async def extract(
request: api_utils.FaceSwapExtractRequest,
) -> api_utils.FaceSwapExtractResponse:
pp_options = None
if request.postprocessing:
pp_options = get_postprocessing_options(request.postprocessing)
images = [base64_to_pil(img) for img in request.images]
faces = swapper.extract_faces(
images, extract_path=None, postprocess_options=pp_options
)
result_images = [encode_to_base64(img) for img in faces]
response = api_utils.FaceSwapExtractResponse(images=result_images)
return response

@ -41,13 +41,15 @@ def on_ui_settings() -> None:
"faceswaplab_detection_threshold", "faceswaplab_detection_threshold",
shared.OptionInfo( shared.OptionInfo(
0.5, 0.5,
"Detection threshold ", "Face Detection threshold",
gr.Slider, gr.Slider,
{"minimum": 0.1, "maximum": 0.99, "step": 0.001}, {"minimum": 0.1, "maximum": 0.99, "step": 0.001},
section=section, section=section,
), ),
) )
# DEFAULT UI SETTINGS
shared.opts.add_option( shared.opts.add_option(
"faceswaplab_pp_default_face_restorer", "faceswaplab_pp_default_face_restorer",
shared.OptionInfo( shared.OptionInfo(
@ -105,6 +107,30 @@ def on_ui_settings() -> None:
), ),
) )
shared.opts.add_option(
"faceswaplab_pp_default_inpainting_prompt",
shared.OptionInfo(
"Portrait of a [gender]",
"UI Default inpainting prompt [gender] is replaced by man or woman (requires restart)",
gr.Textbox,
{},
section=section,
),
)
shared.opts.add_option(
"faceswaplab_pp_default_inpainting_negative_prompt",
shared.OptionInfo(
"blurry",
"UI Default inpainting negative prompt [gender] (requires restart)",
gr.Textbox,
{},
section=section,
),
)
# UPSCALED SWAPPER
shared.opts.add_option( shared.opts.add_option(
"faceswaplab_upscaled_swapper", "faceswaplab_upscaled_swapper",
shared.OptionInfo( shared.OptionInfo(

@ -148,6 +148,71 @@ def batch_process(
return None return None
def extract_faces(
images: List[Image.Image],
extract_path: Optional[str],
postprocess_options: PostProcessingOptions,
) -> Optional[List[str]]:
"""
Extracts faces from a list of image files.
Given a list of image file paths, this function opens each image, extracts the faces,
and saves them in a specified directory. Post-processing is applied to each extracted face,
and the processed faces are saved as separate PNG files.
Parameters:
files (Optional[List[Image]]): List of file paths to the images to extract faces from.
extract_path (Optional[str]): Path where the extracted faces will be saved.
If no path is provided, a temporary directory will be created.
postprocess_options (PostProcessingOptions): Post-processing settings to be applied to the images.
Returns:
Optional[List[img]]: List of face images
"""
try:
if extract_path:
os.makedirs(extract_path, exist_ok=True)
if images:
result_images = []
for img in images:
faces = get_faces(pil_to_cv2(img))
if faces:
face_images = []
for face in faces:
bbox = face.bbox.astype(int)
x_min, y_min, x_max, y_max = bbox
face_image = img.crop((x_min, y_min, x_max, y_max))
if postprocess_options and (
postprocess_options.face_restorer_name
or postprocess_options.restorer_visibility
):
postprocess_options.scale = (
1 if face_image.width > 512 else 512 // face_image.width
)
face_image = enhance_image(face_image, postprocess_options)
if extract_path:
path = tempfile.NamedTemporaryFile(
delete=False, suffix=".png", dir=extract_path
).name
face_image.save(path)
face_images.append(face_image)
result_images += face_images
return result_images
except Exception as e:
logger.info("Failed to extract : %s", e)
import traceback
traceback.print_exc()
return None
class FaceModelException(Exception): class FaceModelException(Exception):
"""Exception raised when an error is encountered in the face model.""" """Exception raised when an error is encountered in the face model."""

@ -6,7 +6,7 @@ from modules.shared import opts
from scripts.faceswaplab_postprocessing.postprocessing_options import InpaintingWhen from scripts.faceswaplab_postprocessing.postprocessing_options import InpaintingWhen
def upscaler_ui() -> List[gr.components.Component]: def postprocessing_ui() -> List[gr.components.Component]:
with gr.Tab(f"Post-Processing"): with gr.Tab(f"Post-Processing"):
gr.Markdown( gr.Markdown(
"""Upscaling is performed on the whole image. Upscaling happens before face restoration.""" """Upscaling is performed on the whole image. Upscaling happens before face restoration."""
@ -87,12 +87,16 @@ def upscaler_ui() -> List[gr.components.Component]:
) )
inpainting_denoising_prompt = gr.Textbox( inpainting_denoising_prompt = gr.Textbox(
"Portrait of a [gender]", opts.data.get(
"faceswaplab_pp_default_inpainting_prompt", "Portrait of a [gender]"
),
elem_id="faceswaplab_pp_inpainting_denoising_prompt", elem_id="faceswaplab_pp_inpainting_denoising_prompt",
label="Inpainting prompt use [gender] instead of men or woman", label="Inpainting prompt use [gender] instead of men or woman",
) )
inpainting_denoising_negative_prompt = gr.Textbox( inpainting_denoising_negative_prompt = gr.Textbox(
"", opts.data.get(
"faceswaplab_pp_default_inpainting_negative_prompt", "blurry"
),
elem_id="faceswaplab_pp_inpainting_denoising_neg_prompt", elem_id="faceswaplab_pp_inpainting_denoising_neg_prompt",
label="Inpainting negative prompt use [gender] instead of men or woman", label="Inpainting negative prompt use [gender] instead of men or woman",
) )

@ -1,5 +1,4 @@
import os import os
import tempfile
from pprint import pformat, pprint from pprint import pformat, pprint
import dill as pickle import dill as pickle
@ -8,14 +7,13 @@ import modules.scripts as scripts
import onnx import onnx
import pandas as pd import pandas as pd
from scripts.faceswaplab_ui.faceswaplab_unit_ui import faceswap_unit_ui from scripts.faceswaplab_ui.faceswaplab_unit_ui import faceswap_unit_ui
from scripts.faceswaplab_ui.faceswaplab_upscaler_ui import upscaler_ui from scripts.faceswaplab_ui.faceswaplab_postprocessing_ui import postprocessing_ui
from insightface.app.common import Face from insightface.app.common import Face
from modules import scripts from modules import scripts
from PIL import Image from PIL import Image
from modules.shared import opts from modules.shared import opts
from scripts.faceswaplab_utils import imgutils from scripts.faceswaplab_utils import imgutils
from scripts.faceswaplab_utils.imgutils import pil_to_cv2
from scripts.faceswaplab_utils.models_utils import get_models from scripts.faceswaplab_utils.models_utils import get_models
from scripts.faceswaplab_utils.faceswaplab_logging import logger from scripts.faceswaplab_utils.faceswaplab_logging import logger
import scripts.faceswaplab_swapping.swapper as swapper import scripts.faceswaplab_swapping.swapper as swapper
@ -54,7 +52,7 @@ def extract_faces(
files: List[gr.File], files: List[gr.File],
extract_path: Optional[str], extract_path: Optional[str],
*components: List[gr.components.Component], *components: List[gr.components.Component],
) -> Optional[List[str]]: ) -> Optional[List[Image.Image]]:
""" """
Extracts faces from a list of image files. Extracts faces from a list of image files.
@ -73,49 +71,13 @@ def extract_faces(
If no faces are found, None is returned. If no faces are found, None is returned.
""" """
try:
postprocess_options = PostProcessingOptions(*components) # type: ignore postprocess_options = PostProcessingOptions(*components) # type: ignore
images = [
if not extract_path: Image.open(file.name) for file in files
extract_path = tempfile.mkdtemp() ] # potentially greedy but Image.open is supposed to be lazy
return swapper.extract_faces(
if files: images, extract_path=extract_path, postprocess_options=postprocess_options
images = []
for file in files:
img = Image.open(file.name)
faces = swapper.get_faces(pil_to_cv2(img))
if faces:
face_images = []
for face in faces:
bbox = face.bbox.astype(int)
x_min, y_min, x_max, y_max = bbox
face_image = img.crop((x_min, y_min, x_max, y_max))
if (
postprocess_options.face_restorer_name
or postprocess_options.restorer_visibility
):
postprocess_options.scale = (
1 if face_image.width > 512 else 512 // face_image.width
) )
face_image = enhance_image(face_image, postprocess_options)
path = tempfile.NamedTemporaryFile(
delete=False, suffix=".png", dir=extract_path
).name
face_image.save(path)
face_images.append(path)
images += face_images
return images
except Exception as e:
logger.info("Failed to extract : %s", e)
import traceback
traceback.print_exc()
return None
def analyse_faces(image: Image.Image, det_threshold: float = 0.5) -> Optional[str]: def analyse_faces(image: Image.Image, det_threshold: float = 0.5) -> Optional[str]:
@ -459,7 +421,7 @@ def tools_ui() -> None:
for i in range(1, opts.data.get("faceswaplab_units_count", 3) + 1): for i in range(1, opts.data.get("faceswaplab_units_count", 3) + 1):
unit_components += faceswap_unit_ui(False, i, id_prefix="faceswaplab_tab") unit_components += faceswap_unit_ui(False, i, id_prefix="faceswaplab_tab")
upscale_options = upscaler_ui() upscale_options = postprocessing_ui()
explore_btn.click( explore_btn.click(
explore_onnx_faceswap_model, inputs=[model], outputs=[explore_result_text] explore_onnx_faceswap_model, inputs=[model], outputs=[explore_result_text]

@ -108,7 +108,7 @@ def faceswap_unit_ui(
elem_id=f"{id_prefix}_face{unit_num}_sort_by_size", elem_id=f"{id_prefix}_face{unit_num}_sort_by_size",
) )
target_faces_index = gr.Textbox( target_faces_index = gr.Textbox(
value="0", value=f"{unit_num-1}",
placeholder="Which face to swap (comma separated), start from 0 (by gender if same_gender is enabled)", placeholder="Which face to swap (comma separated), start from 0 (by gender if same_gender is enabled)",
label="Target face : Comma separated face number(s)", label="Target face : Comma separated face number(s)",
elem_id=f"{id_prefix}_face{unit_num}_target_faces_index", elem_id=f"{id_prefix}_face{unit_num}_target_faces_index",

@ -14,6 +14,9 @@ from client_api.api_utils import (
pil_to_base64, pil_to_base64,
InpaintingWhen, InpaintingWhen,
FaceSwapCompareRequest, FaceSwapCompareRequest,
FaceSwapExtractRequest,
FaceSwapExtractResponse,
compare_faces,
) )
from PIL import Image from PIL import Image
@ -79,6 +82,38 @@ def test_compare() -> None:
assert similarity > 0.90 assert similarity > 0.90
def test_extract() -> None:
pp = PostProcessingOptions(
face_restorer_name="CodeFormer",
codeformer_weight=0.5,
restorer_visibility=1,
upscaler_name="Lanczos",
)
request = FaceSwapExtractRequest(
images=[pil_to_base64("tests/test_image.png")], postprocessing=pp
)
response = requests.post(
url=f"{base_url}/faceswaplab/extract",
data=request.json(),
headers={"Content-Type": "application/json; charset=utf-8"},
)
assert response.status_code == 200
res = FaceSwapExtractResponse.parse_obj(response.json())
assert len(res.pil_images) == 2
# First face is the man
assert (
compare_faces(
res.pil_images[0], Image.open("tests/test_image.png"), base_url=base_url
)
> 0.5
)
def test_faceswap(face_swap_request: FaceSwapRequest) -> None: def test_faceswap(face_swap_request: FaceSwapRequest) -> None:
response = requests.post( response = requests.post(
f"{base_url}/faceswaplab/swap_face", f"{base_url}/faceswaplab/swap_face",

Loading…
Cancel
Save