# Source code for aiida_atomistic.data.structure.models

import copy
import functools
import json
import typing as t
from pydantic import BaseModel, Field, field_validator, ConfigDict, computed_field, model_validator
import numpy as np
import warnings

from collections import defaultdict

from aiida import orm
from aiida.common.constants import elements
from aiida.orm.nodes.data import Data

from aiida_atomistic.data.structure.site import Site, FrozenList, freeze_nested, FrozenSite
from aiida_atomistic.data.structure.kind import Kind

from aiida_quantumespresso.common.hubbard import Hubbard

from aiida_atomistic.data.structure import (
    _atomic_masses,
    _DEFAULT_CELL,
    _DEFAULT_PBC,
    _DEFAULT_VALUES,
)

class StructureBaseModel(BaseModel):
    """
    A base model representing a structure in atomistic simulations.

    Attributes:
        pbc (List[bool]): Periodic boundary conditions in the x, y, and z directions
            (exactly three entries).
        cell (np.ndarray): The 3x3 array of cell vectors defining the unit cell, in Angstrom.
        sites (List[Site]): List of sites in the structure.
        tot_magnetization (Optional[float]): Global property; None when not set.
        tot_charge (Optional[float]): Global property; None when not set.
        hubbard (Optional[Hubbard]): Hubbard parameters; defaults to an empty
            ``Hubbard`` so that its methods are always accessible.
        custom (Optional[dict]): Free-form user data; None when not set.
    """

    # Class-level switch read by the validators: when False, the cell array is made
    # read-only and the sites are frozen. Subclasses override it.
    _mutable: t.ClassVar[bool] = True

    pbc: list[bool] = Field(
        default=_DEFAULT_PBC,
        description="Periodic boundary conditions",
        min_length=3,
        max_length=3,
    )
    cell: t.Union[np.ndarray[float]] = Field(
        default=_DEFAULT_CELL,
        description="Lattice vectors",
        json_schema_extra={"units": "Angstrom"},
    )
    sites: list[Site] = Field(
        default=[],
        description="List of sites in the structure",
    )

    # global and more specific properties
    tot_magnetization: t.Optional[float] = Field(default=None)
    tot_charge: t.Optional[float] = Field(default=None)
    hubbard: t.Optional[Hubbard] = Field(default=Hubbard(parameters=[]))  # to have access to the methods.
    custom: t.Optional[dict] = Field(default=None)

    model_config = ConfigDict(
        from_attributes=True,
        frozen=False,
        arbitrary_types_allowed=True,
    )

    @field_validator('cell', mode='before')
    @classmethod
    def validate_cell_shape(cls, v):
        """Ensure cell is always a 3x3 array.

        The input is copied before its ``writeable`` flag is set. This keeps the
        flag change from leaking into the caller's array or into the shared
        module-level default cell.

        Raises:
            ValueError: If the cell does not have shape (3, 3).
        """
        # np.array copies by default; np.asarray would alias an ndarray input.
        v = np.array(v)
        v.flags.writeable = cls._mutable
        if v.shape != (3, 3):
            raise ValueError("The cell must be a 3x3 array.")
        return v

    @model_validator(mode='before')
    @classmethod
    def check_minimal_requirements(cls, data):
        """
        Validate the minimal requirements of the structure.

        Args:
            data: The raw input for the structure, automatically passed by pydantic.
                Non-mapping inputs (e.g. objects validated via ``from_attributes``)
                are returned untouched.

        Returns:
            dict: A new dict with ``pbc`` and ``cell`` always present. When no sites
            are given, only ``pbc``, ``cell`` and an empty ``sites`` list are kept.
            The caller's input is never mutated.
        """
        from collections.abc import Mapping

        if not isinstance(data, Mapping):
            return data

        if not data.get("sites", None):
            # if no symbols, no positions, we just return the pbc and cell
            return {
                "pbc": data.get("pbc", cls.model_fields["pbc"].default),
                "cell": data.get("cell", cls.model_fields["cell"].default),
                "sites": [],
            }

        # Shallow copy so the defaults below are not written into the caller's dict.
        data = dict(data)
        # explicitly set default values for pbc and cell if not provided, so in the
        # self.get_defined_properties() they are always there
        for global_property in ('pbc', 'cell'):
            data.setdefault(global_property, cls.model_fields[global_property].default)
        return data

    @field_validator('sites', mode='before')
    @classmethod
    def validate_sites(cls, v):
        """Validate the list of sites.

        Checks that every entry can be converted to a ``Site`` and runs
        ``_check_valid_sites`` on the input. The raw input is returned, so pydantic
        performs the actual conversion against the field's declared item type.
        """
        from aiida_atomistic.data.structure.utils import _check_valid_sites

        if v is None:
            return v
        # test if they can be converted to Site (raises on invalid entries)
        for site in v:
            if not isinstance(site, Site):
                Site.model_validate(site)
        _check_valid_sites(v)
        return v

    @field_validator('sites', mode='after')
    @classmethod
    def freeze_sites(cls, v):
        """Freeze the list of sites if the structure is immutable."""
        if not cls._mutable and v is not None:
            return freeze_nested(v)
        return v

    # computed properties
    @computed_field
    def cell_volume(self) -> float:
        """
        Compute the volume of the unit cell.

        Returns:
            float: The volume of the unit cell in cubic Angstroms.
        """
        from aiida_atomistic.data.structure.utils import calc_cell_volume

        return calc_cell_volume(self.cell)

    @computed_field
    def dimensionality(self) -> dict:
        """
        Determine the dimensionality of the structure.

        Returns:
            dict: A dictionary indicating the dimensionality of the structure.
        """
        from aiida_atomistic.data.structure.utils import get_dimensionality

        return get_dimensionality(self.pbc, self.cell)

    @computed_field
    def formula(self) -> str:
        """
        Get the chemical formula of the structure.

        Returns:
            str: The chemical formula of the structure.
        """
        from aiida_atomistic.data.structure.utils import get_formula

        return get_formula(self.sites)

    @computed_field
    def is_alloy(self) -> bool:
        """True if at least one site reports ``is_alloy``."""
        return any(site.is_alloy for site in self.sites)

    @computed_field
    def has_vacancies(self) -> bool:
        """True if at least one site reports ``has_vacancies``."""
        return any(site.has_vacancies for site in self.sites)

    # Per-site properties exposed as arrays/lists. Each returns None when the
    # property is unset (None) on every site.
    @computed_field
    def positions(self) -> t.Optional[np.ndarray]:
        """
        Return the positions of all sites in the structure as a numpy array.

        Returns:
            Optional[np.ndarray]: One row per site, or None if no site has a position.
        """
        if all(site.position is None for site in self.sites):
            return None
        return np.array([site.position for site in self.sites])

    @computed_field
    def kind_names(self) -> t.Optional[t.List[str]]:
        """
        Return the list of kind names for all sites in the structure.

        Sites without a kind name fall back to their chemical symbol.

        Returns:
            Optional[List[str]]: One kind name per site, or None if no site has a kind name.
        """
        if all(site.kind_name is None for site in self.sites):
            return None
        return FrozenList(
            [site.kind_name if site.kind_name is not None else site.symbol for site in self.sites]
        )

    @computed_field
    def symbols(self) -> t.Optional[t.List[str]]:
        """
        Return the list of chemical symbols for all sites in the structure.

        Returns:
            Optional[List[str]]: One symbol per site, or None if no site has a symbol.
        """
        if all(site.symbol is None for site in self.sites):
            return None
        return FrozenList([site.symbol for site in self.sites])

    @computed_field
    def masses(self) -> t.Optional[np.ndarray]:
        """
        Return the masses of all sites in the structure as a numpy array.

        Returns:
            Optional[np.ndarray]: One mass per site, or None if no site has a mass.
        """
        if all(site.mass is None for site in self.sites):
            return None
        return np.array([site.mass for site in self.sites])

    @computed_field
    def charges(self) -> t.Optional[np.ndarray]:
        """
        Return the charges of all sites in the structure as a numpy array.

        Sites whose charge is None get ``_DEFAULT_VALUES['charge']``.

        Returns:
            Optional[np.ndarray]: One charge per site, or None if no site has a charge.
        """
        if all(site.charge is None for site in self.sites):
            return None
        # `is not None` (not truthiness) so an explicit 0.0 is kept, consistent with
        # the other per-site properties.
        return np.array(
            [site.charge if site.charge is not None else _DEFAULT_VALUES['charge'] for site in self.sites]
        )

    @computed_field
    def magmoms(self) -> t.Optional[np.ndarray]:
        """
        Return the magnetic moments of all sites in the structure as a numpy array.

        Sites whose magmom is None get ``_DEFAULT_VALUES['magmom']``.

        Returns:
            Optional[np.ndarray]: One magnetic moment per site, or None if no site has one.
        """
        if all(site.magmom is None for site in self.sites):
            return None
        return np.array(
            [site.magmom if site.magmom is not None else _DEFAULT_VALUES['magmom'] for site in self.sites]
        )

    @computed_field
    def magnetizations(self) -> t.Optional[np.ndarray]:
        """
        Return the magnetizations of all sites in the structure as a numpy array.

        Sites whose magnetization is None get ``_DEFAULT_VALUES['magnetization']``.

        Returns:
            Optional[np.ndarray]: One magnetization per site, or None if no site has one.
        """
        if all(site.magnetization is None for site in self.sites):
            return None
        return np.array(
            [
                site.magnetization if site.magnetization is not None else _DEFAULT_VALUES['magnetization']
                for site in self.sites
            ]
        )

    @computed_field
    def weights(self) -> t.Optional[t.List[t.Tuple[float, ...]]]:
        """
        Return the weights of all sites in the structure as a list of tuples.

        Sites whose weight is None get ``_DEFAULT_VALUES['weight']``.

        Returns:
            Optional[List[Tuple[float, ...]]]: One weight tuple per site, or None if no site has one.
        """
        if all(site.weight is None for site in self.sites):
            return None
        return FrozenList(
            [site.weight if site.weight is not None else _DEFAULT_VALUES['weight'] for site in self.sites]
        )

    @computed_field
    def kinds(self) -> t.Optional[FrozenList[Kind]]:
        """
        Return the reduced set of kinds, one per distinct kind name.

        Sites are grouped by the names returned by ``kind_names`` (kind name, or
        symbol as fallback), in order of first occurrence. Each ``Kind`` takes its
        properties from the first site of its group, plus the group's site indices
        and positions. No check is made that the other sites in a group share
        those properties.

        Returns:
            Optional[FrozenList[Kind]]: The kinds, or None if no kind names are defined.
        """
        kind_names = self.kind_names
        if not kind_names:
            return None

        # Compute the positions array once instead of once per site index.
        positions = self.positions

        # Single pass: kind name -> site indices; dict order keeps first occurrence.
        indices_by_name = defaultdict(list)
        for index, name in enumerate(kind_names):
            indices_by_name[name].append(index)

        kinds_list = []
        for site_indices in indices_by_name.values():
            first_site = self.sites[site_indices[0]]
            kinds_list.append(
                Kind(
                    **first_site.model_dump(exclude={'position'}),
                    site_indices=site_indices,
                    positions=np.array([positions[i] for i in site_indices]),
                )
            )
        return FrozenList(kinds_list)
