import { NearestFilter, RenderTarget, Vector2, RendererUtils, QuadMesh, TempNode, NodeMaterial, NodeUpdateType } from 'three/webgpu';
import { reference, viewZToPerspectiveDepth, logarithmicDepthToViewZ, getScreenPosition, getViewPosition, sqrt, mul, div, cross, float, Continue, Break, Loop, int, max, abs, sub, If, dot, reflect, normalize, screenCoordinate, nodeObject, Fn, passTexture, uv, uniform, perspectiveDepthToViewZ, orthographicDepthToViewZ, vec2, vec3, vec4 } from 'three/tsl';
// Shared fullscreen quad used to render the SSR pass into the render target.
const _quadMesh = /*@__PURE__*/ new QuadMesh();

// Scratch vector for querying the renderer's drawing buffer size without allocating.
const _size = /*@__PURE__*/ new Vector2();

// Holds the renderer state captured before the SSR pass so it can be restored afterwards.
let _rendererState;
/**
 * Post processing node for computing screen space reflections (SSR).
 *
 * Reference: {@link https://lettier.github.io/3d-game-shaders-for-beginners/screen-space-reflection.html}
 *
 * @augments TempNode
 * @three_import import { ssr } from 'three/addons/tsl/display/SSRNode.js';
 */
class SSRNode extends TempNode {

	static get type() {

		return 'SSRNode';

	}

	/**
	 * Constructs a new SSR node.
	 *
	 * @param {Node<vec4>} colorNode - The node that represents the beauty pass.
	 * @param {Node<float>} depthNode - A node that represents the beauty pass's depth.
	 * @param {Node<vec3>} normalNode - A node that represents the beauty pass's normals.
	 * @param {Node<float>} metalnessNode - A node that represents the beauty pass's metalness.
	 * @param {Camera} camera - The camera the scene is rendered with.
	 */
	constructor( colorNode, depthNode, normalNode, metalnessNode, camera ) {

		super( 'vec4' );

		/**
		 * The node that represents the beauty pass.
		 *
		 * @type {Node<vec4>}
		 */
		this.colorNode = colorNode;

		/**
		 * A node that represents the beauty pass's depth.
		 *
		 * @type {Node<float>}
		 */
		this.depthNode = depthNode;

		/**
		 * A node that represents the beauty pass's normals.
		 *
		 * @type {Node<vec3>}
		 */
		this.normalNode = normalNode;

		/**
		 * A node that represents the beauty pass's metalness.
		 *
		 * @type {Node<float>}
		 */
		this.metalnessNode = metalnessNode;

		/**
		 * The camera the scene is rendered with.
		 *
		 * @type {Camera}
		 */
		this.camera = camera;

		/**
		 * The resolution scale. By default SSR reflections
		 * are computed in half resolutions. Setting the value
		 * to `1` improves quality but also results in more
		 * computational overhead.
		 *
		 * @type {number}
		 * @default 0.5
		 */
		this.resolutionScale = 0.5;

		/**
		 * The `updateBeforeType` is set to `NodeUpdateType.FRAME` since the node renders
		 * its effect once per frame in `updateBefore()`.
		 *
		 * @type {string}
		 * @default 'frame'
		 */
		this.updateBeforeType = NodeUpdateType.FRAME;

		/**
		 * The render target the SSR is rendered into.
		 *
		 * @private
		 * @type {RenderTarget}
		 */
		this._ssrRenderTarget = new RenderTarget( 1, 1, { depthBuffer: false, minFilter: NearestFilter, magFilter: NearestFilter } );
		this._ssrRenderTarget.texture.name = 'SSRNode.SSR';

		/**
		 * Controls how far a fragment can reflect.
		 *
		 * @type {UniformNode<float>}
		 */
		this.maxDistance = uniform( 1 );

		/**
		 * Controls the cutoff between what counts as a possible reflection hit and what does not.
		 *
		 * @type {UniformNode<float>}
		 */
		this.thickness = uniform( 0.1 );

		/**
		 * Controls the transparency of the reflected colors.
		 *
		 * @type {UniformNode<float>}
		 */
		this.opacity = uniform( 1 );

		/**
		 * Represents the projection matrix of the scene's camera.
		 *
		 * @private
		 * @type {UniformNode<mat4>}
		 */
		this._cameraProjectionMatrix = uniform( camera.projectionMatrix );

		/**
		 * Represents the inverse projection matrix of the scene's camera.
		 *
		 * @private
		 * @type {UniformNode<mat4>}
		 */
		this._cameraProjectionMatrixInverse = uniform( camera.projectionMatrixInverse );

		/**
		 * Represents the near value of the scene's camera.
		 *
		 * @private
		 * @type {ReferenceNode<float>}
		 */
		this._cameraNear = reference( 'near', 'float', camera );

		/**
		 * Represents the far value of the scene's camera.
		 *
		 * @private
		 * @type {ReferenceNode<float>}
		 */
		this._cameraFar = reference( 'far', 'float', camera );

		/**
		 * Whether the scene's camera is perspective (1) or orthographic (0).
		 *
		 * @private
		 * @type {UniformNode<bool>}
		 */
		this._isPerspectiveCamera = uniform( camera.isPerspectiveCamera ? 1 : 0 );

		/**
		 * The resolution of the pass.
		 *
		 * @private
		 * @type {UniformNode<vec2>}
		 */
		this._resolution = uniform( new Vector2() );

		/**
		 * This value is derived from the resolution and restricts
		 * the maximum raymarching steps in the fragment shader.
		 *
		 * @private
		 * @type {UniformNode<float>}
		 */
		this._maxStep = uniform( 0 );

		/**
		 * The material that is used to render the effect.
		 *
		 * @private
		 * @type {NodeMaterial}
		 */
		this._material = new NodeMaterial();
		this._material.name = 'SSRNode.SSR';

		/**
		 * The result of the effect is represented as a separate texture node.
		 *
		 * @private
		 * @type {PassTextureNode}
		 */
		this._textureNode = passTexture( this, this._ssrRenderTarget.texture );

	}

