import copy
import functools
import json
import typing as t
from pydantic import BaseModel, Field, field_validator, ConfigDict, computed_field, model_validator
import numpy as np
import warnings
from collections import defaultdict
from aiida import orm
from aiida.common.constants import elements
from aiida.orm.nodes.data import Data
from aiida_atomistic.data.structure.site import Site, FrozenList, freeze_nested, FrozenSite
from aiida_atomistic.data.structure.kind import Kind
from aiida_quantumespresso.common.hubbard import Hubbard
from aiida_atomistic.data.structure import (
_atomic_masses,
_DEFAULT_CELL,
_DEFAULT_PBC,
_DEFAULT_VALUES,
)
class StructureBaseModel(BaseModel):
    """
    A base model representing a structure in atomistic simulations.

    Attributes:
        pbc (List[bool]): Periodic boundary conditions in the x, y, and z directions.
        cell (np.ndarray): The 3x3 array of cell vectors defining the unit cell, in Angstrom.
        sites (List[Site]): The sites of the structure.
        tot_magnetization (Optional[float]): Total magnetization, if set.
        tot_charge (Optional[float]): Total charge, if set.
        hubbard (Optional[Hubbard]): Hubbard parameters (an empty set by default).
        custom (Optional[dict]): Free-form custom data, if set.
    """

    # `cell` is stored as a numpy array, which pydantic only accepts as an arbitrary type.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Class-level switch read by the validators below: when False, the validated
    # `cell` is made read-only and `sites` are passed through `freeze_nested`.
    _mutable: t.ClassVar[bool] = True

    pbc: list[bool] = Field(
        default=_DEFAULT_PBC,
        description="Periodic boundary conditions",
        min_length=3,
        max_length=3,
    )

    cell: np.ndarray = Field(
        default=_DEFAULT_CELL,
        description="Lattice vectors",
        json_schema_extra={"units": "Angstrom"},
    )

    sites: list[Site] = Field(
        default=[],
        description="List of sites in the structure",
    )

    # Global (whole-structure) properties.
    tot_magnetization: t.Optional[float] = Field(default=None)

    tot_charge: t.Optional[float] = Field(default=None)

    # An (empty) Hubbard instance rather than None, to have access to its methods.
    hubbard: t.Optional[Hubbard] = Field(default=Hubbard(parameters=[]))

    custom: t.Optional[dict] = Field(default=None)

    @field_validator('cell', mode='before')
    @classmethod
    def validate_cell_shape(cls, v):
        """Coerce ``cell`` into a 3x3 numpy array owned by this model.

        A copy is always taken (``np.array`` rather than ``np.asarray``) so that the
        caller's array, or the module-level default cell, is never aliased by the
        model and never has its ``writeable`` flag changed as a side effect.

        Raises:
            ValueError: If the value is not a 3x3 array (reported by pydantic as a
                validation error).
        """
        v = np.array(v)  # always a fresh copy, see docstring
        if v.shape != (3, 3):
            raise ValueError("The cell must be a 3x3 array.")
        # Read-only for immutable models, writeable for mutable ones.
        v.flags.writeable = cls._mutable
        return v

    @model_validator(mode='before')
    @classmethod
    def check_minimal_requirements(cls, data):
        """
        Normalize the raw input before field validation.

        Args:
            data: The raw input passed to pydantic. Non-dict input is returned
                unchanged and left to pydantic's normal validation.

        Returns:
            dict: A new dict in which ``pbc`` and ``cell`` are always present
            (filled with the field defaults when missing). When no sites are given,
            only ``pbc``, ``cell`` and an empty ``sites`` list are returned.
        """
        if not isinstance(data, dict):
            return data

        if not data.get("sites", None):
            # No sites: keep only pbc and cell.
            # NOTE(review): any other keys in `data` (tot_charge, custom, ...) are
            # dropped on this path — confirm this is intended.
            return {
                "pbc": data.get("pbc", cls.model_fields["pbc"].default),
                "cell": data.get("cell", cls.model_fields["cell"].default),
                "sites": [],
            }

        # Work on a copy so the caller's dict is not mutated.
        data = dict(data)
        # Explicitly set pbc and cell when missing, so they always count as defined properties.
        for global_property in ("pbc", "cell"):
            data.setdefault(global_property, cls.model_fields[global_property].default)
        return data

    @field_validator('sites', mode='before')
    @classmethod
    def validate_sites(cls, v):
        """Check that every entry can be converted to a ``Site`` and that the sites are valid.

        The raw input is returned unchanged so that pydantic then validates it against
        the field's declared site type (e.g. ``FrozenSite`` in ``ImmutableStructureModel``).
        """
        from aiida_atomistic.data.structure.utils import _check_valid_sites

        if v is None:
            return v
        # Conversion is used only as a check: it raises if an entry is not a valid Site.
        for site in v:
            if not isinstance(site, Site):
                Site.model_validate(site)
        _check_valid_sites(v)
        return v

    @field_validator('sites', mode='after')
    @classmethod
    def freeze_sites(cls, v):
        """Pass the validated sites through ``freeze_nested`` when the model is immutable."""
        if not cls._mutable and v is not None:
            return freeze_nested(v)
        return v

    # ------------------------------------------------------------------ computed properties

    @computed_field
    @property
    def cell_volume(self) -> float:
        """
        Compute the volume of the unit cell.

        Returns:
            float: The volume of the unit cell in cubic Angstroms.
        """
        from aiida_atomistic.data.structure.utils import calc_cell_volume

        return calc_cell_volume(self.cell)

    @computed_field
    @property
    def dimensionality(self) -> dict:
        """
        Determine the dimensionality of the structure.

        Returns:
            dict: A dictionary indicating the dimensionality of the structure.
        """
        from aiida_atomistic.data.structure.utils import get_dimensionality

        return get_dimensionality(self.pbc, self.cell)

    @computed_field
    @property
    def is_alloy(self) -> bool:
        """True if any site of the structure is an alloy site."""
        return any(site.is_alloy for site in self.sites)

    @computed_field
    @property
    def has_vacancies(self) -> bool:
        """True if any site of the structure has vacancies."""
        return any(site.has_vacancies for site in self.sites)

    # Per-site properties are exposed explicitly as computed fields below.
    # TODO: these could possibly be generated (e.g. via a metaclass).

    @computed_field
    @property
    def positions(self) -> t.Optional[np.ndarray]:
        """
        Return the positions of all sites in the structure as a numpy array.

        Returns:
            np.ndarray | None: An array of shape (N, 3) where N is the number of sites,
            or None if no site has a position.
        """
        if all(site.position is None for site in self.sites):
            return None
        return np.array([site.position for site in self.sites])

    @computed_field
    @property
    def kind_names(self) -> t.Optional[t.List[str]]:
        """
        Return the list of kind names for all sites in the structure.

        Sites without a kind name fall back to their chemical symbol.

        Returns:
            List[str] | None: Kind names per site, or None if no site has a kind name.
        """
        if all(site.kind_name is None for site in self.sites):
            return None
        return FrozenList([site.kind_name if site.kind_name is not None else site.symbol for site in self.sites])

    @computed_field
    @property
    def symbols(self) -> t.Optional[t.List[str]]:
        """
        Return the list of chemical symbols for all sites in the structure.

        Returns:
            List[str] | None: Chemical symbols per site, or None if no site has a symbol.
        """
        if all(site.symbol is None for site in self.sites):
            return None
        return FrozenList([site.symbol for site in self.sites])

    @computed_field
    @property
    def masses(self) -> t.Optional[np.ndarray]:
        """
        Return the masses of all sites in the structure as a numpy array.

        Returns:
            np.ndarray | None: Masses per site, or None if no site has a mass.
        """
        if all(site.mass is None for site in self.sites):
            return None
        return np.array([site.mass for site in self.sites])

    @computed_field
    @property
    def charges(self) -> t.Optional[np.ndarray]:
        """
        Return the charges of all sites in the structure as a numpy array.

        Returns:
            np.ndarray | None: Charges per site (missing ones replaced by the default
            charge), or None if no site has a charge.
        """
        if all(site.charge is None for site in self.sites):
            return None
        # `is not None` (not truthiness) so an explicit charge of 0 is kept as given,
        # consistent with magmoms/magnetizations/weights.
        return np.array([site.charge if site.charge is not None else _DEFAULT_VALUES['charge'] for site in self.sites])

    @computed_field
    @property
    def magmoms(self) -> t.Optional[np.ndarray]:
        """
        Return the magnetic moments of all sites in the structure as a numpy array.

        Returns:
            np.ndarray | None: Magnetic moments per site (missing ones replaced by the
            default), or None if no site has a magnetic moment.
        """
        if all(site.magmom is None for site in self.sites):
            return None
        return np.array([site.magmom if site.magmom is not None else _DEFAULT_VALUES['magmom'] for site in self.sites])

    @computed_field
    @property
    def magnetizations(self) -> t.Optional[np.ndarray]:
        """
        Return the magnetizations of all sites in the structure as a numpy array.

        Returns:
            np.ndarray | None: Magnetizations per site (missing ones replaced by the
            default), or None if no site has a magnetization.
        """
        if all(site.magnetization is None for site in self.sites):
            return None
        return np.array(
            [site.magnetization if site.magnetization is not None else _DEFAULT_VALUES['magnetization'] for site in self.sites]
        )

    @computed_field
    @property
    def weights(self) -> t.Optional[t.List[t.Tuple[float, ...]]]:
        """
        Return the weights of all sites in the structure as a list of tuples.

        Returns:
            List[Tuple[float, ...]] | None: Weights per site (missing ones replaced by
            the default), or None if no site has weights.
        """
        if all(site.weight is None for site in self.sites):
            return None
        return FrozenList([site.weight if site.weight is not None else _DEFAULT_VALUES['weight'] for site in self.sites])

    @computed_field
    @property
    def kinds(self) -> t.Optional[FrozenList[Kind]]:
        """
        Return the reduced set of kinds, grouping sites by kind name.

        Each kind takes all properties except the position from the first site with
        that kind name, plus the indices and positions of all its sites. No kinds
        validation is done here; use the dedicated ``validate_kinds`` method for that.

        Returns:
            FrozenList[Kind] | None: Kinds in order of first appearance, or None if
            kind names are not defined.
        """
        kind_names = self.kind_names
        if not kind_names:
            return None

        # Computed once: `positions` is a property that rebuilds the whole array on each access.
        all_positions = self.positions

        # kind name -> site indices, in order of first appearance (dicts keep insertion order).
        indices_by_kind = defaultdict(list)
        for index, name in enumerate(kind_names):
            indices_by_kind[name].append(index)

        kinds_list = []
        for site_indices in indices_by_kind.values():
            first_site = self.sites[site_indices[0]]
            kinds_list.append(
                Kind(
                    **first_site.model_dump(exclude={'position'}),
                    site_indices=site_indices,
                    positions=np.array([all_positions[i] for i in site_indices]),
                )
            )
        return FrozenList(kinds_list)
