演算子オーバーロードにおける参照型の考慮について。

基礎知識

演算子オーバーロード

演算子オーバーロードを使うと、型に対する演算子の挙動を定義できる。

例えば、組込の数値型 i32 では +, -, *, / などが定義されている。

そして、ユーザが作成した独自型についても、組込型と同様、これらを定義できる。

参照型への対応

組込の数値型の場合、その参照型についても演算子オーバーロードがされている。これにより例えば、組込の数値型どうしの加算だけでなく、その参照どうしの加算でも、同じ結果が得られるようになっている。

左辺と右辺

+ 演算子のような二項演算子では、左辺と右辺それぞれが値または参照になりうる。

そのため、組込型をまねるなら、演算子オーバーロードも四回必要になる。

サンプル

以下では、+ 演算子を型 P3 とその参照について、四通り定義している。


use std::ops::Add;

fn main() {
    let r1 = P3([1, 2, 3]) + P3([4, 5, 6]);
    let r2 = P3([1, 2, 3]) + &P3([4, 5, 6]);
    let r3 = &P3([1, 2, 3]) + P3([4, 5, 6]);
    let r4 = &P3([1, 2, 3]) + &P3([4, 5, 6]);
    assert!([r1, r2, r3, r4].iter().all(|x| *x == P3([5, 7, 9])));
}

#[derive(Eq, PartialEq)]
struct P3([i32; 3]);

impl Add for P3 {
    type Output = Self;
    fn add(self, rhs: Self) -> Self::Output {
        &self + &rhs
    }
}

impl Add<&Self> for P3 {
    type Output = Self;
    fn add(self, rhs: &Self) -> Self::Output {
        &self + rhs
    }
}

impl Add<P3> for &P3 {
    type Output = P3;
    fn add(self, rhs: P3) -> Self::Output {
        self + &rhs
    }
}

impl<'r> Add<&'r P3> for &P3 {
    type Output = P3;
    fn add(self, rhs: &'r P3) -> Self::Output {
        let arr = self.0.iter().enumerate().map(|(i, x)| x + rhs.0[i]);
        P3(arr.collect::<Vec<_>>().try_into().unwrap())
    }
}

トレイトの定義

トレイト境界に演算子用のトレイトを指定すれば、同種の演算子が使える型をまとめられる。これにより例えば、数値型と座標型の加算を同じソースコードで扱えるようになる。

関連バグ

、『E0275 - トレイト確認の無限再帰 - バグ』があるため、左辺が参照型の場合について理想的な表現ができない。この影響で値と参照の境界が個別に必要になる。

サンプル

以下では、前のサンプルで定義したのと同じ型 P3 を抽象化している。


use std::ops::Add;

fn main() {
    let def = P3([0, 0, 0]);
    let ret = add(Some(P3([1, 2, 3])), Some(P3([4, 5, 6])), &def);
    assert!(ret == P3([5, 7, 9]));
}

fn add<T>(x: Option<T>, y: Option<T>, def: &T) -> T
where
    T: Value<T>,
    for<'a> &'a T: Value<T>
{
    match (x, y) {
        (None, None) => def + def,
        (None, Some(y)) => def + y,
        (Some(x), None) => x + def,
        (Some(x), Some(y)) => x + y,
    }
}

trait Value<R>: Sized
where Self: Add<R, Output = R> + for<'a> Add<&'a R, Output = R> {}

impl<L, R> Value<R> for L
where L: Add<R, Output = R> + for<'a> Add<&'a R, Output = R> {}

// -- `P3` definition -- //

#[derive(Eq, PartialEq)]
struct P3([i32; 3]);

impl Add for P3 {
    type Output = Self;
    fn add(self, rhs: Self) -> Self::Output {
        &self + &rhs
    }
}

impl Add<&Self> for P3 {
    type Output = Self;
    fn add(self, rhs: &Self) -> Self::Output {
        &self + rhs
    }
}

impl Add<P3> for &P3 {
    type Output = P3;
    fn add(self, rhs: P3) -> Self::Output {
        self + &rhs
    }
}

impl<'r> Add<&'r P3> for &P3 {
    type Output = P3;
    fn add(self, rhs: &'r P3) -> Self::Output {
        let arr = self.0.iter().enumerate().map(|(i, x)| x + rhs.0[i]);
        P3(arr.collect::<Vec<_>>().try_into().unwrap())
    }
}