Baladithya Balamurugan
Wave 20: close F4 s3:// gap — live-S3 DiLoCo allreduce smoke (AWS_SMOKE-gated)
bd37412
raw
history blame contribute delete
15.5 kB
"""Verifies the serverless DiLoCo allreduce wraps correctly across local
multiprocessing replicas using `file://` rendezvous.
This is the core multi-process test for the serverless layer. It exercises
the real allreduce barrier (with concurrent processes), not just the
single-process API.
"""
from __future__ import annotations
import os
import sys
import tempfile
import time
import pytest
import torch
from composer_replication.diloco.serverless import (
LocalProcessExecutor,
ObjectStoreAllReduce,
ReplicaHandle,
)
# ---------------------------------------------------------------------
# Single-process tests of ObjectStoreAllReduce primitives
# (don't need executor, just the file:// path + local manual orchestration)
# ---------------------------------------------------------------------
def test_object_store_allreduce_init_validates_rank():
    """Constructing a store with rank=5 and world_size=2 must be rejected.

    The constructor is expected to raise ``ValueError`` with a message
    matching ``"not in"``.
    """
    with tempfile.TemporaryDirectory() as scratch_dir, pytest.raises(
        ValueError, match="not in"
    ):
        ObjectStoreAllReduce(scratch_dir, rank=5, world_size=2)
def test_object_store_allreduce_local_paths_create_dir():
    """A not-yet-existing local rendezvous directory is created on init."""
    with tempfile.TemporaryDirectory() as scratch_dir:
        # Two levels deep, so intermediate directories must be created too.
        nested_dir = os.path.join(scratch_dir, "subdir", "subsubdir")
        allreduce = ObjectStoreAllReduce(nested_dir, rank=0, world_size=1)

        directory_created = os.path.isdir(nested_dir)
        assert directory_created
        assert allreduce.world_size == 1
def test_object_store_allreduce_world_size_1_passthrough():
    """With a single replica, allreduce hands back the input values unchanged
    (the tensor is only averaged with itself)."""
    original = torch.tensor([1.0, 2.0, 3.0])
    with tempfile.TemporaryDirectory() as scratch_dir:
        allreduce = ObjectStoreAllReduce(
            scratch_dir, rank=0, world_size=1, timeout_s=10.0
        )
        # Pass a clone so `original` stays available as the reference value.
        reduced = allreduce.allreduce(original.clone())
        torch.testing.assert_close(reduced, original, atol=1e-6, rtol=1e-6)
def test_object_store_allreduce_round_id_increments():
    """``round_id`` starts at 0 and goes up by one per ``allreduce`` call."""
    zeros = torch.zeros(3)
    with tempfile.TemporaryDirectory() as scratch_dir:
        allreduce = ObjectStoreAllReduce(
            scratch_dir, rank=0, world_size=1, timeout_s=10.0
        )
        assert allreduce.round_id == 0
        for expected_round in (1, 2):
            allreduce.allreduce(zeros.clone())
            assert allreduce.round_id == expected_round
# ---------------------------------------------------------------------
# Multi-process tests (the real verification — local executor + spawn)
# ---------------------------------------------------------------------
def _replica_compute_and_sync(
    rendezvous_uri: str,
    world_size: int,
    rank_value: float,
) -> dict:
    """Replica entrypoint: build a rank-dependent tensor and allreduce it.

    Kept at module top level so it can be resolved by name when replicas are
    started with multiprocessing 'spawn'. The rank is read from the
    ``REPLICA_RANK`` environment variable; the local tensor has four elements,
    each equal to ``rank_value * (rank + 1)``.

    Returns a dict with the rank, the tensor values before (``"pre"``) and
    after (``"post"``) the allreduce as plain lists, and ``world_size``.
    """
    rank = int(os.environ["REPLICA_RANK"])
    store = ObjectStoreAllReduce(
        rendezvous_uri,
        rank=rank,
        world_size=world_size,
        timeout_s=120.0,
    )

    fill_value = float(rank_value * (rank + 1))
    local_tensor = torch.full((4,), fill_value)
    # Snapshot first: the allreduce call may modify its argument in place.
    before = local_tensor.clone()
    after = store.allreduce(local_tensor)

    report = {"rank": rank}
    report["pre"] = before.tolist()
    report["post"] = after.tolist()
    report["world_size"] = world_size
    return report
