You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
44 lines
1.4 KiB
Python
44 lines
1.4 KiB
Python
from dataclasses import fields, is_dataclass
|
|
from typing import *
|
|
|
|
|
|
def dataclass_from_flat_list(cls: type, values: Tuple[Any, ...]) -> Any:
    """Build an instance of dataclass ``cls`` from a flat sequence of values.

    Fields of ``cls`` are filled in declaration order. A plain field consumes
    one value. A field whose declared type is itself a dataclass consumes
    ``len(fields(field.type))`` consecutive values, which are passed
    positionally to that type's constructor.

    The expansion is one level deep only: a dataclass nested inside a nested
    dataclass receives a single raw value, not its own expansion. Values left
    over after every field has been filled are ignored.

    Args:
        cls: The dataclass *type* to instantiate (not an instance).
        values: Flat sequence of field values, in declaration order.

    Returns:
        A new instance of ``cls``.

    Raises:
        TypeError: If ``cls`` is not a dataclass type.
        IndexError: If ``values`` is too short to fill every field, including
            the fields of nested dataclasses.
    """
    if not is_dataclass(cls):
        raise TypeError(f"{cls} is not a dataclass")
    # is_dataclass() is also true for dataclass *instances*. An instance cannot
    # be called to build a new object, so fail early with a clear message
    # instead of a confusing "object is not callable" at the end.
    if not isinstance(cls, type):
        raise TypeError(f"{cls!r} is a dataclass instance; expected a dataclass type")

    idx = 0
    init_values = {}
    for field in fields(cls):
        # NOTE(review): field.type is the raw annotation. Under
        # `from __future__ import annotations` it is a string, so a nested
        # dataclass would not be detected here. Confirm callers don't rely
        # on postponed annotations.
        nested = is_dataclass(field.type)
        width = len(fields(field.type)) if nested else 1

        # A single bounds check covers both plain and nested fields, so a
        # short `values` always produces the same descriptive error.
        if idx + width > len(values):
            raise IndexError(
                f"Expected more values for dataclass {cls}. "
                f"Current index: {idx}, values length: {len(values)}"
            )

        if nested:
            init_values[field.name] = field.type(*values[idx : idx + width])
        else:
            init_values[field.name] = values[idx]
        idx += width

    return cls(**init_values)
|
|
|
|
|
|
def dataclasses_from_flat_list(
    classes_mapping: List[type], values: Tuple[Any, ...]
) -> List[Any]:
    """Build one instance per dataclass in ``classes_mapping`` from a flat tuple.

    ``values`` is consumed left to right. Each class takes one value per plain
    field, plus ``len(fields(T))`` values for a field whose type ``T`` is a
    dataclass. That slice is handed to ``dataclass_from_flat_list``. Values
    left over after the last class are ignored.

    Args:
        classes_mapping: Dataclass types, in the order their values appear in
            ``values``.
        values: The concatenated flat field values for all classes.

    Returns:
        A list of instances, one per entry of ``classes_mapping``, in order.
        An empty ``classes_mapping`` yields an empty list.

    Raises:
        TypeError: If an entry of ``classes_mapping`` is not a dataclass
            (raised by ``dataclasses.fields``).
        IndexError: Propagated from ``dataclass_from_flat_list`` when
            ``values`` runs out.
    """
    instances = []
    idx = 0
    for cls in classes_mapping:
        # Must mirror the one-level nested expansion that
        # dataclass_from_flat_list performs, or later slices would misalign.
        num_fields = sum(
            len(fields(field.type)) if is_dataclass(field.type) else 1
            for field in fields(cls)
        )
        instances.append(dataclass_from_flat_list(cls, values[idx : idx + num_fields]))
        idx += num_fields

    # Internal sanity check. The previous `assert [...]` asserted a list
    # literal: it was truthy whenever non-empty (so it checked nothing) and
    # falsy for an empty classes_mapping (so it always failed). all()
    # evaluates the actual per-instance predicates.
    assert all(
        isinstance(i, t) for i, t in zip(instances, classes_mapping)
    ), "Instances should match types"
    return instances
|