Lindenii Project Forge
Login

hare-ds

Data structures for Hare
Commit info
ID
aaaf85993ac7cdd6cd29778db21aefc71a16e26f
Author
Runxi Yu <me@runxiyu.org>
Author date
Tue, 23 Sep 2025 11:50:32 +0800
Committer
Runxi Yu <me@runxiyu.org>
Committer date
Tue, 23 Sep 2025 11:50:32 +0800
Actions
Fix memory leaks
// SPDX-License-Identifier: MPL-2.0

use bytes;
use sort;

// Deletes an item from a [[map]]. Returns the removed value or void.
export fn del(m: *map, key: []u8) (*opaque | void) = {
	const r = delete_rec(m, m.root, key);
	if (len(m.root.keys) == 0 && !m.root.leaf) {
		m.root = m.root.children[0];
		let old = m.root;
		m.root = old.children[0];
		node_shallow_finish(old);
	};
	return r;
};

// SPDX-License-Identifier: MPL-2.0

use bytes;
use sort;

fn keycmp(a: []u8, b: []u8) int = {
	let n = if (len(a) < len(b)) len(a) else len(b);
	for (let i = 0z; i < n; i += 1) {
		if (a[i] < b[i]) return -1;
		if (a[i] > b[i]) return 1;
	};
	if (len(a) < len(b)) return -1;
	if (len(a) > len(b)) return 1;
	return 0;
};

fn cmp_u8slice(a: const *opaque, b: const *opaque) int = {
	let sa = *(a: *[]u8);
	let sb = *(b: *[]u8);
	return keycmp(sa, sb);
};

fn node_new(t: size, leaf: bool) (*node | nomem) = {
	let capk = 2 * t - 1;
	let capc = if (leaf) 0z else 2z * t;

	let empty_keys: [][]u8 = [];
	let keys = alloc(empty_keys, capk)?;

	let empty_vals: []*opaque = [];
	let vals = alloc(empty_vals, capk)?;

	let children: []*node = if (leaf) {
		yield [];
	} else {
		let empty_children: []*node = [];
		yield alloc(empty_children, capc)?;
	};

	let nd = alloc(node {
		leaf = leaf,
		keys = keys,
		vals = vals,
		children = children,
	})?;
	return nd;
};

fn node_shallow_finish(n: *node) void = {
	if (!n.leaf) {
		free(n.children);
	};
	free(n.keys);
	free(n.vals);
	free(n);
};

fn split_child(m: *map, x: *node, i: size) (void | nomem) = {
	const t = m.t;
	let y = x.children[i];
	let z = node_new(t, y.leaf)?;

	let medk = y.keys[t - 1];
	let medv = y.vals[t - 1];

	append(z.keys, y.keys[t..]...)?;
	append(z.vals, y.vals[t..]...)?;
	if (!y.leaf) {
		append(z.children, y.children[t..]...)?;
	};

	y.keys = y.keys[..t - 1];
	y.vals = y.vals[..t - 1];
	if (!y.leaf) {
		y.children = y.children[..t];
	};

	insert(x.keys[i], medk)?;
	insert(x.vals[i], medv)?;
	insert(x.children[i + 1], z)?;
};

fn dup_u8(src: []u8) ([]u8 | nomem) = {
	return match (alloc(src, len(src))) {
	case let b: []u8 => yield b;
	case nomem => return nomem;
	};
};

fn insert_nonfull(m: *map, x: *node, key: []u8, val: *opaque) (void | nomem) = {
	let i = sort::lbisect((x.keys: []const opaque), size([]u8),
		(&key: const *opaque), &cmp_u8slice);

	if (i < len(x.keys) && bytes::equal(x.keys[i], key)) {
		x.vals[i] = val;
		return;
	};

	if (x.leaf) {
		let kcopy = match (dup_u8(key)) {
		case let b: []u8 => yield b;
		case nomem => return nomem;
		};
		insert(x.keys[i], kcopy)?;
		insert(x.vals[i], val)?;
		return;
	};

	if (len(x.children[i].keys) == 2 * m.t - 1) {
		split_child(m, x, i)?;
		let cmp = cmp_u8slice((&key: const *opaque),
			(&x.keys[i]: const *opaque));
		if (cmp == 0) {
			x.vals[i] = val;
			return;
		};
		if (cmp > 0) {
			i += 1;
		};
	};
	insert_nonfull(m, x.children[i], key, val)?;
};