@pytest.mark.parametrize("n_replicas", [2, 3])
def test_local_executor_runs_allreduce_across_replicas(n_replicas):
    """End-to-end check: every replica process ends up with the cross-rank mean.

    Launches ``n_replicas`` processes running ``_replica_compute_and_sync``
    against a rendezvous directory inside a temp dir, then verifies the handles,
    the per-replica status, and the averaged values.
    """
    launch_args = {
        "rendezvous_uri": None,  # filled in once the temp dir exists
        "world_size": n_replicas,
        "rank_value": 10.0,
        "rank_env": "REPLICA_RANK",
    }
    with tempfile.TemporaryDirectory() as scratch_dir:
        launch_args["rendezvous_uri"] = os.path.join(scratch_dir, "run")
        executor = LocalProcessExecutor()
        handles = executor.launch_replicas(
            n_replicas=n_replicas,
            entrypoint=f"{__name__}._replica_compute_and_sync",
            entrypoint_args=launch_args,
            timeout=180,
        )

        # One handle per replica, ranks assigned in order, all on the local backend.
        assert len(handles) == n_replicas
        for position, handle in enumerate(handles):
            assert handle.rank == position
            assert handle.backend_name == "local_process"

        results = executor.collect(handles, timeout=180)
        assert len(results) == n_replicas

        # First make sure nobody failed, so value checks below are meaningful.
        for res in results:
            assert res["status"] == "succeeded", \
                f"rank {res['rank']} failed: {res.get('error')}"

        # Rank k contributes rank_value * (k + 1), so the mean over N replicas
        # is rank_value * (1 + 2 + ... + N) / N.
        expected_mean = 10.0 * (n_replicas * (n_replicas + 1) / 2) / n_replicas
        for res in results:
            averaged_values = res["result"]["post"]
            for value in averaged_values:
                assert abs(value - expected_mean) < 1e-4, \
                    f"rank {res['rank']}: expected mean {expected_mean}, got {value}"
def _replica_two_round_sync(
    rendezvous_uri: str,
    world_size: int,
) -> dict:
    """Replica entrypoint that performs two back-to-back allreduce calls.

    Round one reduces a 2-element tensor filled with ``rank``; round two a
    2-element tensor filled with ``rank * 100``. The rank comes from the
    ``REPLICA_RANK`` environment variable.

    Returns a dict with the rank, the store's ``round_id`` after both calls,
    and both averaged tensors as plain lists.
    """
    rank = int(os.environ["REPLICA_RANK"])
    store = ObjectStoreAllReduce(
        rendezvous_uri,
        rank=rank,
        world_size=world_size,
        timeout_s=120.0,
    )

    # Clone each result so the two rounds' outputs are kept independently.
    first_mean = store.allreduce(torch.full((2,), float(rank))).clone()
    second_mean = store.allreduce(torch.full((2,), float(rank * 100))).clone()

    return {
        "rank": rank,
        "round_after_2_calls": store.round_id,
        "avg1": first_mean.tolist(),
        "avg2": second_mean.tolist(),
    }
def test_local_executor_handles_multiple_rounds():
    """Two consecutive rounds each give the right mean; round counter advances.

    Three replicas run ``_replica_two_round_sync``. Round one averages the
    values 0, 1, 2 (mean 1.0); round two averages 0, 100, 200 (mean 100.0).
    Every replica must report success, a ``round_id`` of 2 after both calls,
    and both means within 1e-4.
    """
    n_replicas = 3
    with tempfile.TemporaryDirectory() as td:
        rendezvous = os.path.join(td, "run-2round")
        executor = LocalProcessExecutor()
        handles = executor.launch_replicas(
            n_replicas=n_replicas,
            entrypoint=f"{__name__}._replica_two_round_sync",
            entrypoint_args={
                "rendezvous_uri": rendezvous,
                "world_size": n_replicas,
            },
            timeout=180,
        )
        results = executor.collect(handles, timeout=180)
        # Guard against a vacuous pass: every check below lives inside the
        # loop, so an empty or short result list would otherwise go unnoticed.
        assert len(results) == n_replicas
        for r in results:
            assert r["status"] == "succeeded", r.get("error")
            assert r["result"]["round_after_2_calls"] == 2
            # mean of 0,1,2 = 1.0
            assert all(abs(v - 1.0) < 1e-4 for v in r["result"]["avg1"])
            # mean of 0,100,200 = 100.0
            assert all(abs(v - 100.0) < 1e-4 for v in r["result"]["avg2"])
