Source code for academic_metrics.dataclass_models.abstract_base_dataclass

from abc import ABC
from dataclasses import asdict, dataclass, fields
from typing import Any, Dict, List, Set


[docs] @dataclass class AbstractBaseDataClass(ABC): """ Abstract base class for all data model classes providing common functionality. Methods: to_dict: Converts the dataclass to a dictionary, handling Set conversion for JSON serialization. set_params: Sets the parameters from a dictionary, handling type conversions. """
[docs] def to_dict(self, exclude_keys: List[str] | None = None) -> dict: """Convert the dataclass to a dictionary, handling Set conversion for JSON serialization. Returns: dict: A dictionary representation of the dataclass. """ data_dict = asdict(self) # Remove excluded keys if any if exclude_keys: for key in exclude_keys: data_dict.pop(key, None) def convert_sets(obj): if isinstance(obj, dict): return {k: convert_sets(v) for k, v in obj.items()} elif isinstance(obj, Set): return list(obj) return obj # Convert sets to lists, including those in nested dicts return convert_sets(data_dict)
[docs] def set_params(self, params: Dict[str, Any], debug: bool = False) -> None: """ Updates the dataclass fields, merging sets and handling nested updates. It handles: 1. Converting lists to sets for fields annotated as Set 2. Merging sets instead of overwriting 3. Ignoring keys that don't match attributes 4. Handling nested dataclass updates Args: params (Dict[str, Any]): A dictionary of parameters to update the dataclass fields. Examples: >>> class MyClass(AbstractBaseDataClass): ... items: Set[str] = field(default_factory=set) >>> obj = MyClass() >>> obj.set_params({"items": ["a", "b"]}) >>> obj.set_params({"items": ["c", "d"]}) >>> sorted(list(obj.items)) # Contains all items ['a', 'b', 'c', 'd'] """ # Get fields from the concrete class, not the base class if debug: print( f"AbstractBaseDataClass.set_params called on {self.__class__.__name__}" ) input() print(f"Fields: {fields(self.__class__)}") input() field_types = {field.name: field.type for field in fields(self.__class__)} if debug: print(f"Field types: {field_types}") input() for key, value in params.items(): if debug: print(f"Processing {key} = {value}") input() if hasattr(self, key) and value is not None: current_value = getattr(self, key) if debug: print(f"Current value of {key}: {current_value}") input() # Handle Set fields if key in field_types and field_types[key] == Set[str]: if debug: print(f"{key} is a Set[str] field") input() # Convert input to set if needed if isinstance(value, (List, Set)): new_value = set(value) else: new_value = {str(value)} # Merge with existing set if isinstance(current_value, Set): current_value.update(new_value) else: setattr(self, key, new_value) # Handle other fields normally else: if debug: print(f"{key} is not a Set[str] field") input() setattr(self, key, value)