	/**
	 * Returns the result of the effect as a texture node.
	 *
	 * @return {PassTextureNode} A texture node that represents the result of the effect.
	 */
	getTextureNode() {

		return this._textureNode;

	}

	/**
	 * Sets the size of the effect.
	 *
	 * @param {number} width - The width of the effect.
	 * @param {number} height - The height of the effect.
	 */
	setSize( width, height ) {

		// the effective resolution is scaled by `resolutionScale` (0.5 by default)
		width = Math.round( this.resolutionScale * width );
		height = Math.round( this.resolutionScale * height );

		this._resolution.value.set( width, height );

		// the screen diagonal in pixels bounds the number of raymarching steps
		this._maxStep.value = Math.round( Math.sqrt( width * width + height * height ) );

		this._ssrRenderTarget.setSize( width, height );

	}

	/**
	 * This method is used to render the effect once per frame.
	 *
	 * @param {NodeFrame} frame - The current node frame.
	 */
	updateBefore( frame ) {

		const { renderer } = frame;

		// capture the renderer state so it can be restored after the pass
		_rendererState = RendererUtils.resetRendererState( renderer, _rendererState );

		const size = renderer.getDrawingBufferSize( _size );

		_quadMesh.material = this._material;

		this.setSize( size.width, size.height );

		// clear
		renderer.setMRT( null );
		renderer.setClearColor( 0x000000, 0 );

		// ssr
		renderer.setRenderTarget( this._ssrRenderTarget );
		_quadMesh.render( renderer );

		// restore
		RendererUtils.restoreRendererState( renderer, _rendererState );

	}