# ---------------------------------------------------------------------
# Live-S3 smoke (F4 step 1): the file:// → s3:// transport gap.
#
# ObjectStoreAllReduce's S3 branches (_init_fsspec/_put/_exists/_get over
# s3fs) only have mock coverage; this exercises them against REAL S3 with
# concurrent OS processes, relying on S3's strong read-after-write
# consistency (the poll loop's _exists()→_get() assumption). Gated on
# AWS_SMOKE=1 so it never runs in ordinary CI / on machines without creds.
#
# Run it with:
# AWS_SMOKE=1 AWS_REGION=us-west-2 \
# DILOCO_S3_RENDEZVOUS=s3://<sagemaker-bucket>/diloco-rdv \
# pytest composer_replication/diloco/serverless/tests/test_serverless_local.py \
# -k s3_rendezvous -s
#
# Use a sagemaker-named bucket: stock AmazonSageMakerFullAccess only grants
# S3 on buckets whose name contains "sagemaker"/"aws-glue" — a custom-named
# bucket would 403 the first PUT and hang every peer until timeout_s (F4 §3).
# Verified PASS 2026-06-09 against
# s3://amazon-sagemaker-386931836011-us-west-2-7597bf4d9a3d/diloco-rdv/.
# ---------------------------------------------------------------------
def _s3_smoke_enabled() -> bool:
return os.environ.get("AWS_SMOKE") == "1"
@pytest.mark.skipif(
    not _s3_smoke_enabled(),
    reason="live-S3 smoke; set AWS_SMOKE=1 (+ AWS creds, DILOCO_S3_RENDEZVOUS) to run",
)
@pytest.mark.parametrize("n_replicas", [2])
def test_s3_rendezvous_allreduce_across_replicas(n_replicas):
    """Live-S3 counterpart of test_local_executor_runs_allreduce_across_replicas.

    Checks the same property (N processes call allreduce and every rank ends
    with the cross-rank mean), but the rendezvous is an ``s3://`` prefix rather
    than a temp directory, so the exchange goes through s3fs against real S3
    and relies on S3's strong read-after-write consistency. It is the cheapest
    (≈$0, no GPU) way to close F4's documented "ObjectStoreAllReduce over
    s3:// never exercised against real S3" gap.

    The base prefix comes from ``DILOCO_S3_RENDEZVOUS`` (with a built-in
    default); each run uses a fresh random sub-prefix that is removed
    afterwards on a best-effort basis.
    """
    import uuid

    pytest.importorskip("s3fs", reason="s3fs required for the live-S3 smoke")
    import s3fs

    default_base = (
        "s3://amazon-sagemaker-386931836011-us-west-2-7597bf4d9a3d/diloco-rdv"
    )
    base = os.environ.get("DILOCO_S3_RENDEZVOUS", default_base).rstrip("/")
    run_tag = uuid.uuid4().hex[:8]
    rendezvous = f"{base}/smoke-{run_tag}/"
    # Scheme-less form of the same prefix, used for the direct s3fs calls.
    bucket_path = rendezvous.replace("s3://", "")

    launch_args = {
        "rendezvous_uri": rendezvous,
        "world_size": n_replicas,
        "rank_value": 10.0,
        "rank_env": "REPLICA_RANK",
    }
    executor = LocalProcessExecutor()
    handles = executor.launch_replicas(
        n_replicas=n_replicas,
        entrypoint=f"{__name__}._replica_compute_and_sync",
        entrypoint_args=launch_args,
        timeout=300,
    )
    try:
        results = executor.collect(handles, timeout=300)
        for res in results:
            assert res["status"] == "succeeded", (
                f"rank {res['rank']} failed (S3 rendezvous {rendezvous}): "
                f"{res.get('error')}"
            )

        # All ranks agreeing on the mean is only possible if each one read the
        # SAME peer objects through S3, i.e. the cross-process exchange worked.
        expected_mean = 10.0 * (n_replicas * (n_replicas + 1) / 2) / n_replicas
        for res in results:
            for value in res["result"]["post"]:
                assert abs(value - expected_mean) < 1e-4, (
                    f"rank {res['rank']}: expected S3-averaged mean {expected_mean}, "
                    f"got {value}"
                )

        # Each rank's pseudo-gradient object for round 0 must exist in S3.
        fs = s3fs.S3FileSystem()
        present = set(map(os.path.basename, fs.ls(bucket_path + "round_000000/")))
        wanted = {f"rank_{rank:04d}.pt" for rank in range(n_replicas)}
        missing = wanted - present
        assert not missing, f"missing rank objects in S3: {missing}"
    finally:
        # Best-effort cleanup so repeated smokes don't pile up prefixes; a
        # cleanup failure is deliberately ignored so it cannot mask the result.
        try:
            s3fs.S3FileSystem().rm(bucket_path, recursive=True)
        except Exception:
            pass
