# Source code for mujoco_warp._src.island

# Copyright 2026 The Newton Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from mujoco_warp._src import types
from mujoco_warp._src.types import ConstraintType
from mujoco_warp._src.types import EqType
from mujoco_warp._src.types import ObjType
from mujoco_warp._src.warp_util import event_scope
import warp as wp


@wp.kernel
def _tree_edges(
  # Model:
  nv: int,
  body_treeid: wp.array(dtype=int),
  jnt_dofadr: wp.array(dtype=int),
  dof_treeid: wp.array(dtype=int),
  geom_bodyid: wp.array(dtype=int),
  site_bodyid: wp.array(dtype=int),
  eq_type: wp.array(dtype=int),
  eq_obj1id: wp.array(dtype=int),
  eq_obj2id: wp.array(dtype=int),
  eq_objtype: wp.array(dtype=int),
  # Data in:
  nefc_in: wp.array(dtype=int),
  contact_geom_in: wp.array(dtype=wp.vec2i),
  efc_type_in: wp.array2d(dtype=int),
  efc_id_in: wp.array2d(dtype=int),
  efc_J_in: wp.array3d(dtype=float),
  njmax_in: int,
  # Out:
  tree_tree: wp.array3d(dtype=int),  # kernel_analyzer: off
):
  """Find tree edges.

  Each thread handles one (world, constraint row) pair and marks, in
  `tree_tree[worldid]`, which trees that constraint row touches:

  - a row touching a single tree sets the diagonal entry (self-edge);
  - a row touching two or more trees sets the off-diagonal entries in BOTH
    directions, so the matrix stays symmetric and the graph can be traversed
    from either endpoint.

  Connect/weld equalities, dof friction, joint limits and geom-geom contacts
  resolve their trees directly from model indices. Everything else (other
  equality types, contacts with a negative geom id, remaining constraint types)
  falls back to scanning the nonzero entries of the Jacobian row.

  Entries are only ever raised to 1 via `wp.atomic_max`, so the caller must
  clear `tree_tree` before launching.
  """
  worldid, efcid = wp.tid()

  # skip if beyond active constraints
  if efcid >= wp.min(njmax_in, nefc_in[worldid]):
    return

  efc_type = efc_type_in[worldid, efcid]
  efc_id = efc_id_in[worldid, efcid]

  # -1 means "no tree resolved for this endpoint"
  tree0 = int(-1)
  tree1 = int(-1)
  use_generic = int(0)

  # equality (connect/weld)
  if efc_type == ConstraintType.EQUALITY:
    eq_t = eq_type[efc_id]

    if eq_t == EqType.CONNECT or eq_t == EqType.WELD:
      b1 = eq_obj1id[efc_id]
      b2 = eq_obj2id[efc_id]

      # site semantics: object ids are site ids, map them to their bodies
      if eq_objtype[efc_id] == ObjType.SITE:
        b1 = site_bodyid[b1]
        b2 = site_bodyid[b2]

      tree0 = body_treeid[b1]
      tree1 = body_treeid[b2]
    else:
      # JOINT, TENDON, FLEX
      use_generic = 1

  # joint friction
  elif efc_type == ConstraintType.FRICTION_DOF:
    tree0 = dof_treeid[efc_id]

  # joint limit
  elif efc_type == ConstraintType.LIMIT_JOINT:
    tree0 = dof_treeid[jnt_dofadr[efc_id]]

  # contact
  elif (
    efc_type == ConstraintType.CONTACT_FRICTIONLESS
    or efc_type == ConstraintType.CONTACT_PYRAMIDAL
    or efc_type == ConstraintType.CONTACT_ELLIPTIC
  ):
    geom_pair = contact_geom_in[efc_id]
    g1 = geom_pair[0]
    g2 = geom_pair[1]

    # flex contacts have negative geom ids
    if g1 >= 0 and g2 >= 0:
      tree0 = body_treeid[geom_bodyid[g1]]
      tree1 = body_treeid[geom_bodyid[g2]]
    else:
      use_generic = 1

  # generic
  else:
    use_generic = 1

  # handle static bodies
  if use_generic == 0:
    # swap so tree0 is non-negative if possible
    if tree0 < 0 and tree1 >= 0:
      tree0 = tree1
      tree1 = -1

    # mark the edge (nothing is marked when neither endpoint has a tree)
    if tree0 >= 0:
      if tree1 < 0 or tree0 == tree1:
        # self-edge
        wp.atomic_max(tree_tree, worldid, tree0, tree0, 1)
      else:
        # cross-tree edge, written symmetrically
        t1 = wp.min(tree0, tree1)
        t2 = wp.max(tree0, tree1)
        wp.atomic_max(tree_tree, worldid, t1, t2, 1)
        wp.atomic_max(tree_tree, worldid, t2, t1, 1)
    return

  # generic: scan Jacobian row
  first_tree = int(-1)
  has_cross_edge = int(0)

  for dof in range(nv):
    # TODO(team): sparse efc_J
    # TODO(team): tree dof skip
    J_val = efc_J_in[worldid, efcid, dof]
    if J_val != 0.0:
      tree = dof_treeid[dof]
      if tree < 0:
        continue
      if first_tree == -1:
        first_tree = tree
      elif tree != first_tree:
        # connect every additional tree to the first one found; write both
        # directions so the adjacency matrix is symmetric, matching the
        # direct path above (a one-directional edge is invisible when the
        # graph is traversed starting from the higher-indexed tree)
        t1 = wp.min(first_tree, tree)
        t2 = wp.max(first_tree, tree)
        wp.atomic_max(tree_tree, worldid, t1, t2, 1)
        wp.atomic_max(tree_tree, worldid, t2, t1, 1)
        has_cross_edge = 1

  # the row touches exactly one tree: record it as a self-edge
  if first_tree >= 0 and has_cross_edge == 0:
    wp.atomic_max(tree_tree, worldid, first_tree, first_tree, 1)


