Add `PyAny::lookup_special`

`PyAny::lookup_special` is an approximate equivalent to the CPython
internal `_PyObject_LookupSpecial`, which is used to resolve lookups of
"magic" methods.  These are only looked up from the type, and skip the
instance dictionary during the lookup.  Despite this, they are still
required to resolve the descriptor protocol.

Many magic methods have slots on the `PyTypeObject` or respective
subobjects, but these are not necessarily available when targeting the
limited API or PyPy.  In these cases, the requisite logic can be worked
around using safe but likely slower APIs.

Co-authored-by: David Hewitt <1939362+davidhewitt@users.noreply.github.com>

Fix up lookup-special
This commit is contained in:
Jake Lishman 2023-05-31 16:51:40 +01:00
parent 451729aef0
commit 194d0c791f
No known key found for this signature in database
GPG Key ID: F111E77FA4F6AF0D
1 changed files with 119 additions and 1 deletions

View File

@ -124,6 +124,51 @@ impl PyAny {
}
}
/// Retrieve an attribute value, skipping the instance dictionary during the lookup but still
/// binding the object to the instance.
///
/// This is useful when trying to resolve Python's "magic" methods like `__getitem__`, which
/// are looked up starting from the type object. This returns an `Option` as it is not
/// typically a direct error for the special lookup to fail, as magic methods are optional in
/// many situations in which they might be called.
///
/// To avoid repeated temporary allocations of Python strings, the [`intern!`] macro can be used
/// to intern `attr_name`.
#[allow(dead_code)] // Currently only used with num-complex+abi3, so dead without that.
pub(crate) fn lookup_special<N>(&self, attr_name: N) -> PyResult<Option<&PyAny>>
where
N: IntoPy<Py<PyString>>,
{
let py = self.py();
let self_type = self.get_type();
let attr = if let Ok(attr) = self_type.getattr(attr_name) {
attr
} else {
return Ok(None);
};
// Manually resolve descriptor protocol.
if cfg!(Py_3_10)
|| unsafe { ffi::PyType_HasFeature(attr.get_type_ptr(), ffi::Py_TPFLAGS_HEAPTYPE) } != 0
{
// This is the preferred faster path, but does not work on static types (generally,
// types defined in extension modules) before Python 3.10.
unsafe {
let descr_get_ptr = ffi::PyType_GetSlot(attr.get_type_ptr(), ffi::Py_tp_descr_get);
if descr_get_ptr.is_null() {
return Ok(Some(attr));
}
let descr_get: ffi::descrgetfunc = std::mem::transmute(descr_get_ptr);
let ret = descr_get(attr.as_ptr(), self.as_ptr(), self_type.as_ptr());
py.from_owned_ptr_or_err(ret).map(Some)
}
} else if let Ok(descr_get) = attr.get_type().getattr(crate::intern!(py, "__get__")) {
descr_get.call1((attr, self, self_type)).map(Some)
} else {
Ok(Some(attr))
}
}
/// Sets an attribute value.
///
/// This is equivalent to the Python expression `self.attr_name = value`.
@ -974,9 +1019,82 @@ impl PyAny {
#[cfg(test)]
mod tests {
use crate::{
types::{IntoPyDict, PyBool, PyList, PyLong, PyModule},
types::{IntoPyDict, PyAny, PyBool, PyList, PyLong, PyModule},
Python, ToPyObject,
};
#[test]
fn test_lookup_special() {
Python::with_gil(|py| {
let module = PyModule::from_code(
py,
r#"
class CustomCallable:
def __call__(self):
return 1
class SimpleInt:
def __int__(self):
return 1
class InheritedInt(SimpleInt): pass
class NoInt: pass
class NoDescriptorInt:
__int__ = CustomCallable()
class InstanceOverrideInt:
def __int__(self):
return 1
instance_override = InstanceOverrideInt()
instance_override.__int__ = lambda self: 2
class ErrorInDescriptorInt:
@property
def __int__(self):
raise ValueError("uh-oh!")
class NonHeapNonDescriptorInt:
# A static-typed callable that doesn't implement `__get__`. These are pretty hard to come by.
__int__ = int
"#,
"test.py",
"test",
)
.unwrap();
let int = crate::intern!(py, "__int__");
let eval_int =
|obj: &PyAny| obj.lookup_special(int)?.unwrap().call0()?.extract::<u32>();
let simple = module.getattr("SimpleInt").unwrap().call0().unwrap();
assert_eq!(eval_int(simple).unwrap(), 1);
let inherited = module.getattr("InheritedInt").unwrap().call0().unwrap();
assert_eq!(eval_int(inherited).unwrap(), 1);
let no_descriptor = module.getattr("NoDescriptorInt").unwrap().call0().unwrap();
assert_eq!(eval_int(no_descriptor).unwrap(), 1);
let missing = module.getattr("NoInt").unwrap().call0().unwrap();
assert!(missing.lookup_special(int).unwrap().is_none());
// Note the instance override should _not_ call the instance method that returns 2,
// because that's not how special lookups are meant to work.
let instance_override = module.getattr("instance_override").unwrap();
assert_eq!(eval_int(instance_override).unwrap(), 1);
let descriptor_error = module
.getattr("ErrorInDescriptorInt")
.unwrap()
.call0()
.unwrap();
assert!(descriptor_error.lookup_special(int).is_err());
let nonheap_nondescriptor = module
.getattr("NonHeapNonDescriptorInt")
.unwrap()
.call0()
.unwrap();
assert_eq!(eval_int(nonheap_nondescriptor).unwrap(), 0);
})
}
#[test]
fn test_call_for_non_existing_method() {
Python::with_gil(|py| {