fn merge_children(m: *map, x: *node, i: size) void = {
	let left = x.children[i];
	let right = x.children[i + 1];

	insert(left.keys[len(left.keys)], x.keys[i])!;
	insert(left.vals[len(left.vals)], x.vals[i])!;

	append(left.keys, right.keys...)!;
	append(left.vals, right.vals...)!;
	if (!left.leaf) {
		append(left.children, right.children...)!;
	};

	delete(x.keys[i]);
	delete(x.vals[i]);
	delete(x.children[i + 1]);
	node_shallow_finish(right);
};

fn ensure_child_has_space(m: *map, x: *node, i: size) void = {
	const t = m.t;
	let c = x.children[i];

	if (len(c.keys) >= t) return;

	if (i > 0 && len(x.children[i - 1].keys) >= t) {
		let ls = x.children[i - 1];

		insert(c.keys[0], x.keys[i - 1])!;
		insert(c.vals[0], x.vals[i - 1])!;

		if (!c.leaf) {
			let moved = ls.children[len(ls.children) - 1];
			insert(c.children[0], moved)!;
			delete(ls.children[len(ls.children) - 1]);
		};

		x.keys[i - 1] = ls.keys[len(ls.keys) - 1];
		x.vals[i - 1] = ls.vals[len(ls.vals) - 1];
		delete(ls.keys[len(ls.keys) - 1]);
		delete(ls.vals[len(ls.vals) - 1]);
		return;
	};

	if (i + 1 < len(x.children) && len(x.children[i + 1].keys) >= t) {
		let rs = x.children[i + 1];

		insert(c.keys[len(c.keys)], x.keys[i])!;
		insert(c.vals[len(c.vals)], x.vals[i])!;

		if (!c.leaf) {
			let moved = rs.children[0];
			insert(c.children[len(c.children)], moved)!;
			delete(rs.children[0]);
		};

		x.keys[i] = rs.keys[0];
		x.vals[i] = rs.vals[0];
		delete(rs.keys[0]);
		delete(rs.vals[0]);
		return;
	};

	if (i + 1 < len(x.children)) {
		merge_children(m, x, i);
	} else {
		merge_children(m, x, i - 1);
	};
};

fn pop_max(m: *map, x: *node) ([]u8, *opaque) = {
	let cur = x;
	for (!cur.leaf) {
		let last_before = len(cur.children) - 1;
		ensure_child_has_space(m, cur, last_before);
		let last = len(cur.children) - 1;
		cur = cur.children[last];
	};
	let k = cur.keys[len(cur.keys) - 1];
	let v = cur.vals[len(cur.vals) - 1];
	delete(cur.keys[len(cur.keys) - 1]);
	delete(cur.vals[len(cur.vals) - 1]);
	return (k, v);
};

fn pop_min(m: *map, x: *node) ([]u8, *opaque) = {
	let cur = x;
	for (!cur.leaf) {
		ensure_child_has_space(m, cur, 0);
		cur = cur.children[0];
	};
	let k = cur.keys[0];
	let v = cur.vals[0];
	delete(cur.keys[0]);
	delete(cur.vals[0]);
	return (k, v);
};

fn delete_rec(m: *map, x: *node, key: []u8) (*opaque | void) = {
	let i = sort::lbisect((x.keys: []const opaque), size([]u8),
		(&key: const *opaque), &cmp_u8slice);

	if (i < len(x.keys) && bytes::equal(x.keys[i], key)) {
		if (x.leaf) {
			let ret = x.vals[i];
			free(x.keys[i]);
			delete(x.keys[i]);
			delete(x.vals[i]);
			return ret;
		};

		const t = m.t;
		let y = x.children[i];
		let z = x.children[i + 1];

		if (len(y.keys) >= t) {
			let (pk, pv) = pop_max(m, y);
			let ret = x.vals[i];
			let oldk = x.keys[i];
			x.keys[i] = pk;
			x.vals[i] = pv;
			free(oldk);
			return ret;
		} else if (len(z.keys) >= t) {
			let (sk, sv) = pop_min(m, z);
			let ret = x.vals[i];
			let oldk = x.keys[i];
			x.keys[i] = sk;
			x.vals[i] = sv;
			free(oldk);
			return ret;
		} else {
			merge_children(m, x, i);
			return delete_rec(m, y, key);
		};
	};

	if (x.leaf) {
		return;
	};

	ensure_child_has_space(m, x, i);
	if (i >= len(x.children)) {
		i = len(x.children) - 1;
	};
	return delete_rec(m, x.children[i], key);
};
// SPDX-License-Identifier: MPL-2.0