class MutableStructureModel(StructureBaseModel):
    """
    A mutable structure model that extends the StructureBaseModel class.

    Attributes:
        _mutable (bool): Class-level flag set to True. The base-class validators
            therefore leave the cell array writeable and do not freeze the sites.
        sites (List[Site]): List of (mutable) sites in the structure.
    """

    _mutable = True
class ImmutableStructureModel(StructureBaseModel):
    """
    A class representing an immutable structure model.

    This class inherits from `StructureBaseModel`. The pydantic model is frozen,
    and its `_mutable` flag is False, so the base-class validators make the cell
    array read-only and freeze the sites.

    Attributes:
        _mutable (bool): Flag indicating whether the structure is mutable or not (False).
        sites (Optional[List[FrozenSite]]): List of immutable sites in the structure.

    Configuration (``model_config``):
        from_attributes (bool): Allow validating from objects' attributes.
        frozen (bool): The model is frozen (True).
        arbitrary_types_allowed (bool): Allow arbitrary field types such as numpy arrays.
    """

    _mutable = False

    sites: t.Optional[list[FrozenSite]] = Field(
        default=None,
        description="List of sites in the structure",
    )

    model_config = ConfigDict(
        from_attributes=True,
        frozen=True,
        arbitrary_types_allowed=True,
    )

    def __setattr__(self, key, value):
        """Reject assignment to model fields with an explanatory error.

        Raises:
            ValueError: If ``key`` is one of the model's fields.
        """
        # Customizing the exception message when trying to mutate attributes.
        # Look up `model_fields` on the class: instance access is deprecated in pydantic v2.
        if key in type(self).model_fields:
            raise ValueError("The AiiDA `StructureData` is immutable. You can create a mutable copy of it using its `get_value` method.")
        super().__setattr__(key, value)