def _replica_that_raises(rendezvous_uri: str, world_size: int) -> dict:
"""Simulates a replica that crashes mid-run."""
rank = int(os.environ["REPLICA_RANK"])
if rank == 1:
raise RuntimeError(f"Simulated crash on rank {rank}")
return {"rank": rank, "ok": True}
def test_local_executor_reports_failed_replicas():
    """A crashing replica is reported as ``failed`` by collect() instead of
    causing a hang; the healthy rank still completes.

    Note (Wave 18): the timeouts here were raised from 30s to 90s. The test
    passed on its own but was occasionally timing out in full-suite runs when
    other parallel multiprocessing tests competed for CPU: 30s was tight for
    a cold subprocess start plus imports plus rendezvous-file IO under
    contention. 90s leaves comfortable headroom and does not change what is
    being tested (a subprocess crash surfaces as `failed` status, not a hang).
    """
    n_replicas = 2
    launch_timeout = collect_timeout = 90
    with tempfile.TemporaryDirectory() as scratch_dir:
        executor = LocalProcessExecutor()
        handles = executor.launch_replicas(
            n_replicas=n_replicas,
            entrypoint=f"{__name__}._replica_that_raises",
            entrypoint_args={
                "rendezvous_uri": os.path.join(scratch_dir, "run-failure"),
                "world_size": n_replicas,
            },
            timeout=launch_timeout,
        )
        results = executor.collect(handles, timeout=collect_timeout)

        status_by_rank = {}
        for res in results:
            status_by_rank[res["rank"]] = res["status"]
        assert status_by_rank[0] == "succeeded"
        assert status_by_rank[1] == "failed"

        # The error recorded for the crashed rank should mention the crash.
        crashed = next(res for res in results if res["rank"] == 1)
        failure_log = crashed.get("error") or ""
        assert "Simulated crash" in failure_log
# ---------------------------------------------------------------------
# Sanity: MockManager is shape-compatible with torchft Manager surface
# ---------------------------------------------------------------------
def test_mock_manager_shape_compat():
    """MockManager exposes the attributes and return shapes checked below,
    which mirror the torchft Manager surface used on the DiLoCo path."""
    from composer_replication.diloco.serverless import MockManager

    # torchft.Manager surface (audited from torchft/local_sgd.py DiLoCo path)
    required_surface = (
        "allreduce",
        "should_commit",
        "start_quorum",
        "wait_quorum",
        "current_step",
        "disallow_state_dict_read",
        "allow_state_dict_read",
        "register_state_dict_fn",
        "_use_async_quorum",
    )
    with tempfile.TemporaryDirectory() as scratch_dir:
        store = ObjectStoreAllReduce(
            scratch_dir, rank=0, world_size=1, timeout_s=10.0
        )
        mgr = MockManager(store)

        for attr_name in required_surface:
            assert hasattr(mgr, attr_name)

        assert mgr._use_async_quorum is False
        assert mgr.num_participants == 1
        assert mgr.rank == 0
        assert mgr.should_commit() is True

        # With one replica the averaging is a passthrough, but allreduce must
        # still return a Work-shaped object (DiLoCo calls .wait() on it). The
        # buffer itself is mutated in place by ObjectStoreAllReduce, so the
        # result is checked on `buf`, not on a return value.
        reference = torch.tensor([1.0, 2.0])
        buf = reference.clone()
        work = mgr.allreduce(buf)
        assert hasattr(work, "wait") and callable(work.wait)
        assert work.wait() is True
        torch.testing.assert_close(buf, reference, atol=1e-6, rtol=1e-6)
# ---------------------------------------------------------------------
# Public re-export surface (Wave 17a)
# ---------------------------------------------------------------------
def test_public_reexports_include_all_executors():
    """Pin the package-level re-exports of `composer_replication.diloco.serverless`.

    Every executor adapter the module docstring advertises must be importable
    from the package itself, not only LocalProcessExecutor.

    Background: Wave 16's user-journey reviewer found that ModalExecutor and
    HFJobsExecutor lived in `modal.py` / `hf_jobs.py` but were missing from the
    package's `__init__.py`, so users copying the documented
    `from composer_replication.diloco.serverless import ModalExecutor` line
    hit an ImportError. Wave 17a added the re-exports; this test keeps them
    from regressing.
    """
    import composer_replication.diloco.serverless as ss

    expected = {
        "LocalProcessExecutor",
        "ModalExecutor",
        "HFJobsExecutor",
        "MockManager",
        "ObjectStoreAllReduce",
        "ReplicaHandle",
        "ServerlessExecutor",
    }
    actual = set(ss.__all__)
    missing = expected - actual
    assert not missing, (
        f"Missing re-exports: {missing}. "
        f"__all__ should include every executor adapter the package "
        f"docstring documents."
    )

    # Being listed is not enough: each name must resolve on the package too.
    for name in expected:
        assert hasattr(ss, name), (
            f"{name} listed in __all__ but not present on package."
        )