def tree_edges(m: types.Model, d: types.Data, tree_tree: wp.array3d(dtype=int)):
  """Fill `tree_tree` with the tree-tree adjacency of the current constraints.

  Args:
    m: Model providing the index arrays passed to the kernel.
    d: Data providing the constraint rows (`nefc`, `efc.*`, `contact.geom`) and
      the launch dimensions (`nworld`, `njmax`).
    tree_tree: Output array; it is zeroed here and then written in place by
      the `_tree_edges` kernel.
  """
  # start from an empty graph before the kernel marks edges
  tree_tree.zero_()

  model_inputs = [
    m.nv,
    m.body_treeid,
    m.jnt_dofadr,
    m.dof_treeid,
    m.geom_bodyid,
    m.site_bodyid,
    m.eq_type,
    m.eq_obj1id,
    m.eq_obj2id,
    m.eq_objtype,
  ]
  data_inputs = [
    d.nefc,
    d.contact.geom,
    d.efc.type,
    d.efc.id,
    d.efc.J,
    d.njmax,
  ]

  # one thread per (world, constraint row)
  launch_dim = (d.nworld, d.njmax)
  wp.launch(
    kernel=_tree_edges,
    dim=launch_dim,
    inputs=model_inputs + data_inputs,
    outputs=[tree_tree],
  )


@wp.kernel
def _flood_fill(
    # Model:
    ntree: int,
    # In:
    tree_tree_in: wp.array3d(dtype=int),
    labels_in: wp.array2d(dtype=int),
    stack_in: wp.array2d(dtype=int),
    # Data out:
    nisland_out: wp.array(dtype=int),
    tree_island_out: wp.array2d(dtype=int),
    # Out:
    stack_out: wp.array2d(dtype=int),
):
  """DFS flood fill to discover islands using tree_tree matrix.

  One thread per world. Trees are visited in increasing index order; each
  unlabeled tree that has at least one nonzero entry in its `tree_tree_in` row
  seeds a new island, and every tree reachable from it through nonzero row
  entries receives the same island id. Trees whose row is entirely zero are
  left untouched.

  NOTE: the algorithm relies on aliasing. Labels are read through `labels_in`
  but written through `tree_island_out`, and the stack is read through
  `stack_in` but written through `stack_out`. The `island` caller in this file
  passes the same array for each pair; with distinct arrays the "already
  assigned" checks and the stack pops would not observe this kernel's writes.

  A label of -1 means "not yet assigned". The number of islands found is
  written to `nisland_out[worldid]`.
  """
  worldid = wp.tid()
  nisland = int(0)

  # iterate over trees
  for i in range(ntree):
    # already assigned
    if labels_in[worldid, i] != -1:
      continue

    # check if tree has any edges
    has_edge = int(0)
    for j in range(ntree):
      if tree_tree_in[worldid, i, j] != 0:
        has_edge = 1
        break
    # trees with an all-zero row do not start (or join) an island
    if has_edge == 0:
      continue

    # DFS: push i onto stack
    nstack = int(0)
    stack_out[worldid, nstack] = i
    nstack = nstack + 1

    while nstack > 0:
      # pop v from stack
      nstack = nstack - 1
      v = stack_in[worldid, nstack]

      # already assigned (a tree can be pushed more than once before it is
      # popped, so duplicates are filtered here)
      if labels_in[worldid, v] != -1:
        continue

      # assign to current island
      tree_island_out[worldid, v] = nisland

      # push neighbors: only row v is consulted, so an edge must be present at
      # [v, neighbor] to be followed from v
      for neighbor in range(ntree):
        if tree_tree_in[worldid, v, neighbor] != 0:
          if labels_in[worldid, neighbor] == -1:
            stack_out[worldid, nstack] = neighbor
            nstack = nstack + 1

    # island filled
    nisland = nisland + 1

  nisland_out[worldid] = nisland


@event_scope
def island(m: types.Model, d: types.Data):
  """Discover constraint islands.

  Builds the tree-tree adjacency matrix from the active constraints, then
  flood-fills it to label connected groups of trees.

  Results are written into `d`:
    - `d.nisland`: number of islands found per world.
    - `d.tree_island`: island id per tree; trees that take part in no
      constraint edge keep the value -1.

  Args:
    m: Model; `m.ntree` sizes the adjacency matrix and the DFS stack.
    d: Data holding the constraints to analyze and receiving the results.
  """
  # nothing to label when the model has no trees
  if m.ntree == 0:
    d.nisland.zero_()
    return

  # Step 1: Find tree edges
  tree_tree = wp.zeros((d.nworld, m.ntree, m.ntree), dtype=int)
  tree_edges(m, d, tree_tree)

  # Step 2: DFS flood fill
  # -1 marks a tree as not yet assigned to an island
  d.tree_island.fill_(-1)
  stack_scratch = wp.empty((d.nworld, m.ntree * m.ntree), dtype=int)
  # d.tree_island and stack_scratch are deliberately passed as both input and
  # output: the kernel reads back the labels and stack entries it writes
  wp.launch(
    _flood_fill,
    dim=d.nworld,
    inputs=[m.ntree, tree_tree, d.tree_island, stack_scratch],
    outputs=[d.nisland, d.tree_island, stack_scratch],
  )