use ds::map;

export type iterator = struct {
	vt: map::iterator,
	nodes: []*node,
	idxs: []size,
	visit_key: []bool,
	finished: bool,
};

const _itvt: map::vtable_iterator = map::vtable_iterator {
	nexter = &vt_next,
};

fn vt_next(it: *map::iterator) (([]u8, *opaque) | done) = next(it: *iterator);

export fn iter(m: *map) (*map::iterator | nomem) = {
	let it = alloc(iterator {
		vt = &_itvt,
		nodes = [],
		idxs = [],
		visit_key = [],
		finished = false,
	})?;

	match (append(it.nodes, m.root)) {
	case void => void;
	case nomem => return nomem;
	};
	match (append(it.idxs, 0z)) {
	case void => void;
	case nomem => return nomem;
	};
	match (append(it.visit_key, false)) {
	case void => void;
	case nomem => return nomem;
	};

	return (it: *map::iterator);
};

export fn next(it: *iterator) (([]u8, *opaque) | done) = {
	if (it.finished) {
		return done;
	};
	for (len(it.nodes) != 0) {
		let top = len(it.nodes) - 1;
		let x = it.nodes[top];
		let i = it.idxs[top];
		let vk = it.visit_key[top];

		if (x.leaf) {
			if (i < len(x.keys)) {
				it.idxs[top] = i + 1;
				return (x.keys[i], x.vals[i]);
			};
			delete(it.nodes[top]);
			delete(it.idxs[top]);
			delete(it.visit_key[top]);
			continue;
		};

		if (vk) {
			if (i >= len(x.keys)) {
				it.visit_key[top] = false;
				continue;
			};
			it.visit_key[top] = false;
			it.idxs[top] = i + 1;
			return (x.keys[i], x.vals[i]);
		};

		if (i < len(x.keys)) {
			match (append(it.nodes, x.children[i])) {
			case void => void;
			case nomem => abort("btree::iter: nomem");
			};
			match (append(it.idxs, 0z)) {
			case void => void;
			case nomem => abort("btree::iter: nomem");
			};
			match (append(it.visit_key, false)) {
			case void => void;
			case nomem => abort("btree::iter: nomem");
			};
			it.visit_key[top] = true;
			continue;
		};

		if (i == len(x.keys)) {
			match (append(it.nodes, x.children[i])) {
			case void => void;
			case nomem => abort("btree::iter: nomem");
			};
			match (append(it.idxs, 0z)) {
			case void => void;
			case nomem => abort("btree::iter: nomem");
			};
			match (append(it.visit_key, false)) {
			case void => void;
			case nomem => abort("btree::iter: nomem");
			};
			it.idxs[top] = i + 1;
			continue;
		};

		delete(it.nodes[top]);
		delete(it.idxs[top]);
		delete(it.visit_key[top]);
	};

	free(it.nodes);
	free(it.idxs);
	free(it.visit_key);
	it.nodes = [];
	it.idxs = [];
	it.visit_key = [];
	it.finished = true;
	return done;
};
// SPDX-License-Identifier: MPL-2.0
// SPDX-FileCopyrightText: 2024 Drew DeVault <drew@ddevault.org>
// SPDX-FileCopyrightText: 2025 Runxi Yu <me@runxiyu.org>

use ds::map;

export type iterator = struct {
	vt: map::iterator,
	m: *map,
	bi: size,
	cur: nullable *map::iterator,
	finished: bool,
};

const _itvt: map::vtable_iterator = map::vtable_iterator {
	nexter = &vt_next,
};

fn vt_next(it: *map::iterator) (([]u8, *opaque) | done) = next(it: *iterator);

export fn iter(m: *map) (*map::iterator | nomem) = {
	let it = alloc(iterator {
		vt = &_itvt,
		m = m,
		bi = 0,
		cur = null,
		finished = false,
	})?;
	for (it.bi < m.n) {
		let b = m.buckets[it.bi];
		it.bi += 1;
		match (map::iter(b)) {
		case let i: *map::iterator =>
			it.cur = i;
			break;
		case nomem =>
			abort("hashmap::iter: nomem");
		};
	};
	return (it: *map::iterator);
};