	/**
	 * This method is used to setup the effect's TSL code.
	 *
	 * @param {NodeBuilder} builder - The current node builder.
	 * @return {PassTextureNode}
	 */
	setup( builder ) {

		const uvNode = uv();

		// distance of a point to the line through linePointA and linePointB (all in view space)
		const pointToLineDistance = Fn( ( [ point, linePointA, linePointB ] ) => {

			// https://mathworld.wolfram.com/Point-LineDistance3-Dimensional.html

			return cross( point.sub( linePointA ), point.sub( linePointB ) ).length().div( linePointB.sub( linePointA ).length() );

		} );

		// signed distance of a point to the plane defined by planePoint and planeNormal
		const pointPlaneDistance = Fn( ( [ point, planePoint, planeNormal ] ) => {

			// https://mathworld.wolfram.com/Point-PlaneDistance.html
			// https://en.wikipedia.org/wiki/Plane_(geometry)
			// http://paulbourke.net/geometry/pointlineplane/

			const d = mul( planeNormal.x, planePoint.x ).add( mul( planeNormal.y, planePoint.y ) ).add( mul( planeNormal.z, planePoint.z ) ).negate().toVar();
			const denominator = sqrt( mul( planeNormal.x, planeNormal.x, ).add( mul( planeNormal.y, planeNormal.y ) ).add( mul( planeNormal.z, planeNormal.z ) ) ).toVar();
			const distance = div( mul( planeNormal.x, point.x ).add( mul( planeNormal.y, point.y ) ).add( mul( planeNormal.z, point.z ) ).add( d ), denominator );
			return distance;

		} );

		// converts a depth value to view space z, honoring the camera's projection type
		const getViewZ = Fn( ( [ depth ] ) => {

			let viewZNode;

			if ( this.camera.isPerspectiveCamera ) {

				viewZNode = perspectiveDepthToViewZ( depth, this._cameraNear, this._cameraFar );

			} else {

				viewZNode = orthographicDepthToViewZ( depth, this._cameraNear, this._cameraFar );

			}

			return viewZNode;

		} );

		// samples the depth texture, normalizing logarithmic depth to perspective depth if required
		const sampleDepth = ( uv ) => {

			const depth = this.depthNode.sample( uv ).r;

			if ( builder.renderer.logarithmicDepthBuffer === true ) {

				const viewZ = logarithmicDepthToViewZ( depth, this._cameraNear, this._cameraFar );

				return viewZToPerspectiveDepth( viewZ, this._cameraNear, this._cameraFar );

			}

			return depth;

		};

		const ssr = Fn( () => {

			const metalness = this.metalnessNode.sample( uvNode ).r;

			// fragments with no metalness do not reflect their environment

			metalness.equal( 0.0 ).discard();

			// compute some standard FX entities

			const depth = sampleDepth( uvNode ).toVar();
			const viewPosition = getViewPosition( uvNode, depth, this._cameraProjectionMatrixInverse ).toVar();
			const viewNormal = this.normalNode.rgb.normalize().toVar();

			// compute the direction from the position in view space to the camera

			const viewIncidentDir = ( ( this.camera.isPerspectiveCamera ) ? normalize( viewPosition ) : vec3( 0, 0, - 1 ) ).toVar();

			// compute the direction in which the light is reflected on the surface

			const viewReflectDir = reflect( viewIncidentDir, viewNormal ).toVar();

			// adapt maximum distance to the local geometry (see https://www.mathsisfun.com/algebra/vectors-dot-product.html)

			const maxReflectRayLen = this.maxDistance.div( dot( viewIncidentDir.negate(), viewNormal ) ).toVar();

			// compute the maximum point of the reflection ray in view space

			const d1viewPosition = viewPosition.add( viewReflectDir.mul( maxReflectRayLen ) ).toVar();

			// check if d1viewPosition lies behind the camera near plane

			If( this._isPerspectiveCamera.equal( float( 1 ) ).and( d1viewPosition.z.greaterThan( this._cameraNear.negate() ) ), () => {

				// if so, ensure d1viewPosition is clamped on the near plane.
				// this prevents artifacts during the ray marching process

				const t = sub( this._cameraNear.negate(), viewPosition.z ).div( viewReflectDir.z );
				d1viewPosition.assign( viewPosition.add( viewReflectDir.mul( t ) ) );

			} );

			// d0 and d1 are the start and maximum points of the reflection ray in screen space

			const d0 = screenCoordinate.xy.toVar();
			const d1 = getScreenPosition( d1viewPosition, this._cameraProjectionMatrix ).mul( this._resolution ).toVar();

			// below variables are used to control the raymarching process

			// total length of the ray

			const totalLen = d1.sub( d0 ).length().toVar();

			// offset in x and y direction

			const xLen = d1.x.sub( d0.x ).toVar();
			const yLen = d1.y.sub( d0.y ).toVar();

			// determine the larger delta
			// The larger difference will help to determine how much to travel in the X and Y direction each iteration and
			// how many iterations are needed to travel the entire ray

			const totalStep = max( abs( xLen ), abs( yLen ) ).toVar();

			// step sizes in the x and y directions

			const xSpan = xLen.div( totalStep ).toVar();
			const ySpan = yLen.div( totalStep ).toVar();

			const output = vec4( 0 ).toVar();

			// the actual ray marching loop
			// starting from d0, the code gradually travels along the ray and looks for an intersection with the geometry.
			// it does not exceed d1 (the maximum ray extend)

			Loop( { start: int( 0 ), end: int( this._maxStep ), type: 'int', condition: '<' }, ( { i } ) => {

				// TODO: Remove this when Chrome is fixed, see https://issues.chromium.org/issues/372714384#comment14
				If( metalness.equal( 0 ), () => {

					Break();

				} );

				// stop if the maximum number of steps is reached for this specific ray

				If( float( i ).greaterThanEqual( totalStep ), () => {

					Break();

				} );

				// advance on the ray by computing a new position in screen space

				const xy = vec2( d0.x.add( xSpan.mul( float( i ) ) ), d0.y.add( ySpan.mul( float( i ) ) ) ).toVar();

				// stop processing if the new position lies outside of the screen

				If( xy.x.lessThan( 0 ).or( xy.x.greaterThan( this._resolution.x ) ).or( xy.y.lessThan( 0 ) ).or( xy.y.greaterThan( this._resolution.y ) ), () => {

					Break();

				} );

				// compute new uv, depth, viewZ and viewPosition for the new location on the ray

				const uvNode = xy.div( this._resolution );
				const d = sampleDepth( uvNode ).toVar();
				const vZ = getViewZ( d ).toVar();
				const vP = getViewPosition( uvNode, d, this._cameraProjectionMatrixInverse ).toVar();

				const viewReflectRayZ = float( 0 ).toVar();

				// normalized distance between the current position xy and the starting point d0

				const s = xy.sub( d0 ).length().div( totalLen );

				// depending on the camera type, we now compute the z-coordinate of the reflected ray at the current step in view space

				If( this._isPerspectiveCamera.equal( float( 1 ) ), () => {

					// perspective interpolation: z is interpolated in reciprocal (1/z) space

					const recipVPZ = float( 1 ).div( viewPosition.z ).toVar();
					viewReflectRayZ.assign( float( 1 ).div( recipVPZ.add( s.mul( float( 1 ).div( d1viewPosition.z ).sub( recipVPZ ) ) ) ) );

				} ).Else( () => {

					// orthographic interpolation: z is interpolated linearly

					viewReflectRayZ.assign( viewPosition.z.add( s.mul( d1viewPosition.z.sub( viewPosition.z ) ) ) );

				} );

				// if viewReflectRayZ is less or equal than the real z-coordinate at this place, it potentially intersects the geometry

				If( viewReflectRayZ.lessThanEqual( vZ ), () => {

					// compute the distance of the new location to the ray in view space
					// to clarify vP is the fragment's view position which is not an exact point on the ray

					const away = pointToLineDistance( vP, viewPosition, d1viewPosition ).toVar();

					// compute the minimum thickness between the current fragment and its neighbor in the x-direction.

					const xyNeighbor = vec2( xy.x.add( 1 ), xy.y ).toVar(); // move one pixel
					const uvNeighbor = xyNeighbor.div( this._resolution );
					const vPNeighbor = getViewPosition( uvNeighbor, d, this._cameraProjectionMatrixInverse ).toVar();

					const minThickness = vPNeighbor.x.sub( vP.x ).toVar();
					minThickness.mulAssign( 3 ); // expand a bit to avoid errors

					const tk = max( minThickness, this.thickness ).toVar();

					If( away.lessThanEqual( tk ), () => { // hit

						const vN = this.normalNode.sample( uvNode ).rgb.normalize().toVar();

						If( dot( viewReflectDir, vN ).greaterThanEqual( 0 ), () => {

							// the reflected ray is pointing towards the same side as the fragment's normal (current ray position),
							// which means it wouldn't reflect off the surface. The loop continues to the next step for the next ray sample.

							Continue();

						} );

						// this distance represents the depth of the intersection point between the reflected ray and the scene.

						const distance = pointPlaneDistance( vP, viewPosition, viewNormal ).toVar();

						If( distance.greaterThan( this.maxDistance ), () => {

							// Distance exceeding limit: The reflection is potentially too far away and
							// might not contribute significantly to the final color

							Break();

						} );

						const op = this.opacity.mul( metalness ).toVar();

						// distance attenuation (the reflection should fade out the farther it is away from the surface)

						const ratio = float( 1 ).sub( distance.div( this.maxDistance ) ).toVar();
						const attenuation = ratio.mul( ratio );
						op.mulAssign( attenuation );

						// fresnel (reflect more light on surfaces that are viewed at grazing angles)

						const fresnelCoe = div( dot( viewIncidentDir, viewReflectDir ).add( 1 ), 2 );
						op.mulAssign( fresnelCoe );

						// output

						const reflectColor = this.colorNode.sample( uvNode );
						output.assign( vec4( reflectColor.rgb, op ) );
						Break();

					} );

				} );

			} );

			return output;

		} );

		this._material.fragmentNode = ssr().context( builder.getSharedContext() );
		this._material.needsUpdate = true;

		return this._textureNode;

	}

	/**
	 * Frees internal resources. This method should be called
	 * when the effect is no longer required.
	 */
	dispose() {

		this._ssrRenderTarget.dispose();
		this._material.dispose();

	}

}
  379. export default SSRNode;
  380. /**
  381. * TSL function for creating screen space reflections (SSR).
  382. *
  383. * @tsl
  384. * @function
  385. * @param {Node<vec4>} colorNode - The node that represents the beauty pass.
  386. * @param {Node<float>} depthNode - A node that represents the beauty pass's depth.
  387. * @param {Node<vec3>} normalNode - A node that represents the beauty pass's normals.
  388. * @param {Node<float>} metalnessNode - A node that represents the beauty pass's metalness.
  389. * @param {Camera} camera - The camera the scene is rendered with.
  390. * @returns {SSRNode}
  391. */
  392. export const ssr = ( colorNode, depthNode, normalNode, metalnessNode, camera ) => nodeObject( new SSRNode( nodeObject( colorNode ), nodeObject( depthNode ), nodeObject( normalNode ), nodeObject( metalnessNode ), camera ) );