class MutableStructureModel(StructureBaseModel):
    """
    A mutable structure model that extends the StructureBaseModel class.

    Attributes:
        _mutable (bool): Flag indicating whether the structure is mutable; True here.
        sites (List[Site]): List of (mutable) sites in the structure.
    """

    # Stated explicitly although it is also the base-class value: the base validators
    # read this to keep `cell` writeable and to skip freezing the sites.
    _mutable: t.ClassVar[bool] = True
class ImmutableStructureModel(StructureBaseModel):
    """
    A class representing an immutable structure model.

    This class inherits from `StructureBaseModel` and provides additional functionality for handling immutable structures.

    Immutability is enforced in two ways:
      * ``_mutable`` is False, so the inherited validators make ``cell`` read-only
        and pass ``sites`` through ``freeze_nested``;
      * ``__setattr__`` refuses assignment to any model field.

    Attributes:
        _mutable (bool): Flag indicating whether the structure is mutable; False here.
        sites (List[FrozenSite]): List of immutable sites in the structure.
    """

    # Read by the inherited cell/sites validators to freeze the validated values.
    _mutable: t.ClassVar[bool] = False

    sites: t.Optional[list[FrozenSite]] = Field(
        default=None,
        description="List of sites in the structure",
    )

    def __setattr__(self, key, value):
        """Reject assignment to model fields; other attributes are set normally.

        Raises:
            ValueError: If ``key`` is a model field.
        """
        # Look up fields on the class: instance access to `model_fields` is deprecated in pydantic.
        if key in type(self).model_fields:
            raise ValueError("The AiiDA `StructureData` is immutable. You can create a mutable copy of it using its `get_value` method.")
        super().__setattr__(key, value)