export fn next(it: *iterator) (([]u8, *opaque) | done) = {
	if (it.finished) {
		return done;
	};
	for (true) {
		match (it.cur) {
		case null =>
			if (it.bi >= it.m.n) return done;
			if (it.bi >= it.m.n) {
				it.finished = true;
				return done;
			};
			let b = it.m.buckets[it.bi];
			it.bi += 1;
			match (map::iter(b)) {
			case let i: *map::iterator =>
				it.cur = i;
			case nomem =>
				abort("hashmap::iter: nomem");
			};
		case let curi: *map::iterator =>
			match (map::next(curi)) {
			case let kv: ([]u8, *opaque) =>
				return kv;
			case done =>
				free(curi);
				it.cur = null;
			};
		};
	};
};
// SPDX-License-Identifier: MPL-2.0

def KEY_LEN: size = 16z;

fn put_le64(dst: []u8, x: u64) void = {
	for (let i = 0z; i < 8z; i += 1) {
		dst[i] = ((x >> (8u64 * (i: u64))) & 0xFFu64): u8;
	};
};

fn get_le64(src: []u8) u64 = {
	let x: u64 = 0u64;
	for (let i = 0z; i < 8z; i += 1) {
		x |= (src[i]: u64) << (8u64 * (i: u64));
	};
	return x;
};

type oracle = []nullable *opaque;

fn must_set(m: *map, key: []u8, v: *opaque) void = {
	match (set(m, key, v)) {
	case void => void;
	case nomem => abort("set: out of memory");
	};
};
fn must_get(m: *map, key: []u8) (*opaque | void) = get(m, key);
fn must_del(m: *map, key: []u8) (*opaque | void) = del(m, key);

fn verify_iter_matches_oracle(m: *map, exp: oracle) void = {
	let seen: []bool = alloc([false...], len(exp))!;
	defer free(seen);

	let it: *iterator = match (iter(m)) {
	case let p: *iterator => yield p;
	case nomem => abort("iter: out of memory");
	};
	defer free(it);

	for (true) {
		match (next(it)) {
		case let kv: ([]u8, *opaque) =>
			let k = kv.0;
			let v = kv.1;

			assert(len(k) >= 8z, "iter: key too short");
			let idx = (get_le64(k[..8]): size);
			assert(idx < len(exp), "iter: decoded index out of range");

			match (exp[idx]) {
			case null =>
				abort("iter: found element which should not exist");
			case let p: *opaque =>
				assert(v == p, "iter: value pointer mismatch");
			};

			assert(!seen[idx], "iter: duplicate key encountered");
			seen[idx] = true;
		case done =>
			break;
		};
	};

	for (let i = 0z; i < len(exp); i += 1) {
		match (exp[i]) {
		case null => void;
		case *opaque =>
			assert(seen[i], "iter: missing expected key");
		};
	};
};

export fn stress_test(m: *map, key_space: size) void = {
	let empty: [KEY_LEN]u8 = [0...];
	let keybufs: [][KEY_LEN]u8 = alloc([empty...], key_space)!;
	let keys: [][]u8 = alloc([[0...]...], key_space)!;
	defer free(keys);
	defer free(keybufs);

	for (let i = 0z; i < key_space; i += 1) {
		for (let j = 8z; j < KEY_LEN; j += 1) keybufs[i][j] = 0xABu8;
		put_le64((keybufs[i][..8]), (i: u64));
		keys[i] = keybufs[i][..];
	};

	let vals: []int = alloc([0...], key_space)!;
	defer free(vals);
	for (let i = 0z; i < key_space; i += 1) vals[i] = (i: int);

	let exp: oracle = alloc([null...], key_space)!;
	defer free(exp);

	// Sequential inserts with immediate verification
	for (let i = 0z; i < key_space; i += 1) {
		let vp: *opaque = (&vals[i]: *opaque);
		must_set(m, keys[i], vp);
		exp[i] = vp;

		match (must_get(m, keys[i])) {
		case let got: *opaque => assert(got == vp, "phase1: get != set");
		case void => abort("phase1: get void after set");
		};
	};

	// Verify contents via iterator after initial inserts.
	verify_iter_matches_oracle(m, exp);

	// Forward read sweep (all should be present)
	for (let i = 0z; i < key_space; i += 1) {
		match (must_get(m, keys[i])) {
		case let got: *opaque =>
			let want = match (exp[i]) {
			case null => abort("phase2: expect null but found value");
			case let p: *opaque => yield p;
			};
			assert(got == want, "phase2: value mismatch");
		case void =>
			abort("phase2: get void but expect value");
		};
	};

	// Strided overwrites with immediate verification
	for (let step = 0z; step < key_space; step += 1) {
		let i = (step * 7z) % key_space;
		let new_ix = (i + 12345z) % key_space;
		let vp: *opaque = (&vals[new_ix]: *opaque);
		must_set(m, keys[i], vp);
		exp[i] = vp;

		match (must_get(m, keys[i])) {
		case let got: *opaque => assert(got == vp, "phase3: replace mismatch");
		case void => abort("phase3: get void after replace");
		};
	};

	// Sparse deletes, every 3rd key
	for (let i = 0z; i < key_space; i += 1) {
		if (i % 3z != 0z) continue;

		let want = exp[i];
		match (must_del(m, keys[i])) {
		case let ret: *opaque =>
			match (want) {
			case null => abort("phase4: del returned value but expect null");
			case let p: *opaque => assert(ret == p, "phase4: del value mismatch");
			};
			exp[i] = null;

			match (must_del(m, keys[i])) {
			case void => void;
			case *opaque => abort("phase4: second delete returned value");
			};
		case void =>
			match (want) {
			case null => void;
			case *opaque => abort("phase4: del void but expect value");
			};
		};
	};

	// Insert again every 6th key
	for (let i = 0z; i < key_space; i += 1) {
		if (i % 6z != 0z) continue;
		let vp: *opaque = (&vals[(i * 5z + 1z) % key_space]: *opaque);
		must_set(m, keys[i], vp);
		exp[i] = vp;
	};

	// Even indices read, odd indices write
	for (let i = 0z; i < key_space; i += 1) {
		if ((i & 1z) == 0z) {
			// Read check
			match (must_get(m, keys[i])) {
			case let got: *opaque =>
				match (exp[i]) {
				case null => abort("phase6: even get found value; expect null");
				case let p: *opaque => assert(got == p, "phase6: even mismatch");
				};
			case void =>
				match (exp[i]) {
				case null => void;
				case *opaque => abort("phase6: even get void; expect value");
				};
			};
		} else {
			// Write (overwrite or create)
			let vp: *opaque = (&vals[(i * 3z + 7z) % key_space]: *opaque);
			must_set(m, keys[i], vp);
			exp[i] = vp;
		};
	};

	// Verify contents via iterator after mixed updates.
	verify_iter_matches_oracle(m, exp);

	// Reading in reverse order
	for (let r = key_space; r > 0z; r -= 1) {
		let i = r - 1z;
		match (must_get(m, keys[i])) {
		case let got: *opaque =>
			match (exp[i]) {
			case null => abort("phase7: get found value; expect null");
			case let p: *opaque => assert(got == p, "phase7: reverse mismatch");
			};
		case void =>
			match (exp[i]) {
			case null => void;
			case *opaque => abort("phase7: reverse get void; expect value");
			};
		};
	};

	// Clear in reverse order
	for (let r = key_space; r > 0z; r -= 1) {
		let i = r - 1z;
		match (must_del(m, keys[i])) {
		case let ret: *opaque =>
			match (exp[i]) {
			case null => abort("phase8: del returned value; expect null");
			case let p: *opaque => assert(ret == p, "phase8: final del mismatch");
			};
			exp[i] = null;
		case void =>
			match (exp[i]) {
			case null => void;
			case *opaque => abort("phase8: del void; expect value");
			};
		};
	};

	// Read sweep ensure empty
	for (let i = 0z; i < key_space; i += 1) {
		match (must_get(m, keys[i])) {
		case void => void;
		case *opaque => abort("final: get returned value after full clear");
		};
		match (must_del(m, keys[i])) {
		case void => void;
		case *opaque => abort("final: del returned value after full clear");
		};
	};

	// Iterator should also report empty
	let it_empty: *iterator = match (iter(m)) {
	case let p: *iterator => yield p;
	case nomem => abort("iter(empty): out of memory");
	};
	defer free(it_empty);
	match (next(it_empty)) {
	case done => void;
	case ([]u8, *opaque) => abort("final: iterator produced elements after clear");
	};
};